In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/polypharmacy-dataset/Side_effects_unique.csv
/kaggle/input/polypharmacy-dataset/neg.csv
/kaggle/input/polypharmacy-dataset/Drugbank_ID_SMILE_all_structure links.csv
/kaggle/input/polypharmacy-dataset/pos.csv
/kaggle/input/polypharmacy-dataset/DrugBankID2SMILES.csv


In [3]:
!pip install rdkit torch_geometric

Collecting rdkit
  Downloading rdkit-2025.9.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Downloading rdkit-2025.9.1-cp311-cp311-manylinux_2_28_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m54.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m56.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric, rdkit
Successfully installed rdkit-2025.9.1 torch_geometric-2.6.1


In [4]:
# %% Cell 1 — Imports & device
import os, ast, math, time, json
from pathlib import Path
from itertools import combinations
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# RDKit for fingerprints
from rdkit import Chem
from rdkit.Chem import AllChem

# sklearn utils
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import StandardScaler, RobustScaler, normalize
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, roc_curve, f1_score, confusion_matrix

# torch + pyg
import torch
import torch.nn.functional as F
from torch import nn, optim

# try imports for torch_geometric
try:
    from torch_geometric.data import Data
    from torch_geometric.nn import SAGEConv, global_mean_pool
except Exception as e:
    raise ImportError("torch_geometric not available. Install it before running. Error: " + str(e))

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [5]:
# %% Cell 2 — Paths to files (Kaggle dataset)
DATA_DIR = Path('/kaggle/input/polypharmacy-dataset')
SIDE_EFFECTS_CSV = DATA_DIR / 'Side_effects_unique.csv'
DRUG_SMILES_CSV = DATA_DIR / 'DrugBankID2SMILES.csv'
POS_CSV = DATA_DIR / 'pos.csv'
NEG_CSV = DATA_DIR / 'neg.csv'

for p in [SIDE_EFFECTS_CSV, DRUG_SMILES_CSV, POS_CSV, NEG_CSV]:
    print(p, "exists?", p.exists())


/kaggle/input/polypharmacy-dataset/Side_effects_unique.csv exists? True
/kaggle/input/polypharmacy-dataset/DrugBankID2SMILES.csv exists? True
/kaggle/input/polypharmacy-dataset/pos.csv exists? True
/kaggle/input/polypharmacy-dataset/neg.csv exists? True


In [6]:
# %% Cell 3 — Load side-effect embeddings and normalize
se_df = pd.read_csv(SIDE_EFFECTS_CSV, low_memory=False)
# first two columns are UMLS and name; rest numeric
se_ids = se_df.iloc[:, 0].astype(str).tolist()
se_names = se_df.iloc[:, 1].astype(str).tolist()
se_vectors = se_df.iloc[:, 2:].apply(pd.to_numeric, errors='coerce').fillna(0).values
print("SE matrix shape:", se_vectors.shape)

# Standardize (z-score)
se_scaler = StandardScaler()
se_vectors_scaled = se_scaler.fit_transform(se_vectors)
# store mapping
SE_MAP = {uid: vec for uid, vec in zip(se_ids, se_vectors_scaled)}
SE_NAME_MAP = {uid: name for uid, name in zip(se_ids, se_names)}


SE matrix shape: (7350, 768)


In [7]:
# 🧹 Add this at the very top of your notebook to silence RDKit warnings
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')       # hides all RDKit info/warning/deprecation logs
import warnings
warnings.filterwarnings("ignore")    # hides Python-level warnings too


In [None]:
# %% Cell 4 replacement — fingerprints + RDKit molecular descriptors + cache
from rdkit.Chem import Descriptors
from rdkit import Chem, DataStructs, RDLogger
RDLogger.DisableLog('rdApp.*')

DRUG_SMILES_CSV = Path('/kaggle/input/polypharmacy-dataset/DrugBankID2SMILES.csv')
drug_smiles_df = pd.read_csv(DRUG_SMILES_CSV)
drug_smiles_df['drugbank_id'] = drug_smiles_df['drugbank_id'].astype(str)
drug_smiles_df['smiles'] = drug_smiles_df['smiles'].fillna('').astype(str)

N_BITS = 1024

