In [1]:
%%capture
%cd "Compound GRN ENC Analysis/scripts"
%matplotlib agg

In [2]:
from collections import defaultdict
from itertools import product
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from statannotations.Annotator import Annotator

# Params
# All folders are resolved against the working directory at the time this cell
# runs (os.path.abspath('') expands to the current directory).
DATA_FOLDER, RESULTS_FOLDER, PLOTS_FOLDER = (
    os.path.join(os.path.abspath(''), relative_path)
    for relative_path in ('../../data', '../results', '../plots')
)

# Style
sns.set_theme(context='talk', style='white', palette='Accent')
# Same font type for both vector backends (PDF and PostScript)
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})


# Data

In [3]:
# Available (cohort, disease, delimiter) combinations.
#   Cohort      : Disease  : Delimiter
#   CMC         : SCZ      : tsv
#   UCLA_ASD    : ASD      : csv
#   Urban_DLPFC : BPD, SCZ : tsv
data_sources = (
    ('CMC', 'SCZ', '\t'),
    ('UCLA_ASD', 'ASD', ','),
    ('Urban_DLPFC', 'BPD', '\t'),
    ('Urban_DLPFC', 'SCZ', '\t'),
)

# Index into `data_sources` of the combination analysed by the rest of the
# notebook; change this single value to rerun everything on another source.
SOURCE_INDEX = 1
group, disease, delimiter = data_sources[SOURCE_INDEX]

In [4]:
# Get disease gene labels: one plain-text gene list per `.txt` file, keyed by
# the file name without its extension. Sorted so the dict order is stable.
gene_dir = os.path.join(DATA_FOLDER, 'new_labels')
gene_fnames = sorted(fname for fname in os.listdir(gene_dir) if fname.endswith('.txt'))
gene_lists = {
    # ndmin=1 keeps a one-gene file as a 1-d array instead of a 0-d scalar array
    fname[:-len('.txt')]: np.loadtxt(os.path.join(gene_dir, fname), dtype=str, ndmin=1)
    for fname in gene_fnames
}
gene_lists['BPD'] = gene_lists.pop('BD')  # The rest of the notebook names this disease 'BPD'

# Get ASD labels
sfari = pd.read_csv(os.path.join(DATA_FOLDER, 'sfari/SFARI-Gene_genes_01-16-2024release_03-21-2024export.csv'))
# Keep genes whose score is strictly above the threshold; rows with a missing
# score never satisfy the comparison and are therefore dropped.
gene_score_threshold = -1
sfari = sfari.loc[sfari['gene-score'] > gene_score_threshold]
gene_lists['ASD'] = sfari['gene-symbol'].to_numpy()

# Set positive genes
positive_genes = gene_lists[disease]

# Get files for contrast
base_dir = os.path.join(DATA_FOLDER, 'merged_GRNs_v2', group)
disease_folder = os.path.join(base_dir, disease)
control_folder = os.path.join(base_dir, 'ctrl')
# The same file names are expected in the disease folder. os.listdir returns
# entries in arbitrary order, so sort for a reproducible panel order, and skip
# hidden entries (e.g. '.DS_Store') that are not GRN files.
grn_fnames = sorted(fname for fname in os.listdir(control_folder) if not fname.startswith('.'))
num_panels = len(grn_fnames)


def get_cell_type(fname):
    """Return the part of a GRN file name that precedes its last underscore."""
    return '_'.join(fname.split('_')[:-1])


def get_result_name(fname):
    """Return the prioritized-genes result file name matching a GRN file name."""
    return f'{group}_{disease}_{get_cell_type(fname)}_prioritized_genes.csv'

# Network Analyses

## Performance

In [6]:
# Load the per-fold performance statistics for the selected cohort/disease
fname = os.path.join(RESULTS_FOLDER, f'{group}_{disease}_performance.csv')
df = pd.read_csv(fname, index_col=0)

# Baseline per cell type: share of genes labelled 1 among genes labelled 0 or 1
baselines = {}
for ct in df['Cell Type'].unique():
    df_temp = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{group}_{disease}_{ct}_prioritized_genes.csv'), index_col=0)
    num_neg = (df_temp['label'] == 0).sum()
    num_pos = (df_temp['label'] == 1).sum()
    baselines[ct] = num_pos / (num_pos + num_neg)

