In [1]:
import os
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
from plotly.subplots import make_subplots

In [2]:
# --- Experiment selection ---------------------------------------------------
# `exp_id` names the sub-folder of `experiments/` to analyse.
# `epoch` is the epoch marked on the loss plot and used to select rows for the
# classification metrics further down.
exp_id = 'exp10'
epoch = 10

# --- Paths --------------------------------------------------------------------
exp_path = os.path.join('experiments', exp_id)
metrics_path = os.path.join(exp_path, 'metrics')
img_path = os.path.join(exp_path, 'images')
# exist_ok=True creates the folder only when it is missing, without a separate
# isdir() check that could race with another process creating it.
os.makedirs(img_path, exist_ok=True)

# --- Load data ----------------------------------------------------------------
df_summ = pd.read_hdf(os.path.join(metrics_path, 'summary_data.hdf5'), 'summary')
metrics = pd.read_hdf(os.path.join(metrics_path, 'metrics.hdf5'), 'metrics')
print(df_summ.shape)
print(metrics.shape)

# Attach the `cluster` of each entry to its metrics rows, then order by epoch.
# The sort is chained (no inplace=True) so `df` is always rebuilt from the
# freshly loaded frames when this cell is re-run.
df = (
    pd.merge(metrics, df_summ[['cluster', 'entry']], on="entry", how="outer")
    .sort_values(by=['epoch'])
)
print(df.shape)
# With an outer merge, an entry that is absent from the summary table would
# get a missing cluster; True here means every row found its cluster.
print(df.cluster.notna().all())
df.head()

(50, 4)
(5599, 6)
(5599, 7)
True


Unnamed: 0,phase,epoch,entry,output,target,loss,cluster
0,training,0.0,residue-ppi-BA_55272.BL00080001:M-P,"[0.5181684494018555, 0.4818316102027893]",0.0,4.341669,4.0
2169,validation,0.0,residue-ppi-BA_55308.BL00010001:M-P,"[0.5181684494018555, 0.4818316102027893]",0.0,4.250819,4.0
2170,validation,0.0,residue-ppi-BA_55308.BL00010001:M-P,"[0.5181684494018555, 0.4818316102027893]",0.0,4.250819,4.0
2171,validation,0.0,residue-ppi-BA_55308.BL00010001:M-P,"[0.5181684494018555, 0.4818316102027893]",0.0,4.250819,4.0
2292,testing,0.0,residue-ppi-BA_55308.BL00010001:M-P,"[0.8814681768417358, 0.11853183805942535]",0.0,0.354167,4.0


## Loss vs epochs

In [3]:
# Loss curves: only the training and validation rows are plotted.
is_train_or_valid = df.phase.isin(['training', 'validation'])
fig = px.line(
    df[is_train_or_valid], x='epoch', y='loss', color='phase', markers=True)

fig.update_layout(
    title='Loss vs epochs',
    title_x=0.5,  # centre the title horizontally
    xaxis_title='Epoch #',
    yaxis_title='Loss',
    width=800,
    height=500,
    margin=go.layout.Margin(l=50, r=50, b=50, t=50, pad=4),
    # Legend anchored by its top-right corner near the right edge of the plot;
    # its vertical position is left at plotly's default.
    legend=dict(yanchor="top", xanchor="right", x=0.99),
)

# Dashed green vertical line at the epoch chosen in the configuration cell.
fig.add_vline(x=epoch, line_width=3, line_dash="dash", line_color="green")
fig.show()
fig.write_html(os.path.join(img_path, 'loss_epoch.html'))

## Binary classification metrics

### Target and score distributions

In [4]:
# Rows used by the classification plots: every row logged at the selected
# epoch, plus the testing rows logged at epoch 0.
# NOTE(review): presumably testing results are only ever stored with epoch 0,
# which is why they need their own condition - confirm against the trainer.
at_selected_epoch = df.epoch == epoch
testing_at_epoch_zero = (df.phase == 'testing') & (df.epoch == 0)
df_plot = df[at_selected_epoch | testing_at_epoch_zero]
print(df_plot.shape)

# True labels, and the second component of each `output` pair as the score.
y_true = df_plot.target
y_score = np.array(df_plot.output.values.tolist())[:, 1]

(629, 7)


In [5]:
# Score distribution per phase (one facet row each), coloured by true label.
fig = px.histogram(
    df_plot,
    x=y_score,
    color=y_true,
    facet_row='phase',
    nbins=20,
    category_orders={'phase': ['training', "validation", "testing"]},
    labels={'color': 'True Labels', 'x': 'Score'},
)

