In [None]:
!pip install torch torchvision torchaudio --upgrade
!pip install transformers datasets scikit-learn
!pip uninstall dgl -y -q
!pip install dgl -f https://data.dgl.ai/wheels/repo.html
!pip install torchdata==0.7.1

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Set DGL backend before importing
os.environ['DGLBACKEND'] = 'pytorch'

# Mock GraphBolt to prevent loading issues
class MockModule:
    def __getattr__(self, name):
        return lambda *args, **kwargs: None

sys.modules['dgl.graphbolt'] = MockModule()

# Import DGL
import dgl

# Import other libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score, accuracy_score, average_precision_score, precision_recall_curve, auc, confusion_matrix,  classification_report
import dgl.function as fn

# Test everything
print(f"DGL version: {dgl.__version__}")
print(f"PyTorch version: {torch.__version__}")

# Test creating a hypergraph
try:
    data_dict = {
        ('node', 'in', 'edge'): ([0, 1], [0, 0]),
        ('edge', 'con', 'node'): ([0, 0], [0, 1])
    }
    test_hyG = dgl.heterograph(data_dict)
    print("Hypergraph creation successful!")
except Exception as e:
    print(f"Error: {e}")

DGL version: 2.1.0
PyTorch version: 2.9.1+cu128
Hypergraph creation successful!


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
class MLPPredictor(nn.Module):
    """Multi-class predictor for 86 interaction types"""
    def __init__(self, h_feats, num_classes=86):
        super().__init__()
        self.W1 = nn.Linear(h_feats * 2, h_feats)
        self.W2 = nn.Linear(h_feats, num_classes)

    def apply_edges(self, edges):
        h = torch.cat([edges.src['h'], edges.dst['h']], 1)
        return {'score': self.W2(F.relu(self.W1(h)))}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            return g.edata['score']  # Shape: [num_edges, 87]

In [None]:
def compute_loss(pos_score, pos_labels, class_weights=None):
    """Loss for 86-class classification on positives only"""
    if class_weights is not None:
        return F.cross_entropy(pos_score, pos_labels, weight=class_weights)
    return F.cross_entropy(pos_score, pos_labels)


def compute_auc(pos_score, pos_labels):
    """Compute metrics for 86 classes"""
    probs = F.softmax(pos_score, dim=1).cpu().detach().numpy()
    labels_np = pos_labels.cpu().numpy()

    labels_onehot = np.eye(86)[labels_np]

    roc_auc = roc_auc_score(labels_onehot, probs, multi_class='ovr', average='macro')
    pr_auc = average_precision_score(labels_onehot, probs, average='macro')

    return roc_auc, pr_auc

In [None]:
class HyGNN(nn.Module):
    """
    Original HyGNN with Edge→Node→Edge flow
    Adapted for Stage 2: Nodes=Drugs, Edges=Interaction Types
    Flow: Type → Drug → Type
    """
    def __init__(self, input_dim, query_dim, vertex_dim, edge_dim, dropout):
        super(HyGNN, self).__init__()
        self.dropout = dropout
        self.query_dim = query_dim

        self.in_first_layer = torch.nn.Linear(input_dim, vertex_dim)
        self.not_in_first_layer = torch.nn.Linear(vertex_dim, vertex_dim)

        # Hyperedge-level attention (Type → Drug)
        self.w6 = torch.nn.Linear(edge_dim, query_dim)    # Edge → Query
        self.w5 = torch.nn.Linear(vertex_dim, query_dim)  # Node → Key
        self.w4 = torch.nn.Linear(vertex_dim, edge_dim)   # Node → Value

        # Node-level attention (Drug → Type)
        self.w3 = torch.nn.Linear(vertex_dim, query_dim)  # Node → Query
        self.w2 = torch.nn.Linear(edge_dim, query_dim)    # Edge → Key
        self.w1 = torch.nn.Linear(edge_dim, vertex_dim)   # Edge → Value

    def red_function(self, nodes):
        attention_score = F.softmax((nodes.mailbox['Attn']), dim=1)
        aggregated = torch.sum(attention_score.unsqueeze(-1) * nodes.mailbox['v'], dim=1)
        return {'h': aggregated}

    def attention(self, edges):
        attn_score = F.leaky_relu((edges.src['k'] * edges.dst['q']).sum(-1))
        return {'Attn': attn_score / np.sqrt(self.query_dim)}

    def msg_fucntion(self, edges):
        return {'v': edges.src['v'], 'Attn': edges.data['Attn']}

    def forward(self, hyG, vfeat, efeat, first_layer, last_layer):
        if first_layer:
            feat_v = self.in_first_layer(vfeat)  # vfeat = type features [86,128]
        else:
            feat_v = self.not_in_first_layer(vfeat)
        feat_e = efeat  # efeat = drug features [1709,128]

        # Stage 1: Type → Drug (Edge → Node)
        hyG.ndata['h'] = {'edge': feat_v}  # Types are EDGES
        hyG.ndata['k'] = {'edge': self.w2(feat_v)}  # Keys from types
        hyG.ndata['v'] = {'edge': self.w1(feat_v)}  # Values from types
        hyG.ndata['q'] = {'node': self.w3(feat_e)}  # Queries from drugs
        hyG.apply_edges(self.attention, etype='con')
        hyG.update_all(self.msg_fucntion, self.red_function, etype='con')

        # Stage 2: Drug → Type (Node → Edge)
        feat_e = hyG.ndata['h']['node']  # Updated drugs
        hyG.ndata['k'] = {'node': self.w5(feat_e)}  # Keys from drugs
        hyG.ndata['v'] = {'node': self.w4(feat_e)}  # Values from drugs
        hyG.ndata['q'] = {'edge': self.w6(feat_v)}  # Queries from types
        hyG.apply_edges(self.attention, etype='in')
        hyG.update_all(self.msg_fucntion, self.red_function, etype='in')
        feat_v = hyG.ndata['h']['edge']  # Final type features

        if not last_layer:
            feat_e = F.dropout(feat_e, self.dropout)

        if last_layer:
            return feat_v, feat_e
        else:
            return [hyG, feat_v, feat_e]