# Express both AUPRC columns as fold change over the row's cell-type baseline,
# so a value of 1 means "equal to the positive-label fraction".
metrics = ['AUPRC', 'Validation AUPRC']
baseline_per_row = df['Cell Type'].map(baselines).astype(float)
df[metrics] = df[metrics].div(baseline_per_row, axis=0)
df = df[['Cell Type', 'Fold'] + metrics]
df = df.melt(id_vars=['Cell Type', 'Fold'], value_name='Fold Change', var_name='Statistic')

# Plot
fig, ax = plt.subplots(1, 1, figsize=(10, 4))
sns.boxplot(data=df, x='Cell Type', y='Fold Change', hue='Statistic', ax=ax)
sns.despine(ax=ax)
ax.axhline(y=1, color='black', ls='--')  # Reference line at the baseline
ax.tick_params(axis='x', labelrotation=90)
fig.savefig(os.path.join(PLOTS_FOLDER, f'GCN_Performance_{group}_{disease}.pdf'), bbox_inches='tight')
plt.close(fig)  # Release the figure, as the other plotting cells do

## In and Out Degrees

In [7]:
def _plot_degree_panel(axis, graph, scores, node_col, degree_col, cmap, norm):
    """Draw one control-vs-disease degree panel and return the plotted table.

    Parameters
    ----------
    axis : matplotlib Axes to draw on.
    graph : edge table with `node_col` and a 'disease' column whose values are
        'Control' or 'Disease'; every row counts as one edge.
    scores : table indexed by gene that has a 'mean' column.
    node_col : column whose occurrences are counted ('TF' or 'target').
    degree_col : name given to the temporary per-edge count column.
    cmap, norm : colour map and normalisation shared by all panels.

    Returns
    -------
    DataFrame with one row per (Control degree, Disease degree) pair and the
    average 'mean' score of the genes having that pair.
    """
    degree_df = graph.copy()
    degree_df[degree_col] = 1  # One per edge, so the grouped sum is the edge count
    degree_df = degree_df[[node_col, 'disease', degree_col]].groupby([node_col, 'disease']).sum().reset_index()
    # Genes absent from one of the two graphs get degree 0 there
    degree_df = degree_df.pivot(index=node_col, columns='disease', values=degree_col).fillna(0)
    degree_df = degree_df.join(scores, on=node_col)
    degree_df = degree_df[['Control', 'Disease', 'mean']].groupby(['Control', 'Disease']).mean().reset_index()
    sns.scatterplot(data=degree_df, x='Control', y='Disease', c=degree_df['mean'], cmap=cmap, norm=norm, ax=axis)
    sns.despine(ax=axis)
    axis.set_xlabel('Control')
    axis.set_ylabel('Disease')
    return degree_df


# Create figures: one panel per GRN file plus a trailing panel for the colorbar
fig, ax = {}, {}
for key in ('out', 'in'):
    fig[key], ax[key] = plt.subplots(1, num_panels+1, figsize=(3*(num_panels+1), 3), sharex=True, sharey=True)
    ax[key] = ax[key].flatten()

# Construct consistent objects
cmap = 'Blues'
norm = plt.Normalize(0, .2)  # Change top as needed
sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)

# Plot network analyses
for i, fname in enumerate(grn_fnames):
    # Load graph
    disease_graph = pd.read_csv(os.path.join(disease_folder, fname), index_col=False, delimiter=delimiter)
    disease_graph['disease'] = 'Disease'
    control_graph = pd.read_csv(os.path.join(control_folder, fname), index_col=False, delimiter=delimiter)
    control_graph['disease'] = 'Control'
    combined_graph = pd.concat((disease_graph, control_graph), axis=0)

    # Load scores
    scores = pd.read_csv(os.path.join(RESULTS_FOLDER, get_result_name(fname)), index_col=0)[['label', 'mean', 'std']]

    # Out degree counts edges per TF, in degree counts edges per target
    out_degree_df = _plot_degree_panel(ax['out'][i], combined_graph, scores, 'TF', 'Out Degree', cmap, norm)
    in_degree_df = _plot_degree_panel(ax['in'][i], combined_graph, scores, 'target', 'In Degree', cmap, norm)

