In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%matplotlib notebook


In [2]:
import os

import graph_tool.all as gt
import matplotlib
from matplotlib.colors import LogNorm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from functions import *


# Graph-Tool compatibility
plt.switch_backend('cairo')
# Style
sns.set_theme(context='talk', style='white', palette='Set2')
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42


# Meta and Parameters

In [3]:
# Load metadata
meta = get_meta()

# Subject preview
filtered = []
for i, row in meta.iterrows():
    try:
        load_graph_by_id(row['SubID'])
        assert not np.isnan(row['nps_MoodDysCurValue'])  # Has NPS information available
        # assert row['BRAAK_AD'] in (3, 4, 5)
    except:
        continue
    filtered.append(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}')
filtered = np.sort(filtered)
for i in range(len(filtered)):
    print(filtered[i])

# Parameters
print(f'\nAvailable attention columns: {get_attention_columns()}')
column_ad = get_attention_columns()[0]
column_scz = get_attention_columns()[2]
column_data = get_attention_columns()[4]
synthetic_nodes_of_interest = ['OPC', 'Micro', 'Oligo', 'EN']


  return pd.read_csv(META)


M10031 AMR Male, 71.0, BRAAK 0.0
M10282 EUR Male, 73.0, BRAAK 6.0
M10730 EUR Female, 90.0, BRAAK 5.0
M107983 AMR Female, 65.0, BRAAK 1.0
M10874 EUR Female, 88.0, BRAAK 3.0
M10886 AMR Male, 46.0, BRAAK 0.0
M1119 EUR Female, 91.0, BRAAK 6.0
M11371 EUR Male, 95.0, BRAAK 2.0
M1140 EUR Female, 96.0, BRAAK 1.0
M11588 EUR Female, 62.0, BRAAK 6.0
M11589 AMR Female, 63.0, BRAAK 2.0
M11716 EUR Female, 89.0, BRAAK 5.0
M1176 EUR Female, 85.0, BRAAK 0.0
M118449 EUR Male, 80.0, BRAAK 6.0
M11938 EUR Female, 90.0, BRAAK 3.0
M1198 EUR Female, 87.0, BRAAK 6.0
M12047 EUR Female, 80.0, BRAAK 5.0
M12249 EUR Male, 51.0, BRAAK 0.0
M12326 EUR Female, 93.0, BRAAK 6.0
M12479 AMR Female, 83.0, BRAAK 6.0
M12514 EUR Female, 88.0, BRAAK 4.0
M12614 AFR Female, 75.0, BRAAK 0.0
M12792 EUR Female, 91.0, BRAAK 5.0
M12876 AMR Female, 76.0, BRAAK 2.0
M13326 EUR Male, 84.0, BRAAK 4.0
M133696 EUR Female, 52.0, BRAAK nan
M13458 EUR Male, 91.0, BRAAK 6.0
M13640 EUR Female, 82.0, BRAAK 3.0
M13670 AFR Male, 83.0, BRAAK 6.0
M139

# Plots

## Individual Comparisons (Figure 3)

In [4]:
individual_comparisons = [
    # M19050 Hispanic Female, 74.0, BRAAK 5.0
    # M59593 Hispanic Female, 76.0, BRAAK 5.0
    # M72079 Black Female, 64.0, BRAAK 6.0
    # M41496 Black Female, 76.0, BRAAK 4.0
    # M11589 Black Female, 63.0, BRAAK 2.0
    # M73342 Black Female, 62.0, BRAAK 0.0
    # (subject_id_1, subject_id_2, column)
    # for subject_id_1, subject_id_2, column in individual_comparisons:
    ('M19050', 'M59593', column_ad),  # AD - AD
    # ('M72079', 'M41496', column_ad),  # AD - High BRAAK
    # ('M72079', 'M11589', column_ad),  # AD - Low BRAAK
    # ('M72079', 'M73342', column_ad),  # AD - CTRL
]
palette = plt.rcParams['axes.prop_cycle'].by_key()['color']
individual_colors = {
    sid: rgba_to_hex(palette[i]) for i, sid in enumerate(
        sum([list(comparison[:2]) for comparison in individual_comparisons], []))
}

