In [1]:
%pip install -U -q seaborn

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import os
import sys
import torch
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Optional, Dict

In [3]:
# Set the three pandas display limits (rows, columns, column width) to None
# so that displayed frames are not truncated.
for _display_option in ("display.max_rows", "display.max_columns", "display.max_colwidth"):
    pd.set_option(_display_option, None)

In [4]:
def _load_pt_file(pt_path: Path | str) -> torch.Tensor:
    obj = torch.load(pt_path, map_location="cpu")
    if isinstance(obj, torch.Tensor):
        T = obj
    else:
        raise ValueError(f"Unsupported saved object in {pt_path}")

    if T.dim() != 2:
        raise ValueError(f"Expected 2D tensor [L, D], got shape {T.shape} in {p}")

    return T.contiguous()

In [5]:
# Smoke-check the loader on a single file; the cell displays the shape of the returned tensor.
_load_pt_file(Path("../smoketest/artifacts/pts/A0A077B1I6_CHIKV.pt")).shape

torch.Size([2474, 320])

In [6]:
# Load one tensor into T1 (the path is passed as a plain str here) and display it.
T1 = _load_pt_file("../smoketest/artifacts/pts/A8D0M1_ADE02.pt")
T1

tensor([[ 0.1928,  0.0971, -0.2783,  ...,  0.6644,  0.0129, -0.3273],
        [-0.0482, -0.3755, -0.2580,  ...,  0.1180,  0.5264, -0.0385],
        [ 0.0179, -0.4515,  0.1118,  ...,  0.1207,  0.1228,  0.1003],
        ...,
        [ 0.1070,  0.3048,  0.0387,  ..., -0.1975, -0.0663, -0.2012],
        [-0.0979,  0.0137, -0.0529,  ...,  0.2223, -0.2307, -0.3664],
        [ 0.1540, -0.0168,  0.1488,  ...,  0.1536, -0.2626, -0.5395]])

In [7]:
# Display the shape of T1 as loaded above.
T1.shape

torch.Size([80, 320])

In [8]:
def _seq_stats(T: torch.Tensor) -> Dict[str, float | int]:
    L, D = T.shape
    norms = torch.linalg.vector_norm(T, ord=2, dim=1)
    return {"L": int(L), "D": int(D), 
            "norm_min": float(norms.min()), 
            "norm_max": float(norms.max()), 
            "norms_mean": float(norms.mean()), 
            "norms_std": float(norms.std(unbiased=False))}

In [9]:
# Compute the shape / row-norm summary for T1 and display the resulting dict.
T1_stats = _seq_stats(T1)
T1_stats

{'L': 80,
 'D': 320,
 'norm_min': 6.438788890838623,
 'norm_max': 7.202156066894531,
 'norms_mean': 6.753352165222168,
 'norms_std': 0.176552414894104}

In [10]:
def cmd_summary(per_seq_dir: Path | str, out_csv: Path | str) -> None:
    rows = []
    for pt in sorted(per_seq_dir.glob("*.pt")):
        try:
            T = _load_pt_file(pt)
            s = _seq_stats(T)
            s["file"] = str(pt)
            s["id"] = pt.stem
            rows.append(s)
        except Exception as e:
            rows.append({"file": str(pt), "id": pt.stem, "error": str(e)})

    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)
    print(f"Wrote summary to {out_csv} with {len(df)} rows")
    print(df.head().to_string(index=False))

In [11]:
# Summarise every *.pt file under the pts directory and write the table to analysis.csv.
cmd_summary(Path("../smoketest/artifacts/pts"), Path("../smoketest/artifacts/analysis.csv"))

Wrote summary to ..\smoketest\artifacts\analysis.csv with 4 rows
   L   D  norm_min  norm_max  norms_mean  norms_std                                           file               id
2474 320  6.192690  8.448963    6.638195   0.295635 ..\smoketest\artifacts\pts\A0A077B1I6_CHIKV.pt A0A077B1I6_CHIKV
2474 320  6.197494  8.450518    6.635635   0.294995 ..\smoketest\artifacts\pts\A0A2R4P450_CHIKV.pt A0A2R4P450_CHIKV
  80 320  6.438789  7.202156    6.753352   0.176552     ..\smoketest\artifacts\pts\A8D0M1_ADE02.pt     A8D0M1_ADE02
  61 320  6.438868  7.112586    6.621868   0.101677     ..\smoketest\artifacts\pts\J9Z4E7_9ADEN.pt     J9Z4E7_9ADEN


In [12]:
def cmd_residue_norms(pt_file: Path | str, export_csv: Optional[Path | str] = None) -> None:
    """Report the L2 norm of every row of the tensor stored in ``pt_file``.

    Args:
        pt_file: File loaded through ``_load_pt_file``.
        export_csv: If truthy, the full table (columns ``i`` and ``l2_norm``)
            is written there as CSV. Otherwise the first ten rows are printed,
            followed by the row count and the mean / std of the norms.

    Returns:
        None. The result is either written to disk or printed.
    """
    embedding = _load_pt_file(pt_file)
    row_norms = torch.linalg.vector_norm(embedding, ord=2, dim=1).numpy()
    n_rows = len(row_norms)
    table = pd.DataFrame({"i": np.arange(n_rows, dtype=int), "l2_norm": row_norms})

    if not export_csv:
        # Preview mode: a short head of the table plus aggregate statistics.
        print(table.head(10).to_string(index=False))
        print(f"   L={n_rows}  |  mean={row_norms.mean():.3f}  |  std={row_norms.std():.3f}")
        return

    table.to_csv(export_csv, index=False)
    print(f"Wrote residue norms to {export_csv}  |  L={n_rows}")