def compute_descriptors(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.zeros(6, dtype=np.float32)
        return np.array([
            Descriptors.MolWt(mol),
            Descriptors.TPSA(mol),
            Descriptors.MolLogP(mol),
            Descriptors.NumHDonors(mol),
            Descriptors.NumHAcceptors(mol),
            Descriptors.NumRotatableBonds(mol)
        ], dtype=np.float32)
    except Exception:
        return np.zeros(6, dtype=np.float32)

def smiles_to_ecfp_bits_one(smiles, nBits=N_BITS):
    if not isinstance(smiles, str) or smiles.strip()=="":
        return np.zeros(nBits, dtype=np.uint8), True
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return np.zeros(nBits, dtype=np.uint8), True
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=nBits)
        arr = np.zeros((nBits,), dtype=np.uint8)
        DataStructs.ConvertToNumpyArray(fp, arr)
        return arr, False
    except Exception:
        return np.zeros(nBits, dtype=np.uint8), True

# compute/cached
CACHE_FP = Path('/kaggle/working/drug_fp_desc.npz')
if CACHE_FP.exists():
    cache = np.load(CACHE_FP, allow_pickle=True)
    drug_ids = cache['drug_ids'].tolist()
    fps = cache['fps']
    descs = cache['descs']
    miss_flags = cache['miss_flags'].tolist()
else:
    drug_ids = drug_smiles_df['drugbank_id'].tolist()
    smiles_list = drug_smiles_df['smiles'].tolist()
    from joblib import Parallel, delayed
    results = Parallel(n_jobs=min(12, (os.cpu_count() or 1)), backend='loky')(
        delayed(lambda s: (smiles_to_ecfp_bits_one(s), compute_descriptors(s)))(s) for s in smiles_list
    )
    fps = np.stack([r[0][0] for r in results], axis=0)
    miss_flags = [r[0][1] for r in results]
    descs = np.stack([r[1] for r in results], axis=0)
    np.savez_compressed(CACHE_FP, drug_ids=drug_ids, fps=fps, descs=descs, miss_flags=miss_flags)

drug_fp_map = {db: fps[i] for i, db in enumerate(drug_ids)}
drug_desc_map = {db: descs[i] for i, db in enumerate(drug_ids)}
missing_smiles_flag = {db: bool(miss_flags[i]) for i, db in enumerate(drug_ids)}

print("Loaded/cached fingerprints + descriptors for", len(drug_ids))


In [9]:
# %% Cell 5 — Load pos/neg and expand hyperedges to pairwise edges (clique expansion)
pos_df = pd.read_csv(POS_CSV)
neg_df = pd.read_csv(NEG_CSV)

def parse_drug_list(s):
    # safe parse of "['DB0001','DB0002']"
    try:
        return [x.strip().strip("'\" ") for x in ast.literal_eval(s)]
    except Exception:
        # fallback heuristics
        s2 = s.strip().strip('[]')
        return [x.strip().strip("'\" ") for x in s2.split(',') if x.strip()]

def hyperedges_to_pairs(df, label):
    rows = []
    for _, r in df.iterrows():
        drugs = parse_drug_list(r['DrugBankID'])
        time_col = r.get('time', None)
        report_id = r.get('report_id', None)
        se_cui = r.get('SE_above_0.9', None)
        # create all unordered pairs (i<j)
        for a,b in combinations(sorted(set(drugs)), 2):
            rows.append({'drug_a': a, 'drug_b': b, 'label': label, 'time': time_col, 'report_id': report_id, 'se_cui': se_cui})
    return pd.DataFrame(rows)

pos_pairs = hyperedges_to_pairs(pos_df, 1)
neg_pairs = hyperedges_to_pairs(neg_df, 0)
edges_df = pd.concat([pos_pairs, neg_pairs], ignore_index=True).drop_duplicates(subset=['drug_a','drug_b','report_id','time'])
print("Edges shape:", edges_df.shape)
print(edges_df.head())


Edges shape: (2548128, 6)
    drug_a   drug_b  label    time report_id    se_cui
0  DB00273  DB00472      1  2015Q4  11809573  C0151878
1  DB00273  DB00555      1  2015Q4  11809573  C0151878
2  DB00273  DB00557      1  2015Q4  11809573  C0151878
3  DB00273  DB00564      1  2015Q4  11809573  C0151878
4  DB00273  DB01050      1  2015Q4  11809573  C0151878


In [10]:
# %% Cell 6 replacement — build node features and compress
unique_drugs = sorted(set(edges_df['drug_a']).union(set(edges_df['drug_b'])))
print("Unique drugs:", len(unique_drugs))

# build fingerprint + descriptor matrices aligned to unique_drugs
fps_mat = np.stack([drug_fp_map.get(d, np.zeros(N_BITS, dtype=np.uint8)).astype(np.float32) for d in unique_drugs])
desc_mat = np.stack([drug_desc_map.get(d, np.zeros(6, dtype=np.float32)) for d in unique_drugs])

