In [None]:
# %load ../snippets/basic_settings.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import sys
import plotly.express as px

# Notebook-wide display defaults for seaborn, pandas and matplotlib.
sns.set_context("notebook", font_scale=1.1)

# Show up to 100 columns / rows before pandas truncates a displayed frame.
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

plt.rcParams.update({
    "figure.figsize": (16, 12),
    'savefig.dpi': 200,
    'figure.autolayout': False,
    'axes.labelsize': 18,
    'axes.titlesize': 20,
    'font.size': 16,
    'lines.linewidth': 2.0,
    'lines.markersize': 8,
    'legend.fontsize': 14,
    'text.usetex': False,  # True activates latex output in fonts!
})
#pd.set_option('display.float_format', lambda x: '{:,.2f}'.format(x))

In [None]:
# One fixed colour per library so that every figure uses the same mapping.
# NOTE(review): `resDf` is only assigned in the "Load current results" cell
# further down, so on a fresh kernel this cell raises NameError unless that
# cell has been run first.
clrs = px.colors.qualitative.Safe
libraries = resDf.library.unique()
# zip() stops at the shorter input: libraries beyond len(clrs) get no entry.
library_clrs = dict(zip(libraries, clrs))


def get_ci_corr(comp, gt_CI, exp_CI, method):
    """Correlate ground-truth and experimental values per (library, day).

    Parameters
    ----------
    comp : pandas.DataFrame
        Must contain the columns 'library', 'day' and the two columns named
        by ``gt_CI`` and ``exp_CI``.
    gt_CI : str
        Name of the ground-truth column.
    exp_CI : str
        Name of the experimental column.
    method : str
        Label written to the 'method' column of the result.

    Returns
    -------
    pandas.DataFrame
        One row per (library, day) with columns
        ['library', 'day', 'R2', 'method', 'r2lib'].
        'R2' holds the coefficient returned by ``DataFrame.corr`` with its
        default settings (Pearson r; it is NOT squared, despite the name).
        'r2lib' is the mean of 'R2' over all days of the same library.
    """
    pairwise = comp.groupby(['library', 'day'])[[gt_CI, exp_CI]].corr().unstack()
    # Select the gt-vs-experiment entry by label rather than by position, so
    # the result does not depend on how unstack() happens to order columns.
    corr_df = pairwise[(gt_CI, exp_CI)].rename('R2').reset_index()
    corr_df = corr_df.assign(method=method)
    # Broadcast each library's mean correlation back onto its rows.
    corr_df['r2lib'] = corr_df.groupby('library')['R2'].transform('mean')
    return corr_df



def compare_to_gt(compDf, gt_padj = "gt_padj", exp_padj = 'padj', pval_cutoff=0.05):
    """Classify each row as TP/TN/FP/FN against the ground truth.

    A row counts as a ground-truth hit when ``compDf[gt_padj] < pval_cutoff``
    and as a screen hit when ``compDf[exp_padj] < pval_cutoff``. Missing
    adjusted p-values compare as False and are therefore treated as non-hits.

    Parameters
    ----------
    compDf : pandas.DataFrame
        Frame holding both adjusted p-value columns. It is not modified.
    gt_padj : str
        Column with the ground-truth adjusted p-values.
    exp_padj : str
        Column with the experimental adjusted p-values.
    pval_cutoff : float
        Significance threshold applied to both columns (default 0.05).

    Returns
    -------
    tuple
        ``(annotated, confMat, prec, recall)`` where ``annotated`` is a copy
        of ``compDf`` with the boolean columns 'gt_hits', 'screen_hits', 'TP',
        'TN', 'FP' and 'FN' added; ``confMat`` is the 2x2 confusion matrix
        (rows 'Real Pos'/'Real Neg', columns 'Pred Pos'/'Pred Neg');
        ``prec`` is TP / (TP + FP) and ``recall`` is TP / (TP + FN). Either
        ratio is NaN when its denominator is zero.
    """
    # Work on a copy: callers often pass a groupby slice, and writing new
    # columns into it would alter their data / raise SettingWithCopyWarning.
    compDf = compDf.copy()
    gt_hits = compDf[gt_padj] < pval_cutoff
    screen_hits = compDf[exp_padj] < pval_cutoff
    compDf['gt_hits'] = gt_hits
    compDf['screen_hits'] = screen_hits
    compDf['TP'] = gt_hits & screen_hits
    compDf['TN'] = ~gt_hits & ~screen_hits
    compDf['FP'] = ~gt_hits & screen_hits
    compDf['FN'] = gt_hits & ~screen_hits

    # Order TP, FN, FP, TN fills the matrix row-wise: [[TP, FN], [FP, TN]].
    counts = compDf[['TP', 'FN', 'FP', 'TN']].sum().values.reshape((2, 2))
    confMat = pd.DataFrame(counts,
                           index=['Real Pos', 'Real Neg'],
                           columns=['Pred Pos', 'Pred Neg'])
    true_pos = confMat.loc['Real Pos', 'Pred Pos']
    prec = true_pos / confMat['Pred Pos'].sum()
    recall = true_pos / confMat.loc['Real Pos'].sum()
    return compDf, confMat, prec, recall


