In [1]:
import h5py
import glob
import os
import sys
from pathlib import Path
import torch
import pandas as pd
import numpy as np
import logging
from deeprank2.trainer import Trainer
from deeprank2.dataset import GraphDataset
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from sklearn.metrics import (
    roc_curve,
    precision_recall_curve,
    auc,
    average_precision_score,
    precision_score,
    recall_score,
    accuracy_score,
    f1_score,
    matthews_corrcoef)

In [3]:
#################### To fill
exp_id = 'exp_100k_std_transf_bs64_naivegnn1_wloss_0'
cluster_dataset =  None # fill in only if the experiment has clusters
project_folder = '/projects/0/einf2380'
protein_class = 'I'
exp_basepath = f'{project_folder}/data/pMHC{protein_class}/trained_models/deeprank2/experiments/'
validate = True
test = True
####################
exp_df = pd.read_excel(os.path.join(exp_basepath, '_experiments_log.xlsx'), index_col='exp_id')
exp_df.head()

Unnamed: 0_level_0,exp_fullname,exp_path,start_time,end_time,input_data_path,protein_class,target_data,resolution,task,node_features,...,training_accuracy,validation_accuracy,testing_accuracy,training_precision,validation_precision,testing_precision,training_recall,validation_recall,testing_recall,features
exp_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
exp_100k_std_transf_bs64_naivegnn1_wloss_all_data_0,exp_100k_std_transf_bs64_naivegnn1_wloss_all_d...,/projects/0/einf2380/data/pMHCI/trained_models...,06/Oct/2023_17:19:14,06/Oct/2023_22:28:17,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,,,,,,,,,,
exp_100k_std_transf_bs64_naivegnn1_wloss_cl_peptide_0,exp_100k_std_transf_bs64_naivegnn1_wloss_cl_pe...,/projects/0/einf2380/data/pMHCI/trained_models...,23/Aug/2023_12:24:43,23/Aug/2023_18:42:33,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.797,0.779,0.755,0.729,0.721,0.69,0.856,0.812,0.786,
exp_100k_std_transf_bs64_naivegnn1_wloss_cl_peptide2_10set_0,exp_100k_std_transf_bs64_naivegnn1_wloss_cl_pe...,/projects/0/einf2380/data/pMHCI/trained_models...,23/Aug/2023_12:22:00,23/Aug/2023_18:40:56,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.804,0.776,0.758,0.826,0.804,0.749,0.693,0.636,0.813,
exp_100k_std_transf_bs64_naivegnn1_wloss_wdecay_cl_peptide_0,exp_100k_std_transf_bs64_naivegnn1_wloss_wdeca...,/projects/0/einf2380/data/pMHCI/trained_models...,09/Aug/2023_10:27:46,09/Aug/2023_18:31:10,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.778,0.776,0.766,0.745,0.748,0.731,0.756,0.741,0.727,
exp_100k_std_transf_bs64_naivegnn1_wloss_wdecay_cl_peptide2_10set_0,exp_100k_std_transf_bs64_naivegnn1_wloss_wdeca...,/projects/0/einf2380/data/pMHCI/trained_models...,08/Aug/2023_12:22:21,09/Aug/2023_00:48:04,['/projects/0/einf2380/data/pMHCI/features_out...,I,BA,residue,classif,all,...,0.777,0.77,0.749,0.764,0.781,0.761,0.7,0.651,0.763,


In [4]:
exp_fullname = exp_df.loc[exp_id].exp_fullname
exp_path = os.path.join(exp_basepath, exp_fullname)
output_path = os.path.join(exp_path, 'output')
img_path = os.path.join(exp_path, 'images')

df_summ = pd.read_hdf(os.path.join(output_path, 'summary_data.hdf5'), key='summary')
output_train = pd.read_hdf(os.path.join(output_path, 'output_exporter.hdf5'), key='training')
if test:
    output_test = pd.read_hdf(os.path.join(output_path, 'output_exporter.hdf5'), key='testing')
    df = pd.concat([output_train, output_test])
else:
    df = output_train.copy()
df.sort_values(by=['epoch'], inplace = True)
print(df.shape)

(4332983, 6)


In [5]:
if cluster_dataset:
    df = df.merge(df_summ[['entry', 'cluster']], how='left', on='entry')
    print(df.shape)
    df.head()

## Loss vs epochs

In [6]:
# loss vs epoch skipping epoch 0
fig = px.line(
    df[((df.phase =='training') | (df.phase =='validation')) & (df.epoch > 0)],
    x='epoch',
    y='loss',
    color='phase',
    markers=True)

fig.update_layout(
    xaxis_title='Epoch #',
    yaxis_title='Loss',
    width=800, height=500,
    title='Loss vs epochs',
    title_x=0.5,
    margin=go.layout.Margin(
        l=50,
        r=50,
        b=50,
        t=50,
        pad=4),
    legend=dict(
        yanchor="top",
        #y=0.99
        xanchor="right",
        x=0.99
        )
)
epoch = exp_df.loc[exp_id].saved_epoch
fig.add_vline(x=epoch, line_width=3, line_dash="dash", line_color="green")
fig.write_html(os.path.join(img_path, 'loss_epoch_1.html'))

## Binary classification metrics (for best/saved epoch)

In [7]:
df_plot = df[(df.epoch == epoch) | ((df.epoch == exp_df.loc[exp_id].saved_epoch) & (df.phase == 'testing'))]
print(df_plot.shape)
y_true = df_plot.target
y_score = np.array(df_plot.output.values.tolist())[:, 1]

(100069, 6)


### Only target distributions per target and cluster (data exploration)

In [8]:
if cluster_dataset: 
    clusters = df_plot['cluster'].unique()
    clusters.sort()
    clusters = list(clusters)
    df_plot['target_str'] = df_plot['target'].astype(str)
    # The histogram of scores compared to true labels
    fig = px.histogram(
        df_plot,
        x='target_str',
        color=df_plot['cluster'],
        facet_row='phase',
        category_orders={'phase': [
            'training',
            "validation",
            "testing"],
            'cluster': clusters}
    )
    fig.update_layout(
        width=600, height=600,
        showlegend=True,
        title='Target',
        title_x=0.5,
        margin=go.layout.Margin(
            l=50,
            r=50,
            b=50,
            t=50,
            pad=4
        )
    )
    # fig.update_yaxes(range=[0, 2700], constrain='domain')
    fig.update_layout(bargap=0.30,bargroupgap=0.0)
    fig.write_html(os.path.join(img_path, 'target_only.html'))

### Target and score distributions per target and cluster

In [9]:
if cluster_dataset:
    clusters = df_plot['cluster'].unique()
    clusters.sort()
    clusters = list(clusters)
    # The histogram of scores compared to true labels
    fig = px.histogram(
        df_plot,
        x=y_score,
        color=df_plot['cluster'],
        nbins=20,
        facet_col='target',
        facet_row='phase',
        labels=dict(color='True Labels', x='Score'),
        category_orders={'phase': [
            'training',
            "validation",
            "testing"],
            'cluster': clusters}
    )
    fig.update_layout(
        width=900, height=600,
        showlegend=True,
        title='Target and scores',
        title_x=0.5,
        margin=go.layout.Margin(
            l=50,
            r=50,
            b=50,
            t=50,
            pad=4
        ),
        legend=dict(
            yanchor="top",
            y = 1.05,
            xanchor="left",
            x=0.87
            )
    )
    fig.update_xaxes(range=[0, 1], constrain='domain')
    fig.update_yaxes(range=[0, 7000], constrain='domain')
    fig.write_html(os.path.join(img_path, 'target_scores.html'))

### AUC and PR curves

In [10]:
if validate and test:
    phases = ['training', 'validation', 'testing']
elif validate and not test:
    phases = ['training', 'validation']
else:
    phases = ['training']

In [11]:
fig = make_subplots(rows=1, cols=2, subplot_titles = ['ROC Curves (AUC)', 'PR Curves (AUCPR)'], horizontal_spacing = 0.05)
colors = ["darkcyan", "coral", "cornflowerblue"]

for idx, set in enumerate(phases):
    df_plot_phase = df_plot[(df_plot.phase == set)]
    y_true = df_plot_phase.target
    y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]

    fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)
    pr_pr, rec_pr, thr_pr = precision_recall_curve(y_true, y_score)

    name_roc = f'AUC={auc(fpr_roc, tpr_roc):.4f}'
    name_pr = f'AUCPR={average_precision_score(y_true, y_score):.4f}'

    fig.add_trace(go.Scatter(
        x=fpr_roc,
        y=tpr_roc,
        name=name_roc,
        mode='markers+lines',
        legendgroup=set,
        legendgrouptitle_text=f"{set}",
        marker_color = colors[idx]),
        row=1,
        col=1)
    fig.add_trace(go.Scatter(
        x=rec_pr,
        y=pr_pr,
        name=name_pr,
        mode='markers+lines',
        legendgroup=set,
        marker_color = colors[idx]),
        row=1,
        col=2)

fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1,
    row = 1, col = 1
)
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=1, y1=0,
    row = 1, col = 2 
)
fig.update_layout(
    width=900, height=400,
    margin=go.layout.Margin(
        l=50,
        r=50,
        b=50,
        t=50,
        pad=4
    ),
    legend=dict(
        yanchor="top",
        y=1.05,
        xanchor="left",
        x=0.95
        ))
fig.update_xaxes(title_text="FPR", constrain='domain', scaleratio = 1, row=1, col=1)
fig.update_yaxes(title_text="TPR (Recall)", constrain='domain', scaleanchor = "x", scaleratio = 1, row=1, col=1)
fig.update_xaxes(title_text="Recall", constrain='domain', scaleanchor = "y", scaleratio = 1, row=1, col=2)
fig.update_yaxes(title_text="Precision", constrain='domain', scaleratio = 1, row=1, col=2)
fig.write_html(os.path.join(img_path, 'auc_aucpr.html'))

### Metrics vs thresholds curves

In [12]:
d = {'thr': [], 'precision': [], 'recall': [], 'accuracy': [], 'f1': [], 'mcc': [], 'auc': [], 'aucpr': [], 'phase': []}
thr_df = pd.DataFrame(data=d)

for idx, set in enumerate(phases):
    df_plot_phase = df_plot[(df_plot.phase == set)]
    y_true = df_plot_phase.target
    y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]

    thrs = np.linspace(0,1,100)
    precision = []
    recall = []
    accuracy = []
    f1 = []
    mcc = []
    
    for thr in thrs:
        y_pred = (y_score > thr)*1
        precision.append(precision_score(y_true, y_pred))
        recall.append(recall_score(y_true, y_pred))
        accuracy.append(accuracy_score(y_true, y_pred))
        f1.append(f1_score(y_true, y_pred))
        mcc.append(matthews_corrcoef(y_true, y_pred))
    
    fpr_roc, tpr_roc, thr_roc = roc_curve(y_true, y_score)
    auc_score = auc(fpr_roc, tpr_roc)
    aucpr = average_precision_score(y_true, y_score)

    phase_df = pd.DataFrame({'thr': thrs, 'precision': precision, 'recall': recall, 'accuracy': accuracy, 'f1': f1, 'mcc': mcc, 'auc': auc_score, 'aucpr': aucpr, 'phase': set})
    thr_df = pd.concat([thr_df, phase_df], ignore_index=True)