In [None]:
class HyGNN_Modified(nn.Module):
    """
    Modified HyGNN with Node→Edge→Node flow
    Starts with drugs (which have pre-trained features)
    Flow: Drug → Type → Drug
    """
    def __init__(self, input_dim, query_dim, vertex_dim, edge_dim, dropout):
        super(HyGNN_Modified, self).__init__()
        self.dropout = dropout
        self.query_dim = query_dim

        self.in_first_layer = torch.nn.Linear(input_dim, vertex_dim)
        self.not_in_first_layer = torch.nn.Linear(vertex_dim, vertex_dim)

        # Node-level attention (Drug → Type)
        self.w6 = torch.nn.Linear(edge_dim, query_dim)    # Edge → Query
        self.w5 = torch.nn.Linear(vertex_dim, query_dim)  # Node → Key
        self.w4 = torch.nn.Linear(vertex_dim, edge_dim)   # Node → Value

        # Hyperedge-level attention (Type → Drug)
        self.w3 = torch.nn.Linear(vertex_dim, query_dim)  # Node → Query
        self.w2 = torch.nn.Linear(edge_dim, query_dim)    # Edge → Key
        self.w1 = torch.nn.Linear(edge_dim, vertex_dim)   # Edge → Value

    def red_function(self, nodes):
        attention_score = F.softmax((nodes.mailbox['Attn']), dim=1)
        aggregated = torch.sum(attention_score.unsqueeze(-1) * nodes.mailbox['v'], dim=1)
        return {'h': aggregated}

    def attention(self, edges):
        attn_score = F.leaky_relu((edges.src['k'] * edges.dst['q']).sum(-1))
        return {'Attn': attn_score / np.sqrt(self.query_dim)}

    def msg_fucntion(self, edges):
        return {'v': edges.src['v'], 'Attn': edges.data['Attn']}

    def forward(self, hyG, vfeat, efeat, first_layer, last_layer):
        if first_layer:
            feat_e = self.in_first_layer(efeat)  # efeat = drug features [1709,128]
        else:
            feat_e = self.not_in_first_layer(efeat)
        feat_v = vfeat  # vfeat = type features [86,128]

        # Stage 1: Drug → Type (Node → Edge) - YOUR IDEA
        hyG.ndata['h'] = {'node': feat_e}  # Start with drugs
        hyG.ndata['k'] = {'node': self.w5(feat_e)}  # Keys from drugs
        hyG.ndata['v'] = {'node': self.w4(feat_e)}  # Values from drugs
        hyG.ndata['q'] = {'edge': self.w6(feat_v)}  # Queries from types
        hyG.apply_edges(self.attention, etype='in')  # Types query drugs
        hyG.update_all(self.msg_fucntion, self.red_function, etype='in')

        # Stage 2: Type → Drug (Edge → Node)
        feat_v = hyG.ndata['h']['edge']  # Updated types
        hyG.ndata['k'] = {'edge': self.w2(feat_v)}  # Keys from types
        hyG.ndata['v'] = {'edge': self.w1(feat_v)}  # Values from types
        hyG.ndata['q'] = {'node': self.w3(feat_e)}  # Queries from drugs
        hyG.apply_edges(self.attention, etype='con')  # Drugs query types
        hyG.update_all(self.msg_fucntion, self.red_function, etype='con')
        feat_e = hyG.ndata['h']['node']  # Final drug features

        if not last_layer:
            feat_v = F.dropout(feat_v, self.dropout)

        if last_layer:
            return feat_v, feat_e
        else:
            return [hyG, feat_v, feat_e]

In [None]:
class Model(nn.Module):
    """Wrapper for HyGNN - can switch between original and modified"""
    def __init__(self, drug_feature_dim, config, use_modified=False):
        super(Model, self).__init__()

        if use_modified:
            self.gat1 = HyGNN_Modified(
                drug_feature_dim,
                config['hidden_units'],
                config['hidden_units'],
                config['hidden_units'],
                config['dropout']
            )
        else:
            self.gat1 = HyGNN(
                drug_feature_dim,
                config['hidden_units'],
                config['hidden_units'],
                config['hidden_units'],
                config['dropout']
            )

    def forward(self, hyG, v_feat, e_feat, f, l):
        h = self.gat1(hyG, v_feat, e_feat, f, l)
        return h

In [None]:
def load_stage1_embeddings():
    """Load drug embeddings from Stage 1 (Chemical Network)"""
    print("Loading Stage 1 embeddings...")

    path = '/content/drive/MyDrive/MyModel/Model2-S42-K9-Experment-1/chemical_network_output.pt' # we use Model2-S42-K9-Experment-1
    checkpoint = torch.load(path, weights_only=False)

    H = checkpoint['drug_embeddings']  # [1709, 128]
    drug_to_idx = checkpoint['drug_to_id']

    print(f"  Drug embeddings shape: {H.shape}")
    print(f"  Number of drugs: {len(drug_to_idx)}")
    print(f"  Stage 1 best epoch: {checkpoint['best_epoch']}")
    print(f"  Stage 1 ROC-AUC: {checkpoint['final_performance']['roc_auc']:.4f}")

    return H, drug_to_idx


def load_train_test_data_multiclass(drug_to_idx):
    """Load training/validation/test data with interaction type labels (including negatives)"""
    print("\nLoading multi-class training data...")

    # Load positive samples with interaction types (1-86)
    train_pos = pd.read_csv('/content/drive/MyDrive/MyModel/Metabolic-Seed42/train_fixed.csv')
    val_pos = pd.read_csv('/content/drive/MyDrive/MyModel/Metabolic-Seed42/val_fixed.csv')
    test_pos = pd.read_csv('/content/drive/MyDrive/MyModel/Metabolic-Seed42/test_fixed.csv')

    def create_dgl_graph_with_types(df, num_nodes, drug_mapping):
        """Create DGL graph with interaction type labels"""
        # Get drug IDs
        if 'Drug1_ID' in df.columns and 'Drug2_ID' in df.columns:
            src_ids = df['Drug1_ID'].values
            dst_ids = df['Drug2_ID'].values
        else:
            src_ids = df.iloc[:, 0].values
            dst_ids = df.iloc[:, 1].values

        # Get interaction types/labels
        if 'Label' in df.columns:
            types = df['Label'].values 
        elif 'Y' in df.columns:
            types = df['Y'].values
        elif 'Interaction_Type' in df.columns:
            types = df['Interaction_Type'].values
        else:
            raise ValueError(f"No label column found. Available: {df.columns.tolist()}")

        types = types - 1
        # Convert to indices
        src = torch.tensor([drug_mapping[drug_id] for drug_id in src_ids], dtype=torch.long)
        dst = torch.tensor([drug_mapping[drug_id] for drug_id in dst_ids], dtype=torch.long)
        types = torch.tensor(types, dtype=torch.long)  # Keep as-is: 0-86

        # Create graph
        g = dgl.graph((src, dst), num_nodes=num_nodes)
        g.edata['type'] = types

        return g

    # Create graphs for positive samples
    train_pos_g = create_dgl_graph_with_types(train_pos, len(drug_to_idx), drug_to_idx)
    val_pos_g = create_dgl_graph_with_types(val_pos, len(drug_to_idx), drug_to_idx)
    test_pos_g = create_dgl_graph_with_types(test_pos, len(drug_to_idx), drug_to_idx)


    print(f"  Train positive samples: {train_pos_g.number_of_edges()}")
    print(f"  Validation positive samples: {val_pos_g.number_of_edges()}")
    print(f"  Test positive samples: {test_pos_g.number_of_edges()}")

    return train_pos_g, val_pos_g, test_pos_g


