In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import pandas as pd
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import matplotlib.pyplot as plt
import logomaker

import sys
sys.path.append('../')
import plotting

# Plot performance with motif substitution

## AUPRC

In [None]:
# Trace colours, looked up by dataset name ('GCall' / 'GCfix').
# Both AUPRC baselines are drawn in the same gray.
color_baseline = {'GCall': 'gray', 'GCfix': 'gray'}
color_points = {'GCall': '#de2d26', 'GCfix': '#3182bd'}
color_line = {'GCall': '#fb6a4a', 'GCfix': '#6baed6'}
# x-axis ranges of the left and right subplot (the x-axis is split into two panels).
cut_interval = [[0, 10.3], [10, 30]]

# One figure per validation setting. Every run is a (colour key, pickle stem) pair:
#   internal -> "<dataset>_motif_substitution_performance.pkl"
#   external -> "<source>to<target>_motif_substitution_performance.pkl", coloured by <source>
auprc_runs = {
    "internal": [(name, name) for name in ["GCall", "GCfix"]],
    "external": [
        (source, "{}to{}".format(source, target))
        for source, target in [["GCall", "GCfix"], ["GCfix", "GCall"]]
    ],
}

for validation, runs in auprc_runs.items():
    fig = make_subplots(
        rows=1,
        cols=2,
        horizontal_spacing=0.075,
        column_widths=[0.55, 0.45],
        shared_yaxes=True,
    )

    for color_key, file_stem in runs:
        # The pickle unpacks as (AUPRC, baseline AUPRC, AUROC, baseline AUROC);
        # only the AUPRC pair is plotted in this cell.
        auprc_replaced, baseline_auprc_replaced, _, _ = pd.read_pickle(
            '../data/machine_learning_results/{}_motif_substitution_performance.pkl'.format(file_stem)
        )

        # Markers are only drawn in the left panel (see add_traces below).
        model_scatter_points = go.Scatter(
            x=auprc_replaced.index,
            y=auprc_replaced.values,
            mode='markers',
            marker=dict(color=color_points[color_key], size=5),
        )
        # The connecting line is split in two pieces that share the 11th data point,
        # one piece per panel.
        # NOTE(review): presumably the first 11 entries are 0..10 substitutions - confirm.
        model_scatter_line_left = go.Scatter(
            x=auprc_replaced.index[:11],
            y=auprc_replaced.values[:11],
            mode='lines',
            marker=dict(color=color_line[color_key]),
        )
        model_scatter_line_right = go.Scatter(
            x=auprc_replaced.index[10:],
            y=auprc_replaced.values[10:],
            mode='lines',
            marker=dict(color=color_line[color_key]),
        )
        baseline_scatter = go.Scatter(
            x=baseline_auprc_replaced.index,
            y=baseline_auprc_replaced.values,
            mode='lines',
            marker=dict(color=color_baseline[color_key]),
        )

        fig.add_traces([baseline_scatter, model_scatter_line_left, model_scatter_points], rows=1, cols=1)
        fig.add_traces([baseline_scatter, model_scatter_line_right], rows=1, cols=2)

    # Axis settings do not depend on the run, so they are applied once per figure.
    fig.update_xaxes(range=cut_interval[0], dtick=2, minor_dtick=1, row=1, col=1)
    fig.update_xaxes(range=cut_interval[1], dtick=10, minor_dtick=5, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)

    fig.update_layout(
        xaxis_title="Number of substituted motifs",
        yaxis_title="AUPRC",
        showlegend=False,
        margin=dict(l=0, r=10, t=10, b=0),
        width=180,
        height=180,
    )
    fig.update_yaxes(range=[0, 0.5], dtick=0.25)
    fig = plotting.standardize_plot(fig)
    fig.show()
    fig.write_image("./figure_4_motif_replacement/substitution_{}_auprc.svg".format(validation))

## AUROC

In [None]:
# Trace colours, looked up by dataset name ('GCall' / 'GCfix').
color_baseline = {'GCall': '#b0887a', 'GCfix': '#909da5'}
color_points = {'GCall': '#de2d26', 'GCfix': '#3182bd'}
color_line = {'GCall': '#fb6a4a', 'GCfix': '#6baed6'}
# x-axis ranges of the left and right subplot (the x-axis is split into two panels).
cut_interval = [[0, 10.3], [10, 30]]

