In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import graph_tool.all as gt
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from functions import *


# Switch matplotlib to the cairo backend (needed for Graph-Tool compatibility)
plt.switch_backend('cairo')

# Global plotting style
sns.set_theme(context='talk', style='white', palette='Set2')
# Font type 42 for both PDF and PS exports, set in one call
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})

# Computation

In [3]:
# Load metadata
meta = get_meta()

# Subject preview: keep subjects whose graph can be loaded, that have NPS
# information available, and that are female and non-White.
filtered = []
for _, row in meta.iterrows():
    try:
        # Presumably raises when no graph exists for this subject -- the result is discarded
        load_graph_by_id(row['SubID'])
        keep = (
            not np.isnan(row['nps_MoodDysCurValue'])  # Has NPS information available
            and row['Sex'] == 'Female'
            and row['Ethnicity'] != 'White'
            # and row['BRAAK_AD'] in (3, 4, 5)
        )
    except Exception:
        # Best-effort preview: any subject that cannot be loaded or checked is skipped.
        # `Exception` (rather than a bare `except`) lets KeyboardInterrupt/SystemExit through.
        continue
    if not keep:
        continue
    filtered.append(f'{row["SubID"]} {row["Ethnicity"]} {row["Sex"]}, {row["Age"]}, BRAAK {row["BRAAK_AD"]}')

# Show at most the first ten matches; slicing is safe when fewer than ten are found
for line in filtered[:10]:
    print(line)

# Parameters
attention_columns = get_attention_columns()
print(f'\nAvailable attention columns: {attention_columns}')
column_ad = attention_columns[0]
column_data = attention_columns[4]

M48247 Black Female, 95.0, BRAAK 6.0
M41496 Black Female, 76.0, BRAAK 4.0
M19050 Hispanic Female, 74.0, BRAAK 5.0
M61862 Black Female, 79.0, BRAAK 6.0
M59593 Hispanic Female, 76.0, BRAAK 5.0
M83214 Hispanic Female, 83.0, BRAAK 6.0
M36634 Hispanic Female, 87.0, BRAAK 6.0
M46196 Black Female, 80.0, BRAAK 6.0
M63213 Black Female, 78.0, BRAAK 6.0
M51374 Black Female, 72.0, BRAAK 2.0

Available attention columns: ['att_D_AD_0_1', 'att_D_AD_0_3', 'att_D_AD_0_5', 'att_D_AD_0_7', 'att_D_no_prior_0', 'att_D_no_prior_1', 'att_D_no_prior_2', 'att_D_no_prior_3']


# Plots

In [5]:
# Broad parameters
# Vertex ids handed to `filter_graph_by_synthetic_vertices` in the plotting cells
synthetic_nodes_of_interest = [
    'OPC',
    'Micro',
    'Oligo',
    'EN',
]

## Individual Comparisons (Figure 3)

In [6]:
# Subject pairs compared in Figure 3.
# Each entry is (subject_id_1, subject_id_2, column); unpack with
#   for subject_id_1, subject_id_2, column in individual_comparisons:
#
# Subjects referenced:
#   M19050 Hispanic Female, 74.0, BRAAK 5.0
#   M59593 Hispanic Female, 76.0, BRAAK 5.0
#   M72079 Black Female, 64.0, BRAAK 6.0
#   M41496 Black Female, 76.0, BRAAK 4.0
#   M11589 Black Female, 63.0, BRAAK 2.0
#   M73342 Black Female, 62.0, BRAAK 0.0
individual_comparisons = [
    ('M19050', 'M59593', column_ad),  # AD - AD
    ('M72079', 'M41496', column_ad),  # AD - High BRAAK
    ('M72079', 'M11589', column_ad),  # AD - Low BRAAK
    ('M72079', 'M73342', column_ad),  # AD - CTRL
]

# Verify all are available: load each referenced subject once so that a
# problem surfaces here rather than in the middle of a plotting cell
for subject_id_1, subject_id_2, column in individual_comparisons:
    for sid in (subject_id_1, subject_id_2):
        load_graph_by_id(sid)

### 3A Mini Plots

