In [None]:
import numpy as np
import networkx as nx
from dotmotif import Motif, GrandIsoExecutor
import matplotlib.pyplot as plt
from scipy.stats import ranksums
import statsmodels.stats.multitest as smm
from lsmm_data import LSMMData
import json
from tqdm import tqdm

# Load the LSMM data object described by pyr_chains.json and expose its
# tables, params, dirs and mappings as notebook globals.
with open('pyr_chains.json') as json_file:
    loaded_json = json.load(json_file)
my_data = LSMMData.LSMMData(loaded_json)
tables, params, dirs, mappings = my_data.data, my_data.params, my_data.dirs, my_data.mappings
 

## Load the precomputed chain results (to regenerate them instead, uncomment
## and run the DotMotif cell below).
two_chain_results_array = np.load('pyr_cell_two_chain_results_array.npy')
three_chain_results_array = np.load('pyr_cell_three_chain_results_array.npy')
four_chain_results_array = np.load('pyr_cell_four_chain_results_array.npy')
chain_results_arrays = [two_chain_results_array, three_chain_results_array, four_chain_results_array]
chain_count_string_array = ['pyr_cell_2chain', 'pyr_cell_3chain', 'pyr_cell_4chain']

N_ASSEMBLIES = 13  # assemblies 'A 1' ... 'A 13' are analysed in this notebook
individual_assembly_indexes = [mappings['connectome_indexes_by_assembly'][f'A {i}'] for i in range(1, N_ASSEMBLIES + 1)]


def possible_chain_count(assembly_size, chain_length, n_cells):
    """Scaling factor used to normalize an assembly's chain counts.

    Returns assembly_size * (assembly_size - 1) multiplied by
    (n_cells - 2) * (n_cells - 3) * ... * (n_cells - chain_length + 1):
    ordered pairs of distinct assembly cells for the chain endpoints, times
    the choices of intermediate cells drawn from all `n_cells` cells.
    For chain_length == 2 there are no intermediate cells.
    """
    count = assembly_size * (assembly_size - 1)
    for k in range(2, chain_length):
        count *= n_cells - k
    return count


n_cells = len(tables['structural']['pre_cell'])
scaling_factors_two = [possible_chain_count(len(indexes), 2, n_cells) for indexes in individual_assembly_indexes]
scaling_factors_three = [possible_chain_count(len(indexes), 3, n_cells) for indexes in individual_assembly_indexes]
scaling_factors_four = [possible_chain_count(len(indexes), 4, n_cells) for indexes in individual_assembly_indexes]
scaling_factors_lists = [scaling_factors_two, scaling_factors_three, scaling_factors_four]

coregistered_cell_indexes = mappings['assemblies_by_connectome_index'].keys()
no_a_cell_indexes = mappings['connectome_indexes_by_assembly']['No A']




In [None]:

# ## Uncomment here to generate the chain results
# # Make a graph of just excitatory cells
# cell_table = tables['structural']['pre_cell']
# cell_table['connectome_index'] = cell_table.index
# synapse_table = tables['structural']['synapse']
# adjacency_matrix = tables['structural']['binary_connectome']
# pyr_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)

# # Motif Analysis with DotMotif
# executor = GrandIsoExecutor(graph=pyr_graph)

# chain_defs = Motif("""
#                 A -> B
#               """)

# chain_results = executor.find(chain_defs)

# chain_results_array = np.array([list(c.values()) for c in tqdm(chain_results)])
# np.save('pyr_cell_two_chain_results_array.npy', chain_results_array)
# np.save('pyr_cell_two_chain_chain_results.npy', chain_results)

In [None]:
# Intra-Assembly to Intra-Assembly chains
#
# For each chain length (2, 3 and 4 cells), count for every ordered pair of
# co-registered cells how many chains start at the first cell and end at the
# last one. Per-length counts are rebuilt each iteration in
# `chain_participation_by_coregistered_cell_pair`; counts summed over all
# lengths accumulate in `multi_length_chain_participation_by_coregistered_cell_pair`.
# Pairs inside each assembly (and pooled over all assembly cells) are compared
# with pairs of cells that belong to no assembly using a Wilcoxon rank-sum test.

