In [None]:
# Quietly (-q) upgrade (-U) seaborn to the latest release; the version is unpinned.
# NOTE(review): seaborn is only referenced by a commented-out import in the
# imports cell — pin a version or drop this install if plotting stays disabled.
%pip install -U -q seaborn

In [None]:
import torch
from pathlib import Path
import pandas as pd
import numpy as np
# import matplotlib.pyplot as plt
# import seaborn as sns
from typing import Optional, Dict
import torch.serialization as ts

In [None]:
# Show every row/column and full cell contents when DataFrames are rendered.
pd.set_option(
    "display.max_rows", None,
    "display.max_columns", None,
    "display.max_colwidth", None,
)

In [None]:
def _load_pt_file(pt_path: Path | str) -> torch.Tensor:
    obj = torch.load(pt_path, map_location="cpu")
    if isinstance(obj, torch.Tensor):
        T = obj
    else:
        raise ValueError(f"Unsupported saved object in {pt_path}")

    if T.dim() != 2:
        raise ValueError(f"Expected 2D tensor [L, D], got shape {T.shape} in {pt_path}")

    return T.contiguous()

In [None]:
# Four sample embeddings, loaded in order into T0..T3.
T0, T1, T2, T3 = (
    _load_pt_file(Path(path_str))
    for path_str in (
        "../data/sample/embeddings/U3M724_CVHOC.pt",
        "../data/sample/embeddings/A0A1B0RMV1_9ORTO.pt",
        "../data/sample/embeddings/A0A2I6Q731_9RHAB.pt",
        "../data/sample/embeddings/K4GMQ4_9ENTO.pt",
    )
)

In [None]:
# Full shape of T0, including its last column.
T0.size()

In [None]:
# Display the tensor loaded into T1 for a quick visual check.
T1

In [None]:
# Display the tensor loaded into T2 for a quick visual check.
T2

In [None]:
# Display the tensor loaded into T3 for a quick visual check.
T3

In [None]:
# Shape of T0 once its trailing column is dropped (the slice _seq_stats uses).
T0[:, :-1].size()

In [None]:
# REMOVES LENGTH TO CALCULATE STATS
def _seq_stats(T: torch.Tensor) -> Dict[str, float | int]:
    L, D = T[:, :-1].shape
    norms = torch.linalg.vector_norm(T[:, :-1], ord=2, dim=1)
    return {"L": int(L), "D": int(D + 1), 
            "norm_min": float(norms.min()), 
            "norm_max": float(norms.max()), 
            "norms_mean": float(norms.mean()), 
            "norms_std": float(norms.std(unbiased=False))}

In [None]:
# Norm statistics for T0, displayed as the cell result.
T0_stats = _seq_stats(T=T0)
T0_stats

In [None]:
# Norm statistics for T1, displayed as the cell result.
T1_stats = _seq_stats(T=T1)
T1_stats

In [None]:
# Norm statistics for T2, displayed as the cell result.
T2_stats = _seq_stats(T=T2)
T2_stats

In [None]:
# Norm statistics for T3, displayed as the cell result.
T3_stats = _seq_stats(T=T3)
T3_stats

In [None]:
def cmd_summary(per_seq_dir: Path | str, out_csv: Path | str) -> None:
    """Summarise every ``*.pt`` embedding file in a directory into one CSV.

    Each file (non-recursive, sorted by path) is loaded with ``_load_pt_file``
    and summarised with ``_seq_stats``; ``file`` and ``id`` (the file stem)
    columns are added. A file that fails to load or summarise is still
    recorded, as a row carrying an ``error`` column with the exception
    message, so one bad file does not abort the whole run.

    Args:
        per_seq_dir: Directory holding the per-sequence ``.pt`` files.
        out_csv: Destination CSV path.

    Raises:
        FileNotFoundError: if ``per_seq_dir`` is not an existing directory.
    """
    # Accept str as the annotation promises (a plain str has no .glob()).
    per_seq_dir = Path(per_seq_dir)
    if not per_seq_dir.is_dir():
        # Without this, a mistyped path silently produces an empty CSV.
        raise FileNotFoundError(f"No such directory: {per_seq_dir}")

    rows = []
    for pt in sorted(per_seq_dir.glob("*.pt")):
        try:
            s = _seq_stats(_load_pt_file(pt))
        except Exception as e:
            # Best effort: record why this file was skipped and keep going.
            rows.append({"file": str(pt), "id": pt.stem, "error": str(e)})
            continue
        s["file"] = str(pt)
        s["id"] = pt.stem
        rows.append(s)

    df = pd.DataFrame(rows)
    df.to_csv(out_csv, index=False)
    print(f"Wrote summary to {out_csv} with {len(df)} rows")
    print(df.head().to_string(index=False))

In [None]:
# Summarise the smoke-test embeddings into a single CSV.
cmd_summary(
    per_seq_dir=Path("../smoketest/artifacts/pts"),
    out_csv=Path("../smoketest/artifacts/analysis.csv"),
)

