In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily enumerate every file below the read-only Kaggle input directory,
# then echo each full path on its own line.
input_file_paths = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_file_path in input_file_paths:
    print(input_file_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/polypharmacy-dataset/Side_effects_unique.csv
/kaggle/input/polypharmacy-dataset/neg.csv
/kaggle/input/polypharmacy-dataset/Drugbank_ID_SMILE_all_structure links.csv
/kaggle/input/polypharmacy-dataset/pos.csv
/kaggle/input/polypharmacy-dataset/DrugBankID2SMILES.csv


In [2]:
!pip install rdkit torch_geometric

Collecting rdkit
  Downloading rdkit-2025.9.1-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.1 kB)
Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
Downloading rdkit-2025.9.1-cp311-cp311-manylinux_2_28_x86_64.whl (36.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m36.2/36.2 MB[0m [31m56.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[?25hDownloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m57.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric, rdkit
Successfully installed rdkit-2025.9.1 torch_geometric-2.6.1


In [3]:
# %% Cell 1 — Imports & device
import os, ast, math, time, json
from pathlib import Path
from itertools import combinations
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# RDKit for fingerprints
from rdkit import Chem
from rdkit.Chem import AllChem

# sklearn utils
from sklearn.decomposition import TruncatedSVD
from sklearn.preprocessing import StandardScaler, RobustScaler, normalize
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve, roc_curve, f1_score, confusion_matrix

# torch + pyg
import torch
import torch.nn.functional as F
from torch import nn, optim

# torch_geometric is installed at runtime by the pip cell above, so fail early with an
# actionable message if it is missing or cannot be loaded.
try:
    from torch_geometric.data import Data
    from torch_geometric.nn import SAGEConv, global_mean_pool
except Exception as e:
    # Chain the original exception so its traceback stays visible to the reader.
    raise ImportError("torch_geometric not available. Install it before running. Error: " + str(e)) from e

# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)


Device: cuda


In [4]:
# %% Cell 2 — Paths to files (Kaggle dataset)
DATA_DIR = Path('/kaggle/input/polypharmacy-dataset')

# Every input table sits directly inside DATA_DIR.
SIDE_EFFECTS_CSV = DATA_DIR.joinpath('Side_effects_unique.csv')
DRUG_SMILES_CSV = DATA_DIR.joinpath('DrugBankID2SMILES.csv')
POS_CSV = DATA_DIR.joinpath('pos.csv')
NEG_CSV = DATA_DIR.joinpath('neg.csv')

# Report, one line per file, whether each expected input is present.
for p in (SIDE_EFFECTS_CSV, DRUG_SMILES_CSV, POS_CSV, NEG_CSV):
    print(p, "exists?", p.exists())


/kaggle/input/polypharmacy-dataset/Side_effects_unique.csv exists? True
/kaggle/input/polypharmacy-dataset/DrugBankID2SMILES.csv exists? True
/kaggle/input/polypharmacy-dataset/pos.csv exists? True
/kaggle/input/polypharmacy-dataset/neg.csv exists? True


In [5]:
# %% Cell 3 — Load side-effect embeddings and normalize
se_df = pd.read_csv(SIDE_EFFECTS_CSV, low_memory=False)

# Column 0 holds the UMLS id and column 1 the name; every later column is numeric.
se_ids = list(se_df.iloc[:, 0].astype(str))
se_names = list(se_df.iloc[:, 1].astype(str))

# Cells that cannot be parsed as numbers become NaN and are then zero-filled.
se_vectors = (
    se_df.iloc[:, 2:]
    .apply(pd.to_numeric, errors='coerce')
    .fillna(0)
    .values
)
print("SE matrix shape:", se_vectors.shape)

# Standardize (z-score) each embedding dimension.
se_scaler = StandardScaler()
se_vectors_scaled = se_scaler.fit_transform(se_vectors)

# Lookups keyed by UMLS id: scaled vector and readable name.
SE_MAP = dict(zip(se_ids, se_vectors_scaled))
SE_NAME_MAP = dict(zip(se_ids, se_names))


SE matrix shape: (7350, 768)


In [6]:
# Silence RDKit logging and Python-level warnings for the rest of the session.
# NOTE(review): the original note asked for this to sit at the very top of the notebook,
# but in notebook order it runs after the RDKit/sklearn/torch imports — consider moving it up.
# NOTE(review): the blanket "ignore" filter also hides deprecation warnings from every
# library used below; confirm that is intended.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')       # hides all RDKit info/warning/deprecation logs
import warnings
warnings.filterwarnings("ignore")    # hides Python-level warnings too


In [None]:
# %% Cell 4 replacement — fingerprints + RDKit molecular descriptors + cache
from rdkit.Chem import Descriptors
from rdkit import Chem, DataStructs, RDLogger
RDLogger.DisableLog('rdApp.*')

DRUG_SMILES_CSV = Path('/kaggle/input/polypharmacy-dataset/DrugBankID2SMILES.csv')

# Read the id -> SMILES table; ids are forced to text and missing SMILES become "".
drug_smiles_df = pd.read_csv(DRUG_SMILES_CSV).assign(
    drugbank_id=lambda frame: frame['drugbank_id'].astype(str),
    smiles=lambda frame: frame['smiles'].fillna('').astype(str),
)

# Length of the Morgan fingerprint bit vector.
N_BITS = 1024

def compute_descriptors(smiles):
    """Return six RDKit descriptors for a SMILES string as a float32 vector.

    The entries are, in order: molecular weight, TPSA, logP, H-bond donors,
    H-bond acceptors and rotatable bonds. A zero vector of length 6 is returned
    when RDKit cannot parse the SMILES or any descriptor call raises.
    """
    blank = np.zeros(6, dtype=np.float32)
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return blank
        # Order matters: downstream code treats the vector positionally.
        descriptor_fns = (
            Descriptors.MolWt,
            Descriptors.TPSA,
            Descriptors.MolLogP,
            Descriptors.NumHDonors,
            Descriptors.NumHAcceptors,
            Descriptors.NumRotatableBonds,
        )
        return np.array([fn(mol) for fn in descriptor_fns], dtype=np.float32)
    except Exception:
        return blank

def smiles_to_ecfp_bits_one(smiles, nBits=N_BITS):
    """Convert one SMILES string into a radius-2 Morgan fingerprint bit array.

    Returns a tuple ``(bits, missing)`` where ``bits`` is a uint8 array of length
    ``nBits`` and ``missing`` is True when the input was not a non-blank string,
    could not be parsed, or fingerprinting raised (in which case ``bits`` is all zero).
    """
    def _failed():
        # Placeholder result: all-zero fingerprint flagged as missing.
        return np.zeros(nBits, dtype=np.uint8), True

    usable = isinstance(smiles, str) and smiles.strip() != ""
    if not usable:
        return _failed()
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            bitvect = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=nBits)
            bits = np.zeros((nBits,), dtype=np.uint8)
            # Fills `bits` in place from the RDKit bit vector.
            DataStructs.ConvertToNumpyArray(bitvect, bits)
            return bits, False
        return _failed()
    except Exception:
        return _failed()

# compute/cached
CACHE_FP = Path('/kaggle/working/drug_fp_desc.npz')


def _load_feature_cache(path, expected_ids, n_bits):
    """Load cached drug features, or return None when the cache cannot be trusted.

    Parameters
    ----------
    path : Path
        Location of the ``.npz`` archive written by this cell.
    expected_ids : list of str
        Drug ids, in order, that the cached rows must correspond to.
    n_bits : int
        Expected fingerprint width.

    Returns
    -------
    tuple or None
        ``(fps, descs, miss_flags)`` when the archive exists, is readable, was built
        for exactly ``expected_ids`` and has the expected shapes; otherwise None so
        the caller recomputes.
    """
    if not path.exists():
        return None
    try:
        # allow_pickle=False: every array saved below is a plain string/bool/numeric
        # array, so nothing needs unpickling and a tampered file cannot execute code.
        with np.load(path, allow_pickle=False) as cached:
            if cached['drug_ids'].tolist() != list(expected_ids):
                return None  # cache was built for a different drug list
            cached_fps = cached['fps']
            cached_descs = cached['descs']
            cached_flags = cached['miss_flags'].tolist()
    except Exception:
        # Best-effort cache: any read failure simply triggers a recompute.
        return None
    if cached_fps.shape != (len(expected_ids), n_bits) or len(cached_descs) != len(expected_ids):
        return None
    return cached_fps, cached_descs, cached_flags


drug_ids = drug_smiles_df['drugbank_id'].tolist()
cached_features = _load_feature_cache(CACHE_FP, drug_ids, N_BITS)
if cached_features is not None:
    fps, descs, miss_flags = cached_features
else:
    smiles_list = drug_smiles_df['smiles'].tolist()
    from joblib import Parallel, delayed
    # One task per SMILES: ((fingerprint, missing_flag), descriptor_vector).
    results = Parallel(n_jobs=min(12, (os.cpu_count() or 1)), backend='loky')(
        delayed(lambda s: (smiles_to_ecfp_bits_one(s), compute_descriptors(s)))(s) for s in smiles_list
    )
    fps = np.stack([r[0][0] for r in results], axis=0)
    miss_flags = [r[0][1] for r in results]
    descs = np.stack([r[1] for r in results], axis=0)
    np.savez_compressed(CACHE_FP, drug_ids=drug_ids, fps=fps, descs=descs, miss_flags=miss_flags)

# Per-drug lookups keyed by DrugBank id (row i of each array belongs to drug_ids[i]).
drug_fp_map = {db: fps[i] for i, db in enumerate(drug_ids)}
drug_desc_map = {db: descs[i] for i, db in enumerate(drug_ids)}
missing_smiles_flag = {db: bool(miss_flags[i]) for i, db in enumerate(drug_ids)}

print("Loaded/cached fingerprints + descriptors for", len(drug_ids))


In [8]:
# %% Cell 5 — Load pos/neg and expand hyperedges to pairwise edges (clique expansion)
# Positive and negative report tables are read with identical default options.
pos_df, neg_df = (pd.read_csv(csv_path) for csv_path in (POS_CSV, NEG_CSV))

def parse_drug_list(s):
    """Parse a stringified list of DrugBank ids such as "['DB0001','DB0002']".

    Parameters
    ----------
    s : str or any
        Cell value from the ``DrugBankID`` column. Missing values (None / NaN)
        and blank strings are accepted.

    Returns
    -------
    list of str
        The ids with surrounding whitespace and quotes removed; an empty list
        when the cell is missing or blank.
    """
    # Missing cells arrive as None or float NaN rather than text (NaN != NaN).
    if s is None or (isinstance(s, float) and s != s):
        return []
    try:
        parsed = ast.literal_eval(s)
        # A single quoted id evaluates to a str; wrap it so it is not split into characters.
        if isinstance(parsed, str):
            parsed = [parsed]
        return [x.strip().strip("'\" ") for x in parsed]
    except Exception:
        # fallback heuristics: treat the text as a comma-separated list, with or without brackets
        s2 = str(s).strip().strip('[]')
        if s2 == '' or s2.lower() == 'nan':
            return []
        return [x.strip().strip("'\" ") for x in s2.split(',') if x.strip()]

def hyperedges_to_pairs(df, label):
    """Expand each report (a set of drugs) into all unordered drug pairs.

    Every row of ``df`` contributes one output record per pair of distinct drugs
    listed in its ``DrugBankID`` cell. Each record carries the given ``label`` plus
    the row's ``time``, ``report_id`` and ``SE_above_0.9`` values (None when the
    column is absent). Returns the records as a DataFrame.
    """
    pair_records = []
    for _, report in df.iterrows():
        # Deduplicate and sort so every pair is emitted once with drug_a < drug_b.
        members = sorted(set(parse_drug_list(report['DrugBankID'])))
        shared = {
            'label': label,
            'time': report.get('time', None),
            'report_id': report.get('report_id', None),
            'se_cui': report.get('SE_above_0.9', None),
        }
        pair_records.extend(
            {'drug_a': first, 'drug_b': second, **shared}
            for first, second in combinations(members, 2)
        )
    return pd.DataFrame(pair_records)

# Clique-expand both tables: label 1 for positive reports, 0 for negative ones.
pos_pairs = hyperedges_to_pairs(pos_df, 1)
neg_pairs = hyperedges_to_pairs(neg_df, 0)

# Stack positives before negatives, then keep the first row for each
# (pair, report, time) key.
edge_key = ['drug_a','drug_b','report_id','time']
all_pairs = pd.concat([pos_pairs, neg_pairs], ignore_index=True)
edges_df = all_pairs.drop_duplicates(subset=edge_key)
print("Edges shape:", edges_df.shape)
print(edges_df.head())


Edges shape: (2548128, 6)
    drug_a   drug_b  label    time report_id    se_cui
0  DB00273  DB00472      1  2015Q4  11809573  C0151878
1  DB00273  DB00555      1  2015Q4  11809573  C0151878
2  DB00273  DB00557      1  2015Q4  11809573  C0151878
3  DB00273  DB00564      1  2015Q4  11809573  C0151878
4  DB00273  DB01050      1  2015Q4  11809573  C0151878


In [9]:
# %% Cell 6 replacement — build node features and compress
from sklearn.preprocessing import RobustScaler, StandardScaler
from sklearn.decomposition import TruncatedSVD

# Node set = every drug that appears at either end of an edge, in sorted order.
endpoint_drugs = set(edges_df['drug_a']) | set(edges_df['drug_b'])
unique_drugs = sorted(endpoint_drugs)
print("Unique drugs:", len(unique_drugs))

# Fingerprint / descriptor matrices with one row per entry of unique_drugs;
# drugs without cached features get all-zero rows.
fps_mat = np.stack([
    drug_fp_map[d].astype(np.float32) if d in drug_fp_map else np.zeros(N_BITS, dtype=np.float32)
    for d in unique_drugs
])
desc_mat = np.stack([
    drug_desc_map[d] if d in drug_desc_map else np.zeros(6, dtype=np.float32)
    for d in unique_drugs
])

# Robust-scale the fingerprint bits, then compress them with truncated SVD.
robust = RobustScaler()
fps_scaled = robust.fit_transform(fps_mat)  # (N, N_BITS)

SVD_DIM = 256
svd = TruncatedSVD(n_components=SVD_DIM, random_state=42)
fps_svd = svd.fit_transform(fps_scaled)  # (N, SVD_DIM)

# Z-score the descriptors and append them to the compressed fingerprints.
desc_scaler = StandardScaler()
desc_scaled = desc_scaler.fit_transform(desc_mat)  # (N, 6)
X_node = np.concatenate([fps_svd, desc_scaled], axis=1)  # (N, SVD_DIM + 6)
print("Node features shape:", X_node.shape)

# Downstream cells read the node features under the name X_fp_svd.
X_fp_svd = X_node.astype(np.float32)
drug_to_idx = dict(zip(unique_drugs, range(len(unique_drugs))))


Unique drugs: 12298
Node features shape: (12298, 262)


In [10]:
# %% Cell 7 replacement — scale node features after SVD (important)
from sklearn.preprocessing import StandardScaler

# Standardize the concatenated node features. The scaler is fitted on X_node (the
# unscaled matrix from the previous cell) rather than on X_fp_svd itself, so
# re-running this cell always starts from the same input instead of re-scaling
# its own output.
scaler_node = StandardScaler()
X_fp_svd = scaler_node.fit_transform(X_node.astype(np.float32))
print("Node feature shape after SVD & scaling:", X_fp_svd.shape)

# Row i of X_fp_svd belongs to unique_drugs[i].
drug_to_idx = {d:i for i,d in enumerate(unique_drugs)}

# Save scaler for later use
import joblib
joblib.dump(scaler_node, "/kaggle/working/node_scaler.joblib")


Node feature shape after SVD & scaling: (12298, 262)


['/kaggle/working/node_scaler.joblib']

In [11]:
# %% Cell 8 replacement — rebuild edges and do time-aware / group split (robust)
import ast
from itertools import combinations
from sklearn.model_selection import GroupShuffleSplit, train_test_split
import numpy as np

# --- helper to safely parse drug lists ---
def parse_drug_list(s):
    """Turn a stringified drug list into a list of cleaned id strings.

    First tries to evaluate ``s`` as a Python literal; if that fails for any
    reason, falls back to splitting the text on commas after removing the
    surrounding brackets. Blank input and the text "nan" yield an empty list.
    """
    clean = lambda token: token.strip().strip("'\" ")
    try:
        return list(map(clean, ast.literal_eval(s)))
    except Exception:
        pass
    # Fallback: comma-separated text, with or without brackets.
    body = str(s).strip().strip('[]')
    if body.lower() in ('', 'nan'):
        return []
    return [clean(token) for token in body.split(',') if token.strip()]

# --- (re)create edges_df if needed ---
# Only runs when no `edges_df` exists in the kernel namespace; otherwise the frame
# built by an earlier cell is reused as-is.
# NOTE(review): the branch taken depends on kernel state via globals(), so the result of
# this cell is only reproducible when the notebook is run top to bottom.
if 'edges_df' not in globals():
    # read pos/neg if not loaded yet
    if 'pos_df' not in globals():
        pos_df = pd.read_csv(POS_CSV, low_memory=False)
    if 'neg_df' not in globals():
        neg_df = pd.read_csv(NEG_CSV, low_memory=False)

    def hyperedges_to_pairs(df, label):
        """Expand each row's drug list into all unordered pairs, tagged with `label`.

        The drug list is read from 'DrugBankID' (falling back to 'drug_list'), and the
        side-effect id from 'SE_above_0.9' (falling back to 'se_cui'). Rows with fewer
        than two drugs are skipped. Returns a DataFrame with one row per pair.
        """
        rows = []
        for _, r in df.iterrows():
            drugs = parse_drug_list(r.get('DrugBankID', r.get('drug_list', '[]')))
            if len(drugs) < 2:
                continue
            time_col = r.get('time', None)
            report_id = r.get('report_id', None)
            se_cui = r.get('SE_above_0.9', r.get('se_cui', None))
            # sorted(set(...)) removes repeats and fixes the order so drug_a < drug_b
            for a,b in combinations(sorted(set(drugs)), 2):
                rows.append({'drug_a': a, 'drug_b': b, 'label': int(label),
                             'time': time_col, 'report_id': report_id, 'se_cui': se_cui})
        return pd.DataFrame(rows)

    pos_pairs = hyperedges_to_pairs(pos_df, 1)
    neg_pairs = hyperedges_to_pairs(neg_df, 0)
    # positives are concatenated first, so they win when a (pair, report, time) key repeats
    edges_df = pd.concat([pos_pairs, neg_pairs], ignore_index=True).drop_duplicates(subset=['drug_a','drug_b','report_id','time'])
    print("Rebuilt edges_df shape:", edges_df.shape)
else:
    print("Using existing edges_df with shape:", edges_df.shape)

# --- working copy of the edge table, restricted to the columns used below ---
label_columns = ['drug_a','drug_b','label','time','report_id','se_cui']
edge_labels = edges_df.loc[:, label_columns].copy()

# --- build the drug -> node index lookup only when no earlier cell provided one ---
if 'drug_to_idx' not in globals():
    unique_drugs = sorted(set(edge_labels['drug_a']) | set(edge_labels['drug_b']))
    drug_to_idx = dict(zip(unique_drugs, range(len(unique_drugs))))
    print("Created drug_to_idx for", len(unique_drugs), "drugs")

# integer node ids 'u' and 'v'; nullable Int64 so unmapped drugs show up as <NA>
for node_col, drug_col in (('u', 'drug_a'), ('v', 'drug_b')):
    edge_labels[node_col] = edge_labels[drug_col].map(drug_to_idx).astype('Int64')

# discard edges whose endpoints could not be mapped, and report how many were lost
before = len(edge_labels)
edge_labels = edge_labels.dropna(subset=['u','v']).reset_index(drop=True)
after = len(edge_labels)
if before > after:
    print(f"Dropped {before-after} edges due to missing drug->idx mapping")

# --- Now do time-aware split if time exists, else group split by report_id ---
def time_to_ordinal(t):
    """Map a reporting period to a sortable integer.

    Values containing "Q" (e.g. "2015Q4") become ``year * 10 + quarter``; anything
    else is parsed as a number and truncated to int. Any value that cannot be
    parsed maps to 0.
    """
    try:
        text = str(t)
        if 'Q' not in text:
            return int(float(text))
        year, quarter = text.split('Q')
        return 10 * int(year) + int(quarter)
    except Exception:
        return 0

# Time-aware branch: used when a 'time' column exists with at least one non-null value.
if 'time' in edge_labels.columns and edge_labels['time'].notnull().any():
    # unparseable times map to 0 (see time_to_ordinal) and therefore land in train
    edge_labels['time_ord'] = edge_labels['time'].apply(time_to_ordinal)
    times_sorted = sorted(edge_labels['time_ord'].unique())
    # cutoff is the period at 70% of the DISTINCT periods, not 70% of the rows
    cutoff_idx = max(1, int(len(times_sorted)*0.7))
    cutoff_time = times_sorted[min(cutoff_idx, len(times_sorted)-1)]
    # everything up to and including the cutoff period is training data
    train_mask = edge_labels['time_ord'] <= cutoff_time
    rest = ~train_mask
    rest_idx = edge_labels[rest].index
    # stratify only if both classes present in 'rest'
    stratify_param = edge_labels.loc[rest, 'label'] if edge_labels.loc[rest, 'label'].nunique() > 1 else None
    # NOTE(review): the later rows are split 50/50 at row level, so pairs expanded from the
    # same report_id can end up in both val and test — confirm this is acceptable.
    val_idx, test_idx = train_test_split(rest_idx, test_size=0.5, random_state=42, stratify=stratify_param)
    val_mask = edge_labels.index.isin(val_idx)
    test_mask = edge_labels.index.isin(test_idx)
else:
    # Fallback: hold out whole reports (grouped by report_id) for val/test.
    gss = GroupShuffleSplit(n_splits=1, train_size=0.7, random_state=42)
    groups = edge_labels['report_id'].fillna('_no_report_')
    # split() yields positional indices; they equal index labels here because
    # edge_labels was reset_index'ed above
    train_idx, other_idx = next(gss.split(edge_labels, edge_labels['label'], groups=groups))
    # stratify only if both classes present in other_idx
    strat = edge_labels.loc[other_idx, 'label'] if edge_labels.loc[other_idx, 'label'].nunique() > 1 else None
    val_idx, test_idx = train_test_split(other_idx, test_size=0.5, random_state=42, stratify=strat)
    train_mask = edge_labels.index.isin(train_idx)
    val_mask = edge_labels.index.isin(val_idx)
    test_mask = edge_labels.index.isin(test_idx)

# any row that is neither train nor val is labelled 'test'
edge_labels['split'] = np.where(train_mask, 'train', np.where(val_mask, 'val', 'test'))
print(edge_labels['split'].value_counts())
# expose train/val/test dfs
train_df = edge_labels[edge_labels['split']=='train'].reset_index(drop=True)
val_df = edge_labels[edge_labels['split']=='val'].reset_index(drop=True)
test_df = edge_labels[edge_labels['split']=='test'].reset_index(drop=True)
print("train/val/test sizes:", len(train_df), len(val_df), len(test_df))


Using existing edges_df with shape: (2548128, 6)
split
train    1954270
val       296929
test      296929
Name: count, dtype: int64
train/val/test sizes: 1954270 296929 296929


In [20]:
# %% REPLACE Cell 9 — build edge_index_t from ONLY TRAINING POSITIVE EDGES
import torch
from torch_geometric.utils import remove_self_loops

# Message-passing graph: only edges that are both in the training split and positive.
is_train_positive = train_df['label'] == 1
train_pos_edges = train_df.loc[is_train_positive, ['u', 'v']].values

# Emit each edge in both directions so the graph is undirected.
src_nodes = train_pos_edges[:, 0]
dst_nodes = train_pos_edges[:, 1]
edge_index_np = np.stack([
    np.concatenate([src_nodes, dst_nodes]),
    np.concatenate([dst_nodes, src_nodes]),
]).astype(np.int64)

# Drop edges that connect a node to itself.
edge_index_tensor = torch.tensor(edge_index_np, dtype=torch.long)
edge_index_tensor, _ = remove_self_loops(edge_index_tensor)

# Node features + connectivity, moved to the compute device.
node_features = torch.tensor(X_fp_svd, dtype=torch.float)
data = Data(x=node_features, edge_index=edge_index_tensor).to(device)
edge_index_t = edge_index_tensor

print("✅ CORRECTED: Graph built with only positive training edges")
print("Data created:", data)
print("Num nodes:", data.num_nodes, "Num edges (directed, positive only):", data.num_edges)

✅ CORRECTED: Graph built with only positive training edges
Data created: Data(x=[12298, 262], edge_index=[2, 1574422])
Num nodes: 12298 Num edges (directed, positive only): 1574422


In [21]:
# %% REPLACE Cell 10 — Improved faster encoder + classifier
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class ImprovedGraphSAGE(nn.Module):
    """Stack of SAGEConv layers, each followed by ReLU and dropout.

    Parameters
    ----------
    in_channels : int
        Width of the input node features.
    hidden_channels : int
        Output width of every layer.
    n_layers : int
        Number of SAGEConv layers; at least one layer is always created.
    dropout : float
        Dropout probability applied after every layer, including the last.
    """

    def __init__(self, in_channels, hidden_channels=256, n_layers=2, dropout=0.3):
        super().__init__()
        # Consecutive (input, output) widths: in -> hidden -> hidden -> ...
        widths = [in_channels] + [hidden_channels] * max(n_layers, 1)
        self.convs = nn.ModuleList(
            [SAGEConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])]
        )
        self.dropout = nn.Dropout(dropout)
        self.act = nn.ReLU()

    def forward(self, x, edge_index):
        """Return node embeddings after running every conv -> ReLU -> dropout stage."""
        hidden = x
        for conv in self.convs:
            hidden = self.dropout(self.act(conv(hidden, edge_index)))
        return hidden

class ImprovedEdgeClassifier(nn.Module):
    """Two-layer MLP that scores a node pair from its concatenated embeddings.

    Parameters
    ----------
    node_emb_dim : int
        Width of a single node embedding; the MLP input is twice this.
    hidden : int
        Width of the hidden layer.
    dropout : float
        Dropout probability applied after the hidden activation.
    """

    def __init__(self, node_emb_dim, hidden=128, dropout=0.3):
        super().__init__()
        pair_dim = 2 * node_emb_dim
        layers = [
            nn.Linear(pair_dim, hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, ha, hb):
        """Return one logit per row for the pairs (ha[i], hb[i])."""
        pair_features = torch.cat((ha, hb), dim=1)
        scores = self.mlp(pair_features)
        return scores.squeeze(1)

# Instantiate the encoder and pair classifier and place both on the compute device.
_, in_dim = X_fp_svd.shape
encoder = ImprovedGraphSAGE(in_channels=in_dim, hidden_channels=256, n_layers=2)
encoder = encoder.to(device)
edge_clf = ImprovedEdgeClassifier(node_emb_dim=256)
edge_clf = edge_clf.to(device)

print("✅ CORRECTED: Using improved faster architecture")

✅ CORRECTED: Using improved faster architecture


In [15]:
# %% Cell 11 replacement — balanced batches with oversampling (more positive presence)
# Re-derive the three partitions from edge_labels, each with a fresh 0..n-1 index
# and the node-id columns converted to plain ints.
train_df, val_df, test_df = (
    edge_labels[edge_labels['split'] == split_name]
    .reset_index(drop=True)
    .astype({'u': int, 'v': int})
    for split_name in ('train', 'val', 'test')
)

# Index labels of the positive and negative training rows.
train_pos_idx = train_df.index[(train_df['label'] == 1).to_numpy()].to_numpy()
train_neg_idx = train_df.index[(train_df['label'] == 0).to_numpy()].to_numpy()

# oversample positives to reach a target positive ratio per batch (e.g., 40%)
def balanced_edge_batches(df, pos_idx, neg_idx, batch_size=4096, pos_ratio=0.4):
    """Yield an endless stream of edge batches with a fixed positive/negative mix.

    Each batch is drawn independently, with replacement, so every batch contains
    exactly ``int(batch_size * pos_ratio)`` positives and the remainder negatives.
    (Drawing per batch avoids walking two pools of different lengths in lock-step,
    which produced single-class, undersized batches once the shorter pool ran out.)

    Parameters
    ----------
    df : pandas.DataFrame
        Edge table with 'u', 'v' and 'label' columns; rows are selected by index label.
    pos_idx, neg_idx : array-like
        Index labels of the positive and negative rows of ``df``.
    batch_size : int
        Number of edges per batch.
    pos_ratio : float
        Fraction of each batch that is positive.

    Yields
    ------
    (u, v, y) : tuple of torch.Tensor
        Node-id tensors (long) and float labels, placed on the notebook-level ``device``.

    Raises
    ------
    ValueError
        If a class that must be sampled has no candidate rows.
    """
    pos_per = int(batch_size * pos_ratio)
    neg_per = batch_size - pos_per
    if (pos_per > 0 and len(pos_idx) == 0) or (neg_per > 0 and len(neg_idx) == 0):
        raise ValueError("balanced_edge_batches needs at least one row of every sampled class")
    while True:
        chosen = np.concatenate([
            np.random.choice(pos_idx, size=pos_per, replace=True),
            np.random.choice(neg_idx, size=neg_per, replace=True),
        ])
        # Mix the classes so positives are not all at the front of the batch.
        np.random.shuffle(chosen)
        sub = df.loc[chosen]
        u = torch.tensor(sub['u'].values, dtype=torch.long, device=device)
        v = torch.tensor(sub['v'].values, dtype=torch.long, device=device)
        y = torch.tensor(sub['label'].astype(np.float32).values, dtype=torch.float, device=device)
        yield u, v, y


In [22]:
# %% REPLACE Cell 12 — Optimized training (30 min target)
import torch
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, average_precision_score
import time
import numpy as np

# Configuration
EPOCHS = 25  # Reduced for speed
BATCH_SIZE = 8192
LR = 0.001
SEED = 42                 # drives batch sampling, dropout and the validation subset
VAL_SAMPLE_SIZE = 50000   # rows of val_df scored after every epoch
CHECKPOINT_PATH = "/kaggle/working/gnn_optimized.pth"

# Seed the randomness used below so a rerun draws the same batches.
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Move the training columns to the device once; batches are then plain tensor
# gathers instead of a pandas .loc + concat + shuffle per step.
train_labels_np = train_df['label'].to_numpy(dtype=np.float32)
train_u = torch.as_tensor(train_df['u'].to_numpy(dtype=np.int64), device=device)
train_v = torch.as_tensor(train_df['v'].to_numpy(dtype=np.int64), device=device)
train_y = torch.as_tensor(train_labels_np, device=device)

# Row positions of each class inside train_df, used for balanced sampling.
train_pos_idx = np.flatnonzero(train_labels_np == 1)
train_neg_idx = np.flatnonzero(train_labels_np == 0)
if len(train_pos_idx) == 0 or len(train_neg_idx) == 0:
    raise ValueError("Training split must contain both positive and negative edges")

# Optimizer
optimizer = torch.optim.Adam(
    list(encoder.parameters()) + list(edge_clf.parameters()),
    lr=LR, weight_decay=1e-5
)

# Class weighting for imbalance
# NOTE(review): batches below are already sampled 50/50, so this weight compensates for
# imbalance a second time and shifts predicted probabilities — confirm it is intended.
pos_weight = torch.tensor([len(train_neg_idx) / len(train_pos_idx)], device=device)

# Mixed precision for speed; only enabled on CUDA so the cell also runs on CPU.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# One fixed validation subset for all epochs, so AUC values are comparable from
# epoch to epoch and early stopping is not driven by sampling noise.
val_sample = val_df.sample(min(VAL_SAMPLE_SIZE, len(val_df)), random_state=SEED)
u_val = torch.as_tensor(val_sample['u'].to_numpy(dtype=np.int64), device=device)
v_val = torch.as_tensor(val_sample['v'].to_numpy(dtype=np.int64), device=device)
y_val = val_sample['label'].to_numpy()

best_val_auc = 0
patience_counter = 0
patience = 5
checkpoint_saved = False

half_batch = BATCH_SIZE // 2
# At least one step per epoch, even when the training set is smaller than a batch.
batches_per_epoch = max(1, len(train_df) // BATCH_SIZE)

print("🚀 Starting optimized training (target: 30 minutes)...")
start_time = time.time()

for epoch in range(1, EPOCHS + 1):
    encoder.train()
    edge_clf.train()

    epoch_loss = 0.0
    num_batches = 0

    # Fast edge sampling (no subgraph building)
    for _ in range(batches_per_epoch):
        # Sample balanced batch: half positives, half negatives, with replacement.
        # (Row order inside the batch does not affect the mean loss.)
        batch_rows = np.concatenate([
            rng.choice(train_pos_idx, half_batch, replace=True),
            rng.choice(train_neg_idx, half_batch, replace=True),
        ])
        batch_rows = torch.as_tensor(batch_rows, device=device)
        u = train_u[batch_rows]
        v = train_v[batch_rows]
        y = train_y[batch_rows]

        optimizer.zero_grad()

        # Mixed precision forward
        with torch.autocast(device_type=device.type, enabled=use_amp):
            # Single forward pass on full graph (FAST)
            node_emb = encoder(data.x, data.edge_index)
            logits = edge_clf(node_emb[u], node_emb[v])
            loss = F.binary_cross_entropy_with_logits(logits, y, pos_weight=pos_weight)

        # Scaled backward (a no-op scaler when AMP is disabled)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

    # Validation (every epoch for monitoring)
    encoder.eval()
    edge_clf.eval()
    with torch.no_grad():
        node_emb = encoder(data.x, data.edge_index)
        logits = edge_clf(node_emb[u_val], node_emb[v_val])
        probs = torch.sigmoid(logits).cpu().numpy()

    val_auc = roc_auc_score(y_val, probs)
    val_ap = average_precision_score(y_val, probs)

    elapsed = (time.time() - start_time) / 60
    print(f"Epoch {epoch:02d} | Loss: {epoch_loss/num_batches:.4f} | Val AUC: {val_auc:.4f} | AP: {val_ap:.4f} | Time: {elapsed:.1f}m")

    # Early stopping
    if val_auc > best_val_auc + 0.001:  # Small improvement threshold
        best_val_auc = val_auc
        patience_counter = 0
        torch.save({
            'encoder': encoder.state_dict(),
            'edge_clf': edge_clf.state_dict()
        }, CHECKPOINT_PATH)
        checkpoint_saved = True
        print(f"💾 Saved new best model (AUC: {val_auc:.4f})")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"🛑 Early stopping at epoch {epoch}")
            break

# Load best model — only if this run actually wrote one, so a stale file from an
# earlier session is never picked up silently.
if checkpoint_saved:
    # weights_only=True: the file holds only state dicts, so restrict unpickling to tensors.
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True)
    encoder.load_state_dict(checkpoint['encoder'])
    edge_clf.load_state_dict(checkpoint['edge_clf'])
else:
    print("⚠️ No checkpoint was saved during this run; keeping the final-epoch weights")

total_time = (time.time() - start_time) / 60
print(f"✅ Training completed in {total_time:.1f} minutes")
print(f"🏆 Best Validation AUC: {best_val_auc:.4f}")

🚀 Starting optimized training (target: 30 minutes)...
Epoch 01 | Loss: 0.6774 | Val AUC: 0.6455 | AP: 0.5071 | Time: 0.4m
💾 Saved new best model (AUC: 0.6455)
Epoch 02 | Loss: 0.6647 | Val AUC: 0.6506 | AP: 0.5115 | Time: 0.9m
💾 Saved new best model (AUC: 0.6506)
Epoch 03 | Loss: 0.6638 | Val AUC: 0.6500 | AP: 0.5113 | Time: 1.3m
Epoch 04 | Loss: 0.6630 | Val AUC: 0.6487 | AP: 0.5098 | Time: 1.7m
Epoch 05 | Loss: 0.6623 | Val AUC: 0.6515 | AP: 0.5095 | Time: 2.1m
Epoch 06 | Loss: 0.6617 | Val AUC: 0.6516 | AP: 0.5117 | Time: 2.6m
💾 Saved new best model (AUC: 0.6516)
Epoch 07 | Loss: 0.6616 | Val AUC: 0.6529 | AP: 0.5116 | Time: 3.0m
💾 Saved new best model (AUC: 0.6529)
Epoch 08 | Loss: 0.6617 | Val AUC: 0.6525 | AP: 0.5193 | Time: 3.4m
Epoch 09 | Loss: 0.6612 | Val AUC: 0.6502 | AP: 0.5103 | Time: 3.8m
Epoch 10 | Loss: 0.6611 | Val AUC: 0.6561 | AP: 0.5157 | Time: 4.3m
💾 Saved new best model (AUC: 0.6561)
Epoch 11 | Loss: 0.6616 | Val AUC: 0.6566 | AP: 0.5143 | Time: 4.7m
Epoch 12 | Lo

In [None]:
# Installs the optional PyG extension wheels from the wheel index matching the
# installed torch version, then torch-geometric itself.
# NOTE(review): torch_geometric is already installed by the first pip cell, and no cell
# visible in this notebook imports torch_sparse / torch_scatter — confirm this cell is
# still needed; it also sits after the training cell in notebook order.
!pip install torch-sparse -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html
!pip install torch-scatter -f https://data.pyg.org/whl/torch-$(python -c "import torch; print(torch.__version__)").html
!pip install torch-geometric


In [None]:
# Limit CPU threading to a single thread; meant to be placed right before the epoch loop.
# NOTE(review): in notebook order this cell comes AFTER the training cell, so it does not
# affect training as written. OMP_NUM_THREADS is also set after torch has been imported —
# presumably only torch.set_num_threads has an effect at this point; verify.
import os
os.environ["OMP_NUM_THREADS"] = "1"
torch.set_num_threads(1)


In [30]:
# %% Load model and predict - SIMPLE VERSION
import torch
from torch_geometric.nn import SAGEConv
import numpy as np

# Define model architecture (same as training)
class ImprovedGraphSAGE(torch.nn.Module):
    """GraphSAGE encoder: SAGEConv layers, each followed by ReLU and dropout.

    The attribute names (``convs``, ``dropout``, ``act``) determine the state-dict
    keys, which must match the checkpoint loaded later in this cell.
    """

    def __init__(self, in_channels, hidden_channels=256, n_layers=2, dropout=0.3):
        super().__init__()
        # The first layer consumes the raw features; any further layer consumes
        # hidden_channels. At least one layer is always built.
        input_widths = [in_channels] + [hidden_channels] * (max(n_layers, 1) - 1)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(width, hidden_channels) for width in input_widths]
        )
        self.dropout = torch.nn.Dropout(dropout)
        self.act = torch.nn.ReLU()

    def forward(self, x, edge_index):
        """Apply conv -> ReLU -> dropout for every layer and return the result."""
        out = x
        for layer in self.convs:
            out = layer(out, edge_index)
            out = self.dropout(self.act(out))
        return out

