# Evaluate model on Binding Affinity test set

Code to evaluate models on the test set. It generates the following files
for each evaluated model:
- test_results.tsv: File containing the labels and predictions of the model
- test_metrics.tsv: File with global metrics:
    - Accuracy
    - Precision
    - Recall
    - Average predicted probability for samples whose true label is negative (non-binders)
    - Average predicted probability for samples whose true label is positive (binders)
    - AUROC

In [None]:
import os
import sys
import glob
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, precision_recall_curve, auc, precision_score, recall_score, confusion_matrix

from nimbus.predictors import pHLABindingPredictor, pHLAPseudoseqBindingPredictor
from nimbus.data_processing import pHLADataset, pHLAPseudoseqDataset
from nimbus.utils import LoggerFactory

# Notebook-wide logger used by all cells below ('INFO' is presumably the log level — confirm in nimbus.utils).
logger = LoggerFactory.get_logger('explore_pHLApredictor_nb', 'INFO')

In [None]:
# ---- Input data locations (relative to the notebook's working directory) ----
DATA_DIR = '../data'
RAW_DATA, PROCESSED_DATA = (os.path.join(DATA_DIR, sub_dir) for sub_dir in ('raw', 'processed'))
HLA_FP_DIR = os.path.join(PROCESSED_DATA, 'hla_fingerprints')

# HLA fingerprints: a CSV giving each allele's row, plus the matching .npy embedding matrix.
# 36-dimensional representation
hla_fp_36_data_file, hla_fp_36_file = (
    os.path.join(HLA_FP_DIR, file_name)
    for file_name in (
        'hla_index_netMHCpan_pseudoseq_res_representation.csv',
        'hla_fingerprint_netMHCpan_pseudoseq_res_representation.npy',
    )
)
# 400-dimensional representation
hla_fp_400_data_file, hla_fp_400_file = (
    os.path.join(HLA_FP_DIR, file_name)
    for file_name in (
        'hla_af_patch_info_patch_r18_pt400.csv',
        'hla_af_patch_emb_patch_r18_pt400.npy',
    )
)

hla_pseudoseq_file = os.path.join(RAW_DATA, 'pHLA_binding', 'NetMHCpan_train', 'MHC_pseudo_fixed.dat')
test_netmhcpan_data_file = os.path.join(
    PROCESSED_DATA, 'pHLA_binding', 'NetMHCpan_dataset', 'test_set_peptides_data_MaxLenPep15_hla_ABC.csv.gz'
)

# ---- Baseline predictions (precomputed by external tools) ----
BASELINES_DIR = os.path.join('../checkpoints', 'baselines')
NETMHCPAN_BASELINE_DIR, MIXMHCPRED22_BASELINE_DIR, MIXMHCPRED3_BASELINE_DIR = (
    os.path.join(BASELINES_DIR, tool_dir) for tool_dir in ('netmhcpan41', 'mixmhcpred22', 'mixmhcpred3')
)
# Every baseline folder holds the same benchmark file name.
(
    merged_netmhcpan_cd8_benchmark_file,
    merged_mixmhcpred22_cd8_benchmark_file,
    merged_mixmhcpred3_cd8_benchmark_file,
) = (
    os.path.join(baseline_dir, 'CD8_benchmark_filtered_outs.csv.gz')
    for baseline_dir in (NETMHCPAN_BASELINE_DIR, MIXMHCPRED22_BASELINE_DIR, MIXMHCPRED3_BASELINE_DIR)
)

CHECKPOINTS_DIR = '../checkpoints/csv_logger'
v_num = 0  # Version number (selects the version_<v_num> sub-folder of each experiment)


def _find_checkpoint(exp_dir_name, version=None):
    """Return the checkpoint file saved for one experiment run.

    Searches ``CHECKPOINTS_DIR/<exp_dir_name>/version_<version>/checkpoints/ep*``.

    Args:
        exp_dir_name: experiment folder name inside ``CHECKPOINTS_DIR``.
        version: run version number; ``None`` means the module-level ``v_num``.

    Returns:
        The first match in sorted order. ``glob`` returns matches in arbitrary
        order, so sorting makes the choice reproducible when several files match.

    Raises:
        FileNotFoundError: if nothing matches (instead of the bare IndexError
            that ``glob(...)[0]`` would raise).
    """
    if version is None:
        version = v_num
    pattern = os.path.join(CHECKPOINTS_DIR, exp_dir_name, f'version_{version}', 'checkpoints', 'ep*')
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"No checkpoint matching {pattern}")
    if len(matches) > 1:
        logger.warning(f"{len(matches)} checkpoints match {pattern}; using {matches[0]}")
    return matches[0]


def _experiment(exp_dir_name, hla_fp_size, hla_representation_type='surface_fp', version=None):
    """Build one ``experiments_dict`` entry for the run stored in ``exp_dir_name``."""
    return {
        'model_checkpoint': _find_checkpoint(exp_dir_name, version),
        'hla_fp_size': hla_fp_size,
        'hla_representation_type': hla_representation_type,
    }


# Experiments to evaluate; uncomment an entry to include it.
# hla_fp_size selects the HLA fingerprint dataset (36 or 400) for 'surface_fp' models;
# 'pseudoseq' models use HLA pseudosequences and set it to 0.
experiments_dict = {
    # 'pHLA_balance': _experiment('pHLA_balance', 400),
    # 'pHLA_balance_hla_pseudoseq': _experiment('pHLA_balance_hla_pseudoseq', 36),
    'pHLA_imbalance': _experiment('pHLA_imbalance', 400),
    'pHLA_imbalance_hla_pseudoseq': _experiment('pHLA_imbalance_hla_pseudoseq', 36),
    # 'pHLA_balance_FILIP128': _experiment('pHLA_balance_FILIP128', 400),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0123_4': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0123_4', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0124_3': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0124_3', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0134_2': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0134_2', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits0234_1': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits0234_1', 36),
    'pHLA_imbalance_hla_pseudoseq_ManSplits1234_0': _experiment('pHLA_imbalance_hla_pseudoseq_ManSplits1234_0', 36),
    # 'pseudoseq_pHLA_imbalance_ManSplits0123_4': _experiment('pseudoseq_pHLA_imbalance_ManSplits0123_4', 0, 'pseudoseq'),
    # 'pHLA_imbalance_newHLAFP_ManSplits0123_4': _experiment('pHLA_imbalance_newHLAFP_ManSplits0123_4', 400),
    # 'pHLA_imbalance_EL_hla_pseudoseq_splitTrainTest': _experiment('pHLA_imbalance_EL_hla_pseudoseq_splitTrainTest', 36),
    # 'pHLA_imbalance_EL_splitTrainTest': _experiment('pHLA_imbalance_EL_splitTrainTest', 400),
    # 'pHLA_imbalance_hla_pseudoseq_AllBA_TestAsVal': _experiment('pHLA_imbalance_hla_pseudoseq_AllBA_TestAsVal', 36),
    # 'pHLA_imbalance_hla_pseudoseq_AllEL_TestAsVal': _experiment('pHLA_imbalance_hla_pseudoseq_AllEL_TestAsVal', 36),
    # 'pHLA_imbalance_newHLAFP_AllBA_TestAsVal': _experiment('pHLA_imbalance_newHLAFP_AllBA_TestAsVal', 400, version=0),  # TODO test
    # 'pHLA_imbalance_newHLAFP_AllBA_TestAsVal_v1': _experiment('pHLA_imbalance_newHLAFP_AllBA_TestAsVal', 400, version=1),  # TODO test
    # 'pHLA_imbalance_newHLAFP_AllEL_TestAsVal': _experiment('pHLA_imbalance_newHLAFP_AllEL_TestAsVal', 400),
    # 'pseudoseq_pHLA_imbalance_AllBA_TestAsVal': _experiment('pseudoseq_pHLA_imbalance_AllBA_TestAsVal', 0, 'pseudoseq'),
    # 'pseudoseq_pHLA_imbalance_AllEL_TestAsVal': _experiment('pseudoseq_pHLA_imbalance_AllEL_TestAsVal', 0, 'pseudoseq'),
}

In [None]:
# For debugging: evaluate on a random subset of the test set.
# Set DEBUG_N_SAMPLES = None to evaluate on the full test set.
DEBUG_N_SAMPLES = 500
test_netmhcpan_data = pd.read_csv(test_netmhcpan_data_file)
if DEBUG_N_SAMPLES is not None:
    # min(): sampling without replacement cannot draw more rows than exist
    n_samples = min(DEBUG_N_SAMPLES, len(test_netmhcpan_data))
    test_netmhcpan_data = test_netmhcpan_data.sample(n_samples, replace=False, random_state=42)
# Reset to a 0..N-1 index so index labels equal row positions. The datasets below are built from
# `.values` (positional) and saved results use a RangeIndex; later cells index results with these labels.
test_netmhcpan_data = test_netmhcpan_data.reset_index(drop=True)
# show rows with label 1 (binders)
test_netmhcpan_data[test_netmhcpan_data['label'] == 1]

In [None]:
# test_netmhcpan_data = pd.read_csv(test_netmhcpan_data_file)
# HLA fingerprint embedding matrices; the *_data dicts map allele name -> row of the matching matrix.
hla_fp_36_emb = np.load(hla_fp_36_file)
hla_fp_400_emb = np.load(hla_fp_400_file)
hla_fp_36_data = pd.read_csv(hla_fp_36_data_file, index_col=1, names=['index'], header=0).to_dict()['index']
hla_fp_400_data = pd.read_csv(hla_fp_400_data_file, index_col=1, names=['index'], header=0).to_dict()['index']
# allele name -> fingerprint tensor
hla_fp_36_dict = {hla: torch.Tensor(hla_fp_36_emb[idx]) for hla, idx in hla_fp_36_data.items()}
hla_fp_400_dict = {hla: torch.Tensor(hla_fp_400_emb[idx]) for hla, idx in hla_fp_400_data.items()}
# allele -> pseudosequence (looked up with the test set's hla_allele names below).
# Raw string: '\s' is an invalid escape sequence in a normal string literal (SyntaxWarning on Python 3.12+).
hla_pseudoseq_dict = pd.read_csv(hla_pseudoseq_file, sep=r'\s+', names=['pseudoseq'], header=None).to_dict()['pseudoseq']


In [None]:
# Test set paired with the 36-dimensional HLA fingerprints
test_dataset_pseudoseq_surf = pHLADataset(
    labels=test_netmhcpan_data['label'].to_numpy(),
    hla_fp_dict=hla_fp_36_dict,
    hla_names_arr=test_netmhcpan_data['hla_allele'].to_numpy(),
    peptide_seq_arr=test_netmhcpan_data['peptide'].to_numpy(),
)

In [None]:
# Test set paired with the 400-dimensional HLA fingerprints
test_dataset_patch_surf = pHLADataset(
    labels=test_netmhcpan_data['label'].to_numpy(),
    hla_fp_dict=hla_fp_400_dict,
    hla_names_arr=test_netmhcpan_data['hla_allele'].to_numpy(),
    peptide_seq_arr=test_netmhcpan_data['peptide'].to_numpy(),
)

In [None]:
# Test set paired with HLA pseudosequences (for 'pseudoseq' models)
test_dataset_pseudoseq = pHLAPseudoseqDataset(
    labels=test_netmhcpan_data['label'].to_numpy(),
    hla_pseudoseq_dict=hla_pseudoseq_dict,
    hla_names_arr=test_netmhcpan_data['hla_allele'].to_numpy(),
    peptide_seq_arr=test_netmhcpan_data['peptide'].to_numpy(),
)

In [None]:
def load_pretrained_model(checkpoint_file, hla_representation_type='surface_fp'):
    """Load a pretrained binding predictor from a checkpoint.

    Args:
        checkpoint_file: path to the checkpoint file.
        hla_representation_type: 'surface_fp' loads a ``pHLABindingPredictor``;
            'pseudoseq' loads a ``pHLAPseudoseqBindingPredictor``.

    Returns:
        The model returned by the selected class's ``load_from_checkpoint``.

    Raises:
        ValueError: if ``hla_representation_type`` is not a supported value.
            Raising (instead of ``sys.exit``) lets notebook and library callers
            see a normal traceback and handle the error.
    """
    if hla_representation_type not in ('surface_fp', 'pseudoseq'):
        raise ValueError(
            f"Unknown hla_representation_type {hla_representation_type}. "
            f"Expected 'surface_fp' or 'pseudoseq'"
        )
    if hla_representation_type == 'surface_fp':
        logger.info(f"Loading pHLABindingPredictor pretrained model {checkpoint_file}")
        return pHLABindingPredictor.load_from_checkpoint(checkpoint_file)
    logger.info(f"Loading pHLAPseudoseqBindingPredictor pretrained model {checkpoint_file}")
    return pHLAPseudoseqBindingPredictor.load_from_checkpoint(checkpoint_file)


In [None]:
def _select_test_dataset(exp_name, exp_cfg):
    """Return the test dataset matching an experiment's HLA representation.

    Raises:
        ValueError: for an unknown representation type or an unsupported
            fingerprint size (otherwise a stale dataset could be reused).
    """
    hla_representation_type = exp_cfg['hla_representation_type']
    if hla_representation_type == 'surface_fp':
        hla_fp_size = exp_cfg['hla_fp_size']
        if hla_fp_size == 400:
            logger.info("Using 400 dimensional HLA fingerprints")
            return test_dataset_patch_surf
        if hla_fp_size == 36:
            logger.info("Using 36 dimensional HLA fingerprints")
            return test_dataset_pseudoseq_surf
        raise ValueError(f"Unsupported hla_fp_size {hla_fp_size} for {exp_name}. Expected 36 or 400")
    if hla_representation_type == 'pseudoseq':
        logger.info("Using HLA pseudosequences")
        return test_dataset_pseudoseq
    raise ValueError(f"Unknown hla_representation_type {hla_representation_type} for {exp_name}")


def _predict(model, dataset, batch_size=64):
    """Run inference over ``dataset`` in order.

    Returns:
        (labels, probabilities) as 1-D numpy arrays aligned with the dataset rows.
    """
    model.eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
    batch_preds, batch_labels = [], []
    # no_grad: inference only, so skip building the autograd graph (saves memory and time)
    with torch.no_grad():
        for peptides, hlas, labels_batch in loader:
            reps = model(peptides, hlas)  # Filip Representation
            logits = model.to_pred(model.linear_to_logits(reps))
            # reshape(-1) guards against (batch, 1) outputs, which would otherwise broadcast
            # against the 1-D labels when computing metrics
            batch_preds.append(torch.sigmoid(logits).reshape(-1).cpu().numpy())
            batch_labels.append(labels_batch.reshape(-1).cpu().numpy())
    labels = np.concatenate(batch_labels).astype(int)
    preds = np.concatenate(batch_preds).astype(np.float64)
    return labels, preds


def _compute_metrics(labels, preds, threshold=0.5):
    """Compute the global test metrics written to test_metrics.tsv."""
    preds_bin = preds > threshold
    fpr, tpr, _ = roc_curve(labels, preds)
    return {
        'accuracy': np.mean(preds_bin == labels),
        'precision': precision_score(labels, preds_bin),  # TP / (TP + FP)
        'recall': recall_score(labels, preds_bin),  # TP / (TP + FN)
        'auroc': auc(fpr, tpr),
        'neg_prob': np.mean(preds[labels == 0]),  # mean predicted prob of true non-binders
        'pos_prob': np.mean(preds[labels == 1]),  # mean predicted prob of true binders
    }


# Evaluate every experiment configured in experiments_dict; files are written next to each checkpoint folder.
for exp_name, exp_cfg in experiments_dict.items():
    checkpoint_file = exp_cfg['model_checkpoint']
    out_dir = os.path.join(os.path.dirname(checkpoint_file), '..')

    test_dataset = _select_test_dataset(exp_name, exp_cfg)
    model = load_pretrained_model(checkpoint_file, hla_representation_type=exp_cfg['hla_representation_type'])
    labels, preds = _predict(model, test_dataset)

    # save results to file in checkpoint folder
    results = pd.DataFrame({'labels': labels, 'preds': preds})
    results_file = os.path.join(out_dir, 'test_results.tsv')
    results.to_csv(results_file, index=False, sep='\t')
    logger.info(f"Predicted test results saved to {results_file}")

    # save metrics to file in checkpoint folder
    metrics = _compute_metrics(labels, preds)
    metrics_file = os.path.join(out_dir, 'test_metrics.tsv')
    pd.DataFrame(metrics, index=[0]).to_csv(metrics_file, index=False, sep='\t')
    logger.info(f"Metrics saved to {metrics_file}")

# Analysis on precomputed results
## Load baselines results
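Each baseline DataFrame has the following columns. Note that the baselines were run on the CD8 benchmark set (`CD8_benchmark_filtered_outs.csv.gz`), not on the test set used for our models, so their metrics are not computed on the same samples.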
- NetMHCpan4.1:
    - MHC column with format HLA-A*02:01
    - Score_EL column with the predicted score
    - %Rank_EL column with the predicted rank
    - Exp column with the true label
    - BindLevel column with the binding level predicted with the NetMHCpan4.1 default thresholds (Strong, Weak and Non-binders)
- MixMHCpred2.2 and MixMHCpred3 have the same columns:
    - MHC column with format HLA-A02-01
    - Score column with the predicted score
    - %Rank column with the predicted rank
    - is_binder column with 0 or 1. Indicates if the peptide is a binder
    - pred_is_binder column with 0 or 1. Indicates if the peptide is predicted as a binder according to the model default usage

In [None]:
# Load baselines data
cd8_netmhcpan_df = pd.read_csv(merged_netmhcpan_cd8_benchmark_file)
# change MHC column format from HLA-A*02:01 to HLA-A02-01.
# regex=False: '*' and ':' are removed/replaced literally ('*' alone is an invalid regex, and older
# pandas versions default to regex=True).
cd8_netmhcpan_df['MHC'] = (
    cd8_netmhcpan_df['MHC']
    .str.replace('*', '', regex=False)
    .str.replace(':', '-', regex=False)
)


def _normalize_mixmhcpred(df):
    """Convert a MixMHCpred output table to the common baseline column format.

    BestAllele 'A0201' becomes 'HLA-A02-01', and the best-allele columns are renamed
    to MHC / Score / %Rank.
    """
    df = df.copy()
    best_allele = df['BestAllele']
    df['BestAllele'] = 'HLA-' + best_allele.str.slice(0, 3) + '-' + best_allele.str.slice(3, 5)
    return df.rename(columns={'Score_bestAllele': 'Score', 'BestAllele': 'MHC', '%Rank_bestAllele': '%Rank'})


cd8_mixmhcpred22_df = _normalize_mixmhcpred(pd.read_csv(merged_mixmhcpred22_cd8_benchmark_file))
cd8_mixmhcpred3_df = _normalize_mixmhcpred(pd.read_csv(merged_mixmhcpred3_cd8_benchmark_file))

## ROC curve
Plot the ROC curves (with their AUC in the legend) for all models and baselines

In [None]:
def _plot_roc(ax, y_true, scores, label, **plot_kwargs):
    """Plot the ROC curve of ``scores`` against binary ``y_true`` on ``ax``.

    ``scores`` must be "higher = more likely binder"; pass the negated value for
    rank-like scores where lower is better. Negating gives the correct curve,
    whereas swapping tpr/fpr only mirrors it about the diagonal (same area, wrong shape).

    Returns:
        The area under the ROC curve.
    """
    fpr, tpr, _ = roc_curve(y_true, scores)
    auroc_score = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f'{label} (area = {auroc_score:.2f})', **plot_kwargs)
    return auroc_score


# Set the font size before drawing so it also applies to this figure.
plt.rcParams.update({'font.size': 18})
fig, ax = plt.subplots(figsize=(10, 10))

for exp_name, exp_cfg in experiments_dict.items():
    results_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_results.tsv')
    try:
        results = pd.read_csv(results_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {results_file} not found")
        continue
    _plot_roc(ax, results['labels'], results['preds'], exp_name)

# plot baselines (%Rank: lower means stronger binder, hence the negation)
_plot_roc(ax, cd8_netmhcpan_df['Exp'], cd8_netmhcpan_df['Score_EL'], 'NetMHCpan4.1_EL_score',
          linestyle='--', alpha=0.5, color='black')
_plot_roc(ax, cd8_netmhcpan_df['Exp'], -cd8_netmhcpan_df['%Rank_EL'], 'NetMHCpan4.1_EL_rank',
          linestyle='-', alpha=0.5, color='black')
_plot_roc(ax, cd8_mixmhcpred22_df['is_binder'], -cd8_mixmhcpred22_df['%Rank'], 'MixMHCpred2.2_Rank',
          linestyle='--', alpha=0.5, color='purple')
_plot_roc(ax, cd8_mixmhcpred3_df['is_binder'], -cd8_mixmhcpred3_df['%Rank'], 'MixMHCpred3_Rank',
          linestyle='-', alpha=0.5, color='purple')

ax.plot([0, 1], [0, 1], lw=2, linestyle='dotted', color='gray', alpha=0.5)
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Receiver Operating Characteristic')
ax.set_ylim([-0.01, 1.01])
ax.set_xlim([-0.01, 1.01])
# plot legend outside (below) the axes
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), frameon=False, facecolor='white')
# make plot square
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()
# bbox_inches='tight' keeps the out-of-axes legend inside the saved image
fig.savefig('roc_curve.png', bbox_inches='tight')
plt.show()


## Plot Precision-Recall curve

In [None]:
from sklearn.metrics import average_precision_score


def _plot_pr(ax, y_true, scores, label, **plot_kwargs):
    """Plot the precision-recall curve of ``scores`` on ``ax`` and return its average precision.

    Average precision is reported instead of the trapezoidal ``auc(recall, precision)``,
    which linearly interpolates the PR curve and tends to be optimistic.
    """
    precision, recall, _ = precision_recall_curve(y_true, scores)
    ap_score = average_precision_score(y_true, scores)
    ax.plot(recall, precision, label=f'{label} (AP = {ap_score:.2f})', **plot_kwargs)
    return ap_score


# Set the font size before drawing so it also applies to this figure.
plt.rcParams.update({'font.size': 16})
fig, ax = plt.subplots(figsize=(10, 10))

for exp_name, exp_cfg in experiments_dict.items():
    results_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_results.tsv')
    try:
        results = pd.read_csv(results_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {results_file} not found")
        continue
    _plot_pr(ax, results['labels'], results['preds'], exp_name)

# Baseline (evaluated on the CD8 benchmark, whose positive rate may differ from the test set)
_plot_pr(ax, cd8_netmhcpan_df['Exp'], cd8_netmhcpan_df['Score_EL'], 'NetMHCpan4.1_EL_score')

ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall curve')
ax.set_ylim([-0.01, 1.01])
ax.set_xlim([-0.01, 1.01])
# plot legend outside (below) the axes
ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), frameon=False, facecolor='white')
# make plot square
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()
# bbox_inches='tight' keeps the out-of-axes legend inside the saved image
fig.savefig('pr_curve.png', bbox_inches='tight')
plt.show()

## Compare metrics across models

In [None]:
# compute metrics for baselines


def _baseline_metrics(y_true, y_pred, ranking_scores, mean_scores):
    """Compute the same metric set as the model evaluation for one baseline.

    Args:
        y_true: true binary labels.
        y_pred: binary predictions after applying the baseline's threshold.
        ranking_scores: continuous scores, higher = more likely binder (used for AUROC).
        mean_scores: values averaged per true class to give 'neg_prob' / 'pos_prob'.

    Returns:
        Dict with accuracy, precision, recall, auroc, neg_prob and pos_prob.
    """
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    mean_scores = np.asarray(mean_scores)
    # Each baseline gets its own ROC curve (previously every row reused the NetMHCpan Score_EL curve).
    fpr, tpr, _ = roc_curve(y_true, ranking_scores)
    return {
        'accuracy': np.mean(y_pred == y_true),
        'precision': precision_score(y_true, y_pred),
        'recall': recall_score(y_true, y_pred),
        'auroc': auc(fpr, tpr),
        'neg_prob': np.mean(mean_scores[y_true == 0]),
        'pos_prob': np.mean(mean_scores[y_true == 1]),
    }


metrics_baseline = {}
# NetMHCpan4.1 EL score thresholded at 0.5. neg/pos_prob average the EL score itself,
# matching the model metrics, which average predicted probabilities.
metrics_baseline['NetMHCpan4.1_EL'] = _baseline_metrics(
    cd8_netmhcpan_df['Exp'],
    cd8_netmhcpan_df['Score_EL'] > 0.5,
    cd8_netmhcpan_df['Score_EL'],
    cd8_netmhcpan_df['Score_EL'],
)
# Rank-based baselines: binder if %Rank < 2; AUROC uses the negated rank (lower rank = stronger binder).
# These tools give no probability here, so neg/pos_prob are the fraction predicted as binders per class.
bl_preds_swb = cd8_netmhcpan_df['%Rank_EL'] < 2
metrics_baseline['NetMHCpan4.1_EL_rank'] = _baseline_metrics(
    cd8_netmhcpan_df['Exp'], bl_preds_swb, -cd8_netmhcpan_df['%Rank_EL'], bl_preds_swb,
)
bl_preds_rank2 = cd8_mixmhcpred22_df['%Rank'] < 2
metrics_baseline['MixMHCpred2.2_Rank'] = _baseline_metrics(
    cd8_mixmhcpred22_df['is_binder'], bl_preds_rank2, -cd8_mixmhcpred22_df['%Rank'], bl_preds_rank2,
)
bl_preds_rank2 = cd8_mixmhcpred3_df['%Rank'] < 2
metrics_baseline['MixMHCpred3_Rank'] = _baseline_metrics(
    cd8_mixmhcpred3_df['is_binder'], bl_preds_rank2, -cd8_mixmhcpred3_df['%Rank'], bl_preds_rank2,
)

metrics_baseline_df = pd.DataFrame.from_dict(metrics_baseline, orient='columns').T
metrics_baseline_df

In [None]:
# Gather one single-row metrics frame per experiment, then append one per baseline.
metrics = {}
for exp_name, exp_cfg in experiments_dict.items():
    metrics_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_metrics.tsv')
    try:
        metrics[exp_name] = pd.read_csv(metrics_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {metrics_file} not found")
metrics.update(
    {baseline_name: row.to_frame().T for baseline_name, row in metrics_baseline_df.iterrows()}
)

# Stack everything; the outer dict key becomes the 'exp_name' column.
metrics_df = (
    pd.concat(metrics)
    .reset_index(drop=False)
    .rename(columns={'level_0': 'exp_name'})
    .drop(columns='level_1')
)
metrics_df

In [None]:
# Plot metrics
# Set the font size before creating the figure: rcParams only affect artists created afterwards.
plt.rcParams.update({'font.size': 12})
fig, axs = plt.subplots(2, 2, figsize=(20, 20))
for ax, metric_name in zip(axs.flat, ['accuracy', 'precision', 'recall', 'auroc']):
    sns.barplot(x='exp_name', y=metric_name, data=metrics_df, ax=ax)
    # rotate x labels so long experiment names stay readable
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
plt.show()

## Confusion matrix
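Make a plot with all confusion matrices. For the models and the NetMHCpan4.1 EL score, the confusion matrix is computed with a threshold of 0.5; for the rank-based baselines, a peptide is predicted as a binder when %Rank < 2. Each subplot title is the experiment name.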

In [None]:
from matplotlib.colors import LogNorm

n_baselines = len(metrics_baseline_df)
n_plots = len(experiments_dict) + n_baselines
n_cols = 3
n_rows = int(np.ceil(n_plots / n_cols))
# Set the font size before drawing so it also applies to this figure.
plt.rcParams.update({'font.size': 20})
# squeeze=False keeps `axs` 2-D even when there is a single row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(8*n_cols, 8*n_rows), squeeze=False)
normalize_cm = 'true' # Normalize confusion matrix to get %. Must be str among {'all', 'true', 'pred'} or None
if normalize_cm is None:
    fmt = 'd'
    norm = LogNorm()
else:
    fmt = '.4f'
    norm = None

# One (title, true labels, binary predictions) entry per subplot, models first, then baselines.
# A missing results file leaves its slot empty (None) so the layout stays the same.
panels = []
for exp_name, exp_cfg in experiments_dict.items():
    results_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_results.tsv')
    try:
        results = pd.read_csv(results_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {results_file} not found")
        panels.append(None)
        continue
    panels.append((exp_name, results['labels'], results['preds'] > 0.5))
panels += [
    ('NetMHCpan4.1_EL score', cd8_netmhcpan_df['Exp'], cd8_netmhcpan_df['Score_EL'] > 0.5),
    ('NetMHCpan4.1_EL rank', cd8_netmhcpan_df['Exp'], cd8_netmhcpan_df['%Rank_EL'] < 2),
    ('MixMHCpred2.2_Rank', cd8_mixmhcpred22_df['is_binder'], cd8_mixmhcpred22_df['%Rank'] < 2),
    ('MixMHCpred3_Rank', cd8_mixmhcpred3_df['is_binder'], cd8_mixmhcpred3_df['%Rank'] < 2),
]

# Positions come from a flat axes iterator instead of a leaked loop index.
for ax, panel in zip(axs.flat, panels):
    if panel is None:
        continue
    title, y_true, y_pred = panel
    conf_matrix = confusion_matrix(y_true, y_pred, normalize=normalize_cm)
    sns.heatmap(conf_matrix, annot=True, fmt=fmt, cmap='Blues_r', ax=ax, norm=norm)
    ax.set_title(title)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')
# hide grid cells that received no panel
for ax in list(axs.flat)[len(panels):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

## Check false positives and false negatives

For each model, find which HLA alleles are most often involved in false positives and false negatives when using a threshold of 0.5.

In [None]:
# Reload the complete test set (replacing any debug subset sampled earlier in the notebook)
test_netmhcpan_data = pd.read_csv(filepath_or_buffer=test_netmhcpan_data_file)

In [None]:
training_ba_data_file = os.path.join(
    PROCESSED_DATA,
    'pHLA_binding',
    'NetMHCpan_dataset',
    'train_binding_affinity_peptides_data_MaxLenPep15_hla_ABC_with_BalancedSplits.csv',
)
training_ba_data = pd.read_csv(training_ba_data_file)
training_ba_data['label'] = training_ba_data['label'].astype(int)

# Per-allele counts of label 0 (non-binders) and label 1 (binders) in the training set
hla_allele_counts = (
    training_ba_data.groupby('hla_allele')['label']
    .value_counts()
    .unstack()
    .fillna(0)
)
hla_allele_counts['train_total'] = hla_allele_counts.sum(axis=1)
hla_allele_counts = (
    hla_allele_counts
    .sort_values(by='train_total', ascending=False)
    .astype(int)
    .rename(columns={0: 'train_NB', 1: 'train_B'})
)
# fraction of binders among all training samples of the allele
hla_allele_counts['train_ratio_pos_per_total'] = hla_allele_counts['train_B'] / hla_allele_counts['train_total']
hla_allele_counts

In [None]:
def _errors_per_allele(error_mask):
    """Return the test rows selected by ``error_mask`` and their per-allele counts.

    ``error_mask`` is a boolean Series over the rows of a results file; rows are
    matched to ``test_netmhcpan_data`` by position. The per-allele counts are
    inner-joined with the training label counts in ``hla_allele_counts``.
    """
    error_data = test_netmhcpan_data.iloc[np.flatnonzero(error_mask.to_numpy())]
    per_allele = error_data['hla_allele'].value_counts()
    return error_data, pd.concat([per_allele, hla_allele_counts], axis=1, join='inner')


# Analyse every model (a leftover `break` used to stop after the first one).
for exp_name, exp_cfg in experiments_dict.items():
    results_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_results.tsv')
    try:
        results = pd.read_csv(results_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {results_file} not found")
        continue
    # Results are matched to test rows by position, e.g. results from a debug subset cannot be mapped.
    if len(results) != len(test_netmhcpan_data):
        logger.warning(
            f"Skipping {exp_name}: {len(results)} predictions vs {len(test_netmhcpan_data)} test rows, "
            "so rows cannot be matched by position"
        )
        continue
    results['preds_05'] = results['preds'] > 0.5
    fp_data, fp_hla_alleles = _errors_per_allele((results['labels'] == 0) & results['preds_05'])
    fn_data, fn_hla_alleles = _errors_per_allele((results['labels'] == 1) & ~results['preds_05'])

    logger.info(f"False positives (non-binders classified as binders) for {exp_name} were {len(fp_data)}:")
    logger.info(fp_hla_alleles)
    logger.info(f"False negatives (binders classified as non-binders) for {exp_name} were {len(fn_data)}:")
    logger.info(fn_hla_alleles)

## Plot ROC-AUC curve for each hla allele
The number of training samples (N) and the ratio of binders (r_pos) shown in each subplot title are taken from the binding affinity training set loaded in the previous section.

In [None]:
def _plot_allele_roc(ax, y_true, scores, label, **plot_kwargs):
    """Plot one ROC curve on ``ax``; higher ``scores`` must mean "more likely binder".

    Skips (with a warning) subsets that do not contain both classes: there the ROC
    is undefined, and roc_curve/auc would return NaN or raise.
    """
    y_true = np.asarray(y_true)
    present_classes = np.unique(y_true)
    if present_classes.size < 2:
        logger.warning(f"Skipping ROC for {label}: classes present {present_classes.tolist()}")
        return None
    fpr, tpr, _ = roc_curve(y_true, scores)
    auroc_score = auc(fpr, tpr)
    ax.plot(fpr, tpr, label=f'{label} (area = {auroc_score:.2f})', **plot_kwargs)
    return auroc_score


n_hla_alleles = test_netmhcpan_data['hla_allele'].nunique()
n_cols = 4
n_rows = int(np.ceil(n_hla_alleles / n_cols))
# Set the font size before drawing so it also applies to this figure.
plt.rcParams.update({'font.size': 20})
# squeeze=False keeps `axs` 2-D even when there is a single row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(10*n_cols, 10*n_rows), squeeze=False)
uniq_hla_alleles = sorted(test_netmhcpan_data['hla_allele'].unique())

# Read each results file once (not once per allele); keep only those aligned with the test set.
exp_results = {}
for exp_name, exp_cfg in experiments_dict.items():
    results_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_results.tsv')
    try:
        results = pd.read_csv(results_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {results_file} not found")
        continue
    if len(results) != len(test_netmhcpan_data):
        logger.warning(f"Skipping {exp_name}: {len(results)} predictions vs {len(test_netmhcpan_data)} test rows")
        continue
    exp_results[exp_name] = results

hla_allele_arr = test_netmhcpan_data['hla_allele'].to_numpy()
for ax, hla_allele in zip(axs.flat, uniq_hla_alleles):
    # row positions of this allele; results rows are aligned with the test set by position
    hla_positions = np.flatnonzero(hla_allele_arr == hla_allele)
    for exp_name, results in exp_results.items():
        hla_results = results.iloc[hla_positions]
        _plot_allele_roc(ax, hla_results['labels'], hla_results['preds'], exp_name)

    # Add baselines (%Rank: lower means stronger binder, hence the negation)
    netmhcpan_hla = cd8_netmhcpan_df[cd8_netmhcpan_df['MHC'] == hla_allele]
    _plot_allele_roc(ax, netmhcpan_hla['Exp'], netmhcpan_hla['Score_EL'], 'NetMHCpan4.1_EL_score',
                     linestyle='--', alpha=0.5, color='black')
    _plot_allele_roc(ax, netmhcpan_hla['Exp'], -netmhcpan_hla['%Rank_EL'], 'NetMHCpan4.1_EL_rank',
                     linestyle='-', alpha=0.5, color='black')
    mixmhcpred22_hla = cd8_mixmhcpred22_df[cd8_mixmhcpred22_df['MHC'] == hla_allele]
    _plot_allele_roc(ax, mixmhcpred22_hla['is_binder'], -mixmhcpred22_hla['%Rank'], 'MixMHCpred2.2_Rank',
                     linestyle='--', alpha=0.5, color='purple')
    mixmhcpred3_hla = cd8_mixmhcpred3_df[cd8_mixmhcpred3_df['MHC'] == hla_allele]
    _plot_allele_roc(ax, mixmhcpred3_hla['is_binder'], -mixmhcpred3_hla['%Rank'], 'MixMHCpred3_Rank',
                     linestyle='-', alpha=0.5, color='purple')

    # alleles absent from the training set have no training counts
    if hla_allele in hla_allele_counts.index:
        ax.set_title(
            f'{hla_allele} (N={hla_allele_counts.loc[hla_allele, "train_total"]}, '
            f'r_pos={hla_allele_counts.loc[hla_allele, "train_ratio_pos_per_total"]:.2f})'
        )
    else:
        ax.set_title(f'{hla_allele} (not in BA training set)')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.plot([0, 1], [0, 1], lw=2, linestyle='dotted', color='gray', alpha=0.5)
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.2), frameon=False, facecolor='white')
    ax.set_aspect('equal', adjustable='box')
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])

# hide grid cells without an allele
for ax in list(axs.flat)[len(uniq_hla_alleles):]:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
from sklearn.metrics import classification_report

# One report per experiment (threshold 0.5), labelled with the experiment name, instead of a single
# report for whichever `results` the previous cell happened to leave behind.
for exp_name, exp_cfg in experiments_dict.items():
    results_file = os.path.join(os.path.dirname(exp_cfg['model_checkpoint']), '..', 'test_results.tsv')
    try:
        results = pd.read_csv(results_file, sep='\t')
    except FileNotFoundError:
        logger.error(f"File {results_file} not found")
        continue
    print(exp_name)
    print(classification_report(results['labels'], results['preds'] > 0.5))