In [None]:
import plotly.express as px
import plotly.graph_objects as pg
import numpy as np
import sklearn.metrics
import math
import svgutils.transform as sg
import datashader as ds

import sys
sys.path.append('..')
import model
import plotting


def compile_data(folder, exp_data):
    """Load one dataset, fit the model to it and assess the fit.

    Args:
        folder: dataset directory handed to ``model.read_csvs``.
        exp_data: per-experiment mapping handed unchanged to
            ``model.fit_parameters`` and ``model.assess_fit``.

    Returns:
        ``(params_df, all_df, assessment)``: the fitted parameters from
        ``model.fit_parameters``, followed by the two values returned by
        ``model.assess_fit``.
    """
    measurements = model.read_csvs(folder)
    fitted = model.fit_parameters(measurements, exp_data)
    merged, fit_quality = model.assess_fit(measurements, fitted, exp_data)
    return fitted, merged, fit_quality




# plots the correlation between the PCR efficiency and the initial abundance
def plot_eff_vs_x0(params_df, xrange, yrange, folder="./"):
    """Write a 2-D density plot of PCR efficiency vs. initial abundance.

    Rows of ``params_df`` are binned on a 100x100 datashader grid spanning
    ``xrange`` (``eff``) by ``yrange`` (``x0``). Empty bins are left blank.
    The Spearman correlation between ``eff`` and ``x0`` is printed near the
    top-left corner. Output file: ``{folder}/eff_vs_x0.svg``.

    Args:
        params_df: table with ``eff`` and ``x0`` columns.
        xrange: ``[lo, hi]`` limits of the efficiency axis.
        yrange: ``[lo, hi]`` limits of the abundance axis.
        folder: output directory (not created by this function).
    """
    cvs = ds.Canvas(plot_width=100, plot_height=100, x_range=xrange, y_range=yrange)
    counts = cvs.points(params_df, x='eff', y='x0')
    # Bins with no points become NaN, so they are not coloured and
    # (with hoverongaps=False below) not hoverable.
    counts = counts.where(counts != 0)

    fig = px.imshow(
        counts,
        origin='lower',
        color_continuous_scale=["#d9d9d9", "#000000"],
    )
    fig.update_traces(hoverongaps=False)

    # Keep the label at the same relative spot for any axis range.
    # For xrange=[0.9, 1.05] and yrange=[0, y] this gives x=0.925, y=0.95*y.
    rho = params_df.eff.corr(params_df['x0'], method='spearman')
    fig.add_annotation(
        x=xrange[0] + (xrange[1] - xrange[0]) / 6,
        y=yrange[0] + 0.95 * (yrange[1] - yrange[0]),
        text=f"ρ = {rho:.2f}",
        font_color="black",
        align='right',
        showarrow=False,
    )

    fig.update_layout(
        width=225,
        height=200,
        margin=dict(l=0, r=10, t=10, b=0),
        showlegend=False,
        coloraxis_showscale=False,
    )
    fig.update_xaxes(range=xrange, title="PCR efficiency")
    fig.update_yaxes(range=yrange, autorange=False, title="Initial abundance")
    fig = plotting.standardize_plot(fig)
    fig.write_image(f"{folder}/eff_vs_x0.svg")


# plots a histogram of the parameters
def plot_histogram(params_df, x, title, range, folder="./"):
    """Write a grey histogram of column ``x`` to ``{folder}/{x}_histogram.svg``.

    Args:
        params_df: table containing column ``x``.
        x: column to histogram; also used in the output file name.
        title: x-axis title.
        range: x-axis limits ``[lo, hi]``. The name shadows the builtin but is
            kept for keyword-argument compatibility.
        folder: output directory (not created by this function).
    """
    hist_fig = px.histogram(params_df, x=x)
    # Solid grey bars with no outline.
    hist_fig.update_traces(
        marker=dict(color='#969696', line=dict(width=0, color="white")),
    )
    hist_fig.update_layout(
        width=225,
        height=200,
        margin=dict(l=0, r=10, t=10, b=0),
        yaxis_title="Count",
        xaxis_title=title,
        showlegend=False,
    )
    hist_fig.update_xaxes(range=range)
    hist_fig = plotting.standardize_plot(hist_fig)
    hist_fig.write_image(f"{folder}/{x}_histogram.svg")


