In [1]:
import os
import random
from time import time
import pandas as pd
import numpy as np
import networkx as nx
import itertools

# only allow 100 rows to display pandas
pd.set_option('display.max_rows', 100)

In [2]:
syn_table = pd.read_csv('/Users/rweberla/Downloads/connections_no_threshold.csv')
syn_table

Unnamed: 0,pre_root_id,post_root_id,neuropil,syn_count,nt_type
0,720575940629970489,720575940631267655,AVLP_R,7,GABA
1,720575940623828999,720575940612348950,SLP_R,4,GLUT
2,720575940624078484,720575940616950161,SMP_R,2,ACH
3,720575940629583345,720575940620324735,SMP_L,2,GLUT
4,720575940605876866,720575940606514878,LAL_R,15,GABA
...,...,...,...,...,...
16847992,720575940615769750,720575940622822680,AVLP_L,1,ACH
16847993,720575940624016823,720575940622760993,ME_R,4,ACH
16847994,720575940637397309,720575940633255123,SLP_R,2,GABA
16847995,720575940636432014,720575940617470681,ME_R,1,GLUT


In [3]:
# Find all unique cell ids in both the pre and post columns
cellids =  np.unique(syn_table[["pre_root_id", "post_root_id"]])
print(len(cellids))

138639


In [4]:
# Create a dictionary that maps cell ids to index id values
nid2cid = {i: cid for i, cid in enumerate(cellids)} 

In [5]:
# Create a dictionary that maps index id values to cell ids, may not be needed
cid2nid = {cid: i for i, cid in enumerate(cellids)}

In [6]:
# Add the index id values to the syn_table for pre and post columns
syn_table["pre_nid"] = pd.Series([cid2nid[cid] for cid in syn_table["pre_root_id"]], 
		index=syn_table.index)
syn_table["post_nid"] = pd.Series([cid2nid[cid] for cid in syn_table["post_root_id"]], 
		index=syn_table.index)

In [7]:
syn_table

Unnamed: 0,pre_root_id,post_root_id,neuropil,syn_count,nt_type,pre_nid,post_nid
0,720575940629970489,720575940631267655,AVLP_R,7,GABA,96274,103137
1,720575940623828999,720575940612348950,SLP_R,4,GLUT,62790,13822
2,720575940624078484,720575940616950161,SMP_R,2,ACH,64248,29086
3,720575940629583345,720575940620324735,SMP_L,2,GLUT,94318,42941
4,720575940605876866,720575940606514878,LAL_R,15,GABA,2405,3347
...,...,...,...,...,...,...,...
16847992,720575940615769750,720575940622822680,AVLP_L,1,ACH,24961,57182
16847993,720575940624016823,720575940622760993,ME_R,4,ACH,63889,56848
16847994,720575940637397309,720575940633255123,SLP_R,2,GABA,122411,111854
16847995,720575940636432014,720575940617470681,ME_R,1,GLUT,120280,30981


In [33]:
def collect_triplets(V, E):
    pre_col = E.columns[0]
    post_col = E.columns[1]
    
    # Create a dictionary for subgraphs
    tri = {}
    for a, b in E.values:
        for c in V:
            # Create pairs to check
            pair1 = (c, a)
            pair2 = (a, c)
            pair3 = (c, b)
            pair4 = (b, c)
            # Check if each pair exists in E
            if c != a and c != b and (E[[pre_col, post_col]].isin(pair1).all(axis=1).any() or
                    E[[pre_col, post_col]].isin(pair2).all(axis=1).any() or
                    E[[pre_col, post_col]].isin(pair3).all(axis=1).any() or
                    E[[pre_col, post_col]].isin(pair4).all(axis=1).any()):
                
                # sort the triad and create set
                t = tuple(sorted([a, b, c]))
                print(t)
                if t not in tri:
                    print('not in tri')
                    tri[t] = set([(a,b)])
                    if (a, c) in E:
                        
                        tri[t] = tri[t].union(set([(a,c)]))
                    if (c, a) in E:
                        tri[t] = tri[t].union(set([(c,a)]))
                    if (b, c) in E:
                        tri[t] = tri[t].union(set([(b,c)]))
                    if (c, b) in E:
                        tri[t] = tri[t].union(set([(c,b)]))
                    if (b, a) in E:
                        tri[t] = tri[t].union(set([(b,a)]))
                else:
                    print('in tri')
                    if (a, c) in E:
                        tri[t] = tri[t].union(set([(a,c)]))
                    if (c, a) in E:
                        tri[t] = tri[t].union(set([(c,a)]))
                    if (b, c) in E:
                        tri[t] = tri[t].union(set([(b,c)]))
                    if (c, b) in E:
                        tri[t] = tri[t].union(set([(c,b)]))
                    if (b, a) in E:
                        tri[t] = tri[t].union(set([(b,a)]))
                
    return tri

#t = [0,1,2]
#motifs = { 1: Triplet(t, edges=set([(1,0),(2,0),(1,2),(2,1)]))}

def collect_three_neuron_motifs(V, E):
    tri_g = collect_triplets(V, E)
    print(tri_g)
   



In [34]:
# Isolate the pre_nid and post_nid columns with syn_count
syn_table_limit = syn_table[["pre_nid", "post_nid"]]
syn_table_test=syn_table_limit[:1000]

In [35]:
e_test= [(11,21), (21,10), (21,11), (11,10)]
df_test = pd.DataFrame(columns=["s", "t"], data = e_test)
v_test = np.unique(df_test[["s", "t"]])
df_test

Unnamed: 0,s,t
0,11,21
1,21,10
2,21,11
3,11,10


In [36]:
# Adjust the syn_table_test to be able to run the count_three_neuron_motifs function
cellids_test =  np.unique(syn_table_test[["pre_nid", "post_nid"]])
collect_three_neuron_motifs(v_test, df_test)


(10, 11, 21)
not in tri
(10, 11, 21)
in tri
(10, 11, 21)
in tri
(10, 11, 21)
in tri
{(10, 11, 21): {(11, 21)}}


In [140]:
E = df_test
V = v_test

In [173]:
tri[t] + (a, b)

(11, 10, 11, 10)

In [171]:
tri[t]

(11, 10, 11, 10)