class ImprovedEdgeClassifier(torch.nn.Module):
    """Pair scorer: concatenates two node embeddings and applies a two-layer MLP.

    The layers live in ``self.mlp`` at positions 0 (Linear), 1 (ReLU), 2 (Dropout)
    and 3 (Linear), which fixes the state-dict keys expected by the checkpoint.
    """

    def __init__(self, node_emb_dim, hidden=128, dropout=0.3):
        super().__init__()
        first = torch.nn.Linear(2 * node_emb_dim, hidden)
        last = torch.nn.Linear(hidden, 1)
        self.mlp = torch.nn.Sequential(
            first,
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            last,
        )

    def forward(self, ha, hb):
        """Return a 1-D tensor with one logit per (ha[i], hb[i]) pair."""
        joined = torch.cat((ha, hb), dim=1)
        return self.mlp(joined).squeeze(1)

# Load model
HIDDEN_DIM = 256  # hidden size used when the checkpoint was trained; encoder output = classifier input
in_dim = X_fp_svd.shape[1]
encoder = ImprovedGraphSAGE(in_dim, hidden_channels=HIDDEN_DIM).to(device)
edge_clf = ImprovedEdgeClassifier(HIDDEN_DIM).to(device)

# weights_only=True: the file holds only state dicts, so restrict unpickling to
# tensors and plain containers instead of arbitrary Python objects.
checkpoint = torch.load("/kaggle/working/gnn_optimized.pth", map_location=device, weights_only=True)
encoder.load_state_dict(checkpoint['encoder'])
edge_clf.load_state_dict(checkpoint['edge_clf'])
# Inference mode: disables dropout in both modules.
encoder.eval()
edge_clf.eval()