# plots the comparison between the model and the experimental data
def plot_comparison(all_df, exp_data, limit=3, folder="./"):
    """Write one model-vs-ground-truth density plot per experiment.

    For the j-th key of ``exp_data``, the rows of ``all_df`` with
    ``exp == key`` are binned on a 50x50 grid over ``[0, limit]`` on both
    axes. The plot is saved to ``{folder}/comparison_{j}.svg`` and annotated
    with Spearman's rho, the mean absolute error and R^2 between ``true``
    and ``model``. The file index ``j`` follows the iteration order of
    ``exp_data``, which ``compose_plots`` relies on.

    Args:
        all_df: table with columns ``exp``, ``model`` and ``true``.
        exp_data: mapping whose keys are the experiment names to plot
            (values are not used here).
        limit: upper bound of both axes; the lower bound is 0.
        folder: output directory (not created by this function).
    """
    for j, exp in enumerate(exp_data):
        idata = all_df.loc[all_df.exp == exp]

        cvs = ds.Canvas(plot_width=50, plot_height=50, x_range=[0, limit], y_range=[0, limit])
        counts = cvs.points(idata, x='model', y='true')
        # Bins with no points become NaN so they are left uncoloured.
        counts = counts.where(counts != 0)

        fig = px.imshow(
            counts,
            origin='lower',
            color_continuous_scale=["#d9d9d9", "#000000"],
        )
        # y = x reference line across the plotted range. mode is explicit
        # because plotly adds markers to traces with fewer than 20 points.
        fig.add_trace(
            pg.Scatter(
                x=[0, limit],
                y=[0, limit],
                mode='lines',
                line_color="#ff0000",
                line_width=1,
                line_dash='dash',
            ),
        )

        rho = idata.true.corr(idata.model, method='spearman')
        mae = np.mean(np.abs(idata.true - idata.model))
        r2 = sklearn.metrics.r2_score(idata.true, idata.model)
        fig.add_annotation(
            x=0.22*limit,
            y=0.85*limit,
            text=f"<b>{exp}</b><br>ρ = {rho:.2f}<br>MAE = {mae:.2f}<br>R<sup>2</sup> = {r2:.2f}",
            font_color="black",
            align='right',
            showarrow=False,
        )
        fig.update_layout(
            showlegend=False,
            margin=dict(l=0, r=10, t=10, b=0),
            width=(680/4)+20,
            height=(680/4)+20,
            coloraxis_showscale=False,
        )
        fig.update_xaxes(range=[0, limit], title="Model")
        fig.update_yaxes(range=[0, limit], autorange=False, title="Ground truth")
        fig = plotting.standardize_plot(fig)
        fig.write_image(f"{folder}/comparison_{j}.svg")


# starts the compositing of plots to save some manual work later
def compose_plots(params, data, exp_data, limit=3, folder="./", eff_range=(0.9, 1.05)):
    """Render all panels for one dataset and composite them into ``composed.svg``.

    Panels written to ``folder``:
    - top row, labelled a-c: ``x0_histogram.svg``, ``eff_histogram.svg`` and
      ``eff_vs_x0.svg``;
    - label d: one ``comparison_{j}.svg`` per experiment in ``exp_data``, laid
      out four per row.

    Args:
        params: fitted parameters with ``x0`` and ``eff`` columns. Not
            modified; a clipped copy is plotted.
        data: table passed to ``plot_comparison``.
        exp_data: mapping passed to ``plot_comparison``; its length sets the
            number of comparison panels.
        limit: upper bound of the abundance / comparison axes.
        folder: output directory; created if it does not exist.
        eff_range: ``(lo, hi)`` limits of the PCR-efficiency axis.
    """
    import os

    # plotly's write_image needs the target directory to exist.
    os.makedirs(folder, exist_ok=True)

    eff_range = list(eff_range)
    # Clip out-of-range values so they pile up at the axis edges
    # instead of falling outside the plotted ranges.
    params = params.copy()
    params["x0"] = params["x0"].clip(0, limit)
    params["eff"] = params["eff"].clip(*eff_range)
    plot_histogram(params, "x0", "Initial abundance", [0, limit], folder=folder)
    plot_histogram(params, "eff", "PCR efficiency", eff_range, folder=folder)
    plot_eff_vs_x0(params, eff_range, [0, limit], folder=folder)
    plot_comparison(data, exp_data, limit=limit, folder=folder)

    # Layout in px: a 680-wide canvas, a top row of three panels, then a grid
    # of comparison panels with n_cols per row.
    fig_width = 680
    n_cols = 4
    cell = fig_width / n_cols
    n_rows = math.ceil(len(exp_data) / n_cols)
    fig = sg.SVGFigure(f"{fig_width}px", f"{240+n_rows*cell}px")

    top_row = (
        ("x0_histogram.svg", 0),
        ("eff_histogram.svg", 225 + 2.5),
        ("eff_vs_x0.svg", 450 + 5),
    )
    for filename, x_offset in top_row:
        panel = sg.fromfile(f'{folder}/{filename}').getroot()
        panel.moveto(x_offset, 10)
        fig.append(panel)

    # Comparison panels are wider than their grid spacing, so neighbours
    # overlap. SVG draws later elements on top, so iterating columns
    # right-to-left leaves each left-hand panel above its right neighbour.
    for n_row in range(n_rows):
        for n_col in reversed(range(n_cols)):
            j = n_row*n_cols + n_col
            if j >= len(exp_data):
                continue
            panel = sg.fromfile(f'{folder}/comparison_{j}.svg').getroot()
            panel.moveto(n_col*(cell-7), 230+n_row*(cell-11))
            fig.append(panel)

    # Panel letters.
    for i, letter in enumerate(["a", "b", "c"]):
        txt = sg.TextElement(2+i*(fig_width/3), 6, letter, size=7, weight="bold", font="Inter")
        fig.append([txt])
    txt = sg.TextElement(2, 220+6, "d", size=7, weight="bold", font="Inter")
    fig.append([txt])

    fig.save(f'{folder}/composed.svg')

