# initialTestingAndScratchPaper.ipynb
## Marcus Viscardi,    December 21, 2022

This is a script to get initial ideas down on how to analyse the four libraries I produced on 12/16/2022
These libraries were:
1. **sMV025:** xrn-1 knockdown + tagged xrn-1 + 5TERA
2. **sMV026:** xrn-1 knockdown + tagged xrn-1 + smg-5 allele + 5TERA
3. **sMV026:** xrn-1 knockdown + tagged xrn-1 + smg-6 allele + 5TERA
4. **sMV026:** xrn-1 knockdown + tagged xrn-1 + smg-7 allele + 5TERA

Josh thinks I should initially try to compare general metrics between my new libraries and my two old libraries from ~ this time last year:
1. **sMV002:** xrn-1 knockdown + tagged xrn-1 + 5TERA (9/18/2021)
2. **sMV003:** xrn-1 knockdown + tagged xrn-1 + smg-6 allele + 5TERA (12/10/2021)

An initial assessment I want to run is plotting the "fraction adapted" for each gene against each other, comparing the various permutations of libraries.

In [None]:
import sys
sys.path.insert(0, '/data16/marcus/scripts/nanoporePipelineScripts')
import nanoporePipelineCommon as npCommon

import seaborn as sea
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

import numpy as np
import pandas as pd
pd.set_option('display.width', 200)
pd.set_option('display.max_columns', None)

# Maps each library's full name (the value stored in the "lib" column) to a short label.
CONVERSION_DICT = dict([
    ("xrn-1-5tera", "oldN2"),
    ("xrn-1-5tera-smg-6", "oldS6"),
    ("5tera_xrn-1-KD_wt", "newN2"),
    ("5tera_xrn-1-KD_wt_rerun", "newerN2"),
    ("5tera_xrn-1-KD_smg-6_rerun", "newerS6"),
    ("5tera_xrn-1-KD_smg-5", "newS5"),
    ("5tera_xrn-1-KD_smg-6", "newS6"),
    ("5tera_xrn-1-KD_smg-7", "newS7"),
    ("sPM57", "sPM57"),
    ("sPM58", "sPM58"),
])
# Inverse lookup: short label -> full library name.
REV_CONVERSION_DICT = dict(zip(CONVERSION_DICT.values(), CONVERSION_DICT.keys()))

# Progress marker for the cell output; get_dt(for_print=True) presumably returns a
# human-readable timestamp (confirm in nanoporePipelineCommon).
print(f"Imports done at {npCommon.get_dt(for_print=True)}")

In [None]:
# Set to True to skip any cached merged parquets and rebuild them from the per-library files.
regenerate = False
# Short library labels (keys of REV_CONVERSION_DICT). Sorting the set gives a stable order,
# so the cache file name built from these labels does not depend on listing order.
libs_to_load = sorted({
    'oldN2',
    'newN2',
    'newerN2',
    'oldS6',
    'newS6',
    'newerS6',
    # 'newS5',
    # 'newerS5',
    # 'newS7',
})

merged_parquet_dir = "./output_files/mega_merge_parquets"
libs_tag = '-'.join(libs_to_load)

try:
    if regenerate:
        # Jump straight to the rebuild path in the except branch below.
        raise ValueError
    
    # NOTE(review): this relies on find_newest_matching_file raising ValueError when no file
    # matches the glob -- confirm in nanoporePipelineCommon.
    reads_df_raw_path = npCommon.find_newest_matching_file(f"{merged_parquet_dir}/*_{libs_tag}_merged5TERA.reads_df.parquet")
    compressed_df_genes_raw_path = npCommon.find_newest_matching_file(f"{merged_parquet_dir}/*_{libs_tag}_merged5TERA.compressed_df.parquet")
    print(f"Found preprocessed files at:\n\t{reads_df_raw_path}\nand:\n\t{compressed_df_genes_raw_path}")

    reads_df_genes_raw = pd.read_parquet(reads_df_raw_path)
    compressed_df_genes_raw = pd.read_parquet(compressed_df_genes_raw_path)
