In [None]:
import sys

import pandas as pd
import plotly.express as px
import svgutils.transform as sg
from IPython.display import SVG, display

# extend the module search path with the parent directory before importing `plotting`
sys.path.append('../')
import plotting

In [None]:
# Dataset identifiers in display order; `plot` uses this list to select the
# columns and to reindex the rows of every heatmap matrix.
dataset_order = [
    "GCall",
    "GCfix",
    "Koch_et_al",
    "Erlich_et_al",
    "Song_et_al",
    "Choi_et_al",
    "Gao_et_al",
]

# fix _et_al names
def fix_et_al(df):
    """Replace the ``*_et_al`` identifiers with human-readable citation labels.

    The same mapping is applied to both the row index and the column labels;
    labels that are not in the mapping are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose index and/or columns may contain ``*_et_al`` identifiers.

    Returns
    -------
    pandas.DataFrame
        A new frame with the relabelled axes (``df`` itself is not modified).
    """
    pretty_labels = {
        "Koch_et_al": "Koch et al.",
        "Erlich_et_al": "Erlich et al.",
        "Song_et_al": "Song et al.",
        "Choi_et_al": "Choi et al.",
        "Gao_et_al": "Gao et al.",
    }
    return df.rename(index=pretty_labels, columns=pretty_labels)

def plot(data, title, colorscale, zrange, dtick=0.1):
    """Draw a dataset-by-dataset matrix as an annotated Plotly heatmap.

    Parameters
    ----------
    data : pandas.DataFrame
        Matrix to draw. Its columns are selected in ``dataset_order`` and its
        rows are reindexed to the same order before plotting.
    title : str
        Title shown above the colorbar.
    colorscale : list
        Passed to ``px.imshow`` as ``color_continuous_scale``.
    zrange : sequence
        ``zrange[0]`` and ``zrange[1]`` are used as the lower and upper limit
        of the color axis.
    dtick : float, optional
        Tick spacing of the colorbar.

    Returns
    -------
    The figure as returned by ``plotting.standardize_plot``.
    """
    # Bring rows and columns into the canonical order, then prettify the labels.
    matrix = fix_et_al(data[dataset_order].reindex(dataset_order))

    heatmap = px.imshow(
        matrix,
        color_continuous_scale=colorscale,
        zmin=zrange[0],
        zmax=zrange[1],
        text_auto=".2f",  # annotate every cell with its value, two decimals
    )

    # The colorbar title and its tick labels share one font size.
    font_size = 28/3
    colorbar = {
        "title": title,
        "title_font_size": font_size,
        "title_font_family": 'Inter',
        "title_side": 'top',
        "tickfont_size": font_size,
        "tickfont_family": 'Inter',
        # fixed pixel dimensions for the horizontal colorbar
        "lenmode": "pixels",
        "len": 180,
        "thicknessmode": "pixels",
        "thickness": 10,
        "orientation": "h",
        "dtick": dtick,
    }
    heatmap.update_layout(
        coloraxis_colorbar=colorbar,
        margin={"l": 0, "r": 0, "t": 0, "b": 0},
        width=350,
        height=350,
    )

    # Hand the figure to the project's shared styling helper.
    return plotting.standardize_plot(heatmap)

# Baseline Models at threshold 2perc


In [None]:
# Baseline models, listed in the top-to-bottom row order of the composed figure.
# A single list drives both the plotting loop and the grid layout below.
baseline_models = ['1DCNN_without_PE', 'RNN', 'LR']
# Vertical offset in px between consecutive rows of the composed figure.
row_height = 270

