GraphRX" methodology: Molecular Encoder + GNN Topology Learning
| Feature | Specification |
| :--- | :--- |
| **Problem Type** | Link Prediction |
| **Classes** | 2 (Binary: safe/interact) |
| **Loss Function** | `BCEWithLogitsLoss` |
| **Architecture** | PyTorch Geometric `GATv2Conv` |
| **Training Edges** | 97,028 positive edges + negatives |
| **Complexity** | Lower |
| **Training Time** | ~15 min for 150 epochs |

```bash
# PyTorch Geometric - LINK PREDICTION
class ModelArchitecture:
    - Uses GATv2Conv (optimized C++ backend)
    - 3 layers: 128 ‚Üí 64 ‚Üí 64
    - Dot product decoder: z[src] ¬∑ z[dst] ‚Üí probability
    - Binary output: safe (0) or interaction (1)
```

**Why it's fast:**

-  Optimized library - GATv2Conv is highly optimized
- Binary classification - only 2 classes (much simpler)
-  Smaller model - 128‚Üí64‚Üí64 vs 1032‚Üí256‚Üí256‚Üí128
-  Dot product decoder - simple operation
-  Negative sampling - doesn't process all edges

**Strengths:**

-  Fast training
-  Simple deployment
-  Good for screening (safe vs unsafe)
-  Industry-standard approach

### For EACH epoch:
1. Single forward pass through entire graph (GATv2Conv optimized)
2. Sample negative edges (fast operation)
3. Compute dot products for pos/neg edges
4. Binary loss (BCEWithLogitsLoss)
5. Single backward pass

# Per epoch: 1 pass √ó optimized ops = FAST

In [7]:
import os
import json
import gc
from datetime import datetime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, f1_score, classification_report
from sklearn.model_selection import train_test_split
import torch
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv, BatchNorm
from rdkit import Chem
from rdkit.Chem import AllChem, rdFingerprintGenerator
import torch.nn.functional as F
from safetensors.torch import save_file



### 2. System Setup

In [8]:
torch.manual_seed(42)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"DEVICE: {DEVICE}")
# Memory optimization settings for GPU, Aggressive memory management
if DEVICE.type == 'cuda':
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    
    # Force garbage collection
    gc.collect()
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print("‚úì Memory optimization enabled")

# specific style date time saving
timestamp = datetime.now().strftime("%d_%b_%H-%M")

# validating required directories:
required_directories = ['images', 'models']
for folder in required_directories:
    if not os.path.exists(folder):
        print(f"‚úò Directory `{folder}/` not found  Creating...")
        os.makedirs(folder)
    else:
        print(f"‚úì Directory `{folder}/` exists ")

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")


DEVICE: cuda
GPU: NVIDIA GeForce GTX 1650
GPU Memory: 4.00 GB
‚úì Memory optimization enabled
‚úì Directory `images/` exists 
‚úì Directory `models/` exists 


### 2.2 Custom Configurations

In [13]:
# Config matches the paper's "Safe Pairing" focus
CONFIG = {
    'MODEL_PATH': 'models/Gemini_AushadiNet_GATv2_128',
    'DATA_PATH': 'dataset/drugdata', 
    'INTERACTION_FILE': 'ddis.csv',
    'SMILES_FILE': 'drug_smiles.csv',
    'NODE_DIM': 128,      # Paper uses 128-bit embeddings
    'HIDDEN_DIM': 128,     # Hidden layers
    'OUTPUT_DIM': 64,     # Latent space
    'HEADS': 4,           # Multi-head attention (Critical for accuracy)
    'DROPOUT': 0.2,       # Regularization
    'LR': 0.005,
    'EPOCHS': 150,        # Extended training for convergence
    'DEVICE': torch.device('cuda' if torch.cuda.is_available() else 'cpu')
}

print(f"‚úì Customized Configurations Initialized on {DEVICE}")

‚úì Customized Configurations Initialized on cuda


### 3. Data Loading Preparation

