Cell 1 — Title & env echo

In [264]:
# EGNN Drug Encoder — Part 1: Setup, Config, Data Scan
# Notebook: EGNN_Encode_Drugs.ipynb (Part 1/3)

import sys, os, platform, json, math, random, shutil
from pathlib import Path

import numpy as np
import pandas as pd

import torch

print("Environment")
print("--------------------------------------------------------------------")
print(f"Python      : {platform.python_version()}")
print(f"OS          : {platform.system()} {platform.release()} ({platform.machine()})")
print(f"PyTorch     : {torch.__version__}")
print(f"Device      : {'CUDA:'+str(torch.cuda.current_device()) if torch.cuda.is_available() else 'CPU'}")
if torch.cuda.is_available():
    print(f"GPU         : {torch.cuda.get_device_name(0)}")
    print(f"VRAM total  : {round(torch.cuda.get_device_properties(0).total_memory/1024**3,1)} GB")
print(f"TF32        : {torch.backends.cuda.matmul.allow_tf32}")
print(f"AMP dtype   : float16 (per config)")
print(f"cudnn.bench : {torch.backends.cudnn.benchmark}")
print(f"torch.compile available: {hasattr(torch, 'compile')}")


Environment
--------------------------------------------------------------------
Python      : 3.12.11
OS          : Windows 11 (AMD64)
PyTorch     : 2.7.0+cu128
Device      : CUDA:0
GPU         : NVIDIA GeForce RTX 5070 Ti
VRAM total  : 15.9 GB
TF32        : True
AMP dtype   : float16 (per config)
cudnn.bench : True
torch.compile available: True


Cell 2 — Config loader

In [265]:
import yaml

# Paths relative to this notebook (inside /preporcessing_code)
NB_DIR = Path(__file__).parent if '__file__' in globals() else Path.cwd()
CFG_PATH = NB_DIR / "egnn_config.yaml"

assert CFG_PATH.exists(), f"Config file not found: {CFG_PATH}"

with open(CFG_PATH, "r", encoding="utf-8") as f:
    CFG = yaml.safe_load(f)

def cfg_get(path, default=None):
    """Helper to fetch nested config keys with 'dot.path' style."""
    cur = CFG
    for key in path.split('.'):
        if key not in cur:
            return default
        cur = cur[key]
    return cur

# Echo key config bits
print("Loaded config from:", CFG_PATH.as_posix())
print(json.dumps({
    "data_root": cfg_get("paths.data_root"),
    "main_parquet": cfg_get("paths.main_parquet"),
    "out_embeddings_parquet": cfg_get("paths.out_embeddings_parquet"),
    "graph_cache_dir": cfg_get("paths.graph_cache_dir"),
    "device": cfg_get("hardware.device"),
    "embed_dim": cfg_get("egnn_encoder.embed_dim"),
    "n_layers": cfg_get("egnn_encoder.n_layers"),
    "pooling": cfg_get("egnn_encoder.pooling"),
}, indent=2))


Loaded config from: f:/Thesis Korbi na/dti-prediction-with-adr/preporcessing_code/egnn_config.yaml
{
  "data_root": "../Data",
  "main_parquet": "../Data/scope_onside_common_v3.parquet",
  "out_embeddings_parquet": "../Data/EGNN_drug_embeddings.parquet",
  "graph_cache_dir": "../Data/graph_cache_egnn_v1",
  "device": "cuda:0",
  "embed_dim": 256,
  "n_layers": 6,
  "pooling": "mean"
}


Cell 3 — Repro, CUDA preferences (TF32/AMP)

In [266]:
# Repro
seed = int(cfg_get("project.seed", 1337))
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# CUDA preferences from config / your environment
torch.backends.cuda.matmul.allow_tf32 = bool(cfg_get("hardware.tf32", True))
torch.backends.cudnn.benchmark = True  # per your env
DEVICE = torch.device(cfg_get("hardware.device", "cuda:0") if torch.cuda.is_available() else "cpu")
AMP_DTYPE = torch.float16  # per your env
print("Device set to:", DEVICE)


