In [None]:
import torch
import re
import random
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from typing import Optional, Dict, Tuple, List
import torch.serialization as ts
from collections import Counter

In [None]:
# Show every DataFrame column and allow wide rows so metadata tables are not truncated.
for _opt_name, _opt_value in (("display.max_columns", None), ("display.width", 180)):
    pd.set_option(_opt_name, _opt_value)

In [None]:
# Input locations for the sample dataset, relative to the notebook's working directory.
SAMPLE_DIR = Path("../localdata/sample")
EMB_DIR = SAMPLE_DIR / "embeddings"                      # *.pt files, indexed by file stem below
LABELS_PT = SAMPLE_DIR / "labels" / "labels_single.pt"   # torch payload with a "labels" dict
META_TSV = SAMPLE_DIR / "sample_meta.tsv"                # tab-separated metadata (FullName column)

In [None]:
# Label payload: its "labels" entry maps protein ID -> label tensor (presumably per-residue).
# NOTE(review): weights_only=False runs the full pickle loader, which can execute arbitrary
# code embedded in the file — only load label files from a trusted source.
payload = torch.load(LABELS_PT, map_location="cpu", weights_only=False)
labels_by_id = payload["labels"]
label_ids = set(labels_by_id.keys())

# Index embedding files by file stem so they can be joined against the label IDs.
emb_by_id = {}
for emb_path in EMB_DIR.glob("*.pt"):
    emb_by_id[emb_path.stem] = emb_path

In [None]:
# Join label IDs against embedding-file stems and report what is missing on each side.
emb_ids = set(emb_by_id.keys())
matched_ids = sorted(label_ids & emb_ids)
missing_emb = sorted(label_ids - emb_ids)
extra_emb = sorted(emb_ids - label_ids)

for caption, count in (
    ("label IDs", len(label_ids)),
    ("embedding files", len(emb_by_id)),
    ("matched", len(matched_ids)),
    ("missing embeddings for labels", len(missing_emb)),
    ("extra embeddings not in labels", len(extra_emb)),
):
    print(f"{caption}: {count}")

In [None]:
# Keep only proteins that have both an embedding file and labels, in sorted-ID order.
selected_emb_paths = []
selected_labels = {}
for protein_id in matched_ids:
    selected_emb_paths.append(emb_by_id[protein_id])
    selected_labels[protein_id] = labels_by_id[protein_id]

In [None]:
def parse_pid(fullname: str) -> str:
    """Return the protein ID from a ``FullName`` string of the form ``"ID=<id> ..."``.

    Leading/trailing whitespace is ignored. Returns ``""`` when the string does not
    start with ``ID=`` followed by at least one non-whitespace character.
    """
    # \S+ captures the whole ID token; the previous ([^\s]) captured only its first character.
    m = re.match(r"^ID=(\S+)", fullname.strip())
    return m.group(1) if m else ""

def load_emb(pt_path: Path) -> torch.Tensor:
    """Load a per-protein embedding tensor from ``pt_path`` onto the CPU.

    ``weights_only=True`` restricts unpickling to tensors and plain containers, so a
    malformed or malicious file cannot execute code on load.

    Raises:
        ValueError: if the file does not contain a 2-D ``torch.Tensor``.
    """
    x = torch.load(pt_path, map_location="cpu", weights_only=True)
    if not isinstance(x, torch.Tensor) or x.dim() != 2:
        # Report the type of the *loaded object* (not of the path) so the message is actionable.
        raise ValueError(f"Bad embedding in {pt_path}: type={type(x)}, shape={getattr(x, 'shape', None)}")
    return x

In [None]:
# Show one protein end-to-end: its embedding, its label rows, and per-class label totals.
# A fixed seed keeps the chosen protein stable across Restart & Run All.
demo_rng = random.Random(0)
pid = demo_rng.choice(matched_ids)
x = load_emb(emb_by_id[pid])   # validated 2-D tensor; presumably (L, D)
y = selected_labels[pid]       # presumably (L, 3) with columns [def, unc, not] — TODO confirm

print("protein_id:", pid)
print("embedding shape:", tuple(x.shape), "dtype:", x.dtype)
print("label shape:", tuple(y.shape), "dtype:", y.dtype)
# Embedding rows and label rows are expected to align residue-for-residue.
if x.shape[0] != y.shape[0]:
    print(f"WARNING: embedding length {x.shape[0]} != label length {y.shape[0]}")

print("\nEmbedding sample (first 3 residues, first 8 dims):")
print(x[:3, :8])

print("\nLabel sample (first 12 residues) [def, unc, not]:")
print(y[:12])

print("\nClass counts for this protein:")
class_names = ("def", "unc", "not")
print({name: int(y[:, col].sum().item()) for col, name in enumerate(class_names)})

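### Current Idea
- Read the metadata TSV and its `FullName` column.
- Parse the `ID` and `OXX` fields out of each `FullName`.
- Map each protein ID → its OXX taxonomy IDs (used below to derive family/genus for a family-level train/val split).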

In [None]:
# Per-protein metadata; read_table defaults to a tab separator, matching the .tsv file.
meta_df = pd.read_table(META_TSV)
taxonomic_pattern = re.compile(r"^ID=([^\s]+)\s+AC=([^\s]+)\s+OXX=([^\s]+)\s*$")

def parse_pid_oxx(fullname: str) -> Optional[Tuple[str, List[str]]]:
    """Parse a metadata ``FullName`` string into ``(protein_id, oxx_fields)``.

    Expects ``"ID=<id> AC=<accession> OXX=<a>,<b>,..."``. The comma-separated OXX value
    is split, then right-padded with empty strings (or truncated) to exactly four
    entries so callers can index positions 0-3 unconditionally.

    Returns:
        ``(protein_id, oxx_fields)`` on success, or ``None`` when the string does not
        match the pattern or is not a string at all (e.g. NaN from an empty TSV cell).
    """
    # pandas yields float NaN for empty cells; .strip() would raise on it.
    if not isinstance(fullname, str):
        return None
    match_ = taxonomic_pattern.match(fullname.strip())
    if match_ is None:
        return None
    protein_id, _accession, oxx_raw = match_.groups()
    oxx = (oxx_raw.split(",") + ["", "", "", ""])[:4]
    return (protein_id, oxx)

# parse_pid_oxx returns None for rows that do not match the expected FullName format.
# Without this guard the lambdas below fail with an opaque "'NoneType' is not subscriptable".
parsed_fullname = meta_df["FullName"].map(parse_pid_oxx)
unparsed_mask = parsed_fullname.isna()
if unparsed_mask.any():
    print(f"dropping {int(unparsed_mask.sum())} metadata rows with unparsable FullName, e.g.:")
    print(meta_df.loc[unparsed_mask, "FullName"].head(5).tolist())
    meta_df = meta_df.loc[~unparsed_mask].copy()
    parsed_fullname = parsed_fullname.loc[~unparsed_mask]

meta_df["protein_id"] = parsed_fullname.map(lambda parsed: parsed[0])
oxx_fields = parsed_fullname.map(lambda parsed: parsed[1])
# Family is OXX position 2; when it is empty, fall back to a genus-derived pseudo-family.
# NOTE(review): rows with both family and genus empty all collapse into one group
# "fallback_" — confirm that is acceptable for the family-level split.
meta_df["family_taxid"] = oxx_fields.map(lambda oxx: oxx[2] if oxx[2] else f"fallback_{oxx[3]}")
meta_df["genus_taxid"] = oxx_fields.map(lambda oxx: oxx[3])

# Proteins per family (each protein counted once per family it appears under).
family_counts = meta_df[["protein_id", "family_taxid"]].drop_duplicates()["family_taxid"].value_counts()
display(family_counts.head(30))

In [None]:
# Family-level train/val split: all proteins of a family land on the same side, so the
# validation set measures generalization to unseen families.
VAL_FRACTION = 0.2  # fraction of *families* (not proteins) held out for validation
SPLIT_SEED = 42

# Only split proteins that have both an embedding and labels; other metadata rows would
# otherwise leak into the ID lists and fail on later embedding/label lookups.
prot_family = (
    meta_df.loc[meta_df["protein_id"].isin(matched_ids), ["protein_id", "family_taxid"]]
    .drop_duplicates()
)
# A protein listed under several families could otherwise end up in both splits.
n_multi_family = int(prot_family["protein_id"].duplicated().sum())
if n_multi_family:
    print(f"warning: {n_multi_family} proteins map to >1 family; keeping the first family for each")
    prot_family = prot_family.drop_duplicates("protein_id", keep="first")

families = sorted(prot_family["family_taxid"].unique())
rng = random.Random(SPLIT_SEED)
rng.shuffle(families)

n_val = max(1, int(VAL_FRACTION * len(families)))
val_fams = set(families[:n_val])

in_val = prot_family["family_taxid"].isin(val_fams)
train_ids = prot_family.loc[~in_val, "protein_id"].tolist()
val_ids = prot_family.loc[in_val, "protein_id"].tolist()

# Leakage checks (these replace a print that was 0 by construction and, by nesting
# double quotes inside an f-string, was a SyntaxError before Python 3.12).
train_fams = set(prot_family.loc[~in_val, "family_taxid"])
assert train_fams.isdisjoint(val_fams), "family appears in both train and val"
assert set(train_ids).isdisjoint(val_ids), "protein appears in both train and val"

print(f"train proteins: {len(train_ids)}, val proteins: {len(val_ids)}")
print(f"family overlap: {len(train_fams & val_fams)}")