In [None]:
def cmd_residue_norms(
    pt_file: Path | str,
    export_csv: Optional[Path | str] = None,
    drop_length_col: bool = True,
) -> None:
    """Compute the L2 norm of every row (residue) of one embedding file.

    Args:
        pt_file: ``.pt`` file readable by ``_load_pt_file`` (a 2-D tensor).
        export_csv: If given (and truthy), write an ``i, l2_norm`` CSV there.
            Otherwise print the first 10 rows and the mean / population std.
        drop_length_col: Exclude the trailing sequence-length column before
            taking norms. This matches ``_seq_stats`` so the two reports agree.
            Pass False to include it.
    """
    T = _load_pt_file(pt_file)
    emb = T[:, :-1] if drop_length_col else T
    # Reduce in at least float32: numpy has no bfloat16, so .numpy() on such a
    # tensor would fail, and half precision loses accuracy in the norm.
    emb = emb.to(torch.promote_types(emb.dtype, torch.float32))
    norms = torch.linalg.vector_norm(emb, ord=2, dim=1).numpy()
    df = pd.DataFrame({"i": np.arange(len(norms), dtype=int), "l2_norm": norms})
    if export_csv:
        df.to_csv(export_csv, index=False)
        print(f"Wrote residue norms to {export_csv}  |  L={len(norms)}")
    else:
        print(df.head(10).to_string(index=False))
        # numpy's std defaults to ddof=0, the same population std as _seq_stats.
        print(f"   L={len(norms)}  |  mean={norms.mean():.3f}  |  std={norms.std():.3f}")

In [None]:
# Preview per-residue norms for every smoke-test embedding file.
per_seq_dir = Path("../smoketest/artifacts/pts")
for pt in sorted(per_seq_dir.glob("*.pt")):
    # To write one CSV per file instead of printing, also pass e.g.
    # export_csv=Path(f"../smoketest/artifacts/pts/norms/{pt.stem}.csv")
    cmd_residue_norms(pt_file=pt)

In [None]:
# cosine similarity
def _cos(a: torch.Tensor, b: torch.Tensor) -> float:
    a, b = a.float(), b.float()
    a_n = torch.linalg.vector_norm(a)
    b_n = torch.linalg.vector_norm(b)
    if a_n == 0 or b_n == 0:
        return float("NaN")
    return float((a @ b) / (a_n * b_n))

def cos_similarity(T1: torch.Tensor, T2: torch.Tensor, i1: int, i2: int) -> None:
    if not (0 <= i1 < T1.shape[0]) or not (0 <= i2 < T2.shape[0]):
        raise IndexError(f"Indices out of range\n   T1={T1.shape[0]}, i1={i1}\n   T2={T2.shape[0]}, i2={i2}")
    sim = _cos(T1[i1], T2[i2])
    print(f"cos(T1[{i1}], T2[{i2}])={sim:.6}")

In [None]:
T1 = _load_pt_file(Path("../smoketest/artifacts/pts/A8D0M1_ADE02.pt"))
T2 = _load_pt_file(Path("../smoketest/artifacts/pts/J9Z4E7_9ADEN.pt"))
T1_rows, T2_rows = T1.shape[0], T2.shape[0]
# Square grid over the first `limit` rows of each tensor, one printed line per pair.
# NOTE(review): both indices are capped at the shorter length, so rows of the
# longer tensor beyond it are never compared — confirm this is intended.
limit = min(T1_rows, T2_rows)
for i, j in ((row, col) for row in range(limit) for col in range(limit)):
    cos_similarity(T1, T2, i, j)

In [None]:
# Shape check on one long sequence from the HPC test run.
long_seq = _load_pt_file(Path("../data/hpc_test_2/artifacts/pts/shard_003/A0A1V0FX51_COWPX.pt"))
long_seq.size()

In [None]:
# Load the 7lj4_B reference file (a dict) onto CPU, as _load_pt_file does, so
# it also opens on machines without the device the tensors were saved from.
# torch.load unpickles the file: only load trusted files.
# NOTE(review): this path is one directory higher ("../../data") than the other
# "../data" paths in this notebook — confirm it is the intended location.
bp3_raw = torch.load(Path("../../data/7lj4_B.pt"), map_location="cpu")
# NOTE(review): key 33 presumably selects the layer-33 representations
# (e.g. an ESM model's final layer) — confirm against how the file was produced.
bp3_long_seq = bp3_raw["representations"][33]
bp3_long_seq.shape

In [None]:
# verify training data is valid: load the sample bundle onto CPU and check
# that the keys inspected in the following cells are present.
obj = torch.load(Path("../data/localsample/500_sample.pt"), map_location="cpu")
assert {"embeddings", "targets"} <= set(obj.keys()), f"unexpected keys: {list(obj.keys())}"

In [None]:
# Single quotes inside the f-strings: nested double quotes only parse on
# Python >= 3.12 (PEP 701), while the `Path | str` annotations used elsewhere
# only require 3.10+.
print(f"all embeddings: {obj['embeddings']}")
print(f"one peptide: {obj['embeddings'][10]}")
print(f"size of each embeddings: ({len(obj['embeddings'][100])}, {len(obj['embeddings'][100][0])})")

In [None]:
# Single quotes inside the f-strings: nested double quotes only parse on
# Python >= 3.12 (PEP 701).
print(f"all targets: {obj['targets']}")
print(f"size of targets: ({len(obj['targets'])}, {len(obj['targets'][0])})")