In [26]:
import math
import copy
from tqdm import tqdm

import numpy as np
import pandas as pd

import networkx as nx

import matplotlib.pyplot as plt
import plotly.offline
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sbemdb import SBEMDB
from cleandb import clean_db, clean_db_uct

from distance import distance
from findpath import find_path

In [27]:
# Connect to the SBEM database and keep only the handle returned by clean_db
# (clean_db reports before/after row counts, see the cell output).
db = clean_db(SBEMDB())

trees Before 398
trees After 46
nodes Before 37481
nodes After 16320
nodecons Before 73994
nodecons After 32548
syncons Before 1199
syncons After 826
synapses Before 552
synapses After 535


In [28]:
# Display the mapping from numeric node-type code to its name.
db.nodetypes()

{1: 'Soma', 2: 'ExitPoint', 3: 'TreeNode', 5: 'PresynTerm', 6: 'PostsynTerm'}

In [29]:
# Keys of the node-detail records matching tid 444 and node type 6
# ('PostsynTerm' in the db.nodetypes() mapping); presumably these keys are node ids.
# NOTE(review): node_ids_db is not referenced again in the visible cells.
node_ids_db = list(db.nodeDetails('tid==444 and typ=6').keys())

In [30]:
# Load the segment table: one row per (segment, node) with the segment type,
# depth, synapse flag and the node each row points to.
# NOTE(review): the file carries an 'Unnamed: 0' column (see the output below),
# which looks like a saved index; nothing in the visible cells reads it.
segments = pd.read_csv('segments_table.csv')
segments.head()

Unnamed: 0,segment_id,branch_id,type,depth,is_synapse,node_id,point_node_id
0,0,2137,main,0,False,2137,2137
1,0,2137,main,0,False,3477,2137
2,0,2137,main,0,False,3478,3477
3,0,2137,main,0,False,3479,3478
4,0,2137,main,0,False,3480,3479


In [31]:
x, y, z, nid = db.nodexyz()

# Two lookup tables over the same parallel arrays; on duplicate keys the
# later entry wins.
n_nodes = len(x)
coord_nid = {(x[k], y[k], z[k]): nid[k] for k in range(n_nodes)}  # {(coords): nid}
nid_coord = {nid[k]: (x[k], y[k], z[k]) for k in range(n_nodes)}  # {nid: (coords)}

In [32]:
x, y, z = db.segments(444)

# Undirected graph over node ids; every edge carries a 'distance' attribute
# computed by the distance() helper from the two endpoint coordinates.
G = nx.Graph()

for k in range(len(x) - 1):
    # A NaN x-coordinate on either side means these two consecutive points
    # are not joined (presumably NaN separates polylines - TODO confirm).
    if math.isnan(x[k]) or math.isnan(x[k + 1]):
        continue
    head = (x[k], y[k], z[k])
    tail = (x[k + 1], y[k + 1], z[k + 1])
    G.add_edge(coord_nid[head], coord_nid[tail], distance=distance(*head, *tail))

### Length-specific

In [33]:
# Distinct segment ids, in order of first appearance in the table.
segment_ids = segments['segment_id'].unique()

# Example: show all rows of one arbitrarily picked segment.
is_example_segment = segments['segment_id'] == segment_ids[111]
segments[is_example_segment]

Unnamed: 0,segment_id,branch_id,type,depth,is_synapse,node_id,point_node_id
1211,698,5846,is,7,False,7022,7022
1212,698,5846,is,7,False,7024,7022
1213,698,5846,is,7,False,7025,7024
1214,698,5846,is,7,False,7026,7025
1215,698,5846,is,7,False,7027,7026
1216,698,5846,is,7,False,7028,7027
1217,698,5846,is,7,False,7029,7028
1218,698,5846,is,7,False,7030,7029
1219,698,5846,is,7,False,7031,7030
1220,698,5846,is,7,False,7032,7031


In [34]:
def get_length(segment_id):
    """Return the weighted shortest-path length in ``G`` covered by one segment.

    The path runs from the ``node_id`` of the segment's last row to the
    ``point_node_id`` of its first row, with the ``'distance'`` edge
    attribute as the weight.
    """
    rows = segments[segments['segment_id'] == segment_id]
    far_end = rows.iloc[-1]['node_id']
    near_end = rows.iloc[0]['point_node_id']
    return nx.dijkstra_path_length(G, far_end, near_end, weight='distance')

In [35]:
# Per-segment lookups: path length, and the type taken from the segment's first row.
segment_length = {}
segment_type = {}
for seg_id in segment_ids:
    segment_length[seg_id] = get_length(seg_id)
    segment_type[seg_id] = segments.loc[segments['segment_id'] == seg_id, 'type'].values[0]