In [10]:
def load_metabolic_hypergraph(hyg_path, metadata_path):
    """Load metabolic hypergraph structure"""
    print("\nLoading metabolic hypergraph...")

    # Load hypergraph edges
    edges = torch.load(hyg_path, weights_only=False)
    metadata = torch.load(metadata_path, weights_only=False)

    print(f"  Hypergraph edges shape: {edges.shape}")
    print(f"  Total connections: {edges.shape[0]:,}")
    print(f"  Unique drugs: {len(torch.unique(edges[:, 0]))}")
    print(f"  Unique interaction types: {len(torch.unique(edges[:, 1]))}")

    # Create DGL heterograph
    # edges[:, 0] = drug indices (nodes)
    # edges[:, 1] = interaction type indices (hyperedges)
    data_dict = {
        ('node', 'in', 'edge'): (edges[:, 0], edges[:, 1]),
        ('edge', 'con', 'node'): (edges[:, 1], edges[:, 0])
    }

    hyG = dgl.heterograph(data_dict)

    print(f"\nHypergraph structure:")
    print(hyG)

    return hyG, metadata

In [11]:
def create_interaction_type_features(num_types=86, hidden_dim=128):
    """Create one-hot features for interaction types"""
    print(f"\nCreating interaction type features...")
    print(f"  Using one-hot encoding for {num_types} types")

    # Create sparse identity matrix
    from scipy.sparse import coo_matrix
    nl = coo_matrix((num_types, num_types))
    nl.setdiag(1)
    values = nl.data
    indices = np.vstack((nl.row, nl.col))
    i = torch.LongTensor(indices)
    v = torch.FloatTensor(values)
    type_X = torch.sparse_coo_tensor(i, v, torch.Size((num_types, num_types)))

    print(f"  Sparse identity matrix created: {type_X.shape}")

    return type_X

In [12]:
# ================= CELL 10: TRAINING CONFIGURATION =================

# EXPERIMENT CONFIG
EXPERIMENT_CONFIG = {
    'learning_rate': 0.005,
    'hidden_units': 128,
    'dropout': 0.5,
    'weight_decay': 0.0,
    'training_seed': 42,
    'experiment_name': 'metabolic_network_original',  # or 'metabolic_network_modified'
    'use_modified_attention': False  # Set True to test modified flow
}

base_path = f'/content/drive/MyDrive/MyModel/DB/Metabolic/Model1-{EXPERIMENT_CONFIG["experiment_name"]}/'

# Create experiment directory
os.makedirs(base_path, exist_ok=True)
print(f"Experiment: {EXPERIMENT_CONFIG['experiment_name']}")
print(f"Config: {EXPERIMENT_CONFIG}")
print(f"Directory: {base_path}")

# Set random seeds
torch.manual_seed(EXPERIMENT_CONFIG['training_seed'])
np.random.seed(EXPERIMENT_CONFIG['training_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(EXPERIMENT_CONFIG['training_seed'])

Experiment: metabolic_network_original
Config: {'learning_rate': 0.005, 'hidden_units': 128, 'dropout': 0.5, 'weight_decay': 0.0, 'training_seed': 42, 'experiment_name': 'metabolic_network_original', 'use_modified_attention': False}
Directory: /content/drive/MyDrive/MyModel/DB/Metabolic/Model1-metabolic_network_original/


In [None]:
# Load Stage 1 embeddings (these become drug features)
H_stage1, drug_to_idx = load_stage1_embeddings()

# Load metabolic hypergraph
hyG, metadata = load_metabolic_hypergraph(
    '/content/drive/MyDrive/MyModel/metabolic_hypergraph.pt',
    '/content/drive/MyDrive/MyModel/metabolic_hypergraph_metadata.pt'
)

# Create interaction type features
type_X = create_interaction_type_features(num_types=86, hidden_dim=128)

# Load train/val/test data
train_pos_g, val_pos_g, test_pos_g = load_train_test_data_multiclass(drug_to_idx)

# Set initial features for hypergraph
# e_feat = drug features from Stage 1 [1709, 128]
# v_feat = interaction type features (ones vector) [86, 128]
e_feat = H_stage1
v_feat = torch.tensor(type_X.to_dense())


print("\n" + "="*60)
print("Data loading complete!")
print("="*60)
print(f"Drug features (e_feat): {e_feat.shape}")
print(f"Interaction type features (v_feat): {v_feat.shape}")
print(f"Hypergraph nodes (drugs): {hyG.num_nodes('node')}")
print(f"Hypergraph edges (types): {hyG.num_nodes('edge')}")


# ================= VERIFICATION =================

# Verify drug mapping consistency
stage1_drugs = set(drug_to_idx.keys())
stage2_drugs = set(metadata['drug_to_idx'].keys())

print(f"\n{'='*60}")
print("Drug Mapping Verification")
print(f"{'='*60}")
print(f"Stage 1 drugs: {len(stage1_drugs)}")
print(f"Stage 2 drugs: {len(stage2_drugs)}")
print(f"Shared drugs: {len(stage1_drugs & stage2_drugs)}")

if stage1_drugs == stage2_drugs:
    print("✓ Drug vocabularies are identical")

    # Check if indices match
    mismatch = False
    for drug_id in list(stage1_drugs)[:5]:
        idx1 = drug_to_idx[drug_id]
        idx2 = metadata['drug_to_idx'][drug_id]
        if idx1 != idx2:
            print(f"✗ Mismatch: {drug_id} has index {idx1} in Stage1 but {idx2} in Stage2")
            mismatch = True

    if not mismatch:
        print("✓ Drug index mappings are identical")
        print("✓ Stage 1 embeddings will correctly transfer to Stage 2")
else:
    print("✗ WARNING: Drug vocabularies differ!")

Loading Stage 1 embeddings...
  Drug embeddings shape: torch.Size([1709, 128])
  Number of drugs: 1709
  Stage 1 best epoch: 494
  Stage 1 ROC-AUC: 0.9846

Loading metabolic hypergraph...
  Hypergraph edges shape: torch.Size([13486, 2])
  Total connections: 13,486
  Unique drugs: 1709
  Unique interaction types: 86

Hypergraph structure:
Graph(num_nodes={'edge': 86, 'node': 1709},
      num_edges={('edge', 'con', 'node'): 13486, ('node', 'in', 'edge'): 13486},
      metagraph=[('edge', 'node', 'con'), ('node', 'edge', 'in')])

Creating interaction type features...
  Using one-hot encoding for 86 types
  Sparse identity matrix created: torch.Size([86, 86])

Loading multi-class training data...
  Train positive samples: 153489
  Validation positive samples: 19188
  Test positive samples: 19200

Data loading complete!
Drug features (e_feat): torch.Size([1709, 128])
Interaction type features (v_feat): torch.Size([86, 86])
Hypergraph nodes (drugs): 1709
Hypergraph edges (types): 86

Drug Ma

In [14]:
import psutil
import os
import time

def calculate_ram_usage():
    """Calculate current RAM usage in GB"""
    process = psutil.Process(os.getpid())
    ram_gb = process.memory_info().rss / (1024 ** 3)  # Convert to GB
    return ram_gb


In [None]:
# Create model and decoder
if EXPERIMENT_CONFIG['use_modified_attention']:
    input_dim = e_feat.shape[1]  # 128 for Modified
    print('Use Modified HyGNN attantion')
else:
    input_dim = type_X.shape[1]  # 86 for Original
    print('Use Original HyGNN attantion')

model = Model(
    input_dim,  
    EXPERIMENT_CONFIG,
    use_modified=EXPERIMENT_CONFIG['use_modified_attention']
)
decoder = MLPPredictor(EXPERIMENT_CONFIG['hidden_units'], num_classes=86)  )

# Optimizer
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), decoder.parameters()),
    lr=EXPERIMENT_CONFIG['learning_rate']
)