# One figure per validation setting. Every run is a (colour key, pickle stem) pair:
#   internal -> "<dataset>_motif_substitution_performance.pkl"
#   external -> "<source>to<target>_motif_substitution_performance.pkl", coloured by <source>
auroc_runs = {
    "internal": [(name, name) for name in ["GCall", "GCfix"]],
    "external": [
        (source, "{}to{}".format(source, target))
        for source, target in [["GCall", "GCfix"], ["GCfix", "GCall"]]
    ],
}

for validation, runs in auroc_runs.items():
    fig = make_subplots(
        rows=1,
        cols=2,
        horizontal_spacing=0.075,
        column_widths=[0.55, 0.45],
        shared_yaxes=True,
    )

    for color_key, file_stem in runs:
        # The pickle unpacks as (AUPRC, baseline AUPRC, AUROC, baseline AUROC);
        # only the AUROC pair is plotted in this cell.
        _, _, auroc_replaced, baseline_auroc_replaced = pd.read_pickle(
            '../data/machine_learning_results/{}_motif_substitution_performance.pkl'.format(file_stem)
        )

        # Markers are only drawn in the left panel (see add_traces below).
        model_scatter_points = go.Scatter(
            x=auroc_replaced.index,
            y=auroc_replaced.values,
            mode='markers',
            marker=dict(color=color_points[color_key], size=5),
        )
        # The connecting line is split in two pieces that share the 11th data point,
        # one piece per panel.
        # NOTE(review): presumably the first 11 entries are 0..10 substitutions - confirm.
        model_scatter_line_left = go.Scatter(
            x=auroc_replaced.index[:11],
            y=auroc_replaced.values[:11],
            mode='lines',
            marker=dict(color=color_line[color_key]),
        )
        model_scatter_line_right = go.Scatter(
            x=auroc_replaced.index[10:],
            y=auroc_replaced.values[10:],
            mode='lines',
            marker=dict(color=color_line[color_key]),
        )
        baseline_scatter = go.Scatter(
            x=baseline_auroc_replaced.index,
            y=baseline_auroc_replaced.values,
            mode='lines',
            marker=dict(color=color_baseline[color_key]),
        )

        fig.add_traces([baseline_scatter, model_scatter_line_left, model_scatter_points], rows=1, cols=1)
        fig.add_traces([baseline_scatter, model_scatter_line_right], rows=1, cols=2)

    # Axis settings do not depend on the run, so they are applied once per figure.
    fig.update_xaxes(range=cut_interval[0], dtick=2, minor_dtick=1, row=1, col=1)
    fig.update_xaxes(range=cut_interval[1], dtick=10, minor_dtick=5, row=1, col=2)
    fig.update_yaxes(visible=False, row=1, col=2)

    fig.update_layout(
        xaxis_title="Number of substituted motifs",
        yaxis_title="AUROC",
        showlegend=False,
        margin=dict(l=0, r=10, t=10, b=0),
        width=180,
        height=180,
    )
    fig.update_yaxes(range=[0.5, 1.0], dtick=0.25)
    fig = plotting.standardize_plot(fig)
    fig.show()
    fig.write_image("./figure_4_motif_replacement/substitution_{}_auroc.svg".format(validation))

# Visualize the replaced motifs

In [None]:
def get_top_k_pwm_motifs(filename):
    """Return the significant enriched motifs of one dataset, most significant first.

    Reads ``../data/machine_learning_results/{filename}_2perc_enriched_motifs_whole_sequence.pkl``,
    which is iterated as records of
    ``(pwm, chi2, p_value, cluster_idx, counts_positive, counts_negative)``.
    A record is kept when both conditions hold:

    * ``p_value`` is below ``0.05 / number_of_records`` (Bonferroni-style correction), and
    * the first entry of ``counts_positive`` exceeds one fifth of its sum.

    Parameters
    ----------
    filename : str
        Dataset prefix inserted into the pickle file name.

    Returns
    -------
    list of (p_value, pwm) tuples
        Sorted by ascending p-value. Motifs that share the same p-value are all
        kept, in the order they appear in the file. They used to overwrite each
        other because the intermediate container was a dict keyed by p-value.
        An empty input file yields an empty list.
    """
    # NOTE: unpickling executes arbitrary code; only load trusted result files.
    motif_by_pvalue = pd.read_pickle(
        "../data/machine_learning_results/{}_2perc_enriched_motifs_whole_sequence.pkl".format(filename)
    )
    if len(motif_by_pvalue) == 0:
        # Nothing to correct for; also avoids dividing by zero below.
        return []

    corrected_alpha = 0.05 / len(motif_by_pvalue)
    significant_motifs = [
        (p_value, pwm)
        for pwm, _chi2, p_value, _cluster_idx, counts_positive, _counts_negative in motif_by_pvalue
        if p_value < corrected_alpha and counts_positive[0] > (counts_positive.sum() / 5)
    ]

    # Sort on the p-value only: the sort is stable for ties and never has to
    # compare two PWM arrays with each other.
    significant_motifs.sort(key=lambda item: item[0])
    return significant_motifs

