In [None]:
# %load ../snippets/basic_settings.py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from pathlib import Path
import seaborn as sns
import sys
import plotly.express as px

# Notebook-wide display defaults: seaborn context, pandas table limits, matplotlib styling.
sns.set_context("notebook", font_scale=1.1)

# Show up to 100 columns / rows before pandas truncates a displayed frame.
pd.set_option("display.max_columns", 100)
pd.set_option("display.max_rows", 100)

# All matplotlib defaults in one place.
plt.rcParams.update({
    "figure.figsize": (16, 12),
    "savefig.dpi": 200,
    "figure.autolayout": False,
    "axes.labelsize": 18,
    "axes.titlesize": 20,
    "font.size": 16,
    "lines.linewidth": 2.0,
    "lines.markersize": 8,
    "legend.fontsize": 14,
    "text.usetex": False,  # True activates latex output in fonts!
})
#pd.set_option('display.float_format', lambda x: '{:,.2f}'.format(x))

# Load Ground Truth data

In [None]:
# List the contents of the benchmark results directory (IPython `%ls` line magic).
%ls /nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/02_22_result_benchmarks/

In [None]:
# Ground-truth table: read the CSV, keep its first six columns and add the
# log2-transformed ground-truth CI as `log_gt_CI`.
outDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/02_22_result_benchmarks")
gt_file = outDir.joinpath("15-02-2022-ground_truth.csv")
gtDf = (
    pd.read_csv(gt_file)
    .iloc[:, :6]
    .assign(log_gt_CI=lambda df: np.log2(df.gt_CI))
)

# Load current results

In [None]:
# Directory and file holding the current screen results.
resDir = Path("/nfs/nas22/fs2202/biol_micro_bioinf_nccr/hardt/nguyenb/tnseq/scratch/08_21/results/nguyenb")
res_file = resDir.joinpath('24-11-2021-all-libraries-zscores.csv')

In [None]:
# Screen results: first CSV column becomes the index, `ci` is renamed to `CI`,
# and its log2 transform is stored as `log_CI`.
resDf = (
    pd.read_csv(res_file, index_col=0)
    .rename(columns={'ci': 'CI'})
    .assign(log_CI=lambda df: np.log2(df.CI))
)

In [None]:
# Peek at five random rows; the fixed seed keeps the preview identical across re-runs.
resDf.sample(5, random_state=0)

In [None]:
# Try for one library


In [None]:
def compare_to_gt(gtDf, results, alpha=0.05):
    """Benchmark screen results against the ground truth.

    The two tables are inner-joined on ``gene`` and ``day``. A row counts as a
    ground-truth hit when ``gt_padj < alpha`` and as a screen hit when
    ``padj < alpha``; rows with a missing p-value therefore count as non-hits.

    Parameters
    ----------
    gtDf : pandas.DataFrame
        Ground-truth table, merged on ``gene`` and ``day``.
    results : pandas.DataFrame
        Screen results, merged on ``gene`` and ``day``.
    alpha : float, default 0.05
        Adjusted p-value cutoff applied to both ``gt_padj`` and ``padj``.

    Returns
    -------
    compDf : pandas.DataFrame
        Merged table with added boolean columns ``gt_hits``, ``screen_hits``,
        ``TP``, ``TN``, ``FP`` and ``FN``.
    confMat : pandas.DataFrame
        2x2 confusion matrix; rows are ``Real Pos`` / ``Real Neg``, columns are
        ``Pred Pos`` / ``Pred Neg``.
    prec : float
        TP / (TP + FP); NaN when nothing was predicted positive.
    recall : float
        TP / (TP + FN); NaN when there are no real positives.

    Raises
    ------
    ValueError
        If the merged table lacks any of ``gene``, ``day``, ``log_CI``,
        ``log_gt_CI``, ``gt_padj`` or ``padj``.
    """
    required = ['gene', 'day', 'log_CI', 'log_gt_CI', 'gt_padj', 'padj']

    compDf = results.merge(gtDf, on=['gene', 'day'])

    # Fail loudly and name the culprits instead of printing and returning None,
    # which would only surface later as an obscure TypeError in the caller.
    missing = [c for c in required if c not in compDf.columns]
    if missing:
        raise ValueError(f"Some of the columns are missing after merging: {missing}")

    compDf['gt_hits'] = compDf['gt_padj'] < alpha
    compDf['screen_hits'] = compDf['padj'] < alpha
    compDf['TP'] = compDf['gt_hits'] & compDf['screen_hits']
    compDf['TN'] = ~compDf['gt_hits'] & ~compDf['screen_hits']
    compDf['FP'] = ~compDf['gt_hits'] & compDf['screen_hits']
    compDf['FN'] = compDf['gt_hits'] & ~compDf['screen_hits']

    counts = compDf[['TP', 'FN', 'FP', 'TN']].sum()
    confMat = pd.DataFrame(
        [[counts['TP'], counts['FN']],
         [counts['FP'], counts['TN']]],
        index=['Real Pos', 'Real Neg'],
        columns=['Pred Pos', 'Pred Neg'],
    )

    tp = confMat.loc['Real Pos', 'Pred Pos']
    pred_pos = confMat['Pred Pos'].sum()    # TP + FP
    real_pos = confMat.loc['Real Pos'].sum()  # TP + FN
    # Guard the 0/0 case explicitly rather than relying on a numpy RuntimeWarning.
    prec = tp / pred_pos if pred_pos else float('nan')
    recall = tp / real_pos if real_pos else float('nan')
    return compDf, confMat, prec, recall

