In [1]:
%%capture
%cd "Compound GRN ENC Analysis/scripts"
%matplotlib agg

In [2]:
from collections import defaultdict
from itertools import product
import os
import warnings

import matplotlib
from matplotlib.collections import PatchCollection
import matplotlib.pyplot as plt
import matplotlib.transforms
import numpy as np
import pandas as pd
import seaborn as sns
from statannotations.Annotator import Annotator

# Params: folders are resolved against the current working directory, which the first
# cell moves into the `scripts` folder via `%cd`.
DATA_FOLDER, RESULTS_FOLDER, PLOTS_FOLDER = (
    os.path.join(os.path.abspath(''), relative)
    for relative in ('../../data', '../results', '../plots')
)

# Style (set_theme first: it resets rcParams that are overridden below)
sns.set_theme(context='talk', style='white', palette='Accent')
matplotlib.rcParams.update({
    # 42 = embed TrueType (Type 42) fonts in PDF/PS output
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    'font.family': 'Helvetica',  # NOTE: Make sure to download Helvetica
})


In [3]:
# TODO
# Pull all results at beginning, then just make copies
# Fix processing for subclass

# Parameters

In [4]:
"""
Cohort : Disease : Delimiter
CMC: SCZ : tsv
UCLA_ASD: ASD : csv
Urban_DLPFC: BPD, SCZ : tsv
Subclass: ASD, BPD, SCZ : csv
"""
data_sources = (
    ('CMC', 'SCZ', '\t'),
    ('UCLA_ASD', 'ASD', ','),
    ('Urban_DLPFC', 'BPD', '\t'),
    # ('Urban_DLPFC', 'SCZ', '\t'),  # Removed for low sample size
    ('Subclass', 'ASD', ','),
    ('Subclass', 'BPD', ','),
    ('Subclass', 'SCZ', ','),
    # ('Coregulation', 'ASD', ','),
    # ('Coregulation', 'BPD', ','),
    # ('Coregulation', 'SCZ', ','),
)
modules = [None, 1, 2]
use_ctl = True  # Not really needed here

In [5]:
# Get AD, BPD, and SCZ labels
# One `.txt` file per disease in `new_labels`; the file stem becomes the dictionary key.
# Presumably one gene symbol per line — TODO confirm against the data files.
# `ndmin=1` keeps a single-gene file as a 1-D array instead of a 0-d scalar array.
gene_dir = os.path.join(DATA_FOLDER, 'new_labels')
gene_fnames = [fname for fname in os.listdir(gene_dir) if fname.endswith('.txt')]
gene_lists = {
    os.path.splitext(fname)[0]: np.loadtxt(os.path.join(gene_dir, fname), dtype=str, ndmin=1)
    for fname in gene_fnames
}
gene_lists['BPD'] = gene_lists.pop('BD')  # Rename to the 'BPD' key used by `data_sources`

# Get ASD labels
sfari = pd.read_csv(os.path.join(DATA_FOLDER, 'sfari/SFARI-Gene_genes_01-16-2024release_03-21-2024export.csv'))
gene_score_threshold = -1
# Threshold by score; rows with a missing 'gene-score' compare False and are dropped
sfari = sfari.loc[sfari['gene-score'] > gene_score_threshold]
gene_lists['ASD'] = sfari['gene-symbol'].to_numpy()
# Get module genes
def get_module_genes(group, disease, ct, use_modules=None):
    """Read the positive and negative gene labels from one module annotation file.

    Loads ``<DATA_FOLDER>/modules/<model dir>/<ct>_<group>_<disease>.txt`` as a
    comma-separated table with ``gene`` and ``label`` columns. The model directory is
    omitted when ``use_modules`` is None.

    Returns:
        (positive_genes, negative_genes): lists of genes labelled ``'positive'`` and
        ``'negative'`` respectively, in file order.
    """
    annotation_path = os.path.join(
        DATA_FOLDER,
        'modules',
        get_modules_fname(use_modules=use_modules),
        f'{ct}_{group}_{disease}.txt',
    )
    annotations = pd.read_csv(annotation_path, index_col=False, delimiter=',')
    by_label = {
        label: annotations.loc[annotations['label'] == label, 'gene'].to_list()
        for label in ('positive', 'negative')
    }
    return by_label['positive'], by_label['negative']

# Get files for contrast
def get_grn_fnames(group, disease, use_ctl=True):
    """Locate the merged GRN files for one group/disease contrast.

    Returns:
        ``(base_dir, disease_folder, control_folder, grn_fnames)`` when ``use_ctl`` is
        truthy, otherwise ``(base_dir, disease_folder, grn_fnames)``. ``grn_fnames`` is a
        sorted numpy array of file names. With ``use_ctl`` it contains only the names
        present in both the disease folder and the ``ctrl`` folder.
    """
    base_dir = os.path.join(DATA_FOLDER, 'merged_GRNs_v2', group)
    disease_folder = os.path.join(base_dir, disease)

    if not use_ctl:
        return base_dir, disease_folder, np.sort(os.listdir(disease_folder))

    control_folder = os.path.join(base_dir, 'ctrl')
    # Only cell types with a network for both the disease and the controls
    shared = set(os.listdir(disease_folder)) & set(os.listdir(control_folder))
    return base_dir, disease_folder, control_folder, np.sort(list(shared))