from collections import Counter

import scipy.stats as stats

print(scaling_factors_lists)

multi_length_chain_participation_by_coregistered_cell_pair = {}

pooled_assembly_indexes = list(set(coregistered_cell_indexes) - set(no_a_cell_indexes))

for chain_results_array, title, scaling_factors in zip(chain_results_arrays, chain_count_string_array, scaling_factors_lists):

    # Tally chains by (first cell, last cell) in one pass over the results,
    # instead of re-scanning the whole array once for every pair of cells.
    endpoint_counts = Counter(zip(chain_results_array[:, 0].tolist(),
                                  chain_results_array[:, -1].tolist()))

    # Get chain participation
    chain_participation_by_coregistered_cell_pair = {}
    for index1 in tqdm(coregistered_cell_indexes):
        pair_counts = chain_participation_by_coregistered_cell_pair.setdefault(index1, {})
        multi_length_counts = multi_length_chain_participation_by_coregistered_cell_pair.setdefault(index1, {})
        for index2 in coregistered_cell_indexes:
            temp = endpoint_counts[(index1, index2)]  # Counter yields 0 for unseen pairs
            pair_counts[index2] = pair_counts.get(index2, 0) + temp
            multi_length_counts[index2] = multi_length_counts.get(index2, 0) + temp

    # Pool cells which are not in assemblies
    no_a_cell_participation = [
        chain_participation_by_coregistered_cell_pair[index1][index2]
        for index1 in tqdm(no_a_cell_indexes)
        for index2 in no_a_cell_indexes
    ]

    # Examine individual intra-assembly participation for an individual chain length.
    # NOTE(review): intra-assembly values are divided by the assembly's scaling
    # factor, while the No A values they are tested against are raw counts, so the
    # two samples are on different scales -- confirm this is intended.
    print(f'Wilcoxon Rank Sum Test: \n\t{title} from cells within each assembly to other cells within that assembly\n\tvs {title} from cells not in assemblies to other cells not in assemblies')
    per_assembly_cell_participation = []
    for a, assembly_indexes in enumerate(individual_assembly_indexes):
        scaling_factor = scaling_factors[a]
        if scaling_factor > 0:
            # Both ends of every pair are members of assembly `a`.
            assembly_participation = [
                chain_participation_by_coregistered_cell_pair[source][target] / scaling_factor
                for source in assembly_indexes
                for target in assembly_indexes
            ]
        else:
            # A non-positive scaling factor (e.g. an assembly with fewer than two
            # cells) leaves nothing to normalize; the sample stays empty.
            assembly_participation = []
        per_assembly_cell_participation.append(assembly_participation)

        plt.figure()
        plt.boxplot([assembly_participation, no_a_cell_participation])
        plt.xticks([1, 2], [f'A {a+1}', 'No A'])
        plt.title(f'{title}: intra-assembly chain participation')
        plt.savefig(f'chains/{title}_intra_assembly_{a+1}.png')
        plt.close()
        # Perform Wilcoxon Rank Sum Test
        print(f'A {a+1} vs No A')
        print(stats.ranksums(assembly_participation, no_a_cell_participation))

    # Pool cells which are in assemblies: ordered pairs of any two assembly cells,
    # which need not belong to the same assembly.
    pooled_assembly_cell_participation = [
        chain_participation_by_coregistered_cell_pair[index1][index2]
        for index1 in tqdm(pooled_assembly_indexes)
        for index2 in pooled_assembly_indexes
    ]

    # Perform Wilcoxon Rank Sum Test
    print(f'Wilcoxon Rank Sum Test: \n\t{title} from cells in any assembly to other cells in any assembly\n\tvs {title} from cells not in assemblies to other cells not in assemblies')
    print('A vs No A')
    print(stats.ranksums(pooled_assembly_cell_participation, no_a_cell_participation))

    plt.figure()
    plt.boxplot([pooled_assembly_cell_participation, no_a_cell_participation])
    plt.xticks([1, 2], ['A (pooled)', 'No A'])
    plt.title(f'{title}: pooled assembly vs No A chain participation')
    plt.savefig(f'chains/{title}_intra_pooled.png')
    plt.close('all')