# robust scaling for fps then SVD compress + concat desc (normed)
from sklearn.preprocessing import RobustScaler, StandardScaler
robust = RobustScaler()
fps_scaled = robust.fit_transform(fps_mat)  # (N, N_BITS)

# SVD on fingerprint to reduce to 256 dims
SVD_DIM = 256
from sklearn.decomposition import TruncatedSVD
svd = TruncatedSVD(n_components=SVD_DIM, random_state=42)
fps_svd = svd.fit_transform(fps_scaled)  # (N, 256)

# standardize descriptors and concat
desc_scaler = StandardScaler()
desc_scaled = desc_scaler.fit_transform(desc_mat)  # (N,6)
X_node = np.hstack([fps_svd, desc_scaled])  # final node features (N, 256+6 = 262)
print("Node features shape:", X_node.shape)

# overwrite X_fp_svd used downstream
X_fp_svd = X_node.astype(np.float32)
drug_to_idx = {d: i for i, d in enumerate(unique_drugs)}


Unique drugs: 12298
Node features shape: (12298, 262)


In [11]:
# %% Cell 7 replacement — scale node features after SVD (important)
from sklearn.preprocessing import StandardScaler

# X_fp_svd already computed
scaler_node = StandardScaler()
X_fp_svd = scaler_node.fit_transform(X_fp_svd)   # overwrite with scaled version
print("Node feature shape after SVD & scaling:", X_fp_svd.shape)

drug_to_idx = {d:i for i,d in enumerate(unique_drugs)}

# Save scaler for later use
import joblib
joblib.dump(scaler_node, "/kaggle/working/node_scaler.joblib")


Node feature shape after SVD & scaling: (12298, 262)


['/kaggle/working/node_scaler.joblib']

In [12]:
# %% Cell 8 replacement — rebuild edges and do time-aware / group split (robust)
import ast
from itertools import combinations
from sklearn.model_selection import GroupShuffleSplit, train_test_split
import numpy as np

# --- helper to safely parse drug lists ---
def parse_drug_list(s):
    try:
        return [x.strip().strip("'\" ") for x in ast.literal_eval(s)]
    except Exception:
        s2 = str(s).strip().strip('[]')
        if s2 == '' or s2.lower() == 'nan':
            return []
        return [x.strip().strip("'\" ") for x in s2.split(',') if x.strip()]

# --- (re)create edges_df if needed ---
if 'edges_df' not in globals():
    # read pos/neg if not loaded yet
    if 'pos_df' not in globals():
        pos_df = pd.read_csv(POS_CSV, low_memory=False)
    if 'neg_df' not in globals():
        neg_df = pd.read_csv(NEG_CSV, low_memory=False)

    def hyperedges_to_pairs(df, label):
        rows = []
        for _, r in df.iterrows():
            drugs = parse_drug_list(r.get('DrugBankID', r.get('drug_list', '[]')))
            if len(drugs) < 2:
                continue
            time_col = r.get('time', None)
            report_id = r.get('report_id', None)
            se_cui = r.get('SE_above_0.9', r.get('se_cui', None))
            for a,b in combinations(sorted(set(drugs)), 2):
                rows.append({'drug_a': a, 'drug_b': b, 'label': int(label),
                             'time': time_col, 'report_id': report_id, 'se_cui': se_cui})
        return pd.DataFrame(rows)

    pos_pairs = hyperedges_to_pairs(pos_df, 1)
    neg_pairs = hyperedges_to_pairs(neg_df, 0)
    edges_df = pd.concat([pos_pairs, neg_pairs], ignore_index=True).drop_duplicates(subset=['drug_a','drug_b','report_id','time'])
    print("Rebuilt edges_df shape:", edges_df.shape)
else:
    print("Using existing edges_df with shape:", edges_df.shape)

# --- create edge_labels from edges_df if missing ---
edge_labels = edges_df[['drug_a','drug_b','label','time','report_id','se_cui']].copy()

# --- ensure drug_to_idx and node indexing exist; if not create from present drugs ---
if 'drug_to_idx' not in globals():
    unique_drugs = sorted(set(edge_labels['drug_a']).union(set(edge_labels['drug_b'])))
    drug_to_idx = {d:i for i,d in enumerate(unique_drugs)}
    print("Created drug_to_idx for", len(unique_drugs), "drugs")

# map to integer node ids 'u' and 'v'
edge_labels['u'] = edge_labels['drug_a'].map(drug_to_idx).astype('Int64')
edge_labels['v'] = edge_labels['drug_b'].map(drug_to_idx).astype('Int64')