# Get fname suffix
def get_modules_fname(use_modules, **kwargs):
    """Return the file-name fragment for a module model: ``''`` for None, else ``'model<N>'``.

    Extra keyword arguments are accepted and ignored, so a shared kwargs dict can be forwarded.
    """
    if use_modules is None:
        return ''
    return f'model{use_modules}'


def get_fname_suffix(**kwargs):
    """Build the ``_``-prefixed suffix appended to result and plot file names.

    The only component is currently the module-model fragment. This yields ``''`` for
    ``use_modules=None`` and e.g. ``'_model1'`` for ``use_modules=1``.
    """
    parts = [part for part in (get_modules_fname(**kwargs),) if part]
    return '_' + '_'.join(parts) if parts else ''

# Get cell-type based on fname
def get_cell_type(fname):
    """Return everything before the last underscore in ``fname`` (``''`` if there is none)."""
    return fname.rpartition('_')[0]

In [6]:
# Subclass major pairing: map each neuronal subclass name to its major class
inhibitory = ['Lamp5', 'Pvalb', 'Sncg', 'Sst', 'Sst.Chodl', 'Lamp5.Lhx6', 'Vip', 'Pax6', 'Chandelier']
excitatory = ['L2.3.IT', 'L4.IT', 'L5.IT', 'L5.ET', 'L5.6.NP', 'L6b', 'L6.IT', 'L6.CT', 'L6.IT.Car3']
ct_conversion = {
    subclass: major
    for major, subclasses in (('inhibitory', inhibitory), ('excitatory', excitatory))
    for subclass in subclasses
}
# Disease pairings: disease -> non-subclass cohort for that disease
group_conversion = dict(ASD='UCLA_ASD', BPD='Urban_DLPFC', SCZ='CMC')

# Preliminary

In [7]:
# Determine cell types
# Hard-coded copies of `df['Cell Type'].unique()` from the performance CSV of the first
# (major cell types) and last (subclasses) entries of `data_sources`, i.e. from
#   pd.read_csv(os.path.join(RESULTS_FOLDER, f'{group}_{disease}_performance{get_fname_suffix(use_modules=None)}.csv'), index_col=0)
major_ct = np.array(['astro', 'endo', 'excitatory', 'inhibitory', 'micro', 'oligo', 'opc', 'vlmc'])
minor_ct = np.array(['Chandelier', 'L2.3.IT', 'L4.IT', 'L5.6.NP', 'L5.ET', 'L5.IT', 'L6.CT', 'L6.IT.Car3', 'L6.IT', 'L6b', 'Lamp5.Lhx6', 'Lamp5', 'Pax6', 'Pvalb', 'Sncg', 'Sst', 'Vip'])

# Colors: sample each colormap at evenly spaced interior points k / (n + 1), k = 1..n
major_palette = sns.color_palette('Dark2', as_cmap=True)
minor_palette = sns.color_palette('magma', as_cmap=True)
major_colors = {ct: major_palette(rank / (len(major_ct) + 1)) for rank, ct in enumerate(major_ct, start=1)}
minor_colors = {ct: minor_palette(rank / (len(minor_ct) + 1)) for rank, ct in enumerate(minor_ct, start=1)}
# Single lookup used by every plot (major and subclass names do not overlap)
merged_colors = {**major_colors, **minor_colors}

In [8]:
# Major and legends: export one standalone colour-key PDF per cell-type palette
for name, colors in (('major', major_colors), ('minor', minor_colors)):
    fig, ax = plt.subplots(1, 1)
    ax.spines.top.set_visible(False)  # Avoid line through legend if too long
    # Invisible, empty plots used only as square legend swatches
    handles = [ax.plot([], [], color=swatch, marker='s', ls='none')[0] for swatch in colors.values()]
    labels = [*colors]
    legend = ax.legend(handles, labels, loc=3, frameon=False)
    ax.axis('off')
    # Crop the saved file to the legend box (display pixels -> inches)
    bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    fname = f'legend_{name}.pdf'
    fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches=bbox)

# Performance

## Individual Datasets

In [9]:
# # Params
# horizontal = False

# for use_modules in modules:
#     # Get performance df
#     performance = pd.DataFrame()
#     for source in data_sources:
#         group, disease, delimiter = source

#         # Load performance
#         fname = os.path.join(RESULTS_FOLDER, f'{group}_{disease}_performance{get_fname_suffix(use_modules=use_modules)}.csv')
#         try: df = pd.read_csv(fname, index_col=0)
#         except: continue
#         df['Group'] = group
#         df['Disease'] = disease

#         # Get baselines
#         def get_baseline(ct):
#             df_temp = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{group}_{disease}_{ct}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'), index_col=0)
#             num_neg = (df_temp['label'] == 0).sum()
#             num_pos = (df_temp['label'] == 1).sum()
#             baseline = num_pos / (num_pos + num_neg)

#             return baseline
#         df['Baseline AUPRC'] = df['Cell Type'].map(get_baseline)

#         # Concatenate
#         performance = pd.concat((performance, df), axis=0)

