In [None]:
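# Drug-target binary classification notebook. Includes:
# - GraphSAGE encoder for the drug graph
# - Drug + protein MLP fusion
# - Binary classifier
# - Training loop (class-weighted BCE loss, ROC-AUC evaluation, early stopping)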

In [None]:
# Define the dataset
from torch_geometric.data import Dataset, DataLoader
from torch_geometric.nn import SAGEConv, global_mean_pool
import pandas as pd, torch, os
from torch.utils.data import random_split
import torch.nn as nn, torch.nn.functional as F
from sklearn.metrics import roc_auc_score
import numpy as np
from collections import Counter

class DrugTargetDataset(Dataset):
    """Dataset of ``(drug graph, protein vector, label)`` triples.

    Every row of the pair table names a drug and a target. The drug's graph is
    read from ``<graph_dir>/<drug_id>.pt`` and the target's feature row is
    looked up in the protein table, which is indexed by ``sequence_id``.
    """

    def __init__(self, pair_file, graph_dir, protein_csv):
        """Read both CSV tables into memory; graph files are loaded on access.

        Args:
            pair_file: CSV providing ``drug_id``, ``target_id`` and ``label``.
            graph_dir: Directory holding one ``<drug_id>.pt`` file per drug.
            protein_csv: CSV with a ``sequence_id`` column; all remaining
                columns of a row form that protein's feature vector.
        """
        self.pairs = pd.read_csv(pair_file)
        self.graph_dir = graph_dir
        protein_table = pd.read_csv(protein_csv)
        self.protein_df = protein_table.set_index("sequence_id")

    def __len__(self):
        """Number of drug-target pairs in the pair table."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(graph, protein_vec, label)`` for the pair at position ``idx``.

        ``label`` is a float tensor with a single element and ``protein_vec``
        is a float tensor built from the protein table row for the target.
        """
        pair = self.pairs.iloc[idx]
        drug_id, target_id = pair["drug_id"], pair["target_id"]
        label = torch.tensor([pair["label"]], dtype=torch.float)

        # NOTE(review): torch.load unpickles the file, so graph_dir should only
        # ever contain trusted data.
        graph_path = os.path.join(self.graph_dir, f"{drug_id}.pt")
        graph = torch.load(graph_path)

        protein_features = self.protein_df.loc[target_id].values
        protein_vec = torch.tensor(protein_features, dtype=torch.float)

        return graph, protein_vec, label


In [None]:
# Define the GraphSAGE model architecture with enhanced layers and dropout
class DrugTargetModel(nn.Module):
    """GraphSAGE drug encoder fused with a protein vector for binary prediction.

    Three SAGEConv layers encode the drug graph, the node embeddings are
    mean-pooled per graph, concatenated with the protein vector and fed through
    a two-layer MLP that emits a single logit per pair.
    """

    def __init__(self, in_channels, hidden_dim, protein_dim):
        """Build the encoder and the fusion head.

        Args:
            in_channels: Width of the node features in ``graph.x``.
            hidden_dim: Output width of every SAGEConv layer.
            protein_dim: Length of the protein vector concatenated after pooling.
        """
        super().__init__()
        # Graph encoder: three stacked SAGEConv layers.
        self.conv1 = SAGEConv(in_channels, hidden_dim)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim)
        self.conv3 = SAGEConv(hidden_dim, hidden_dim)

        # One dropout module shared by the encoder and the fusion head.
        self.dropout = nn.Dropout(p=0.3)

        # Fusion head operating on [pooled graph embedding | protein vector].
        fused_dim = hidden_dim + protein_dim
        self.lin1 = nn.Linear(fused_dim, 128)
        self.lin2 = nn.Linear(128, 1)

    def forward(self, graph, protein_vec, batch, return_logits=False):
        """Score a batch of drug-protein pairs.

        Args:
            graph: Object exposing ``x`` and ``edge_index`` for the batched graphs.
            protein_vec: Protein vectors, concatenated along ``dim=1`` with the
                pooled graph embeddings.
            batch: Node-to-graph assignment passed to ``global_mean_pool``.
            return_logits: When true return raw logits (for BCEWithLogitsLoss);
                otherwise return sigmoid probabilities.
        """
        edge_index = graph.edge_index

        # Dropout follows the first two conv layers only.
        h = self.dropout(F.relu(self.conv1(graph.x, edge_index)))
        h = self.dropout(F.relu(self.conv2(h, edge_index)))
        h = F.relu(self.conv3(h, edge_index))

        pooled = global_mean_pool(h, batch)  # one embedding per graph
        fused = torch.cat([pooled, protein_vec], dim=1)
        hidden = self.dropout(F.relu(self.lin1(fused)))
        logits = self.lin2(hidden)

        return logits if return_logits else torch.sigmoid(logits)


In [None]:
# Early Stopping Class
class EarlyStopping:
    """Signal when a monitored score (higher is better) has stopped improving.

    Call the instance once per epoch with the latest score and the model. The
    call returns ``True`` once ``patience`` consecutive calls have failed to
    beat the best score by at least ``min_delta``; until then it returns
    ``False``.

    Args:
        patience: Number of consecutive non-improving calls tolerated before
            stopping is signalled.
        min_delta: Minimum increase over the best score that counts as an
            improvement.
        restore_best_weights: When true, the weights snapshotted at the best
            score are loaded back into the model at the moment stopping is
            signalled.
    """

    def __init__(self, patience=7, min_delta=0.001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_score = None   # best score seen so far
        self.counter = 0         # consecutive calls without sufficient improvement
        self.best_weights = None  # snapshot of the model at best_score

    def __call__(self, val_score, model):
        """Record ``val_score`` and return whether training should stop."""
        if self.best_score is None:
            # First call: whatever we have is the best so far.
            self.best_score = val_score
            self.save_checkpoint(model)
        elif val_score < self.best_score + self.min_delta:
            # Not a big enough improvement.
            self.counter += 1
            if self.counter >= self.patience:
                if self.restore_best_weights:
                    model.load_state_dict(self.best_weights)
                return True
        else:
            self.best_score = val_score
            self.counter = 0
            self.save_checkpoint(model)
        return False

    def save_checkpoint(self, model):
        """Store an independent copy of the model's current weights.

        ``state_dict()`` hands back the live parameter tensors and
        ``dict.copy()`` is shallow, so a plain copy would keep changing as
        training continues and restoring it would be a no-op. Each tensor is
        therefore cloned.
        """
        self.best_weights = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

# Training loop with enhanced evaluation using ROC-AUC
def train(model, loader, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: Module called as ``model(graph, protein_vec, batch[, return_logits])``.
        loader: Iterable yielding ``(graph_batch, protein_vec, label)`` triples;
            ``graph_batch`` must expose a ``batch`` attribute.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(output, label)``.

    Returns:
        Sum of the per-batch loss values divided by ``len(loader)``.
    """
    model.train()

    # BCEWithLogitsLoss applies the sigmoid itself, so it must be fed raw
    # logits; any other criterion receives the model's default output.
    wants_logits = isinstance(criterion, nn.BCEWithLogitsLoss)

    batch_losses = []
    for graph_batch, protein_batch, targets in loader:
        optimizer.zero_grad()
        if wants_logits:
            preds = model(graph_batch, protein_batch, graph_batch.batch, return_logits=True)
        else:
            preds = model(graph_batch, protein_batch, graph_batch.batch)
        batch_loss = criterion(preds, targets)
        batch_loss.backward()
        optimizer.step()
        batch_losses.append(batch_loss.item())

    return sum(batch_losses) / len(loader)

def evaluate(model, loader):
    """Score the model on ``loader`` and return ``(roc_auc, accuracy)``.

    The model is switched to eval mode and run without gradient tracking.
    ROC-AUC is computed from the predicted probabilities; accuracy uses a
    fixed 0.5 decision threshold and is reported for reference.
    """
    model.eval()
    y_true, y_prob = [], []
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for graph_batch, protein_batch, targets in loader:
            # Called without return_logits, so the model yields probabilities.
            probs = model(graph_batch, protein_batch, graph_batch.batch)
            y_prob.extend(probs.cpu().numpy())
            y_true.extend(targets.cpu().numpy())

            # Accuracy bookkeeping at the 0.5 threshold.
            hits = (probs > 0.5).float() == targets
            n_correct += hits.sum().item()
            n_seen += targets.size(0)

    return roc_auc_score(y_true, y_prob), n_correct / n_seen

def _collect_labels(dataset):
    """Return every sample's label as a Python float, skipping sample loading when possible."""
    pairs = getattr(dataset, "pairs", None)
    columns = getattr(pairs, "columns", ())
    if pairs is not None and "label" in columns and len(pairs) == len(dataset):
        # Read the label column directly instead of materialising every sample
        # (for DrugTargetDataset each dataset[i] loads a graph file from disk).
        # Going through a float tensor mirrors the float conversion that
        # __getitem__ applies to the label.
        return torch.tensor(pairs["label"].tolist(), dtype=torch.float).tolist()
    # Generic fallback (e.g. Subset or any other indexable): label is item 2.
    return [dataset[i][2].item() for i in range(len(dataset))]


def calculate_class_weights(dataset):
    """Calculate class weights for handling imbalanced datasets.

    Each class weight is ``n_samples / (n_classes * class_count)``, i.e.
    inversely proportional to the class frequency. The class distribution and
    the resulting weights are printed.

    Args:
        dataset: Indexable whose items are ``(graph, protein_vec, label)`` with
            a single-element label tensor. If it exposes a ``pairs`` table with
            a ``label`` column of matching length, labels are read from there
            without loading any sample.

    Returns:
        Dict mapping each label value (as a float) to its weight; empty for an
        empty dataset.
    """
    labels = _collect_labels(dataset)
    class_counts = Counter(labels)
    total_samples = len(labels)
    n_classes = len(class_counts)

    # Calculate weights inversely proportional to class frequency
    weights = {
        class_label: total_samples / (n_classes * count)
        for class_label, count in class_counts.items()
    }

    print(f"Dataset class distribution: {dict(class_counts)}")
    print(f"Calculated class weights: {weights}")
    return weights


In [None]:
# Enhanced training script with all advanced features
if __name__ == "__main__":
    # Seed torch so the random train/test split and the weight initialisation
    # are the same on every run.
    SEED = 42
    torch.manual_seed(SEED)

    # Load dataset
    dataset = DrugTargetDataset(
        "data/step6_training_pairs.csv",
        "data/graphs/",
        "data/step4_protein_onehot.csv"
    )
    
    print(f"📊 Dataset loaded: {len(dataset)} samples")
    
    # Calculate class weights for handling imbalance
    class_weights = calculate_class_weights(dataset)
    
    # Create weighted loss function.
    # pos_weight is the ratio weight(positive) / weight(negative); a missing
    # class falls back to a weight of 1.0.
    pos_weight = torch.tensor([class_weights.get(1.0, 1.0) / class_weights.get(0.0, 1.0)])
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    print(f"🎯 Using weighted BCE loss with pos_weight: {pos_weight.item():.3f}")

    # Split dataset (80% train / 20% test)
    train_len = int(0.8 * len(dataset))
    train_set, test_set = random_split(dataset, [train_len, len(dataset) - train_len])
    train_loader = DataLoader(train_set, batch_size=16, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=16)
    
    print(f"📚 Training samples: {len(train_set)}")
    print(f"🧪 Test samples: {len(test_set)}")

    # Initialize model and training components.
    # in_channels must equal the node-feature width of graph.x and protein_dim
    # the length of the protein vector the dataset returns.
    model = DrugTargetModel(in_channels=6, hidden_dim=128, protein_dim=20)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    
    # Initialize early stopping
    early_stopping = EarlyStopping(patience=5, min_delta=0.001, restore_best_weights=True)
    
    # Training tracking
    best_auc = 0.0
    best_state = None  # in-memory copy of the weights that produced best_auc
    training_history = {
        'epoch': [],
        'train_loss': [],
        'test_auc': [],
        'test_accuracy': []
    }

    print("\n🚀 Starting enhanced training...")
    print("📈 Metrics: Loss (↓), ROC-AUC (↑ 1.0=perfect), Accuracy (↑)")
    print("🛑 Early stopping: patience=5, min_delta=0.001")
    print("=" * 80)
    
    # NOTE(review): the test split drives both best-model selection and early
    # stopping, so the reported test metrics are optimistic. A separate
    # validation split would give an unbiased final number.
    for epoch in range(1, 51):  # Increased max epochs
        # Training phase
        train_loss = train(model, train_loader, optimizer, criterion)
        
        # Evaluation phase
        test_auc, test_accuracy = evaluate(model, test_loader)
        # roc_auc_score may hand back a NumPy scalar; keep plain Python floats
        # in the history and in the checkpoint.
        test_auc = float(test_auc)
        test_accuracy = float(test_accuracy)
        
        # Store history
        training_history['epoch'].append(epoch)
        training_history['train_loss'].append(train_loss)
        training_history['test_auc'].append(test_auc)
        training_history['test_accuracy'].append(test_accuracy)
        
        # Print progress
        print(f"Epoch {epoch:02d} | Loss: {train_loss:.4f} | AUC: {test_auc:.4f} | Acc: {test_accuracy:.4f}", end="")
        
        # Check for best model
        if test_auc > best_auc:
            best_auc = test_auc
            # Clone the tensors: state_dict() references the live parameters,
            # which keep changing in later epochs.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'best_auc': best_auc,
                'training_history': training_history
            }, 'best_drug_target_model.pth')
            print(" 🌟 NEW BEST!")
        else:
            print()
        
        # Early stopping check
        if early_stopping(test_auc, model):
            print(f"\n🛑 Early stopping triggered at epoch {epoch}")
            print(f"🔄 Restored best weights from epoch with AUC: {early_stopping.best_score:.4f}")
            break
    
    print("=" * 80)
    print(f"✅ Training completed!")
    print(f"🏆 Best ROC-AUC achieved: {best_auc:.4f}")
    print(f"💾 Best model saved as: 'best_drug_target_model.pth'")
    
    # Final evaluation on test set with best model.
    # Without this restore the model would still hold the last epoch's weights
    # whenever training ran to the epoch limit without early stopping.
    if best_state is not None:
        model.load_state_dict(best_state)
    final_auc, final_accuracy = evaluate(model, test_loader)
    print(f"🎯 Final test performance:")
    print(f"   ROC-AUC: {final_auc:.4f}")
    print(f"   Accuracy: {final_accuracy:.4f}")
    
    # Save training history
    import json
    with open('training_history.json', 'w') as f:
        json.dump(training_history, f, indent=2)
    print(f"📊 Training history saved as: 'training_history.json'")


In [None]:
# Optional: visualize training progress.
# The plotting code below is disabled by wrapping it in a triple-quoted string;
# remove the opening and closing quotes to run it. It reads
# 'training_history.json' and saves the figure as 'training_progress.png'.
"""
import matplotlib.pyplot as plt
import json

# Load training history
with open('training_history.json', 'r') as f:
    history = json.load(f)

# Create plots
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))

# Plot 1: Training Loss
ax1.plot(history['epoch'], history['train_loss'], 'b-', linewidth=2)
ax1.set_title('Training Loss Over Time')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.grid(True, alpha=0.3)

# Plot 2: ROC-AUC
ax2.plot(history['epoch'], history['test_auc'], 'g-', linewidth=2)
ax2.axhline(y=0.5, color='r', linestyle='--', alpha=0.5, label='Random (0.5)')
ax2.set_title('ROC-AUC Score Over Time')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('AUC')
ax2.legend()
ax2.grid(True, alpha=0.3)

# Plot 3: Accuracy
ax3.plot(history['epoch'], history['test_accuracy'], 'orange', linewidth=2)
ax3.set_title('Test Accuracy Over Time')
ax3.set_xlabel('Epoch')
ax3.set_ylabel('Accuracy')
ax3.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_progress.png', dpi=300, bbox_inches='tight')
plt.show()

print("📈 Training visualization saved as 'training_progress.png'")
"""