In [1]:
%%capture
# Switch to the scripts folder; the data/results/plots folders used by this notebook are
# resolved relative to the working directory.
# NOTE(review): the path is relative, so this cell presumably has to be run exactly once from
# the notebook's starting directory — confirm before re-running without a kernel restart.
%cd "Compound GRN ENC Analysis/scripts"
# Select matplotlib's non-interactive Agg backend
%matplotlib agg

In [2]:
from collections import defaultdict
from itertools import product
import os

import matplotlib
from matplotlib.collections import PatchCollection
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from statannotations.Annotator import Annotator

# Params
# Project folders, each anchored at the current working directory
# (`os.path.abspath('')` resolves to it).
DATA_FOLDER, RESULTS_FOLDER, PLOTS_FOLDER = (
    os.path.join(os.path.abspath(''), relative_path)
    for relative_path in ('../../data', '../results', '../plots')
)

# Style
sns.set_theme(context='talk', style='white', palette='Accent')
# Font type 42 for both the PDF and PS backends
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})
# matplotlib.rcParams['font.family'] = 'Helvetica'  # NOTE: Make sure to download Helvetica


# Parameters

In [7]:
"""
Cohort : Disease : Delimiter
CMC: SCZ : tsv
UCLA_ASD: ASD : csv
Urban_DLPFC: BPD, SCZ : tsv
Subclass: ASD, BPD, SCZ : csv
"""
# One entry per analysed contrast: (cohort/group, disease, delimiter).
# The first two fields select the GRN folder and the result file names; the delimiter is
# unpacked but not used by the loading loop in this part of the notebook.
data_sources = (
    ('CMC', 'SCZ', '\t'),
    ('UCLA_ASD', 'ASD', ','),
    ('Urban_DLPFC', 'BPD', '\t'),
    # ('Urban_DLPFC', 'SCZ', '\t'),  # Removed for low sample size
    # ('Subclass', 'ASD', ','),  # No CTL, No modules
    # ('Subclass', 'BPD', ','),  # No CTL, No modules
    # ('Subclass', 'SCZ', ','),  # No CTL, No modules
)
# Model/module identifiers; each one becomes a `model{N}` folder name / file suffix
# and a `_{N}` column suffix once the results are stacked.
modules = [1, 2]
# Read by `get_grn_fnames` to decide whether the control folder is part of its return value.
use_ctl = True  # Not really needed here

In [10]:
# Get AD, BPD, and SCZ labels
# Every `.txt` file in `new_labels` is loaded as one array of strings and keyed by its
# file name without the extension.
gene_dir = os.path.join(DATA_FOLDER, 'new_labels')
gene_fnames = [entry for entry in os.listdir(gene_dir) if entry.endswith('.txt')]
gene_lists = {
    txt_name[:-len('.txt')]: np.loadtxt(os.path.join(gene_dir, txt_name), dtype=str)
    for txt_name in gene_fnames
}
# The list loaded under the key 'BD' is re-keyed to 'BPD' (the disease code used in `data_sources`)
gene_lists['BPD'] = gene_lists.pop('BD')

# Get ASD labels
# Only rows whose `gene-score` is strictly above the threshold are kept; a missing (NaN)
# score would not pass this comparison.
gene_score_threshold = -1
sfari = pd.read_csv(os.path.join(DATA_FOLDER, 'sfari/SFARI-Gene_genes_01-16-2024release_03-21-2024export.csv'))
sfari = sfari[sfari['gene-score'].gt(gene_score_threshold)]  # Threshold by score
gene_lists['ASD'] = sfari['gene-symbol'].to_numpy()

# Get module genes
def get_module_genes(group, disease, ct, use_modules=None):
    """Read the module annotation file for one cell type and split its genes by label.

    The file is `<DATA_FOLDER>/modules/<modules folder>/<ct>_<group>_<disease>.txt`, where
    the modules folder name comes from `get_modules_fname(use_modules=...)`. It is parsed
    as comma-delimited text with `label` and `gene` columns.

    Returns:
        (positive_genes, negative_genes): two lists holding the `gene` values of the rows
        labelled 'positive' and 'negative' respectively.
    """
    annotation_path = os.path.join(
        DATA_FOLDER,
        'modules',
        get_modules_fname(use_modules=use_modules),
        f'{ct}_{group}_{disease}.txt',
    )
    annotations = pd.read_csv(annotation_path, index_col=False, delimiter=',')

    def genes_labelled(label):
        # Genes of every row carrying the requested label, in file order
        return annotations.loc[annotations['label'] == label, 'gene'].to_list()

    return genes_labelled('positive'), genes_labelled('negative')

# Get files for contrast
def get_grn_fnames(group, disease):
    """Locate the GRN folders for one cohort/disease pair and list the usable GRN files.

    Reads the module-level `DATA_FOLDER` and `use_ctl`.

    Args:
        group: Cohort folder name under `<DATA_FOLDER>/merged_GRNs_v2`.
        disease: Disease folder name inside the cohort folder.

    Returns:
        With `use_ctl` true: `(base_dir, disease_folder, control_folder, grn_fnames)`,
        where `grn_fnames` is the sorted array of file names present in BOTH the disease
        and the `ctrl` folder.
        With `use_ctl` false: `(base_dir, disease_folder, grn_fnames)`, where `grn_fnames`
        is the sorted array of all file names in the disease folder.
        `grn_fnames` is always the last element.
    """
    # Calculate directories
    base_dir = os.path.join(DATA_FOLDER, 'merged_GRNs_v2', group)
    disease_folder = os.path.join(base_dir, disease)

    candidate_fnames = set(os.listdir(disease_folder))
    if use_ctl:
        # Keep only files that also have a control counterpart. The control folder is
        # touched only inside this branch, so `use_ctl = False` no longer hits an
        # undefined `control_folder`.
        control_folder = os.path.join(base_dir, 'ctrl')
        candidate_fnames &= set(os.listdir(control_folder))
    grn_fnames = np.sort(list(candidate_fnames))

    # Return
    if use_ctl:
        return base_dir, disease_folder, control_folder, grn_fnames
    return base_dir, disease_folder, grn_fnames