# drop any rows with missing mapping (if some drug not found)
before = len(edge_labels)
edge_labels = edge_labels.dropna(subset=['u','v']).reset_index(drop=True)
after = len(edge_labels)
if after < before:
    print(f"Dropped {before-after} edges due to missing drug->idx mapping")

# --- Now do time-aware split if time exists, else group split by report_id ---
def time_to_ordinal(t):
    try:
        ts = str(t)
        if 'Q' in ts:
            y,q = ts.split('Q')
            return int(y)*10 + int(q)
        return int(float(ts))
    except Exception:
        return 0

if 'time' in edge_labels.columns and edge_labels['time'].notnull().any():
    edge_labels['time_ord'] = edge_labels['time'].apply(time_to_ordinal)
    times_sorted = sorted(edge_labels['time_ord'].unique())
    cutoff_idx = max(1, int(len(times_sorted)*0.7))
    cutoff_time = times_sorted[min(cutoff_idx, len(times_sorted)-1)]
    train_mask = edge_labels['time_ord'] <= cutoff_time
    rest = ~train_mask
    rest_idx = edge_labels[rest].index
    # stratify only if both classes present in 'rest'
    stratify_param = edge_labels.loc[rest, 'label'] if edge_labels.loc[rest, 'label'].nunique() > 1 else None
    val_idx, test_idx = train_test_split(rest_idx, test_size=0.5, random_state=42, stratify=stratify_param)
    val_mask = edge_labels.index.isin(val_idx)
    test_mask = edge_labels.index.isin(test_idx)
else:
    gss = GroupShuffleSplit(n_splits=1, train_size=0.7, random_state=42)
    groups = edge_labels['report_id'].fillna('_no_report_')
    train_idx, other_idx = next(gss.split(edge_labels, edge_labels['label'], groups=groups))
    # stratify only if both classes present in other_idx
    strat = edge_labels.loc[other_idx, 'label'] if edge_labels.loc[other_idx, 'label'].nunique() > 1 else None
    val_idx, test_idx = train_test_split(other_idx, test_size=0.5, random_state=42, stratify=strat)
    train_mask = edge_labels.index.isin(train_idx)
    val_mask = edge_labels.index.isin(val_idx)
    test_mask = edge_labels.index.isin(test_idx)

edge_labels['split'] = np.where(train_mask, 'train', np.where(val_mask, 'val', 'test'))
print(edge_labels['split'].value_counts())
# expose train/val/test dfs
train_df = edge_labels[edge_labels['split']=='train'].reset_index(drop=True)
val_df = edge_labels[edge_labels['split']=='val'].reset_index(drop=True)
test_df = edge_labels[edge_labels['split']=='test'].reset_index(drop=True)
print("train/val/test sizes:", len(train_df), len(val_df), len(test_df))


Using existing edges_df with shape: (2548128, 6)
split
train    1954270
val       296929
test      296929
Name: count, dtype: int64
train/val/test sizes: 1954270 296929 296929


In [13]:
# %% Cell 9 replacement — build edge_index_t from edge_labels and create PyG Data
import torch
from torch_geometric.utils import remove_self_loops

# ensure edge_labels with 'u' and 'v' exists
assert 'edge_labels' in globals(), "edge_labels missing — run the edge-building cell first."
assert 'X_fp_svd' in globals(), "X_fp_svd missing — run SVD / node feature cell first."

# build undirected adjacency from edge_labels (use all observed edges for message passing)
u_arr = edge_labels['u'].astype(int).to_numpy()
v_arr = edge_labels['v'].astype(int).to_numpy()

# stack both directions
edge_index_np = np.vstack([np.concatenate([u_arr, v_arr]), np.concatenate([v_arr, u_arr])]).astype(np.int64)
# remove self-loops if any
edge_index_tensor = torch.tensor(edge_index_np, dtype=torch.long)
edge_index_tensor, _ = remove_self_loops(edge_index_tensor)

# create Data object
data = Data(x=torch.tensor(X_fp_svd, dtype=torch.float), edge_index=edge_index_tensor).to(device)
edge_index_t = edge_index_tensor  # keep variable name used elsewhere
print("Data created:", data)
print("Num nodes:", data.num_nodes, "Num edges (directed):", data.num_edges)


Data created: Data(x=[12298, 262], edge_index=[2, 5096256])
Num nodes: 12298 Num edges (directed): 5096256


