In [7]:
import pandas as pd
from datetime import datetime
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import os
from itertools import cycle
import numpy as np
from sklearn.metrics import (
    roc_curve,
    precision_recall_curve,
    auc,
    average_precision_score,
    accuracy_score,
    f1_score,
    matthews_corrcoef)
from scipy import interpolate

In [8]:
baseline_results_path = '/projects/0/einf2380/data/results/best_models_metrics.csv'
baseline_df = pd.read_csv(baseline_results_path)
models = {
    'SHUFF_CNN': 'CNN_shuffle_class',
    'PEPT_CNN': 'CNN_peptide_class',
    'ALLELE_CNN': 'CNN_allele_class',
    'SHUFF_Group_Reg4': 'CNN_shuffle_reg',
    'PEPT_Group_Reg4': 'CNN_peptide_reg',
    'ALLELE_Group_Reg4': 'CNN_allele_reg',
    'mlp_classification_blosum_with_allele_encoder_500_neurons_50_epochs_shuffled_64_batch_size.pt': 'MLP_shuffle_class',
    'mlp_classification_blosum_with_allele_encoder_500_neurons_50_epochs_LOGO_anch_rep_64_batch_size.pt': 'MLP_peptide_class',
    'mlp_classification_blosum_with_allele_encoder_500_neurons_50_epochs_pseudoseq_cluster_64_batch_size.pt': 'MLP_allele_class',
    'mhcflurry_held_out_trained': 'MHCFlurry_shuffle',
    'mhcflurry_peptide_clustered_trained': 'MHCFlurry_peptide',
    'mhcflurry_allele_clustered_trained': 'MHCFlurry_allele',
}
baseline_df.rename(columns={"Model":"model"}, inplace=True)
baseline_df.replace(models, inplace=True)
baseline_df.head()

Unnamed: 0,model,acc,auc,f1,mcc,tnr,tpr
0,SHUFF_Cnn_SumFeat_ChannExpand,0.787383,0.855928,0.751575,0.566675,0.831581,0.73111
1,PEPT_Cnn_SumFeat_ChannExpand,0.717805,0.788124,0.70636,0.43834,0.695017,0.745038
2,ALLELE_Cnn_SumFeat_ChannExpand,0.649774,0.697055,0.431115,0.245109,0.878164,0.322732
3,SHUFF_Group_Class4,0.777401,0.848964,0.740999,0.546493,0.819462,0.723848
4,PEPT_Group_Class4,0.710789,0.750552,0.692116,0.42058,0.708477,0.713552


In [9]:
baseline_df

Unnamed: 0,model,acc,auc,f1,mcc,tnr,tpr
0,SHUFF_Cnn_SumFeat_ChannExpand,0.787383,0.855928,0.751575,0.566675,0.831581,0.73111
1,PEPT_Cnn_SumFeat_ChannExpand,0.717805,0.788124,0.70636,0.43834,0.695017,0.745038
2,ALLELE_Cnn_SumFeat_ChannExpand,0.649774,0.697055,0.431115,0.245109,0.878164,0.322732
3,SHUFF_Group_Class4,0.777401,0.848964,0.740999,0.546493,0.819462,0.723848
4,PEPT_Group_Class4,0.710789,0.750552,0.692116,0.42058,0.708477,0.713552
5,ALLELE_Group_Class4,0.646601,0.674479,0.49623,0.244818,0.802548,0.423293
6,CNN_shuffle_class,0.780295,0.871268,0.712325,0.557802,0.907503,0.618334
7,CNN_peptide_class,0.713439,0.807924,0.624438,0.42716,0.872852,0.52293
8,CNN_allele_class,0.639004,0.725724,0.332978,0.221437,0.932223,0.21913
9,MLP_shuffle_class,0.818141,0.892664,0.793863,0.631333,0.79577,0.79577


In [10]:
######## Modify here
project_folder = '/projects/0/einf2380'
protein_class = 'I'
exp_path = f'{project_folder}/data/pMHC{protein_class}/trained_models/deeprankcore/experiments/cyulin/'
# exps we want to compare with sequence-based model baseline and best CNN
exp_ids = [
    'exp_100k_final_feattrans_Increase1_seed44_rmpssm_0',
    'exp_100k_final_feattrans_Increase1_cl_peptide2_seed55_rmpssm_0',
    'exp_100k_final_feattrans_Increase1_cl_allele_seed55_rmpssm_0']