type_counts = torch.bincount(train_pos_g.edata['type'], minlength=86).float()
alpha = 0.5
class_weights = 1.0 / torch.pow(type_counts.clamp(min=1.0), alpha)
class_weights = class_weights / class_weights.mean()

print(f"\nClass weights statistics:")
print(f"  Min weight: {class_weights.min():.4f}")
print(f"  Max weight: {class_weights.max():.4f}")
print(f"  Mean weight: {class_weights.mean():.4f}")


# Training variables
best_val_loss = 1e10
patience = 0
best_embeddings = None
best_epoch = 0

training_start_time = time.time()
# Get RAM usage before training
ram_before = calculate_ram_usage()
print(f" RAM usage before training: {ram_before:.2f} GB")
print("\nStarting Stage 2 (Metabolic Network) training...")
print("\nStarting Stage 2 (Metabolic Network) training...")
print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(training_start_time))}")

print("\nStarting Stage 2 (Metabolic Network) training...")

for e in range(500):
    # Training phase
    model.train()
    decoder.train()

    # Forward pass through HyGNN
    h = model(hyG, v_feat, e_feat, True, True)
    h_drug = h[1]  # Get drug embeddings [1709, 128]

    pos_score = decoder(train_pos_g, h_drug)  # [num_edges, 87]

    pos_labels = train_pos_g.edata['type']  # [num_edges] - values 1-86

    
    loss = compute_loss(pos_score, pos_labels)

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Validation phase
    with torch.no_grad():
        model.eval()
        decoder.eval()

        pos_score_val = decoder(val_pos_g, h_drug)
        pos_labels_val = val_pos_g.edata['type']
        val_loss = compute_loss(pos_score_val, pos_labels_val, class_weights)

        # Model selection based on validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss.item()
            best_embeddings = h_drug.clone()
            best_epoch = e
            patience = 0

            # Save models
            torch.save(decoder.state_dict(), f'{base_path}decoder_best.pth')
            torch.save(model.state_dict(), f'{base_path}model_best.pth')
            torch.save(best_embeddings, f'{base_path}best_embeddings.pt')  # ADD THIS LINE

        else:
            patience += 1

        # Early stopping
        if patience > 100:
            print(f"Early stopping at epoch {e}")
            break

    # Progress reporting every 10 epochs
    if e % 10 == 0:
        print(f'Epoch {e}: loss: {loss:.4f}, val_loss: {val_loss:.4f} (best: {best_val_loss:.4f}, patience: {patience})')

print(f"\nTraining completed!")
print(f"Best epoch: {best_epoch}")
print(f"Best validation loss: {best_val_loss:.4f}")

H = best_embeddings
E = best_epoch


training_end_time = time.time()
total_training_time = training_end_time - training_start_time
# Get RAM usage after training
ram_after = calculate_ram_usage()
ram_used = ram_after - ram_before

print(f"  - Drug embeddings: {H.shape}")
print(f"  - Best epoch: {E}")
print(f"  - Best validation loss: {best_val_loss:.4f}")

print(f"\nTiming Statistics:")
print(f"  Total training time: {total_training_time:.2f} seconds")
print(f"  Total training time: {total_training_time/60:.2f} minutes")
print(f"  Total training time: {total_training_time/3600:.2f} hours")
print(f"  Average time per epoch: {total_training_time/(e+1):.2f} seconds")
print(f"RAM usage before training: {ram_before:.2f} GB")
print(f"RAM usage after training:  {ram_after:.2f} GB")
print(f"RAM used during training:  {ram_used:.2f} GB")
print(f"  Epochs completed: {e+1}")
print(f"  End time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(training_end_time))}")


Use Original HyGNN attantion

Class weights statistics:
  Min weight: 0.0387
  Max weight: 4.2745
  Mean weight: 1.0000
 RAM usage before training: 1.39 GB

Starting Stage 2 (Metabolic Network) training...

Starting Stage 2 (Metabolic Network) training...
Start time: 2025-11-17 18:14:44

