## Import relevant Python modules

In [None]:
import pickle
import numpy as np
from numpy import sin, cos, pi
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import networkx as nx
import plotly.offline
import plotly.graph_objects as go
from plotly.graph_objs import Mesh3d
import scipy.stats as sci  # calculate standard error
from scipy.stats import ks_2samp
from sbemdb import SBEMDB
from cleandb import clean_db, clean_db_uct
import os
import scipy.io
from math import sqrt
import math
from mapping import Mapping
import warnings

# Silence warnings for the rest of the notebook. The original called
# `np.warnings.filterwarnings`, but `np.warnings` was only an accidental
# re-export of the stdlib module and is gone in NumPy >= 1.24, where that line
# raises AttributeError. The stdlib module behaves identically.
warnings.filterwarnings('ignore')

db = SBEMDB()  # connect to DB
db = clean_db(db)
# Coordinates of tree 444's segments; plot_graph draws them as the 'DE3' line trace.
x, y, z = db.segments(444)
mapping = Mapping()

In [None]:
def load_obj(name):
    """Unpickle and return the object stored in the file `name + '.pkl'`.

    Only use on trusted local files: unpickling can execute arbitrary code.
    """
    pkl_path = name + '.pkl'
    with open(pkl_path, 'rb') as handle:
        obj = pickle.load(handle)
    return obj
# Pre-computed clustering results keyed by a parameter tuple; later indexed as
# cluster_params[(5, 50)]['clusters'].
cluster_params = load_obj('saved_param_clusters_path_2-5')

## Extract coordinates and other information on pre- and postsynaptic sites

In [None]:
# Synapses onto tree 444 (post.tid=444) from every other tree (pre.tid!=444).
# With extended=True eight sequences are returned. Going by the names they are
# coordinates, pre/post tree ids, synapse ids and pre/post node ids — meaning
# inferred from the names, TODO confirm against SBEMDB.synapses.
(xx, yy, zz, pretid, posttid, synid, prenid, postnid) = db.synapses('pre.tid!=444 and post.tid=444',extended = True)

In [None]:
# Map each synapse id to the tree id of its presynaptic partner
# (later ids overwrite earlier ones, as with repeated assignment).
sid_tid = dict(zip(synid, pretid))

## Retrieve coherence values and the corresponding colors for visualization

In [None]:
# Read in the color scheme. Each row has a phase value `phi` plus r, g, b
# channels. The channels are multiplied by 255 in to_8bit_rgb, so they are
# presumably in [0, 1] — TODO confirm.
path_color = 'cohcolor.csv'
color_map = pd.read_csv(path_color)
# Round phi to 3 decimals so phases rounded the same way can be matched with ==.
color_map.phi = color_map.phi.apply(lambda x: round(x, 3))

# Keep the raw lines too; the colorscale cell parses them directly.
with open(path_color) as f:
    lines = f.readlines()

# Read in the per-ROI phase and magnitude values (the magnitude is used as alpha).
path_phase = 'roi_phase_63.mat'
phase_mat = scipy.io.loadmat(path_phase)
path_alpha = 'roi_mag_63.mat'
mag_mat = scipy.io.loadmat(path_alpha)

col_names = ['roi', 'Phase', 'Alpha']
mag_values = mag_mat['roi_mag'][0]

# Collect all rows first and build the DataFrame once. Growing a DataFrame
# cell by cell with .loc is slow and can leave columns with object dtype.
rows = []
for el, val in enumerate(phase_mat['roi_phase'][0]):
    # Shift negative phases by 2*pi so they land in [0, 2pi], then round to
    # 3 decimals so they match color_map.phi.
    phase = val + 2*pi if val < 0 else val
    rows.append({'roi': el + 1, 'Phase': round(phase, 3), 'Alpha': mag_values[el]})
data_values = pd.DataFrame(rows, columns=col_names)

In [None]:
def is_roi(roi):
    """Return a boolean mask over `data_values` rows whose 'roi' column equals `roi`."""
    mask = data_values['roi'] == roi
    return mask

def to_8bit_rgb(rgb):
    """Scale a color channel by 255 and round it to an int (Python round: ties to even).

    The channel is presumably in [0, 1]; no clamping is applied.
    """
    scaled = rgb * 255
    return int(round(scaled))


