In [None]:
import pandas as pd
import numpy as np
from os.path import basename
from glob import glob
import h5py
import scipy.stats as ss

import matplotlib.pyplot as plt
from plotnine import *
from IPython.display import display, Markdown

import sys
# Put the project's helper directory at the front of the import search path
# (the path is relative to the notebook's working directory).
sys.path.insert(0, '../CODE/')
from importlib import reload
import visualization_utils
# Re-execute the module so edits made to it on disk are picked up without
# restarting the kernel.
reload(visualization_utils)
# NOTE(review): names used later in this notebook that are not defined here
# (COLORS, calculate_resp_and_unresp_signed_shap_sum,
# calculate_shap_net_influence) presumably come from this star import -- confirm.
from visualization_utils import *

import warnings
# Silences every Python warning for the rest of the session, including
# deprecation notices and plotting warnings about dropped rows.
warnings.filterwarnings('ignore')

In [None]:
def resp_ratio(x):
    """Return the fraction of entries in `x` that are equal to 1.

    Parameters
    ----------
    x : array-like
        One-dimensional collection of labels, e.g. the pandas Series handed
        in by ``groupby(...).agg``. Plain lists and NumPy arrays work too.

    Returns
    -------
    float
        Number of entries equal to 1 divided by the total number of entries
        (entries that are not 1, including missing values, still count
        towards the total). NaN when `x` is empty.
    """
    labels = np.asarray(x)
    if labels.size == 0:
        # An empty group has no defined ratio; report NaN instead of raising
        # ZeroDivisionError.
        return float('nan')
    # Vectorised count rather than iterating element by element with sum().
    return float(np.count_nonzero(labels == 1) / labels.size)

# Human K562 10-fold cross-validation performance

In [None]:
# Directory holding the archived outputs of the K562 run.
out_dir = '../OUTPUT/archive/human_42tfs_k562.10_cv_folds/'

# Prediction table; the code below relies on its 'tf' and 'label' columns.
pred_df = pd.read_csv('{}preds.csv.gz'.format(out_dir))

# One row per TF: the share of that TF's rows whose label equals 1
# (see `resp_ratio`). The ratio is stored in the 'label' column.
labels_by_tf = pred_df.groupby('tf')['label']
rr_df = labels_by_tf.agg(resp_ratio).reset_index()

In [None]:
# Read the per-TF statistics and attach the per-TF response ratio from
# `rr_df` (inner join on 'tf'). The plot that follows reads the ratio from the
# 'label' column -- presumably stats.csv.gz has no 'label' column of its own.
stats_df = (
    pd.read_csv('{}stats.csv.gz'.format(out_dir))
    .merge(rr_df, on='tf')
)

stats_df

In [None]:
# Scatter of AUPRC against the response ratio (axis label: 'Chance'), one point
# per row of `stats_df`. The dashed y = x line marks where AUPRC equals that
# ratio; both axes share the same range.
axis_min, axis_max = 0, .65
(
    ggplot(stats_df, aes(x='label', y='auprc'))
    + geom_point(alpha=.5, color=COLORS['dark_blue'])
    + geom_abline(intercept=0, slope=1, linetype='dashed')
    + lims(x=[axis_min, axis_max], y=[axis_min, axis_max])
    + labs(x='Chance', y='AUPRC')
    + theme_linedraw()
    + theme(dpi=150, figure_size=(3, 3))
)

# SHAP

In [None]:
organism = 'human_k562'
k562_dir = '../OUTPUT/archive/human_42tfs_k562.10_cv_folds/'
# First whitespace-separated column of the TF list file; ndmin=1 keeps the
# result one-dimensional even if the file contains a single line.
k562_tfs = np.loadtxt(
    '../../Pert_Response_Modeling/RESOURCES/TF_list/Human_ENCODE_K562_TFs.txt',
    dtype=str, usecols=[0], ndmin=1)

# Process the TFs in 4 chunks and collect the per-chunk tables in a list.
# They are concatenated once at the end: DataFrame.append was removed in
# pandas 2.0, and appending inside a loop re-copies the growing frame each time.
k562_sss_chunks = []

for i, tf_chunk in enumerate(np.array_split(k562_tfs, 4)):
    print('Working on TF chunk {}, n={}'.format(i, len(tf_chunk)))

    k562_sss_subdf = calculate_resp_and_unresp_signed_shap_sum(k562_dir, organism=organism, tfs=tf_chunk)
    k562_sss_chunks.append(k562_sss_subdf)

