In [34]:
import os
import gc
import json
import time
from datetime import timedelta, datetime
from collections import Counter

import numpy as np
import pandas as pd

import torch
from torch import amp
import torch.nn as nn
import torch.nn.functional as F

from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader  
from torch_geometric.nn import GATv2Conv
from torch_geometric.loader import LinkNeighborLoader
from torch_geometric.data import Data

from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator, Descriptors

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    accuracy_score, f1_score, roc_auc_score,
    classification_report, confusion_matrix
)

import matplotlib.pyplot as plt
import seaborn as sns

In [35]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"● DEVICE: {DEVICE}")
# Memory optimization settings for GPU, Aggressive memory management
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

    # Force garbage collection
    gc.collect()
    print(f"- GPU: {torch.cuda.get_device_name(0)}")
    print(f"- GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print("- ✓ Memory optimization enabled")

# Force Garbage Collection
print("● Cleaning GPU Memory...")
# Delete potential ghost variables if they exist
try:
    del model, optimizer, scaler, z, loss, out
except NameError:
    pass
gc.collect()
#  Clear CUDA Cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    # Set Allocator Strategy to reduce fragmentation
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    print(f"- GPU Memory Freed. Allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
    print(f"- Reserved: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
else:
    print("-Running on CPU (No CUDA clear needed)")

● DEVICE: cuda
- GPU: NVIDIA GeForce GTX 1650
- GPU Memory: 4.00 GB
- ✓ Memory optimization enabled
● Cleaning GPU Memory...
- GPU Memory Freed. Allocated: 0.00 MB
- Reserved: 0.00 MB


In [36]:
# Notebook / plotting style
sns.set(style="whitegrid")

# Device and basic GPU hygiene
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    gc.collect()

# Timestamp for saving artifacts
TIMESTAMP = datetime.now().strftime("%d_%b_%H-%M")

# Ensure output folders exist
for folder in ["images", "models"]:
  os.makedirs(folder, exist_ok=True)

Using device: cuda


In [37]:
# Load DDI data

ddi_data = 'dataset/drugdata/ddis.csv'
drug_smiles_data = 'dataset/drugdata/drug_smiles.csv'
if ddi_data and drug_smiles_data:
    print("✓ Data loaded successfully")
def load_raw_data(ddi_path: str = ddi_data, smiles_path: str = drug_smiles_data):
    """Load raw DDI and SMILES tables from disk."""
    ddi_df = pd.read_csv(ddi_path)
    smiles_df = pd.read_csv(smiles_path)

    print("● DDI shape:", ddi_df.shape)
    print("● SMILES shape:", smiles_df.shape)
    print("● DDI columns:", ddi_df.columns.tolist())
    print("● SMILES columns:", smiles_df.columns.tolist())

    return ddi_df, smiles_df


def quick_sanity_report(ddi_df: pd.DataFrame, smiles_df: pd.DataFrame):
    """Print a compact sanity report (non‑destructive)."""
    print("\nⓘ QUICK SANITY REPORT:\n")

    # DDI
    print("[1. DDI Data]")
    print("● rows:", len(ddi_df))
    print("● nulls per column:", ddi_df.isnull().sum().to_dict())
    if {"d1", "d2", "type"}.issubset(ddi_df.columns):
        print("● unique types:", ddi_df["type"].nunique())
        print("● type distribution (top 10):")
        print(ddi_df["type"].value_counts().head(10))

    # SMILES
    print("\n[2. SMILES Data]")
    print("● rows:", len(smiles_df))
    print("● nulls per column:", smiles_df.isnull().sum().to_dict())

    # Overlap
    ddi_drugs = set(ddi_df["d1"]).union(set(ddi_df["d2"]))
    smiles_drugs = set(smiles_df["drug_id"]) if "drug_id" in smiles_df.columns else set()
    overlap = ddi_drugs & smiles_drugs
    print(f"\nOverlap: {len(overlap)}/{len(ddi_drugs)} DDI drugs have SMILES ({100*len(overlap)/len(ddi_drugs):.2f}%)")


# Load once (reused later)
ddi_df, smiles_df = load_raw_data()
quick_sanity_report(ddi_df, smiles_df)

✓ Data loaded successfully
● DDI shape: (191808, 4)
● SMILES shape: (1706, 2)
● DDI columns: ['d1', 'd2', 'type', 'Neg samples']
● SMILES columns: ['drug_id', 'smiles']

ⓘ QUICK SANITY REPORT:

[1. DDI Data]
● rows: 191808
● nulls per column: {'d1': 0, 'd2': 0, 'type': 0, 'Neg samples': 0}
● unique types: 86
● type distribution (top 10):
type
48    60751
46    34360
72    23779
74     9470
59     8397
69     7786
19     6140
15     5413
3      5011
5      3160
Name: count, dtype: int64

[2. SMILES Data]
● rows: 1706
● nulls per column: {'drug_id': 0, 'smiles': 0}

Overlap: 1706/1706 DDI drugs have SMILES (100.00%)


In [38]:
# 3. RDKit‑based drug feature extraction

class DrugFeatureExtractor:
    """SMILES → numeric feature vector using RDKit
    Features:
    - 1024 bit Morgan fingerprint (radius=2)
    - A small set of physicochemical descriptors
    """

    def __init__(self, fp_size: int = 1024, radius: int = 2):
        self.fp_size = fp_size
        self.radius = radius
        self._mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fp_size)

    def smiles_to_features(self, smiles: str):
        mol = Chem.MolFromSmiles(str(smiles))
        if mol is None:
            return None

        # Fingerprint
        fp = self._mfpgen.GetFingerprint(mol)
        fp_array = np.array(fp, dtype=np.float32)

        # Simple descriptors
        desc = np.array([
            Descriptors.MolWt(mol),
            Descriptors.MolLogP(mol),
            Descriptors.NumHDonors(mol),
            Descriptors.NumHAcceptors(mol),
            Descriptors.TPSA(mol),
            Descriptors.NumRotatableBonds(mol),
            Descriptors.NumAromaticRings(mol),
            Descriptors.FractionCSP3(mol),
        ], dtype=np.float32)

        return np.concatenate([fp_array, desc])

    def build_feature_table(self, smiles_df: pd.DataFrame):
        """Return: dict {drug_id -> feature_vector}, plus the final feature dim."""
        features = {}
        failed = 0

        print("\nExtracting RDKit features from SMILES…")
        for i, row in smiles_df.iterrows():
            drug_id = row["drug_id"]
            feats = self.smiles_to_features(row["smiles"])
            if feats is None:
                failed += 1
                continue
            features[drug_id] = feats
            if (i + 1) % 200 == 0:
                print(f"● {i+1}/{len(smiles_df)} processed")

        print(f"✓ Features for {len(features)} drugs ({failed} failures)")
        feat_dim = len(next(iter(features.values())))
        return features, feat_dim


feature_extractor = DrugFeatureExtractor()
drug_features, feature_dim = feature_extractor.build_feature_table(smiles_df)
print("\nFeature dim:", feature_dim)


Extracting RDKit features from SMILES…
● 200/1706 processed
● 400/1706 processed
● 600/1706 processed
● 800/1706 processed
● 1000/1706 processed
● 1200/1706 processed
● 1400/1706 processed
● 1600/1706 processed
✓ Features for 1706 drugs (0 failures)

Feature dim: 1032


In [39]:
# 4. Graph construction with positive & negative edges
class DDIGraph:
    """Utility to construct a PyG style graph for DDI.

    - Nodes: drugs for which we have RDKit features.
    - Positive edges: from `ddis.csv` (observed interactions).
    - Negative edges: sampled using the `Neg samples` column when available;
      we treat these as non interacting pairs.

    Multi task targets per edge:
    - `y_binary`: 0/1 for interaction existence.
    - `y_type`: encoded interaction type for positives, -1 for negatives.
    """

    def __init__(self, ddi_df: pd.DataFrame, drug_features: dict):
        self.ddi_df = ddi_df
        self.drug_features = drug_features
        self.drug_to_idx = {}
        self.idx_to_drug = {}
        self.type_encoder = LabelEncoder()

    def _index_drugs(self):
        unique_drugs = sorted(self.drug_features.keys())
        self.drug_to_idx = {d: i for i, d in enumerate(unique_drugs)}
        self.idx_to_drug = {i: d for d, i in self.drug_to_idx.items()}
        return unique_drugs

    def build(self):
        unique_drugs = self._index_drugs()
        n_nodes = len(unique_drugs)

        # Node feature matrix
        feat_dim = len(next(iter(self.drug_features.values())))
        x = np.zeros((n_nodes, feat_dim), dtype=np.float32)
        for d, idx in self.drug_to_idx.items():
            x[idx] = self.drug_features[d]

        # Standardize features channel‑wise
        x_mean = x.mean(axis=0, keepdims=True)
        x_std = x.std(axis=0, keepdims=True) + 1e-8
        x = (x - x_mean) / x_std

        pos_edges = []
        pos_types = []
        neg_edges = []

        # Positive edges from DDI table
        for _, row in self.ddi_df.iterrows():
            d1, d2 = row["d1"], row["d2"]
            if d1 not in self.drug_to_idx or d2 not in self.drug_to_idx:
                continue
            u, v = self.drug_to_idx[d1], self.drug_to_idx[d2]
            # bidirectional
            pos_edges.extend([(u, v), (v, u)])
            pos_types.extend([row["type"], row["type"]])

            # Negatives from `Neg samples` if available
            if "Neg samples" in row and isinstance(row["Neg samples"], str):
                neg_tokens = [t for t in row["Neg samples"].split("$t") if t]
                for neg_d in neg_tokens:
                    if neg_d in self.drug_to_idx:
                        w = self.drug_to_idx[neg_d]
                        neg_edges.extend([(w, v), (v, w)])

        print(f"● Positive directed edges: {len(pos_edges)}")
        print(f"● Negative directed edges (from dataset hints): {len(neg_edges)}")

        # Encode interaction types (for positives only)
        pos_types_encoded = self.type_encoder.fit_transform(np.array(pos_types))
        n_types = len(self.type_encoder.classes_)
        print(f"● interaction types: {n_types}")

        # Build unified edge list and labels
        edge_src = []
        edge_dst = []
        y_binary = []
        y_type = []

        # positives
        for (u, v), t_enc in zip(pos_edges, pos_types_encoded):
            edge_src.append(u)
            edge_dst.append(v)
            y_binary.append(1)
            y_type.append(t_enc)

        # negatives (type unknown / not applicable)
        for (u, v) in neg_edges:
            edge_src.append(u)
            edge_dst.append(v)
            y_binary.append(0)
            y_type.append(-1)  # ignored in type loss

        edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)
        y_binary = torch.tensor(y_binary, dtype=torch.long)
        y_type = torch.tensor(y_type, dtype=torch.long)

        x_tensor = torch.from_numpy(x).float()

        print(f"● Total edges: {edge_index.shape[1]} (pos+neg)")

        return {
            "x": x_tensor,
            "edge_index": edge_index,
            "y_binary": y_binary,
            "y_type": y_type,
            "n_nodes": n_nodes,
            "n_types": n_types,
            "type_encoder": self.type_encoder,
            "drug_to_idx": self.drug_to_idx,
            "idx_to_drug": self.idx_to_drug,
        }


graph_builder = DDIGraph(ddi_df, drug_features)
graph = graph_builder.build()

● Positive directed edges: 383616
● Negative directed edges (from dataset hints): 197610
● interaction types: 86
● Total edges: 581226 (pos+neg)


In [40]:
# 5. Train/val/test split at edge level
def split_edges(graph, val_size=0.15, test_size=0.15, random_state=42):
    n_edges = graph["edge_index"].shape[1]
    all_idx = np.arange(n_edges)

    train_idx, temp_idx = train_test_split(
        all_idx, test_size=val_size + test_size, random_state=random_state, shuffle=True
    )
    val_rel = val_size / (val_size + test_size)
    val_idx, test_idx = train_test_split(
        temp_idx, test_size=1 - val_rel, random_state=random_state, shuffle=True
    )

    masks = {}
    for name, idx in zip(["train", "val", "test"], [train_idx, val_idx, test_idx]):
        m = torch.zeros(n_edges, dtype=torch.bool)
        m[idx] = True
        masks[f"{name}_mask"] = m
        print(f"● {name.capitalize()} edges: {m.sum().item()} ({100*m.float().mean():.1f}% of total)")

    return masks


masks = split_edges(graph)
graph.update(masks)

● Train edges: 406858 (70.0% of total)
● Val edges: 87184 (15.0% of total)
● Test edges: 87184 (15.0% of total)


In [41]:
class GATv2DDI(nn.Module):
    """GATv2 for edge level DDI prediction with mutiple heads:
    - binary (interaction yes/no)
    - type (multi class), trained only on positive edges.

    This version factorizes into an encoder (node level) and a decoder
    (edge level) so we can cache node embeddings once per epoch during
    training and reuse them for many edge mini batches.
    """

    def __init__(self, in_dim, hidden_dim, n_heads, n_types, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        self.gat1 = GATv2Conv(in_dim, hidden_dim, heads=n_heads, dropout=dropout, concat=True)
        self.gat2 = GATv2Conv(hidden_dim * n_heads, hidden_dim, heads=n_heads, dropout=dropout, concat=True)
        self.gat3 = GATv2Conv(hidden_dim * n_heads, hidden_dim, heads=1, dropout=dropout, concat=False)

        node_emb_dim = hidden_dim  # output of gat3

        # Edge MLP shared trunk
        self.edge_mlp = nn.Sequential(
            nn.Linear(node_emb_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Two task‑specific heads
        self.binary_head = nn.Linear(hidden_dim // 2, 2)
        self.type_head = nn.Linear(hidden_dim // 2, n_types)

    def encode(self, x, edge_index):
        """Compute node embeddings for the whole graph."""
        x = x.float()
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.gat3(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        return x

    def edge_forward(self, node_emb, edge_index_supervised):
        """Edge level prediction from pre-computed node embeddings."""
        src, dst = edge_index_supervised
        h_src, h_dst = node_emb[src], node_emb[dst]
        h_edge = torch.cat([h_src, h_dst], dim=-1)
        h_shared = self.edge_mlp(h_edge)
        logits_bin = self.binary_head(h_shared)
        logits_type = self.type_head(h_shared)
        return logits_bin, logits_type

    def forward(self, x, edge_index, edge_index_supervised):
        """Standard full graph forward (kept for evaluation/inference)."""
        node_emb = self.encode(x, edge_index)
        return self.edge_forward(node_emb, edge_index_supervised)


In [42]:
# 7. Edge dataset and training utilities

class EdgeDataset(Dataset):
    def __init__(self, graph, mask_name: str):
        mask = graph[mask_name]
        self.edge_index = graph["edge_index"][:, mask]
        self.y_binary = graph["y_binary"][mask]
        self.y_type = graph["y_type"][mask]

    def __len__(self):
        return self.edge_index.shape[1]

    def __getitem__(self, idx):
        return (
            self.edge_index[:, idx],
            self.y_binary[idx],
            self.y_type[idx],
        )

def compute_binary_metrics(y_true, y_prob):
    y_pred = (y_prob[:, 1] >= 0.5).astype(int)
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred)
    try:
        auc = roc_auc_score(y_true, y_prob[:, 1])
    except ValueError:
        auc = float("nan")
    return {"acc": acc, "f1": f1, "auc": auc}


def compute_type_accuracy(y_type_true, y_type_pred):
    mask = y_type_true >= 0
    if mask.sum() == 0:
        return 0.0
    return accuracy_score(y_type_true[mask], y_type_pred[mask])

In [43]:
# 6. GATv2 multi‑task model (binary + type)

class GATv2DDI(nn.Module):
    """GATv2 for edge‑level DDI prediction with two heads:
    - binary (interaction yes/no)
    - type (multi‑class), trained only on positive edges
    """

    def __init__(self, in_dim, hidden_dim, n_heads, n_types, dropout=0.3):
        super().__init__()
        self.dropout = dropout

        self.gat1 = GATv2Conv(in_dim, hidden_dim, heads=n_heads, dropout=dropout, concat=True)
        self.gat2 = GATv2Conv(hidden_dim * n_heads, hidden_dim, heads=n_heads, dropout=dropout, concat=True)
        self.gat3 = GATv2Conv(hidden_dim * n_heads, hidden_dim, heads=1, dropout=dropout, concat=False)

        node_emb_dim = hidden_dim  # output of gat3

        # Edge MLP shared trunk
        self.edge_mlp = nn.Sequential(
            nn.Linear(node_emb_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

        # Two task‑specific heads
        self.binary_head = nn.Linear(hidden_dim // 2, 2)
        self.type_head = nn.Linear(hidden_dim // 2, n_types)

    def forward(self, x, edge_index, edge_index_supervised):
        # Node encoder (full graph)
        x = x.float()
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat3(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        # Edge representation
        src, dst = edge_index_supervised
        h_src, h_dst = x[src], x[dst]
        h_edge = torch.cat([h_src, h_dst], dim=-1)

        h_shared = self.edge_mlp(h_edge)
        logits_bin = self.binary_head(h_shared)
        logits_type = self.type_head(h_shared)

        return logits_bin, logits_type

In [44]:
def plot_history(history, save_path=None):
    epochs = range(1, len(history["train_loss"]) + 1)
    fig, axes = plt.subplots(1, 3, figsize=(18, 4))

    # Loss
    axes[0].plot(epochs, history["train_loss"], label="Train")
    axes[0].plot(epochs, history["val_loss"], label="Val")
    axes[0].set_title("Loss")
    axes[0].legend()

    # Binary F1
    axes[1].plot(epochs, history["train_bin_f1"], label="Train F1")
    axes[1].plot(epochs, history["val_bin_f1"], label="Val F1")
    axes[1].set_title("Binary F1")
    axes[1].legend()

    # Binary AUC
    axes[2].plot(epochs, history["val_bin_auc"], label="Val AUC")
    axes[2].set_title("Validation AUC")
    axes[2].legend()

    plt.tight_layout()
    if save_path is not None:
        plt.savefig(save_path, dpi=300, bbox_inches="tight")
        print(f"History plot saved to {save_path}")
    plt.show()


In [None]:
# 8. Training loop (modular, with history tracking)
def train_model(graph, config):
    """Full training loop. Returns (model, history dict)"""
    x = graph["x"].to(device)
    full_edge_index = graph["edge_index"].to(device)

    model = GATv2DDI(
        in_dim=x.shape[1],
        hidden_dim=config["hidden_dim"],
        n_heads=config["n_heads"],
        n_types=graph["n_types"],
        dropout=config["dropout"],
    ).to(device)

    optimizer = Adam(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
    scaler = amp.GradScaler(enabled=(device.type == "cuda"))

    # Losses
    ce_binary = nn.CrossEntropyLoss()
    ce_type = nn.CrossEntropyLoss(ignore_index=-1)

    # Datasets / loaders
    train_ds = EdgeDataset(graph, "train_mask")
    val_ds = EdgeDataset(graph, "val_mask")

    train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)

    history = {
        "train_bin_acc": [], "train_bin_f1": [], "train_bin_auc": [],
        "val_bin_acc": [], "val_bin_f1": [], "val_bin_auc": [],
        "val_type_acc": [],
        "train_loss": [], "val_loss": [], "training_duration":[],
    }

    # Intializing some variables
    start_time = time.perf_counter()
    print(f"Model Training started at: {time.strftime('%H:%M:%S', time.localtime())}...\n")

    best_val_f1 = -1
    best_state = None
    patience_counter = 0

    for epoch in range(1, config["epochs"] + 1):
        model.train()
        epoch_losses = []
        all_yb_true, all_yb_prob = [], []

        for edges, yb, yt in train_loader:
            # DataLoader stacks edges as [batch, 2]; transpose to [2, batch]
            edges = edges.t().contiguous().to(device)
            yb = yb.to(device)
            yt = yt.to(device)

            optimizer.zero_grad(set_to_none=True)

            # Mixed precision on GPU; FP32 on CPU
            with amp.autocast("cuda", enabled=(device.type == "cuda")):
                logits_bin, logits_type = model(x, full_edge_index, edges)
                loss_bin = ce_binary(logits_bin, yb)
                loss_type = ce_type(logits_type, yt)
                loss = loss_bin + config["lambda_type"] * loss_type

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_losses.append(loss.item())

            probs_bin = torch.softmax(logits_bin.detach(), dim=1).cpu().numpy()
            all_yb_true.append(yb.cpu().numpy())
            all_yb_prob.append(probs_bin)

        all_yb_true = np.concatenate(all_yb_true)
        all_yb_prob = np.concatenate(all_yb_prob)
        train_metrics = compute_binary_metrics(all_yb_true, all_yb_prob)

        history["train_loss"].append(float(np.mean(epoch_losses)))
        history["train_bin_acc"].append(train_metrics["acc"])
        history["train_bin_f1"].append(train_metrics["f1"])
        history["train_bin_auc"].append(train_metrics["auc"])

        # Validation
        model.eval()
        val_losses = []
        vb_true, vb_prob = [], []
        vt_true, vt_pred = [], []

        with torch.no_grad():
            for edges, yb, yt in val_loader:
                edges = edges.t().contiguous().to(device)
                yb = yb.to(device)
                yt = yt.to(device)

                logits_bin, logits_type = model(x, full_edge_index, edges)
                loss_bin = ce_binary(logits_bin, yb)
                loss_type = ce_type(logits_type, yt)
                loss = loss_bin + config["lambda_type"] * loss_type
                val_losses.append(loss.item())

                probs_bin = torch.softmax(logits_bin, dim=1).cpu().numpy()
                vb_true.append(yb.cpu().numpy())
                vb_prob.append(probs_bin)

                vt_true.append(yt.cpu().numpy())
                vt_pred.append(torch.argmax(logits_type, dim=1).cpu().numpy())

        vb_true = np.concatenate(vb_true)
        vb_prob = np.concatenate(vb_prob)
        vb_metrics = compute_binary_metrics(vb_true, vb_prob)

        vt_true = np.concatenate(vt_true)
        vt_pred = np.concatenate(vt_pred)
        vt_acc = compute_type_accuracy(vt_true, vt_pred)

        history["val_loss"].append(float(np.mean(val_losses)))
        history["val_bin_acc"].append(vb_metrics["acc"])
        history["val_bin_f1"].append(vb_metrics["f1"])
        history["val_bin_auc"].append(vb_metrics["auc"])
        history["val_type_acc"].append(vt_acc)

        plot_history(history, save_path=f"images/DDI_GATv2_history_{TIMESTAMP}.png")

        # Early stopping on validation F1
        if vb_metrics["f1"] > best_val_f1 + 1e-4:
            best_val_f1 = vb_metrics["f1"]
            best_state = {k: v.cpu() for k, v in model.state_dict().items()}
            patience_counter = 0
            improved = "*"
        else:
            patience_counter += 1
            improved = ""

        if (epoch % 1 == 0):
            print(
                f"Epoch {epoch:03d} | "
                f"TRAIN: B_acc={train_metrics['acc']:.3f} F1={train_metrics['f1']:.3f} | "
                f"VALIDATION: B_acc={vb_metrics['acc']:.3f} F1={vb_metrics['f1']:.3f} "
                f"TypeAcc={vt_acc:.3f} ROC-AUC={vb_metrics['auc']:.3f} {improved}"
            )

        if patience_counter >= config["patience"]:
            print("Early stopping triggered.")
            total_seconds = end_time - start_time
            formatted_time = str(timedelta(seconds=int(total_seconds)))
            history["training_duration"].append(formatted_time)
            print(f"{epoch} epochs training complete in {history["training_duration"]} at {time.strftime('%H:%M:%S', time.localtime())}")
            break
    end_time = time.perf_counter()

    # 3. Calculate and show final time
    total_seconds = end_time - start_time
    formatted_time = str(timedelta(seconds=int(total_seconds))) 
    history["training_duration"].append(formatted_time)
    print(f"{config["epochs"]} epochs training complete in {history["training_duration"]} at {time.strftime('%H:%M:%S', time.localtime())}")

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, history

In [46]:
# 9. History plotting and evaluation on test set
def evaluate_on_test(graph, model):
    x = graph["x"].to(device)
    full_edge_index = graph["edge_index"].to(device)

    test_ds = EdgeDataset(graph, "test_mask")
    test_loader = DataLoader(test_ds, batch_size=4096, shuffle=False)

    ce_binary = nn.CrossEntropyLoss()
    ce_type = nn.CrossEntropyLoss(ignore_index=-1)

    model.eval()
    test_losses = []
    yb_true, yb_prob = [], []
    yt_true, yt_pred = [], []

    with torch.no_grad():
        for edges, yb, yt in test_loader:
            # [batch, 2] -> [2, batch]
            edges = edges.t().contiguous().to(device)
            yb = yb.to(device)
            yt = yt.to(device)

            logits_bin, logits_type = model(x, full_edge_index, edges)
            loss_bin = ce_binary(logits_bin, yb)
            loss_type = ce_type(logits_type, yt)
            loss = loss_bin + loss_type
            test_losses.append(loss.item())

            probs_bin = torch.softmax(logits_bin, dim=1).cpu().numpy()
            yb_true.append(yb.cpu().numpy())
            yb_prob.append(probs_bin)

            yt_true.append(yt.cpu().numpy())
            yt_pred.append(torch.argmax(logits_type, dim=1).cpu().numpy())

    yb_true = np.concatenate(yb_true)
    yb_prob = np.concatenate(yb_prob)
    bin_metrics = compute_binary_metrics(yb_true, yb_prob)

    yt_true = np.concatenate(yt_true)
    yt_pred = np.concatenate(yt_pred)
    type_acc = compute_type_accuracy(yt_true, yt_pred)

    print("\nModel Evaluation Summary:")
    print(f"● Binary: acc={bin_metrics['acc']:.3f}, f1={bin_metrics['f1']:.3f}, auc={bin_metrics['auc']:.3f}")
    print(f"● Type accuracy (positives only): {type_acc:.3f}")

    return {
        "bin_metrics": bin_metrics,
        "type_acc": type_acc,
        "y_binary_true": yb_true,
        "y_binary_prob": yb_prob,
        "y_type_true": yt_true,
        "y_type_pred": yt_pred,
    }

In [None]:
# 10. High‑level training run (config)

config = {
    "hidden_dim": 128,
    "n_heads": 4,
    "dropout": 0.3,
    "lr": 1e-3,
    "weight_decay": 5e-4,
    "batch_size": 2048,
    "epochs": 4, # currently 5 only for speed testing purpose. 
    "patience": 20,
    "lambda_type": 1.0,  # weight for type loss vs binary loss
}

model, history = train_model(graph, config)

plot_history(history, save_path=f"images/DDI_GATv2_history_{TIMESTAMP}.png")


Model Training started at: 20:14:32


In [None]:
test_results = evaluate_on_test(graph, model)

In [None]:
def predict_interaction(drug_a: str, drug_b: str, model, graph, threshold: float = 0.5):
    """Predict interaction (0/1), probability, and type for a pair of DrugBank IDs.

    Returns a dict with:
    - binary_label (0/1)
    - binary_prob (probability of interaction=1)
    - type_index (internal index or None)
    - type_original (original `type` value from DDI, if available)
    """
    model.eval()
    x = graph["x"].to(device)
    full_edge_index = graph["edge_index"].to(device)

    if drug_a not in graph["drug_to_idx"] or drug_b not in graph["drug_to_idx"]:
        raise ValueError("One or both drugs are unknown to the graph (no SMILES/features).")

    u = graph["drug_to_idx"][drug_a]
    v = graph["drug_to_idx"][drug_b]

    edge = torch.tensor([[u], [v]], dtype=torch.long, device=device)

    with torch.no_grad():
        logits_bin, logits_type = model(x, full_edge_index, edge)
        prob_bin = torch.softmax(logits_bin, dim=1).cpu().numpy()[0]
        yb_prob = float(prob_bin[1])
        yb_label = int(yb_prob >= threshold)

        # Type prediction (best guess; meaningful mainly if binary says interact)
        type_idx = int(torch.argmax(logits_type, dim=1).cpu().item())
        type_original = int(graph["type_encoder"].inverse_transform([type_idx])[0])

    return {
        "drug_a": drug_a,
        "drug_b": drug_b,
        "binary_label": yb_label,
        "binary_prob": yb_prob,
        "type_index": type_idx,
        "type_original": type_original,
    }


# Example usage (adjust drug IDs to ones that exist in your CSVs):
# result = predict_interaction("DB04571", "DB00460", model, graph)
# print(result)