In [9]:
# Load DDI data
print("EXPLORING DRUG-DRUG INTERACTION DATASET:\n")

# Load DDI interactions
ddi_df = pd.read_csv('dataset/drugdata/ddis.csv')
print(f"DDI Dataset Shape: {ddi_df.shape}")
print(f"Columns: {ddi_df.columns.tolist()}")
print(f"\nüî¨ Interaction Types Distribution:")
print(ddi_df['type'].value_counts())
print(f"\nüìã Sample DDI Data:")
print(ddi_df.head(10))

# Load drug SMILES
smiles_df = pd.read_csv('dataset/drugdata/drug_smiles.csv')
print(f"\nüíä Drug SMILES Dataset Shape: {smiles_df.shape}")
print(f"Columns: {smiles_df.columns.tolist()}")
print(f"\nüìã Sample SMILES Data:")
print(smiles_df.head(10))

# Get unique drugs
unique_drugs_ddi = set(ddi_df['d1'].unique()) | set(ddi_df['d2'].unique())
print(f"\nüìà Statistics:")
print(f"‚Ä¢ Total DDI pairs: {len(ddi_df)}")
print(f"‚Ä¢ Unique drugs in DDI: {len(unique_drugs_ddi)}")
print(f"‚Ä¢ Drugs with SMILES: {len(smiles_df)}")
print(f"‚Ä¢ Interaction type 0: {(ddi_df['type'] == 0).sum()}")
print(f"‚Ä¢ Interaction type 1: {(ddi_df['type'] == 1).sum()}")

# Check overlap
drugs_with_smiles = set(smiles_df['drug_id'].unique())
overlap = unique_drugs_ddi & drugs_with_smiles
print(f"‚Ä¢ Drugs with both DDI and SMILES: {len(overlap)}")
print(f"‚Ä¢ Coverage: {len(overlap)/len(unique_drugs_ddi)*100:.2f}%")

EXPLORING DRUG-DRUG INTERACTION DATASET:

DDI Dataset Shape: (191808, 4)
Columns: ['d1', 'd2', 'type', 'Neg samples']

üî¨ Interaction Types Distribution:
type
48    60751
46    34360
72    23779
74     9470
59     8397
      ...  
42       11
61       11
51       10
25        7
41        6
Name: count, Length: 86, dtype: int64

üìã Sample DDI Data:
        d1       d2  type Neg samples
0  DB04571  DB00460     0   DB01579$t
1  DB00855  DB00460     0   DB01178$t
2  DB09536  DB00460     0   DB06626$t
3  DB01600  DB00460     0   DB01588$t
4  DB09000  DB00460     0   DB06196$t
5  DB11630  DB00460     0   DB00744$t
6  DB00553  DB00460     0   DB06413$t
7  DB06261  DB00460     0   DB00876$t
8  DB01878  DB00460     0   DB09267$t
9  DB00140  DB00460     0   DB01204$t

üíä Drug SMILES Dataset Shape: (1706, 2)
Columns: ['drug_id', 'smiles']

üìã Sample SMILES Data:
   drug_id                                             smiles
0  DB04571                CC1=CC2=CC3=C(OC(=O)C=C3C)C(C)=C2O1
1  D