# Like the former append-based loop, the row index of each chunk is kept as-is.
k562_sss_df = pd.concat(k562_sss_chunks)

# Cache the combined table next to the model outputs (the index is written too).
k562_sss_df.to_csv(k562_dir + 'signed_shap_sum.csv.gz')

In [None]:
# Uncomment to reuse the cached table instead of recomputing it. The CSV was
# written with its index, so a plain read_csv adds an extra unnamed column;
# consider index_col=0.
# k562_sss_df = pd.read_csv('../OUTPUT/archive/human_42tfs_k562.10_cv_folds/signed_shap_sum.csv.gz')

# Attach each TF's AUPRC. Any 'auprc' column left over from a previous run of
# this cell is dropped first; otherwise a re-run would produce
# 'auprc_x'/'auprc_y' and break the later filter on k562_sss_df['auprc'].
k562_sss_df = (
    k562_sss_df
    .drop(columns=['auprc'], errors='ignore')
    .merge(stats_df[['tf', 'auprc']], on='tf', how='left')
)
# NOTE(review): the plots below read a 'shap_diff' column, presumably added by
# this helper; whether applying it twice is safe is not visible here -- confirm.
k562_sss_df = calculate_shap_net_influence(k562_sss_df)

In [None]:
# Distribution of 'shap_diff' per feature type across all TF models,
# restricted to rows labelled 'Responsive'.
is_responsive = k562_sss_df['label_name'] == 'Responsive'
plot_df = k562_sss_df[is_responsive]
n_tfs = len(plot_df['tf'].unique())
print('All TFs (n={})'.format(n_tfs))

# NOTE(review): ylim() sets scale limits; in plotnine, values outside them are
# presumably removed before the boxplot statistics are computed, and warnings
# are silenced in this notebook -- confirm nothing falls outside the range.
ax = (
    ggplot(plot_df, aes(x='feat_type_name', y='shap_diff'))
    + geom_hline(linetype='dashed', yintercept=0)
    + geom_boxplot(color=COLORS['orange'], width=.4, size=.7, outlier_size=.5)
    # Optional overlay of the individual points (disabled):
    # + geom_jitter(height=0, width=.2, alpha=.2, color='blue')
    + coord_flip()
    + ylim(-.1, 1.6)
    + labs(y='Net influence of SHAP values\nfor responsive genes', x=None)
    + theme_classic()
    + theme(
        dpi=150,
        figure_size=(2.5, 3),
        axis_title=element_text(size=10, lineheight=1.5),
        axis_text_x=element_text(color='#000000'),
        axis_text_y=element_text(color='#000000'),
    )
)
display(ax)

In [None]:
# Same boxplot as for all TFs, but restricted to models whose AUPRC exceeds
# 0.1 (rows labelled 'Responsive' only).
passes_auprc = k562_sss_df['auprc'] > 0.1
is_responsive = k562_sss_df['label_name'] == 'Responsive'
plot_df = k562_sss_df[passes_auprc & is_responsive]
n_tfs = len(plot_df['tf'].unique())
print('TFs w/ AUPRC > 0.1 (n={})'.format(n_tfs))

# NOTE(review): ylim() sets scale limits; in plotnine, values outside them are
# presumably removed before the boxplot statistics are computed, and warnings
# are silenced in this notebook -- confirm nothing falls outside the range.
ax = (
    ggplot(plot_df, aes(x='feat_type_name', y='shap_diff'))
    + geom_hline(linetype='dashed', yintercept=0)
    + geom_boxplot(color=COLORS['orange'], width=.4, size=.7, outlier_size=.5)
    # Optional overlay of the individual points (disabled):
    # + geom_jitter(height=0, width=.2, alpha=.2, color='blue')
    + coord_flip()
    + ylim(-.1, 1.5)
    + labs(y='Net influence of SHAP values\nfor responsive genes', x=None)
    + theme_classic()
    + theme(
        dpi=150,
        figure_size=(2.5, 3),
        axis_title=element_text(size=10, lineheight=1.5),
        axis_text_x=element_text(color='#000000'),
        axis_text_y=element_text(color='#000000'),
    )
)
display(ax)