In [None]:
# ================= CELL 1: INSTALLATION =================
# Colab environment setup; run once per runtime.
# NOTE(review): only torchdata is pinned. `torch ... --upgrade` pulls the newest torch,
# while the DGL wheel index may not ship a build compiled for it; the import cell below
# stubs out dgl.graphbolt, presumably to work around the resulting load errors.
# Consider pinning torch and dgl to a known-compatible pair for reproducibility.
# NOTE(review): transformers/datasets are not imported in the cells shown here, and
# tensorflow is imported later without being installed here (assumed preinstalled).
!pip install torch torchvision torchaudio --upgrade
!pip install transformers datasets scikit-learn
!pip uninstall dgl -y -q
!pip install dgl -f https://data.dgl.ai/wheels/repo.html
!pip install torchdata==0.7.1

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

# Set DGL backend before importing
# (must run before `import dgl` below so DGL picks up the PyTorch backend)
os.environ['DGLBACKEND'] = 'pytorch'

# Mock GraphBolt to prevent loading issues
# Registering a stand-in object under sys.modules['dgl.graphbolt'] makes any later
# `import dgl.graphbolt` (presumably triggered from inside `import dgl`) resolve to this
# object instead of the real submodule. GraphBolt functionality is therefore unavailable.
class MockModule:
    def __getattr__(self, name):
        # Only called for attributes not found normally: every such lookup yields a
        # do-nothing function that accepts any arguments and returns None.
        return lambda *args, **kwargs: None

sys.modules['dgl.graphbolt'] = MockModule()

# Import DGL
import dgl

# Import other libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_auc_score, f1_score, recall_score, precision_score, accuracy_score, average_precision_score, precision_recall_curve, auc
import dgl.function as fn

# Sanity check: report library versions and build a tiny hypergraph.
print(f"DGL version: {dgl.__version__}")
print(f"PyTorch version: {torch.__version__}")

# One hyperedge ('edge' 0) containing vertices ('node') 0 and 1, stored as the two
# directed relations node -in-> edge and edge -con-> node.
try:
    data_dict = {
        ('node', 'in', 'edge'): ([0, 1], [0, 0]),
        ('edge', 'con', 'node'): ([0, 0], [0, 1])
    }
    test_hyG = dgl.heterograph(data_dict)
    print("Hypergraph creation successful!")
except Exception as e:
    # f-string so the real exception text is shown (previously printed a literal "{e}")
    print(f"Error: {e}")

DGL version: 2.1.0
PyTorch version: 2.9.0+cu128
Hypergraph creation successful!


In [None]:
import os
import warnings
warnings.filterwarnings('ignore')
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
import itertools
import numpy as np
import scipy.sparse as sp
import pandas as pd
import tensorflow as tf
from sklearn.metrics import roc_auc_score, f1_score, recall_score,precision_score, accuracy_score,average_precision_score,precision_recall_curve,auc
import dgl.function as fn

In [None]:
# Mount Google Drive so the /content/drive/... data and checkpoint paths used below resolve.
from google.colab import drive

drive.mount('/content/drive')


Mounted at /content/drive


In [None]:

class DotPredictor(nn.Module):
    """Parameter-free edge scorer: dot product of the two endpoint embeddings.

    The training cells shown here use ``MLPPredictor`` instead; this class is kept
    as a simple baseline.
    """

    def forward(self, g, h):
        """Return one dot-product logit per edge of ``g``.

        Args:
            g: pair graph whose edges are scored (e.g. a pos/neg train/test graph).
            h: node (drug) embeddings, one row per node of ``g``.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            edge_scores = g.edata['score']
        # Keep only column 0, dropping the trailing feature dimension of the u_dot_v result
        return edge_scores[:, 0]


# Edge scorer actually used by the training loop.
class MLPPredictor(nn.Module):
    """Score each edge with a two-layer MLP over ``[h_src || h_dst]``.

    Args:
        h_feats: width of the node embeddings; the hidden layer has the same width.
    """

    def __init__(self, h_feats):
        super().__init__()
        # W1/W2 are kept as attribute names because they are the saved state_dict keys
        self.W1 = nn.Linear(2 * h_feats, h_feats)
        self.W2 = nn.Linear(h_feats, 1)

    def apply_edges(self, edges):
        """Edge UDF: return ``{'score': logits}`` with one raw logit per edge."""
        pair_repr = torch.cat((edges.src['h'], edges.dst['h']), dim=1)
        hidden = F.relu(self.W1(pair_repr))
        logits = self.W2(hidden).squeeze(1)
        return {'score': logits}

    def forward(self, g, h):
        """Return the 1-D tensor of edge logits for every edge of ``g``.

        ``h`` holds one embedding row per node of ``g``; the graph's own node data is
        left untouched thanks to ``local_scope``.
        """
        with g.local_scope():
            g.ndata['h'] = h
            g.apply_edges(self.apply_edges)
            edge_logits = g.edata['score']
        return edge_logits


In [None]:
# Binary cross-entropy loss on raw logits.
def compute_loss(pos_score, neg_score):
    """Mean BCE-with-logits loss where positive pairs are labelled 1 and negatives 0.

    Args:
        pos_score: logits for positive (interacting) pairs.
        neg_score: logits for negative pairs.

    Returns:
        Scalar loss tensor.
    """
    scores = torch.cat([pos_score, neg_score])
    # Build labels with the scores' own device/dtype so the loss also works on GPU
    # (the old torch.ones(...) labels were always float32 on the CPU).
    labels = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
    return F.binary_cross_entropy_with_logits(scores, labels)


def compute_auc(pos_score, neg_score):
    """Return ``(roc_auc, pr_auc)`` for separating positive from negative pair scores.

    Args:
        pos_score: scores (logits or probabilities) for positive pairs.
        neg_score: scores for negative pairs.

    Returns:
        Tuple ``(ROC-AUC, PR-AUC)``. PR-AUC is the trapezoidal area under the
        precision-recall curve (``auc(recall, precision)``), not average precision.
    """
    # detach/cpu so this also works on tensors that require grad or live on a GPU
    scores = torch.cat([pos_score, neg_score]).detach().cpu().numpy()
    labels = np.concatenate([np.ones(pos_score.shape[0]), np.zeros(neg_score.shape[0])])
    precision, recall, _ = precision_recall_curve(labels, scores)
    return roc_auc_score(labels, scores), auc(recall, precision)

In [None]:

class HyGNN(nn.Module):
    """One round of two-stage attention message passing on a hypergraph.

    The heterograph has node types ``'edge'`` (hyperedges; drugs) and ``'node'``
    (vertices; substructures), linked by ``'con'`` (edge -> node) and ``'in'``
    (node -> edge).

    * Stage 1 (relation ``'con'``): each vertex attends over the hyperedges that
      contain it (queries from vertices, keys/values from hyperedges).
    * Stage 2 (relation ``'in'``): each hyperedge attends over its member vertices,
      using the vertex features produced by stage 1.

    Args:
        input_dim: width of raw hyperedge features fed to ``in_first_layer``.
        query_dim: width of the attention query/key space.
        vertex_dim: width of vertex embeddings (output width of the input projections).
        edge_dim: width of hyperedge embeddings. ``forward`` feeds the
            ``vertex_dim``-wide projected hyperedge features into ``w2``/``w1``/``w6``
            (declared with ``edge_dim`` inputs), so ``edge_dim`` must equal ``vertex_dim``.
        dropout: dropout probability applied to vertex features on non-final layers.
    """

    def __init__(self, input_dim, query_dim, vertex_dim, edge_dim, dropout):
        super(HyGNN, self).__init__()
        self.dropout = dropout
        self.query_dim = query_dim

        # Attribute names and creation order are kept: they define the state_dict keys
        # of saved checkpoints and the seeded initialisation.
        self.in_first_layer = torch.nn.Linear(input_dim, vertex_dim)
        self.not_in_first_layer = torch.nn.Linear(vertex_dim, vertex_dim)
        # Stage 2 (vertex -> hyperedge) projections
        self.w6 = torch.nn.Linear(edge_dim, query_dim)    # hyperedge -> query
        self.w5 = torch.nn.Linear(vertex_dim, query_dim)  # vertex -> key
        self.w4 = torch.nn.Linear(vertex_dim, edge_dim)   # vertex -> value (for hyperedges)
        # Stage 1 (hyperedge -> vertex) projections
        self.w3 = torch.nn.Linear(vertex_dim, query_dim)  # vertex -> query
        self.w2 = torch.nn.Linear(edge_dim, query_dim)    # hyperedge -> key
        self.w1 = torch.nn.Linear(edge_dim, vertex_dim)   # hyperedge -> value (for vertices)

    def red_function(self, nodes):
        """Reduce UDF: softmax the incoming attention logits over each receiver's
        mailbox and return the attention-weighted sum of incoming values as ``'h'``."""
        attention_score = F.softmax(nodes.mailbox['Attn'], dim=1)
        aggregated = torch.sum(attention_score.unsqueeze(-1) * nodes.mailbox['v'], dim=1)
        return {'h': aggregated}

    def attention(self, edges):
        """Edge UDF: ``leaky_relu(<k_src, q_dst>) / sqrt(query_dim)`` per edge."""
        attn_score = F.leaky_relu((edges.src['k'] * edges.dst['q']).sum(-1))
        return {'Attn': attn_score / np.sqrt(self.query_dim)}

    def msg_fucntion(self, edges):
        """Message UDF: send the source value ``'v'`` and the edge attention logit."""
        return {'v': edges.src['v'], 'Attn': edges.data['Attn']}

    def forward(self, hyG, vfeat, efeat, first_layer, last_layer):
        """Run both attention stages; writes 'h'/'k'/'v'/'q' node data into ``hyG``.

        Args:
            hyG: the hypergraph (modified in place: node/edge data is overwritten).
            vfeat: vertex features, ``vertex_dim`` wide.
            efeat: hyperedge features (``input_dim`` wide if ``first_layer`` else ``vertex_dim``).
            first_layer: choose ``in_first_layer`` vs ``not_in_first_layer`` projection.
            last_layer: controls the return format and whether dropout is applied.

        Returns:
            ``(feat_v, feat_e)`` if ``last_layer``; otherwise ``[hyG, feat_v, feat_e]``
            with dropout applied to ``feat_v`` (only in training mode).
        """
        if first_layer:
            feat_e = self.in_first_layer(efeat)
        else:
            feat_e = self.not_in_first_layer(efeat)
        feat_v = vfeat

        # Stage 1: hyperedge (drug) -> vertex (substructure) attention
        hyG.ndata['h'] = {'edge': feat_e}
        hyG.ndata['k'] = {'edge': self.w2(feat_e)}  # keys from drugs
        hyG.ndata['v'] = {'edge': self.w1(feat_e)}  # values from drugs
        hyG.ndata['q'] = {'node': self.w3(feat_v)}  # queries from substructures
        hyG.apply_edges(self.attention, etype='con')
        hyG.update_all(self.msg_fucntion, self.red_function, etype='con')

        # Stage 2: vertex -> hyperedge attention using the updated vertex features
        feat_v = hyG.ndata['h']['node']
        hyG.ndata['k'] = {'node': self.w5(feat_v)}  # keys from substructures
        hyG.ndata['v'] = {'node': self.w4(feat_v)}  # values from substructures
        hyG.ndata['q'] = {'edge': self.w6(feat_e)}  # queries from drugs
        hyG.apply_edges(self.attention, etype='in')
        hyG.update_all(self.msg_fucntion, self.red_function, etype='in')
        feat_e = hyG.ndata['h']['edge']  # updated drug features

        if last_layer:
            return feat_v, feat_e
        # F.dropout defaults to training=True; tie it to the module mode so that
        # model.eval() really disables dropout.
        feat_v = F.dropout(feat_v, self.dropout, training=self.training)
        return [hyG, feat_v, feat_e]

In [None]:
def load_data_and_create_graphs(
    metadata_path='/content/drive/MyDrive/MLHygnn/DB/hypergraphs/hyG_drug_drugbank_kmer_12_metadata.pt',
    hypergraph_path='/content/drive/MyDrive/MLHygnn/DB/hypergraphs/hyG_drug_drugbank_kmer_12.pt',
    node_feat_dim=128,
):
    """Load the drug/substructure hypergraph and build its initial features.

    Args:
        metadata_path: saved dict containing at least ``'drug_to_idx'`` and
            ``'node_to_idx'``; only their lengths are used here.
        hypergraph_path: saved incidence array; column 0 is used as the
            substructure ('node') index and column 1 as the drug ('edge') index.
        node_feat_dim: width of the constant all-ones substructure features.

    Returns:
        ``(hyG, v_feat, e_feat, drug_X, metadata)``:
        ``drug_X`` is a sparse ``(num_drugs, num_drugs)`` identity (one-hot drug features),
        ``e_feat`` a float copy of it, and ``v_feat`` is ``ones(num_substructures, node_feat_dim)``.
    """
    # NOTE: weights_only=False fully unpickles the file; only load trusted data.
    metadata = torch.load(metadata_path, weights_only=False)
    num_drugs = len(metadata['drug_to_idx'])
    num_substructures = len(metadata['node_to_idx'])

    print(f"Number of drugs: {num_drugs}")
    print(f"Number of substructures: {num_substructures}")

    chemicalsub_drug = torch.load(hypergraph_path, weights_only=False)
    sub_idx, drug_idx = chemicalsub_drug[:, 0], chemicalsub_drug[:, 1]
    data_dict = {
        ('node', 'in', 'edge'): (sub_idx, drug_idx),
        ('edge', 'con', 'node'): (drug_idx, sub_idx),
    }
    # Pin node counts to the metadata. Otherwise DGL infers them from the largest index,
    # and a drug/substructure without incidences at the end of the range would be dropped,
    # misaligning the graph with the feature matrices below.
    hyG = dgl.heterograph(
        data_dict,
        num_nodes_dict={'node': num_substructures, 'edge': num_drugs},
    )
    print("Hypergraph structure:")
    print(hyG)
    print("=" * 80)

    # Sparse identity matrix: drug i gets one-hot feature vector e_i.
    diag = torch.arange(num_drugs, dtype=torch.long)
    drug_X = torch.sparse_coo_tensor(
        torch.stack([diag, diag]), torch.ones(num_drugs), (num_drugs, num_drugs)
    )

    e_feat = drug_X.detach().clone().float()
    v_feat = torch.ones(num_substructures, node_feat_dim)
    # Initial node data; separate copies so the graph does not alias e_feat / v_feat.
    hyG.ndata['h'] = {'edge': e_feat.clone(), 'node': v_feat.clone()}

    return hyG, v_feat, e_feat, drug_X, metadata

In [None]:
def load_train_test_data(
    split_dir='/content/drive/MyDrive/MLHygnn/DB/Partition-Dataset/train-seed32',
    negatives_subdir='processed_with_negatives_fixed',
    metadata_path='/content/drive/MyDrive/MLHygnn/DB/hypergraphs/hyG_drug_drugbank_kmer_12_metadata.pt',
):
    """Load the positive/negative drug-pair CSVs and turn each split into a DGL graph.

    Every graph spans all ``num_drugs`` drug nodes; DrugBank IDs are translated to
    node indices through ``metadata['drug_to_idx']``. Each CSV row becomes one edge,
    in row order.

    Args:
        split_dir: directory with ``train.csv``, ``val.csv`` and ``test.csv`` (positives).
        negatives_subdir: sub-directory of ``split_dir`` with ``{split}_negatives.csv``.
        metadata_path: hypergraph metadata file providing ``'drug_to_idx'``.

    Returns:
        ``(train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g)``.

    Raises:
        KeyError: if a CSV contains drug IDs missing from ``drug_to_idx``.
    """
    # NOTE: weights_only=False fully unpickles the file; only load trusted data.
    metadata = torch.load(metadata_path, weights_only=False)
    drug_to_id_mapping = metadata['drug_to_idx']
    num_drugs = len(drug_to_id_mapping)

    def create_dgl_graph(df, num_nodes, drug_mapping):
        """Create a DGL graph with one edge Drug1 -> Drug2 per row of ``df``."""
        if 'Drug1_ID' in df.columns and 'Drug2_ID' in df.columns:
            src_ids = df['Drug1_ID'].values
            dst_ids = df['Drug2_ID'].values
        else:
            # Fall back to the first two columns, but show what was found
            print("Available columns:", df.columns.tolist())
            src_ids = df.iloc[:, 0].values
            dst_ids = df.iloc[:, 1].values

        unknown = {d for d in itertools.chain(src_ids, dst_ids) if d not in drug_mapping}
        if unknown:
            sample = sorted(map(str, unknown))[:5]
            raise KeyError(
                f"{len(unknown)} drug ID(s) not found in metadata['drug_to_idx'], e.g. {sample}"
            )

        src = torch.tensor([drug_mapping[d] for d in src_ids], dtype=torch.long)
        dst = torch.tensor([drug_mapping[d] for d in dst_ids], dtype=torch.long)
        return dgl.graph((src, dst), num_nodes=num_nodes)

    splits = ('train', 'val', 'test')
    pos_graphs = {
        s: create_dgl_graph(pd.read_csv(os.path.join(split_dir, f'{s}.csv')),
                            num_drugs, drug_to_id_mapping)
        for s in splits
    }
    neg_graphs = {
        s: create_dgl_graph(pd.read_csv(os.path.join(split_dir, negatives_subdir, f'{s}_negatives.csv')),
                            num_drugs, drug_to_id_mapping)
        for s in splits
    }

    for title, s in (("Train", 'train'), ("Validation", 'val'), ("Test", 'test')):
        print(f"{title} positive edges: {pos_graphs[s].num_edges()}")
        print(f"{title} negative edges: {neg_graphs[s].num_edges()}")

    return (pos_graphs['train'], neg_graphs['train'],
            pos_graphs['val'], neg_graphs['val'],
            pos_graphs['test'], neg_graphs['test'])

In [None]:
class Model(nn.Module):
    """Encoder wrapper: a single HyGNN layer whose hidden widths all come from ``config``.

    Args:
        drug_feature_dim: width of the raw drug (hyperedge) features.
        config: dict with ``'hidden_units'`` and ``'dropout'``.
    """

    def __init__(self, drug_feature_dim, config):
        super().__init__()
        width = config['hidden_units']
        # Attribute name 'gat1' is part of the saved state_dict keys
        self.gat1 = HyGNN(
            input_dim=drug_feature_dim,
            query_dim=width,
            vertex_dim=width,
            edge_dim=width,
            dropout=config['dropout'],
        )

    def forward(self, hyG, v_feat, e_feat, f, l):
        """Run the HyGNN layer; ``f``/``l`` are its first_layer/last_layer flags."""
        return self.gat1(hyG, v_feat, e_feat, f, l)


In [None]:
import psutil
import os
import time
def calculate_ram_usage():
    """Return the resident set size (RSS) of this Python process in GiB."""
    bytes_per_gib = 1024 ** 3
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / bytes_per_gib


In [None]:
# Hyper-parameters and bookkeeping for this run
EXPERIMENT_CONFIG = dict(
    learning_rate=0.005,
    hidden_units=128,
    dropout=0.5,
    weight_decay=0.0,
    training_seed=42,
    experiment_name='fixed_conservative',
)

# One output directory per experiment name
base_path = ('/content/drive/MyDrive/MLHygnn/DB/Stage1_Chemical_NetworkModel1-ForTime/'
             + EXPERIMENT_CONFIG['experiment_name'] + '/')
os.makedirs(base_path, exist_ok=True)

print(f"Experiment: {EXPERIMENT_CONFIG['experiment_name']}")
print(f"Config: {EXPERIMENT_CONFIG}")
print(f"Directory: {base_path}")

# Seed torch (CPU and, if present, CUDA) and NumPy for reproducibility
training_seed = EXPERIMENT_CONFIG['training_seed']
torch.manual_seed(training_seed)
np.random.seed(training_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(training_seed)

# Load the hypergraph/features and the train/val/test drug-pair graphs
hyG, v_feat, e_feat, drug_X, metadata = load_data_and_create_graphs()
train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = load_train_test_data()

# Encoder (HyGNN over the hypergraph) + MLP edge decoder, optimised jointly
model = Model(drug_X.shape[1], EXPERIMENT_CONFIG)
decoder = MLPPredictor(EXPERIMENT_CONFIG['hidden_units'])

optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), decoder.parameters()),
    lr=EXPERIMENT_CONFIG['learning_rate'],
    # The config's weight_decay was previously ignored; now it is applied (0.0 = no change)
    weight_decay=EXPERIMENT_CONFIG['weight_decay'],
)


# ---------------- Training ----------------
MAX_EPOCHS = 500
EARLY_STOP_PATIENCE = 200  # stop once more than this many epochs pass without a new best

best_val_loss = float('inf')
patience = 0
best_embeddings = None
best_epoch = 0

training_start_time = time.time()
ram_before = calculate_ram_usage()
print(f"RAM usage before training: {ram_before:.2f} GB")

print("\nStarting Stage 1 (Chemical Network) training...")
print(f"Start time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(training_start_time))}")

for e in range(MAX_EPOCHS):
    # --- optimisation step on the training pairs ---
    model.train()
    decoder.train()
    h = model(hyG, v_feat, e_feat, True, True)
    h_drug = h[1]  # (feat_v, feat_e) -> drug (hyperedge) embeddings
    pos_score = decoder(train_pos_g, h_drug)
    neg_score = decoder(train_neg_g, h_drug)
    loss = compute_loss(pos_score, neg_score)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # --- validation with the freshly updated weights ---
    # Re-encode after the step, in eval mode, so the embeddings that are scored (and
    # possibly kept as best_embeddings) are exactly what the saved model/decoder weights
    # produce. Previously validation reused the pre-step training embeddings with the
    # post-step decoder, so the checkpoint and the saved embeddings did not match.
    model.eval()
    decoder.eval()
    with torch.no_grad():
        h_eval = model(hyG, v_feat, e_feat, True, True)[1]
        pos_score_val = decoder(val_pos_g, h_eval)
        neg_score_val = decoder(val_neg_g, h_eval)
        val_loss = compute_loss(pos_score_val, neg_score_val).item()

    # Model selection on validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_embeddings = h_eval.clone()
        best_epoch = e
        patience = 0
        torch.save(decoder.state_dict(), f'{base_path}decoder_best.pth')
        torch.save(model.state_dict(), f'{base_path}model_best.pth')
    else:
        patience += 1

    if patience > EARLY_STOP_PATIENCE:
        print(f"Early stopping at epoch {e}")
        break

    if e % 10 == 0:
        print(f'Epoch {e}, train loss: {loss.item():.4f}, val loss: {val_loss:.4f} '
              f'(best: {best_val_loss:.4f}, patience: {patience})')

if best_embeddings is None:
    raise RuntimeError("Validation loss never improved (NaN loss?); no checkpoint was saved.")

training_end_time = time.time()
total_training_time = training_end_time - training_start_time
ram_after = calculate_ram_usage()
ram_used = ram_after - ram_before

print("\n" + "=" * 80)
print("TRAINING COMPLETED")
print("=" * 80)
print(f"Best epoch: {best_epoch}")
print(f"Best validation loss: {best_val_loss:.4f}")
print("\nTiming Statistics:")
print(f"  Total training time: {total_training_time:.2f} seconds")
print(f"  Total training time: {total_training_time/60:.2f} minutes")
print(f"  Total training time: {total_training_time/3600:.2f} hours")
print(f"  Average time per epoch: {total_training_time/(e+1):.2f} seconds")
print(f"RAM usage before training: {ram_before:.2f} GB")
print(f"RAM usage after training:  {ram_after:.2f} GB")
print(f"RAM used during training:  {ram_used:.2f} GB")
print(f"  Epochs completed: {e+1}")
print(f"  End time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(training_end_time))}")

# Best (eval-mode) embeddings and their epoch, used for the final evaluation
H = best_embeddings
E = best_epoch

Experiment: fixed_conservative
Config: {'learning_rate': 0.005, 'hidden_units': 128, 'dropout': 0.5, 'weight_decay': 0.0, 'training_seed': 42, 'experiment_name': 'fixed_conservative'}
Directory: /content/drive/MyDrive/MLHygnn/DB/Stage1_Chemical_NetworkModel1-ForTime/fixed_conservative/
Number of drugs: 1709
Number of substructures: 43655
Hypergraph structure:
Graph(num_nodes={'edge': 1709, 'node': 43655},
      num_edges={('edge', 'con', 'node'): 91615, ('node', 'in', 'edge'): 91615},
      metagraph=[('edge', 'node', 'con'), ('node', 'edge', 'in')])
Train positive edges: 153501
Train negative edges: 153501
Validation positive edges: 19187
Validation negative edges: 19187
Test positive edges: 19189
Test negative edges: 19189
ðŸ“Š RAM usage before training: 1.42 GB

Starting Stage 2 (Metabolic Network) training...
Start time: 2025-10-21 15:25:18
Epoch 0, train loss: 0.6939, val loss: 0.6947 (best: 0.6947, patience: 0)
Epoch 10, train loss: 0.4350, val loss: 0.4623 (best: 0.4623, patienc

In [None]:

# ---------------- Test-set evaluation ----------------
# Restore the best decoder; H / E are the embeddings and epoch of that same checkpoint.
decoder.load_state_dict(torch.load(f'{base_path}decoder_best.pth', weights_only=True))
model.eval()
decoder.eval()
with torch.no_grad():
    pos_score = decoder(test_pos_g, H)
    neg_score = decoder(test_neg_g, H)
    test_acc = compute_auc(pos_score, neg_score)  # (ROC-AUC, PR-AUC) despite the name

scores = torch.cat([pos_score, neg_score])
labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])

# The decoder outputs logits; every thresholded metric below works on probabilities
# with the 0.5 threshold.
sig_scores = torch.sigmoid(scores)

# Accuracy. Feeding raw logits here (as before) made Keras' 0.5 threshold mean
# p > sigmoid(0.5) ~= 0.62, inconsistent with precision/recall below.
m1 = tf.keras.metrics.BinaryAccuracy()
m1.update_state(labels, sig_scores)

# Precision = TP / (TP + FP)
m2 = tf.keras.metrics.Precision()
m2.update_state(labels, sig_scores)
M2 = m2.result().numpy()

# Recall = TP / (TP + FN)
m3 = tf.keras.metrics.Recall()
m3.update_state(labels, sig_scores)
M3 = m3.result().numpy()

# F1 = harmonic mean of precision and recall (defined as 0 when both are 0)
F1 = 2 * (M2 * M3) / (M2 + M3) if (M2 + M3) > 0 else 0.0
print('Best Epoch: {}, Accuracy: {:.4f}, Precision: {:.4f}, Recall: {:.4f}, F1-score {:.4f}, ROC-AUC {:.4f}, PR-AUC {:.4f}'.format(
    E, m1.result().numpy(), M2, M3, F1, test_acc[0], test_acc[1]))

# Save the final drug embeddings for the second network
print(f"Final drug embeddings shape: {H.shape}")
print("These embeddings will be used as input to the second (metabolic) network")

# Embeddings + drug-ID mapping + test metrics for the second (metabolic) network.
# The metric values are numpy scalars, so reloading this file needs weights_only=False.
torch.save({
    'drug_embeddings': H,
    'drug_to_id': metadata['drug_to_idx'],  # stored under 'drug_to_id' (metadata key is 'drug_to_idx')
    'best_epoch': E,
    'final_performance': {
        'accuracy': m1.result().numpy(),
        'precision': M2,
        'recall': M3,
        'f1': F1,
        'roc_auc': test_acc[0],
        'pr_auc': test_acc[1]
    }
}, f'{base_path}chemical_network_output.pt')

Best Epoch: 469, Accuracy: 0.9313, Precision: 0.9196, Recall: 0.9469, F1-score 0.9330, ROC-AUC 0.9842, PR-AUC 0.9841
Final drug embeddings shape: torch.Size([1709, 128])
These embeddings will be used as input to the second (metabolic) network


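# Reload the saved artifacts (hypergraph, test pairs, best embeddings and model/decoder weights) so the final evaluation below can run on its own and give a detailed view of the test results.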

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn.functional as F
import tensorflow as tf
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score, confusion_matrix,
    classification_report
)

# ======================= SETUP =======================
EXPERIMENT_CONFIG = {
    'learning_rate': 0.005,
    'hidden_units': 128,
    'dropout': 0.5,
    'weight_decay': 0.0,
    'training_seed': 42,
    'experiment_name': 'fixed_conservative'
}

# NOTE(review): this directory differs from the training cell's base_path
# ('.../Stage1_Chemical_NetworkModel1-ForTime/...'), so this evaluates a previously
# saved run. Confirm it is the run you intend to report.
base_path = f'/content/drive/MyDrive/MLHygnn/DB/Model-K12/Stage1_Chemical_NetworkModel1/{EXPERIMENT_CONFIG["experiment_name"]}/'

# ======================= LOAD DATA =======================
# Hypergraph and the train/val/test pair graphs
hyG, v_feat, e_feat, drug_X, metadata = load_data_and_create_graphs()
train_pos_g, train_neg_g, val_pos_g, val_neg_g, test_pos_g, test_neg_g = load_train_test_data()

# ======================= CREATE MODEL INSTANCES =======================
model = Model(drug_X.shape[1], EXPERIMENT_CONFIG)
decoder = MLPPredictor(EXPERIMENT_CONFIG['hidden_units'])

# ======================= LOAD TRAINED WEIGHTS =======================
# The output dict holds numpy scalars in 'final_performance', so it needs full
# unpickling (weights_only=False); only load files you trust.
checkpoint = torch.load(f'{base_path}chemical_network_output.pt', weights_only=False)

H = checkpoint['drug_embeddings']  # best drug embeddings, one row per drug
E = checkpoint['best_epoch']       # epoch that produced them

print(f"Loaded embeddings from epoch {E}")
print(f"Embedding shape: {H.shape}")

# Plain tensor state_dicts: the restricted (safe) loader is sufficient.
model.load_state_dict(torch.load(f'{base_path}model_best.pth', weights_only=True))
decoder.load_state_dict(torch.load(f'{base_path}decoder_best.pth', weights_only=True))

# ======================= EVALUATION =======================
with torch.no_grad():
    model.eval()
    decoder.eval()

    # Decoder logits for every test pair, using the saved best embeddings H
    pos_score = decoder(test_pos_g, H)
    neg_score = decoder(test_neg_g, H)

    # (ROC-AUC, PR-AUC); the variable name is historical, it does not hold accuracy
    test_acc = compute_auc(pos_score, neg_score)

# Positives first, then negatives (the same order used for the results table below)
scores = torch.cat([pos_score, neg_score])
labels = torch.cat([
    torch.ones(pos_score.shape[0]),
    torch.zeros(neg_score.shape[0])
])

# The decoder outputs logits: convert once, then threshold everything at p > 0.5
sig_scores = torch.sigmoid(scores)
predictions = (sig_scores > 0.5).long()

# ======================= AGGREGATE METRICS =======================

print("=" * 80)
print("AGGREGATE PERFORMANCE METRICS")
print("=" * 80)

# Accuracy must also see probabilities. Fed raw logits (as before), Keras' 0.5
# threshold meant p > sigmoid(0.5) ~= 0.62, so it disagreed with precision/recall
# and with the confusion matrix below.
m1 = tf.keras.metrics.BinaryAccuracy()
m1.update_state(labels, sig_scores)

m2 = tf.keras.metrics.Precision()
m2.update_state(labels, sig_scores)
M2 = m2.result().numpy()

m3 = tf.keras.metrics.Recall()
m3.update_state(labels, sig_scores)
M3 = m3.result().numpy()

# Harmonic mean of precision and recall, defined as 0 when both are 0
F1 = 2 * (M2 * M3) / (M2 + M3) if (M2 + M3) > 0 else 0.0

print(f'Best Epoch: {E}')
print(f'Accuracy:   {m1.result().numpy():.4f}')
print(f'Precision:  {M2:.4f}')
print(f'Recall:     {M3:.4f}')
print(f'F1-Score:   {F1:.4f}')
print(f'ROC-AUC:    {test_acc[0]:.4f}')
print(f'PR-AUC:     {test_acc[1]:.4f}')

# Rows = actual class (0 = negative, 1 = positive), columns = predicted class.
# labels=[0, 1] keeps the matrix 2x2 even if one class is never predicted.
cm = confusion_matrix(labels.long().numpy(), predictions.numpy(), labels=[0, 1])
print("\nConfusion Matrix:")
print(f"                Predicted Negative    Predicted Positive")
print(f"Actual Negative        {cm[0,0]:6d}              {cm[0,1]:6d}")
print(f"Actual Positive        {cm[1,0]:6d}              {cm[1,1]:6d}")


print("\n" + "=" * 80)
print("INDIVIDUAL PREDICTIONS")
print("=" * 80)

# The raw test-pair tables are local to load_train_test_data(), so reload them here.
# Referring to global `test_pos` / `test_neg` (as before) only worked with leftover
# kernel state and raised NameError under Restart & Run All.
TEST_SPLIT_DIR = '/content/drive/MyDrive/MLHygnn/DB/Partition-Dataset/train-seed32/'
test_pos = pd.read_csv(f'{TEST_SPLIT_DIR}test.csv')
test_neg = pd.read_csv(f'{TEST_SPLIT_DIR}processed_with_negatives_fixed/test_negatives.csv')


def _drug_pair_lists(df):
    """Return ``(drug1_ids, drug2_ids)`` lists, preferring the named ID columns and
    otherwise falling back to the first two columns (same rule as the graph loader)."""
    if 'Drug1_ID' in df.columns and 'Drug2_ID' in df.columns:
        return df['Drug1_ID'].tolist(), df['Drug2_ID'].tolist()
    return df.iloc[:, 0].tolist(), df.iloc[:, 1].tolist()


pos_drug1, pos_drug2 = _drug_pair_lists(test_pos)
neg_drug1, neg_drug2 = _drug_pair_lists(test_neg)

# Each CSV row became one graph edge in row order, so row i <-> score i. Check that
# the counts line up before pairing IDs with scores.
if len(pos_drug1) != pos_score.shape[0] or len(neg_drug1) != neg_score.shape[0]:
    raise ValueError(
        "Test CSV row counts do not match the scored test graphs; "
        "drug IDs and predictions would be misaligned."
    )

# Positive pairs THEN negative pairs, matching the order of `labels` / `scores`
results_df = pd.DataFrame({
    'Drug1_ID': pos_drug1 + neg_drug1,
    'Drug2_ID': pos_drug2 + neg_drug2,
    'True_Label': labels.numpy(),
    'Predicted_Label': predictions.numpy(),
    'Prediction_Score': sig_scores.numpy(),
    'Correct': (predictions.numpy() == labels.numpy()).astype(int)
})

results_df.to_csv(f'{base_path}test_predictions_detailed.csv', index=False)
print(f"\nAll predictions saved to: {base_path}test_predictions_detailed.csv")

print("\n" + "-" * 80)
print("SAMPLE PREDICTIONS (First 20)")
print("-" * 80)
print(results_df.head(20).to_string(index=False))


def _report_misclassified(frame, true_label, header, explanation, what, csv_name):
    """Print and save the rows of ``frame`` whose true label is ``true_label`` but
    whose predicted label is the other class; return those rows."""
    wrong_mask = (frame['True_Label'] == true_label) & (frame['Predicted_Label'] == 1 - true_label)
    missed = frame[wrong_mask]
    print(f"{header}{len(missed)} cases")
    print(explanation)
    print("-" * 80)
    if len(missed) > 0:
        print(missed.head(10).to_string(index=False))
        out_path = f'{base_path}{csv_name}'
        missed.to_csv(out_path, index=False)
        print(f"\nAll {what} saved to: {out_path}")
    return missed


print("\n" + "=" * 80)
print("ERROR ANALYSIS")
print("=" * 80)

# Type I errors: predicted an interaction that does not exist
false_positives = _report_misclassified(
    results_df, 0, "\nFalse Positives: ",
    "(Predicted interaction, but drugs DON'T actually interact)",
    "false positives", "false_positives.csv",
)

# Type II errors: missed a real interaction
false_negatives = _report_misclassified(
    results_df, 1, "\n\nFalse Negatives: ",
    "(Predicted NO interaction, but drugs DO actually interact)",
    "false negatives", "false_negatives.csv",
)

print("\n" + "=" * 80)
print("PREDICTION CONFIDENCE ANALYSIS")
print("=" * 80)

score_col = results_df['Prediction_Score']

# Correct predictions made with probability above 0.9 or below 0.1
is_confident = (score_col > 0.9) | (score_col < 0.1)
high_conf_correct = results_df[(results_df['Correct'] == 1) & is_confident]
print(f"\nHigh Confidence Correct: {len(high_conf_correct)} cases")

# Predictions strictly inside (0.4, 0.6): the model is unsure either way
uncertain = results_df[(score_col > 0.4) & (score_col < 0.6)]
print(f"Uncertain Predictions (0.4-0.6): {len(uncertain)} cases")
if len(uncertain) > 0:
    print("\nSample Uncertain Predictions:")
    print(uncertain.head(10).to_string(index=False))
    uncertain.to_csv(f'{base_path}uncertain_predictions.csv', index=False)

# ======================= STATISTICS SUMMARY =======================
import json

print("\n" + "=" * 80)
print("STATISTICS SUMMARY")
print("=" * 80)

n_total = len(results_df)
n_positive = int(labels.sum())
correct_col = results_df['Correct']
n_correct = correct_col.sum()
correct_rate = correct_col.mean()

print(f"\nTotal Test Samples: {n_total}")
print(f"  - Positive (interact): {n_positive}")
print(f"  - Negative (no interact): {len(labels) - n_positive}")

print(f"\nCorrect Predictions: {n_correct} ({correct_rate*100:.2f}%)")
print(f"Incorrect Predictions: {n_total - n_correct} ({(1-correct_rate)*100:.2f}%)")

# Distribution of the predicted interaction probabilities
prob_col = results_df['Prediction_Score']
print("\nPrediction Score Statistics:")
for stat_name, stat_value in (("Mean", prob_col.mean()), ("Std", prob_col.std()),
                              ("Min", prob_col.min()), ("Max", prob_col.max())):
    print(f"  {stat_name + ':':<5} {stat_value:.4f}")

# Everything needed to compare runs later, written as JSON next to the CSVs
summary_report = {
    'Best_Epoch': E,
    'Accuracy': float(m1.result().numpy()),
    'Precision': float(M2),
    'Recall': float(M3),
    'F1_Score': float(F1),
    'ROC_AUC': float(test_acc[0]),
    'PR_AUC': float(test_acc[1]),
    'Total_Samples': n_total,
    'True_Positives': int(cm[1, 1]),
    'True_Negatives': int(cm[0, 0]),
    'False_Positives': int(cm[0, 1]),
    'False_Negatives': int(cm[1, 0]),
    'High_Confidence_Correct': len(high_conf_correct),
    'Uncertain_Predictions': len(uncertain),
}

summary_path = f'{base_path}evaluation_summary.json'
with open(summary_path, 'w') as f:
    json.dump(summary_report, f, indent=4)

print(f"\n\nEvaluation summary saved to: {summary_path}")
print("=" * 80)

Number of drugs: 1709
Number of substructures: 43655
Hypergraph structure:
Graph(num_nodes={'edge': 1709, 'node': 43655},
      num_edges={('edge', 'con', 'node'): 91615, ('node', 'in', 'edge'): 91615},
      metagraph=[('edge', 'node', 'con'), ('node', 'edge', 'in')])
Train positive edges: 153501
Train negative edges: 153501
Validation positive edges: 19187
Validation negative edges: 19187
Test positive edges: 19189
Test negative edges: 19189
Loaded embeddings from epoch 469
Embedding shape: torch.Size([1709, 128])
AGGREGATE PERFORMANCE METRICS
Best Epoch: 469
Accuracy:   0.9313
Precision:  0.9196
Recall:     0.9469
F1-Score:   0.9330
ROC-AUC:    0.9842
PR-AUC:     0.9841

Confusion Matrix:
                Predicted Negative    Predicted Positive
Actual Negative         17600                1589
Actual Positive          1019               18170

INDIVIDUAL PREDICTIONS

All predictions saved to: /content/drive/MyDrive/MLHygnn/DB/Model-K12/Stage1_Chemical_NetworkModel1/fixed_conservativ