In [None]:
import pandas as pd
import numpy as np
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import plotly.express as px
import plotly.graph_objects as go

import sys
sys.path.append('../')
import plotting

In [None]:
# Quick look at the internal-validation baseline results.
# Loaded into baseline_* names so this inspection cell does not shadow
# prediction_per_data / label_per_data, which the next cell uses for the model results.
baseline_prediction_per_data, baseline_label_per_data = pd.read_pickle('../data/machine_learning_results/GCdata_internal_validation_baseline.pkl')

# number of folds stored per dataset (avoids dumping every prediction array)
{name: len(folds) for name, folds in baseline_prediction_per_data.items()}

# Internal validation on GC data

In [None]:
# load prediction data (model and baseline; per dataset, predictions/labels are indexed per fold)
prediction_per_data, label_per_data = pd.read_pickle('../data/machine_learning_results/GCdata_internal_validation.pkl')
baseline_prediction_per_data, baseline_label_per_data = pd.read_pickle('../data/machine_learning_results/GCdata_internal_validation_baseline.pkl')
for key in ['GCall', 'GCfix']:
    prediction_per_data[f'{key}_baseline'] = baseline_prediction_per_data[key]
    label_per_data[f'{key}_baseline'] = baseline_label_per_data[key]

data_names = ['GCall', 'GCfix', 'GCall_baseline', 'GCfix_baseline']
linecolors = {'GCall': '#de2d26', 'GCfix': '#3182bd', 'GCall_baseline': '#636363', 'GCfix_baseline': '#636363'}
linedashes = {'GCall': 'solid', 'GCfix': 'solid', 'GCall_baseline': 'solid', 'GCfix_baseline': 'dash'}
areacolors = {'GCall': 'rgba(222,45,38,0.2)', 'GCfix': 'rgba(49,130,189,0.2)', 'GCall_baseline': 'rgba(100,100,100,0.2)', 'GCfix_baseline': 'rgba(100,100,100,0.2)'}
# label positions (x, y) in data coordinates
# NOTE(review): both baselines share [0.8, 0.2], so their labels are drawn on top of each other — confirm intended.
positions_roc = {'GCall': [0.2, 0.5], 'GCfix': [0.6, 0.8], 'GCall_baseline': [0.8, 0.2], 'GCfix_baseline': [0.8, 0.2]}
positions_pr = {'GCall': [0.125, 0.45], 'GCfix': [0.6, 0.8], 'GCall_baseline': [0.8, 0.2], 'GCfix_baseline': [0.8, 0.2]}

# common x grid every fold's curve is interpolated onto:
# used as the false positive rate for ROC and as the recall for PR
interp_grid = np.linspace(0, 1, 100)


def _mean_std_band_traces(x, mean, std, line_color, line_dash, fill_color):
    """Build the Scatter traces for a mean curve with a shaded mean ± std band.

    Returns [lower edge, upper edge filled down to the lower edge, mean line].
    The band is clipped to [0, 1], the valid range of TPR and precision.
    """
    invisible = dict(color='rgba(0,0,0,0)')
    return [
        go.Scatter(x=x, y=np.clip(mean - std, 0, 1), line=invisible),
        go.Scatter(x=x, y=np.clip(mean + std, 0, 1), line=invisible, fill='tonexty', fillcolor=fill_color),
        go.Scatter(x=x, y=mean, line=dict(color=line_color, dash=line_dash)),
    ]


# create empty figures for ROC and PR curves
fig_roc = go.Figure()
fig_pr = go.Figure()

# empty dicts to store the metrics
data_roc = {}
data_pr = {}

