In [None]:
# Install required packages
!pip install torch torch-geometric scikit-learn rdkit pandas -q

[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m63.7/63.7 kB[0m [31m3.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m1.3/1.3 MB[0m [31m35.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m36.2/36.2 MB[0m [31m49.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tqdm import tqdm
import os

# RDKit for SMILES fingerprints
from rdkit import Chem
from rdkit.Chem import AllChem, MACCSkeys
from rdkit import DataStructs


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [None]:

DATA_PATH = '/content/drive/MyDrive/GAT/data/'
DATA_PATH_SMILES = '/content/drive/MyDrive/GAT/'
print("Loading data from Google Drive...")

# Load CSV files
train_df = pd.read_csv(DATA_PATH + 'train_positive.csv')
val_df = pd.read_csv(DATA_PATH + 'val_positive.csv')
test_df = pd.read_csv(DATA_PATH + 'test_positive.csv')

print(f"\nTrain data preview:")
print(train_df.head())
print(f"\nColumns: {train_df.columns.tolist()}")

Loading data from Google Drive...

Train data preview:
  Drug1_ID Drug2_ID  Label
0  DB01097  DB05219     47
1  DB00547  DB00784     49
2  DB00623  DB01365     61
3  DB00328  DB09027     73
4  DB00742  DB00955     57

Columns: ['Drug1_ID', 'Drug2_ID', 'Label']


In [None]:

drug_smiles_df = pd.read_csv(DATA_PATH_SMILES + 'Drugs_with_Smiles.csv')

print(f"\nDrug SMILES loaded: {len(drug_smiles_df)} drugs")
print(drug_smiles_df.head())


Drug SMILES loaded: 1709 drugs
  DrugBank_ID                                             SMILES
0     DB00006  CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...
1     DB00014  CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...
2     DB00027  CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...
3     DB00035  NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...
4     DB00080  CCCCCCCCCC(=O)N[C@@H](CC1=CNC2=C1C=CC=C2)C(=O)...


## We use Morgan Fingerprint (2, 512)

In [1]:
# fingerprint type
FINGERPRINT_TYPE = 'morgan'
N_BITS = 512


def extract_fingerprint(smiles, fp_type='morgan', radius=2, n_bits=512):

    try:
        mol = Chem.MolFromSmiles(smiles)

        if mol is None:
            print(f"Warning: Invalid SMILES: {smiles}")
            # Return zero vector for invalid SMILES
            return np.zeros(n_bits if fp_type == 'morgan' else 166)

        if fp_type == 'morgan':
            # Morgan fingerprint (ECFP-like)
            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=n_bits)
            arr = np.zeros((n_bits,))
        elif fp_type == 'maccs':
            # MACCS keys (166-bit)
            fp = MACCSkeys.GenMACCSKeys(mol)
            arr = np.zeros((166,))
        else:
            raise ValueError("fp_type must be 'morgan' or 'maccs'")

        DataStructs.ConvertToNumpyArray(fp, arr)
        return arr

    except Exception as e:
        print(f"Error processing SMILES {smiles}: {e}")
        return np.zeros(n_bits if fp_type == 'morgan' else 166)




print(f"\nüî¨ Extracting {FINGERPRINT_TYPE} fingerprints...")

# Extract fingerprints for all drugs
drug_id_to_smiles = dict(zip(drug_smiles_df['DrugBank_ID'], drug_smiles_df['SMILES']))

# Get all unique drugs from train/val/test
all_drug_ids = set()
for df in [train_df, val_df, test_df]:
    all_drug_ids.update(df['Drug1_ID'].values)
    all_drug_ids.update(df['Drug2_ID'].values)

all_drug_ids = sorted(list(all_drug_ids))
drug_to_idx = {drug_id: idx for idx, drug_id in enumerate(all_drug_ids)}
idx_to_drug = {idx: drug_id for drug_id, idx in drug_to_idx.items()}

print(f"Total unique drugs: {len(all_drug_ids)}")

# Extract fingerprints
fingerprints = []
for drug_id in tqdm(all_drug_ids, desc="Extracting fingerprints"):
    smiles = drug_id_to_smiles.get(drug_id, None)
    if smiles is None:
        print(f"Warning: No SMILES found for drug {drug_id}")
        fp = np.zeros(N_BITS if FINGERPRINT_TYPE == 'morgan' else 166)
    else:
        fp = extract_fingerprint(smiles, fp_type=FINGERPRINT_TYPE, n_bits=N_BITS)
    fingerprints.append(fp)

# Convert to tensor
drug_features = torch.FloatTensor(np.array(fingerprints))

print(f"\nDrug features extracted:")
print(f"   - Shape: {drug_features.shape}")
print(f"   - Type: {FINGERPRINT_TYPE} fingerprints")
print(f"   - This is a FAIR baseline (not using our chemical embeddings!)")


üî¨ Extracting morgan fingerprints...


NameError: name 'drug_smiles_df' is not defined

In [None]:
# Convert drug pairs to indices and get types
def df_to_tensors(df, drug_to_idx):
    pairs = []
    types = []
    for _, row in df.iterrows():
        drug1_idx = drug_to_idx[row['Drug1_ID']]
        drug2_idx = drug_to_idx[row['Drug2_ID']]
        interaction_type = int(row['Label'])

        pairs.append([drug1_idx, drug2_idx])
        types.append(interaction_type)

    return torch.tensor(pairs, dtype=torch.long), torch.tensor(types, dtype=torch.long)

train_pairs, train_types = df_to_tensors(train_df, drug_to_idx)
val_pairs, val_types = df_to_tensors(val_df, drug_to_idx)
test_pairs, test_types = df_to_tensors(test_df, drug_to_idx)

# ADD THIS CONVERSION:
print("\nConverting labels to 0-based indexing for PyTorch...")
print(f"Original label range: {train_types.min().item()} to {train_types.max().item()}")

print(f"Original label range: {train_types.min()} to {train_types.max()}")

# If labels are 1-86, convert to 0-85
if train_types.min() == 1:
    train_types = train_types - 1
    val_types = val_types - 1
    test_types = test_types - 1
    print(f"‚úÖ Converted labels from 1-86 to 0-85")
else:
    print("‚úÖ Labels already 0-based")

print(f"Final label range: {train_types.min()} to {train_types.max()}")

print(f"Converted label range: {train_types.min().item()} to {train_types.max().item()}")
print(f"Number of classes: {len(torch.unique(train_types))} (0-{train_types.max().item()})")

# Update num_types to match
num_types = len(torch.unique(train_types))
print(f"Final num_types: {num_types}")

# Get dimensions
num_drugs = len(drug_to_idx)
num_types = len(torch.unique(train_types))
feature_dim = drug_features.shape[1]

print(f"\nDataset Statistics:")
print(f"   - Drugs: {num_drugs}")
print(f"   - Feature dimension: {feature_dim}")
num_types = len(torch.unique(train_types))  # Should be 86
print(f"   - Interaction types: {num_types} (0-{num_types-1})")
print(f"   - Train pairs: {len(train_pairs)}")
print(f"   - Val pairs: {len(val_pairs)}")
print(f"   - Test pairs: {len(test_pairs)}")


Converting labels to 0-based indexing for PyTorch...
Original label range: 1 to 86
Original label range: 1 to 86
‚úÖ Converted labels from 1-86 to 0-85
Final label range: 0 to 85
Converted label range: 0 to 85
Number of classes: 86 (0-85)
Final num_types: 86

Dataset Statistics:
   - Drugs: 1709
   - Feature dimension: 512
   - Interaction types: 86 (0-85)
   - Train pairs: 153489
   - Val pairs: 19188
   - Test pairs: 19200


In [None]:
def build_graph(train_pairs, num_drugs):
    edges = set()

    for pair in train_pairs:
        drug_i = int(pair[0])
        drug_j = int(pair[1])

        # Add both directions (undirected graph)
        edges.add((drug_i, drug_j))
        edges.add((drug_j, drug_i))

    # Convert to tensor
    edge_list = list(edges)
    edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()

    print(f"\nGraph Statistics:")
    print(f"   - Nodes: {num_drugs}")
    print(f"   - Edges: {edge_index.shape[1]}")
    print(f"   - NO TYPE INFORMATION in graph!")

    return edge_index

# Build the graph
edge_index = build_graph(train_pairs, num_drugs)


Graph Statistics:
   - Nodes: 1709
   - Edges: 306978
   - NO TYPE INFORMATION in graph!


In [None]:
class GAT_Encoder(nn.Module):
    """SINGLE LAYER GAT Encoder - to match our HGNN"""
    def __init__(self, in_features, hidden_dim, out_dim, num_heads=4, dropout=0.3):
        super(GAT_Encoder, self).__init__()

        # SINGLE LAYER only - matching our HGNN
        self.gat1 = GATConv(in_features, out_dim, heads=num_heads, dropout=dropout, concat=True)

        # Projection to get the right output dimension
        self.projection = nn.Linear(out_dim * num_heads, out_dim)

        self.dropout = nn.Dropout(dropout)
        self.activation = nn.ELU()

    def forward(self, x, edge_index):
        # Single GAT layer
        x = self.gat1(x, edge_index)
        x = self.activation(x)
        x = self.dropout(x)
        x = self.projection(x)  # Project to desired dimension

        return x

class MLP_Decoder(nn.Module):
    """Simplified MLP Decoder - closer to our MLPPredictor"""
    def __init__(self, embedding_dim, hidden_dim, num_types, dropout=0.3):
        super(MLP_Decoder, self).__init__()

        # Make sure num_types is 87
        print(f"MLP Decoder initialized with {num_types} classes")

        input_dim = embedding_dim * 2  # Concatenated pair

        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_types)  # Should output 87 classes
        )

    def forward(self, drug_i_emb, drug_j_emb):
        pair_emb = torch.cat([drug_i_emb, drug_j_emb], dim=1)
        logits = self.mlp(pair_emb)
        return logits


class GAT_MLP_Model(nn.Module):
    """Complete GAT + MLP Model"""
    def __init__(self, in_features, hidden_dim, embedding_dim, num_types, num_heads=4, dropout=0.3):
        super(GAT_MLP_Model, self).__init__()

        self.encoder = GAT_Encoder(in_features, hidden_dim, embedding_dim, num_heads, dropout)
        self.decoder = MLP_Decoder(embedding_dim, hidden_dim, num_types, dropout)

    def forward(self, x, edge_index, drug_i_idx, drug_j_idx):
        # Encode all drugs
        drug_embeddings = self.encoder(x, edge_index)

        # Get pair embeddings
        drug_i_emb = drug_embeddings[drug_i_idx]
        drug_j_emb = drug_embeddings[drug_j_idx]

        # Predict type
        logits = self.decoder(drug_i_emb, drug_j_emb)

        return logits

print("Model classes defined")

Model classes defined


In [None]:
# Matching our HGNN config
HIDDEN_DIM = 128          # Same as our 'hidden_units': 128
EMBEDDING_DIM = 128       # Same as hidden for consistency
NUM_HEADS =2              # Standard for GAT
DROPOUT = 0.5             # Same as our 'dropout': 0.5
BATCH_SIZE = 128          # Standard batch size
LEARNING_RATE = 0.005     # Same as our 'learning_rate': 0.005
WEIGHT_DECAY = 0.001       # Same as our 'weight_decay': 0.0
NUM_EPOCHS = 200          # Same as our epochs
SEED = 42                 # Same as our 'training_seed': 42

# Set random seed for reproducibility
torch.manual_seed(SEED)
np.random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

print("Hyperparameters (matched to HGNN metabolic):")
print(f"   - Hidden dim: {HIDDEN_DIM}")
print(f"   - Embedding dim: {EMBEDDING_DIM}")
print(f"   - Dropout: {DROPOUT}")
print(f"   - Learning rate: {LEARNING_RATE}")
print(f"   - Weight decay: {WEIGHT_DECAY}")
print(f"   - Seed: {SEED}")

model = GAT_MLP_Model(
    in_features=512,
    hidden_dim=HIDDEN_DIM,
    embedding_dim=EMBEDDING_DIM,
    num_types=num_types,  # Should be 87 now
    num_heads=NUM_HEADS,
    dropout=DROPOUT
).to(device)

# Move data to device
drug_features = drug_features.to(device)
edge_index = edge_index.to(device)

# Optimizer and loss
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY
)
criterion = nn.CrossEntropyLoss()

print(f"\nModel initialized on {device}")
print(f"   - Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Debug: Test one forward pass
print("Testing one forward pass...")
model.eval()
with torch.no_grad():
    test_i = train_pairs[:2, 0].to(device)
    test_j = train_pairs[:2, 1].to(device)
    test_types_debug = train_types[:2].to(device)

    logits = model(drug_features, edge_index, test_i, test_j)
    print(f"Logits shape: {logits.shape}")  # Should be [2, 86]
    print(f"Predictions: {torch.argmax(logits, dim=1)}")
    print(f"True labels: {test_types_debug}")
    print(f"Loss: {criterion(logits, test_types_debug):.4f}")
print("‚úÖ Forward pass successful!")

# Check data sizes
print(f"\nüìä Data sizes:")
print(f"Train pairs: {train_pairs.shape}")
print(f"Train types: {train_types.shape}")
print(f"Drug features: {drug_features.shape}")
print(f"Edge index: {edge_index.shape}")

Testing one forward pass...
Logits shape: torch.Size([2, 86])
Predictions: tensor([61, 61], device='cuda:0')
True labels: tensor([46, 48], device='cuda:0')
Loss: 4.4905
‚úÖ Forward pass successful!

üìä Data sizes:
Train pairs: torch.Size([153489, 2])
Train types: torch.Size([153489])
Drug features: torch.Size([1709, 512])
Edge index: torch.Size([2, 306978])


In [None]:
import time
import psutil
import os

def calculate_ram_usage():
    """Calculate RAM usage in GB"""
    process = psutil.Process(os.getpid())
    return process.memory_info().rss / (1024 ** 3)  # Convert to GB

In [None]:
def train_epoch(model, optimizer, criterion, pairs, types, class_weights):
    """Train for one epoch with class weights"""
    model.train()

    # Shuffle data
    indices = torch.randperm(len(pairs))
    pairs = pairs[indices]
    types = types[indices]

    total_loss = 0
    all_preds = []
    all_labels = []

    # Batch training
    num_batches = (len(pairs) + BATCH_SIZE - 1) // BATCH_SIZE

    for i in range(num_batches):
        start_idx = i * BATCH_SIZE
        end_idx = min((i + 1) * BATCH_SIZE, len(pairs))

        batch_pairs = pairs[start_idx:end_idx]
        batch_types = types[start_idx:end_idx]

        batch_i = batch_pairs[:, 0].to(device)
        batch_j = batch_pairs[:, 1].to(device)
        batch_types = batch_types.to(device)

        optimizer.zero_grad()

        # Forward
        logits = model(drug_features, edge_index, batch_i, batch_j)

        # Apply class weights to loss
        loss = F.cross_entropy(logits, batch_types, weight=class_weights)

        # Backward
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Track predictions
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(batch_types.cpu().numpy())

    avg_loss = total_loss / num_batches
    accuracy = accuracy_score(all_labels, all_preds)

    return avg_loss, accuracy


def evaluate(model, pairs, types, criterion, class_weights):
    """Evaluate model with class weights"""
    model.eval()

    all_preds = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        num_batches = (len(pairs) + BATCH_SIZE - 1) // BATCH_SIZE

        for i in range(num_batches):
            start_idx = i * BATCH_SIZE
            end_idx = min((i + 1) * BATCH_SIZE, len(pairs))

            batch_pairs = pairs[start_idx:end_idx]
            batch_types = types[start_idx:end_idx]

            batch_i = batch_pairs[:, 0].to(device)
            batch_j = batch_pairs[:, 1].to(device)
            batch_types = batch_types.to(device)

            logits = model(drug_features, edge_index, batch_i, batch_j)

            # Apply class weights to loss
            loss = F.cross_entropy(logits, batch_types, weight=class_weights)
            total_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_types.cpu().numpy())

    # Compute metrics
    accuracy = accuracy_score(all_labels, all_preds)
    f1_micro = f1_score(all_labels, all_preds, average='micro')
    f1_macro = f1_score(all_labels, all_preds, average='macro')
    precision = precision_score(all_labels, all_preds, average='micro')
    recall = recall_score(all_labels, all_preds, average='micro')

    avg_loss = total_loss / num_batches

    metrics = {
        'accuracy': accuracy,
        'f1_micro': f1_micro,
        'f1_macro': f1_macro,
        'precision': precision,
        'recall': recall
    }

    return metrics, avg_loss

print("Training functions ready with class weights")

Training functions ready with class weights


In [None]:
print("\n" + "="*80)
print("FULL TRAINING GAT+MLP (200 EPOCHS) - GPU ACCELERATED")
print("="*80)

# Calculate class weights (EXACTLY like HGNN)
type_counts = torch.bincount(train_types, minlength=86).float()
alpha = 0.3
class_weights = 1.0 / torch.pow(type_counts.clamp(min=1.0), alpha)
class_weights = class_weights / class_weights.mean()
class_weights = class_weights.to(device)

print(f"\nClass weights statistics (alpha={alpha}):")
print(f"  Min weight: {class_weights.min():.4f}")
print(f"  Max weight: {class_weights.max():.4f}")
print(f"  Mean weight: {class_weights.mean():.4f}")
print(f"  Sample counts - Min: {type_counts.min():.0f}, Max: {type_counts.max():.0f}")

ram_before = calculate_ram_usage()
print(f"\nRAM usage before training: {ram_before:.2f} GB")
print(f"Device: {device} (GPU)")
print(f"Training data: {len(train_pairs):,} pairs")

best_val_loss = 1e10
patience = 100
patience_counter = 0
best_epoch = 0
start_time = time.time()

train_losses = []
val_losses = []

for epoch in range(NUM_EPOCHS):
    epoch_start_time = time.time()

    # Training
    train_loss, train_acc = train_epoch(model, optimizer, criterion, train_pairs, train_types, class_weights)

    # Validation
    val_metrics, val_loss = evaluate(model, val_pairs, val_types, criterion, class_weights)

    train_losses.append(train_loss)
    val_losses.append(val_loss)

    epoch_time = time.time() - epoch_start_time

    # Early stopping based on VALIDATION LOSS
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        best_epoch = epoch
        torch.save(model.state_dict(), '/content/best_gat_mlp_model.pth')
        print(f"New best: Epoch {epoch} - Val Loss: {val_loss:.4f} (Time: {epoch_time:.1f}s)")
    else:
        patience_counter += 1

    # Print every 10 epochs
    if epoch % 10 == 0:
        print(f"Epoch {epoch}: loss: {train_loss:.4f}, val_loss: {val_loss:.4f} (best: {best_val_loss:.4f}, patience: {patience_counter})")

    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

print(f"\nBest epoch: {best_epoch}, Best val loss: {best_val_loss:.4f}")


FULL TRAINING GAT+MLP (200 EPOCHS) - GPU ACCELERATED

Class weights statistics (alpha=0.3):
  Min weight: 0.1566
  Max weight: 2.6340
  Mean weight: 1.0000
  Sample counts - Min: 4, Max: 48746

RAM usage before training: 1.63 GB
Device: cuda (GPU)
Training data: 153,489 pairs
New best: Epoch 0 - Val Loss: 1.7177 (Time: 27.0s)
Epoch 0: loss: 2.4016, val_loss: 1.7177 (best: 1.7177, patience: 0)
New best: Epoch 1 - Val Loss: 1.5461 (Time: 26.7s)
New best: Epoch 2 - Val Loss: 1.5451 (Time: 26.9s)
New best: Epoch 3 - Val Loss: 1.4399 (Time: 26.5s)
New best: Epoch 6 - Val Loss: 1.3788 (Time: 26.5s)
New best: Epoch 8 - Val Loss: 1.3438 (Time: 26.5s)
Epoch 10: loss: 1.7732, val_loss: 1.3907 (best: 1.3438, patience: 2)
New best: Epoch 15 - Val Loss: 1.3098 (Time: 26.7s)
Epoch 20: loss: 1.7475, val_loss: 1.3489 (best: 1.3098, patience: 5)
New best: Epoch 21 - Val Loss: 1.3028 (Time: 26.5s)
Epoch 30: loss: 1.7565, val_loss: 1.3337 (best: 1.3028, patience: 9)
New best: Epoch 31 - Val Loss: 1.2979

In [None]:
# Save the best model to our DATA_PATH
best_model_save_path = DATA_PATH + 'best_gat_mlp_model.pth'

# Save complete model checkpoint
torch.save({
    'epoch': best_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'val_loss': best_val_loss,
    'train_losses': train_losses,
    'val_losses': val_losses,
    'hyperparameters': {
        'HIDDEN_DIM': HIDDEN_DIM,
        'EMBEDDING_DIM': EMBEDDING_DIM,
        'NUM_HEADS': NUM_HEADS,
        'DROPOUT': DROPOUT,
        'LEARNING_RATE': LEARNING_RATE,
        'WEIGHT_DECAY': WEIGHT_DECAY,
        'BATCH_SIZE': BATCH_SIZE,
    }
}, best_model_save_path)

print("‚úÖ Best model saved successfully!")
print(f"üìÅ Location: {best_model_save_path}")
print(f"üìä Best epoch: {best_epoch}")
print(f"üìä Best val loss: {best_val_loss:.4f}")


# ============================================================================
# Run this cell AFTER training finishes to create checkpoint at current epoch
# ============================================================================

# Create checkpoint directory
CHECKPOINT_DIR = DATA_PATH + 'checkpoints/'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Save checkpoint at current epoch (e.g., epoch 200)
current_epoch = NUM_EPOCHS  # This will be 200 if you trained for 200 epochs
checkpoint_save_path = CHECKPOINT_DIR + f'checkpoint_epoch_{current_epoch}.pth'

torch.save({
    'epoch': current_epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'val_losses': val_losses,
    'best_val_loss': best_val_loss,
    'class_weights': class_weights,
    'hyperparameters': {
        'HIDDEN_DIM': HIDDEN_DIM,
        'EMBEDDING_DIM': EMBEDDING_DIM,
        'NUM_HEADS': NUM_HEADS,
        'DROPOUT': DROPOUT,
        'LEARNING_RATE': LEARNING_RATE,
        'WEIGHT_DECAY': WEIGHT_DECAY,
        'BATCH_SIZE': BATCH_SIZE,
        'NUM_EPOCHS': NUM_EPOCHS,
        'SEED': SEED,
    }
}, checkpoint_save_path)

print("‚úÖ Checkpoint saved successfully!")
print(f"üìÅ Location: {checkpoint_save_path}")
print(f"üìä Epoch: {current_epoch}")
print(f"\nüí° You can now continue training from epoch {current_epoch} to 500")



‚úÖ Best model saved successfully!
üìÅ Location: /content/drive/MyDrive/GAT/data/best_gat_mlp_model.pth
üìä Best epoch: 79
üìä Best val loss: 1.2339
‚úÖ Checkpoint saved successfully!
üìÅ Location: /content/drive/MyDrive/GAT/data/checkpoints/checkpoint_epoch_200.pth
üìä Epoch: 200

üí° You can now continue training from epoch 200 to 500


## 9. Evaluate on Test Set

In [None]:
# Load best model
model.load_state_dict(torch.load('/content/best_gat_mlp_model.pth'))

# Evaluate on test set
test_metrics, test_loss = evaluate(model, test_pairs, test_types, criterion, class_weights)

print("\n" + "=" * 60)
print("üìä FINAL TEST SET RESULTS (GAT + MLP)")
print("=" * 60)
print(f"Test Loss:  {test_loss:.4f}")
print(f"Accuracy:   {test_metrics['accuracy']:.4f}")
print(f"F1 (Micro): {test_metrics['f1_micro']:.4f}")
print(f"F1 (Macro): {test_metrics['f1_macro']:.4f}")
print(f"Precision:  {test_metrics['precision']:.4f}")
print(f"Recall:     {test_metrics['recall']:.4f}")
print("=" * 60)


üìä FINAL TEST SET RESULTS (GAT + MLP)
Test Loss:  1.2492
Accuracy:   0.6444
F1 (Micro): 0.6444
F1 (Macro): 0.3840
Precision:  0.6444
Recall:     0.6444