Device set to: cuda:0


Cell 4 — Load main dataset (Parquet) & validate schema

In [267]:
MAIN_PARQUET = (NB_DIR / cfg_get("paths.main_parquet")).resolve()
assert MAIN_PARQUET.exists(), f"Main parquet not found at {MAIN_PARQUET}"

df = pd.read_parquet(MAIN_PARQUET)
required_cols = ["drug_chembl_id", "target_uniprot_id", "label", "smiles", "sequence", "molfile_3d", "rxcui"]
missing = [c for c in required_cols if c not in df.columns]
assert not missing, f"Missing required columns: {missing}"

# No missing data per your constraint, but we'll assert anyway
for c in required_cols:
    assert df[c].isna().sum() == 0, f"Column {c} has NA values"

print("Loaded main dataset:", MAIN_PARQUET.name)
print(df.info())

if bool(cfg_get("logging.echo_dataset_head", True)):
    display(df.head(5))

Loaded main dataset: scope_onside_common_v3.parquet
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 34741 entries, 0 to 34740
Data columns (total 7 columns):
 #   Column             Non-Null Count  Dtype 
---  ------             --------------  ----- 
 0   drug_chembl_id     34741 non-null  object
 1   target_uniprot_id  34741 non-null  object
 2   label              34741 non-null  int64 
 3   smiles             34741 non-null  object
 4   sequence           34741 non-null  object
 5   molfile_3d         34741 non-null  object
 6   rxcui              34741 non-null  object
dtypes: int64(1), object(6)
memory usage: 1.9+ MB
None


Unnamed: 0,drug_chembl_id,target_uniprot_id,label,smiles,sequence,molfile_3d,rxcui
0,CHEMBL1000,O15245,0,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MPTVDDILEQVGESGWFQKQAFLILCLLSAAFAPICVGIVFLGFTP...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
1,CHEMBL1000,P08183,1,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MDLEGDRNGGAKKKNFFKLNNKSEKDKKEKKPTVSVFSMFRYSNWL...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
2,CHEMBL1000,P35367,1,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MSLPNSSCLLEDKMCEGNKTTMASPQLMPLVVVLSTICLVTVGLNL...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
3,CHEMBL1000,Q02763,0,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MDSLASLVLCGVSLLLSGTVEGAMDLILINSLPLVSDAETSLTCIA...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610
4,CHEMBL1000,Q12809,0,O=C(O)COCCN1CCN(C(c2ccccc2)c2ccc(Cl)cc2)CC1,MPVRRGHVAPQNTFLDTIIRKFEGQSRKFIIANARVENCAVIYCND...,\n RDKit 3D\n\n 52 54 0 0 0 0...,20610


Cell 5 — Derive unique drugs & basic stats

In [268]:
# We will embed *drugs* once; later we’ll join by drug_chembl_id/rxcui during DTI/ADR training
drug_cols = ["drug_chembl_id", "smiles", "molfile_3d", "rxcui"]
df_drugs = df[drug_cols].drop_duplicates(subset=["drug_chembl_id"]).reset_index(drop=True)

n_pairs = len(df)
n_drugs = len(df_drugs)
n_targets = df["target_uniprot_id"].nunique()

print(f"DTI pairs : {n_pairs:,}")
print(f"Unique drugs : {n_drugs:,}")
print(f"Unique proteins : {n_targets:,}")

display(df_drugs.sample(min(5, n_drugs), random_state=seed))

DTI pairs : 34,741
Unique drugs : 1,028
Unique proteins : 2,385