for modelname in baseline_models:
    # Load this model's results; the first two CSV columns form a two-level index.
    performance = pd.read_csv(
        f'../data/machine_learning_results/Baseline_{modelname}_performance_2perc.csv',
        index_col=[0, 1],
    ).sort_index()

    # One matrix per metric: first index level -> rows, second index level -> columns.
    perf_auroc = performance['AUROC'].unstack(level=1)
    perf_auprc = performance['AUPRC'].unstack(level=1)

    # Render each metric as a heatmap and save it as a standalone SVG.
    fig_auroc = plot(perf_auroc, "AUROC", ["white", "#3182bd"], [0.5, 1])
    fig_auroc.write_image(f"./SI_figure_models/{modelname}_auroc_heatmap.svg")
    fig_auprc = plot(perf_auprc, "AUPRC", ["white", "#de2d26"], [0, 0.6])
    fig_auprc.write_image(f"./SI_figure_models/{modelname}_auprc_heatmap.svg")


# Arrange the saved panels in a grid: one row per model, AUROC left, AUPRC right.
fig = sg.SVGFigure("680px", "900px")
fig.set_size(("680px", "900px"))

# Walk the rows bottom-up. This keeps the append order of the previous
# implementation, which decides which panel is drawn on top where panels overlap.
for row, modelname in reversed(list(enumerate(baseline_models))):
    y_offset = row * row_height

    panel = sg.fromfile(f'./SI_figure_models/{modelname}_auroc_heatmap.svg').getroot()
    panel.moveto(50, y_offset)
    fig.append(panel)

    panel = sg.fromfile(f'./SI_figure_models/{modelname}_auprc_heatmap.svg').getroot()
    panel.moveto(375, y_offset)
    fig.append(panel)

fig.save('./SI_figure_models/baselines_composed.svg')

# Show the composed figure inline (SVG/display come from the notebook's import cell).
display(SVG(filename='./SI_figure_models/baselines_composed.svg'))

# 1D CNN models for all thresholds

In [None]:
# Load the 1D CNN results for all thresholds; the first CSV column is the index.
performance = pd.read_csv('../data/machine_learning_results/1DCNN_performance_all_threshold.csv', index_col=0)
# AUPRC expressed relative to the prevalence stored in the same row.
performance["Normalized AUPRC"] = performance["AUPRC"] / performance["Prevalence"]

# Average every numeric column per (source, target, threshold) combination.
# `numeric_only=True` is passed explicitly because the default differs between
# pandas versions (since 2.0 a non-numeric column would make `.mean()` raise);
# with all-numeric columns the result is unchanged.
agg_df = (
    performance.groupby(["source", "target", "threshold"])
    .mean(numeric_only=True)
    .reset_index()
)

agg_df

In [None]:
for threshold in performance["threshold"].unique():
    # Aggregated rows for the current threshold, selected once and reused for both metrics.
    threshold_df = agg_df[agg_df["threshold"] == threshold]

    # Output files of this iteration.
    auroc_svg = f"./SI_figure_models/{threshold}_auroc_heatmap.svg"
    auprc_svg = f"./SI_figure_models/{threshold}_auprc_heatmap.svg"
    composed_svg = f'./SI_figure_models/{threshold}_composed.svg'

    # source x target matrix of the aggregated AUROC
    pivot_df = threshold_df.pivot(index="source", columns="target", values="AUROC")
    fig_auroc = plot(pivot_df, "AUROC", ["white", "#3182bd"], [0.5, 1])
    fig_auroc.write_image(auroc_svg)

    # source x target matrix of the aggregated normalized AUPRC
    pivot_df = threshold_df.pivot(index="source", columns="target", values="Normalized AUPRC")
    fig_auprc = plot(pivot_df, "Normalized AUPRC", ["white", "#de2d26"], [0, 10], dtick=2)
    fig_auprc.write_image(auprc_svg)

    # Place both panels side by side: AUROC left, normalized AUPRC right.
    fig = sg.SVGFigure("680px", "330px")
    fig.set_size(("680px", "330px"))

    panel = sg.fromfile(auroc_svg).getroot()
    panel.moveto(50, -20)
    fig.append(panel)
    panel = sg.fromfile(auprc_svg).getroot()
    panel.moveto(375, -20)
    fig.append(panel)

    fig.save(composed_svg)

    # Show the composed figure inline (SVG/display come from the notebook's import cell).
    display(SVG(filename=composed_svg))