# go through all datasets
for filename in data_names:
    prediction_per_fold = prediction_per_data[filename]
    label_per_fold = label_per_data[filename]

    tprs = []
    aucs = []
    precisions = []
    average_precisions = []

    # calculate metrics per fold and interpolate each fold's curve onto interp_grid
    for i in range(len(prediction_per_fold)):
        fold_predictions = np.array(prediction_per_fold[i])
        fold_labels = np.array(label_per_fold[i])

        fpr, tpr, _ = roc_curve(fold_labels, fold_predictions)
        aucs.append(auc(fpr, tpr))
        fold_tpr = np.interp(interp_grid, fpr, tpr)
        fold_tpr[0] = 0.0  # anchor every fold's ROC curve at the origin
        tprs.append(fold_tpr)

        # precision_recall_curve returns recall in decreasing order; np.interp needs increasing x
        precision, recall, _ = precision_recall_curve(fold_labels, fold_predictions)
        precisions.append(np.interp(interp_grid, recall[::-1], precision[::-1]))
        average_precisions.append(average_precision_score(fold_labels, fold_predictions))

    # mean and standard deviation across folds
    mean_tpr = np.mean(tprs, axis=0)
    std_tpr = np.std(tprs, axis=0)
    mean_auc = np.mean(aucs)
    mean_precision = np.mean(precisions, axis=0)
    std_precision = np.std(precisions, axis=0)
    mean_average_precision = np.mean(average_precisions)

    # save metrics
    data_roc[filename] = {'fpr': interp_grid, 'mean_tpr': mean_tpr, 'std_tpr': std_tpr}
    data_pr[filename] = {'recall': interp_grid, 'mean_precision': mean_precision, 'std_precision': std_precision}

    # ROC curve with uncertainty band
    fig_roc.add_traces(_mean_std_band_traces(
        interp_grid, mean_tpr, std_tpr,
        linecolors[filename], linedashes[filename], areacolors[filename],
    ))
    fig_roc.add_annotation(
        x=positions_roc[filename][0],
        y=positions_roc[filename][1],
        text=f'{filename}<br>({mean_auc:.2f})',
        showarrow=False,
        font_color=linecolors[filename]
    )

    # Precision-recall curve with uncertainty band
    fig_pr.add_traces(_mean_std_band_traces(
        interp_grid, mean_precision, std_precision,
        linecolors[filename], linedashes[filename], areacolors[filename],
    ))
    fig_pr.add_annotation(
        x=positions_pr[filename][0],
        y=positions_pr[filename][1],
        text=f'{filename}<br>({mean_average_precision:.2f})',
        showarrow=False,
        font_color=linecolors[filename]
    )

# add random classifier (diagonal drawn beyond the axis limits so it spans the visible range)
fig_roc.add_trace(
    go.Scatter(
        x=[-2, 2],
        y=[-2, 2],
        line=dict(color='gray', width=1, dash='dot'),
    )
)
fig_roc.add_annotation(
    x=0.55,
    y=0.4,
    text='Random classifier<br>(0.50)',
    showarrow=False,
    font_color="gray",
    yanchor="middle",
    xanchor="center",
    textangle=-45
)
fig_roc.update_layout(
    xaxis_title='False positive rate',
    yaxis_title='True positive rate',
    showlegend=False,
    margin=dict(l=0, r=5, t=5, b=0),
    width=160,
    height=160
)
fig_roc.update_yaxes(range=[0, 1.01])
fig_roc.update_xaxes(range=[0, 1])
fig_roc = plotting.standardize_plot(fig_roc)
fig_roc.write_image("./figure_3_performance/roc_curve_internal.svg")
fig_roc.show()

# save data
for name, dat in data_roc.items():
    pd.DataFrame(dat).to_csv(f'./figure_3_performance/roc_curve_internal_data_{name}.csv', index=False)

# A dashed class-prevalence reference line (precision 0.02) can be added with
# fig_pr.add_hline(0.02, ...); it is intentionally not drawn.
fig_pr.update_layout(
    xaxis_title='Recall',
    yaxis_title='Precision',
    showlegend=False,
    margin=dict(l=0, r=5, t=5, b=0),
    width=160,
    height=160
)
fig_pr.update_yaxes(range=[0, 1.01])
fig_pr.update_xaxes(range=[0, 1])
fig_pr = plotting.standardize_plot(fig_pr)
fig_pr.write_image("./figure_3_performance/pr_curve_internal.svg")
fig_pr.show()

# save data
for name, dat in data_pr.items():
    pd.DataFrame(dat).to_csv(f'./figure_3_performance/pr_curve_internal_data_{name}.csv', index=False)

# External validation on GC data

In [None]:
# load prediction data (model and baseline; one flat prediction/label array per train -> test pair)
prediction_per_data, label_per_data = pd.read_pickle('../data/machine_learning_results/GCdata_external_validation.pkl')
baseline_prediction_per_data, baseline_label_per_data = pd.read_pickle('../data/machine_learning_results/GCdata_external_validation_baseline.pkl')
for key in ['GCall -> GCfix', 'GCfix -> GCall']:
    prediction_per_data[f'{key}_baseline'] = baseline_prediction_per_data[key]
    label_per_data[f'{key}_baseline'] = baseline_label_per_data[key]

