In [None]:
import pandas as pd
import numpy as np
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import statsmodels.formula.api as smf
import svgutils.transform as sg
import datashader as ds

import sys
sys.path.append('..')
import plotting

# collect the seqprop data and the parameters of a dataset
def collect_data(folder):
    """Load and merge the sequence properties and fitted parameters of one dataset.

    Reads ``seqprops.csv`` (keyed by ``id``) and ``params.csv`` (keyed by
    ``seq_id``) from ``folder`` and inner-joins them on the sequence id.
    Ids are read as strings so values such as ``'001'`` keep their leading zeros.

    Values in the ``dg`` column that are missing, non-numeric, infinite
    (either sign) or larger than 10 are treated as invalid and replaced by NaN.

    Parameters
    ----------
    folder : str
        Dataset directory; a trailing path separator is optional.

    Returns
    -------
    pandas.DataFrame
        The merged table with a cleaned, float ``dg`` column.
    """
    import os  # local import: the notebook's import cell does not import os

    df_props = pd.read_csv(os.path.join(folder, 'seqprops.csv'), dtype={'id': str}).rename(columns={'id': 'seq_id'})
    print(f'Loaded properties of {df_props.shape[0]} sequences')
    df_params = pd.read_csv(os.path.join(folder, 'params.csv'), dtype={'seq_id': str})
    print(f'Loaded parameters of {df_params.shape[0]} sequences')
    df = pd.merge(df_props, df_params, on='seq_id')

    # detect invalid dG values once and reuse the mask for reporting and cleaning;
    # np.isinf also catches -inf, which an `== np.inf` test would miss
    dg = pd.to_numeric(df['dg'], errors='coerce')
    invalid = dg.isna() | np.isinf(dg) | (dg > 10)
    count = int(invalid.sum())
    # guard against an empty merge, which would otherwise divide by zero
    share = 100 * count / len(df) if len(df) else 0.0
    print(f'{count}, {share:.2f}% of dG values are missing or invalid')
    df['dg'] = dg.mask(invalid)
    return df


# compose the plots into a single SVG
def compose_plots(folder, property):
    """Assemble the individual SVG panels of one parameter into a single figure.

    Top row: ``{property}_GC``, ``{property}_dg``, ``{property}_hp`` and the
    ``{property}_firstlast`` panel. Bottom row: ``{property}_A`` .. ``_T``.
    Every panel except the left-most of a row gets a copy of
    ``./SI_figure_parameter_vs_seqprops/cover.svg`` placed over it (presumably
    to hide repeated axis labels). Bold letters a-h label the columns. The
    result is saved as ``{folder}/{property}_composed.svg``.
    """
    canvas = sg.SVGFigure("680px", "515px")
    canvas.set_size(("680px", "515px"))

    total_width = 680
    panel_width = 200
    crop_panel_width = 150
    panel_spacing = (total_width - panel_width - 3*crop_panel_width) / 3
    second_row_offset = 260
    cover_file = './SI_figure_parameter_vs_seqprops/cover.svg'

    def column_offset(column):
        # the left-most column keeps an integer 0 offset
        if column == 0:
            return 0
        return column*crop_panel_width + column*panel_spacing

    def add_row(names, y_offset):
        # panels are appended right-to-left, so left panels are painted over
        # their right neighbours (SVG paints later elements on top); keep this order
        for column in reversed(range(len(names))):
            x_offset = column_offset(column)
            element = sg.fromfile(f'{folder}/{property}_{names[column]}.svg').getroot()
            element.moveto(x_offset, y_offset)
            canvas.append(element)
            if column:
                # a fresh cover element is loaded each time; one element cannot be appended twice
                mask = sg.fromfile(cover_file).getroot()
                mask.moveto(x_offset+22, y_offset+20)
                canvas.append(mask)

    def add_labels(letters, y_offset):
        for column, letter in enumerate(letters):
            label = sg.TextElement(40 + column*crop_panel_width + column*panel_spacing, y_offset+10, letter, size=7, weight="bold", font="Inter")
            # wrapping in a list makes svgutils append the text inside a group element
            canvas.append([label])

    # first/last base panel sits to the right of the three cropped top-row panels
    firstlast = sg.fromfile(f'{folder}/{property}_firstlast.svg').getroot()
    firstlast.moveto(3*crop_panel_width + 3*panel_spacing + 50, 0)
    canvas.append(firstlast)

    add_row(["GC", "dg", "hp"], 0)
    add_row(["A", "C", "G", "T"], second_row_offset)
    add_labels("abcd", 0)
    add_labels("efgh", second_row_offset)

    canvas.save(f'{folder}/{property}_composed.svg')