In [14]:
# --- 2. ROBUST DATA LOADER ---
class MulticlassDataLoader:
    def __init__(self, config):
        self.config = config
        self.drug_map = {} 
        self.label_encoder = LabelEncoder()
        # Initialize Fingerprint Generator once
        self.fp_gen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=config['NODE_DIM'])

    def get_molecular_features(self, smiles):
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None: 
                return np.zeros((self.config['NODE_DIM'],), dtype=np.float32)
            return self.fp_gen.GetFingerprintAsNumPy(mol).astype(np.float32)
        except:
            return np.zeros((self.config['NODE_DIM'],), dtype=np.float32)

    def load_data(self):
        print("üì• Loading datasets...")
        
        ddi_path = os.path.join(self.config['DATA_PATH'], self.config['INTERACTION_FILE'])
        smiles_path = os.path.join(self.config['DATA_PATH'], self.config['SMILES_FILE'])

        # 1. Load Interaction File (Robust Check)
        if not os.path.exists(ddi_path):
            raise FileNotFoundError(f"‚ùå File not found: {ddi_path}. Please ensure 'ddis.csv' is in the dataset folder.")
            
        # We assume ddis.csv is comma-separated based on previous context
        ddi_df = pd.read_csv(ddi_path)
        
        # Validation: Check if 'type' column exists for Multiclass
        if 'type' not in ddi_df.columns:
            # Fallback for TSV/No-Header files (Not recommended for Multiclass, but prevents crash)
            print("‚ö†Ô∏è Warning: 'type' column missing! Attempting to read as headerless TSV...")
            ddi_df = pd.read_csv(ddi_path, sep='\t', names=['d1', 'd2'])
            if 'type' not in ddi_df.columns:
                raise ValueError("‚ùå CRITICAL ERROR: The dataset provided does not have a 'type' column. "
                                 "You cannot perform Multiclass Classification without interaction types. "
                                 "Please use 'ddis.csv' instead of the TSV file.")

        smiles_df = pd.read_csv(smiles_path)
        
        # 2. Map Drug IDs -> Integer Indices
        all_drugs = set(ddi_df['d1']).union(set(ddi_df['d2'])).union(set(smiles_df['drug_id']))
        self.drug_map = {d: i for i, d in enumerate(all_drugs)}
        num_nodes = len(all_drugs)
        
        # 3. Create Node Features (X)
        print(f"‚öóÔ∏è Generating features for {num_nodes} drugs...")
        x = np.zeros((num_nodes, self.config['NODE_DIM']), dtype=np.float32)
        smiles_dict = dict(zip(smiles_df.drug_id, smiles_df.smiles))
        
        for drug_id, idx in self.drug_map.items():
            if drug_id in smiles_dict:
                x[idx] = self.get_molecular_features(smiles_dict[drug_id])
        x = torch.tensor(x, dtype=torch.float)

        # 4. Process Edges & Labels
        print("üè∑Ô∏è Encoding Interaction Types...")
        
        # Filter valid drugs
        valid_mask = ddi_df['d1'].isin(self.drug_map) & ddi_df['d2'].isin(self.drug_map)
        clean_df = ddi_df[valid_mask].copy()

        # Encode the 'type' column (e.g., 48 -> 0, 72 -> 1)
        clean_df['encoded_type'] = self.label_encoder.fit_transform(clean_df['type'])
        num_classes = len(self.label_encoder.classes_)
        
        # Build Edge Index
        src = [self.drug_map[d] for d in clean_df['d1']]
        dst = [self.drug_map[d] for d in clean_df['d2']]
        edge_index = torch.tensor([src, dst], dtype=torch.long)
        
        # Build Labels
        edge_attr = torch.tensor(clean_df['encoded_type'].values, dtype=torch.long)

        print(f"‚úÖ Data Ready: {len(clean_df)} interactions, {num_classes} unique interaction types.")
        
        data = Data(x=x, edge_index=edge_index, y=edge_attr)
        return data, self.drug_map, num_classes, self.label_encoder

# Run Loader
loader = MulticlassDataLoader(CONFIG)
data, drug_map, num_classes, label_encoder = loader.load_data()

üì• Loading datasets...
‚öóÔ∏è Generating features for 1706 drugs...
üè∑Ô∏è Encoding Interaction Types...
‚úÖ Data Ready: 191808 interactions, 86 unique interaction types.


### Model Architecture