In [36]:
# remove terminal segments whose length is < 0.5
#for segment_id, length in segment_length.items():
#    if segment_type[segment_id] == 'ts' and length < 0.5:
#        segments = segments[segments['segment_id'] != segment_id]

In [37]:
segment_length_all = list(segment_length.values())

# Distinct segment ids for each segment type, in order of first appearance.
seg_ids_by_type = {
    seg_kind: segments.loc[segments['type'] == seg_kind, 'segment_id'].unique()
    for seg_kind in ('main', 'is', 'ts', 'root')
}

segment_length_main = [segment_length[s] for s in seg_ids_by_type['main']]
segment_length_intermediate = [segment_length[s] for s in seg_ids_by_type['is']]
segment_length_terminal = [segment_length[s] for s in seg_ids_by_type['ts']]
segment_length_root = [segment_length[s] for s in seg_ids_by_type['root']]

In [51]:
fig = make_subplots(rows=3, cols=2, subplot_titles=("All", "Main", "Intermediate", "Terminal", "Root", ""))

# One histogram per segment category, filling the 3x2 grid row by row
# (the last cell stays empty).
length_panels = [
    (segment_length_all, 1, 1),
    (segment_length_main, 1, 2),
    (segment_length_intermediate, 2, 1),
    (segment_length_terminal, 2, 2),
    (segment_length_root, 3, 1),
]
for lengths, grid_row, grid_col in length_panels:
    fig.add_trace(go.Histogram(x=lengths, name='Length'), row=grid_row, col=grid_col)

fig.update_layout(height=600, width=800, title_text="Distribution of lengths")
fig.show()

### Synapse-specific

In [40]:
# Number of rows flagged as synapse in every segment.
seg_n_synapses = {
    seg_id: segments.loc[segments['segment_id'] == seg_id, 'is_synapse'].sum()
    for seg_id in segment_ids
}

seg_n_syn_all = list(seg_n_synapses.values())

# Distinct segment ids for each segment type, in order of first appearance.
syn_seg_ids_by_type = {
    seg_kind: segments.loc[segments['type'] == seg_kind, 'segment_id'].unique()
    for seg_kind in ('main', 'is', 'ts', 'root')
}
seg_n_syn_main = [seg_n_synapses[s] for s in syn_seg_ids_by_type['main']]
seg_n_syn_intermediate = [seg_n_synapses[s] for s in syn_seg_ids_by_type['is']]
seg_n_syn_terminal = [seg_n_synapses[s] for s in syn_seg_ids_by_type['ts']]
seg_n_syn_root = [seg_n_synapses[s] for s in syn_seg_ids_by_type['root']]

In [41]:
fig = make_subplots(rows=3, cols=2, subplot_titles=("All", "Main", "Intermediate", "Terminal", "Root", ""))

# One histogram of synapse counts per segment category, row by row on the
# 3x2 grid (the last cell stays empty).
synapse_panels = [
    (seg_n_syn_all, 1, 1),
    (seg_n_syn_main, 1, 2),
    (seg_n_syn_intermediate, 2, 1),
    (seg_n_syn_terminal, 2, 2),
    (seg_n_syn_root, 3, 1),
]
for counts, grid_row, grid_col in synapse_panels:
    fig.add_trace(go.Histogram(x=counts, name='Num. synapses'), row=grid_row, col=grid_col)

fig.update_layout(
    height=600,
    width=800,
    title_text=f"Distribution of number of synapses per segment (average: {np.mean(seg_n_syn_all):.3})",
)
fig.show()

### At least 1

In [42]:
# Synapse count per segment, keeping only segments with at least one synapse.
# Each per-segment sum is computed once (the previous version filtered and
# summed the table twice per id).
seg_n_synapses = {}
for seg_id in segment_ids:
    n_syn = segments.loc[segments['segment_id'] == seg_id, 'is_synapse'].sum()
    if n_syn > 0:
        seg_n_synapses[seg_id] = n_syn
all_seg_ids = list(seg_n_synapses.keys())


def _syn_counts_for_type(seg_kind):
    """Return the synapse counts of the kept segments that have type ``seg_kind``.

    The ids of that type are collected once into a set, instead of rebuilding
    the ``unique()`` array for every id being tested. Order follows
    ``all_seg_ids``.
    """
    ids_of_kind = set(segments.loc[segments['type'] == seg_kind, 'segment_id'].unique())
    return [seg_n_synapses[seg_id] for seg_id in all_seg_ids if seg_id in ids_of_kind]


seg_n_syn_all = list(seg_n_synapses.values())
seg_n_syn_main = _syn_counts_for_type('main')
seg_n_syn_intermediate = _syn_counts_for_type('is')
seg_n_syn_terminal = _syn_counts_for_type('ts')
seg_n_syn_root = _syn_counts_for_type('root')

