In [1]:
%%capture
%cd scripts

In [2]:
# Reload edited project modules (e.g. `functions`) automatically before each cell executes
%load_ext autoreload
%autoreload 2
# NOTE(review): `%matplotlib notebook` immediately replaces the inline backend chosen on the
# previous line, and the next cell calls `plt.switch_backend('cairo')`, so neither interactive
# backend appears to be in effect afterwards -- confirm whether either line is still needed
%matplotlib inline
%matplotlib notebook

# Imports

In [3]:
from collections import defaultdict
from itertools import combinations, product
import os
import pickle

import graph_tool.all as gt
import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
from scipy.stats import hypergeom, pearsonr
import seaborn as sns
from sklearn.cluster import KMeans

from functions import *


# Graph-Tool compatibility: draw everything through the cairo backend
plt.switch_backend('cairo')

# Style: seaborn theme first, then explicit rcParams overrides layered on top of it
sns.set_theme(context='talk', style='white', palette='Set2')
style_overrides = {
    # Type 42 embeds TrueType fonts so text stays editable in exported PDF/PS files
    'pdf.fonttype': 42,
    'ps.fonttype': 42,
    # Base size; the relative sizes below ('medium', 'large', ...) scale from it
    'font.size': 22,
    'axes.titlesize': 'medium',
    'axes.labelsize': 'large',
    'xtick.labelsize': 'medium',
    'ytick.labelsize': 'medium',
    'legend.fontsize': 'medium',
    'legend.title_fontsize': 'medium',
    'figure.titlesize': 'x-large',
}
plt.rcParams.update(style_overrides)

# Figure transparency: fully transparent red face color.
# For debugging, raise the alpha (e.g. to 0.3) to make figure bounds visible.
matplotlib.rcParams['figure.facecolor'] = (1., 0., 0., 0.)


In [4]:
# Integrity check
# NOTE(review): project helper (presumably from `functions`); confirm whether it raises on a
# violation or only reports, since nothing here inspects a return value
check_ct_edge_specificity()  # Check for duplicate edges with different attentions

100%|███████████████████████████████████████████████████████████████████| 1022/1022 [00:21<00:00, 46.60it/s]


# Metadata

In [5]:
# Load metadata
meta = get_meta()

# Subject preview: one-line summaries of subjects with BRAAK 6, CERAD 4 and CDR 3 that also
# have NPS mood information and a loadable graph
filtered = []
for _, row in meta.iterrows():
    try:
        # Cheap metadata checks first, so graphs are only loaded for candidate subjects
        if np.isnan(row['nps_MoodDysCurValue']):  # No NPS information available
            continue
        if not (row['BRAAK_AD'] in (6,) and row['CERAD'] in (4,) and row['CDRScore'] in (3,)):
            continue
        load_graph_by_id(row['SubID'])
    except Exception:
        # Skip subjects whose checks error out (e.g. a non-numeric NPS value makes np.isnan
        # raise, or presumably a missing graph).  Unlike a bare `except:`, this lets
        # KeyboardInterrupt/SystemExit through.
        continue
    filtered.append(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}, CERAD {row["CERAD"]}, CDR {row["CDRScore"]}, {row["Dx"]}')
filtered = np.sort(filtered)
# To inspect the candidates, uncomment:
# for line in filtered:
#     print(line)


# Attention Stack

In [6]:
fname = './attentions.pkl'
if os.path.isfile(fname):
    # Load cached data.  NOTE: `pickle.load` can execute arbitrary code -- only load a cache
    # that this notebook wrote itself.
    with open(fname, 'rb') as f:
        all_data = pickle.load(f)
    attention_stack, all_edges, columns, subject_ids = all_data['data'], all_data['edges'], all_data['heads'], all_data['subject_ids']

