In [1]:
%%capture
%cd "Compound GRN ENC Analysis/scripts"
%matplotlib agg

In [2]:
from collections import defaultdict
from itertools import product
import os

import matplotlib
from matplotlib.collections import PatchCollection
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from statannotations.Annotator import Annotator

# Params
DATA_FOLDER = os.path.join(os.path.abspath(''), '../../data')
RESULTS_FOLDER = os.path.join(os.path.abspath(''), '../results')
PLOTS_FOLDER = os.path.join(os.path.abspath(''), '../plots')

# Style
sns.set_theme(context='talk', style='white', palette='Accent')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


# Parameters

In [3]:
"""
Cohort : Disease : Delimiter
CMC: SCZ : tsv
UCLA_ASD: ASD : csv
Urban_DLPFC: BPD, SCZ : tsv
Subclass: ASD, BPD, SCZ : csv
"""
data_sources = (
    ('CMC', 'SCZ', '\t'),
    ('UCLA_ASD', 'ASD', ','),
    ('Urban_DLPFC', 'BPD', '\t'),
    ('Urban_DLPFC', 'SCZ', '\t'),
    ('Subclass', 'ASD', ','),
    ('Subclass', 'BPD', ','),
    ('Subclass', 'SCZ', ','),
)
# group, disease, delimiter = data_sources[6]
use_ctl = True

In [4]:
# Get AD, BPD, and SCZ labels
gene_dir = os.path.join(DATA_FOLDER, 'new_labels')
gene_fnames = [fname for fname in os.listdir(gene_dir) if fname.endswith('.txt')]
gene_lists = {'.'.join(fname.split('.')[:-1]): np.loadtxt(os.path.join(gene_dir, fname), dtype=str) for fname in gene_fnames}
gene_lists['BPD'] = gene_lists.pop('BD')

# Get ASD labels
sfari = pd.read_csv(os.path.join(DATA_FOLDER, 'sfari/SFARI-Gene_genes_01-16-2024release_03-21-2024export.csv'))
gene_score_threshold = -1
sfari = sfari.loc[sfari['gene-score'] > gene_score_threshold]  # Threshold by score
gene_lists['ASD'] = sfari['gene-symbol'].to_numpy()

# Set positive genes
# positive_genes = gene_lists[disease]

# Get cell-type based on fname
get_cell_type = lambda fname: '_'.join(fname.split('_')[:-1])
# Convert to result file name
# get_result_name = lambda fname: f'{group}_{disease}_{get_cell_type(fname)}_prioritized_genes.csv'

# Get files for contrast
def get_grn_fnames(group, disease):
    # Calculate directories
    base_dir = os.path.join(DATA_FOLDER, 'merged_GRNs_v2', group)
    disease_folder = os.path.join(base_dir, disease)
    if use_ctl: control_folder = os.path.join(base_dir, 'ctrl')
    grn_fnames = np.sort(list(set(os.listdir(disease_folder)).intersection(set(os.listdir(control_folder)))))

    # Return
    ret = ()
    ret += (base_dir, disease_folder)
    if use_ctl: ret += (control_folder,)
    ret += (grn_fnames,)
    # base_dir, disease_folder, control_folder, grn_fnames
    return ret

# Performance

In [5]:
# Params
horizontal = False

# Get performance df
performance = pd.DataFrame()
for source in data_sources:
    group, disease, delimiter = source
    # base_dir, disease_folder, control_folder, grn_fnames = get_grn_fnames(group, disease)

    # Load performance
    fname = os.path.join(RESULTS_FOLDER, f'{group}_{disease}_performance.csv')
    df = pd.read_csv(fname, index_col=0)
    df['Group'] = group
    df['Disease'] = disease

    # Get baselines
    def get_baseline(ct):
        df_temp = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{group}_{disease}_{ct}_prioritized_genes.csv'), index_col=0)
        num_neg = (df_temp['label'] == 0).sum()
        num_pos = (df_temp['label'] == 1).sum()
        baseline = num_pos / (num_pos + num_neg)

        return baseline
    df['Baseline AUPRC'] = df['Cell Type'].map(get_baseline)

    # Concatenate
    performance = pd.concat((performance, df), axis=0)