#     # Take means
#     performance = performance.drop(columns='Fold').groupby(['Group', 'Disease', 'Cell Type']).mean().reset_index()

#     # Apply formatting
#     # performance['Dataset'] = performance.apply(lambda r: f'{r["Group"]} ({r["Disease"]})', axis=1)  # Include group name
#     performance['Dataset'] = performance.apply(lambda r: f'{r["Disease"]}', axis=1)  # Exclude group name

#     # Compute fold changes
#     for stat in ('AUPRC', 'Validation AUPRC'):
#         performance[f'{stat} Fold Change'] = performance[stat] / performance['Baseline AUPRC']

#     # Plot
#     fnames = [
#         f'Performance_Main{get_fname_suffix(use_modules=use_modules)}.pdf',
#         f'Performance_Subclass{get_fname_suffix(use_modules=use_modules)}.pdf',
#     ]
#     plot_cell_types_lists = [
#         performance.loc[performance['Group'] != 'Subclass', 'Cell Type'].unique().tolist(),
#         performance.loc[performance['Group'] == 'Subclass', 'Cell Type'].unique().tolist(),
#     ]
#     scale = [
#         1,
#         2
#     ]
#     for fname, plot_cell_types, sc in zip(fnames, plot_cell_types_lists, scale):
#         # Check for data
#         if len(plot_cell_types) == 0: break

#         # Filter to cell types
#         performance_filtered = performance.loc[performance['Cell Type'].isin(plot_cell_types)]

#         # Format as heatmap
#         order_dataset = performance_filtered[['Group', 'Disease', 'Dataset']].groupby(['Group', 'Disease', 'Dataset']).count().reset_index().sort_values(['Disease', 'Group'])['Dataset'].to_list()
#         order_cell = performance_filtered.sort_values('Cell Type')['Cell Type'].unique()
#         if horizontal: order_dataset = order_dataset[::-1]
#         else: order_cell = order_cell[::-1]
#         heatmap_data = performance_filtered.pivot(index='Dataset', columns='Cell Type', values='AUPRC Fold Change').loc[order_dataset, order_cell]
#         heatmap_data_val = performance_filtered.pivot(index='Dataset', columns='Cell Type', values='Validation AUPRC Fold Change').loc[order_dataset, order_cell]
#         if not horizontal:
#             heatmap_data = heatmap_data.transpose()
#             heatmap_data_val = heatmap_data_val.transpose()

#         # Plot
#         figsize = heatmap_data.shape[::-1] if horizontal else [sc*s for s in heatmap_data.shape]
#         fig, ax = plt.subplots(1, 1, figsize=figsize)

#         # Define vars
#         xlabels = heatmap_data.columns.to_list()
#         ylabels = heatmap_data.index.to_list()
#         m, n = heatmap_data.shape[1], heatmap_data.shape[0]
#         X, Y = np.meshgrid(np.arange(m), np.arange(n))
#         S = heatmap_data.to_numpy()
#         T = heatmap_data_val.to_numpy()

#         # Squares are training, circles are validation
#         maxval = 2 if not use_modules else 20
#         norm = matplotlib.colors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=maxval)
#         patches = [plt.Rectangle((j-.5, i-.5), 1, 1) for j, i in zip(X.flat, Y.flat)]
#         col = PatchCollection(patches, array=S.flatten(), linewidth=0, clim=[0, maxval], norm=norm, cmap='RdBu')
#         ax.add_collection(col)
#         patches = [plt.Circle((j, i), .4) for j, i in zip(X.flat, Y.flat)]
#         col = PatchCollection(patches, array=T.flatten(), linewidth=0, clim=[0, maxval], norm=norm, cmap='RdBu')
#         ax.add_collection(col)
#         from mpl_toolkits.axes_grid1 import make_axes_locatable
#         divider = make_axes_locatable(ax)
#         cax = divider.append_axes('top', size="5%", pad=.5)
#         cbar = fig.colorbar(col, orientation='horizontal', cax=cax)
#         cbar.ax.set_xlabel('AUPRC Fold Change', fontsize='small')
#         cbar.ax.xaxis.set_label_position('top') 
#         cbar.ax.set_xticks([0, 1, maxval])  # Add 1 as an xtick

#         ax.set(
#             xticks=np.arange(m),
#             yticks=np.arange(n),
#             xticklabels=xlabels,
#             yticklabels=ylabels,
#         )
#         ax.set_xticks(np.arange(m+1)-.5, minor=True)
#         ax.set_yticks(np.arange(n+1)-.5, minor=True)
#         ax.tick_params(axis='x', labelrotation=90)
#         ax.grid(which='minor')
#         ax.set_aspect('equal', adjustable='box')

#         fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

## Model Comparison

In [10]:
def get_baseline_auprc(group, disease, ct, use_modules):
    """Return the positive-label fraction in one cell type's prioritized-gene table.

    This is stored as 'Baseline AUPRC'. Only rows labelled 0 or 1 are counted; any other
    label value is ignored.
    """
    df_temp = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{group}_{disease}_{ct}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'), index_col=0)
    num_neg = (df_temp['label'] == 0).sum()
    num_pos = (df_temp['label'] == 1).sum()
    return num_pos / (num_pos + num_neg)