fig.update_layout(
    title='Target and scores',
    title_x=0.5,  # centre the title horizontally
    width=800,
    height=500,
    showlegend=True,
    margin=go.layout.Margin(l=50, r=50, b=50, t=50, pad=4),
    # Legend anchored by its top-left corner, slightly above the plot area.
    legend=dict(xanchor="left", x=0.87, yanchor="top", y=1.05),
)

# Scores are plotted on [0, 1]; every facet shares the same count range,
# capped at one third of the number of plotted rows.
fig.update_xaxes(range=[0, 1], constrain='domain')
fig.update_yaxes(range=[0, int(df_plot.shape[0] / 3)], constrain='domain')

fig.show()
fig.write_html(os.path.join(img_path, 'target_scores.html'))

### ROC and PR curves (AUC and AUCPR)

In [8]:
# Two panels side by side: ROC curves on the left, precision-recall curves on
# the right, with one curve per phase in each panel.
fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=['ROC Curves (AUC)', 'PR Curves (AUCPR)'],
    horizontal_spacing=0.05)
colors = ["darkcyan", "coral", "cornflowerblue"]

# The loop variable is called `phase` rather than `set`, which would shadow
# the builtin for every later cell. The per-phase labels/scores get their own
# names so the notebook-level `y_true` / `y_score` (read by the histogram
# cell) are not overwritten with testing-only data.
for idx, phase in enumerate(['training', 'validation', 'testing']):
    df_plot_phase = df_plot[df_plot.phase == phase]
    y_true_phase = df_plot_phase.target
    # Second component of each `output` pair is used as the score.
    y_score_phase = np.array(df_plot_phase.output.values.tolist())[:, 1]

    fpr_roc, tpr_roc, thr_roc = roc_curve(y_true_phase, y_score_phase)
    pr_pr, rec_pr, thr_pr = precision_recall_curve(y_true_phase, y_score_phase)

    # The metric values are shown as the legend entries of the curves.
    name_roc = f'AUC={auc(fpr_roc, tpr_roc):.4f}'
    name_pr = f'AUCPR={average_precision_score(y_true_phase, y_score_phase):.4f}'

    # Both traces of a phase share a legend group and colour; the group title
    # is set on the ROC trace only.
    fig.add_trace(go.Scatter(
        x=fpr_roc,
        y=tpr_roc,
        name=name_roc,
        mode='markers+lines',
        legendgroup=phase,
        legendgrouptitle_text=f"{phase}",
        marker_color=colors[idx]),
        row=1,
        col=1)
    fig.add_trace(go.Scatter(
        x=rec_pr,
        y=pr_pr,
        name=name_pr,
        mode='markers+lines',
        legendgroup=phase,
        marker_color=colors[idx]),
        row=1,
        col=2)

# Dashed reference diagonals: (0,0)-(1,1) on the ROC panel and (0,1)-(1,0) on
# the PR panel.
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=0, y1=1,
    row=1, col=1
)
fig.add_shape(
    type='line', line=dict(dash='dash'),
    x0=0, x1=1, y0=1, y1=0,
    row=1, col=2
)
fig.update_layout(
    width=900, height=400,
    margin=go.layout.Margin(
        l=50,
        r=50,
        b=50,
        t=50,
        pad=4
    ),
    legend=dict(
        yanchor="top",
        y=1.05,
        xanchor="left",
        x=0.95
        ))

# Keep both panels at a 1:1 aspect ratio. Axis ids are per subplot: the first
# panel uses "x"/"y", the second "x2"/"y2", so the PR panel's x axis must be
# anchored to "y2" (anchoring it to "y" would tie it to the ROC panel instead).
fig.update_xaxes(title_text="FPR", constrain='domain', scaleratio=1, row=1, col=1)
fig.update_yaxes(title_text="TPR (Recall)", constrain='domain', scaleanchor="x", scaleratio=1, row=1, col=1)
fig.update_xaxes(title_text="Recall", constrain='domain', scaleanchor="y2", scaleratio=1, row=1, col=2)
fig.update_yaxes(title_text="Precision", constrain='domain', scaleratio=1, row=1, col=2)

# Same order as the other plotting cells: display first, then export.
fig.show()
fig.write_html(os.path.join(img_path, 'auc_aucpr.html'))