# Take means
performance = performance.drop(columns='Fold').groupby(['Group', 'Disease', 'Cell Type']).mean().reset_index()

# Apply formatting
performance['Dataset'] = performance.apply(lambda r: f'{r["Group"]} ({r["Disease"]})', axis=1)
# Compute fold changes
for stat in ('AUPRC', 'Validation AUPRC'):
    performance[f'{stat} Fold Change'] = performance[stat] / performance['Baseline AUPRC']

# Plot
fnames = [
    'Performance_Main.pdf',
    'Performance_Subclass.pdf',
]
plot_cell_types_lists = [
    performance.loc[performance['Group'] != 'Subclass', 'Cell Type'].unique().tolist(),
    performance.loc[performance['Group'] == 'Subclass', 'Cell Type'].unique().tolist(),
]
scale = [
    1,
    2
]
for fname, plot_cell_types, sc in zip(fnames, plot_cell_types_lists, scale):
    # Filter to cell types
    performance_filtered = performance.loc[performance['Cell Type'].isin(plot_cell_types)]

    # Format as heatmap
    order_dataset = performance_filtered[['Group', 'Disease', 'Dataset']].groupby(['Group', 'Disease', 'Dataset']).count().reset_index().sort_values(['Disease', 'Group'])['Dataset'].to_list()
    order_cell = performance_filtered.sort_values('Cell Type')['Cell Type'].unique()
    if horizontal: order_dataset = order_dataset[::-1]
    else: order_cell = order_cell[::-1]
    heatmap_data = performance_filtered.pivot(index='Dataset', columns='Cell Type', values='AUPRC Fold Change').loc[order_dataset, order_cell]
    heatmap_data_val = performance_filtered.pivot(index='Dataset', columns='Cell Type', values='Validation AUPRC Fold Change').loc[order_dataset, order_cell]
    if not horizontal:
        heatmap_data = heatmap_data.transpose()
        heatmap_data_val = heatmap_data_val.transpose()

    # Plot
    figsize = heatmap_data.shape[::-1] if horizontal else [sc*s for s in heatmap_data.shape]
    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # Define vars
    xlabels = heatmap_data.columns.to_list()
    ylabels = heatmap_data.index.to_list()
    m, n = heatmap_data.shape[1], heatmap_data.shape[0]
    X, Y = np.meshgrid(np.arange(m), np.arange(n))
    S = heatmap_data.to_numpy()
    T = heatmap_data_val.to_numpy()

    # Squares are training, circles are validation
    maxval = 3
    norm = matplotlib.colors.TwoSlopeNorm(vmin=0, vcenter=1, vmax=maxval)
    patches = [plt.Rectangle((j-.5, i-.5), 1, 1) for j, i in zip(X.flat, Y.flat)]
    col = PatchCollection(patches, array=S.flatten(), linewidth=0, clim=[0, maxval], norm=norm, cmap='RdBu')
    ax.add_collection(col)
    patches = [plt.Circle((j, i), .4) for j, i in zip(X.flat, Y.flat)]
    col = PatchCollection(patches, array=T.flatten(), linewidth=0, clim=[0, maxval], norm=norm, cmap='RdBu')
    ax.add_collection(col)
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    divider = make_axes_locatable(ax)
    cax = divider.append_axes('top', size="5%", pad=.5)
    cbar = fig.colorbar(col, orientation='horizontal', cax=cax)
    cbar.ax.set_xlabel('AUPRC Fold Change', fontsize='small')
    cbar.ax.xaxis.set_label_position('top') 
    cbar.ax.set_xticks([0, 1, maxval])  # Add 1 as an xtick

    ax.set(
        xticks=np.arange(m),
        yticks=np.arange(n),
        xticklabels=xlabels,
        yticklabels=ylabels,
    )
    ax.set_xticks(np.arange(m+1)-.5, minor=True)
    ax.set_yticks(np.arange(n+1)-.5, minor=True)
    ax.tick_params(axis='x', labelrotation=90)
    ax.grid(which='minor')
    ax.set_aspect('equal', adjustable='box')

    fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