Starting Stage 2 (Metabolic Network) training...
Epoch 0: loss: 4.4829, val_loss: 4.4217 (best: 4.4217, patience: 0)
Epoch 10: loss: 2.6488, val_loss: 4.6302 (best: 3.9911, patience: 7)
Epoch 20: loss: 2.4583, val_loss: 4.2259 (best: 3.9911, patience: 17)
Epoch 30: loss: 2.3091, val_loss: 4.1341 (best: 3.9911, patience: 27)
Epoch 40: loss: 2.0704, val_loss: 3.8873 (best: 3.8873, patience: 0)
Epoch 50: loss: 1.9072, val_loss: 3.5960 (best: 3.5960, patience: 0)
Epoch 60: loss: 1.7692, val_loss: 3.3514 (best: 3.3514, patience: 0)
Epoch 70: loss: 1.6186, val_loss: 3.0396 (best: 3.0396, patience: 0)
Epoch 80: loss: 1.4684, val_loss: 2.6774 (best: 2.6774, patience: 0)
Epoch 90: loss: 1.3099, val_loss: 2.3013

In [None]:
# ================= EVALUATION =================
decoder.load_state_dict(torch.load(f'{base_path}decoder_best.pth', weights_only=False))

with torch.no_grad():
    model.eval()
    decoder.eval()

    # Test on positives only
    pos_score = decoder(test_pos_g, H)
    pos_labels = test_pos_g.edata['type']

    test_auc = compute_auc(pos_score, pos_labels)

predictions = torch.argmax(pos_score, dim=1).numpy()
ground_truth = pos_labels.numpy()

print("\n" + "="*80)
print("STAGE 2 (METABOLIC NETWORK) - FINAL RESULTS ")
print("="*80)
print(f"Best Epoch: {E}")
print(f"\nTest Performance (Positives Only):")
print(f"  ROC-AUC: {test_auc[0]:.6f}")
print(f"  PR-AUC:  {test_auc[1]:.6f}")

accuracy = accuracy_score(ground_truth, predictions)
print(f"  Top-1 Accuracy: {accuracy:.6f}")

# Top-3 accuracy
top3_preds = torch.topk(pos_score, k=3, dim=1)[1].cpu().numpy()
top3_acc = np.mean([gt in pred for gt, pred in zip(ground_truth, top3_preds)])
print(f"  Top-3 Accuracy: {top3_acc:.6f}")

# Per-class metrics
print("\nPer-Class Performance:")
print(classification_report(
    ground_truth,
    predictions,
    target_names=[f"Type_{i}" for i in range(1, 87)],
    zero_division=0
))

# ================= SAVE FINAL OUTPUTS FOR FUTURE EVALUATION =================

print("\n" + "="*80)
print("SAVING TRAINING OUTPUTS")
print("="*80)

# Save embeddings and complete state
final_checkpoint = {
    'best_epoch': E,
    'drug_embeddings': H.cpu(),  # Best embeddings from training
    'drug_to_idx': drug_to_idx,
    'config': EXPERIMENT_CONFIG,
    'final_metrics': {
        'best_val_loss': best_val_loss,
        'test_roc_auc': test_auc[0],
        'test_pr_auc': test_auc[1]
    },
    'model_state_dict': model.state_dict(),
    'decoder_state_dict': decoder.state_dict()
}

checkpoint_path = f'{base_path}metabolic_network_complete_checkpoint.pt'
torch.save(final_checkpoint, checkpoint_path)

print(f"✓ Complete checkpoint saved to: {checkpoint_path}")


STAGE 2 (METABOLIC NETWORK) - FINAL RESULTS (LINE 3: INTRINSIC)
Best Epoch: 499

Test Performance (Positives Only):
  ROC-AUC: 0.994405
  PR-AUC:  0.872433
  Top-1 Accuracy: 0.857344
  Top-3 Accuracy: 0.981094

Per-Class Performance:
              precision    recall  f1-score   support

      Type_1       1.00      0.50      0.67         2
      Type_2       0.87      1.00      0.93        27
      Type_3       0.98      0.98      0.98        58
      Type_4       0.78      0.84      0.81       504
      Type_5       0.90      0.84      0.87        32
      Type_6       0.92      0.91      0.91       296
      Type_7       1.00      1.00      1.00         1
      Type_8       0.94      0.89      0.92        19
      Type_9       0.89      0.93      0.91       230
     Type_10       0.87      0.98      0.92        61
     Type_11       0.83      0.19      0.30        27
     Type_12       0.80      0.80      0.80        30
     Type_13       1.00      0.25      0.40         4
     Typ

# IMPROVED METABOLIC NETWORK EVALUATION (Using Checkpoint Directly)

*The textual description of the interaction type can be found in the output files test_predictions_top3.csv and test_predictions_with_names.csv*

In [None]:
# ================= IMPROVED METABOLIC NETWORK EVALUATION (Using Checkpoint Directly) =================

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import tensorflow as tf
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix,
    classification_report
)
import json
from datetime import datetime

# ======================= SETUP =======================
base_path = '/content/drive/MyDrive/MyModel/DB/Metabolic/Model1-metabolic_network_original/'
checkpoint_path = f'{base_path}metabolic_network_complete_checkpoint.pt'

print("="*80)
print("LOADING CHECKPOINT DIRECTLY")
print("="*80)

# ======================= LOAD CHECKPOINT =======================
print(f"\nLoading checkpoint from: {checkpoint_path}")
checkpoint = torch.load(checkpoint_path, weights_only=False)

# Extract all information from checkpoint
H = checkpoint['drug_embeddings']  # Pre-computed embeddings [1709, 128]
drug_to_idx = checkpoint['drug_to_idx']
EXPERIMENT_CONFIG = checkpoint['config']
best_epoch = checkpoint['best_epoch']
best_val_loss = checkpoint['final_metrics']['best_val_loss']

print(f"✓ Checkpoint loaded successfully!")
print(f"  - Best epoch: {best_epoch}")
print(f"  - Best validation loss: {best_val_loss:.4f}")
print(f"  - Drug embeddings shape: {H.shape}")
print(f"  - Number of drugs in mapping: {len(drug_to_idx)}")

# ======================= LOAD DRUG NAMES =======================
print("\nLoading drug names from Google Drive...")

drug_names_path = '/content/drive/MyDrive/MyModel/ddi_DrugBank_DrugName_map.csv'
drug_names_df = pd.read_csv(drug_names_path)

# Create drug ID to name mapping
if 'Drug ID' in drug_names_df.columns and 'Name' in drug_names_df.columns:
    drug_id_to_name = dict(zip(drug_names_df['Drug ID'], drug_names_df['Name']))
elif 'Drug_ID' in drug_names_df.columns and 'Name' in drug_names_df.columns:
    drug_id_to_name = dict(zip(drug_names_df['Drug_ID'], drug_names_df['Name']))
else:
    print(f"Available columns: {drug_names_df.columns.tolist()}")
    drug_id_to_name = dict(zip(drug_names_df.iloc[:, 0], drug_names_df.iloc[:, 1]))