# Raw (un-normalized) intra-assembly chain counts for the last chain length
# handled by the loop above: `title` and `chain_participation_by_coregistered_cell_pair`
# still refer to that final iteration (pyr_cell_4chain).
pooled_assembly_indexes = list(set(coregistered_cell_indexes) - set(no_a_cell_indexes))
individual_assembly_indexes = [mappings['connectome_indexes_by_assembly'][f'A {i}'] for i in range(1,14)]

# Pool cells which are in assemblies (ordered pairs of any two assembly cells)
pooled_assembly_cell_participation = [
    chain_participation_by_coregistered_cell_pair[index1][index2]
    for index1 in tqdm(pooled_assembly_indexes)
    for index2 in pooled_assembly_indexes
]

# Pool cells which are not in assemblies
no_a_cell_participation = [
    chain_participation_by_coregistered_cell_pair[index1][index2]
    for index1 in tqdm(no_a_cell_indexes)
    for index2 in no_a_cell_indexes
]

# Pool individual assembly participation (raw counts)
print(f'Wilcoxon Rank Sum Test (raw counts): \n\t{title} from cells within each assembly to other cells within that assembly\n\tvs {title} from cells not in assemblies to other cells not in assemblies')
per_assembly_cell_participation = []
for a, assembly_indexes in enumerate(individual_assembly_indexes):
    # Both the source and the target of every pair are members of assembly `a`.
    per_assembly_cell_participation.append([
        chain_participation_by_coregistered_cell_pair[source][target]
        for source in assembly_indexes
        for target in assembly_indexes
    ])
    plt.figure()
    plt.boxplot([per_assembly_cell_participation[a], no_a_cell_participation])
    plt.xticks([1, 2], [f'A {a+1}', 'No A'])
    plt.title(f'{title}: intra-assembly chain counts (raw)')
    # Distinct suffix so the normalized per-assembly figures saved under
    # chains/{title}_intra_assembly_{n}.png are not overwritten.
    plt.savefig(f'chains/{title}_intra_assembly_{a+1}_raw_counts.png')
    plt.close()
    # Perform Wilcoxon Rank Sum Test
    print(f'A {a+1} vs No A')
    print(stats.ranksums(per_assembly_cell_participation[a], no_a_cell_participation))

# Pool all chain lengths: multi_length_chain_participation_by_coregistered_cell_pair
# holds, per ordered cell pair, the 2-, 3- and 4-chain counts summed together.
# These are raw counts; no normalization by the number of possible chains is applied.

# Pool cells which are in assemblies (ordered pairs of any two assembly cells)
allchain_pooled_assembly_cell_participation = [
    multi_length_chain_participation_by_coregistered_cell_pair[index1][index2]
    for index1 in tqdm(pooled_assembly_indexes)
    for index2 in pooled_assembly_indexes
]

# Pool cells which are not in assemblies
allchain_no_a_cell_participation = [
    multi_length_chain_participation_by_coregistered_cell_pair[index1][index2]
    for index1 in tqdm(no_a_cell_indexes)
    for index2 in no_a_cell_indexes
]

# Perform Wilcoxon Rank Sum Test
print('Wilcoxon Rank Sum Test: \n\tAll pyr chain lengths, assembly cells to assembly cells (any assembly), raw chain counts\n\tvs All pyr chain lengths, non-assembly cells to non-assembly cells, raw chain counts')
print('A vs No A')
print(stats.ranksums(allchain_pooled_assembly_cell_participation, allchain_no_a_cell_participation))

plt.figure()
plt.boxplot([allchain_pooled_assembly_cell_participation, allchain_no_a_cell_participation])
plt.xticks([1, 2], ['A (pooled)', 'No A'])
plt.title('All pyr chain lengths: chain counts between cell pairs')
plt.savefig('chains/pyr_all_chain_lengths_intra_pooled.png')

In [None]:

# Per-cell 4-chain participation: how many 4-chains each co-registered cell
# appears in, compared between assembly and non-assembly cells, followed by a
# power analysis for the per-assembly comparisons.