# Disease Distribution

In [6]:
for source in data_sources:
    group, disease, _ = source
    base_dir, disease_folder, control_folder, grn_fnames = get_grn_fnames(group, disease)

    # Plot
    fig, axs = plt.subplots(len(grn_fnames), 1, figsize=(10, len(grn_fnames) // 2), sharex=True, sharey=True, gridspec_kw={'hspace': -.5})
    fig.suptitle(f'{group} ({disease})')

    for i, fname in enumerate(grn_fnames):
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes.csv'
        result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)
        ax = axs[i]

        # KDS
        sns.kdeplot(result['mean'], lw=1, color='white', ax=ax)
        x, y = ax.get_lines()[0].get_xydata().T

        # Stylizing and coloration
        ax.set_xlabel('Score')
        # ax.set_ylabel(cell_type, rotation='horizontal', ha='right', va='center')
        ax.set_ylabel(None)
        # ax.set_xticklabels([])
        ax.set_yticklabels([])
        color_val = (i+2) / (len(grn_fnames)+3)  # Padded
        # color_val = i / len(grn_fnames)  # Unpadded
        color = matplotlib.colormaps['Blues'](color_val)
        ax.fill_between(x, 0, y, color=color)

        patches = []
        width = 3e-3  # if group != 'Subclass' else 3e-6
        alpha = .02
        # alpha = 2 / (result.loc[result['label'] != -1, 'label']).sum()
        # alpha = (result.loc[result['label'] != -1, 'label']).sum() / result['label'].count()
        for r in result.loc[result['label'] == 1].iterrows():
            score = (r[1]['mean'])

            # Fill
            # NOTE: Could also compute KDE of positive labels and use a gradient
            xrange = np.linspace(max(0, score-width), min(1, score+width), 10)
            ytop = np.interp(xrange, x, y)
            ax.fill_between(xrange, 0, ytop, color='red', alpha=alpha)
            
            # Patches
            # height = np.interp(score, x, y)
            # width = .1 / x.shape[0]
            # arrow = plt.Rectangle((score, -width/2), width, height)
            # patches.append(arrow)
        # col = PatchCollection(patches, linewidth=0, color='salmon', alpha=.1)
        # ax.add_collection(col)

        # Formatting
        sns.despine(bottom=True, left=True)
        ax.patch.set_alpha(0)  # Make background transparent
        ax.text(0, 0, cell_type, fontsize='small', ha='right', va='bottom', transform=ax.transAxes)  # Add x label

    # Save
    fname = f'Distribution_{group}_{disease}.pdf'
    fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

# Drugs

In [7]:
# Parameters
num_top_drugs = 10

# Load drugs and gene lists
drugs = pd.read_csv(os.path.join(DATA_FOLDER, 'pharmacologically_active.csv'))
drugs = drugs[['Gene Name', 'Drug IDs', 'Species']]
drugs['Drug IDs'] = drugs['Drug IDs'].apply(lambda s: s.split('; '))
drugs = drugs.explode('Drug IDs')

# Add BBB predictions
bbb = pd.read_csv(os.path.join(DATA_FOLDER, 'BBB_plus_dbIDS.csv'), index_col=0)
drugs['BBB'] = drugs['Drug IDs'].map(lambda x: x in list(bbb['ID']))

# Add scores
for source in data_sources:
    group, disease, _ = source
    base_dir, disease_folder, control_folder, grn_fnames = get_grn_fnames(group, disease)

    for i, fname in enumerate(grn_fnames):
        # Load prioritization
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes.csv'
        result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)

        # Append to drugs
        drug_result = (
            drugs
                .set_index('Gene Name')
                .join(result[['mean', 'std']])
                .fillna(0)
                .reset_index(names='Gene')
        )

        # Compute drug scores
        groupby_list = ['Drug IDs', 'Species', 'BBB']
        df_genelist = (
            drug_result
                .groupby(groupby_list)['Gene']
                .apply(list)
                .reset_index(name='Genes')
                .set_index(groupby_list)
        )
        df_genelist['Num Genes'] = df_genelist['Genes'].apply(lambda l: len(l))
        df_means = drug_result.groupby(groupby_list)[['mean']].sum()
        df_stds = drug_result.groupby(groupby_list)[['std']].mean()
        df = (
            df_genelist
                .join(df_means)
                .join(df_stds)
                .sort_values('mean', ascending=False)
                .reset_index()
                .rename(columns={'mean': 'Cumulative Score', 'std': 'Mean STD'})
        )

        # Save
        fname = f'{group}_{disease}_{cell_type}_prioritized_drugs.csv'
        df.to_csv(os.path.join(RESULTS_FOLDER, fname), index=False)

        # Plot
        # TODO: Refine ordering (currently by sum of scores) and analyze/plot
        fig, ax = plt.subplots(1, 1, figsize=(4, 4))
        df_plot = df.sort_values('Cumulative Score', ascending=False).iloc[:num_top_drugs]
        sns.barplot(data=df_plot, x='Cumulative Score', y='Drug IDs', palette='Blues', ax=ax)
        fname = f'Drug_Relevance_{group}_{disease}_{cell_type}.pdf'
        fig.savefig(os.path.join(PLOTS_FOLDER, fname), bbox_inches='tight')

  fig, ax = plt.subplots(1, 1, figsize=(4, 4))