# find max mcc of validation set
if validate: 
    valid_df = thr_df.loc[thr_df.phase == 'validation']
    valid_mcc_idxmax = valid_df.mcc.idxmax()
    if thr_df.loc[valid_mcc_idxmax].mcc > 0:
        sel_thr = thr_df.loc[valid_mcc_idxmax].thr
    # use max mcc of all data if max of test set is 0 (usually only on small local test experiments)
    else:
        mcc_idxmax = thr_df.mcc.idxmax()
        sel_thr = thr_df.loc[mcc_idxmax].thr
        print("WARNING: Maximum mcc of test set is 0. Instead, maximum mcc of all data will be used for determining optimal threshold.\n")
else:
    mcc_idxmax = thr_df.mcc.idxmax()
    sel_thr = thr_df.loc[mcc_idxmax].thr

fig_thresh = px.line(
    thr_df,
    x='thr',
    y=[
        'precision',
        'recall',
        'accuracy',
        'f1',
        'mcc'
    ],
    facet_col='phase',
    category_orders={'phase': phases},
    width=1000,
    height=500
)
fig_thresh.add_vline(x=sel_thr, line_width=3, line_dash="dash", line_color="green")
fig_thresh.update_layout(
    title='Metrics vs thresholds',
    title_x=0.5)
fig_thresh.update_yaxes(range=[-0.2, 1.2], scaleanchor="x", scaleratio=1, constrain='domain')
fig_thresh.update_xaxes(range=[0, 1], scaleratio = 1, constrain='domain')
fig_thresh.write_html(os.path.join(img_path, 'thresholds_metrics.html'))



Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.


Precision is ill-defined and being set to 0.0 due to no predicted samples. Use `zero_division` parameter to control this behavior.



### 0.5 Threshold

In [13]:
print('Setting thr to 0.5 ...\n')
for idx, set in enumerate(phases):
    print(f'{set} set')
    df_plot_phase = df_plot[(df_plot.phase == set)]
    y_true = df_plot_phase.target
    y_score = np.array(df_plot_phase.output.values.tolist())[:, 1]
    thr = 0.5
    y_pred = (y_score > thr)*1
    print(f'F1: {f1_score(y_true, y_pred)}')
    print(f'MCC: {matthews_corrcoef(y_true, y_pred)}\n')

Setting thr to 0.5 ...

training set
F1: 0.7940689927304184
MCC: 0.622454519842347

validation set
F1: 0.7575757575757576
MCC: 0.5593637639536596

testing set
F1: 0.7563417890520693
MCC: 0.5582110529142474



In [14]:
sel_thr

0.5151515151515152