# Use the 4-chain results explicitly instead of relying on whatever value an
# earlier cell happened to leave in `chain_results_array`.
participation_chain_results_array = four_chain_results_array

# Get chain participation.
# Counting occurrences equals counting chains only because an index cannot
# appear in the same chain more than once.
chain_node_ids, chain_node_counts = np.unique(participation_chain_results_array, return_counts=True)
occurrences_by_index = dict(zip(chain_node_ids.tolist(), chain_node_counts.tolist()))
chain_participation_by_coregistered_cell = {}
for index in tqdm(coregistered_cell_indexes):
    chain_participation_by_coregistered_cell[index] = occurrences_by_index.get(index, 0)

pooled_assembly_indexes = list(set(coregistered_cell_indexes) - set(no_a_cell_indexes))
individual_assembly_indexes = [mappings['connectome_indexes_by_assembly'][f'A {i}'] for i in range(1,14)]

# Pool cells which are in assemblies
pooled_assembly_cell_participation = [chain_participation_by_coregistered_cell[index] for index in pooled_assembly_indexes]

# Pool cells which are not in assemblies
no_a_cell_participation = [chain_participation_by_coregistered_cell[index] for index in no_a_cell_indexes]

# Pool individual assembly participation
per_assembly_cell_participation = []
for a, assembly_indexes in enumerate(individual_assembly_indexes):
    per_assembly_cell_participation.append(
        [chain_participation_by_coregistered_cell[index] for index in assembly_indexes])
    plt.figure()
    plt.boxplot([per_assembly_cell_participation[a], no_a_cell_participation])
    plt.xticks([1, 2], [f'A {a+1}', 'No A'])
    plt.title(f'4-chain participation per cell: A {a+1} vs No A')
    plt.savefig(f'pyr_per_assembly_chain_participation_assembly_{a+1}.png')
    plt.close()

# estimate sample size via power analysis
from statsmodels.stats.power import tt_ind_solve_power


def cohens_d(sample1, sample2):
    """Return Cohen's d: (mean1 - mean2) divided by the pooled (ddof=1) standard deviation.

    Returns NaN when either sample has fewer than two values or the pooled
    variance is zero, i.e. when the standardized effect is undefined.
    """
    x = np.asarray(sample1, dtype=float)
    y = np.asarray(sample2, dtype=float)
    n_x, n_y = x.size, y.size
    if n_x < 2 or n_y < 2:
        return np.nan
    pooled_var = ((n_x - 1) * x.var(ddof=1) + (n_y - 1) * y.var(ddof=1)) / (n_x + n_y - 2)
    if pooled_var <= 0:
        return np.nan
    return (x.mean() - y.mean()) / np.sqrt(pooled_var)


# parameters for power analysis
nobs1_array = [len(individual_assembly_indexes[i]) for i in range(len(individual_assembly_indexes))]
nobs2 = len(no_a_cell_participation)
mean_no_a = np.mean(no_a_cell_participation)
mean_diff_by_difference = [np.mean(per_assembly_cell_participation[a]) - mean_no_a for a in range(len(per_assembly_cell_participation))]
alpha = 0.05
# perform power analysis
for i, n in enumerate(nobs1_array):
    effect = cohens_d(per_assembly_cell_participation[i], no_a_cell_participation)
    if not np.isfinite(effect) or effect == 0:
        print(f'Sample Size for Assembly {i + 1}: undefined (effect size = {effect})')
        continue
    # statsmodels' `ratio` is nobs2 / nobs1: the No A group size relative to this assembly.
    r = nobs2 / n
    # statsmodels documents effect_size as positive; for the default two-sided
    # test only the magnitude of the effect matters.
    result = tt_ind_solve_power(effect_size=abs(effect), nobs1=None, alpha=alpha, power=0.70, ratio=r)
    print(f'Sample Size for Assembly {i + 1}: %.3f' % result)

plt.figure()
plt.boxplot([pooled_assembly_cell_participation, no_a_cell_participation])
plt.xticks([1, 2], ['A (pooled)', 'No A'])
plt.title('4-chain participation per cell')
plt.savefig('pyr_cell_4chain_participation.png')

print('A', pooled_assembly_cell_participation)
print('No A', no_a_cell_participation)

