# Imports

In [None]:
# Cosmetic notebook tweak: inject a <style> element that pads cell outputs on the left.
# `display`/`HTML` come from the public `IPython.display` module; importing `display`
# from `IPython.core.display` is deprecated.
# NOTE(review): the `.container` rule contains only `!important;` with no property, so it
# has no effect -- presumably a width override was lost; confirm before restoring it.
from IPython.display import display, HTML
display(HTML(
    "<style>.container {  !important;\
    } div.output_wrapper .output { padding-left: 14px; }</style>"
))

In [None]:
import os
import numpy as np
import pandas as pd
import h5py

from matplotlib import pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import seaborn as sns

from sklearn.metrics import auc, roc_auc_score, adjusted_mutual_info_score
from scipy.stats import ks_2samp, chi2_contingency, ttest_rel, ttest_1samp, ttest_ind
from cds.utilities import read_hdf5, write_hdf5
from statsmodels.stats.multitest import fdrcorrection
from sklearn.cluster import KMeans

# Utility and Metric Functions

In [None]:
def write_hdf5(df, filename):
    """Write a DataFrame to an HDF5 file, replacing any existing file.

    Three datasets are created: ``dim_0`` (UTF-8 encoded index labels),
    ``dim_1`` (UTF-8 encoded column labels) and ``data`` (``df.values``).

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose index and column labels are strings.
    filename : str
        Destination path; an existing file at this path is removed first.
    """
    if os.path.exists(filename):
        os.remove(filename)

    # Open explicitly for writing: without a mode argument, recent h5py versions open
    # the file read-only, which fails because the file was just removed.
    with h5py.File(filename, 'w') as dest:
        dim_0 = [x.encode('utf8') for x in df.index]
        dim_1 = [x.encode('utf8') for x in df.columns]

        dest.create_dataset('dim_0', track_times=False, data=dim_0)
        dest.create_dataset('dim_1', track_times=False, data=dim_1)
        dest.create_dataset("data", track_times=False, data=df.values)

def read_hdf5(filename):
    """Read a DataFrame from an HDF5 file laid out as ``dim_0`` / ``dim_1`` / ``data``.

    ``dim_0`` and ``dim_1`` hold UTF-8 encoded row and column labels; ``data`` holds
    the matrix values. The file is closed before returning.
    """
    with h5py.File(filename, 'r') as handle:
        row_labels = [raw.decode('utf8') for raw in handle['dim_0']]
        column_labels = [raw.decode('utf8') for raw in handle['dim_1']]
        values = np.array(handle['data'])

    return pd.DataFrame(index=row_labels, columns=column_labels, data=values)


In [None]:
def index_melt(df, **kwargs):
    """Melt ``df`` to long format, keeping its (innermost) index level as the id column.

    The id column is named after ``df.index.name``, or ``'index'`` when the index is
    unnamed. Extra keyword arguments are forwarded to ``pandas.melt``.
    """
    id_column = df.index.name if df.index.name else 'index'
    innermost_level = df.index.nlevels - 1
    flattened = df.reset_index(level=innermost_level)
    return pd.melt(flattened, id_vars=id_column, **kwargs)

In [None]:
def roc_auc(positive_controls, negative_controls):
    """ROC AUC for separating two groups of scores.

    Positive controls are labelled 0 and negative controls 1 before calling
    ``roc_auc_score``, so the result is high when negative controls score
    higher than positive controls.
    """
    labels = [0] * len(positive_controls) + [1] * len(negative_controls)
    scores = list(positive_controls) + list(negative_controls)
    return roc_auc_score(labels, scores)

In [None]:
def empirical_pval(observed, null):
    '''
    generates left-tailed pvalues

    For each non-missing observed value, the p-value is
    (number of null values strictly below it + 1) / (number of null values + 1).

    observed: pandas Series of observed values; NaNs are dropped
    null: pandas Series of null-distribution values; NaNs are dropped

    Returns a Series indexed like the non-missing entries of `observed`, in their
    original order. Neither input is modified.
    '''
    null_sorted = np.sort(null.dropna().to_numpy())
    observed = observed.dropna()
    # searchsorted does not need sorted query values, so the observed values are looked
    # up in their original order; this also works when index labels are duplicated.
    n_below = np.searchsorted(null_sorted, observed.to_numpy(), side='left')
    return pd.Series(
        (n_below + 1) / (len(null_sorted) + 1),
        index=observed.index
    )

In [None]:
def empirical_fdr(observed, null):
    """Adjusted p-values from `fdrcorrection`, applied to the left-tailed empirical p-values.

    The result is indexed like the p-values returned by `empirical_pval`.
    """
    pvalues = empirical_pval(observed, null)
    adjusted = fdrcorrection(pvalues)[1]
    return pd.Series(adjusted, index=pvalues.index)

In [None]:
def nnmd(x):
    """Difference of essential and nonessential means, in nonessential standard deviations.

    Reads the notebook-level gene lists `ess` and `ness` to select entries of `x`.
    """
    essential_scores = x[ess]
    nonessential_scores = x[ness]
    mean_gap = essential_scores.mean() - nonessential_scores.mean()
    return mean_gap / nonessential_scores.std()

In [None]:
def get_ind(pairs, keep):
    '''
    Flat (row-major) positions of column pairs in a raveled len(keep) x len(keep) matrix.

    pairs: list of tuple of strings indicating column pairs to check
    keep: list of strings indicating columns to keep

    Returns an integer array with one entry per pair: position(first) * len(keep) + position(second).
    '''
    # Map each label to its position in `keep`; sorting the index only affects lookup, not values.
    keep_mapping = pd.Series(
        np.arange(len(keep)),
        index=keep
    ).sort_index()
    n_keep = len(keep_mapping)
    first_pos = keep_mapping.loc[[v[0] for v in pairs]].values
    second_pos = keep_mapping.loc[[v[1] for v in pairs]].values
    # Builtin `int` replaces the `np.int` alias, which no longer exists in current NumPy.
    return (first_pos * n_keep + second_pos).astype(int)