except ValueError:
    print(f"Could not find preprocessed files matching these libs: {'/'.join(libs_to_load)}\nGoing to create new ones from scratch! This will take longer.")
    reads_df_genes_raw, compressed_df_genes_raw = npCommon.load_and_merge_lib_parquets([REV_CONVERSION_DICT[lib] for lib in libs_to_load],
                                                                                       drop_sub_n=1,
                                                                                       add_tail_groupings=False,
                                                                                       drop_failed_polya=False,
                                                                                       group_by_t5=True,
                                                                                       use_josh_assignment=False)
    print("Saving new parquets to speed up future runs.")
    # Take the timestamp once so the reads_df / compressed_df pair always share the same
    # file-name prefix (two separate get_dt() calls could straddle a clock tick).
    save_stamp = npCommon.get_dt()
    reads_df_genes_raw.to_parquet(f"{merged_parquet_dir}/{save_stamp}_{libs_tag}_merged5TERA.reads_df.parquet")
    compressed_df_genes_raw.to_parquet(f"{merged_parquet_dir}/{save_stamp}_{libs_tag}_merged5TERA.compressed_df.parquet")
print(f"Lib load done @ {npCommon.get_dt(for_print=True)}")

In [None]:
# Spot check of a single gene in a single library.
# NOTE(review): 'sPM58' is not among libs_to_load in the load cell, so this is presumably
# empty unless that library is added there.
compressed_df_genes_raw.query("lib == 'sPM58'").query("gene_id == 'unc-54'")

In [None]:
# Work on indexed deep copies so the *_raw frames stay untouched for re-runs.
gene_level_index = ["lib", "chr_id", "gene_id", "gene_name"]
read_level_index = gene_level_index + ["read_id"]

compressed_df_genes = compressed_df_genes_raw.copy(deep=True).set_index(gene_level_index)
reads_df_genes = reads_df_genes_raw.copy(deep=True).set_index(read_level_index)

# Total gene hits per library, split by the t5 flag ('+' / '-').
compressed_df_genes.groupby(["lib", "t5"])['gene_hits'].sum()

In [None]:
# Genes are kept only when total_hits is strictly greater than this value (see the query below).
min_gene_hits = 1
# Shared name prefix of the new libraries; removing it leaves short column suffixes
# such as "wt" or "smg-6".
new_lib_prefix = "5tera_xrn-1-KD_"
# Older libraries that are left out of the side-by-side table.
excluded_libs = ["xrn-1-5tera", "xrn-1-5tera-smg-6"]


def _short_lib_suffix(lib_name, prefix=new_lib_prefix):
    """Return lib_name without the leading prefix (unchanged if the prefix is absent).

    str.lstrip() is deliberately avoided: it strips a *set of characters*, not a prefix,
    so a name such as '5tera_xrn-1-KD_rerun' would be cut down to 'un'.
    """
    if lib_name.startswith(prefix):
        return lib_name[len(prefix):]
    return lib_name


# Put the unadapted (t5 == '-') and adapted (t5 == '+') counts for each lib/gene side by side.
count_cols = ["gene_hits", "gene_rpm"]
df = pd.concat([compressed_df_genes.query("t5 == '-'")[count_cols],
                compressed_df_genes.query("t5 == '+'")[count_cols]], axis=1)
df.columns = ['unadapted_hits', 'unadapted_rpm', 'adapted_hits', 'adapted_rpm']
# A gene seen in only one t5 class has NaN for the other; treat that as zero reads.
df = df.fillna(0)
df['total_hits'] = df['adapted_hits'] + df['unadapted_hits']
df['total_rpm'] = df['adapted_rpm'] + df['unadapted_rpm']
df['fraction_adapted'] = df['adapted_hits'] / df['total_hits']
df = df.sort_values('fraction_adapted', ascending=False).query(f"total_hits > {min_gene_hits}")

# Pivot to one row per gene, with one set of suffixed columns per library.
df_list = []
for lib_name in df.index.unique(level="lib"):
    if lib_name not in excluded_libs:
        df_list.append(df.query(f"lib == '{lib_name}'")
                       .add_suffix(f"_{_short_lib_suffix(lib_name)}")
                       .droplevel("lib"))
remerge_df = pd.concat(df_list, axis=1).sort_values("gene_id")
# Genes missing from a library (or filtered out of it above) get zeros in that library's columns.
remerge_df = remerge_df.fillna(0)
remerge_df

