In [None]:
!pip install torch-geometric


In [48]:
# =============================================================
# FULL END-TO-END DDI SYSTEM (KAGGLE-READY)
# GNN + MC Dropout + Reliability Auditor + Evaluation + Saving
# Author: Gaurav
# =============================================================

# =============================================================
# 0. IMPORTS & GLOBAL CONFIG (KAGGLE SAFE)
# =============================================================
import warnings, os, json, joblib, random
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
from tqdm import tqdm
from collections import Counter, defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.cuda.amp import autocast, GradScaler

from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, global_mean_pool, global_max_pool, LayerNorm

from rdkit import Chem
from rdkit.Chem import rdchem, AllChem
from rdkit.Chem.Scaffolds import MurckoScaffold

from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.utils.class_weight import compute_class_weight
from xgboost import XGBClassifier


# =============================================================
# EMA HELPER (MEDICAL-GRADE STABILITY) – KEEP ONLY ONCE
# =============================================================

class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = (
                    self.decay * self.shadow[name]
                    + (1.0 - self.decay) * param.data
                )

    def apply_shadow(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self, model):
        for name, param in model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}

# Performance flags
torch.backends.cudnn.benchmark = True
torch.set_float32_matmul_precision("high")

# Reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(42)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

FILE_PATH = "/kaggle/input/drugbank3/drugbank.tab"   
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

Device: cuda


In [38]:
import torch
import torch_geometric

print("Torch version:", torch.__version__)
print("PyG version:", torch_geometric.__version__)
print("CUDA available:", torch.cuda.is_available())


ERROR! Session/line number was not unique in database. History logging moved to new session 30
Torch version: 2.6.0+cu124
PyG version: 2.7.0
CUDA available: True


In [39]:
# =============================================================
# 1. LOAD & CLEAN DATASET
# =============================================================
df = pd.read_csv(FILE_PATH, sep="\t")
print("Raw samples:", len(df))
print("Raw classes:", df['Y'].nunique())

# Convert labels to 0-based
if df['Y'].min() >= 1:
    df['Y'] -= 1

# Remove rare classes
MIN_SAMPLES = 20
valid = df['Y'].value_counts()
valid = valid[valid >= MIN_SAMPLES].index
df = df[df['Y'].isin(valid)].copy()

# Relabel
label_map = {old: new for new, old in enumerate(sorted(df['Y'].unique()))}
df['Y'] = df['Y'].map(label_map)
NUM_CLASSES = df['Y'].nunique()

print("Filtered samples:", len(df))
print("Final classes:", NUM_CLASSES)

# Validate SMILES
def valid_smiles(s):
    return Chem.MolFromSmiles(str(s)) is not None

df = df[df['X1'].apply(valid_smiles) & df['X2'].apply(valid_smiles)]
df.reset_index(drop=True, inplace=True)


Raw samples: 191808
Raw classes: 86
Filtered samples: 191700
Final classes: 76


In [40]:
# =============================================================
# 2. SCAFFOLD SPLIT (NO CHEMICAL LEAKAGE)
# =============================================================
print("Performing Scaffold Split (Scientific Best Practice)...")
def get_scaffold(smiles):
    mol = Chem.MolFromSmiles(str(smiles))
    if mol is None: return None
    return MurckoScaffold.MurckoScaffoldSmiles(mol=mol)

scaffold_groups = defaultdict(list)
for idx, row in tqdm(df.iterrows(), total=len(df)):
    scaf = get_scaffold(row.X1)
    if scaf: scaffold_groups[scaf].append(idx)

scaffolds = sorted(scaffold_groups.keys(), key=lambda x: len(scaffold_groups[x]), reverse=True)
train_idx, test_idx = [], []

for scaf in scaffolds:
    if len(train_idx) < 0.8 * len(df): train_idx.extend(scaffold_groups[scaf])
    else: test_idx.extend(scaffold_groups[scaf])

train_df = df.loc[train_idx].reset_index(drop=True)
test_df  = df.loc[test_idx].reset_index(drop=True)

print("Train:", len(train_df), "Test:", len(test_df))


Performing Scaffold Split (Scientific Best Practice)...


100%|██████████| 191700/191700 [01:39<00:00, 1918.94it/s]

Train: 153395 Test: 29323





In [41]:
# =============================================================
# 3. GRAPH CONSTRUCTION
# =============================================================
def atom_features(atom):
    hyb = {
        rdchem.HybridizationType.SP: 0,
        rdchem.HybridizationType.SP2: 1,
        rdchem.HybridizationType.SP3: 2,
        rdchem.HybridizationType.SP3D: 3,
        rdchem.HybridizationType.SP3D2: 4,
    }
    return [
        atom.GetAtomicNum(), atom.GetTotalDegree(),
        hyb.get(atom.GetHybridization(), 5),
        int(atom.GetIsAromatic()), atom.GetFormalCharge(),
        int(atom.IsInRing()), atom.GetNumRadicalElectrons(),
        int(atom.GetChiralTag())
    ]

def bond_features(bond):
    bt = {
        rdchem.BondType.SINGLE: 0,
        rdchem.BondType.DOUBLE: 1,
        rdchem.BondType.TRIPLE: 2,
        rdchem.BondType.AROMATIC: 3,
    }
    return [bt.get(bond.GetBondType(), 4), int(bond.GetIsConjugated()), int(bond.IsInRing())]

