Cell 1 — Title & env echo

In [57]:
# EGNN Drug Encoder — Part 1: Setup, Config, Data Scan
# Notebook: EGNN_Encode_Drugs.ipynb (Part 1/3)

import sys, os, platform, json, math, random, shutil
from pathlib import Path

import numpy as np
import pandas as pd

import torch

print("Environment")
print("--------------------------------------------------------------------")
print(f"Python      : {platform.python_version()}")
print(f"OS          : {platform.system()} {platform.release()} ({platform.machine()})")
print(f"PyTorch     : {torch.__version__}")
print(f"Device      : {'CUDA:'+str(torch.cuda.current_device()) if torch.cuda.is_available() else 'CPU'}")
if torch.cuda.is_available():
    print(f"GPU         : {torch.cuda.get_device_name(0)}")
    print(f"VRAM total  : {round(torch.cuda.get_device_properties(0).total_memory/1024**3,1)} GB")
print(f"TF32        : {torch.backends.cuda.matmul.allow_tf32}")
print(f"AMP dtype   : float16 (per config)")
print(f"cudnn.bench : {torch.backends.cudnn.benchmark}")
print(f"torch.compile available: {hasattr(torch, 'compile')}")


Environment
--------------------------------------------------------------------
Python      : 3.12.11
OS          : Windows 11 (AMD64)
PyTorch     : 2.7.0+cu128
Device      : CUDA:0
GPU         : NVIDIA GeForce RTX 5070 Ti
VRAM total  : 15.9 GB
TF32        : True
AMP dtype   : float16 (per config)
cudnn.bench : True
torch.compile available: True


Cell 2 — Config loader

In [58]:
import yaml

# Paths relative to this notebook (inside /preporcessing_code)
NB_DIR = Path(__file__).parent if '__file__' in globals() else Path.cwd()
CFG_PATH = NB_DIR / "egnn_config.yaml"

assert CFG_PATH.exists(), f"Config file not found: {CFG_PATH}"

with open(CFG_PATH, "r", encoding="utf-8") as f:
    CFG = yaml.safe_load(f)

def cfg_get(path, default=None):
    """Helper to fetch nested config keys with 'dot.path' style."""
    cur = CFG
    for key in path.split('.'):
        if key not in cur:
            return default
        cur = cur[key]
    return cur

# Echo key config bits
print("Loaded config from:", CFG_PATH.as_posix())
print(json.dumps({
    "data_root": cfg_get("paths.data_root"),
    "main_parquet": cfg_get("paths.main_parquet"),
    "out_embeddings_parquet": cfg_get("paths.out_embeddings_parquet"),
    "graph_cache_dir": cfg_get("paths.graph_cache_dir"),
    "device": cfg_get("hardware.device"),
    "embed_dim": cfg_get("egnn_encoder.embed_dim"),
    "n_layers": cfg_get("egnn_encoder.n_layers"),
    "pooling": cfg_get("egnn_encoder.pooling"),
}, indent=2))


Loaded config from: f:/Thesis Korbi na/dti-prediction-with-adr/preporcessing_code/egnn_config.yaml
{
  "data_root": "../Data",
  "main_parquet": "../Data/scope_onside_common_v3.parquet",
  "out_embeddings_parquet": "../Data/EGNN_drug_embeddings_v2.parquet",
  "graph_cache_dir": "../Data/graph_cache_egnn_v2",
  "device": "cuda:0",
  "embed_dim": 256,
  "n_layers": 6,
  "pooling": "mean"
}


Cell 3 — Repro, CUDA preferences (TF32/AMP)

In [59]:
# Repro
seed = int(cfg_get("project.seed", 1337))
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# CUDA preferences from config / your environment
torch.backends.cuda.matmul.allow_tf32 = bool(cfg_get("hardware.tf32", True))
torch.backends.cudnn.benchmark = True  # per your env
DEVICE = torch.device(cfg_get("hardware.device", "cuda:0") if torch.cuda.is_available() else "cpu")
AMP_DTYPE = torch.float16  # per your env
print("Device set to:", DEVICE)