def alpha_grey_convert(r, g, b, alpha=0):
    """Blend the color (r, g, b) toward mid-grey (128, 128, 128) by weight `alpha`.

    alpha >= 1 returns the color unchanged and alpha < 0 returns pure grey.
    Otherwise each channel becomes int(alpha * channel + 128 * (1 - alpha)).
    """
    if alpha >= 1:
        return (r, g, b)
    if alpha < 0:
        return (128, 128, 128)
    grey_part = 128 * (1 - alpha)
    return tuple(int(alpha * channel + grey_part) for channel in (r, g, b))

def get_color(db, tid=None, phase=None, alpha=None):
    """Return a plotly 'rgb(r,g,b)' color for a phase/coherence pair.

    Args:
        db: unused; kept so existing callers keep working.
        tid: presynaptic tree id. When given, its ROI is found via
            mapping.sbem2roi and that ROI's Phase/Alpha are read from data_values.
        phase, alpha: used (rounded to 3 decimals) when `tid` is None.

    The hue comes from the color_map row whose `phi` equals the phase. It is then
    blended toward grey by alpha (see alpha_grey_convert). Returns
    'rgb(255,255,255)' (white) when no single color_map row matches, or when the
    channels cannot be converted.
    """
    if tid is not None:
        roi = mapping.sbem2roi[tid]
        row = data_values.iloc[roi - 1]
        phase, alpha = row.Phase, row.Alpha
    else:
        phase = round(phase, 3)
        alpha = round(alpha, 3)

    # Exact float equality is used: phi and the phases are rounded to 3 decimals.
    match = color_map.loc[color_map.phi == phase, ['r', 'g', 'b']]
    if len(match) != 1:
        # Missing or ambiguous color entry: fall back to white (the original
        # reached the same result by letting int(Series) fail).
        return 'rgb(255,255,255)'

    try:
        r, g, b = (to_8bit_rgb(channel) for channel in match.iloc[0])
        r, g, b = alpha_grey_convert(r, g, b, alpha)
    except (TypeError, ValueError):
        # e.g. NaN channel or alpha values that cannot be turned into ints.
        return 'rgb(255,255,255)'

    return f'rgb({r},{g},{b})'

def get_color_bar(phase):
    """Return the unblended 'rgb(r,g,b)' color from color_map for `phase`.

    `phase` is rounded to 3 decimals, the same way color_map.phi and get_color
    round their phases, and then matched exactly against `phi`.

    Raises:
        ValueError: if no unique color_map row has that phi.
    """
    phase = round(phase, 3)
    match = color_map.loc[color_map.phi == phase, ['r', 'g', 'b']]
    if len(match) != 1:
        raise ValueError(
            f'expected exactly one color_map row for phi={phase}, found {len(match)}'
        )
    r, g, b = (to_8bit_rgb(channel) for channel in match.iloc[0])
    return f'rgb({r},{g},{b})'

In [None]:
# Group synapses by presynaptic tree id: {pretid: [[x, y, z, sid], ...]}.
synapses = {}
for sx, sy, sz, tid, sid in zip(xx, yy, zz, pretid, synid):
    synapses.setdefault(tid, []).append([sx, sy, sz, sid])

## Define ellipsoids to visualize synaptic clusters

In [None]:
def generate_color():
    """Return a random plotly color string 'rgb(r, g, b)'.

    Draws r, g, b from np.random in that order. If all three channels are
    below 25 (near black), it draws again.
    """
    while True:
        r, g, b = (int(np.random.random() * 255) for _ in range(3))
        if not (r < 25 and g < 25 and b < 25):
            return f'rgb({r}, {g}, {b})'

def ellipsoid3d(x0, y0, z0, xd, yd, zd):
    """Sample the surface of an axis-aligned ellipsoid centred at (x0, y0, z0).

    The semi-axes are xd, yd and zd, each padded by 2 units. The surface is
    sampled on a 50 x 50 (latitude x longitude) grid, using np.linspace's
    default of 50 points. Returns flattened x, y and z arrays.
    """
    pad = 2
    lon, lat = np.meshgrid(np.linspace(0, 2*pi), np.linspace(-pi/2, pi/2))
    ring = cos(lat)
    xs = x0 + ring * sin(lon) * (xd + pad)
    ys = y0 + ring * cos(lon) * (yd + pad)
    zs = z0 + sin(lat) * (zd + pad)
    return (xs.flatten(), ys.flatten(), zs.flatten())