In [None]:
# Run the ground-truth comparison separately for every library in resDf.
# rdict maps each 'library' value to whatever compare_to_gt returns for that group.
rdict = {}
for i, g  in resDf.groupby('library'):
    print(i)  # show which library is being processed
    rdict[i] = compare_to_gt(gtDf, g)

In [None]:
# Report precision and recall per library.
# v is the compare_to_gt result: index 2 is the precision, index 3 the recall.
for k, v in rdict.items():
    print(k)
    print(f"Precision: {v[2]}")
    print(f"Recall: {v[3]}")

In [None]:
# Detailed look at a single library. This cell used to read from `res14_2`, a
# variable that is not defined in the visible part of this notebook (NameError
# on a fresh kernel), and repeated the body of compare_to_gt inline. It now
# selects one library from resDf and reuses compare_to_gt.
example_library = resDf['library'].iloc[0]  # TODO: set to the library of interest
comp, confMat, prec, recall = compare_to_gt(
    gtDf, resDf[resDf['library'] == example_library]
)
# Keep the column names that later cells plot against.
comp['gt_CI_log'] = comp['log_gt_CI']
comp['ci_log'] = comp['log_CI']

In [None]:
# Column-wise totals of the boolean flags: ground-truth hits, screen hits, TP, TN, FP, FN.
comp[['gt_hits', 'screen_hits', 'TP', 'TN', 'FP', 'FN']].sum()

In [None]:
# Display the 2x2 confusion matrix (rows: Real Pos/Neg, columns: Pred Pos/Neg).
confMat

In [None]:
# Precision: TP / (TP + FP).
prec

In [None]:
# Recall: TP / (TP + FN).
recall

In [None]:
# Screen log2 CI against ground-truth log2 CI, coloured by day, with an OLS trendline.
px.scatter(
    comp,
    x='gt_CI_log',
    y='ci_log',
    color='day',
    trendline="ols",
    height=800,
)

# Load MAGeCK results

In [None]:
# MAGeCK gene summary: keep the id / FDR / LFC columns, take the smaller of the
# negative- and positive-selection FDR as `fdr`, rename `id` to `gene`, and tag
# every row with day 'd1' so it can be merged with the ground truth.
mres = outDir.joinpath('test8.gene_summary.txt')
maDf = (
    pd.read_table(mres)[['id', 'neg|fdr', 'neg|lfc', 'pos|fdr']]
    .assign(fdr=lambda df: df[['neg|fdr', 'pos|fdr']].min(axis=1))
    .rename(columns={'id': 'gene'})
    .assign(day='d1')
)

In [None]:
# Join MAGeCK results with the ground truth on gene and day, then add the
# log2 ground-truth CI under the name used for plotting.
comp2 = (
    maDf
    .merge(gtDf, on=['gene', 'day'])
    .assign(gt_CI_log=lambda df: np.log2(df.gt_CI))
)

In [None]:
# MAGeCK negative-selection LFC against ground-truth log2 CI, with an OLS trendline.
px.scatter(
    comp2,
    x='gt_CI_log',
    y='neg|lfc',
    color='day',
    trendline="ols",
    height=800,
)