In [13]:
# Print a residue-norm preview for every *.pt file in the directory, in sorted order.
# The CSV export argument is left commented out, so nothing is written to disk here.
per_seq_dir = Path("../smoketest/artifacts/pts")
for pt in sorted(per_seq_dir.glob("*.pt")):
    cmd_residue_norms(pt)#, export_csv=Path(f"./smoketest/artifacts/pts/norms/{pt.stem}.csv"))

 i  l2_norm
 0 6.955391
 1 6.581041
 2 6.475143
 3 6.443740
 4 6.555764
 5 6.513677
 6 6.554464
 7 6.540605
 8 6.524702
 9 6.534371
   L=2474  |  mean=6.638  |  std=0.296
 i  l2_norm
 0 6.939563
 1 6.584851
 2 6.465514
 3 6.435401
 4 6.551120
 5 6.484616
 6 6.556869
 7 6.540776
 8 6.537248
 9 6.510118
   L=2474  |  mean=6.636  |  std=0.295
 i  l2_norm
 0 7.202156
 1 6.803749
 2 6.692681
 3 6.786354
 4 6.609060
 5 6.647263
 6 6.780688
 7 6.685127
 8 6.910825
 9 6.678438
   L=80  |  mean=6.753  |  std=0.177
 i  l2_norm
 0 7.112586
 1 6.697171
 2 6.631411
 3 6.571994
 4 6.592375
 5 6.571852
 6 6.601831
 7 6.646228
 8 6.535510
 9 6.697135
   L=61  |  mean=6.622  |  std=0.102


In [14]:
# cosine similarity
def _cos(a: torch.Tensor, b: torch.Tensor) -> float:
    a, b = a.float(), b.float()
    a_n = torch.linalg.vector_norm(a)
    b_n = torch.linalg.vector_norm(b)
    if a_n == 0 or b_n == 0:
        return float("NaN")
    return float((a @ b) / (a_n * b_n))

def cos_similarity(T1: torch.Tensor, T2: torch.Tensor, i1: int, i2: int) -> None:
    """Print the cosine similarity between row ``i1`` of ``T1`` and row ``i2`` of ``T2``.

    Args:
        T1: First tensor; ``i1`` indexes its first dimension.
        T2: Second tensor; ``i2`` indexes its first dimension.
        i1: Row index into ``T1``; must satisfy ``0 <= i1 < T1.shape[0]``.
        i2: Row index into ``T2``; must satisfy ``0 <= i2 < T2.shape[0]``.

    Raises:
        IndexError: If either index is outside its valid range. Negative
            indices are rejected too.
    """
    n1, n2 = T1.shape[0], T2.shape[0]
    both_in_range = (0 <= i1 < n1) and (0 <= i2 < n2)
    if not both_in_range:
        raise IndexError(f"Indices out of range\n   T1={n1}, i1={i1}\n   T2={n2}, i2={i2}")

    sim = _cos(T1[i1], T2[i2])
    print(f"cos(T1[{i1}], T2[{i2}])={sim:.6}")

In [15]:
# Load two tensors and print the cosine similarity of every (i, j) row pair.
T1 = _load_pt_file(Path("../smoketest/artifacts/pts/A8D0M1_ADE02.pt"))
T2 = _load_pt_file(Path("../smoketest/artifacts/pts/J9Z4E7_9ADEN.pt"))
T1_rows, T2_rows = T1.shape[0], T2.shape[0]
# Both loops stop at the smaller row count, so rows beyond `limit` in the
# longer tensor are not compared. This prints limit * limit lines.
limit = min(T1_rows, T2_rows)
for i in range(limit):
    for j in range(limit):
        cos_similarity(T1, T2, i, j)

cos(T1[0], T2[0])=0.8003
cos(T1[0], T2[1])=0.458844
cos(T1[0], T2[2])=0.452129
cos(T1[0], T2[3])=0.487064
cos(T1[0], T2[4])=0.476941
cos(T1[0], T2[5])=0.462267
cos(T1[0], T2[6])=0.47371
cos(T1[0], T2[7])=0.461989
cos(T1[0], T2[8])=0.504824
cos(T1[0], T2[9])=0.428601
cos(T1[0], T2[10])=0.504682
cos(T1[0], T2[11])=0.491778
cos(T1[0], T2[12])=0.497401
cos(T1[0], T2[13])=0.489229
cos(T1[0], T2[14])=0.478003
cos(T1[0], T2[15])=0.516583
cos(T1[0], T2[16])=0.491391
cos(T1[0], T2[17])=0.538426
cos(T1[0], T2[18])=0.496446
cos(T1[0], T2[19])=0.534545
cos(T1[0], T2[20])=0.507407
cos(T1[0], T2[21])=0.493278
cos(T1[0], T2[22])=0.534773
cos(T1[0], T2[23])=0.505152
cos(T1[0], T2[24])=0.5231
cos(T1[0], T2[25])=0.503875
cos(T1[0], T2[26])=0.497416
cos(T1[0], T2[27])=0.491831
cos(T1[0], T2[28])=0.494024
cos(T1[0], T2[29])=0.480601
cos(T1[0], T2[30])=0.504853
cos(T1[0], T2[31])=0.500082
cos(T1[0], T2[32])=0.523447
cos(T1[0], T2[33])=0.480703
cos(T1[0], T2[34])=0.500881
cos(T1[0], T2[35])=0.559979
cos(T1[

In [16]:
# Load one tensor into long_seq and display its shape.
long_seq = _load_pt_file(Path("../smoketest/artifacts/pts/A0A077B1I6_CHIKV.pt"))
long_seq.shape

torch.Size([2474, 320])