performance_frames = []
for source, use_modules in product(data_sources, modules):
    group, disease, _ = source

    # Skip non-models
    if not use_modules: continue

    # Load performance; group/disease/model combinations that were never run have no file
    fname = os.path.join(RESULTS_FOLDER, f'{group}_{disease}_performance{get_fname_suffix(use_modules=use_modules)}.csv')
    try: df = pd.read_csv(fname, index_col=0)
    except FileNotFoundError: continue
    df['Group'] = group
    df['Disease'] = disease
    df['Model'] = f'Model {use_modules}'

    # Get baselines (map() calls the lambda immediately, so the loop variables are current)
    df['Baseline AUPRC'] = df['Cell Type'].map(
        lambda ct: get_baseline_auprc(group, disease, ct, use_modules))

    performance_frames.append(df)

# Concatenate once; growing a DataFrame with concat inside the loop is quadratic
assert performance_frames, 'No performance results found in RESULTS_FOLDER'
performance = pd.concat(performance_frames, axis=0)

# Apply formatting: x-axis key "<Group> - <Disease> - <Cell Type>" (split apart again when plotting)
performance['Dataset'] = (
    performance['Group'] + ' - ' + performance['Disease'] + ' - ' + performance['Cell Type'].astype(str)
)

# Validation AUPRC as a fold change over the positive-fraction baseline
performance['AUPRC Fold'] = performance['Validation AUPRC'] / performance['Baseline AUPRC']

# Sort df: by disease, non-subclass datasets before subclass ones, then by group
performance['Subclass'] = performance['Group'] == 'Subclass'
performance = performance.sort_values(['Disease', 'Subclass', 'Group']).drop(columns='Subclass')

In [11]:
# Parameters
ypos = -.1  # Position of the major disease labels

# Plot
fig, ax = plt.subplots(1, 1, figsize=(30, 4))
sns.barplot(data=performance, x='Dataset', y='AUPRC Fold', hue='Model', ax=ax)

# Formatting
sns.despine()
ax.set_xlabel(None)

# Set xticklabels
original = []
diseases = []
groups = []
cell_types = []
for l in ax.get_xticklabels():
    split = l.get_text().split(' - ')
    original.append(l.get_text())
    groups.append(split[0])
    diseases.append(split[1])
    cell_types.append(split[2])
ax.set_xticklabels([], rotation=90)  # cell_types

# Set limits
ax.set_xlim([-.5, len(cell_types)-.5])

# Recolor
for bars, hatch in zip(ax.containers, [None, '///']):
    for bar, ct in zip(bars, cell_types):
        bar.set_facecolor(merged_colors[ct])
        bar.set_hatch(hatch)

# Remove legend title
handles, labels = ax.get_legend_handles_labels()
ax.legend(handles=handles, labels=labels)

# Draw baselines
## Non-scaled
# for x, dataset in enumerate(original):
#     baseline = performance.loc[performance['Dataset'] == dataset, 'Baseline AUPRC'].mean()
## Scaled
ax.axhline(y=1, linewidth=.5, linestyle='dashed', color='black')

# Draw group labels
trans = matplotlib.transforms.blended_transform_factory(ax.transData, ax.transAxes)
lidx = 0
for ridx in range(len(diseases)):
    if (ridx == len(diseases) - 1) or (diseases[ridx+1] != diseases[ridx]):
        # Draw label
        xpos = (lidx + ridx) / 2
        label = f'{diseases[ridx]}'
        # TODO: Maybe separate subclass and non-subclass?
        # if groups[ridx] == 'Subclass': label += ' Subclass'
        ax.text(xpos, ypos, label, ha='center', va='bottom', fontsize='medium', transform=trans)

        # Draw lines
        it = [ridx+.5]
        if lidx == 0: it += [lidx-.5]
        for xpos in it:
            ax.arrow(
                xpos, 0, 0, ypos,
                width=.01,
                head_width=0,
                transform=trans,
                color='black',
                clip_on=False)
            pass

        # Iterate
        lidx = ridx + 1

# Save
fname = f'Performance.pdf'
fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

# Disease Distribution