In [None]:
def np_cor_no_missing(x, y):
    """Full column-wise Pearson correlations of two matrices with no missing values.

    x, y: 2-D arrays with the same number of rows.
    Returns an array whose [i, j] entry is the correlation of column i of `x`
    with column j of `y`.
    """
    xv = (x - x.mean(axis=0))/x.std(axis=0)
    # Standardize y by its own standard deviation (previously x's was used, which is
    # only correct when x and y are the same matrix).
    yv = (y - y.mean(axis=0))/y.std(axis=0)
    result = np.dot(xv.T, yv)/len(xv)
    return result

In [None]:
def corr_dist(df, pairs):
    '''
    Correlations between columns of `df` for several named sets of column pairs.

    df: matrix (DataFrame) with no missing values
    pairs: dict mapping a pair-type name to a list of (column, column) tuples

    Returns a dict with the same keys; each value is a Series of correlations
    indexed by the corresponding pairs.
    '''
    print('calculating correlations')
    matrix = df.values
    flat_corr = np.ravel(np_cor_no_missing(matrix, matrix))
    print('subsetting')
    return {
        label: pd.Series(flat_corr[get_ind(pairset, df.columns)], index=pairset)
        for label, pairset in pairs.items()
    }

# Load

## Choose preprocessings

In [None]:
# Point to a local directory where you've stored the data files from Figshare.
# Keep the trailing slash: file names are appended to this string directly.
source = './Data_23.9.20/'