new_exp_ids = [
    'GNN_shuffle_class',
    'GNN_peptide_class',
    'GNN_allele_class']
comparison_id = 'cluster_exp'
########
exp_log = pd.read_excel(exp_path + 'cyulin_experiments_log.xlsx', index_col='exp_id')
exp_log.head()

Unnamed: 0_level_0,exp_fullname,exp_path,start_time,end_time,input_data_path,protein_class,target_data,resolution,task,node_features,...,testing_f1,training_accuracy,validation_accuracy,testing_accuracy,training_precision,validation_precision,testing_precision,training_recall,validation_recall,testing_recall
exp_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
exp_100k_final_feattrans_Increase1_cl_allele_seed55_rmpssm_0,exp_100k_final_feattrans_Increase1_cl_allele_s...,/projects/0/einf2380/data/pMHCI/trained_models...,04/Jun/2023_00:48:35,04/Jun/2023_05:37:54,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.579,0.696,0.685,0.629,0.598,0.589,0.543,0.96,0.957,0.62
exp_100k_final_feattrans_Increase1_seed55_rmpssm_0,exp_100k_final_feattrans_Increase1_seed55_rmps...,/projects/0/einf2380/data/pMHCI/trained_models...,04/Jun/2023_00:34:14,04/Jun/2023_05:27:42,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.757,0.792,0.765,0.771,0.725,0.69,0.711,0.849,0.847,0.809
exp_100k_final_feattrans_Increase1_cl_peptide2_seed55_rmpssm_0,exp_100k_final_feattrans_Increase1_cl_peptide2...,/projects/0/einf2380/data/pMHCI/trained_models...,04/Jun/2023_00:37:48,04/Jun/2023_05:23:14,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.692,0.798,0.769,0.773,0.746,0.719,0.65,0.85,0.818,0.741
exp_100k_final_feattrans_Increase1_seed44_rmpssm_0,exp_100k_final_feattrans_Increase1_seed44_rmps...,/projects/0/einf2380/data/pMHCI/trained_models...,04/Jun/2023_00:28:18,04/Jun/2023_05:22:43,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.763,0.805,0.782,0.783,0.774,0.757,0.735,0.787,0.743,0.795
exp_100k_final_Increase2_seed55_rmpssm_0,exp_100k_final_Increase2_seed55_rmpssm_0_230603,/projects/0/einf2380/data/pMHCI/trained_models...,03/Jun/2023_13:34:06,03/Jun/2023_20:29:05,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.756,0.778,0.757,0.771,0.695,0.677,0.711,0.885,0.856,0.808


In [11]:
######## Definitions used in the plotting
comparisons_path = os.path.join(exp_path, 'comparisons', 'baseline')
comparison_path = os.path.join(comparisons_path, comparison_id)

if not os.path.exists(comparisons_path):
    os.makedirs(comparisons_path)

if not os.path.exists(comparison_path):
    os.makedirs(comparison_path)
else:
    print(f'Folder comparisons/{comparison_id}/ already exists! \
          \nChange comparison_id if you want to save plots for a different comparison.')

def get_single_exp_df(exp_id, exp_log, exp_path):
    exp_fullname = exp_log.loc[exp_id].exp_fullname
    exp_path = os.path.join(exp_path, exp_fullname)
    output_path = os.path.join(exp_path, 'output')
    output_train = pd.read_hdf(os.path.join(output_path, 'output_exporter.hdf5'), key='training')
    output_test = pd.read_hdf(os.path.join(output_path, 'output_exporter.hdf5'), key='testing')
    df = pd.concat([output_train, output_test])
    df.sort_values(by=['epoch'], inplace = True)
    return df