In [7]:
# 3A: side-by-side filtered graphs for every subject pair, one PDF per pair
for subject_id_1, subject_id_2, column in individual_comparisons:
    print(' - '.join((subject_id_1, subject_id_2, column)))

    # Assemble
    sids = [subject_id_1, subject_id_2]
    gs = [compute_graph(load_graph_by_id(sid, column=column)) for sid in sids]

    # Filter (on copies, so the computed graphs themselves are not passed to the filter)
    gs = [
        filter_graph_by_synthetic_vertices(g.copy(), vertex_ids=synthetic_nodes_of_interest)
        for g in gs
    ]

    # Plot
    fig, axs = get_mosaic([list(range(2))], scale=9)
    plot_graph_comparison(gs, axs=axs, subject_ids=sids)
    fig.savefig(f'../plots/individual_mini_{"-".join(sids)}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

M19050 - M59593 - att_D_AD_0_1
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 218372.09it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 69/69 [00:00<00:00, 198495.87it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 71/71 [00:00<00:00, 225431.93it/s]



M72079 - M41496 - att_D_AD_0_1
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 147/147 [00:00<00:00, 271016.57it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 75/75 [00:00<00:00, 209715.20it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 199464.92it/s]



M72079 - M11589 - att_D_AD_0_1
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 147/147 [00:00<00:00, 253416.64it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 75/75 [00:00<00:00, 208050.79it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 72/72 [00:00<00:00, 174359.06it/s]



M72079 - M73342 - att_D_AD_0_1
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████████| 143/143 [00:00<00:00, 214285.63it/s]


Calculating positions...
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 75/75 [00:00<00:00, 206548.13it/s]


Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████████| 68/68 [00:00<00:00, 147396.73it/s]





### 3B Attention Comparisons

In [None]:
# 3B: per-edge comparison between the two subjects of every pair, one PDF per pair
for subject_id_1, subject_id_2, column in individual_comparisons:
    print(' - '.join((subject_id_1, subject_id_2, column)))

    # Assemble
    sample_ids = [subject_id_1, subject_id_2]
    graphs = [compute_graph(load_graph_by_id(sid, column=column)) for sid in sample_ids]

    # Get graph: concatenate both subjects, then apply intersection and leaf culling
    g = concatenate_graphs(*graphs, threshold=False)
    g = get_intersection(g)
    g = cull_isolated_leaves(g)

    fig, axs = get_mosaic([list(range(1))], scale=6)
    df = plot_individual_edge_comparison(g, sample_ids, ax=axs[0])
    plt.tight_layout()
    fig.savefig(f'../plots/individual_edge_comparison_{"-".join((subject_id_1, subject_id_2))}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

M19050 - M59593 - att_D_AD_0_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 1940/1940 [00:00<00:00, 156233.43it/s]
  plt.tight_layout()



M72079 - M41496 - att_D_AD_0_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 1737/1737 [00:00<00:00, 147107.64it/s]
  plt.tight_layout()



M72079 - M11589 - att_D_AD_0_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 2080/2080 [00:00<00:00, 165899.41it/s]
  plt.tight_layout()



M72079 - M73342 - att_D_AD_0_1
Removing duplicate edges...


100%|███████████████████████████████████████████████████████████████████████████| 2239/2239 [00:00<00:00, 174389.46it/s]
  plt.tight_layout()





### 3C Head Comparison

In [8]:
# Parameters
# NOTE(review): `num_edges_per_head` and `all_columns` are not passed to anything in
# this cell -- confirm whether `plot_head_comparison` was meant to receive them
num_edges_per_head = 2

# Compute
all_columns = get_attention_columns()
# Unpack `column` from the tuple itself; previously it was discarded as `_` and the
# body silently relied on a `column` left behind by an earlier cell
for subject_id_1, subject_id_2, column in individual_comparisons:
    print(' - '.join((subject_id_1, subject_id_2, column)))

    # Plot (`column` is not passed to the plot; it only labels the log line and the file name)
    fig, axs = get_mosaic([list(range(1))], scale=9)
    plot_head_comparison(subject_id_1, subject_id_2, ax=axs[0])
    plt.tight_layout()
    fig.savefig(f'../plots/individual_head_comparison_{"-".join((subject_id_1, subject_id_2))}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

M19050 - M59593 - att_D_AD_0_1
879 common edges found

M72079 - M41496 - att_D_AD_0_1
695 common edges found

M72079 - M11589 - att_D_AD_0_1
1038 common edges found

M72079 - M73342 - att_D_AD_0_1
1042 common edges found



### 3D Population Variance Analysis

In [81]:
# Get variance for each head over all edges
all_columns = get_attention_columns()
subject_ids = list(meta['SubID'])
value_name = 'Variance'

# Collect one long-format frame per head and concatenate once at the end;
# growing a DataFrame with `pd.concat` inside the loop re-copies all rows every iteration
head_frames = []
for column in all_columns:
    print(column)

    # Compute variance
    df_subgroup = compute_contrast_summary(
        {'Population': subject_ids},
        column=column,
        population=False)

    # Format df
    df = df_subgroup['Population'][['Edge', value_name]].copy()
    df['Head'] = column
    head_frames.append(df[['Edge', 'Head', value_name]])

    print()

if head_frames:
    df_all = pd.concat(head_frames)
else:
    # No attention columns: keep the empty frame with the expected columns
    df_all = pd.DataFrame(columns=['Edge', 'Head', value_name])

# Format: one row per edge, one column per head
df_all_pivot = df_all.pivot(index='Edge', columns='Head', values=value_name)

# Get top edges
idx_to_include = get_top_idx(df_all_pivot.abs(), all_columns, num_edges_per_head=2)

# Plot curves and circle heatmap
fig, axs = get_mosaic([list(range(2))], scale=9)
plot_contrast_curve(df_all, subgroup_name='Head', value_name=value_name, sorting_subgroup='Mean', concatenate=False, ax=axs[0])
axs[0].set_ylabel('Attention Variance')
plot_circle_heatmap(df_all_pivot.iloc[idx_to_include], column_name='Head', index_name='Edge', value_name=value_name, ax=axs[1])
axs[1].set_xlabel(None)
axs[1].set_ylabel(None)
plt.tight_layout()
fig.savefig('../plots/individual_population_curves.pdf', format='pdf', transparent=True, backend='cairo')

  plt.tight_layout()


### 3E SNP Trend Analysis

In [52]:
# Get top variant edges across population, and values for all heads
# TODO

# For all edges
# TODO
    # Get 100 closest SNPs for TG
    # TODO

    # For all close SNPs
    # TODO
        # For all heads
        # TODO
            # Assess mean separation across SNP variants
            # TODO

            # Add mean separation, edge, SNP, head to LIST
            # TODO
    
# Choose best mean separation results from LIST
# TODO

# Plot SNP variants vs attention
# TODO

In [75]:
# Fake data placeholder: three random SNP box plots until the real analysis exists
n = 200  # rows per fake dataframe
variants = ['AATG', 'ACTG', 'GATG', 'GGTG', 'AATA', 'AAGG']
heads = get_attention_columns()
for i in range(3):
    # A different but fixed seed per figure keeps every plot reproducible
    np.random.seed(42+i)

    # Fake dataframe for a single SNP.
    # The order of the random draws below is kept fixed so the generated data stay the same.
    chromosome = np.random.randint(23) + 1
    start = np.random.randint(10000, 1000000)
    locus = f'{chromosome}:{start}-{start+len(variants[0])}'
    df = pd.DataFrame({
        'Subject ID': np.random.choice(range(n // 10), n, replace=True),
        'SNP Variant': np.random.choice(variants, n, replace=True),
        'Head': np.random.choice(heads, n, replace=True),
        'Attention Weight': np.random.rand(n) / 10,
    })

    # Plot
    fig, axs = get_mosaic([2*[0]], scale=9)

    sns.boxplot(data=df, x='Head', y='Attention Weight', hue='SNP Variant', ax=axs[0])
    axs[0].set_title(f'SNP: {locus}')
    plt.xticks(rotation=90)
    plt.legend(bbox_to_anchor=(1.02, .7), loc='upper left', borderaxespad=0, frameon=False)

    plt.tight_layout()
    fig.savefig(f'../plots/individual_SNP_analysis_{i}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)

  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


## Group Comparisons (Figure 4)

In [10]:
# Combinations
# Each entry is (contrast name, contrast group, attention column, target meta column); unpack with
#   for contrast_name, contrast_group, column, target in contrast_groupings:
contrast_groupings = [
    # Eventually SCZ, BP and such
    ('c06x', 'AD', column_ad, 'BRAAK_AD'),
    # Dysphoria
    ('c71x', 'MoodDys', column_data, 'nps_MoodDysCurValue'),
    # Anhedonia
    ('c72x', 'DecInt', column_data, 'nps_DecIntCurValue'),
]

### 4A Group Visualizations

In [11]:
# 4A: aggregate graph of one contrast group per entry of `contrast_groupings`, one PDF each
for contrast_name, contrast_group, column, _ in contrast_groupings:
    print(' - '.join((contrast_name, contrast_group, column)))
    # Get contrast
    contrast = get_contrast(contrast_name)

    # Calculate aggregate graph.
    # Load with this entry's own `column` (previously always `column_ad`), so the graphs
    # match the column printed above and embedded in the output file name.
    all_graphs, _ = load_many_graphs(contrast[contrast_group], column=column)
    all_graphs = [compute_graph(graph) for graph in all_graphs]
    aggregate_graph = concatenate_graphs(
        *all_graphs,
        threshold=True,
        color_by_source=False,
        # remove_duplicate_edge=False,
    )

    # Filter graph
    aggregate_graph = filter_graph_by_synthetic_vertices(aggregate_graph, vertex_ids=synthetic_nodes_of_interest)

    # Plot
    fig, axs = get_mosaic([list(range(1))], scale=9)

    ax = axs[0]
    visualize_graph_base(aggregate_graph, pos=get_graph_pos(aggregate_graph), mplfig=ax)
    ax.axis('off')

    fig.savefig(f'../plots/group_aggregate_{contrast_name}_{contrast_group}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

c06x - AD - att_D_AD_0_1
No threshold provided, using threshold of 0.028856692729820266.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 187003/187003 [00:02<00:00, 77295.00it/s]


Filtered from 3703 vertices and 50922 edges to 519 vertices and 3444 edges via common edge filtering.
Calculating positions...

c71x - MoodDys - att_D_no_prior_0
No threshold provided, using threshold of 0.19078570709222195.
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 10968/10968 [00:00<00:00, 100579.66it/s]


Filtered from 872 vertices and 6108 edges to 107 vertices and 896 edges via common edge filtering.
Calculating positions...

c72x - DecInt - att_D_no_prior_0
No threshold provided, using threshold of 0.17215301886965925.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 12862/12862 [00:00<00:00, 97944.67it/s]


Filtered from 982 vertices and 6839 edges to 122 vertices and 987 edges via common edge filtering.
Calculating positions...



### 4B BRAAK Comparison

In [12]:
# 4B: comparison plot against the target metadata column of every contrast, one PDF each
for contrast_name, _, column, target in contrast_groupings:
    print(' - '.join((contrast_name, column, target)))
    # Get contrast
    contrast = get_contrast(contrast_name)

    # Plot
    fig, axs = get_mosaic([[0, 0]], scale=9)

    # NOTE(review): no axes are passed here, so the plot presumably lands on pyplot's
    # current axes (the ones just created) -- confirm against `plot_BRAAK_comparison`
    df = plot_BRAAK_comparison(contrast, meta=meta, column=column, target=target)

    fig.savefig(f'../plots/group_differential_expression_{contrast_name}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

c06x - att_D_AD_0_1 - BRAAK_AD
No threshold provided, using threshold of 0.023704617839825772.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 250195/250195 [00:02<00:00, 99699.86it/s]


Filtered from 4002 vertices and 62619 edges to 556 vertices and 3837 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 3837/3837 [00:01<00:00, 3750.65it/s]



c71x - att_D_no_prior_0 - nps_MoodDysCurValue
No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 93368.20it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 10334.68it/s]



c72x - att_D_no_prior_0 - nps_DecIntCurValue
No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 94892.25it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 7944.93it/s]





### 4C Cross-Validation Accuracies

In [13]:
# 4C: confusion plot of a classifier trained on prioritized edges, one PDF per contrast
# TODO: Make all y-labels horizontal
for contrast_name, _, column, target in contrast_groupings:
    print(' - '.join((contrast_name, column, target)))
    # Get contrast
    contrast = get_contrast(contrast_name)

    # Compute prioritized edges
    # TODO: Revise this method, maybe also consider means
    df_subgroup = compute_contrast_summary(contrast, column=column)
    df = join_df_subgroup(df_subgroup, num_sort=100)
    prioritized_edges = list(df.index)

    # Plot
    # TODO: Maybe return to row-normalization
    fig, axs = get_mosaic([[0, 0]], scale=9)
    df, acc = plot_prediction_confusion(contrast, meta=meta, column=column, target=target, prioritized_edges=prioritized_edges, classifier_type='SGD', ax=axs[0])
    fig.savefig(f'../plots/group_prioritized_edge_prediction_{contrast_name}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

c06x - att_D_AD_0_1 - BRAAK_AD
No threshold provided, using threshold of 0.028856692729820266.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 187003/187003 [00:01<00:00, 96353.09it/s]


Filtered from 3703 vertices and 50922 edges to 519 vertices and 3444 edges via common edge filtering.
No threshold provided, using threshold of 0.058875504282979364.
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 63192/63192 [00:00<00:00, 100404.37it/s]


Filtered from 2498 vertices and 23940 edges to 255 vertices and 1640 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 3444/3444 [00:00<00:00, 4293.14it/s]


Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1640/1640 [00:00<00:00, 8580.04it/s]


No threshold provided, using threshold of 0.023704617839825772.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 250195/250195 [00:02<00:00, 97648.00it/s]


Filtered from 4002 vertices and 62619 edges to 556 vertices and 3837 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 3837/3837 [00:01<00:00, 3139.73it/s]


No threshold provided, using threshold of 0.023704617839825772.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 250195/250195 [00:02<00:00, 98858.19it/s]


Filtered from 4002 vertices and 62619 edges to 556 vertices and 3837 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 3837/3837 [00:01<00:00, 3487.47it/s]



c71x - att_D_no_prior_0 - nps_MoodDysCurValue
No threshold provided, using threshold of 0.07789234502208577.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 42021/42021 [00:00<00:00, 97184.07it/s]


Filtered from 2059 vertices and 17327 edges to 233 vertices and 1503 edges via common edge filtering.
No threshold provided, using threshold of 0.19078570709222195.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 10968/10968 [00:00<00:00, 91621.26it/s]


Filtered from 872 vertices and 6108 edges to 107 vertices and 896 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 1503/1503 [00:00<00:00, 10048.42it/s]


Collecting edges...


100%|██████████████████████████████████████████████████████████████████████████████| 896/896 [00:00<00:00, 17626.52it/s]


No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 84108.33it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 7961.97it/s]


No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 99799.49it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 7285.79it/s]



c72x - att_D_no_prior_0 - nps_DecIntCurValue
No threshold provided, using threshold of 0.08023518943922868.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 40127/40127 [00:00<00:00, 80555.57it/s]


Filtered from 2015 vertices and 16803 edges to 225 vertices and 1458 edges via common edge filtering.
No threshold provided, using threshold of 0.17215301886965925.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 12862/12862 [00:00<00:00, 96744.30it/s]


Filtered from 982 vertices and 6839 edges to 122 vertices and 987 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 1458/1458 [00:00<00:00, 11914.31it/s]


Collecting edges...


100%|██████████████████████████████████████████████████████████████████████████████| 987/987 [00:00<00:00, 13554.56it/s]


No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 105508.15it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 9741.54it/s]
  fig = plt.figure(figsize=(scale*len(mosaic[0]), scale*len(mosaic)), constrained_layout=True)


No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 91269.53it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 7586.58it/s]





### 4D Characteristic Curves

In [16]:
# 4D: subgroup heatmap next to the mean-sorted contrast curve, one PDF per contrast
# TODO: Separate into two plots
# Get plots for each column
for contrast_name, _, column, _ in contrast_groupings:
    print(' - '.join((contrast_name, column)))

    # Get contrast
    contrast = get_contrast(contrast_name)

    # Compute
    df_subgroup = compute_contrast_summary(contrast, column=column)

    # Plot
    fig, axs = get_mosaic([list(range(2))], scale=9)

    # Plot heatmap, individually-sorted, and pop-sorted
    # plot_subgroup_heatmap(df_subgroup, ax=axs[0])
    # axs[1].get_shared_x_axes().join(axs[1], axs[2])
    # plot_contrast_curve(df_subgroup, ax=axs[1], legend=False)  # Individually sorted
    # plot_contrast_curve(df_subgroup, sorting_subgroup='Population', ax=axs[1])  # Population sorted
    # axs[2].set_ylabel(None)

    # Plot mean-sorted
    plot_subgroup_heatmap(df_subgroup, ax=axs[0])
    plot_contrast_curve(df_subgroup, sorting_subgroup='Mean', ax=axs[1])  # Mean sorted
    plt.tight_layout()
    fig.savefig(f'../plots/group_characteristic_curves_{contrast_name}_{column}.pdf', format='pdf', transparent=True, backend='cairo')
    # pyplot keeps every figure alive until it is closed; release it once saved
    plt.close(fig)
    print()

c06x - att_D_AD_0_1
No threshold provided, using threshold of 0.028856692729820266.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 187003/187003 [00:01<00:00, 96744.01it/s]


Filtered from 3703 vertices and 50922 edges to 519 vertices and 3444 edges via common edge filtering.
No threshold provided, using threshold of 0.058875504282979364.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 63192/63192 [00:00<00:00, 90166.52it/s]


Filtered from 2498 vertices and 23940 edges to 255 vertices and 1640 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 3444/3444 [00:00<00:00, 3915.28it/s]


Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1640/1640 [00:00<00:00, 7130.03it/s]


No threshold provided, using threshold of 0.023704617839825772.
Removing duplicate edges...


100%|████████████████████████████████████████████████████████████████████████| 250195/250195 [00:02<00:00, 97289.62it/s]


Filtered from 4002 vertices and 62619 edges to 556 vertices and 3837 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 3837/3837 [00:01<00:00, 3366.70it/s]



c71x - att_D_no_prior_0
No threshold provided, using threshold of 0.07789234502208577.
Removing duplicate edges...


100%|█████████████████████████████████████████████████████████████████████████| 42021/42021 [00:00<00:00, 104287.53it/s]


Filtered from 2059 vertices and 17327 edges to 233 vertices and 1503 edges via common edge filtering.
No threshold provided, using threshold of 0.19078570709222195.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 10968/10968 [00:00<00:00, 78228.23it/s]


Filtered from 872 vertices and 6108 edges to 107 vertices and 896 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 1503/1503 [00:00<00:00, 11350.75it/s]


Collecting edges...


100%|██████████████████████████████████████████████████████████████████████████████| 896/896 [00:00<00:00, 17737.16it/s]


No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 91409.02it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 8756.49it/s]



c72x - att_D_no_prior_0
No threshold provided, using threshold of 0.08023518943922868.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 40127/40127 [00:00<00:00, 97857.45it/s]


Filtered from 2015 vertices and 16803 edges to 225 vertices and 1458 edges via common edge filtering.
No threshold provided, using threshold of 0.17215301886965925.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 12862/12862 [00:00<00:00, 92103.19it/s]


Filtered from 982 vertices and 6839 edges to 122 vertices and 987 edges via common edge filtering.
Collecting edges...


100%|████████████████████████████████████████████████████████████████████████████| 1458/1458 [00:00<00:00, 11949.02it/s]

Collecting edges...



100%|██████████████████████████████████████████████████████████████████████████████| 987/987 [00:00<00:00, 18582.36it/s]


No threshold provided, using threshold of 0.06673925242461716.
Removing duplicate edges...


100%|██████████████████████████████████████████████████████████████████████████| 52989/52989 [00:00<00:00, 92997.87it/s]


Filtered from 2282 vertices and 20744 edges to 282 vertices and 1765 edges via common edge filtering.
Collecting edges...


100%|█████████████████████████████████████████████████████████████████████████████| 1765/1765 [00:00<00:00, 8344.65it/s]





## General

### Enrichment

In [15]:
# Generate fake enrichment data
import itertools

# Categories: every cell type is paired with every disease
cell_types = synthetic_nodes_of_interest
diseases = ["Alzheimer's", 'Schizophrenia', 'Bipolar', 'Depression', 'Weight Loss', 'Sleeplessness']
pairs = list(itertools.product(cell_types, diseases))

# Get significance: seeded random values in (0, 1], one per pair
np.random.seed(42)
significance = np.exp(-8 * np.random.rand(len(pairs)))

# Combine into one row per (cell type, disease) pair
df = pd.DataFrame({
    'cell_type': [cell_type for cell_type, _ in pairs],
    'disease': [disease for _, disease in pairs],
    'significance': significance,
})
# Keep only rows below the 5e-2 significance cutoff
df = df.loc[df['significance'] < 5e-2]

# Rename
df = df.rename(columns={'cell_type': 'Cell Type', 'disease': 'Disease'})
# Add significance scale
df['-log10(p)'] = -np.log10(df['significance'])

# Plot
fig, axs = get_mosaic([[0, 0, 0], [0, 0, 0]], scale=3)
plot_enrichment(df, ax=axs[0])
# Save through the figure handle (as the other cells do) instead of pyplot's current figure
fig.savefig('../plots/enrichment.pdf', format='pdf', transparent=True, backend='cairo')

### Legend

In [17]:
# Plot legend
# NOTE(review): no figure is created here; `plt.clf()` clears pyplot's current figure,
# so the legend is presumably drawn on whatever figure the previously executed
# plotting cell left active -- confirm that the resulting size is intended
plt.clf()
plot_legend()
# Hide the axes frame and ticks so that only the legend content remains
plt.gca().axis('off')
plt.tight_layout()
plt.savefig(f'../plots/graph_legend.pdf', format='pdf', transparent=True, backend='cairo')

  plt.tight_layout()


# Archive

### Coex Individual Trio Comparison

In [None]:
# # Choose three graphs
# graphs = coex_g_individuals[:3]
# graphs_subject_ids = individual_subject_ids[:3]

# # Create figure
# fig, axs = get_mosaic([list(range(len(graphs)+1))], scale=9)

# # Compute edge summaries
# df, concatenated_graph = compute_edge_summary(graphs=graphs, subject_ids=graphs_subject_ids)

# # Show individual graph comparisons
# plot_graph_comparison(graphs, axs=axs, subject_ids=graphs_subject_ids)

# # Show edge summary
# plot_edge_summary(graphs, df=df, ax=axs[len(graphs)], subject_ids=graphs_subject_ids)

# # Save figure
# plt.tight_layout()
# fig.savefig(f'../plots/CoexIndividualTrioComparison.pdf', format='pdf', transparent=True, backend='cairo')

### Individual Trio Comparison

In [None]:
# # Choose three graphs
# graphs = data_g_individuals[:3]
# graphs_subject_ids = individual_subject_ids[:3]

# # Create figure
# plt.clf()
# fig, axs = get_mosaic([list(range(len(graphs)+1))], scale=9)

# # Compute edge summaries
# df, concatenated_graph = compute_edge_summary(graphs=graphs, subject_ids=graphs_subject_ids)

# # Show individual graph comparisons
# plot_graph_comparison(graphs, axs=axs, subject_ids=graphs_subject_ids)

# # Show edge summary
# plot_edge_summary(graphs, df=df, ax=axs[len(graphs)], subject_ids=graphs_subject_ids)

# # Save figure
# plt.tight_layout()
# fig.savefig(f'../plots/IndividualTrioComparison.pdf', format='pdf', transparent=True, backend='cairo')

### Aggregate Trio Comparison

In [None]:
# # Parameters
# contrast = 'c01x'
# column = column_ad

# # Create figure
# fig, axs = get_mosaic([list(range(len(get_contrast(contrast))+1))], scale=9)

# # Compute aggregate edge summaries
# contrast_group = compute_aggregate_edge_summary(get_contrast(contrast), column=column_ad)

# # Plot graph comparison
# plot_graph_comparison(
#     graphs=[v for k, v in contrast_group[0].items()],
#     subject_ids=[k for k, v in contrast_group[1].items()],
#     axs=[axs[i] for i in range(len(get_contrast(contrast)))])

# # Plot edge summary for subgroups
# plot_aggregate_edge_summary(ax=axs[len(get_contrast(contrast))], contrast=contrast_group)

# # Save figure
# plt.tight_layout()
# fig.savefig(f'../plots/AggregateTrioComparison.pdf', format='pdf', transparent=True, backend='cairo')

##### Linkage Analysis

In [None]:
# # Record edge instances
# # df = pd.DataFrame(columns=['Edge', 'Subgroup', 'Count'])
# df = {k: [] for k in ['Edge', 'Subgroup', 'Count']}
# for subgroup in contrast_group[0]:
#     g = contrast_group[0][subgroup]
#     for e in tqdm(g.edges(), total=g.num_edges()):
#         coefs = g.ep.coefs[e]
#         row = [get_edge_string(g, e), subgroup, sum([c!=0 for c in coefs])]
#         # df.loc[df.shape[0]] = row  # Slow
#         for k, v in zip(df, row):
#             df[k].append(v)
# df = pd.DataFrame(df)

# # Get edge counts
# count_table = df.pivot(index='Edge', columns='Subgroup', values='Count')
# count_table = count_table.fillna(0)
# # Max scale for fairness
# for subgroup in contrast_group[0]:
#     count_table[subgroup] /= count_table[subgroup].max()
# # Compute differences
# # TODO: REVISE DIFFERENCE METRIC
# count_table['Difference'] = count_table['AD'] - count_table['Control']
# count_table['Range'] = count_table.max(axis=1) - count_table.min(axis=1)

# # Get list of linkages by significance
# open(f'../plots/AggregateTrioComparisonList.txt', 'w').close()
# for i in np.unique(count_table['Difference'])[::-1]:
#     condition = (count_table['Difference'] == i)
#     significant_edges = list(count_table.loc[condition].index)
#     synthetic_genes = np.concatenate([detect_synthetic_vertices_graph(contrast_group[0][subgroup]) for subgroup in contrast_group[0]])
#     try: significant_genes = np.concatenate([e.split('--') for e in significant_edges])
#     except: significant_genes = []
#     significant_genes = np.unique([g for g in significant_genes if g not in synthetic_genes])

#     # Print significant genes
#     if len(significant_genes) > 0:
#         with open(f'../plots/AggregateTrioComparisonList.txt', 'a') as f:
#             print(f'--- {i} ---', file=f)
#             for g in significant_genes:
#                 print(g, file=f)
#             print(file=f)

### Differentially Expressed Edges

In [None]:
# # TODO: Fix nodes cutting off
# # Plot total and subplots for aggregate differences
# for prefix, individuals in zip(('diff', 'data'), (diff_g_individuals, data_g_individuals)):
#     plt.clf()
#     concat = concatenate_graphs(*individuals)
#     concat = get_intersection(concat)
#     concat = cull_isolated_leaves(concat)
#     concat = remove_text_by_centrality(concat)
#     concat = color_by_significance(concat)
#     visualize_graph(concat)
#     plt.gca().axis('off')
#     plt.tight_layout()
#     plt.savefig(f'../plots/{prefix}_concat.pdf', format='pdf', transparent=True, backend='cairo')

#     # Show all subsets of graph by cell type
#     for v_name in detect_synthetic_vertices_graph(concat):
#         plt.clf()
#         subset = subset_by_hub(concat, [v_name])
#         visualize_graph(subset)
#         plt.gca().axis('off')
#         plt.tight_layout()
#         plt.savefig(f'../plots/{prefix}_concat_{v_name}.pdf', format='pdf', transparent=True, backend='cairo')