In [None]:
# # First Order
# # tables['structural']['synapse']

# tables['structural']['synapse']['pre_connectome_index'] = [mappings['pt_root_id_to_connectome_index'][tables['structural']['synapse'].iloc[i]['pre_pt_root_id']] for i in range(len(tables['structural']['synapse']))]
# tables['structural']['synapse']['post_connectome_index'] = [mappings['pt_root_id_to_connectome_index'][tables['structural']['synapse'].iloc[i]['post_pt_root_id']] for i in range(len(tables['structural']['synapse']))]
# # sizes_by_pair = tables['structural']['synapse'][['pre_connectome_index', 'post_connectome_index', 'size']].groupby(['pre_connectome_index', 'post_connectome_index'])



# intra_assembly_sizes = {}
# intra_assembly_connections = {}
# intra_assembly_syn_count = {}
# assembly_indexes = mappings['connectome_indexes_by_assembly']
# for a in range(1, 16):
#     intra_assembly_sizes[f'A {a}'] = []
#     intra_assembly_syn_count[f'A {a}'] = []
#     intra_assembly_connections[f'A {a}'] = []
# intra_assembly_sizes['No A'] = []
# intra_assembly_syn_count['No A'] = []
# intra_assembly_connections['No A'] = []

# # for a in range(len(individual_assembly_indexes)):

# for assembly in intra_assembly_sizes.keys():
#     coregistered_assembly_indexes = assembly_indexes[assembly]
#     no_a_indexes = assembly_indexes['No A']
#     print(assembly, coregistered_assembly_indexes)
#     for pre_index in coregistered_assembly_indexes:
#         for post_index in no_a_indexes:
#             syns = tables['structural']['synapse'].loc[tables['structural']['synapse'].pre_connectome_index == pre_index].loc[tables['structural']['synapse'].post_connectome_index == post_index]['size']
#             summed_psd_volumes = np.sum(syns)
#             if summed_psd_volumes > 0:
#             # print(summed_psd_volumes)
#                 for syn_item in syns:
#                     intra_assembly_sizes[assembly].append(syn_item)
#                 # intra_assembly_sizes[assembly] = intra_assembly_sizes[assembly] + syns
#                 intra_assembly_syn_count[assembly].append(len(syns))
            
#             # if summed_psd_volumes > 0:
#                 intra_assembly_connections[assembly].append(1)
#             else:
#                 intra_assembly_connections[assembly].append(0)
        
# pooled_data = [intra_assembly_sizes[assembly] for assembly in intra_assembly_sizes.keys() if assembly != 'No A']
# pooled_data = [item for sublist in pooled_data for item in sublist]
# intra_assembly_sizes['Pooled'] = pooled_data
# print(intra_assembly_sizes)

In [None]:
# #####
# ## BREADCRUMB: Is there a way to tell if that one 6k outlier in the No A category is a misclassified cell, or mis-co-registered?  Otherwise invalid?
# ## BREADCRUMB 2: Bayesian hypothesis testing
# ## BREADCRUMB 3: Multiple regression for assembly membership scores?  Weighting by assembly membership scores?  Or the decoding classifier Julian made applied to decoding assembly status from connectivity data?
# #######

# # print(assembly_sizes.keys())
# # print(len(chain_indexes))
# plt.figure()
# data_arrays = []
# ticks = []
# for assembly in intra_assembly_sizes.keys():
#     ticks.append(assembly)
#     # print(intra_assembly_syn_count[assembly])
#     # print(, len(assembly_sizes[f'A {a}']), np.max(assembly_sizes[f'A {a}']))
#     data_arrays.append(intra_assembly_sizes[assembly])
# # data_arrays.append(intra_assembly_sizes['No A'])

# # pooled_data = [intra_assembly_syn_count[assembly] for assembly in intra_assembly_syn_count.keys():]
# # print(data_arrays1.shape)
# # data_arrays1 = np.delete(data_arrays1, 16, axis=1)
# # data_arrays1 = np.sum(data_arrays1, axis=1)
# # data_arrays = np.hstack([data_arrays, data_arrays1])
# data_arrays.append(pooled_data)
# plt.boxplot(data_arrays)
# ticks.append('Pooled')
# plt.xticks(range(1, len(intra_assembly_sizes.keys())+2), ticks)
# # plt.savefig(f'first_order_intra_assembly_syn_count.png')

