# quickRocketPlot.ipynb
## Marcus Viscardi,    July 03, 2024

I want a quicker way to generate the rocket plots (library-vs-library log-log scatter plots of per-gene abundance) from `initialPlanningAndTests.ipynb`.



In [None]:
import pandas as pd
from pathlib import Path

import seaborn as sea
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

import nanoporePipelineCommon as npCommon

from icecream import ic
from datetime import datetime

def __time_formatter__():
    """Build the prefix icecream prints before each message.

    Returns:
        A string of the form ``ic: YYYY-MM-DD HH:MM:SS | > `` using the local
        time at the moment of the call.
    """
    # datetime's format-spec support delegates to strftime, so this yields the
    # same timestamp text as an explicit strftime call.
    timestamp = f"{datetime.now():%Y-%m-%d %H:%M:%S}"
    return "ic: " + timestamp + " | > "
# Prefix every icecream message with a timestamp generated at print time.
ic.configureOutput(prefix=__time_formatter__)
_ = ic("Imports done.")

# Output folder for this session's figures, created under the current working
# directory and named with the prefix returned by npCommon.get_dt(for_file=True).
rocket_plots_dirname = f"{npCommon.get_dt(for_file=True)}_rocketPlots"
working_dir = Path.cwd() / rocket_plots_dirname
try:
    working_dir.mkdir()
except FileExistsError:
    # Re-running the notebook with the same prefix simply reuses the folder.
    _ = ic("Working dir already exists!", working_dir)
else:
    _ = ic("Working dir created!", working_dir)

In [None]:
# Load the lookup table used further down to attach gene names to the
# per-gene tables (presumably one row per gene with `gene_id` and `gene_name`
# columns; confirm against npCommon.gene_names_to_gene_ids).
gene_id_gene_name_df = npCommon.gene_names_to_gene_ids()
_ = ic("Gene ID to Gene Name conversion table loaded.")
# A notebook cell only renders its final expression, so the preview has to come
# last; placed before the log call its result was silently discarded.
gene_id_gene_name_df.head()

In [None]:
# Run configuration: which libraries to load and how to label them.

# Library nickname -> npCommon.NanoporeRun object; filled by the loading loop below.
obj_dict = {}
# Library nickname -> display label. Gene names are wrapped in HTML-style
# <i>...</i> tags. Commented-out entries are excluded from the whole run,
# because `libs_to_run` is derived from this dict's keys.
pretty_name_map = {
    # The classics:
    "oldN2": "Wildtype (rep1)",
    # "oldS6": "<i>smg-6</i> (rep1)",
    # The terrible second replicates:
    # "newN2": "Wildtype (bad rep)",
    # "newS5": "<i>smg-5</i> (bad rep)",
    # "newS6": "<i>smg-6</i> (bad rep)",
    # The "better" second replicates:
    "newerN2": "Wildtype (rep2)",
    "newerS6": "<i>smg-6</i> (rep2)",
    "newerS5": "<i>smg-5</i> (rep2)",
    # The triplicates!
    "thirdN2": "Wildtype (rep3)",
    "thirdS5": "<i>smg-5</i> (rep3)",
    "thirdS6": "<i>smg-6</i> (rep3)",
    # At 25C for smg-7 (and fourth replicates in a way):
    "temp25cN2": "Wildtype (25°C)",
    "temp25cS5": "<i>smg-5</i> (25°C)",
    "temp25cS6": "<i>smg-6</i> (25°C)",
    "temp25cS7": "<i>smg-7</i> (25°C)",
}
# Every library that has an (uncommented) label above gets loaded.
libs_to_run = list(pretty_name_map.keys())
# (x_lib, y_lib) pairs to plot against each other.
# NOTE: `scatter_combos` is reassigned in the plotting cell near the end of the
# notebook, and that later list is the one the plotting loop iterates over.
scatter_combos = [
    # ("oldN2", "newerN2"),
    # ("oldN2", "thirdN2"),
    ("newerN2", "thirdN2"),
    ("newerS6", "thirdS6"),
    ("newerS5", "thirdS5"),
    ("newerN2", "temp25cN2"),
    ("newerS6", "temp25cS6"),
    ("newerS5", "temp25cS5"),
    ("thirdN2", "temp25cN2"),
    ("thirdS6", "temp25cS6"),
    ("thirdS5", "temp25cS5"),
    ("temp25cN2", "temp25cS7"),
    ("temp25cS5", "temp25cS7"),
    ("temp25cS6", "temp25cS7"),
]

# Create one NanoporeRun handle per library nickname and register it in obj_dict.
for lib in libs_to_run:
    print(f"\nLoading {lib}...", end="")
    obj = obj_dict[lib] = npCommon.NanoporeRun(run_nickname=lib)
    print(" Done!")
    # Read-level tables are not needed for the per-gene plots, so this stays off:
    # obj.load_mergedOnReads()