else:
    # Parameters
    # Scaled probably shouldn't be used, but better for visualization
    # until results are more even
    columns = get_attention_columns(scaled=False)
    subject_ids = meta['SubID'].to_numpy()

    # Load graphs (also returns the subject IDs matching the loaded graphs)
    graphs, subject_ids = load_many_graphs(subject_ids, column=columns)

    # Set indices to edges and clean
    print('Fixing indices...')
    for i in tqdm(range(len(graphs))):
        graphs[i].index = graphs[i].apply(lambda r: get_edge_string([r['TF'], r['TG']]), axis=1)
        graphs[i] = graphs[i].drop(columns=['TF', 'TG'])
        # Remove duplicates (unique labels are also required by `reindex` below)
        graphs[i] = graphs[i][~graphs[i].index.duplicated(keep='first')]

    # Get all unique edges.  An incremental set union is linear in the total number of
    # edges, whereas concatenating lists with `sum(lists, [])` is quadratic.
    print('Getting unique edges...')
    edge_union = set()
    for g in graphs:
        edge_union.update(g.index)
    all_edges = np.unique(list(edge_union))

    # Standardize index order: rows follow `all_edges`, and edges a subject lacks become all-NaN rows
    print('Standardizing indices...')
    for i in tqdm(range(len(graphs))):
        graphs[i] = graphs[i].reindex(all_edges)

    # Convert to numpy
    # attention_stack.shape = (Edge, Head, Subject)
    # attention_stack.shape = (all_edges, columns, subject_ids)
    graphs = [g.to_numpy() for g in graphs]
    attention_stack = np.stack(graphs, axis=-1)

    # Save all data
    all_data = {'data': attention_stack, 'edges': all_edges, 'heads': columns, 'subject_ids': subject_ids}
    with open(fname, 'wb') as f:
        pickle.dump(all_data, f, protocol=pickle.HIGHEST_PROTOCOL)


In [7]:
# Additional useful parameters
# Flag self-loop edges (both endpoints identical), splitting each edge string only once
edge_endpoints = [split_edge_string(s) for s in all_edges]
self_loops = np.array([ends[0] == ends[1] for ends in edge_endpoints], dtype=bool)  # dtype=bool keeps `~` valid for an empty edge list
# Remove self loops
# NOTE(review): only `all_edges` / `attention_stack` are filtered here.  `all_data`, which the
# figure cells below unpack via `**all_data`, still contains self-loop edges -- confirm whether
# those analyses are meant to exclude them too.
all_edges = all_edges[~self_loops]
attention_stack = attention_stack[~self_loops]


# Global Parameters

In [8]:
# Parameters
# Query the available heads once and pick by position (see the printed list for the order)
attention_columns = get_attention_columns()
print(f'\nAvailable attention columns: {attention_columns}')
column_ad = attention_columns[0]
column_scz = attention_columns[2]
column_data = attention_columns[4]
synthetic_nodes_of_interest = ['OPC', 'Micro', 'Oligo']



Available attention columns: ['AD_imp_1', 'AD_imp_2', 'SCZ_imp_1', 'SCZ_imp_2', 'data_imp_1', 'data_imp_2', 'data_imp_3', 'data_imp_4']


# Intra-Contrast Comparisons

In [9]:
# Figure parameters
param = {
    'subjects': ['M31969', 'M20337'],
    'columns': [column_data, column_ad, column_scz],
    'column_names': ['Data-Driven', 'AD-Prior', 'SCZ-Prior'],
    'column_groups': [get_attention_columns()[4:8], get_attention_columns()[:2], get_attention_columns()[2:4]],
    'column_group_names': ['Data Prioritization', 'AD Prioritization', 'SCZ Prioritization'],
    'ancestries': meta.groupby('Ethnicity').count()['SubID'].sort_values().index[::-1].to_list()[:3] + ['all'],
    'contrast': 'c15x',
}

# Generate palette (wraps around if there are more subjects than colors)
palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
param['palette'] = {sid: rgba_to_hex(palette[i % len(palette)]) for i, sid in enumerate(param['subjects'])}

# Fail fast, with a clear message, if a selected subject is absent from the metadata
missing_subjects = [sid for sid in param['subjects'] if not (meta['SubID'] == sid).any()]
if missing_subjects:
    raise ValueError(f'Subjects not found in metadata: {missing_subjects}')