In [12]:
for source, use_modules in product(data_sources, modules):
    group, disease, _ = source
    grn_fnames = get_grn_fnames(group, disease, use_ctl=use_ctl)[-1]

    # Nothing to plot; plt.subplots cannot create a figure with zero rows
    if len(grn_fnames) == 0: continue

    # Plot
    # Each (xmin, xmax) in `ranges` becomes its own column, which allows a broken x-axis;
    # (None, None) is a single unclipped column.
    ranges = [(None, None)]  # [(0, .55), (.85, 1)]
    width_ratios = [
        (rg[1] - rg[0]) if (rg[0] is not None) and (rg[1] is not None) else 1
        for rg in ranges
    ]
    figsize = (10, 10)  # (10, len(grn_fnames) // 2)
    # squeeze=False always yields a 2-D (cell types x ranges) array of axes, so `axs[i]` is
    # the row for cell type i even with one cell type or several ranges.
    # Negative hspace makes neighbouring rows overlap.
    fig, axs = plt.subplots(
        len(grn_fnames), len(ranges),
        figsize=figsize, width_ratios=width_ratios, sharex='col', sharey='row',
        squeeze=False, gridspec_kw={'hspace': -.5, 'wspace': .1})
    fig.suptitle(f'{group} ({disease})')

    content = False
    for i, fname in enumerate(grn_fnames):
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'
        try: result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)
        except FileNotFoundError: continue
        content = True

        # Scale score from 0-1 PER CELL TYPE
        result['mean'] /= result['mean'].max()  # Do we include this?

        for j, (ax, rgs) in enumerate(zip(axs[i], ranges)):
            # KDE, drawn invisibly only to obtain the density curve, which is filled manually below
            sns.kdeplot(result['mean'], lw=0, color='white', ax=ax)
            x, y = ax.get_lines()[0].get_xydata().T

            # Stylizing and coloration
            ax.set_xlabel(None)
            ax.set_ylabel(None)
            ax.set_yticklabels([])
            # Color
            # color_val = (i+2) / (len(grn_fnames)+3)  # Padded
            # # color_val = i / len(grn_fnames)  # Unpadded
            # color = matplotlib.colormaps['Blues'](color_val)
            color = merged_colors[cell_type]
            ax.fill_between(x, 0, y, color=color)

            # Each positive-label gene adds a faint red band (half-width `width`) under the
            # curve. Overlapping bands stack, so regions dense in positive genes look redder.
            width = 3e-3  # if group != 'Subclass' else 3e-6
            alpha = .05
            # alpha = 2 / (result.loc[result['label'] != -1, 'label']).sum()
            # alpha = (result.loc[result['label'] != -1, 'label']).sum() / result['label'].count()
            for score in result.loc[result['label'] == 1, 'mean']:
                # Fill
                # NOTE: Could also compute KDE of positive labels and use a gradient
                xrange = np.linspace(max(0, score-width), min(1, score+width), 10)
                ytop = np.interp(xrange, x, y)
                ax.fill_between(xrange, 0, ytop, color='red', alpha=alpha)

            # Formatting
            ax.spines.left.set_visible(False)
            ax.spines.right.set_visible(False)
            ax.spines.top.set_visible(False)
            ax.spines.bottom.set_visible(False)
            # if i != len(grn_fnames) - 1: ax.spines.bottom.set_visible(False)
            left, right = rgs
            if left is not None: ax.set_xlim(left=left)
            if right is not None: ax.set_xlim(right=right)
            ax.patch.set_alpha(0)  # Make background transparent
            if j == 0: ax.text(-.02, 0, cell_type, fontsize='small', ha='right', va='bottom', transform=ax.transAxes)  # Add y label

            # Plot split line (diagonal break markers between range columns on the bottom row)
            if i == len(grn_fnames) - 1:
                # Create marker
                kwargs = dict(marker=[(-.5, -1), (.5, 1)], markersize=12, linestyle='none', color='k', mec='k', mew=2, clip_on=False)

                # Plot markers
                if j > 0: ax.plot([0], [0], transform=ax.transAxes, **kwargs)
                if j < len(ranges) - 1: ax.plot([1], [0], transform=ax.transAxes, **kwargs)

            # Set ticks to only between 0 and 1, only set in here for broken purposes
            xticks = [tick for tick in ax.get_xticks() if -.01 <= tick <= 1.01]
            ax.set_xticks(xticks)

    # Skip if no data (release the empty figure first)
    if not content:
        plt.close(fig)
        continue

    # Final formatting
    fig.text(.5, 0, 'Gene Relevance Score', fontsize='medium', ha='center', va='center')

    # Save, then close: one figure is created per (source, model) pair
    fname = f'Distribution_{group}_{disease}{get_fname_suffix(use_modules=use_modules)}.pdf'
    fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')
    plt.close(fig)

  fig, axs = plt.subplots(len(grn_fnames), len(ranges), figsize=figsize, width_ratios=width_ratios, sharex='col', sharey='row', gridspec_kw={'hspace': -.5, 'wspace': .1})


# Drugs

## BBB Drug Target Counts

In [13]:
# Load drugs and gene lists
drugs = pd.read_csv(os.path.join(DATA_FOLDER, 'pharmacologically_active.csv')).rename(columns={'Drug IDs': 'Drug ID'})
drugs = drugs[['Gene Name', 'Drug ID', 'Species']]
drugs['Drug ID'] = drugs['Drug ID'].apply(lambda s: s.split('; '))
drugs = drugs.explode('Drug ID')

# Add BBB predictions
bbb = pd.read_csv(os.path.join(DATA_FOLDER, 'BBB_plus_dbIDS.csv'), index_col=0)
drugs['BBB'] = drugs['Drug ID'].map(lambda x: x in list(bbb['ID']))

# Get BBB drug targets
bbb_drug_targets = drugs.loc[drugs['BBB'], 'Gene Name'].unique()  # Figure out why 'G', 'L', etc. in here