def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(str(smiles))
    if mol is None: return None
    x = torch.tensor([atom_features(a) for a in mol.GetAtoms()], dtype=torch.float)
    ei, ea = [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        ei += [[i, j], [j, i]]
        f = bond_features(b)
        ea += [f, f]
    if len(ei) == 0:
        ei = torch.empty((2, 0), dtype=torch.long)
        ea = torch.empty((0, 3), dtype=torch.float)
    else:
        ei = torch.tensor(ei).t().contiguous()
        ea = torch.tensor(ea, dtype=torch.float)
    return Data(x=x, edge_index=ei, edge_attr=ea)

In [42]:
# =============================================================
# 4. DATASET & LOADERS 
# =============================================================

MICRO_BATCH = 32        
ACCUM_STEPS = 4        

class DDIDataset(torch.utils.data.Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        r = self.df.iloc[idx]
        g1 = smiles_to_graph(r.X1)
        g2 = smiles_to_graph(r.X2)
        if g1 is None or g2 is None:
            return None
        return g1, g2, torch.tensor(r.Y, dtype=torch.long)

def collate(batch):
    batch = [b for b in batch if b is not None]
    if not batch:
        return None
    g1, g2, y = zip(*batch)
    return (
        Batch.from_data_list(g1),
        Batch.from_data_list(g2),
        torch.stack(y)
    )

train_loader = DataLoader(
    DDIDataset(train_df),
    batch_size=MICRO_BATCH,
    shuffle=True,
    collate_fn=collate,
    pin_memory=True,
    num_workers=2
)

test_loader = DataLoader(
    DDIDataset(test_df),
    batch_size=MICRO_BATCH,
    shuffle=False,
    collate_fn=collate,
    pin_memory=True,
    num_workers=2
)


In [43]:
# =============================================================
# 5. GNN MODEL 
# =============================================================

class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.c1 = GATv2Conv(8, 128, heads=4, edge_dim=3)
        self.n1 = LayerNorm(512)
        self.c2 = GATv2Conv(512, 128, heads=2, edge_dim=3)
        self.n2 = LayerNorm(256)
        self.c3 = GATv2Conv(256, 128, heads=1, edge_dim=3)
        self.n3 = LayerNorm(128)
        self.dp = nn.Dropout(0.25)

    def forward(self, g):
        x, ei, ea, b = g.x, g.edge_index, g.edge_attr, g.batch
        x = self.dp(F.elu(self.n1(self.c1(x, ei, ea))))
        x = self.dp(F.elu(self.n2(self.c2(x, ei, ea))))
        x = F.elu(self.n3(self.c3(x, ei, ea)))
        return torch.cat(
            [global_mean_pool(x, b), global_max_pool(x, b)],
            dim=1
        )

class SiameseDDI(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc = Encoder()
        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.LayerNorm(256),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Linear(128, NUM_CLASSES)
        )

    def forward(self, g1, g2):
        return self.fc(
            torch.cat([self.enc(g1), self.enc(g2)], dim=1)
        )


model = SiameseDDI().to(DEVICE)
ema = EMA(model, decay=0.999)

print("Running on SINGLE GPU (PyG safe)")

Running on SINGLE GPU (PyG safe)


In [49]:
# =============================================================
# 6. TRAINING SETUP (FINAL – MEDICAL GRADE)
# =============================================================

# ---- CLASS WEIGHTS (FIXED MISMATCH SAFE) ----
train_classes = np.unique(train_df.Y)

computed_weights = compute_class_weight(
    class_weight="balanced",
    classes=train_classes,
    y=train_df.Y
)

full_weights = np.ones(NUM_CLASSES)
full_weights[train_classes] = computed_weights
weights = torch.tensor(np.sqrt(full_weights), dtype=torch.float).to(DEVICE)

class FocalLoss(nn.Module):
    def forward(self, x, y):
        ce = F.cross_entropy(x, y, weight=weights, reduction="none")
        pt = torch.exp(-ce)
        return (((1 - pt) ** 2) * ce).mean()

criterion = FocalLoss()

BASE_LR = 3e-4       
optimizer = AdamW(
    model.parameters(),
    lr=BASE_LR,
    weight_decay=1e-4
)

# ---- LR SCHEDULER ----
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="max",
    patience=3,
    factor=0.5
)

# ---- AMP ----
scaler = GradScaler(enabled=(DEVICE.type == "cuda"))

# ---- WARMUP ----
WARMUP_EPOCHS = 2


In [53]:
# =============================================================
# 7. TRAINING LOOP (FINAL - WITH PROPER EMA VALIDATION)
# =============================================================

best_f1, wait = 0, 0
EPOCHS = 30

for ep in range(EPOCHS):
    # ---- LR WARMUP ----
    if ep < WARMUP_EPOCHS:
        lr_scale = (ep + 1) / WARMUP_EPOCHS
        for pg in optimizer.param_groups:
            pg["lr"] = BASE_LR * lr_scale

    model.train()
    optimizer.zero_grad()
    total_loss = 0

    for step, batch in enumerate(tqdm(train_loader, desc=f"Epoch {ep+1}")):
        if batch is None: continue

        g1, g2, y = batch
        g1, g2, y = g1.to(DEVICE), g2.to(DEVICE), y.to(DEVICE)

        with autocast(enabled=(DEVICE.type == "cuda")):
            out = model(g1, g2)
            loss = criterion(out, y) / ACCUM_STEPS

        scaler.scale(loss).backward()
        total_loss += loss.item()

        if (step + 1) % ACCUM_STEPS == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            ema.update(model) 
    # ---------- VALIDATION (USING EMA WEIGHTS) ----------
    ema.apply_shadow(model) 
    
    model.eval()
    yt, yp = [], []

    with torch.no_grad():
        for batch in test_loader:
            if batch is None: continue
            g1, g2, y = batch
            out = model(g1.to(DEVICE), g2.to(DEVICE))
            yt.extend(y.tolist())
            yp.extend(out.argmax(dim=1).cpu().tolist())

    ema.restore(model)

    acc = accuracy_score(yt, yp)
    f1  = f1_score(yt, yp, average="weighted")

    print(f"Epoch {ep+1:02d} | Loss {total_loss:.2f} | Acc {acc:.4f} | F1 {f1:.4f}")

    scheduler.step(f1)

    if f1 > best_f1:
        best_f1 = f1
        wait = 0
        ema.apply_shadow(model)
        torch.save(model.state_dict(), "best_ddi_model.pt")
        ema.restore(model)
    else:
        wait += 1
        if wait >= 6:
            print("Early stopping triggered.")
            break

Epoch 1: 100%|██████████| 4794/4794 [03:42<00:00, 21.54it/s]


Epoch 01 | Loss 1034.26 | Acc 0.1550 | F1 0.1549


Epoch 2: 100%|██████████| 4794/4794 [03:47<00:00, 21.09it/s]


Epoch 02 | Loss 833.95 | Acc 0.1921 | F1 0.1893


Epoch 3: 100%|██████████| 4794/4794 [03:47<00:00, 21.09it/s]


Epoch 03 | Loss 635.79 | Acc 0.2415 | F1 0.2427


Epoch 4: 100%|██████████| 4794/4794 [03:47<00:00, 21.07it/s]


Epoch 04 | Loss 500.72 | Acc 0.2845 | F1 0.2876


Epoch 5: 100%|██████████| 4794/4794 [03:47<00:00, 21.06it/s]


Epoch 05 | Loss 415.00 | Acc 0.3243 | F1 0.3271


Epoch 6: 100%|██████████| 4794/4794 [03:47<00:00, 21.10it/s]


Epoch 06 | Loss 354.53 | Acc 0.3542 | F1 0.3580


Epoch 7: 100%|██████████| 4794/4794 [03:47<00:00, 21.07it/s]


Epoch 07 | Loss 305.91 | Acc 0.3795 | F1 0.3823


Epoch 8: 100%|██████████| 4794/4794 [03:42<00:00, 21.53it/s]


Epoch 08 | Loss 270.81 | Acc 0.3985 | F1 0.4013


Epoch 9: 100%|██████████| 4794/4794 [03:42<00:00, 21.54it/s]


Epoch 09 | Loss 245.70 | Acc 0.4181 | F1 0.4210


Epoch 10: 100%|██████████| 4794/4794 [03:42<00:00, 21.58it/s]


Epoch 10 | Loss 225.83 | Acc 0.4336 | F1 0.4364


Epoch 11: 100%|██████████| 4794/4794 [03:42<00:00, 21.57it/s]


Epoch 11 | Loss 210.67 | Acc 0.4415 | F1 0.4437


Epoch 12: 100%|██████████| 4794/4794 [03:41<00:00, 21.65it/s]


Epoch 12 | Loss 192.77 | Acc 0.4544 | F1 0.4568


Epoch 13: 100%|██████████| 4794/4794 [03:41<00:00, 21.60it/s]


Epoch 13 | Loss 179.90 | Acc 0.4676 | F1 0.4699


Epoch 14: 100%|██████████| 4794/4794 [03:42<00:00, 21.57it/s]


Epoch 14 | Loss 167.10 | Acc 0.4772 | F1 0.4798


Epoch 15: 100%|██████████| 4794/4794 [03:42<00:00, 21.58it/s]


Epoch 15 | Loss 157.55 | Acc 0.4861 | F1 0.4890


Epoch 16: 100%|██████████| 4794/4794 [03:41<00:00, 21.60it/s]


Epoch 16 | Loss 152.51 | Acc 0.4955 | F1 0.4990


Epoch 17: 100%|██████████| 4794/4794 [03:41<00:00, 21.65it/s]


Epoch 17 | Loss 143.08 | Acc 0.5056 | F1 0.5097


Epoch 18: 100%|██████████| 4794/4794 [03:42<00:00, 21.55it/s]


Epoch 18 | Loss 138.36 | Acc 0.5108 | F1 0.5148


Epoch 19: 100%|██████████| 4794/4794 [03:42<00:00, 21.57it/s]


Epoch 19 | Loss 128.84 | Acc 0.5208 | F1 0.5249


Epoch 20: 100%|██████████| 4794/4794 [03:41<00:00, 21.64it/s]


Epoch 20 | Loss 124.05 | Acc 0.5272 | F1 0.5314


Epoch 21: 100%|██████████| 4794/4794 [03:41<00:00, 21.69it/s]


Epoch 21 | Loss 119.50 | Acc 0.5317 | F1 0.5358


Epoch 22: 100%|██████████| 4794/4794 [03:42<00:00, 21.57it/s]


Epoch 22 | Loss 114.77 | Acc 0.5361 | F1 0.5396


Epoch 23: 100%|██████████| 4794/4794 [03:41<00:00, 21.64it/s]


Epoch 23 | Loss 111.32 | Acc 0.5425 | F1 0.5461


Epoch 24: 100%|██████████| 4794/4794 [03:43<00:00, 21.50it/s]


Epoch 24 | Loss 110.42 | Acc 0.5464 | F1 0.5496


Epoch 25: 100%|██████████| 4794/4794 [03:42<00:00, 21.53it/s]


Epoch 25 | Loss 102.96 | Acc 0.5536 | F1 0.5572


Epoch 26: 100%|██████████| 4794/4794 [03:42<00:00, 21.56it/s]


Epoch 26 | Loss 103.61 | Acc 0.5543 | F1 0.5574


Epoch 27: 100%|██████████| 4794/4794 [03:44<00:00, 21.32it/s]


Epoch 27 | Loss 95.40 | Acc 0.5577 | F1 0.5620


Epoch 28: 100%|██████████| 4794/4794 [03:48<00:00, 21.02it/s]


Epoch 28 | Loss 94.15 | Acc 0.5666 | F1 0.5705


Epoch 29: 100%|██████████| 4794/4794 [03:47<00:00, 21.08it/s]


Epoch 29 | Loss 94.06 | Acc 0.5665 | F1 0.5699


Epoch 30: 100%|██████████| 4794/4794 [03:47<00:00, 21.04it/s]


Epoch 30 | Loss 90.71 | Acc 0.5692 | F1 0.5730


In [57]:
# =============================================================
# 8. STANDARD EVALUATION (SYMMETRY ENFORCED VIA TTA)
# =============================================================
print("Loading Best Model for Evaluation with TTA...")
model.load_state_dict(torch.load("best_ddi_model.pt", map_location=DEVICE))
model.eval()

y_true, y_pred = [], []

with torch.no_grad():
    for batch in test_loader:
        if batch is None: continue
        g1, g2, y = batch
        g1, g2 = g1.to(DEVICE), g2.to(DEVICE)

        # Forward Pass 1 (Normal)
        out1 = model(g1, g2)
        
        # Forward Pass 2 (Swapped/Flipped) -> Yeh hai magic trick
        out2 = model(g2, g1) 

        # Average the predictions (Enforcing Symmetry)
        # Softmax pehle lagana zaroori hai average karne se pehle
        probs1 = F.softmax(out1, dim=1)
        probs2 = F.softmax(out2, dim=1)
        
        avg_probs = (probs1 + probs2) / 2
        
        y_true.extend(y.tolist())
        y_pred.extend(avg_probs.argmax(dim=1).cpu().tolist())

print("Final Accuracy (Symmetrized):", accuracy_score(y_true, y_pred))
print("Final Weighted F1 (Symmetrized):", f1_score(y_true, y_pred, average="weighted"))

Loading Best Model for Evaluation with TTA...
Final Accuracy (Symmetrized): 0.49384442246700544
Final Weighted F1 (Symmetrized): 0.5016178289846266


In [58]:
# =============================================================
# 9. MC DROPOUT UNCERTAINTY
# =============================================================
def enable_mc_dropout(m):
    for x in m.modules():
        if isinstance(x, nn.Dropout): x.train()

def mc_dropout_predict(model, g1, g2, T=5):
    model.eval()
    enable_mc_dropout(model) # Force dropout on
    probs = []
    with torch.no_grad():
        for _ in range(T):
            # T forward passes with different dropout masks
            probs.append(F.softmax(model(g1, g2), dim=1).unsqueeze(0))
    
    probs = torch.cat(probs, dim=0)
    mean_probs = probs.mean(dim=0)
    # Entropy calculation (Uncertainty Metric)
    entropy = -torch.sum(mean_probs * torch.log(mean_probs + 1e-9), dim=1)
    return mean_probs, entropy

In [63]:
# =============================================================
# 10.1 RELIABILITY AUDITOR
# =============================================================
import numpy as np
from xgboost import XGBClassifier
from rdkit.Chem import AllChem

def fp_batch(smiles):
    """
    Generate Morgan fingerprints for a list of SMILES.
    Output shape: (N, 2048)
    """
    fps = []
    for s in smiles:
        mol = Chem.MolFromSmiles(str(s))
        if mol is None:
            fps.append(np.zeros(2048))
        else:
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=2, nBits=2048)
            fps.append(np.array(fp))
    return np.array(fps)


