In [1]:
%%capture
# Change into `scripts`; presumably so `functions` and the relative data/plot paths used below resolve — TODO confirm
%cd scripts

In [2]:
# Reload edited modules automatically before each cell runs
%load_ext autoreload
%autoreload 2
# NOTE(review): two backends are selected back to back, so the second call wins,
# and the imports cell later switches to 'cairo' — confirm which one is intended
%matplotlib inline
%matplotlib notebook

# Imports

In [3]:
from itertools import product
import os

import graph_tool.all as gt
import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from scipy.stats import hypergeom, pearsonr
import seaborn as sns
from sklearn.cluster import KMeans

from functions import *


# Graph-Tool compatibility: draw through the cairo backend
plt.switch_backend('cairo')

# Style: seaborn theme first, then every rcParams override in a single update
sns.set_theme(context='talk', style='white', palette='Set2')
plt.rcParams.update({
    # Font embedding for PDF/PS output
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    # Text sizes
    'font.size': 22,
    'axes.titlesize': 'medium',
    'axes.labelsize': 'large',
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'legend.fontsize': 'medium',
    'legend.title_fontsize': 'medium',
    'figure.titlesize': 'x-large',
    # Figure transparency (alpha 0 makes the figure background invisible)
    # 'figure.facecolor': (1., 0., 0., 0.3),  # Debugging
    'figure.facecolor': (1., 0., 0., 0.),
})


In [4]:
# Integrity check
# NOTE(review): the return value is discarded, so this presumably reports or raises on failure — confirm in `functions`
check_ct_edge_specificity()  # Check for duplicate edges with different attentions

100%|███████████████████████████████████████████████████████████████████| 1022/1022 [00:15<00:00, 67.39it/s]


# Metadata

In [5]:
# Load metadata
meta = get_meta()

# Subject preview: one summary line per subject whose graph can be loaded, who has
# NPS information, and who has BRAAK 6, CERAD 4 and CDR 3
filtered = []
for i, row in meta.iterrows():
    try:
        load_graph_by_id(row['SubID'])  # Only verifies the graph loads; the result is discarded
        has_nps = not np.isnan(row['nps_MoodDysCurValue'])  # Has NPS information available
        matches_scores = row['BRAAK_AD'] in (6,) and row['CERAD'] in (4,) and row['CDRScore'] in (3,)
    except Exception:
        # Best-effort preview: subjects whose graph or metadata cannot be read are skipped.
        # `Exception` (not a bare `except`) so that Ctrl-C / kernel interrupts still propagate.
        continue
    if not (has_nps and matches_scores):
        continue
    filtered.append(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}, CERAD {row["CERAD"]}, CDR {row["CDRScore"]}, {row["Dx"]}')
filtered = np.sort(filtered)
# for line in filtered: print(line)  # Uncomment to preview the matching subjects


# Attention Stack

In [6]:
# Build (or load from cache) the dense attention array shared by all figure cells.
# Produces: `attention_stack` (edges x heads x subjects), `all_edges`, `columns`,
# `subject_ids`, and the `all_data` dict bundling them.
fname = './attentions.pkl'
if os.path.isfile(fname):
    # Load data
    # NOTE: unpickling can execute arbitrary code; only load a cache written by this notebook
    with open(fname, 'rb') as f:
        all_data = pickle.load(f)
    attention_stack, all_edges, columns, subject_ids = all_data['data'], all_data['edges'], all_data['heads'], all_data['subject_ids']

else:
    # Parameters
    # Scaled probably shouldn't be used, but better for visualization
    # until results are more even
    columns = get_attention_columns(scaled=False)
    subject_ids = meta['SubID'].to_numpy()

    # Load graphs (`subject_ids` is replaced by the ids returned alongside the graphs)
    graphs, subject_ids = load_many_graphs(subject_ids, column=columns)
    # graphs = [compute_graph(g) for g in graphs]

    # # Get attentions
    # df = {}
    # for column in get_attention_columns():
    #     attention, _ = compute_edge_summary(graphs, subject_ids=subject_ids)
    #     attention = attention.set_index('Edge')
    #     df[column] = attention.var(axis=1)


    # Set indices to edges and clean
    print('Fixing indices...')
    for i in tqdm(range(len(graphs))):
        graphs[i].index = graphs[i].apply(lambda r: get_edge_string([r['TF'], r['TG']]), axis=1)
        graphs[i] = graphs[i].drop(columns=['TF', 'TG'])
        # Remove duplicates
        graphs[i] = graphs[i][~graphs[i].index.duplicated(keep='first')]

    # Get all unique edges
    print('Getting unique edges...')
    # Extend one flat list instead of `sum(lists, [])`, which re-copies the
    # accumulated list for every graph (quadratic in the total number of edges)
    edge_labels = []
    for g in graphs:
        edge_labels.extend(g.index)
    all_edges = np.unique(edge_labels)


    # Standardize index order
    print('Standardizing indices...')
    for i in tqdm(range(len(graphs))):
        # Add missing edges as all-NaN rows and order rows by `all_edges`.
        # Indices are unique after the de-duplication above, as `reindex` requires.
        graphs[i] = graphs[i].reindex(all_edges)

    # Convert to numpy
    graphs = [g.to_numpy() for g in graphs]
    attention_stack = np.stack(graphs, axis=-1)
    # attention_stack.shape = (Edge, Head, Subject)
    # attention_stack.shape = (all_edges, columns, subject_ids)

    # Save all data
    all_data = {'data': attention_stack, 'edges': all_edges, 'heads': columns, 'subject_ids': subject_ids}
    # np.savez('attentions.npz', **all_data)
    with open(fname, 'wb') as f:
        pickle.dump(
            all_data,
            f,
            protocol=pickle.HIGHEST_PROTOCOL,
        )


In [7]:
# Additional useful parameters
# Boolean mask over `all_edges`: True where both halves of the edge string are equal
self_loops = [split_edge_string(s)[0] == split_edge_string(s)[1] for s in all_edges]
self_loops = np.array(self_loops)
# Remove self loops
# NOTE(review): only the names `all_edges` / `attention_stack` are rebound here;
# `all_data['edges']` and `all_data['data']`, which the figure cells below read,
# still contain the self loops — confirm whether that is intended.
all_edges = all_edges[~self_loops]
attention_stack = attention_stack[~self_loops]


# Global Parameters

In [8]:
# Parameters
print(f'\nAvailable attention columns: {get_attention_columns()}')
# Representative heads: positions 0, 2 and 4 of the attention-column list
column_ad, column_scz, column_data = (get_attention_columns()[position] for position in (0, 2, 4))
synthetic_nodes_of_interest = ['OPC', 'Micro', 'Oligo']



