In [1]:
import numpy as np
import os
from numpy import sin, cos, pi
import pandas as pd
import plotly.offline
import plotly.graph_objects as go
from plotly.graph_objs import Mesh3d

import nn_clustering

from sbemdb import SBEMDB
from cleandb import clean_db_uct
from get_tree_ids import get_tree_ids

import pickle

In [2]:
def save_obj(obj, name):
    """Pickle ``obj`` into the file ``name + '.pkl'``.

    Args:
        obj: Any picklable Python object.
        name: Output path without extension; ``'.pkl'`` is appended here.
    """
    target_path = name + '.pkl'
    with open(target_path, 'wb') as handle:
        # Use the newest pickle protocol available to the running interpreter.
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

Getting and preparing database.

In [3]:
# Open the SBEM database and pass it straight through the cleaning step, so
# that `db` is only ever bound to the cleaned handle.
db = clean_db_uct(SBEMDB())

trees Before 398
trees After 46
nodes Before 37481
nodes After 16320
nodecons Before 73994
nodecons After 32548
syncons Before 1199
syncons After 826
synapses Before 552
synapses After 535


Get the IDs of all trees in the database except tree 444.

In [4]:
# Filter expression that excludes tree 444 from the selection.
tree_filter = 'tid != 444'
all_tree_ids = get_tree_ids(db, tree_filter)

For each tree, collect its synapses onto tree 444, keeping only the trees that have at least 2 such synapses.

In [5]:
# tree id -> list of Synapse objects, only for trees that pass the count check below.
all_trees = {}
# Flat list of every synapse stored in `all_trees`.
all_synapses = []
# synapse id -> id of the tree it was queried for (filled for every tree,
# including trees that are later rejected by the count check).
syn_tid = {}

for tree_id in all_tree_ids:
    # Synapses with pre.tid == tree_id and post.tid == 444.
    synapse_filter = f'pre.tid={tree_id} and post.tid=444'
    _x, _y, _z, _, _, _sid, _pre_nid, _post_nid = db.synapses(
        synapse_filter, extended=True)

    # Keyed by synapse id, so a sid that is returned several times is stored
    # once: the last row returned wins (exact coordinates for synapses with
    # different pre- and postnodes should not matter).
    synapses_by_sid = {}
    for idx, sid in enumerate(_sid):
        syn_tid[sid] = tree_id
        synapses_by_sid[sid] = nn_clustering.Synapse(
            sid, _x[idx], _y[idx], _z[idx], _pre_nid[idx], _post_nid[idx]
        )

    # Skip trees with fewer than 2 distinct synapses.
    if len(synapses_by_sid) < 2:
        continue

    kept_synapses = list(synapses_by_sid.values())
    all_trees[tree_id] = kept_synapses
    all_synapses.extend(kept_synapses)

In [6]:
for tree_id, tree_synapses in all_trees.items():
    # Inspection hook: uncomment the print to list the synapses kept per tree.
    #print(f"From tree: {tree_id}:\nSynapses:\n{tree_synapses}\n")
    pass

Now that the data is available, the analysis can be done. It needs some setup first:
define the parameters, select the functions, and instantiate the objects needed for the hierarchical clustering.

In [7]:
def find_clusters(param_1, param_2, synapses=None, synapse_tree_ids=None):
    """Run the constrained hierarchical clustering and summarise the result.

    Args:
        param_1: Limit handed to ``nn_clustering.ConstraintChaining``.
        param_2: Limit handed to ``nn_clustering.ConstraintDiameter``.
        synapses: Synapses to cluster. Defaults to the notebook-level
            ``all_synapses`` list.
        synapse_tree_ids: Mapping from synapse id to tree id. Defaults to the
            notebook-level ``syn_tid`` dict.

    Returns:
        A tuple ``(clustered, unclustered, hetero_frac)`` where

        * ``clustered`` is the list of clusters whose length is not 1,
        * ``unclustered`` is the flat list of synapses that ended up in a
          cluster of length 1,
        * ``hetero_frac`` is the fraction of ``clustered`` clusters whose
          synapses do not all map to the same tree id. It is NaN when
          ``clustered`` is empty, because the fraction is undefined then.
    """
    # Fall back to the notebook globals so existing two-argument calls keep working.
    if synapses is None:
        synapses = all_synapses
    if synapse_tree_ids is None:
        synapse_tree_ids = syn_tid

    # Create the distance calculator object from the stored distance matrix.
    path_matrix = os.path.join('..', 'data', 'distance_matrix_path.npy')
    distance_calculator = nn_clustering.MatrixSynapseDistance(np.load(path_matrix))

    # Constraint classes define how the clusters (in this case their geometry)
    # are limited in the hierarchical clustering algorithm.
    # To add or remove a "cluster" limit from the termination condition, add or
    # remove an entry in the list passed to the Validator.
    constraint_validator = nn_clustering.Validator(
        [
            nn_clustering.ConstraintChaining(distance_calculator, param_1),
            nn_clustering.ConstraintDiameter(distance_calculator, param_2),
        ]
    )
    # Closest linkage --> can take any distance function with appropriate signature.
    linkage = nn_clustering.ClosestLinkage(distance_calculator)

    clusters = nn_clustering.hierarchical_clustering(
        linkage, constraint_validator, synapses)

    # Separate singleton clusters (reported as loose synapses) from the rest.
    clustered = []
    unclustered = []
    for clus in clusters:
        if len(clus) == 1:
            unclustered.extend(clus)
        else:
            clustered.append(clus)

    # Without any cluster the fraction is 0/0; return NaN explicitly instead of
    # relying on a NumPy division that emits a RuntimeWarning.
    if not clustered:
        return clustered, unclustered, np.float64(np.nan)

    # A cluster is homogeneous when every synapse in it maps to the same tree id.
    is_homo = []
    for clus in clustered:
        cluster_tree_ids = np.array([synapse_tree_ids[syn.sid] for syn in clus])
        is_homo.append(np.all(cluster_tree_ids == cluster_tree_ids[0]))
    is_homo = np.array(is_homo, dtype=bool)
    hetero_frac = (~is_homo).sum() / len(is_homo)

    return clustered, unclustered, hetero_frac