# Get fname suffix
def get_modules_fname(use_modules=None, **kwargs):
    """Return the module tag `model{use_modules}`, or '' when `use_modules` is None.

    `use_modules` defaults to None so that callers forwarding an empty set of keyword
    arguments get the "no module" tag instead of a TypeError. Extra keyword arguments
    are accepted and ignored.
    """
    if use_modules is None:
        return ''
    return f'model{use_modules}'


def get_fname_suffix(**kwargs):
    """Build the file-name suffix for the given options.

    The non-empty tags (currently only the module tag) are joined with '_' and prefixed
    with '_'. Returns '' when every tag is empty.
    """
    tags = [tag for tag in (get_modules_fname(**kwargs),) if len(tag) > 0]
    if not tags:
        return ''
    return f'_{"_".join(tags)}'

# Get cell-type based on fname
def get_cell_type(fname):
    """Return `fname` up to, but excluding, its last underscore.

    A name without any underscore yields an empty string.
    """
    prefix, _separator, _last_segment = fname.rpartition('_')
    return prefix

# Combine Predictions

In [101]:
# Load all results
# Source column -> display column; also fixes which columns are kept and in what order
renames = {'label': 'Label', 'mean': 'Mean', 'std': 'STD', 'gene': 'Gene'}

# Per-file frames are collected here and concatenated once at the end, instead of
# re-copying the accumulated frame with `pd.concat` on every iteration.
result_frames = []
for source in data_sources:
    group, disease, _ = source
    grn_fnames = get_grn_fnames(group, disease)[-1]

    for fname, use_modules in product(grn_fnames, modules):
        # Load prioritized genes
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'
        result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0).reset_index()

        # Tag df
        result = result.rename(columns=renames)[list(renames.values())]
        result['Group'] = group
        result['Disease'] = disease
        result['Cell Type'] = cell_type
        result['Module'] = use_modules

        # Append
        result_frames.append(result)

# Long-format table: one row per (group, disease, cell type, module, gene).
# An empty DataFrame is kept as the outcome when nothing was loaded.
results = pd.concat(result_frames, axis=0) if result_frames else pd.DataFrame()

In [102]:
# Stack model results as columns
# Input: long-format `results` with a `Module` column.
# Output: `results` rebound to a wide frame indexed by (Group, Disease, Cell Type, Gene)
# with one `<column>_<module>` column set per module.
# NOTE: because `results` is rebound, this cell needs the long-format `results` (with its
# `Module` column) to be rebuilt before it can be run again.
model_cols = ['Group', 'Disease', 'Cell Type', 'Gene']
assert results.groupby(model_cols + ['Module']).count().max().max() == 1, 'Values not unique'
df = None
for module in modules:
    # Rows are unique per key (asserted above), so `max` only moves the key columns
    # into the index without aggregating anything.
    df_concat = (
        results
            .loc[results['Module'] == module]
            .groupby(model_cols)
            .max()
            .drop(columns='Module')
            .rename(columns=lambda s: f'{s}_{module}')
    )
    if df is None: df = df_concat
    else: df = df.join(df_concat, how='outer')

# Check that all values match
# After the outer join a gene missing from any module shows up as NaN in that module's
# label column, so every `Label_<module>` column must be fully populated.
label_cols = [f'Label_{module_id}' for module_id in modules]
assert df[label_cols].notna().all(axis=1).all(), 'Not all genes match between models'
results = df


In [104]:
# Show every (group, disease, cell type) row of one gene across both models
results.loc[results.index.isin(['APBB2'], level='Gene')]

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Label_1,Mean_1,STD_1,Label_2,Mean_2,STD_2
Group,Disease,Cell Type,Gene,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
CMC,SCZ,astro,APBB2,-1,0.228024,0.161939,1,0.581221,0.199055
CMC,SCZ,micro,APBB2,-1,0.123615,0.04759,-1,0.201744,0.074349
CMC,SCZ,oligo,APBB2,-1,0.108319,0.095984,-1,0.222423,0.156811
CMC,SCZ,vlmc,APBB2,-1,0.105772,0.082372,-1,0.205615,0.165477
UCLA_ASD,ASD,astro,APBB2,-1,0.639237,0.141159,1,0.79395,0.150655
UCLA_ASD,ASD,excitatory,APBB2,0,0.421414,0.093437,0,0.470572,0.065124
UCLA_ASD,ASD,micro,APBB2,0,0.238604,0.02166,0,0.365486,0.082941
Urban_DLPFC,BPD,astro,APBB2,0,0.008027,0.006119,0,0.008027,0.006119
Urban_DLPFC,BPD,endo,APBB2,-1,0.008668,0.004417,1,0.171199,0.028407
Urban_DLPFC,BPD,micro,APBB2,0,0.1415,0.080968,0,0.425756,0.048299


0