In [None]:
!pip install chardet

In [None]:
!pip install seaborn

In [None]:
!!pip install torch-geometric

In [None]:
!pip install torch

In [None]:
!pip install networkx

In [None]:
!pip install tensorflow

In [None]:
!pip install stellargraph

In [None]:
# Core libraries
import os
import time
import json
import functools
import itertools
import inspect
import traceback

# Numerical and Data Handling
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
import networkx as nx

# Progress and System Utilities
from tqdm import tqdm
import psutil

# Scikit-learn
from sklearn import model_selection
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    confusion_matrix,
    auc,
    precision_recall_curve,
    roc_curve,
    accuracy_score,
    f1_score,
    matthews_corrcoef,
    mean_squared_error,
    recall_score,
    roc_auc_score,
)

# TensorFlow / Keras
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import binary_crossentropy
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

# StellarGraph
import stellargraph as sg
from stellargraph import StellarGraph, IndexedArray
from stellargraph.mapper import PaddedGraphGenerator
from stellargraph.layer import GCNSupervisedGraphClassification, DeepGraphCNN

# PyTorch / PyG
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, BatchNorm


def hla_df_to_dic(hla):
    """
    Build a lookup from HLA allele name to its pseudo-sequence.

    Args:
        hla: DataFrame providing an 'HLA' column (allele names) and a
            'pseudo' column (pseudo-sequences), row-aligned.

    Returns:
        dict mapping allele -> pseudo-sequence. When an allele occurs on
        several rows, the value from the last such row is kept.
    """
    return dict(zip(hla['HLA'], hla['pseudo']))


def dict_inventory(inventory):
    """
    Group HLA allele names by locus and by their first two allele digits.

    Names are sliced positionally: character 4 is the locus, characters 6-7
    the first field and everything from character 8 on the second field
    (e.g. 'HLA-A*0201' -> 'A', '02', '01').

    Args:
        inventory: Iterable of allele names in the layout above.

    Returns:
        {'A': {...}, 'B': {...}, 'C': {...}} where each inner dict maps the
        first field to the list of second fields, in input order.

    Raises:
        KeyError: if an allele's locus is not A, B or C.
    """
    grouped = {locus: {} for locus in ('A', 'B', 'C')}
    for allele in inventory:
        locus, group, protein = allele[4], allele[6:8], allele[8:]
        grouped[locus].setdefault(group, []).append(protein)
    return grouped


def rescue_unknown_hla(hla, dic_inventory):
    """
    Substitute the numerically closest known allele for an unknown one.

    If the allele's first field exists in the inventory for its locus, the
    second field with the smallest absolute numeric difference is chosen.
    Otherwise the closest first field is chosen and paired with the first
    second field listed under it. Ties go to the earliest candidate.

    Args:
        hla: Allele name laid out like 'HLA-A*0201'.
        dic_inventory: Output of ``dict_inventory``.

    Returns:
        Replacement allele name in the same 'HLA-<locus>*<field1><field2>' form.
    """
    locus, group, protein = hla[4], hla[6:8], hla[8:]
    by_group = dic_inventory[locus]

    if group not in by_group:
        nearest_group = min(by_group, key=lambda g: abs(int(group) - int(g)))
        return f'HLA-{locus}*{nearest_group}{by_group[nearest_group][0]}'

    nearest_protein = min(by_group[group], key=lambda p: abs(int(protein) - int(p)))
    return f'HLA-{locus}*{group}{nearest_protein}'


class Graph_Constructor:
    """
    Build StellarGraph representations of peptide-HLA pairs.

    Every residue becomes a node: peptide residues are named 'p1'..'pN' and
    HLA pseudo-sequence residues 'h1'..'hM'. Node features are rows of a
    per-residue embedding matrix (``after_pca``) selected by residue symbol.
    """

    # Residue alphabet; a symbol's index here selects its row in ``after_pca``.
    AMINO = 'ARNDCQEGHILKMFPSTWYV-'

    # Peptide nodes whose edges weight_anchor_edge up-weights.
    # NOTE(review): positions 2, 9 and 10 come from the original
    # ('a2', 'a9', 'a10') list, which could never match because nodes are
    # named 'p<i>'. Presumably meant as the P2 / C-terminal anchor residues
    # of 9- and 10-mers -- confirm.
    ANCHOR_NODES = ('p2', 'p9', 'p10')
    ANCHOR_WEIGHT = 1.5

    @staticmethod
    def combinator(pep, hla):
        """Return node names (['p1', ...], ['h1', ...]) for *pep* and *hla*."""
        source = ['p' + str(i + 1) for i in range(len(pep))]
        target = ['h' + str(i + 1) for i in range(len(hla))]
        return source, target

    @staticmethod
    def numerical(pep, hla, after_pca, embed=12):
        """
        Encode residues as feature rows: peptide rows first, then HLA rows.

        Sequences are upper-cased and 'X' is treated as the gap symbol '-'.

        Args:
            pep: Peptide residue string.
            hla: HLA pseudo-sequence string.
            after_pca: Matrix with one row per symbol of ``AMINO`` (same order).
            embed: Width of each output row; must equal the number of
                columns of ``after_pca``.

        Returns:
            ndarray of shape (len(pep) + len(hla), embed).

        Raises:
            ValueError: if a residue is not in ``AMINO``.
        """
        def encode(seq):
            # Upper-case first so a lower-case 'x' is also mapped to the gap.
            seq = seq.upper().replace('X', '-')
            features = np.empty([len(seq), embed])
            for i, residue in enumerate(seq):
                features[i, :] = after_pca[Graph_Constructor.AMINO.index(residue), :]
            return features

        return np.concatenate([encode(pep), encode(hla)], axis=0)

    @staticmethod
    def _edge_frame(pairs, weights):
        """Tabulate a list of (source, target) pairs and their weights as an edge DataFrame."""
        return pd.DataFrame({
            'source': [s for s, _ in pairs],
            'target': [t for _, t in pairs],
            'weight': weights,
        })

    @staticmethod
    def _build_graph(pep, hla, after_pca, source, target, edges):
        """Attach residue features to the named nodes and wrap them with *edges* in a StellarGraph."""
        feature_array = Graph_Constructor.numerical(pep, hla, after_pca)
        try:
            nodes = IndexedArray(feature_array, index=source + target)
        except Exception:
            # Report the offending pair, then propagate the real error instead
            # of failing later on an unbound ``nodes``.
            print(pep, hla, feature_array.shape)
            raise
        return StellarGraph(nodes, edges, node_type_default='corner', edge_type_default='line')

    @staticmethod
    def unweight_edge(pep, hla, after_pca):
        """Connect every peptide node to every HLA node with weight 1."""
        source, target = Graph_Constructor.combinator(pep, hla)
        pairs = list(itertools.product(source, target))
        edges = Graph_Constructor._edge_frame(pairs, [1] * len(pairs))
        return Graph_Constructor._build_graph(pep, hla, after_pca, source, target, edges)

    @staticmethod
    def weight_anchor_edge(pep, hla, after_pca):
        """
        Connect every peptide node to every HLA node. Edges leaving a node in
        ``ANCHOR_NODES`` get ``ANCHOR_WEIGHT``; all others get 1.0.
        """
        source, target = Graph_Constructor.combinator(pep, hla)
        pairs = list(itertools.product(source, target))
        anchors = set(Graph_Constructor.ANCHOR_NODES)
        # Weights are computed up front. The previous row loop used chained
        # assignment (edges.iloc[i]['weight'] = ...), which writes to a copy,
        # so no weight was ever changed.
        weights = [Graph_Constructor.ANCHOR_WEIGHT if s in anchors else 1.0 for s, _ in pairs]
        edges = Graph_Constructor._edge_frame(pairs, weights)
        return Graph_Constructor._build_graph(pep, hla, after_pca, source, target, edges)

    @staticmethod
    def intra_and_inter(pep, hla, after_pca):
        """
        Build a graph with peptide-HLA edges (weight 2) and edges between every
        pair of residues within the peptide and within the HLA (weight 1).
        """
        source, target = Graph_Constructor.combinator(pep, hla)
        inter = list(itertools.product(source, target))
        intra = list(itertools.combinations(source, 2)) + list(itertools.combinations(target, 2))
        edges = Graph_Constructor._edge_frame(inter + intra, [2] * len(inter) + [1] * len(intra))
        return Graph_Constructor._build_graph(pep, hla, after_pca, source, target, edges)

    @staticmethod
    def entrance(df, after_pca, hla_dic, dic_inventory, graph_type='intra_and_inter'):
        """
        Build one graph per row of *df*.

        Args:
            df: DataFrame with 'peptide', 'HLA' and 'immunogenicity' columns.
                If the label column is non-numeric and its first value is a
                string, the column is rewritten IN PLACE: 'Negative' -> 0,
                anything else -> 1.
            after_pca: Residue embedding matrix (see ``numerical``).
            hla_dic: Allele name -> pseudo-sequence.
            dic_inventory: Output of ``dict_inventory``. Used via
                ``rescue_unknown_hla`` when an allele is missing from *hla_dic*.
            graph_type: 'unweight_edge' or 'weight_anchor_edge'; any other
                value builds 'intra_and_inter' graphs.

        Returns:
            (graphs, labels): a list of StellarGraph objects and a float
            Series of labels in row order.
        """
        labels_col = df['immunogenicity']
        if (df.shape[0] > 0
                and not pd.api.types.is_numeric_dtype(labels_col)
                and isinstance(labels_col.iloc[0], str)):
            # NOTE(review): every value other than exactly 'Negative' -- including
            # NaN or variants such as 'Positive-Low' -- becomes 1; confirm intended.
            df['immunogenicity'] = labels_col.map(lambda x: 0 if x == 'Negative' else 1)

        builders = {
            'unweight_edge': Graph_Constructor.unweight_edge,
            'weight_anchor_edge': Graph_Constructor.weight_anchor_edge,
        }
        build = builders.get(graph_type, Graph_Constructor.intra_and_inter)

        total = df.shape[0]
        graphs, graph_labels = [], []
        rows = zip(df['peptide'], df['HLA'], df['immunogenicity'])
        for i, (pep, allele, raw_label) in enumerate(rows):
            if i % 100 == 0:
                print(f"Processing sample {i}/{total}")

            try:
                hla = hla_dic[allele]
            except KeyError:
                hla = hla_dic[rescue_unknown_hla(allele, dic_inventory)]

            graphs.append(build(pep, hla, after_pca))
            graph_labels.append(float(raw_label))

        return graphs, pd.Series(graph_labels, dtype=float)

def plot_separate_training_history(history, save_path=None):
    """
    Draw each view of a training history as its own figure (for paper use).

    Figures, in order:
      1. training (and, if present, validation) loss, each with a dashed
         exponentially smoothed copy;
      2. training/validation accuracy;
      3. loss-vs-accuracy scatter;
      4. 5-epoch rolling mean of the epoch-to-epoch percentage change of loss
         (and accuracy).
    Figures 2-3 need an 'acc' or 'accuracy' column; figure 4 needs more than
    5 epochs.

    Args:
        history: Object whose ``.history`` attribute is a dict of metric
            name -> per-epoch values (e.g. a Keras ``History``).
        save_path: Optional filename prefix; when given, each figure is also
            saved as ``{save_path}_<suffix>.png`` at 300 dpi.
    """
    frame = pd.DataFrame(history.history)
    columns = frame.columns

    def smooth_curve(points, factor=0.8):
        """Exponential moving average; the first point is kept unchanged."""
        smoothed = []
        for value in points:
            smoothed.append(value if not smoothed else smoothed[-1] * factor + value * (1 - factor))
        return smoothed

    def finish(title, xlabel, ylabel, suffix):
        """Apply the shared decorations, optionally save, then show the figure."""
        plt.title(title, fontsize=14)
        plt.xlabel(xlabel, fontsize=12)
        plt.ylabel(ylabel, fontsize=12)
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
        if save_path:
            plt.savefig(f"{save_path}_{suffix}.png", dpi=300, bbox_inches='tight')
        plt.show()

    # Keras reports accuracy as 'acc' or 'accuracy' depending on version.
    acc_key = val_acc_key = None
    for candidate in ('acc', 'accuracy'):
        if candidate in columns:
            acc_key, val_acc_key = candidate, 'val_' + candidate
            break

    # Figure 1: loss curves with smoothed overlays.
    plt.figure(figsize=(8, 6))
    plt.plot(frame['loss'], label='Training loss', linewidth=2)
    plt.plot(smooth_curve(frame['loss'].values),
             linestyle='--', alpha=0.7, color='blue', label='Smoothed training loss')
    if 'val_loss' in columns:
        plt.plot(frame['val_loss'], label='Validation loss', linewidth=2)
        plt.plot(smooth_curve(frame['val_loss'].values),
                 linestyle='--', alpha=0.7, color='orange', label='Smoothed validation loss')
    finish('Training and Validation Loss', 'Epoch', 'Loss', 'loss')

    if acc_key:
        # Figure 2: accuracy curves.
        plt.figure(figsize=(8, 6))
        plt.plot(frame[acc_key], label='Training accuracy', linewidth=2)
        if val_acc_key in columns:
            plt.plot(frame[val_acc_key], label='Validation accuracy', linewidth=2)
        finish('Training and Validation Accuracy', 'Epoch', 'Accuracy', 'accuracy')

        # Figure 3: loss against accuracy, one point per epoch.
        plt.figure(figsize=(8, 6))
        plt.scatter(frame['loss'], frame[acc_key], alpha=0.7, label='Training', s=70)
        if 'val_loss' in columns and val_acc_key in columns:
            plt.scatter(frame['val_loss'], frame[val_acc_key], alpha=0.7, label='Validation', s=70)
        finish('Loss vs Accuracy', 'Loss', 'Accuracy', 'loss_vs_accuracy')

    # Figure 4: rolling relative change per epoch.
    if len(frame) > 5:
        plt.figure(figsize=(8, 6))
        plt.plot(frame.index, frame['loss'].pct_change().rolling(5).mean(),
                 label='Loss improvement rate', linewidth=2)
        if acc_key in columns:
            plt.plot(frame.index[1:], frame[acc_key].pct_change().rolling(5).mean()[1:],
                     label='Accuracy improvement rate', linewidth=2)
        finish('Learning Rate Improvement', 'Epoch', 'Improvement Rate', 'improvement_rate')


def plot_confusion_matrix(y_true, y_pred, title='Confusion Matrix', save_path=None):
    """
    Plot a binary (0/1) confusion-matrix heatmap and print summary metrics.

    Rows are true labels and columns are predicted labels, both ordered
    [0, 1]. Accuracy, precision, recall and F1 for the positive class (1)
    are printed. Each falls back to 0 when its denominator is 0.

    Args:
        y_true: True labels (0/1).
        y_pred: Predicted labels (0/1).
        title: Title for the plot.
        save_path: Optional filename prefix; when given, the figure is saved
            as ``{save_path}_confusion_matrix.png`` at 300 dpi.
    """
    # Pin the label set so the matrix is always 2x2. Without it, a dataset
    # where truth and predictions contain only one class gives a 1x1 matrix,
    # and the TN/FP/FN/TP unpacking below raises ValueError.
    cm = confusion_matrix(y_true, y_pred, labels=[0, 1])

    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True)

    plt.title(title, fontsize=14)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)

    # Lay out before saving so the saved file matches the displayed figure.
    plt.tight_layout()
    if save_path:
        plt.savefig(f"{save_path}_confusion_matrix.png", dpi=300, bbox_inches='tight')
    plt.show()

    tn, fp, fn, tp = cm.ravel()
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1 Score: {f1:.4f}")


def validate_on_test_datasets(model, after_pca, hla_dic, dic_inventory):
    """
    Evaluate a trained model on the external test datasets.

    Reads these files from the working directory:
      - 'dengue_test.csv'
      - 'ori_test_cells.csv'
      - 'sars_cov_2_result.csv', scored twice, using the
        'immunogenicity-con' and 'immunogenicity-un' label columns
      - 'deephlapan_result_cell.csv'
    For each dataset it builds graphs with ``Graph_Constructor.entrance``,
    thresholds the predicted positive-class probability at 0.5, prints the
    metrics, and plots/saves a confusion matrix.

    Args:
        model: Either a PyTorch module (detected by a callable ``eval``),
            called as ``model(x, edge_index, edge_attr, batch)`` and returning
            two-class logits (optionally as the first element of a tuple), or
            a Keras model that accepts a ``PaddedGraphGenerator`` flow.
        after_pca: Residue embedding matrix passed to ``Graph_Constructor``.
        hla_dic: Allele name -> pseudo-sequence.
        dic_inventory: Allele inventory used to rescue unknown alleles.

    Returns:
        dict with keys 'dengue_acc', 'dengue_mcc', 'cell_recall', 'cell_mcc',
        'covid_conv_recall', 'covid_conv_mcc', 'covid_unexp_recall',
        'covid_unexp_mcc', 'deephlapan_acc', 'deephlapan_recall',
        'deephlapan_mcc'.

    Raises:
        RuntimeError: if any graph cannot be converted for a PyTorch model.
    """

    def model_predict(model, graphs_list, labels):
        """Return the positive-class probability for each graph, in input order."""
        is_pytorch = hasattr(model, 'eval') and callable(getattr(model, 'eval'))

        if not is_pytorch:
            # TensorFlow/Keras: StellarGraph's padded generator feeds the graphs.
            generator = PaddedGraphGenerator(graphs=graphs_list)
            data_gen = generator.flow(
                np.arange(len(graphs_list)),
                targets=labels.values.astype(np.float32)
            )
            return model.predict(data_gen).flatten()

        # PyTorch: convert every StellarGraph to a PyG Data object.
        torch_graphs, failed = [], []
        for i, (g, lbl) in enumerate(zip(graphs_list, labels)):
            try:
                torch_graphs.append(stellargraph_to_torch_data(g, lbl))
            except Exception as e:
                print(f"Error converting graph {i}: {e}")
                failed.append(i)
        if failed:
            # Dropping graphs would misalign predictions with labels, so fail
            # here with a clear message instead of inside the metric code.
            raise RuntimeError(
                f"{len(failed)} of {len(graphs_list)} graphs could not be converted "
                f"(first failing index: {failed[0]})"
            )
        if not torch_graphs:
            return np.array([])

        # PyG's DataLoader always uses its own batching collater (it discards
        # any collate_fn), so none is passed.
        data_loader = DataLoader(torch_graphs, batch_size=16, shuffle=False)
        device = next(model.parameters()).device

        model.eval()
        all_preds = []
        with torch.no_grad():
            for data in data_loader:
                data = data.to(device)
                outputs = model(data.x, data.edge_index, data.edge_attr, data.batch)
                # Some models return (logits, ...) tuples; DeepGCNClassifier in
                # this file returns the logits tensor directly.
                if isinstance(outputs, (tuple, list)):
                    outputs = outputs[0]
                probs = F.softmax(outputs, dim=1)[:, 1].cpu().numpy()
                all_preds.extend(probs)
        return np.array(all_preds)

    def to_binary_labels(values):
        """Coerce a label column to numbers; anything unrecognised becomes 0."""
        numeric = pd.to_numeric(values, errors='coerce')
        # Text labels ('Positive...'/'Negative') would otherwise be coerced to
        # NaN and then silently counted as negatives.
        text = values.astype(str).str.strip().str.lower()
        from_text = text.map(
            lambda t: 1.0 if t.startswith('positive') else (0.0 if t == 'negative' else np.nan)
        )
        return numeric.fillna(from_text).fillna(0)

    # Metric key (suffix of the results key) -> (printed name, sklearn scorer).
    metric_table = {
        'acc': ('Accuracy', accuracy_score),
        'recall': ('Recall', recall_score),
        'mcc': ('MCC', matthews_corrcoef),
    }
    results = {}

    def evaluate(data, display_name, prefix, title, metric_keys):
        """Score one dataset, record results under '<prefix>_<metric>', print, and plot."""
        graphs, labels = Graph_Constructor.entrance(data, after_pca, hla_dic, dic_inventory)
        pred_prob = model_predict(model, graphs, labels)
        pred = (pred_prob > 0.5).astype(int)
        y_true = labels.values

        scores = {key: metric_table[key][1](y_true, pred) for key in metric_keys}
        for key, score in scores.items():
            results[f'{prefix}_{key}'] = score
        for key, score in scores.items():
            print(f"{display_name} {metric_table[key][0]}: {score:.3f}")

        plot_confusion_matrix(y_true, pred, title=title, save_path=prefix)

    # 1. Dengue dataset
    print("\n=== Dengue Dataset Accuracy ===")
    dengue_data = pd.read_csv('dengue_test.csv')
    dengue_data['immunogenicity'] = to_binary_labels(dengue_data['immunogenicity'])
    evaluate(dengue_data, 'Dengue dataset', 'dengue',
             'Dengue Dataset Confusion Matrix', ['acc', 'mcc'])

    # 2. Cell dataset
    print("\n=== Cell Dataset Recall ===")
    cell_data = pd.read_csv('ori_test_cells.csv')
    cell_data['immunogenicity'] = to_binary_labels(cell_data['immunogenicity'])
    evaluate(cell_data, 'Cell dataset', 'cell',
             'Cell Dataset Confusion Matrix', ['recall', 'mcc'])

    # 3-4. COVID dataset, scored against convalescent and unexposed labels
    print("\n=== COVID Dataset ===")
    covid_data = pd.read_csv('sars_cov_2_result.csv')
    covid_variants = [
        ('immunogenicity-con', 'Convalescent', 'covid_conv'),
        ('immunogenicity-un', 'Unexposed', 'covid_unexp'),
    ]
    for column, cohort, prefix in covid_variants:
        subset = covid_data.copy()
        if column in subset.columns:
            subset['immunogenicity'] = to_binary_labels(subset[column])
        else:
            print(f"Warning: '{column}' column not found in COVID dataset")
            subset['immunogenicity'] = 0  # Default if column not found
        evaluate(subset, f'COVID dataset ({cohort})', prefix,
                 f'COVID Dataset ({cohort}) Confusion Matrix', ['recall', 'mcc'])

    # 5. DeepHLApan dataset
    print("\n=== DeepHLApan Dataset ===")
    deephlapan_data = pd.read_csv('deephlapan_result_cell.csv')
    deephlapan_data['immunogenicity'] = to_binary_labels(deephlapan_data['immunogenicity'])
    evaluate(deephlapan_data, 'DeepHLApan dataset', 'deephlapan',
             'DeepHLApan Dataset Confusion Matrix', ['acc', 'recall', 'mcc'])

    # Summary table
    print("\n=== Summary of Results ===")
    print(f"Dengue dataset Accuracy: {results['dengue_acc']:.3f}, MCC: {results['dengue_mcc']:.3f}")
    print(f"Cell dataset Recall: {results['cell_recall']:.3f}, MCC: {results['cell_mcc']:.3f}")
    print(f"COVID dataset (Convalescent) Recall: {results['covid_conv_recall']:.3f}, MCC: {results['covid_conv_mcc']:.3f}")
    print(f"COVID dataset (Unexposed) Recall: {results['covid_unexp_recall']:.3f}, MCC: {results['covid_unexp_mcc']:.3f}")
    print(f"DeepHLApan dataset Accuracy: {results['deephlapan_acc']:.3f}, Recall: {results['deephlapan_recall']:.3f}, MCC: {results['deephlapan_mcc']:.3f}")

    return results

def run_validation():
    """
    Validate the saved model on every external test set and store the scores.

    Reads 'hla2paratopeTable_aligned.txt' (tab-separated HLA table),
    'after_pca.txt' (residue embedding matrix) and the model file
    'hla_peptide_model.h5' from the working directory. Writes the metrics
    returned by ``validate_on_test_datasets`` as a single row to
    'validation_results.csv'.
    """
    print("Loading data...")
    hla_table = pd.read_csv('hla2paratopeTable_aligned.txt', sep='\t')
    after_pca = np.loadtxt('after_pca.txt')

    # Allele -> pseudo-sequence, plus the grouped inventory used for rescue.
    hla_dic = hla_df_to_dic(hla_table)
    dic_inventory = dict_inventory(list(hla_dic))

    model = load_model_with_custom_objects("hla_peptide_model.h5")

    print("Running validation on test datasets...")
    metrics = validate_on_test_datasets(model, after_pca, hla_dic, dic_inventory)

    pd.DataFrame([metrics]).to_csv('validation_results.csv', index=False)
    print("Validation results saved to validation_results.csv")

def save_history_to_log(history, filepath="History_IDGL.log"):
    """
    Write a per-epoch training history to a plain-text log file.

    Format: a two-line header, then for each metric a '<name>:' line,
    one 'Epoch <n>: <value>' line per epoch (1-based), and a blank line.

    Args:
        history: Mapping of metric name -> sequence of per-epoch values, or
            an object exposing such a mapping as ``.history`` (e.g. the
            Keras ``History`` returned by ``model.fit``).
        filepath: Destination path; overwritten if it already exists.
    """
    # Accept a Keras-style History object as well as its plain dict.
    metrics = getattr(history, 'history', history)

    lines = ["IDGL Training History\n", "====================\n\n"]
    for key in metrics:
        lines.append(f"{key}:\n")
        lines.extend(f"Epoch {i + 1}: {value}\n" for i, value in enumerate(metrics[key]))
        lines.append("\n")

    with open(filepath, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    print(f"History saved to {filepath}")

# After training is complete and model is saved
def run_validation_after_training(model, after_pca, hla_dic, dic_inventory):
    """
    Print a banner, then evaluate *model* on the external test datasets.

    Thin wrapper around ``validate_on_test_datasets``; returns its metrics dict.
    """
    print("\n=== Running Validation on Test Datasets ===")
    return validate_on_test_datasets(model, after_pca, hla_dic, dic_inventory)



# === Deep GCN Classifier with Residual Connections ===
class DeepGCNClassifier(nn.Module):
    """
    Graph-level two-class classifier.

    Architecture: stacked GCN layers, each followed by batch norm, ReLU,
    dropout and a residual connection; then mean pooling per graph and a
    linear layer producing two logits.

    NOTE: ``DeepGCNClassifier`` is defined a second time further down in this
    file with a ``forward(x, edge_index, edge_attr, batch)`` signature. When
    the file runs top to bottom, that later definition replaces this one.
    """

    def __init__(self, input_dim, hidden_dims=(64, 64, 32), dropout=0.3):
        """
        Args:
            input_dim: Node feature dimension.
            hidden_dims: Output width of each GCN layer, in order. May be
                empty, in which case node features are pooled and classified
                directly.
            dropout: Dropout probability applied after each GCN block.
        """
        super().__init__()

        self.gcn_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.skip_proj = nn.ModuleList()
        self.dropout = dropout

        prev_dim = input_dim
        for h_dim in hidden_dims:
            self.gcn_layers.append(GCNConv(prev_dim, h_dim))
            self.norm_layers.append(BatchNorm(h_dim))
            # Residual path needs a projection only when the width changes.
            self.skip_proj.append(nn.Linear(prev_dim, h_dim) if prev_dim != h_dim else nn.Identity())
            prev_dim = h_dim

        # prev_dim is the last hidden width (or input_dim when hidden_dims is
        # empty), so an empty hidden_dims no longer raises IndexError.
        self.classifier = nn.Linear(prev_dim, 2)

    def forward(self, x, edge_index, batch):
        """
        Args:
            x: Node features [num_nodes, input_dim].
            edge_index: Graph connectivity [2, num_edges].
            batch: Graph index of each node [num_nodes].

        Returns:
            Classification logits [num_graphs, 2].
        """
        for conv, norm, skip in zip(self.gcn_layers, self.norm_layers, self.skip_proj):
            h_in = x  # kept for the residual connection
            x = conv(x, edge_index)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + skip(h_in)

        x = global_mean_pool(x, batch)
        return self.classifier(x)

# === Conversion from StellarGraph to PyTorch Geometric ===
def stellargraph_to_torch_data(graph, label):
    """
    Convert a StellarGraph to a PyTorch Geometric ``Data`` object.

    Edges are read from the graph's weighted adjacency matrix. Every nonzero
    entry (row, col) becomes one edge in ``edge_index``, with its value in
    ``edge_attr``.

    Args:
        graph: StellarGraph object.
        label: Graph label; converted with ``int(label)``.

    Returns:
        Data with ``x`` (float32 node features), ``edge_index`` (long,
        shape [2, E]), ``edge_attr`` (float32, shape [E]) and ``y`` (long
        scalar).
    """
    # Plain ndarray (not np.matrix) so that fancy indexing below is 1-D.
    adj = np.asarray(graph.to_adjacency_matrix(weighted=True).todense())
    rows, cols = adj.nonzero()

    edge_index = torch.tensor(np.vstack([rows, cols]), dtype=torch.long)
    # 1-D indexing keeps shape [E] even when E == 1; the previous
    # ``.squeeze()`` collapsed that case to a 0-d tensor.
    edge_attr = torch.tensor(adj[rows, cols], dtype=torch.float32)

    x = torch.tensor(np.array(graph.node_features()), dtype=torch.float32)
    y = torch.tensor(int(label), dtype=torch.long)

    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y)

# === Deep GCN Classifier with Residual Connections ===
class DeepGCNClassifier(nn.Module):
    """
    Graph-level two-class classifier that uses edge weights.

    Architecture: stacked GCN layers (edge weights passed to each
    convolution), each followed by batch norm, ReLU, dropout and a residual
    connection; then mean pooling per graph and a linear layer producing two
    logits.

    This redefinition replaces the earlier ``DeepGCNClassifier`` in this file.
    """

    def __init__(self, input_dim, hidden_dims=(64, 64, 32), dropout=0.3):
        """
        Args:
            input_dim: Node feature dimension.
            hidden_dims: Output width of each GCN layer, in order. May be
                empty, in which case node features are pooled and classified
                directly.
            dropout: Dropout probability applied after each GCN block.
        """
        super().__init__()
        self.gcn_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.skip_proj = nn.ModuleList()
        self.dropout = dropout

        prev_dim = input_dim
        for h_dim in hidden_dims:
            self.gcn_layers.append(GCNConv(prev_dim, h_dim))
            self.norm_layers.append(BatchNorm(h_dim))
            # Residual path needs a projection only when the width changes.
            self.skip_proj.append(nn.Linear(prev_dim, h_dim) if prev_dim != h_dim else nn.Identity())
            prev_dim = h_dim

        # prev_dim is the last hidden width (or input_dim when hidden_dims is
        # empty), so an empty hidden_dims no longer raises IndexError.
        self.classifier = nn.Linear(prev_dim, 2)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """
        Args:
            x: Node features [num_nodes, input_dim].
            edge_index: Graph connectivity [2, num_edges].
            edge_attr: Optional per-edge weights [num_edges], passed to each
                GCNConv as its edge weight.
            batch: Graph index of each node [num_nodes]; None treats all
                nodes as one graph.

        Returns:
            Classification logits [num_graphs, 2].
        """
        for conv, norm, skip in zip(self.gcn_layers, self.norm_layers, self.skip_proj):
            h_in = x  # kept for the residual connection
            x = conv(x, edge_index, edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + skip(h_in)

        x = global_mean_pool(x, batch)
        return self.classifier(x)

class GraphLearningEnhanced(nn.Module):
    """Enhanced Graph Learning Module with anchor-based similarity for peptide-HLA connections.

    For every graph in a batch, the module does the following:

    1. Build a multi-head cosine-similarity matrix (see ``compute_similarity``).
    2. Raise peptide-HLA similarities to a position-dependent floor.
    3. Apply ReLU and a learnable element-wise affinity MLP with a sigmoid output.
    4. Sparsify with a fixed threshold and a per-row top-k.
    5. Symmetrize the result.

    It returns the learned edges in the batched node space, together with
    sparsity and Laplacian-smoothness regularization terms.

    Node roles are assigned by position: the first ``num_nodes // 3`` nodes of
    each graph are treated as peptide nodes and the rest as HLA nodes.
    NOTE(review): confirm this split matches how the input graphs are built.
    """

    def __init__(self, input_dim, hidden_dim=64, k=10, lambda_sparse=0.1, lambda_smooth=0.1, num_heads=4, epsilon=0.5):
        super(GraphLearningEnhanced, self).__init__()
        self.k = k
        self.lambda_sparse = lambda_sparse
        self.lambda_smooth = lambda_smooth
        # NOTE(review): stored for API compatibility only; forward() currently
        # thresholds with a fixed value of 0.3 instead of this epsilon.
        self.epsilon = epsilon
        self.num_heads = num_heads

        # One linear projection per head used for similarity computation
        self.att_weights = nn.ModuleList([
            nn.Linear(input_dim, hidden_dim) for _ in range(num_heads)
        ])

        # Logits (softmaxed over heads) that weight the per-head similarities
        self.att_scores = nn.Parameter(torch.ones(num_heads, 1) / num_heads)

        # Learnable element-wise map from similarity to affinity in (0, 1)
        self.affinity_transform = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Flag to control similarity computation method
        self.use_anchors = True  # Set to False to use full pairwise similarity

    def compute_similarity(self, x_graph, batch=None):
        """
        Compute the combined multi-head similarity matrix of one graph.

        Graphs with fewer than 30 nodes, or any graph when ``use_anchors`` is
        False, get exact pairwise cosine similarity per head.

        Larger graphs get an anchor approximation. Each node is compared to a
        random subset of 4-16 anchor nodes, which gives a matrix S. The result
        is S @ S.T with the diagonal set to 1. Because the anchors come from
        ``torch.randperm``, this path consumes the torch RNG and is not
        deterministic across calls.

        Per-head results are combined with softmax(att_scores) weights.

        Args:
            x_graph: Node features [num_nodes, input_dim]
            batch: Unused; kept for API compatibility.

        Returns:
            sim_matrix: Similarity matrix [num_nodes, num_nodes]
        """
        num_nodes = x_graph.size(0)

        if num_nodes < 30 or not self.use_anchors:
            # Full pairwise cosine similarity per head
            similarity_matrices = []
            for head in range(self.num_heads):
                x_proj = self.att_weights[head](x_graph)
                x_proj = F.normalize(x_proj, p=2, dim=1)  # L2 normalization
                similarity_matrices.append(torch.mm(x_proj, x_proj.t()))

            sim_stack = torch.stack(similarity_matrices)  # [num_heads, n_nodes, n_nodes]
            attention_weights = F.softmax(self.att_scores, dim=0)  # [num_heads, 1]
            sim_matrix = torch.sum(sim_stack * attention_weights.view(-1, 1, 1), dim=0)

        else:
            # Anchor-based approximation: roughly sqrt(n) anchors, clamped to [4, 16]
            import math
            num_anchors = min(16, max(4, int(math.sqrt(num_nodes))))

            # Uniform random sampling of anchor nodes
            perm = torch.randperm(num_nodes, device=x_graph.device)
            anchor_idx = perm[:num_anchors]
            anchors = x_graph[anchor_idx]

            node_anchor_sim = []
            for head in range(self.num_heads):
                x_proj = F.normalize(self.att_weights[head](x_graph), p=2, dim=1)
                a_proj = F.normalize(self.att_weights[head](anchors), p=2, dim=1)
                node_anchor_sim.append(torch.mm(x_proj, a_proj.t()))  # [n_nodes, n_anchors]

            sim_stack = torch.stack(node_anchor_sim)  # [num_heads, n_nodes, n_anchors]
            attention_weights = F.softmax(self.att_scores, dim=0)  # [num_heads, 1]
            combined_sim = torch.sum(sim_stack * attention_weights.view(-1, 1, 1), dim=0)

            # Approximate the full similarity matrix from node-anchor profiles
            sim_matrix = torch.mm(combined_sim, combined_sim.t())  # [n_nodes, n_nodes]
            sim_matrix.fill_diagonal_(1.0)

        return sim_matrix

    @staticmethod
    def _boost_peptide_hla(sim_matrix, num_peptide, num_hla):
        """In place, raise each peptide-HLA similarity to a position-dependent floor.

        Entry (p, h) becomes max(sim[p, h], 0.6 + 0.3 * score), and the same
        value is mirrored to (h, p). Here
        score = 1 - 0.5 * |p / P - (h - P) / H|, with P and H the peptide and
        HLA node counts.

        The floor is computed in float64 and then rounded to the matrix dtype.
        Rounding is monotonic, so the result equals rounding the float64
        maximum. The written block is detached, so those entries carry no
        gradient.
        """
        if num_peptide == 0 or num_hla == 0:
            return
        p_rel = torch.arange(num_peptide, dtype=torch.float64) / num_peptide
        h_rel = torch.arange(num_hla, dtype=torch.float64) / num_hla
        position_score = 1.0 - 0.5 * (p_rel[:, None] - h_rel[None, :]).abs()
        floor = (0.6 + 0.3 * position_score).to(device=sim_matrix.device, dtype=sim_matrix.dtype)

        boosted = torch.maximum(sim_matrix[:num_peptide, num_peptide:], floor).detach()
        sim_matrix[:num_peptide, num_peptide:] = boosted
        sim_matrix[num_peptide:, :num_peptide] = boosted.t()  # Maintain symmetry

    def forward(self, x, batch=None):
        """
        Learn a sparse, symmetric adjacency for every graph in the batch.

        Args:
            x: Node features [num_nodes, input_dim] for the whole batch.
            batch: Graph id of each node [num_nodes]; None means one graph.

        Returns:
            edge_index: [2, E] learned edges, indexed in the batched node space.
            edge_weight: [E] weights of those edges.
            sparse_loss: Mean over graphs of the average adjacency entry.
            smooth_loss: Mean over graphs of trace(X^T L X) / num_nodes.
        """
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        edge_indices_list = []
        edge_weights_list = []
        sparse_loss_list = []
        smooth_loss_list = []

        # Process each graph in the batch separately
        for b in batch.unique():
            node_mask = batch == b
            x_graph = x[node_mask]
            # Batch-level index of each node of this graph. It is used to map
            # graph-local edge indices back into the batched node space.
            node_idx = torch.nonzero(node_mask, as_tuple=False).view(-1)

            # Positional peptide / HLA split (see class docstring)
            num_nodes = x_graph.size(0)
            num_peptide = num_nodes // 3
            num_hla = num_nodes - num_peptide

            dynamic_k = min(self.k, max(num_nodes // 4, 5))  # Smaller k for smaller graphs
            dynamic_epsilon = 0.3  # Fixed threshold on the (0, 1) affinities

            sim_matrix = self.compute_similarity(x_graph)
            self._boost_peptide_hla(sim_matrix, num_peptide, num_hla)

            # Non-negative similarities, then learnable affinity transformation
            sim_matrix = F.relu(sim_matrix)
            transformed_sim = self.affinity_transform(sim_matrix.view(-1, 1)).view(sim_matrix.shape)

            # Epsilon-neighborhood sparsification
            sparse_sim = (transformed_sim > dynamic_epsilon).float() * transformed_sim

            # Keep only each row's top-k affinities
            if dynamic_k < num_nodes:
                values, topk_idx = torch.topk(transformed_sim, k=min(dynamic_k, num_nodes), dim=1)
                topk_mask = torch.zeros_like(transformed_sim).scatter(1, topk_idx, values)
                sparse_sim = sparse_sim * topk_mask

            # Ensure adjacency matrix is symmetric
            adj = 0.5 * (sparse_sim + sparse_sim.t())

            # Guarantee a minimum number of peptide-HLA connections
            p_h_edges = int((adj[:num_peptide, num_peptide:] > 0).sum().item())
            min_expected = min(num_peptide, num_hla) // 2
            if p_h_edges < min_expected:
                n_p = min(5, num_peptide)
                n_h = min(5, num_hla)
                adj[:n_p, num_peptide:num_peptide + n_h] = 0.7
                adj[num_peptide:num_peptide + n_h, :n_p] = 0.7  # Ensure symmetry

            # Extract local edges and weights, then map into batch indices
            adj_sparse = adj.to_sparse()
            edge_index = node_idx[adj_sparse.indices()]
            edge_weight = adj_sparse.values()

            edge_indices_list.append(edge_index)
            edge_weights_list.append(edge_weight)

            # 1. Sparsity regularization: mean adjacency entry
            sparse_loss = torch.sum(adj) / (adj.size(0) * adj.size(1))
            sparse_loss_list.append(sparse_loss)

            # 2. Smoothness regularization via the graph Laplacian
            D = torch.diag(torch.sum(adj, dim=1))
            L = D - adj
            smooth_loss = torch.trace(torch.mm(torch.mm(x_graph.t(), L), x_graph)) / x_graph.size(0)
            smooth_loss_list.append(smooth_loss)

        # Combine every graph's edges (graphs with no edges contribute [2, 0])
        if edge_indices_list:
            combined_edge_index = torch.cat(edge_indices_list, dim=1)
            combined_edge_weights = torch.cat(edge_weights_list)
        else:
            combined_edge_index = torch.zeros((2, 0), dtype=torch.long, device=x.device)
            combined_edge_weights = torch.zeros(0, dtype=torch.float, device=x.device)

        # Average regularization losses over graphs
        avg_sparse_loss = sum(sparse_loss_list) / len(sparse_loss_list) if sparse_loss_list else torch.tensor(0.0).to(x.device)
        avg_smooth_loss = sum(smooth_loss_list) / len(smooth_loss_list) if smooth_loss_list else torch.tensor(0.0).to(x.device)

        return combined_edge_index, combined_edge_weights, avg_sparse_loss, avg_smooth_loss

class EnhancedIDGLFramework(nn.Module):
    """Enhanced IDGL Framework with learnable combination weights and GCN refinement.

    Iteration 1 learns a graph from the raw features and updates node
    embeddings with ``update_gcn``. Each later iteration learns a new graph
    from the current embeddings and mixes it with the input graph and the
    first learned graph.

    The mixing weights are:

    - input graph: sigmoid(lambda_init)
    - first learned graph: (1 - lambda) * sigmoid(eta)
    - latest learned graph: (1 - lambda) * (1 - eta)

    Iteration stops early, from the third iteration on, when the sparse plus
    smooth loss fails to fall below 99% of the previous value. The final mixed
    graph is classified from the ORIGINAL features by ``DeepGCNClassifier``.
    """

    def __init__(self, input_dim, hidden_dims=(64, 64, 32), dropout=0.3, k=10,
                 lambda_sparse=0.1, lambda_smooth=0.1, num_heads=4, epsilon=0.5,
                 max_iterations=5, tol=1e-4):
        super(EnhancedIDGLFramework, self).__init__()
        # Tuple copy avoids the shared-mutable-default pitfall.
        hidden_dims = tuple(hidden_dims)

        # Graph learning module
        self.graph_learner = GraphLearningEnhanced(
            input_dim=input_dim,
            hidden_dim=hidden_dims[0],
            k=k,
            lambda_sparse=lambda_sparse,
            lambda_smooth=lambda_smooth,
            num_heads=num_heads,
            epsilon=epsilon
        )

        # GCN for node embedding updates (keeps the feature width)
        self.update_gcn = GCNConv(input_dim, input_dim)

        # GCN classifier
        self.gcn = DeepGCNClassifier(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            dropout=dropout
        )

        # Learnable mixing parameters (squashed with sigmoid when used)
        self.lambda_init = nn.Parameter(torch.tensor(0.8))
        self.eta = nn.Parameter(torch.tensor(0.5))

        # IDGL parameters
        self.max_iterations = max_iterations
        self.tol = tol  # NOTE(review): not read by forward(); convergence uses a fixed 0.99 ratio.
        self.input_dim = input_dim

    def _mix_graphs(self, edge_index, edge_attr, edge_index_1, edge_attr_1, edge_index_t, edge_attr_t):
        """Concatenate the three edge sets with their learnable mixing weights."""
        lambda_constrained = torch.sigmoid(self.lambda_init)  # in (0, 1)
        eta_constrained = torch.sigmoid(self.eta)  # in (0, 1)

        # Duplicate (src, dst) pairs are kept; GCNConv sums their weights.
        indices = torch.cat([edge_index, edge_index_1, edge_index_t], dim=1)
        weights = torch.cat([
            edge_attr * lambda_constrained,
            edge_attr_1 * (1 - lambda_constrained) * eta_constrained,
            edge_attr_t * (1 - lambda_constrained) * (1 - eta_constrained),
        ])
        return indices, weights

    def forward(self, x, edge_index, edge_attr, batch):
        """
        Run iterative graph learning, then classify.

        Args:
            x: Node features of the batch.
            edge_index: Input graph connectivity [2, E].
            edge_attr: Input edge weights [E]; None means all-ones weights.
            batch: Graph id of each node.

        Returns:
            out: Classifier logits.
            final_indices: Edge index of the graph used for classification.
            avg_reg_loss: Weighted sparse + smooth loss averaged over iterations.
        """
        if edge_attr is None:
            edge_attr = torch.ones(edge_index.size(1), dtype=x.dtype, device=x.device)

        h = x

        # First iteration: learn a graph from raw features
        edge_index_1, edge_attr_1, sparse_loss_1, smooth_loss_1 = self.graph_learner(h, batch)
        if edge_index_1.size(1) > 0:
            h = F.relu(self.update_gcn(h, edge_index_1, edge_attr_1))

        lambda_sparse = self.graph_learner.lambda_sparse
        lambda_smooth = self.graph_learner.lambda_smooth
        prev_loss = sparse_loss_1 + smooth_loss_1
        total_reg_loss = lambda_sparse * sparse_loss_1 + lambda_smooth * smooth_loss_1
        iterations = 1

        # Latest learned graph. Defaulting to the first one keeps the final
        # combination well-defined when max_iterations < 2.
        edge_index_t, edge_attr_t = edge_index_1, edge_attr_1

        # Iterative refinement
        for t in range(2, self.max_iterations + 1):
            edge_index_t, edge_attr_t, sparse_loss_t, smooth_loss_t = self.graph_learner(h, batch)

            # Converged once the loss stops decreasing meaningfully
            if t > 2 and (sparse_loss_t + smooth_loss_t) >= prev_loss * 0.99:
                break

            if edge_index.size(1) > 0 and edge_index_t.size(1) > 0:
                mixed_index, mixed_weight = self._mix_graphs(
                    edge_index, edge_attr, edge_index_1, edge_attr_1, edge_index_t, edge_attr_t)
                h = F.relu(self.update_gcn(h, mixed_index, mixed_weight))
            elif edge_index_t.size(1) > 0:
                # Input graph is empty: use the current learned graph only
                h = F.relu(self.update_gcn(h, edge_index_t, edge_attr_t))

            prev_loss = sparse_loss_t + smooth_loss_t
            iterations += 1
            total_reg_loss = total_reg_loss + lambda_sparse * sparse_loss_t + lambda_smooth * smooth_loss_t

        avg_reg_loss = total_reg_loss / iterations

        # Graph used for the final classification
        if edge_index.size(1) > 0 and edge_index_t.size(1) > 0:
            final_indices, final_weights = self._mix_graphs(
                edge_index, edge_attr, edge_index_1, edge_attr_1, edge_index_t, edge_attr_t)
        elif edge_index_t.size(1) > 0:
            final_indices, final_weights = edge_index_t, edge_attr_t
        else:
            final_indices, final_weights = edge_index, edge_attr

        # Classify from the original features on the final graph
        out = self.gcn(x, final_indices, final_weights, batch)

        return out, final_indices, avg_reg_loss


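# NOTE: This redefinition shadows the earlier DeepGCNClassifier definition above;
# because it is defined later, this is the class the name resolves to at runtime.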
class DeepGCNClassifier(nn.Module):
    """Graph-level two-class classifier built from residual GCN layers.

    Each layer applies GCNConv -> BatchNorm -> ReLU -> dropout and then adds a
    residual connection from the layer input. The residual uses a linear
    projection when the width changes and identity otherwise. Node embeddings
    are mean-pooled per graph and mapped to two logits.

    Args:
        input_dim: Number of input node features.
        hidden_dims: Output width of each GCN layer, in order (must be non-empty,
            since the last entry sizes the classifier head).
        dropout: Dropout probability applied after each layer while training.
    """

    def __init__(self, input_dim, hidden_dims=(64, 64, 32), dropout=0.3):
        super(DeepGCNClassifier, self).__init__()
        # Copy into a tuple: avoids the shared-mutable-default pitfall and
        # decouples the module from later mutation of a caller's list.
        hidden_dims = tuple(hidden_dims)
        self.gcn_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.skip_proj = nn.ModuleList()
        self.dropout = dropout

        # Build network with residual connections
        prev_dim = input_dim
        for h_dim in hidden_dims:
            self.gcn_layers.append(GCNConv(prev_dim, h_dim))
            self.norm_layers.append(BatchNorm(h_dim))
            if prev_dim != h_dim:
                # Project the residual so it matches the new width.
                self.skip_proj.append(nn.Linear(prev_dim, h_dim))
            else:
                self.skip_proj.append(nn.Identity())
            prev_dim = h_dim

        # Final classification layer (two logits)
        self.classifier = nn.Linear(hidden_dims[-1], 2)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return per-graph logits from ``self.classifier`` (2 outputs each).

        Args:
            x: Node feature matrix.
            edge_index: Graph connectivity in COO format.
            edge_attr: Edge weights, passed to GCNConv as its edge weight.
            batch: Graph id of each node, used for mean pooling.
        """
        # Pass through GCN layers with residual connections
        for conv, norm, skip in zip(self.gcn_layers, self.norm_layers, self.skip_proj):
            h_in = x  # Store for residual connection
            x = conv(x, edge_index, edge_attr)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + skip(h_in)  # Residual connection

        # Global pooling: one embedding per graph
        x = global_mean_pool(x, batch)

        # Classification
        return self.classifier(x)

def _run_idgl_epoch(model, loader, device, optimizer=None, reg_weight=0.1):
    """Run one pass over ``loader``; train if ``optimizer`` is given, else evaluate.

    The loss per batch is cross-entropy plus ``reg_weight`` times the model's
    regularization loss. Returns (mean loss per graph, accuracy).
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for data in loader:
            data = data.to(device)
            if training:
                optimizer.zero_grad()

            logits, _, reg_loss = model(data.x, data.edge_index, data.edge_attr, data.batch)
            loss = F.cross_entropy(logits, data.y) + reg_weight * reg_loss

            if training:
                loss.backward()
                optimizer.step()

            total_loss += loss.item() * data.num_graphs
            correct += logits.argmax(dim=1).eq(data.y).sum().item()
            total += data.num_graphs

    return total_loss / total, correct / total


def run_enhanced_idgl_training():
    """Run the enhanced IDGL training pipeline.

    The pipeline does the following:

    - Reads 'hla2paratopeTable_aligned.txt', 'after_pca.txt' and
      'remove0123_sample100_test.csv' from the working directory.
    - Builds graphs and converts them to PyTorch Geometric format.
    - Trains an EnhancedIDGLFramework for 100 epochs on a stratified 80/20
      split.
    - Checkpoints the best validation accuracy to
      'best_enhanced_idgl_model.pt' and writes the history to
      'History_IDGL.log'.

    Returns:
        (model, history) on success, where ``model`` holds the best checkpoint.
        None if any exception occurs (the traceback is printed).
    """
    checkpoint_path = 'best_enhanced_idgl_model.pt'

    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Load data
    print("Loading data...")
    hla_df = pd.read_csv('hla2paratopeTable_aligned.txt', sep='\t')
    after_pca = np.loadtxt('after_pca.txt')
    dataset = pd.read_csv('remove0123_sample100_test.csv')

    # Ensure labels are numeric
    if dataset['immunogenicity'].dtype == 'object':
        if isinstance(dataset['immunogenicity'].iloc[0], str):
            label_map = {'Positive': 1, 'Negative': 0}
            dataset['immunogenicity'] = dataset['immunogenicity'].map(
                lambda x: label_map.get(x, 1 if x != 'Negative' else 0)
            )
        dataset['immunogenicity'] = dataset['immunogenicity'].astype(int)

    # Process HLA data
    hla_dic = hla_df_to_dic(hla_df)
    inventory = list(hla_dic.keys())
    dic_inventory = dict_inventory(inventory)

    print("Creating StellarGraph graphs...")
    try:
        stellar_graphs, stellar_labels = Graph_Constructor.entrance(
            dataset, after_pca, hla_dic, dic_inventory, graph_type='intra_and_inter'
        )

        # Convert to PyTorch Geometric Data objects, skipping failures
        print("Converting to PyTorch Geometric format...")
        torch_graphs = []
        for i, (g, lbl) in enumerate(zip(stellar_graphs, stellar_labels)):
            try:
                torch_graphs.append(stellargraph_to_torch_data(g, lbl))
            except Exception as e:
                print(f"Error converting graph {i}: {e}")

        print(f"Successfully converted {len(torch_graphs)} graphs to PyTorch Geometric format")
        if not torch_graphs:
            raise ValueError("No graphs could be converted; cannot train.")

        # Feature dimension from the first graph
        input_dim = torch_graphs[0].x.size(1)

        print("Creating enhanced IDGL model...")
        model = EnhancedIDGLFramework(
            input_dim=input_dim,
            hidden_dims=[64, 32],
            dropout=0.3,
            k=10,
            lambda_sparse=0.1,
            lambda_smooth=0.1,
            num_heads=4,
            epsilon=0.5,
            max_iterations=3
        ).to(device)

        # Stratified train/test split
        labels_array = np.array([g.y.item() for g in torch_graphs])
        train_indices, test_indices = train_test_split(
            range(len(torch_graphs)),
            test_size=0.2,
            stratify=labels_array,
            random_state=42
        )
        train_dataset = [torch_graphs[i] for i in train_indices]
        test_dataset = [torch_graphs[i] for i in test_indices]

        # PyG's DataLoader collates Data objects into a Batch itself (it drops
        # any user-supplied collate_fn), so none is passed here.
        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

        # Halve the LR when validation accuracy plateaus. LR changes are
        # reported manually because the scheduler's `verbose` flag is deprecated.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=5
        )

        # Start below any achievable accuracy so the first epoch always writes
        # a checkpoint; otherwise the final load fails if val acc stays at 0.
        best_val_acc = -1.0
        best_epoch = 0

        history = {
            'loss': [],
            'acc': [],
            'val_loss': [],
            'val_acc': [],
            'lambda': [],
            'eta': []
        }

        print("Starting training...")
        epochs = 100
        for epoch in range(epochs):
            avg_train_loss, train_acc = _run_idgl_epoch(model, train_loader, device, optimizer)
            avg_val_loss, val_acc = _run_idgl_epoch(model, test_loader, device)

            lr_before = optimizer.param_groups[0]['lr']
            scheduler.step(val_acc)
            lr_after = optimizer.param_groups[0]['lr']
            if lr_after < lr_before:
                print(f"Reducing learning rate to {lr_after:.4e}")

            # Current values of the learnable mixing parameters
            lambda_value = torch.sigmoid(model.lambda_init).item()
            eta_value = torch.sigmoid(model.eta).item()

            history['loss'].append(avg_train_loss)
            history['acc'].append(train_acc)
            history['val_loss'].append(avg_val_loss)
            history['val_acc'].append(val_acc)
            history['lambda'].append(lambda_value)
            history['eta'].append(eta_value)

            print(f"Epoch {epoch+1}/{epochs}, "
                  f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}, "
                  f"λ: {lambda_value:.3f}, η: {eta_value:.3f}")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_epoch = epoch
                torch.save(model.state_dict(), checkpoint_path)
                print(f"New best model saved with validation accuracy: {best_val_acc:.4f}")

        # Restore the best checkpoint
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print(f"Loaded best model from epoch {best_epoch+1} with validation accuracy: {best_val_acc:.4f}")

        save_history_to_log(history, "History_IDGL.log")

        return model, history

    except Exception as e:
        print(f"Error during training: {e}")
        traceback.print_exc()
        return None

def create_paper_quality_visualization(model, data_list, dataset_df, device, save_path="idgl_paper_figure"):
    """
    Create publication-quality visualization of initial and learned graphs
    for HLA-peptide interactions, similar to Figure 10 in IDGL paper.

    Up to four examples are drawn on a 2x2 grid. The code picks the first two
    positive and the first two negative graphs, or the first four graphs if
    that gives fewer than four. Each panel shows the input graph as gray
    dashed edges and the graph produced by ``model.graph_learner`` as colored
    edges. The figure is saved as ``{save_path}.png`` and ``{save_path}.pdf``.

    Args:
        model: Trained IDGL model
        data_list: List of PyTorch Geometric Data objects
        dataset_df: Original dataframe with 'peptide' and 'HLA' columns;
            assumed row-aligned with ``data_list`` (TODO confirm).
        device: Device to run model on
        save_path: Path prefix for the saved figures

    Returns:
        List of "peptide / HLA (label)" strings for the plotted examples.
    """
    model.eval()

    # Select representative examples (2 positive, 2 negative)
    pos_indices = [i for i, d in enumerate(data_list) if d.y.item() == 1][:2]
    neg_indices = [i for i, d in enumerate(data_list) if d.y.item() == 0][:2]

    # Fall back to the first graphs when a balanced set is not available
    selected_indices = pos_indices + neg_indices
    if len(selected_indices) < 4:
        selected_indices = list(range(min(4, len(data_list))))

    fig = plt.figure(figsize=(10, 8))
    gs = gridspec.GridSpec(2, 2)

    # Colormap for learned edges
    cmap_learned = LinearSegmentedColormap.from_list("learned",
                                                    ["#7570b3", "#e7298a", "#66a61e", "#d95f02", "#1b9e77"])

    for i, idx in enumerate(selected_indices[:4]):
        data = data_list[idx].to(device)
        num_nodes = data.x.size(0)

        peptide = dataset_df['peptide'].iloc[idx]
        hla = dataset_df['HLA'].iloc[idx]

        batch = torch.zeros(num_nodes, dtype=torch.long, device=device)
        ax = fig.add_subplot(gs[i // 2, i % 2])

        # Initial graph as networkx. Missing edge weights default to 1.0, and
        # reshape(-1) tolerates a 0-d edge_attr (single-edge graphs).
        initial_graph = nx.Graph()
        initial_graph.add_nodes_from(range(num_nodes))
        edge_attr = getattr(data, 'edge_attr', None)
        initial_weights = edge_attr.reshape(-1).tolist() if edge_attr is not None else []
        for j in range(data.edge_index.size(1)):
            src, dst = data.edge_index[0, j].item(), data.edge_index[1, j].item()
            weight = initial_weights[j] if j < len(initial_weights) else 1.0
            initial_graph.add_edge(src, dst, weight=weight)

        # Learned graph from the model's graph learner
        with torch.no_grad():
            edge_index_learned, edge_weights_learned, _, _ = model.graph_learner(data.x, batch)

        learned_graph = nx.Graph()
        learned_graph.add_nodes_from(range(num_nodes))
        for j in range(edge_index_learned.size(1)):
            src, dst = edge_index_learned[0, j].item(), edge_index_learned[1, j].item()
            learned_graph.add_edge(src, dst, weight=edge_weights_learned[j].item())

        # Positional split: first third of the nodes are peptide, rest HLA
        split = num_nodes // 3
        peptide_nodes = list(range(split))
        hla_nodes = list(range(split, num_nodes))
        peptide_set = set(peptide_nodes)
        hla_set = set(hla_nodes)

        # Shared layout computed on the union of both edge sets
        combined_graph = nx.Graph()
        combined_graph.add_nodes_from(initial_graph.nodes())
        combined_graph.add_edges_from(initial_graph.edges())
        combined_graph.add_edges_from(learned_graph.edges())
        layout = nx.drawing.layout.spring_layout(combined_graph, seed=42)

        # Initial graph edges: gray, dashed
        nx.draw_networkx_edges(initial_graph, layout, alpha=0.2, width=0.5,
                              edge_color='gray', style='dashed')

        # Learned graph edges: colormap colors, width proportional to weight
        edge_weights = [learned_graph[u][v].get('weight', 1.0) for u, v in learned_graph.edges()]
        max_weight = max(edge_weights) if edge_weights else 1.0
        num_learned = len(learned_graph.edges())
        edge_colors = [cmap_learned(j / num_learned) for j in range(num_learned)]

        nx.draw_networkx_edges(learned_graph, layout,
                              width=[w / max_weight * 2.5 for w in edge_weights],
                              edge_color=edge_colors,
                              alpha=0.7)

        # Peptide nodes: red circles
        nx.draw_networkx_nodes(combined_graph, layout,
                              nodelist=peptide_nodes,
                              node_color='#d73027',  # Red
                              node_size=60,
                              alpha=0.9)

        # HLA nodes: blue squares
        nx.draw_networkx_nodes(combined_graph, layout,
                              nodelist=hla_nodes,
                              node_color='#4575b4',  # Blue
                              node_size=60,
                              alpha=0.9,
                              node_shape='s')  # Square

        ax.set_axis_off()
        label = "Positive" if data.y.item() == 1 else "Negative"
        ax.set_title(f"{peptide} / {hla} ({label})", fontsize=9)

        # Edge statistics for the caption
        def _count_peptide_hla(graph):
            """Count edges with one endpoint in each node group."""
            return sum(1 for u, v in graph.edges()
                       if (u in peptide_set and v in hla_set) or
                          (u in hla_set and v in peptide_set))

        initial_edges = initial_graph.number_of_edges()
        learned_edges = learned_graph.number_of_edges()
        p_h_initial = _count_peptide_hla(initial_graph)
        p_h_learned = _count_peptide_hla(learned_graph)

        caption = f"Initial edges: {initial_edges}→ learned edges: {learned_edges} , "
        caption += f"p-h initial: {p_h_initial}→ p-h learned: {p_h_learned}"
        ax.text(0.5, -0.1, caption, transform=ax.transAxes, horizontalalignment='center',
               fontsize=8, style='italic')

    # Invisible full-figure axes that only hosts the legend
    ax_legend = fig.add_subplot(gs[:, :])
    ax_legend.axis('off')

    ax_legend.scatter([], [], c='#d73027', s=60, label='Peptide Node')
    ax_legend.scatter([], [], c='#4575b4', s=60, marker='s', label='HLA Node')
    ax_legend.plot([], [], c='gray', linestyle='dashed', linewidth=0.5, alpha=0.2, label='Initial Edge')
    ax_legend.plot([], [], c=cmap_learned(0.5), linewidth=2, alpha=0.7, label='Learned Edge')

    ax_legend.legend(loc='lower center', bbox_to_anchor=(0.5, 0.02), ncol=4, fontsize=10)

    plt.suptitle("Graph Structure Learning for HLA-Peptide Interaction", fontsize=14)

    fig.text(0.5, 0.01,
            "Figure: Visualization of initial (gray dashed) and learned (colored) graph structures. "
            "Connections between peptide (red) and HLA (blue) nodes are refined.",
            ha='center', fontsize=9)

    plt.tight_layout(rect=[0, 0.05, 1, 0.95])

    plt.savefig(f"{save_path}.png", dpi=300, bbox_inches='tight')
    plt.savefig(f"{save_path}.pdf", format='pdf', bbox_inches='tight')
    plt.close()

    print(f"Publication-quality visualization saved to {save_path}.png and {save_path}.pdf")

    # Summary strings for the plotted examples
    stats = []
    for idx in selected_indices[:4]:
        data = data_list[idx]
        peptide = dataset_df['peptide'].iloc[idx]
        hla = dataset_df['HLA'].iloc[idx]
        label = "Positive" if data.y.item() == 1 else "Negative"
        stats.append(f"{peptide} / {hla} ({label})")

    return stats


def _combined_vis_safe_max(weights):
    """Return max(weights) for edge-width normalisation; 1.0 if empty or non-positive."""
    max_weight = max(weights, default=1.0)
    # Guard against division by zero (all-zero weights) and negative widths.
    return max_weight if max_weight > 0 else 1.0


def _combined_vis_initial_graph(data):
    """Build an undirected networkx graph from ``data.edge_index``, dropping self-loops.

    Edge weights come from ``data.edge_attr[j].item()`` when an ``edge_attr`` entry
    exists for edge ``j``; otherwise 1.0.
    """
    graph = nx.Graph()
    graph.add_nodes_from(range(data.x.size(0)))
    edge_attr = getattr(data, 'edge_attr', None)
    for j, (src, dst) in enumerate(data.edge_index.t().tolist()):
        if src == dst:
            continue
        if edge_attr is not None and len(edge_attr) > j:
            weight = edge_attr[j].item()
        else:
            weight = 1.0
        graph.add_edge(src, dst, weight=weight)
    return graph


def _combined_vis_learned_adjacency(model, x, peptide_nodes, hla_nodes, device, force_peptide_hla_edges):
    """Recompute a symmetric, sparsified similarity adjacency from the model's graph learner.

    Each head's projection ``model.graph_learner.att_weights[head](x)`` is
    L2-normalised and its Gram matrix is averaged over heads. Entries <= 0.3
    are zeroed, the diagonal is cleared, each row keeps its 10 largest entries,
    and the result is symmetrised as ``0.5 * (A + A.T)``.

    NOTE(review): the 0.3 threshold and k=10 are hard-coded here and may differ
    from the model's own settings. When ``force_peptide_hla_edges`` is True,
    every peptide-HLA similarity is overwritten with a random value in
    [0.8, 1.0), so those edges are not model output.
    """
    learner = model.graph_learner
    num_nodes = x.size(0)
    with torch.no_grad():
        sim_matrix = torch.zeros((num_nodes, num_nodes), device=device)
        for head in range(learner.num_heads):
            x_proj = F.normalize(learner.att_weights[head](x), p=2, dim=1)  # L2 normalization
            sim_matrix += torch.mm(x_proj, x_proj.t()) / learner.num_heads

        if force_peptide_hla_edges:
            for p in peptide_nodes:
                for h in hla_nodes:
                    sim_matrix[p, h] = 0.8 + 0.2 * torch.rand(1).item()
                    sim_matrix[h, p] = sim_matrix[p, h]  # keep symmetry

        mask = (sim_matrix > 0.3).float() * sim_matrix
        mask.fill_diagonal_(0)  # no self-loops

        values, indices = torch.topk(mask, k=min(10, num_nodes), dim=1)
        # topk indices are unique per row, so scatter_ matches per-row assignment.
        topk_mask = torch.zeros_like(mask).scatter_(1, indices, values)
        return 0.5 * (topk_mask + topk_mask.t())


def _combined_vis_graph_from_adjacency(adj):
    """Convert a dense weighted adjacency tensor into an undirected networkx graph (no self-loops)."""
    graph = nx.Graph()
    graph.add_nodes_from(range(adj.size(0)))
    rows, cols = adj.nonzero(as_tuple=True)
    weights = adj[rows, cols]
    for src, dst, weight in zip(rows.tolist(), cols.tolist(), weights.tolist()):
        if src != dst:
            graph.add_edge(src, dst, weight=weight)
    return graph


def _combined_vis_layout(peptide_nodes, hla_nodes, radius=2, offset=3):
    """Place peptide nodes on a circle centred at (-offset, 0) and HLA nodes on one at (+offset, 0)."""
    pos = {}
    for nodes, x_shift in ((peptide_nodes, -offset), (hla_nodes, offset)):
        angles = np.linspace(0, 2 * np.pi, len(nodes), endpoint=False)
        for node, theta in zip(nodes, angles):
            pos[node] = np.array([radius * np.cos(theta) + x_shift, radius * np.sin(theta)])
    return pos


def _combined_vis_draw_nodes(ax, graph, pos, peptide_nodes, hla_nodes, peptide_size, hla_size):
    """Draw peptide nodes (light red) and HLA nodes (light blue) on ``ax``."""
    nx.draw_networkx_nodes(graph, pos, nodelist=peptide_nodes, node_color='#FF7F7F',
                           node_size=peptide_size, alpha=0.9, ax=ax)
    nx.draw_networkx_nodes(graph, pos, nodelist=hla_nodes, node_color='#7FB3D5',
                           node_size=hla_size, alpha=0.9, ax=ax)


def _combined_vis_cross_edges(edges, peptide_set, hla_set):
    """Return the edges that connect a peptide node to an HLA node (either direction)."""
    return [(u, v) for u, v in edges
            if (u in peptide_set and v in hla_set) or (v in peptide_set and u in hla_set)]


def create_combined_graph_visualization(model, data_list, dataset_df, device, save_path="idgl_combined_vis",
                                        *, force_peptide_hla_edges=True):
    """
    Create one figure comparing the initial and learned graphs of a positive and a negative example.

    Layout: 2 rows (positive example, negative example) x 3 columns (initial
    graph, learned graph, strongest peptide-HLA edges with weight labels).
    The figure is saved as ``{save_path}.png`` (300 dpi) and ``{save_path}.pdf``.

    NOTE(review): with ``force_peptide_hla_edges=True`` (the default, kept for
    backward compatibility), every peptide-HLA similarity is replaced by a
    random value in [0.8, 1.0). If no peptide-HLA edge remains, a 3x3 block of
    artificial edges (weight 0.8) is also drawn in the third column. Those edges
    are NOT produced by the model. Pass ``False`` to plot model-derived edges only.

    Args:
        model: Trained IDGL model exposing ``graph_learner.num_heads`` and
            ``graph_learner.att_weights`` (indexable per head, callable on node features).
        data_list: List of PyTorch Geometric Data objects with ``x``, ``edge_index``, ``y``
            and optionally ``edge_attr``.
        dataset_df: DataFrame with 'peptide' and 'HLA' columns; rows are looked up with
            ``iloc`` at the same positions as ``data_list``.
        device: Device to run the model on.
        save_path: Output path without extension.
        force_peptide_hla_edges: Keyword-only; see the note above.
    """
    model.eval()

    # One positive and one negative example. Fall back to any other index, then 0.
    pos_idx = next((i for i, d in enumerate(data_list) if d.y.item() == 1), 0)
    neg_idx = next((i for i, d in enumerate(data_list) if d.y.item() == 0),
                   next((i for i, d in enumerate(data_list) if i != pos_idx), 0))

    fig = plt.figure(figsize=(15, 10))
    gs = gridspec.GridSpec(2, 3, height_ratios=[1, 1], width_ratios=[1, 1, 1])

    for row, (idx, label) in enumerate(zip([pos_idx, neg_idx], ["Positive", "Negative"])):
        data = data_list[idx].to(device)
        peptide = dataset_df['peptide'].iloc[idx]
        hla = dataset_df['HLA'].iloc[idx]
        num_nodes = data.x.size(0)

        initial_graph = _combined_vis_initial_graph(data)

        # NOTE(review): assumes the first third of the nodes are peptide nodes and the
        # rest are HLA nodes; confirm against the graph construction code.
        peptide_nodes = list(range(num_nodes // 3))
        hla_nodes = list(range(num_nodes // 3, num_nodes))
        peptide_set, hla_set = set(peptide_nodes), set(hla_nodes)

        print(f"Peptide sequence: {peptide}, length: {len(peptide)}")
        print(f"HLA sequence: {hla}, length: {len(hla)}")
        print(f"Total nodes: {num_nodes}")
        print(f"Peptide nodes: {len(peptide_nodes)}, HLA nodes: {len(hla_nodes)}")

        adj = _combined_vis_learned_adjacency(model, data.x, peptide_nodes, hla_nodes, device,
                                              force_peptide_hla_edges)
        learned_graph = _combined_vis_graph_from_adjacency(adj)
        pos = _combined_vis_layout(peptide_nodes, hla_nodes)

        # Column 1: initial graph, all edges, width proportional to weight.
        ax1 = fig.add_subplot(gs[row, 0])
        _combined_vis_draw_nodes(ax1, initial_graph, pos, peptide_nodes, hla_nodes, 80, 80)
        initial_edge_list = list(initial_graph.edges())
        initial_weights = [initial_graph[u][v].get('weight', 1.0) for u, v in initial_edge_list]
        max_weight = _combined_vis_safe_max(initial_weights)
        nx.draw_networkx_edges(initial_graph, pos,
                               edgelist=initial_edge_list,
                               width=[w / max_weight * 1.5 for w in initial_weights],
                               alpha=0.4,
                               edge_color='gray',
                               ax=ax1)
        ax1.set_title(f"{peptide} / {hla} ({label})\nInitial Graph", fontsize=11)
        ax1.set_axis_off()

        # Column 2: learned graph, with peptide-HLA edges highlighted on top.
        ax2 = fig.add_subplot(gs[row, 1])
        _combined_vis_draw_nodes(ax2, learned_graph, pos, peptide_nodes, hla_nodes, 80, 80)
        learned_edge_list = list(learned_graph.edges())
        learned_weights = [learned_graph[u][v].get('weight', 1.0) for u, v in learned_edge_list]
        max_weight = _combined_vis_safe_max(learned_weights)
        nx.draw_networkx_edges(learned_graph, pos,
                               edgelist=learned_edge_list,
                               width=[w / max_weight * 1.5 for w in learned_weights],
                               alpha=0.4,
                               edge_color='#FFA500',  # Orange edges
                               ax=ax2)

        p_h_edges = _combined_vis_cross_edges(learned_edge_list, peptide_set, hla_set)
        print(f"Number of all edges in learned graph: {len(learned_edge_list)}")
        print(f"Number of peptide-HLA edges in learned graph: {len(p_h_edges)}")

        if p_h_edges:
            nx.draw_networkx_edges(learned_graph, pos,
                                   edgelist=p_h_edges,
                                   width=[learned_graph[u][v].get('weight', 1.0) / max_weight * 2.5
                                          for u, v in p_h_edges],
                                   alpha=0.7,
                                   edge_color='#FF4500',  # Red-orange edges
                                   ax=ax2)
        ax2.set_title(f"{peptide} / {hla} ({label})\nLearned Graph", fontsize=11)
        ax2.set_axis_off()

        # Column 3: strongest peptide-HLA interactions with weight labels.
        ax3 = fig.add_subplot(gs[row, 2])
        _combined_vis_draw_nodes(ax3, learned_graph, pos, peptide_nodes, hla_nodes, 150, 100)
        node_labels = {node: f"P{node + 1}" for node in peptide_nodes[:5]}
        node_labels.update({node: f"H{node + 1 - len(peptide_nodes)}" for node in hla_nodes[:5]})
        nx.draw_networkx_labels(learned_graph, pos,
                                labels=node_labels,
                                font_size=8,
                                font_color='black',
                                ax=ax3)

        if not p_h_edges and force_peptide_hla_edges:
            # Fabricated connections (not model output) so the panel is never empty.
            artificial_edges = [(p, h) for p in peptide_nodes[:3] for h in hla_nodes[:3]]
            learned_graph.add_edges_from(artificial_edges, weight=0.8)
            nx.draw_networkx_edges(learned_graph, pos,
                                   edgelist=artificial_edges,
                                   width=2.0,
                                   alpha=0.7,
                                   edge_color='#FF4500',
                                   ax=ax3)
            p_h_edges = artificial_edges
            print(f"Added {len(artificial_edges)} artificial peptide-HLA edges for visualization")

        if p_h_edges:
            # Stable sort: ties keep their edge order.
            top_edges = sorted(((u, v, learned_graph[u][v].get('weight', 1.0)) for u, v in p_h_edges),
                               key=lambda edge: edge[2],
                               reverse=True)[:10]
            top_weights = [w for _, _, w in top_edges]
            max_weight = _combined_vis_safe_max(top_weights)
            nx.draw_networkx_edges(learned_graph, pos,
                                   edgelist=[(u, v) for u, v, _ in top_edges],
                                   width=[w / max_weight * 3 for w in top_weights],
                                   alpha=0.7,
                                   edge_color='#FF4500',
                                   ax=ax3)
            edge_labels = {(u, v): f"{w:.2f}" for u, v, w in top_edges[:5]}
            nx.draw_networkx_edge_labels(learned_graph, pos,
                                         edge_labels=edge_labels,
                                         font_size=7,
                                         ax=ax3)
        ax3.set_title(f"{peptide} / {hla} ({label})\nKey Interactions", fontsize=11)
        ax3.set_axis_off()

    # Shared legend in a thin axes strip at the bottom of the figure.
    legend_ax = fig.add_axes([0.1, 0.01, 0.8, 0.03])
    legend_ax.axis('off')
    legend_ax.scatter([], [], c='#FF7F7F', s=80, label='Peptide')
    legend_ax.scatter([], [], c='#7FB3D5', s=80, label='HLA')
    legend_ax.plot([], [], c='gray', lw=1.5, alpha=0.4, label='Initial Edge')
    legend_ax.plot([], [], c='#FFA500', lw=2, alpha=0.6, label='Learned Edge')
    legend_ax.plot([], [], c='#FF4500', lw=3, alpha=0.7, label='Strong Interaction')
    legend_ax.legend(loc='center', ncol=5)

    fig.suptitle('Graph Structure Learning for HLA-Peptide Interaction', fontsize=16)
    fig.tight_layout(rect=[0, 0.05, 1, 0.95])

    fig.savefig(f"{save_path}.png", dpi=300, bbox_inches='tight')
    fig.savefig(f"{save_path}.pdf", format='pdf', bbox_inches='tight')
    plt.close(fig)

    print(f"Combined visualization saved to {save_path}.png and {save_path}.pdf")

def load_model_with_custom_objects(filepath):
    """Deserialize a saved Keras model while StellarGraph's custom layers are registered.

    Args:
        filepath: Location passed straight to ``tf.keras.models.load_model``.

    Returns:
        The loaded Keras model.
    """
    with tf.keras.utils.custom_object_scope(sg.custom_keras_layers):
        loaded = tf.keras.models.load_model(filepath)
    print(f"Model loaded from {filepath} with custom objects")
    return loaded

def run_paper_visualization_pipeline(model_path='best_enhanced_idgl_model.pt',
                                     hla_table_path='hla2paratopeTable_aligned.txt',
                                     pca_path='after_pca.txt',
                                     dataset_path='remove0123_sample100_Test.csv'):
    """Run the paper-quality visualization pipeline on a trained IDGL checkpoint.

    Seeds torch/NumPy and loads the HLA table, PCA features and dataset. It then
    builds StellarGraph graphs, converts them to PyTorch Geometric Data objects,
    restores an ``EnhancedIDGLFramework`` from ``model_path``, and writes the
    ``idgl_paper_figure`` and ``idgl_combined_figure`` outputs.

    Args:
        model_path: State-dict checkpoint; if missing, a message is printed and the
            function returns without doing anything else.
        hla_table_path: Tab-separated HLA table read with ``pd.read_csv(sep='\\t')``.
        pca_path: Text file read with ``np.loadtxt``.
        dataset_path: CSV containing an 'immunogenicity' label column.

    Returns:
        None. Errors after data loading are printed with a traceback, not raised.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    if not os.path.exists(model_path):
        print(f"Model file {model_path} not found. Please train the model first.")
        return

    print("Loading data...")
    hla_df = pd.read_csv(hla_table_path, sep='\t')
    after_pca = np.loadtxt(pca_path)
    dataset = pd.read_csv(dataset_path)

    # Ensure labels are numeric: 'Negative' -> 0, anything else -> 1.
    if dataset['immunogenicity'].dtype == 'object':
        if isinstance(dataset['immunogenicity'].iloc[0], str):
            label_map = {'Positive': 1, 'Negative': 0}
            dataset['immunogenicity'] = dataset['immunogenicity'].map(
                lambda x: label_map.get(x, 1 if x != 'Negative' else 0)
            )
        dataset['immunogenicity'] = dataset['immunogenicity'].astype(int)

    hla_dic = hla_df_to_dic(hla_df)
    inventory = list(hla_dic.keys())
    dic_inventory = dict_inventory(inventory)

    print("Creating StellarGraph graphs...")
    try:
        stellar_graphs, stellar_labels = Graph_Constructor.entrance(
            dataset, after_pca, hla_dic, dic_inventory, graph_type='intra_and_inter'
        )

        print("Converting to PyTorch Geometric format...")
        torch_graphs = []
        for i, (g, lbl) in enumerate(zip(stellar_graphs, stellar_labels)):
            try:
                torch_graphs.append(stellargraph_to_torch_data(g, lbl))
            except Exception as e:
                # Best-effort conversion: skip graphs that fail, but report them.
                print(f"Error converting graph {i}: {e}")

        if not torch_graphs:
            # Without this guard torch_graphs[0] below fails with an opaque IndexError.
            print("No graphs could be converted to PyTorch Geometric format; nothing to visualize.")
            return

        # Hyper-parameters must match those used when the checkpoint was saved.
        input_dim = torch_graphs[0].num_node_features
        model = EnhancedIDGLFramework(
            input_dim=input_dim,
            hidden_dims=[64, 32],
            dropout=0.3,
            k=10,
            lambda_sparse=0.1,
            lambda_smooth=0.1,
            num_heads=4,
            epsilon=0.5,
            max_iterations=3
        ).to(device)

        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()

        print("Creating paper-quality visualizations...")
        stats = create_paper_quality_visualization(
            model=model,
            data_list=torch_graphs,
            dataset_df=dataset,
            device=device,
            save_path="idgl_paper_figure"
        )

        print("Creating combined visualization...")
        create_combined_graph_visualization(
            model=model,
            data_list=torch_graphs,
            dataset_df=dataset,
            device=device,
            save_path="idgl_combined_figure"
        )

        print("Visualization pipeline completed successfully.")
        print("Examples visualized:")
        for stat in stats:
            print(f"- {stat}")

    except Exception as e:
        print(f"Error during visualization: {e}")
        traceback.print_exc()

def save_history_to_log(history, filepath="History_IDGL.log"):
    """Save a training history to a human-readable log file.

    Each metric is written as a block of ``Epoch i: value`` lines. Real numbers,
    including NumPy scalars, are formatted with 6 decimals. The whole history is
    then appended as indented JSON so it can be reloaded later.

    Args:
        history: Mapping of metric name to an iterable of per-epoch values.
            Values exposing ``tolist()`` (NumPy/PyTorch scalars and arrays) are
            converted to built-in types for the JSON section.
        filepath: Destination path; the file is overwritten.

    Raises:
        TypeError: If a value cannot be JSON-serialized. This is raised before the
            file is opened, so an existing log is not truncated.
    """
    import numbers  # local: only needed for the numeric-format check below

    def _to_builtin(obj):
        # json cannot encode NumPy/PyTorch values; both provide tolist().
        if hasattr(obj, "tolist"):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Serialize first so a failure cannot leave a half-written log behind.
    json_text = json.dumps(history, indent=2, default=_to_builtin)

    lines = ["IDGL Training History\n", "====================\n\n"]
    for key in history:
        lines.append(f"{key}:\n")
        for epoch, value in enumerate(history[key], start=1):
            # numbers.Real covers NumPy floats/ints too; bools are printed verbatim.
            if isinstance(value, numbers.Real) and not isinstance(value, bool):
                lines.append(f"Epoch {epoch}: {value:.6f}\n")
            else:
                lines.append(f"Epoch {epoch}: {value}\n")
        lines.append("\n")
    lines.append("\n\nJSON Format:\n")
    lines.append(json_text)

    with open(filepath, 'w', encoding='utf-8') as f:
        f.writelines(lines)

    print(f"History saved to {filepath}")

def track_performance(func):
    """Decorator that logs execution time and memory change of each call.

    For every call of the wrapped function a record is appended to
    ``space_time_IDGL.log``. It contains:

    - the function name;
    - wall-clock time, measured with the monotonic ``time.perf_counter``;
    - the change in resident set size (RSS, via ``psutil``);
    - the RSS sampled after the call.

    The post-call RSS is a snapshot, not a true peak. Calls that raise are also
    logged, with the exception type, and the exception is re-raised unchanged.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        process = psutil.Process()
        start_memory = process.memory_info().rss / (1024 * 1024)  # MB
        start_time = time.perf_counter()
        error = None
        try:
            return func(*args, **kwargs)
        except BaseException as exc:
            error = exc
            raise
        finally:
            execution_time = time.perf_counter() - start_time
            end_memory = process.memory_info().rss / (1024 * 1024)  # MB
            with open("space_time_IDGL.log", "a", encoding="utf-8") as f:
                f.write(f"Function: {func.__name__}\n")
                f.write(f"Execution time: {execution_time:.2f} seconds\n")
                f.write(f"Memory usage: {end_memory - start_memory:.2f} MB\n")
                f.write(f"End memory (RSS): {end_memory:.2f} MB\n")
                if error is not None:
                    f.write(f"Raised: {type(error).__name__}\n")
                f.write("-" * 50 + "\n")

    return wrapper

# === Main Function ===
def main():
    """Run the training pipeline end to end.

    Seeds torch/NumPy (and makes cuDNN deterministic when CUDA is available).
    Loads the HLA table, PCA features and training CSV, and normalises the
    'immunogenicity' labels to integers. Builds StellarGraph graphs and converts
    them to PyTorch Geometric format. Then trains via
    ``run_enhanced_idgl_training()``, plots the training history and runs both
    validation routines. Errors raised during graph construction, training or
    validation are printed with a traceback rather than propagated.
    """
    import types  # local: SimpleNamespace exposes the history as `.history` for the plotter

    torch.manual_seed(42)
    np.random.seed(42)
    if torch.cuda.is_available():
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    print("Loading data...")
    hla_df = pd.read_csv('hla2paratopeTable_aligned.txt', sep='\t')
    after_pca = np.loadtxt('after_pca.txt')
    dataset = pd.read_csv('remove0123_sample100.csv')

    # Ensure labels are numeric: 'Negative' -> 0, anything else -> 1.
    if dataset['immunogenicity'].dtype == 'object':
        if isinstance(dataset['immunogenicity'].iloc[0], str):
            label_map = {'Positive': 1, 'Negative': 0}
            dataset['immunogenicity'] = dataset['immunogenicity'].map(
                lambda x: label_map.get(x, 1 if x != 'Negative' else 0)
            )
        dataset['immunogenicity'] = dataset['immunogenicity'].astype(int)

    hla_dic = hla_df_to_dic(hla_df)
    inventory = list(hla_dic.keys())
    dic_inventory = dict_inventory(inventory)

    print("Creating StellarGraph graphs...")
    try:
        stellar_graphs, stellar_labels = Graph_Constructor.entrance(
            dataset, after_pca, hla_dic, dic_inventory, graph_type='intra_and_inter'
        )
        if not stellar_graphs:
            # Fail with a clear message instead of an IndexError on stellar_graphs[0].
            raise ValueError("Graph construction produced no graphs")

        print(f"Created {len(stellar_graphs)} graphs")
        print(f"Label distribution: {pd.Series(stellar_labels).value_counts().to_dict()}")

        sample_graph = stellar_graphs[0]
        print(f"Sample graph: {len(sample_graph.nodes())} nodes, {len(sample_graph.edges())} edges")

        print("Converting to PyTorch Geometric format...")
        torch_graphs = []
        for i, (g, lbl) in enumerate(zip(stellar_graphs, stellar_labels)):
            try:
                torch_graphs.append(stellargraph_to_torch_data(g, lbl))
            except Exception as e:
                # Best-effort conversion: skip graphs that fail, but report them.
                print(f"Error converting graph {i}: {e}")

        print(f"Successfully converted {len(torch_graphs)} graphs to PyTorch Geometric format")

        # NOTE(review): torch_graphs is not passed to the trainer below; confirm that
        # run_enhanced_idgl_training() loads its own data.
        print("Training IDGL model...")
        model, history = run_enhanced_idgl_training()

        print("Saving training history...")
        history_obj = types.SimpleNamespace(history=history)
        plot_separate_training_history(history_obj, save_path="idgl_training_plots")

        run_validation_after_training(model, after_pca, hla_dic, dic_inventory)

        print("Running validation...")
        run_validation()
        print("Completed validation.")

    except Exception as e:
        print(f"Error during graph processing or training: {e}")
        traceback.print_exc()

    print("Process completed")

if __name__ == "__main__":
    # Train and validate first, then render the figures; the visualization step
    # loads its model checkpoint from disk.
    for _pipeline_step in (main, run_paper_visualization_pipeline):
        _pipeline_step()