def get_stats(comp, method, gt_padj, exp_padj):
    """Compute precision and recall per library and print each confusion matrix.

    ``comp`` is grouped by its 'library' column; every group is passed to
    ``compare_to_gt`` together with the two adjusted p-value column names.
    The library name and its confusion matrix are printed as a side effect.

    Returns a DataFrame with columns ['library', 'precision', 'recall',
    'method'], one row per library, where 'method' is the given label.
    """
    rows = []
    for lib_name, lib_df in comp.groupby('library'):
        print(lib_name)
        _, conf_mat, precision, recall = compare_to_gt(lib_df, gt_padj, exp_padj)
        rows.append([lib_name, precision, recall])
        print(conf_mat)
    summary = pd.DataFrame(rows, columns=['library', 'precision', 'recall'])
    return summary.assign(method=method)


def get_numHits(res, pval_col, method, pval_cutoff=0.05):
    """Count distinct significant genes per (library, day).

    Rows of ``res`` with ``res[pval_col] < pval_cutoff`` are kept, grouped by
    'library' and 'day', and the number of unique values in 'gene' is counted.

    Returns a DataFrame with columns ['library', 'day', 'gene', 'method'],
    where 'gene' holds the count and 'method' the given label.
    """
    significant = res.loc[res[pval_col] < pval_cutoff]
    hit_counts = significant.groupby(['library', 'day'])['gene'].nunique()
    return hit_counts.reset_index().assign(method=method)

def plot_correlations(corr_df):
    """Bar chart of the per-library mean correlation stored in 'r2lib'.

    ``corr_df`` must contain the columns 'library', 'r2lib' and 'method'.
    Bars are ordered by ascending 'r2lib' and coloured with the module-level
    ``library_clrs`` mapping. Returns the plotly figure.
    """
    lib_means = corr_df[['library', 'r2lib', 'method']].drop_duplicates()
    ordered_libs = lib_means.sort_values('r2lib').library.values

    bar_fig = px.bar(
        lib_means,
        x='library',
        y='r2lib',
        color='library',
        color_discrete_map=library_clrs,
        category_orders={'library': ordered_libs},
        labels={'library': 'Library', 'r2lib': 'R2'},
        title="R2 between RBSeq CIs and experimental CIs",
        hover_data=['method'],
        template='simple_white',
    )

    centered_title = {'y': 0.9, 'x': 0.5, 'xanchor': 'center', 'yanchor': 'top'}
    # The y axis is pinned to [0, 1]; bars with negative values fall outside it.
    bar_fig.update_layout(font_size=14, title=centered_title, yaxis_range=[0, 1])
    # Library names are omitted from the x axis tick labels.
    bar_fig.update_xaxes(showticklabels=False)
    return bar_fig


# Load Ground Truth data

In [None]:
# List the benchmark directory to see which result files are available.
%ls /nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/02_22_result_benchmarks/

In [None]:
outDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/02_22_result_benchmarks")
gt_file = outDir / "15-02-2022-ground_truth.csv"
# Keep only the first six columns of the ground-truth table and add the
# log2-transformed ground-truth CI next to them.
gtDf = (
    pd.read_csv(gt_file)
    .iloc[:, :6]
    .assign(log_gt_CI=lambda df: np.log2(df['gt_CI']))
)

# Load June results

In [None]:
# June z-score results, joined to the ground truth on gene and day.
resDirJune = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/08_21/results/nguyenb")
# Build the path from resDirJune (not resDir, which is only assigned in a
# later cell) so that this cell also runs on a fresh kernel.
res_file_june = resDirJune/"27_07/26-07-final-results.csv"
resJun = pd.read_csv(res_file_june, index_col=0)
resJun['log_CI'] = np.log2(resJun.CI)
comp0 = resJun.merge(gtDf, on=['gene', 'day'])

In [None]:
# NOTE(review): `zscoreJun_pr` is not assigned in any cell shown in this
# notebook, so this cell raises NameError on Restart & Run All. Presumably a
# leftover name for the precision/recall table (`zscoreJune_stats`) - confirm,
# then remove this cell or move it below the cell that computes that table.
zscoreJun_pr

In [None]:
# Column names for the June z-score results vs. the ground truth.
gt_CI, gt_padj = "log_gt_CI", "gt_padj"
exp_CI, exp_padj = "log_CI", "zscore_padj"

