In [None]:
import os
import sys

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score

sys.path.append('..')
import plotting

# Read data with predictions

In [None]:
# Predictions file for the external validation set (also holds the eff_* columns used below).
PREDICTIONS_CSV = '../data/machine_learning_results/external-validation-predictions.csv'
pred_df = pd.read_csv(PREDICTIONS_CSV)
pred_df

# Filter out the sequences with inserted motifs (i.e. use only the randomly generated sequences)

In [None]:
# Keep only the randomly generated sequences, i.e. rows where no motif was inserted.
is_random_seq = pred_df['has_insertedmotif'].eq(False)
pred_df = pred_df.loc[is_random_seq].copy()
pred_df

# Label low-efficiency sequences (lowest ~2% of measured efficiencies, per experiment)

In [None]:
# Fraction of each efficiency distribution that is labeled "low" (the positive class).
LOW_EFF_QUANTILE = 0.02


def ecdf(a):
    """Empirical cumulative distribution function of `a`.

    Returns the sorted unique values of `a` and, for each, the fraction of
    elements of `a` that are <= that value.

    NaNs are not removed here: np.unique keeps them (sorted last), so they
    would be counted in the denominator. Drop them before calling.
    """
    x, counts = np.unique(a, return_counts=True)
    cusum = np.cumsum(counts)
    return x, cusum / cusum[-1]


for exp in ['Taq', 'Q5']:
    eff = pred_df[f'eff_{exp}']

    # Build the ECDF from measured values only. Including NaNs would inflate
    # the denominator and shift the threshold above the intended quantile.
    x, y = ecdf(eff.dropna())
    # Smallest observed efficiency whose ECDF exceeds the quantile; values
    # strictly below it form the low-efficiency class.
    threshold = x[np.argmax(y > LOW_EFF_QUANTILE)]
    print(exp, threshold)

    # 1.0 = low efficiency, 0.0 = not low, NaN = efficiency missing.
    # Built as a float column in one step. np.nan is used because the
    # np.NAN alias was removed in NumPy 2.0; building it this way also avoids
    # setting NaN into an int column in place.
    pred_df[f'eff_{exp}_low'] = np.where(
        eff.isna(), np.nan, (eff < threshold).astype(float)
    )

    # By construction, at most LOW_EFF_QUANTILE of the labeled rows are positive.
    assert pred_df[f'eff_{exp}_low'].mean() <= LOW_EFF_QUANTILE

    display(pred_df[f'eff_{exp}_low'].value_counts())

# Plot ROC and precision-recall curves for each model and experiment

In [None]:
# Output directory for the figure panels (SVG) and the underlying curve data (CSV).
OUT_DIR = './figure_6_model_performance'
os.makedirs(OUT_DIR, exist_ok=True)

# Prediction columns of pred_df to evaluate, and experiments to evaluate them on.
MODELS = ['GCfix 2perc', 'GCall 2perc', 'Erlich_et_al 2perc']
EXPERIMENTS = ['Taq', 'Q5']

# color definitions
linecolors = {
    'GCfix 2perc': '#bdd7e7',
    'GCall 2perc': '#3182bd',
    'Erlich_et_al 2perc': '#de2d26',
}
# hand-placed [x, y] positions of each model's label, per experiment
locations_roc = {
    'GCfix 2perc': {'Taq': [0.45, 0.65], 'Q5': [0.45, 0.85]},
    'GCall 2perc': {'Taq': [0.15, 0.35], 'Q5': [0.15, 0.65]},
    'Erlich_et_al 2perc': {'Taq': [0.75, 0.25], 'Q5': [0.75, 0.25]},
}
locations_pr = {
    'GCfix 2perc': {'Taq': [0.8, 0.4], 'Q5': [0.15, 0.45]},
    'GCall 2perc': {'Taq': [0.5, 0.6], 'Q5': [0.4, 0.25]},
    'Erlich_et_al 2perc': {'Taq': [0.22, 0.15], 'Q5': [0.8, 0.15]},
}