In [8]:
# Map every synapse id in the database to the `pretid` value returned with it.
# When a synapse id occurs more than once, the last occurrence wins.
xx, yy, zz, pretid, posttid, synid, prenid, postnid = db.synapses(extended=True)
sid_tid = {synid[row]: pretid[row] for row in range(len(xx))}

In [None]:
# Parameter grids for the sweep; each (p1, p2) pair is passed to find_clusters.
params_1 = np.arange(2.5, 30, 2.5)
params_2 = np.arange(10, 125, 5)

# Column names are written to the CSV as-is. The spelling 'hetero_fracion' is
# kept on purpose so existing readers of the file keep working.
stat_columns = ['id_param_pair', 'id_cluster', 'param1', 'param2',
                'total_clusters', 'num_synapses', 'hetero_fracion', 'synapse ids', 'tids']
# Rows are collected in a plain list and turned into a DataFrame once at the
# end; growing a DataFrame with .loc inside the loop re-allocates it per row.
stat_rows = []

# Grids indexed [i_param1][i_param2]; used by the heatmaps below.
num_clusters = []
fractions = []

id_param_pair = 0
# Running cluster id across the whole sweep (not reset per parameter pair).
id_cluster = 0

data_to_save = {}

for p1 in params_1:
    n_cl_ = []
    frac_ = []
    for p2 in params_2:
        clusters_, unclustered_, hetero_fraction_ = find_clusters(p1, p2)
        data_to_save[(p1, p2)] = {'clusters': clusters_, 'unclustered': unclustered_, 'hetero_fraction': hetero_fraction_}
        n_cl_.append(len(clusters_))
        frac_.append(hetero_fraction_)

        # One row per cluster found for this parameter pair.
        for clus_ in clusters_:
            syn_ids = [syn.sid for syn in clus_]
            tids = [sid_tid[syn.sid] for syn in clus_]
            stat_rows.append([id_param_pair, id_cluster, p1, p2, len(clusters_),
                              len(clus_), hetero_fraction_, syn_ids, tids])
            id_cluster += 1

        # One extra row gathering all synapses left unclustered; it is marked
        # with id_cluster='unclustered' and total_clusters=-1.
        syn_ids = [syn.sid for syn in unclustered_]
        tids = [sid_tid[syn.sid] for syn in unclustered_]
        stat_rows.append([id_param_pair, 'unclustered', p1, p2, -1,
                          len(unclustered_), hetero_fraction_, syn_ids, tids])
        id_param_pair += 1
    num_clusters.append(n_cl_)
    fractions.append(frac_)

syn_stats = pd.DataFrame(stat_rows, columns=stat_columns)
syn_stats.to_csv('clust_stats_path.csv', index=False)
save_obj(data_to_save, 'saved_param_clusters_path')

## Number of clusters

In [None]:
import plotly.graph_objects as go

# Heatmap of the cluster count for every parameter pair:
# rows follow params_1, columns follow params_2.
cluster_count_trace = go.Heatmap(
    z=num_clusters,
    x=[str(p2) for p2 in params_2],
    y=[str(p1) for p1 in params_1],
)
cluster_count_layout = go.Layout(
    xaxis=dict(title='param2', tickvals=params_2),
    yaxis=dict(title='param1', tickvals=params_1),
)
fig = go.Figure(data=cluster_count_trace, layout=cluster_count_layout)
fig.show()

## Fraction of heterogeneous clusters

In [None]:
import plotly.graph_objects as go

# Heatmap of the heterogeneous-cluster fraction for every parameter pair:
# rows follow params_1, columns follow params_2.
hetero_fraction_trace = go.Heatmap(
    z=fractions,
    x=[str(p2) for p2 in params_2],
    y=[str(p1) for p1 in params_1],
)
hetero_fraction_layout = go.Layout(
    xaxis=dict(title='param2', tickvals=params_2),
    yaxis=dict(title='param1', tickvals=params_1),
)
fig = go.Figure(data=hetero_fraction_trace, layout=hetero_fraction_layout)
fig.show()