# -------------------------------------------------------------
# 10.2 Prepare Auditor Features (DrugA || DrugB)
# -------------------------------------------------------------
print("Generating fingerprints for Auditor...")

X_train_fp = np.concatenate(
    [fp_batch(train_df.X1), fp_batch(train_df.X2)], axis=1
)  # shape: (N, 4096)

y_train_fp = train_df.Y.values

X_test_fp = np.concatenate(
    [fp_batch(test_df.X1), fp_batch(test_df.X2)], axis=1
)


# -------------------------------------------------------------
# 10.3 HANDLE MISSING CLASSES (CORRECT WAY: LABEL REMAPPING)
# -------------------------------------------------------------
# Auditor is trained ONLY on classes present in training data
audit_classes = sorted(np.unique(y_train_fp))
print(f"Auditor training on {len(audit_classes)} / {NUM_CLASSES} classes")

# Create forward & inverse maps
audit_map = {c: i for i, c in enumerate(audit_classes)}
inv_audit_map = {i: c for c, i in audit_map.items()}

# Remap labels for auditor training
y_train_audit = np.array([audit_map[y] for y in y_train_fp])


# -------------------------------------------------------------
# 10.4 Train XGBoost Auditor
# -------------------------------------------------------------
print("Training Reliability Auditor (XGBoost)...")