In [19]:
# %% Cell 10 replacement — stronger encoder + classifier (3-layer, larger)
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphSAGEEncoder(nn.Module):
    def __init__(self, in_channels, hidden_channels=512, n_layers=3, dropout=0.2):
        super().__init__()
        self.convs = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        self.bns.append(nn.BatchNorm1d(hidden_channels))
        for _ in range(n_layers-1):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
            self.bns.append(nn.BatchNorm1d(hidden_channels))
        self.act = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
    def forward(self, x, edge_index):
        for conv, bn in zip(self.convs, self.bns):
            x = conv(x, edge_index)
            x = bn(x)
            x = self.act(x)
            x = self.dropout(x)
        return x

class EdgeClassifier(nn.Module):
    def __init__(self, node_emb_dim, hidden=256, dropout=0.2):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(node_emb_dim*2, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, hidden//2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden//2, 1)
        )
    def forward(self, ha, hb):
        h = torch.cat([ha, hb], dim=1)
        return self.mlp(h).squeeze(1)

in_dim = X_fp_svd.shape[1]
encoder = encoder.to(device)
edge_clf = edge_clf.to(device)
data = data.to(device)

encoder = GraphSAGEEncoder(in_dim, hidden_channels=512, n_layers=3, dropout=0.2).to(device)
edge_clf = EdgeClassifier(node_emb_dim=512, hidden=256, dropout=0.2).to(device)

# optimizer + scheduler
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(edge_clf.parameters()), lr=5e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)


In [15]:
# %% Cell 11 replacement — balanced batches with oversampling (more positive presence)
train_df = edge_labels[edge_labels['split']=='train'].reset_index(drop=True)
val_df = edge_labels[edge_labels['split']=='val'].reset_index(drop=True)
test_df = edge_labels[edge_labels['split']=='test'].reset_index(drop=True)

train_df['u'] = train_df['u'].astype(int); train_df['v'] = train_df['v'].astype(int)
val_df['u'] = val_df['u'].astype(int); val_df['v'] = val_df['v'].astype(int)
test_df['u'] = test_df['u'].astype(int); test_df['v'] = test_df['v'].astype(int)

train_pos_idx = train_df[train_df['label']==1].index.to_numpy()
train_neg_idx = train_df[train_df['label']==0].index.to_numpy()

# oversample positives to reach a target positive ratio per batch (e.g., 40%)
def balanced_edge_batches(df, pos_idx, neg_idx, batch_size=4096, pos_ratio=0.4):
    pos_per = int(batch_size * pos_ratio)
    neg_per = batch_size - pos_per
    # create repeated pools
    while True:
        p = np.random.choice(pos_idx, size=max(len(pos_idx), pos_per*2), replace=True)
        n = np.random.choice(neg_idx, size=max(len(neg_idx), neg_per*2), replace=True)
        # iterate in chunks
        for i in range(0, max(len(p), len(n)), max(pos_per, neg_per)):
            pos_chunk = p[i:i+pos_per]
            neg_chunk = n[i:i+neg_per]
            chosen = np.concatenate([pos_chunk, neg_chunk])
            np.random.shuffle(chosen)
            sub = df.loc[chosen]
            u = torch.tensor(sub['u'].values, dtype=torch.long, device=device)
            v = torch.tensor(sub['v'].values, dtype=torch.long, device=device)
            y = torch.tensor(sub['label'].astype(np.float32).values, dtype=torch.float, device=device)
            yield u, v, y


In [38]:
# %% Cell 12 — mini-batch GNN training without torch-sparse/pyg-lib
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
from tqdm import tqdm

EPOCHS = 50
BATCH_SIZE = 2048   # reduce if OOM
best_val_auc = -float('inf')
patience, pat = 6, 0