# Colorbar, title, save and close, per figure
for key, title, stem in (('out', 'Out Degree', 'DegreeOut'), ('in', 'In Degree', 'DegreeIn')):
    ax[key][-1].axis('off')
    fig[key].colorbar(sm, ax=ax[key][-1])
    fig[key].suptitle(title)
    fig[key].savefig(os.path.join(PLOTS_FOLDER, f'{stem}_{group}_{disease}.pdf'), bbox_inches='tight')
    # Close by handle: a bare plt.close() would only close the current figure
    plt.close(fig[key])

## Score by Label

In [8]:
# Create figure
fig, ax = plt.subplots(1, 1, figsize=(2*num_panels, 3))

# Create df: one frame per cell type, concatenated once after the loop
label_names = {0: 'Control', 1: 'Disease'}
frames = []
for fname in grn_fnames:
    # Load scores
    scores = pd.read_csv(os.path.join(RESULTS_FOLDER, get_result_name(fname)), index_col=0)[['label', 'mean', 'std']]

    # Format df
    df = scores.rename(columns={'label': 'Label', 'mean': 'Score'})
    # Build the text column in one pass (values other than 0/1 are kept as-is)
    # rather than writing strings into the numeric column through .loc
    df['Label'] = df['Label'].map(lambda value: label_names.get(value, value))
    df['Cell Type'] = get_cell_type(fname)
    frames.append(df)
df_all = pd.concat(frames, axis=0) if frames else pd.DataFrame()

# Params
hue_order = ['Control', 'Disease']

# Plot
sns.violinplot(data=df_all, x='Cell Type', y='Score', hue='Label', hue_order=hue_order, split=True, inner='quart', density_norm='count', ax=ax)
sns.despine(ax=ax)

# Annotate significance: Control vs. Disease within every cell type
pairs = [((ct, hue_order[0]), (ct, hue_order[1])) for ct in df_all['Cell Type'].unique()]
annotator = Annotator(ax, pairs, data=df_all, x='Cell Type', y='Score', hue='Label', hue_order=hue_order)
annotator.configure(test='Mann-Whitney', text_format='star', loc='outside')
results = annotator.apply_test().annotate()

# Save figure
fig.savefig(os.path.join(PLOTS_FOLDER, f'DistributionScore_{group}_{disease}.pdf'), bbox_inches='tight')
plt.close(fig)

p-value annotation legend:
      ns: 5.00e-02 < p <= 1.00e+00
       *: 1.00e-02 < p <= 5.00e-02
      **: 1.00e-03 < p <= 1.00e-02
     ***: 1.00e-04 < p <= 1.00e-03
    ****: p <= 1.00e-04

endo_Control vs. endo_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:8.917e-04 U_stat=1.693e+05
astro_Control vs. astro_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:2.836e-05 U_stat=1.134e+05
excitatory_Control vs. excitatory_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:1.433e-04 U_stat=3.342e+04
inhibitory_Control vs. inhibitory_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:2.304e-05 U_stat=2.482e+04
micro_Control vs. micro_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:1.369e-03 U_stat=1.302e+05
oligo_Control vs. oligo_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:1.587e-06 U_stat=7.017e+04
opc_Control vs. opc_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:9.954e-03 U_stat=9.762e+04
vlmc_Control vs. vlmc_Disease: Mann-Whitney-Wilcoxon test two-side

## Score by Graph

In [9]:
def _scored_genes(graph, scores):
    """Return the genes of `graph` (TFs or targets) that have a row in `scores`.

    The result follows the order of `scores.index`, so it is the same on every
    run (a plain set intersection has no defined order).
    """
    genes = set(graph['TF']).union(graph['target'])
    return scores.index[scores.index.isin(genes)].unique().to_list()


# Create figure
fig, ax = plt.subplots(1, 1, figsize=(2*num_panels, 3))

# Create df: one frame per cell type, concatenated once after the loop
frames = []
for fname in grn_fnames:
    # Load graphs
    disease_graph = pd.read_csv(os.path.join(disease_folder, fname), index_col=False, delimiter=delimiter)
    control_graph = pd.read_csv(os.path.join(control_folder, fname), index_col=False, delimiter=delimiter)

    # Load scores
    scores = pd.read_csv(os.path.join(RESULTS_FOLDER, get_result_name(fname)), index_col=0)[['label', 'mean', 'std']]

    # Format df: a gene present in both graphs contributes one row per graph
    control_genes = _scored_genes(control_graph, scores)
    control_df = pd.DataFrame({'genes': control_genes})
    control_df['Graph'] = 'Control'
    disease_genes = _scored_genes(disease_graph, scores)
    disease_df = pd.DataFrame({'genes': disease_genes})
    disease_df['Graph'] = 'Disease'
    df = pd.concat((control_df, disease_df), axis=0).reset_index(drop=True)
    df = df.join(scores, on='genes')
    df = df.rename(columns={'mean': 'Score'})
    df['Cell Type'] = get_cell_type(fname)
    frames.append(df)