auditor = XGBClassifier(
    n_estimators=100,
    max_depth=6,
    learning_rate=0.1,
    objective="multi:softprob",
    num_class=len(audit_classes),
    n_jobs=-1,
    tree_method="hist",
    random_state=42
)

auditor.fit(X_train_fp, y_train_audit)

print("Auditor trained successfully (clean & valid).")


# -------------------------------------------------------------
# 10.5 Auditor Prediction Helper (for Reliability Loop)
# -------------------------------------------------------------
def auditor_predict(fp_batch_slice):
    """
    Predict original DDI class IDs using the trained auditor.
    """
    raw_preds = auditor.predict(fp_batch_slice)
    return np.array([inv_audit_map[p] for p in raw_preds])

Generating fingerprints for Auditor...
Auditor training on 74 / 76 classes
Training Reliability Auditor (XGBoost)...
Auditor trained successfully (clean & valid).


In [65]:
# =============================================================
# 11. RELIABILITY-AWARE AUDITING LOOP 
# =============================================================

y_true_a, y_pred_a, reliab = [], [], []
ptr = 0

# Normalized class frequency from TRAIN set (risk prior)
class_freq = train_df.Y.value_counts(normalize=True)

print("Running Reliability Audit...")

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Auditing"):
        if batch is None:
            continue

        g1, g2, y = batch
        g1, g2 = g1.to(DEVICE), g2.to(DEVICE)

        # -----------------------------------------------------
        # 1. GNN Prediction with MC Dropout (Uncertainty-aware)
        # -----------------------------------------------------
        probs, entropy = mc_dropout_predict(model, g1, g2)
        preds = probs.argmax(dim=1).cpu().numpy()

        # -----------------------------------------------------
        # 2. Auditor Prediction (USING REMAPPED HELPER)
        # -----------------------------------------------------
        bs = len(y)

        # Important: test_loader is NOT shuffled
        cnt_preds = auditor_predict(X_test_fp[ptr:ptr + bs])

        # -----------------------------------------------------
        # 3. Reliability Score Computation
        # -----------------------------------------------------
        for i in range(bs):
            # Base confidence from GNN
            s = probs[i, preds[i]].item()

            # (a) Auditor disagreement penalty
            if preds[i] != cnt_preds[i]:
                s *= 0.6

            # (b) Rare-class penalty (clinical risk prior)
            freq = class_freq.get(preds[i], 0.0)
            s *= np.exp(-0.5 * (1.0 - freq))

            # (c) Uncertainty penalty (entropy)
            s *= np.exp(-entropy[i].item())

            reliab.append(s)

        # -----------------------------------------------------
        # 4. Store outputs
        # -----------------------------------------------------
        y_true_a.extend(y.tolist())
        y_pred_a.extend(preds.tolist())

        ptr += bs


