# Graph Attention Networks (GAT) - Complete Professional Implementation

This notebook provides a comprehensive implementation of Graph Attention Networks (GAT) based on the paper "Graph Attention Networks" by Veličković et al.

### Key Features:
- ✅ Multi-head attention mechanism
- ✅ Professional training pipeline with early stopping
- ✅ Comprehensive evaluation metrics
- ✅ Visualization capabilities
- ✅ Error handling and logging

### Architecture Overview:
GATs work on graph data where nodes represent entities and edges represent relationships. The attention mechanism allows nodes to focus on the most relevant neighbors, similar to transformers but adapted for graph structures.

## 1. Importing Required Libraries

In [None]:
# Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from safetensors.torch import save_file, load_file
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score, f1_score
from sklearn.preprocessing import LabelEncoder
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator, Descriptors
import networkx as nx
from datetime import datetime
import time
import warnings
import os
import json

In [None]:
# The CUDA caching allocator reads this variable when it initializes, so set it
# before any torch.cuda.* call. setdefault keeps a value exported in the shell.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')

# NOTE: this silences *all* warnings (including deprecations from torch/rdkit);
# narrow the filter if something behaves unexpectedly.
warnings.filterwarnings('ignore')

# Seed the RNGs used downstream (weight init, dropout, torch.randperm in the
# trainer) so that Restart & Run All reproduces the same numbers.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Memory / speed settings that only apply on GPU
if device.type == 'cuda':
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    # Force garbage collection
    import gc
    gc.collect()

    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print("Memory optimization enabled ✓")

# Timestamp used to tag every artefact (models, images, results) of this run
timestamp = datetime.now().strftime("%d_%b_%H-%M")

# Make sure the output directories exist
required_directories = ['images', 'models', 'results']
for folder in required_directories:
    if os.path.isdir(folder):
        print(f"Directory `{folder}/` exists ✓")
    else:
        print(f"Directory `{folder}/` not found ✘ Creating...")
        os.makedirs(folder, exist_ok=True)

In [None]:
# Load DDI data
print("="*60)
print("EXPLORING DRUG-DRUG INTERACTION DATASET")
print("="*60)

# Load DDI interactions
ddi_df = pd.read_csv('dataset/drugdata/ddis.csv')
print(f"\n📊 DDI Dataset Shape: {ddi_df.shape}")
print(f"Columns: {ddi_df.columns.tolist()}")
interaction_counts = ddi_df['type'].value_counts()
print("\n🔬 Interaction Types Distribution:")
print(interaction_counts)
print("\n📋 Sample DDI Data:")
print(ddi_df.head(10))

# Load drug SMILES
smiles_df = pd.read_csv('dataset/drugdata/drug_smiles.csv')
print(f"\n💊 Drug SMILES Dataset Shape: {smiles_df.shape}")
print(f"Columns: {smiles_df.columns.tolist()}")
print("\n📋 Sample SMILES Data:")
print(smiles_df.head(10))

# Get unique drugs appearing on either side of an interaction
unique_drugs_ddi = set(ddi_df['d1'].unique()) | set(ddi_df['d2'].unique())
print("\n📈 Statistics:")
print(f"• Total DDI pairs: {len(ddi_df)}")
print(f"• Unique drugs in DDI: {len(unique_drugs_ddi)}")
print(f"• Drugs with SMILES: {len(smiles_df)}")
print(f"• Number of interaction types: {len(interaction_counts)}")
# Report every type present (not only types 0 and 1); skip when there are
# many types because the distribution above already lists them.
if len(interaction_counts) <= 10:
    for interaction_type, count in interaction_counts.sort_index().items():
        print(f"• Interaction type {interaction_type}: {count}")

# Check overlap between interacting drugs and drugs we can featurize
drugs_with_smiles = set(smiles_df['drug_id'].unique())
overlap = unique_drugs_ddi & drugs_with_smiles
coverage = len(overlap) / len(unique_drugs_ddi) * 100 if unique_drugs_ddi else 0.0
print(f"• Drugs with both DDI and SMILES: {len(overlap)}")
print(f"• Coverage: {coverage:.2f}%")

In [None]:
def prepare_ddi_data(data_dir='dataset/drugdata', seed=42):
    """
    Load the DDI tables, featurize drugs, build the graph and split the edges.

    The split is made over *unordered drug pairs*, not over directed edges.
    The graph builder stores each interaction in both directions with the same
    label. Splitting directed edges independently would therefore put (i, j) in
    train and (j, i) in val/test, which leaks the held-out label into training.

    Args:
        data_dir: Directory containing ``ddis.csv`` and ``drug_smiles.csv``.
        seed: Random state for the 70/15/15 train/val/test split.

    Returns:
        dict with 'node_features', 'edge_index', 'edge_labels', dense
        'adj_mat' of shape [n_nodes, n_nodes, 1] (with self-loops), boolean
        'train_mask' / 'val_mask' / 'test_mask' over edges, 'n_classes' and
        the fitted 'label_encoder'.
    """
    print("_"*40)
    print("LOADING DRUG-DRUG INTERACTION DATA")
    print("-"*40)

    # Load data
    ddi_df = pd.read_csv(os.path.join(data_dir, 'ddis.csv'))
    smiles_df = pd.read_csv(os.path.join(data_dir, 'drug_smiles.csv'))

    print(f"✓ Loaded {len(ddi_df)} DDI pairs")
    print(f"✓ Loaded {len(smiles_df)} drug SMILES")
    print("_"*40)

    # Extract features
    feature_extractor = DrugFeatureExtractor()
    drug_features = feature_extractor.extract_all_features(smiles_df)

    # Build graph
    graph_builder = DDIGraphBuilder(ddi_df, drug_features)
    node_features, edge_index, edge_labels, n_classes = graph_builder.build_graph()

    # Dense adjacency [n_nodes, n_nodes, 1] as consumed by GraphAttentionLayer
    n_nodes = node_features.shape[0]
    adj_mat = torch.zeros(n_nodes, n_nodes, 1)
    adj_mat[edge_index[0], edge_index[1], 0] = 1
    # Add self-loops
    adj_mat[range(n_nodes), range(n_nodes), 0] = 1

    # Group directed edges by their unordered endpoint pair. Both directions,
    # and any duplicate rows for the same pair, then land in the same split.
    n_edges = edge_index.shape[1]
    src = edge_index[0].cpu().numpy()
    dst = edge_index[1].cpu().numpy()
    pair_key = np.minimum(src, dst) * n_nodes + np.maximum(src, dst)
    _, pair_id = np.unique(pair_key, return_inverse=True)
    pair_id = pair_id.reshape(-1)
    n_pairs = int(pair_id.max()) + 1 if n_edges else 0
    pair_indices = np.arange(n_pairs)

    train_pairs, temp_pairs = train_test_split(pair_indices, test_size=0.3, random_state=seed)
    val_pairs, test_pairs = train_test_split(temp_pairs, test_size=0.5, random_state=seed)

    train_mask = torch.from_numpy(np.isin(pair_id, train_pairs))
    val_mask = torch.from_numpy(np.isin(pair_id, val_pairs))
    test_mask = torch.from_numpy(np.isin(pair_id, test_pairs))

    print(f"\n🛈 Data Split (by drug pair, {n_pairs} pairs):")
    for split_name, mask in (('Train', train_mask), ('Val', val_mask), ('Test', test_mask)):
        n_split = int(mask.sum())
        print(f"• {split_name}: {n_split} edges ({n_split / n_edges * 100:.1f}%)")

    return {
        'node_features': node_features,
        'edge_index': edge_index,
        'edge_labels': edge_labels,
        'adj_mat': adj_mat,
        'train_mask': train_mask,
        'val_mask': val_mask,
        'test_mask': test_mask,
        'n_classes': n_classes,
        'label_encoder': graph_builder.label_encoder
    }