df_all = pd.concat(frames, axis=0) if frames else pd.DataFrame()

# Params
hue_order = ['Control', 'Disease']

# Plot
sns.violinplot(data=df_all, x='Cell Type', y='Score', hue='Graph', hue_order=hue_order, split=True, inner='quart', density_norm='count', ax=ax)
sns.despine(ax=ax)

# Annotate significance: Control vs. Disease graph within every cell type
pairs = [((ct, hue_order[0]), (ct, hue_order[1])) for ct in df_all['Cell Type'].unique()]
annotator = Annotator(ax, pairs, data=df_all, x='Cell Type', y='Score', hue='Graph', hue_order=hue_order)
annotator.configure(test='Mann-Whitney', text_format='star', loc='outside')
results = annotator.apply_test().annotate()

# Save figure
fig.savefig(os.path.join(PLOTS_FOLDER, f'DistributionDisease_{group}_{disease}.pdf'), bbox_inches='tight')
plt.close(fig)

p-value annotation legend:
      ns: 5.00e-02 < p <= 1.00e+00
       *: 1.00e-02 < p <= 5.00e-02
      **: 1.00e-03 < p <= 1.00e-02
     ***: 1.00e-04 < p <= 1.00e-03
    ****: p <= 1.00e-04

endo_Control vs. endo_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:2.557e-21 U_stat=4.631e+06
astro_Control vs. astro_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:6.003e-14 U_stat=3.538e+06
excitatory_Control vs. excitatory_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:5.397e-07 U_stat=1.356e+06
inhibitory_Control vs. inhibitory_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:1.113e-01 U_stat=6.533e+05
micro_Control vs. micro_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:3.709e-14 U_stat=3.133e+06
oligo_Control vs. oligo_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:7.730e-01 U_stat=2.529e+06
opc_Control vs. opc_Disease: Mann-Whitney-Wilcoxon test two-sided, P_val:1.542e-08 U_stat=3.123e+06
vlmc_Control vs. vlmc_Disease: Mann-Whitney-Wilcoxon test two-side

## Drug Targets

In [21]:
# Build the drug -> target genes table (all drugs are kept unless one of the
# optional selections below is re-enabled)
df = pd.read_csv(os.path.join(DATA_FOLDER, 'pharmacologically_active.csv'))
df = df[['Gene Name', 'Drug IDs']].copy()  # Copy so the assignment below never writes into a view
df['Drug IDs'] = df['Drug IDs'].apply(lambda s: s.split('; '))  # 'A; B' -> ['A', 'B']
df = df.explode('Drug IDs')  # One row per (gene, drug)
df = df.groupby('Drug IDs')['Gene Name'].apply(list).reset_index(name='Genes')
df['Num Genes'] = df['Genes'].map(len)
df = df.sort_values('Num Genes', ascending=False)  # For visualization
# Filter
# df = df.loc[df['Num Genes'] >= 10]
#### Select drugs at random
# np.random.seed(42)
# drug_idx = np.random.choice(df.shape[0], 20, replace=False)
# df = df.iloc[drug_idx]
### Manually select drugs
# drug_names = ['DB05541', 'DB01189', 'DB09166', 'DB09089', 'DB09078', 'DB09232']  # MANUAL
# df = df.loc[df['Drug IDs'].isin(drug_names)]
### END

# Annotate based on BBB: a drug is flagged when its ID appears in the BBB table
# TODO: Check that this data is correct, providing opposite results right now
df_bbb = pd.read_csv(os.path.join(DATA_FOLDER, 'BBB_plus_dbIDS.csv'), index_col=0)
df['BBB'] = df['Drug IDs'].isin(df_bbb['ID'])  # Hashed membership test instead of a list scan per row