Running Reliability Audit...


Auditing: 100%|██████████| 917/917 [01:06<00:00, 13.87it/s]


In [66]:
# =============================================================
# 12. COVERAGE vs ACCURACY (RELIABILITY-BASED REJECTION CURVE)
# =============================================================

import numpy as np
from sklearn.metrics import accuracy_score, f1_score

# Convert lists to arrays once (efficiency + clarity)
y_true_arr = np.array(y_true_a)
y_pred_arr = np.array(y_pred_a)
reliab_arr = np.array(reliab)

# Sort samples by reliability (ascending: least reliable first)
order = np.argsort(reliab_arr)

# Rejection rates: 0% → 50%
rates = np.linspace(0.0, 0.5, 11)

print("\nReject% | Coverage% | Accuracy | Weighted F1")
print("-" * 50)

for r in rates:
    # Number of samples to reject
    k = int(len(order) * r)

    # Keep the most reliable samples
    keep_idx = order[k:]

    # Safety stop: too few samples left
    if len(keep_idx) < 200:
        break

    acc = accuracy_score(
        y_true_arr[keep_idx],
        y_pred_arr[keep_idx]
    )

    f1 = f1_score(
        y_true_arr[keep_idx],
        y_pred_arr[keep_idx],
        average="weighted"
    )

    coverage = len(keep_idx) / len(order) * 100.0

    print(
        f"{int(r*100):>3}%    | "
        f"{coverage:6.1f}%    | "
        f"{acc:.4f}   | "
        f"{f1:.4f}"
    )