zscoreJun_corr = get_ci_corr(comp0, gt_CI, exp_CI, method="zscore-original")
zscoreJune_stats = get_stats(comp0, method="zscore-original",
                             gt_padj=gt_padj, exp_padj=exp_padj)
fig = plot_correlations(corr_df=zscoreJun_corr)
fig

# Load current results

In [None]:
resDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/08_21/results/nguyenb")
res_file = resDir / '24-11-2021-all-libraries-zscores.csv'
# The file's 'ci' column is renamed to 'CI', then log2-transformed.
resDf = (
    pd.read_csv(res_file, index_col=0)
    .rename(columns={'ci': 'CI'})
    .assign(log_CI=lambda df: np.log2(df['CI']))
)

# Inner join with the ground truth on gene and day.
comp = pd.merge(resDf, gtDf, on=['gene', 'day'])
comp.head()

In [None]:
# Column names for the current z-score results vs. the ground truth.
gt_CI, gt_padj = "log_gt_CI", "gt_padj"
exp_CI, exp_padj = "log_CI", "padj"

zscoreNov_corr = get_ci_corr(comp, gt_CI, exp_CI, method="zscore-current")
zscoreNov_stats = get_stats(comp, method="zscore-current",
                            gt_padj=gt_padj, exp_padj=exp_padj)
fig = plot_correlations(corr_df=zscoreNov_corr)
fig

# Load MAGeCK results

In [None]:
maDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/02_22_mageck")
mres = maDir / '16-02-2022-batch-corrected-9-libraries.csv'
# 'fdr' is the smaller of the negative- and positive-selection FDRs per row;
# 'id' and 'contrast' are renamed to match the ground-truth join keys.
maDf = (
    pd.read_csv(mres)[['id', 'neg|fdr', 'neg|lfc', 'pos|fdr', 'contrast', 'library']]
    .assign(fdr=lambda df: df[['neg|fdr', 'pos|fdr']].min(axis=1))
    .rename(columns={'id': 'gene', 'contrast': 'day'})
)
comp2 = pd.merge(maDf, gtDf, on=['gene', 'day'])

In [None]:

# Per-(library, day) correlation between ground truth and the MAGeCK 'neg|lfc'.
# NOTE(review): the following cell repeats this exact call with the column
# names bound to variables, so this cell looks redundant.
mageck_corr = get_ci_corr(comp2,"log_gt_CI", "neg|lfc", 'mageck' )


In [None]:
# Column names for the batch-corrected MAGeCK results vs. the ground truth.
gt_CI, gt_padj = "log_gt_CI", "gt_padj"
exp_CI, exp_padj = "neg|lfc", "fdr"

mageck_corr = get_ci_corr(comp2, gt_CI, exp_CI, method="mageck")
mageck_stats = get_stats(comp2, method="mageck",
                         gt_padj=gt_padj, exp_padj=exp_padj)
fig = plot_correlations(corr_df=mageck_corr)
fig

In [None]:
# Preview the MAGeCK results merged with the ground truth.
comp2.head()

In [None]:
# Ground truth vs. MAGeCK 'neg|lfc' for library_15_1: one facet row per day,
# each with an OLS trendline.
px.scatter(
    comp2.loc[comp2['library'] == 'library_15_1'],
    x='log_gt_CI',
    y='neg|lfc',
    color='day',
    facet_row='day',
    trendline='ols',
    template='simple_white',
    width=400,
    height=1000,
)

In [None]:
mresnoBatch = maDir / '16-02-2022-not-batch-corrected-9-libraries.csv'
# Same preparation as for the batch-corrected file: per-row minimum of the two
# FDR columns, then rename 'id'/'contrast' to the ground-truth join keys.
maDf2 = (
    pd.read_csv(mresnoBatch)[['id', 'neg|fdr', 'neg|lfc', 'pos|fdr', 'contrast', 'library']]
    .assign(fdr=lambda df: df[['neg|fdr', 'pos|fdr']].min(axis=1))
    .rename(columns={'id': 'gene', 'contrast': 'day'})
)
comp3 = pd.merge(maDf2, gtDf, on=['gene', 'day'])

In [None]:
# Column names for the non-batch-corrected MAGeCK results vs. the ground truth.
gt_CI, gt_padj = "log_gt_CI", "gt_padj"
exp_CI, exp_padj = "neg|lfc", "fdr"

mageck_no_batch_corr = get_ci_corr(comp3, gt_CI, exp_CI, method="mageck-no-batch")
mageck_no_batch_stats = get_stats(comp3, method="mageck-no-batch",
                                  gt_padj=gt_padj, exp_padj=exp_padj)
fig = plot_correlations(corr_df=mageck_no_batch_corr)
fig