print(f"✓ Loaded {len(drug_id_to_name)} drug names")

# ======================= LOAD TEST DATA =======================
print("\nLoading test data...")

# Load test data CSV to get drug IDs
test_pos = pd.read_csv('/content/drive/MyDrive/MyModel/Metabolic-Seed42/test_fixed.csv')

print(f"✓ Test set: {len(test_pos)} samples")

# ======================= CREATE DECODER AND LOAD STATE =======================
print("\nCreating decoder and loading trained weights...")

# Create decoder
decoder = MLPPredictor(EXPERIMENT_CONFIG['hidden_units'], num_classes=86)

# Load decoder state from checkpoint (not from separate file)
decoder.load_state_dict(checkpoint['decoder_state_dict'])
decoder.eval()

print(f"✓ Decoder loaded from checkpoint")

# ======================= EVALUATION =======================
print("\nRunning evaluation...")

with torch.no_grad():
    decoder.eval()

    # Get predictions (test_pos_g should be available from previous cells)
    pos_score = decoder(test_pos_g, H)  # Shape: [num_edges, 86]
    pos_labels = test_pos_g.edata['type']  # Ground truth types [0-85]

    # Get probabilities and predictions
    probs = F.softmax(pos_score, dim=1)
    pred_probs, pred_classes = torch.max(probs, dim=1)

    # Compute metrics
    predictions = pred_classes.numpy()
    ground_truth = pos_labels.numpy()
    confidence_scores = pred_probs.numpy()
    all_probs = probs.numpy()

# ======================= COMPUTE METRICS =======================

# Accuracy
accuracy = accuracy_score(ground_truth, predictions)

# Top-3 accuracy
top3_preds = torch.topk(pos_score, k=3, dim=1)[1].cpu().numpy()
top3_acc = np.mean([gt in pred for gt, pred in zip(ground_truth, top3_preds)])

# Multi-class ROC-AUC and PR-AUC
labels_onehot = np.eye(86)[ground_truth]
roc_auc = roc_auc_score(labels_onehot, all_probs, multi_class='ovr', average='macro')
pr_auc = average_precision_score(labels_onehot, all_probs, average='macro')

# TensorFlow metrics for consistency
m1 = tf.keras.metrics.CategoricalAccuracy()
m1.update_state(labels_onehot, all_probs)

# Precision, Recall, F1 (macro average)
precision = precision_score(ground_truth, predictions, average='macro', zero_division=0)
recall = recall_score(ground_truth, predictions, average='macro', zero_division=0)
f1 = f1_score(ground_truth, predictions, average='macro', zero_division=0)

# Confusion Matrix
cm = confusion_matrix(ground_truth, predictions)

# ======================= CREATE DETAILED RESULTS TABLE =======================
print("\nCreating detailed results table...")

# Get drug pairs from test data
drug1_ids = test_pos['Drug1_ID'].values
drug2_ids = test_pos['Drug2_ID'].values

# Create detailed results
detailed_results = []
top3_results = []

for i in range(len(drug1_ids)):
    drug1_id = drug1_ids[i]
    drug2_id = drug2_ids[i]
    true_type = ground_truth[i]
    pred_type = predictions[i]
    score = confidence_scores[i]

    # Get top-3 predictions for this sample
    top3_indices = np.argsort(all_probs[i])[-3:][::-1]
    top3_scores = [all_probs[i][idx] for idx in top3_indices]

    # Main detailed results table
    detailed_results.append({
        'Drug1_ID': drug1_id,
        'Drug1_Name': drug_id_to_name.get(drug1_id, 'Unknown'),
        'Drug2_ID': drug2_id,
        'Drug2_Name': drug_id_to_name.get(drug2_id, 'Unknown'),
        'Prediction_Score': float(score),
        'Predicted_Type_Index': int(pred_type + 1),  # 1-86
        'Predicted_Type_Name': f"Type_{pred_type + 1}",
        'Interaction_Translation': 'N/A',  # Placeholder for translation
        'True_Type_Index': int(true_type + 1),  # 1-86
        'True_Type_Name': f"Type_{true_type + 1}",
        'Match': 'Yes' if pred_type == true_type else 'No'
    })

    # Top-3 predictions table
    top3_results.append({
        'Drug1_ID': drug1_id,
        'Drug1_Name': drug_id_to_name.get(drug1_id, 'Unknown'),
        'Drug2_ID': drug2_id,
        'Drug2_Name': drug_id_to_name.get(drug2_id, 'Unknown'),
        'Predicted_Type_1': int(top3_indices[0] + 1),  # Final prediction
        'Type_1_Score': float(top3_scores[0]),
        'Predicted_Type_2': int(top3_indices[1] + 1),
        'Type_2_Score': float(top3_scores[1]),
        'Predicted_Type_3': int(top3_indices[2] + 1),
        'Type_3_Score': float(top3_scores[2]),
        'True_Type_Index': int(true_type + 1),
        'True_Type_Name': f"Type_{true_type + 1}",
        'Match_Top1': 'Yes' if top3_indices[0] == true_type else 'No',
        'In_Top3': 'Yes' if true_type in top3_indices else 'No'
    })

results_df = pd.DataFrame(detailed_results)
top3_df = pd.DataFrame(top3_results)

# ======================= ADD INTERACTION TRANSLATIONS =======================
print("\nLoading interaction type descriptions...")

# Load the interaction type mapping
interaction_map_path = '/content/drive/MyDrive/MyModel/ddi_DrugBank_label_map.csv'
interaction_map_df = pd.read_csv(interaction_map_path)

# Create a dictionary mapping type ID to description
if 'ID' in interaction_map_df.columns and 'InteractionType' in interaction_map_df.columns:
    type_to_description = dict(zip(interaction_map_df['ID'], interaction_map_df['InteractionType']))
else:
    type_to_description = dict(zip(interaction_map_df.iloc[:, 0], interaction_map_df.iloc[:, 1]))

print(f"✓ Loaded {len(type_to_description)} interaction type descriptions")

print("Adding interaction translations with drug names...")

def translate_interaction(row):
    """Replace #Drug1 and #Drug2 with actual drug names in the interaction description"""
    type_idx = row['Predicted_Type_Index']

    # Get the description template
    description = type_to_description.get(type_idx, 'N/A')

    if description != 'N/A':
        # Replace #Drug1 and #Drug2 with actual drug names
        description = description.replace('#Drug1', row['Drug1_Name'])
        description = description.replace('#Drug2', row['Drug2_Name'])

    return description