In [14]:
# Load all results (Taken from model-ensemble)
results = pd.DataFrame()
for source in data_sources:
    group, disease, _ = source
    grn_fnames = get_grn_fnames(group, disease, use_ctl=use_ctl)[-1]

    for fname, use_modules in product(grn_fnames, modules):
        # Load prioritized genes
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'
        try: result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0).reset_index()
        except: continue

        # Tag df
        renames = {'label': 'Label', 'mean': 'Mean', 'std': 'STD', 'gene': 'Gene'}
        result = result.rename(columns=renames)[renames.values()]
        result['Group'] = group
        result['Disease'] = disease
        result['Cell Type'] = cell_type
        result['Model'] = use_modules

        # Append
        results = pd.concat((results, result), axis=0)
    
# Format and decide targets
group_cols = ['Group', 'Disease', 'Cell Type', 'Model']
results['BBB Count'] = results['Gene'].isin(bbb_drug_targets)
results['Model'] = results['Model'].fillna('None')  # Necessary for proper groupby behavior
results = results.set_index(group_cols)

In [15]:
# Parameters
quantile = .9  # BBB targets in the top 10%

# Filter to desired quantile
thresholds = results.groupby(group_cols).quantile(quantile, numeric_only=True)[['Mean']]
results['_mask'] = None
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    for idx in thresholds.index:
        results.loc[idx, '_mask'] = results.loc[idx, 'Mean'] > thresholds.loc[idx, 'Mean']
assert results['_mask'].isna().sum() == 0, 'Not all values checked'
results = results.loc[results['_mask']].drop(columns='_mask')

# Compute counts of BBB targets and pivot separate by model
bbb_counts = results.groupby(group_cols)[['BBB Count']].sum()

In [16]:
# Make monochrome cmap (shared by both heatmaps)
# Black -> white ramp, reversed: 0 BBB targets maps to white and `maxval` or more to black.
# The grey level is then blended with each row's cell-type colour below.
cdict = {
    'red': [3*[0.], 3*[1.]],
    'green': [3*[0.], 3*[1.]],
    'blue': [3*[0.], 3*[1.]],
}
cmap = matplotlib.colors.LinearSegmentedColormap('Monochrome', segmentdata=cdict, N=256).reversed()
maxval = 20  # Upper colour limit; larger counts saturate

for (fname_suffix, df) in zip(
    (
        'major',
        'minor',
    ),
    (
        bbb_counts.loc[bbb_counts.index.get_level_values('Group') != 'Subclass'],  # Non-subclass
        bbb_counts.loc[bbb_counts.index.get_level_values('Group') == 'Subclass'],  # Subclass
    ),
):
    # Check if df is empty
    if df.shape[0] == 0:
        print(f'Input DataFrame for {fname_suffix} is empty, skipping.')
        continue

    # Generate heatmaps: one Cell Type x Disease grid of counts per model. Both grids are
    # reindexed onto the same rows/columns so the two triangles of a cell refer to the same
    # (cell type, disease), even if a model lacks some entries (those become NaN/blank).
    model_grids = [
        df.loc[df.index.get_level_values('Model') == model]
        .reset_index()
        .pivot(index='Cell Type', columns='Disease', values='BBB Count')
        for model in (1, 2)
    ]
    row_order = model_grids[0].index.union(model_grids[1].index)[::-1]  # Reverse-sorted rows, as before
    col_order = model_grids[0].columns.union(model_grids[1].columns)
    heatmap_1, heatmap_2 = (grid.reindex(index=row_order, columns=col_order) for grid in model_grids)

    # Plot
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))

    # Define vars
    xlabels = heatmap_1.columns.to_list()
    ylabels = heatmap_1.index.to_list()
    m, n = heatmap_1.shape[1], heatmap_1.shape[0]
    X, Y = np.meshgrid(np.arange(m), np.arange(n))
    S = heatmap_1.to_numpy()
    T = heatmap_2.to_numpy()

    # Top left is model 1, bottom right is model 2
    patches = [plt.Polygon([[j-.5, i-.5], [j+.5, i+.5], [j-.5, i+.5]]) for j, i in zip(X.flat, Y.flat)]
    col = PatchCollection(patches, array=S.flatten(), linewidth=0, clim=[0, maxval], cmap=cmap)
    ax.add_collection(col)
    patches = [plt.Polygon([[j-.5, i-.5], [j+.5, i+.5], [j+.5, i-.5]]) for j, i in zip(X.flat, Y.flat)]
    col = PatchCollection(patches, array=T.flatten(), linewidth=0, clim=[0, maxval], cmap=cmap)
    ax.add_collection(col)

    # Modify colors to match cell types
    # Blend: white stays white (few targets), black becomes the full cell-type colour (many)
    fig.canvas.draw()  # Needed to set facecolors originally
    for collection in ax.collections:
        facecolors = collection.get_facecolors()
        for i, _, y in zip(range(facecolors.shape[0]), X.flat, Y.flat):
            ct = ylabels[y]
            ct_color = merged_colors[ct]
            color = (1 - facecolors[i]) * ct_color + facecolors[i]
            color[-1] = facecolors[i][-1]  # Don't include missing data
            facecolors[i] = color
        collection.set_array(None)
        collection.set_facecolors(facecolors)

    # Set ticks
    ax.set(
        xticks=np.arange(m),
        yticks=np.arange(n),
        xticklabels=xlabels,
        yticklabels=ylabels,
    )
    ax.set_xticks(np.arange(m+1)-.5, minor=True)
    ax.set_yticks(np.arange(n+1)-.5, minor=True)
    ax.tick_params(axis='x', labelrotation=90)
    ax.grid(which='minor')
    ax.set_aspect('equal', adjustable='box')

    # Formatting
    ax.spines.left.set_visible(False)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    ax.spines.bottom.set_visible(False)

    # Save
    fname = f'BBB_Prevalence_{fname_suffix}.pdf'
    fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