# for assembly in intra_assembly_sizes.keys():
#     plt.figure()
#     plt.title(assembly)
#     vals, bins, _ = plt.hist(intra_assembly_sizes[assembly], density=False)
#     print(assembly, vals)
#     plt.xlabel('Synapse Count')
#     plt.ylabel('Frequency')
#     # plt.savefig(f'first_order_intra_assembly_{assembly}_syn_count_histogram.png')


In [None]:
# # Sizes between Assembly Cells

# # print(assembly_sizes.keys())
# # print(len(chain_indexes))
# plt.figure()
# data_arrays = []
# ticks = []
# for assembly in intra_assembly_sizes.keys():
#     ticks.append(assembly)
#     # print(intra_assembly_syn_count[assembly])
#     # print(, len(assembly_sizes[f'A {a}']), np.max(assembly_sizes[f'A {a}']))
#     data_arrays.append(intra_assembly_sizes[assembly])
# data_arrays.append(intra_assembly_sizes['No A'])
# # print(data_arrays)
# plt.boxplot(data_arrays)
# # ticks.append('No A')
# plt.xticks(range(len(intra_assembly_sizes.keys())), intra_assembly_sizes.keys())
# plt.savefig(f'first_order_intra_assembly_syn_count.png')

# for assembly in intra_assembly_sizes.keys():
#     plt.figure()
#     plt.title(assembly)
#     vals, bins, _ = plt.hist(intra_assembly_sizes[assembly], density=False)
#     # print(assembly, vals)
#     plt.xlabel('Summed PSD Size')
#     plt.ylabel('Frequency')
#     plt.savefig(f'first_order_intra_assembly_{assembly}_syn_count_histogram.png')

# import scipy.stats
# w_test = scipy.stats.ranksums(intra_assembly_sizes['No A'], intra_assembly_sizes['Pooled'], alternative='greater')
# print(w_test)


In [None]:
# # print(assembly_sizes.keys())
# # print(len(chain_indexes))
# plt.figure()
# data_arrays = []
# ticks = []
# for assembly in intra_assembly_syn_count.keys():
#     ticks.append(assembly)
#     # print(intra_assembly_syn_count[assembly])
#     # print(, len(assembly_sizes[f'A {a}']), np.max(assembly_sizes[f'A {a}']))
#     data_arrays.append(intra_assembly_syn_count[assembly])
# data_arrays.append(intra_assembly_syn_count['No A'])
# # print(data_arrays)
# plt.boxplot(data_arrays)
# # ticks.append('No A')
# plt.xticks(range(len(intra_assembly_syn_count.keys())), intra_assembly_syn_count.keys())
# plt.savefig(f'first_order_intra_assembly_syn_count.png')

# for assembly in intra_assembly_syn_count.keys():
#     plt.figure()
#     plt.title(assembly)
#     vals, bins, _ = plt.hist(intra_assembly_syn_count[assembly], density=False)
#     print(assembly, vals)
#     plt.xlabel('Synapse Count')
#     plt.ylabel('Frequency')
#     plt.savefig(f'first_order_intra_assembly_{assembly}_syn_count_histogram.png')


In [None]:
# `data_arrays` is only built by the first-order plotting cells above, which are
# currently commented out; guard so "Restart & Run All" does not stop here.
if 'data_arrays' in globals():
    print(data_arrays[0])
    print(data_arrays[-1])
else:
    print('data_arrays is not defined; run the first-order plotting cells first.')

In [None]:
# 3-Chain
# For every 3-chain containing a co-registered cell, sum the synapse 'size'
# values over all synapses of each consecutive edge of the chain, and record
# that total under every assembly the co-registered cell belongs to.