In [10]:
# Subplot layout (doesn't work well with constrained layout)
# NOTE: This cannot be used, as constrained layout has glitches
# (see https://github.com/matplotlib/matplotlib/issues/23290)
# with uneven mosaics
# fig, axs = get_mosaic(shape, figsize=(int((3/2) * shape_array.shape[1]), int((3/2) * shape_array.shape[0])), constrained_layout=False)

# Subfigure layout (longer)
# NOTE: Constrained layout will fail for all
# subplots if a single one is not able to scale.
# Also, sometimes leaving a subfigure blank will
# cause it to fail, especially if on an edge.
# It is VERY finicky.
# SOLUTION: Save again using `fig.savefig(...)`
# and it will run without warning.  Then, you
# can visually inspect for scaling issues.
# fig, axs = create_subfigure_mosaic(shape_array)
# fig.set_constrained_layout_pads(w_pad=0, h_pad=0, wspace=.4, hspace=.4)  # *_pad is pad for figs (including subfigs), *_space is pad between subplots


## Edge Prioritization and Cross-Ancestry Enrichment - Figure 5cd

In [13]:
shape = """
    NNNNNNNNNNN
    NNNNNNNNNNN
    NNNNNNNNNNN
    NNNNNNNNNNN
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
    RRRRRRRRRRR
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

axs_lab = (len(param['ancestries']) - 1) * ['None'] + ['N']
# axs_lab = ['K', 'L', 'M', 'N']
print(f'\nEdge Discovery Enrichment ({", ".join(axs_lab)})')
for ancestry, ax in zip(param['ancestries'], [axs[lab] if lab in axs else None for lab in axs_lab]):
    # Filter to ancestry
    anc_data = all_data.copy()
    if ancestry != 'all':
        sub_ids = meta.loc[meta['Ethnicity'] == ancestry, 'SubID'].to_list()
        mask = [sid in sub_ids for sid in anc_data['subject_ids']]
        anc_data['data'] = anc_data['data'][:, :, mask]
        anc_data['subject_ids'] = np.array(anc_data['subject_ids'])[mask]

    # Run
    temp = plot_edge_discovery_enrichment(
        **anc_data,
        column=param['columns'][0],
        range_colors=[rgb_to_float(hex_to_rgb('#7aa457')), rgb_to_float(hex_to_rgb('#a46cb7')), rgb_to_float(hex_to_rgb('#cb6a49'))],
        ax=ax,
        postfix=f'{ancestry}_{param["columns"][0]}',
        gene_max_num=300,
        threshold=95,
        clamp_min=4,
        skip_plot=(ax is None),
        verbose=True)
    if ax is not None:
        ax.set_xlabel(f'High-Scoring Edges ({param["column_names"][0]})')
        ylabel = 'Frequency'
        if ancestry != 'all': ylabel += f' ({ancestry})'
        ax.set_ylabel(ylabel)
    # MANUAL PROCESSING
    # Run the output '../plots/genes_<column>.csv' from above on Metascape as multiple gene list and perform
    # enrichment.  From the all-in-one ZIP file, save the file from Enrichment_GO/GO_membership.csv as '../plots/go_<column>.csv'
    # and rerun.

axs_lab = ['R']
print(f'\nAncestry Enrichment Comparison ({", ".join(axs_lab)})')
postfixes = [f'{ancestry}_{param["columns"][0]}' for ancestry in param['ancestries']]
enrichments = plot_cross_enrichment(postfixes, names=param['ancestries'], ax=axs[axs_lab[0]], excluded_subgroups=['all'])

# Place labels
offset = plot_labels(axs, shape=shape)

# Save figure
print('\nSaving Figure...')
fig.savefig(f'../plots/figure_5_main.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True, backend='cairo')



Edge Discovery Enrichment (None, None, None, N)
Prioritization Ranges [[14.0, 22.0], [22.0, 49.0], [49.0, 680.0]]
Filtered 165669 edges of 204507 total from histogram
Prioritization Ranges [[9.0, 13.0], [13.0, 25.0], [25.0, 104.0]]
Filtered 52826 edges of 61429 total from histogram
Prioritization Ranges [[8.0, 11.0], [11.0, 21.0], [21.0, 88.0]]
Filtered 46979 edges of 53899 total from histogram
Prioritization Ranges [[15.0, 23.0], [23.0, 51.0], [51.0, 875.0]]
Filtered 195859 edges of 242486 total from histogram

Ancestry Enrichment Comparison (R)
Index(['AFR_data_imp_1', 'AMR_data_imp_1', 'EUR_data_imp_1', 'all_data_imp_1'], dtype='object', name='Ancestry')
Index(['EUR', 'AFR', 'AMR', 'all'], dtype='object')
Index(['EUR', 'AFR', 'AMR'], dtype='object')

Saving Figure...


## Distributions Across Heads and Ancestries - Supplementary Figure 12

In [14]:
shape = """
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    KKKKKKKKKKKLLLLLLLLLLLMMMMMMMMMMM
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    NNNNNNNNNNNOOOOOOOOOOOPPPPPPPPPPP
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    QQQQQQQQQQQRRRRRRRRRRRSSSSSSSSSSS
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
    TTTTTTTTTTTUUUUUUUUUUUVVVVVVVVVVV
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

