In [None]:
import os
import sys

import numpy as np
import pandas as pd
import plotly.express as px
import sklearn.metrics
from Bio import SeqIO

# make the project-local modules in the parent directory importable
sys.path.append('..')
import model
import plotting

# Define function for data reading

In [None]:
# Maps every experiment column (PCR1 ... PCR6) to its associated value, which
# grows in steps of 15 (15, 30, ..., 90). compile_dataframe hands this mapping
# to model.fit_parameters.
# NOTE(review): the values are presumably cumulative PCR cycle counts - confirm.
exp_data = {f"PCR{idx}": 15 * idx for idx in range(1, 7)}

def compile_dataframe(folder, merge_seqid=True, keep_all_sequences=False):
    """Combine experimental data, reference values and fitted model parameters.

    The experimental data is read with ``model.read_csvs``; the reference
    efficiencies (``eff_reference.csv``) and initial counts
    (``initial_pool.csv``) are read from ``folder`` as header-less two-column
    CSV files. Model parameters are fitted with ``model.fit_parameters`` using
    the module-level ``exp_data`` mapping, and everything is joined on the
    index into a single dataframe.

    Side effect: the fitted parameters are written to ``params.csv`` inside
    ``folder``.

    Parameters
    ----------
    folder : str
        Directory holding the input files. A trailing path separator is
        appended if it is missing.
    merge_seqid : bool, default True
        If True, the first column of the reference files is matched against
        the sequences in ``design_files.fasta`` (one level above ``folder``)
        and replaced by the corresponding record id. If False, the first
        column is used as the index directly.
    keep_all_sequences : bool, default False
        Forwarded unchanged to ``model.read_csvs``.

    Returns
    -------
    pandas.DataFrame
        Fitted parameters left-joined with the reference data and the
        experimental data, plus the derived columns ``x_error``,
        ``eff_error``, ``PCR_x_mean`` and ``PCR_x_nzeros`` (number of zero
        entries across the experiment columns, stored as a string).
        The joined frame is expected to contain ``x0`` and ``eff`` columns.
    """
    # normalise the folder so the concatenated paths below are valid whether
    # or not the caller passed a trailing separator
    folder = os.path.join(folder, "")

    # the experiment columns are defined in one place: the keys of exp_data
    pcr_columns = list(exp_data)

    # read experimental data
    data_df = model.read_csvs(folder, keep_all_sequences=keep_all_sequences)

    # read reference counts and efficiencies
    df_ref_eff = pd.read_csv(
        f"{folder}eff_reference.csv",
        index_col=False,
        header=None,
        names=['seq_id', 'efficiency'],
        dtype={'seq_id': str},
    ).set_index('seq_id')
    df_ref_x0 = pd.read_csv(
        f"{folder}initial_pool.csv",
        index_col=False,
        header=None,
        names=['seq_id', 'initial_count'],
        dtype={'seq_id': str},
    ).set_index('seq_id')
    df_ref = pd.merge(df_ref_eff, df_ref_x0, left_index=True, right_index=True)

    if merge_seqid:
        # read reference sequences: sequence string -> record id
        seq_ref = {}
        with open(f"{folder}../design_files.fasta") as handle:
            for record in SeqIO.parse(handle, "fasta"):
                seq_ref[str(record.seq)] = str(record.id)
        seq_ref = pd.DataFrame.from_dict(seq_ref, orient='index', columns=['seq_id'])

        # merge on sequence to get seq_id into df, then index by that id
        df_ref = pd.merge(df_ref, seq_ref, left_index=True, right_index=True)
        df_ref = df_ref.set_index("seq_id")

    # normalize reference counts and efficiency for comparison
    # (1 + mean(efficiency - 1) is algebraically the mean efficiency)
    df_ref['eff_ref'] = df_ref.efficiency/(1+np.mean(df_ref.efficiency-1))
    df_ref['x0_ref'] = df_ref.initial_count/df_ref.initial_count.mean()

    # fit modelling parameters and persist them next to the input data
    df_param = model.fit_parameters(data_df, exp_data)
    df_param.to_csv(f'{folder}params.csv')

    # merge experimental data, reference data, and modelling parameters
    merged_df = df_param.join(df_ref)
    merged_df = merged_df.join(data_df)

    # calculate error metrics
    merged_df['x_error'] = np.abs(merged_df.x0_ref - merged_df.x0)
    merged_df['eff_error'] = np.abs(merged_df.eff_ref - merged_df.eff)
    merged_df['PCR_x_mean'] = merged_df[pcr_columns].mean(axis=1)
    # stored as string so it can serve as a discrete colour category
    merged_df['PCR_x_nzeros'] = (merged_df[pcr_columns] == 0).astype(int).sum(axis=1).astype(str)

    return merged_df