# plot the discrete first and last base properties
def plot_discrete(df, property, yrange, colors, fillcolors, title):
    """Plot ``property`` as half-violins grouped by the first (left) and last (right) base.

    For each base and panel an annotation reports R² (squared Pearson
    correlation of the base indicator with ``property``) and Spearman ρ.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs the columns ``first``, ``last`` and ``property``.
    property : str
        Column drawn along the x axis.
    yrange : sequence of two floats
        Value range of ``property``; used as the x-axis range of both panels.
    colors, fillcolors : dict
        Outline and fill colours, indexed by 'A', 'C', 'G', 'T'.
    title : str
        Title of both x axes.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure after ``plotting.standardize_plot``.
    """
    fig = make_subplots(rows=1, cols=2, horizontal_spacing=0.1, shared_yaxes=True)
    label_x = yrange[0] + 0.3*(yrange[1]-yrange[0])

    for idx, nucleotide in enumerate(['T', 'G', 'C', 'A']):
        # column 1 groups by the first base, column 2 by the last base
        for column, end in ((1, 'first'), (2, 'last')):
            is_base = df[end] == nucleotide
            values = df.loc[is_base, property].dropna()
            fig.add_trace(
                go.Violin(x=values, line_color=colors[nucleotide], name=nucleotide, fillcolor=fillcolors[nucleotide]),
                row=1, col=column,
            )
            r2 = is_base.corr(df[property])**2
            rho = is_base.corr(df[property], method='spearman')
            fig.add_annotation(
                x=label_x,
                y=4-idx-1.4,
                text=f"<b>{nucleotide}</b><br>R<sup>2</sup> = {r2:.2f}<br>ρ = {rho:.2f}",
                col=column,
                row=1,
                font_color="black",
                align='right',
                showarrow=False,
            )

    # NOTE(review): the `xaxis_*` layout keys only reach the first panel's x axis
    # (the second is `xaxis2`) -- confirm that leaving its grid/zero line is intended
    fig.update_layout(xaxis_showgrid=False, xaxis_zeroline=False)
    for column in (1, 2):
        fig.update_xaxes(range=yrange, col=column, row=1, title=title)

    # dashed reference line at 1
    fig.add_vline(x=1, line_width=2, line_dash="dash", line_color="black")

    # horizontal half-violins on the negative side, without individual points
    fig.update_traces(orientation='h', side='negative', points=False, width=3, opacity=1)
    # reverse the trace (z-)order and the legend order together
    fig.data = fig.data[::-1]
    fig.update_layout(legend_traceorder='reversed')
    # the inverted y range flips the axis: negative violins now open upwards
    fig.update_yaxes(title="", range=[3, -2], autorange=False)
    fig.update_yaxes(col=1, row=1, title="Probability density", autorange=False)
    fig.update_yaxes(minor_dtick=20, dtick=20, showgrid=False, showticklabels=False)
    fig.update_layout(height=250, width=150, showlegend=False, margin=dict(l=0, r=10, t=10, b=0))
    return plotting.standardize_plot(fig)


# plot pseudo-continuous properties
def plot_density(df, var, property, yrange, xtitle, ytitle, histogram_step, histogram_color, density_colors, line_color):
    
    fig = make_subplots(rows=2, cols=1, shared_xaxes=True, vertical_spacing=0.003, row_heights=[0.25, 0.75])
    
    # data aggregation with datashader
    cvs = ds.Canvas(
        plot_width=int((df[var].max()-df[var].min())/histogram_step)+1, 
        plot_height=100, 
        y_range=yrange,
        x_range=[df[var].min()-histogram_step/2, df[var].max()+histogram_step/2]
    )
    agg = cvs.points(df, x=var, y=property)
    zero_mask = agg.values == 0
    agg.values = np.array(agg.values, dtype=float)
    agg.values[zero_mask] = np.nan

    # plot the density
    fig.add_trace(
        px.imshow(
        agg, 
        origin='lower',
    ).data[0], row=2, col=1)
    fig.update_xaxes(title=xtitle, row=2, col=1)
    fig.update_yaxes(range=yrange, autorange=False, title=ytitle, row=2, col=1)
    fig.update_layout(coloraxis_colorscale=density_colors)

    # plot the histogram
    fig.add_trace(
        px.histogram(
        df[var],
    ).data[0], row=1, col=1)
    fig.update_yaxes(showgrid=False, visible=False, row=1, col=1)
    fig.update_traces(
        marker_color=histogram_color, 
        marker_line_width=0, 
        marker_line_color="white", 
        xbins=dict(start=df[var].min()-histogram_step/2, end=df[var].max()+histogram_step/2, size=histogram_step),
        selector=dict(type="histogram"),
        row=1, col=1
    )

    # add the line at 1
    fig.add_hline(y=1, line_width=2, line_dash="dash", line_color="black", row=2, col=1)

    # add regression line
    model = smf.ols(f'{property} ~ {var}', data=df).fit()
    x = np.linspace(df[var].min(), df[var].max(), 100)
    y = model.params.iloc[1]*x + model.params.iloc[0]
    fig.add_trace(go.Scatter(
        x=x, 
        y=y, 
        mode='lines', 
        line_color=line_color, 
        line_width=2),
    row=2, col=1)

    # add the R2 and correlation
    rsquared = df[var].corr(df[property])**2
    corr = df[var].corr(df[property], method='spearman')
    fig.add_annotation(
        x=df[var].min()+0.85*(df[var].max()-df[var].min()),
        y=yrange[0]+0.15*(yrange[1]-yrange[0]),
        text=f"R<sup>2</sup> = {rsquared:.2f}<br>ρ = {corr:.2f}",
        font_color="black",
        align='right',
        showarrow=False,
        row=2, col=1
    )

    # layout changes
    l_margin = 0 if property == "eff" else 50
    fig.update_layout(
        height=250,
        width=200,
        showlegend=False,
        coloraxis_showscale=False,
        margin=dict(l=l_margin, r=10, t=0, b=0),
    )
    fig = plotting.standardize_plot(fig)
    return fig


