# lib_v_lib_scatterPlots.ipynb
### Marcus Viscardi,    August 31, 2023

A simple script whose goal is to look at read-count differences between libraries.

In [None]:
import sys
import warnings

from tqdm.notebook import tqdm

import seaborn as sea
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

import numpy as np
import pandas as pd
import statistics as stats
from pathlib import Path

sys.path.insert(0, '/data16/marcus/scripts/nanoporePipelineScripts')
import nanoporePipelineCommon as npCommon

# Use plotly's "browser" renderer unless a call to show() overrides it.
pio.renderers.default = "browser"

# Give pandas more room when printing frames: wider lines, no cap on the number of columns shown.
for pandas_option, option_value in (('display.width', 200),
                                    ('display.max_columns', None)):
    pd.set_option(pandas_option, option_value)

# Local aliases for the two lookup tables exported by nanoporePipelineCommon.
# NOTE(review): judging from how they are used below, CONVERSION_DICT maps the names stored in the
# 'lib' column to the short names in libs_to_load, and REV_CONVERSION_DICT is the inverse -- confirm.
CONVERSION_DICT, REV_CONVERSION_DICT = npCommon.CONVERSION_DICT, npCommon.REV_CONVERSION_DICT

print(f"Imports done at {npCommon.get_dt(for_print=True)}")

In [None]:
# Set to True to ignore any cached parquets and rebuild the merged dataframes from scratch.
regenerate = False
# A set de-duplicates the names; sorting makes the cache file name independent of the order typed here.
libs_to_load = sorted({
    'oldN2',
    'oldS6',
    'newerN2',
    'newerS6',
    'newerS5',
    'thirdN2',
    'thirdS5',
    'thirdS6',
})

# Cached merges all live in one directory and are keyed by the dash-joined library names.
merged_parquet_dir = "./output_files/mega_merge_parquets"
libs_tag = '-'.join(libs_to_load)

try:
    if regenerate:
        # Deliberately jump straight to the rebuild branch below.
        raise FileNotFoundError

    # NOTE(review): this relies on find_newest_matching_file raising FileNotFoundError when nothing
    # matches the glob -- confirm against nanoporePipelineCommon.
    reads_df_raw_path = npCommon.find_newest_matching_file(f"{merged_parquet_dir}/*_{libs_tag}_merged5TERA.reads_df.parquet")
    compressed_df_genes_raw_path = npCommon.find_newest_matching_file(f"{merged_parquet_dir}/*_{libs_tag}_merged5TERA.compressed_df.parquet")
    print(f"Found preprocessed files at:\n\t{reads_df_raw_path}\nand:\n\t{compressed_df_genes_raw_path}")

    reads_df_genes_raw = pd.read_parquet(reads_df_raw_path)
    compressed_df_genes_raw = pd.read_parquet(compressed_df_genes_raw_path)
except FileNotFoundError:
    print(f"Could not find preprocessed files matching these libs: {'/'.join(libs_to_load)}\nGoing to create new ones from scratch! This will take longer.")
    # Make sure the cache directory exists *before* the slow merge, so a missing folder cannot
    # cause a failure only once all of the work is done.
    Path(merged_parquet_dir).mkdir(parents=True, exist_ok=True)
    reads_df_genes_raw, compressed_df_genes_raw = npCommon.load_and_merge_lib_parquets([REV_CONVERSION_DICT[lib] for lib in libs_to_load],
                                                                                       drop_sub_n=1,
                                                                                       add_tail_groupings=False,
                                                                                       drop_failed_polya=False,
                                                                                       group_by_t5=True,
                                                                                       use_josh_assignment=False)
    print(f"Saving new parquets to speed up future runs.")
    # Take the timestamp once so both files of the pair always share the same prefix.
    save_stamp = npCommon.get_dt()
    reads_df_genes_raw.to_parquet(f"{merged_parquet_dir}/{save_stamp}_{libs_tag}_merged5TERA.reads_df.parquet")
    compressed_df_genes_raw.to_parquet(f"{merged_parquet_dir}/{save_stamp}_{libs_tag}_merged5TERA.compressed_df.parquet")
print(f"Lib load done @ {npCommon.get_dt(for_print=True)}")

# Select the handful of needed columns first, then copy: same result as copying the whole frame
# and subsetting afterwards, without duplicating every column in memory.
compressed_df_genes_short = compressed_df_genes_raw[["lib", "chr_id", "gene_id", "gene_name", "t5", "gene_hits", "gene_rpm"]].copy()
# Quick sanity check on a single gene (displayed as the cell's output).
compressed_df_genes_short.query("gene_name == 'rpl-12'")

In [None]:
conversion_dict = CONVERSION_DICT
# Columns identifying a gene (these become the shared index) and the per-library measurements.
gene_id_cols = ["chr_id", "gene_id", "gene_name"]
value_cols = ["gene_hits", "gene_rpm"]