# Map each synapse's pre/post pt_root_id to a connectome index (an unmapped id
# still raises KeyError, exactly like a direct dict lookup).
root_id_to_connectome_index = mappings['pt_root_id_to_connectome_index']
tables['structural']['synapse']['pre_connectome_index'] = [root_id_to_connectome_index[root_id] for root_id in tables['structural']['synapse']['pre_pt_root_id']]
tables['structural']['synapse']['post_connectome_index'] = [root_id_to_connectome_index[root_id] for root_id in tables['structural']['synapse']['post_pt_root_id']]

# Total synapse size for each (pre, post) connectome-index edge, computed once
# instead of filtering the full synapse table for every edge of every chain.
summed_size_by_edge = (
    tables['structural']['synapse']
    .groupby(['pre_connectome_index', 'post_connectome_index'])['size']
    .sum()
    .to_dict()
)

assembly_sizes = {}
for a in range(1, 16):
    assembly_sizes[f'A {a}'] = []
assembly_sizes['No A'] = []

for index in tqdm(coregistered_cell_indexes):
    assemblies = mappings['assemblies_by_connectome_index'][index]
    # Use the 3-chain results explicitly: the preceding analysis cells leave
    # `chain_results_array` pointing at the 4-chain results.
    chain_indexes = np.where(three_chain_results_array == index)[0]
    for row in three_chain_results_array[chain_indexes, :].tolist():
        # Sum over every consecutive (pre, post) edge along the chain; an edge
        # with no synapses in the table contributes 0.
        summed_chain_size = sum(
            summed_size_by_edge.get((pre, post), 0.0)
            for pre, post in zip(row[:-1], row[1:])
        )
        for assembly in assemblies:
            assembly_sizes[assembly].append(summed_chain_size)
        
        
        # Map connectome indexes to pt_root_ids
        
        # size2 = np.sum(tables['structural']['synapse'].query('pre_pt_root_id == @pt_root_id2 and post_pt_root_id == @pt_root_id3').iloc[0].values)
        



    # pt_root_ids = [mappings['pt_root_id_to_connectome_index'].get(index) for index in row['connectome_indexes']]
    
    # Retrieve the size from tables['structural']['synapse'] using the mapped pt_root_ids
    # sizes = [tables['structural']['synapse'].get(pt_root_id, {}).get('size') for pt_root_id in pt_root_ids]
    
    # # Iterate through each assembly in mappings['assemblies by connectome_index']
    # for connectome_index, assemblies in mappings['assemblies by connectome_index'].items():
    #     # Store the resulting size under those assemblies
    #     for assembly in assemblies:
    #         if 'sizes' not in assembly:
    #             assembly['sizes'] = []
    #         assembly['sizes'].extend(sizes)



# for n, p in enumerate(sizes_by_pair):
#     print(p)



# for c in range(chain_results_array.shape[0]):
#     chain(cells, synapses) = chain_results_array[c]
    

In [None]:
## BREADCRUMB: Check chains between assembly activity frames. Use code from old check with current framework.

In [None]:
# for a in range(1, 16):
#     plt.figure()
#     plt.title(f'A {a}')
#     plt.hist(assembly_sizes[f'A {a}'], bins=100)
#     plt.xlabel('Summed Synapse Sizes')
#     plt.ylabel('Frequency')
#     plt.savefig(f'histogram_of_assembly_{a}_summed_chain_size.png')

# plt.figure()
# plt.title('No A')
# plt.hist(assembly_sizes['No A'], bins=100)
# plt.xlabel('Summed Synapse Sizes')
# plt.ylabel('Frequency')
# plt.savefig('histogram_of_no_a_summed_chain_size.png')

In [None]:
# # print(assembly_sizes.keys())
# print(len(chain_indexes))
# plt.figure()
# data_arrays = []
# ticks = []
# for a in range(1, 16):
#     ticks.append(f'A {a}')
#     print(a, len(assembly_sizes[f'A {a}']), np.max(assembly_sizes[f'A {a}']))
#     data_arrays.append(assembly_sizes[f'A {a}'])
# data_arrays.append(assembly_sizes['No A'])
# # print(data_arrays)
# plt.boxplot(data_arrays)
# ticks.append('No A')
# plt.xticks(range(1, 17), ticks)
# plt.savefig(f'three_chain_psd_sizes2.png')

In [None]:
# print(np.array(data_arrays))