# Plot all panels
axs_lab = ['K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V']
print(f'\nEdge Discovery Enrichment ({", ".join(axs_lab)})')
for (ancestry, column_idx), ax in zip(product(param['ancestries'], range(len(param['columns']))), [axs[lab] if lab in axs else None for lab in axs_lab]):
    # Filter to ancestry
    anc_data = all_data.copy()
    if ancestry != 'all':
        sub_ids = meta.loc[meta['Ethnicity'] == ancestry, 'SubID'].to_list()
        mask = [sid in sub_ids for sid in anc_data['subject_ids']]
        anc_data['data'] = anc_data['data'][:, :, mask]
        anc_data['subject_ids'] = np.array(anc_data['subject_ids'])[mask]

    # Run
    temp = plot_edge_discovery_enrichment(
        **anc_data,
        column=param['columns'][column_idx],
        range_colors=[rgb_to_float(hex_to_rgb('#7aa457')), rgb_to_float(hex_to_rgb('#a46cb7')), rgb_to_float(hex_to_rgb('#cb6a49'))],
        ax=ax,
        postfix=f'{ancestry}_{param["columns"][column_idx]}',
        gene_max_num=300,
        threshold=95,
        skip_plot=(ax is None))
    if ax is not None:
        ax.set_xlabel(f'High-Scoring Edges ({param["column_names"][column_idx]})')
        ylabel = 'Frequency'
        if ancestry != 'all': ylabel += f' ({ancestry})'
        ax.set_ylabel(ylabel)
    # MANUAL PROCESSING
    # Run the output '../plots/genes_<column>.csv' from above on Metascape as multiple gene list and perform
    # enrichment.  From the all-in-one ZIP file, save the file from Enrichment_GO/GO_membership.csv as '../plots/go_<column>.csv'
    # and rerun.

# Place labels
offset = plot_labels(axs, shape=shape)

