In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import graph_tool.all as gt
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from functions import *


# Graph-Tool compatibility: render every figure through matplotlib's cairo backend
# NOTE(review): presumably required for graph-tool to draw into matplotlib axes -- confirm
plt.switch_backend('cairo')
# Style: seaborn theme shared by all figures in this notebook
sns.set_theme(context='talk', style='white', palette='Set2')
# Font type 42 makes the PDF/PS writers embed TrueType fonts instead of the default Type 3
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

# Computation

In [3]:
# Load metadata
meta = get_meta()

# Subject preview: collect a one-line description of every subject that has a
# loadable graph, a recorded NPS value, and matches the demographic filter.
# The printed lines are used to hand-pick `individual_subject_ids` below.
filtered = []
for _, row in meta.iterrows():
    try:
        # A subject whose graph cannot be loaded, or whose fields cannot be
        # read, is skipped. `Exception` (rather than a bare `except`) lets
        # KeyboardInterrupt / SystemExit propagate, and the explicit boolean
        # (rather than `assert`) keeps the filter active under `python -O`.
        load_graph_by_id(row['SubID'])
        keep = (
            not np.isnan(row['nps_MoodDysCurValue'])  # Has NPS
            and row['Sex'] == 'Female'
            and row['Ethnicity'] != 'White'
        )
    except Exception:
        continue
    if not keep:
        continue
    filtered.append(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}')
# Show only the hand-selected positions of the filtered list
for i in (2, 4, 5, 6):
    print(filtered[i])

# Parameters
column_diff = 'att_D_AD_0_1'
column_data = 'att_D_no_prior_0'
coex_diff_compare_phenotype = 'nps_PsychoAgiCurValue'
diff_data_compare_phenotype = 'nps_WtLossCurValue'
individual_subject_ids = ['M19050', 'M59593', 'M83214', 'M36634']

M19050 Hispanic Female, 74.0, BRAAK 5.0
M59593 Hispanic Female, 76.0, BRAAK 5.0
M83214 Hispanic Female, 83.0, BRAAK 6.0
M36634 Hispanic Female, 87.0, BRAAK 6.0


In [4]:
# Compute individual graphs: one per hand-picked subject, for each data source
# Coexpression graphs: edge coefficients scaled by 1/60, computed with filter=.9, isolated leaves culled
coex_g_individuals = [
    cull_isolated_leaves(
        compute_graph(
            scale_edge_coefs_list(load_graph_by_id(subject_id, source='coexpression'), 1./60),
            filter=.9,
        )
    )
    for subject_id in individual_subject_ids
]
# Graphs built from the `column_diff` column
diff_g_individuals = [
    compute_graph(load_graph_by_id(subject_id, column=column_diff))
    for subject_id in individual_subject_ids
]
# Graphs built from the `column_data` column
data_g_individuals = [
    compute_graph(load_graph_by_id(subject_id, column=column_data))
    for subject_id in individual_subject_ids
]

# Plots

### 3A Individual Plots