# One narrow frame per (library, t5) group, keyed by (short library name, t5), with the value
# columns suffixed so they stay distinct once everything is concatenated side by side.
df_dict = {}
for _, group_df in compressed_df_genes_short.groupby(['lib', 't5'], as_index=False):
    short_lib = conversion_dict[group_df["lib"].unique()[0]]
    t5 = group_df["t5"].unique()[0]
    column_suffix = f"_{short_lib}_t5{t5}"
    renamed_df = group_df[gene_id_cols + value_cols].rename(
        columns={col: f"{col}{column_suffix}" for col in value_cols}
    )
    df_dict[(short_lib, t5)] = renamed_df.set_index(gene_id_cols)

# Outer join on the gene index; a gene absent from a given library/t5 group gets 0 there.
super_df = pd.concat(df_dict.values(), axis=1, join='outer').fillna(0)
super_df

In [None]:
# Rebuild the total RPM of each gene per library by adding its adapted (t5+) and unadapted (t5-) RPM columns.
for lib in libs_to_load:
    per_t5_columns = [f"gene_rpm_{lib}{t5_suffix}" for t5_suffix in ("_t5+", "_t5-")]
    super_df[f"gene_rpm_{lib}"] = super_df[per_t5_columns].sum(axis=1)

In [None]:
from plotly.subplots import make_subplots
from plotly import graph_objects as go
# Genes left out of every plot.
genes_to_exclude = ['xrn-1', 'rrn-2.1', 'F23A7.4', 'F23A7.8', 'unNamed']
# Boolean-mask selection returns a new frame, so super_df itself is left untouched.
keep_gene = ~super_df.index.get_level_values('gene_name').isin(genes_to_exclude)
plot_df = super_df[keep_gene].sort_index()

def plot_rockets(l1, l2, plotting_df, save_dir=None, force_limits=False, renderer='firefox'):
    """Draw three side-by-side log-log scatter plots comparing gene RPM between two libraries.

    Left to right the panels show total, unadapted (``t5-``) and adapted (``t5+``) RPM, with
    ``l1`` on the x axis and ``l2`` on the y axis. Each point is one row of ``plotting_df``.

    Args:
        l1: Short library name for the x axes; columns ``gene_rpm_{l1}``, ``gene_rpm_{l1}_t5-``
            and ``gene_rpm_{l1}_t5+`` must exist in ``plotting_df``.
        l2: Short library name for the y axes (same column requirements as ``l1``).
        plotting_df: DataFrame whose index contains a ``gene_name`` level (used as hover text).
        save_dir: If truthy, directory to write ``{l1}_v_{l2}_scatters.html/.png/.svg`` into.
            It is created (with parents) if missing, after emitting a warning.
        force_limits: If True, pin both axes of each panel to a fixed range. Plotly expresses
            log-axis ranges as log10 exponents, so ``[0.5, 4.5]`` (total and unadapted panels)
            is roughly 3.2 to 31,600 RPM and ``[0, 3]`` (adapted panel) is 1 to 1,000 RPM.
        renderer: Plotly renderer name passed to ``fig.show``. Defaults to ``'firefox'`` to keep
            the previous behaviour; pass ``None`` to use ``pio.renderers.default``.

    Returns:
        The assembled plotly figure.
    """
    fig = make_subplots(rows=1, cols=3,
                        subplot_titles=[f"{l1} vs {l2} RPM (total)",
                                        f"{l1} vs {l2} RPM (unadapted)",
                                        f"{l1} vs {l2} RPM (adapted)",
                                        ],
                        row_heights=[500],
                        column_widths=[500, 500, 500],
                        )

    # Flatten the index once (so 'gene_name' is a plain column) instead of once per panel.
    flat_df = plotting_df.reset_index()
    # Styling shared by every x and y axis.
    log_axis_style = dict(ticks="inside", ticklen=5, showgrid=True, gridcolor='lightgrey', type='log',
                          minor=dict(ticks="inside", ticklen=5, showgrid=True))

    for panel_col, t5 in enumerate(['', '_t5-', '_t5+'], start=1):
        # Build the panel with plotly express, then move its single trace into the subplot grid.
        panel = px.scatter(flat_df,
                           x=f"gene_rpm_{l1}{t5}",
                           y=f"gene_rpm_{l2}{t5}",
                           hover_name="gene_name",
                           )
        fig.add_trace(panel.data[0], row=1, col=panel_col)

        axis_tag = t5.strip('_')
        fig.update_xaxes(title=f"{l1} RPM {axis_tag}", row=1, col=panel_col, **log_axis_style)
        fig.update_yaxes(title=f"{l2} RPM {axis_tag}", row=1, col=panel_col, **log_axis_style)

        if force_limits:
            # The adapted panel gets a lower window than the total/unadapted panels.
            limits = [0, 3] if t5 == '_t5+' else [0.5, 4.5]
            fig.update_xaxes(range=limits, row=1, col=panel_col)
            fig.update_yaxes(range=limits, row=1, col=panel_col)

    fig.update_traces(marker=dict(size=5, color='black'))
    fig.update_layout(height=500,
                      width=1500,
                      template='none')

    if save_dir:
        out_dir = Path(save_dir)
        if not out_dir.exists():
            warnings.warn(f"Save directory doesn't exist! Making it now at: {save_dir}")
        # exist_ok avoids failing if the directory appears between the check and the creation.
        out_dir.mkdir(parents=True, exist_ok=True)
        file_stem = f"{save_dir}/{l1}_v_{l2}_scatters"
        fig.write_html(f"{file_stem}.html")
        for image_ext in ("png", "svg"):
            fig.write_image(f"{file_stem}.{image_ext}")

    fig.show(renderer=renderer)
    return fig

