Part 1 (set up the environment, load config, define AA tables, geometry utilities (RBFs, torsions), and a PDB parser that returns clean residue rows with N/CA/C coordinates and pLDDT.)

Cell 1 — Imports & environment echo (Windows + CUDA 12.8 stack)

In [22]:
# Cell 1: Imports & environment echo
from __future__ import annotations
from pathlib import Path
import os, sys, json, math, hashlib, time
import numpy as np
import pandas as pd

import torch
from typing import List, Tuple, Dict, Optional

# Biopython PDB parser for robust, dependency-light parsing
from Bio.PDB import PDBParser, is_aa

# Env echo (matches your provided environment)
print("Python     :", sys.version.split()[0])
print("OS         :", os.name, "-", sys.platform)
print("PyTorch    :", torch.__version__)
print("CUDA avail :", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Device     :", torch.cuda.get_device_name(0))
    print("TF32       :", torch.backends.cuda.matmul.allow_tf32)
    print("AMP dtype  :", torch.get_autocast_dtype('cuda'))
    print("cudnn.bench:", torch.backends.cudnn.benchmark)

Python     : 3.12.11
OS         : nt - win32
PyTorch    : 2.7.0+cu128
CUDA avail : True
Device     : NVIDIA GeForce RTX 5070 Ti
TF32       : True
AMP dtype  : torch.float16
cudnn.bench: True


Cell 2 — Load YAML config + make dirs

In [23]:
# Cell 2: Load YAML config + make dirs
import yaml

CFG_PATH = Path("./gvp_config.yaml")
assert CFG_PATH.exists(), f"Config file not found: {CFG_PATH}"

with open(CFG_PATH, "r", encoding="utf-8") as f:
    CFG = yaml.safe_load(f)

paths = CFG["paths"]
io_cfg = CFG["io"]
enc = CFG["encoder"]
mdl = CFG["model"]
pool = CFG["pooling"]
perf = CFG["performance"]
repro = CFG["repro"]

DATA_ROOT = Path(paths["data_root"]).resolve()
PDB_DIR = Path(paths["pdb_dir"]).resolve()
MAIN_PARQUET = Path(paths["main_parquet"]).resolve()
OUT_PARQUET = Path(paths["protein_embeddings_out"]).resolve()
GRAPH_DIR = Path(paths["graph_cache_dir"]).resolve()
META_DIR = Path(paths["meta_dir"]).resolve()

for d in [GRAPH_DIR, META_DIR, OUT_PARQUET.parent]:
    d.mkdir(parents=True, exist_ok=True)

print("Main parquet :", MAIN_PARQUET)
print("PDB dir      :", PDB_DIR)
print("Graph cache  :", GRAPH_DIR)
print("Meta dir     :", META_DIR)
print("Embeddings   :", OUT_PARQUET)


Main parquet : F:\Thesis Korbi na\dti-prediction-with-adr\Data\scope_onside_common_v3.parquet
PDB dir      : F:\Thesis Korbi na\dti-prediction-with-adr\AlphaFoldData
Graph cache  : F:\Thesis Korbi na\dti-prediction-with-adr\Data\graph_cache_gvp_v1
Meta dir     : F:\Thesis Korbi na\dti-prediction-with-adr\Data\gvp_meta
Embeddings   : F:\Thesis Korbi na\dti-prediction-with-adr\Data\protein_embeddings.parquet


Cell 3 — Repro settings (TF32, AMP, torch.compile flags)

In [24]:
# Cell 3: Repro & perf toggles
seed = int(repro.get("seed", 1337))
torch.manual_seed(seed)
np.random.seed(seed)

torch.backends.cuda.matmul.allow_tf32 = bool(perf.get("tf32", True))
torch.backends.cudnn.benchmark = bool(perf.get("cudnn_benchmark", True))
AMP_DTYPE = torch.float16 if perf.get("amp_dtype", "float16") == "float16" else torch.bfloat16
USE_COMPILE = bool(perf.get("torch_compile", True))
print("AMP_DTYPE   :", AMP_DTYPE)
print("torch.compile:", USE_COMPILE)

AMP_DTYPE   : torch.float16
torch.compile: True


Cell 4 — Amino acid tables & simple chemistry features

In [25]:
# Cell 4: Amino acid indices and basic chemistry scalars
AA3_TO_AA1 = {
    'ALA':'A','ARG':'R','ASN':'N','ASP':'D','CYS':'C','GLN':'Q','GLU':'E','GLY':'G',
    'HIS':'H','ILE':'I','LEU':'L','LYS':'K','MET':'M','PHE':'F','PRO':'P','SER':'S',
    'THR':'T','TRP':'W','TYR':'Y','VAL':'V'
}
AA1_LIST = list("ARNDCEQGHILKMFPSTWYV")
AA1_TO_IDX = {aa:i for i,aa in enumerate(AA1_LIST)}

# Kyte-Doolittle hydrophobicity (scaled later)
HYDRO = {
 'A':1.8,'R':-4.5,'N':-3.5,'D':-3.5,'C':2.5,'Q':-3.5,'E':-3.5,'G':-0.4,'H':-3.2,'I':4.5,
 'L':3.8,'K':-3.9,'M':1.9,'F':2.8,'P':-1.6,'S':-0.8,'T':-0.7,'W':-0.9,'Y':-1.3,'V':4.2
}
# Simplified charge at pH 7: +1 (K,R,H≈+0.1), -1 (D,E), else 0
CHARGE = {aa:(1 if aa in ['K','R'] else (0.1 if aa=='H' else (-1 if aa in ['D','E'] else 0))) for aa in AA1_LIST}
# Polarity flag: polar vs nonpolar (0/1)
POLAR = {aa:(1 if aa in ['R','N','D','Q','E','H','K','S','T','Y','C','W'] else 0) for aa in AA1_LIST}

def aa1_index(aa1: str) -> int:
    return AA1_TO_IDX.get(aa1, 20)  # 20 -> unknown

def chem_triplet(aa1: str) -> Tuple[float, float, float]:
    return float(HYDRO.get(aa1, 0.0)), float(CHARGE.get(aa1, 0.0)), float(POLAR.get(aa1, 0))

Cell 5 — Geometry helpers: normalize, RBF, angles

In [26]:
# Cell 5: Geometry helpers (normalize, RBF expansion, dihedrals)
def unit_vec(v: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    n = np.linalg.norm(v) + eps
    return v / n

def rbf_expand(d: np.ndarray, num_bins: int, dmin: float, dmax: float) -> np.ndarray:
    """Gaussian RBF expansion over [dmin, dmax]."""
    centers = np.linspace(dmin, dmax, num_bins)
    gamma = 1.0 / ((centers[1] - centers[0] + 1e-8) ** 2)
    return np.exp(-gamma * (d[..., None] - centers[None, ...])**2)

def dihedral(p0, p1, p2, p3) -> float:
    """Dihedral angle (radians) for four points."""
    b0 = p1 - p0
    b1 = p2 - p1
    b2 = p3 - p2
    # Normalize b1 for stability
    b1n = b1 / (np.linalg.norm(b1) + 1e-8)
    v = b0 - (b0 @ b1n) * b1n
    w = b2 - (b2 @ b1n) * b1n
    x = np.dot(v, w)
    y = np.dot(np.cross(b1n, v), w)
    return np.arctan2(y, x)

def sincos(x: float) -> Tuple[float, float]:
    return float(np.sin(x)), float(np.cos(x))

Cell 6 — Build per-residue local frames (x,y,z) from N-CA-C

In [27]:
# Cell 6: Local backbone frame from (N, CA, C)
def local_frame(N: np.ndarray, CA: np.ndarray, C: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Returns an orthonormal (x,y,z) frame:
      z = unit(CA - N)
      x = unit( (C - CA) - proj_{z}(C - CA) )
      y = z × x
    """
    z = unit_vec(CA - N)
    c = C - CA
    c_orth = c - (c @ z) * z
    x = unit_vec(c_orth)
    y = unit_vec(np.cross(z, x))
    return x, y, z

Cell 7 — PDB parsing (AlphaFold: pLDDT in B-factor), residue rows

In [28]:
# Cell 7: Parse AlphaFold PDB -> per-residue table with N/CA/C, pLDDT, AA, index
def parse_af2_pdb(pdb_path: Path) -> pd.DataFrame:
    """
    Returns a DataFrame with columns:
      ['res_index', 'aa1', 'N', 'CA', 'C', 'pLDDT', 'valid']
    Where N/CA/C are (3,) np.float32 arrays in Å. pLDDT from B-factor (CA atom).
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("P", str(pdb_path))
    model = next(structure.get_models())  # AF2 usually one model

    rows = []
    res_i = 0
    for chain in model:
        for residue in chain:
            if not is_aa(residue, standard=True):
                continue
            try:
                N = residue["N"].get_coord().astype(np.float32)
                CA = residue["CA"].get_coord().astype(np.float32)
                C = residue["C"].get_coord().astype(np.float32)
            except KeyError:
                # missing backbone atom(s)
                rows.append({
                    "res_index": res_i,
                    "aa1": AA3_TO_AA1.get(residue.get_resname().upper(), "X"),
                    "N": None, "CA": None, "C": None,
                    "pLDDT": np.nan,
                    "valid": False
                })
                res_i += 1
                continue

            # AF2 stores pLDDT in B-factor; use CA B-factor as residue pLDDT
            plddt = float(residue["CA"].get_bfactor())
            aa1 = AA3_TO_AA1.get(residue.get_resname().upper(), "X")

            rows.append({
                "res_index": res_i,
                "aa1": aa1,
                "N": N, "CA": CA, "C": C,
                "pLDDT": plddt,
                "valid": True
            })
            res_i += 1

    df = pd.DataFrame(rows)
    return df

Cell 8 — Torsions (ϕ, ψ, ω) and backbone vectors per residue

In [29]:
# Cell 8: Compute torsions and backbone vector features
def compute_backbone_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Adds:
      - phi_sin, phi_cos, psi_sin, psi_cos, omega_sin, omega_cos
      - v_N_CA (3,), v_CA_C (3,)
      - local frame axes: frame_x, frame_y, frame_z (each (3,))
    Missing neighbors -> NaNs for torsions; vectors default to unit/NaN-safe.
    """
    n = len(df)
    phi_sin = np.full(n, np.nan, np.float32); phi_cos = np.full(n, np.nan, np.float32)
    psi_sin = np.full(n, np.nan, np.float32); psi_cos = np.full(n, np.nan, np.float32)
    omg_sin = np.full(n, np.nan, np.float32); omg_cos = np.full(n, np.nan, np.float32)

    v_N_CA = np.zeros((n,3), np.float32)
    v_CA_C = np.zeros((n,3), np.float32)
    fx = np.zeros((n,3), np.float32)
    fy = np.zeros((n,3), np.float32)
    fz = np.zeros((n,3), np.float32)

    for i in range(n):
        if not df.at[i, "valid"]:
            continue
        N_i, CA_i, C_i = df.at[i, "N"], df.at[i, "CA"], df.at[i, "C"]
        v_N_CA[i] = unit_vec(CA_i - N_i)
        v_CA_C[i] = unit_vec(C_i - CA_i)
        x, y, z = local_frame(N_i, CA_i, C_i)
        fx[i], fy[i], fz[i] = x, y, z

        # torsions need neighbors
        # phi(i) = C(i-1), N(i), CA(i), C(i)
        # psi(i) = N(i), CA(i), C(i), N(i+1)
        # omega(i) = CA(i-1), C(i-1), N(i), CA(i)
        try:
            if i > 0 and df.at[i-1, "valid"]:
                Cm1 = df.at[i-1, "C"]
                Np = df.at[i, "N"]; CAp = df.at[i, "CA"]; Cp = df.at[i, "C"]
                Nm1 = df.at[i-1, "N"]; CAm1 = df.at[i-1, "CA"]; Cm1 = df.at[i-1, "C"]

                # phi
                phi = dihedral(df.at[i-1,"C"], Np, CAp, Cp)
                s, c = sincos(phi); phi_sin[i], phi_cos[i] = s, c

                # omega
                omega = dihedral(CAm1, Cm1, Np, CAp)
                s, c = sincos(omega); omg_sin[i], omg_cos[i] = s, c
            if i+1 < n and df.at[i+1, "valid"]:
                Nn = df.at[i+1, "N"]
                psi = dihedral(df.at[i,"N"], df.at[i,"CA"], df.at[i,"C"], Nn)
                s, c = sincos(psi); psi_sin[i], psi_cos[i] = s, c
        except Exception:
            # leave NaNs if any geometry missing
            pass

    out = df.copy()
    out["phi_sin"], out["phi_cos"] = phi_sin, phi_cos
    out["psi_sin"], out["psi_cos"] = psi_sin, psi_cos
    out["omega_sin"], out["omega_cos"] = omg_sin, omg_cos
    out["v_N_CA"] = list(v_N_CA)
    out["v_CA_C"] = list(v_CA_C)
    out["frame_x"] = list(fx)
    out["frame_y"] = list(fy)
    out["frame_z"] = list(fz)
    return out

Cell 9 — Positional encoding & chemistry scalars

In [30]:
# Cell 9: Positional enc (sin/cos) and chemistry triplet
def positional_enc(idx: np.ndarray, L: int, num_freqs: int = 8) -> np.ndarray:
    """
    Sin/cos positional enc over residue indices normalized to [0,1].
    Returns [len(idx), 2*num_freqs].
    """
    x = idx.astype(np.float32) / max(L - 1, 1)
    freqs = np.pi * (2 ** np.arange(num_freqs, dtype=np.float32))
    sins = np.sin(x[:, None] * freqs[None, :])
    coss = np.cos(x[:, None] * freqs[None, :])
    return np.concatenate([sins, coss], axis=-1).astype(np.float32)

def add_positional_and_chemistry(df: pd.DataFrame) -> pd.DataFrame:
    L = len(df)
    idx = df["res_index"].values.astype(np.int32)
    pos = positional_enc(idx, L, num_freqs=8)  # shape [L, 16]
    # Chemistry
    hydro = np.zeros(L, np.float32)
    charge = np.zeros(L, np.float32)
    polar  = np.zeros(L, np.float32)
    aa_idx = np.zeros(L, np.int64)
    for i, aa in enumerate(df["aa1"].values):
        hydro[i], charge[i], polar[i] = chem_triplet(aa)
        aa_idx[i] = aa1_index(aa)

    out = df.copy()
    out["pos_enc"] = list(pos)
    out["hydro"] = hydro
    out["charge"] = charge
    out["polar"] = polar
    out["aa_index"] = aa_idx
    return out

Cell 10 — Quick test on your sample PDB (A0PJK1.pdb)

In [31]:
# Cell 10: Dry run on a single AlphaFold PDB (e.g., A0PJK1.pdb)
test_pdb = PDB_DIR / "A0PJK1.pdb"
if test_pdb.exists():
    df_res = parse_af2_pdb(test_pdb)
    df_res = compute_backbone_features(df_res)
    df_res = add_positional_and_chemistry(df_res)

    print(df_res.head(3)[["res_index","aa1","pLDDT","phi_sin","phi_cos","psi_sin","psi_cos"]])
    print("Valid residues:", int(df_res["valid"].sum()), "/", len(df_res))
else:
    print("Sample PDB not found at:", test_pdb)

   res_index aa1  pLDDT   phi_sin   phi_cos   psi_sin   psi_cos
0          0   M  33.94       NaN       NaN -0.887723  0.460378
1          1   A  30.64 -0.758021 -0.652230 -0.822307  0.569044
2          2   A  33.57  0.370100  0.928992 -0.717845  0.696203
Valid residues: 596 / 596


Part 2 — Graph Construction:

Cell 11 — Seq-sep buckets, contact flag, md5 helper

In [32]:
# Cell 11: Sequence separation buckets, contact flag, MD5 helper

def seq_sep_bucket(delta: int) -> int:
    """Bucketize |i-j| into {0,1,2..4,5..8,9..16,>16} -> indices 0..5."""
    if delta == 0:     return 0
    if delta == 1:     return 1
    if 2 <= delta <= 4:  return 2
    if 5 <= delta <= 8:  return 3
    if 9 <= delta <= 16: return 4
    return 5

def contact_flag(dist_A: float, thresh_A: float = 8.0) -> int:
    return 1 if dist_A < thresh_A else 0

def file_md5(path: Path, chunk: int = 1 << 20) -> str:
    m = hashlib.md5()
    with open(path, "rb") as f:
        while True:
            b = f.read(chunk)
            if not b: break
            m.update(b)
    return m.hexdigest()

Cell 12 — k-NN ∪ radius ∪ sequence edges (Cα graph)

In [33]:
# Cell 12: Edge builder (k-NN ∪ radius ∪ sequence edges)

def build_edges_from_CA(
    CA: np.ndarray,
    knn_k: int,
    radius_A: float,
    add_seq_edges: bool = True,
    bidirectional: bool = True
) -> np.ndarray:
    """
    CA: [L,3] coordinates (float32)
    Returns edge_index [2, E] with unique pairs (i->j).
    """
    L = CA.shape[0]
    # Pairwise distances (broadcast, no scipy)
    # D[i,j] = ||CA[i] - CA[j]||
    diff = CA[:, None, :] - CA[None, :, :]
    D = np.linalg.norm(diff, axis=-1)  # [L,L]
    np.fill_diagonal(D, np.inf)

    edges = set()

    # k-NN edges
    if knn_k is not None and knn_k > 0:
        idx_knn = np.argpartition(D, kth=knn_k, axis=1)[:, :knn_k]  # [L,k]
        for i in range(L):
            for j in idx_knn[i]:
                edges.add((i, int(j)))

    # radius edges
    if radius_A is not None and radius_A > 0:
        within = np.where(D <= radius_A)
        for i, j in zip(within[0], within[1]):
            edges.add((int(i), int(j)))

    # sequence edges
    if add_seq_edges:
        for i in range(L - 1):
            edges.add((i, i + 1))
            edges.add((i + 1, i))

    # bidirectional
    if bidirectional:
        extra = []
        for (i, j) in edges:
            extra.append((j, i))
        for e in extra:
            edges.add(e)

    # remove self loops
    edges = [(i, j) for (i, j) in edges if i != j]

    # to array
    if len(edges) == 0:
        return np.empty((2, 0), dtype=np.int64)
    edge_index = np.array(edges, dtype=np.int64).T  # [2, E]
    # de-duplicate
    edge_index = np.unique(edge_index, axis=1)
    return edge_index

Cell 13 — Edge features (RBF distance, direction vector, seq-sep, contact)

In [34]:
# Cell 13: Compute edge features

def compute_edge_features(
    CA: np.ndarray,
    edge_index: np.ndarray,
    rbf_bins: int,
    rbf_dmin: float,
    rbf_dmax: float,
    use_seq_sep_buckets: bool = True,
    add_contact: bool = True
) -> Dict[str, np.ndarray]:
    """
    Returns dict with:
      edge_scalar: [E, S_e]
      edge_vector: [E, V_e, 3]
      Also returns auxiliaries: raw distance [E,], seq_sep_idx [E,] if requested.
    """
    src, dst = edge_index
    E = edge_index.shape[1]

    # distances + direction vectors
    vec_ij = CA[dst] - CA[src]                  # [E,3]
    dist = np.linalg.norm(vec_ij, axis=1)      # [E]
    u_ij = (vec_ij / (dist[:, None] + 1e-8)).astype(np.float32)  # unit vectors

    # RBF expansion
    rbf = rbf_expand(dist.astype(np.float32), rbf_bins, rbf_dmin, rbf_dmax)  # [E, rbf_bins]

    # sequence separation buckets
    if use_seq_sep_buckets:
        sep = np.abs(dst - src)
        sep_idx = np.array([seq_sep_bucket(int(d)) for d in sep], dtype=np.int64)
        # one-hot (6 buckets)
        sep_oh = np.zeros((E, 6), dtype=np.float32)
        sep_oh[np.arange(E), sep_idx] = 1.0
    else:
        sep_idx = None
        sep_oh = np.zeros((E, 0), dtype=np.float32)

    # contact flag
    if add_contact:
        contact = np.array([contact_flag(float(d)) for d in dist], dtype=np.int64)
        contact = contact.reshape(-1, 1).astype(np.float32)
    else:
        contact = np.zeros((E, 0), dtype=np.float32)

    # Pack edge scalars/vectors
    edge_scalar = np.concatenate([rbf.astype(np.float32), sep_oh, contact], axis=1).astype(np.float32)  # [E, S_e]
    edge_vector = u_ij[:, None, :].astype(np.float32)  # [E, 1, 3]  (V_e=1 so far)

    return {
        "edge_scalar": edge_scalar,
        "edge_vector": edge_vector,
        "dist": dist.astype(np.float32),
        "seq_sep_idx": sep_idx
    }

Cell 14 — Node feature packing (scalar/vector blocks)

In [35]:
# Cell 14: Assemble node scalar/vector feature blocks (Tier B)

def stack_node_features(
    df: pd.DataFrame,
    use_local_frame: bool = True,
    use_backbone_dirs: bool = True,
    use_torsions: bool = True,
    use_positional_enc: bool = True,
    use_pLDDT: bool = True,
    use_basic_chemistry: bool = True
) -> Dict[str, np.ndarray]:
    """
    Returns:
      node_scalar: [L, S_n] float32
      node_vector: [L, V_n, 3] float32
      aa_index:   [L] int64
      plddt:      [L] float32
      valid_mask: [L] bool
    """
    L = len(df)
    scalars = []

    # pLDDT
    if use_pLDDT:
        plddt = df["pLDDT"].fillna(0.0).values.astype(np.float32)
        scalars.append(plddt[:, None])
    else:
        plddt = np.zeros(L, np.float32)

    # torsions (sin/cos)
    if use_torsions:
        for nm in ["phi_sin","phi_cos","psi_sin","psi_cos","omega_sin","omega_cos"]:
            v = df[nm].fillna(0.0).values.astype(np.float32)
            scalars.append(v[:, None])

    # positional encoding
    if use_positional_enc:
        pos = np.stack(df["pos_enc"].values, axis=0).astype(np.float32)  # [L,16]
        scalars.append(pos)

    # chemistry triplet
    if use_basic_chemistry:
        hydro = df["hydro"].values.astype(np.float32)
        charge = df["charge"].values.astype(np.float32)
        polar  = df["polar"].values.astype(np.float32)
        scalars.append(np.stack([hydro, charge, polar], axis=1))  # [L,3]

    # concat scalar block
    node_scalar = np.concatenate(scalars, axis=1).astype(np.float32) if len(scalars) else np.zeros((L,0), np.float32)

    # vector block
    vectors = []
    if use_local_frame:
        fx = np.stack(df["frame_x"].values, axis=0).astype(np.float32)  # [L,3]
        fy = np.stack(df["frame_y"].values, axis=0).astype(np.float32)
        fz = np.stack(df["frame_z"].values, axis=0).astype(np.float32)
        vectors += [fx, fy, fz]
    if use_backbone_dirs:
        v1 = np.stack(df["v_N_CA"].values, axis=0).astype(np.float32)
        v2 = np.stack(df["v_CA_C"].values, axis=0).astype(np.float32)
        vectors += [v1, v2]
    if len(vectors):
        node_vector = np.stack(vectors, axis=1).astype(np.float32)  # [L, V_n, 3]
    else:
        node_vector = np.zeros((L, 0, 3), dtype=np.float32)

    aa_index = df["aa_index"].values.astype(np.int64)
    valid_mask = df["valid"].values.astype(bool)

    return {
        "node_scalar": node_scalar,
        "node_vector": node_vector,
        "aa_index": aa_index,
        "plddt": plddt,
        "valid_mask": valid_mask
    }

Cell 15 — Graph pack & NPZ cache writer

In [36]:
# Cell 15: Pack graph + metadata and save to NPZ

def pack_and_save_graph_npz(
    uniprot_id: str,
    df: pd.DataFrame,
    edge_index: np.ndarray,
    edge_feats: Dict[str, np.ndarray],
    node_feats: Dict[str, np.ndarray],
    pdb_path: Path,
    cfg: dict,
    out_dir: Path
) -> Path:
    """
    Writes {uniprot_id}.graph.npz with:
      node_scalar [L,S_n], node_vector [L,V_n,3], aa_index [L], valid_mask [L], plddt [L]
      edge_index [2,E], edge_scalar [E,S_e], edge_vector [E,V_e,3]
      CA coords [L,3] (for debugging) and metadata json
    """
    L = len(df)
    CA = np.stack(df["CA"].values, axis=0).astype(np.float32)

    meta = {
        "uniprot_id": uniprot_id,
        "L": int(L),
        "E": int(edge_index.shape[1]),
        "node_scalar_shape": list(node_feats["node_scalar"].shape),
        "node_vector_shape": list(node_feats["node_vector"].shape),
        "edge_scalar_shape": list(edge_feats["edge_scalar"].shape),
        "edge_vector_shape": list(edge_feats["edge_vector"].shape),
        "knn_k": int(cfg["encoder"]["knn_k"]),
        "radius_A": float(cfg["encoder"]["radius_cutoff_A"]),
        "add_sequence_edges": bool(cfg["encoder"]["add_sequence_edges"]),
        "bidirectional_edges": bool(cfg["encoder"]["bidirectional_edges"]),
        "rbf_bins": int(cfg["encoder"]["rbf_bins"]),
        "rbf_range": [float(cfg["encoder"]["rbf_dmin"]), float(cfg["encoder"]["rbf_dmax"])],
        "encoder_version": cfg["repro"]["encoder_version"],
        "cache_version": int(cfg["repro"]["cache_version"]),
        "pdb_md5": file_md5(pdb_path),
        "pdb_path": str(pdb_path.resolve()),
        "timestamp": time.strftime("%Y-%m-%dT%H:%M:%S%z", time.localtime()),
    }

    out_path = out_dir / f"{uniprot_id}.graph.npz"
    np.savez_compressed(
        out_path,
        node_scalar=node_feats["node_scalar"],
        node_vector=node_feats["node_vector"],
        aa_index=node_feats["aa_index"],
        valid_mask=node_feats["valid_mask"],
        plddt=node_feats["plddt"],
        edge_index=edge_index.astype(np.int64),
        edge_scalar=edge_feats["edge_scalar"],
        edge_vector=edge_feats["edge_vector"],
        CA=CA,
        meta=json.dumps(meta).encode("utf-8")
    )
    return out_path

Cell 16 — One-protein graph builder (end-to-end) + dry run

In [37]:
# Cell 16: Build & cache graph for a UniProt ID (end-to-end)

def build_and_cache_graph_for_uniprot(uniprot_id: str) -> Dict[str, object]:
    pdb_path = PDB_DIR / f"{uniprot_id}.pdb"
    assert pdb_path.exists(), f"PDB not found for {uniprot_id} at {pdb_path}"

    # 1) Parse & per-residue features
    df = parse_af2_pdb(pdb_path)
    df = compute_backbone_features(df)
    df = add_positional_and_chemistry(df)

    # Ensure all residues valid for Tier B (your dataset claims no missing)
    assert df["valid"].all(), "Encountered invalid residues; your data promised no missing."

    # 2) Build edges on Cα
    CA = np.stack(df["CA"].values, axis=0).astype(np.float32)
    edge_index = build_edges_from_CA(
        CA=CA,
        knn_k=int(enc["knn_k"]),
        radius_A=float(enc["radius_cutoff_A"]),
        add_seq_edges=bool(enc["add_sequence_edges"]),
        bidirectional=bool(enc["bidirectional_edges"])
    )

    # 3) Edge features
    edge_feats = compute_edge_features(
        CA=CA,
        edge_index=edge_index,
        rbf_bins=int(enc["rbf_bins"]),
        rbf_dmin=float(enc["rbf_dmin"]),
        rbf_dmax=float(enc["rbf_dmax"]),
        use_seq_sep_buckets=bool(enc["use_seq_sep_buckets"]),
        add_contact=bool(enc["add_contact_flag"])
    )

    # 4) Node features
    node_feats = stack_node_features(
        df=df,
        use_local_frame=bool(enc["use_local_frame"]),
        use_backbone_dirs=bool(enc["use_backbone_dirs"]),
        use_torsions=bool(enc["use_torsions"]),
        use_positional_enc=bool(enc["use_positional_enc"]),
        use_pLDDT=bool(enc["use_pLDDT"]),
        use_basic_chemistry=bool(enc["use_basic_chemistry"])
    )

    # 5) Save NPZ
    out_path = pack_and_save_graph_npz(
        uniprot_id=uniprot_id,
        df=df,
        edge_index=edge_index,
        edge_feats=edge_feats,
        node_feats=node_feats,
        pdb_path=pdb_path,
        cfg=CFG,
        out_dir=GRAPH_DIR
    )

    return {
        "graph_path": out_path,
        "L": len(df),
        "E": int(edge_index.shape[1]),
        "node_scalar_shape": node_feats["node_scalar"].shape,
        "node_vector_shape": node_feats["node_vector"].shape,
        "edge_scalar_shape": edge_feats["edge_scalar"].shape,
        "edge_vector_shape": edge_feats["edge_vector"].shape,
    }

# --- Dry run on A0PJK1 ---
dry = build_and_cache_graph_for_uniprot("A0PJK1")
print("Saved:", dry["graph_path"])
print("L:", dry["L"], "E:", dry["E"])
print("node_scalar:", dry["node_scalar_shape"], "node_vector:", dry["node_vector_shape"])
print("edge_scalar:", dry["edge_scalar_shape"], "edge_vector:", dry["edge_vector_shape"])

Saved: F:\Thesis Korbi na\dti-prediction-with-adr\Data\graph_cache_gvp_v1\A0PJK1.graph.npz
L: 596 E: 26116
node_scalar: (596, 26) node_vector: (596, 5, 3)
edge_scalar: (26116, 39) edge_vector: (26116, 1, 3)


Cell 17 — Scalar/Vector norms, LayerNorm, GVP block

In [38]:
# Cell 17: SV utils, LayerNorm, and GVP block

import torch.nn as nn
import torch.nn.functional as F

class LayerNormSV(nn.Module):
    """LayerNorm on scalar stream only; vector stream gets per-vector L2 re-scale."""
    def __init__(self, s_dim: int, v_dim: int, eps: float = 1e-5):
        super().__init__()
        self.ln = nn.LayerNorm(s_dim, eps=eps)
        self.v_dim = v_dim

    def forward(self, s: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # s: [N, S], v: [N, V, 3]
        s = self.ln(s)
        if v.shape[1] > 0:
            # normalize vector magnitudes softly (keep direction)
            vn = torch.linalg.norm(v, dim=-1, keepdim=True).clamp_min(1e-6)
            v = v / vn
        return s, v

class GVP(nn.Module):
    """
    Geometric Vector Perceptron: maps (s_in, v_in) -> (s_out, v_out)
    Uses gating from scalars to modulate vectors. All linear maps are equivariant
    (vectors are mixed only with vectors).
    """
    def __init__(self, s_in: int, v_in: int, s_out: int, v_out: int, dropout: float = 0.1):
        super().__init__()
        self.s_in, self.v_in, self.s_out, self.v_out = s_in, v_in, s_out, v_out

        # Scalar pathway
        self.ws = nn.Linear(s_in + (v_in if v_in > 0 else 0), s_out)

        # Vector pathway: mix vectors with vectors only
        self.wv = nn.Linear(v_in, v_out) if v_in > 0 and v_out > 0 else None

        # Gates to modulate vector magnitudes using scalars
        self.v_gate = nn.Linear(s_out, v_out) if v_out > 0 else None

        self.drop = nn.Dropout(dropout)

    def forward(self, s: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # s: [N, S_in], v: [N, V_in, 3]
        if self.v_in > 0:
            v_norm = torch.linalg.norm(v, dim=-1)  # [N, V_in]
            s_cat = torch.cat([s, v_norm], dim=-1)  # concat norms into scalar stream
        else:
            s_cat = s

        s_out = self.ws(s_cat)
        s_out = F.silu(s_out)
        s_out = self.drop(s_out)

        if self.wv is not None:
            # Mix vector channels linearly, preserving 3D
            # Rearrange to [N, 3, V_in] -> [N, 3, V_out] -> back to [N, V_out, 3]
            W = self.wv.weight   # [V_out, V_in]
            b = self.wv.bias     # [V_out] or None
            v_out = torch.einsum('bvc,ov->boc', v, W)  # (N,V_in,3) x (V_out,V_in) -> (N,V_out,3)
            if b is not None:
                v_out = v_out + b.view(1, -1, 1)       # broadcast over xyz

            # Gate vectors with scalars
            g = torch.sigmoid(self.v_gate(s_out)).unsqueeze(-1)  # [N, V_out, 1]
            v_out = v_out * g
            v_out = self.drop(v_out)
        else:
            v_out = v.new_zeros((v.shape[0], self.v_out, 3))

        return s_out, v_out

Cell 18 — Message passing layer (GVPConv)

In [39]:
# Cell 18: GVPConv layer with edge conditioning

class GVPConv(nn.Module):
    """
    One GVP message-passing layer:
      - Edge encoder turns (edge_scalar, edge_vector) into edge messages.
      - Messages combine src node -> dst using edge conditioning.
      - Node update via a GVP.
    """
    def __init__(
        self,
        s_node: int, v_node: int,
        s_edge: int, v_edge: int,
        s_hidden: int, v_hidden: int,
        dropout: float = 0.1
    ):
        super().__init__()
        # Edge encoder: (s_e, v_e) -> (s_h, v_h)
        self.edge_gvp = GVP(s_edge, v_edge, s_hidden, v_hidden, dropout=dropout)
        self.edge_norm = LayerNormSV(s_hidden, v_hidden)

        # Message combiner: merge src node with edge message -> (s_h, v_h)
        self.msg_gvp = GVP(s_node + s_hidden, v_node + v_hidden, s_hidden, v_hidden, dropout=dropout)
        self.msg_norm = LayerNormSV(s_hidden, v_hidden)

        # Node updater: (s_node + s_h, v_node + v_h) -> (s_node, v_node)
        self.up_gvp = GVP(s_node + s_hidden, v_node + v_hidden, s_node, v_node, dropout=dropout)
        self.up_norm = LayerNormSV(s_node, v_node)

    def forward(
        self,
        s: torch.Tensor, v: torch.Tensor,
        edge_index: torch.Tensor,
        e_s: torch.Tensor, e_v: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        s: [N, S_node], v: [N, V_node, 3]
        edge_index: [2, E] (src, dst)
        e_s: [E, S_edge], e_v: [E, V_edge, 3]
        """
        src, dst = edge_index

        # Encode edges
        e_s_h, e_v_h = self.edge_gvp(e_s, e_v)
        e_s_h, e_v_h = self.edge_norm(e_s_h, e_v_h)

        # Build messages from src nodes + edge encodings
        msg_s_in = torch.cat([s[src], e_s_h], dim=-1)              # [E, S_node + S_h]
        msg_v_in = torch.cat([v[src], e_v_h], dim=1)               # [E, V_node + V_h, 3]
        m_s, m_v = self.msg_gvp(msg_s_in, msg_v_in)
        m_s, m_v = self.msg_norm(m_s, m_v)

        # --- DTYPE FIX for AMP: match message dtypes to accumulators ---
        m_s = m_s.to(s.dtype)
        m_v = m_v.to(v.dtype)

        # Aggregate to destinations
        N = s.size(0)
        S_h = m_s.size(-1)
        V_h = m_v.size(1)

        agg_s = torch.zeros((N, S_h), device=s.device, dtype=s.dtype)
        agg_v = torch.zeros((N, V_h, 3), device=v.device, dtype=v.dtype)

        agg_s.index_add_(0, dst, m_s)
        agg_v.index_add_(0, dst, m_v)


        # Update nodes (residual)
        up_s_in = torch.cat([s, agg_s], dim=-1)
        up_v_in = torch.cat([v, agg_v], dim=1)
        s_new, v_new = self.up_gvp(up_s_in, up_v_in)
        s_new, v_new = self.up_norm(s_new, v_new)

        return s_new, v_new

Cell 19 — Input embeddings & initial projections

In [40]:
# Cell 19: AA embedding and initial scalar/vector projections

class InputSVProjector(nn.Module):
    """
    Projects raw node scalar/vector features to model's working dims.
    Also embeds amino acid indices and concatenates into scalar stream.
    """
    def __init__(self, s_in: int, v_in: int, s_node: int, v_node: int, aa_vocab: int = 21, aa_emb: int = 16, dropout: float = 0.1):
        super().__init__()
        self.aa_emb = nn.Embedding(aa_vocab, aa_emb)
        self.pre_s = nn.Linear(s_in + aa_emb, s_node)
        self.pre_v = nn.Linear(v_in, v_node) if v_in > 0 and v_node > 0 else None
        self.drop = nn.Dropout(dropout)

    def forward(self, node_scalar: torch.Tensor, node_vector: torch.Tensor, aa_index: torch.Tensor):
        # node_scalar: [N, S_in], node_vector: [N, V_in, 3], aa_index: [N]
        aa_e = self.aa_emb(aa_index)                         # [N, aa_emb]
        s_in = torch.cat([node_scalar, aa_e], dim=-1)
        s = F.silu(self.pre_s(s_in))
        s = self.drop(s)

        if self.pre_v is not None and node_vector.size(1) > 0:
            # (N, V_in, 3) -> linear on channel dim
            # reshape to (N,3,V_in) x (V_in,V_node) -> (N,3,V_node) -> (N,V_node,3)
            W = self.pre_v.weight   # [V_out, V_in]
            b = self.pre_v.bias     # [V_out] or None
            v = torch.einsum('bvc,ov->boc', node_vector, W)  # -> (N,V_out,3)
            if b is not None:
                v = v + b.view(1, -1, 1)
            v = self.drop(v)
        else:
            v = node_vector.new_zeros((node_vector.size(0), 0, 3))

        return s, v


Cell 20 — Attention pooling with pLDDT gating + mean-pool concat

In [41]:
# Cell 20: Attention pooling with optional pLDDT gating and mean-pool concat

class AttnPool(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, gate_plddt: bool = True):
        super().__init__()
        self.gate_plddt = gate_plddt
        self.attn = nn.Linear(in_dim, 1)
        self.proj = nn.Linear(in_dim, out_dim)

        # gate maps pLDDT (0..100) -> [0,1] via affine+sigmoid
        self.gate_a = nn.Parameter(torch.tensor(0.05))  # scale
        self.gate_b = nn.Parameter(torch.tensor(-3.0))  # bias

    def forward(self, h: torch.Tensor, plddt: torch.Tensor) -> torch.Tensor:
        """
        h: [L, H], plddt: [L]
        returns pooled [out_dim]
        """
        a = self.attn(h).squeeze(-1)                    # [L]
        if self.gate_plddt:
            g = torch.sigmoid(self.gate_a * (plddt/100.0) + self.gate_b)  # [L]
            a = a + torch.log(g.clamp_min(1e-6))        # log-space gating
        w = torch.softmax(a, dim=0)                     # [L]
        pooled = (w.unsqueeze(-1) * h).sum(dim=0)       # [H]
        return self.proj(pooled)

class GlobalReadout(nn.Module):
    """
    Final global embedding = concat(attention_pool, mean_pool) -> linear to target dim.
    """
    def __init__(self, in_dim: int, out_dim: int, use_attention: bool = True, gate_plddt: bool = True, concat_mean: bool = True):
        super().__init__()
        self.use_attention = use_attention
        self.concat_mean = concat_mean
        if use_attention:
            self.attn = AttnPool(in_dim, in_dim, gate_plddt=gate_plddt)
        concat_dim = in_dim + (in_dim if use_attention and concat_mean else 0)
        self.proj = nn.Linear(concat_dim if use_attention else in_dim, out_dim)

    def forward(self, h: torch.Tensor, plddt: torch.Tensor) -> torch.Tensor:
        """
        h: [L, H]
        """
        parts = []
        if self.use_attention:
            ha = self.attn(h, plddt)     # [H]
            parts.append(ha)
            if self.concat_mean:
                parts.append(h.mean(dim=0))
            hcat = torch.cat(parts, dim=-1)
            return self.proj(hcat)
        else:
            return self.proj(h.mean(dim=0))


Cell 21 — Full ProteinEncoder (10× GVPConv + residue & global heads)

In [42]:
# Cell 21: ProteinEncoder assembly

class ProteinGVPEncoder(nn.Module):
    def __init__(self, cfg: dict):
        super().__init__()
        enc = cfg["encoder"]; mdl = cfg["model"]; pool = cfg["pooling"]

        self.s_node = mdl["node_scalar_dim"]
        self.v_node = mdl["node_vector_dim"]
        self.s_edge = enc["rbf_bins"] + (6 if enc["use_seq_sep_buckets"] else 0) + (1 if enc["add_contact_flag"] else 0)
        self.v_edge = 1  # from edge direction unit vector

        self.res_dim = mdl["residue_embed_dim"]
        self.glob_dim = mdl["global_embed_dim"]
        self.num_layers = mdl["num_layers"]
        self.dropout = mdl.get("scalar_dropout", 0.1)

        # Input projector (raw -> working dims)
        # Infer raw dims from your graph cache design:
        # node_scalar raw = 26 (from your print), node_vector raw = 5
        self.input_proj = InputSVProjector(
            s_in=26, v_in=5, s_node=self.s_node, v_node=self.v_node, aa_vocab=21, aa_emb=16, dropout=self.dropout
        )

        # Stack of GVPConv layers
        layers = []
        for _ in range(self.num_layers):
            layers.append(GVPConv(
                s_node=self.s_node, v_node=self.v_node,
                s_edge=self.s_edge, v_edge=self.v_edge,
                s_hidden=self.s_node, v_hidden=self.v_node,
                dropout=self.dropout
            ))
        self.layers = nn.ModuleList(layers)

        # Residue head: collapse (s,v) to residue embedding
        self.res_scalar_head = nn.Linear(self.s_node, self.res_dim)
        self.res_vector_head = nn.Linear(self.v_node, self.res_dim) if self.v_node > 0 else None

        # Global readout
        self.readout = GlobalReadout(
            in_dim=self.res_dim,
            out_dim=self.glob_dim,
            use_attention=bool(pool["use_attention_pool"]),
            gate_plddt=bool(pool["plddt_gate_attention"]),
            concat_mean=bool(pool["concat_mean_pool"])
        )

    def forward(self, batch: dict) -> dict:
        """
        batch keys:
          node_scalar [L,S_n_raw], node_vector [L,V_n_raw,3], aa_index [L],
          edge_index [2,E], edge_scalar [E,S_e_raw], edge_vector [E,V_e_raw,3],
          plddt [L]
        """
        ns = batch["node_scalar"]; nv = batch["node_vector"]; aa = batch["aa_index"]
        ei = batch["edge_index"]; es = batch["edge_scalar"]; ev = batch["edge_vector"]
        plddt = batch["plddt"]

        # Project inputs
        s, v = self.input_proj(ns, nv, aa)

        # Edge linear pre-map to working dims (if needed, we already set s_edge/v_edge consistent)
        e_s = es
        e_v = ev

        # GVPConv stack
        for layer in self.layers:
            s, v = layer(s, v, ei, e_s, e_v)

        # Residue embeddings from s (+ vector magnitude summary)
        r_s = self.res_scalar_head(s)                         # [L, res_dim]
        if self.res_vector_head is not None and v.size(1) > 0:
            v_mag = torch.linalg.norm(v, dim=-1)             # [L, V]
            r_v = self.res_vector_head(v_mag)                # [L, res_dim]
            r = F.silu(r_s + r_v)
        else:
            r = F.silu(r_s)

        # Global embedding
        g = self.readout(r, plddt)                           # [glob_dim]

        return {"residue_emb": r, "global_emb": g}


Cell 22 — Inference helper: load .npz → tensors → encode one protein

In [43]:
# Cell 22: Encode one cached graph (.npz) to residue/global embeddings

def load_graph_npz(path: Path, device: torch.device) -> dict:
    data = np.load(path, allow_pickle=True)
    ns = torch.from_numpy(data["node_scalar"]).to(device=device, dtype=torch.float32)
    nv = torch.from_numpy(data["node_vector"]).to(device=device, dtype=torch.float32)
    aa = torch.from_numpy(data["aa_index"]).to(device=device, dtype=torch.long)
    vi = torch.from_numpy(data["valid_mask"]).to(device=device, dtype=torch.bool)  # not used directly now
    plddt = torch.from_numpy(data["plddt"]).to(device=device, dtype=torch.float32)
    ei = torch.from_numpy(data["edge_index"]).to(device=device, dtype=torch.long)
    es = torch.from_numpy(data["edge_scalar"]).to(device=device, dtype=torch.float32)
    ev = torch.from_numpy(data["edge_vector"]).to(device=device, dtype=torch.float32)
    return {
        "node_scalar": ns, "node_vector": nv, "aa_index": aa,
        "edge_index": ei, "edge_scalar": es, "edge_vector": ev,
        "plddt": plddt, "valid_mask": vi
    }

# Build model
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ProteinGVPEncoder(CFG).to(DEVICE)

# Optional: compile for speed
# if USE_COMPILE and hasattr(torch, "compile"):
#     model = torch.compile(model)

# AMP autocast dtype from config
amp_dtype = AMP_DTYPE

# --- Dry run on A0PJK1 ---
graph_npz = GRAPH_DIR / "A0PJK1.graph.npz"
assert graph_npz.exists(), f"Missing graph: {graph_npz}"

batch = load_graph_npz(graph_npz, DEVICE)
model.eval()
with torch.no_grad():
    with torch.autocast(device_type="cuda", dtype=amp_dtype) if DEVICE.type == "cuda" else torch.no_grad():
        out = model(batch)
re, ge = out["residue_emb"], out["global_emb"]
print("Residue emb shape:", tuple(re.shape))
print("Global  emb shape:", tuple(ge.shape))


Residue emb shape: (596, 256)
Global  emb shape: (1024,)


Part 4 — Batch-encode all proteins and write the final Parquet

Cell 23 — Load main Parquet & list unique proteins

In [44]:
# Cell 23: List unique UniProt IDs from main dataset
MAIN_DF = pd.read_parquet(MAIN_PARQUET)
uni_prots = pd.unique(MAIN_DF["target_uniprot_id"]).tolist()
print("Unique proteins:", len(uni_prots))

Unique proteins: 2385


Cell 24 — Helper: ensure graph exists, then encode one protein

In [45]:
# Cell 24: Ensure graph exists -> load -> encode -> return row dict
from tqdm import tqdm

def ensure_graph(uniprot_id: str) -> Path:
    npz_path = GRAPH_DIR / f"{uniprot_id}.graph.npz"
    if not npz_path.exists():
        # Build it now (uses PDB in PDB_DIR)
        _ = build_and_cache_graph_for_uniprot(uniprot_id)
    return npz_path

@torch.no_grad()
def encode_one(uniprot_id: str) -> Optional[dict]:
    npz_path = ensure_graph(uniprot_id)
    if not npz_path.exists():
        print(f"[WARN] Graph missing for {uniprot_id}, skipping.")
        return None

    batch = load_graph_npz(npz_path, DEVICE)
    model.eval()
    # autocast for speed; you already have AMP dtype set
    if DEVICE.type == "cuda":
        ctx = torch.autocast(device_type="cuda", dtype=amp_dtype)
    else:
        from contextlib import nullcontext
        ctx = nullcontext()

    with ctx:
        out = model(batch)
    r = out["residue_emb"]           # [L, 256]
    g = out["global_emb"]            # [1024]

    # Metadata
    L = int(r.size(0))
    mean_plddt = float(batch["plddt"].mean().item())
    # Read md5 & meta from npz
    meta = json.loads(np.load(npz_path, allow_pickle=True)["meta"].item().decode("utf-8"))
    pdb_md5 = meta["pdb_md5"]

    return {
        "uniprot_id": uniprot_id,
        "length": L,
        "mean_pLDDT": round(mean_plddt, 2),
        "embedding_dim": int(g.numel()),
        "encoder_version": CFG["repro"]["encoder_version"],
        "pdb_md5": pdb_md5,
        "embedding": g.detach().float().cpu().tolist(),  # store as list in parquet
    }


Cell 25 — Run all, write Parquet + meta files

In [120]:
# Cell 25: Batch encode and write outputs
rows = []
fail = []
for up in tqdm(uni_prots, desc="Encoding proteins"):
    try:
        row = encode_one(up)
        if row is not None:
            rows.append(row)
    except Exception as e:
        fail.append((up, str(e)))

df_out = pd.DataFrame(rows)
print("Encoded:", len(df_out), "Failed:", len(fail))

# Write final embeddings parquet
OUT_PARQUET.parent.mkdir(parents=True, exist_ok=True)
df_out.to_parquet(OUT_PARQUET, index=False)
print("Wrote:", OUT_PARQUET)

# Save meta/build_config.json
build_cfg_path = META_DIR / "build_config.json"
with open(build_cfg_path, "w", encoding="utf-8") as f:
    json.dump(CFG, f, indent=2)
print("Saved:", build_cfg_path)

# Save PDB checksums (for reproducibility)
chk_rows = []
for up in df_out["uniprot_id"].tolist():
    gpath = GRAPH_DIR / f"{up}.graph.npz"
    meta = json.loads(np.load(gpath, allow_pickle=True)["meta"].item().decode("utf-8"))
    chk_rows.append({"uniprot_id": up, "pdb_md5": meta["pdb_md5"], "graph_npz": str(gpath)})
chk_df = pd.DataFrame(chk_rows)
chk_path = META_DIR / "pdb_checksums.csv"
chk_df.to_csv(chk_path, index=False)
print("Saved:", chk_path)

# (Optional) brief stats
if len(df_out):
    print("Mean embedding norm:", float(np.mean([np.linalg.norm(np.array(x)) for x in df_out["embedding"]])))
if fail:
    print("Failures (first 5):", fail[:5])


Encoding proteins: 100%|██████████| 2385/2385 [21:41<00:00,  1.83it/s] 


Encoded: 2382 Failed: 3
Wrote: F:\Thesis Korbi na\dti-prediction-with-adr\Data\protein_embeddings.parquet
Saved: F:\Thesis Korbi na\dti-prediction-with-adr\Data\gvp_meta\build_config.json
Saved: F:\Thesis Korbi na\dti-prediction-with-adr\Data\gvp_meta\pdb_checksums.csv
Mean embedding norm: 7.591435501926689
Failures (first 5): [('Q6LAP9', 'kth(=36) out of bounds (17)'), ('Q9UE13', 'kth(=36) out of bounds (35)'), ('O43519', 'kth(=36) out of bounds (23)')]


Identify missing proteins (no graph / no GVP embedding)

In [46]:
# Cell 37: Which UniProt IDs are missing from embeddings?
main_df = pd.read_parquet(MAIN_PARQUET)
all_up = pd.unique(main_df["target_uniprot_id"]).tolist()

emb_parquet = OUT_PARQUET
if emb_parquet.exists():
    emb_df = pd.read_parquet(emb_parquet)
    have_up = set(emb_df["uniprot_id"].tolist())
else:
    emb_df = pd.DataFrame(columns=["uniprot_id"])
    have_up = set()

missing_up = [u for u in all_up if u not in have_up]
print("Missing proteins:", len(missing_up))
print(missing_up[:10])


Missing proteins: 3
['Q6LAP9', 'Q9UE13', 'O43519']


Load sequences for the missing UniProt IDs

In [47]:
# Cell 38: Fetch sequences for missing proteins
seq_map = (main_df[["target_uniprot_id","sequence"]]
           .drop_duplicates("target_uniprot_id")
           .set_index("target_uniprot_id")["sequence"]
           .to_dict())

missing_with_seq = [(u, seq_map.get(u, None)) for u in missing_up]
missing_with_seq = [(u, s) for (u, s) in missing_with_seq if isinstance(s, str) and len(s) > 0]

print("Missing with sequences available:", len(missing_with_seq))
print(missing_with_seq[:3])


Missing with sequences available: 3
[('Q6LAP9', 'MWLRAFILATLSASAAW'), ('Q9UE13', 'NASPSELRDLLSEFNVLKQVNHPHVIKLYGACSQD'), ('O43519', 'GEGDVRCRGAASAVAAAAAAARQ')]


ESM2 loader (pick the size you already use) + sequence embedding helper

In [49]:
# Cell 39: ESM2 model + helpers
import torch
import numpy as np

try:
    import esm  # fair-esm
except ImportError as e:
    raise RuntimeError("Please `pip install fair-esm` in your environment before running this cell.") from e

ESM_MODEL_NAME = "esm2_t33_650M_UR50D"  # or "esm2_t12_35M_UR50D" if you want speed
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D() if ESM_MODEL_NAME=="esm2_t33_650M_UR50D" else esm.pretrained.esm2_t12_35M_UR50D()
batch_converter = alphabet.get_batch_converter()
model = model.eval().to(DEVICE)

# Autocast on GPU
def esm_embed_sequence(uniprot_id: str, seq: str) -> np.ndarray:
    """
    Returns a single vector: mean-pooled per-residue ESM2 representation (excluding BOS/EOS).
    Shape depends on model: t33 -> 1280 dims; t12 -> 480 dims.
    """
    data = [(uniprot_id, seq)]
    batch_labels, batch_strs, batch_tokens = batch_converter(data)
    batch_tokens = batch_tokens.to(DEVICE)
    with torch.no_grad():
        with torch.autocast(device_type="cuda", dtype=AMP_DTYPE) if DEVICE.type=="cuda" else torch.no_grad():
            out = model(batch_tokens, repr_layers=[model.num_layers], return_contacts=False)
    token_reprs = out["representations"][model.num_layers][0]            # [L+2, D]
    # slice off BOS/EOS
    reps = token_reprs[1:1+len(seq)]
    emb = reps.mean(dim=0).float().cpu().numpy()                         # [D]
    return emb


Downloading: "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt" to C:\Users\Fahmid/.cache\torch\hub\checkpoints\esm2_t33_650M_UR50D.pt
Downloading: "https://dl.fbaipublicfiles.com/fair-esm/regression/esm2_t33_650M_UR50D-contact-regression.pt" to C:\Users\Fahmid/.cache\torch\hub\checkpoints\esm2_t33_650M_UR50D-contact-regression.pt


Learn a linear map ESM→GVP (fits once on proteins you already have)

In [50]:
# Cell 40: Fit linear projection from ESM space to GVP 1024-D space
# Build a small calibration set from proteins where we have both:
cal_ids = emb_df["uniprot_id"].tolist()
np.random.seed(1337)
np.random.shuffle(cal_ids)
cal_ids = cal_ids[:min(200, len(cal_ids))]  # up to 200 for speed; increase if you want

if not len(cal_ids):
    raise RuntimeError("No existing GVP embeddings found to fit ESM→GVP map. Encode some proteins first.")

# Gather pairs (ESM, GVP)
X_list, Y_list = [], []
for up in cal_ids:
    seq = seq_map.get(up, None)
    if not isinstance(seq, str) or not len(seq):
        continue
    try:
        x = esm_embed_sequence(up, seq)          # [D_esm]
        y = emb_df.loc[emb_df["uniprot_id"]==up, "embedding"].iloc[0]
        y = np.array(y, dtype=np.float32)        # [1024]
        if x.ndim==1 and y.ndim==1:
            X_list.append(x)
            Y_list.append(y)
    except Exception as e:
        # skip problematic sequences if any
        pass

X = np.stack(X_list, axis=0).astype(np.float32)   # [N, D_esm]
Y = np.stack(Y_list, axis=0).astype(np.float32)   # [N, 1024]

print("Calibration pairs:", X.shape, Y.shape)

# Solve least squares: X W = Y -> W = (X^T X)^-1 X^T Y
# Use torch for GPU-accelerated solve (then move back to CPU for storage).
Xt = torch.from_numpy(X).to(DEVICE)
Yt = torch.from_numpy(Y).to(DEVICE)
with torch.no_grad():
    # regularized least squares (ridge) for stability
    lam = 1e-5
    A = Xt.T @ Xt + lam * torch.eye(Xt.shape[1], device=DEVICE, dtype=Xt.dtype)
    B = Xt.T @ Yt
    W = torch.linalg.solve(A, B)  # [D_esm, 1024]
W_cpu = W.float().cpu().numpy()
print("W shape:", W_cpu.shape)


Calibration pairs: (200, 1280) (200, 1024)
W shape: (1280, 1024)


Embed the missing proteins and write them into the Parquet

In [51]:
# Cell 41: Apply ESM → project to 1024 → append to protein_embeddings.parquet
rows = []
for (up, seq) in missing_with_seq:
    try:
        x = esm_embed_sequence(up, seq)                 # [D_esm]
        gvp_1024 = (torch.from_numpy(x).to(torch.float32) @ torch.from_numpy(W_cpu)).numpy()  # [1024]
        # metadata
        L = len(seq)
        mean_plddt = float("nan")  # no structure; mark NaN
        row = {
            "uniprot_id": up,
            "length": int(L),
            "mean_pLDDT": mean_plddt,
            "embedding_dim": int(gvp_1024.shape[0]),
            "encoder_version": CFG["repro"]["encoder_version"] + "+seq_fallback_" + ESM_MODEL_NAME,
            "pdb_md5": "NA_sequence_only",
            "embedding": gvp_1024.astype(np.float32).tolist(),
            "source": "sequence_only",
        }
        rows.append(row)
    except Exception as e:
        print(f"[SKIP] {up} failed with: {e}")

add_df = pd.DataFrame(rows)
print("Sequence-only embeddings:", len(add_df))

# Merge into existing parquet (append)
if emb_parquet.exists() and len(add_df):
    base = pd.read_parquet(emb_parquet)
    # de-dup if re-running
    base = base[~base["uniprot_id"].isin(add_df["uniprot_id"])]
    out = pd.concat([base, add_df], ignore_index=True)
    out.to_parquet(emb_parquet, index=False)
    print("Updated:", emb_parquet)
elif len(add_df):
    add_df.to_parquet(emb_parquet, index=False)
    print("Wrote new:", emb_parquet)
else:
    print("Nothing to append.")


Sequence-only embeddings: 3
Updated: F:\Thesis Korbi na\dti-prediction-with-adr\Data\protein_embeddings.parquet