# Plotting functions

In [None]:
def groundtruth_plot(df, x, y, hover_data, color=None, colormap=None):
    """Scatter the model estimate against the ground truth on square axes.

    Both axes share one range, derived from the data extremes padded by 0.3
    standard deviations. An identity line is drawn across that whole range
    and the R² between ``df[x]`` (truth) and ``df[y]`` (estimate) is
    annotated in the bottom-right corner.

    Parameters
    ----------
    df : pandas.DataFrame
        Data containing the columns named by ``x``, ``y``, ``hover_data`` and
        ``color``.
    x : str
        Column holding the ground-truth values. For ``"x0_ref"`` the lower
        x bound is pinned to 0.
    y : str
        Column holding the model estimates.
    hover_data : list of str
        Forwarded to ``px.scatter`` as ``hover_data``.
    color : str, optional
        Column used to colour the markers. If falsy, all markers are drawn
        in ``#3182bd``.
    colormap : dict, optional
        Forwarded to ``px.scatter`` as ``color_discrete_map``.

    Returns
    -------
    The figure after ``plotting.standardize_plot`` has been applied.
    """
    pad_x = 0.3*df[x].std()
    pad_y = 0.3*df[y].std()

    value_min_x = df[x].min() - pad_x
    value_min_y = df[y].min() - pad_y

    value_max_x = df[x].max() + pad_x
    value_max_y = df[y].max() + pad_y

    if x == "x0_ref":
        value_min_x = 0

    # one common range keeps the plot square so the identity line is the diagonal
    min_val = min(value_min_x, value_min_y)
    max_val = max(value_max_x, value_max_y)

    fig = px.scatter(
        df,
        x=x,
        y=y,
        hover_data=hover_data,
        color=color,
        color_discrete_map=colormap,
    )
    # identity line across the full displayed range; a fixed [0, 20] segment
    # would stop short whenever the axes extend beyond 20 or below 0
    fig.add_scatter(
        x=[min_val, max_val],
        y=[min_val, max_val],
        mode='lines',
        line_color="black",
    )
    fig.update_layout(
        width=150,
        height=150,
        margin=dict(l=0, r=10, t=5, b=0),
        yaxis_title="Model",
        xaxis_title="Ground truth",
        xaxis_range=[min_val, max_val],
        yaxis_range=[min_val, max_val],
        showlegend=False,
    )
    # anchor the R² label at the bottom-right corner and nudge it inside the axes
    fig.add_annotation(x=max_val, y=min_val,
        text=f"R<sup>2</sup> = {sklearn.metrics.r2_score(df[x], df[y]):0.2f}",
        showarrow=False,
        xshift=-25,
        yshift=10,
    )

    # the selector restricts the styling to the scatter markers, not the identity line
    if not color:
        fig.update_traces(
            marker=dict(color='#3182bd'),
            selector=dict(mode='markers')
        )
    fig.update_traces(
        marker=dict(size=3),
        selector=dict(mode='markers')
    )
    fig = plotting.standardize_plot(fig)
    return fig


def histogram_plot(df, x):
    """Draw a small grey histogram of column ``x`` of ``df``.

    For ``x == "x0_ref"`` the bins run from 0 to 10 in steps of 0.05 and the
    x axis shows 0 to 4; for any other column the bins run from 0 to 2 in
    steps of 0.0025 and the x axis shows 0.90 to 1.05.

    Returns the figure after ``plotting.standardize_plot`` has been applied.
    """
    # pick binning and visible range up front, depending on the plotted column
    if x == "x0_ref":
        bin_spec = dict(start=0, end=10, size=0.05)
        visible_range = [0, 4]
    else:
        bin_spec = dict(start=0, end=2, size=0.0025)
        visible_range = [0.90, 1.05]

    hist = px.histogram(df, x=x)
    hist.update_layout(
        width=150,
        height=150,
        margin=dict(l=0, r=10, t=5, b=0),
        yaxis_title="Count",
        xaxis_title="Ground truth",
        showlegend=False,
        xaxis_range=visible_range,
    )
    hist.update_traces(
        marker_color="#acacac",
        marker_line_width=0,
        marker_line_color="white",
        xbins=bin_spec,
    )
    return plotting.standardize_plot(hist)