In [15]:
class AushadhiNetMulticlass(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, num_classes, heads=4, dropout=0.2):
        super().__init__()
        
        # --- ENCODER (The Graph Brain) ---
        # Learns "Who is connected to whom" and "What are they made of"
        self.conv1 = GATv2Conv(in_dim, hidden_dim, heads=heads, dropout=dropout, concat=True)
        self.bn1 = BatchNorm(hidden_dim * heads)
        
        self.conv2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout, concat=True)
        self.bn2 = BatchNorm(hidden_dim * heads)
        
        self.conv3 = GATv2Conv(hidden_dim * heads, out_dim, heads=1, dropout=dropout, concat=False)
        self.skip = torch.nn.Linear(in_dim, out_dim)

        # --- DECODER (The Classifier Head) ---
        # Instead of dot product, we Concatenate embeddings -> MLP -> Softmax
        # Input dim is out_dim * 2 because we concat Drug A and Drug B
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(out_dim * 2, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_dim, hidden_dim // 2),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim // 2, num_classes) # Outputs logits for 86 classes
        )

    def encode(self, x, edge_index):
        # 1. Graph Processing
        identity = self.skip(x)
        x = F.elu(self.bn1(self.conv1(x, edge_index)))
        x = F.elu(self.bn2(self.conv2(x, edge_index)))
        x = self.conv3(x, edge_index)
        return x + identity

    def decode(self, z, edge_index):
        # 2. Extract Embeddings for Source and Dest nodes
        src_emb = z[edge_index[0]]
        dst_emb = z[edge_index[1]]
        
        # 3. Concatenate Features (Drug A || Drug B)
        edge_feat = torch.cat([src_emb, dst_emb], dim=1)
        
        # 4. Classify Interaction Type
        return self.classifier(edge_feat)

model = AushadhiNetMulticlass(
    in_dim=CONFIG['NODE_DIM'],
    hidden_dim=CONFIG['HIDDEN_DIM'],
    out_dim=CONFIG['OUTPUT_DIM'],
    num_classes=num_classes,
    heads=CONFIG['HEADS']
).to(CONFIG['DEVICE'])

print("üß† Architecture Upgraded: MLP Decoder for Multiclass Prediction")
print(model)

üß† Architecture Upgraded: MLP Decoder for Multiclass Prediction
AushadhiNetMulticlass(
  (conv1): GATv2Conv(128, 128, heads=4)
  (bn1): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): GATv2Conv(512, 128, heads=4)
  (bn2): BatchNorm(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): GATv2Conv(512, 64, heads=1)
  (skip): Linear(in_features=128, out_features=64, bias=True)
  (classifier): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=128, out_features=64, bias=True)
    (4): ReLU()
    (5): Linear(in_features=64, out_features=86, bias=True)
  )
)


  return t.to(


### Training Pipeline

In [16]:
# Setup Optimizer & Loss
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG['LR'], weight_decay=1e-4)
criterion = torch.nn.CrossEntropyLoss() # Multiclass Loss


In [17]:

# Split Data (Train on known interactions to learn types)
# We split the *edges* themselves
num_edges = data.edge_index.size(1)
perm = torch.randperm(num_edges)

train_size = int(0.8 * num_edges)
val_size = int(0.1 * num_edges)

train_idx = perm[:train_size]
val_idx = perm[train_size:train_size + val_size]
test_idx = perm[train_size + val_size:]

print("üöÄ Starting Multiclass Training Protocol...")
history = {'loss': [], 'acc': []}

for epoch in range(1, CONFIG['EPOCHS'] + 1):
    model.train()
    optimizer.zero_grad()
    
    # 1. Get Node Embeddings (using whole graph structure)
    z = model.encode(data.x.to(CONFIG['DEVICE']), data.edge_index.to(CONFIG['DEVICE']))
    
    # 2. Predict Classes for Training Edges
    # We only train on the edges in the training set
    train_edges = data.edge_index[:, train_idx].to(CONFIG['DEVICE'])
    train_labels = data.y[train_idx].to(CONFIG['DEVICE'])
    
    out = model.decode(z, train_edges)
    
    # 3. Compute Loss
    loss = criterion(out, train_labels)
    loss.backward()
    optimizer.step()
    
    # 4. Validation
    if epoch % 10 == 0:
        model.eval()
        with torch.no_grad():
            val_edges = data.edge_index[:, val_idx].to(CONFIG['DEVICE'])
            val_labels = data.y[val_idx].to(CONFIG['DEVICE'])
            
            val_out = model.decode(z, val_edges)
            # Get predicted class (argmax)
            preds = val_out.argmax(dim=1)
            
            acc = accuracy_score(val_labels.cpu(), preds.cpu())
            history['loss'].append(loss.item())
            history['acc'].append(acc)
            
            print(f"Epoch {epoch:03d} | Loss: {loss:.4f} | Val Accuracy: {acc:.2%}")

