# Imports

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%matplotlib notebook


In [2]:
from itertools import product
import os

import graph_tool.all as gt
import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import hypergeom, pearsonr
import seaborn as sns
from sklearn.cluster import KMeans

from functions import *


# Graph-Tool compatibility
plt.switch_backend('cairo')
# Style
sns.set_theme(context='talk', style='white', palette='Set2')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


# Meta and Parameters

In [3]:
# Load metadata
meta = get_meta()

# Subject preview
filtered = []
for i, row in meta.iterrows():
    try:
        load_graph_by_id(row['SubID'])
        assert not np.isnan(row['nps_MoodDysCurValue'])  # Has NPS information available
        # assert row['BRAAK_AD'] in (3, 4, 5)
    except:
        continue
    filtered.append(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}')
filtered = np.sort(filtered)
for i in range(len(filtered)):
    print(filtered[i])

# Parameters
print(f'\nAvailable attention columns: {get_attention_columns()}')
column_ad = get_attention_columns()[0]
column_scz = get_attention_columns()[2]
column_data = get_attention_columns()[4]
synthetic_nodes_of_interest = ['OPC', 'Micro', 'Oligo']


  return pd.read_csv(META)


M10031 AMR Male, 71.0, BRAAK 0.0
M10282 EUR Male, 73.0, BRAAK 6.0
M10730 EUR Female, 90.0, BRAAK 5.0
M107983 AMR Female, 65.0, BRAAK 1.0
M10874 EUR Female, 88.0, BRAAK 3.0
M10886 AMR Male, 46.0, BRAAK 0.0
M1119 EUR Female, 91.0, BRAAK 6.0
M11371 EUR Male, 95.0, BRAAK 2.0
M1140 EUR Female, 96.0, BRAAK 1.0
M11588 EUR Female, 62.0, BRAAK 6.0
M11589 AMR Female, 63.0, BRAAK 2.0
M11716 EUR Female, 89.0, BRAAK 5.0
M1176 EUR Female, 85.0, BRAAK 0.0
M118449 EUR Male, 80.0, BRAAK 6.0
M11938 EUR Female, 90.0, BRAAK 3.0
M1198 EUR Female, 87.0, BRAAK 6.0
M12047 EUR Female, 80.0, BRAAK 5.0
M12249 EUR Male, 51.0, BRAAK 0.0
M12326 EUR Female, 93.0, BRAAK 6.0
M12479 AMR Female, 83.0, BRAAK 6.0
M12514 EUR Female, 88.0, BRAAK 4.0
M12614 AFR Female, 75.0, BRAAK 0.0
M12792 EUR Female, 91.0, BRAAK 5.0
M12876 AMR Female, 76.0, BRAAK 2.0
M13326 EUR Male, 84.0, BRAAK 4.0
M133696 EUR Female, 52.0, BRAAK nan
M13458 EUR Male, 91.0, BRAAK 6.0
M13640 EUR Female, 82.0, BRAAK 3.0
M13670 AFR Male, 83.0, BRAAK 6.0
M139

# Plots

## Attention Stack

In [4]:
# # Parameters
# # Scaled probably shouldn't be used, but better for visualization
# # until results are more even
# columns = get_attention_columns(scaled=False)
# subject_ids = meta['SubID'].to_numpy()

# # Load graphs
# graphs, subject_ids = load_many_graphs(subject_ids, column=columns)
# # graphs = [compute_graph(g) for g in graphs]

# # # Get attentions
# # df = {}
# # for column in get_attention_columns():
# #     attention, _ = compute_edge_summary(graphs, subject_ids=subject_ids)
# #     attention = attention.set_index('Edge')
# #     df[column] = attention.var(axis=1)


# # Set indices to edges and clean
# print('Fixing indices...')
# for i in tqdm(range(len(graphs))):
#     graphs[i].index = graphs[i].apply(lambda r: get_edge_string([r['TF'], r['TG']]), axis=1)
#     graphs[i] = graphs[i].drop(columns=['TF', 'TG'])
#     # Remove duplicates
#     graphs[i] = graphs[i][~graphs[i].index.duplicated(keep='first')]

# # Get all unique edges
# print('Getting unique edges...')
# all_edges = np.unique(sum([list(g.index) for g in graphs], []))


# # Standardize index order
# print('Standardizing indices...')
# for i in tqdm(range(len(graphs))):
#     # Add missing indices and order based on `all_edges`
#     # to_add = [edge for edge in all_edges if edge not in list(graphs[i].index)]  # SLOW
#     to_add = list(set(all_edges) - set(graphs[i].index))

#     # Empty rows
#     new_rows = pd.DataFrame(
#         [[np.nan]*len(graphs[i].columns)]*len(to_add),
#         columns=graphs[i].columns,
#     ).set_index(pd.Series(to_add))
#     # Native concat
#     graphs[i] = pd.concat([graphs[i], new_rows]).loc[all_edges]

# # Convert to numpy
# graphs = [g.to_numpy() for g in graphs]
# attention_stack = np.stack(graphs, axis=-1)
# # attention_stack.shape = (Edge, Head, Subject)
# # attention_stack.shape = (all_edges, columns, subject_ids)