Reject% | Coverage% | Accuracy | Weighted F1
--------------------------------------------------
  0%    |  100.0%    | 0.5732   | 0.5748
  5%    |   95.0%    | 0.5962   | 0.5978
 10%    |   90.0%    | 0.6179   | 0.6195
 15%    |   85.0%    | 0.6398   | 0.6408
 20%    |   80.0%    | 0.6610   | 0.6616
 25%    |   75.0%    | 0.6793   | 0.6792
 30%    |   70.0%    | 0.6979   | 0.6975
 35%    |   65.0%    | 0.7165   | 0.7157
 40%    |   60.0%    | 0.7352   | 0.7338
 45%    |   55.0%    | 0.7505   | 0.7486
 50%    |   50.0%    | 0.7670   | 0.7648


In [70]:
# =============================================================
# 13. SAVE EVERYTHING (FINAL FIX - KEY SANITIZATION)
# =============================================================

import json
import joblib
import torch
import numpy as np

# -------------------------------------------------------------
# 1. Save Models
# -------------------------------------------------------------
torch.save(model.state_dict(), "best_ddi_model.pt")
joblib.dump(auditor, "auditor_xgboost.pkl")
print("✅ Models Saved.")

# -------------------------------------------------------------
# 2. Prepare Metadata with SANITIZED KEYS
# -------------------------------------------------------------


def clean_dict(d):
    """Converts keys to str and values to python types"""
    new_d = {}
    for k, v in d.items():
        # Key fix: Force string
        k_clean = str(k) 
        
        # Value fix: Handle numpy types
        if isinstance(v, (np.integer, np.int64)):
            v_clean = int(v)
        elif isinstance(v, (np.floating, np.float64, np.float32)):
            v_clean = float(v)
        elif isinstance(v, np.ndarray):
            v_clean = v.tolist()
        else:
            v_clean = v
            
        new_d[k_clean] = v_clean
    return new_d

# Manually clean each dictionary
metadata = {
    "label_map": clean_dict(label_map),
    "class_freq": clean_dict(class_freq.to_dict()),
    "audit_classes": [int(x) for x in audit_classes], 
    "audit_map": clean_dict(audit_map),
    "inv_audit_map": clean_dict(inv_audit_map)
}

# -------------------------------------------------------------
# 3. Save JSON 
# -------------------------------------------------------------
with open("ddi_metadata.json", "w") as f:
    json.dump(metadata, f, indent=2)

print("Metadata Saved Successfully .")


# -------------------------------------------------------------
# 4. Final Confirmation & Download
# -------------------------------------------------------------
print("\nALL FILES SAVED:")
print("- best_ddi_model.pt")
print("- auditor_xgboost.pkl")
print("- ddi_metadata.json")

from IPython.display import FileLink
display(FileLink("best_ddi_model.pt"))
display(FileLink("auditor_xgboost.pkl"))
display(FileLink("ddi_metadata.json"))

✅ Models Saved.
Metadata Saved Successfully .

ALL FILES SAVED:
- best_ddi_model.pt
- auditor_xgboost.pkl
- ddi_metadata.json


In [71]:
# =============================================================
# 14. AUTO DOWNLOAD TRAINED ASSETS
# =============================================================

import os
from IPython.display import FileLink, display

print("Download trained assets:")

files_to_download = [
    "best_ddi_model.pt",
    "auditor_xgboost.pkl",
    "ddi_metadata.json"
]

for file in files_to_download:
    if os.path.exists(file):
        display(FileLink(file))
    else:
        print(f" File not found: {file}")



Download trained assets:


In [72]:
import os, json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import (
    accuracy_score,
    f1_score,
    classification_report,
    confusion_matrix
)
from IPython.display import FileLink, display

os.makedirs("exports", exist_ok=True)

# =============================================================
# 15.1 SAVE FINAL METRICS REPORT
# =============================================================
final_metrics = {
    "accuracy": accuracy_score(y_true_a, y_pred_a),
    "weighted_f1": f1_score(y_true_a, y_pred_a, average="weighted"),
    "total_samples": len(y_true_a)
}

with open("exports/final_metrics.json", "w") as f:
    json.dump(final_metrics, f, indent=2)

# =============================================================
# 15.2 SAVE CLASSIFICATION REPORT
# =============================================================
report = classification_report(
    y_true_a, y_pred_a, output_dict=True
)
pd.DataFrame(report).transpose().to_csv(
    "exports/classification_report.csv"
)

# =============================================================
# 15.3 CONFUSION MATRIX (PNG)
# =============================================================
cm = confusion_matrix(y_true_a, y_pred_a)

plt.figure(figsize=(10, 8))
plt.imshow(cm, cmap="Blues")
plt.title("Confusion Matrix (DDI)")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.colorbar()
plt.tight_layout()
plt.savefig("exports/confusion_matrix.png", dpi=300)
plt.close()

# =============================================================
# 15.4 COVERAGE vs ACCURACY TABLE (CSV)
# =============================================================
coverage_rows = []

y_true_arr = np.array(y_true_a)
y_pred_arr = np.array(y_pred_a)
reliab_arr = np.array(reliab)
order = np.argsort(reliab_arr)

