In [12]:
%%capture
# Switch to the scripts directory: the DATA_FOLDER / RESULTS_FOLDER / PLOTS_FOLDER paths defined
# later are built relative to the current working directory. %%capture silences the %cd output.
%cd "Compound GRN ENC Analysis/scripts"

In [13]:
import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Params: project folders, each expressed relative to the current working directory
_working_dir = os.path.abspath('')
DATA_FOLDER, RESULTS_FOLDER, PLOTS_FOLDER = (
    os.path.join(_working_dir, relative_path)
    for relative_path in ('../../data', '../results', '../plots')
)

# Style: seaborn theme first, then the matplotlib font type used for PDF/PS output
sns.set_theme(context='talk', style='white', palette='Set2')
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})


# Data

In [14]:
"""
Cohort : Disease : Delimiter
CMC: SCZ : tsv
UCLA_ASD: ASD : csv
Urban_DLPFC: BPD, SCZ : tsv
"""
group = ['CMC', 'UCLA_ASD', 'Urban_DLPFC'][2]
disease = ['SCZ', 'ASD', 'BPD'][0]
delimiter = [',', '\t'][1]

In [15]:
# Get AD, BPD, and SCZ labels: one gene list per .txt file, keyed by the file name without extension
gene_dir = os.path.join(DATA_FOLDER, 'new_labels')
gene_fnames = [fname for fname in os.listdir(gene_dir) if fname.endswith('.txt')]
gene_lists = {
    # ndmin=1 keeps a file holding a single gene as a 1-D array instead of a 0-d scalar
    os.path.splitext(fname)[0]: np.loadtxt(os.path.join(gene_dir, fname), dtype=str, ndmin=1)
    for fname in gene_fnames
}
gene_lists['BPD'] = gene_lists.pop('BD')  # The label file is named 'BD'; the rest of the notebook uses 'BPD'

# Get ASD labels
sfari = pd.read_csv(os.path.join(DATA_FOLDER, 'sfari/SFARI-Gene_genes_01-16-2024release_03-21-2024export.csv'))
gene_score_threshold = -1
# Threshold by score (strictly greater than); rows with a missing gene-score compare False and are dropped
sfari = sfari.loc[sfari['gene-score'] > gene_score_threshold]
gene_lists['ASD'] = sfari['gene-symbol'].to_numpy()

# Set positive genes
positive_genes = gene_lists[disease]

# Get files for contrast
base_dir = os.path.join(DATA_FOLDER, 'merged_GRNs_v2', group)
disease_folder = os.path.join(base_dir, disease)
control_folder = os.path.join(base_dir, 'ctrl')

# Each contrast needs the same file name in both folders. Ignore hidden entries
# (e.g. .DS_Store, .ipynb_checkpoints), report unpaired files, and sort for a stable order.
control_fnames = {fname for fname in os.listdir(control_folder) if not fname.startswith('.')}
disease_fnames = {fname for fname in os.listdir(disease_folder) if not fname.startswith('.')}
unpaired_fnames = sorted(control_fnames ^ disease_fnames)
if unpaired_fnames:
    print(f'Skipping GRN files without a counterpart in both folders: {unpaired_fnames}')
grn_fnames = sorted(control_fnames & disease_fnames)


def get_cell_type(fname):
    """Return the part of a GRN file name before its last underscore (used as the cell type)."""
    return '_'.join(fname.split('_')[:-1])


def get_result_name(fname):
    """Return the prioritized-genes result file name for a GRN file, using the current group and disease."""
    return f'{group}_{disease}_{get_cell_type(fname)}_prioritized_genes.csv'

# Network Analyses

In [16]:
%matplotlib agg