In [12]:
for idx, exp_id in enumerate(exp_ids):
    exp_dict = {}
    exp_dict["model"] = new_exp_ids[idx]
    df = get_single_exp_df(exp_id, exp_log, exp_path)
    df_plot = df[(df.epoch == 0) & (df.phase == 'testing')]

    y_true = df_plot.target
    y_score = np.array(df_plot.output.values.tolist())[:, 1]

    thrs = np.linspace(0,1,100)
    accuracy = []
    f1 = []
    mcc = []
    for thr in thrs:
        y_pred = (y_score > thr)*1
        accuracy.append(accuracy_score(y_true, y_pred))
        f1.append(f1_score(y_true, y_pred))
        mcc.append(matthews_corrcoef(y_true, y_pred))

    thr_df = pd.DataFrame({
        'thr': thrs,
        'accuracy': accuracy,
        'f1': f1,
        'mcc': mcc})
    # maximize mcc
    mcc_idxmax = thr_df.mcc.idxmax()
    sel_thr = thr_df.loc[mcc_idxmax].thr
    exp_dict["model"] = exp_dict["model"] + f"_{sel_thr:.2f}"
    exp_dict["mcc"] = thr_df.loc[mcc_idxmax].mcc
    exp_dict["acc"] = thr_df.loc[mcc_idxmax].accuracy
    exp_dict["f1"] = thr_df.loc[mcc_idxmax].f1

    fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)
    exp_dict["auc"] = auc(fpr_roc, tpr_roc)
    tpr_intrp = interpolate.interp1d(thr_roc, tpr_roc)
    exp_dict["tpr"] = float(tpr_intrp(sel_thr)) # recall

    baseline_df = baseline_df.append(exp_dict, ignore_index=True)

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

In [None]:
baseline_df

Unnamed: 0,model,acc,auc,f1,mcc,tnr,tpr
0,CNN_shuffle_class,0.780295,0.871268,0.712325,0.557802,0.907503,0.618334
1,CNN_peptide_class,0.713439,0.807924,0.624438,0.42716,0.872852,0.52293
2,CNN_allele_class,0.639004,0.725724,0.332978,0.221437,0.932223,0.21913
3,MLP_shuffle_class,0.818141,0.892664,0.793863,0.631333,0.79577,0.79577
4,MLP_peptide_class,0.784439,0.855406,0.747912,0.553975,0.745867,0.745867
5,MLP_allele_class,0.567818,0.459883,0.255109,0.042437,0.195628,0.195628
6,MHCFlurry_shuffle,0.736301,0.735547,0.723246,0.471687,0.713289,0.757804
7,MHCFlurry_peptide,0.690395,0.688556,0.663286,0.376789,0.66712,0.709991
8,MHCFlurry_allele,0.643538,0.606256,0.477279,0.235476,0.395371,0.817141
9,GNN_shuffle_class_0.58,0.756039,0.829692,0.708284,0.501701,,0.672789


In [None]:
######## Compare
cl_type = 'allele'
metrics = ['auc', 'f1', 'mcc', 'tpr', 'acc']
########
models = [
    f'CNN_{cl_type}_class', 
    # f'CNN_{cl_type}_reg',
    f'MLP_{cl_type}_class',
    f'MHCFlurry_{cl_type}',
    f'GNN_{cl_type}_class_0.43']
fig = go.Figure()

for model in models:
    idx = baseline_df.index[baseline_df['model'] == model].tolist()[0]
    fig.add_trace(go.Bar(
        x = metrics,
        y = [baseline_df.loc[idx][metric] for metric in metrics],
        name = ''.join(model.split(f"_{cl_type}")),
        legendgroup = model
    ))

fig.update_yaxes(title_text="Value")
fig.update_layout(
    barmode='group',
    title=f'Experiments type: {cl_type}',
    title_x=0.5,
    width=1100, height=600)
fig.write_html(os.path.join(comparison_path, f'{cl_type}.html'))

In [None]:
#### Poster
cnn = [0.871268, 0.807924, 0.725724]
gnn = [0.8585, 0.8404, 0.6673]
mhcflurry = [0.735547, 0.688556, 0.606256]
mlp = [0.892664, 0.855406, 0.459883]
x_axis = ['Shuffled', 'Peptide-clustered', 'Allele-clustered']

models = {
    '3D-CNN': [0.871268, 0.807924, 0.725724],
    '3D-GNN': [0.8585, 0.8404, 0.6673],
    'Re-trained MHCFlurry2.0': [0.735547, 0.688556, 0.606256],
    'Seq-based NN': [0.892664, 0.855406, 0.459883]}
fig = go.Figure()

for key, value in models.items():
    fig.add_trace(go.Bar(
        x = x_axis,
        y = value,
        name = key,
        legendgroup = key,
        # text = key
    ))

fig.update_yaxes(title_text="AUC", tickfont_size=15)
fig.update_xaxes(title_text="Dataset", tickfont_size=15)
fig.update_layout(
    barmode='group',
    title_x=0.5,
    width=900, height=500,
    showlegend = True,
    font=dict(
        size=16,
        color="#421A48"
    )
    )
# fig.write_html(os.path.join(comparison_path, f'{cl_type}.html'))
fig.write_image("plot1.svg")
fig.show()