In [None]:
read_df_dict = {}  # only referenced by the disabled read-level code below
gene_df_dict = {}
drop_standards = True

# Stage 1: load each library's per-gene table and, when drop_standards is set,
# remove the 'cerENO2' rows.
for lib, obj in obj_dict.items():
    print(f"Processing {lib}...")
    # Read-level processing (currently disabled):
    # read_df = obj.mergedOnReads_df.copy()
    # if drop_standards and lib not in  ['oldN2', 'newN2', 'newS5', 'newS6']:
    #     read_df = read_df.query("assignment == 'NotAStandard'")
    # read_df.qc_pass_featc = read_df.qc_pass_featc.fillna(False)
    # read_df.qc_pass_featc = read_df.qc_pass_featc.astype(bool)
    # cols_to_keep = ['read_id', 'chr_id', 'chr_pos', 'qc_pass_featc', 'gene_id', 'gene_name', 'sequence', 'cigar', 'strand', 'read_length', 'polya_length', 'qc_tag_polya']
    # if lib not in  ['oldN2', 'newN2', 'newS5', 'newS6']:
    #     cols_to_keep += ['assignment']
    # read_df_dict[lib] = read_df[cols_to_keep]
    # print(read_df.value_counts('qc_pass_featc', normalize=True))

    # Looks like the old N2 library had a read cutoff of 5 while everything else had no cutoff!!
    per_gene_table = obj.load_compressedOnGenes()
    gene_df_dict[lib] = (
        per_gene_table.query("gene_id != 'cerENO2'") if drop_standards else per_gene_table
    )

# Stage 2: keep genes with at least 2 reads, then reduce each table to a single
# read-count column named after its library and indexed by gene_id.
read_hits_series_dict = {}
for lib, df in gene_df_dict.items():
    genes_before = df.shape[0]
    print(f"Pre-cutdown:  {lib} - {genes_before:,} Genes", end=" ")
    # TODO: Eventually, I should rerun the compressing for oldN2 without the cutoff!!!
    kept_genes = df.query("read_hits >= 2")
    print(f"Post-cutdown: {lib} - {kept_genes.shape[0]:,} Genes")
    read_hits_series_dict[lib] = (
        kept_genes[['gene_id', 'read_hits']]
        .set_index('gene_id')
        .rename(columns={'read_hits': lib})
    )

In [None]:
plot_libs = libs_to_run

# One read-count column per library with genes on the index; a gene missing
# from a library gets 0 reads. dict.fromkeys drops repeated library names while
# keeping their order, matching the de-duplication a dict keyed on lib would do.
lib_read_hits = pd.concat(
    [read_hits_series_dict[lib] for lib in dict.fromkeys(plot_libs)],
    axis=1,
).fillna(0)

# Compute every summary statistic from the library columns ONLY. Taking
# `.std(axis=1)` on the table after 'avg' and 'avg_rounded' were appended would
# fold those two derived columns into the standard deviation (and 'std/avg').
plot_read_hits_table = lib_read_hits.copy()
plot_read_hits_table['avg'] = lib_read_hits.mean(axis=1)
plot_read_hits_table['avg_rounded'] = plot_read_hits_table['avg'].round(1)
plot_read_hits_table['std'] = lib_read_hits.std(axis=1)  # sample std (pandas default ddof=1)
plot_read_hits_table['std/avg'] = plot_read_hits_table['std'] / plot_read_hits_table['avg']

# Attach gene names: move gene_id from the index into a column, then left-join
# so genes without a name entry are kept.
if 'gene_name' not in plot_read_hits_table.columns:
    gene_id_gene_name_df = npCommon.gene_names_to_gene_ids()
    plot_read_hits_table = (
        plot_read_hits_table
        .reset_index(names='gene_id')
        .merge(gene_id_gene_name_df, on='gene_id', how='left')
    )
plot_read_hits_table

In [None]:
# Scale each library's read counts to reads-per-million (RPM) of the column
# total of the counts held in read_hits_series_dict.
rpm_series_dict = {}
for lib, hits_series in read_hits_series_dict.items():
    print(f"{lib}: {hits_series.shape[0]:,} Genes")
    library_total = hits_series.sum()
    rpm_series = hits_series / library_total * 1_000_000
    rpm_series_dict[lib] = rpm_series

# Side-by-side RPM columns for the libraries being plotted; genes absent from a
# library get an RPM of 0.
rpm_frames_by_lib = {lib: rpm_series_dict[lib] for lib in plot_libs}
plot_rpm_table = pd.concat(list(rpm_frames_by_lib.values()), axis=1).fillna(0)