# Apply translation to main results
results_df['Interaction_Translation'] = results_df.apply(translate_interaction, axis=1)

# Also add translations for top-3 predictions
def translate_top3(row, type_col):
    """Translate interaction for top-3 predictions"""
    type_idx = row[type_col]
    description = type_to_description.get(type_idx, 'N/A')

    if description != 'N/A':
        description = description.replace('#Drug1', row['Drug1_Name'])
        description = description.replace('#Drug2', row['Drug2_Name'])

    return description

top3_df['Type_1_Translation'] = top3_df.apply(lambda row: translate_top3(row, 'Predicted_Type_1'), axis=1)
top3_df['Type_2_Translation'] = top3_df.apply(lambda row: translate_top3(row, 'Predicted_Type_2'), axis=1)
top3_df['Type_3_Translation'] = top3_df.apply(lambda row: translate_top3(row, 'Predicted_Type_3'), axis=1)
top3_df['True_Translation'] = top3_df.apply(lambda row: translate_top3(row, 'True_Type_Index'), axis=1)

print("✓ Translations added successfully!")

# Save detailed results
results_df.to_csv(f'{base_path}test_predictions_with_names.csv', index=False)
print(f"✓ Detailed results saved to: {base_path}test_predictions_with_names.csv")

# Save top-3 predictions
top3_df.to_csv(f'{base_path}test_predictions_top3.csv', index=False)
print(f"✓ Top-3 predictions saved to: {base_path}test_predictions_top3.csv")

# ======================= PRINT RESULTS =======================

print("\n" + "=" * 80)
print("METABOLIC NETWORK - EVALUATION RESULTS (FROM CHECKPOINT)")
print("=" * 80)
print(f"\nExperiment: {EXPERIMENT_CONFIG['experiment_name']}")
print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"Checkpoint Epoch: {best_epoch}")

print(f"\n{'='*80}")
print("AGGREGATE PERFORMANCE METRICS")
print(f"{'='*80}")

print(f'Top-1 Accuracy:  {accuracy:.6f}')
print(f'Top-3 Accuracy:  {top3_acc:.6f}')
print(f'Precision (macro): {precision:.6f}')
print(f'Recall (macro):    {recall:.6f}')
print(f'F1-Score (macro):  {f1:.6f}')
print(f'ROC-AUC (macro):   {roc_auc:.6f}')
print(f'PR-AUC (macro):    {pr_auc:.6f}')

# ======================= ANALYZE PREDICTIONS BY CONFIDENCE =======================

print(f"\n{'='*80}")
print("CONFIDENCE-BASED ANALYSIS")
print(f"{'='*80}")

sample_cols = ['Drug1_Name', 'Drug2_Name', 'Predicted_Type_Name', 'True_Type_Name',
               'Prediction_Score', 'Match']

# High confidence correct predictions
high_conf_correct = results_df[
    (results_df['Match'] == 'Yes') &
    (results_df['Prediction_Score'] > 0.7)
]
print(f"High Confidence Correct (>0.7): {len(high_conf_correct):,} cases ({len(high_conf_correct)/len(results_df)*100:.2f}%)")

if len(high_conf_correct) > 0:
    print("\nSample High Confidence Correct Predictions:")
    print(high_conf_correct[sample_cols].head(10).to_string(index=False))

# High confidence errors
high_conf_errors = results_df[
    (results_df['Match'] == 'No') &
    (results_df['Prediction_Score'] > 0.7)
]
print(f"\nHigh Confidence Errors (>0.7): {len(high_conf_errors):,} cases ({len(high_conf_errors)/len(results_df)*100:.2f}%)")

if len(high_conf_errors) > 0:
    print("\nSample High Confidence Errors:")
    print(high_conf_errors[sample_cols].head(10).to_string(index=False))
    high_conf_errors.to_csv(f'{base_path}high_confidence_errors_with_names.csv', index=False)

# Low confidence predictions
low_conf = results_df[results_df['Prediction_Score'] < 0.5]
print(f"\nLow Confidence (<0.5): {len(low_conf):,} cases ({len(low_conf)/len(results_df)*100:.2f}%)")

if len(low_conf) > 0:
    print("\nSample Low Confidence Predictions:")
    print(low_conf[sample_cols].head(10).to_string(index=False))
    low_conf.to_csv(f'{base_path}low_confidence_with_names.csv', index=False)

# Uncertain predictions (moderate confidence)
uncertain = results_df[
    (results_df['Prediction_Score'] >= 0.5) &
    (results_df['Prediction_Score'] <= 0.7)
]
print(f"\nModerate Confidence (0.5-0.7): {len(uncertain):,} cases ({len(uncertain)/len(results_df)*100:.2f}%)")

# ======================= STATISTICS SUMMARY =======================

print("\n" + "=" * 80)
print("STATISTICS SUMMARY")
print("=" * 80)

print(f"\nTotal Test Samples: {len(results_df):,}")
print(f"Number of Interaction Types: 86")

print(f"\nCorrect Predictions: {results_df[results_df['Match']=='Yes'].shape[0]:,} ({accuracy*100:.2f}%)")
print(f"Incorrect Predictions: {results_df[results_df['Match']=='No'].shape[0]:,} ({(1-accuracy)*100:.2f}%)")

# Top-3 statistics
top3_correct = top3_df[top3_df['In_Top3']=='Yes'].shape[0]
print(f"\nTop-3 Statistics:")
print(f"  Correct in Top-3: {top3_correct:,} ({top3_acc*100:.2f}%)")
print(f"  Not in Top-3: {len(top3_df) - top3_correct:,} ({(1-top3_acc)*100:.2f}%)")

print(f"\nPrediction Score Statistics:")
print(f"  Mean:   {results_df['Prediction_Score'].mean():.4f}")
print(f"  Median: {results_df['Prediction_Score'].median():.4f}")
print(f"  Std:    {results_df['Prediction_Score'].std():.4f}")
print(f"  Min:    {results_df['Prediction_Score'].min():.4f}")
print(f"  Max:    {results_df['Prediction_Score'].max():.4f}")

# Score statistics by correctness
correct_scores = results_df[results_df['Match']=='Yes']['Prediction_Score']
incorrect_scores = results_df[results_df['Match']=='No']['Prediction_Score']

print(f"\nBy Correctness:")
print(f"  Correct predictions:   Mean = {correct_scores.mean():.4f}, Std = {correct_scores.std():.4f}")
print(f"  Incorrect predictions: Mean = {incorrect_scores.mean():.4f}, Std = {incorrect_scores.std():.4f}")