In [17]:
# Save colorbar legend
fig, ax = plt.subplots(1, 1, figsize=(4, .5))
cbar = fig.colorbar(col, orientation='horizontal', cax=ax)
cbar.ax.set_xlabel('BBB Targets', fontsize='small')
# cbar.ax.set_ylabel('a')
# cbar.ax.xaxis.set_label_position('top')
cbar.outline.set_visible(False)
cbar.ax.set_xticks([])
cbar.ax.text(-.02, .5, '0', fontsize='small', ha='right', va='center', transform=cbar.ax.transAxes)
cbar.ax.text(1.02, .5, f'{maxval}', fontsize='small', ha='left', va='center', transform=cbar.ax.transAxes)
bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
fname = f'legend_bbb_color.pdf'
fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

In [18]:
# Shape key for the BBB heatmaps: upper-left triangle = Model 1, lower-right = Model 2
fig, ax = plt.subplots(1, 1, figsize=(2, 2))
triangles = (
    # (vertices, grey fill, label position, label)
    ([[-1, -1], [1, 1], [-1, 1]], (.9, .9, .9, 1), (-.2, .2), 'Model 1'),
    ([[-1, -1], [1, 1], [1, -1]], (.8, .8, .8, 1), (.2, -.2), 'Model 2'),
)
for vertices, shade, (tx, ty), label in triangles:
    patch = matplotlib.patches.Polygon(vertices, color=shade)
    ax.add_patch(patch)
    ax.text(tx, ty, label, ha='center', va='center', rotation=45, transform=ax.transData)
ax.set_xlim([-1, 1])
ax.set_ylim([-1, 1])

# Save
ax.axis('off')
fname = f'legend_bbb_shape.pdf'
fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

## Known Medication Heatmap

In [19]:
# # Load drugs and gene lists
# drugs = pd.read_csv(os.path.join(DATA_FOLDER, 'pharmacologically_active.csv')).rename(columns={'Drug IDs': 'Drug ID'})
# drugs = drugs[['Gene Name', 'Drug ID', 'Species']]
# drugs['Drug ID'] = drugs['Drug ID'].apply(lambda s: s.split('; '))
# drugs = drugs.explode('Drug ID')

# # Add BBB predictions
# bbb = pd.read_csv(os.path.join(DATA_FOLDER, 'BBB_plus_dbIDS.csv'), index_col=0)
# drugs['BBB'] = drugs['Drug ID'].map(lambda x: x in list(bbb['ID']))

# # Add scores
# for source, use_modules in product(data_sources, modules):
#     group, disease, _ = source
#     grn_fnames = get_grn_fnames(group, disease, use_ctl=use_ctl)[-1]

#     df_all = None
#     for i, fname in enumerate(grn_fnames):
#         # Load prioritization
#         cell_type = get_cell_type(fname)
#         result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'
#         try: result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)
#         except: continue

#         # Append to drugs
#         drug_result = (
#             drugs
#                 .set_index('Gene Name')
#                 .join(result[['mean', 'std']], how='inner')  # Inner join to account for missing genes
#                 .fillna(0)
#                 .reset_index(names='Gene')
#         )

#         # Compute drug scores
#         # NOTE: Currently `sqrt(num_genes) * mean(means)`
#         groupby_list = ['Drug ID', 'Species', 'BBB']
#         df_genelist = (
#             drug_result
#                 .groupby(groupby_list)['Gene']
#                 .apply(list)
#                 .reset_index(name='Genes')
#                 .set_index(groupby_list)
#         )
#         df_genelist['Num Genes'] = df_genelist['Genes'].apply(lambda l: len(l))
#         df_means = drug_result.groupby(groupby_list)[['mean']].mean().rename(columns={'mean': 'Mean Mean'})
#         df_stds = drug_result.groupby(groupby_list)[['std']].mean().rename(columns={'std': 'Mean STD'})
#         df_scores = df_means
#         df = (
#             df_genelist
#                 .join(df_means)
#                 .join(df_stds)
#         )
#         df['Drug Relevance Score'] = df.apply(lambda r: r['Mean Mean'] * np.sqrt(r['Num Genes']), axis=1)
#         df['Drug Relevance Score'] /= df['Drug Relevance Score'].max()  # Scale score PER CELL TYPE
#         df = df.sort_values('Drug Relevance Score', ascending=False).reset_index()

#         # Save
#         fname = f'{group}_{disease}_{cell_type}_prioritized_drugs{get_fname_suffix(use_modules=use_modules)}.csv'
#         df.to_csv(os.path.join(RESULTS_FOLDER, fname), index=False)

#         # Append to df
#         df['Cell Type'] = cell_type
#         if df_all is None: df_all = df
#         else: df_all = pd.concat((df_all, df), axis=0)

#     # Skip if no data
#     if df_all is None: continue