for ep in range(1, EPOCHS + 1):
    encoder.train(); edge_clf.train()
    torch.cuda.empty_cache()
    epoch_losses = []
    pbar = tqdm(range(0, len(train_df), BATCH_SIZE), desc=f"Epoch {ep}")

    for start in pbar:
        end = min(start + BATCH_SIZE, len(train_df))
        batch = train_df.iloc[start:end]

        # subgraph of drugs in this mini-batch
        nodes = np.unique(np.concatenate([batch['u'].values, batch['v'].values]))
        idx_map = {n: i for i, n in enumerate(nodes)}
        sub_x = data.x[nodes]
        mask = np.isin(edge_index_t[0].cpu(), nodes) & np.isin(edge_index_t[1].cpu(), nodes)
        sub_edge = edge_index_t[:, mask]
        sub_edge = torch.stack([torch.tensor([idx_map[i.item()] for i in sub_edge[0]]),
                                torch.tensor([idx_map[i.item()] for i in sub_edge[1]])]).to(device)

        sub_x = sub_x.to(device)
        node_emb = encoder(sub_x, sub_edge)

        u = torch.tensor([idx_map[i] for i in batch['u'].values], device=device)
        v = torch.tensor([idx_map[i] for i in batch['v'].values], device=device)
        y = torch.tensor(batch['label'].values, dtype=torch.float, device=device)

        optimizer.zero_grad()
        logits = edge_clf(node_emb[u], node_emb[v])
        loss = F.binary_cross_entropy_with_logits(logits, y)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    # ----- validation -----
    encoder.eval(); edge_clf.eval()
    with torch.no_grad():
        val_probs, val_true = [], []
        for start in range(0, len(val_df), BATCH_SIZE):
            end = min(start + BATCH_SIZE, len(val_df))
            b = val_df.iloc[start:end]
            u = torch.tensor(b['u'].values, device=device)
            v = torch.tensor(b['v'].values, device=device)
            logits = edge_clf(encoder(data.x.to(device), data.edge_index)[u],
                              encoder(data.x.to(device), data.edge_index)[v])
            probs = torch.sigmoid(logits).cpu().numpy()
            val_probs.extend(probs); val_true.extend(b['label'].values)

        val_auc = roc_auc_score(val_true, val_probs)
        val_ap = average_precision_score(val_true, val_probs)
        val_f1 = f1_score(val_true, (np.array(val_probs) >= 0.5).astype(int))
        print(f"Epoch {ep:02d} | loss {np.mean(epoch_losses):.4f} | val_auc {val_auc:.4f} | val_ap {val_ap:.4f} | val_f1 {val_f1:.4f}")

    # early stopping
    if val_auc > best_val_auc + 1e-4:
        best_val_auc, pat = val_auc, 0
        torch.save({'encoder': encoder.state_dict(), 'edge_clf': edge_clf.state_dict()},
                   "/kaggle/working/gnn_best_simple.pth")
    else:
        pat += 1
        if pat >= patience:
            print("Early stopping")
            break

ckp = torch.load("/kaggle/working/gnn_best_simple.pth", map_location=device)
encoder.load_state_dict(ckp['encoder']); edge_clf.load_state_dict(ckp['edge_clf'])
print("Loaded best model, best_val_auc=", best_val_auc)


Epoch 1: 100%|██████████| 4/4 [00:10<00:00,  2.67s/it]


Epoch 01 | loss 0.6711 | val_auc 0.4968 | val_ap 0.5031 | val_f1 0.4970


Epoch 2: 100%|██████████| 4/4 [00:10<00:00,  2.69s/it]


Epoch 02 | loss 0.6667 | val_auc 0.4983 | val_ap 0.5039 | val_f1 0.4992


Epoch 3: 100%|██████████| 4/4 [00:10<00:00,  2.74s/it]


Epoch 03 | loss 0.6617 | val_auc 0.4992 | val_ap 0.5047 | val_f1 0.5103


Epoch 4: 100%|██████████| 4/4 [00:10<00:00,  2.73s/it]


Epoch 04 | loss 0.6560 | val_auc 0.4993 | val_ap 0.5049 | val_f1 0.5152


Epoch 5: 100%|██████████| 4/4 [00:11<00:00,  2.77s/it]


Epoch 05 | loss 0.6494 | val_auc 0.4991 | val_ap 0.5046 | val_f1 0.5153


Epoch 6: 100%|██████████| 4/4 [00:10<00:00,  2.66s/it]


Epoch 06 | loss 0.6418 | val_auc 0.4988 | val_ap 0.5042 | val_f1 0.5152


Epoch 7: 100%|██████████| 4/4 [00:11<00:00,  2.80s/it]


Epoch 07 | loss 0.6337 | val_auc 0.4990 | val_ap 0.5044 | val_f1 0.5191


Epoch 8: 100%|██████████| 4/4 [00:10<00:00,  2.68s/it]


Epoch 08 | loss 0.6257 | val_auc 0.4999 | val_ap 0.5040 | val_f1 0.5068


Epoch 9: 100%|██████████| 4/4 [00:10<00:00,  2.72s/it]


Epoch 09 | loss 0.6195 | val_auc 0.5019 | val_ap 0.5026 | val_f1 0.4097


Epoch 10: 100%|██████████| 4/4 [00:10<00:00,  2.68s/it]


Epoch 10 | loss 0.6232 | val_auc 0.5009 | val_ap 0.5006 | val_f1 0.4548


Epoch 11: 100%|██████████| 4/4 [00:10<00:00,  2.73s/it]