# Save figure
print('\nSaving Figure...')
fig.savefig(f'../plots/figure_5_supplement.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True, backend='cairo')



Edge Discovery Enrichment (K, L, M, N, O, P, Q, R, S, T, U, V)

Saving Figure...


## PRS Analyses - Figure 6a

In [15]:
# PRS correlation panels: one per trait (AD -> S, SCZ -> U); per-trait results are cached to CSV
shape = """
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
    SSSSSSSSSSSSSSSUUUUUUUUUUUUUUU
"""
fig, axs = create_subfigure_mosaic(shape_array_from_shape(shape))

# Plot all panels
axs_lab = ['S', 'U']
print(f'\nPRS Analysis ({", ".join(axs_lab)})')
# Covariates are identical for both traits, so load them once outside the loop
covariates = get_genotype_meta()[['SubID', 'imp_sex_score'] + [f'imp_anc_PC{i}' for i in range(1, 7)] + [f'imp_anc_{anc}' for anc in ('AFR', 'AMR', 'EAS', 'EUR')]]
# Takes around an hour for each loop with no subsampling (on first run)
for fname, head_prefix, ylabel, prs_col, ax_idx in zip(
    ('ad_prs_df.csv', 'scz_prs_df.csv'),
    ('_'.join(column_ad.split('_')[:-1]), '_'.join(column_scz.split('_')[:-1])),
    ('AD Importance Score', 'SCZ Importance Score'),
    ('prs_scaled_AD_Bellenguez', 'prs_scaled_SCZ.3.5_MVP'),
    axs_lab
):
    cached = os.path.isfile(fname)
    df = pd.read_csv(fname, index_col=0) if cached else None
    df, prs_df, axs[ax_idx] = plot_prs_correlation(
        meta, **all_data, ax=axs[ax_idx],
        df=df, num_targets=5, ylabel=ylabel, max_scale=False,
        head_prefix=head_prefix, prs_col=prs_col,
        covariates=covariates, subsample=1)
    if not cached:
        df.to_csv(fname)

# Place labels
offset = plot_labels(axs, shape=shape)

# Save figure
print('\nSaving Figure...')
fig.savefig('../plots/figure_6_prs.pdf', bbox_inches='tight', pad_inches=1, format='pdf', transparent=True, backend='cairo')


PRS Analysis (S, U)

Saving Figure...


# Revision Panels

## Revision Gene Importance Enrichment

In [18]:
# Load data
importance_scores = get_importance_scores()
importance_scores = importance_scores.rename(columns={'ad_imp_score': 'AD', 'scz_imp_score': 'SCZ', 'data_imp_score': 'Data'})
# Weighted mean of the three heads (weights 2/2/4 -- presumably the number of AD/SCZ/data
# attention columns, cf. `get_attention_columns()`)
importance_scores['All'] = (importance_scores['AD'] * 2 + importance_scores['SCZ'] * 2 + importance_scores['Data'] * 4) / 8

# Get meta: attach each sample's Braak stage
importance_scores['BRAAK_AD'] = meta.set_index('SubID').loc[importance_scores['sample'], 'BRAAK_AD'].to_numpy()

# Calculate Spearman correlations between each node's score and Braak stage.
# A single groupby pass replaces a full boolean scan per gene (O(genes x rows)), and results
# are collected as records instead of growing the DataFrame one row at a time.
heads = ['All']  # ['AD', 'SCZ', 'Data']
records = []
grouped = importance_scores.groupby('node', sort=False)  # sort=False keeps first-appearance order, like `.unique()`
for g, filtered_df in tqdm(grouped, total=grouped.ngroups):
    geno = filtered_df['BRAAK_AD'].to_numpy()
    for h in heads:
        col = filtered_df[h].to_numpy()
        mask = ~np.isnan(col) & ~np.isnan(geno)
        # Need >2 paired samples and non-constant data for a defined rank correlation
        if mask.sum() > 2 and col[mask].var() != 0 and geno[mask].var() != 0:
            corr, sig = scipy.stats.spearmanr(col[mask], geno[mask])
            records.append({'Gene': g, 'Head': h, 'Correlation': corr, 'Significance': sig, 'Samples': mask.sum()})
df = pd.DataFrame(records, columns=['Gene', 'Head', 'Correlation', 'Significance', 'Samples'])

# FDR correction (Benjamini-Hochberg); guarded so an empty result does not error
if df.empty:
    df['Adjusted Significance'] = np.nan
else:
    df['Adjusted Significance'] = scipy.stats.false_discovery_control(df['Significance'].clip(0, 1), method='bh')


100%|███████████████████████████████████████████████████████████████████| 6534/6534 [22:15<00:00,  4.89it/s]


In [71]:
# Parameters
min_samples = 5
significance_threshold = 5e-2
top_genes = int(1e3)

# Histogram of raw and FDR-adjusted significance, one subplot per head.
# bins = 2 / threshold, i.e. bins roughly threshold/2 wide when the p-values span [0, 1].
axs = df.plot.hist(column=['Significance', 'Adjusted Significance'], by='Head', bins=int(2/significance_threshold), figsize=(12, 5*df['Head'].unique().shape[0]))
axs = np.ravel(axs)  # works whether pandas returns one Axes or an array of them
for ax in axs:
    ax.set_yscale('log')
    ax.axvline(x=significance_threshold, ls='--', color='red')
# Save the figure these axes belong to (not pyplot's "current" figure), then release it
fig = axs[0].figure
fig.savefig('../plots/rev_1_histogram.pdf', bbox_inches='tight', transparent=True)
plt.close(fig)

In [72]:
# ## Separate head analysis
# # Gene lists
# sig_genes = pd.DataFrame()
# for h in heads:
#     # Filter dataframe
#     df_filtered = df
#     df_filtered = df_filtered.loc[df_filtered['Head'] == h]  # Head
#     df_filtered = df_filtered.loc[df_filtered['Samples'] > min_samples]  # Samples
#     df_filtered = df_filtered.loc[df_filtered['Adjusted Significance'] < significance_threshold]  # Significance
#     genes = np.unique([g.split(':')[1] for g in df_filtered[['Gene']].to_numpy().flatten() if g.split(':')[0] in ('TF', 'TG')])

#     # Save genes
#     print(f'Significant {h} genes: {genes.shape[0]}')
#     sig_genes = pd.concat([sig_genes, pd.DataFrame({h: genes})], axis=1)

# # Add background and save
# genes = np.unique([g.split(':')[1] for g in importance_scores[['node']].to_numpy().flatten() if g.split(':')[0] in ('TF', 'TG')])
# genes = np.array([g for g in genes if not string_is_synthetic(g)])
# sig_genes = pd.concat([sig_genes, pd.DataFrame({'_BACKGROUND': genes})], axis=1)
# pd.DataFrame(sig_genes).to_csv('../plots/rev_1_genes.csv', index=False)
# # np.savetxt(f'../plots/rev_1_genes_{h.lower()}.txt', genes, fmt='%s')


In [73]:
## Aggregate analysis
# Top genes whose importance correlates positively / negatively with Braak stage (one column
# each), plus the full gene background, written for downstream enrichment.

def _node_gene(node):
    """Return the gene of a 'TF:<gene>' / 'TG:<gene>' node, or ``pd.NA`` for other nodes (e.g. cell types)."""
    parts = node.split(':')
    return parts[1] if parts[0] in ('TF', 'TG') else pd.NA

gene_frames = []
for name, factor in zip(('Positive', 'Negative'), (1, -1)):
    df_filtered = (
        df.loc[
            (df['Samples'] >= min_samples)  # Samples
            & (df['Adjusted Significance'] <= significance_threshold)  # Significance
            & (factor * df['Correlation'] > 0)  # Correlation parity
        ]
        .sort_values(by='Correlation', ascending=factor < 0)  # Strongest correlations first
        .assign(Gene=lambda d: d['Gene'].map(_node_gene))  # Remove celltypes
        .dropna(subset=['Gene'])
        .drop_duplicates(subset='Gene', keep='first')  # Gene present as both TF and TG: keep the higher score
        .iloc[:top_genes]
    )
    genes = df_filtered['Gene'].to_numpy()

    # Save genes
    print(f'Post-filter {name} genes: {genes.shape[0]}')  # Post-filter
    gene_frames.append(pd.DataFrame({name: genes}))

# Add background and save (columns of unequal length are NaN-padded by the column-wise concat)
genes = np.unique([gene for gene in map(_node_gene, importance_scores['node']) if gene is not pd.NA])
gene_frames.append(pd.DataFrame({'_BACKGROUND': genes}))
sig_genes = pd.concat(gene_frames, axis=1)
sig_genes.to_csv('../plots/rev_1_genes.csv', index=False)


Post-filter Positive genes: 607
Post-filter Negative genes: 316


## Revision Ancestry Enrichment

In [11]:
# Data-driven head and the individual ancestries (everything except the trailing 'all')
dd_col = param['columns'][0]
ancs = param['ancestries'][:-1]

# Read each ancestry's high-percentile gene table.  The gene list comes from the second-to-last
# column (presumably the one just before '_BACKGROUND' -- confirm against the CSV layout).
fnames = [f'../plots/genes_{anc}_{dd_col}.csv' for anc in ancs]
dfs = list(map(pd.read_csv, fnames))
data = [sorted(table.iloc[:, -2].dropna().to_list()) for table in dfs]
anc_sizes = list(map(len, data))

# Find intersections
def find_intersections(data, names):
    """Partition the union of several collections by exact group membership.

    Parameters
    ----------
    data : sequence of iterables
        One collection of hashable, mutually comparable items per group (e.g. gene lists).
        The inputs are NOT modified.  They need not be sorted, and an item repeated within
        one collection is counted once.
    names : sequence of str
        Name of each group, in the same order as ``data``.

    Returns
    -------
    collections.defaultdict(list)
        Maps ``'-'.join(sorted(names of the groups containing an item))`` to the sorted list
        of items with exactly that membership.  Keys appear in order of their smallest item.
        Looking up an absent pattern yields ``[]``.
    """
    member_sets = [set(d) for d in data]
    groups = defaultdict(list)
    # Iterating the sorted union reproduces a merge of the sorted inputs without consuming
    # them, and never compares an item against a NaN placeholder for an exhausted list
    for item in sorted(set().union(*member_sets)):
        key = '-'.join(sorted(name for name, members in zip(names, member_sets) if item in members))
        groups[key].append(item)
    return groups
groups = find_intersections(data, ancs)

# Report how many genes fall into each ancestry-membership pattern
print('Intersection counts')
for k, v in groups.items():
    print(f'{k}: {len(v)}')

# Write one gene list per pattern, each paired with the shared background, for Metascape
new_fnames = [f'../plots/genes_{k}_{dd_col}_uniq.csv' for k in groups]
gene_col = dfs[0].columns[-2]
background = pd.DataFrame({'_BACKGROUND': dfs[0]['_BACKGROUND']})
for new_fname, k in zip(new_fnames, groups):
    unique_genes = pd.DataFrame({gene_col: groups[k]})
    pd.concat([unique_genes, background], axis=1).to_csv(new_fname, index=False)

Intersection counts
AFR-AMR: 166
AMR: 62
EUR: 203
AFR: 59
AFR-EUR: 25
AFR-AMR-EUR: 50
AMR-EUR: 22


In [12]:
# Simulate overlap: draw random gene sets with the observed ancestry sizes and record how many
# genes land in each ancestry-membership pattern, giving a null distribution per pattern
np.random.seed(42)
num_gene_sets = [
    ('All Background', dfs[0]['_BACKGROUND'].shape[0]),  # Assumes background is the same between all
    ('Highly Variant', np.unique([g for grp in groups.values() for g in grp])),  # Only variant genes from each ancestry
]
group_sizes = anc_sizes
num_iterations = int(1e4)

# Every possible membership pattern, keyed exactly as `find_intersections` keys them
# ('-'-joined sorted names).  Recording all of them -- zeros included -- keeps every null
# distribution complete.  Recording only the patterns that occurred in a draw would silently
# drop every empty-intersection draw and bias the null upward.
all_group_keys = ['-'.join(sorted(combo)) for r in range(1, len(ancs) + 1) for combo in combinations(ancs, r)]

sim_group_counts = defaultdict(lambda: defaultdict(lambda: []))
for background_name, gene_pool in num_gene_sets:
    # `gene_pool` is either a gene count (sample from range(n)) or an array of gene names
    for _ in tqdm(range(num_iterations), total=num_iterations, desc=f'{background_name}'):
        # Sample gene lists: one draw without replacement per ancestry, in `group_sizes` order
        sim_data = [np.sort(np.random.choice(gene_pool, size, replace=False)).tolist() for size in group_sizes]

        # Find intersections
        sim_groups = find_intersections(sim_data, ancs)

        # Record (use .get so the lookup does not insert keys into `sim_groups`)
        for k in all_group_keys:
            sim_group_counts[background_name][k].append(len(sim_groups.get(k, [])))

All Background: 100%|████████████████████████████████████████████████| 10000/10000 [01:18<00:00, 128.00it/s]
Highly Variant: 100%|████████████████████████████████████████████████| 10000/10000 [00:56<00:00, 177.05it/s]


In [13]:
# Plot each sim group's null distribution with the observed count (red line)
group_order = ['AFR', 'AMR', 'EUR', 'AFR-AMR', 'AFR-EUR', 'AMR-EUR', 'AFR-AMR-EUR']
sig_thresholds = np.array([5e-2, 1e-2, 1e-3])
min_pval = 1 / num_iterations  # smallest p-value resolvable with `num_iterations` draws
fig, axs = plt.subplots(len(num_gene_sets), len(group_order), figsize=(len(group_order)*8, len(num_gene_sets)*4), squeeze=False)
for i, k1 in enumerate(sim_group_counts):
    for j, k2 in enumerate(group_order):
        # Draws in which this intersection came out empty may not have been recorded;
        # count every missing draw as a zero so the null has `num_iterations` entries
        sim_count = np.asarray(sim_group_counts[k1].get(k2, []), dtype=int)
        sim_count = np.concatenate([sim_count, np.zeros(max(num_iterations - sim_count.size, 0), dtype=int)])
        obs_count = len(groups.get(k2, []))  # .get: do not insert missing keys into `groups`

        # Plot
        ax = axs[i][j]
        sns.histplot(data=pd.DataFrame({'count': sim_count}), x='count', kde=True, color='gray', ax=ax)

        # Significance: the smaller empirical tail (not doubled), counting ties in both tails,
        # since counts are discrete and ties are common
        ax.axvline(x=obs_count, ls='-', color='red')
        pval = min((sim_count <= obs_count).mean(), (sim_count >= obs_count).mean())
        clamped = pval < min_pval
        if clamped: pval = min_pval
        sig_tail = '*' * (pval < sig_thresholds).sum()
        pval_string = f'p{"<" if clamped else "="}{pval:.2e}{sig_tail}'
        ax.text(.5, .95, pval_string, ha='center', va='top', transform=ax.transAxes)

        # Labels
        ax.set(xlabel=None, ylabel=None)
        if i == 0: ax.set_title(k2)
        if j == 0: ax.set_ylabel(k1)
fig.savefig('../plots/rev_1_ancestry_significance.pdf', bbox_inches='tight', transparent=True)
plt.close(fig)

In [14]:
# Plot after MANUAL PROCESSING (Metascape enrichment of the `genes_<pattern>_<column>_uniq.csv` lists)
# TODO: Maybe also plot intersecting genes?
min_group_size = 100  # only patterns with more than this many genes are compared
fig, ax = plt.subplots(1, 1, figsize=(int((3/2)*11), int((3/2)*7)))
# 'genes_<pattern>_<column>_uniq.csv' -> '<pattern>_<column>_uniq' (drop directory, 'genes_' prefix and extension)
file_prefixes = [
    '_'.join(os.path.splitext(os.path.basename(fname))[0].split('_')[1:])
    for fname, k in zip(new_fnames, groups) if len(groups[k]) > min_group_size
]
group_names = [prefix.split('_')[0] for prefix in file_prefixes]
enrichments = plot_cross_enrichment(file_prefixes, names=group_names, num_terms=30, ax=ax)
fig.savefig('../plots/rev_1_ancestry.pdf', bbox_inches='tight', transparent=True)
plt.close(fig)