#     # Create figure
#     fig, ax = plt.subplots(1, 1, figsize=(4, 4))
#     df_filtered = df_all.loc[df_all['Species'] == 'Humans']
#     sorted_drugs = df_filtered.groupby('Drug ID')[['Drug Relevance Score']].max().sort_values('Drug Relevance Score', ascending=False).index.to_numpy()

#     # Pivot
#     df_filtered = df_filtered.pivot(index='Cell Type', columns='Drug ID', values='Drug Relevance Score')

#     # Sort
#     df_filtered = df_filtered[sorted_drugs]

#     # Plot
#     sns.heatmap(df_filtered, vmin=0, vmax=1, cmap='Blues', ax=ax)
#     ax.set_xlabel(None)
#     ax.set_ylabel(None)

#     # Labels
#     to_label = {
#         'DB00408': 'Loxapine',  # SCZ
#         'DB04842': 'Fluspirilene',  # SCZ
#         'DB06144': 'Sertindole',  # SCZ
#         'DB00502': 'Haloperidol',  # SCZ
#         'DB00734': 'Risperidone',  # AD
#         'DB00472': 'Fluoxetine',  # AD
#         'DB01104': 'Sertraline',  # AD
#     }
#     xticks = [np.argwhere(df_filtered.columns == k)[0][0] + .5 for k in to_label if k in df_filtered.columns]
#     xticklabels = [v for k, v in to_label.items() if k in df_filtered.columns]
#     ax.set_xticks(xticks)
#     ax.set_xticklabels(xticklabels)

#     # Save
#     fname = f'Drug_Relevance_{group}_{disease}{get_fname_suffix(use_modules=use_modules)}.pdf'
#     fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

# Enrichments

## GOrilla

In [20]:
# # Params
# sort_param = '-log10(q)'
# origin_files = [
#     # Should be named <{Group}_{Disease}_{Cell Type}_prioritized_genes>_GOrilla.csv
#     'UCLA_ASD_ASD_micro_prioritized_genes',
# ]

# # Plot
# for fname in origin_files:
#     # Enrichment done with rank-order GOrilla
#     enrichment = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{fname}_GOrilla.csv'))
#     enrichment['-log10(p)'] = -np.log(enrichment['P-value'])
#     enrichment['-log10(q)'] = -np.log(enrichment['FDR q-value'])

#     # Filter to top
#     enrichment = enrichment.sort_values(sort_param, ascending=False)
#     enrichment = enrichment.iloc[:25]

#     # Create figure
#     fig, ax = plt.subplots(1, 1, figsize=(6, 8))
#     sns.barplot(data=enrichment, x=sort_param, y='Description', ax=ax)
#     ax.set_ylabel(None)
#     ax.axvline(-np.log(.05), color='black')
#     fig.savefig(os.path.join(PLOTS_FOLDER, f'Enrichment_GOrilla_{fname}.pdf'), bbox_inches='tight')

## Metascape

In [21]:
## Preparing the gene lists
# Params
percentages = (5, 10, 15)  # Top-N% cut-offs exported next to the full background list

# Convert to gene lists: per cell type, write a CSV whose first column holds every ranked
# gene (the enrichment background), followed by one column per top-percentage list.
for source, use_modules in product(data_sources, modules):
    group, disease, _ = source
    grn_fnames = get_grn_fnames(group, disease, use_ctl=use_ctl)[-1]

    for fname in grn_fnames:
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes{get_fname_suffix(use_modules=use_modules)}.csv'
        try: result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)
        except FileNotFoundError: continue

        # Sort by score (highest first); genes are the index of the prioritized-gene table
        ranked_genes = result.sort_values('mean', ascending=False).index.to_list()
        num_genes = len(ranked_genes)
        # Columns differ in length: building from Series aligns on position and pads the
        # shorter columns with NaN (written as empty cells), as the column-wise concat did.
        columns = {'_BACKGROUND': pd.Series(ranked_genes, dtype=object)}
        for percent in percentages:
            num_genes_to_take = num_genes * percent // 100
            columns[f'{percent}p'] = pd.Series(ranked_genes[:num_genes_to_take], dtype=object)
        df = pd.DataFrame(columns)

        # Save (separate name so the loop variable `fname` is not overwritten)
        out_fname = f'{group}_{disease}_{cell_type}_prioritized_genes_GO_genes{get_fname_suffix(use_modules=use_modules)}.csv'
        df.to_csv(os.path.join(RESULTS_FOLDER, out_fname), index=False)

In [22]:
## Visualizing the enrichments
# NOTE: Currently only loads each exported GO enrichment table; plotting is still TODO.
for source, use_modules in product(data_sources, modules):
    group, disease, _ = source
    grn_fnames = get_grn_fnames(group, disease, use_ctl=use_ctl)[-1]
    suffix = get_fname_suffix(use_modules=use_modules)

    for i, fname in enumerate(grn_fnames):
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes_GO_enrichment{suffix}.csv'
        result_path = os.path.join(RESULTS_FOLDER, result_fname)
        # Cell types without an exported enrichment table are skipped
        if os.path.exists(result_path):
            try:
                result = pd.read_csv(result_path, index_col=0)
            except:
                pass

# TODO