Epoch 11 | loss 0.6177 | val_auc 0.4991 | val_ap 0.4994 | val_f1 0.5664


Epoch 12: 100%|██████████| 4/4 [00:10<00:00,  2.70s/it]


Epoch 12 | loss 0.6042 | val_auc 0.5024 | val_ap 0.5010 | val_f1 0.4619


Epoch 13: 100%|██████████| 4/4 [00:10<00:00,  2.70s/it]


Epoch 13 | loss 0.5978 | val_auc 0.5013 | val_ap 0.5010 | val_f1 0.5496


Epoch 14: 100%|██████████| 4/4 [00:10<00:00,  2.72s/it]


Epoch 14 | loss 0.5889 | val_auc 0.5030 | val_ap 0.5014 | val_f1 0.4825


Epoch 15: 100%|██████████| 4/4 [00:10<00:00,  2.74s/it]


Epoch 15 | loss 0.5816 | val_auc 0.5017 | val_ap 0.5005 | val_f1 0.5412


Epoch 16: 100%|██████████| 4/4 [00:11<00:00,  2.76s/it]


Epoch 16 | loss 0.5745 | val_auc 0.5031 | val_ap 0.5000 | val_f1 0.4719


Epoch 17: 100%|██████████| 4/4 [00:10<00:00,  2.64s/it]


Epoch 17 | loss 0.5670 | val_auc 0.5024 | val_ap 0.4987 | val_f1 0.5277


Epoch 18: 100%|██████████| 4/4 [00:11<00:00,  2.79s/it]


Epoch 18 | loss 0.5593 | val_auc 0.5035 | val_ap 0.4980 | val_f1 0.5020


Epoch 19: 100%|██████████| 4/4 [00:10<00:00,  2.66s/it]


Epoch 19 | loss 0.5519 | val_auc 0.5042 | val_ap 0.4967 | val_f1 0.4926


Epoch 20: 100%|██████████| 4/4 [00:10<00:00,  2.71s/it]


Epoch 20 | loss 0.5451 | val_auc 0.5040 | val_ap 0.4954 | val_f1 0.5109


Epoch 21: 100%|██████████| 4/4 [00:10<00:00,  2.69s/it]


Epoch 21 | loss 0.5373 | val_auc 0.5038 | val_ap 0.4952 | val_f1 0.5282


Epoch 22: 100%|██████████| 4/4 [00:10<00:00,  2.73s/it]


Epoch 22 | loss 0.5305 | val_auc 0.5037 | val_ap 0.4949 | val_f1 0.5287


Epoch 23: 100%|██████████| 4/4 [00:10<00:00,  2.71s/it]


Epoch 23 | loss 0.5238 | val_auc 0.5040 | val_ap 0.4941 | val_f1 0.4970


Epoch 24: 100%|██████████| 4/4 [00:10<00:00,  2.70s/it]


Epoch 24 | loss 0.5188 | val_auc 0.5054 | val_ap 0.4935 | val_f1 0.4520


Epoch 25: 100%|██████████| 4/4 [00:10<00:00,  2.71s/it]


Epoch 25 | loss 0.5208 | val_auc 0.5053 | val_ap 0.4926 | val_f1 0.4630


Epoch 26: 100%|██████████| 4/4 [00:10<00:00,  2.71s/it]


Epoch 26 | loss 0.5237 | val_auc 0.5024 | val_ap 0.4909 | val_f1 0.5695


Epoch 27: 100%|██████████| 4/4 [00:11<00:00,  2.77s/it]


Epoch 27 | loss 0.5266 | val_auc 0.5042 | val_ap 0.4930 | val_f1 0.5099


Epoch 28: 100%|██████████| 4/4 [00:10<00:00,  2.63s/it]


Epoch 28 | loss 0.4988 | val_auc 0.5046 | val_ap 0.4931 | val_f1 0.4967


Epoch 29: 100%|██████████| 4/4 [00:11<00:00,  2.80s/it]


Epoch 29 | loss 0.4906 | val_auc 0.5035 | val_ap 0.4919 | val_f1 0.5339


Epoch 30: 100%|██████████| 4/4 [00:10<00:00,  2.69s/it]

Epoch 30 | loss 0.4850 | val_auc 0.5049 | val_ap 0.4923 | val_f1 0.4668
Early stopping
Loaded best model, best_val_auc= 0.5054440670371392





In [None]:
!pip install torch-sparse -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html
!pip install torch-scatter -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html
!pip install torch-geometric