def find_coords(xs, ys, zs, euclidian=False):
    """Return the centroid and half-extents of a point cloud.

    Returns (x0, y0, z0, xd, yd, zd), where (x0, y0, z0) is the arithmetic
    mean of the coordinates. With euclidian=True, all three extents equal the
    largest Euclidean distance from the centroid to any point. Otherwise each
    extent is the largest absolute deviation along its own axis.
    """
    cx = sum(xs) / len(xs)
    cy = sum(ys) / len(ys)
    cz = sum(zs) / len(zs)

    if not euclidian:
        return (cx, cy, cz,
                max(abs(v - cx) for v in xs),
                max(abs(v - cy) for v in ys),
                max(abs(v - cz) for v in zs))

    radius = max(np.sqrt(np.square(px - cx) + np.square(py - cy) + np.square(pz - cz))
                 for px, py, pz in zip(xs, ys, zs))
    return cx, cy, cz, radius, radius, radius

## Select cluster parameters: nearest-neighbor distance and cluster extent

In [None]:
# Clusters for parameter key (5, 50). Per the heading above, the key is presumably
# (nearest-neighbor distance, cluster extent) — TODO confirm against the code that
# produced saved_param_clusters_path_2-5.pkl.
clusters_vis = cluster_params[(5, 50)]['clusters']

In [None]:
# Build a list of 'rgb(r,g,b)' colors along the phase color scheme by parsing
# the raw cohcolor.csv lines (lines[0] is the header row).
# NOTE(review): the slice also drops the final line — confirm that is intended.
# NOTE(review): colorscale is not referenced by plot_graph.
colorscale = []
for line in lines[1:-1]:
    phase_text = line.split(",")[0]
    try:
        # Round like color_map.phi (3 decimals) so the exact-equality lookup can
        # also match CSV values written with more digits.
        colorscale.append(get_color_bar(round(float(phase_text), 3)))
    except (ValueError, TypeError):
        # Unparseable phase, or no unique color_map row for it: skip this line.
        continue

## Plot the DE3-R reconstruction, synapses from presynaptic partner neurons, and synaptic clusters

In [None]:
# Phase tick labels and five evenly spaced positions (1..5), presumably meant
# for a phase color bar. NOTE(review): neither name is referenced by plot_graph
# below — confirm whether they are still needed.
color_names = ['0°', '-90°', '±180°', '90°', '0°']
color_vals = np.linspace(1, 5, 5)

def _cluster_meshes(clustered, elips):
    """Collect multi-ROI clusters and, if `elips`, build an ellipsoid mesh for each.

    Each cluster is an iterable of synapse objects exposing .x/.y/.z/.sid. The
    distinct ROIs of a cluster's presynaptic partners are found through sid_tid
    and mapping.sbem2roi. Only clusters drawing input from more than one ROI
    are kept.

    Returns:
        (meshes, hetero_sids): the Mesh3d traces (empty if not `elips`) and the
        set of synapse ids belonging to multi-ROI clusters.
    """
    meshes = []
    hetero_sids = set()
    for clust in clustered:
        sids = [syn.sid for syn in clust]
        rois = []  # distinct partner ROIs, in first-seen order
        for sid in sids:
            roi = mapping.sbem2roi[sid_tid[sid]]
            if roi not in rois:
                rois.append(roi)
        if len(rois) <= 1:
            continue  # single-ROI cluster: not highlighted
        hetero_sids.update(sids)
        if not elips:
            continue

        # Average the ROIs' coherences as complex numbers Alpha * exp(i * Phase).
        vals = [data_values.iloc[roi - 1] for roi in rois]
        cohs = [val['Alpha'] * math.cos(val['Phase']) + 1j * val['Alpha'] * math.sin(val['Phase'])
                for val in vals]
        coh_mean = np.mean(cohs)
        alpha_value = np.absolute(coh_mean)
        phase_value = np.angle(coh_mean)
        if phase_value < 0:
            # np.angle returns (-pi, pi]; shift to the non-negative convention
            # used for data_values phases.
            phase_value = phase_value + 2*pi
        color_ = get_color(db, phase=phase_value, alpha=alpha_value)

        xs = [syn.x for syn in clust]
        ys = [syn.y for syn in clust]
        zs = [syn.z for syn in clust]
        x0, y0, z0, xd, yd, zd = find_coords(xs, ys, zs, True)
        xe, ye, ze = ellipsoid3d(x0, y0, z0, xd, yd, zd)
        meshes.append(Mesh3d({
            'x': xe,
            'y': ye,
            'z': ze,
            'alphahull': 0,
            'opacity': 0.3,
            'color': color_,
            'hoverinfo': 'text',
            'hovertext': [],
        }))
    return meshes, hetero_sids