In [5]:
### Save all data
# all_data = {'data': attention_stack, 'edges': all_edges, 'heads': columns, 'subject_ids': subject_ids}
# # np.savez('attentions.npz', **all_data)
# with open('attentions.pkl', 'wb') as f:
#     pickle.dump(
#         all_data,
#         f,
#         protocol=pickle.HIGHEST_PROTOCOL,
#     )

### Load data
with open('attentions.pkl', 'rb') as f:
    all_data = pickle.load(f)
attention_stack, all_edges, columns, subject_ids = all_data['data'], all_data['edges'], all_data['heads'], all_data['subject_ids']

### Save variance filtered by contrast
# contrast = 'c15x'
# for group in list(get_contrast(contrast).keys()) + [None]:
#     # group = 'Control'  # Either None or group name
#     if group is None:
#         # Population
#         contrast_subjects = sum([v for k, v in get_contrast(contrast).items()], [])
#     else:
#         # Group
#         contrast_subjects = get_contrast(contrast)[group]

#     # Modify stack to include only contrast
#     df = np.var(np.nan_to_num(attention_stack[:, :, [s in contrast_subjects for s in subject_ids]]), axis=2)
#     df = pd.DataFrame(df, index=all_edges, columns=columns)

#     # Save
#     df.to_csv(
#         f'../plots/{contrast}_variation.csv'
#         if group is None else
#         f'../plots/{contrast}_{group}_variation.csv')


In [6]:
# Additional useful parameters
self_loops = [split_edge_string(s)[0] == split_edge_string(s)[1] for s in all_edges]
self_loops = np.array(self_loops)
# Remove self loops
all_edges = all_edges[~self_loops]
attention_stack = attention_stack[~self_loops]


## Individual Comparisons (Figure 3)

In [7]:
individual_comparisons = [
    # M19050 Hispanic Female, 74.0, BRAAK 5.0
    # M59593 Hispanic Female, 76.0, BRAAK 5.0
    # M72079 Black Female, 64.0, BRAAK 6.0
    # M41496 Black Female, 76.0, BRAAK 4.0
    # M11589 Black Female, 63.0, BRAAK 2.0
    # M73342 Black Female, 62.0, BRAAK 0.0
    # (subject_id_1, subject_id_2, column)
    # for subject_id_1, subject_id_2, column in individual_comparisons:
    ('M19050', 'M59593', column_ad),  # AD - AD
    ('M19050', 'M59593', column_data),  # AD - AD
    # ('M72079', 'M41496', column_ad),  # AD - High BRAAK
    # ('M72079', 'M11589', column_ad),  # AD - Low BRAAK
    # ('M72079', 'M73342', column_ad),  # AD - CTRL
]
palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
individual_colors = {
    sid: rgba_to_hex(palette[i]) for i, sid in enumerate(
        sum([list(comparison[:2]) for comparison in individual_comparisons], []))
}

# Verify all are available
for subject_id_1, subject_id_2, column in individual_comparisons:
    for sid in [subject_id_1, subject_id_2]:
        load_graph_by_id(sid)


### 3A Mini Plots