# Attach gene names via a left join on gene_id (moved from the index to a column).
if 'gene_name' not in plot_rpm_table.columns:
    gene_id_gene_name_df = npCommon.gene_names_to_gene_ids()
    plot_rpm_table = (
        plot_rpm_table
        .reset_index(names='gene_id')
        .merge(gene_id_gene_name_df, on='gene_id', how='left')
    )
plot_rpm_table

In [None]:
# (x_lib, y_lib) library pairs that the plotting loop at the end of this cell
# will draw. This assignment replaces the `scatter_combos` list defined earlier
# in the notebook; commented-out pairs are skipped.
scatter_combos = [
    # ("oldN2", "newerN2"),
    # ("oldN2", "thirdN2"),
    # ("newerN2", "thirdN2"),
    # ("newerS6", "thirdS6"),
    # ("newerS5", "thirdS5"),
    # ("newerN2", "temp25cN2"),
    # ("newerS6", "temp25cS6"),
    # ("newerS5", "temp25cS5"),
    # ("thirdN2", "temp25cN2"),
    # ("thirdS6", "temp25cS6"),
    # ("thirdS5", "temp25cS5"),
    ("temp25cN2", "temp25cS5"),
    ("temp25cN2", "temp25cS6"),
    ("temp25cN2", "temp25cS7"),
    ("temp25cS5", "temp25cS7"),
    ("temp25cS6", "temp25cS7"),
    
]

# Plain-text labels for the matplotlib titles and axis labels: the pretty names
# with their <i> and </i> tags stripped out.
name_map = {
    lib: pretty_name.replace("<i>", "").replace("</i>", "")
    for lib, pretty_name in pretty_name_map.items()
}

def rocket_plot(x_lib, y_lib, plot_read_hits_table, ax=None, min_reads=10, name_map=name_map, save_dir: Path = None):
    """Scatter two libraries' per-gene read counts against each other on log-log axes.

    Args:
        x_lib: Column of ``plot_read_hits_table`` plotted on the x axis.
        y_lib: Column of ``plot_read_hits_table`` plotted on the y axis.
        plot_read_hits_table: DataFrame with one read-count column per library.
        ax: Existing matplotlib Axes to draw on. When ``None``, a new 5x5 inch
            figure is created and shown before returning.
        min_reads: A row is plotted only if both libraries have at least this
            many reads.
        name_map: Mapping from library column name to the label used in the
            title and axis labels. A library missing from the mapping is
            labelled with its raw column name.
        save_dir: If given and it exists, the figure is written there as
            ``<x_lib>_vs_<y_lib>_rocket_plot.png`` (300 dpi) and ``.svg``. If
            given but missing, an error is printed and nothing is saved.
            NOTE: ``rocket_plot_rpm`` uses the same file names, so saving both
            plot types for one pair into one directory overwrites the first.

    Returns:
        The Axes that was drawn on, or ``None`` if selecting the data failed.
    """
    try:
        # Backticks let column names that are not valid Python identifiers
        # (e.g. containing '-' or starting with a digit) be used in the query.
        data = plot_read_hits_table.query(f"`{x_lib}` >= @min_reads & `{y_lib}` >= @min_reads")
        x = data[x_lib]
        y = data[y_lib]
    except Exception as e:
        # Deliberate best-effort: one bad pair must not abort a batch of plots.
        print(f"Error: `{e}`, giving up on the plot for: {x_lib} vs {y_lib}!")
        return

    # Remember whether this call owns the figure. `ax` is rebound just below,
    # so testing `ax is None` later on could never be true.
    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots(figsize=(5, 5))
    else:
        # Lay out and save the figure that holds the caller's Axes, not
        # whichever figure pyplot currently considers "current".
        fig = ax.figure

    x_label = name_map.get(x_lib, x_lib)
    y_label = name_map.get(y_lib, y_lib)

    ax.scatter(x=x, y=y, alpha=0.5, color='black', marker='.')

    # Reference lines at 100 reads on each axis.
    ax.axvline(100, color='red', linestyle='--')
    ax.axhline(100, color='red', linestyle='--')

    ax.set_title(
        f"Reproducibility of gene abundance at various read count cutoffs\n{x_label} vs {y_label}")
    ax.set_xlabel(f"Log10 Reads per gene - {x_label}")
    ax.set_ylabel(f"Log10 Reads per gene - {y_label}")

    ax.set_yscale('log')
    ax.set_xscale('log')

    ax.grid(True, which="both", alpha=0.5)

    fig.tight_layout()
    if save_dir is not None and save_dir.exists():
        save_name = f"{x_lib}_vs_{y_lib}_rocket_plot"
        fig.savefig(save_dir / f"{save_name}.png", dpi=300)
        fig.savefig(save_dir / f"{save_name}.svg")
        print(f"Saved plot to {save_dir / save_name}")
    elif save_dir is not None:
        print(f"Error: Save directory does not exist: {save_dir}")

    if created_figure:
        # Only show figures made here; a caller-supplied Axes may belong to a
        # multi-panel figure that is not finished yet.
        plt.show()
    return ax