Available attention columns: ['AD_imp_1', 'AD_imp_2', 'SCZ_imp_1', 'SCZ_imp_2', 'data_imp_1', 'data_imp_2', 'data_imp_3', 'data_imp_4']


# Intra-Contrast Comparisons

In [9]:
# Figure parameters
# Shared configuration read by the figure cells below: example subjects, the heads
# (attention columns) to show with their display names, head groups, and ancestries
param = {
    'subjects': ['M31969', 'M20337'],
    'columns': [column_data, column_ad, column_scz],
    'column_names': ['Data-Driven', 'AD-Prior', 'SCZ-Prior'],
    'column_groups': [get_attention_columns()[4:8], get_attention_columns()[:2], get_attention_columns()[2:4]],
    'column_group_names': ['Data Prioritization', 'AD Prioritization', 'SCZ Prioritization'],
    # Three most frequent `Ethnicity` values (by non-null `SubID` count, descending), plus the pooled 'all' group
    'ancestries': meta.groupby('Ethnicity').count()['SubID'].sort_values().index[::-1].to_list()[:3] + ['all'],
    'contrast': 'c15x',
}

# Generate palette
# One color per example subject, taken in order from the active matplotlib color cycle
palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
param['palette'] = {sid: rgba_to_hex(palette[i]) for i, sid in enumerate(param['subjects'])}

# Preview subjects
# With the print commented out this only looks each subject up; `.iloc[0]` raises
# IndexError if a subject id is missing from `meta`
for sid in param['subjects']:
    row = meta.loc[meta['SubID']==sid].iloc[0]
    # print(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}, CERAD {row["CERAD"]}, CDR {row["CDRScore"]}, {row["Dx"]}')


In [10]:
# Subplot layout (doesn't work well with constrained layout)
# NOTE: This cannot be used, as constrained layout has glitches
# (see https://github.com/matplotlib/matplotlib/issues/23290)
# with uneven mosaics
# fig, axs = get_mosaic(shape, figsize=(int((3/2) * shape_array.shape[1]), int((3/2) * shape_array.shape[0])), constrained_layout=False)

# Subfigure layout (longer)
# NOTE: Constrained layout will fail for all
# subplots if a single one is not able to scale.
# Also, sometimes leaving a subfigure blank will
# cause it to fail, especially if on an edge.
# It is VERY finicky.
# SOLUTION: Save again using `fig.savefig(...)`
# and it will run without warning.  Then, you
# can visually inspect for scaling issues.
# fig, axs = create_subfigure_mosaic(shape_array)
# fig.set_constrained_layout_pads(w_pad=0, h_pad=0, wspace=.4, hspace=.4)  # *_pad is pad for figs (including subfigs), *_space is pad between subplots


## Edge Prioritization and Cross-Ancestry Enrichment - Figure 5cd

In [None]:
# Figure 5c/d: edge-discovery enrichment for the data-driven head (panel N) and the
# cross-ancestry enrichment comparison (panel R)
shape = """
    NNNNNNNNNNN
    NNNNNNNNNNN
    NNNNNNNNNNN
    NNNNNNNNNNN
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

# Every ancestry is processed, but only the last one ('all') is given an axis (panel N).
# The placeholder 'None' labels are presumably not keys of `axs`, so those ancestries
# get `ax=None` and run with `skip_plot=True` — TODO confirm
axs_lab = (len(param['ancestries']) - 1) * ['None'] + ['N']
# axs_lab = ['K', 'L', 'M', 'N']
print(f'\nEdge Discovery Enrichment ({", ".join(axs_lab)})')
for ancestry, ax in zip(param['ancestries'], [axs[lab] if lab in axs else None for lab in axs_lab]):
    # Filter to ancestry
    # Shallow copy: replacing keys below leaves the arrays inside `all_data` untouched
    anc_data = all_data.copy()
    if ancestry != 'all':
        sub_ids = meta.loc[meta['Ethnicity'] == ancestry, 'SubID'].to_list()
        mask = [sid in sub_ids for sid in anc_data['subject_ids']]
        # Subjects are the last axis of the data array
        anc_data['data'] = anc_data['data'][:, :, mask]
        anc_data['subject_ids'] = np.array(anc_data['subject_ids'])[mask]

    # Run
    # The return value (`temp`) is not used in this cell
    temp = plot_edge_discovery_enrichment(
        **anc_data,
        column=param['columns'][0],
        range_colors=[rgb_to_float(hex_to_rgb('#7aa457')), rgb_to_float(hex_to_rgb('#a46cb7')), rgb_to_float(hex_to_rgb('#cb6a49'))],
        ax=ax,
        postfix=f'{ancestry}_{param["columns"][0]}',
        gene_max_num=300,
        threshold=95,
        clamp_min=4,
        skip_plot=(ax is None),
        verbose=True)
    if ax is not None:
        ax.set_xlabel(f'High-Scoring Edges ({param["column_names"][0]})')
        ylabel = 'Frequency'
        if ancestry != 'all': ylabel += f' ({ancestry})'
        ax.set_ylabel(ylabel)
    # MANUAL PROCESSING
    # Run the output '../plots/genes_<column>.csv' from above on Metascape as multiple gene list and perform
    # enrichment.  From the all-in-one ZIP file, save the file from Enrichment_GO/GO_membership.csv as '../plots/go_<column>.csv'
    # and rerun.

axs_lab = ['R']
print(f'\nAncestry Enrichment Comparison ({", ".join(axs_lab)})')
# Same '<ancestry>_<column>' postfixes that were passed to the enrichment calls above
postfixes = [f'{ancestry}_{param["columns"][0]}' for ancestry in param['ancestries']]
enrichments = plot_cross_enrichment(postfixes, names=param['ancestries'], ax=axs[axs_lab[0]], excluded_subgroups=['all'])

# Place labels
offset = plot_labels(axs, shape=shape)

# Save figure
print('\nSaving Figure...')
fig.savefig(f'../plots/figure_5_main.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True, backend='cairo')


## Distributions Across Heads and Ancestries - Supplementary Figure 12

In [None]:
# Supplementary Figure 12: edge-discovery enrichment for every (ancestry, head)
# pair, one panel each (K..V, ancestries by row and heads by column)
shape = """
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