In [None]:
# Read the table describing every gene-effect file, then index the file paths as
# files[postcorrection][preprocessing].
metadata_path = source + 'MetaSamples.csv'
try:
    metadata = pd.read_csv(metadata_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file MetaSamples.csv from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

files = {}
for _, record in metadata.iterrows():
    by_preprocessing = files.setdefault(record.postcorrection, {})
    by_preprocessing[record.preprocessing] = os.path.join(source, record.Names)

In [None]:
# Directory that receives every figure and table written by this notebook.
# The value ends with a path separator because later cells build paths as
# `output + "file_name"`; without it the files would be written next to the
# directory (as "AnalysisOutput<file_name>") instead of inside it.
output = os.path.join('AnalysisOutput', '')
os.makedirs(output, exist_ok=True)

In [None]:
# Preprocessing methods compared throughout, in the order used on plot axes.
preprocessings = ['CRISPRCleanR', 'CCR-JACKS', 'CERES']  # alternative: metadata['preprocessing'].unique()

In [None]:
# Post-correction variants compared throughout, in the order used for plot hues.
postcorrections = ['ComBat', 'ComBat+QN', 'ComBat+QN+PC1', 'ComBat+QN+PC1-2']

In [None]:
# Number of preprocessing methods listed above.
npreprocessings = len(preprocessings)

## Load Gene Effects

In [None]:
# Load every gene-effect matrix as gene_effects[postcorrection][preprocessing].
# Each file is transposed after reading.
gene_effects = {}
for postcorrection, paths in files.items():
    loaded = {}
    for dataset, path in paths.items():
        try:
            loaded[dataset] = read_hdf5(path).T
        except FileNotFoundError:
            raise FileNotFoundError((
                "You will need to download the file %s from the "
                "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
                "directory"
            ) % path)
    gene_effects[postcorrection] = loaded

In [None]:
# Table with (at least) the `symbol` and `gene` columns used below to relabel genes.
gene_map_path = source + "DepMap_gene_map.csv"
try:
    gene_map = pd.read_csv(gene_map_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file DepMap_gene_map.csv from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

In [None]:
# Genes present in every loaded matrix that also have an entry in the gene map.
column_sets = [
    set(frame.columns)
    for by_preprocessing in gene_effects.values()
    for frame in by_preprocessing.values()
]
shared_genes = set.intersection(*column_sets) & set(gene_map['symbol'])


In [None]:
# Drop columns whose symbol is missing from the gene map, then relabel the remaining
# columns with the map's `gene` value. The matrices are modified in place.
symbol_lookup = gene_map.set_index('symbol')
mapped_symbols = set(gene_map.symbol)
for by_preprocessing in gene_effects.values():
    for effect in by_preprocessing.values():
        unmapped = sorted(set(effect.columns) - mapped_symbols)
        effect.drop(unmapped, axis=1, inplace=True)
        effect.columns = symbol_lookup.loc[effect.columns, 'gene'].values

In [None]:
# Cell line annotations (sample info).
sample_info_path = source + "sample_info.csv"
try:
    si = pd.read_csv(sample_info_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file sample_info.csv from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

In [None]:
# Two-column slice of the sample info pairing each DepMap_ID with its Sanger_Model_ID.
lines = si[['DepMap_ID', 'Sanger_Model_ID']]

In [None]:
# Lookup from either kind of identifier to DepMap_ID: DepMap IDs map to themselves,
# and Sanger model IDs (rows with no missing value only) map to their DepMap ID.
complete_rows = lines.dropna()
depmap_ids = lines['DepMap_ID'].values
identity_part = pd.Series(depmap_ids, index=depmap_ids)
sanger_part = pd.Series(
    complete_rows['DepMap_ID'].values,
    index=complete_rows['Sanger_Model_ID'].values,
)
cell_line_map = pd.concat([identity_part, sanger_part], axis=0)

In [None]:
# Relabel the rows of every matrix with DepMap IDs (in place) and collect the union
# of all cell lines seen.
all_lines = set()
for by_preprocessing in gene_effects.values():
    for effect in by_preprocessing.values():
        effect.index = cell_line_map.loc[effect.index].values
        all_lines |= set(effect.index)

## Summarize

In [None]:
# Number of post-correction variants actually loaded (top-level keys of gene_effects).
npostcorrections = len(gene_effects)

In [None]:
# Recompute the genes common to every matrix, now using the current column labels.
relabeled_column_sets = [
    set(frame.columns)
    for by_preprocessing in gene_effects.values()
    for frame in by_preprocessing.values()
]
shared_genes = set.intersection(*relabeled_column_sets)

In [None]:
def _load_first_column(filename):
    """Return the first column of a CSV in the source directory, with a download hint on failure."""
    try:
        return pd.read_csv(source + filename).iloc[:, 0]
    except FileNotFoundError:
        raise FileNotFoundError((
            "You will need to download the file %s from the "
            "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
            "directory"
        ) % filename)


# Control gene lists, restricted to genes shared by all matrices (file order is kept).
ess = [gene for gene in _load_first_column("depmap_common_essentials.csv") if gene in shared_genes]
ness = [gene for gene in _load_first_column("depmap_nonessentials.csv") if gene in shared_genes]

In [None]:
# Index the sample info by DepMap_ID, in place. This cell is not re-runnable on its
# own: once DepMap_ID has become the index, the column is gone and a second run
# raises KeyError until `si` is reloaded.
si.set_index('DepMap_ID', inplace=True)

## Scaling

In [None]:
# Rescale every matrix in place: first shift by the median (over lines) of the per-line
# nonessential medians, then divide by the median (over lines) of the absolute
# per-line essential medians, computed after the shift.
for by_preprocessing in gene_effects.values():
    for effect in by_preprocessing.values():
        nonessential_center = effect[ness].median(axis=1).median()
        effect -= nonessential_center
        essential_scale = effect[ess].median(axis=1).abs().median()
        effect /= essential_scale

## Load Omics

In [None]:
# Related gene pairs (target, partner), limited to pairs where both genes are shared
# by all matrices, without duplicate rows.
related_genes_path = source + "related_genes.csv"
try:
    related_genes = pd.read_csv(related_genes_path)[['target', 'partner']]
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file related_genes.csv from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

both_shared = related_genes.target.isin(shared_genes) & related_genes.partner.isin(shared_genes)
related_genes = related_genes[both_shared].drop_duplicates()

In [None]:
# Expression matrix, stored in the same HDF5 layout as the gene effects.
expression_path = source + "DepMap_20Q2_expression.hdf5"
try:
    expression = read_hdf5(expression_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file DepMap_20Q2_expression.hdf5 from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

In [None]:
# Mutation calls: keep screened lines, label genes as "SYMBOL (ENTREZ)", flag hotspot
# calls (TCGA or COSMIC) and keep one hotspot row per (gene, line).
mutations_path = source + "DepMap_20Q2_mutations.csv"
try:
    mut = pd.read_csv(mutations_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file DepMap_20Q2_mutations.csv from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

mut = mut[mut.DepMap_ID.isin(all_lines)]
mut = mut.assign(
    Gene=mut.apply(lambda x: '%s (%i)' % (x['Hugo_Symbol'], x['Entrez_Gene_Id']), axis=1),
    hotspot=mut['isTCGAhotspot'] | mut['isCOSMIChotspot'],
)
mut = mut[mut.hotspot].drop_duplicates(subset=['Gene', 'DepMap_ID'])

In [None]:
# Fusion calls indexed by DepMap_ID. Gene labels are cut at the first space, and
# `Sorted` holds the two partner genes as an order-independent tuple.
fusions_path = source + "DepMap_20Q2_fusions.csv"
try:
    fusions = pd.read_csv(fusions_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the file DepMap_20Q2_fusions.csv from the "
        "figshare at https://figshare.com/projects/Integrated_CRISPR/78252 and save it in the source "
        "directory"
    )

fusions = fusions.set_index('DepMap_ID')
fusions = fusions.drop(['Unnamed: 0'], errors='ignore', axis=1)

for side in ('LeftGene', 'RightGene'):
    fusions[side] = fusions[side].apply(lambda s: s.split(' ')[0])

fusions['Sorted'] = fusions.apply(
    lambda x: tuple(sorted(x[['LeftGene', 'RightGene']])),
    axis=1,
)

# Control Separation

## Unexpressed False Positives

### Calculation

In [None]:
# Boolean matrix marking entries with expression below 0.01 ("unexpressed").
# Builtin `bool` replaces the `np.bool` alias, which no longer exists in current NumPy.
ux = (expression < .01).astype(bool)

In [None]:
# For each matrix, count unexpressed (line, gene) entries whose gene effect ranks in the
# lowest `threshold` fraction of that line (top_ux), and the number of (line, gene)
# entries that have expression data at all (total_ux).
top_ux = {}
total_ux = {}
threshold = 0.15
for postcorrection, coll in gene_effects.items():

    top_ux[postcorrection] = {}
    total_ux[postcorrection] = {}
    print(postcorrection)
    for key in coll.keys():
        ux_shared = sorted(set(ux.columns) & set(coll[key].columns))
        print('\t', key)
        val = coll[key][ux_shared]
        # Within-line percentile rank of every gene effect.
        rank = val.rank(pct=True, axis=1)
        # reindex (rather than .loc) so that lines without expression data become NaN
        # rows instead of raising KeyError on current pandas.
        mask = ux.reindex(index=rank.index, columns=rank.columns)
        # True only where the entry is known to be unexpressed; NaN counts as False.
        unexpressed = mask.eq(True)
        selected = rank.where(unexpressed)
        top_ux[postcorrection][key] = (selected < threshold).sum().sum()
        total_ux[postcorrection][key] = mask.notnull().sum().sum()
        print('\t', top_ux[postcorrection][key], total_ux[postcorrection][key])

### Plot

In [None]:
# Long-format table of top_ux / total_ux per (preprocessing, postcorrection), then a bar plot.
uxfp_ratio = pd.DataFrame(top_ux) / pd.DataFrame(total_ux)
df = index_melt(uxfp_ratio)
df.columns = ['preprocessing', 'postcorrection', 'UXFP']

plt.close('all')
fig, ax = plt.subplots(figsize=(6, 7))
sns.barplot(
    data=df, x='preprocessing', y='UXFP', hue='postcorrection',
    order=preprocessings, hue_order=postcorrections, ax=ax,
)
sns.despine(top=True, right=True)
ax.set_xlabel('')
ax.set_ylabel("Unexpressed False Positives")
plt.xticks(rotation=90)
ax.grid(axis='y', which='major')
fig.tight_layout()
fig.savefig(output + "unexpressed_bar.pdf", dpi=600)

### Significance

In [None]:
# Median UXFP for every (preprocessing, postcorrection) combination.
print(df.groupby(['preprocessing', 'postcorrection']).UXFP.median())

In [None]:
def _ux_counts(postcorrection, dataset):
    """One row of the 2x2 table: [top-ranked unexpressed count, total count minus that]."""
    hits = top_ux[postcorrection][dataset]
    return [hits, total_ux[postcorrection][dataset] - hits]


# Chi-squared tests between every pair of preprocessings, within each correction.
preprocessing_rows = []
for postcorrection in postcorrections:
    for first in range(len(preprocessings) - 1):
        d1 = preprocessings[first]
        for d2 in preprocessings[first + 1:]:
            p = chi2_contingency([
                _ux_counts(postcorrection, d1),
                _ux_counts(postcorrection, d2),
            ])[1]
            preprocessing_rows.append({"Correction": postcorrection, "Preprocessing1": d1,
                                       "Preprocessing2": d2, 'p': p})
compare_between_preprocessing = pd.DataFrame(preprocessing_rows)

# Chi-squared tests between every pair of corrections, within each preprocessing.
correction_rows = []
for dataset in preprocessings:
    for first in range(len(postcorrections) - 1):
        postcorrection1 = postcorrections[first]
        for postcorrection2 in postcorrections[first + 1:]:
            p = chi2_contingency([
                _ux_counts(postcorrection1, dataset),
                _ux_counts(postcorrection2, dataset),
            ])[1]
            correction_rows.append({"Preprocessing": dataset, "Correction1": postcorrection1,
                                    "Correction2": postcorrection2, 'p': p})
compare_between_correction = pd.DataFrame(correction_rows)

In [None]:
# Largest (least significant) chi-squared p-value over corrections, per pair of preprocessings.
compare_between_preprocessing.groupby(["Preprocessing1", "Preprocessing2"]).p.max()

In [None]:
# Largest (least significant) chi-squared p-value over preprocessings, per pair of corrections.
compare_between_correction.groupby(["Correction1", "Correction2"]).p.max()

## FDR

### Calculate

In [None]:
def _line_recall(x):
    """Fraction of `ess` genes with FDR < 0.1 in one line, using that line's unexpressed genes as null.

    `x` is one row of a gene-effect matrix (its name is the cell line).
    """
    unexpressed = ux.loc[x.name, x.index].fillna(False).values
    fdr = empirical_fdr(x, x.iloc[unexpressed])
    # reindex (rather than label lookup) so essential genes that are missing from the
    # FDR series count as not recalled instead of raising KeyError on current pandas.
    return (fdr.reindex(ess) < .1).mean()


# recall_at_10fdr[postcorrection][dataset] is a Series of per-line recall values.
recall_at_10fdr = {}
for postcorrection in gene_effects.keys():
    print(postcorrection)
    recall_at_10fdr[postcorrection] = {}
    for dataset in gene_effects[postcorrection].keys():
        val = gene_effects[postcorrection][dataset]
        ux_shared = sorted(set(val.columns) & set(ux.columns))
        print('\t', dataset)
        # Only lines that have expression data can provide a null distribution.
        overlap = sorted(set(ux.index) & set(val.index))
        recall_at_10fdr[postcorrection][dataset] = val.loc[overlap, ux_shared].apply(
            _line_recall, axis=1
        )

In [None]:
# Stack the per-line recall Series into one long table for plotting.
recall_frames = []
for post_label, by_dataset in recall_at_10fdr.items():
    for pre_label in by_dataset.keys():
        recall_frames.append(pd.DataFrame({
            'postcorrection': post_label,
            'preprocessing': pre_label,
            'Recall': recall_at_10fdr[post_label][pre_label],
        }))
recall10_flat = pd.concat(recall_frames, ignore_index=True)

### Plot

In [None]:
# Box plot of per-line recall, grouped by preprocessing and coloured by postcorrection.
plt.close('all')
fig, ax = plt.subplots(figsize=(6, 6))
sns.boxplot(
    data=recall10_flat, x='preprocessing', y='Recall', hue='postcorrection',
    order=preprocessings, hue_order=postcorrections, ax=ax,
)
ax.set_xlabel("")
plt.xticks(rotation=90)
ax.grid(axis='y')
ax.set_ylabel("Recall of Essentials at 10% FDR")
fig.tight_layout()
fig.savefig(output + "recall_10fdr.pdf", dpi=600)

### Significance

In [None]:
# Median and mean recall for every dataset, grouped by postcorrection.
for post_label, by_dataset in recall_at_10fdr.items():
    print(post_label)
    for pre_label, recall in by_dataset.items():
        print('\t', pre_label, recall.median(), recall.mean())
    print()

In [None]:
# Paired t-tests on per-line recall.
# NOTE(review): ttest_rel pairs values by position, so both Series must list the same
# cell lines in the same order -- confirm this holds for all datasets.

# Between every pair of preprocessings, within each correction.
compare_between_preprocessing = []
for postcorrection in postcorrections:
    recalls = recall_at_10fdr[postcorrection]
    for i, d1 in enumerate(preprocessings[:-1]):
        for d2 in preprocessings[i+1:]:
            t, p = ttest_rel(recalls[d1], recalls[d2])
            compare_between_preprocessing.append({"Correction": postcorrection, "Preprocessing1": d1,
                                                  "Preprocessing2": d2, 'p': p})
compare_between_preprocessing = pd.DataFrame(compare_between_preprocessing)

# Between every pair of corrections, within each preprocessing.
compare_between_correction = []
for dataset in preprocessings:
    for i, postcorrection1 in enumerate(postcorrections[:-1]):
        for postcorrection2 in postcorrections[i+1:]:
            t, p = ttest_rel(recall_at_10fdr[postcorrection1][dataset],
                             recall_at_10fdr[postcorrection2][dataset])
            # The column is named "Preprocessing" (it was "Preprocessing2", a copy-paste
            # slip), matching the chi-squared comparison table built earlier.
            compare_between_correction.append({"Correction1": postcorrection1,
                                               "Correction2": postcorrection2,
                                               "Preprocessing": dataset,
                                               'p': p})
compare_between_correction = pd.DataFrame(compare_between_correction)

In [None]:
# Largest (least significant) paired t-test p-value over corrections, per pair of preprocessings.
compare_between_preprocessing.groupby(["Preprocessing1", "Preprocessing2"]).p.max()

In [None]:
# Largest (least significant) paired t-test p-value over preprocessings, per pair of corrections.
compare_between_correction.groupby(["Correction1", "Correction2"]).p.max()

## NNMD

### Calculate

In [None]:
# One NNMD value per cell line for every (postcorrection, preprocessing) matrix.
nnmd_frames = []
for post_label, by_dataset in gene_effects.items():
    for pre_label, effect in by_dataset.items():
        nnmd_frames.append(pd.DataFrame({
            'postcorrection': post_label,
            'preprocessing': pre_label,
            'nnmd': effect.apply(nnmd, axis=1).values,
            'line': effect.index,
        }))
nnmds = pd.concat(nnmd_frames, ignore_index=True)

### Plot

In [None]:
# Violin plot of per-line NNMD, grouped by preprocessing and coloured by postcorrection.
plt.close('all')
fig, ax = plt.subplots(figsize=(6, 6))
sns.violinplot(
    data=nnmds, x='preprocessing', y='nnmd', hue='postcorrection',
    order=preprocessings, hue_order=postcorrections, ax=ax,
)
sns.despine(top=True, right=True)
ax.set_xlabel("")
ax.set_ylabel("NNMD (lower is better)")
# The figure is resized after creation; this is the size that gets saved.
fig.set_size_inches(7.5, 5)
ax.grid(axis='y')
plt.xticks(rotation=90)
fig.tight_layout()
fig.savefig(output + "nnmd_violin.pdf", dpi=400)

In [None]:
# Save the per-line NNMD table.
# NOTE(review): `output` is concatenated directly with the file name, so it must end
# with a path separator for the file to land inside the output directory.
nnmds.to_csv(output + 'nnmd.csv')

# Oncogene Biomarkers

## Identify Oncogenes

In [None]:
# OncoKB "all annotated variants" table. The file is looked up in the data directory
# first; the original author's absolute path is kept as a fallback so existing setups
# keep working.
oncokb_filename = 'oncokb_allvariants_20200715.csv'
oncokb_candidates = [
    os.path.join(source, oncokb_filename),
    '/Users/dempster/Documents/genes/' + oncokb_filename,
]
oncokb_path = next(
    (path for path in oncokb_candidates if os.path.exists(path)),
    oncokb_candidates[0],
)
try:
    oncogenes = pd.read_csv(oncokb_path)
except FileNotFoundError:
    raise FileNotFoundError(
        "You will need to download the OncoKB all variants annotation from "
        "http://oncokb.org/api/v1/utils/allAnnotatedVariants "
        "and save locally as a csv at %s" % oncokb_candidates[0]
    )

# Keep (likely) oncogenic variants with a (likely) gain- or switch-of-function effect.
oncogenes = oncogenes[oncogenes.oncogenicity.isin(['Likely Oncogenic', 'Oncogenic'])]
oncogenes = oncogenes[oncogenes['mutationEffect'].isin([
    'Likely Gain-of-function','Gain-of-function',
    'Likely Switch-of-function', 'Switch-of-function'
])]

In [None]:
# Remove oncogenes whose median gene effect in the ComBat / CRISPRCleanR matrix is
# already below -0.5.
reference_effect = gene_effects['ComBat']['CRISPRCleanR']
oncogene_medians = reference_effect.reindex(columns=oncogenes.gene.unique()).median()
already_ess = oncogene_medians[oncogene_medians < -.5].index

oncogenes = oncogenes[~oncogenes.gene.isin(already_ess)]

In [None]:
def p_change_parse(s):
    """Strip a leading 'p.' from a protein-change string; null values are returned unchanged."""
    if pd.isnull(s):
        return s
    return s[2:] if s.startswith('p.') else s

# Normalize protein changes by removing the leading 'p.' so they can be compared with
# the variant names matched against them later.
mut['Protein_Change'] = mut.Protein_Change.apply(p_change_parse)

In [None]:
# For every oncogene, collect the cell lines carrying one of its annotated alterations:
# a matching protein change, any fusion involving the gene ("... Fusions"), or a
# specific fusion pair ("A-B Fusion").
indicated_alterations = {}
for gene in oncogenes.gene.unique():
    altered_lines = set()
    gene_mutations = mut[mut.Hugo_Symbol == gene]
    for variant in oncogenes.variant[oncogenes.gene == gene].unique():
        altered_lines |= set(gene_mutations[gene_mutations.Protein_Change == variant].DepMap_ID)
        if variant.endswith('Fusions'):
            altered_lines |= set(fusions[fusions.LeftGene == gene].index)
            altered_lines |= set(fusions[fusions.RightGene == gene].index)
        elif variant.endswith('Fusion'):
            partner_text = variant.split('Fusion')[0]
            partners = tuple(sorted([s.strip() for s in partner_text.split('-')]))
            altered_lines |= set(fusions[fusions.Sorted == partners].index)
    indicated_alterations[gene] = altered_lines

# Genes without any altered line are listed separately and removed from the mapping.
dropped_alterations = [gene for gene, found in indicated_alterations.items() if not found]
indicated_alterations = {
    gene: sorted(found)
    for gene, found in indicated_alterations.items()
    if found
}

In [None]:
# Boolean matrix (cell line x gene): True where the line carries an indicated alteration
# of the gene. Rows are sorted so the matrix is identical from run to run (a bare set
# has no stable order), and an empty mapping yields an empty matrix instead of an error.
altered_line_index = pd.Index(sorted(set().union(*indicated_alterations.values())))
alteration_matrix = pd.DataFrame(
    {gene: altered_line_index.isin(lines) for gene, lines in indicated_alterations.items()},
    index=altered_line_index,
)

In [None]:
# Long table with one row per (postcorrection, dataset, oncogene, cell line): the gene
# effect and whether the line carries an indicated alteration of that gene.
# Local names are used for the per-matrix lookup so the notebook-level `gene_map` table
# and the earlier `df` summary are not overwritten by this cell.
cannonical = []
for postcorrection in postcorrections:
    for dataset in preprocessings:
        effect = gene_effects[postcorrection][dataset]
        # Map the text before the first space of each column label to the full label.
        symbol_to_column = pd.Series(
            effect.columns,
            index=[s.split(' ')[0] for s in effect.columns]
        )
        available_lines = set(effect.index)
        for gene, lines in indicated_alterations.items():
            if gene not in symbol_to_column.index:
                continue
            column = symbol_to_column[gene]
            altered = set(lines)
            groups = (
                (True, sorted(altered & available_lines)),
                (False, sorted(available_lines - altered)),
            )
            for has_biomarker, group_lines in groups:
                cannonical.append(pd.DataFrame({
                    'postcorrection': postcorrection,
                    'dataset': dataset,
                    'gene': column,
                    'line': group_lines,
                    'biomarker': has_biomarker,
                    'gene_effect': effect.loc[group_lines, column]
                }))
cannonical = pd.concat(cannonical, ignore_index=True)

In [None]:
# Drop rows containing any missing value, in place.
cannonical.dropna(inplace=True)

# Number of distinct genes left in the table.
len(cannonical.gene.unique())

## Overall Separation

In [None]:
# ROC AUC for separating lines with and without the biomarker, pooled over all
# oncogenes, for every (preprocessing, postcorrection) combination.
# Rows are collected in a list and turned into a DataFrame once; DataFrame.append is
# no longer available in current pandas.
roc_rows = []
for dataset in preprocessings:
    print(dataset)
    for postcorrection in postcorrections:
        sub = cannonical.query('postcorrection == %r' % postcorrection).query("dataset == %r" % dataset)
        yf = sub[sub.biomarker == False].gene_effect
        yt = sub[sub.biomarker == True].gene_effect
        auroc = roc_auc(yt, yf)
        roc_rows.append({
            'Preprocessing': dataset,
            'postcorrection': postcorrection,
            'ROC AUC': auroc
        })
oncogene_roc_auc = pd.DataFrame(roc_rows, columns=['Preprocessing', 'postcorrection', 'ROC AUC'])

In [None]:
# Display the ROC AUC table.
oncogene_roc_auc

In [None]:
# Bar plot of the pooled ROC AUC values. The table is stacked on itself before
# plotting, so every bar is drawn from two identical observations.
doubled_roc = pd.concat([oncogene_roc_auc, oncogene_roc_auc], ignore_index=True)
ax = sns.barplot(
    data=doubled_roc,
    x='Preprocessing', y='ROC AUC', hue='postcorrection',
    order=preprocessings, hue_order=postcorrections,
)
ax.set_xlabel("")
ax.set_ylabel("ROC AUC")
ax.set_ylim(.5, .82)
# Replace the automatic legend with one placed at the lower left.
ax.get_legend().remove()
ax.legend(loc="lower left")
sns.despine(right=True, top=True)
ax.figure.set_size_inches((4,4))
ax.figure.tight_layout()
ax.figure.savefig(output + "oncogene_roc_auc.pdf", dpi=600)

In [None]:
# Median ROC AUC over postcorrections, for each preprocessing.
oncogene_roc_auc.groupby("Preprocessing")['ROC AUC'].median()

## Per Gene Separation

In [None]:
def nnmd_reduce(group):
    """Mean gene-effect gap between biomarker and non-biomarker rows, in non-biomarker standard deviations.

    `group` needs a boolean `biomarker` column and a numeric `gene_effect` column.
    """
    with_marker = group[group.biomarker == True].gene_effect
    without_marker = group[group.biomarker == False].gene_effect
    return (with_marker.mean() - without_marker.mean()) / without_marker.std()

In [None]:
# Genes that have more than one biomarker-positive line in at least one
# (postcorrection, dataset) combination.
biomarker_counts = cannonical.groupby(["postcorrection", "dataset", "gene"]).biomarker.sum()
positive_counts = biomarker_counts[biomarker_counts > 1].index.get_level_values(2).unique()

In [None]:
# Per-gene NNMD for every (preprocessing, postcorrection) combination; the NNMD values
# end up in the column labelled 0. Only genes listed in positive_counts are kept.
per_combination = []
for dataset in preprocessings:
    print(dataset)
    for postcorrection in postcorrections:
        in_combination = (cannonical.postcorrection == postcorrection) & (cannonical.dataset == dataset)
        gene_scores = (
            cannonical[in_combination]
            .groupby('gene')
            .apply(nnmd_reduce)
            .dropna()
            .reset_index()
        )
        gene_scores["preprocessing"] = dataset
        gene_scores['postcorrection'] = postcorrection
        per_combination.append(gene_scores)
oncogene_nnmds = pd.concat(per_combination, ignore_index=True)
oncogene_nnmds = oncogene_nnmds[oncogene_nnmds.gene.isin(positive_counts)]

# Box plot of per-gene NNMD (column 0), grouped by preprocessing and coloured by
# postcorrection. `order` lists the x-axis levels and `hue_order` the hue levels; the
# previous call passed an unknown `postcorrections=` keyword and gave `hue_order` the
# preprocessing names.
sns.boxplot(data=oncogene_nnmds, hue="postcorrection", x="preprocessing", y=0,
            order=preprocessings, hue_order=postcorrections)
plt.xlabel("")
plt.ylabel("NNMD")
# One minor tick between major ticks on the y axis, with grid lines at both levels.
plt.gca().minorticks_on()
plt.gca().yaxis.set_minor_locator(AutoMinorLocator(2))
plt.grid(axis='y', which="minor")
plt.grid(axis='y', which="major")
sns.despine(left=True, top=True)
plt.gcf().set_size_inches((4, 4))
# Replace the automatic legend with one placed at the lower right.
plt.gca().get_legend().remove()
plt.legend(loc="lower right")
plt.tight_layout()
plt.savefig(output + "oncogene_per_gene_NNMD.pdf", dpi=600)

# Lineage Clustering

## Get Clusters

In [None]:
# Cluster cell lines with K-means on the 500 most variable genes and score agreement
# with lineage labels (adjusted mutual information), `nsamples` times per matrix.
nsamples = 100
# Seeded generator so the clusterings are reproducible. A RandomState instance (rather
# than an integer seed) is passed so that successive fits still differ from each other.
kmeans_rng = np.random.RandomState(0)

reference_effect = gene_effects['ComBat']['CERES']
disease_ami = {postcorrection: {} for postcorrection in gene_effects.keys()}
# A sorted list is used because current pandas rejects a set as a column indexer.
var_genes = reference_effect[sorted(shared_genes)].std().sort_values()[-500:].index

# Integer lineage code per cell line (first occurrence kept for duplicated lines).
# The labels do not depend on the dataset, so they are computed once.
disease = pd.Series(
    pd.Categorical(si.loc[reference_effect.index, 'lineage']).codes,
    index=reference_effect.index
)
disease = disease[~disease.index.duplicated()]
n_clusters = int(disease.max() + 1)

for dataset in preprocessings:
    print(dataset)
    for key, val in gene_effects.items():
        if dataset not in val:
            continue
        overlap = sorted(set(disease.index) & set(val[dataset].index))
        v = val[dataset].loc[overlap, var_genes].copy()
        # Standardize each gene, then replace missing values by the gene's mean.
        # fillna is used because assignment through `v[gene].loc[...]` is chained
        # indexing and may not modify `v`.
        v -= v.mean()
        v /= v.std()
        v = v.fillna(v.mean())
        model = KMeans(n_clusters=n_clusters, random_state=kmeans_rng)
        labels = disease.loc[v.index].astype(int).values
        disease_ami[key][dataset] = [
            adjusted_mutual_info_score(labels, model.fit_predict(v).astype(int))
            for _ in range(nsamples)
        ]

## Check agreement

In [None]:
# Flatten the nested dict of AMI lists into one long table.
# Note that this rebinds `disease_ami` from a dict to a DataFrame.
ami_frames = []
for post_label, by_dataset in disease_ami.items():
    for pre_label, ami_values in by_dataset.items():
        ami_frames.append(pd.DataFrame({
            'postcorrection': post_label,
            'preprocessing': pre_label,
            'AMI': ami_values,
        }))
disease_ami = pd.concat(ami_frames)

In [None]:
# Save the AMI table without the index column.
# NOTE(review): `output` is concatenated directly with the file name, so it must end
# with a path separator for the file to land inside the output directory.
disease_ami.to_csv(output + 'lineage_ami.csv', index=None)

In [None]:
# Median AMI for every (preprocessing, postcorrection) combination.
print(disease_ami.groupby(['preprocessing', 'postcorrection']).median())

In [None]:
# One-sample t-test p-value, per combination, for the mean AMI differing from 0.
print(disease_ami.groupby(['postcorrection', 'preprocessing']).AMI.agg(lambda x: ttest_1samp(x, 0)[1]))

In [None]:
# Two-sample t-test on AMI between every pair of preprocessings (all postcorrections
# pooled). Each unordered pair is tested once, with the names in alphabetical order.
# Rows are selected with boolean masks; the previous query string ('preprocessing = =...')
# was not a valid expression.
for processing1 in disease_ami.preprocessing.unique():
    for processing2 in disease_ami.preprocessing.unique():
        if not processing2 > processing1:
            continue
        ami1 = disease_ami.loc[disease_ami.preprocessing == processing1, 'AMI']
        ami2 = disease_ami.loc[disease_ami.preprocessing == processing2, 'AMI']
        pvalue = ttest_ind(ami1, ami2)[1]
        print(processing1, processing2, pvalue)

In [None]:
# Box plot of AMI, grouped by preprocessing and coloured by postcorrection.
plt.close('all')
fig, ax = plt.subplots(figsize=(6, 6))
sns.boxplot(
    data=disease_ami, x='preprocessing', y='AMI', hue='postcorrection',
    order=preprocessings, hue_order=postcorrections, ax=ax,
)
ax.set_ylabel("Adjusted Mutual Information")
sns.despine(top=True, right=True)
ax.grid(axis='y')
plt.xticks(rotation=90)
fig.tight_layout()
fig.savefig(output + "lineage_clustering.pdf", dpi=400)

# Related Gene Agreement

## Munge related pairs

In [None]:
# Genes that have no missing value in any of the loaded matrices.
complete_gene_sets = []
for by_dataset in gene_effects.values():
    for effect in by_dataset.values():
        complete_gene_sets.append(set(effect.columns[effect.notnull().all(axis=0)]))
non_null_genes = set.intersection(*complete_gene_sets)

In [None]:
# Related (target, partner) pairs, keeping only the orientation where the first label
# sorts after the second and both genes have complete data; sorted for a stable order.
candidate_pairs = set(related_genes.itertuples(index=False, name=None))

gene_pairs = sorted(
    pair for pair in candidate_pairs
    if pair[0] > pair[1]
    and pair[0] in non_null_genes
    and pair[1] in non_null_genes
)

In [None]:
# The same related pairs with the two genes swapped, sorted.
gene_pairs_reversed = sorted([(s[1], s[0]) for s in gene_pairs])

In [None]:
# Mean gene effect of every complete gene, per matrix:
# gene_means[postcorrection][dataset] is a Series indexed by gene.
# The gene set is turned into a sorted list because current pandas rejects a set as a
# column indexer; the values are looked up by label afterwards, so order is irrelevant.
complete_genes = sorted(non_null_genes)
gene_means = {postcorrection: {key: val[complete_genes].mean() for key, val in gene_effect.items()}
              for postcorrection, gene_effect in gene_effects.items()}

In [None]:
# Cut the gene means of every matrix into 20 equal-width bins, keeping both the bin
# edges (gene_mean_bins) and the bin assigned to each gene (gene_mean_binned).
gene_mean_bins = {}
gene_mean_binned = {}
for post_label, means_by_dataset in gene_means.items():
    edges_by_dataset = {}
    assigned_by_dataset = {}
    for pre_label, mean_values in means_by_dataset.items():
        assigned, edges = pd.cut(mean_values, bins=20, retbins=True)
        edges_by_dataset[pre_label] = edges
        assigned_by_dataset[pre_label] = assigned
    gene_mean_bins[post_label] = edges_by_dataset
    gene_mean_binned[post_label] = assigned_by_dataset

In [None]:
# Average of the two genes' mean effects for every related pair, per matrix.
first_genes = [s[0] for s in gene_pairs]
second_genes = [s[1] for s in gene_pairs]
related_gene_means = {}
for post_label in postcorrections:
    related_gene_means[post_label] = {}
    for pre_label in preprocessings:
        dataset_means = gene_means[post_label][pre_label]
        related_gene_means[post_label][pre_label] = .5*(
            dataset_means[first_genes].values + dataset_means[second_genes].values
        )

## Bin related genes by mean effect

In [None]:
# Partition related pairs by the mean of the partner (second gene of the pair):
# partner_gene_means holds that mean per pair, and cuts assigns each pair to one of
# 10 equal-width bins of it.
partner_genes = [s[1] for s in gene_pairs]
partner_gene_means = {}
cuts = {}
for post_label in postcorrections:
    partner_gene_means[post_label] = {}
    for pre_label in preprocessings:
        partner_gene_means[post_label][pre_label] = pd.Series(
            gene_means[post_label][pre_label][partner_genes].values,
            index=gene_pairs,
        )
for post_label, means_by_dataset in partner_gene_means.items():
    cuts[post_label] = {
        pre_label: pd.cut(pair_means, bins=10)
        for pre_label, pair_means in means_by_dataset.items()
    }

## Create null distribution

In [None]:
# Null pairs: within every bin of partner mean, the partner genes are shuffled among
# the pairs of that bin. Pairs that are truly related (in either order) and
# self-pairs are then removed, and the remainder is sorted.
# A seeded generator makes the null distribution reproducible between runs.
null_rng = np.random.RandomState(0)
known_pairs = set(gene_pairs) | set(gene_pairs_reversed)

null_pairs = {}
for postcorrection in postcorrections:
    null_pairs[postcorrection] = {}
    print(postcorrection)
    for key in preprocessings:
        pair_bins = cuts[postcorrection][key]
        shuffled_pairs = set()
        for b in pair_bins.unique():
            ind = pair_bins.loc[lambda x: x == b].index
            ind1, ind2 = zip(*list(ind))
            ind1, ind2 = list(ind1), list(ind2)
            null_rng.shuffle(ind2)
            shuffled_pairs.update(zip(ind1, ind2))

        # Filtering pair by pair also works when no shuffled pair exists at all.
        null_pairs[postcorrection][key] = sorted(
            pair for pair in shuffled_pairs
            if pair not in known_pairs and pair[0] != pair[1]
        )

        print('\t', key, len(null_pairs[postcorrection][key]))

## Find correlations for related and null pairs

In [None]:
# Correlations of the related ('true') and shuffled ('null') gene pairs in every matrix:
# gene_corrs[postcorrection][dataset] = {'true': Series, 'null': Series}.
# Null pairs are only built for the configured postcorrections and preprocessings, so
# any other loaded matrix is skipped instead of raising KeyError.
complete_gene_columns = sorted(non_null_genes)
gene_corrs = {}
for postcorrection, gene_effect in gene_effects.items():
    if postcorrection not in null_pairs:
        continue
    gene_corrs[postcorrection] = {}
    for dataset, val in gene_effect.items():
        if dataset not in null_pairs[postcorrection]:
            continue
        gene_corrs[postcorrection][dataset] = corr_dist(
            val[complete_gene_columns],
            {'true': gene_pairs, 'null': null_pairs[postcorrection][dataset]}
        )

## Empirical P-values

In [None]:
# Right-tailed empirical p-value of every related pair's correlation against the null
# pairs of the same matrix, with FDR adjustment per matrix.
# The per-matrix tables are collected in a list and concatenated once, instead of
# growing a DataFrame inside the loop from an empty object-typed frame.
# NOTE(review): the p-value is 1 - (rank + 1) / (n_null + 1), which is 0 for a
# correlation above every null value -- confirm that zero p-values are intended.
pvalue_frames = []
for postcorrection in gene_effects.keys():
    for key in preprocessings:
        null = gene_corrs[postcorrection][key]['null'].sort_values().dropna()
        true = gene_corrs[postcorrection][key]['true'].sort_values().dropna()
        pvals = 1 - (np.searchsorted(null, true)+1) / (len(null)+1)

        pvalue_frames.append(pd.DataFrame({
            'Gene1': [tup[0] for tup in true.index],
            'Gene2': [tup[1] for tup in true.index],
            'postcorrection': postcorrection,
            'preprocessing': key,
            'p': pvals,
            'FDR': fdrcorrection(pvals, alpha=.05)[1]
        }))

if pvalue_frames:
    pvalues_gene = pd.concat(pvalue_frames, ignore_index=True)
else:
    # Nothing to compare: keep an empty table with the expected columns.
    pvalues_gene = pd.DataFrame(
        columns=['Gene1', 'Gene2', 'postcorrection', 'preprocessing', 'p', 'FDR']
    )

In [None]:
# Boolean column flagging related pairs whose FDR is below 0.1.
pvalues_gene['FDR < .1'] = pvalues_gene.FDR < .1

## Plot

In [None]:
# 2x2 figure: one recall-vs-FDR panel per preprocessing (one curve per postcorrection,
# drawn for FDR below 0.5), plus a bar plot of the fraction of pairs with FDR < 0.1.
fig, axes = plt.subplots(2, 2, figsize=(8, 8))

curve_axes = [axes[0, 0], axes[0, 1], axes[1, 0]]
for panel, (dataset, ax) in enumerate(zip(preprocessings, curve_axes)):
    plt.sca(ax)
    print(dataset)
    in_dataset = pvalues_gene[pvalues_gene.preprocessing == dataset]
    for postcorrection in postcorrections:
        fdr = in_dataset.loc[in_dataset.postcorrection == postcorrection, 'FDR'].dropna().sort_values()
        # Recall at each sorted FDR value: the fraction of pairs at or below it.
        recall = np.linspace(0, 1, len(fdr))
        below_half = (fdr < .5).values
        ax.plot(fdr[below_half], recall[below_half], label=postcorrection)
        ax.set_xlabel("FDR")
        ax.set_ylabel("Recall of Relationships")
    if panel == 0:
        ax.legend()
    sns.despine(right=True, top=True)
    ax.set_title(dataset)

bar_ax = axes[1, 1]
plt.sca(bar_ax)
sns.barplot(
    data=pvalues_gene, x="preprocessing", y="FDR < .1", hue="postcorrection",
    order=preprocessings, hue_order=postcorrections, n_boot=100, ax=bar_ax,
)
sns.despine(top=True, right=True)
bar_ax.set_xlabel("")
bar_ax.set_ylabel("Recall 10% FDR")
# Replace the automatic legend with one placed at the lower left.
bar_ax.get_legend().remove()
bar_ax.legend(loc="lower left")

fig.tight_layout()
fig.savefig(output + "gene_relationships.pdf", dpi=600)