<font size="8"> HLA-A*02:01 LOO Data Analysis </font>

Load the necessary libraries. Make sure they are all installed in the active conda environment (PyTorch, Plotly, Matplotlib, NumPy and scikit-learn).

In [2]:
import torch
import plotly.graph_objects as go
import re
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import (
    roc_auc_score, roc_curve, accuracy_score,
    precision_score, recall_score, f1_score, matthews_corrcoef, auc
)
import os

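Following Hogan, J. and Adams, N. M., "On Averaging ROC Curves," *Transactions on Machine Learning Research* (2023), pooled averaging is used here. The test predictions and targets from all LOO folds are concatenated, and a single ROC curve (and AUC) is computed from the pooled data.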

In [5]:
# Checkpoint locations of the leave-one-out runs: one .pth file per held-out
# cluster, numbered 1-5. These are absolute paths on the Y: drive; adjust them
# when running the notebook on another machine.

# Runs named "3072_Linear_no_dropout".
LOO_Linear_3072 = [
    rf"Y:\models_gijs\MLP_using_LOO_3072_Linear_no_dropout_batch_16_epochs_150_cluster_{k}.pth"
    for k in range(1, 6)
]

# Runs stored under the "512_MLP_Linear_NoBias" experiment folder.
LOO_Linear_NoBias_512 = [
    rf"Y:\models_gijs\LOO_experiments\512_MLP_Linear_NoBias\MLP_using_LOO_512_Reduce_batch_16_epochs_150_cluster_{k}.pth"
    for k in range(1, 6)
]

# Display names, index-aligned with the path lists above.
LOO_name = [f"Cluster_{k}" for k in range(1, 6)]

In [12]:
# Plot one ROC curve per held-out cluster, plus the pooled ROC curve.
#
# Pooled averaging: the test predictions and targets of all folds are
# concatenated and a single ROC curve / AUC is computed from them.
# NOTE(review): pooling only makes sense if the fold models' scores are on a
# comparable scale (e.g. all are probabilities) - confirm for these models.

# Experiment to plot; LOO_Linear_3072 is the other available run.
MODEL_PATHS = LOO_Linear_NoBias_512
assert len(MODEL_PATHS) == len(LOO_name), "every checkpoint needs exactly one display name"

fig = go.Figure()


def reshape_and_stack(data, n_epochs=None):
    """Stack a list of tensors and reshape the result to ``(n_epochs, -1)``.

    Parameters
    ----------
    data : list of torch.Tensor
        Tensors taken from a checkpoint's ``metrics`` dict,
        e.g. ``metrics["train"]["predict"]``.
    n_epochs : int, optional
        Number of rows of the result. When omitted, falls back to the
        notebook-global ``epochs`` (set by the loading loop below) for
        backward compatibility; pass it explicitly to avoid hidden state.

    Returns
    -------
    torch.Tensor
        Detached tensor of shape ``(n_epochs, -1)``.
    """
    if n_epochs is None:
        n_epochs = epochs  # legacy behaviour: global set by the loading loop
    return torch.stack(data).detach().reshape(n_epochs, -1)


all_test_predict = []
all_test_target = []

for path, model_name in zip(MODEL_PATHS, LOO_name):
    print(model_name)

    try:
        # torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(path, map_location=torch.device("cpu"))
        metrics = checkpoint["metrics"]
        epochs = checkpoint["epoch"]
    except FileNotFoundError:
        print(f"File not found: {path}")
        # Skip this fold; otherwise the previous fold's metrics would be
        # plotted (and pooled) a second time under this fold's name.
        continue
    except Exception as e:
        print(f"An error occurred while loading {path}: {e}")
        continue

    # Training/validation tensors are not used for the ROC plot; they are kept
    # (reshaped to (epochs, -1)) as notebook variables for inspection.
    train_predict, train_target, train_loss = [
        reshape_and_stack(metrics["train"][key], epochs) for key in ["predict", "targets", "losses"]
    ]
    val_predict, val_target, val_loss = [
        reshape_and_stack(metrics["val"][key], epochs) for key in ["predict", "targets", "losses"]
    ]

    # Merge all leading dimensions so each row is one sample
    # (presumably the lists hold per-batch tensors - TODO confirm).
    test_predict = torch.stack(metrics["test"]["predict"]).detach()
    test_predict = test_predict.view(-1, test_predict.size(-1))
    test_target = torch.stack(metrics["test"]["targets"]).detach()
    test_target = test_target.view(-1, test_target.size(-1))

    predict = test_predict.numpy()
    target = test_target.numpy()
    all_test_predict.append(predict)
    all_test_target.append(target)

    # A binary ROC curve needs exactly one score per sample; roc_curve
    # expects 1-D label and score arrays.
    assert predict.shape[-1] == 1, f"expected one score per sample, got shape {predict.shape}"
    fpr, tpr, _ = roc_curve(target.ravel(), predict.ravel())
    roc_auc = auc(fpr, tpr)

    fig.add_trace(go.Scatter(x=fpr, y=tpr,
                             mode='lines',
                             name=f'{model_name} (AUC={roc_auc:.2f})'))

if not all_test_predict:
    raise RuntimeError("No checkpoint could be loaded, so there is nothing to plot.")

# Pool the test predictions and targets of all successfully loaded folds.
all_test_predict = np.concatenate(all_test_predict)
all_test_target = np.concatenate(all_test_target)

fpr_pooled, tpr_pooled, _ = roc_curve(all_test_target.ravel(), all_test_predict.ravel())
roc_auc_pooled = auc(fpr_pooled, tpr_pooled)

# Pooled ROC curve as a black dashed line.
fig.add_trace(go.Scatter(x=fpr_pooled, y=tpr_pooled,
                         mode='lines',
                         line=dict(color='black', dash='dash'),
                         name=f'Pooled (AUC={roc_auc_pooled:.2f})'))

# Diagonal y = x in red: the ROC curve of a random classifier.
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1],
                         mode='lines',
                         line=dict(color='red'),
                         name='Random (AUC=0.50)'))

fig.update_layout(title='ROC Curve LOO cluster, Linear',
                  xaxis=dict(title='False Positive Rate'),
                  yaxis=dict(title='True Positive Rate'),
                  legend=dict(x=0.7, y=0.2),
                  width=600,
                  height=600,
                  font_size=13,
                  template="plotly",
                  showlegend=True,
                  legend_bordercolor='black')

fig.show()


Cluster_1
Cluster_2
Cluster_3
Cluster_4
Cluster_5