In [None]:
def plot_scatter(input_df, x_lib_suffix, y_lib_suffix,
                 shared_read_count_minimum=20, drop_MtDNA=True,
                 compare_prefix="total_rpm", log_axes=True,
                 color_by=None,
                 ):
    """Scatter one per-gene metric of two libraries against each other and show the figure.

    Parameters
    ----------
    input_df : DataFrame whose columns are named "<metric>_<lib suffix>" and which has
        "total_hits_<suffix>" columns for both libraries. After reset_index() it must
        expose "gene_name" (and "chr_id" when drop_MtDNA is True).
    x_lib_suffix, y_lib_suffix : library suffixes plotted on the x and y axes.
    shared_read_count_minimum : genes whose summed total_hits across the two libraries
        fall below this value are dropped.
    drop_MtDNA : if True, drop rows with chr_id == 'MtDNA'.
    compare_prefix : metric to plot, e.g. "fraction_adapted", "total_hits" or "total_rpm".
    log_axes : use log scaling on both axes.
    color_by : optional column used to color points; it must be one of the columns kept
        in the plotted frame (index levels, the two libraries' columns, or "shared_hits").

    Returns
    -------
    (fig, plot_df) : the plotly figure and the filtered DataFrame that was plotted.
    """
    x_col = f"{compare_prefix}_{x_lib_suffix}"
    y_col = f"{compare_prefix}_{y_lib_suffix}"
    x_hits_col = f"total_hits_{x_lib_suffix}"
    y_hits_col = f"total_hits_{y_lib_suffix}"

    lib_cols = [col for col in input_df.columns
                if col.endswith(x_lib_suffix) or col.endswith(y_lib_suffix)]
    plot_df = input_df[lib_cols].reset_index()
    plot_df['shared_hits'] = plot_df[x_hits_col] + plot_df[y_hits_col]
    plot_df = plot_df[plot_df['shared_hits'] >= shared_read_count_minimum]
    if drop_MtDNA:
        plot_df = plot_df.query("chr_id != 'MtDNA'")
    fig = px.scatter(plot_df,
                     x=x_col,
                     y=y_col,
                     # size="shared_hits",
                     log_x=log_axes, log_y=log_axes,
                     hover_name="gene_name",
                     color=color_by,
                     hover_data=[x_hits_col, y_hits_col],
                     height=800, width=800,
                     template='plotly_white',
                     )

    # The diagonal below is drawn corner-to-corner in paper coordinates, so it is only the
    # y = x line when both axes cover the same range. Force a shared range for that reason.
    plotted_vals = plot_df[[x_col, y_col]].to_numpy(dtype=float).ravel()
    plotted_vals = plotted_vals[np.isfinite(plotted_vals)]
    if log_axes:
        # Non-positive values cannot be drawn on a log axis.
        plotted_vals = plotted_vals[plotted_vals > 0]
    if plotted_vals.size:
        low, high = plotted_vals.min(), plotted_vals.max()
        if log_axes:
            # Plotly expects log-axis ranges as log10 of the data values.
            low, high = np.log10(low), np.log10(high)
        pad = 0.05 * (high - low) or 0.5
        shared_range = [low - pad, high + pad]
        fig.update_xaxes(range=shared_range)
        fig.update_yaxes(range=shared_range)

    fig.update_layout(shapes=[{'type': 'line', 'yref': 'paper', 'xref': 'paper', 'y0': 0, 'y1': 1, 'x0': 0, 'x1': 1}])
    fig.show()
    return fig, plot_df


# Compare adapted rpm of wt against each mutant on linear axes.
# plot_scatter indexes "total_hits_<suffix>", so both suffixes must exist in remerge_df.
# After the loop, `figure` / `plotted_df` hold the results of the last (smg-7) comparison.
for mutant_suffix in ("smg-5", "smg-7"):
    figure, plotted_df = plot_scatter(remerge_df, "wt", mutant_suffix,
                                      compare_prefix="adapted_rpm",
                                      log_axes=False,
                                      shared_read_count_minimum=100)
plotted_df

In [None]:
# Show the adapted / total / fraction columns (alphabetical order) for one gene of interest.
metric_prefixes = ('adapted', 'total', 'fraction')
metric_cols = sorted(col for col in remerge_df.columns if col.startswith(metric_prefixes))
remerge_df[metric_cols].query("gene_name == 'F23A7.8'")

In [None]:
# Summed gene_rpm and gene_hits for every (library, t5) combination.
per_lib_and_t5 = compressed_df_genes.groupby(["lib", "t5"])
per_lib_and_t5[['gene_rpm', 'gene_hits']].sum()

# Jump to here if loading from file:

In [None]:
save_file_suffix = "quad5TERA.counts.parquet"
try:
    # Raises NameError when the cells that build remerge_df were skipped in this kernel.
    remerge_df
except NameError:
    # Nothing in memory: fall back to the most recently saved counts table.
    remerge_df = pd.read_parquet(npCommon.find_newest_matching_file(f"./output_files/*_{save_file_suffix}"))
else:
    # Built in this session: write a new timestamped snapshot.
    remerge_df.to_parquet(f"./output_files/{npCommon.get_dt(for_file=True)}_{save_file_suffix}")