print("Model loaded")

Model loaded


In [32]:
# %% Use the loaded model to predict on test data
def predict_with_loaded_model(n_samples=50, threshold=0.5, random_state=42):
    """Use ONLY the loaded gnn_optimized.pth model to predict on a test sample.

    Scores a random sample of ``test_df`` in one batched forward pass, prints one
    line per pair plus the sample accuracy, and returns the hard predictions.

    Parameters
    ----------
    n_samples : int
        Number of test rows to score; capped at ``len(test_df)``.
    threshold : float
        Probability at or above which a pair is predicted positive.
    random_state : int
        Seed for the row sample, so repeated calls score the same rows.

    Returns
    -------
    (predictions, actual_labels) : tuple of list of int
        Predicted and true labels, in the order the rows were printed.
    """
    encoder.eval()
    edge_clf.eval()

    print("🧪 Using gnn_optimized.pth to predict on test data...")

    # Take random samples from test set (never more rows than exist)
    n_take = min(n_samples, len(test_df))
    if n_take == 0:
        print("No test rows available to score")
        return [], []
    test_samples = test_df.sample(n=n_take, random_state=random_state)

    with torch.no_grad():
        node_emb = encoder(data.x, data.edge_index)
        # Gather both endpoints for all sampled pairs at once and score them in a
        # single classifier call instead of one call per row.
        u_idx = torch.as_tensor(test_samples['u'].to_numpy(dtype=np.int64), device=node_emb.device)
        v_idx = torch.as_tensor(test_samples['v'].to_numpy(dtype=np.int64), device=node_emb.device)
        logits = edge_clf(node_emb[u_idx], node_emb[v_idx])
        probs = torch.sigmoid(logits).cpu().numpy()

    predictions = (probs >= threshold).astype(int).tolist()
    actual_labels = test_samples['label'].astype(int).tolist()

    # Print result, one line per scored pair
    for drug_a, drug_b, true_label, pred_label, pred_prob in zip(
        test_samples['drug_a'], test_samples['drug_b'], actual_labels, predictions, probs
    ):
        print(f"{drug_a} + {drug_b} | True: {true_label} | Pred: {pred_label} | Prob: {pred_prob:.3f}")

    # Calculate accuracy on the sampled rows only
    accuracy = (np.array(predictions) == np.array(actual_labels)).mean()
    print(f"\n📊 Accuracy: {accuracy:.4f}")

    return predictions, actual_labels