## 3. Feature and Pattern Extraction

In [None]:
class DrugFeatureExtractor:
    """
    Extract molecular features from SMILES strings using RDKit.

    Each drug becomes a Morgan fingerprint bit vector followed by a fixed list
    of RDKit descriptors (``DESCRIPTOR_NAMES``, in that order).

    Args:
        radius: Morgan fingerprint radius (default 2).
        fp_size: Number of fingerprint bits (default 1024).
    """

    # Names of `rdkit.Chem.Descriptors` functions appended after the fingerprint.
    # They are stored as strings so RDKit is only touched when features are computed.
    DESCRIPTOR_NAMES = (
        'MolWt',
        'MolLogP',
        'NumHDonors',
        'NumHAcceptors',
        'TPSA',
        'NumRotatableBonds',
        'NumAromaticRings',
        'FractionCSP3',
    )

    def __init__(self, radius=2, fp_size=1024):
        self.radius = radius
        self.fp_size = fp_size
        # Build the fingerprint generator once instead of once per molecule
        self._mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=radius, fpSize=fp_size)
        # Column names matching the layout of the vectors from smiles_to_features
        self.feature_names = [f'morgan_{i}' for i in range(fp_size)] + list(self.DESCRIPTOR_NAMES)

    def smiles_to_features(self, smiles):
        """
        Convert a SMILES string to a fingerprint + descriptor feature vector.

        Args:
            smiles: SMILES string

        Returns:
            1-D numpy array of length ``fp_size + len(DESCRIPTOR_NAMES)``, or
            None when RDKit cannot parse the SMILES or a descriptor fails.
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None

            # Morgan fingerprint bits
            fp_array = np.array(self._mfpgen.GetFingerprint(mol))

            # Molecular descriptors, in DESCRIPTOR_NAMES order
            descriptors = [getattr(Descriptors, name)(mol) for name in self.DESCRIPTOR_NAMES]

            return np.concatenate([fp_array, descriptors])

        except Exception as e:
            # Best effort: a single bad SMILES must not abort featurization
            print(f"Error processing SMILES: {e}")
            return None

    def extract_all_features(self, smiles_df):
        """
        Extract features for all drugs.

        Args:
            smiles_df: DataFrame with 'drug_id' and 'smiles' columns

        Returns:
            Dictionary mapping drug_id to feature vector; drugs whose SMILES
            cannot be processed are omitted (and counted in the summary).
        """
        print("\n"+"_"*45)
        print("Extracting molecular features from SMILES...")
        print("-"*45)

        drug_features = {}
        failed = 0
        n_total = len(smiles_df)

        # Count by position (not the DataFrame label) so progress is correct
        # for filtered frames or non-default indexes.
        rows = zip(smiles_df['drug_id'], smiles_df['smiles'])
        for position, (drug_id, smiles) in enumerate(rows, start=1):
            features = self.smiles_to_features(smiles)

            if features is None:
                failed += 1
            else:
                drug_features[drug_id] = features

            if position % 200 == 0:
                print(f"{position}/{n_total} drugs processed ")
        print("-"*45)
        print(f"✓ Extracted features for {len(drug_features)} drugs")
        if failed > 0:
            print(f"⚠ Failed to process {failed} drugs")
        print("_"*45 + "\n")
        return drug_features

## 4. GAT Architecture Building

In [None]:
class DDIGraphBuilder:
    """
    Build the drug graph used by the GAT from DDI rows and per-drug features.

    Nodes are the drugs that have a feature vector, indexed in sorted drug-id
    order. Every DDI row whose two drugs both have features becomes two directed
    edges, (d1 -> d2) and (d2 -> d1), with the same label. Edges are emitted in
    row order with the two directions adjacent.

    Args:
        ddi_df: DataFrame with 'd1', 'd2' and 'type' columns.
        drug_features: Mapping drug_id -> 1-D feature vector (all of equal length).
    """

    def __init__(self, ddi_df, drug_features):
        self.ddi_df = ddi_df
        self.drug_features = drug_features
        self.drug_to_idx = {}
        self.idx_to_drug = {}
        self.label_encoder = LabelEncoder()

    def build_graph(self):
        """
        Build graph from DDI data.

        Returns:
            Tuple ``(node_features, edge_index, edge_labels, n_classes)``:
            node_features: float32 tensor [n_nodes, feature_dim], z-scored per column;
            edge_index: long tensor [2, n_edges] (shape [2, 0] when no row matches);
            edge_labels: long tensor [n_edges] of label-encoded interaction types;
            n_classes: number of distinct interaction types.

        Raises:
            ValueError: if ``drug_features`` is empty.
        """
        print("_"*40)
        print("Building DDI graph...")

        if not self.drug_features:
            raise ValueError("drug_features is empty: no drug could be featurized")

        # Node indexing: sorted drug ids
        unique_drugs = sorted(self.drug_features)
        self.drug_to_idx = {drug: idx for idx, drug in enumerate(unique_drugs)}
        self.idx_to_drug = dict(enumerate(unique_drugs))

        print(f"• Number of nodes (drugs): {len(unique_drugs)}")

        # Node feature matrix, rows in node-index order
        node_features = np.stack(
            [np.asarray(self.drug_features[drug], dtype=float) for drug in unique_drugs]
        )
        feature_dim = node_features.shape[1]

        # Normalize features (epsilon guards constant columns)
        node_features = (node_features - node_features.mean(axis=0)) / (node_features.std(axis=0) + 1e-8)

        # Map drug ids to node indices; rows with an unknown drug are dropped
        src = self.ddi_df['d1'].map(self.drug_to_idx)
        dst = self.ddi_df['d2'].map(self.drug_to_idx)
        keep = src.notna().to_numpy() & dst.notna().to_numpy()
        src_idx = src.to_numpy()[keep].astype(np.int64)
        dst_idx = dst.to_numpy()[keep].astype(np.int64)
        row_labels = self.ddi_df['type'].to_numpy()[keep]

        # Bidirectional edges, interleaved: row k -> edges 2k (d1->d2) and 2k+1 (d2->d1)
        forward = np.stack([src_idx, dst_idx], axis=1)
        backward = forward[:, ::-1]
        edges = np.stack([forward, backward], axis=1).reshape(-1, 2)
        edge_index = torch.as_tensor(edges.T.copy(), dtype=torch.long)

        # Encode labels (each row's label is repeated for its two directions)
        encoded = self.label_encoder.fit_transform(np.repeat(row_labels, 2))
        edge_labels = torch.tensor(np.asarray(encoded), dtype=torch.long)
        n_classes = len(self.label_encoder.classes_)

        print(f"• Number of edges: {edges.shape[0]}")
        print(f"• Number of interaction types: {n_classes}")
        print(f"• Feature dimension: {feature_dim}")

        return (
            torch.FloatTensor(node_features),
            edge_index,
            edge_labels,
            n_classes
        )

In [None]:
class GraphAttentionLayer(nn.Module):
    """
    Single multi-head Graph Attention Layer (Veličković et al.).

    For every directed edge i -> j present in the adjacency matrix, head k computes
        e_ij     = LeakyReLU(a_src . W h_i + a_dst . W h_j)
        alpha_ij = softmax of e_ij over all incoming edges of j
        h'_j     = sum_i alpha_ij * W h_i
    (a_src / a_dst are the two halves of the paper's vector a applied to [W h_i || W h_j].)
    Heads are concatenated (``is_concat=True``) or averaged.

    Args:
        in_features: Number of input features per node (F)
        out_features: Number of output features per node (F')
        n_heads: Number of attention heads (K)
        is_concat: Whether to concatenate or average multi-head results
        dropout: Dropout probability applied to the attention coefficients
        leaky_relu_negative_slope: Negative slope for LeakyReLU activation
    """

    def __init__(self, in_features: int, out_features: int, n_heads: int = 8,
                 is_concat: bool = True, dropout: float = 0.6,
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()

        self.is_concat = is_concat
        self.n_heads = n_heads

        # Calculate dimensions per head
        if is_concat:
            assert out_features % n_heads == 0, "out_features must be divisible by n_heads when concatenating"
            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features

        # Linear transformation: W in the paper
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)

        # Attention parameters (a^T split into source / destination halves), one per head
        self.attn_src = nn.Parameter(torch.empty(1, n_heads, self.n_hidden))
        self.attn_dst = nn.Parameter(torch.empty(1, n_heads, self.n_hidden))

        # Activation and regularisation
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        # Not used by forward (softmax is done with scatter ops); kept for compatibility
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout)

        # Initialize weights using Xavier uniform
        self._init_weights()

    def _init_weights(self):
        """Initialize layer weights using Xavier uniform initialization"""
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.xavier_uniform_(self.attn_src)
        nn.init.xavier_uniform_(self.attn_dst)

    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the attention layer.

        Args:
            h: Node embeddings [n_nodes, in_features]
            adj_mat: Adjacency [n_nodes, n_nodes] or [n_nodes, n_nodes, k]; entry
                (i, j) is an edge i -> j if any channel is non-zero. Nodes with no
                incoming edge (no self-loop) get a zero output.

        Returns:
            Updated node embeddings [n_nodes, n_heads * n_hidden] when
            concatenating, else [n_nodes, n_hidden].
        """
        n_nodes = h.shape[0]

        # Step 1: g_i = W h_i for each head -> [n_nodes, n_heads, n_hidden]
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

        # Directed edge list from the non-zero adjacency entries
        adj_2d = adj_mat.reshape(n_nodes, n_nodes, -1).ne(0).any(dim=-1)
        src_nodes, dst_nodes = adj_2d.nonzero(as_tuple=True)

        out_dim = self.n_heads * self.n_hidden if self.is_concat else self.n_hidden
        if src_nodes.numel() == 0:
            # No edges: nothing to aggregate
            return h.new_zeros(n_nodes, out_dim)

        # Step 2: attention logits. a . g is computed once per node and then
        # gathered per edge; this gives the same values as scoring each edge but
        # avoids materialising two [n_edges, n_heads, n_hidden] tensors.
        score_src = (g * self.attn_src).sum(dim=-1)  # [n_nodes, n_heads]
        score_dst = (g * self.attn_dst).sum(dim=-1)  # [n_nodes, n_heads]
        e = self.activation(score_src[src_nodes] + score_dst[dst_nodes])  # [n_edges, n_heads]

        # Step 3: softmax over the incoming edges of each destination, all heads at once.
        # The per-destination max is subtracted for numerical stability; softmax is
        # shift-invariant, so the max is detached from the graph.
        index = dst_nodes.unsqueeze(-1).expand_as(e)
        e_max = e.new_full((n_nodes, self.n_heads), float('-inf'))
        e_max = e_max.scatter_reduce(0, index, e.detach(), reduce='amax', include_self=False)
        exp_e = torch.exp(e - e_max[dst_nodes])
        denom = e.new_zeros(n_nodes, self.n_heads).index_add(0, dst_nodes, exp_e)
        alpha = exp_e / (denom[dst_nodes] + 1e-16)

        # Attention dropout (as in the GAT paper); identity in eval mode
        alpha = self.dropout(alpha)

        # Step 4: h'_j = sum_i alpha_ij * g_i, accumulated per destination node
        messages = g[src_nodes] * alpha.unsqueeze(-1)  # [n_edges, n_heads, n_hidden]
        out = g.new_zeros(n_nodes, self.n_heads, self.n_hidden).index_add(0, dst_nodes, messages)

        # Concatenate or average heads
        if self.is_concat:
            return out.reshape(n_nodes, self.n_heads * self.n_hidden)
        return out.mean(dim=1)


In [None]:
class GAT_DDI(nn.Module):
    """
    Three-layer GAT encoder + MLP edge classifier for drug-drug interaction types.

    Node embeddings are computed over the whole graph, optionally with gradient
    checkpointing to trade recomputation for activation memory. Each queried
    edge (src, dst) is then classified from the concatenation of its two
    endpoint embeddings.

    Args:
        n_features: Input feature dimension per node.
        n_hidden: Width of the first two GAT layers; the third outputs n_hidden // 2.
        n_classes: Number of interaction types to predict.
        n_heads: Attention heads per GAT layer.
        dropout: Dropout probability (inputs, attention and MLP).
        use_checkpointing: Recompute GAT activations in backward instead of storing them.
    """

    def __init__(self, n_features, n_hidden, n_classes, n_heads=8, dropout=0.6,
                 use_checkpointing=True):
        super().__init__()

        self.dropout = dropout
        node_dim = n_hidden // 2

        # GAT layers
        self.gat1 = GraphAttentionLayer(n_features, n_hidden, n_heads, True, dropout)
        self.gat2 = GraphAttentionLayer(n_hidden, n_hidden, n_heads, True, dropout)
        self.gat3 = GraphAttentionLayer(n_hidden, node_dim, n_heads, True, dropout)

        # Edge representation is [x_src || x_dst], i.e. 2 * node_dim features
        # (equal to n_hidden only when n_hidden is even).
        edge_dim = 2 * node_dim
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_dim, n_hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden // 2, n_hidden // 4),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden // 4, n_classes)
        )

        self.dropout_layer = nn.Dropout(dropout)

        # Gradient checkpointing for memory efficiency (training only)
        self.use_checkpointing = use_checkpointing

    def _gat_forward(self, x, adj_mat, gat_layer):
        """Helper for gradient checkpointing"""
        return gat_layer(x, adj_mat)

    def forward(self, x, adj_mat, edge_index):
        """
        Classify the edges in ``edge_index``.

        Args:
            x: Node features [n_nodes, n_features].
            adj_mat: Adjacency passed to every GraphAttentionLayer.
            edge_index: Long tensor [2, n_query_edges] of (src, dst) node indices.

        Returns:
            Logits [n_query_edges, n_classes].
        """
        # Explicit import: `import torch` alone does not guarantee torch.utils.checkpoint is loaded
        from torch.utils.checkpoint import checkpoint

        # Checkpointing only pays off when a backward pass will follow
        use_ckpt = self.use_checkpointing and self.training and torch.is_grad_enabled()

        for gat_layer in (self.gat1, self.gat2, self.gat3):
            x = self.dropout_layer(x)
            if use_ckpt:
                x = checkpoint(self._gat_forward, x, adj_mat, gat_layer, use_reentrant=False)
            else:
                x = gat_layer(x, adj_mat)
            x = F.elu(x)

        # Edge prediction from concatenated endpoint embeddings
        src_nodes, dst_nodes = edge_index[0], edge_index[1]
        edge_features = torch.cat([x[src_nodes], x[dst_nodes]], dim=1)
        return self.edge_mlp(edge_features)

## 5. Training Pipeline

Our training pipeline includes:
- Early stopping to prevent overfitting
- Learning rate scheduling
- Comprehensive logging
- Model checkpointing

In [None]:
class DDITrainer:
    """
    Mini-batch trainer for edge-level drug-drug interaction classification.

    Each step runs the model on the full node graph and scores only the current
    batch of edges. ``batch_size`` therefore bounds the number of edges
    classified per step.

    Args:
        model: Module called as ``model(x, adj_mat, edge_index)`` returning
            per-edge class logits.
        device: Device (or device string) the model is moved to.
        batch_size: Edges per training / validation step (default 7000).
        checkpoint_path: File the best model (by validation F1) is saved to.
            ``None`` uses ``models/best_GAT_DDI_{timestamp}.pth`` with the
            notebook-level ``timestamp``.
    """

    def __init__(self, model, device, batch_size=7000, checkpoint_path=None):
        self.device = torch.device(device)
        self.model = model.to(self.device)
        self.batch_size = batch_size
        self.checkpoint_path = checkpoint_path
        self.history = {
            'train_loss': [], 'val_loss': [],
            'train_acc': [], 'val_acc': [],
            'train_f1': [], 'val_f1': []
        }

    def _iter_batches(self, indices):
        """Yield consecutive chunks of ``indices`` with at most ``batch_size`` entries."""
        for start in range(0, len(indices), self.batch_size):
            yield indices[start:start + self.batch_size]

    @staticmethod
    def _mask_indices(mask, name):
        """Edge indices selected by a boolean mask; raises ValueError if it selects none."""
        indices = torch.where(mask)[0]
        if len(indices) == 0:
            raise ValueError(f"{name} selects no edges")
        return indices

    @staticmethod
    def _summarise(total_loss, n_edges, labels, preds):
        """Return (mean loss, accuracy, weighted F1) over the processed edges."""
        return (total_loss / n_edges,
                accuracy_score(labels, preds),
                f1_score(labels, preds, average='weighted'))

    def train_epoch(self, x, adj_mat, edge_index, edge_labels, train_mask, optimizer, criterion):
        """Run one shuffled mini-batch pass over the training edges.

        Returns:
            (mean loss, accuracy, weighted F1) over all training edges.
        """
        self.model.train()

        train_indices = self._mask_indices(train_mask, 'train_mask')
        n_train = len(train_indices)

        # Shuffle on the indices' own device so indexing never crosses devices
        perm = torch.randperm(n_train, device=train_indices.device)
        train_indices = train_indices[perm]

        total_loss = 0.0
        all_preds, all_labels = [], []

        for step, batch_indices in enumerate(self._iter_batches(train_indices), start=1):
            optimizer.zero_grad()

            batch_edge_index = edge_index[:, batch_indices]
            batch_labels = edge_labels[batch_indices]

            out = self.model(x, adj_mat, batch_edge_index)
            loss = criterion(out, batch_labels)
            loss.backward()

            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            optimizer.step()

            # Loss is a batch mean, so weight it by batch size for the epoch mean
            total_loss += loss.item() * len(batch_indices)
            all_preds.extend(out.argmax(dim=1).cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

            # Release cached GPU blocks periodically
            if step % 5 == 0 and self.device.type == 'cuda':
                torch.cuda.empty_cache()

        return self._summarise(total_loss, n_train, all_labels, all_preds)

    def validate(self, x, adj_mat, edge_index, edge_labels, val_mask, criterion):
        """Evaluate on the validation edges in mini-batches (no gradients).

        Returns:
            (mean loss, accuracy, weighted F1) over all validation edges.
        """
        self.model.eval()

        val_indices = self._mask_indices(val_mask, 'val_mask')
        n_val = len(val_indices)

        total_loss = 0.0
        all_preds, all_labels = [], []

        with torch.no_grad():
            for batch_indices in self._iter_batches(val_indices):
                batch_edge_index = edge_index[:, batch_indices]
                batch_labels = edge_labels[batch_indices]

                out = self.model(x, adj_mat, batch_edge_index)
                loss = criterion(out, batch_labels)

                total_loss += loss.item() * len(batch_indices)
                all_preds.extend(out.argmax(dim=1).cpu().numpy())
                all_labels.extend(batch_labels.cpu().numpy())

        return self._summarise(total_loss, n_val, all_labels, all_preds)

    def train(self, x, adj_mat, edge_index, edge_labels, train_mask, val_mask,
              epochs=200, lr=0.005, weight_decay=5e-4, patience=30):
        """
        Complete training loop with early stopping on validation F1.

        The best model (highest validation F1) is written to the checkpoint path,
        which is also stored on ``self.best_checkpoint_path``.

        Returns:
            The accumulated ``history`` dict of per-epoch metrics.
        """
        checkpoint_path = self.checkpoint_path or f'models/best_GAT_DDI_{timestamp}.pth'

        print("\n♻ Training GAT-DDI Model")
        print(f"• Device: {self.device}")
        print(f"• Parameters: {sum(p.numel() for p in self.model.parameters()):,}")
        print(f"• Training edges: {int(train_mask.sum())}")
        print(f"• Validation edges: {int(val_mask.sum())}")

        optimizer = optim.Adam(self.model.parameters(), lr=lr, weight_decay=weight_decay)
        criterion = nn.CrossEntropyLoss()

        best_val_acc = 0.0
        # -inf guarantees the first epoch writes a checkpoint, even if its F1 is 0
        best_val_f1 = float('-inf')
        patience_counter = 0
        start_time = time.time()

        print("\n" + "_"*80)
        print("EPOCH | TRAIN LOSS | TRAIN ACC | TRAIN F1 | VAL LOSS | VAL ACC | VAL F1")
        print("-"*80)

        for epoch in range(1, epochs + 1):
            train_loss, train_acc, train_f1 = self.train_epoch(
                x, adj_mat, edge_index, edge_labels, train_mask, optimizer, criterion
            )
            val_loss, val_acc, val_f1 = self.validate(
                x, adj_mat, edge_index, edge_labels, val_mask, criterion
            )

            for key, value in (('train_loss', train_loss), ('val_loss', val_loss),
                               ('train_acc', train_acc), ('val_acc', val_acc),
                               ('train_f1', train_f1), ('val_f1', val_f1)):
                self.history[key].append(value)

            # Early-stopping bookkeeping / best-model checkpoint
            if val_f1 > best_val_f1:
                best_val_acc = val_acc
                best_val_f1 = val_f1
                patience_counter = 0
                torch.save(self.model.state_dict(), checkpoint_path)
            else:
                patience_counter += 1

            stop = patience_counter >= patience

            # Print every 10 epochs, plus the final epoch
            if epoch % 10 == 0 or stop or epoch == epochs:
                print(f"{epoch:5d} | {train_loss:10.4f} | {train_acc:9.4f} | "
                      f"{train_f1:8.4f} | {val_loss:8.4f} | {val_acc:7.4f} | {val_f1:6.4f}")

            if stop:
                print(f"\n⊘ Early stopping at epoch {epoch}")
                break

        self.best_checkpoint_path = checkpoint_path
        training_time = time.time() - start_time
        print("-"*80)
        print(f"(✓) Training completed in {training_time:.2f} seconds")
        print(f"★ Best validation accuracy: {best_val_acc:.4f}")
        print(f"★ Best validation F1-score: {best_val_f1:.4f}")
        print("_"*80)

        return self.history

## 6. Model Evaluation

In [None]:
def evaluate_model(model, data, device, batch_size=7000):
    """
    Evaluate ``model`` on the test edges of ``data``.

    Only the test edges are scored, in mini-batches (as in training), instead
    of scoring every edge in the graph and discarding the non-test ones.

    Args:
        model: Trained model called as ``model(x, adj_mat, edge_index)``.
        data: Dict from ``prepare_ddi_data`` with its tensors already on the target device.
        device: Unused; kept for backward compatibility (the tensors in ``data``
            determine the device).
        batch_size: Number of test edges scored per forward pass.

    Returns:
        dict with 'accuracy', 'f1_score', 'predictions', 'true_labels' and
        'confusion_matrix' (rows = true class, columns = predicted class).
    """
    print("\n" + "="*60)
    print("MODEL EVALUATION")
    print("="*60)

    n_classes = data['n_classes']
    class_ids = list(range(n_classes))
    test_idx = torch.where(data['test_mask'])[0]

    model.eval()
    batch_preds = []
    with torch.no_grad():
        for start in range(0, len(test_idx), batch_size):
            batch = test_idx[start:start + batch_size]
            logits = model(data['node_features'], data['adj_mat'], data['edge_index'][:, batch])
            batch_preds.append(logits.argmax(dim=1).cpu())

    y_pred = torch.cat(batch_preds).numpy() if batch_preds else np.array([], dtype=np.int64)
    y_true = data['edge_labels'][test_idx].cpu().numpy()

    # Metrics
    acc = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')

    print(f"\n🎯 Test Accuracy: {acc:.4f}")
    print(f"🎯 Test F1-Score: {f1:.4f}")

    # Name classes after the original interaction types when the encoder is available
    encoder = data.get('label_encoder')
    encoder_classes = getattr(encoder, 'classes_', None)
    if encoder_classes is not None and len(encoder_classes) == n_classes:
        target_names = [f'Type {c}' for c in encoder_classes]
    else:
        target_names = [f'Type {i}' for i in class_ids]

    # Passing labels= keeps target_names aligned even if a class is absent from the test split
    print("\n📊 Classification Report:")
    print(classification_report(y_true, y_pred,
                                labels=class_ids,
                                target_names=target_names,
                                zero_division=0))

    conf_mat = confusion_matrix(y_true, y_pred, labels=class_ids)

    return {
        'accuracy': acc,
        'f1_score': f1,
        'predictions': y_pred,
        'true_labels': y_true,
        'confusion_matrix': conf_mat,
    }

## 7. Plotting and Visualization

In [None]:
def plot_training_history(history, save_path=f'images/DDI_training_history_{timestamp}.png'):
    """
    Plot train/validation loss, accuracy and F1 curves side by side and save the figure.

    Args:
        history: Dict with 'train_loss', 'val_loss', 'train_acc', 'val_acc',
            'train_f1' and 'val_f1' sequences (as returned by DDITrainer.train).
        save_path: Output image path. The default is evaluated once, when this
            cell runs, using the notebook-level ``timestamp``.
    """
    panels = [('loss', 'Loss'), ('acc', 'Accuracy'), ('f1', 'F1-Score')]
    fig, axes = plt.subplots(1, len(panels), figsize=(18, 5))

    for ax, (key, title) in zip(axes, panels):
        ax.plot(history[f'train_{key}'], label='Train', linewidth=2)
        ax.plot(history[f'val_{key}'], label='Validation', linewidth=2)
        ax.set_title(title, fontsize=14, fontweight='bold')
        ax.set_xlabel('Epoch')
        ax.set_ylabel(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    # Free the figure so repeated calls do not accumulate open figures
    plt.close(fig)
    print(f"✓ Saved training history to {save_path}")

## Model Training Step

In [None]:
# Prepare data
data = prepare_ddi_data()

# Move every tensor used by the model / trainer to the selected device
tensor_keys = ['node_features', 'edge_index', 'edge_labels', 'adj_mat',
               'train_mask', 'val_mask', 'test_mask']
for key in tensor_keys:
    data[key] = data[key].to(device)

# Model configuration
config = {
    'n_features': data['node_features'].shape[1],
    'n_hidden': 256,
    'n_classes': data['n_classes'],
    'n_heads': 8,
    'dropout': 0.3
}
print('_'*50)
print("\n⛯ Model Configuration:")
for key, value in config.items():
    print(f"• {key}: {value}")
print('_'*50)

# Seed right before model construction so weight init, dropout masks and
# batch shuffling are reproducible even when this cell is re-run on its own.
TRAIN_SEED = 42
torch.manual_seed(TRAIN_SEED)
np.random.seed(TRAIN_SEED)

# Initialize model
model = GAT_DDI(**config)
trainer = DDITrainer(model, device)

# Train
history = trainer.train(
    data['node_features'], data['adj_mat'], data['edge_index'],
    data['edge_labels'], data['train_mask'], data['val_mask'],
    epochs=50, lr=0.001, weight_decay=5e-4, patience=30
)


In [None]:
# Draw the per-epoch loss / accuracy / F1 curves collected by the trainer
plot_training_history(history=history)

In [None]:
# Visualize learned embeddings.
# NOTE(review): no `evaluator` object is created anywhere above this cell
# (ModelEvaluator is only defined further down), so an unconditional call
# raises NameError on Restart & Run All. Run it only when an evaluator exists.
print("Generating embedding visualization...")
evaluator = globals().get('evaluator')
if evaluator is not None and hasattr(evaluator, 'visualize_embeddings'):
    evaluator.visualize_embeddings(data)
else:
    print("⚠ Skipping: no `evaluator` with `visualize_embeddings` has been created yet.")

In [None]:
# Load the best checkpoint written during training and evaluate it on the test edges.
best_model_path = f'models/best_GAT_DDI_{timestamp}.pth'
# weights_only=True refuses to unpickle arbitrary objects. map_location lets a
# checkpoint saved on GPU load on a CPU-only machine.
model.load_state_dict(torch.load(best_model_path, map_location=device, weights_only=True))
results = evaluate_model(model, data, device)

# Save results
results_path = f'results/GAT_DDI_results_{timestamp}.pth'
torch.save({
    'config': config,
    'history': history,
    'results': results,
    'timestamp': timestamp
}, results_path)

# Also export the weights as safetensors (the "Custom testing" cell loads this file).
# safetensors needs contiguous tensors; move to CPU for a device-agnostic file.
safetensors_path = f'models/best_GAT_DDI_{timestamp}.safetensors'
save_file({name: tensor.detach().cpu().contiguous() for name, tensor in model.state_dict().items()},
          safetensors_path)

print("\n✅ All results saved!")
print(f"   • Model: {best_model_path}")
print(f"   • Model (safetensors): {safetensors_path}")
print(f"   • Results: {results_path}")

## 8. Saving Logs and Tracks

In [None]:
# 1. Prepare the new log entry (plain Python types so json.dump accepts it)
current_log = {
    "date": timestamp,
    "graph_details": {
        "Number of edges": int(data['edge_index'].shape[1]),  # edge_index is [2, n_edges]
        "Number of interaction types": int(data['n_classes']),
        "Feature dimension": int(data['node_features'].shape[1])
    },
    "config": config,
    "runtime_info": {
        "device": str(device),
        "parameters": sum(p.numel() for p in model.parameters()),
        "training_edges": int(data['train_mask'].sum()),
        "validation_edges": int(data['val_mask'].sum())
    },
    # Metrics of the LAST training epoch (not necessarily the best checkpoint)
    "metrics": {
        "train_loss": float(history['train_loss'][-1]),
        "train_acc": float(history['train_acc'][-1]),
        "train_f1": float(history.get('train_f1', [-1])[-1]),  # .get() prevents errors if missing
        "val_loss": float(history['val_loss'][-1]),
        "val_acc": float(history['val_acc'][-1]),
        "val_f1": float(history.get('val_f1', [-1])[-1])
    }
}

# 2. File handling: load the existing list of logs, or start a new one
log_file = 'model_evals_log.json'

logs_list = []
if os.path.exists(log_file):
    with open(log_file, 'r') as f:
        try:
            loaded = json.load(f)
            if isinstance(loaded, list):
                logs_list = loaded
        except json.JSONDecodeError:
            # Corrupt / empty file: start over rather than crash the notebook
            logs_list = []

# 3. Insert the new log at the TOP (index 0)
logs_list.insert(0, current_log)

# 4. Save back to file
with open(log_file, 'w') as f:
    json.dump(logs_list, f, indent=4)

print(f"🚀 Log successfully prepended to {log_file}!")


## Custom testing

In [None]:
# Rebuild the architecture with the same config and load the exported weights.
# (GAT_DDI, config, data, device and timestamp come from earlier cells;
#  torch and load_file are imported in the first cell.)
test_model = GAT_DDI(**config)

weights_path = f'models/best_GAT_DDI_{timestamp}.safetensors'
if os.path.exists(weights_path):
    state_dict = load_file(weights_path)
else:
    # Fall back to the torch checkpoint written during training
    weights_path = f'models/best_GAT_DDI_{timestamp}.pth'
    state_dict = torch.load(weights_path, map_location='cpu', weights_only=True)
test_model.load_state_dict(state_dict)

# Prepare for inference: eval() disables dropout
test_model.to(device)
test_model.eval()

# Smoke test on a small sample of held-out edges; scoring every edge is not needed here
sample_edge_index = data['edge_index'][:, data['test_mask']][:, :1024]
with torch.no_grad():
    test_output = test_model(
        data['node_features'],
        data['adj_mat'],
        sample_edge_index
    )

# Quick verification
print(f"✅ Model loaded from: {weights_path}")
print(f"📡 Output Shape: {test_output.shape}")
print(f"🎯 Sample Prediction: {torch.softmax(test_output[0], dim=0)}")


## 9. Model Analysis and Insights

In [None]:
def plot_confusion_matrix(conf_matrix, class_names, save_path=f'images/GAT_confusion_matrix_{timestamp}.png'):
    """
    Draw a confusion matrix as an annotated heatmap, save it, and show it.

    Args:
        conf_matrix: square 2-D array-like; rows are plotted as the true class and
            columns as the predicted class (the layout sklearn's confusion_matrix uses).
        class_names: tick labels, one per row/column of ``conf_matrix``.
        save_path: output image path. Its parent directory is created if missing.
            The default is evaluated once, when this cell runs, so it embeds the
            ``timestamp`` value current at definition time.
    """
    conf_matrix = np.asarray(conf_matrix)
    # Integer counts are annotated with 'd'; a normalised (float) matrix would make
    # fmt='d' raise, so fall back to two decimals for non-integer dtypes.
    cell_fmt = 'd' if np.issubdtype(conf_matrix.dtype, np.integer) else '.2f'

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(conf_matrix, annot=True, fmt=cell_fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'}, ax=ax)
    ax.set_title('Confusion Matrix', fontsize=16, fontweight='bold')
    ax.set_xlabel('Predicted Class', fontsize=12)
    ax.set_ylabel('True Class', fontsize=12)
    ax.tick_params(axis='x', labelrotation=45)
    ax.tick_params(axis='y', labelrotation=0)
    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

# Plot confusion matrix
# NOTE(review): `test_results` is produced by the ModelEvaluator cell further down;
# on a fresh kernel, run that cell first (or move this cell below it).
test_conf_matrix = np.asarray(test_results['confusion_matrix'])

# Without `labels=`, sklearn's confusion_matrix only has rows for the labels that
# occur in y_true or y_pred, in sorted order. Name the rows after exactly those labels
# so the tick count always matches the matrix. If the matrix was built over a fixed
# label set instead (size differs), fall back to 0..n-1.
present_labels = np.unique(np.concatenate([
    np.asarray(test_results['true_labels']),
    np.asarray(test_results['predictions']),
]))
if len(present_labels) != test_conf_matrix.shape[0]:
    present_labels = np.arange(test_conf_matrix.shape[0])
class_names = [f'Class {label}' for label in present_labels]

plot_confusion_matrix(test_conf_matrix, class_names)

In [None]:
class ModelEvaluator:
    """
    Evaluate a trained node-classification GAT.

    Computes test-set metrics and renders a t-SNE view of the first attention
    layer's output.

    NOTE(review): expects a graph object with attribute access (``data.x``,
    ``data.adj_mat``, ``data.y``, ``data.test_mask``). It also expects a model exposing
    ``dropout_layer`` and ``attention1``. Confirm this matches the data and model used
    in this notebook.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        """Move ``model`` to ``device`` and keep references to both."""
        self.model = model.to(device)
        self.device = device

    def test(self, data):
        """
        Run inference and score the nodes selected by ``data.test_mask``.

        Returns:
            dict with keys 'accuracy', 'classification_report' (sklearn output_dict),
            'confusion_matrix', 'predictions' and 'true_labels' (numpy arrays).
        """
        self.model.eval()
        with torch.no_grad():
            out = self.model(data.x, data.adj_mat)

        # Predicted class = index of the largest logit; assumes out is (nodes, classes).
        pred = out[data.test_mask].argmax(dim=1)
        y_true = data.y[data.test_mask].cpu().numpy()
        y_pred = pred.cpu().numpy()

        # One label per output column, so the report and the matrix cover every class.
        # This holds even when a class never appears in the test split; otherwise
        # sklearn silently drops it.
        labels = list(range(out.size(1)))

        return {
            'accuracy': accuracy_score(y_true, y_pred),
            'classification_report': classification_report(
                y_true, y_pred, labels=labels, output_dict=True, zero_division=0),
            'confusion_matrix': confusion_matrix(y_true, y_pred, labels=labels),
            'predictions': y_pred,
            'true_labels': y_true
        }

    def visualize_embeddings(self, data, save_path=f'images/GAT_visualize_embeddings_{timestamp}.png'):
        """
        Project the first attention layer's node embeddings to 2-D with t-SNE.

        The points are coloured by ``data.y``. The figure is saved to ``save_path``
        (parent directory created if missing) and shown.
        """
        # TSNE is not among the notebook's top-level imports.
        from sklearn.manifold import TSNE

        self.model.eval()
        with torch.no_grad():
            # Get embeddings from first layer
            x = self.model.dropout_layer(data.x)
            embeddings = self.model.attention1(x, data.adj_mat).cpu().numpy()

        print("üîÑ Computing t-SNE embeddings...")
        # t-SNE requires perplexity < n_samples; keep the original 30 whenever possible.
        perplexity = min(30, max(1, len(embeddings) - 1))
        tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity)
        embeddings_2d = tsne.fit_transform(embeddings)

        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        fig, ax = plt.subplots(figsize=(12, 8))
        scatter = ax.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1],
                             c=data.y.cpu().numpy(), cmap='tab10', alpha=0.7, s=50)
        fig.colorbar(scatter, ax=ax, label='Node Class')
        ax.set_title('GAT Node Embeddings Visualization (t-SNE)', fontsize=16, fontweight='bold')
        ax.set_xlabel('t-SNE Dimension 1', fontsize=12)
        ax.set_ylabel('t-SNE Dimension 2', fontsize=12)
        ax.grid(True, alpha=0.3)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

# Load best model and evaluate.
# map_location remaps a checkpoint saved on GPU onto the current device, so this
# also works on a CPU-only machine.
# NOTE(review): this file name differs from the DDI checkpoint names used elsewhere
# (e.g. models/best_GAT_DDI_{timestamp}.pth); confirm which file training wrote.
model.load_state_dict(torch.load(f'best_GAT_model_{timestamp}.pth', map_location=device))

# Alternative: load the safetensors export instead
# weights = load_file(f'models/best_GAT_DDI_{timestamp}.safetensors')
# model.load_state_dict(weights)

evaluator = ModelEvaluator(model, device)

test_results = evaluator.test(data)

print(f"üéØ Test Result Accuracy: {test_results['accuracy']:.4f}")
print("\nüìä Classification Report:")
print(classification_report(test_results['true_labels'], test_results['predictions']))

In [None]:
# Analyze model performance
def analyze_results(history, test_results, net=None):
    """
    Print a summary of training history, test metrics, per-class scores and model size.

    Args:
        history: dict of training curves. Both key conventions used in this notebook
            are accepted: 'train_accuracies'/'val_accuracies'/'train_losses' or
            'train_acc'/'val_acc'/'train_loss'. 'best_val_acc' falls back to the
            maximum validation accuracy. 'training_time' is optional.
        test_results: dict with 'accuracy' and 'classification_report' (sklearn
            output_dict format, per-class entries keyed by the label as a string).
        net: model whose parameters are counted; defaults to the notebook-global
            ``model`` (the original behaviour).
    """
    def _series(*keys):
        # Return the first history entry present under any of the given names.
        for key in keys:
            if key in history:
                return history[key]
        raise KeyError(f"history contains none of the keys {keys}")

    train_acc = _series('train_accuracies', 'train_acc')
    val_acc = _series('val_accuracies', 'val_acc')
    train_loss = _series('train_losses', 'train_loss')
    best_val_acc = history['best_val_acc'] if 'best_val_acc' in history else max(val_acc)
    training_time = history.get('training_time')

    print("üìà MODEL PERFORMANCE ANALYSIS")
    print("="*60)

    # Training metrics
    print(f"\nTraining Metrics:")
    print(f"‚Ä¢ Final Training Accuracy: {train_acc[-1]:.4f}")
    print(f"‚Ä¢ Final Validation Accuracy: {val_acc[-1]:.4f}")
    print(f"‚Ä¢ Best Validation Accuracy: {best_val_acc:.4f}")
    if training_time is not None:
        print(f"‚Ä¢ Training Time: {training_time:.2f} seconds")
    else:
        print("‚Ä¢ Training Time: n/a")
    print(f"‚Ä¢ Total Epochs: {len(train_loss)}")

    # Test metrics
    print(f"Test Metrics:")
    print(f"‚Ä¢ Test Accuracy: {test_results['accuracy']:.4f}")

    # Per-class performance: iterate over the class entries actually in the report
    # (digit keys). Summary rows such as 'accuracy' / 'macro avg' are skipped, and
    # classes sklearn omitted cannot cause a KeyError.
    report = test_results['classification_report']
    class_keys = sorted((key for key in report if str(key).isdigit()), key=int)
    print(f"\nüìä Per-Class Performance:")
    for key in class_keys:
        scores = report[key]
        print(f"   ‚Ä¢ Class {key}: Precision={scores['precision']:.3f}, "
              f"Recall={scores['recall']:.3f}, F1={scores['f1-score']:.3f}")

    # Model complexity
    if net is None:
        net = model  # notebook-global model, as in the original implementation
    total_params = sum(p.numel() for p in net.parameters())
    trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)

    print(f"\nüîß Model Complexity:")
    print(f"   ‚Ä¢ Total Parameters: {total_params:,}")
    print(f"   ‚Ä¢ Trainable Parameters: {trainable_params:,}")
    # Size estimate assumes 4 bytes (float32) per parameter.
    print(f"   ‚Ä¢ Model Size: ~{total_params * 4 / 1024 / 1024:.2f} MB")

# Print the full performance breakdown for the trained model
analyze_results(history=history, test_results=test_results)

## 10. Save Results and Model

In [None]:
# Save comprehensive results
# NOTE(review): this cell reads attribute-style graph fields (data.x, data.edge_index)
# and `dataset`, while the DDI cells above use data['node_features'] etc.; confirm
# which data object is meant here.
results_summary = {
    'model_config': config,
    'training_history': history,
    'test_results': {
        'accuracy': test_results['accuracy'],
        'classification_report': test_results['classification_report']
    },
    'dataset_info': {
        'name': 'Cora',
        'num_nodes': data.x.size(0),
        'num_edges': data.edge_index.size(1),
        'num_features': data.x.size(1),
        'num_classes': dataset.num_classes
    }
}

results_path = 'models/gat_complete_results.pth'
os.makedirs(os.path.dirname(results_path), exist_ok=True)  # torch.save does not create dirs
torch.save(results_summary, results_path)
# Report the paths actually used: the summary above, and the checkpoint that the
# evaluation cell loaded.
print(f"\n Results saved to '{results_path}'üóπ")
print(f"Best model saved to 'best_GAT_model_{timestamp}.pth' üóπ")

# History key names differ between training loops in this notebook; fall back
# gracefully instead of raising KeyError.
training_time = history.get('training_time')
best_val_acc = history.get('best_val_acc')
if best_val_acc is None:
    val_series = history.get('val_accuracies', history.get('val_acc'))
    best_val_acc = max(val_series) if val_series else float('nan')
time_text = f"{training_time:.2f} seconds" if training_time is not None else "n/a"

print("_"*50)
print(f"üèÜ Final Test Accuracy: {test_results['accuracy']:.4f}")
print(f"‚è±Ô∏è Total Training Time: {time_text}")
print(f"üéØ Best Validation Accuracy: {best_val_acc:.4f}")