def compare_against_groundtruth(df, folder):
    """Export the ground-truth histograms and model-vs-truth scatter plots.

    Four SVG files are written into ``folder`` (which is used as a plain
    string prefix, so it must end with a path separator):
    ``initial_abundance_histogram.svg``, ``initial_efficiency_histogram.svg``,
    ``xy_initial_abundance_by_zerocount.svg`` and
    ``xy_efficiency_by_zerocount.svg``. Nothing is returned.
    """
    hover_columns = ["x0_ref", "x0", "eff_ref", "eff", "PCR1", "PCR2", "PCR3", "PCR4", "PCR5", "PCR6"]

    # marker colour per value of PCR_x_nzeros (stored as strings in the dataframe)
    zero_count_colors = {"0": "#31a354", "1": "#3182bd", "2": "#e6550d", "3": "#de2d26", "4": "#de2d26", "5": "#de2d26"}

    # histograms of the ground-truth distributions
    histogram_jobs = (
        ("x0_ref", "initial_abundance_histogram.svg"),
        ("eff_ref", "initial_efficiency_histogram.svg"),
    )
    for column, filename in histogram_jobs:
        fig = histogram_plot(df, column)
        fig = plotting.standardize_plot(fig)
        fig.write_image(f"{folder}{filename}")

    # model estimate vs. ground truth, coloured by the number of zero counts
    scatter_jobs = (
        ("x0_ref", "x0", "xy_initial_abundance_by_zerocount.svg"),
        ("eff_ref", "eff", "xy_efficiency_by_zerocount.svg"),
    )
    for truth_column, model_column, filename in scatter_jobs:
        fig = groundtruth_plot(
            df,
            truth_column,
            model_column,
            hover_columns,
            color="PCR_x_nzeros",
            # hand over a fresh copy for every plot
            colormap=dict(zero_count_colors),
        )
        fig = plotting.standardize_plot(fig)
        fig.write_image(f"{folder}{filename}")

# Basic verification - Gaussian with outliers

In [None]:
folder = "../data/model_verification/basic/gaussian_outlier/"
out_folder = "./SI_figure_model_verification/basic/gaussian_outlier/"

# create the output directory if needed so the exports below work on a fresh checkout
os.makedirs(out_folder, exist_ok=True)

df = compile_dataframe(folder, merge_seqid=False)
df.to_csv(f"{out_folder}data.csv")
compare_against_groundtruth(df, out_folder)

# Basic verification - Lognormal

In [None]:
folder = "../data/model_verification/basic/lognormal/"
out_folder = "./SI_figure_model_verification/basic/lognormal/"

# create the output directory if needed so the exports below work on a fresh checkout
os.makedirs(out_folder, exist_ok=True)

df = compile_dataframe(folder, merge_seqid=False)
df.to_csv(f"{out_folder}data.csv")
compare_against_groundtruth(df, out_folder)

# Basic verification - Uniform

In [None]:
folder = "../data/model_verification/basic/uniform/"
out_folder = "./SI_figure_model_verification/basic/uniform/"

# create the output directory if needed so the exports below work on a fresh checkout
os.makedirs(out_folder, exist_ok=True)

df = compile_dataframe(folder, merge_seqid=False)
df.to_csv(f"{out_folder}data.csv")
compare_against_groundtruth(df, out_folder)

# DT4DDS verification - Gaussian with outliers

In [None]:
folder = "../data/model_verification/dt4dds/gaussian_outlier/"
out_folder = "./SI_figure_model_verification/dt4dds/gaussian_outlier/"

# create the output directory if needed so the exports below work on a fresh checkout
os.makedirs(out_folder, exist_ok=True)