In [5]:
# One side-by-side figure per column, each showing the first `num_graphs` subjects
for column, graphs in ((column_diff, diff_g_individuals), (column_data, data_g_individuals)):
    # Limit number
    num_graphs = 2
    sids = individual_subject_ids[:num_graphs]

    # Filter copies of the selected graphs down to the listed synthetic vertices
    gs = []
    for g in graphs[:num_graphs]:
        gs.append(filter_graph_by_synthetic_vertices(g.copy(), vertex_ids=['OPC', 'Micro', 'Oligo']))

    # Plot: a single row with one panel per subject
    panel_row = list(range(num_graphs))
    fig, axs = get_mosaic([panel_row], scale=9)
    plot_graph_comparison(gs, axs=axs, subject_ids=sids)

    # Save
    out_path = f'../plots/mini_{"-".join(sids)}_{column}.pdf'
    fig.savefig(out_path, format='pdf', transparent=True, backend='cairo')

Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 212831.66it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 218916.02it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 219613.26it/s]


Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 143/143 [00:00<00:00, 220428.32it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 197547.42it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 74/74 [00:00<00:00, 223454.64it/s]


### 3B Individual Comparisons

In [6]:
# Compare edges between the first two individuals, once per column
for column, graphs in (
    (column_diff, diff_g_individuals[:2]),
    (column_data, data_g_individuals[:2]),
):
    sample_ids = individual_subject_ids[:2]

    # Get graph: concatenate both subjects, take the intersection, then cull isolated leaves
    g = cull_isolated_leaves(
        get_intersection(
            concatenate_graphs(*graphs, threshold=False)
        )
    )

    # Single-panel figure
    fig, axs = get_mosaic([[0]], scale=6)
    df = plot_individual_edge_comparison(g, sample_ids, ax=axs[0])
    plt.tight_layout()

    # Save
    out_path = f'../plots/edge_comparison_{column}.pdf'
    fig.savefig(out_path, format='pdf', transparent=True, backend='cairo')

Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 1940/1940 [00:00<00:00, 106157.20it/s]
  plt.tight_layout()


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 1940/1940 [00:00<00:00, 107594.61it/s]
  plt.tight_layout()


### 3C Characteristic Curves

In [7]:
# Parameters
contrast_name = 'c06x'
contrast = get_contrast(contrast_name)

# One three-panel figure per column: heatmap, contrast curve, and contrast
# curve sorted by the 'Population' subgroup
for column in [column_diff, column_data]:
    # Compute
    df_subgroup = compute_contrast_summary(contrast, column=column)

    # Plot
    fig, axs = get_mosaic([[0, 1, 2]], scale=9)
    ax_heatmap, ax_curve, ax_curve_population = axs[0], axs[1], axs[2]

    plot_subgroup_heatmap(df_subgroup, ax=ax_heatmap)
    plot_contrast_curve(df_subgroup, ax=ax_curve, legend=False)
    plot_contrast_curve(df_subgroup, sorting_subgroup='Population', ax=ax_curve_population)

    # Save
    out_path = f'../plots/characteristic_curve_{contrast_name}_{column}.pdf'
    fig.savefig(out_path, format='pdf', transparent=True, backend='cairo')

Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 187003/187003 [00:28<00:00, 6533.74it/s]


Filtered from 3703 vertices and 50922 edges to 823 vertices and 5699 edges via common edge filtering.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 63192/63192 [00:04<00:00, 12732.36it/s]


Filtered from 2498 vertices and 23940 edges to 473 vertices and 2955 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5699/5699 [00:00<00:00, 14719.38it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 2955/2955 [00:00<00:00, 28836.66it/s]


Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 250195/250195 [00:46<00:00, 5405.14it/s]


Filtered from 4002 vertices and 62619 edges to 801 vertices and 5752 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5752/5752 [00:00<00:00, 12542.53it/s]


Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 187003/187003 [00:27<00:00, 6712.58it/s]


Filtered from 3703 vertices and 50922 edges to 823 vertices and 5699 edges via common edge filtering.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 63192/63192 [00:04<00:00, 13244.01it/s]


Filtered from 2498 vertices and 23940 edges to 473 vertices and 2955 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5699/5699 [00:00<00:00, 15268.58it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 2955/2955 [00:00<00:00, 28173.75it/s]


Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 250195/250195 [00:45<00:00, 5472.98it/s]


Filtered from 4002 vertices and 62619 edges to 801 vertices and 5752 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5752/5752 [00:00<00:00, 12812.80it/s]


### 4A Group Visualizations

In [8]:
# TODO: Refine this, put into function
# Parameters
contrast_name = 'c06x'
contrast = get_contrast(contrast_name)
max_subjects = 10  # Number of subjects, taken from the start of the group, to aggregate

# Iteration
group = 'AD'
column = column_diff

# Calculate aggregate graph
# `group` and `column` drive both the computation and the output file name,
# so changing either parameter cannot produce a mislabeled figure
all_graphs = [
    compute_graph(load_graph_by_id(sid, column=column))
    for sid in contrast[group][:max_subjects]
]
aggregate_graph = concatenate_graphs(
    *all_graphs,
    threshold=False,
    color_by_source=False,
    remove_duplicate_edge=False,
)

# Plot: single panel with the aggregate graph and no axis decorations
fig, axs = get_mosaic([list(range(1))], scale=9)

ax = axs[0]
visualize_graph_base(aggregate_graph, pos=get_graph_pos(aggregate_graph), mplfig=ax)
ax.axis('off')

fig.savefig(f'../plots/aggregate_{contrast_name}_{group}_{column}.pdf', format='pdf', transparent=True, backend='cairo')

Calculating positions...


### 4B BRAAK Comparison

In [9]:
# TODO: Put into function
# Parameters
contrast_name = 'c06x'
contrast = get_contrast(contrast_name)
min_edge_entries = 100  # An edge must have more than this many non-zero entries to be eligible
num_edges = 5           # Number of eligible edges drawn at random for the plot
seed = 42               # Fixes the random draw so the saved figure is reproducible

# Iteration
column = column_diff

# Calculate
# Flatten the per-group subject lists of the contrast into one list
sids = [sid for group_sids in contrast.values() for sid in group_sids]
# `load_many_graphs` returns the graphs together with the subject ids it returned them for
all_graphs, sids = load_many_graphs(sids, column=column)
all_graphs = [compute_graph(graph) for graph in all_graphs]
df, concatenated_graph = compute_edge_summary(graphs=all_graphs, subject_ids=sids)

# Process: reshape to long form (one row per edge/subject) and attach each subject's BRAAK stage
df = df.drop(columns=['Variance', 'Mean'])
df = pd.melt(df, id_vars=['Edge'], var_name='Subject ID', value_name='Attention')
df_meta = meta.set_index('SubID')[['BRAAK_AD']]
df = (
    df.set_index('Subject ID', drop=False)
    .join(df_meta, how='left')
    .reset_index(drop=True)
)

# Format
df = df.loc[df['Attention'] != 0]  # Remove 0 attention
all_possible_edges, counts = np.unique(df['Edge'], return_counts=True)
all_possible_edges = all_possible_edges[counts > min_edge_entries]
# Local generator: reproducible draw without touching NumPy's global random state.
# Cap the sample size so the draw cannot fail when fewer edges than requested qualify.
rng = np.random.default_rng(seed)
edges_include = rng.choice(
    all_possible_edges,
    min(num_edges, len(all_possible_edges)),
    replace=False,
)
df = df.loc[df['Edge'].isin(edges_include)]

# Plot
fig, axs = get_mosaic([[0, 0]], scale=9)

# Draw on the mosaic axis explicitly instead of relying on pyplot's current axes
ax = axs[0]
sns.violinplot(data=df, hue='BRAAK_AD', y='Attention', x='Edge', ax=ax)
# sns.lineplot(data=df, x='BRAAK_AD', y='Attention', hue='Edge')
sns.despine(ax=ax, offset=10, trim=True)
ax.set_yscale('log')

fig.savefig(f'../plots/BRAAK_{contrast_name}_{column}.pdf', format='pdf', transparent=True, backend='cairo')

Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 250195/250195 [00:45<00:00, 5447.46it/s]


Filtered from 4002 vertices and 62619 edges to 801 vertices and 5752 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5752/5752 [00:00<00:00, 12830.87it/s]


### 4C Cross-Validation Accuracies

In [10]:
# Compute prioritized edges
# Uses `contrast` and `column` as they were set by the preceding cell
# Flatten the per-group subject lists of the contrast into one list
sids = [sid for group_sids in contrast.values() for sid in group_sids]
df_subgroup = compute_contrast_summary(contrast, column=column)
# The prioritized edges are the index of the joined subgroup summary
df = join_df_subgroup(df_subgroup)
prioritized_edges = df.index.tolist()

Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 187003/187003 [00:28<00:00, 6667.24it/s]


Filtered from 3703 vertices and 50922 edges to 823 vertices and 5699 edges via common edge filtering.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 63192/63192 [00:04<00:00, 13344.91it/s]


Filtered from 2498 vertices and 23940 edges to 473 vertices and 2955 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5699/5699 [00:00<00:00, 16090.34it/s]


Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 2955/2955 [00:00<00:00, 31075.15it/s]


Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 250195/250195 [00:45<00:00, 5499.14it/s]


Filtered from 4002 vertices and 62619 edges to 801 vertices and 5752 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5752/5752 [00:00<00:00, 12744.36it/s]


In [11]:
# TODO: Choose more objectives, put into function
from sklearn import metrics
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier  # Alternative classifier, not used below

# Parameters
contrast_name = 'c06x'
contrast = get_contrast(contrast_name)
braak_stages = list(range(7))  # Stages 0-6: fixed rows/columns of the confusion matrix
seed = 42                      # Shared by the train/test split and the classifier

# Iteration
column = column_diff

# Calculate
# Rebuild the subject list from the contrast here, so the cell does not depend
# on (and then overwrite) a `sids` variable left behind by an earlier cell
sids = [sid for group_sids in contrast.values() for sid in group_sids]
all_graphs, sids = load_many_graphs(sids, column=column)
all_graphs = [compute_graph(graph) for graph in all_graphs]
df, concatenated_graph = compute_edge_summary(graphs=all_graphs, subject_ids=sids)

# Filter: keep only the prioritized edges (`prioritized_edges` comes from the preceding cell)
df = df.drop(columns=['Variance', 'Mean'])
df = df.loc[df['Edge'].isin(prioritized_edges)]

# Format
# Column 0 is skipped as the edge label; the remaining columns are treated as
# one column per subject, so after transposing each row of X is a subject
X = np.array(df)[:, 1:].T
subject_columns = list(df.columns)[1:]
# Look up metadata in the same subject order as the rows of X
df_meta = meta.set_index('SubID', drop=False).loc[subject_columns].reset_index(drop=True)
y = np.array(df_meta['BRAAK_AD'])

# Remove nan: subjects without a BRAAK stage cannot be used as labeled samples
is_nan = pd.isna(y)
X = X[~is_nan]
y = y[~is_nan]

# Predict
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, random_state=seed)
# random_state makes the stochastic optimizer, and therefore the figure, reproducible
classifier = SGDClassifier(random_state=seed).fit(X_train, y_train)  # MLPClassifier
accuracy = classifier.score(X_test, y_test)  # Not plotted; kept for interactive inspection

# Evaluate
# Passing `labels` pins the matrix to 7x7 even when a stage is absent from the
# test split; without it the labeled DataFrame below fails on a shape mismatch
confusion_matrix = metrics.confusion_matrix(
    y_test,
    classifier.predict(X_test),
    labels=braak_stages,
)
stage_names = [f'BRAAK {i}' for i in braak_stages]
df_confusion = pd.DataFrame(confusion_matrix, columns=stage_names, index=stage_names)

# Plot
fig, axs = get_mosaic([[0, 0]], scale=9)
ax = axs[0]
sns.heatmap(data=df_confusion, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('True')
fig.savefig(f'../plots/prioritized_{contrast_name}_{column}.pdf', format='pdf', transparent=True, backend='cairo')

Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 250195/250195 [00:45<00:00, 5445.80it/s]


Filtered from 4002 vertices and 62619 edges to 801 vertices and 5752 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 5752/5752 [00:00<00:00, 13269.84it/s]


### Enrichment

In [12]:
# Generate fake enrichment data: one row per (cell type, disease) pair
import itertools
# Get columns
cell_type_names = ['OPC', 'Micro', 'Oligo', 'Endo']
disease_names = ['Alzheimer\'s Disease', 'Schizophrenia', 'Bipolar Disorder', 'A', 'B', 'C', 'D', 'E', 'F', 'G']
combined = list(itertools.product(cell_type_names, disease_names))
cell_type = [pair[0] for pair in combined]
disease = [pair[1] for pair in combined]
# Get significance: seeded pseudo p-values in (exp(-8), 1]
np.random.seed(42)
uniform_draws = np.random.rand(len(combined))
significance = np.exp(-8 * uniform_draws)
# Combine, using the display column names directly
df = pd.DataFrame({'Cell Type': cell_type, 'Disease': disease, 'significance': significance})
# Keep only pairs significant at the 0.05 level
is_significant = df['significance'] < 5e-2
df = df.loc[is_significant]
# Add significance scale
df = df.assign(**{'-log10(p)': -np.log10(df['significance'])})
# Plot
fig, axs = get_mosaic([[0, 0, 0], [0, 0, 0]], scale=3)
plot_enrichment(df, ax=axs[0])
plt.savefig('../plots/enrichment.pdf', format='pdf', transparent=True, backend='cairo')

### Legend

In [13]:
# Plot legend: clear the current figure, draw the legend alone, and export it
plt.clf()
plot_legend()
legend_ax = plt.gca()
legend_ax.axis('off')  # Hide the axis frame and ticks around the legend
plt.tight_layout()
plt.savefig('../plots/legend.pdf', format='pdf', transparent=True, backend='cairo')

  plt.tight_layout()


# Archive

### Coex Individual Trio Comparison

In [14]:
# # Choose three graphs
# graphs = coex_g_individuals[:3]
# graphs_subject_ids = individual_subject_ids[:3]

# # Create figure
# fig, axs = get_mosaic([list(range(len(graphs)+1))], scale=9)

# # Compute edge summaries
# df, concatenated_graph = compute_edge_summary(graphs=graphs, subject_ids=graphs_subject_ids)

# # Show individual graph comparisons
# plot_graph_comparison(graphs, axs=axs, subject_ids=graphs_subject_ids)

# # Show edge summary
# plot_edge_summary(graphs, df=df, ax=axs[len(graphs)], subject_ids=graphs_subject_ids)

# # Save figure
# plt.tight_layout()
# fig.savefig(f'../plots/CoexIndividualTrioComparison.pdf', format='pdf', transparent=True, backend='cairo')

### Individual Trio Comparison

In [15]:
# # Choose three graphs
# graphs = data_g_individuals[:3]
# graphs_subject_ids = individual_subject_ids[:3]

# # Create figure
# plt.clf()
# fig, axs = get_mosaic([list(range(len(graphs)+1))], scale=9)

# # Compute edge summaries
# df, concatenated_graph = compute_edge_summary(graphs=graphs, subject_ids=graphs_subject_ids)

# # Show individual graph comparisons
# plot_graph_comparison(graphs, axs=axs, subject_ids=graphs_subject_ids)

# # Show edge summary
# plot_edge_summary(graphs, df=df, ax=axs[len(graphs)], subject_ids=graphs_subject_ids)

# # Save figure
# plt.tight_layout()
# fig.savefig(f'../plots/IndividualTrioComparison.pdf', format='pdf', transparent=True, backend='cairo')

### Aggregate Trio Comparison

In [16]:
# # Parameters
# contrast = 'c01x'
# column = column_diff

# # Create figure
# fig, axs = get_mosaic([list(range(len(get_contrast(contrast))+1))], scale=9)

# # Compute aggregate edge summaries
# contrast_group = compute_aggregate_edge_summary(get_contrast(contrast), column=column_diff)

# # Plot graph comparison
# plot_graph_comparison(
#     graphs=[v for k, v in contrast_group[0].items()],
#     subject_ids=[k for k, v in contrast_group[1].items()],
#     axs=[axs[i] for i in range(len(get_contrast(contrast)))])

# # Plot edge summary for subgroups
# plot_aggregate_edge_summary(ax=axs[len(get_contrast(contrast))], contrast=contrast_group)

# # Save figure
# plt.tight_layout()
# fig.savefig(f'../plots/AggregateTrioComparison.pdf', format='pdf', transparent=True, backend='cairo')

##### Linkage Analysis

In [17]:
# # Record edge instances
# # df = pd.DataFrame(columns=['Edge', 'Subgroup', 'Count'])
# df = {k: [] for k in ['Edge', 'Subgroup', 'Count']}
# for subgroup in contrast_group[0]:
#     g = contrast_group[0][subgroup]
#     for e in tqdm(g.edges(), total=g.num_edges()):
#         coefs = g.ep.coefs[e]
#         row = [get_edge_string(g, e), subgroup, sum([c!=0 for c in coefs])]
#         # df.loc[df.shape[0]] = row  # Slow
#         for k, v in zip(df, row):
#             df[k].append(v)
# df = pd.DataFrame(df)

# # Get edge counts
# count_table = df.pivot(index='Edge', columns='Subgroup', values='Count')
# count_table = count_table.fillna(0)
# # Max scale for fairness
# for subgroup in contrast_group[0]:
#     count_table[subgroup] /= count_table[subgroup].max()
# # Compute differences
# # TODO: REVISE DIFFERENCE METRIC
# count_table['Difference'] = count_table['AD'] - count_table['Control']
# count_table['Range'] = count_table.max(axis=1) - count_table.min(axis=1)

# # Get list of linkages by significance
# open(f'../plots/AggregateTrioComparisonList.txt', 'w').close()
# for i in np.unique(count_table['Difference'])[::-1]:
#     condition = (count_table['Difference'] == i)
#     significant_edges = list(count_table.loc[condition].index)
#     synthetic_genes = np.concatenate([detect_synthetic_vertices_graph(contrast_group[0][subgroup]) for subgroup in contrast_group[0]])
#     try: significant_genes = np.concatenate([e.split('--') for e in significant_edges])
#     except: significant_genes = []
#     significant_genes = np.unique([g for g in significant_genes if g not in synthetic_genes])

#     # Print significant genes
#     if len(significant_genes) > 0:
#         with open(f'../plots/AggregateTrioComparisonList.txt', 'a') as f:
#             print(f'--- {i} ---', file=f)
#             for g in significant_genes:
#                 print(g, file=f)
#             print(file=f)

### Differentially Expressed Edges

In [18]:
# # TODO: Fix nodes cutting off
# # Plot total and subplots for aggregate differences
# for prefix, individuals in zip(('diff', 'data'), (diff_g_individuals, data_g_individuals)):
#     plt.clf()
#     concat = concatenate_graphs(*individuals)
#     concat = get_intersection(concat)
#     concat = cull_isolated_leaves(concat)
#     concat = remove_text_by_centrality(concat)
#     concat = color_by_significance(concat)
#     visualize_graph(concat)
#     plt.gca().axis('off')
#     plt.tight_layout()
#     plt.savefig(f'../plots/{prefix}_concat.pdf', format='pdf', transparent=True, backend='cairo')

#     # Show all subsets of graph by cell type
#     for v_name in detect_synthetic_vertices_graph(concat):
#         plt.clf()
#         subset = subset_by_hub(concat, [v_name])
#         visualize_graph(subset)
#         plt.gca().axis('off')
#         plt.tight_layout()
#         plt.savefig(f'../plots/{prefix}_concat_{v_name}.pdf', format='pdf', transparent=True, backend='cairo')