def rocket_plot_rpm(x_lib, y_lib, rpm_table, ax=None, min_rpm=10, name_map=name_map, save_dir: Path = None):
    """Scatter two libraries' per-gene RPM values against each other on log-log axes.

    A dashed red diagonal from ``min_rpm`` to 1,000,000 is drawn, and both axes
    run from ``min_rpm`` to the largest plotted value.

    Args:
        x_lib: Column of ``rpm_table`` plotted on the x axis.
        y_lib: Column of ``rpm_table`` plotted on the y axis.
        rpm_table: DataFrame with one RPM column per library.
        ax: Existing matplotlib Axes to draw on. When ``None``, a new 5x5 inch
            figure is created and shown before returning.
        min_rpm: A row is plotted only if both libraries have at least this
            RPM; also the lower limit of both axes.
        name_map: Mapping from library column name to the label used in the
            title and axis labels. A library missing from the mapping is
            labelled with its raw column name.
        save_dir: If given and it exists, the figure is written there as
            ``<x_lib>_vs_<y_lib>_rocket_plot.png`` (300 dpi) and ``.svg``. If
            given but missing, an error is printed and nothing is saved.
            NOTE: ``rocket_plot`` uses the same file names, so saving both plot
            types for one pair into one directory overwrites the first.

    Returns:
        The Axes that was drawn on, or ``None`` if selecting the data failed or
        no gene passed the cutoff.
    """
    try:
        print(f"Filtering for {x_lib} >= {min_rpm} & {y_lib} >= {min_rpm}")
        # Backticks let column names that are not valid Python identifiers
        # (e.g. containing '-' or starting with a digit) be used in the query.
        data = rpm_table.query(f"`{x_lib}` >= @min_rpm & `{y_lib}` >= @min_rpm")
        x = data[x_lib]
        y = data[y_lib]
    except Exception as e:
        # Deliberate best-effort: one bad pair must not abort a batch of plots.
        print(f"Error: `{e}`, giving up on the plot for: {x_lib} vs {y_lib}!")
        return
    if data.empty:
        # With no rows, x.max()/y.max() are NaN and set_xlim/set_ylim would
        # raise after an empty figure had already been created.
        print(f"Error: `no genes pass the cutoff`, giving up on the plot for: {x_lib} vs {y_lib}!")
        return

    # Remember whether this call owns the figure. `ax` is rebound just below,
    # so testing `ax is None` later on could never be true.
    created_figure = ax is None
    if created_figure:
        fig, ax = plt.subplots(figsize=(5, 5))
    else:
        # Lay out and save the figure that holds the caller's Axes, not
        # whichever figure pyplot currently considers "current".
        fig = ax.figure

    x_label = name_map.get(x_lib, x_lib)
    y_label = name_map.get(y_lib, y_lib)

    ax.scatter(x=x, y=y, alpha=0.5, color='black', marker='.')
    # Let's add a diagonal line!
    ax.plot([min_rpm, 1_000_000], [min_rpm, 1_000_000], color='red', linestyle='--')
    # Same limits on both axes so the diagonal is a true y = x reference.
    x_y_lim = max(x.max(), y.max())
    ax.set_xlim(min_rpm, x_y_lim)
    ax.set_ylim(min_rpm, x_y_lim)

    ax.set_title(f"{x_label} vs {y_label}")
    ax.set_xlabel(f"Log10 RPM - {x_label}")
    ax.set_ylabel(f"Log10 RPM - {y_label}")

    ax.set_yscale('log')
    ax.set_xscale('log')

    ax.grid(True, which="both", alpha=0.5)

    fig.tight_layout()
    if save_dir is not None and save_dir.exists():
        save_name = f"{x_lib}_vs_{y_lib}_rocket_plot"
        fig.savefig(save_dir / f"{save_name}.png", dpi=300)
        fig.savefig(save_dir / f"{save_name}.svg")
        print(f"Saved plot to {save_dir / save_name}")
    elif save_dir is not None:
        print(f"Error: Save directory does not exist: {save_dir}")

    if created_figure:
        # Only show figures made here; a caller-supplied Axes may belong to a
        # multi-panel figure that is not finished yet.
        plt.show()
    return ax

# Draw one RPM rocket plot per selected library pair, saving each into working_dir.
for x_lib, y_lib in scatter_combos:
    rocket_plot_rpm(
        x_lib=x_lib,
        y_lib=y_lib,
        rpm_table=plot_rpm_table,
        save_dir=working_dir,
    )