In [None]:
#px.line(corr_df_mageck2, x='day', y='R2', color='library', markers=True)

In [None]:
# Stack the per-method precision/recall tables and reshape to long format:
# one row per (library, method, metric) with the value in 'prop'.
precision_all = (
    pd.concat([zscoreJune_stats, zscoreNov_stats, mageck_stats, mageck_no_batch_stats])
    .melt(id_vars=['library', 'method'], var_name='metric', value_name='prop')
)

In [None]:

# Distribution of precision and recall across libraries, one facet per metric.
fig = px.box(
    precision_all,
    x='method',
    y='prop',
    color='method',
    facet_col='metric',
    template="simple_white",
    height=600,
    width=700,
)
fig.update_layout(yaxis_range=[0, 1])

In [None]:
# Scratch arithmetic (evaluates to 0.75).
# NOTE(review): no other cell uses this value - candidate for removal.
30/40

In [None]:
# Box plot with the individual libraries overlaid as black points, one panel
# per metric. Each panel is titled with its metric so the two can be told
# apart (both share the y label 'prop').
f, axes = plt.subplots(1, 2, figsize=(8, 6))
for _ax, _metric in zip(axes, ['precision', 'recall']):
    _panel = precision_all[precision_all.metric == _metric]
    sns.boxplot(data=_panel, x='method', y='prop', ax=_ax)
    sns.stripplot(data=_panel, x='method', y='prop', ax=_ax, color='black')
    _ax.set_title(_metric)

#sns.stripplot(data=precision_all, x='metric', y='prop', hue='method', )

In [None]:
# Stack the per-(library, day) correlation tables of all four methods.
all_corr = pd.concat([zscoreJun_corr, zscoreNov_corr, mageck_corr, mageck_no_batch_corr])

In [None]:
# Display the MAGeCK correlation table.
mageck_corr

In [None]:
# Reduce to one row per (library, method) and draw each library as a line
# across methods, using the shared library colour map.
all_corr = all_corr.loc[:, ['library', 'method', 'r2lib']].drop_duplicates()
px.line(
    all_corr,
    x='method',
    y='r2lib',
    color='library',
    markers=True,
    color_discrete_map=library_clrs,
    hover_data=['library', 'method'],
)
            

In [None]:
# Display the combined per-library correlation table.
all_corr

In [None]:
# MAGeCK correlation per day, one line per library.
px.line(mageck_corr, x='day', y='R2', color='library', markers=True)

In [None]:
# Keep only rows whose gene identifier is shorter than 15 characters.
# NOTE(review): presumably this drops non-gene feature identifiers - confirm
# the reason for the 15-character limit.
resJun = resJun.loc[resJun['gene'].str.len() < 15]
resNov = resDf.loc[resDf['gene'].str.len() < 15]
maDf = maDf.loc[maDf['gene'].str.len() < 15]
maDf2 = maDf2.loc[maDf2['gene'].str.len() < 15]

# Number of distinct significant genes per (library, day) for every method.
numHit1 = get_numHits(resJun, 'zscore_padj', 'zscore-original', pval_cutoff=0.05)
numHit2 = get_numHits(resNov, 'padj', 'zscore-current', pval_cutoff=0.05)
numHit3 = get_numHits(maDf, 'fdr', 'mageck', pval_cutoff=0.05)
numHit4 = get_numHits(maDf2, 'fdr', 'mageck-no-batch', pval_cutoff=0.05)
hits = pd.concat([numHit1, numHit2, numHit3, numHit4])

In [None]:
# Median number of hits per day across libraries (zscore-original).
numHit1.groupby('day').gene.median()

In [None]:
# Median number of hits per day across libraries (zscore-current).
numHit2.groupby('day').gene.median()

In [None]:
# Median number of hits per day across libraries (mageck).
numHit3.groupby('day').gene.median()

In [None]:
# Median number of hits per day across libraries (mageck-no-batch).
numHit4.groupby('day').gene.median()

In [None]:
# Libraries used for the restricted hit comparison below.
# NOTE(review): only three distinct libraries are listed here; if a fourth
# one was intended, add it.
good_libraries = ['library_11_2', 'library_14_2', 'library_15_1']

In [None]:
# Hit counts restricted to the libraries listed in good_libraries.
good_hits = hits[hits.library.isin(good_libraries)]

In [None]:
# Number of hits per method for every library, coloured by day.
px.scatter(hits, x='method', y='gene', color='day', hover_data=['library'])

In [None]:
# Hit counts per method for the selected libraries, one facet per day.
px.box(good_hits, x='method', y='gene', facet_col='day', height=400)

In [None]:
# Hit counts per method across all libraries for day 'd2' only.
px.box(hits[hits.day == 'd2'], x='method', y='gene')