Unnamed: 0,drug_chembl_id,smiles,molfile_3d,rxcui
126,CHEMBL1200472,Fc1ccccc1C1=NCC(=S)N(CC(F)(F)F)c2ccc(Cl)cc21,\n RDKit 3D\n\n 36 38 0 0 0 0...,35185
737,CHEMBL3545363,CC(C)(C)[C@@H]1NC(=O)O[C@@H]2CCC[C@H]2OC/C=C/C...,\n RDKit 3D\n\n104110 0 0 0 0...,1940635
277,CHEMBL128,CNS(=O)(=O)Cc1ccc2[nH]cc(CCN(C)C)c2c1,\n RDKit 3D\n\n 41 42 0 0 0 0...,37418
758,CHEMBL372795,CN[C@@H]1[C@H](O[C@H]2[C@H](O[C@H]3[C@H](O)[C@...,\n RDKit 3D\n\n 79 81 0 0 0 0...,10109
170,CHEMBL1201199,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(=N)N)NC(=...,\n RDKit 3D\n\n171176 0 0 0 0...,42375


Cell 6 — Prepare output and cache dirs

In [269]:
GRAPH_CACHE_DIR = (NB_DIR / cfg_get("paths.graph_cache_dir")).resolve()
OUT_PARQUET = (NB_DIR / cfg_get("paths.out_embeddings_parquet")).resolve()

GRAPH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
OUT_PARQUET.parent.mkdir(parents=True, exist_ok=True)

print("Graph cache dir :", GRAPH_CACHE_DIR.as_posix())
print("Output parquet  :", OUT_PARQUET.as_posix())

# If you want a clean slate (commented by default)
# if OUT_PARQUET.exists():
#     OUT_PARQUET.unlink()
#     print("Removed existing output parquet to start fresh.")

Graph cache dir : F:/Thesis Korbi na/dti-prediction-with-adr/Data/graph_cache_egnn_v1
Output parquet  : F:/Thesis Korbi na/dti-prediction-with-adr/Data/EGNN_drug_embeddings.parquet


Cell 7 — Define output schema helper (embedding placeholder)

In [270]:
EMBED_DIM = int(cfg_get("egnn_encoder.embed_dim", 256))

def make_empty_output_df():
    return pd.DataFrame({
        "drug_chembl_id": pd.Series(dtype="string"),
        "rxcui": pd.Series(dtype="string"),
        "embedding": pd.Series(dtype="object"),   # list<float32> per row
    })

# Create file with schema if it doesn't exist (append later in batches)
if not OUT_PARQUET.exists():
    empty = make_empty_output_df()
    empty.to_parquet(OUT_PARQUET, engine=cfg_get("io.parquet_engine", "pyarrow"))
    print("Initialized empty embeddings parquet.")
else:
    print("Embeddings parquet already exists; will append batches.")

Embeddings parquet already exists; will append batches.


Cell 9 — Imports for chemistry + graph building

In [271]:
# Part 2 – Molecular Graph Builder
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem

import torch
from torch_geometric.data import Data

import pickle
from tqdm.auto import tqdm

Cell 10 — Atom & bond feature helpers

In [272]:
# Atom featurization (matches config entries)
def atom_features(atom: Chem.Atom):
    feats = [
        atom.GetAtomicNum(),
        atom.GetFormalCharge(),
        int(atom.GetIsAromatic()),
        atom.GetHybridization().real,
        int(atom.IsInRing()),
        atom.GetDegree(),
    ]
    return feats

# Bond featurization
def bond_features(bond: Chem.Bond):
    feats = [
        bond.GetBondTypeAsDouble(),   # numeric bond type
        int(bond.GetIsConjugated()),
        int(bond.IsInRing()),
    ]
    return feats

Cell 11 — Parse MOL and build graph object

In [None]:
def mol_to_graph(mol_block: str, max_atoms: int = 1000) -> Data:
    """Convert MOL block → torch_geometric.Data"""
    mol = Chem.MolFromMolBlock(mol_block, removeHs=False, sanitize=True)
    assert mol is not None, "Failed to parse MOL block"

    conf = mol.GetConformer()
    N = mol.GetNumAtoms()
    assert N <= max_atoms, f"Molecule too large: {N}>{max_atoms}"

    # --- nodes
    x = torch.tensor([atom_features(a) for a in mol.GetAtoms()], dtype=torch.float32)

    # --- edges
    src, dst, e = [], [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        f = bond_features(b)
        src += [i, j]
        dst += [j, i]
        e += [f, f]
    edge_index = torch.tensor([src, dst], dtype=torch.long)
    edge_attr = torch.tensor(e, dtype=torch.float32)

    # --- coordinates
    pos = torch.tensor([list(conf.GetAtomPosition(i)) for i in range(N)], dtype=torch.float32)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, pos=pos)

Cell 12 — Graph cache writer

In [274]:
from pathlib import Path

def cache_graph(drug_id: str, mol_block: str, cache_dir: Path, max_atoms: int):
    """Build & save graph for a drug if not already cached"""
    cache_path = cache_dir / f"{drug_id}.pkl"
    if cache_path.exists():
        return cache_path

    data = mol_to_graph(mol_block, max_atoms=max_atoms)
    with open(cache_path, "wb") as f:
        pickle.dump(data, f)
    return cache_path

Cell 13 — Batch graph creation loop

In [275]:
MAX_ATOMS = int(cfg_get("egnn_encoder.max_atoms", 120))
cache_dir = GRAPH_CACHE_DIR

paths = []
print(f"Building and caching graphs to {cache_dir.as_posix()} ...")

for _, row in tqdm(df_drugs.iterrows(), total=len(df_drugs)):
    drug_id = row["drug_chembl_id"]
    mol_block = row["molfile_3d"]
    path = cache_graph(drug_id, mol_block, cache_dir, MAX_ATOMS)
    paths.append(path)

print(f"Cached {len(paths)} molecule graphs.")

Building and caching graphs to F:/Thesis Korbi na/dti-prediction-with-adr/Data/graph_cache_egnn_v1 ...


100%|██████████| 1028/1028 [00:00<00:00, 15804.47it/s]

Cached 1028 molecule graphs.





Cell 14 — Quick integrity check

In [276]:
import random, pickle
sample = random.choice(paths)
with open(sample, "rb") as f:
    g = pickle.load(f)

print("Random cached graph:", sample.name)
print(f"Nodes: {g.x.shape[0]}, Edges: {g.edge_index.shape[1]}, Edge features: {g.edge_attr.shape[1]}")
print(f"Coords shape: {g.pos.shape}")

Random cached graph: CHEMBL3646221.pkl
Nodes: 32, Edges: 66, Edge features: 3
Coords shape: torch.Size([32, 3])


In [277]:
print("""
✅  Graph cache complete.

Next:
 • Define the EGNN layer stack (equivariant message passing with coord updates)
 • Encode all cached graphs in mini-batches
 • Pool node embeddings → per-drug vectors
 • Save to Data/EGNN_drug_embeddings.parquet
""")



✅  Graph cache complete.

Next:
 • Define the EGNN layer stack (equivariant message passing with coord updates)
 • Encode all cached graphs in mini-batches
 • Pool node embeddings → per-drug vectors
 • Save to Data/EGNN_drug_embeddings.parquet



Cell 16 — Imports and utilities

In [278]:
# Part 3 – EGNN Encoder & Embedding Writer
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch_geometric.loader import DataLoader

import pickle
from tqdm.auto import tqdm

Cell 17 — EGNN layer definition

In [279]:
class EGNNLayer(nn.Module):
    def __init__(self, in_dim, edge_dim, hidden_dim, update_coords=True):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(in_dim * 2 + edge_dim + 1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.node_mlp = nn.Sequential(
            nn.Linear(in_dim + hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.update_coords = update_coords
        if update_coords:
            self.coord_mlp = nn.Sequential(
                nn.Linear(hidden_dim, 1),
                nn.Tanh(),
            )

    def forward(self, x, pos, edge_index, edge_attr):
        i, j = edge_index
        pos = pos.float()
        rel = pos[i] - pos[j]
        dist = (rel ** 2).sum(dim=1, keepdim=True).sqrt()

        # If there are no edges, skip edge MLP and use zero messages
        if edge_index.numel() == 0:
            # m_j needs hidden_dim columns for node_mlp concat
            hidden_dim = self.node_mlp[0].in_features - x.size(1)
            m_j = torch.zeros(x.size(0), hidden_dim, device=x.device, dtype=x.dtype)
            x = self.node_mlp(torch.cat([x, m_j], dim=1))
            # no coordinate update possible without edges
            return x, pos

        # Normal path (E > 0)
        dist = dist.to(x.dtype)
        edge_attr = edge_attr.to(x.dtype)

        e_input = torch.cat([x[i], x[j], edge_attr, dist], dim=1)

        # Assert expected width matches the layer’s first Linear
        expected_in = x.size(1)*2 + edge_attr.size(1) + 1
        assert self.edge_mlp[0].in_features == expected_in, \
            f"edge_mlp in_features={self.edge_mlp[0].in_features} but got e_input.size(1)={expected_in} " \
            f"(node_feat={x.size(1)}, edge_feat={edge_attr.size(1)})"

        e_ij = self.edge_mlp(e_input)
        e_ij = e_ij.to(x.dtype)

        m_j = torch.zeros(x.size(0), e_ij.size(1), device=x.device, dtype=x.dtype)
        m_j.index_add_(0, i, e_ij)

        x = self.node_mlp(torch.cat([x, m_j], dim=1))

        if self.update_coords:
            w_ij = self.coord_mlp(e_ij).to(rel.dtype)
            delta = rel * w_ij
            coord_update = torch.zeros_like(pos)
            coord_update.index_add_(0, i, delta)
            pos = pos + coord_update / (len(j) + 1e-6)

        return x, pos

Cell 18 — EGNN Encoder (stack + pooling)

In [280]:
class EGNNEncoder(nn.Module):
    def __init__(self, node_dim, edge_dim, hidden_dim, n_layers, embed_dim,
                 update_coords=True, pooling="mean", l2_norm=True):
        super().__init__()
        self.layers = nn.ModuleList([
            EGNNLayer(node_dim if l == 0 else hidden_dim,
                      edge_dim,
                      hidden_dim,
                      update_coords=update_coords)
            for l in range(n_layers)
        ])
        self.proj = nn.Linear(hidden_dim, embed_dim)
        self.pooling = pooling
        self.l2_norm = l2_norm

    def forward(self, data):
        # Parameters dictate compute dtype (fp16 under autocast)
        compute_dtype = next(self.parameters()).dtype 
        # print(f"EGNN forward with compute dtype: {compute_dtype}") -> EGNN forward with compute dtype: torch.float32

        # Node/edge feats to compute dtype; coords stay fp32
        x   = data.x.to(compute_dtype)
        ea  = data.edge_attr.to(compute_dtype)
        pos = data.pos.float()
        ei  = data.edge_index

        for layer in self.layers:
            x, pos = layer(x, pos, ei, ea)

        # Single-graph pooling (we encode one molecule per forward)
        if self.pooling == "mean":
            mol_vec = x.mean(dim=0)
        elif self.pooling == "sum":
            mol_vec = x.sum(dim=0)
        else:
            mol_vec = x.mean(dim=0)

        mol_vec = self.proj(mol_vec)
        if self.l2_norm:
            mol_vec = F.normalize(mol_vec, p=2, dim=-1)
        return mol_vec

Cell 19 — Initialize encoder

In [None]:
import pickle, random
_sample = random.choice(list(GRAPH_CACHE_DIR.glob("*.pkl")))
with open(_sample, "rb") as f:
    g_probe = pickle.load(f)

NODE_DIM = int(g_probe.x.shape[1])          # infer from cache
EDGE_DIM = int(g_probe.edge_attr.shape[1])  # infer from cache
HIDDEN_DIM = int(cfg_get("egnn_encoder.hidden_dim", 256))
EMBED_DIM  = int(cfg_get("egnn_encoder.embed_dim", 256))
N_LAYERS   = int(cfg_get("egnn_encoder.n_layers", 6))
UPDATE_COORDS = bool(cfg_get("egnn_encoder.coord_updates", True))
POOLING    = cfg_get("egnn_encoder.pooling", "mean")
L2NORM     = bool(cfg_get("egnn_encoder.l2_normalize", True))

model = EGNNEncoder(
    node_dim=NODE_DIM,
    edge_dim=EDGE_DIM,
    hidden_dim=HIDDEN_DIM,
    n_layers=N_LAYERS,
    embed_dim=EMBED_DIM,
    update_coords=UPDATE_COORDS,
    pooling=POOLING,
    l2_norm=L2NORM,
).to(DEVICE)

print("Model dtype:", next(model.parameters()).dtype)  # should be torch.float32

# print(model)


Model dtype: torch.float32


Cell 20 — Encoding loop

In [282]:
records = []
model.eval()

with torch.no_grad(), torch.cuda.amp.autocast(dtype=AMP_DTYPE):
    for pkl_path in tqdm(list(GRAPH_CACHE_DIR.glob("*.pkl"))):
        with open(pkl_path, "rb") as f:
            g = pickle.load(f)
        g = g.to(DEVICE)

        emb = model(g).detach().cpu().numpy().astype("float32")
        drug_id = pkl_path.stem
        row = df_drugs.loc[df_drugs["drug_chembl_id"] == drug_id].iloc[0]
        records.append({
            "drug_chembl_id": drug_id,
            "rxcui": row["rxcui"],
            "embedding": emb.tolist(),
        })

  with torch.no_grad(), torch.cuda.amp.autocast(dtype=AMP_DTYPE):
100%|██████████| 1028/1028 [00:04<00:00, 234.50it/s]


Cell 21 — Save to Parquet

In [283]:
df_embed = pd.DataFrame.from_records(records)
df_embed.to_parquet(OUT_PARQUET, engine=cfg_get("io.parquet_engine", "pyarrow"))
print(f"✅ Saved {len(df_embed)} embeddings to {OUT_PARQUET.name}")
display(df_embed.head(3))

✅ Saved 1028 embeddings to EGNN_drug_embeddings.parquet


Unnamed: 0,drug_chembl_id,rxcui,embedding
0,CHEMBL1000,20610,"[-0.0482025183737278, 0.06364382058382034, -0...."
1,CHEMBL1002,237159,"[-0.04840248450636864, 0.06527001410722733, -0..."
2,CHEMBL1004,3642,"[-0.048166919499635696, 0.06434477120637894, -..."


Cell 22 — Quick test: load and shape check

In [284]:
df_check = pd.read_parquet(OUT_PARQUET)
print("Embedding shape check:")
print(df_check["embedding"].iloc[0][:10], " ...")
print("All equal length:", len(set(len(x) for x in df_check.embedding)) == 1)

Embedding shape check:
[-0.04820252  0.06364382 -0.0387178   0.11986481  0.01223391  0.00276065
  0.01033238 -0.03599152 -0.01045839  0.02357433]  ...
All equal length: True


Invariance quick test

In [285]:
from torch.nn.functional import cosine_similarity
with open(random.choice(list(GRAPH_CACHE_DIR.glob("*.pkl"))), "rb") as f:
    g = pickle.load(f)
rot = g.clone()
rot.pos = torch.matmul(g.pos, torch.randn(3,3))  # random rotation
e1 = model(g.to(DEVICE)).detach().cpu()
e2 = model(rot.to(DEVICE)).detach().cpu()
print("cosine sim:", float(cosine_similarity(e1, e2, dim=0)))

cosine sim: 0.9999590516090393