df = compile_dataframe(folder)
df.to_csv(f"{out_folder}data.csv")
compare_against_groundtruth(df, out_folder)

# DT4DDS verification - Lognormal

In [None]:
folder = "../data/model_verification/dt4dds/lognormal/"
out_folder = "./SI_figure_model_verification/dt4dds/lognormal/"

# create the output directory if needed so the exports below work on a fresh checkout
os.makedirs(out_folder, exist_ok=True)

df = compile_dataframe(folder)
df.to_csv(f"{out_folder}data.csv")
compare_against_groundtruth(df, out_folder)

# DT4DDS verification - Uniform

In [None]:
folder = "../data/model_verification/dt4dds/uniform/"
out_folder = "./SI_figure_model_verification/dt4dds/uniform/"

# create the output directory if needed so the exports below work on a fresh checkout
os.makedirs(out_folder, exist_ok=True)

df = compile_dataframe(folder)
df.to_csv(f"{out_folder}data.csv")
compare_against_groundtruth(df, out_folder)

# Composing

In [None]:
import svgutils.transform as sg

# layout constants of the composed figure
total_width = 680
panel_width = 150
total_whitespace = total_width - 4*panel_width
x0_offset = 25
y0_offset = 35
panel_spacing = (total_whitespace-x0_offset)/3
row_offset = 135

# horizontal distance between the origins of two neighbouring panel columns
column_pitch = panel_width + panel_spacing

# text styling shared by every label
label_style = dict(size=28/3, weight="bold", font="Inter", anchor="middle")

# (file name, column index, extra x shift), listed in the order the panels are appended
panel_specs = [
    ("initial_efficiency_histogram.svg", 1, 0),
    ("initial_abundance_histogram.svg", 0, 30),
    ("xy_efficiency_by_zerocount.svg", 3, 0),
    ("xy_initial_abundance_by_zerocount.svg", 2, 30),
]

# (column multiplier, x shift to subtract, first line, second line)
column_titles = [
    (1, 50, "True distribution of", "initial abundance"),
    (2, 70, "True distribution of", "amplification efficiency"),
    (3, 60, "Model vs. ground truth for", "initial abundance"),
    (4, 80, "Model vs. ground truth for", "amplification efficiency"),
]

# two-line row labels, top to bottom
row_titles = [
    ("Basic simulation", "with gaussian distribution"),
    ("Basic simulation", "with lognormal distribution"),
    ("Basic simulation", "with uniform distribution"),
    ("DT4DDS simulation", "with gaussian distribution"),
    ("DT4DDS simulation", "with lognormal distribution"),
    ("DT4DDS simulation", "with uniform distribution"),
]

fig = sg.SVGFigure("680px", "875px")

# one row per (group, distribution) pair, four panels per row
for group_idx, group in enumerate(["basic", "dt4dds"]):
    for dist_idx, dist in enumerate(["gaussian_outlier", "lognormal", "uniform"]):
        panel_y = (group_idx*3+dist_idx)*row_offset+y0_offset
        for filename, column, extra_shift in panel_specs:
            panel = sg.fromfile(f'./SI_figure_model_verification/{group}/{dist}/{filename}').getroot()
            panel.moveto(column*column_pitch+x0_offset+extra_shift, panel_y)
            fig.append(panel)

# column headers: two lines of text above each panel column
for multiplier, shift, first_line, second_line in column_titles:
    title_x = x0_offset + multiplier*column_pitch - shift
    for title_y, text in ((10, first_line), (22, second_line)):
        txt = sg.TextElement(title_x, title_y, text, **label_style)
        fig.append([txt])

# row labels: two lines of text, each rotated by -90 degrees about its own anchor point
for row_number, row_lines in enumerate(row_titles, start=1):
    label_y = y0_offset + row_number*(row_offset) - 80
    for label_x, text in zip((10, 22), row_lines):
        txt = sg.TextElement(label_x, label_y, text, **label_style)
        txt.rotate(-90, label_x, label_y)
        fig.append([txt])

fig.save("./SI_figure_model_verification/composed.svg")
# from IPython.display import SVG, display
# display(SVG(filename=f"./figures/composed.svg"))