rates = np.linspace(0, 0.5, 11)

for r in rates:
    k = int(len(order) * r)
    keep = order[k:]
    if len(keep) < 200:
        break

    coverage_rows.append({
        "reject_percent": int(r * 100),
        "coverage_percent": len(keep) / len(order) * 100,
        "accuracy": accuracy_score(y_true_arr[keep], y_pred_arr[keep]),
        "weighted_f1": f1_score(y_true_arr[keep], y_pred_arr[keep], average="weighted")
    })

pd.DataFrame(coverage_rows).to_csv(
    "exports/coverage_vs_accuracy.csv", index=False
)

# =============================================================
# 15.5 COVERAGE vs F1 GRAPH (PNG)
# =============================================================
plt.figure(figsize=(7, 5))
plt.plot(
    [r["coverage_percent"] for r in coverage_rows],
    [r["weighted_f1"] for r in coverage_rows],
    marker="o"
)
plt.xlabel("Coverage (%)")
plt.ylabel("Weighted F1")
plt.title("Coverage vs Weighted F1")
plt.grid(True)
plt.tight_layout()
plt.savefig("exports/coverage_vs_f1.png", dpi=300)
plt.close()

# =============================================================
# 15.6 RELIABILITY SCORE DISTRIBUTION
# =============================================================
plt.figure(figsize=(7, 5))
plt.hist(reliab_arr, bins=50)
plt.xlabel("Reliability Score")
plt.ylabel("Count")
plt.title("Reliability Score Distribution")
plt.tight_layout()
plt.savefig("exports/reliability_distribution.png", dpi=300)
plt.close()

# =============================================================
# 15.7 SAVE RAW PREDICTIONS (CSV)
# =============================================================
pd.DataFrame({
    "y_true": y_true_a,
    "y_pred": y_pred_a,
    "reliability": reliab
}).to_csv("exports/raw_predictions.csv", index=False)

# =============================================================
# 15.8 LIST & DOWNLOAD ALL FILES
# =============================================================
print("\n DOWNLOAD ALL EXPORTS:")

all_files = [
    "best_ddi_model.pt",
    "auditor_xgboost.pkl",
    "ddi_metadata.json",
    "exports/final_metrics.json",
    "exports/classification_report.csv",
    "exports/confusion_matrix.png",
    "exports/coverage_vs_accuracy.csv",
    "exports/coverage_vs_f1.png",
    "exports/reliability_distribution.png",
    "exports/raw_predictions.csv"
]

for f in all_files:
    if os.path.exists(f):
        display(FileLink(f))
    else:
        print(f" Missing: {f}")



 DOWNLOAD ALL EXPORTS:


In [74]:
# =============================================================
# FINAL COMPREHENSIVE EVALUATION REPORT
# =============================================================
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import (
    accuracy_score, f1_score, precision_score, recall_score,
    top_k_accuracy_score, classification_report
)
from tqdm import tqdm

print("GENERATING FINAL MEDICAL-GRADE REPORT...\n")

# 1. SETUP
# -------------------------------------------------------------
model.eval()
y_true = []
y_pred_hard = []  # Hard class prediction (Argmax)
y_probs_all = []  # Full probability vectors
y_reliab = []     # Reliability scores

# Ensure Auditor Data exists (Safety Check)
try:
    if 'X_test_fp' not in locals(): raise NameError
except:
    print("Auditor features missing. Generating on fly (might take 2 mins)...")
    X_test_fp = np.concatenate([fp_batch(test_df.X1), fp_batch(test_df.X2)], axis=1)

ptr = 0
class_freq = train_df.Y.value_counts(normalize=True)

# 2. EVALUATION LOOP (With Symmetry & Reliability)
# -------------------------------------------------------------
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        if batch is None: continue
        g1, g2, y = batch
        g1, g2 = g1.to(DEVICE), g2.to(DEVICE)

        # A. Symmetric Prediction (TTA)
        out1 = model(g1, g2)
        out2 = model(g2, g1)
        
        # Softmax & Average
        p1 = F.softmax(out1, dim=1)
        p2 = F.softmax(out2, dim=1)
        avg_probs = (p1 + p2) / 2
        
        # B. Get Predictions
        preds = avg_probs.argmax(dim=1).cpu().numpy()
        probs_np = avg_probs.cpu().numpy()
        
        # C. Auditor Check (Reliability)
        bs = len(y)
        # Handle Auditor Label Mapping
        raw_audit_preds = auditor.predict(X_test_fp[ptr:ptr+bs])
        audit_preds = [inv_audit_map.get(p, -1) for p in raw_audit_preds]
        
        # Calculate Reliability Score per Sample
        batch_reliab = []
        for i in range(bs):
            conf = probs_np[i, preds[i]]
            # 1. Disagreement Penalty
            if preds[i] != audit_preds[i]: conf *= 0.6
            # 2. Rare Class Penalty
            conf *= np.exp(-0.5 * (1 - class_freq.get(preds[i], 0)))
            batch_reliab.append(conf)
            
        # Store Data
        y_true.extend(y.tolist())
        y_pred_hard.extend(preds.tolist())
        y_probs_all.extend(probs_np)
        y_reliab.extend(batch_reliab)
        ptr += bs