pair_names = ['GCall -> GCfix', 'GCfix -> GCall', 'GCall -> GCfix_baseline', 'GCfix -> GCall_baseline']
linecolors = {'GCall -> GCfix': '#de2d26', 'GCfix -> GCall': '#3182bd', 'GCall -> GCfix_baseline': '#636363', 'GCfix -> GCall_baseline': '#636363'}
linedashes = {'GCall -> GCfix': 'solid', 'GCfix -> GCall': 'solid', 'GCall -> GCfix_baseline': 'solid', 'GCfix -> GCall_baseline': 'dash'}
areacolors = {'GCall -> GCfix': 'rgba(222,45,38,0.2)', 'GCfix -> GCall': 'rgba(49,130,189,0.2)', 'GCall -> GCfix_baseline': 'rgba(100,100,100,0.2)', 'GCfix -> GCall_baseline': 'rgba(100,100,100,0.2)'}
# label positions (x, y) in data coordinates
# NOTE(review): both baselines share [0.8, 0.2], so their labels are drawn on top of each other — confirm intended.
positions_roc = {'GCall -> GCfix': [0.375, 0.65], 'GCfix -> GCall': [0.35, 0.925], 'GCall -> GCfix_baseline': [0.8, 0.2], 'GCfix -> GCall_baseline': [0.8, 0.2]}
positions_pr = {'GCall -> GCfix': [0.5, 0.8], 'GCfix -> GCall': [0.7, 0.4], 'GCall -> GCfix_baseline': [0.8, 0.2], 'GCfix -> GCall_baseline': [0.8, 0.2]}

# create empty figures for ROC and PR curves
fig_roc = go.Figure()
fig_pr = go.Figure()

# empty dicts to store the metrics
data_roc = {}
data_pr = {}

# go through all train -> test pairs
for pair_name in pair_names:
    flat_predictions = prediction_per_data[pair_name]
    flat_labels = label_per_data[pair_name]

    # compute metrics
    fpr, tpr, _ = roc_curve(flat_labels, flat_predictions)
    roc_auc = auc(fpr, tpr)
    precision, recall, _ = precision_recall_curve(flat_labels, flat_predictions)
    average_precision = average_precision_score(flat_labels, flat_predictions)

    # save metrics
    data_roc[pair_name] = {'fpr': fpr, 'tpr': tpr}
    data_pr[pair_name] = {'precision': precision, 'recall': recall}

    # ROC curve
    fig_roc.add_trace(
        go.Scatter(
            x=fpr,
            y=tpr,
            line=dict(color=linecolors[pair_name], dash=linedashes[pair_name]),
        )
    )
    fig_roc.add_annotation(
        x=positions_roc[pair_name][0],
        y=positions_roc[pair_name][1],
        text=f'{pair_name}<br>({roc_auc:.2f})',
        showarrow=False,
        font_color=linecolors[pair_name]
    )

    # Precision-recall curve
    fig_pr.add_trace(
        go.Scatter(
            x=recall,
            y=precision,
            line=dict(color=linecolors[pair_name], dash=linedashes[pair_name]),
        )
    )
    fig_pr.add_annotation(
        x=positions_pr[pair_name][0],
        y=positions_pr[pair_name][1],
        text=f'{pair_name}<br>({average_precision:.2f})',
        showarrow=False,
        font_color=linecolors[pair_name]
    )

# add random classifier (diagonal drawn beyond the axis limits so it spans the visible range)
fig_roc.add_trace(
    go.Scatter(
        x=[-2, 2],
        y=[-2, 2],
        line=dict(color='gray', width=1, dash='dot'),
    )
)
fig_roc.add_annotation(
    x=0.55,
    y=0.4,
    text='Random classifier<br>(0.50)',
    showarrow=False,
    font_color="gray",
    yanchor="middle",
    xanchor="center",
    textangle=-45
)
fig_roc.update_layout(
    xaxis_title='False positive rate',
    yaxis_title='True positive rate',
    showlegend=False,
    margin=dict(l=0, r=5, t=5, b=0),
    width=160,
    height=160
)
fig_roc.update_yaxes(range=[0, 1.01])
fig_roc.update_xaxes(range=[0, 1])
fig_roc = plotting.standardize_plot(fig_roc)
fig_roc.write_image("./figure_3_performance/roc_curve_external.svg")
fig_roc.show()