# Format into dicts for speed
drug_targets = dict(zip(df['Drug IDs'], df['Genes']))
drug_bbb = dict(zip(df['Drug IDs'], df['BBB']))
# Invert drug -> genes into gene -> drugs
target_drugs = defaultdict(list)
for drug, targets in drug_targets.items():
    for target in targets:
        target_drugs[target].append(drug)
target_drugs = dict(target_drugs)  # Plain dict so later lookups cannot create keys

# CLI
print('Selected Drugs: ' + ', '.join(list(drug_targets)))

Selected Drugs: DB00228, DB11148, DB01049, DB13345, DB01189, DB00273, DB04855, DB11273, DB13025, DB01558, DB00628, DB00683, DB12404, DB00546, DB00690, DB13335, DB13437, DB09017, DB01595, DB15489, DB01594, DB00897, DB01589, DB00842, DB00829, DB01489, DB01588, DB01587, DB01511, DB00786, DB01544, DB01545, DB01553, DB01068, DB01559, DB13837, DB00404, DB14719, DB01215, DB14715, DB09166, DB13872, DB14672, DB00349, DB14028, DB09089, DB00186, DB00475, DB00312, DB00371, DB00463, DB00794, DB01159, DB00306, DB00241, DB01239, DB00753, DB01236, DB09283, DB01956, DB01028, DB01205, DB00292, DB08896, DB00818, DB09118, DB00907, DB01437, DB00402, DB01107, DB01567, DB00237, DB00189, DB13643, DB00231, DB06637, DB00801, DB11859, DB11582, DB09167, DB00193, DB05541, DB13269, DB00252, DB00909, DB01363, DB09231, DB01331, DB00562, DB09238, DB00606, DB09060, DB00825, DB09085, DB00028, DB10772, DB11278, DB01577, DB12107, DB00555, DB06603, DB00202, DB05015, DB01388, DB00675, DB09232, DB09088, DB09342, DB09345, DB0

In [25]:
# List the drug IDs on each side of the boolean 'BBB' flag of the drug table.
bbb_flag = df['BBB']

# DB00228, DB11148: Anaesthetic, anaesthetic
bbb_positive_ids = df.loc[bbb_flag, 'Drug IDs'].to_list()
print('BBB Positive: ' + ', '.join(bbb_positive_ids))

# DB13345, DB00273: AD treatment, migraines
bbb_negative_ids = df.loc[~bbb_flag, 'Drug IDs'].to_list()
print('BBB Negative: ' + ', '.join(bbb_negative_ids))

BBB Positive: DB00228, DB11148, DB01049, DB01189, DB00463, DB00794, DB01159, DB00753, DB01236, DB01028, DB01437, DB01107, DB00193, DB00252, DB00909, DB01363, DB00028, DB10772, DB00555, DB00333, DB11124, DB11571, DB11300, DB00098, DB13151, DB09228, DB01043, DB00182, DB00191, DB09568, DB01353, DB01351, DB06404, DB00849, DB01483, DB11572, DB11603, DB06607, DB05121, DB01221, DB14738, DB00898, DB15258, DB00893, DB13152, DB00013, DB00100, DB00721, DB11311, DB00075, DB13150, DB13149, DB01032, DB08885, DB09109, DB05679, DB11130, DB00048, DB06738, DB13998, DB00812, DB13961, DB13933, DB13999, DB00020, DB00025, DB00041, DB00042, DB06779, DB16695, DB16220, DB00895, DB00069, DB00068, DB11639, DB08954, DB01440, DB01281, DB00032, DB00033, DB09052, DB01109, DB00034, DB09336, DB00043, DB00031, DB00029, DB00018, DB00022, DB01156, DB00015, DB11166, DB00054, DB00055, DB00060, DB00005, DB11312, DB00004, DB00052, DB01050, DB09033, DB00011, DB01225, DB00009, DB00008, DB13896, DB14004, DB13923, DB13133, DB131

In [20]:
# Create figure
fig, ax = plt.subplots(1, 1, figsize=(12, 6))