# Verify all are available
for subject_id_1, subject_id_2, column in individual_comparisons:
    for sid in [subject_id_1, subject_id_2]:
        load_graph_by_id(sid)


### 3A Mini Plots

In [10]:
for i, (subject_id_1, subject_id_2, column) in enumerate(individual_comparisons):
    print(' - '.join((subject_id_1, subject_id_2, column)))

    # Assemble
    sids = [subject_id_1, subject_id_2]
    gs = [compute_graph(load_graph_by_id(sid, column=column)) for sid in sids]

    # Filter
    gs = [
        filter_to_synthetic_vertices(g.copy(), vertex_ids=synthetic_nodes_of_interest)
        for g in gs
    ]

    # Recalculate
    gs = [assign_vertex_properties(g) for g in gs]

    # Plot
    fig, axs = get_mosaic([list(range(2))], scale=9)
    plot_graph_comparison(gs, axs=axs, subject_ids=sids)
    fig.savefig(f'../plots/individual_mini_{"-".join(sids)}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    print()


M19050 - M59593 - att_D_AD_1
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 167/167 [00:00<00:00, 202852.24it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 84/84 [00:00<00:00, 154121.41it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 83/83 [00:00<00:00, 160723.56it/s]





### 3B Attention Comparisons

In [24]:
for subject_id_1, subject_id_2, _ in individual_comparisons:
    for column in get_attention_columns():
        print(' - '.join((subject_id_1, subject_id_2, column)))

        # Assemble
        sample_ids = [subject_id_1, subject_id_2]
        graphs = [compute_graph(load_graph_by_id(sid, column=column)) for sid in sample_ids]

        # Get graph
        g = concatenate_graphs(*graphs, threshold=False)
        g = get_intersection(g)
        g = cull_isolated_leaves(g)

        fig, axs = get_mosaic([list(range(1))], scale=6)
        df = plot_individual_edge_comparison(g, sample_ids, ax=axs[0])
        plt.tight_layout()
        fig.savefig(f'../plots/individual_edge_comparison_{"-".join((subject_id_1, subject_id_2))}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
        print()


M19050 - M59593 - att_D_AD_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 240090.57it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_AD_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 228120.96it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_SCZ_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 252093.59it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_SCZ_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 296480.16it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_0
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 250483.28it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 216393.86it/s]
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 259844.07it/s]
  fig = plt.figure(figsize=(scale*len(mosaic[0]), scale*len(mosaic)), constrained_layout=True)
  plt.tight_layout()



M19050 - M59593 - att_D_no_prior_3
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3347/3347 [00:00<00:00, 251917.16it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_AD_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 276979.59it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_AD_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 231321.22it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_SCZ_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 238449.03it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_SCZ_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 248731.37it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_no_prior_0
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 233153.34it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_no_prior_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 251155.53it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_no_prior_2
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 302122.29it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_no_prior_3
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 3797/3797 [00:00<00:00, 201237.98it/s]





  plt.tight_layout()


### 3C Pathway Enrichment (MANUAL)

In [8]:
# Enrichment
for subject_id_1, subject_id_2, column in individual_comparisons:
    # Assemble
    sample_ids = [subject_id_1, subject_id_2]
    graphs = [load_graph_by_id(sid, column=column) for sid in sample_ids]
    for i in range(len(graphs)):
        # Remove self-loops
        graphs[i] = graphs[i].loc[graphs[i].apply(lambda x: x['TF'] != x['TG'], axis=1)]

        # Set index, combine tf and tg, rename coef column
        graphs[i].index = graphs[i].apply(lambda x: get_edge_string([x['TF'], x['TG']]), axis=1)
        graphs[i] = graphs[i].drop(columns=['TF', 'TG'])
        graphs[i] = graphs[i].rename(columns={'coef': sample_ids[i]})

        # Drop duplicates
        # TODO: Why are there duplicates?
        graphs[i] = graphs[i].drop_duplicates()

    # Filter to common graphs and join
    graphs = graphs[0].join(graphs[1], how='inner')

    # Get differentially expressed genes
    graphs['Difference'] = np.abs(graphs[subject_id_1] - graphs[subject_id_2])
    graphs = graphs.sort_values('Difference', ascending=False)

    # Write to file
    fname_prefix = f'../plots/individual_pathway_enrichment_{"-".join((subject_id_1, subject_id_2))}_{column}'
    f_tfs = open(fname_prefix + '.tfs.txt', 'w')
    f_tgs = open(fname_prefix + '.tgs.txt', 'w')
    f_edges = open(fname_prefix + '.edges.txt', 'w')
    for edge in graphs.index:
        # TODO: Only show unique tf, tg, maybe?
        f_edges.write(edge + '\t' + f'{graphs.loc[edge, "Difference"]:.5f}' + '\n')
        tf, tg = edge.split(get_edge_string())
        # Don't show synthetic nodes
        if not string_is_synthetic(tf): f_tfs.write(tf + '\n')
        if not string_is_synthetic(tg): f_tgs.write(tg + '\n')
    f_tfs.close()
    f_tgs.close()
    f_edges.close()

    # MANUAL PROCESSING
    # Run the output from '<fname_prefix>.xxs.txt' in DisGeNet, save file
    # from Enrichment_QC/GO_DisGeNET as '<fname_prefix>.xxs.csv'

    # Plot enrichments
    get_enrichment_file = lambda x: f'{fname_prefix}.{x}'
    for ftype in ['tfs', 'tgs']:
        # Get file name
        fname = get_enrichment_file(ftype)

        # Get enrichment
        enrichment = get_enrichment(fname)
        if enrichment is None: continue

        # Plot
        fig, axs = get_mosaic([list(range(1))], scale=9)
        plot_enrichment(enrichment, ax=axs[0])
        fig.savefig(fname + '.pdf', format='pdf', transparent=True, backend='cairo')

        # Plot
        fig, axs = get_mosaic([list(range(1))], scale=9)
        plot_enrichment(enrichment, ax=axs[0])
        fig.savefig(fname + '.pdf', format='pdf', transparent=True, backend='cairo')


### 3D Attention Heatmap

In [5]:
# Parameters
# Scaled probably shouldn't be used, but better for visualization
# until results are more even
columns = get_attention_columns(scaled=False)
subject_ids = meta['SubID'].to_numpy()

# Load graphs
graphs, subject_ids = load_many_graphs(subject_ids, column=columns)
# graphs = [compute_graph(g) for g in graphs]

# # Get attentions
# df = {}
# for column in get_attention_columns():
#     attention, _ = compute_edge_summary(graphs, subject_ids=subject_ids)
#     attention = attention.set_index('Edge')
#     df[column] = attention.var(axis=1)


# Set indices to edges and clean
print('Fixing indices...')
for i in tqdm(range(len(graphs))):
    graphs[i].index = graphs[i].apply(lambda r: get_edge_string([r['TF'], r['TG']]), axis=1)
    graphs[i] = graphs[i].drop(columns=['TF', 'TG'])
    # Remove duplicates
    graphs[i] = graphs[i][~graphs[i].index.duplicated(keep='first')]

# Get all unique edges
print('Getting unique edges...')
all_edges = np.unique(sum([list(g.index) for g in graphs], []))


# Standardize index order
print('Standardizing indices...')
for i in tqdm(range(len(graphs))):
    # Add missing indices and order based on `all_edges`
    # to_add = [edge for edge in all_edges if edge not in list(graphs[i].index)]  # SLOW
    to_add = list(set(all_edges) - set(graphs[i].index))

    # Empty rows
    new_rows = pd.DataFrame(
        [[np.nan]*len(graphs[i].columns)]*len(to_add),
        columns=graphs[i].columns,
    ).set_index(pd.Series(to_add))
    # Native concat
    graphs[i] = pd.concat([graphs[i], new_rows]).loc[all_edges]

# Convert to numpy
graphs = [g.to_numpy() for g in graphs]
attention_stack = np.stack(graphs, axis=-1)
# attention_stack.shape = (Edge, Head, Subject)
# attention_stack.shape = (all_edges, columns, subject_ids)


Fixing indices...


100%|████████████████████████████████████████████████████████████████████████████████| 616/616 [00:04<00:00, 150.98it/s]


Getting unique edges...
Standardizing indices...


100%|█████████████████████████████████████████████████████████████████████████████████| 616/616 [05:17<00:00,  1.94it/s]


In [82]:
# Calculate heatmap
df = np.var(np.nan_to_num(attention_stack), axis=2)
# Create df
df = pd.DataFrame(df, index=all_edges, columns=columns)
# Sort
df = df.iloc[df.fillna(0).mean(axis=1).argsort().to_numpy()[::-1]]
# Standardize for visualization
# TODO: Remove once model scale is fixed, only for visualization
df = (df - df.mean()) / df.std()
# Sample or filter
# TODO: Currently, too many edges leads to the whole
# visualization looking like the darkest value.  If
# this can be fixed, then the random sampling can be
# removed.
# df = df.iloc[np.random.choice(df.shape[0], 200, replace=False)]
df = df.iloc[np.nan_to_num(attention_stack).mean(axis=(1,2)) > .05]

### Combined clustermap
# Assign groups by associated cell type
# TODO: Make greater depth, currently 1
clusters = pd.DataFrame(
    np.array([
        [tf, tg] for tf, tg in df.index.map(lambda s: split_edge_string(s))
    ]),
    index=df.index,
    columns=pd.Series(['TF', 'TG']),
)
def assign(row):
    if string_is_synthetic(row['TF']):
        return row['TF']
    if string_is_synthetic(row['TG']):
        return row['TG']
    return 'None'
clusters['Association'] = clusters.apply(assign, axis=1)
# Convert to colors
cluster_colors = {
    a: c for a, c in zip(
        np.unique(clusters['Association']),
        sns.color_palette(palette='husl', n_colors=np.unique(clusters['Association']).shape[0]),
    )
}
clusters['Colors'] = clusters['Association'].apply(lambda a: cluster_colors[a])
# Plot

np.random.seed(42)
fig = sns.clustermap(
    data=df,
    row_colors=clusters[['Colors']].rename(columns={'Colors': 'Cell Association'}),
    # row_cluster=False,
    # norm=LogNorm(),
    cmap='mako',
    # dendrogram_ratio=.1,
    figsize=(9, 27),
)
fig.savefig(f'../plots/individual_edge_variance_heatmap.pdf', format='pdf', transparent=True, backend='cairo')
plt.show()
# Plot legend
plt.clf()
ax = plt.gca()
legend_elements = [
    Line2D([0], [0], color='gray', linestyle='None', markersize=10, marker='s', markerfacecolor=color, label=ct)
    for ct, color in cluster_colors.items()
]
ax.legend(handles=legend_elements, loc='best')
plt.gca().axis('off')
plt.tight_layout()
plt.savefig(f'../plots/individual_edge_variance_heatmap_cell_legend.pdf', format='pdf', transparent=True, backend='cairo')

### Combined heatmap
# fig, axs = get_mosaic([[0,0,0]]*9, scale=3)
# sns.heatmap(data=df, cmap='crest', norm=LogNorm(), ax=axs[0])
# plt.xticks(rotation=60)
# fig.savefig(f'../plots/individual_edge_variance_heatmap.pdf', format='pdf', transparent=True, backend='cairo')
# plt.show()

### Separate heatmaps
# # Get data and diffusion idx
# idx = range(len(columns))
# data_idx = [i for i in idx if 'no_prior' in columns[i]]
# diff_idx = list(set(idx) - set(data_idx))

# # Data heatmap
# fig, axs = get_mosaic([[0,0,0]]*9, scale=3)
# sns.heatmap(data=df.iloc[:, data_idx], cmap='crest', norm=LogNorm(), ax=axs[0])
# plt.xticks(rotation=60)
# fig.savefig(f'../plots/individual_edge_variance_heatmap_data.pdf', format='pdf', transparent=True, backend='cairo')
# plt.show()

# # Diffusion heatmap
# fig, axs = get_mosaic([[0,0,0]]*9, scale=3)
# sns.heatmap(data=df.iloc[:, diff_idx], cmap='crest', norm=LogNorm(), ax=axs[0])
# plt.xticks(rotation=60)
# fig.savefig(f'../plots/individual_edge_variance_heatmap_diff.pdf', format='pdf', transparent=True, backend='cairo')
# plt.show()


In [None]:
# # Get variance filtered by contrast
# contrast = 'c15x'
# contrast_subjects = sum([v for k, v in get_contrast(contrast).items()], [])
# attention_stack[:, :, [s in contrast_subjects for s in subject_ids]]
# df = np.var(np.nan_to_num(attention_stack), axis=2)
# df = pd.DataFrame(df, index=all_edges, columns=columns)
# df.to_csv(f'../plots/{contrast}_variation.csv')


In [33]:
# # Parameters
# column = column_data

# # Filter to data column
# data_idx = np.argwhere(np.array(columns) == column)[0][0]
# df = pd.DataFrame(attention_stack[:, data_idx], index=all_edges, columns=subject_ids)
# # Sort (fillna can be excluded, but this also makes more common edges visible)
# df = df.iloc[df.fillna(0).mean(axis=1).argsort().to_numpy()[::-1]]

# # Individual heatmap (Limited to top 5k links)
# fig, axs = get_mosaic([[0]*9]*9, scale=3)
# sns.heatmap(data=df.iloc[:5000], cmap='crest', norm=LogNorm(), ax=axs[0])
# plt.xticks(rotation=60)
# fig.savefig(f'../plots/individual_edge_heatmap_{column}.pdf', format='pdf', transparent=True, backend='cairo')
# plt.show()


### 3E Dosage Analysis

In [22]:
# Parameters
column = column_ad
subject_ids = meta['SubID'].to_numpy()

# Load graphs
graphs, subject_ids = load_many_graphs(subject_ids, column=column)
graphs = [compute_graph(g) for g in graphs]

# Get dosage information
dosage = get_dosage()
# Why do some SNPs go missing with the new meta?
dosage = convert_dosage_ids_to_subject_ids(dosage, meta=meta)

# Get attention
attention, _ = compute_edge_summary(graphs, subject_ids=subject_ids)
attention = attention.set_index('Edge')


  dosage = pd.read_csv(DOSAGE)
' in metadata


No threshold provided, using threshold of 0.01381273008154222.
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████| 1005000/1005000 [00:03<00:00, 302670.25it/s]


Filtered from 5047 vertices and 362559 edges to 1629 vertices and 12719 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 12719/12719 [00:02<00:00, 5782.52it/s]


In [23]:
# Select target SNP
target_snp = dosage.index[42]  # Random for now

# Make df
data_dosage = dosage.loc[[target_snp]].T
data_attention = attention.T
df = data_dosage.join(data_attention, how='inner')

# Select target edge
p_min = 1
for edge in attention.index:
    corr, pval = scipy.stats.pearsonr(
        df[[edge]].to_numpy().squeeze(),
        df[[target_snp]].to_numpy().squeeze())
    if pval < p_min:
        p_min = pval
        best_corr = corr
        target_edge = edge
print(f'Found minimal p-value of {p_min:.6f} (Correlation: {best_corr:.6f}).')

# Format df
axis_snp = f'{target_snp} Dosage'
axis_edge = f'{target_edge} Attention'
df = df.rename(columns={target_snp: axis_snp, target_edge: axis_edge})

# Scatter
fig, axs = get_mosaic([list(range(1))], scale=9)
sns.scatterplot(data=df, x=axis_snp, y=axis_edge, ax=axs[0])
fig.savefig(f'../plots/individual_dosage_correlation_{column}.pdf', format='pdf', transparent=True, backend='cairo')


Found minimal p-value of 0.000171 (Correlation: 0.163327).


## Group Comparisons (Figure 4)

In [40]:
# Combinations
# TODO: Potentially move each entry to dictionary, so changes in order
#   are easier to propagate
contrast_groupings = [
    # (contrast name, contrast group, attention column, comparison column, target meta column, other target meta column)
    # for contrast_name, contrast_group, column, comparison, target, target_comparison in contrast_groupings:
    # TODO: Revise ethnicity prediction
    ('c15x', 'AD', column_data, column_ad, 'Ethnicity', 'BRAAK_AD'),
    # ('c06x', 'AD', column_ad, column_data, 'BRAAK_AD'),  # Eventually SCZ, BP and such
    # ('c71x', 'MoodDys', column_data, column_ad, 'nps_MoodDysCurValue'),  # Dysphoria
    # ('c72x', 'DecInt', column_data, column_ad, 'nps_DecIntCurValue'),  # Anhedonia
]


### 4A Variance Heatmap

In [35]:
# Get plots for each column
for contrast_name, _, column, comparison, _, _ in contrast_groupings:
    for col in (column, comparison):
        print(' - '.join((contrast_name, col)))

        # Get contrast
        contrast = get_contrast(contrast_name)

        # Compute
        df_subgroup = compute_contrast_summary(contrast, column=col)

        # Plot mean-sorted
        fig, axs = get_mosaic([list(range(1))], scale=9)
        plot_subgroup_heatmap(df_subgroup, ax=axs[0])
        plt.tight_layout()
        fig.savefig(f'../plots/group_variance_heatmap_{contrast_name}_{col}.pdf', format='pdf', transparent=True, backend='cairo')


c15x - att_D_no_prior_0
No threshold provided, using threshold of 0.017553121661133885.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 698279/698279 [00:02<00:00, 309781.98it/s]


Filtered from 5030 vertices and 270687 edges to 1463 vertices and 10173 edges via common edge filtering.
No threshold provided, using threshold of 0.04979338607285737.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 146728/146728 [00:00<00:00, 312538.86it/s]


Filtered from 4719 vertices and 81268 edges to 684 vertices and 3350 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10173/10173 [00:01<00:00, 7498.59it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 3350/3350 [00:00<00:00, 22346.14it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 310301.53it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 6635.70it/s]
  plt.tight_layout()


c15x - att_D_AD_1
No threshold provided, using threshold of 0.017553121661133885.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 698279/698279 [00:02<00:00, 314327.01it/s]


Filtered from 5030 vertices and 270687 edges to 1463 vertices and 10173 edges via common edge filtering.
No threshold provided, using threshold of 0.04979338607285737.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 146728/146728 [00:00<00:00, 324060.76it/s]


Filtered from 4719 vertices and 81268 edges to 684 vertices and 3350 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10173/10173 [00:01<00:00, 8249.59it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 3350/3350 [00:00<00:00, 24357.50it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 318097.24it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 6827.90it/s]
  plt.tight_layout()


### 4B Distribution Comparison

In [39]:
# TODO: Use highest variance rather than random edges
# TODO: Make both columns same scale, maybe?
for contrast_name, _, column, comparison, target, _ in contrast_groupings:
    print(' - '.join((contrast_name, column, target)))
    # Get contrast
    contrast = get_contrast(contrast_name)

    # Plot
    fig, axs = get_mosaic([2*[0], 2*[1]], scale=3)
    sns.despine()

    _, edges_include = plot_BRAAK_comparison(
        contrast,
        # {k: v[:10] for k, v in contrast.items()},
        meta=meta,
        column=column,
        target=target,
        legend=False,
        ax=axs[0])
    plt.xlabel(None)
    plt.ylabel(column)
    plt.xticks([])

    plot_BRAAK_comparison(
        contrast,
        meta=meta,
        column=comparison,
        target=target,
        edges_include=edges_include,
        ax=axs[1])
    plt.ylabel(comparison)

    fig.savefig(f'../plots/group_differential_expression_{contrast_name}_{column}_{comparison}_{target}.pdf', format='pdf', transparent=True, backend='cairo')
    print()


c15x - att_D_no_prior_0 - Ethnicity
No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 327763.05it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 7225.57it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:03<00:00, 248182.11it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 7209.93it/s]





### 4C Cross-Validation Accuracies

In [41]:
# TODO: Make all y-labels horizontal
for contrast_name, _, column, _, target, target_comparison in contrast_groupings:
    for tar in (target, target_comparison):
        # if contrast_name != 'c71x': continue
        print(' - '.join((contrast_name, column, tar)))
        # Get contrast
        contrast = get_contrast(contrast_name)

        # Compute prioritized edges
        # Get 100 most variant edges
        # TODO: Revise this method, maybe also consider means
        sids = sum([sids for _, sids in contrast.items()], [])
        df_subgroup = compute_contrast_summary(contrast, column=column)
        df = join_df_subgroup(df_subgroup, num_sort=100)
        prioritized_edges = list(df.index)

        # Plot
        # TODO: Maybe return to row-normalization
        fig, axs = get_mosaic([[0]], scale=9)
        df, acc = plot_prediction_confusion(contrast, meta=meta, column=column, target=tar, prioritized_edges=prioritized_edges, classifier_type='SGD', ax=axs[0])

        # Save plot
        fname_prefix = f'../plots/group_prioritized_edge_prediction_{contrast_name}_{column}_{tar}'
        fig.savefig(f'{fname_prefix}.pdf', format='pdf', transparent=True, backend='cairo')

        # Save text
        f_edges = open(f'{fname_prefix}.edges.txt', 'w')
        f_tfs = open(f'{fname_prefix}.tfs.txt', 'w')
        f_tgs = open(f'{fname_prefix}.tgs.txt', 'w')
        for edge in prioritized_edges:
            f_edges.write(edge + '\n')
            tf, tg = edge.split(get_edge_string(['', '']))
            f_tfs.write(tf + '\n')
            f_tgs.write(tg + '\n')
        f_edges.close()
        f_tfs.close()
        f_tgs.close()

        # CLI
        print()


c15x - att_D_no_prior_0 - Ethnicity
No threshold provided, using threshold of 0.017553121661133885.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 698279/698279 [00:02<00:00, 317518.60it/s]


Filtered from 5030 vertices and 270687 edges to 1463 vertices and 10173 edges via common edge filtering.
No threshold provided, using threshold of 0.04979338607285737.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 146728/146728 [00:00<00:00, 307304.66it/s]


Filtered from 4719 vertices and 81268 edges to 684 vertices and 3350 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10173/10173 [00:01<00:00, 8003.11it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 3350/3350 [00:00<00:00, 23049.63it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 319710.15it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 6782.52it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 322522.42it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 6991.41it/s]



c15x - att_D_no_prior_0 - BRAAK_AD
No threshold provided, using threshold of 0.017553121661133885.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 698279/698279 [00:02<00:00, 320423.04it/s]


Filtered from 5030 vertices and 270687 edges to 1463 vertices and 10173 edges via common edge filtering.
No threshold provided, using threshold of 0.04979338607285737.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 146728/146728 [00:00<00:00, 314874.63it/s]


Filtered from 4719 vertices and 81268 edges to 684 vertices and 3350 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10173/10173 [00:01<00:00, 7940.68it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 3350/3350 [00:00<00:00, 22981.13it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 320074.75it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 6633.76it/s]


No threshold provided, using threshold of 0.01546432984072714.
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████| 845007/845007 [00:02<00:00, 316957.47it/s]


Filtered from 5039 vertices and 315575 edges to 1488 vertices and 10734 edges via common edge filtering.
Collecting edges...


100%|███████████████████████████████████████████████████████████████████████████| 10734/10734 [00:01<00:00, 6880.93it/s]





## General

### Graph Legend

In [38]:
# Plot legend
plt.clf()
plot_legend()
plt.gca().axis('off')
plt.tight_layout()
plt.savefig(f'../plots/graph_legend.pdf', format='pdf', transparent=True, backend='cairo')


  plt.tight_layout()