In [37]:
# add right before for-epoch loop
import os
os.environ["OMP_NUM_THREADS"] = "1"
torch.set_num_threads(1)


In [None]:
# Improved training cell for higher accuracy
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score
from tqdm import tqdm

EPOCHS = 80
LR = 1e-3
optimizer = torch.optim.AdamW(list(encoder.parameters()) + list(edge_clf.parameters()), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
best_auc, patience, counter = 0, 8, 0

for ep in range(1, EPOCHS + 1):
    encoder.train(); edge_clf.train()
    total_loss = 0
    for u_batch, v_batch, y_batch in balanced_edge_batches(train_df, train_pos_idx, train_neg_idx, batch_size=4096, pos_ratio=0.5):
        optimizer.zero_grad()
        ha, hb = encoder(data.x.to(device), data.edge_index)[u_batch], encoder(data.x.to(device), data.edge_index)[v_batch]
        logits = edge_clf(ha, hb)
        loss = F.binary_cross_entropy_with_logits(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    scheduler.step()
    
    # --- Validation ---
    encoder.eval(); edge_clf.eval()
    with torch.no_grad():
        emb = encoder(data.x.to(device), data.edge_index)
        vu, vv = torch.tensor(val_df['u'].values, device=device), torch.tensor(val_df['v'].values, device=device)
        vy = torch.tensor(val_df['label'].values, dtype=torch.float, device=device)
        val_probs = torch.sigmoid(edge_clf(emb[vu], emb[vv])).cpu().numpy()
        val_auc = roc_auc_score(vy.cpu().numpy(), val_probs)
        val_ap = average_precision_score(vy.cpu().numpy(), val_probs)
        val_f1 = f1_score(vy.cpu().numpy(), (val_probs >= 0.5).astype(int))
    
    print(f"Epoch {ep:03d} | loss {total_loss:.4f} | val_auc {val_auc:.4f} | val_ap {val_ap:.4f} | val_f1 {val_f1:.4f}")

    if val_auc > best_auc:
        best_auc = val_auc
        torch.save({'encoder': encoder.state_dict(), 'edge_clf': edge_clf.state_dict()}, "/kaggle/working/gnn_best_strong.pth")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping — no improvement.")
            break

print(f"✅ Training complete. Best validation AUC = {best_auc:.4f}")


In [None]:
# %% Cell — Final Test / Evaluation / Prediction
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score, classification_report
import numpy as np, torch

# 🔹 Load best trained model
ckp = torch.load("/kaggle/working/gnn_best_simple.pth", map_location=device)
encoder.load_state_dict(ckp["encoder"])
edge_clf.load_state_dict(ckp["edge_clf"])
encoder.eval(); edge_clf.eval()

# 🔹 Compute node embeddings once
with torch.no_grad():
    node_emb = encoder(data.x.to(device), data.edge_index).cpu()

# 🔹 Predict probabilities for test set
def predict_gnn(df):
    probs = []
    for _, r in df.iterrows():
        ua, vb = int(r["u"]), int(r["v"])
        ha, hb = node_emb[ua], node_emb[vb]
        p = torch.sigmoid(edge_clf(ha.to(device), hb.to(device))).cpu().item()
        probs.append(p)
    return np.array(probs)

print("Running predictions on test data...")
test_probs = predict_gnn(test_df)
test_labels = test_df["label"].values
test_preds = (test_probs >= 0.5).astype(int)

# 🔹 Evaluate
test_auc = roc_auc_score(test_labels, test_probs)
test_ap = average_precision_score(test_labels, test_probs)
test_f1 = f1_score(test_labels, test_preds)

print(f"✅ Test AUC: {test_auc:.4f} | AP: {test_ap:.4f} | F1: {test_f1:.4f}")
print(classification_report(test_labels, test_preds))

# 🔹 Example prediction for custom input
sample_input = {"report_id": "runtime-12345", "drugs": ["DB00006", "DB00341", "DB01118"]}
pairs = [(sample_input["drugs"][i], sample_input["drugs"][j])
         for i in range(len(sample_input["drugs"]))
         for j in range(i+1, len(sample_input["drugs"]))]

print("\nPredicted ADR severity (0=Low,1=High):")
for a,b in pairs:
    if a in drug_to_idx and b in drug_to_idx:
        pa = drug_to_idx[a]; pb = drug_to_idx[b]
        p = torch.sigmoid(edge_clf(node_emb[pa].to(device), node_emb[pb].to(device))).cpu().item()
        print(f"{a} + {b} → severity score: {p:.3f}")
    else:
        print(f"{a} + {b} → unknown drug ID")