# Convert to Numpy
y_true = np.array(y_true)
y_pred_hard = np.array(y_pred_hard)
y_probs_all = np.array(y_probs_all)
y_reliab = np.array(y_reliab)

print("\n" + "="*60)
print("FINAL DDI MODEL REPORT CARD")
print("="*60)

# 3. BASE METRICS
# -------------------------------------------------------------
acc = accuracy_score(y_true, y_pred_hard)
f1_w = f1_score(y_true, y_pred_hard, average='weighted')
f1_m = f1_score(y_true, y_pred_hard, average='macro')
prec = precision_score(y_true, y_pred_hard, average='weighted')
rec = recall_score(y_true, y_pred_hard, average='weighted')

print(f"\n1. BASE PERFORMANCE (All Samples):")
print(f"   - Accuracy:          {acc:.4f}  (Base capability)")
print(f"   - Weighted F1:       {f1_w:.4f}  (Real-world utility)")
print(f"   - Macro F1:          {f1_m:.4f}  (Rare class handling)")
print(f"   - Precision:         {prec:.4f}")
print(f"   - Recall:            {rec:.4f}")

# 4. TOP-K METRICS (Medical Context)
# -------------------------------------------------------------
top3 = top_k_accuracy_score(y_true, y_probs_all, k=3, labels=np.arange(NUM_CLASSES))
top5 = top_k_accuracy_score(y_true, y_probs_all, k=5, labels=np.arange(NUM_CLASSES))

print(f"\n2. DOCTOR ASSISTANT METRICS:")
print(f"   - Top-3 Accuracy:    {top3:.4f}  (Correct answer in top 3 suggestions)")
print(f"   - Top-5 Accuracy:    {top5:.4f}  (Correct answer in top 5 suggestions)")

# 5. RISK-COVERAGE CURVE (The Auditor's Job)
# -------------------------------------------------------------
print(f"\n3. RELIABILITY AUDIT (Rejection Curve):")
print("   Reject% | Coverage | Accuracy | Risk Reduction")
print("   " + "-"*45)

sorted_indices = np.argsort(y_reliab) # Low reliability first
thresholds = [0, 0.1, 0.2, 0.3, 0.4, 0.5]

for r in thresholds:
    # Keep top (1-r)% reliability
    if r == 0:
        keep_mask = np.ones_like(y_true, dtype=bool)
    else:
        cutoff = int(len(y_true) * r)
        keep_indices = sorted_indices[cutoff:] # Remove bottom r%
        keep_mask = np.zeros_like(y_true, dtype=bool)
        keep_mask[keep_indices] = True
    
    subset_acc = accuracy_score(y_true[keep_mask], y_pred_hard[keep_mask])
    print(f"   {int(r*100):>3}%    | {int((1-r)*100):>3}%    | {subset_acc:.4f}   | +{(subset_acc-acc)*100:.1f}%")

# 6. CLASS WISE PERFORMANCE (Short Summary)
# -------------------------------------------------------------
print(f"\n4. CLASS DIAGNOSTICS:")
report = classification_report(y_true, y_pred_hard, output_dict=True)
df_rep = pd.DataFrame(report).transpose()
# Filter out 'accuracy', 'macro avg', etc.
df_classes = df_rep.iloc[:-3].sort_values(by='f1-score', ascending=False)

best_3 = df_classes.head(3).index.tolist()
worst_3 = df_classes[df_classes['support'] > 50].tail(3).index.tolist() # Min support to be fair

print(f"   - Best Classes (IDs):  {best_3} (F1 ~ {df_classes.head(3)['f1-score'].mean():.2f})")
print(f"   - Hardest Classes:     {worst_3} (F1 ~ {df_classes.tail(3)['f1-score'].mean():.2f})")

print("\n" + "="*60)
print("END OF REPORT")
print("="*60)

GENERATING FINAL MEDICAL-GRADE REPORT...



Evaluating: 100%|██████████| 917/917 [00:49<00:00, 18.49it/s]



FINAL DDI MODEL REPORT CARD

1. BASE PERFORMANCE (All Samples):
   - Accuracy:          0.4938  (Base capability)
   - Weighted F1:       0.5016  (Real-world utility)
   - Macro F1:          0.3553  (Rare class handling)
   - Precision:         0.5736
   - Recall:            0.4938

2. DOCTOR ASSISTANT METRICS:
   - Top-3 Accuracy:    0.8426  (Correct answer in top 3 suggestions)
   - Top-5 Accuracy:    0.9245  (Correct answer in top 5 suggestions)

3. RELIABILITY AUDIT (Rejection Curve):
   Reject% | Coverage | Accuracy | Risk Reduction
   ---------------------------------------------
     0%    | 100%    | 0.4938   | +0.0%
    10%    |  90%    | 0.5356   | +4.2%
    20%    |  80%    | 0.5835   | +9.0%
    30%    |  70%    | 0.6397   | +14.6%
    40%    |  60%    | 0.7080   | +21.4%
    50%    |  50%    | 0.7758   | +28.2%

4. CLASS DIAGNOSTICS:
   - Best Classes (IDs):  ['20', '63', '49'] (F1 ~ 0.98)
   - Hardest Classes:     ['55', '28', '2'] (F1 ~ 0.00)

END OF REPORT