def plot_rocket_grid(libs, plotting_df, save_dir=None, force_limits=True):
    """Show a scatter-plot matrix (lower half only) of total gene RPM across several libraries.

    Args:
        libs: Short library names; a ``gene_rpm_{lib}`` column must exist for each. The panels
            are laid out in sorted order, while the title and the output file names use the
            order given.
        plotting_df: DataFrame whose index contains a ``gene_name`` level (used as point text).
        save_dir: If truthy, directory to write ``{libs joined by '-'}_scatters.html/.png/.svg``
            into. It is created (with parents) if missing, after emitting a warning.
        force_limits: If True, set every axis range to ``[0.5, 4.5]``.

    Returns:
        None. The figure is displayed with the default renderer.
    """
    ordered_libs = sorted(libs)
    rpm_columns = [f"gene_rpm_{lib}" for lib in ordered_libs]
    plotting_df = plotting_df[rpm_columns]

    # One Splom dimension per library, in sorted order.
    splom_dimensions = []
    for lib, rpm_column in zip(ordered_libs, rpm_columns):
        splom_dimensions.append(dict(label=f"{lib}", values=plotting_df[rpm_column]))

    splom_trace = go.Splom(
        dimensions=splom_dimensions,
        showupperhalf=False,
        text=plotting_df.index.get_level_values('gene_name'),
        marker=dict(color='black', size=5, opacity=0.5),
    )
    fig = go.Figure(data=splom_trace)
    fig.update_layout(
        title=f"Gene RPMs for {', '.join(libs)}",
        width=1000,
        height=1000,
    )

    # The same log-axis styling is applied to every x and y axis of the matrix.
    axis_style = dict(ticks="inside", ticklen=5, showgrid=True, gridcolor='lightgrey', type='log',
                      minor=dict(ticks="inside", ticklen=5, showgrid=True))
    if force_limits:
        axis_style['range'] = [0.5, 4.5]

    axis_numbers = range(1, len(libs) + 1)
    axis_names = ['xaxis', 'yaxis']
    axis_names += [f'xaxis{n}' for n in axis_numbers]
    axis_names += [f'yaxis{n}' for n in axis_numbers]
    fig.update_layout(**dict.fromkeys(axis_names, axis_style))

    if save_dir:
        target_dir = Path(save_dir)
        if not target_dir.exists():
            warnings.warn(f"Save directory doesn't exist! Making it now at: {save_dir}")
            target_dir.mkdir(parents=True)
        file_stem = f"{save_dir}/{'-'.join(libs)}_scatters"
        fig.write_html(f"{file_stem}.html")
        fig.write_image(f"{file_stem}.png")
        fig.write_image(f"{file_stem}.svg")
    fig.show()

In [None]:
# Library pairs to compare as (x-axis library, y-axis library); (un)comment entries to choose plots.
# To compare every ordered pair instead, use:
# lib_combinations = [(l1, l2) for l1 in libs_to_load for l2 in libs_to_load if l1 != l2]
lib_combinations = [
    # ('oldS6', 'newerS6'),
    # ('oldN2', 'newerN2'),
    # ('oldN2', 'oldS6'),
    ('newerN2', 'newerS6'),
    # ('newerN2', 'newerS5'),
    ('newerS6', 'thirdS6'),
    ('newerN2', 'thirdS6'),
]

# Timestamped folder that receives the saved figures.
output_directory = f"/home/marcus/Insync/mviscard@ucsc.edu/Google Drive/insync_folder/NMD_cleavage_and_deadenylation_paper/raw_figures_from_python/{npCommon.get_dt()}_scatterPlots"

for lib_pair in lib_combinations:
    plot_rockets(*lib_pair, plot_df, save_dir=output_directory, force_limits=True)

In [None]:
# Libraries to include in the scatter-plot matrix.
libs_to_plot = [
    'newerN2', 'newerS6', 'newerS5',
    'thirdN2', 'thirdS5', 'thirdS6',
]

plot_rocket_grid(libs=libs_to_plot, plotting_df=plot_df, save_dir=output_directory, force_limits=True)