In [43]:
fig = make_subplots(rows=3, cols=2, subplot_titles=("All", "Main", "Intermediate", "Terminal", "Root", ""))

# Same 3x2 layout as the unfiltered plot; the lists plotted here are the
# seg_n_syn_* values currently in scope.
panel_grid = {
    (1, 1): seg_n_syn_all,
    (1, 2): seg_n_syn_main,
    (2, 1): seg_n_syn_intermediate,
    (2, 2): seg_n_syn_terminal,
    (3, 1): seg_n_syn_root,
}
for (grid_row, grid_col), counts in panel_grid.items():
    fig.add_trace(go.Histogram(x=counts, name='Num. synapses'), row=grid_row, col=grid_col)

fig.update_layout(
    height=600,
    width=800,
    title_text=f"Distribution of number of synapses per segment (average: {np.mean(seg_n_syn_all):.3})",
)
fig.show()

### For branch structures

In [44]:
# Distinct values of the 'branch_id' column, used below to group rows per branch.
branch_ids = segments['branch_id'].unique()

In [45]:
# Synapse count per branch. A node can appear in several rows of one branch,
# so rows are de-duplicated on node_id before summing the synapse flag.
seg_n_synapses = {}
for branch_id in branch_ids:
    branch_rows = segments[segments['branch_id'] == branch_id]
    unique_nodes = branch_rows.drop_duplicates('node_id')
    seg_n_synapses[branch_id] = unique_nodes['is_synapse'].sum()

seg_n_syn_all = list(seg_n_synapses.values())

In [46]:
# subplot_titles must be a sequence of strings. ("All") is just the string
# "All", which is indexed per character and titles the only subplot "A";
# the trailing comma makes it a one-element tuple.
fig = make_subplots(rows=1, cols=1, subplot_titles=("All",))

# Histogram of the per-branch synapse counts.
fig.add_trace(
    go.Histogram(x=seg_n_syn_all, name='Num. synapses'),
    row=1, col=1
)

fig.update_layout(height=600, width=800, 
                  title_text=f"Distribution of number of synapses per branch (average: {np.mean(seg_n_syn_all):.3})")
fig.show()

### For signal path branches

In [47]:
# Node ids of the rows that are both in a terminal segment ('ts') and flagged as synapse.
is_terminal_synapse = (segments['type'] == 'ts') & (segments['is_synapse'])
terminal_synapse_ids = segments.loc[is_terminal_synapse, 'node_id'].values

In [48]:
def get_signal_path_nodes(synapse_node_id):
    """Collect the node ids visited when walking from a synapse node to the main branch.

    Starting at ``synapse_node_id``, the walk repeatedly follows the
    ``point_node_id`` of the current row to the first row whose ``node_id``
    matches it, and stops at the first row of type ``'main'``.

    Parameters
    ----------
    synapse_node_id
        A value of the ``node_id`` column of ``segments`` to start from.

    Returns
    -------
    list
        Node ids in visiting order, beginning with ``synapse_node_id``.

    Raises
    ------
    IndexError
        If a node id on the way has no row in ``segments``.
    """
    nodes = [synapse_node_id]
    visited = {synapse_node_id}
    row = segments[segments['node_id'] == synapse_node_id]
    while row['type'].values[0] != 'main':
        parent_id = row['point_node_id'].values[0]
        if parent_id in visited:
            # The table contains rows that point at themselves; without this
            # guard a non-'main' self-reference or cycle would loop forever.
            break
        visited.add(parent_id)
        row = segments[segments['node_id'] == parent_id]
        nodes.append(row['node_id'].values[0])
    return nodes

In [49]:
# For every terminal synapse: the number of synapse-flagged nodes on its
# signal path, with each node counted once.
syn_n_signal_branches = [
    segments[segments['node_id'].isin(get_signal_path_nodes(syn_id))]
    .drop_duplicates('node_id')['is_synapse']
    .sum()
    for syn_id in terminal_synapse_ids
]

In [50]:
# subplot_titles must be a sequence of strings. ("All") is just the string
# "All", which is indexed per character and titles the only subplot "A";
# the trailing comma makes it a one-element tuple.
fig = make_subplots(rows=1, cols=1, subplot_titles=("All",))

# Histogram of the synapse counts along each terminal synapse's signal path.
fig.add_trace(
    go.Histogram(x=syn_n_signal_branches, name='Num. synapses'),
    row=1, col=1
)

# The average in the title is taken over the plotted data
# (syn_n_signal_branches), not over the per-branch counts of an earlier cell.
fig.update_layout(height=600, width=800, 
                  title_text=f"Distribution of number of synapses per signal path (average: {np.mean(syn_n_signal_branches):.3})")
fig.show()