# Graph Attention Networks (GAT) - Complete Professional Implementation

This notebook provides a comprehensive implementation of Graph Attention Networks (GAT) based on the paper "Graph Attention Networks" by Veličković et al.

### Key Features:
- ✅ Multi-head attention mechanism
- ✅ Professional training pipeline with early stopping
- ✅ Comprehensive evaluation metrics
- ✅ Visualization capabilities
- ✅ Error handling and logging

### Architecture Overview:
GATs operate on graph data, where nodes represent entities and edges represent relationships. The attention mechanism lets each node weight its neighbors by relevance, similar to transformers but restricted to the edges of the graph.
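
As a rough illustration of that idea, the sketch below computes attention weights for a single node of a toy graph using only the standard library. The scalar features and the values of `a_src` and `a_dst` are made up for the example; a real GAT layer uses feature vectors and learned parameters.

```python
import math


def leaky_relu(x, slope=0.2):
    """LeakyReLU on a single number."""
    return x if x > 0 else slope * x


def attention_weights(scores):
    """Numerically stable softmax over one node's neighbor scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Toy graph: node 0 receives messages from neighbors 1, 2 and 3.
features = {0: 0.5, 1: 1.0, 2: -1.0, 3: 2.0}
neighbors_of_0 = [1, 2, 3]
a_src, a_dst = 0.8, 0.3  # hypothetical attention parameters

# One raw score per incoming edge, then normalize across the neighborhood.
raw = [leaky_relu(a_src * features[j] + a_dst * features[0]) for j in neighbors_of_0]
alpha = attention_weights(raw)

# The new representation of node 0 is the attention-weighted sum of its neighbors.
new_h0 = sum(w * features[j] for w, j in zip(alpha, neighbors_of_0))
print(alpha, new_h0)
```

The weights always sum to 1 for each node, so the output stays within the range of the neighbor features.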

```text
# Custom implementation - EDGE CLASSIFICATION
class GraphAttentionLayer:
    - Manual attention computation
    - Sparse edge-based processing
    - Memory-efficient scatter operations
    - Gradient checkpointing

class GAT_DDI:
    - 3 GAT layers (256 → 256 → 128 hidden dims)
    - Edge MLP: 256 → 128 → 64 → 86 classes
    - Processes: node_features → embeddings → edge_features → 86-way classification
```
```text
# For EACH batch of 7000 edges:
1. Forward pass through 3 custom GAT layers
2. Extract source/dest node embeddings
3. Concatenate edge features
4. Pass through 4-layer MLP
5. Compute CrossEntropyLoss over 86 classes
6. Backward pass with gradient checkpointing
7. Repeat for ~27 batches per epoch

# Per epoch: ~27 batches × complex operations = SLOW
```

## 1. Importing Required Libraries

In [1]:
# Libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.nn import GATv2Conv
from torch_geometric.loader import NeighborLoader
from torch.amp import autocast, GradScaler

from safetensors.torch import save_file, load_file
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score, f1_score
from sklearn.preprocessing import LabelEncoder
from sklearn.manifold import TSNE
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator, Descriptors
import networkx as nx
from datetime import datetime
import time
import warnings
import os
import gc
import json
from collections import Counter

In [2]:
warnings.filterwarnings('ignore')

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Memory optimization settings for GPU, Aggressive memory management
if device.type == 'cuda':
    torch.cuda.empty_cache()
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    
    # Force garbage collection
    gc.collect()
    
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
    print("Memory optimization enabled ✓")

# Timestamp used in the names of saved files
timestamp = datetime.now().strftime("%d_%b_%H-%M")

# Make sure the output directories exist
required_directories = ['images', 'models', 'results']
for folder in required_directories:
    if not os.path.exists(folder):
        print(f"✘ Directory `{folder}/` not found \nCreating...")
        os.makedirs(folder)
    else:
        print(f"✓ Directory `{folder}/` exists")

Using device: cuda
GPU: NVIDIA GeForce GTX 1650
GPU Memory: 4.00 GB
Memory optimization enabled ✓
✓ Directory `images/` exists
✓ Directory `models/` exists
✓ Directory `results/` exists


## 2. Data Reading

In [3]:
# Load DDI data
print("EXPLORING DRUG-DRUG INTERACTION DATASET")

# Load DDI interactions
ddi_df = pd.read_csv('dataset/drugdata/ddis.csv')
print(f"\n📊 DDI Dataset Shape: {ddi_df.shape}")
print(f"Columns: {ddi_df.columns.tolist()}")
print(f"\n🔬 Interaction Types Distribution:")
print(ddi_df['type'].value_counts())
print(f"\n📋 Sample DDI data:")
print(ddi_df.head(10))

# Load drug SMILES
smiles_df = pd.read_csv('dataset/drugdata/drug_smiles.csv')
print(f"\n💊 Drug SMILES Dataset Shape: {smiles_df.shape}")
print(f"Columns: {smiles_df.columns.tolist()}")
print(f"\n📋 Sample SMILES data:")
print(smiles_df.head(10))

# Get unique drugs
unique_drugs_ddi = set(ddi_df['d1'].unique()) | set(ddi_df['d2'].unique())
print(f"\n📈 Statistics:")
print(f"• Total DDI pairs: {len(ddi_df)}")
print(f"• Unique drugs in DDI: {len(unique_drugs_ddi)}")
print(f"• Drugs with SMILES: {len(smiles_df)}")
print(f"• Interaction type 0: {(ddi_df['type'] == 0).sum()}")
print(f"• Interaction type 1: {(ddi_df['type'] == 1).sum()}")

# Check overlap
drugs_with_smiles = set(smiles_df['drug_id'].unique())
overlap = unique_drugs_ddi & drugs_with_smiles
print(f"• Drugs with both DDI and SMILES: {len(overlap)}")
print(f"• Coverage: {len(overlap)/len(unique_drugs_ddi)*100:.2f}%")

EXPLORING DRUG-DRUG INTERACTION DATASET

📊 DDI Dataset Shape: (191808, 4)
Columns: ['d1', 'd2', 'type', 'Neg samples']

🔬 Interaction Types Distribution:
type
48    60751
46    34360
72    23779
74     9470
59     8397
      ...  
42       11
61       11
51       10
25        7
41        6
Name: count, Length: 86, dtype: int64

📋 Sample DDI data:
        d1       d2  type Neg samples
0  DB04571  DB00460     0   DB01579$t
1  DB00855  DB00460     0   DB01178$t
2  DB09536  DB00460     0   DB06626$t
3  DB01600  DB00460     0   DB01588$t
4  DB09000  DB00460     0   DB06196$t
5  DB11630  DB00460     0   DB00744$t
6  DB00553  DB00460     0   DB06413$t
7  DB06261  DB00460     0   DB00876$t
8  DB01878  DB00460     0   DB09267$t
9  DB00140  DB00460     0   DB01204$t

💊 Drug SMILES Dataset Shape: (1706, 2)
Columns: ['drug_id', 'smiles']

📋 Sample SMILES data:
   drug_id                                             smiles
0  DB04571                CC1=CC2=CC3=C(OC(=

## Data Preprocessing

### Data Sanity & Integrity Checks

This compact utility validates the two CSV files. It checks for missing or invalid values, SMILES validity (via RDKit), duplicate drug pairs, and how many DDI drugs have a SMILES entry. Run it before feature extraction and graph building.

In [4]:
# Data Sanity utilities


def sanity_check_dataset(ddi_path='dataset/drugdata/ddis.csv',
                         smiles_path='dataset/drugdata/drug_smiles.csv',
                         save_report=True):
    """Run quick, informative checks on DDI and SMILES csv files.

    Returns a report dict. Does NOT modify files by default.
    """
    report = {'files': {}}

    # --- DDI file checks ---
    try:
        ddi_df = pd.read_csv(ddi_path)
    except Exception as e:
        report['files']['ddi'] = {'error': str(e)}
        print(f"✗ Failed to read {ddi_path}: {e}")
        return report

    # Basic checks
    report['files']['ddi'] = {
        'path': ddi_path,
        'shape': ddi_df.shape,
        'columns': ddi_df.columns.tolist(),
        'missing_per_column': ddi_df.isnull().sum().to_dict(),
        'duplicate_pairs': int(ddi_df.duplicated(subset=['d1','d2']).sum()) if set(['d1','d2']).issubset(ddi_df.columns) else None
    }

    # Validate required columns
    req = ['d1', 'd2', 'type']
    for c in req:
        if c not in ddi_df.columns:
            print(f"✗ Column `{c}` missing in {ddi_path}")

    # Interaction type checks
    if 'type' in ddi_df.columns:
        report['files']['ddi']['type_unique_values'] = ddi_df['type'].unique().tolist()
        # Count entries that are not numeric (e.g. unexpected strings).
        # Negate the mask before summing; `~mask.sum()` would bit-flip the count instead.
        is_numeric = ddi_df['type'].apply(
            lambda x: isinstance(x, (int, float, np.integer, np.floating))
        ).astype(bool)
        report['files']['ddi']['type_non_numeric'] = int((~is_numeric).sum())

    # --- SMILES file checks ---
    try:
        smiles_df = pd.read_csv(smiles_path)
    except Exception as e:
        report['files']['smiles'] = {'error': str(e)}
        print(f"✗ Failed to read {smiles_path}: {e}")
        return report

    report['files']['smiles'] = {
        'path': smiles_path,
        'shape': smiles_df.shape,
        'columns': smiles_df.columns.tolist(),
        'missing_per_column': smiles_df.isnull().sum().to_dict()
    }

    if not set(['drug_id','smiles']).issubset(smiles_df.columns):
        print('✗ Expect `drug_id` and `smiles` columns in SMILES file')
    else:
        # Validate SMILES with RDKit
        valid_mask = []
        invalid_examples = []
        for drug_id, sm in zip(smiles_df['drug_id'], smiles_df['smiles'].fillna('')):
            mol = Chem.MolFromSmiles(str(sm))
            # RDKit can return an atom-less molecule for an empty string,
            # so also require at least one atom
            is_valid = mol is not None and mol.GetNumAtoms() > 0
            valid_mask.append(is_valid)
            if not is_valid and len(invalid_examples) < 10:
                # Drug IDs are strings such as 'DB04571'; int() would fail on them
                invalid_examples.append((str(drug_id), sm))

        n_invalid = int(sum([not v for v in valid_mask]))
        report['files']['smiles']['n_invalid_smiles'] = n_invalid
        report['files']['smiles']['invalid_examples'] = invalid_examples

    # --- Cross-file checks ---
    ddi_drugs = set(ddi_df['d1'].unique()) | set(ddi_df['d2'].unique())
    smiles_drugs = set(smiles_df['drug_id'].unique()) if 'drug_id' in smiles_df.columns else set()
    missing_smiles_for_ddi = list(ddi_drugs - smiles_drugs)
    report['cross'] = {
        'num_unique_drugs_ddi': len(ddi_drugs),
        'num_drugs_with_smiles': len(smiles_drugs),
        'num_missing_smiles_for_ddi': len(missing_smiles_for_ddi),
        'sample_missing_drugs': missing_smiles_for_ddi[:20]
    }

    # Basic distribution info for interaction types
    if 'type' in ddi_df.columns:
        report['files']['ddi']['type_counts'] = ddi_df['type'].value_counts().to_dict()

    # Save report
    if save_report:
        out_path = f'data_sanity_report_{timestamp}.json'
        with open(out_path, 'w') as f:
            json.dump(report, f, indent=2)
        print(f"✓ Data sanity report saved to {out_path}")

    print('\n'.join([
        f"• DDI: {report['files']['ddi'].get('shape')} | Missing values per column: {report['files']['ddi']['missing_per_column']}",
        f"• SMILES: {report['files']['smiles'].get('shape')} | Invalid SMILES: {report['files']['smiles'].get('n_invalid_smiles')}",
        f"• Missing SMILES for DDI drugs: {report['cross']['num_missing_smiles_for_ddi']}"
    ]))

    return report


# Quick helper to drop rows with invalid SMILES (only when you want to permanently fix data)
def drop_invalid_smiles(smiles_df, inplace=False):
    df = smiles_df if inplace else smiles_df.copy()
    valid_rows = df['smiles'].apply(lambda s: Chem.MolFromSmiles(str(s)) is not None)
    dropped = int((~valid_rows).sum())
    df = df[valid_rows].reset_index(drop=True)
    print(f"Removed {dropped} rows with invalid SMILES")
    return df
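
The cross-file coverage check in `sanity_check_dataset` comes down to set operations on drug IDs. The sketch below shows that logic on toy data with the standard library only; `coverage_report` and the IDs `D1` to `D9` are made up for illustration and are not part of the notebook.

```python
def coverage_report(ddi_pairs, smiles_ids):
    """Summarize how many drugs that appear in DDI pairs also have a SMILES entry."""
    ddi_drugs = {drug for pair in ddi_pairs for drug in pair}
    covered = ddi_drugs & set(smiles_ids)
    missing = sorted(ddi_drugs - covered)
    pct = 100.0 * len(covered) / len(ddi_drugs) if ddi_drugs else 0.0
    return {
        "n_ddi_drugs": len(ddi_drugs),
        "n_covered": len(covered),
        "coverage_pct": pct,
        "missing": missing,
    }


# Toy data: D4 appears in an interaction but has no SMILES; D9 has SMILES but no interaction.
toy_pairs = [("D1", "D2"), ("D2", "D3"), ("D4", "D1")]
toy_smiles_ids = ["D1", "D2", "D3", "D9"]
report = coverage_report(toy_pairs, toy_smiles_ids)
print(report)
```

Drugs listed under `missing` get no feature vector later, so every interaction involving them is dropped when the graph is built.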

## 3. Feature and Pattern Extraction

In [5]:
class DrugFeatureExtractor:
    """Extract molecular features from SMILES strings using RDKit"""
    
    def __init__(self):
        self.feature_names = []
    
    def smiles_to_features(self, smiles):
        """
        Convert SMILES to molecular fingerprint and descriptors
        Args: smiles: SMILES string
        Returns: Feature vector
        """
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None
            # Morgan fingerprint
            mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=1024)
            fp = mfpgen.GetFingerprint(mol)
            fp_array = np.array(fp)
            
            # Molecular descriptors
            descriptors = [
                Descriptors.MolWt(mol),
                Descriptors.MolLogP(mol),
                Descriptors.NumHDonors(mol),
                Descriptors.NumHAcceptors(mol),
                Descriptors.TPSA(mol),
                Descriptors.NumRotatableBonds(mol),
                Descriptors.NumAromaticRings(mol),
                Descriptors.FractionCSP3(mol),
            ]
            
            # Combine fingerprint and descriptors
            features = np.concatenate([fp_array, descriptors])
            
            return features
            
        except Exception as e:
            print(f"Error processing SMILES: {e}")
            return None
    
    def extract_all_features(self, smiles_df):
        """
        Extract features for all drugs        
        Args: smiles_df: DataFrame with drug_id and smiles columns
        Returns: Dictionary mapping drug_id to feature vector
        """
        print("\n"+"_"*45)
        print("Extracting molecular features from SMILES...")
        print("-"*45)

        drug_features = {}
        failed = 0
        
        for idx, row in smiles_df.iterrows():
            drug_id = row['drug_id']
            smiles = row['smiles']
            
            features = self.smiles_to_features(smiles)
            
            if features is not None:
                drug_features[drug_id] = features
            else:
                failed += 1
            
            if (idx + 1) % 200 == 0:
                print(f"{idx + 1}/{len(smiles_df)} drugs processed ")
        print("-"*45)
        print(f"✓ Extracted features for {len(drug_features)} drugs")
        if failed > 0:
            print(f"⚠ Failed to process {failed} drugs")
        print("_"*45 + "\n")
        return drug_features

## 4. GAT Architecture Building

In [6]:
class DDIGraphBuilder:
    """Build graph structure from DDI data"""
    def __init__(self, ddi_df, drug_features):
        self.ddi_df = ddi_df
        self.drug_features = drug_features
        self.drug_to_idx = {}
        self.idx_to_drug = {}
        self.label_encoder = LabelEncoder()
        
    def build_graph(self):
        """
        Build graph from DDI data
        Returns: Node features, edge indices, edge labels, number of classes
        """
        print("Building DDI graph...")
        
        # Get unique drugs
        unique_drugs = sorted(list(self.drug_features.keys()))
        self.drug_to_idx = {drug: idx for idx, drug in enumerate(unique_drugs)}
        self.idx_to_drug = {idx: drug for drug, idx in self.drug_to_idx.items()}
        
        print(f"• Number of nodes (drugs): {len(unique_drugs)}")
        
        # Create node feature matrix
        feature_dim = len(list(self.drug_features.values())[0])
        node_features = np.zeros((len(unique_drugs), feature_dim))
        
        for drug, idx in self.drug_to_idx.items():
            node_features[idx] = self.drug_features[drug]
        
        # Normalize features
        node_features = (node_features - node_features.mean(axis=0)) / (node_features.std(axis=0) + 1e-8)
        
        # Build edges from DDI data
        edge_list = []
        edge_labels = []
        
        for _, row in self.ddi_df.iterrows():
            d1, d2, interaction_type = row['d1'], row['d2'], row['type']
            
            if d1 in self.drug_to_idx and d2 in self.drug_to_idx:
                idx1 = self.drug_to_idx[d1]
                idx2 = self.drug_to_idx[d2]
                
                # Add bidirectional edges
                edge_list.append([idx1, idx2])
                edge_list.append([idx2, idx1])
                edge_labels.append(interaction_type)
                edge_labels.append(interaction_type)
        
        edge_index = torch.tensor(edge_list, dtype=torch.long).t()
        
        # Encode labels
        edge_labels = self.label_encoder.fit_transform(edge_labels)
        edge_labels = torch.tensor(edge_labels, dtype=torch.long)
        
        print(f"• Number of edges: {len(edge_list)}")
        print(f"• Number of interaction types: {len(self.label_encoder.classes_)}")
        print(f"• Feature dimension: {feature_dim}")
        
        return (
            torch.FloatTensor(node_features),
            edge_index,
            edge_labels,
            len(self.label_encoder.classes_)
        )
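
The `+ 1e-8` in the normalization step of `build_graph` matters because some feature columns can be constant across all drugs (for example, a fingerprint bit that is never set), which gives a standard deviation of zero. The sketch below reproduces the same column-wise z-score on a toy matrix in plain Python; `zscore_columns` is a hypothetical helper written for this illustration.

```python
def zscore_columns(rows, eps=1e-8):
    """Column-wise z-score with a small epsilon, mirroring (x - mean) / (std + eps)."""
    n = len(rows)
    n_cols = len(rows[0])
    means = [sum(r[c] for r in rows) / n for c in range(n_cols)]
    # Population standard deviation, which is also NumPy's default for .std()
    stds = [
        (sum((r[c] - means[c]) ** 2 for r in rows) / n) ** 0.5
        for c in range(n_cols)
    ]
    return [
        [(r[c] - means[c]) / (stds[c] + eps) for c in range(n_cols)]
        for r in rows
    ]


# Column 0 is constant (std = 0); column 1 varies.
toy = [[0.0, 1.0], [0.0, 3.0], [0.0, 5.0]]
normalized = zscore_columns(toy)
print(normalized)
```

Without the epsilon, the constant column would produce a division by zero; with it, that column simply becomes all zeros.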

In [7]:
class GraphAttentionLayer(nn.Module):
    """
    Graph Attention Layer Implementation 
    Args:
        in_features: Number of input features per node (F)
        out_features: Number of output features per node (F')
        n_heads: Number of attention heads (K)
        is_concat: Whether to concatenate or average multi-head results
        dropout: Dropout probability for regularization
        leaky_relu_negative_slope: Negative slope for LeakyReLU activation
    """
    
    def __init__(self, in_features: int, out_features: int, n_heads: int=8, 
                 is_concat: bool = True, dropout: float = 0.6, 
                 leaky_relu_negative_slope: float = 0.2):
        super().__init__()
        
        self.is_concat = is_concat
        self.n_heads = n_heads
        
        # Calculate dimensions per head
        if is_concat:
            assert out_features % n_heads == 0, "out_features must be divisible by n_heads when concatenating"
            self.n_hidden = out_features // n_heads
        else:
            self.n_hidden = out_features
        
        # Linear transformation: W in the paper
        self.linear = nn.Linear(in_features, self.n_hidden * n_heads, bias=False)
        
        # Attention parameters (a^T) (one per head)
        self.attn_src = nn.Parameter(torch.Tensor(1, n_heads, self.n_hidden))
        self.attn_dst = nn.Parameter(torch.Tensor(1, n_heads, self.n_hidden))
        
        # Activation and normalization
        self.activation = nn.LeakyReLU(negative_slope=leaky_relu_negative_slope)
        self.softmax = nn.Softmax(dim=1)
        self.dropout = nn.Dropout(dropout)
        
        # Initialize weights using Xavier uniform
        self._init_weights()
    
    def _init_weights(self):
        """Initialize layer weights using Xavier uniform initialization"""
        nn.init.xavier_uniform_(self.linear.weight)
        nn.init.xavier_uniform_(self.attn_src)
        nn.init.xavier_uniform_(self.attn_dst)
    
    def forward(self, h: torch.Tensor, adj_mat: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the attention layer
        Args:
            h: Node embeddings [n_nodes, in_features]
            adj_mat: Adjacency matrix [n_nodes, n_nodes, 1] (a single mask shared by all heads)
        Returns: Updated node embeddings [n_nodes, out_features]
        """
        n_nodes = h.shape[0]
        device = h.device
        
        # Step 1: Linear transformation g_i = W * h_i for each head
        g = self.linear(h).view(n_nodes, self.n_heads, self.n_hidden)

        # Extract edge indices from adjacency matrix (only non-zero entries)
        adj_2d = adj_mat.squeeze(-1)
        edge_index = adj_2d.nonzero(as_tuple=False).t()

        if edge_index.shape[1] == 0:
            # No edges case
            if self.is_concat:
                return torch.zeros(n_nodes, self.n_heads * self.n_hidden, device=device)
            else:
                return torch.zeros(n_nodes, self.n_hidden, device=device)
        
        src_nodes = edge_index[0]
        dst_nodes = edge_index[1]
        
        # Get features for source and destination nodes
        g_src = g[src_nodes]  # [n_edges, n_heads, n_hidden]
        g_dst = g[dst_nodes]  # [n_edges, n_heads, n_hidden]
        
        # Compute attention scores (edge-wise, memory efficient)
        e_src = (g_src * self.attn_src).sum(dim=-1)  # [n_edges, n_heads]
        e_dst = (g_dst * self.attn_dst).sum(dim=-1)  # [n_edges, n_heads]
        e = self.activation(e_src + e_dst)  # [n_edges, n_heads]
        
        # Softmax normalization per destination node
        # Group by destination node for proper normalization
        alpha = torch.zeros_like(e)
        # Process each head separately to save memory
        for head in range(self.n_heads):
            e_head = e[:, head]
            
            # Use scatter operations instead of loops
            # Create index for scatter
            max_dst = dst_nodes.max().item() + 1
            
            # Compute max for numerical stability
            max_vals = torch.full((max_dst,), float('-inf'), device=device)
            max_vals.scatter_reduce_(0, dst_nodes, e_head, reduce='amax', include_self=False)
            
            # Subtract max and exp
            e_stable = e_head - max_vals[dst_nodes]
            exp_e = torch.exp(e_stable)
            
            # Sum exp values per destination
            sum_exp = torch.zeros(max_dst, device=device)
            sum_exp.scatter_add_(0, dst_nodes, exp_e)
            
            # Normalize
            alpha[:, head] = exp_e / (sum_exp[dst_nodes] + 1e-16)
        
        # Aggregate features using attention weights
        out = torch.zeros(n_nodes, self.n_heads, self.n_hidden, device=device)
        
        # Efficient aggregation using scatter_add
        for head in range(self.n_heads):
            # Weighted features for this head
            weighted_features = g_src[:, head, :] * alpha[:, head].unsqueeze(-1)
            
            # Aggregate to destination nodes
            out[:, head, :].scatter_add_(0, 
                                         dst_nodes.unsqueeze(-1).expand(-1, self.n_hidden),
                                         weighted_features)
        
        # Concatenate or average heads
        if self.is_concat:
            return out.reshape(n_nodes, self.n_heads * self.n_hidden)
        else:
            return out.mean(dim=1)
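
For reference, the edge-wise computation in `GraphAttentionLayer.forward` can be restated in the notation of the GAT paper. Here $g_i = W h_i$ is the transformed feature of node $i$ for one head, $a_{\text{src}}$ and $a_{\text{dst}}$ stand for that head's slices of `attn_src` and `attn_dst`, and $\mathcal{N}(i)$ is the set of source nodes with an edge into node $i$. The symbols are chosen here for readability; the paper uses a single vector $a$ applied to the concatenation of the two transformed features, which is equivalent to the two dot products below.

```latex
e_{ji} = \mathrm{LeakyReLU}\!\left(a_{\text{src}}^{\top} g_j + a_{\text{dst}}^{\top} g_i\right),
\qquad
\alpha_{ji} = \frac{\exp(e_{ji})}{\sum_{k \in \mathcal{N}(i)} \exp(e_{ki})},
\qquad
h_i' = \sum_{j \in \mathcal{N}(i)} \alpha_{ji}\, g_j
```

The per-head outputs $h_i'$ are then concatenated when `is_concat=True` and averaged otherwise. The layer applies no nonlinearity to $h_i'$ itself.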


In [8]:
class GAT_DDI(nn.Module):
    """Memory-efficient GAT with Mixed Precision support using PyTorch Geometric's GATv2Conv"""
    
    def __init__(self, n_features, n_hidden, n_classes, n_heads=8, dropout=0.3):
        super().__init__()
        
        self.dropout = dropout
        # Replace custom GAT with optimized GATv2Conv
        self.gat1 = GATv2Conv(n_features, n_hidden, heads=n_heads, dropout=dropout, concat=True)
        self.gat2 = GATv2Conv(n_hidden * n_heads, n_hidden, heads=n_heads, dropout=dropout, concat=True)
        self.gat3 = GATv2Conv(n_hidden * n_heads, n_hidden // 2, heads=n_heads, dropout=dropout, concat=True)
        
        # Input: concatenated embeddings from src and dst nodes
        edge_input_dim = (n_hidden // 2) * n_heads * 2  # *2 because we concatenate src and dst
        
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_input_dim, n_hidden),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden, n_hidden // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_hidden // 2, n_classes)
        )
        
        self.dropout_layer = nn.Dropout(dropout)
    
    def forward(self, x, edge_index, edge_index_for_prediction):
        """
        Forward pass using edge_index directly (no adjacency matrix)
        Args:
            x: Node features [n_nodes, n_features]
            edge_index: Graph connectivity [2, n_edges] - for message passing
            edge_index_for_prediction: Edges to classify [2, n_prediction_edges]
        Returns: Edge predictions [n_prediction_edges, n_classes]
        """
        # Cast inputs to float32 before the mixed-precision forward pass
        x = x.float()
        # Layer 1
        x = self.dropout_layer(x)
        x = F.elu(self.gat1(x, edge_index))
        # Layer 2
        x = self.dropout_layer(x)
        x = F.elu(self.gat2(x, edge_index))
        # Layer 3
        x = self.dropout_layer(x)
        x = F.elu(self.gat3(x, edge_index))
        # Edge classification
        src = edge_index_for_prediction[0]
        dst = edge_index_for_prediction[1]
        edge_features = torch.cat([x[src], x[dst]], dim=1)
        
        return self.edge_mlp(edge_features)
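
Because every GAT layer concatenates its heads, the layer widths grow with `n_heads`. The sketch below traces the widths implied by the `GAT_DDI` constructor in plain Python. `gat_ddi_widths` is a hypothetical helper, and the example values are assumptions: 1032 input features (1024 fingerprint bits plus 8 descriptors), 86 classes as in the raw dataset, and `n_hidden=128`, `n_heads=4` as in the configuration used for training.

```python
def gat_ddi_widths(n_features, n_hidden, n_classes, n_heads):
    """Trace the tensor widths through the three GAT layers and the edge MLP."""
    gat1_out = n_hidden * n_heads          # heads are concatenated
    gat2_out = n_hidden * n_heads
    gat3_out = (n_hidden // 2) * n_heads
    edge_in = gat3_out * 2                 # source and destination embeddings concatenated
    return {
        "gat_layers": [n_features, gat1_out, gat2_out, gat3_out],
        "edge_mlp": [edge_in, n_hidden, n_hidden // 2, n_classes],
    }


widths = gat_ddi_widths(n_features=1032, n_hidden=128, n_classes=86, n_heads=4)
print(widths)
```

With these values the node embeddings are 256-dimensional and the edge MLP starts from a 512-dimensional input.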

## 5. Training Pipeline

Our training pipeline includes:
- Early stopping to prevent overfitting
- Learning rate scheduling
- Comprehensive logging
- Model checkpointing
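
The training loop itself is not part of this section. As a minimal sketch of the early-stopping idea only, the snippet below tracks the best validation loss and signals a stop after a fixed number of epochs without improvement. `EarlyStopping`, `patience` and `min_delta` are hypothetical names chosen for this illustration, and the loss values are made up.

```python
class EarlyStopping:
    """Stop training when validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.bad_epochs = 0

    def step(self, val_loss, epoch):
        """Record one validation result; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Simulated validation losses: the best value is at epoch 2, then three worse epochs.
val_losses = [1.00, 0.80, 0.75, 0.76, 0.77, 0.78, 0.70]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(loss, epoch):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_epoch, stopper.best_loss)
```

In a real loop, the point where `best_epoch` is updated is also the natural place to save a checkpoint, so the best weights can be restored after stopping.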

In [None]:
def prepare_ddi_data():
    """Load and prepare DDI data"""
    print("_"*40)
    print("LOADING DRUG-DRUG INTERACTION DATA")
    print("-"*40)
    
    # Load data
    ddi_df = pd.read_csv('dataset/drugdata/ddis.csv')
    smiles_df = pd.read_csv('dataset/drugdata/drug_smiles.csv')
    
    print(f"✓ Loaded {len(ddi_df)} DDI pairs")
    print(f"✓ Loaded {len(smiles_df)} drug SMILES")
    print("_"*40)
    
    # Extract features
    feature_extractor = DrugFeatureExtractor()
    drug_features = feature_extractor.extract_all_features(smiles_df)
    
    # Build graph
    graph_builder = DDIGraphBuilder(ddi_df, drug_features)
    node_features, edge_index, edge_labels, n_classes = graph_builder.build_graph()
    
    # Create adjacency matrix
    n_nodes = node_features.shape[0]
    adj_mat = torch.zeros(n_nodes, n_nodes, 1)
    adj_mat[edge_index[0], edge_index[1], 0] = 1
    
    # Add self-loops
    adj_mat[range(n_nodes), range(n_nodes), 0] = 1
    
    # Split data
    n_edges = edge_index.shape[1]
    indices = np.arange(n_edges)
    
    train_idx, temp_idx = train_test_split(indices, test_size=0.3, random_state=42)
    val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, random_state=42)
    
    train_mask = torch.zeros(n_edges, dtype=torch.bool)
    val_mask = torch.zeros(n_edges, dtype=torch.bool)
    test_mask = torch.zeros(n_edges, dtype=torch.bool)
    
    train_mask[train_idx] = True
    val_mask[val_idx] = True
    test_mask[test_idx] = True
    
    print(f"\n🛈 Data Split:")
    print(f"• Train: {train_mask.sum()} edges ({train_mask.sum()/n_edges*100:.1f}%)")
    print(f"• Val: {val_mask.sum()} edges ({val_mask.sum()/n_edges*100:.1f}%)")
    print(f"• Test: {test_mask.sum()} edges ({test_mask.sum()/n_edges*100:.1f}%)")
    
    return {
        'node_features': node_features,
        'edge_index': edge_index,
        'edge_labels': edge_labels,
        'adj_mat': adj_mat,
        'train_mask': train_mask,
        'val_mask': val_mask,
        'test_mask': test_mask,
        'n_classes': n_classes,
        'label_encoder': graph_builder.label_encoder
    }
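
One thing to keep in mind about the split above: `build_graph` stores every interaction in both directions, and the split shuffles individual directed edges, so `(u, v)` can land in the training set while `(v, u)` lands in the test set. Whether that matters depends on the evaluation goal. If the two directions should stay together, one option is to split by unordered pair, as in the standard-library sketch below. `split_by_pair` is a hypothetical helper, not part of the notebook, and the toy edges are made up.

```python
import random


def split_by_pair(edges, val_frac=0.15, test_frac=0.15, seed=42):
    """Assign both directions of each undirected pair to the same split.

    Returns three sorted lists of edge positions: train, val, test.
    """
    groups = {}
    for pos, (u, v) in enumerate(edges):
        groups.setdefault((min(u, v), max(u, v)), []).append(pos)

    pairs = sorted(groups)
    random.Random(seed).shuffle(pairs)
    n_test = int(len(pairs) * test_frac)
    n_val = int(len(pairs) * val_frac)

    def positions(selected_pairs):
        return sorted(p for pair in selected_pairs for p in groups[pair])

    test_pairs = pairs[:n_test]
    val_pairs = pairs[n_test:n_test + n_val]
    train_pairs = pairs[n_test + n_val:]
    return positions(train_pairs), positions(val_pairs), positions(test_pairs)


# Toy graph: 20 undirected pairs, each stored in both directions.
toy_edges = []
for u in range(20):
    v = u + 100
    toy_edges.append((u, v))
    toy_edges.append((v, u))

train_idx, val_idx, test_idx = split_by_pair(toy_edges)
print(len(train_idx), len(val_idx), len(test_idx))
```

The returned positions can be used to fill boolean masks in the same way as `train_idx`, `val_idx` and `test_idx` in `prepare_ddi_data`.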

In [None]:
# Helper: Save state_dict as .safetensors + JSON metadata

def save_state_dict_as_safetensors(state_dict, path_base, metadata=None):
    """Save a state_dict (CPU tensors) as .safetensors and metadata as .json
    Args:
        state_dict: dict of tensors (preferably on CPU)
        path_base: e.g. 'models/best_GAT_DDI_01_Feb_12-00' (no extension)
        metadata: dict (optional) to save alongside as .json
    """
    safetensor_path = f"{path_base}.safetensors"
    json_path = f"{path_base}.json"
    # Ensure CPU tensors
    cpu_state = {k: v.cpu() for k, v in state_dict.items()}

    try:
        save_file(cpu_state, safetensor_path)
        print(f"✓ Saved safetensors: {safetensor_path}")
    except Exception as e:
        print(f"✗ Failed to save safetensors: {e}")
        raise

    if metadata is not None:
        try:
            with open(json_path, 'w') as f:
                json.dump(metadata, f, indent=2)
            print(f"✓ Saved metadata: {json_path}")
        except Exception as e:
            print(f"✗ Failed to save metadata JSON: {e}")


# Example usage
# save_state_dict_as_safetensors(model.state_dict(), f'models/end_GAT_DDI_{timestamp}', metadata={'config':config, 'timestamp':timestamp})

In [None]:
# Helper: Load state_dict from .safetensors or .pth and load into model

def load_state_dict_from_file(path_base, device='cpu'):
    """Load state dict from .safetensors (preferred) or fallback to .pth.

    Returns a dict suitable for model.load_state_dict()
    """
    safetensor_path = f"{path_base}.safetensors"
    pth_path = f"{path_base}.pth"

    if os.path.exists(safetensor_path):
        state = load_file(safetensor_path)
        # move to device
        state = {k: v.to(device) for k, v in state.items()}
        print(f"✓ Loaded safetensors from {safetensor_path}")
        return state
    elif os.path.exists(pth_path):
        state = torch.load(pth_path, map_location=device)
        print(f"✓ Loaded pth state from {pth_path}")
        return state
    else:
        raise FileNotFoundError(f"No model file found at {safetensor_path} or {pth_path}")


def load_model_from_file(model, path_base, device='cpu'):
    state = load_state_dict_from_file(path_base, device=device)
    model.load_state_dict(state)
    model.to(device)
    print(f"✓ Model weights loaded into model and moved to {device}")

## 6. Model Evaluation Pipeline

In [None]:
class ModelEvaluator:
    """Comprehensive evaluation class for GAT models"""
    
    def __init__(self, model: GAT_DDI, device: torch.device):
        self.model = model.to(device)
        self.device = device
    
    def test(self, data):
        """Comprehensive testing with multiple metrics"""
        self.model.eval()
        with torch.no_grad():
            test_edge_index = data['edge_index'][:, data['test_mask']]
            out = self.model(data['node_features'].to(self.device), data['edge_index'].to(self.device), test_edge_index.to(self.device))
            
            pred = out.max(1)[1]
            y_true = data['edge_labels'][data['test_mask']].cpu().numpy()
            y_pred = pred.cpu().numpy()
            
            unique_labels = np.unique(np.concatenate([y_true, y_pred]))
            
            # Calculate metrics
            test_acc = accuracy_score(y_true, y_pred)
            f1 = f1_score(y_true, y_pred, average='weighted')
            
            class_report = classification_report(y_true, y_pred, 
                                                labels=unique_labels,
                                                output_dict=True,
                                                zero_division=0)
            conf_matrix = confusion_matrix(y_true, y_pred, labels=unique_labels)
            
            return {
                'accuracy': test_acc,
                'f1_score': f1, 
                'classification_report': class_report,
                'confusion_matrix': conf_matrix,
                'predictions': y_pred,
                'true_labels': y_true
            }
    
    def visualize_embeddings(self, data, save_path=f'images/GAT_visualize_embeddings_{timestamp}.png'):
        """Visualize node embeddings using t-SNE (computes node embeddings from model GAT layers)"""
        self.model.eval()
        with torch.no_grad():
            x = data['node_features'].to(self.device).float()
            edge_index = data['edge_index'].to(self.device)

            # Compute embeddings using model's GAT layers
            h = self.model.dropout_layer(x)
            h = F.elu(self.model.gat1(h, edge_index))
            h = self.model.dropout_layer(h)
            h = F.elu(self.model.gat2(h, edge_index))
            h = self.model.dropout_layer(h)
            h = F.elu(self.model.gat3(h, edge_index))

            embeddings = h.cpu().numpy()

            print("Computing t-SNE embeddings...")
            tsne = TSNE(n_components=2, random_state=42, perplexity=30)
            embeddings_2d = tsne.fit_transform(embeddings)

            # Use node degree for coloring as an informative scalar
            ei = data['edge_index'].cpu().numpy()
            degrees = np.zeros(embeddings.shape[0])
            for u in ei[0]:
                degrees[u] += 1

            plt.figure(figsize=(12, 8))
            scatter = plt.scatter(embeddings_2d[:, 0], embeddings_2d[:, 1], 
                                c=degrees, cmap='viridis', alpha=0.8, s=50)
            cbar = plt.colorbar(scatter)
            cbar.set_label('Node degree')
            plt.title('GAT Node Embeddings Visualization (t-SNE)', fontsize=16, fontweight='bold')
            plt.xlabel('t-SNE Dimension 1', fontsize=12)
            plt.ylabel('t-SNE Dimension 2', fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.show()
            print(f"✓ Saved embeddings visualization to {save_path}")

## 7. Model Training Visualization

In [None]:
def plot_training_history(history, save_path=f'images/DDI_training_history_{timestamp}.png'):
    """Plot training history"""
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Loss
    axes[0].plot(history['train_loss'], label='Train', linewidth=2)
    axes[0].plot(history['val_loss'], label='Validation', linewidth=2)
    axes[0].set_title('Loss', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[1].plot(history['train_acc'], label='Train', linewidth=2)
    axes[1].plot(history['val_acc'], label='Validation', linewidth=2)
    axes[1].set_title('Accuracy', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # F1-Score
    axes[2].plot(history['train_f1'], label='Train', linewidth=2)
    axes[2].plot(history['val_f1'], label='Validation', linewidth=2)
    axes[2].set_title('F1-Score', fontsize=14, fontweight='bold')
    axes[2].set_xlabel('Epoch')
    axes[2].set_ylabel('F1-Score')
    axes[2].legend()
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    print(f"✓ Saved training history to {save_path}")

## Model Training Step

In [None]:
# load data
training_data = prepare_ddi_data()

# Remove unused adjacency matrix to save VRAM
del training_data['adj_mat']
gc.collect()
torch.cuda.empty_cache()

# Move to device
for key in ['node_features', 'edge_index', 'edge_labels', 
            'train_mask', 'val_mask', 'test_mask']:
    training_data[key] = training_data[key].to(device)

# Model configuration
config = {
    'n_features': training_data['node_features'].shape[1],
    'n_hidden': 128, # 256 OR 128
    'n_classes': training_data['n_classes'],
    'n_heads': 4, # 4 OR 8
    'dropout': 0.3
}
print("\n⛯ Model Configuration:")
for key, value in config.items():
    print(f"• {key}: {value}")


In [None]:
# Alias for compatibility with older cells
data = training_data

### Usage examples (Sanity check + Saving)

Quick usage examples: run the data sanity checks before feature extraction, and call the trainer with metadata saving enabled.

In [None]:
# Run sanity checks BEFORE prepare_ddi_data()
report = sanity_check_dataset()

# If you want to drop invalid smiles permanently:
# smiles_df = drop_invalid_smiles(pd.read_csv('dataset/drugdata/drug_smiles.csv'))
# smiles_df.to_csv('dataset/drugdata/drug_smiles.cleaned.csv', index=False)

# Example: include config metadata when training/saving
# trainer.train(..., save_metadata={'config': config, 'notes': 'GATv2 experiment'})

print('\n(To save models as .safetensors automatically during training, pass save_metadata to trainer.train)')

In [10]:
# Initialize model
model_config = GAT_DDI(**config)
trainer = DDITrainer(model_config, device)

# Training Configuration: Train with smaller batch size
history = trainer.train(
    training_data['node_features'], 
    training_data['edge_index'],
    training_data['edge_labels'], 
    training_data['train_mask'], 
    training_data['val_mask'],
    epochs=5, 
    lr=0.001, 
    weight_decay=5e-4, 
    patience=30,
    batch_size=256  # 800 OR 512 OR 256
)

# Plot history
plot_training_history(history)

NameError: name 'config' is not defined
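
The training call above combines `epochs=5` with `patience=30`. The internals of `DDITrainer.train` are not shown in this section, so the following is only a generic sketch of a patience-based stopping rule (the function name and the score values are made up for illustration). Under that assumption, a patience larger than the epoch budget means early stopping can never trigger:

```python
def run_with_early_stopping(val_scores, patience):
    """Return (last_epoch_run, best_epoch) for per-epoch validation scores.

    Epochs are 1-based. Training stops once `patience` consecutive epochs
    pass without a new best score.
    """
    best_score = float("-inf")
    best_epoch = 0
    epochs_without_improvement = 0
    for epoch, score in enumerate(val_scores, start=1):
        if score > best_score:
            best_score, best_epoch = score, epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch, best_epoch
    return len(val_scores), best_epoch

# Hypothetical validation F1 values for a 5-epoch run
scores = [0.50, 0.55, 0.54, 0.53, 0.52]
print(run_with_early_stopping(scores, patience=30))  # (5, 2): runs all epochs
print(run_with_early_stopping(scores, patience=2))   # (4, 2): stops after epoch 4
```

For short smoke-test runs this is harmless; for longer runs, a patience well below the epoch count is what makes the rule effective.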

In [None]:
# Make a convenience alias for backward compatibility
model = model_config

In [None]:
# Load best model and evaluate (prefer .safetensors; fallback to .pth)

# Create fresh model instance (or reuse existing model_config)
model = GAT_DDI(**config)
try:
    load_model_from_file(model, os.path.join('models', f'best_GAT_DDI_{timestamp}'), device=device)
except FileNotFoundError:
    print("⚠ Best model file not found. Using in-memory model weights (if trained).")

# Evaluate
model.to(device)
evaluator = ModelEvaluator(model, device)

# Ensure training_data variable is used consistently
data = training_data

test_results = evaluator.test(data)
print(f"📋 Test Result Accuracy: {test_results['accuracy']:.4f}")
print("\n🖃 Classification Report:")
print(classification_report(test_results['true_labels'], test_results['predictions']))

In [None]:
# Save final model as .safetensors (example)
final_base = os.path.join('models', f'final_GAT_DDI_{timestamp}')
try:
    state_dict_to_save = model.state_dict() if 'model' in globals() else model_config.state_dict()
    save_state_dict_as_safetensors(state_dict_to_save, final_base, metadata={'config': config, 'timestamp': timestamp})
    print(f"✓ Final model saved to {final_base}.safetensors")
except Exception as e:
    print(f"✗ Failed to save final model: {e}")
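
The save call above passes a nested dict (`config`) as a metadata value. The safetensors format stores metadata as a string-to-string map, and whether `save_state_dict_as_safetensors` already converts values is not visible in this section. A minimal sketch of such a conversion, assuming nothing about that helper (`to_safetensors_metadata` and the timestamp value are hypothetical):

```python
import json

def to_safetensors_metadata(metadata: dict) -> dict:
    """Flatten arbitrary metadata into a str -> str mapping.

    Strings pass through unchanged; everything else is JSON-encoded so it
    can be recovered later with json.loads.
    """
    flat = {}
    for key, value in metadata.items():
        if isinstance(value, str):
            flat[str(key)] = value
        else:
            # default=str is a fallback for values json cannot encode natively
            flat[str(key)] = json.dumps(value, default=str)
    return flat

# Example values mirror the notebook's config keys; the timestamp is a placeholder
example = {
    "config": {"n_hidden": 128, "n_heads": 4, "dropout": 0.3},
    "timestamp": "20250101_120000",
}
flat_meta = to_safetensors_metadata(example)
print(flat_meta["timestamp"])            # 20250101_120000
print(json.loads(flat_meta["config"]))   # {'n_hidden': 128, 'n_heads': 4, 'dropout': 0.3}
```

Reading the file back then only needs a `json.loads` on the `config` entry to rebuild the model with the same hyperparameters.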

In [None]:
# Evaluating Model (load best saved weights if available, otherwise use in-memory model)

# Use the trained model instance
model = model_config
try:
    load_model_from_file(model, os.path.join('models', f'best_GAT_DDI_{timestamp}'), device=device)
except FileNotFoundError:
    print("⚠ No saved best model found; using current in-memory model weights")

# Evaluate
evaluator = ModelEvaluator(model, device)
test_results = evaluator.test(training_data)

print(f"📋 Test Result Accuracy: {test_results['accuracy']:.4f}")
print("\n🖃 Classification Report:")
print(classification_report(test_results['true_labels'], test_results['predictions']))

In [None]:
print("\n🖃 Classification Report:")

## 8. Model Evaluation and Analysis

In [None]:
test_results = evaluator.test(training_data)
test_results

In [None]:
# Visualize node embeddings
# Uses model (should be loaded/assigned before calling)
evaluator = ModelEvaluator(model, device)
evaluator.visualize_embeddings(training_data)

In [None]:
def plot_confusion_matrix_top_classes(test_results, top_n=10, save_path=f'images/GAT_confusion_matrix_{timestamp}.png'):
    """
    Plot confusion matrix for top N most frequent classes only
    """
    y_true = test_results['true_labels']
    y_pred = test_results['predictions']
    
    # Get top N most frequent classes in test set
    unique, counts = np.unique(y_true, return_counts=True)
    top_classes = unique[np.argsort(counts)[-top_n:]][::-1]  # Top N by frequency
    
    # Filter predictions to only include top classes
    mask = np.isin(y_true, top_classes)
    y_true_filtered = y_true[mask]
    y_pred_filtered = y_pred[mask]
    
    # Create confusion matrix for top classes only
    conf_matrix = confusion_matrix(y_true_filtered, y_pred_filtered, labels=top_classes)
    
    # Create class names
    class_names = [f'Type {i}' for i in top_classes]
    
    # Plot with larger figure
    plt.figure(figsize=(12, 10))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Count'}, linewidths=0.5)
    
    plt.title(f'Confusion Matrix - Top {top_n} Interaction Types', fontsize=16, fontweight='bold')
    plt.xlabel('Predicted Class', fontsize=12)
    plt.ylabel('True Class', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Saved confusion matrix for top {top_n} classes to {save_path}")
    print(f"  Top classes: {top_classes}")

# Plot top 10 most frequent interaction types
plot_confusion_matrix_top_classes(test_results, top_n=10)
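
The class-selection expression in the function above is compact, so here is a small worked example on made-up labels. It also shows a side effect worth knowing: samples whose true class is in the top set but whose prediction is not are kept by the mask, yet they do not land in any column once `labels=top_classes` is passed to `confusion_matrix`.

```python
import numpy as np

# Made-up labels for illustration only
y_true = np.array([0, 0, 0, 1, 2, 2, 2, 2, 3, 3])
y_pred = np.array([0, 0, 2, 1, 2, 2, 0, 3, 3, 3])

# Class frequencies: class 0 -> 3, class 1 -> 1, class 2 -> 4, class 3 -> 2
unique, counts = np.unique(y_true, return_counts=True)

# argsort is ascending, so take the last top_n indices and reverse them
top_n = 2
top_classes = unique[np.argsort(counts)[-top_n:]][::-1]
print(top_classes)  # [2 0]

# Keep only samples whose TRUE label is one of the top classes
mask = np.isin(y_true, top_classes)
print(int(mask.sum()))  # 7

# Predictions outside the top set have no column in the filtered matrix
outside = int((~np.isin(y_pred[mask], top_classes)).sum())
print(outside)  # 1
```

Because of those dropped predictions, a row of the filtered matrix can sum to less than the class's sample count.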


In [None]:
def plot_normalized_confusion_matrix(test_results, top_n=15, save_path=f'images/GAT_confusion_matrix_normalized_{timestamp}.png'):
    """
    Plot normalized confusion matrix (percentages) for top N classes
    """
    y_true = test_results['true_labels']
    y_pred = test_results['predictions']
    
    # Get top N classes
    unique, counts = np.unique(y_true, return_counts=True)
    top_classes = unique[np.argsort(counts)[-top_n:]][::-1]
    
    # Filter
    mask = np.isin(y_true, top_classes)
    y_true_filtered = y_true[mask]
    y_pred_filtered = y_pred[mask]
    
    # Normalized confusion matrix
    conf_matrix = confusion_matrix(y_true_filtered, y_pred_filtered, labels=top_classes)
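    # Guard against all-zero rows (every prediction for a class fell outside top_classes)
    row_sums = conf_matrix.sum(axis=1, keepdims=True)
    conf_matrix_norm = conf_matrix.astype(float) / np.maximum(row_sums, 1)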
    
    class_names = [f'Type {i}' for i in top_classes]
    
    # Plot
    plt.figure(figsize=(14, 12))
    sns.heatmap(conf_matrix_norm, annot=True, fmt='.2f', cmap='RdYlGn',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Proportion'}, linewidths=0.5,
                vmin=0, vmax=1)
    
    plt.title(f'Normalized Confusion Matrix - Top {top_n} Types', fontsize=16, fontweight='bold')
    plt.xlabel('Predicted Class', fontsize=12)
    plt.ylabel('True Class', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Saved normalized confusion matrix to {save_path}")

# Plot normalized matrix
plot_normalized_confusion_matrix(test_results, top_n=15)
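
To make the normalization concrete: dividing each row by its sum turns counts into the proportion of each true class assigned to each predicted class, so the diagonal is per-class recall. The sketch below uses a made-up 3x3 matrix and a hypothetical helper (`row_normalize`) that leaves all-zero rows at zero instead of producing NaN:

```python
import numpy as np

def row_normalize(conf_matrix):
    """Divide each row by its sum; rows that sum to zero stay all-zero."""
    conf_matrix = np.asarray(conf_matrix, dtype=float)
    row_sums = conf_matrix.sum(axis=1, keepdims=True)
    return np.divide(
        conf_matrix,
        row_sums,
        out=np.zeros_like(conf_matrix),
        where=row_sums > 0,
    )

# Made-up counts; the last class has no samples in the matrix
toy_cm = np.array([[8, 2, 0],
                   [1, 3, 0],
                   [0, 0, 0]])
toy_norm = row_normalize(toy_cm)
print(toy_norm)
# [[0.8  0.2  0.  ]
#  [0.25 0.75 0.  ]
#  [0.   0.   0.  ]]

# The diagonal of the row-normalized matrix is per-class recall
print(np.diag(toy_norm))  # [0.8  0.75 0.  ]
```

scikit-learn's `confusion_matrix` also accepts `normalize='true'` for row normalization, which may be simpler when no custom handling of empty rows is needed.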


In [None]:
def plot_confusion_summary(test_results, save_path=f'images/GAT_confusion_summary_{timestamp}.png'):
    """
    Plot summary statistics instead of full confusion matrix
    """
    y_true = test_results['true_labels']
    y_pred = test_results['predictions']
    
    # Calculate per-class accuracy
    unique_classes = np.unique(y_true)
    class_accuracies = []
    class_counts = []
    
    for cls in unique_classes:
        mask = y_true == cls
        if mask.sum() > 0:
            acc = (y_pred[mask] == cls).sum() / mask.sum()
            class_accuracies.append(acc)
            class_counts.append(mask.sum())
        else:
            class_accuracies.append(0)
            class_counts.append(0)
    
    # Sort by frequency
    sorted_indices = np.argsort(class_counts)[::-1][:20]  # Top 20
    
    # Plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))
    
    # Per-class accuracy
    ax1.barh([f'Type {unique_classes[i]}' for i in sorted_indices],
             [class_accuracies[i] for i in sorted_indices],
             color='steelblue')
    ax1.set_xlabel('Accuracy', fontsize=12)
    ax1.set_title('Per-Class Accuracy (Top 20)', fontsize=14, fontweight='bold')
    ax1.grid(axis='x', alpha=0.3)
    
    # Class distribution
    ax2.barh([f'Type {unique_classes[i]}' for i in sorted_indices],
             [class_counts[i] for i in sorted_indices],
             color='coral')
    ax2.set_xlabel('Sample Count', fontsize=12)
    ax2.set_title('Class Distribution (Top 20)', fontsize=14, fontweight='bold')
    ax2.grid(axis='x', alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    print(f"✓ Saved confusion summary to {save_path}")

# Plot summary
plot_confusion_summary(test_results)


In [None]:
# Analyze model performance
def analyze_results(history, test_results, data, model):
    """
    Provide comprehensive analysis of model performance
    """
    print("📈 MODEL PERFORMANCE ANALYSIS")
    print("="*60)
    
    # Training metrics
    print("\nTraining Metrics:")
    print(f"• Final Training Accuracy: {history['train_acc'][-1]:.4f}")
    print(f"• Final Dev / Validation Accuracy: {history['val_acc'][-1]:.4f}")
    print(f"• Best Validation Accuracy: {max(history['val_acc']):.4f}")
    print(f"• Final Training F1: {history['train_f1'][-1]:.4f}")
    print(f"• Final Validation F1: {history['val_f1'][-1]:.4f}")
    print(f"• Best Validation F1: {max(history['val_f1']):.4f}")
    print(f"• Total Epochs: {len(history['train_loss'])}")
    
    # Test metrics
    print("\nTest Metrics:")
    print(f"• Test Accuracy: {test_results['accuracy']:.4f}")
    if 'f1_score' in test_results:
        print(f"• Test F1-Score: {test_results['f1_score']:.4f}")
    
    # Per-class performance (only for classes in test set)
    print("\n📊 Per-Class Performance:")
    unique_labels = np.unique(test_results['true_labels'])
    
    for label in unique_labels:
        label_str = str(label)
        if label_str in test_results['classification_report']:
            report = test_results['classification_report'][label_str]
            precision = report['precision']
            recall = report['recall']
            f1 = report['f1-score']
            print(f"   • Type {label}: Precision={precision:.3f}, Recall={recall:.3f}, F1={f1:.3f}")
    
    # Model complexity
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    
    print("\n🔧 Model Complexity:")
    print(f"   • Total Parameters: {total_params:,}")
    print(f"   • Trainable Parameters: {trainable_params:,}")
    # Size estimate assumes 4 bytes (float32) per parameter
    print(f"   • Model Size: ~{total_params * 4 / 1024 / 1024:.2f} MB")
    
    # Dataset info
    print("\n📊 Dataset Info:")
    print(f"   • Total Drugs (Nodes): {data['node_features'].shape[0]}")
    print(f"   • Total Interactions (Edges): {data['edge_index'].shape[1]}")
    print(f"   • Feature Dimension: {data['node_features'].shape[1]}")
    print(f"   • Number of Interaction Types: {data['n_classes']}")
    print(f"   • Classes in Test Set: {len(unique_labels)}")

# Run analysis
analyze_results(history, test_results, data, model)


## Saving Model Logs and Tracks

In [None]:
# 1. Prepare the new log entry with rounded values
current_log = {
    "date": timestamp,
    "graph_details": {
        "num_nodes": data['node_features'].shape[0],  # number of drugs
        "num_edges": data['edge_index'].shape[1],  # total edges (columns of edge_index)
        "num_interaction_types": data['n_classes'],
        "feature_dimension": data['node_features'].shape[1]
    },
    "config": config,
    "runtime_info": {
        "device": str(device),
        "total_parameters": sum(p.numel() for p in model.parameters()),
        "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
        "model_size_mb": round(sum(p.numel() for p in model.parameters()) * 4 / 1024 / 1024, 2),  # assumes float32
        "training_edges": int(data['train_mask'].sum()),
        "validation_edges": int(data['val_mask'].sum()),
        "test_edges": int(data['test_mask'].sum())
    },
    "training_info": {
        "epochs_completed": len(history['train_loss']),
        "best_epoch": history['val_f1'].index(max(history['val_f1'])) + 1,  # 1-based, by validation F1
        # Hard-coded values: keep in sync with the trainer.train(...) call above
        "learning_rate": 0.001,
        "batch_size": 256,
        "weight_decay": 5e-4
    },
    "metrics": {
        # Final-epoch training metrics (rounded to 3 decimals)
        "train_loss": round(history['train_loss'][-1], 3),
        "train_acc": round(history['train_acc'][-1], 3),
        "train_f1": round(history['train_f1'][-1], 3),
        
        # Final-epoch validation metrics (rounded to 3 decimals)
        "val_loss": round(history['val_loss'][-1], 3),
        "val_acc": round(history['val_acc'][-1], 3),
        "val_f1": round(history['val_f1'][-1], 3),
        
        # Best validation metrics
        "best_val_acc": round(max(history['val_acc']), 3),
        "best_val_f1": round(max(history['val_f1']), 3),
        
        # Test metrics (None if the evaluation cells were not run)
        "test_acc": round(test_results['accuracy'], 3) if 'test_results' in locals() else None,
        "test_f1": round(test_results['f1_score'], 3) if 'test_results' in locals() and 'f1_score' in test_results else None
    }
}

# 2. File handling: Load existing data or create a new list
log_file = 'model_evals_log.json'

if os.path.exists(log_file):
    with open(log_file, 'r') as f:
        try:
            logs_list = json.load(f)
            if not isinstance(logs_list, list):
                logs_list = []
        except json.JSONDecodeError:
            logs_list = []
else:
    logs_list = []

# 3. Insert new log at the TOP (index 0)
logs_list.insert(0, current_log)

# 4. Save back to file with nice formatting
with open(log_file, 'w') as f:
    json.dump(logs_list, f, indent=4)

print(f"(✓) Log successfully saved to {log_file}!")
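
The load, insert, and save steps above can also be wrapped in one function. The sketch below is a hypothetical helper (`prepend_log_entry` is not part of the notebook) that follows the same logic and, as an added assumption about what is desirable here, writes through a temporary file so that an interrupted run does not leave a truncated log:

```python
import json
import os
import tempfile

def prepend_log_entry(log_file, entry):
    """Insert `entry` at index 0 of the JSON list in `log_file`; return the list."""
    logs = []
    if os.path.exists(log_file):
        try:
            with open(log_file, "r") as f:
                loaded = json.load(f)
            if isinstance(loaded, list):
                logs = loaded
        except json.JSONDecodeError:
            pass  # unreadable file: start a fresh list, as the cell above does

    logs.insert(0, entry)

    # Write to a temp file in the same directory, then swap it into place
    target_dir = os.path.dirname(os.path.abspath(log_file))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            # default=str stores values json cannot encode as their string form
            json.dump(logs, f, indent=4, default=str)
        os.replace(tmp_path, log_file)
    except Exception:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return logs

# Demo in a throwaway directory; the entries are placeholders, not real runs
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_file = os.path.join(tmp_dir, "model_evals_log.json")
    prepend_log_entry(demo_file, {"run": "first"})
    demo_logs = prepend_log_entry(demo_file, {"run": "second"})
    with open(demo_file) as f:
        on_disk = json.load(f)

print([e["run"] for e in demo_logs])  # ['second', 'first']
```

In the notebook this would be called as `prepend_log_entry(log_file, current_log)`. Note that `default=str` turns unsupported values such as NumPy scalars into strings rather than numbers, so casting them with `int(...)` or `float(...)` beforehand gives cleaner logs.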