# save data (" -> " in pair names becomes "_" in file names)
for name, dat in data_roc.items():
    pd.DataFrame(dat).to_csv(f'./figure_3_performance/roc_curve_external_data_{name.replace(" -> ", "_")}.csv', index=False)

# A dashed class-prevalence reference line (precision 0.02) can be added with
# fig_pr.add_hline(0.02, ...); it is intentionally not drawn.
fig_pr.update_layout(
    xaxis_title='Recall',
    yaxis_title='Precision',
    showlegend=False,
    margin=dict(l=0, r=5, t=5, b=0),
    width=160,
    height=160
)
fig_pr.update_yaxes(range=[0, 1.01])
fig_pr.update_xaxes(range=[0, 1])
fig_pr = plotting.standardize_plot(fig_pr)
fig_pr.write_image("./figure_3_performance/pr_curve_external.svg")
fig_pr.show()

# save data
for name, dat in data_pr.items():
    pd.DataFrame(dat).to_csv(f'./figure_3_performance/pr_curve_external_data_{name.replace(" -> ", "_")}.csv', index=False)

# Performance across all datasets

In [None]:
# datasets in the order they should appear on both heatmap axes
dataset_order = ["GCall", "GCfix", "Koch_et_al", "Erlich_et_al", "Song_et_al", "Choi_et_al", "Gao_et_al"]
perf_auprc, perf_auroc = pd.read_pickle('../data/machine_learning_results/1DCNN_performance.pkl')

# keep only the listed columns, then put the rows into the same order
perf_auprc, perf_auroc = [
    perf.loc[:, dataset_order].reindex(index=dataset_order)
    for perf in (perf_auprc, perf_auroc)
]

# fix _et_al names
def fix_et_al(df):
    """Return a copy of `df` with "<Author>_et_al" row/column labels shown as "<Author> et al.".

    One mapper is applied to both axes, so index and columns cannot drift apart.
    Labels that do not end in "_et_al" (or are not strings) are left unchanged.
    """
    suffix = "_et_al"

    def _pretty(label):
        if isinstance(label, str) and label.endswith(suffix):
            return label[:-len(suffix)] + " et al."
        return label

    return df.rename(columns=_pretty, index=_pretty)

perf_auprc = fix_et_al(perf_auprc)
perf_auroc = fix_et_al(perf_auroc)


def _performance_heatmap(perf, title, color, zmin, zmax):
    """Annotated heatmap of a performance matrix, values printed with two decimals.

    Cells are colored on a white -> `color` scale spanning zmin..zmax, with a
    horizontal colorbar titled `title`. Shared by the AUPRC and AUROC figures so
    their layout stays identical.
    """
    fig = px.imshow(
        perf,
        color_continuous_scale=["white", color],
        zmin=zmin,
        zmax=zmax,
        text_auto=".2f",
    )
    fig.update_layout(
        coloraxis_colorbar=dict(
            title=title,
            title_font_size=28/3,
            title_font_family='Inter',
            title_side='top',
            tickfont_size=20/3,
            tickfont_family='Inter',
            lenmode="pixels",
            len=180,
            thicknessmode="pixels",
            thickness=10,
            orientation="h",
            dtick=0.1,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        width=250,
        height=300
    )
    return fig


# plot AUPRC and save the matrix behind it
fig_auprc = _performance_heatmap(perf_auprc, 'AUPRC', "#de2d26", zmin=0, zmax=0.6)
fig_auprc = plotting.standardize_plot(fig_auprc)
fig_auprc.write_image("./figure_5_generalization/auprc_external.svg")
fig_auprc.show()
perf_auprc.to_csv("./figure_5_generalization/auprc_external.csv")

# plot AUROC (row labels on the right) and save the matrix behind it
fig_auroc = _performance_heatmap(perf_auroc, 'AUROC', "#3182bd", zmin=0.5, zmax=1.0)
fig_auroc.update_layout(yaxis={'side': 'right'})
fig_auroc = plotting.standardize_plot(fig_auroc)
fig_auroc.write_image("./figure_5_generalization/auroc_external.svg")
fig_auroc.show()
perf_auroc.to_csv("./figure_5_generalization/auroc_external.csv")