# Run prediction
predictions, actual_labels = predict_with_loaded_model()

🧪 Using gnn_optimized.pth to predict on test data...
DB01212 + DB13321 | True: 0 | Pred: 0 | Prob: 0.000
DB00864 + DB06216 | True: 0 | Pred: 1 | Prob: 0.685
DB00404 + DB01171 | True: 0 | Pred: 1 | Prob: 0.679
DB01224 + DB01601 | True: 1 | Pred: 1 | Prob: 0.705
DB00196 + DB01327 | True: 0 | Pred: 1 | Prob: 0.655
DB00404 + DB00421 | True: 1 | Pred: 1 | Prob: 0.659
DB00182 + DB03049 | True: 0 | Pred: 0 | Prob: 0.000
DB00186 + DB00996 | True: 0 | Pred: 1 | Prob: 0.654
DB08930 + DB14646 | True: 1 | Pred: 1 | Prob: 0.613
DB00706 + DB00727 | True: 1 | Pred: 1 | Prob: 0.663
DB00264 + DB01056 | True: 0 | Pred: 0 | Prob: 0.000
DB00945 + DB04861 | True: 0 | Pred: 1 | Prob: 0.637
DB00497 + DB06273 | True: 1 | Pred: 1 | Prob: 0.680
DB00181 + DB00987 | True: 1 | Pred: 1 | Prob: 0.706
DB00878 + DB01914 | True: 0 | Pred: 1 | Prob: 0.655
DB00213 + DB02108 | True: 0 | Pred: 0 | Prob: 0.000
DB04224 + DB13961 | True: 1 | Pred: 0 | Prob: 0.000
DB00945 + DB06643 | True: 1 | Pred: 1 | Prob: 0.632
DB00333 + D