# Enlarge text for this figure only. The previous size is restored in `finally`,
# so a failure partway through cannot leave the global font size at 45 for the
# cells that follow.
previous_font_size = matplotlib.rcParams['font.size']
matplotlib.rcParams['font.size'] = 45
try:
    # Plot all panels
    axs_lab = ['K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V']
    print(f'\nEdge Discovery Enrichment ({", ".join(axs_lab)})')
    for (ancestry, column_idx), ax in zip(product(param['ancestries'], range(len(param['columns']))), [axs[lab] if lab in axs else None for lab in axs_lab]):
        # Filter to ancestry
        # Shallow copy: replacing keys below leaves the arrays inside `all_data` untouched
        anc_data = all_data.copy()
        if ancestry != 'all':
            sub_ids = meta.loc[meta['Ethnicity'] == ancestry, 'SubID'].to_list()
            mask = [sid in sub_ids for sid in anc_data['subject_ids']]
            anc_data['data'] = anc_data['data'][:, :, mask]  # Subjects are the last axis
            anc_data['subject_ids'] = np.array(anc_data['subject_ids'])[mask]

        # Run
        temp = plot_edge_discovery_enrichment(
            **anc_data,
            column=param['columns'][column_idx],
            range_colors=[rgb_to_float(hex_to_rgb('#7aa457')), rgb_to_float(hex_to_rgb('#a46cb7')), rgb_to_float(hex_to_rgb('#cb6a49'))],
            ax=ax,
            postfix=f'{ancestry}_{param["columns"][column_idx]}',
            gene_max_num=300,
            threshold=95,
            wrap_chars=15,
            skip_plot=(ax is None))
        if ax is not None:
            ax.set_xlabel(f'High-Scoring Edges ({param["column_names"][column_idx]})')
            ylabel = 'Frequency'
            if ancestry != 'all': ylabel += f' ({ancestry})'
            ax.set_ylabel(ylabel)
        # MANUAL PROCESSING
        # Run the output '../plots/genes_<column>.csv' from above on Metascape as multiple gene list and perform
        # enrichment.  From the all-in-one ZIP file, save the file from Enrichment_GO/GO_membership.csv as '../plots/go_<column>.csv'
        # and rerun.

    # Place labels
    offset = plot_labels(axs, shape=shape)

    # Save figure
    print('\nSaving Figure...')
    fig.savefig(f'../plots/figure_5_supplement.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True)  # , backend='cairo'
finally:
    matplotlib.rcParams['font.size'] = previous_font_size


## PRS Analyses - Figure 6a

In [None]:
# Figure 6a: PRS analysis panels, one per trait (S: AD, U: SCZ)
# NOTE(review): the four labels below do not match the two panels plotted here — presumably stale
# SCZ
# Data
# AD
# SCZ
shape = """
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

# Plot all panels
axs_lab = ['S', 'U']
print(f'\nPRS Analysis ({", ".join(axs_lab)})')
# Takes around an hour for each loop with no subsampling (on first run)
# Per trait: cache file, head prefix (column name minus its trailing '_<n>'),
# y-axis label, PRS column, and target panel
for fname, head_prefix, ylabel, prs_col, ax_idx in zip(
    ('ad_prs_df.csv', 'scz_prs_df.csv'),
    ('_'.join(column_ad.split('_')[:-1]), '_'.join(column_scz.split('_')[:-1])),
    ('AD Importance Score', 'SCZ Importance Score'),
    ('prs_scaled_AD_Bellenguez', 'prs_scaled_SCZ.3.5_MVP'),
    axs_lab
):
    # Reuse the cached table when present; otherwise pass `df=None`
    # NOTE: `fname` rebinds the cache-path name used by the attention-stack cell
    df = pd.read_csv(fname, index_col=0) if os.path.isfile(fname) else None
    # Covariates: subject id, imputed sex score, ancestry PCs 1-6, and four ancestry columns
    # NOTE(review): loop-invariant, but only safe to hoist if `plot_prs_correlation` does not modify it — confirm
    covariates = get_genotype_meta()[['SubID', 'imp_sex_score'] + [f'imp_anc_PC{i}' for i in range(1, 7)] + [f'imp_anc_{anc}' for anc in ('AFR', 'AMR', 'EAS', 'EUR')]]
    df, prs_df, axs[ax_idx] = plot_prs_correlation(
        meta, **all_data, ax=axs[ax_idx],
        df=df, num_targets=5, ylabel=ylabel, max_scale=False,
        head_prefix=head_prefix, prs_col=prs_col,
        covariates=covariates, subsample=1)
    # Write the cache only if it does not exist yet (an existing file is never overwritten)
    if not os.path.isfile(fname): df.to_csv(fname)

# Place labels
offset = plot_labels(axs, shape=shape)

# Save figure
print('\nSaving Figure...')
fig.savefig(f'../plots/figure_6_prs.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True, backend='cairo')

# Revision Panels

## Revision Gene Importance Enrichment

In [None]:
# Correlate per-node importance scores with BRAAK stage (Spearman), one row per
# (node, head) pair, then apply Benjamini-Hochberg FDR correction over all tests.

# Load data
importance_scores = get_importance_scores()
importance_scores = importance_scores.rename(columns={'ad_imp_score': 'AD', 'scz_imp_score': 'SCZ', 'data_imp_score': 'Data'})
# Weighted mean of the three scores with weights 2 : 2 : 4
importance_scores['All'] = (importance_scores['AD'] * 2 + importance_scores['SCZ'] * 2 + importance_scores['Data'] * 4) / 8

# Get meta
# Look up each row's sample in `meta`; `.loc` raises KeyError for samples missing from `meta`
importance_scores['BRAAK_AD'] = meta.set_index('SubID').loc[importance_scores['sample'], 'BRAAK_AD'].to_numpy()

# Calculate correlations
heads = ['All']  # ['AD', 'SCZ', 'Data']
unique_genes = importance_scores['node'].unique()
# Group once instead of re-filtering the whole frame for every node, and collect plain
# records instead of growing the DataFrame one `.loc` row at a time.
# `sort=False` keeps nodes in order of first appearance, the same order as `unique()`.
gene_groups = importance_scores.groupby('node', sort=False)
corr_records = []
for g, filtered_df in tqdm(gene_groups, total=gene_groups.ngroups):
    geno = filtered_df['BRAAK_AD'].to_numpy()
    for h in heads:
        col = filtered_df[h].to_numpy()
        mask = ~np.isnan(col) & ~np.isnan(geno)
        # Need at least 3 paired observations and variation on both sides
        if mask.sum() > 2 and col[mask].var() != 0 and geno[mask].var() != 0:
            corr, sig = scipy.stats.spearmanr(col[mask], geno[mask])
            corr_records.append({'Gene': g, 'Head': h, 'Correlation': corr, 'Significance': sig, 'Samples': mask.sum()})
df_corr = pd.DataFrame(corr_records, columns=['Gene', 'Head', 'Correlation', 'Significance', 'Samples'])

# FDR correction
# p-values are clipped into [0, 1] before the Benjamini-Hochberg adjustment
df_corr['Adjusted Significance'] = scipy.stats.false_discovery_control(df_corr['Significance'].clip(0, 1), method='bh')


In [None]:
# Parameters
# These thresholds are also read by the aggregate gene-list cell below
min_samples = 5
significance_threshold = 5e-2
top_genes = int(1e3)

# Histogram of adjusted significance
# Raw and adjusted p-values, one subplot per head; `bins=int(2/threshold)` gives 40 bins at the 0.05 threshold
axs = df_corr.plot.hist(column=['Significance', 'Adjusted Significance'], by='Head', bins=int(2/significance_threshold), figsize=(12, 5*df_corr['Head'].unique().shape[0]))
for ax in axs:
    ax.set_yscale('log')
    # Mark the significance threshold
    ax.axvline(x=significance_threshold, ls='--', color='red')
# Saves whichever figure is current in pyplot; presumably the one created by the call above
plt.savefig(f'../plots/rev_1_histogram.pdf', bbox_inches='tight', transparent=True)

In [None]:
# ## Separate head analysis
# # Gene lists
# sig_genes = pd.DataFrame()
# for h in heads:
#     # Filter dataframe
#     df_filtered = df
#     df_filtered = df_filtered.loc[df_filtered['Head'] == h]  # Head
#     df_filtered = df_filtered.loc[df_filtered['Samples'] > min_samples]  # Samples
#     df_filtered = df_filtered.loc[df_filtered['Adjusted Significance'] < significance_threshold]  # Significance
#     genes = np.unique([g.split(':')[1] for g in df_filtered[['Gene']].to_numpy().flatten() if g.split(':')[0] in ('TF', 'TG')])

#     # Save genes
#     print(f'Significant {h} genes: {genes.shape[0]}')
#     sig_genes = pd.concat([sig_genes, pd.DataFrame({h: genes})], axis=1)

# # Add background and save
# genes = np.unique([g.split(':')[1] for g in importance_scores[['node']].to_numpy().flatten() if g.split(':')[0] in ('TF', 'TG')])
# genes = np.array([g for g in genes if not string_is_synthetic(g)])
# sig_genes = pd.concat([sig_genes, pd.DataFrame({'_BACKGROUND': genes})], axis=1)
# pd.DataFrame(sig_genes).to_csv('../plots/rev_1_genes.csv', index=False)
# # np.savetxt(f'../plots/rev_1_genes_{h.lower()}.txt', genes, fmt='%s')


In [None]:
## Aggregate analysis
# Build the top positively and negatively correlated gene lists from `df_corr`
# and write them, with the background gene list, to one Excel workbook.
# Gene lists
sig_genes = {}

# Sample-count and significance filters are the same for both directions
df_eligible = df_corr.loc[
    (df_corr['Samples'] >= min_samples)  # Samples
    & (df_corr['Adjusted Significance'] <= significance_threshold)]  # Significance

# Filter for positive and negative
for name, factor in zip(('Positive', 'Negative'), (1, -1)):
    # Built as one chain with `.assign`, so 'Gene' is never written onto a filtered
    # slice of `df_corr` (the original assignment triggers SettingWithCopyWarning)
    df_filtered = (
        df_eligible
        .loc[factor * df_eligible['Correlation'] > 0]  # Correlation parity
        .sort_values(by='Correlation', ascending=factor<0)  # Sort genes, strongest first
        .assign(Gene=lambda d: d['Gene'].apply(lambda g: g.split(':')[1] if g.split(':')[0] in ('TF', 'TG') else pd.NA))  # Remove celltypes
        .dropna()
        .drop_duplicates(subset='Gene', keep='first')  # Remove TF + TG genes that got in twice, keep higher score
        .iloc[:top_genes]
    )
    genes = df_filtered['Gene'].to_numpy()

    # Save genes
    print(f'Post-filter {name} genes: {genes.shape[0]}')  # Post-filter
    sig_genes[name] = df_filtered[['Gene', 'Samples', 'Correlation', 'Significance', 'Adjusted Significance']]

# Add background and save
# Background: every gene name that appears on a 'TF:' or 'TG:' node
genes = np.unique([g.split(':')[1] for g in importance_scores[['node']].to_numpy().flatten() if g.split(':')[0] in ('TF', 'TG')])
background = pd.DataFrame({'_BACKGROUND': genes})
# sig_genes.to_csv('../plots/rev_1_genes.csv', index=False)
with pd.ExcelWriter('../plots/rev_1_genes.xlsx') as w:
    # One sheet per direction, plus the background
    for name, sheet_df in sig_genes.items():
        sheet_df.to_excel(w, sheet_name=name, index=False)
    background.to_excel(w, sheet_name='Background', index=False)


## Revision Ancestry Enrichment

In [None]:
# Get data-driven column and ancestries
dd_col = param['columns'][0]
ancs = param['ancestries'][:-1]  # Drop the trailing pooled 'all' entry

# Load high-percentile gene sets
# Files are the '../plots/genes_<ancestry>_<column>.csv' outputs of the enrichment cells
fnames = [f'../plots/genes_{anc}_{dd_col}.csv' for anc in ancs]
dfs = [pd.read_csv(fname) for fname in fnames]
# Second-to-last column of each file, NaN-dropped and sorted, as a plain list
# NOTE(review): presumably the highest-percentile gene list, with '_BACKGROUND' as the last column — confirm
data = [df.iloc[:, -2].dropna().sort_values().to_list() for df in dfs]
anc_sizes = [len(d) for d in data]

# Find intersections
from collections import defaultdict
def find_intersections(data, names):
    """Partition the values of several sorted lists by the exact set of lists holding them.

    Performs a k-way merge: at each step the smallest value still pending is
    taken, together with every list whose next pending value equals it.

    Parameters
    ----------
    data : list of list
        One ascending-sorted list per group. The lists are only read, never
        modified.
    names : sequence of str
        Group names, aligned with `data`.

    Returns
    -------
    collections.defaultdict
        Maps '-'.join(sorted member names) to the list of values found in
        exactly that combination of groups, in ascending order. A value
        repeated inside one list is consumed one occurrence per step.
    """
    groups = defaultdict(list)
    # Read cursor per list; replaces `pop(0)`, which is O(n) per call and emptied the caller's lists
    positions = [0] * len(data)
    while True:
        # Lists that still have pending values
        active = [i for i, d in enumerate(data) if positions[i] < len(d)]
        if not active:
            break

        # Smallest pending value, taken over non-exhausted lists only. No NaN placeholder is
        # needed, so string data can no longer be compared against a stringified 'nan'.
        low_val = min(data[i][positions[i]] for i in active)

        # Every list whose head equals that value belongs to this intersection
        all_equal = [i for i in active if data[i][positions[i]] == low_val]
        for i in all_equal:
            positions[i] += 1

        # Aggregate intersection and record
        groups['-'.join(sorted(names[i] for i in all_equal))].append(low_val)

    return groups
# Keys are '-'-joined ancestry names; values are the genes found in exactly that combination
groups = find_intersections(data, ancs)

# Print group sizes
# groups = dict(groups)
print('Intersection counts')
for k, v in groups.items():
    print(f'{k}: {len(v)}')

# Save unique genes
# One CSV per combination: its genes (under the same column header as the first input file)
# next to the '_BACKGROUND' column of the first input file
new_fnames = [f'../plots/genes_{k}_{dd_col}_uniq.csv' for k in groups]
for new_fname, k in zip(new_fnames, groups): pd.concat([
        pd.DataFrame({dfs[0].columns[-2]: groups[k]}),
        pd.DataFrame({'_BACKGROUND': dfs[0]['_BACKGROUND']})
    ], axis=1).to_csv(new_fname, index=False)

In [None]:
# Simulate overlap
# Null distribution of the intersection-group sizes: draw random gene sets with the
# observed per-ancestry sizes from each background and count every combination.
np.random.seed(42)
# Each entry is (label, population). `np.random.choice` accepts either an int n
# (sample from range(n)) or an array of items to sample from.
num_gene_sets = [
    ('All Background', dfs[0]['_BACKGROUND'].shape[0]),  # Assumes background is the same between all
    ('Highly Variant', np.unique([g for grp in groups.values() for g in grp])),  # Only variant genes from each ancestry
]
group_sizes = anc_sizes
num_iterations = int(1e4)
sim_group_counts = defaultdict(lambda: defaultdict(lambda: []))
for background_name, num_genes in num_gene_sets:
    # One {group: count} dict per iteration
    iteration_counts = []
    for _ in tqdm(range(num_iterations), total=num_iterations, desc=f'{background_name}'):
        # Sample gene lists
        sim_data = []
        for size in group_sizes:
            sim_data.append( np.sort(np.random.choice(num_genes, size, replace=False)).tolist() )

        # Find intersections
        sim_groups = find_intersections(sim_data, ancs)
        iteration_counts.append({k: len(v) for k, v in sim_groups.items()})

    # Record
    # A group missing from an iteration has size 0 there. It is recorded as 0 rather than
    # skipped, so every list has `num_iterations` entries and zeros count towards the null.
    recorded_groups = set(groups).union(*iteration_counts)
    for k in sorted(recorded_groups):
        sim_group_counts[background_name][k] = [counts.get(k, 0) for counts in iteration_counts]

In [None]:
# Plot each sim group with the corresponding reading
# Rows are backgrounds, columns are intersection groups; the red line marks the observed count.
group_order = ['AFR', 'AMR', 'EUR', 'AFR-AMR', 'AFR-EUR', 'AMR-EUR', 'AFR-AMR-EUR']
# Columns follow `group_order`, which is what the loop below indexes by;
# `squeeze=False` keeps `axs` two-dimensional for any grid size
fig, axs = plt.subplots(
    len(num_gene_sets), len(group_order),
    figsize=(len(group_order)*8, len(num_gene_sets)*4), squeeze=False)
for i, k1 in enumerate(sim_group_counts):
    for j, k2 in enumerate(group_order):
        # `.get` so that looking up an absent group does not insert it into the defaultdicts
        sim_count = sim_group_counts[k1].get(k2, [])
        obs_count = len(groups.get(k2, []))
        ax = axs[i][j]

        if len(sim_count) == 0:
            # Nothing was simulated for this group, so no distribution or p-value can be shown
            ax.text(.5, .5, 'No simulated counts', ha='center', va='center', transform=ax.transAxes)
        else:
            quantile = (np.array(sim_count) < obs_count).mean()

            # Plot
            # The column is named after the group being plotted, not a variable left over from an earlier cell
            sim_df = pd.DataFrame({k2: sim_count})
            sns.histplot(data=sim_df, x=k2, kde=True, color='gray', ax=ax)

            # Significance
            ax.axvline(x=obs_count, ls='-', color='red')
            pval = min(quantile, 1-quantile)  # One-tailed p test
            # Below 1/num_iterations the empirical p-value is only an upper bound
            is_bound = pval < 1/num_iterations
            if is_bound: pval = 1/num_iterations
            sig_thresholds = np.array([5e-2, 1e-2, 1e-3])
            sig_tail = '*' * (pval < sig_thresholds).sum()
            # Use '<' only for the clipped bound; an exact estimate is reported with '='
            pval_string = f'p{"<" if is_bound else "="}{pval:.2e}{sig_tail}'
            ax.text(.5, .95, pval_string, ha='center', va='top', transform=ax.transAxes)

        # Labels
        ax.set(xlabel=None, ylabel=None)
        if i == 0: ax.set_title(k2)
        if j == 0: ax.set_ylabel(k1)
fig.savefig(f'../plots/rev_1_ancestry_significance.pdf', bbox_inches='tight', transparent=True)
plt.close(fig)

In [None]:
# Plot after MANUAL PROCESSING
# TODO: Maybe also plot intersecting genes?
fig, ax = plt.subplots(1, 1, figsize=(int((3/2)*11), int((3/2)*7)))
# Strip the directory, the leading 'genes_' token and the '.csv' extension, leaving
# '<group>_<column>_uniq'; only groups with more than 100 genes are kept
file_prefixes = ['_'.join(fname.split('/')[-1].split('_')[1:])[:-4] for fname, k in zip(new_fnames, groups) if len(groups[k]) > 100]
# Group name is the text before the first underscore of each prefix
group_names = [fname.split('_')[0] for fname in file_prefixes]
enrichments = plot_cross_enrichment(file_prefixes, names=group_names, num_terms=30, ax=ax)
fig.savefig(f'../plots/rev_1_ancestry.pdf', bbox_inches='tight', transparent=True)
plt.close(fig)

# Secondary Revision Panels

## Histograms of Edge Prioritization Distributions - Supplementary Figure XX

In [None]:
# Histograms of per-edge prioritization counts for every (ancestry, head) pair,
# one panel each (K..V, ancestries by row and heads by column)
shape = """
    KKKKKKKLLLLLLLMMMMMMM
    KKKKKKKLLLLLLLMMMMMMM
    KKKKKKKLLLLLLLMMMMMMM
    NNNNNNNOOOOOOOPPPPPPP
    NNNNNNNOOOOOOOPPPPPPP
    NNNNNNNOOOOOOOPPPPPPP
    QQQQQQQRRRRRRRSSSSSSS
    QQQQQQQRRRRRRRSSSSSSS
    QQQQQQQRRRRRRRSSSSSSS
    TTTTTTTUUUUUUUVVVVVVV
    TTTTTTTUUUUUUUVVVVVVV
    TTTTTTTUUUUUUUVVVVVVV
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

# Enlarge text for this figure only. The previous size is restored in `finally`,
# so a failure partway through cannot leave the global font size at 45 for the
# cells that follow.
previous_font_size = matplotlib.rcParams['font.size']
matplotlib.rcParams['font.size'] = 45
try:
    # Plot all panels
    axs_lab = ['K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V']
    print(f'\nEdge Discovery Enrichment ({", ".join(axs_lab)})')
    for (ancestry, column_idx), ax in zip(product(param['ancestries'], range(len(param['columns']))), [axs[lab] if lab in axs else None for lab in axs_lab]):
        # Filter to ancestry
        # Shallow copy: replacing keys below leaves the arrays inside `all_data` untouched
        anc_data = all_data.copy()
        if ancestry != 'all':
            sub_ids = meta.loc[meta['Ethnicity'] == ancestry, 'SubID'].to_list()
            mask = [sid in sub_ids for sid in anc_data['subject_ids']]
            anc_data['data'] = anc_data['data'][:, :, mask]  # Subjects are the last axis
            anc_data['subject_ids'] = np.array(anc_data['subject_ids'])[mask]

        # Run
        # Compute edge counts
        edge_counts = compute_edge_counts(**anc_data)
        head_filt = edge_counts['Head'] == param['columns'][column_idx]
        zero_filt = edge_counts['Count'] > 0
        # Product of the two boolean masks acts as a logical AND
        edge_counts_filt = edge_counts.loc[head_filt * zero_filt]
        # Plot histogram
        # fig, ax = plt.subplots(1, 1, figsize=(10, 5))
        # Bin edges at half-integers: unit-width bins up to 10, then log-spaced edges up to 1000
        sns.histplot(
            edge_counts_filt, x='Count', bins=.5+np.concatenate([np.linspace(0, 10, 11), np.floor(np.logspace(1, 3, 21))[1:]]),
            color='gray', edgecolor='.3', lw=.5, ax=ax)
        sns.despine(ax=ax)
        ax.set(xlabel=f'Prioritizations ({param["column_names"][column_idx]})', ylabel=f'Edges ({ancestry})')
        # Axes
        ax.set_yscale('log', subs=list(range(2, 10)))
        ax.set_xscale('symlog', linthresh=10)  # Linear up to 10, logarithmic beyond
        ax.set_xlim(left=0)
        # Ticks
        ax.yaxis.get_major_locator().set_params(numticks=99)
        ax.yaxis.get_minor_locator().set_params(numticks=99)
        ax.set_xticks(np.concatenate([[0, 1]] + [np.linspace(10**i, 10**(i+1), 10)[1:] for i in range(3)]), minor=True)
        ax.tick_params(which='both', left=True, bottom=True)

    # Place labels
    offset = plot_labels(axs, shape=shape)

    # Save figure
    print('\nSaving Figure...')
    fig.savefig(f'../plots/figure_5_supplement_hist.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True)  # , backend='cairo'
finally:
    matplotlib.rcParams['font.size'] = previous_font_size


## Cluster Edge Enrichments

In [14]:
# Get all present edges
# Long-format table with one row per non-NaN entry of the (edge, head, subject) array
has_value = np.argwhere(~np.isnan(all_data['data']))
present_edges = pd.DataFrame({
    'Edges': all_data['edges'][has_value[:, 0]],
    'Heads': np.array(all_data['heads'])[has_value[:, 1]],
    'Subject IDs': np.array(all_data['subject_ids'])[has_value[:, 2]],
    'Value': all_data['data'][has_value[:, 0], has_value[:, 1], has_value[:, 2]]})
# Split each edge string into its two endpoints
present_edges[['Source', 'Target']] = present_edges['Edges'].str.split(EDGE_SPLIT_STRING, expand=True)

# Apply any pre-filtering
# present_edges = present_edges.loc[present_edges['Value'] > present_edges['Value'].quantile(.9)]  # 90th percentile of attentions BEFORE assessing clusters for uniqueness

# Attach clusters
clusters = pd.read_csv('./figure_2d_clusters.csv', index_col=0)
# NOTE: `dropna()` removes rows with a NaN in any column, not only unmatched subjects
present_edges = present_edges.join(clusters.set_index('SubID')[['louvain']], on='Subject IDs').dropna()  # Some subjects are not found in the clustering
# Row count per (edge, cluster) divided by the number of heads
instances_edges_cluster_series = present_edges.groupby(['Edges', 'louvain']).size() / len(all_data['heads'])  # How many per edge and cluster
instances_edges_series = instances_edges_cluster_series.reset_index().groupby('Edges').size()  # How many clusters per edge
# Edges that occur in exactly one cluster
unique_edges = instances_edges_series.index[instances_edges_series == 1]  # Get cluster-unique edges
unique_present_edges = present_edges.loc[present_edges['Edges'].isin(unique_edges)]

In [39]:
# Params
# Either a head-group prefix (no underscore, e.g. 'data') or one exact head name
head = 'data'

# Apply any post-filtering
if '_' not in head: unique_present_edges_filt = unique_present_edges.loc[unique_present_edges['Heads'].apply(lambda s: s.startswith(head))]  # Filter to group of heads
else: unique_present_edges_filt = unique_present_edges.loc[unique_present_edges['Heads'] == head]  # Filter to individual head
unique_present_edges_filt = unique_present_edges_filt.drop(columns='Heads').groupby(['Edges', 'Subject IDs', 'Source', 'Target', 'louvain']).mean().reset_index()  # Mean over all remaining heads
# NOTE: the cutoff is the 98th percentile of `present_edges` (all heads, before the uniqueness filter), not of the filtered table
unique_present_edges_filt = unique_present_edges_filt.loc[unique_present_edges_filt['Value'] > present_edges['Value'].quantile(.98)]  # 98th percentile of attentions AFTER assessing clusters for uniqueness

# CLI
# Counts rows per cluster, i.e. one per (edge, subject) pair that survives the filter
for cluster, num in unique_present_edges_filt.groupby('louvain').size().items():
    print(f'Cluster {cluster:.0f} has {num:.0f} unique edges')  # NOTE: If multiple heads, might over-count

Cluster 0 has 462 unique edges
Cluster 1 has 358 unique edges
Cluster 2 has 342 unique edges
Cluster 3 has 189 unique edges
Cluster 4 has 47 unique edges


In [40]:
# Get unique genes by cluster
# Background: every endpoint name in `present_edges` (synthetic names are not removed from it here)
unique_genes = {'_BACKGROUND': pd.Series(pd.concat([present_edges['Source'], present_edges['Target']]).unique())}
for cluster in unique_present_edges_filt['louvain'].unique():
    df = unique_present_edges_filt
    df = df.loc[df['louvain'] == cluster]
    # Distinct endpoints of this cluster's edges, with synthetic names removed
    unique_gene_list = pd.Series(pd.concat([df['Source'], df['Target']]).unique())
    unique_gene_list = unique_gene_list[~unique_gene_list.apply(string_is_synthetic)].reset_index(drop=True)
    unique_genes[cluster] = unique_gene_list
    print(f'Cluster {cluster:.0f} has {unique_genes[cluster].shape[0]:.0f} genes within unique edges')
# Columns have different lengths; shorter ones are padded with NaN when aligned on the index
unique_genes = pd.DataFrame(unique_genes)

# Save
unique_genes.to_csv(f'../plots/cluster_unique_genes_{head}.csv', index=None)

Cluster 0 has 361 genes within unique edges
Cluster 1 has 300 genes within unique edges
Cluster 2 has 282 genes within unique edges
Cluster 3 has 167 genes within unique edges
Cluster 4 has 47 genes within unique edges


In [None]:
# RUN IN METASCAPE AND EXTRACT `GO_membership.csv`, rename to `cluster_unique_genes_{head}_GO.csv`

In [None]:
# Params
head = 'data'  # data_imp_1, data, SCZ, AD

# Construct data
enrichment = pd.read_csv(f'../plots/cluster_unique_genes_{head}_GO.csv', index_col=None)
# Column names '_LogP_0.0' ... '_LogP_4.0': five clusters are hard-coded here
cluster_cols = [f'_LogP_{i:.1f}' for i in range(5)]
# Rows ordered by ascending mean LogP, then the first three rows kept per '_PATTERN_' value
# NOTE(review): ascending order matches "3 highest" only if LogP values are negative — presumably log10 p-values; confirm
# NOTE: `argsort` yields positions but is used with `.loc`, which is correct only for the default RangeIndex
filtered_enrichment = enrichment.loc[enrichment[cluster_cols].mean(axis=1).argsort()].groupby('_PATTERN_').head(3)  # 3 highest means for each pattern
# Negate LogP, rename columns to 0..4, and reverse the row order
df = -filtered_enrichment.set_index('Description')[cluster_cols].rename(columns={col: i for i, col in enumerate(cluster_cols)}).iloc[::-1]

# Create plot
fig, ax = plt.subplots(figsize=(10, 20))

# Create grid
xlabels = df.columns
ylabels = df.index
Y, X = np.meshgrid(np.arange(len(ylabels)), np.arange(len(xlabels)), indexing='ij')
# Circle radius scaled so that the largest value gets radius 0.5 (fills one grid cell)
R = df.to_numpy() / df.to_numpy().max() / 2

# Populate axis
# One circle per (term, cluster) cell; both radius and color encode the value
circles = [plt.Circle((x, y), radius=r) for x, y, r in zip(X.flat, Y.flat, R.flat)]
col = matplotlib.collections.PatchCollection(circles, edgecolor='none', array=df.to_numpy().flat, cmap='Reds')  # cmap
col.set_clim(vmin=0)
ax.add_collection(col)

# Formatting
ax.set(
    xticks=np.arange(len(xlabels)), xticklabels=xlabels, xlabel='Development Cluster',
    yticks=np.arange(len(ylabels)), yticklabels=ylabels)
# Minor ticks at cell boundaries so the minor grid outlines each cell
ax.set_xticks(np.arange(len(xlabels)+1)-.5, minor=True)
ax.set_yticks(np.arange(len(ylabels)+1)-.5, minor=True)
ax.grid(which='minor')
ax.set_aspect('equal')

fig.colorbar(col, label='-Log10(p)')
fig.savefig(f'../plots/cluster_unique_genes_{head}.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True)


In [323]:
# # Print gene list
# cluster = 0
# for g in unique_genes[cluster].dropna(): print(g)

## Significant Edge List

In [11]:
# Get significant DEX genes
# Distinct gene IDs whose FDR is below 0.05 in the supplementary table
dex_genes = pd.read_csv('PsychAD_SupplementaryTable6.csv.gz')
dex_gene_list = dex_genes[dex_genes['FDR'].lt(.05)]['ID'].unique()

In [None]:
# Get all present edges
# Long-format table with one row per non-NaN entry of the (edge, head, subject) array
# NOTE: rebinds `present_edges` with singular column names ('Edge', 'Head', 'Subject ID'),
# unlike the cluster cells above, which use the plural forms
has_value = np.argwhere(~np.isnan(all_data['data']))
present_edges = pd.DataFrame({
    'Edge': all_data['edges'][has_value[:, 0]],
    'Head': np.array(all_data['heads'])[has_value[:, 1]],
    'Subject ID': np.array(all_data['subject_ids'])[has_value[:, 2]],
    'Value': all_data['data'][has_value[:, 0], has_value[:, 1], has_value[:, 2]]})
present_edges[['Source', 'Target']] = present_edges['Edge'].str.split(EDGE_SPLIT_STRING, expand=True)  # Extract genes
present_edges = present_edges.join(meta.set_index('SubID')[['AD']], on='Subject ID')  # Append AD annotation

# Get overlaps (total, not by ct)
# TODO: Add CT
# Adding the two boolean masks acts as a logical OR: True if either endpoint is a DEX gene
present_edges['DEX Overlap'] = present_edges['Source'].isin(dex_gene_list) + present_edges['Target'].isin(dex_gene_list)

# Get data means
# NOTE(review): despite the "data" naming, this keeps the heads starting with 'AD_imp_' — confirm intended.
# NOTE(review): the groupby keys are every column except 'Value', which includes 'Head', so each group
# holds a single row and the mean does not average across heads — confirm intended.
data_edges = present_edges.loc[present_edges['Head'].str.startswith('AD_imp_')].groupby(list(np.setxor1d(present_edges.columns, ['Value']))).mean().reset_index()

# Filter
# Flag rows at or above the 98th percentile of 'Value' and keep them as the significant edges
data_edges['Top'] = data_edges['Value'] >= data_edges['Value'].quantile(.98)
sig_edges = data_edges.loc[data_edges['Top']]

In [38]:
# Get percentage shared with capstone
# Mean DEX overlap per (subject, AD status, top-edge flag), compared between
# filtered (top) and non-filtered edges in a boxplot
# shared_w_capstone = data_edges[['Subject ID', 'AD', 'DEX Overlap', 'Top']].groupby(['Subject ID', 'AD', 'Top']).mean().reset_index()
# shared_w_capstone['Hash'] = shared_w_capstone.apply(lambda r: f'{"Filtered " if r["Top"] else ""}{"AD" if r["AD"] else "Non-AD"}', axis=1)
# Hard filter
shared_w_capstone = data_edges[['Subject ID', 'AD', 'DEX Overlap', 'Top']].groupby(['Subject ID', 'AD', 'Top']).mean().reset_index()
# Vectorized label instead of a row-wise `apply`
shared_w_capstone['Hash'] = np.where(shared_w_capstone['Top'], 'Filtered', 'Non-Filtered')

# Create plot
fig, ax = plt.subplots(figsize=(3, 5))

# Plot
sns.boxplot(
    shared_w_capstone, x='Hash', y='DEX Overlap',
    # order=['Non-AD', 'Filtered Non-AD', 'AD', 'Filtered AD'], palette={0.: 'blue', 1.: 'red'}, hue='AD',
    order=['Non-Filtered', 'Filtered'], color='black',
    fill=False, legend=False, ax=ax)
ax.set(xlabel=None, ylabel='Capstone Overlap')
# Rotate through tick parameters; re-setting the labels with `set_xticklabels` without
# fixing the tick positions emits the warning seen in this cell's output
ax.tick_params(axis='x', labelrotation=90)
# ax.set(xticks=[0, 1], xticklabels=['Non-AD', 'AD'])
ax.set_ylim(0, 1)
sns.despine(ax=ax)

# Save
fig.savefig(f'../plots/capstone_overlap.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True)
# Displayed as the cell output: mean overlap per label
shared_w_capstone[['Hash', 'DEX Overlap']].groupby(['Hash']).mean()


  ax.set_xticklabels(ax.get_xticklabels(), rotation=90)


Unnamed: 0_level_0,DEX Overlap
Hash,Unnamed: 1_level_1
Filtered,0.713676
Non-Filtered,0.668606


In [14]:
# Get AD genes, include duplicates
# Pool both endpoints of every significant edge, then drop synthetic names
identified_genes = np.concatenate((sig_edges['Source'], sig_edges['Target']))
identified_genes = np.array(list(filter(lambda gene: not string_is_synthetic(gene), identified_genes)))

# Get AD genes, exclude duplicates
# identified_genes = []
# for sid in tqdm(sig_edges['Subject ID'].unique()):
#     df_filt = sig_edges.loc[sig_edges['Subject ID']==sid]
#     new_genes = np.concatenate([df_filt['Source'], df_filt['Target']])
#     new_genes = np.unique(new_genes)
#     identified_genes.append(new_genes)
# identified_genes = np.concatenate(identified_genes)

# Sort genes by frequency
# `highest_counts` holds positions into `unique_genes`, most frequent first
unique_genes, unique_counts = np.unique(identified_genes, return_counts=True)
highest_counts = np.flip(np.argsort(unique_counts))

In [15]:
# Compute ROC curve
# Label: whether the gene is in the DEX list; score: how often the gene appears on significant edges
import sklearn.metrics
fpr, tpr, thresholds = sklearn.metrics.roc_curve(np.isin(unique_genes, dex_gene_list), unique_counts)

# Create plot
fig, ax = plt.subplots(figsize=(5, 5))

# Plot
ax.plot(fpr, tpr, color='black')
ax.plot([0, 1], [0, 1], ls='--', color='black')  # Diagonal reference line
ax.set(title='Consensus ROC', xlabel='False Positive Rate', ylabel='True Positive Rate')
ax.text(.95, .05, f'AUC: {sklearn.metrics.auc(fpr, tpr):.3f}', ha='right', va='bottom', transform=ax.transAxes)
# Annotate every 45th threshold from index 20 up to 10 before the end, offset to the right of the curve
# NOTE(review): the start/stop/step values look hand-tuned for this dataset — confirm if the data changes
for i in range(20, len(thresholds)-10, 45):
    ax.text(fpr[i]+.07, tpr[i], f'{thresholds[i]:.0f}', ha='left', va='center', fontsize='xx-small', backgroundcolor='white', transform=ax.transData)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Save
fig.savefig(f'../plots/capstone_recovery.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True)


In [24]:
# Get overlap by population threshold
# For every distinct occurrence count, keep the genes seen at least that often
# and record how many there are and what fraction of them is in `dex_gene_list`.

# Loop invariants, computed once instead of once per threshold:
# the number of distinct subjects and each gene's membership in the DEX list.
n_subjects = sig_edges['Subject ID'].unique().shape[0]
gene_in_dex = np.isin(unique_genes, dex_gene_list)

results = []
for thresh in tqdm(np.unique(unique_counts)):
    selected = unique_counts >= thresh  # Genes seen at least `thresh` times
    results.append({
        'Gene Prioritization Count': thresh,
        # Doesn't work with duplicates: when a gene is tallied more than once
        # per subject, counts are no longer bounded by the number of subjects
        'Population Percentage': 1 - thresh / n_subjects,
        'Gene Count': int(np.count_nonzero(selected)),
        'Overlap Percentage': gene_in_dex[selected].mean(),
    })
recovery_df = pd.DataFrame(results)

100%|█████████████████████████████████████████████████████████████████████| 236/236 [00:36<00:00,  6.48it/s]


In [37]:
# Create plot
fig, ax = plt.subplots(figsize=(8, 5))

# Plot overlap fraction against the occurrence-count threshold.
# Draw on `ax` explicitly rather than relying on the current pyplot axes.
sns.lineplot(recovery_df, x='Gene Prioritization Count', y='Overlap Percentage', color='black', ax=ax)
ax.set_xlim(left=0)
ax.set_ylim(0, 1.02)
ax.yaxis.set_major_formatter(matplotlib.ticker.PercentFormatter(1.))
sns.despine(ax=ax)

# Annotate: for each overlap level, mark the smallest count threshold whose
# overlap reaches that level
x_max = ax.get_xlim()[1]  # Fixed by set_xlim above, so safe to read once
for thresh in np.arange(.6, 1.01, .1):
    df_filt = recovery_df.loc[recovery_df['Overlap Percentage'] >= thresh]
    if df_filt.empty:
        # No threshold reaches this overlap level; nothing to annotate
        continue
    first_hit = df_filt.iloc[df_filt['Gene Prioritization Count'].argmin()]
    x_pos, y_pos = first_hit['Gene Prioritization Count'], first_hit['Overlap Percentage']
    x_offset, y_offset = x_pos+.02*x_max, y_pos-.13
    gene_count = first_hit['Gene Count']
    ax.scatter(x_pos, y_pos, marker='.', s=200, color='black')
    # vlines takes data coordinates; axvline's ymin/ymax are axes fractions
    # and would misplace the guide because the y-axis spans 0 to 1.02
    ax.vlines(x_pos, y_offset, y_pos, colors='black', alpha=.7, linestyles='--')
    ax.text(x_offset, y_offset, f'{x_pos:.0f} ({gene_count:.0f} Genes)', ha='left', va='bottom', transform=ax.transData)

# Save
fig.savefig('../plots/capstone_recovery_overlap.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True)