# Create df: one frame per cell type, concatenated once after the loop
label_names = {0: 'Control', 1: 'Disease'}
frames = []
for fname in grn_fnames:
    # Load scores
    scores = pd.read_csv(os.path.join(RESULTS_FOLDER, get_result_name(fname)), index_col=0)[['label', 'mean', 'std']]

    # Format df
    df = scores.rename(columns={'label': 'Label', 'mean': 'Score'})
    # Build the text column in one pass (values other than 0/1 are kept as-is)
    df['Label'] = df['Label'].map(lambda value: label_names.get(value, value))
    df['Cell Type'] = get_cell_type(fname)
    frames.append(df)
df_all = pd.concat(frames, axis=0) if frames else pd.DataFrame()

# Annotate drug targets: every gene gets its targeting drugs plus a
# 'Background' entry, then one row per (gene, cell type, drug)
df_all = df_all.reset_index()  # Moves the gene index into the 'gene' column
df_all['Drugs'] = df_all['gene'].map(lambda gene: target_drugs.get(gene, []) + ['Background'])
df_all = df_all.explode('Drugs').rename(columns={'Drugs': 'Drug'})
# Encode BBB as 1 / 0, with -1 for 'Background', so no group key is missing
df_all['BBB'] = df_all['Drug'].map(lambda drug: -1 if drug == 'Background' else int(drug_bbb[drug]))

### Group by gene
# df = df_all.groupby('gene').max()
### Group by drug: average score over all target genes and cell types
df = df_all[['Drug', 'Score', 'std', 'BBB']].groupby(['Drug', 'BBB']).mean().reset_index()
df = df.loc[df['BBB'] != -1].copy()  # Drop 'Background'; copy so the assignment below is safe
### END
df['BBB'] = df['BBB'].map({1: 'Can Penetrate BBB', 0: 'Cannot Penetrate BBB', -1: 'Background'})

# Save for exploration
# df.to_csv(os.path.join(PLOTS_FOLDER, 'data.csv'))

# Plot
sns.violinplot(data=df, x='Score', y='BBB', orient='h', split=True, ax=ax)  # , inner='quart'
sns.despine(ax=ax)

# Annotate significance for every unordered pair of BBB classes
unique_bbb = list(df['BBB'].unique())
pairs = [(first, second) for pos, first in enumerate(unique_bbb) for second in unique_bbb[pos+1:]]
results = None
if pairs:  # With a single class present there is nothing to compare
    annotator = Annotator(ax, pairs, data=df, x='Score', y='BBB', orient='h')
    annotator.configure(test='Mann-Whitney', text_format='star')
    results = annotator.apply_test().annotate()

# Save figure
fig.savefig(os.path.join(PLOTS_FOLDER, f'DistributionDrug_{group}_{disease}.pdf'), bbox_inches='tight')
plt.close(fig)

p-value annotation legend:
      ns: 5.00e-02 < p <= 1.00e+00
       *: 1.00e-02 < p <= 5.00e-02
      **: 1.00e-03 < p <= 1.00e-02
     ***: 1.00e-04 < p <= 1.00e-03
    ****: p <= 1.00e-04

Can Penetrate BBB vs. Cannot Penetrate BBB: Mann-Whitney-Wilcoxon test two-sided, P_val:4.532e-05 U_stat=8.371e+04


# Enrichment Visualizations

In [24]:
# Params
sort_param = '-log10(q)'  # Column used both for ranking and as the bar length
origin_file = 'UCLA_ASD_ASD_micro_prioritized_genes'

# Enrichment done with rank-order GOrilla
enrichment = pd.read_csv(os.path.join(RESULTS_FOLDER, f'{origin_file}_GO.csv'))
# Base-10 logarithm, matching the column names and the axis label
# (np.log is the natural logarithm and would mislabel the values).
enrichment['-log10(p)'] = -np.log10(enrichment['P-value'])
enrichment['-log10(q)'] = -np.log10(enrichment['FDR q-value'])

# Filter to top
enrichment = enrichment.sort_values(sort_param, ascending=False)
enrichment = enrichment.iloc[:25]

# Create figure
fig, ax = plt.subplots(1, 1, figsize=(6, 8))
sns.barplot(data=enrichment, x=sort_param, y='Description', ax=ax)
ax.set_ylabel(None)
ax.axvline(-np.log10(.05), color='black')  # 0.05 significance threshold on the same scale
fig.savefig(os.path.join(PLOTS_FOLDER, f'Enrichment_{origin_file}.pdf'), bbox_inches='tight')
plt.close(fig)  # Release the figure, as the other plotting cells do