# Nucleotide -> hex colour; passed as `color_scheme` to every logomaker.Logo call below.
color_scheme = dict(A='#31a354', C='#3182bd', G='#fd8d3c', T='#de2d26')

### Top 3 motifs of GCall and GCfix

In [None]:
# Render the most significant motifs of GCall and GCfix as sequence logos
# (at most three per dataset) and save each one as an SVG.
for filename in ['GCall', 'GCfix',]:
    motif_list = get_top_k_pwm_motifs(filename)

    # A dataset can yield fewer than three significant motifs; cap the loop
    # instead of indexing past the end of the list.
    for k in range(min(len(motif_list), 3)):
        p_value, pwm = motif_list[k]
        # NOTE(review): assumes pwm has one row per base (A, C, G, T) and one
        # column per position, hence the transpose - confirm.
        pwm = pd.DataFrame(pwm.T, columns=['A', 'C', 'G', 'T'])
        # Information content per position: 2 bits minus the Shannon entropy
        # (the 1e-9 keeps log2 finite for zero probabilities).
        ic_df = pwm.apply(
            lambda row: 2 + np.sum(row * np.log2(row + 1e-9)), axis=1
        )  # apply entropy-based representation
        # Scale every base probability by the information content of its position.
        pwm = pwm.multiply(ic_df, axis=0)
        plt.figure(figsize=(1.5, 0.25*1.5))
        logo = logomaker.Logo(
            pwm,
            ax=plt.gca(),
            font_name='Inter',
            color_scheme=color_scheme,
        )
        # Strip spines and both axes so only the letters remain.
        logo.style_spines(visible=False)
        plt.gca().get_yaxis().set_visible(False)
        plt.gca().get_xaxis().set_visible(False)
        plt.savefig('./figure_4_motif_replacement/{}_motif_{}.svg'.format(filename, k), bbox_inches='tight')
        
        print(filename, k+1, p_value)
        plt.show()
        print("\n")

### Top-scoring motifs in all datasets

In [None]:
# For every dataset, draw up to three of its most significant motifs as small
# logos for the motif table and report how many significant motifs were found.
for filename in ['GCall', 'GCfix', 'Koch_et_al', 'Erlich_et_al', 'Song_et_al', 'Choi_et_al', 'Gao_et_al']:
    ranked_motifs = get_top_k_pwm_motifs(filename)
    n_significant = len(ranked_motifs)
    print(filename, n_significant)

    # Slicing copes with datasets that have fewer than three motifs.
    for rank, (p_value, weights) in enumerate(ranked_motifs[:3]):
        probabilities = pd.DataFrame(weights.T, columns=['A', 'C', 'G', 'T'])
        # Per-position information content: 2 bits minus the Shannon entropy.
        bits_per_position = probabilities.apply(
            lambda position: 2 + np.sum(position * np.log2(position + 1e-9)), axis=1
        )
        logo_heights = probabilities.multiply(bits_per_position, axis=0)

        plt.figure(figsize=(1, 0.25*1))
        axes = plt.gca()
        logo = logomaker.Logo(
            logo_heights,
            ax=axes,
            font_name='Inter',
            color_scheme=color_scheme,
        )
        # Keep only the letters: no spines, no axes.
        logo.style_spines(visible=False)
        axes.get_yaxis().set_visible(False)
        axes.get_xaxis().set_visible(False)
        plt.savefig('./figure_5_generalization/{}_motif_table_{}.svg'.format(filename, rank), bbox_inches='tight', transparent=True)

        print(filename, rank+1, n_significant, p_value)
        plt.show()
        print("\n")

# Illustrative PWM logos for workflow part of figure