Device set to: cuda:0


Cell 4 — Load main dataset (Parquet) & validate schema

In [60]:
MAIN_PARQUET = (NB_DIR / cfg_get("paths.main_parquet")).resolve()
assert MAIN_PARQUET.exists(), f"Main parquet not found at {MAIN_PARQUET}"

df = pd.read_parquet(MAIN_PARQUET)
required_cols = ["drug_chembl_id", "target_uniprot_id", "label", "smiles", "sequence", "molfile_3d", "rxcui"]
missing = [c for c in required_cols if c not in df.columns]
assert not missing, f"Missing required columns: {missing}"

# No missing data per your constraint, but we'll assert anyway
for c in required_cols:
    assert df[c].isna().sum() == 0, f"Column {c} has NA values"

print("Loaded main dataset:", MAIN_PARQUET.name)
print(df.info())

if bool(cfg_get("logging.echo_dataset_head", True)):
    display(df.head(5))

Loaded main dataset: scope_onside_common_v3.parquet
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 34741 entries, 0 to 34740
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   drug_chembl_id     34741 non-null  object
 1   target_uniprot_id  34741 non-null  object
 2   label              34741 non-null  int64 
 3   smiles             34741 non-null  object
 4   sequence           34741 non-null  object
 5   molfile_3d         34741 non-null  object
 6   rxcui              34741 non-null  object
dtypes: int64(1), object(6)
memory usage: 1.9+ MB
None


Unnamed: 0,drug_chembl_id,target_uniprot_id,label,smiles,sequence,molfile_3d,rxcui
0,CHEMBL1000,O15245,0,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTP...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
1,CHEMBL1000,P08183,1,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MDLEGDRNGGAKKKNFFKLNNKSEKDKKEKKPTVSVFSMFRYSNWL...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
2,CHEMBL1000,P35367,1,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MSLPNSSCLLEDKMCEGNKTTMASPQLMPLVVVLSTICLVTVGLNL...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
3,CHEMBL1000,Q02763,0,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MDSLASLVLCGVSLLLSGTVEGAMDLILINSLPLVSDAETSLTCIA...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
4,CHEMBL1000,Q12809,0,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MPVRRGHVAPQNTFLDTIIRKFEGQSRKFIIANARVENCAVIYCND...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610


Cell 5 — Derive unique drugs & basic stats

In [61]:
# We will embed *drugs* once; later we’ll join by drug_chembl_id/rxcui during DTI/ADR training
drug_cols = ["drug_chembl_id", "smiles", "molfile_3d", "rxcui"]
df_drugs = df[drug_cols].drop_duplicates(subset=["drug_chembl_id"]).reset_index(drop=True)

n_pairs = len(df)
n_drugs = len(df_drugs)
n_targets = df["target_uniprot_id"].nunique()

print(f"DTI pairs : {n_pairs:,}")
print(f"Unique drugs : {n_drugs:,}")
print(f"Unique proteins : {n_targets:,}")

display(df_drugs.sample(min(5, n_drugs), random_state=seed))

DTI pairs : 34,741
Unique drugs : 1,028
Unique proteins : 2,385