def _synapse_traces(synapses, min_synapses, hetero_sids):
    """Build ONE Scatter3d per presynaptic tree with at least `min_synapses` synapses.

    `synapses` maps tree id -> list of [x, y, z, sid]. Markers are colored by
    the tree's ROI phase/coherence (get_color). Synapses in `hetero_sids` are
    drawn at size 10; homo-clustered and unclustered synapses share size 6.
    Trees without an ROI / canonical-name mapping are skipped.
    """
    traces = []
    for tid, syns in synapses.items():
        if len(syns) < min_synapses:
            continue
        try:
            name = f'{tid:4} - {mapping.sbem2can[tid]}'
            color = get_color(db, tid)
        except (KeyError, IndexError):
            # No mapping for this partner tree: skip it, as the original did.
            continue
        traces.append(go.Scatter3d(
            x=[s[0] for s in syns],
            y=[s[1] for s in syns],
            z=[s[2] for s in syns],
            mode='markers',
            hoverinfo='text',
            hovertext=[f'tid={tid} | sid={s[3]}' for s in syns],
            name=name,
            marker=dict(
                color=color,
                size=[10 if s[3] in hetero_sids else 6 for s in syns],
                opacity=1,
                cmin=1,
                cmax=5,
            ),
        ))
    return traces


def plot_graph(synapses, clustered, min_synapses=1, elips=True):
    """Render tree 444 (DE3-R) with partner synapses and multi-ROI cluster ellipsoids.

    Args:
        synapses: dict mapping presynaptic tree id -> list of [x, y, z, sid].
        clustered: iterable of clusters, each an iterable of synapse objects
            with .x/.y/.z/.sid attributes.
        min_synapses: skip presynaptic trees contributing fewer synapses.
        elips: draw a translucent ellipsoid around each multi-ROI cluster.

    Writes and opens 'Visualization_coherent_synaptic_clusters_swim.html' via
    plotly.offline.plot, then also calls fig.show(). Returns None.

    Fixes compared to the earlier version: the per-tree Scatter3d used to be
    created inside the per-synapse loop, which stacked k cumulative duplicate
    traces for a tree with k synapses. It also used a bare `except: pass`.
    """
    each_cluster, hetero_sids = _cluster_meshes(clustered, elips)
    each_tree_scatter = _synapse_traces(synapses, min_synapses, hetero_sids)

    nodes = db.nodexyz('tid==444 and typ==1')  # (x, y, z, nid)
    soma = [go.Scatter3d(x=nodes[0], y=nodes[1], z=nodes[2],
                         name='Soma',
                         hoverinfo=[],
                         hovertext=[],
                         mode='markers',
                         marker=dict(
                             color="rgb(0,0,0)",
                             size=7,
                             opacity=1
                         ))]
    fig = go.Figure(data=[go.Scatter3d(x=x, y=y, z=z,
                                       mode='lines',
                                       line=dict(color='black', width=2.3),
                                       hoverinfo='text',
                                       hovertext=[],
                                       opacity=0.7,
                                       name='DE3',
                                       marker=dict(
                                           color=1,
                                           size=25,
                                           colorscale='Viridis',
                                           opacity=1
                                       )),
                          ] + each_tree_scatter + soma + each_cluster)
    fig.update_layout(title="DE3-R motor neuron with synaptic clusters",
                      showlegend=False,
                      scene=dict(
                          xaxis=dict(nticks=5, range=[50, 300], showbackground=False, showticklabels=False, title=''),
                          yaxis=dict(nticks=13, range=[50, 750], showbackground=False, showticklabels=False, title=''),
                          zaxis=dict(nticks=6, range=[50, 350], showbackground=False, showticklabels=False, title=''),
                          aspectmode='data',
                          dragmode='orbit'
                      ),
                      scene_camera=dict(
                          eye=dict(x=2, y=-0.1, z=-0.8),
                          up=dict(x=0, y=0, z=-1),
                          center=dict(x=0, y=0, z=0)
                      ))

    plotly.offline.plot(fig, filename='Visualization_coherent_synaptic_clusters_swim.html', auto_open=True)
    fig.show()
    