In [None]:
# Recompute an adapted and an unadapted rpm for each library, as a DIFFERENT kind of
# global-effect adjustment: the alternative to comparing each gene's raw adapted fraction
# against the library-wide adapted percentage.
#
# "rpmA" is this adjusted rpm: hits are normalized by the library's total adapted (or
# unadapted) hit count rather than by its overall hit count.

# {(lib, t5): summed gene_hits}
total_counts_dict = compressed_df_genes.groupby(["lib", "t5"])['gene_hits'].sum().to_dict()
library_keys = ['wt', 'smg-5', 'smg-6', 'smg-7']
for lib in library_keys:
    full_lib_name = f'5tera_xrn-1-KD_{lib}'
    for read_class, t5_flag in (('adapted', '+'), ('unadapted', '-')):
        millions_of_reads = total_counts_dict[(full_lib_name, t5_flag)] / 1_000_000
        remerge_df[f'{read_class}_rpmA_{lib}'] = remerge_df[f'{read_class}_hits_{lib}'] / millions_of_reads

rpmA_cols = sorted(col for col in remerge_df.columns if "rpmA" in col)
remerge_df[rpmA_cols].sort_values('adapted_rpmA_wt', ascending=False)

In [None]:
# rpmA columns sorted by wt adapted rpmA. Not used further in this cell.
# NOTE(review): 'gene_name' appears to be an index level rather than a column here (the cell
# relies on reset_index() to expose it), so the `col=='gene_name'` test likely never matches.
remerge_plot_df = remerge_df[sorted([col for col in remerge_df.columns if "rpmA" in col or col=='gene_name'])].sort_values('adapted_rpmA_wt', ascending=False)

# fig = px.scatter_matrix(remerge_df.reset_index(),
#                         dimensions=[col for col in remerge_df if col.startswith("adapted_rpmA")],
#                         hover_name='gene_name',
#                         )
log_dict = {'type': 'log'}

plot_df = remerge_df.reset_index()
rpmA_label_prefix = "adapted_rpmA_"
rpmA_dim_cols = [col for col in plot_df.columns if col.startswith("adapted_rpmA")]
# Remove the prefix by slicing: str.lstrip() strips a *set of characters*, not a prefix, and
# would eat the start of any library suffix beginning with one of those characters.
# The +100 offset is presumably there so zero values remain drawable on the log axes.
dims = [dict(label=f"adapted rpmA<br>log({col[len(rpmA_label_prefix):] if col.startswith(rpmA_label_prefix) else col})",
             values=plot_df[col] + 100)
        for col in rpmA_dim_cols]
# Integer code per chromosome, used only to color the points.
color_vals = plot_df['chr_id'].astype('category').cat.codes
fig = go.Figure(data=go.Splom(dimensions=dims,
                              text=plot_df['gene_name'],
                              showupperhalf=False,
                              diagonal=dict(visible=False),
                              marker=dict(color=color_vals,
                                          showscale=False),
))
# One log x/y axis per Splom dimension ("xaxis", "xaxis2", ...), instead of assuming exactly four.
axis_layout = {}
for axis_number in range(1, len(dims) + 1):
    axis_suffix = "" if axis_number == 1 else str(axis_number)
    axis_layout[f"xaxis{axis_suffix}"] = log_dict
    axis_layout[f"yaxis{axis_suffix}"] = log_dict
fig.update_layout(template='plotly_white', **axis_layout)
fig.write_html(f"./output_files/{npCommon.get_dt(for_file=True)}_readjustedAdaptedRPM.scatterMatrix.html")

In [None]:
# remerge_df.groupby(["gene_id", "gene_name"])[[col for col in remerge_df.columns if col.startswith("unadapted_hits")]].sum()
# Flat copy (index levels back as columns) so genes can be flagged without touching compressed_df_genes.
compressed_df_genes_copy = compressed_df_genes.reset_index()

# Flag everything on MtDNA plus the two F23A7 genes.
on_mtdna = compressed_df_genes_copy['chr_id'] == 'MtDNA'
is_flagged_gene = compressed_df_genes_copy['gene_id'].isin(['F23A7.4', 'F23A7.8'])
compressed_df_genes_copy['adapted_weirdness'] = on_mtdna | is_flagged_gene

rpm_sums = compressed_df_genes_copy.groupby(["lib", "t5", "adapted_weirdness"])['gene_rpm'].sum()
# Note that rpm is no longer "rpm"!!! The summed gene_rpm is divided by 1,000,000 here.
(rpm_sums / 1_000_000).reset_index().query(
    "t5 == '+' & adapted_weirdness == True"
)