# ======================= SAVE SUMMARY =======================

summary_report = {
    'experiment_name': EXPERIMENT_CONFIG['experiment_name'],
    'timestamp': datetime.now().isoformat(),
    'checkpoint_info': {
        'best_epoch': int(best_epoch),
        'best_val_loss': float(best_val_loss),
        'checkpoint_path': checkpoint_path
    },
    'metrics': {
        'top1_accuracy': float(accuracy),
        'top3_accuracy': float(top3_acc),
        'precision_macro': float(precision),
        'recall_macro': float(recall),
        'f1_score_macro': float(f1),
        'roc_auc_macro': float(roc_auc),
        'pr_auc_macro': float(pr_auc)
    },
    'statistics': {
        'total_samples': len(results_df),
        'num_types': 86,
        'correct_predictions': int(np.sum(predictions == ground_truth)),
        'incorrect_predictions': int(np.sum(predictions != ground_truth)),
        'high_confidence_correct': len(high_conf_correct),
        'high_confidence_errors': len(high_conf_errors),
        'low_confidence': len(low_conf),
        'moderate_confidence': len(uncertain)
    },
    'score_statistics': {
        'overall_mean': float(results_df['Prediction_Score'].mean()),
        'overall_std': float(results_df['Prediction_Score'].std()),
        'correct_mean': float(correct_scores.mean()),
        'correct_std': float(correct_scores.std()),
        'incorrect_mean': float(incorrect_scores.mean()),
        'incorrect_std': float(incorrect_scores.std())
    }
}

with open(f'{base_path}evaluation_summary_improved.json', 'w') as f:
    json.dump(summary_report, f, indent=4)

print(f"\n\n✓ Evaluation summary saved to: {base_path}evaluation_summary_improved.json")
print("=" * 80)
print("EVALUATION COMPLETE (Using Checkpoint)")
print("=" * 80)

LOADING CHECKPOINT DIRECTLY

Loading checkpoint from: /content/drive/MyDrive/MyModel/DB/Metabolic/Model1-metabolic_network_original/metabolic_network_complete_checkpoint.pt
✓ Checkpoint loaded successfully!
  - Best epoch: 499
  - Best validation loss: 0.5535
  - Drug embeddings shape: torch.Size([1709, 128])
  - Number of drugs in mapping: 1709

Loading drug names from Google Drive...
✓ Loaded 1709 drug names

Loading test data...
✓ Test set: 19200 samples

Creating decoder and loading trained weights...
✓ Decoder loaded from checkpoint

Running evaluation...

Creating detailed results table...

Loading interaction type descriptions...
✓ Loaded 86 interaction type descriptions
Adding interaction translations with drug names...
✓ Translations added successfully!
✓ Detailed results saved to: /content/drive/MyDrive/MyModel/DB/Metabolic/Model1-metabolic_network_original/test_predictions_with_names.csv
✓ Top-3 predictions saved to: /content/drive/MyDrive/MyModel/DB/Metabolic/Model1-metabol

In [None]:
# ================= CELL 13: UPDATED EVALUATION =================
decoder.load_state_dict(torch.load(f'{base_path}decoder_best.pth', weights_only=False))

with torch.no_grad():
    model.eval()
    decoder.eval()

    # Test on positives only
    pos_score = decoder(test_pos_g, H)
    pos_labels = test_pos_g.edata['type']

    test_auc = compute_auc(pos_score, pos_labels)

predictions = torch.argmax(pos_score, dim=1).numpy()
ground_truth = pos_labels.numpy()

print("\n" + "="*80)
print("STAGE 2 (METABOLIC NETWORK) - FINAL RESULTS (LINE 3: INTRINSIC)")
print("="*80)
print(f"Best Epoch: {E}")
print(f"\nTest Performance (Positives Only):")
print(f"  ROC-AUC: {test_auc[0]:.6f}")
print(f"  PR-AUC:  {test_auc[1]:.6f}")

accuracy = accuracy_score(ground_truth, predictions)
print(f"  Top-1 Accuracy: {accuracy:.6f}")

# Top-3 accuracy
top3_preds = torch.topk(pos_score, k=3, dim=1)[1].cpu().numpy()
top3_acc = np.mean([gt in pred for gt, pred in zip(ground_truth, top3_preds)])
print(f"  Top-3 Accuracy: {top3_acc:.6f}")

# Per-class metrics
print("\nPer-Class Performance:")
print(classification_report(
    ground_truth,
    predictions,
    target_names=[f"Type_{i}" for i in range(1, 87)],
    zero_division=0
))

# ================= SAVE FINAL OUTPUTS FOR FUTURE EVALUATION =================

print("\n" + "="*80)
print("SAVING TRAINING OUTPUTS")
print("="*80)

# Save embeddings and complete state
final_checkpoint = {
    'best_epoch': E,
    'drug_embeddings': H.cpu(),  # Best embeddings from training
    'drug_to_idx': drug_to_idx,
    'config': EXPERIMENT_CONFIG,
    'final_metrics': {
        'best_val_loss': best_val_loss,
        'test_roc_auc': test_auc[0],
        'test_pr_auc': test_auc[1]
    },
    'model_state_dict': model.state_dict(),
    'decoder_state_dict': decoder.state_dict()
}

checkpoint_path = f'{base_path}metabolic_network_complete_checkpoint.pt'
torch.save(final_checkpoint, checkpoint_path)

print(f"✓ Complete checkpoint saved to: {checkpoint_path}")


STAGE 2 (METABOLIC NETWORK) - FINAL RESULTS (LINE 3: INTRINSIC)
Best Epoch: 499

Test Performance (Positives Only):
  ROC-AUC: 0.993705
  PR-AUC:  0.886490
  Top-1 Accuracy: 0.858802
  Top-3 Accuracy: 0.980938

Per-Class Performance:
              precision    recall  f1-score   support

      Type_1       1.00      0.50      0.67         2
      Type_2       0.77      0.85      0.81        27
      Type_3       0.98      1.00      0.99        58
      Type_4       0.79      0.78      0.78       504
      Type_5       0.90      0.81      0.85        32
      Type_6       0.89      0.90      0.89       296
      Type_7       0.50      1.00      0.67         1
      Type_8       0.85      0.89      0.87        19
      Type_9       0.90      0.95      0.92       230
     Type_10       0.82      0.89      0.85        61
     Type_11       0.60      0.22      0.32        27
     Type_12       0.94      0.97      0.95        30
     Type_13       1.00      0.75      0.86         4
     Typ