# Enrichments

## GOrilla

In [8]:
# Params
sort_param = '-log10(q)'
origin_files = [
    # Should be named <{Group}_{Disease}_{Cell Type}_prioritized_genes>_GOrilla.csv
    'UCLA_ASD_ASD_micro_prioritized_genes',
]

# Plot
for fname in origin_files:
    # Enrichment done with rank-order GOrilla
    enrichment = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{fname}_GOrilla.csv'))
    enrichment['-log10(p)'] = -np.log(enrichment['P-value'])
    enrichment['-log10(q)'] = -np.log(enrichment['FDR q-value'])

    # Filter to top
    enrichment = enrichment.sort_values(sort_param, ascending=False)
    enrichment = enrichment.iloc[:25]

    # Create figure
    fig, ax = plt.subplots(1, 1, figsize=(6, 8))
    sns.barplot(data=enrichment, x=sort_param, y='Description', ax=ax)
    ax.set_ylabel(None)
    ax.axvline(-np.log(.05), color='black')
    fig.savefig(os.path.join(PLOTS_FOLDER, f'Enrichment_GOrilla_{fname}.pdf'), bbox_inches='tight')

## Metascape

In [9]:
## Preparing the gene lists
# Params
percentages = (5, 10, 15)

# Convert to gene lists
for source in data_sources:
    group, disease, _ = source
    base_dir, disease_folder, control_folder, grn_fnames = get_grn_fnames(group, disease)

    for i, fname in enumerate(grn_fnames):
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes.csv'
        result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)

        # Sort and filter
        result = result.sort_values('mean', ascending=False)
        num_genes = result['mean'].shape[0]
        df = pd.DataFrame({'_BACKGROUND': result.index.to_list()})
        for percent in percentages:
            num_genes_to_take = num_genes * percent // 100
            df_concat = pd.DataFrame({f'{percent}p': result.index.to_list()[:num_genes_to_take]})
            df = pd.concat((df, df_concat), axis=1)

        # Save
        fname = f'{group}_{disease}_{cell_type}_prioritized_genes_GO_genes.csv'
        df.to_csv(os.path.join(RESULTS_FOLDER, fname), index=False)

In [10]:
## Visualizing the enrichments
for source in data_sources:
    group, disease, _ = source
    base_dir, disease_folder, control_folder, grn_fnames = get_grn_fnames(group, disease)

    for i, fname in enumerate(grn_fnames):
        cell_type = get_cell_type(fname)
        result_fname = f'{group}_{disease}_{cell_type}_prioritized_genes_GO_enrichment.csv'
        result = pd.read_csv(os.path.join(RESULTS_FOLDER, result_fname), index_col=0)

# TODO

FileNotFoundError: [Errno 2] No such file or directory: '/mnt/c/Users/nck/repos/GNN-Plus/Compound GRN ENC Analysis/scripts/../results/CMC_SCZ_astro_prioritized_genes_GO_enrichment.csv'