# Plot network analyses
for fname in grn_fnames:
    # Load graph
    disease_graph = pd.read_csv(os.path.join(disease_folder, fname), index_col=False, delimiter=delimiter)
    disease_graph['disease'] = 'Disease'
    control_graph = pd.read_csv(os.path.join(control_folder, fname), index_col=False, delimiter=delimiter)
    control_graph['disease'] = 'Control'
    combined_graph = pd.concat((disease_graph, control_graph), axis=0)

    # Load scores
    scores = pd.read_csv(os.path.join(RESULTS_FOLDER, get_result_name(fname)), index_col=0)[['label', 'mean', 'std']]

    # Out degree plot
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    out_degree_df = combined_graph.copy()
    out_degree_df['Out Degree'] = 1
    out_degree_df = out_degree_df[['TF', 'disease', 'Out Degree']].groupby(['TF', 'disease']).sum().reset_index()
    out_degree_df = out_degree_df.pivot(index='TF', columns='disease', values='Out Degree').fillna(0)
    out_degree_df = out_degree_df.join(scores, on='TF')
    out_degree_df = out_degree_df[['Control', 'Disease', 'mean']].groupby(['Control', 'Disease']).mean().reset_index()
    # out_degree_df = out_degree_df.pivot(index='Disease', columns='Control', values='mean').iloc[::-1]
    sns.scatterplot(data=out_degree_df, x='Control', y='Disease', hue='mean', palette='Reds')
    sns.despine()
    plt.xlabel('Control'); plt.ylabel('Disease')
    plt.title('Out Degree')
    sm = plt.cm.ScalarMappable(cmap='Reds', norm=plt.Normalize(out_degree_df['mean'].min(), out_degree_df['mean'].max()))
    ax.get_legend().remove()
    ax.figure.colorbar(sm, ax=ax)
    fig.savefig(os.path.join(PLOTS_FOLDER, f'DegreeOut_{group}_{disease}_{get_cell_type(fname)}.pdf'), bbox_inches='tight')
    plt.close()

    # In degree plot
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    in_degree_df = combined_graph.copy()
    in_degree_df['In Degree'] = 1
    in_degree_df = in_degree_df[['target', 'disease', 'In Degree']].groupby(['target', 'disease']).sum().reset_index()
    in_degree_df = in_degree_df.pivot(index='target', columns='disease', values='In Degree').fillna(0)
    in_degree_df = in_degree_df.join(scores, on='target')
    in_degree_df = in_degree_df[['Control', 'Disease', 'mean']].groupby(['Control', 'Disease']).mean().reset_index()
    # in_degree_df = in_degree_df.pivot(index='Disease', columns='Control', values='mean').iloc[::-1]
    sns.scatterplot(data=in_degree_df, x='Control', y='Disease', hue='mean', palette='Reds')
    sns.despine()
    plt.xlabel('Control'); plt.ylabel('Disease')
    plt.title('In Degree')
    sm = plt.cm.ScalarMappable(cmap='Reds', norm=plt.Normalize(in_degree_df['mean'].min(), in_degree_df['mean'].max()))
    ax.get_legend().remove()
    ax.figure.colorbar(sm, ax=ax)
    fig.savefig(os.path.join(PLOTS_FOLDER, f'DegreeIn_{group}_{disease}_{get_cell_type(fname)}.pdf'), bbox_inches='tight')
    plt.close()

    # Score distribution
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    df = scores.copy()
    df.loc[df['label'] == 0, 'label'] = 'Control'; df.loc[df['label'] == 1, 'label'] = 'Disease'
    df = df.rename(columns={'label': 'Label'})
    sns.histplot(data=df.sort_values('Label'), x='mean', hue='Label', element='poly')
    sns.despine()
    plt.xlabel('Score')
    plt.title('Score Distribution by Label')
    fig.savefig(os.path.join(PLOTS_FOLDER, f'DistributionScore_{group}_{disease}_{get_cell_type(fname)}.pdf'), bbox_inches='tight')
    plt.close()

    # Disease score distribution
    fig, ax = plt.subplots(1, 1, figsize=(6, 6))
    control_genes = np.unique(control_graph['TF'].to_list() + control_graph['target'].to_list())
    control_genes = list(set(control_genes).intersection(set(scores.index)))
    control_df = pd.DataFrame({'genes': control_genes})
    control_df['Graph'] = 'Control'
    disease_genes = np.unique(disease_graph['TF'].to_list() + disease_graph['target'].to_list())
    disease_genes = list(set(disease_genes).intersection(set(scores.index)))
    disease_df = pd.DataFrame({'genes': disease_genes})
    disease_df['Graph'] = 'Disease'
    df = pd.concat((control_df, disease_df), axis=0).reset_index(drop=True)
    df = df.join(scores, on='genes')
    sns.histplot(data=df.sort_values('Graph'), x='mean', hue='Graph', element='poly')
    sns.despine()
    plt.xlabel('Score')
    plt.title('Score Distribution by Graph')
    fig.savefig(os.path.join(PLOTS_FOLDER, f'DistributionDisease_{group}_{disease}_{get_cell_type(fname)}.pdf'), bbox_inches='tight')
    plt.close()