Unnamed: 0,drug_chembl_id,smiles,molfile_3d,rxcui
126,CHEMBL1200472,Fc1ccccc1C1=NCC(=S)N(CC(F)(F)F)c2ccc(Cl)cc21,\n RDKit 3D\n\n 36 38 0 0 0 0...,35185
737,CHEMBL3545363,CC(C)(C)[C@@H]1NC(=O)O[C@@H]2CCC[C@H]2OC/C=C/C...,\n RDKit 3D\n\n104110 0 0 0 0...,1940635
277,CHEMBL128,CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1,\n RDKit 3D\n\n 41 42 0 0 0 0...,37418
758,CHEMBL372795,CN[C@@H]1[C@H](O[C@H]2[C@H](O[C@H]3[C@H](O)[C@...,\n RDKit 3D\n\n 79 81 0 0 0 0...,10109
170,CHEMBL1201199,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(=N)N)NC(=...,\n RDKit 3D\n\n171176 0 0 0 0...,42375


Cell 6 — Prepare output and cache dirs

In [62]:
# GRAPH_CACHE_DIR = (NB_DIR / cfg_get("paths.graph_cache_dir")).resolve()
# OUT_PARQUET = (NB_DIR / cfg_get("paths.out_embeddings_parquet")).resolve()

# GRAPH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# OUT_PARQUET.parent.mkdir(parents=True, exist_ok=True)

# print("Graph cache dir :", GRAPH_CACHE_DIR.as_posix())
# print("Output parquet  :", OUT_PARQUET.as_posix())

# If you want a clean slate (commented by default)
# if OUT_PARQUET.exists():
#     OUT_PARQUET.unlink()
#     print("Removed existing output parquet to start fresh.")

# Cell 6B — V2 paths (separate cache + parquet)
GRAPH_CACHE_DIR = (NB_DIR / (cfg_get("paths.graph_cache_dir"))).resolve()
OUT_PARQUET     = (NB_DIR / cfg_get("paths.out_embeddings_parquet")).with_name("EGNN_drug_embeddings_v2.parquet").resolve()

GRAPH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
OUT_PARQUET.parent.mkdir(parents=True, exist_ok=True)

print("V2 Graph cache dir :", GRAPH_CACHE_DIR.as_posix())
print("V2 Output parquet  :", OUT_PARQUET.as_posix())

# (Re)initialize empty parquet if not present
if not OUT_PARQUET.exists():
    pd.DataFrame({
        "drug_chembl_id": pd.Series(dtype="string"),
        "rxcui": pd.Series(dtype="string"),
        "embedding": pd.Series(dtype="object"),
    }).to_parquet(OUT_PARQUET, engine=cfg_get("io.parquet_engine", "pyarrow"))
    print("Initialized empty V2 embeddings parquet.")
else:
    print("V2 embeddings parquet already exists; will overwrite at the end.")


V2 Graph cache dir : F:/Thesis Korbi na/dti-prediction-with-adr/Data/graph_cache_egnn_v2
V2 Output parquet  : F:/Thesis Korbi na/dti-prediction-with-adr/Data/EGNN_drug_embeddings_v2.parquet
V2 embeddings parquet already exists; will overwrite at the end.


Cell 7 — Define output schema helper (embedding placeholder)

In [63]:
EMBED_DIM = int(cfg_get("egnn_encoder.embed_dim", 256))

def make_empty_output_df():
    return pd.DataFrame({
        "drug_chembl_id": pd.Series(dtype="string"),
        "rxcui": pd.Series(dtype="string"),
        "embedding": pd.Series(dtype="object"),   # list<float32> per row
    })

# Create file with schema if it doesn't exist (append later in batches)
if not OUT_PARQUET.exists():
    empty = make_empty_output_df()
    empty.to_parquet(OUT_PARQUET, engine=cfg_get("io.parquet_engine", "pyarrow"))
    print("Initialized empty embeddings parquet.")
else:
    print("Embeddings parquet already exists; will append batches.")

Embeddings parquet already exists; will append batches.


Cell 9 — Imports for chemistry + graph building

In [64]:
# Part 2 – Molecular Graph Builder
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem

import numpy as np
import torch
from torch_geometric.data import Data

import pickle
from tqdm.auto import tqdm

Cell 10 — Atom & bond feature helpers

In [65]:
# Cell 10 — Feature helpers (V2)
def one_hot(x, choices):
    v = [0]*len(choices)
    try:
        v[choices.index(x)] = 1
    except ValueError:
        pass
    return v

# Reasonable coverage for small molecules; extend if needed
ATOM_LIST = list(range(1, 31))  # H..Zn
HYB_LIST  = [
    Chem.rdchem.HybridizationType.SP,
    Chem.rdchem.HybridizationType.SP2,
    Chem.rdchem.HybridizationType.SP3,
    Chem.rdchem.HybridizationType.SP3D,
    Chem.rdchem.HybridizationType.SP3D2,
]
CHIR_LIST = [
    Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
    Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
]

BOND_LIST = [
    Chem.rdchem.BondType.SINGLE,
    Chem.rdchem.BondType.DOUBLE,
    Chem.rdchem.BondType.TRIPLE,
    Chem.rdchem.BondType.AROMATIC,
]
STEREO_LIST = [
    Chem.rdchem.BondStereo.STEREONONE,
    Chem.rdchem.BondStereo.STEREOZ,
    Chem.rdchem.BondStereo.STEREOE,
]

def atom_features(atom: Chem.Atom):
    feats = []
    feats += one_hot(atom.GetAtomicNum(), ATOM_LIST)
    feats += [atom.GetFormalCharge()]
    feats += [int(atom.GetIsAromatic())]
    feats += one_hot(atom.GetHybridization(), HYB_LIST)
    feats += one_hot(atom.GetChiralTag(), CHIR_LIST) if bool(cfg_get("featurization.add_chirality", True)) else []
    if bool(cfg_get("featurization.add_atomic_mass", True)):
        feats += [atom.GetMass() * 0.01]  # light scaling
    if bool(cfg_get("featurization.add_valence", True)):
        feats += [atom.GetTotalValence()]
    if bool(cfg_get("featurization.add_total_h", True)):
        feats += [atom.GetTotalNumHs(includeNeighbors=True)]
    feats += [int(atom.IsInRing())]
    if bool(cfg_get("featurization.add_ring_size", True)):
        try:
            r = atom.GetOwningMol().GetRingInfo().AtomRings()
            # smallest ring size containing this atom
            s = min((len(rng) for rng in r if atom.GetIdx() in rng), default=0)
            feats += [min(s, 8)]
        except Exception:
            feats += [0]
    return feats

def bond_features(b: Chem.Bond):
    feats = []
    feats += one_hot(b.GetBondType(), BOND_LIST)
    feats += [int(b.GetIsConjugated())]
    feats += [int(b.IsInRing())]
    feats += one_hot(b.GetStereo(), STEREO_LIST)
    return feats


Cell 11 — Parse MOL and build graph object

In [72]:
# Cell 11 — mol_to_graph (V2: bond + radius edges, partial charges) — FIXED
from rdkit.Chem import AllChem
from torch_geometric.data import Data
import torch
import numpy as np

def _bond_feat_length():
    # Matches bond_features(): one-hot over BOND_LIST, +is_conjugated, +in_ring, +one-hot over STEREO_LIST
    return len(BOND_LIST) + 1 + 1 + len(STEREO_LIST)

def mol_to_graph(mol_block: str, max_atoms: int = 120, radius: float = 4.8):
    mol = Chem.MolFromMolBlock(
        mol_block,
        removeHs=not bool(cfg_get("featurization.use_explicit_h", True)),
        sanitize=True
    )
    assert mol is not None, "Failed to parse MOL block"
    N = mol.GetNumAtoms()
    assert 0 < N <= max_atoms, f"Molecule atoms {N} out of bounds (max={max_atoms})"
    conf = mol.GetConformer()

    # Optional partial charges once per molecule
    add_q = bool(cfg_get("featurization.add_partial_charge", True))
    if add_q:
        try:
            AllChem.ComputeGasteigerCharges(mol)
        except Exception:
            add_q = False  # fallback silently

    # Node features
    x_list = []
    for a in mol.GetAtoms():
        xf = atom_features(a)
        if add_q:
            try:
                q = float(a.GetProp('_GasteigerCharge'))
                xf += [q if np.isfinite(q) else 0.0]
            except Exception:
                xf += [0.0]
        x_list.append(xf)
    x = torch.tensor(x_list, dtype=torch.float32)

    # Coordinates (float32 for stability)
    pos = torch.tensor([[*conf.GetAtomPosition(i)] for i in range(N)], dtype=torch.float32)

    # Prepare bond feature length and final edge feature length (+2 flags)
    bond_feat_len = _bond_feat_length()
    final_edge_feat_len = bond_feat_len + 2  # is_bond, is_radius

    # Bond edges
    src, dst, eattr = [], [], []
    for b in mol.GetBonds():
        u, v = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bf = bond_features(b)  # length == bond_feat_len
        bf = bf + [1, 0]       # is_bond=1, is_radius=0
        src += [u, v]; dst += [v, u]; eattr += [bf, bf]

    # Radius edges for non-bonded neighbors
    P = pos.cpu().numpy()
    for i in range(N):
        Pi = P[i]
        for j in range(i+1, N):
            if mol.GetBondBetweenAtoms(i, j):  # already added as bond edge
                continue
            d = float(np.linalg.norm(Pi - P[j]))
            if d <= radius:
                rf = [0]*bond_feat_len + [0, 1]  # zeros for bond part, is_bond=0, is_radius=1
                src += [i, j]; dst += [j, i]; eattr += [rf, rf]

    # Build edge tensors (allow E=0)
    if len(src) == 0:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr  = torch.empty((0, final_edge_feat_len), dtype=torch.float32)
    else:
        edge_index = torch.tensor([src, dst], dtype=torch.long)
        edge_attr  = torch.tensor(eattr, dtype=torch.float32)
        # Safety: enforce expected width
        assert edge_attr.shape[1] == final_edge_feat_len, \
            f"edge_attr width {edge_attr.shape[1]} != expected {final_edge_feat_len}"

    g = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)
    return g


Cell 12 — Graph cache writer

In [73]:
from pathlib import Path

def cache_graph(drug_id: str, mol_block: str, cache_dir: Path, max_atoms: int):
    """Build & save graph for a drug if not already cached"""
    cache_path = cache_dir / f"{drug_id}.pkl"
    if cache_path.exists():
        return cache_path

    data = mol_to_graph(mol_block, max_atoms=max_atoms)
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return cache_path

Cell 13 — Batch graph creation loop

In [74]:
# Cell 13 — Rebuild graphs into V2 cache
MAX_ATOMS = int(cfg_get("egnn_encoder.max_atoms", 120))
paths = []
print(f"[V2] Building & caching graphs → {GRAPH_CACHE_DIR.as_posix()}")

for _, row in tqdm(df_drugs.iterrows(), total=len(df_drugs)):
    drug_id = row["drug_chembl_id"]
    mol_block = row["molfile_3d"]
    cache_path = GRAPH_CACHE_DIR / f"{drug_id}.pkl"
    if not cache_path.exists():
        data = mol_to_graph(mol_block, max_atoms=MAX_ATOMS)
        with open(cache_path, "wb") as f:
            pickle.dump(data, f)
    paths.append(cache_path)

print(f"[V2] Cached {len(paths)} molecule graphs.")


[V2] Building & caching graphs → F:/Thesis Korbi na/dti-prediction-with-adr/Data/graph_cache_egnn_v2


  0%|          | 0/1028 [00:00<?, ?it/s]

[V2] Cached 1028 molecule graphs.


Cell 14 — Quick integrity check

In [75]:
import random, pickle
sample = random.choice(paths)
with open(sample, "rb") as f:
    g = pickle.load(f)

print("Random cached graph:", sample.name)
print(f"Nodes: {g.x.shape[0]}, Edges: {g.edge_index.shape[1]}, Edge features: {g.edge_attr.shape[1]}")
print(f"Coords shape: {g.pos.shape}")

Random cached graph: CHEMBL3646221.pkl
Nodes: 32, Edges: 468, Edge features: 11
Coords shape: torch.Size([32, 3])


In [76]:
print("""
✅  Graph cache complete.

Next:
 • Define the EGNN layer stack (equivariant message passing with coord updates)
 • Encode all cached graphs in mini-batches
 • Pool node embeddings → per-drug vectors
 • Save to Data/EGNN_drug_embeddings.parquet
""")



✅  Graph cache complete.

Next:
 • Define the EGNN layer stack (equivariant message passing with coord updates)
 • Encode all cached graphs in mini-batches
 • Pool node embeddings → per-drug vectors
 • Save to Data/EGNN_drug_embeddings.parquet



Cell 16 — Imports and utilities

In [79]:
# Part 3 – EGNN Encoder & Embedding Writer
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.nn.functional as F
from torch_geometric.loader import DataLoader

import pickle
from tqdm.auto import tqdm

In [85]:
# Cell 16.5 — RBF for distances (fixed broadcasting)
import torch

def rbf(d, D_min=0.0, D_max=6.0, D_count=16):
    # d: shape [E] (distances)
    d = d.unsqueeze(-1)  # [E, 1]
    step = max((D_max - D_min) / max(D_count - 1, 1), 1e-6)
    centers = torch.linspace(D_min, D_max, D_count, device=d.device, dtype=d.dtype)  # [D]
    centers = centers.unsqueeze(0)  # [1, D]
    gamma = 1.0 / (step ** 2)
    # result: [E, D]
    return torch.exp(-gamma * (d.clamp(D_min, D_max) - centers) ** 2)


Cell 17 — EGNN layer definition

In [81]:
# Cell 17 — EGNNLayer (V2)


class EGNNLayer(nn.Module):
    def __init__(self, in_dim, edge_dim, hidden_dim, rbf_dim=16, update_coords=False, p_drop=0.1):
        super().__init__()
        self.update_coords = update_coords
        self.rbf_dim = rbf_dim

        self.edge_mlp = nn.Sequential(
            nn.Linear(in_dim*2 + edge_dim + rbf_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(in_dim + hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Dropout(p_drop),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.norm_in  = nn.LayerNorm(in_dim)
        self.norm_msg = nn.LayerNorm(hidden_dim)
        self.norm_out = nn.LayerNorm(hidden_dim)

        if update_coords:
            self.coord_mlp = nn.Sequential(nn.Linear(hidden_dim, 1), nn.Tanh())

    def forward(self, x, pos, edge_index, edge_attr):
        # x/edge_attr follow compute dtype; pos stays fp32
        i, j = edge_index

        x = self.norm_in(x)

        # Geometry in fp32
        pos = pos.float()
        if edge_index.numel() == 0:
            # no edges -> zero messages shaped to hidden_dim
            hidden_dim = self.node_mlp[0].in_features - x.size(1)
            m_j = torch.zeros(x.size(0), hidden_dim, device=x.device, dtype=x.dtype)
            h = self.node_mlp(torch.cat([x, m_j], dim=1))
            h = self.norm_out(h)
            return h, pos

        rel  = pos[i] - pos[j]                   # [E,3] fp32
        dist = (rel.pow(2).sum(dim=1)).sqrt()    # [E]   fp32
        dist_rbf = rbf(dist.to(x.dtype), D_count=self.rbf_dim)  # [E, rbf_dim], compute dtype

        # Build edge input and assert width matches edge_mlp
        ea = edge_attr.to(x.dtype)
        e_input = torch.cat([x[i], x[j], ea, dist_rbf], dim=1)
        expected = self.edge_mlp[0].in_features
        assert e_input.shape[1] == expected, f"edge_mlp expects {expected}, got {e_input.shape[1]}"

        e_ij = self.edge_mlp(e_input).to(x.dtype)
        e_ij = self.norm_msg(e_ij)

        # Aggregate messages on destination i
        m_j = torch.zeros(x.size(0), e_ij.size(1), device=x.device, dtype=x.dtype)
        m_j.index_add_(0, i, e_ij)

        # Node update
        h = self.node_mlp(torch.cat([x, m_j], dim=1))
        h = self.norm_out(h)
        x = h  # residual path not needed since norms stabilize

        # Optional coordinate update (kept in fp32 domain)
        if self.update_coords:
            w_ij = self.coord_mlp(e_ij).to(rel.dtype)  # [E,1] fp32
            delta = rel * w_ij                         # [E,3] fp32
            coord_update = torch.zeros_like(pos)
            coord_update.index_add_(0, i, delta)
            pos = pos + coord_update / (len(j) + 1e-6)

        return x, pos


Cell 18 — EGNN Encoder (stack + pooling)

In [82]:
# Cell 18 — EGNNEncoder (V2)
class AttnPool(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.score = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.SiLU(),
            nn.Linear(hidden_dim//2, 1),
        )
    def forward(self, x):
        a = torch.softmax(self.score(x).squeeze(-1), dim=0)  # per-atom weights
        return (a.unsqueeze(-1) * x).sum(dim=0)

class EGNNEncoder(nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim, n_layers, embed_dim,
                 update_coords=False, pooling="attention", l2_norm=True, rbf_dim=16, p_drop=0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            EGNNLayer(node_dim if l==0 else hidden_dim,
                      edge_dim,
                      hidden_dim,
                      rbf_dim=rbf_dim,
                      update_coords=update_coords,
                      p_drop=p_drop)
            for l in range(n_layers)
        ])
        self.pooling = pooling
        self.l2_norm = l2_norm
        self.attn = AttnPool(hidden_dim) if pooling == "attention" else None
        self.proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, embed_dim),
        )

    def forward(self, data):
        # Keep params FP32; unifying compute dtype to param dtype
        compute_dtype = next(self.parameters()).dtype
        x   = data.x.to(compute_dtype)
        ea  = data.edge_attr.to(compute_dtype)
        pos = data.pos.float()     # fp32 geometry
        ei  = data.edge_index

        for layer in self.layers:
            x, pos = layer(x, pos, ei, ea)

        if self.pooling == "mean":
            mol_vec = x.mean(dim=0)
        elif self.pooling == "sum":
            mol_vec = x.sum(dim=0)
        else:
            mol_vec = self.attn(x)

        mol_vec = self.proj(mol_vec)
        if self.l2_norm:
            mol_vec = F.normalize(mol_vec, p=2, dim=-1)
        return mol_vec


Cell 19 — Initialize encoder

In [None]:
# Cell 19 — Init from V2 cache probe
import random, pickle

_probe = random.choice(list(GRAPH_CACHE_DIR.glob("*.pkl")))
with open(_probe, "rb") as f:
    g_probe = pickle.load(f)

NODE_DIM = int(g_probe.x.shape[1])
EDGE_DIM = int(g_probe.edge_attr.shape[1]) if g_probe.edge_attr.numel() else 0
HIDDEN_DIM = int(cfg_get("egnn_encoder.hidden_dim", 384))
EMBED_DIM  = int(cfg_get("egnn_encoder.embed_dim", 256))
N_LAYERS   = int(cfg_get("egnn_encoder.n_layers", 8))
UPDATE_COORDS = bool(cfg_get("egnn_encoder.coord_updates", False))
POOLING    = cfg_get("egnn_encoder.pooling", "attention")
L2NORM     = bool(cfg_get("egnn_encoder.l2_normalize", True))

print(f"[V2] Inferred dims => NODE_DIM={NODE_DIM}, EDGE_DIM={EDGE_DIM}")

model = EGNNEncoder(
    node_dim=NODE_DIM,
    edge_dim=EDGE_DIM,
    hidden_dim=HIDDEN_DIM,
    n_layers=N_LAYERS,
    embed_dim=EMBED_DIM,
    update_coords=UPDATE_COORDS,
    pooling=POOLING,
    l2_norm=L2NORM,
    rbf_dim=16,
    p_drop=0.1,
).to(DEVICE)

print("Model dtype (should be fp32):", next(model.parameters()).dtype)


[V2] Inferred dims => NODE_DIM=46, EDGE_DIM=11
Model dtype (should be fp32): torch.float32


Cell 20 — Encoding loop

In [86]:
# Cell 20 — Encode all V2 graphs into V2 parquet (FP32 for best accuracy)
records = []
model.eval()

with torch.no_grad():
    for pkl_path in tqdm(list(GRAPH_CACHE_DIR.glob("*.pkl"))):
        with open(pkl_path, "rb") as f:
            g = pickle.load(f)
        g = g.to(DEVICE)

        emb = model(g).detach().cpu().numpy().astype("float32")

        drug_id = pkl_path.stem
        row = df_drugs.loc[df_drugs["drug_chembl_id"] == drug_id].iloc[0]
        records.append({
            "drug_chembl_id": drug_id,
            "rxcui": row["rxcui"],
            "embedding": emb.tolist(),
        })

df_embed = pd.DataFrame.from_records(records)
df_embed.to_parquet(OUT_PARQUET, engine=cfg_get("io.parquet_engine", "pyarrow"))
print(f"✅ [V2] Saved {len(df_embed)} embeddings → {OUT_PARQUET.name}")
display(df_embed.head(3))


  0%|          | 0/1028 [00:00<?, ?it/s]

✅ [V2] Saved 1028 embeddings → EGNN_drug_embeddings_v2.parquet


Unnamed: 0,drug_chembl_id,rxcui,embedding
0,CHEMBL1000,20610,"[0.02189382165670395, 0.016782937571406364, -0..."
1,CHEMBL1002,237159,"[0.02701766975224018, 0.033309683203697205, -0..."
2,CHEMBL1004,3642,"[0.025201082229614258, 0.02474130503833294, -0..."


Cell 21 — Save to Parquet

In [87]:
df_embed = pd.DataFrame.from_records(records)
df_embed.to_parquet(OUT_PARQUET, engine=cfg_get("io.parquet_engine", "pyarrow"))
print(f"✅ Saved {len(df_embed)} embeddings to {OUT_PARQUET.name}")
display(df_embed.head(3))

✅ Saved 1028 embeddings to EGNN_drug_embeddings_v2.parquet


Unnamed: 0,drug_chembl_id,rxcui,embedding
0,CHEMBL1000,20610,"[0.02189382165670395, 0.016782937571406364, -0..."
1,CHEMBL1002,237159,"[0.02701766975224018, 0.033309683203697205, -0..."
2,CHEMBL1004,3642,"[0.025201082229614258, 0.02474130503833294, -0..."


Cell 22 — Quick test: load and shape check

In [89]:
df_check = pd.read_parquet(OUT_PARQUET)
print("Embedding shape check:")
print(df_check["embedding"].iloc[0][:10], " ...")
print("All equal length:", len(set(len(x) for x in df_check.embedding)) == 1)

Embedding shape check:
[ 0.02189382  0.01678294 -0.03726786 -0.04854688 -0.15610014  0.02558895
  0.0249966   0.0026272  -0.01306693  0.04946777]  ...
All equal length: True


Invariance quick test

In [129]:
from torch.nn.functional import cosine_similarity
with open(random.choice(list(GRAPH_CACHE_DIR.glob("*.pkl"))), "rb") as f:
    g = pickle.load(f)
rot = g.clone()
rot.pos = torch.matmul(g.pos, torch.randn(3,3))  # random rotation
e1 = model(g.to(DEVICE)).detach().cpu()
e2 = model(rot.to(DEVICE)).detach().cpu()
print("cosine sim:", float(cosine_similarity(e1, e2, dim=0)))

cosine sim: 0.9936487674713135


Cell 22B — Quick invariance & health checks (safe)

In [114]:
# Cell 22B — Sanity checks (V2)
from torch.nn.functional import cosine_similarity
df_check = pd.read_parquet(OUT_PARQUET)
print("V2 rows:", len(df_check), "| unique lengths:", set(len(v) for v in df_check.embedding))
print("Any NaN:", any(np.isnan(v).any() for v in df_check.embedding))

# rotation invariance on one random graph
with open(random.choice(list(GRAPH_CACHE_DIR.glob("*.pkl"))), "rb") as f:
    g = pickle.load(f)
rot = g.clone()
R = torch.linalg.qr(torch.randn(3,3)).Q  # random orthogonal
rot.pos = (g.pos @ R)
with torch.no_grad():
    e1 = model(g.to(DEVICE)).detach().cpu()
    e2 = model(rot.to(DEVICE)).detach().cpu()
print("cosine(rot-invariance):", float(cosine_similarity(e1, e2, dim=0)))


V2 rows: 1028 | unique lengths: {256}
Any NaN: False
cosine(rot-invariance): 1.0