üöÄ Starting Multiclass Training Protocol...
Epoch 010 | Loss: 2.6559 | Val Accuracy: 34.14%
Epoch 020 | Loss: 2.5450 | Val Accuracy: 31.85%
Epoch 030 | Loss: 2.3811 | Val Accuracy: 37.46%
Epoch 040 | Loss: 2.2295 | Val Accuracy: 40.05%
Epoch 050 | Loss: 1.9998 | Val Accuracy: 44.51%
Epoch 060 | Loss: 1.8692 | Val Accuracy: 46.99%
Epoch 070 | Loss: 1.7110 | Val Accuracy: 50.65%
Epoch 080 | Loss: 1.5667 | Val Accuracy: 56.51%
Epoch 090 | Loss: 1.4931 | Val Accuracy: 56.92%
Epoch 100 | Loss: 1.4527 | Val Accuracy: 57.82%
Epoch 110 | Loss: 1.2802 | Val Accuracy: 63.66%
Epoch 120 | Loss: 1.1785 | Val Accuracy: 65.66%
Epoch 130 | Loss: 1.1167 | Val Accuracy: 67.15%
Epoch 140 | Loss: 1.0545 | Val Accuracy: 69.01%
Epoch 150 | Loss: 0.9840 | Val Accuracy: 71.45%


In [None]:
# 1. Final Evaluation
model.eval()
z = model.encode(data.x.to(CONFIG['DEVICE']), data.edge_index.to(CONFIG['DEVICE']))
test_edges = data.edge_index[:, test_idx].to(CONFIG['DEVICE'])
test_labels = data.y[test_idx].to(CONFIG['DEVICE'])

logits = model.decode(z, test_edges)
preds = logits.argmax(dim=1)

print("\nüèÜ Final Test Performance:")
print(f"Accuracy: {accuracy_score(test_labels.cpu(), preds.cpu()):.4f}")
# Macro F1 is better for imbalanced datasets (like medical data)
print(f"Macro F1 Score: {f1_score(test_labels.cpu(), preds.cpu(), average='macro'):.4f}")

# 2. Save Weights (Safetensors)
tensor_path = f"{CONFIG['MODEL_PATH']}.safetensors"
save_file(model.state_dict(), tensor_path)
print(f"üîí Weights saved: {tensor_path}")

# 3. Save Config & Class Mapping
json_config = CONFIG.copy()
json_config['DEVICE'] = str(json_config['DEVICE'])
json_config['NUM_CLASSES'] = num_classes

metadata = {
    'config': json_config,
    'drug_map': drug_map,
    'class_mapping': label_encoder.classes_.tolist() # Stores [48, 72, 46...]
}

json_path = f"{CONFIG['MODEL_PATH']}_config.json"
with open(json_path, 'w') as f:
    json.dump(metadata, f, indent=4)

print(f"üìú Metadata & Class Mappings saved: {json_path}")
print("‚úÖ AushadhiNet Multiclass Upgrade Complete.")


üèÜ Final Test Performance:
Accuracy: 0.7117
Macro F1 Score: 0.2824
üîí Weights saved: models/Gemini_AushadiNet_GATv2_128.safetensors
üìú Metadata & Class Mappings saved: models/Gemini_AushadiNet_GATv2_128_config.json
‚úÖ AushadhiNet Multiclass Upgrade Complete.


In [26]:
accuracy_score(test_labels.cpu(), preds.cpu())

0.7117088937545616