# GCall

In [None]:
exp_data = {
    'PCR1': 15,
    'PCR2': 30,
    'PCR3': 45,
    'PCR4': 60,
    'PCR5': 75,
    'PCR6': 90,
}
# Must match this section's "GCall" heading; GCfix has its own cell.
name = "GCall"

params, data, assessment = compile_data(f'../data/internal_datasets/{name}/', exp_data)
params.to_csv(f'../data/internal_datasets/{name}/params.csv')
compose_plots(params, data, exp_data, folder=f"./SI_dataset_model_plots/{name}", limit=3)

display(assessment)

# GCfix

In [None]:
# PCR1..PCR6 -> 15, 30, ..., 90
exp_data = {f'PCR{k}': 15 * k for k in range(1, 7)}
name = "GCfix"

params, data, assessment = compile_data(
    f'../data/internal_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/internal_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=3,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)

# Koch et al

In [None]:
exp_data = {
    'Bunny_M': 44,
    'Bunny_P': 59,
    'Bunny_F1': 71,
    'Bunny_F2': 83,
    'Bunny_F3': 95,
    'Bunny_F4': 107,
    'Bunny_F5': 119,
}
name = "Koch_et_al"

# Trailing slash kept consistent with the other dataset cells. It is harmless
# if model.read_csvs joins paths and needed if it concatenates them.
params, data, assessment = compile_data(f'../data/external_datasets/{name}/', exp_data)
params.to_csv(f'../data/external_datasets/{name}/params.csv')
compose_plots(params, data, exp_data, folder=f"./SI_dataset_model_plots/{name}", limit=7)

display(assessment)

# Erlich et al

In [None]:
exp_data = {
    'MasterPool': 10,
    'DeepCopy': 100,
}
name = "Erlich_et_al"

# Trailing slash kept consistent with the other dataset cells. It is harmless
# if model.read_csvs joins paths and needed if it concatenates them.
params, data, assessment = compile_data(f'../data/external_datasets/{name}/', exp_data)
params.to_csv(f'../data/external_datasets/{name}/params.csv')
compose_plots(params, data, exp_data, folder=f"./SI_dataset_model_plots/{name}", limit=4)

display(assessment)

# Song et al

In [None]:
# PCR1..PCR6 -> 30, 60, ..., 180
exp_data = {f'PCR{k}': 30 * k for k in range(1, 7)}
name = "Song_et_al"

params, data, assessment = compile_data(
    f'../data/external_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/external_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=4,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)

# Choi et al

In [None]:
exp_data = dict(PCR1=17, PCR2=17 * 20)
name = "Choi_et_al"

params, data, assessment = compile_data(
    f'../data/external_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/external_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=6,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)

# Gao et al

In [None]:
exp_data = dict(PCR1=10, PCR2=50, PCR3=100)
name = "Gao_et_al"

params, data, assessment = compile_data(
    f'../data/external_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/external_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=7,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)

# Validation GCall/fix

In [None]:
# PCR1..PCR6 -> 15, 30, ..., 90
exp_data = {f'PCR{k}': 15 * k for k in range(1, 7)}
name = "validation_GCall_fix"

params, data, assessment = compile_data(
    f'../data/internal_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/internal_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=3,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)

# Validation Erlich et al.

In [None]:
# PCR1..PCR6 -> 15, 30, ..., 90
exp_data = {f'PCR{k}': 15 * k for k in range(1, 7)}
name = "validation_Erlich_et_al"

params, data, assessment = compile_data(
    f'../data/internal_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/internal_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=3,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)

# Validation Erlich et al. (internal repeat)

In [None]:
# PCR1..PCR6 -> 15, 30, ..., 90
exp_data = {f'PCR{k}': 15 * k for k in range(1, 7)}
name = "validation_Erlich_et_al_internalrepeat"

params, data, assessment = compile_data(
    f'../data/internal_datasets/{name}/',
    exp_data,
)
params.to_csv(f'../data/internal_datasets/{name}/params.csv')
compose_plots(
    params,
    data,
    exp_data,
    limit=3,
    folder=f"./SI_dataset_model_plots/{name}",
)

display(assessment)