# caller script for all plotting functions
def plot(df, property, yrange, folder=".", clip=None):
    """Render every SI panel of one parameter and compose them into one SVG.

    Writes ``{folder}/{property}_{var}.svg`` for each sequence property
    (via ``plot_density``), then ``{folder}/{property}_firstlast.svg`` (via
    ``plot_discrete``), and finally calls ``compose_plots``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain ``length``, the sequence-property columns and
        ``property``. It is copied, not modified.
    property : str
        ``'eff'`` or ``'x0'``.
    yrange : sequence of two floats
        Axis range of ``property``.
    folder : str
        Output directory.
    clip : tuple, optional
        ``(lower, upper)`` bounds applied to ``property`` on the copy.
    """
    # frequency-type properties are binned in steps of one nucleotide of the mean length
    step = 1/df['length'].mean()

    # var: (density colour scale, line colour, fill colour, axis label, bin size)
    table = {
        'A':  (['#fee0d2', '#fb6a4a'], '#cb181d', '#fb6a4a', 'Frequency of A', step),
        'C':  (['#deebf7', '#08519c'], '#2171b5', '#6baed6', 'Frequency of C', step),
        'G':  (['#fee6ce', '#a63603'], '#d94701', '#fd8d3c', 'Frequency of G', step),
        'T':  (['#e5f5e0', '#006d2c'], '#238b45', '#74c476', 'Frequency of T', step),
        'GC': (['#bdbdbd', '#000000'], 'black', '#969696', 'GC content', step),
        'dg': (['#bdbdbd', '#000000'], 'black', '#969696', 'ΔG / kcal mol<sup>-1</sup>', 1),
        'hp': (['#bdbdbd', '#000000'], 'black', '#969696', 'Longest homopolymer / nt', 1),
    }
    linecolors = {key: row[1] for key, row in table.items()}
    fillcolors = {key: row[2] for key, row in table.items()}
    axis_title = {'eff': 'Relative efficiency', 'x0': 'Initial coverage'}[property]

    # clip the range of the property on a private copy
    df = df.copy()
    if clip:
        print(f'Clipping {property} to {clip}')
        df[property] = df[property].clip(*clip)

    for var in ('GC', 'A', 'C', 'G', 'T', 'hp', 'dg'):
        density_colors, line_color, fill_color, xlabel, bin_size = table[var]
        panel = plot_density(df, var, property, yrange, xlabel, axis_title, bin_size, fill_color, density_colors, line_color)
        panel.write_image(f'{folder}/{property}_{var}.svg')

    # discrete variables
    discrete = plot_discrete(df, property, yrange, linecolors, fillcolors, axis_title)
    discrete.write_image(f'{folder}/{property}_firstlast.svg')

    # compose all subpanels
    compose_plots(folder, property)

# GCall

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/GCall'
data = collect_data('../data/internal_datasets/GCall/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.96, 1.02)), ('x0', (0, 3))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)

# GCfix

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/GCfix'
data = collect_data('../data/internal_datasets/GCfix/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.96, 1.02)), ('x0', (0, 3))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)

# Koch et al

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/Koch_et_al'
data = collect_data('../data/external_datasets/Koch_et_al/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.96, 1.04)), ('x0', (0, 4))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)

# Choi et al

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/Choi_et_al'
data = collect_data('../data/external_datasets/Choi_et_al/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.98, 1.02)), ('x0', (0, 4))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)

# Gao et al

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/Gao_et_al'
data = collect_data('../data/external_datasets/Gao_et_al/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.95, 1.05)), ('x0', (0, 4))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)

# Song et al

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/Song_et_al'
data = collect_data('../data/external_datasets/Song_et_al/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.95, 1.05)), ('x0', (0, 4))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)

# Erlich et al

In [None]:
out_dir = './SI_figure_parameter_vs_seqprops/Erlich_et_al'
data = collect_data('../data/external_datasets/Erlich_et_al/')
data.to_csv(f'{out_dir}/data.csv', index=False)

# (parameter, axis range); the same range is used to clip the parameter
for parameter, limits in (('eff', (0.97, 1.03)), ('x0', (0, 3))):
    plot(data, parameter, list(limits), clip=limits, folder=out_dir)