In [8]:
for i, (subject_id_1, subject_id_2, column) in enumerate(individual_comparisons):
    print(' - '.join((subject_id_1, subject_id_2, column)))

    # Assemble
    sids = [subject_id_1, subject_id_2]
    gs = [compute_graph(load_graph_by_id(sid, column=column)) for sid in sids]

    # Filter
    gs = [
        filter_to_synthetic_vertices(g.copy(), vertex_ids=synthetic_nodes_of_interest)
        for g in gs
    ]

    # Recalculate
    gs = [assign_vertex_properties(g) for g in gs]

    # Plot
    fig, axs = get_mosaic([list(range(2))], scale=9)
    plot_graph_comparison(gs, axs=axs, subject_ids=sids)
    fig.savefig(f'../plots/individual_mini_{"-".join(sids)}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    print()


M19050 - M59593 - att_D_AD_1
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 167/167 [00:00<00:00, 287895.10it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 251299.24it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 83/83 [00:00<00:00, 231313.78it/s]


M19050 - M59593 - att_D_no_prior_0





Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 167/167 [00:00<00:00, 275117.35it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 237894.35it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 83/83 [00:00<00:00, 232549.92it/s]







### 3B Attention Comparisons

In [9]:
for subject_id_1, subject_id_2, _ in individual_comparisons:
    for column in get_attention_columns():
        print(' - '.join((subject_id_1, subject_id_2, column)))

        # Assemble
        sample_ids = [subject_id_1, subject_id_2]
        graphs = [compute_graph(load_graph_by_id(sid, column=column)) for sid in sample_ids]

        # Get graph
        g = concatenate_graphs(*graphs, threshold=False)
        g = get_intersection(g)
        g = cull_isolated_leaves(g)

        fig, axs = get_mosaic([list(range(1))], scale=6)
        df = plot_individual_edge_comparison(g, sample_ids, ax=axs[0])
        plt.tight_layout()
        fig.savefig(f'../plots/individual_edge_comparison_{"-".join((subject_id_1, subject_id_2))}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
        print()


M19050 - M59593 - att_D_AD_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 195180.19it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_AD_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 201798.80it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_SCZ_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 195098.82it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_SCZ_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 220465.10it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_0
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 210028.96it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 201338.62it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 212277.50it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_3
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 211264.81it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_AD_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 241582.09it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_AD_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 210785.82it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_SCZ_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 237230.22it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_SCZ_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 206491.66it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_0
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 246606.74it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 249890.27it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 232765.76it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_3
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 207474.33it/s]





  plt.tight_layout()


### 3C Module Analysis

In [10]:
for subject_id_1, subject_id_2, column in individual_comparisons:
    # Get graphs
    g1 = compute_graph(load_graph_by_id(subject_id_1, column=column))
    g2 = compute_graph(load_graph_by_id(subject_id_2, column=column))

    # Compute module scores
    def get_module_scores(g):
        association = []
        name = []
        score = []
        for v in g.vertices():
            # Escape if not TF
            if 'tf' not in g.vp.node_type[v]: continue
            # Get association
            association_list = None
            for e in v.in_edges():
                v_source = e.source()
                # If synthetic, record
                if 'celltype' == g.vp.node_type[v_source]:
                    if association_list is None: association_list = [g.vp.ids[v_source]]
                    else: association_list += [g.vp.ids[v_source]]

            # Get scores
            for e in v.out_edges():
                v_target = e.target()
                # Escape if not TG
                if 'tg' not in g.vp.node_type[v_target]: continue
                # Record weights
                for assoc in association_list:
                    association.append(assoc)
                    name.append(g.vp.ids[v])
                    score.append(g.ep.coef[e])

        return pd.DataFrame({
            'Cell Type': association,
            'TF': name,
            'Module Score': score,
        }).groupby(['Cell Type', 'TF']).sum().reset_index()
    # Get module scores
    module_scores_1 = get_module_scores(g1)
    module_scores_2 = get_module_scores(g2)
    # Make blanks
    zeros_1 = module_scores_1.copy()
    zeros_1['Module Score'] = 0
    zeros_2 = module_scores_2.copy()
    zeros_2['Module Score'] = 0
    # Append for consistency
    module_scores_1 = pd.concat((module_scores_1, zeros_2)).groupby(['Cell Type', 'TF']).max().reset_index()
    module_scores_2 = pd.concat((module_scores_2, zeros_1)).groupby(['Cell Type', 'TF']).max().reset_index()
    # Concatenate subjects
    # NOTE: Only matters that they're in the order sub_1 -> sub_2
    # and all present for the `.diff()` groupby, no need to label
    # module_scores_1['Subject'] = subject_id_1
    # module_scores_2['Subject'] = subject_id_2
    module_scores = pd.concat((module_scores_1, module_scores_2))
    module_scores['Module Score'] = module_scores.groupby(['Cell Type', 'TF']).diff(periods=-1)  # First minus second
    module_scores = module_scores.loc[~module_scores['Module Score'].isna()]

    # Plot
    fig, axs = get_mosaic([[0]*2], scale=6)

    def plot_module_scores(module_scores, ax=None):
        # Pivot
        df = module_scores.pivot(index='Cell Type', columns='TF', values='Module Score')
        # Roughly sort by cell type
        df = df.T
        for c in df.columns:
            df = df.sort_values(c)
        df = df.T  # .iloc[::-1]
        # Plot
        pl = sns.heatmap(
            data=df,
            vmin=np.abs(df.fillna(0).to_numpy()).max(),
            vmax=-np.abs(df.fillna(0).to_numpy()).max(),
            cmap='icefire_r',
            cbar_kws={'label': f'Module Score ({subject_id_1}-{subject_id_2})'},
            ax=ax)
        return pl
    p1 = plot_module_scores(module_scores, ax=axs[0])

    # Inset axis
    from mpl_toolkits.axes_grid1.inset_locator import inset_axes
    axins = inset_axes(
        axs[0],
        width='25%', height='25%',
        loc=4,
        bbox_to_anchor=(0, .15, 1, 1), bbox_transform=axs[0].transAxes)
    sns.histplot(data=module_scores, x='Module Score', kde=True, ax=axins)
    plt.ylabel(None)

    # Format
    p1.set(title=column)

    # Save
    fig.savefig(f'../plots/individual_module_analysis_{subject_id_1}_{subject_id_2}_{column}.pdf', format='pdf', transparent=True, backend='cairo')


### 3D Module Discovery Barplot

In [11]:
for subject_id_1, subject_id_2, column in individual_comparisons:
    # NOTE: Column doesn't matter here with the current snipping method
    g1 = compute_graph(load_graph_by_id(subject_id_1, column=column))
    g2 = compute_graph(load_graph_by_id(subject_id_2, column=column))

    # Get unique TFs
    df = compare_graphs_enrichment(
        g1, g2,
        sid_1=subject_id_1, sid_2=subject_id_2,
        nodes=list(set(get_all_synthetic_ids(g1)).union(set(get_all_synthetic_ids(g2)))),  #list(set(get_all_synthetic_ids(g1)).intersection(set(get_all_synthetic_ids(g2)))),
        include_tgs=True)

    # Get counts of unique TFs
    df = df.melt(var_name='String', value_name='Gene')
    df['String'] = df['String'].apply(lambda x: x.split('.'))
    df = pd.concat((pd.DataFrame(df['String'].tolist(), columns=('Subject', 'Cell Type')), df[['Gene']]), axis=1)
    df = df.groupby(['Subject', 'Cell Type']).count().reset_index().rename(columns={'Gene': 'Unique Modules'})

    # Plot
    fig, axs = get_mosaic([[0]*2], scale=6)
    pl = sns.barplot(data=df, x='Cell Type', y='Unique Modules', hue='Subject', ax=axs[0])
    plt.xticks(rotation=90)
    fig.savefig(f'../plots/individual_module_discovery_barplot_{subject_id_1}_{subject_id_2}_{column}.pdf', format='pdf', transparent=True, backend='cairo')


  fig = plt.figure(figsize=(scale*len(mosaic[0]), scale*len(mosaic)), constrained_layout=True)


### 3E Edge Discovery Line Plot

In [12]:
# Parameters
percentage_prioritizations_range = (.1, .15)


In [13]:
# Threshold by max/10 on head
# NOTE: Percentile is still 0 at 99%
head_threshold = np.nan_to_num(attention_stack).max(axis=(0, 2)).reshape((1, -1, 1)) / 10
within_range = attention_stack > head_threshold

# Get counts for edges
counts = within_range.sum(axis=2)
counts = pd.DataFrame(counts, index=all_edges, columns=columns)

# # Sample
# # NOTE: Maybe remove in final version?  Doesn't matter too much
# np.random.seed(42)
# idx = np.random.choice(counts.shape[0], 1_000, replace=False)
# counts = counts.iloc[idx]

# Melt and format
counts = counts.reset_index(names='Edge').melt(id_vars='Edge', var_name='Head', value_name='Count')

# Remove low counts (was zero, but far too many were low)
counts = counts.loc[counts['Count'] > 1]

# # Average plot
# # Sort by highest spike
# counts = counts.sort_values('Count')
# # Plot
# fig, axs = get_mosaic([[0]*2], scale=6)
# pl = sns.lineplot(data=counts, x='Edge', y='Count', hue='Head')
# plt.xticks(rotation=90)
# # plt.yscale('log')
# limit_labels(pl, n=10)
# fig.savefig(f'../plots/individual_edge_discovery_lineplot.pdf', format='pdf', transparent=True, backend='cairo')

for column in columns:
    # Filter to column
    counts_filtered = counts.loc[counts['Head']==column]

    # Sample
    # NOTE: Maybe remove in final version?  Doesn't matter too much
    np.random.seed(42)
    idx = np.random.choice(counts_filtered.shape[0], min(1_000, counts_filtered.shape[0]), replace=False)
    counts_filtered = counts_filtered.iloc[idx]

    # Sort
    counts_filtered = counts_filtered.sort_values('Count')

    # Plot
    fig, axs = get_mosaic([[0]*2], scale=6)
    pl = sns.lineplot(data=counts_filtered, x='Edge', y='Count')

    # Highlight area
    axs[0].axhspan(
        percentage_prioritizations_range[0]*attention_stack.shape[2],
        percentage_prioritizations_range[1]*attention_stack.shape[2],
        color='red', alpha=.2, lw=0)

    # Format
    plt.xticks(rotation=90)
    # plt.yscale('log')
    limit_labels(pl, n=30)

    # Save
    fig.savefig(f'../plots/individual_edge_discovery_lineplot_{column}.pdf', format='pdf', transparent=True, backend='cairo')


In [14]:
# Determine edges that are highly individual for enrichment (between `percentile_prioritizations_range`%s)
individual_genes = counts.loc[(counts['Count'] > (percentage_prioritizations_range[0]*attention_stack.shape[2])) * (counts['Count'] < (percentage_prioritizations_range[1]*attention_stack.shape[2]))]
individual_genes = np.array([split_edge_string(s) for s in individual_genes['Edge']]).flatten()
individual_genes = [s for s in individual_genes if not string_is_synthetic(s)]


### 3F Pathway Enrichment (MANUAL)

In [15]:
for subject_id_1, subject_id_2, column in individual_comparisons:
    # NOTE: Column doesn't matter here with the current snipping method
    g1 = compute_graph(load_graph_by_id(subject_id_1, column=column))
    g2 = compute_graph(load_graph_by_id(subject_id_2, column=column))

    # Get unique modules
    df = compare_graphs_enrichment(g1, g2, sid_1=subject_id_1, sid_2=subject_id_2, nodes=synthetic_nodes_of_interest)

    # Add individually important edges (requires above)
    df_new = pd.DataFrame(individual_genes, columns=('Population.Specific',))
    df = df.join(df_new, how='outer')

    # Save to file
    df.to_csv(f'../plots/genes_{subject_id_1}_{subject_id_2}_{column}.csv', index=False)


In [16]:
# Enrichment
for subject_id_1, subject_id_2, column in individual_comparisons:
    # MANUAL PROCESSING
    # Run the output from above on Metascape as multiple gene list and perform
    # enrichment.  From the all-in-one ZIP file, save the file from
    # Enrichment_QC/GO_DisGeNET.csv as '../plot/disgenet_{subject_id_1}_{subject_id_2}_{column}.csv' and
    # Overlap_circos/CircosOverlapByGene.svg as '../plot/overlap_{subject_id_1}_{subject_id_2}_{column}.svg'

    # Get enrichment
    enrichment_file = f'../plots/disgenet_{subject_id_1}_{subject_id_2}_{column}.csv'
    if not os.path.isfile(enrichment_file): continue
    enrichment = pd.read_csv(enrichment_file)

    # Format
    enrichment = format_enrichment(enrichment)

    # Plot
    fig, axs = get_mosaic([[0]*2], scale=9)
    pl = sns.scatterplot(
        enrichment,
        x='Gene Set', y='Description',
        size='-log10(p)',
        color='black',
        ax=axs[0])
    # Formatting
    pl.grid()
    plt.xticks(rotation=90)
    pl.set_aspect('equal', 'box')
    pl.legend(bbox_to_anchor=(1.2, 1.05))
    # Zoom X
    margin = .5
    min_xlim, max_xlim = pl.get_xlim()
    min_xlim -= margin; max_xlim += margin
    pl.set(xlim=(min_xlim, max_xlim))
    fig.savefig(f'../plots/individual_enrichment_{subject_id_1}_{subject_id_2}_{column}.pdf', format='pdf', transparent=True, backend='cairo')


### 3X Head Variation Heatmap

In [17]:
# # Calculate heatmap
# df = np.var(np.nan_to_num(attention_stack), axis=2)
# # Create df
# df = pd.DataFrame(df, index=all_edges, columns=columns)

# # Sort
# # df = df.iloc[df.fillna(0).mean(axis=1).argsort().to_numpy()[::-1]]
# # Standardize for visualization
# # TODO: Remove once model scale is fixed, only for visualization
# df = df / df.max(axis=0)

# ### Combined clustermap
# # Assign groups by associated cell type
# # TODO: Make greater depth, currently 1
# clusters = pd.DataFrame(
#     np.array([
#         [tf, tg] for tf, tg in df.index.map(lambda s: split_edge_string(s))
#     ]),
#     index=df.index,
#     columns=pd.Series(['TF', 'TG']),
# )
# assign_saved = {}
# def assign(row, df=None):
#     # Progress printing
#     # if np.random.rand() < .01:
#     #     print(f'{row["TF"]} - {row["TG"]}')

#     # If directly related
#     tf_synthetic = string_is_synthetic(row['TF'])
#     tg_synthetic = string_is_synthetic(row['TG'])
#     if tf_synthetic and tg_synthetic and (row['TF'] != row['TG']):
#         return 'Multiple'
#     elif tf_synthetic:
#         return row['TF']
#     elif tg_synthetic:
#         return row['TG']

#     # Otherwise, take indirect associations
#     if df is not None and 'Association' in df:
#         # Default to TF association
#         if row['TF'] not in assign_saved:
#             nodes, counts = np.unique(df.loc[df['TF']==row['TF'], 'Association'], return_counts=True)
#             nodes, counts = nodes[nodes!='None'], counts[nodes!='None']
#             if nodes.shape[0] == 0: assign_saved[row['TF']] = 'None'
#             else: assign_saved[row['TF']] = nodes[np.argsort(counts)[::-1]][0]
#         return assign_saved[row['TF']]

#     # If all else fails, return no association
#     return 'None'

# # Propagate cell types
# for _ in range(2):  # Depth 2
#     clusters['Association'] = clusters.apply(lambda x: assign(x, df=clusters), axis=1)

# # Convert to colors
# cluster_colors = {
#     a: c for a, c in zip(
#         np.unique(clusters['Association']),
#         sns.color_palette(palette='husl', n_colors=np.unique(clusters['Association']).shape[0]),
#     )
# }
# clusters['Colors'] = clusters['Association'].apply(lambda a: cluster_colors[a])

# # Filter to top 10 per head
# idx = []
# for column in columns:
#     idx += list(df.sort_values(column).index[-10:])
# idx = np.unique(idx)
# df = df.loc[idx]
# clusters = clusters.loc[idx]

# # Plot
# np.random.seed(42)
# fig = sns.clustermap(
#     data=df,
#     row_colors=clusters[['Colors']].rename(columns={'Colors': 'Cell Association'}),
#     row_cluster=False,
#     # norm=LogNorm(),
#     cmap='mako_r',
#     # dendrogram_ratio=.1,
#     # cbar_kws={'label': 'Variation'}
#     figsize=(9, 27),
# )
# fig.savefig(f'../plots/individual_edge_variance_heatmap.pdf', format='pdf', transparent=True, backend='cairo')
# plt.show()
# # Plot legend
# plt.clf()
# ax = plt.gca()
# legend_elements = [
#     Line2D([0], [0], color='gray', linestyle='None', markersize=10, marker='s', markerfacecolor=color, label=ct)
#     for ct, color in cluster_colors.items()
# ]
# ax.legend(handles=legend_elements, loc='best')
# plt.gca().axis('off')
# plt.tight_layout()
# plt.savefig(f'../plots/individual_edge_variance_heatmap_cell_legend.pdf', format='pdf', transparent=True, backend='cairo')


### 3X Individual Heatmap

In [18]:
# # Parameters
# column = column_data

# # Filter to data column
# data_idx = np.argwhere(np.array(columns) == column)[0][0]
# df = pd.DataFrame(attention_stack[:, data_idx], index=all_edges, columns=subject_ids)
# # Sort and filter (fillna can be excluded, but this also makes more common edges visible)
# df = df.iloc[df.fillna(0).mean(axis=1).argsort().to_numpy()[::-1]]
# df = df.iloc[:5000]
# # Sort and filter by common edges
# df = df.iloc[:, df.isna().to_numpy().sum(axis=0).argsort()]
# df = df.iloc[:, :100]

# # Individual heatmap (Limited to top 5k links)
# fig, axs = get_mosaic([[0]*9]*9, scale=3)
# sns.heatmap(data=df.iloc[:5000], cmap='mako_r', ax=axs[0])  # , norm=LogNorm()
# plt.xticks(rotation=60)
# fig.savefig(f'../plots/individual_edge_heatmap_{column}.pdf', format='pdf', transparent=True, backend='cairo')
# plt.show()


### 3X Dosage Analysis

In [19]:
# # Parameters
# column = column_ad
# subject_ids = meta['SubID'].to_numpy()

# # Load graphs
# graphs, subject_ids = load_many_graphs(subject_ids, column=column)
# graphs = [compute_graph(g) for g in graphs]

# # Get dosage information
# dosage = get_dosage()
# # Why do some SNPs go missing with the new meta?
# dosage = convert_dosage_ids_to_subject_ids(dosage, meta=meta)

# # Get attention
# attention, _ = compute_edge_summary(graphs, subject_ids=subject_ids)
# attention = attention.set_index('Edge')


In [20]:
# # Select target SNP
# target_snp = dosage.index[42]  # Random for now

# # Make df
# data_dosage = dosage.loc[[target_snp]].T
# data_attention = attention.T
# df = data_dosage.join(data_attention, how='inner')

# # Select target edge
# p_min = 1
# for edge in attention.index:
#     corr, pval = scipy.stats.pearsonr(
#         df[[edge]].to_numpy().squeeze(),
#         df[[target_snp]].to_numpy().squeeze())
#     if pval < p_min:
#         p_min = pval
#         best_corr = corr
#         target_edge = edge
# print(f'Found minimal p-value of {p_min:.6f} (Correlation: {best_corr:.6f}).')

# # Format df
# axis_snp = f'{target_snp} Dosage'
# axis_edge = f'{target_edge} Attention'
# df = df.rename(columns={target_snp: axis_snp, target_edge: axis_edge})

# # Scatter
# fig, axs = get_mosaic([list(range(1))], scale=9)
# sns.scatterplot(data=df, x=axis_snp, y=axis_edge, ax=axs[0])
# fig.savefig(f'../plots/individual_dosage_correlation_{column}.pdf', format='pdf', transparent=True, backend='cairo')


## Group Comparisons (Figure 4)

In [21]:
# Combinations
# TODO: Potentially move each entry to dictionary, so changes in order
#   are easier to propagate
contrast_groupings = [
    # (contrast name, contrast group, attention column, comparison column, target meta column, other target meta column)
    # for contrast_name, contrast_group, column, comparison, target, target_comparison in contrast_groupings:
    # TODO: Revise ethnicity prediction
    ('c15x', 'AD', column_ad, column_data, 'BRAAK_AD', 'Ethnicity'),
    # ('c06x', 'AD', column_ad, column_data, 'BRAAK_AD', 'nps_MoodDysCurValue'),  # Eventually SCZ, BP and such
    # ('c71x', 'MoodDys', column_data, column_ad, 'nps_MoodDysCurValue'),  # Dysphoria
    # ('c72x', 'DecInt', column_data, column_ad, 'nps_DecIntCurValue'),  # Anhedonia
]


### 4X Variance Heatmap

In [22]:
# # Get plots for each column
# for contrast_name, _, column, comparison, _, _ in contrast_groupings:
#     for col in (column, comparison):
#         print(' - '.join((contrast_name, col)))

#         # Get contrast
#         contrast = get_contrast(contrast_name)

#         # Compute
#         df_subgroup = compute_contrast_summary(contrast, column=col)

#         # Plot mean-sorted
#         fig, axs = get_mosaic([list(range(1))], scale=9)
#         plot_subgroup_heatmap(df_subgroup, ax=axs[0])
#         plt.tight_layout()
#         fig.savefig(f'../plots/group_variance_heatmap_{contrast_name}_{col}.pdf', format='pdf', transparent=True, backend='cairo')


### 4B Distribution Comparison

In [23]:
for contrast_name, _, column, comparison, target, target_comparison in contrast_groupings:
    # Filter attention stack to contrast
    contrast = get_contrast(contrast_name)
    contrast_subject_ids = sum([contrast[group] for group in contrast], [])
    contrast_mask = [sid in contrast_subject_ids for sid in subject_ids]
    contrast_subject_ids = np.array(subject_ids)[contrast_mask]
    contrast_stack = attention_stack[:, :, contrast_mask]

    # Filter to 1000 most variant edges
    top_variant_edge_idx = np.nan_to_num(
        contrast_stack[:, np.argwhere(np.array(columns)==column)[0][0]]).var(axis=1).argsort()[::-1][:1000]
    contrast_stack = contrast_stack[top_variant_edge_idx]
    edge_names = all_edges[top_variant_edge_idx]

    # Correlation df
    df = pd.DataFrame(
        contrast_stack[:, np.argwhere(np.array(columns)==column)[0][0]],
        index=pd.Series(all_edges[top_variant_edge_idx]),
        columns=contrast_subject_ids).T
    df = df.join(meta.set_index('SubID')[[target, target_comparison]]).reset_index(drop=True)
    # Select edge which most cleanly separates `target`
    # top_distinct_edge_idx = df.drop(target_comparison, axis=1).groupby(target).mean().var(axis=0).argsort()[-1]
    # Select most correlating edge
    top_distinct_edge_idx = df.drop(target_comparison, axis=1).corr()[target].abs().drop(target).argsort()[-2]
    # Format
    contrast_stack = contrast_stack[top_distinct_edge_idx]
    edge_name = edge_names[top_distinct_edge_idx]

    # Scale attention
    # TODO: Remove once heads are balanced
    contrast_stack = contrast_stack / np.nan_to_num(contrast_stack).max(axis=1).reshape((-1, 1))

    # Format
    df = pd.DataFrame(contrast_stack, index=pd.Series(columns), columns=contrast_subject_ids)
    df = df.reset_index(names='Head').melt(id_vars='Head', var_name='Subject', value_name=edge_name).dropna()  # Melt
    df = df.set_index('Subject').join(meta.set_index('SubID')[[target, target_comparison]]).reset_index()  # Join meta

    # Plot
    fig, axs = get_mosaic([4*[0], 4*[1]], scale=5)
    sns.despine()
    # axs[0].sharex(axs[1])

    # Main target
    p1 = sns.violinplot(data=df, x='Head', y=edge_name, hue=target, ax=axs[0])
    p1.legend(bbox_to_anchor=(1.1, 1.05))
    # Get correlation p-values for main target (which must be numeric)
    for i, c in enumerate(columns):
        pval = pearsonr(df.loc[df['Head']==c, edge_name], df.loc[df['Head']==c, target])[1]
        axs[0].text(i, axs[0].get_ylim()[0] - .15, f'p={pval:.1e}', ha='center', va='center')

    # Comparison target
    p2 = sns.violinplot(data=df, x='Head', y=edge_name, hue=target_comparison, ax=axs[1])
    p2.legend(bbox_to_anchor=(1.1, 1.05))
    p1.set(xlabel=None, xticklabels=[])
    plt.xticks(rotation=60)
    fig.savefig(f'../plots/group_differential_expression_{contrast_name}_{column}_{target}_{target_comparison}.pdf', format='pdf', transparent=True, backend='cairo')


### 4C Linkage Cluster Enrichment

In [24]:
for contrast_name, _, column, _, target, target_comparison in contrast_groupings:
    for tar in (target, target_comparison):
        # Get subject ids
        group = None  # contrast_group
        if group is None:
            # Population
            contrast_subjects = sum([v for k, v in get_contrast(contrast_name).items()], [])
        else:
            # Group
            contrast_subjects = get_contrast(contrast_name)[group]

        # Modify stack to include only contrast
        df = np.nan_to_num(attention_stack[:, np.argwhere(np.array(columns)==column)[0][0], [s in contrast_subjects for s in subject_ids]])
        new_subject_ids = [s for s in subject_ids if s in contrast_subjects]
        df = pd.DataFrame(df, index=all_edges, columns=new_subject_ids)

        # Get 100 most variant edges
        df = df.iloc[df.to_numpy().var(axis=1).argsort()[::-1][:100]]

        # Cluster
        labels = KMeans(n_clusters=10, n_init=10).fit_predict(df.to_numpy().T)
        labels += 1

        # Get phenotypes
        pheno = [meta.iloc[np.argwhere(meta['SubID'] == sid)[0][0]][tar] for sid in new_subject_ids]

        # Format results
        df = pd.DataFrame({'Cluster': labels, tar: pheno}, index=new_subject_ids)
        df['count'] = 1
        df = df.pivot_table(index='Cluster', columns=tar, values='count', aggfunc='sum').fillna(0)

        # Transform to hypergeometric
        df_np = df.to_numpy()
        df_np_new = np.zeros_like(df_np)
        for i, j in product(*[range(k) for k in df.shape]):
            # i - cluster, j - target
            dist = hypergeom(df_np.sum(), df_np[:, j].sum(), df_np[i, :].sum())
            # Calculate probability of overrepresentation
            df_np_new[i, j] = 1 - dist.cdf(df_np[i, j])
        df_np_new = -np.log10(df_np_new)
        df = pd.DataFrame(df_np_new, index=df.index, columns=df.columns)

        # Plot
        fig, axs = get_mosaic([list(range(1))], scale=9)
        sns.heatmap(df, cmap='rocket_r', cbar_kws={'label': '-log10(p)'}, ax=axs[0])
        # plt.tight_layout()
        fig.savefig(f'../plots/group_linkage_cluster_{contrast_name}_{column}_{tar}.pdf', format='pdf', transparent=True, backend='cairo')


  df_np_new = -np.log10(df_np_new)
  df_np_new = -np.log10(df_np_new)


### 4D Aggregate Graph Enrichment (MANUAL)

In [25]:
# NOTE: Only top 100 are taken for aggregate due to memory concerns
for contrast_name, group, column, _, _, _ in contrast_groupings:
      # Load contrast
      np.random.seed(42)
      contrast_subjects = get_contrast(contrast_name)
      gs = {
            gname: concatenate_graphs(*[
                  compute_graph(g)
                  for g in load_many_graphs(np.random.choice(sids, 100, replace=False))[0]
            ])
            for gname, sids in contrast_subjects.items()
      }

      # Split into groups
      # TODO: Make more general, perhaps add comparison group to arguments
      g1_name = group
      g1 = gs[g1_name]
      g2_name = 'Control'
      g2 = gs[g2_name]

      # Get unique TFs
      df = compare_graphs_enrichment(g1, g2, sid_1=g1_name, sid_2=g2_name, nodes=synthetic_nodes_of_interest)

      # Save to file
      df.to_csv(f'../plots/genes_{contrast_name}_{group}_{column}.csv', index=False)


No threshold provided, using threshold of 0.05016567523128692.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 147973/147973 [00:00<00:00, 266609.17it/s]


Filtered from 4709 vertices and 77684 edges to 748 vertices and 3552 edges via common edge filtering.
No threshold provided, using threshold of 0.05386086725079708.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 130009/130009 [00:00<00:00, 270991.79it/s]


Filtered from 4655 vertices and 73778 edges to 622 vertices and 2925 edges via common edge filtering.


In [26]:
# Enrichment
for contrast_name, group, column, _, _, _ in contrast_groupings:
    # MANUAL PROCESSING
    # Run the output from above on Metascape as multiple gene list and perform
    # enrichment.  From the all-in-one ZIP file, save the file from
    # Enrichment_QC/GO_DisGeNET as '../plot/disgenet_{subject_id_1}_{subject_id_2}_{column}.csv'

    # Get enrichment
    enrichment_file = f'../plots/disgenet_{contrast_name}_{group}_{column}.csv'
    if enrichment_file is None: continue
    enrichment = pd.read_csv(enrichment_file)

    # Format
    enrichment = format_enrichment(enrichment)

    # Plot
    fig, axs = get_mosaic([[0]*2], scale=9)
    pl = sns.scatterplot(
        enrichment,
        x='Gene Set', y='Description',
        size='-log10(p)',
        color='black',
        ax=axs[0])
    # Formatting
    pl.grid()
    plt.xticks(rotation=90)
    pl.set_aspect('equal', 'box')
    pl.legend(bbox_to_anchor=(1.2, 1.05))
    # Zoom X
    margin = .5
    min_xlim, max_xlim = pl.get_xlim()
    min_xlim -= margin; max_xlim += margin
    pl.set(xlim=(min_xlim, max_xlim))
    # Save
    fig.savefig(f'../plots/group_enrichment_{contrast_name}_{group}_{column}.pdf', format='pdf', transparent=True, backend='cairo')


### 4X Cross-Validation Accuracies

In [27]:
# # TODO: Make all y-labels horizontal
# for contrast_name, _, column, _, target, target_comparison in contrast_groupings:
#     for tar in (target, target_comparison):
#         # if contrast_name != 'c71x': continue
#         print(' - '.join((contrast_name, column, tar)))
#         # Get contrast
#         contrast = get_contrast(contrast_name)

#         # Compute prioritized edges
#         # Get 100 most variant edges
#         # TODO: Revise this method, maybe also consider means
#         sids = sum([sids for _, sids in contrast.items()], [])
#         df_subgroup = compute_contrast_summary(contrast, column=column)
#         df = join_df_subgroup(df_subgroup, num_sort=100)
#         prioritized_edges = list(df.index)

#         # Plot
#         # TODO: Maybe return to row-normalization
#         fig, axs = get_mosaic([[0]], scale=9)
#         df, acc = plot_prediction_confusion(contrast, meta=meta, column=column, target=tar, prioritized_edges=prioritized_edges, classifier_type='SGD', ax=axs[0])

#         # Save plot
#         fname_prefix = f'../plots/group_prioritized_edge_prediction_{contrast_name}_{column}_{tar}'
#         fig.savefig(f'{fname_prefix}.pdf', format='pdf', transparent=True, backend='cairo')

#         # Save text
#         f_edges = open(f'{fname_prefix}.edges.txt', 'w')
#         f_tfs = open(f'{fname_prefix}.tfs.txt', 'w')
#         f_tgs = open(f'{fname_prefix}.tgs.txt', 'w')
#         for edge in prioritized_edges:
#             f_edges.write(edge + '\n')
#             tf, tg = edge.split(get_edge_string(['', '']))
#             f_tfs.write(tf + '\n')
#             f_tgs.write(tg + '\n')
#         f_edges.close()
#         f_tfs.close()
#         f_tgs.close()

#         # CLI
#         print()


## General

### Graph Legend

In [28]:
# Plot legend
plt.clf()
plot_legend()
plt.gca().axis('off')
plt.tight_layout()
plt.savefig(f'../plots/graph_legend.pdf', format='pdf', transparent=True, backend='cairo')


  plt.tight_layout()