def _display_name(model):
    """Annotation label for a model column: text before the first space, underscores as spaces."""
    # single quotes inside the call keep the enclosing f-strings valid before Python 3.12
    return model.split(' ')[0].replace('_', ' ')


def _add_model_curve(fig, x, y, model, exp, locations, score):
    """Draw one model's curve on `fig` and label it with `score` at its hand-placed position."""
    fig.add_trace(go.Scatter(x=x, y=y, line=dict(color=linecolors[model])))
    label_x, label_y = locations[model][exp]
    fig.add_annotation(
        x=label_x,
        y=label_y,
        text=f"{_display_name(model)}<br>({score:.2f})",
        font_color=linecolors[model],
        showarrow=False,
    )


def _finish_figure(fig, xaxis_title, yaxis_title, filename):
    """Apply the shared layout and axis ranges, standardize, save to OUT_DIR, and show.

    Returns the figure produced by plotting.standardize_plot.
    """
    fig.update_layout(
        xaxis_title=xaxis_title,
        yaxis_title=yaxis_title,
        showlegend=False,
        margin=dict(l=0, r=5, t=8, b=0),
        width=160,
        height=160,
    )
    fig.update_yaxes(range=[0, 1.01])
    fig.update_xaxes(range=[0, 1])
    fig = plotting.standardize_plot(fig)
    fig.write_image(f"{OUT_DIR}/{filename}")
    fig.show()
    return fig


def _save_curve_data(curve_data, prefix, exp):
    """Write the curve arrays of every model for experiment `exp` (only) to CSV."""
    for model in MODELS:
        name = f"{model}_{exp}"
        pd.DataFrame(curve_data[name]).to_csv(
            f'{OUT_DIR}/{prefix}_{exp}_data_{name}.csv', index=False
        )


# curve arrays for all models/experiments, keyed by f"{model}_{exp}"
roc_data = {}
pr_data = {}

for exp in EXPERIMENTS:
    fig_roc = go.Figure()
    fig_pr = go.Figure()

    # Evaluate only rows with a label (label is NaN where efficiency is missing).
    # Independent of the model, so computed once per experiment.
    idf = pred_df[pred_df[f'eff_{exp}_low'].notna()]
    flat_labels = idf[f"eff_{exp}_low"].astype(bool)

    for model in MODELS:
        flat_predictions = idf[model]

        # compute metrics
        fpr, tpr, _ = roc_curve(flat_labels, flat_predictions)
        roc_auc = auc(fpr, tpr)
        precision, recall, _ = precision_recall_curve(flat_labels, flat_predictions)
        average_precision = average_precision_score(flat_labels, flat_predictions)
        print(model, exp, "auroc", roc_auc)
        print(model, exp, "auprc", average_precision)

        roc_data[f"{model}_{exp}"] = {'fpr': fpr, 'tpr': tpr}
        pr_data[f"{model}_{exp}"] = {'precision': precision, 'recall': recall}

        _add_model_curve(fig_roc, fpr, tpr, model, exp, locations_roc, roc_auc)
        _add_model_curve(fig_pr, recall, precision, model, exp, locations_pr, average_precision)

    # ROC: diagonal = random classifier (drawn past [0, 1]; the axis ranges clip it)
    fig_roc.add_trace(
        go.Scatter(
            x=[-2, 2],
            y=[-2, 2],
            line=dict(color='gray', width=1, dash='dash'),
        )
    )
    fig_roc = _finish_figure(fig_roc, 'False positive rate', 'True positive rate', f"roc_curve_{exp}.svg")
    # Save only this experiment's curves; the dicts also hold earlier experiments.
    _save_curve_data(roc_data, 'roc_curve', exp)

    # PR: a random classifier's precision equals the positive-class prevalence,
    # which the labeling rule keeps at or just below 2%.
    fig_pr.add_hline(
        flat_labels.mean(),
        line_width=1,
        opacity=1,
        line_dash='dash',
        line_color='gray',
    )
    fig_pr = _finish_figure(fig_pr, 'Recall', 'Precision', f"pr_curve_{exp}.svg")
    _save_curve_data(pr_data, 'pr_curve', exp)