# Render tree 444 with its partner synapses and the clusters selected above.
plot_graph(synapses, clusters_vis);

In [None]:
# All synapses onto tree 444. Unlike the earlier query, pre.tid=444 is not excluded.
# Below, only _tid (same position as `pretid` in the earlier unpacking) and _sid are used.
syn_x, syn_y, syn_z, _tid, _, _sid, _pre_nid, _post_nid = db.synapses(f'post.tid=444', extended=True)

In [None]:
# Per-tree synapse types. Each CSV row holds one "tid,type" string; the 2-tuple
# unpacking of itertuples() below relies on the file parsing into one column.
# Type codes as used below: -1 inhibitory, 1 excitatory, 0 unknown.
# NOTE(review): absolute, machine-specific path — consider a path relative to
# the repository's data/ folder.
syn_types = pd.read_csv(r"C:\Users\Amanda P\Desktop\SURF\leechem-public\new\leechem-public\data\tree_sids.csv", header=None)
tree_syn_types = {}
for _, item in syn_types.itertuples():
    fields = item.split(',')
    tree_syn_types[int(fields[0])] = int(fields[1])

# Synapse id -> presynaptic tree id, built once. Previously np.where scanned all
# synapses up to four times per cluster member. setdefault keeps the first
# occurrence, matching np.where(...)[0][0].
sid_to_pretid = {}
for s_id, t_id in zip(_sid, _tid):
    sid_to_pretid.setdefault(s_id, t_id)

syn_ratio = [[], [], []]  # per-cluster fractions: [inhibitory, excitatory, unknown]
syn_total = [[], [], []]  # per-cluster counts:    [inhibitory, excitatory, unknown]
for clust in clusters_vis:
    length = len(clust)
    types = [tree_syn_types[sid_to_pretid[syn.sid]] for syn in clust]
    inhib = types.count(-1)
    exci = types.count(1)
    # Types are parsed with int(), so they can never equal '' (the old extra
    # comparison was dead); 0 is the only "unknown" code.
    unknown = types.count(0)
    syn_total[0].append(inhib)
    syn_total[1].append(exci)
    syn_total[2].append(unknown)
    syn_ratio[0].append(inhib/length)
    syn_ratio[1].append(exci/length)
    syn_ratio[2].append(unknown/length)

In [None]:
# Inhibitory / excitatory / unknown synapse counts summed over all clusters.
inhib_total, exci_total, na_total = sum(syn_total[0]), sum(syn_total[1]), sum(syn_total[2]) # values from an earlier run: 35, 45, 70

In [None]:
# Mean per-cluster fraction of inhibitory / excitatory / unknown synapses,
# plotted with standard-error bars.
# (Earlier run: 0.249 inhibitory, 0.376 excitatory, 0.375 unknown.)
ratio_means = [np.mean(ratios) for ratios in syn_ratio]
mean_inhib_ratio, mean_exci_ratio, mean_na_ratio = ratio_means
plt.rcParams.update({'errorbar.capsize': 2})

ratio_sems = [sci.sem(ratios) for ratios in syn_ratio]
plt.bar(x=range(3),
        height=ratio_means,
        tick_label=['Inhibitory', 'Excitatory', 'Unknown'],
        yerr=ratio_sems)
plt.title("Swimming Trial (Trial 63)")

In [None]:
# Cluster-size statistics (earlier run: 30 clusters, mean 5.00 synapses, SEM 0.743).
num_syn = list(map(len, clusters_vis))
num_clust = len(clusters_vis)
mean_syn = np.mean(num_syn)
sem_syn = sci.sem(num_syn)