In [None]:
def plot_pwm(pos_base, filename, len=10):
    """Draw a hand-specified, single-letter-per-position logo and save it as SVG.

    Parameters
    ----------
    pos_base : dict
        Maps a position to a ``(base, height)`` pair; the named base gets that
        height at that position, every other cell stays 0.
    filename : str
        Suffix of the output file
        ``./figure_4_motif_replacement/illustration_motif_{filename}.svg``.
    len : int, default 10
        Number of positions (rows) in the logo matrix. The name shadows the
        builtin ``len`` inside this function.
    """
    heights = pd.DataFrame(np.zeros((len, 4)), columns=['A', 'C', 'G', 'T'])
    for position in pos_base:
        letter, value = pos_base[position]
        heights.loc[position, letter] = value

    plt.figure(figsize=(1.5, 0.25*1.5))
    axes = plt.gca()
    logo = logomaker.Logo(
        heights,
        ax=axes,
        font_name='Inter',
        color_scheme=color_scheme,
        shade_below=.5,
        fade_below=.5,
    )
    # Keep only the letters: no spines, no axes.
    logo.style_spines(visible=False)
    axes.get_yaxis().set_visible(False)
    axes.get_xaxis().set_visible(False)
    plt.savefig('./figure_4_motif_replacement/illustration_motif_{}.svg'.format(filename), bbox_inches='tight')
    plt.show()

# Hand-picked (base, height) pairs for the workflow illustration. Each list is
# turned into a {position: (base, height)} dict by enumerating it from 0.

# Two 16-position windows.
pos_base = dict(enumerate([
    ('C', -0.2), ('T', 0.8), ('C', 1.0), ('G', 0.9),
    ('T', 0.8), ('G', 0.8), ('T', 0.7), ('A', 0.4),
    ('C', -0.2), ('T', -0.3), ('T', -0.2), ('T', -0.1),
    ('T', 0), ('A', -0.3), ('C', -0.1), ('T', 0.3),
]))
plot_pwm(pos_base, 'window1', len=16)

pos_base = dict(enumerate([
    ('C', -0.1), ('A', 0.1), ('C', -0.3), ('A', 0.1),
    ('C', -0.2), ('G', -0.3), ('C', 0.1), ('C', -0.2),
    ('T', 0.3), ('C', 0.8), ('G', 1.0), ('T', 0.5),
    ('G', 0.7), ('T', 0.3), ('C', 0), ('C', -0.2),
]))
plot_pwm(pos_base, 'window2', len=16)

# Three 6-position cluster motifs.
pos_base = dict(enumerate([
    ('C', 0.6), ('G', 0.8), ('T', 1.0), ('G', 0.7), ('T', 0.4), ('A', 0.3),
]))
plot_pwm(pos_base, 'cluster1', len=6)

pos_base = dict(enumerate([
    ('G', 0.6), ('T', 0.2), ('C', 0.5), ('G', 1.0), ('T', 0.8), ('G', 0.8),
]))
plot_pwm(pos_base, 'cluster2', len=6)

pos_base = dict(enumerate([
    ('T', 0.6), ('C', 1.0), ('G', 0.8), ('T', 0.8), ('G', 1.0), ('T', 0.7),
]))
plot_pwm(pos_base, 'cluster3', len=6)

# Compose motif and positional histogram

In [None]:
import svgutils.transform as sg
from IPython.display import SVG, display

# Canvas size and grid geometry, all in pixels: three panels per row,
# one row per dataset.
fig = sg.SVGFigure("680px", "780px")
fig.set_size(("680px", "780px"))
total_width = 680
panel_width = 210
total_whitespace = total_width - 3*panel_width
panel_spacing = total_whitespace/3
row_offset = 130

# for each dataset, show the top 3 motifs
for i, dataset in enumerate(['GCall', 'GCfix', 'Koch_et_al', 'Erlich_et_al', 'Song_et_al', 'Gao_et_al']):
    # Build the display name outside the f-string: reusing double quotes inside
    # a double-quoted f-string is a SyntaxError before Python 3.12.
    dataset_label = dataset.replace("_et_al", " et al.")

    for j in range(3):
        # try to get the histogram panel, some datasets have less than 3 motifs
        try:
            panel = sg.fromfile(f'./SI_figure_motifs/{dataset}_{j}_motif_positions_histogram.svg').getroot()
        except FileNotFoundError:
            print(f"Did not find {dataset}_{j}_motif_positions_histogram.svg")
            continue
        panel.moveto(j*panel_width + j*panel_spacing, i*row_offset+10)
        fig.append(panel)

        # get the motif panel (placed over the right part of the histogram panel)
        panel = sg.fromfile(f'./figure_5_generalization/{dataset}_motif_table_{j}.svg').getroot()
        panel.moveto((j+0.6)*panel_width + j*panel_spacing, i*row_offset-5)
        fig.append(panel)

        # show title
        title = sg.TextElement(10, 10, f"{dataset_label} #{j+1}", size=28/3, font="Inter", weight="normal")
        title.moveto((j+0.25)*panel_width + j*panel_spacing, i*row_offset)
        fig.append(title)


fig.save('./SI_figure_motifs/composed.svg')
display(SVG(filename='./SI_figure_motifs/composed.svg'))