In [1]:
#from goto import with_goto
import os,glob
import json
import pickle
import copy
import math
import numpy as np
import pandas as pd
import seaborn as sns
#import torch

from sklearn.cluster import MeanShift
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances
from scipy.spatial import distance_matrix
from scipy.spatial.distance import cdist
import SimpleITK as sitk

import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
from matplotlib import cm

#from matplotlib.patches import Patch
# import hdbscan
# import umap

from swc_handler import parse_swc
from file_io import load_image
#from anatomy.anatomy_core import parse_id_map, parse_ana_tree, get_regional_neighbors

import graphviz
from graphviz import Digraph

# 1. Concatenate per-brain region (acronym) signal counts and plot them

In [24]:
def get_tree():
    """Load and return the decoded contents of ``../../assets/tree.json``.

    Elsewhere in this file the result is iterated as a list of per-region
    record dicts.
    """
    tree_path = '../../assets/tree.json'
    with open(tree_path, 'r') as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
    
def get_tree_from(info, tree=None):
    """Index the region tree by one of its record fields.

    Parameters
    ----------
    info : str
        Name of a field present in the tree records (e.g. ``'u16_id'`` or
        ``'acronym'``); its value becomes the key of the returned dict.
    tree : list of dict, optional
        Pre-loaded tree records. When omitted, the tree is read with
        ``get_tree()``.

    Returns
    -------
    dict
        ``{record[info]: record}`` for every record. If several records share
        the same ``info`` value, the last one wins. An empty dict is returned
        (after printing a message) when ``info`` is not a string, the tree is
        empty, or ``info`` is not a field of the first record.
    """
    if not isinstance(info, str):
        print(f'input should be str instead of {type(info)}')
        return {}
    if tree is None:
        tree = get_tree()
    if not tree:
        print('tree is empty')
        return {}
    if info not in tree[0]:
        print(f"input should be a string in {tree[0].keys()} instead of {info}")
        return {}
    return {record[info]: record for record in tree}

    
def get_axis(level,hemi,axis):
    """Return the centers of the level-``level`` regions sorted along one axis.

    Parameters
    ----------
    level : int
        Region granularity; selects the ``in_n{level}`` filter flag and the
        ``n{level}_acronym`` name field of each tree record.
    hemi : bool
        When true, use the left/right-split fields (``*_lr``).
    axis : str
        One of ``'AP'``, ``'DV'``, ``'RL'``, mapped to ``centerx``,
        ``centery``, ``centerz`` respectively; anything else raises ValueError.

    Returns
    -------
    (pandas.DataFrame, pandas.Series)
        The table with columns ``acronym, centerx, centery, centerz`` sorted
        ascending by the chosen center column, and its ``acronym`` column.
    """
    suffix = '_lr' if hemi else ''
    name_key = f'n{level}_acronym{suffix}'
    center_key = f'center_u25{suffix}'
    in_level_key = f'in_n{level}'

    rows = []
    for record in get_tree():
        if not record[in_level_key]:
            continue
        name = record[name_key]
        xyz = record[center_key]
        rows.append([name, xyz[0], xyz[1], xyz[2]])

    columns = ['acronym', 'centerx', 'centery', 'centerz']
    frame = pd.DataFrame(rows, columns=columns)

    sort_column = columns[1:][['AP', 'DV', 'RL'].index(axis)]
    ordered = frame.sort_values(by=sort_column, axis=0, ascending=True)
    return ordered, ordered['acronym']


def calc_dist(xyz1,xyz2):
    """Return the Euclidean distance between two coordinate sequences.

    Parameters
    ----------
    xyz1, xyz2 : sequence of numbers
        Points of the same dimensionality.

    Raises
    ------
    ValueError
        If the two sequences have different lengths.
    """
    if len(xyz1) != len(xyz2):
        raise ValueError(f'coordinate length mismatch: {len(xyz1)} vs {len(xyz2)}')
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(xyz1, xyz2)))


def _label_at(annotation, x, y, z):
    """Return the label of ``annotation`` at the rounded point, or 0 when it lies outside.

    The array is indexed ``[z, y, x]``. Negative indices count as outside:
    plain numpy indexing would wrap them around to the opposite edge of the
    volume. Non-finite coordinates also count as outside.
    """
    try:
        idx = (int(round(z)), int(round(y)), int(round(x)))
    except (ValueError, OverflowError):  # NaN / inf coordinates
        return 0
    if any(i < 0 or i >= n for i, n in zip(idx, annotation.shape)):
        return 0
    return annotation[idx]


def produce_swc_acronym(maskfile,info1,info2,somafile,csvfile):
    """Assign every soma in the SWC marker files to the region containing it.

    Parameters
    ----------
    maskfile : str
        Annotation volume read with SimpleITK. Its voxel labels are looked up
        in the tree indexed by ``info1``; label 0 is treated as "no region".
    info1 : str
        Tree field whose values match the voxel labels (e.g. ``'u16_id'``).
    info2 : str
        Tree field reported for each soma (e.g. ``'acronym_lr'``); also the
        name of the second output column.
    somafile : str
        Directory searched with ``<somafile>/*/[1-9]*.swc``; the name of the
        sub-directory is used as the brain id.
    csvfile : str
        Path the resulting table is written to (with the default index).

    Returns
    -------
    pandas.DataFrame
        Columns ``['soma', info2]``, one row per soma that falls on a non-zero
        label inside the volume. Soma ids are ``200k_<brain>_<x>_<y>_<z>``
        built from the rounded coordinates.
    """
    plot_soma = []
    plot_soma_ac = []
    somata_num = 0
    plot_num = 0
    filt_num = 0

    mapping = get_tree_from(info1)
    annotation = sitk.GetArrayFromImage(sitk.ReadImage(maskfile))
    for marker_25 in glob.glob(f"{somafile}/*/[1-9]*.swc"):
        brain = marker_25.split('/')[-2]
        tree = parse_swc(marker_25)

        for leaf in tree:
            _, _, x, y, z, _, _ = leaf
            somata_num += 1

            ac = _label_at(annotation, x, y, z)
            if ac == 0:  # background label or outside the volume
                filt_num += 1
                continue

            plot_num += 1
            plot_soma.append(f'200k_{brain}_{round(x)}_{round(y)}_{round(z)}')
            plot_soma_ac.append(mapping[ac][info2])

        # Running totals: somata seen, kept, filtered.
        print(marker_25,brain,somata_num,plot_num,filt_num)

    df_200ksoma = pd.DataFrame(list(zip(plot_soma, plot_soma_ac)), columns=['soma', info2])
    df_200ksoma.to_csv(csvfile)

    return df_200ksoma


def mykmeans(X,n_clusters):
    """Greedy 1-D clustering seeded at the widest gaps between consecutive values.

    Despite its name this is not Lloyd's k-means. ``X`` is presumably sorted
    ascending; the caller in this file passes region centers sorted along one
    axis.

    Algorithm:

    1. Compute the gaps ``X[k+1] - X[k]``. Seed one cluster at the left
       endpoint ``X[k]`` of each of the ``n_clusters`` widest gaps; tied gaps
       are taken left to right. If there are fewer gaps, fewer seeds are used.
    2. Repeatedly move the remaining value closest to any cluster's current
       maximum or minimum into that cluster. Ties go to the earlier cluster
       and the earlier value.

    Parameters
    ----------
    X : array-like of numbers, 1-D
    n_clusters : int
        Requested number of clusters.

    Returns
    -------
    list
        One label per element of ``X``, in order. A label is the gap index
        ``k`` of the cluster's seed, or 999 for every element when no cluster
        can be seeded (fewer than two values or ``n_clusters < 1``).
    """
    X = np.asarray(X)
    if n_clusters < 1 or len(X) < 2:
        return [999] * len(X)

    gaps = X[1:] - X[:-1]
    # Widest gaps first. A stable argsort gives tied gaps distinct indices;
    # looking each sorted value up with np.where(...)[0][0] would return the
    # same index for every tie and seed fewer clusters.
    seed_ids = np.argsort(-gaps, kind='stable')[:n_clusters]

    # cluster key -> [member indices into X, member values]. Dict insertion
    # order decides tie-breaking between clusters below.
    cluster = {}
    for c in seed_ids:
        cluster[c] = [[c], X[[c]].tolist()]

    data = X.tolist()
    for members in cluster.values():
        for value in members[1]:
            data.remove(value)

    while data:
        data_dist_tomax = []
        data_dist_tomin = []
        for members in cluster.values():
            hi = max(members[1])
            lo = min(members[1])
            data_dist_tomax.extend(abs(d - hi) for d in data)
            data_dist_tomin.extend(abs(d - lo) for d in data)
        min_maxdist = min(data_dist_tomax)
        min_mindist = min(data_dist_tomin)
        # Flat position = cluster_position * len(data) + data_position.
        if min_maxdist < min_mindist:
            rm_id_dist = data_dist_tomax.index(min_maxdist)
        else:
            rm_id_dist = data_dist_tomin.index(min_mindist)

        rm_data = data[rm_id_dist % len(data)]
        dc = list(cluster.keys())[rm_id_dist // len(data)]

        cluster[dc][1].append(rm_data)
        cluster[dc][0].append(np.where(X == rm_data)[0][0])
        data.remove(rm_data)

    # Label by value membership; the first cluster holding the value wins.
    yhat = []
    for x in X:
        for c, members in cluster.items():
            if x in members[1]:
                yhat.append(c)
                break
        else:
            yhat.append(999)
    return yhat


def get_axis_cluster(level,hemi,axis,n_clusters,outfile_dir,outfig_dir,
                     count_clusters=20,num_clusters=36):
    """Cluster regions by their center along one axis, then plot and save the result.

    Parameters
    ----------
    level, hemi, axis
        Passed to ``get_axis`` to select and sort the regions.
    n_clusters : int
        Initial number of clusters requested from ``mykmeans``.
    outfile_dir : str
        Directory receiving ``cluster_<axis>.json``.
    outfig_dir : str
        Directory receiving ``cluster_<n_clusters>_<axis>.png``.
    count_clusters : int, optional
        Largest acceptable cluster size. While the largest cluster exceeds it,
        ``n_clusters`` is increased by 2 and the clustering is redone...
    num_clusters : int, optional
        ...as long as ``n_clusters`` is still below this value.

    Returns
    -------
    dict
        ``{cluster_label: [acronym, ...]}`` in order of first appearance along
        the axis. JSON turns the integer labels into strings in the saved file.
    """
    center_reshape_df, region_list = get_axis(level,hemi,axis)
    print(f'\nfor axis: {axis}, acronym: {region_list}, len: {len(region_list)}\n')

    a = ['centerx','centery','centerz'][['AP','DV','RL'].index(axis)]
    center_list = center_reshape_df[a]
    print(f'\nfor axis: {axis}, center: {center_list}, len: {len(center_list)}\n')

    def _largest_cluster(labels):
        # Size of the biggest label group (0 when there are no labels).
        counts = np.unique(labels, return_counts=True)[1]
        return counts.max() if counts.size else 0

    yhat = mykmeans(center_list.values,n_clusters=n_clusters)
    count_max = _largest_cluster(yhat)
    print(f'axis {axis} do {n_clusters} clustering\nyhat: {yhat}\nmax: {count_max}\n')

    # Add two seeds per round while any cluster is too large.
    while count_max > count_clusters and n_clusters < num_clusters:
        n_clusters += 2
        yhat = mykmeans(center_list.values,n_clusters=n_clusters)
        count_max = _largest_cluster(yhat)
        print(f'axis {axis} do {n_clusters} clustering\nyhat: {yhat}\nmax: {count_max}\n')

    color = ['black','gray','lightgray','mistyrose',
             'darkred','firebrick','red','orangered',
             'sienna','chocolate','peru','tan',
             'darkorange','orange','gold','yellow','y',
             'darkolivegreen','olivedrab','greenyellow','limegreen','green',
             'aquamarine','paleturquoise','teal','c','cyan','deepskyblue','blue',
             'mediumpurple','blueviolet','darkviolet','purple','magenta','deeppink','crimson',
            ]
    xlims = [(0,528),(0,320),(0,456)][['AP','DV','RL'].index(axis)]
    yhat_arr = np.array(yhat)

    cluster_dict = {}
    legends = []
    plt.figure(figsize=(10,1))
    plt.xlim(xlims)
    for c in yhat:
        if c == 999 or c in cluster_dict:
            continue
        row = np.where(yhat_arr == c)[0]
        row_region = list(region_list.values[row])
        cluster_dict[int(c)] = row_region
        legends.append(''.join(f'{name},' for name in row_region))
        # Cycle through the palette so more clusters than colors cannot raise IndexError.
        plot_color = color[(len(cluster_dict) - 1) % len(color)]
        plt.scatter(x=center_list.values[row], y=0 * center_list.values[row], s=2, c=plot_color)

    plt.legend(legends,ncol=5,fontsize=2)
    plt.savefig(outfig_dir+f'/cluster_{n_clusters}_{axis}.png', dpi=450)
    plt.close()

    with open(f"{outfile_dir}/cluster_{axis}.json", "w") as f:
        f.write(json.dumps(cluster_dict))

    print(f'axis {axis} cluster dict: {cluster_dict}\n')
    return cluster_dict



# def prepare_region(region_list,vtk_des):
#     vtk_list = glob.glob('../../vtk/*.vtk')
#     if not os.path.exists(vtk_des): os.makedirs(vtk_des)

#     acronym_tree = get_tree_from('acronym')  
#     for acronym in region_list:
#         vtk_file = '../../vtk/' + acronym + '.vtk'
#         if not vtk_file in vtk_list: 
#             print(f'{vtk_file} does not exists'); continue
          
#         color = acronym_tree[acronym]['rgb']
#         new_vtk_file = vtk_des + '/' + f'r{color[0]}g{color[1]}b{color[2]}_' + acronym + '.vtk'
        
#         cp_str = f'cp {vtk_file} {new_vtk_file}'
#         os.system(cp_str)
#     return True

In [4]:
class BrainsSignalAnalyzer(object):
    """Aggregate per-brain region signal counts across imaging modalities.

    Each modality is a sub-directory ``<signal_dir>/<modality>/`` holding one
    CSV per brain. The first two columns of a brain CSV are read as a region
    id (``u16_id_lr_mask``, looked up in the tree indexed by ``u16_id``) and a
    count. Results are written as CSV tables and seaborn/matplotlib figures.
    """

    def __init__(self,allow0=False,
                 modalities=None,
                 modalities_colors=None):
        """
        Parameters
        ----------
        allow0 : bool
            Keep rows whose region id is 0 or whose n70 acronym is ``'0'``.
        modalities : list of str, optional
            Modality sub-directory names; defaults to ``['fMOST-Zeng']``.
        modalities_colors : list of str, optional
            Matplotlib color per modality, in the same order; defaults to
            ``['black']``.
        """
        # None sentinels avoid sharing one mutable default list between instances.
        if modalities is None:
            modalities = ['fMOST-Zeng']
        if modalities_colors is None:
            modalities_colors = ['black']
        self.allow0 = allow0
        self.modalities = list(modalities)
        self.modalities_colors = list(modalities_colors)
        # modality -> brains processed for it; color -> brains drawn in that color.
        self.modalities_brains_dict = {m: [] for m in self.modalities}
        self.colors_brains_dict = {c: [] for c in self.modalities_colors}
        self.u16_id_lr_tree = get_tree_from('u16_id')

    @staticmethod
    def _align_columns(df_plot, acronyms):
        """Return ``df_plot`` with exactly the columns in ``acronyms``, in that order.

        Acronyms missing from ``df_plot`` become all-zero columns, and columns
        not listed are dropped. Values are cast to float so the CSV formatting
        is uniform across columns. This works for any number of acronyms,
        including one or none.
        """
        labels = list(acronyms)
        aligned = df_plot.reindex(columns=labels, fill_value=0).astype(float)
        # Name the column axis after the acronym list (e.g. a Series named 'acronym').
        aligned.columns = pd.Index(labels, name=getattr(acronyms, 'name', None))
        return aligned

    def ModalitiesSignalAnalyzer(self,signal_dir,outfile_dir,level,hemi,analysis_agg,outfig_dir):
        """Process every modality, write per-modality and combined tables, and plot.

        For each axis (AP/DV/RL), the brain x acronym pivot (aggregated with
        ``analysis_agg``) is written with its columns ordered by the region
        centers from ``get_axis``.
        """
        column = f'n{level}_acronym_lr' if hemi else f'n{level}_acronym'
        axes = ['AP','DV','RL']
        acronym_list_dict = {}
        for axis in axes:
            center_df,acronym_list_dict[axis] = get_axis(level,hemi,axis)
            center_df.to_csv(outfile_dir+f'/centers_{column}_{axis}.csv', sep=',', float_format='%.4f', index=False)

        modality_frames = []
        for modality_index, modality in enumerate(self.modalities):
            modality_color = self.modalities_colors[modality_index]
            modality_dir = signal_dir + '/' + modality
            print(modality_dir,modality,modality_color)

            df_modality = self.ModalitySignalAnalyzer(modality_dir,modality,modality_color)
            df_modality.to_csv(outfile_dir+'/'+modality+'_modalities.csv', sep=',', index=False)

            df_plot = pd.pivot_table(df_modality,values=['count'],index=['brain'],columns=[column],aggfunc=analysis_agg,)['count'].fillna(0)
            for axis in axes:
                df_plot_reshape = self._align_columns(df_plot, acronym_list_dict[axis])
                df_plot_reshape.to_csv(outfile_dir+f'/{modality}_brain_acronym_n{level}_hemi{hemi}_{axis}.csv', sep=',')

            modality_frames.append(df_modality)

        df_modalities = pd.concat(modality_frames) if modality_frames else pd.DataFrame([])
        df_modalities.to_csv(outfile_dir+'/all_modalities.csv', sep=',', index=False)

        df_plot = pd.pivot_table(df_modalities,values=['count'],index=['modality','brain'],columns=[column],aggfunc=analysis_agg,)['count'].fillna(0)
        df_plot.index = [i[1] for i in df_plot.index]  # keep only the brain id
        for axis in axes:
            df_plot_reshape = self._align_columns(df_plot, acronym_list_dict[axis])
            df_plot_reshape.to_csv(outfile_dir+f'/all_modalities_brain_acronym_n{level}_hemi{hemi}_{axis}.csv', sep=',')

            self.plot_clustermap(column,axis,df_plot_reshape,outfig_dir)
            self.plot_acronyms_line(axis,df_plot_reshape,outfig_dir)
            self.plot_acronyms_heatmap(axis,df_plot_reshape,outfig_dir)
        # Brain-level plots aggregate across columns, so column order does not
        # matter; the last axis' table is used.
        self.plot_brains_line(df_plot_reshape,outfig_dir)
        self.plot_brains_heatmap(df_plot_reshape,outfig_dir)

    def plot_clustermap(self,column,axis,df_plot,outfig_dir):
        """Save a log-scaled brain x acronym heatmap (rows tagged with modality
        colors) and a strip image of the regions' tree colors."""
        row_colors = [
            [c for c,bs in self.colors_brains_dict.items() if b in bs][0]
            for b in df_plot.index
        ]

        acronym_tree = get_tree_from(column)
        rgb_key = column.split('_')[0]+'_rgb'
        rgbs = [acronym_tree[a][rgb_key] for a in df_plot.columns]
        # Two zero-padded hex digits per channel (for the disabled col_colors option).
        column_colors = ['#' + ''.join(f'{int(v):02X}' for v in rgb) for rgb in rgbs]
        rgbs = np.array(rgbs).reshape(1,-1,3)

        sns.clustermap(df_plot,cmap='gray', norm=LogNorm(), cbar_pos=(0.01,0.05,0.02,0.15),#left bottom width height
                       row_cluster=False,
                       row_colors=row_colors,
                       col_cluster=False,
        #               col_colors=column_colors,
        #               figsize=(12,13),
        #               dendrogram_ratio=(.1,.2)),
                       )

        plt.savefig(outfig_dir+f'/clustermap_{axis}.png', dpi=450)
        plt.close()

        plt.imshow(rgbs)
        plt.savefig(outfig_dir+f'/clustermap_row_colors_{axis}.png', dpi=450)
        plt.close()

    def plot_acronyms_line(self,axis,df_plot,outfig_dir):
        """Plot the mean count per acronym over all brains and per modality."""
        sns.lineplot(x=df_plot.columns, y=df_plot.values.mean(axis=0),color='black',)

        row_modalities = np.array([
            [m for m,bs in self.modalities_brains_dict.items() if b in bs][0]
            for b in df_plot.index
        ])
        for modality in self.modalities:
            sns.lineplot(x=df_plot.columns, y=df_plot.values[row_modalities==modality].mean(axis=0),
                         color=self.modalities_colors[self.modalities.index(modality)],)

        plt.legend(['all']+self.modalities)
        plt.savefig(outfig_dir+f'/acronym_lineplot_{axis}.png', dpi=450)
        plt.close()

    def plot_brains_line(self,df_plot,outfig_dir):
        """Plot each brain's mean count over all acronyms."""
        sns.lineplot(x=df_plot.index, y=df_plot.values.mean(axis=1), color='black',)
        plt.savefig(outfig_dir+f'/brain_lineplot.png', dpi=450)
        plt.close()

    def plot_acronyms_heatmap(self,axis,df_plot,outfig_dir):
        """Save the acronym x acronym correlation (across brains) as a heatmap."""
        acronym_corr = df_plot.corr()
        sns.heatmap(acronym_corr
        #            ,mask=mask       # optional; seaborn hides cells where mask is True
                    , cmap='Greens'
        #             , vmax=.3
        #             , vmin=.1
        #             , center=0.5
        #             , annot=True
        #             , xticklabels=True
        #             , yticklabels=True
        #             , square=True
                   )
        plt.savefig(outfig_dir+f'/acronym_corr_heatmap_{axis}.png', dpi=450)
        plt.close()

    def plot_brains_heatmap(self,df_plot,outfig_dir):
        """Save the brain x brain correlation (across acronyms) as a heatmap."""
        brain_corr = df_plot.T.corr()
        sns.heatmap(brain_corr
        #            ,mask=mask       # optional; seaborn hides cells where mask is True
                    , cmap='Greens'
        #             , vmax=.3
        #             , vmin=.1
        #             , center=0
        #             , annot=True
        #             , xticklabels=True
        #             , yticklabels=True
        #             , square=True
                   )
        plt.savefig(outfig_dir+f'/brain_corr_heatmap.png', dpi=450)
        plt.close()

    def ModalitySignalAnalyzer(self,modality_dir,modality,modality_color):
        """Load every ``*.csv`` in ``modality_dir`` (one per brain, in sorted
        filename order) and stack them; an empty DataFrame if there are none."""
        brain_frames = []
        for b,brain_csv in enumerate(sorted(glob.glob(modality_dir + '/*.csv'))):
            brain = os.path.basename(brain_csv).replace('.csv','')
            print(b,brain)
            brain_frames.append(self.BrainSignalAnalyzer(modality,modality_color,brain_csv,brain))
        if not brain_frames:
            return pd.DataFrame([])
        return pd.concat(brain_frames)

    def BrainSignalAnalyzer(self,modality,modality_color,brain_csv,brain):
        """Load one brain's CSV, tag it with modality/color/brain, and add acronym columns.

        The brain is also registered in ``modalities_brains_dict`` and
        ``colors_brains_dict``. Unless ``allow0`` is set, rows with region id 0
        or n70 acronym ``'0'`` are dropped.
        """
        df_brain = pd.read_csv(brain_csv,header=0,usecols=[0,1],names=['u16_id_lr_mask','count'])
        if not self.allow0:
            # .copy() so the column assignments below act on an independent
            # frame rather than a filtered slice (SettingWithCopyWarning).
            df_brain = df_brain[df_brain['u16_id_lr_mask']!=0].copy()

        df_brain['modality'] = modality
        df_brain['modality_color'] = modality_color
        df_brain['brain'] = brain

        self.modalities_brains_dict[modality].append(brain)
        self.colors_brains_dict[modality_color].append(brain)

        # Look each region id up once, then copy the acronym fields in this column order.
        records = [self.u16_id_lr_tree[u16_id_lr] for u16_id_lr in df_brain['u16_id_lr_mask']]
        for key in ('n316_acronym','n70_acronym','n8_acronym',
                    'n316_acronym_lr','n70_acronym_lr','n8_acronym_lr'):
            df_brain[key] = [record[key] for record in records]

        if not self.allow0:
            df_brain = df_brain[df_brain['n70_acronym']!='0']

        return df_brain
    
    
    

In [7]:
# Modalities (one sub-directory of signal_dir each) and their plot colors.
allow0 = False
modalities = ['fMOST-Zeng', 'fMOST-Huang', 'LSFM-Wu', 'LSFM-Osten']
modalities_colors = ['red', 'yellow', 'blue', 'purple']

bssa = BrainsSignalAnalyzer(
    allow0=allow0,
    modalities=modalities,
    modalities_colors=modalities_colors,
)

# Input/output locations and aggregation settings.
signal_dir = '../statis_out_adaThr_all'
outfile_dir = '../data'
outfig_dir = '../plots'
level = 70            # selects the n{level}_* tree fields
hemi = True           # True -> use the left/right-split (*_lr) fields
analysis_agg = 'sum'  # pivot_table aggregation over counts

bssa.ModalitiesSignalAnalyzer(
    signal_dir=signal_dir,
    outfile_dir=outfile_dir,
    level=level,
    hemi=hemi,
    analysis_agg=analysis_agg,
    outfig_dir=outfig_dir,
)

../statis_out_adaThr3/fMOST-Zeng fMOST-Zeng red
0 15257
1 17051
2 17052
3 17109
4 17298
5 17300
6 17301
7 17302
8 17304
9 17539
10 17541
11 17542
12 17543
13 17544
14 17545
15 17781
16 17782
17 17783
18 17785
19 17786
20 17788
21 18047
22 18049
23 18052
24 18053
25 182711
26 182712
27 182720
28 182721
29 182722
30 182724
31 182725
32 182726
33 182727
34 182737
35 18452
36 18453
37 18454
38 18455
39 18457
40 18458
41 18459
42 18461
43 18463
44 18464
45 18465
46 18466
47 18467
48 18468
49 18469
50 18470
51 18471
52 18472
53 18860
54 18861
55 18862
56 18864
57 18865
58 18866
59 18867
60 18868
61 18869
62 18871
63 191797
64 191798
65 191799
66 191801
67 191803
68 191804
69 191807
70 191808
71 191809
72 191810
73 191811
74 191812
75 191813
76 191815
77 191817
78 192333
79 192334
80 192335
81 192337
82 192338
83 192339
84 192340
85 192341
86 192342
87 192343
88 192344
89 192346
90 192348
91 192349
92 194060
93 194062
94 194063
95 194064
96 194065
97 194066
98 194067
99 194068
100 194069
101 

# 2. Soma counts per region (acronym)

In [8]:
# Map every registered soma to its region and cache the table as CSV.
if hemi:
    acronym = 'acronym_lr'
    maskfile = f'../../assets/n{level}_u16.nrrd'
    info = 'u16_id'
else:
    acronym = 'acronym'
    maskfile = f'../../assets/n{level}_u32.nrrd'
    info = 'u32_id'
somafile = '../../fig1bstype/marker_regi_25'
csvfile = f'../../assets/200ksomata_acronym_n{level}_hemi{hemi}.csv'
soma_acronym_df = produce_swc_acronym(maskfile, info, acronym, somafile, csvfile)

../../fig1bstype/marker_regi_25/15257/15257_refined_stps.swc 15257 6308 4910 1398
../../fig1bstype/marker_regi_25/17051/17051_total_stps.swc 17051 6613 5188 1425
../../fig1bstype/marker_regi_25/17052/17052_total_stps.swc 17052 6889 5396 1493
../../fig1bstype/marker_regi_25/17109/17109_total_stps.swc 17109 7773 6200 1573
../../fig1bstype/marker_regi_25/17298/17298_refined_stps.swc 17298 14598 12658 1940
../../fig1bstype/marker_regi_25/17300/17300_total_stps.swc 17300 16900 14604 2296
../../fig1bstype/marker_regi_25/17301/17301_total_stps.swc 17301 17903 15394 2509
../../fig1bstype/marker_regi_25/17302/17302_total_stps.swc 17302 18555 15884 2671
../../fig1bstype/marker_regi_25/17304/17304_total_stps.swc 17304 18756 16080 2676
../../fig1bstype/marker_regi_25/17539/17539_total_stps.swc 17539 19406 16680 2726
../../fig1bstype/marker_regi_25/17541/17541_total_stps.swc 17541 19527 16769 2758
../../fig1bstype/marker_regi_25/17542/17542_total_stps.swc 17542 23002 19818 3184
../../fig1bstype/mar

../../fig1bstype/marker_regi_25/201585/201585_refined_stps.swc 201585 205417 175008 30409
../../fig1bstype/marker_regi_25/201586/201586_refined_stps.swc 201586 210183 178257 31926
../../fig1bstype/marker_regi_25/201588/201588_refined_stps.swc 201588 210296 178340 31956
../../fig1bstype/marker_regi_25/201589/201589_refined_stps.swc 201589 220871 187203 33668
../../fig1bstype/marker_regi_25/201590/201590_refined_stps.swc 201590 223793 189892 33901
../../fig1bstype/marker_regi_25/201595/201595_refined_stps.swc 201595 224248 190218 34030
../../fig1bstype/marker_regi_25/201598/201598_refined_stps.swc 201598 224343 190302 34041
../../fig1bstype/marker_regi_25/201604/201604_refined_stps.swc 201604 224346 190304 34042
../../fig1bstype/marker_regi_25/201605/201605_refined_stps.swc 201605 224353 190311 34042
../../fig1bstype/marker_regi_25/201606/201606_refined_stps.swc 201606 224412 190367 34045
../../fig1bstype/marker_regi_25/211541/211541_refined_stps.swc 211541 224893 190778 34115
../../fig1

In [9]:
soma_acronym_df

Unnamed: 0,soma,acronym_lr
0,200k_15257_275_176_91,STRd_l
1,200k_15257_209_248_306,PALv_r
2,200k_15257_191_209_260,PALc_r
3,200k_15257_199_205_259,PALc_r
4,200k_15257_193_206_261,PALc_r
...,...,...
191360,200k_194075_225_131_332,STRd_r
191361,200k_194075_223_132_332,STRd_r
191362,200k_194075_212_177_51,Isocortex_l
191363,200k_194075_179_163_125,STRd_l


# 3. Cluster regions (acronyms) along each axis (greedy, k-means-style 1-D clustering)

In [7]:
#     axis='DV'
    
#     n_clusters = 10
#     color = ['red','magenta','blue','purple','cyan','green','greenyellow','limegreen','black','orange']
    
#     filt = f'in_n{level}'
#     acronym = f'n{level}_acronym_lr' if hemi else f'n{level}_acronym'
#     center = 'center_u25_lr' if hemi else 'center_u25'
#     tree = [ [v[acronym], v[center][0], v[center][1], v[center][2]] for k,v in get_tree_from('u16_id').items() if v[filt] ]
#     center_df = pd.DataFrame(tree,columns=['acronym','centerx','centery','centerz'])
    
#     a = ['centerx','centery','centerz'][['AP','DV','RL'].index(axis)]
#     center_reshape_df = center_df.sort_values(by=a,axis=0,ascending=True)    

#     region_list = center_reshape_df['acronym']
#     print(f'\nfor axis: {axis}, acronym: {region_list}, len: {len(region_list)}\n')    
    
#     center_list = center_reshape_df[a]    
#     print(f'\nfor axis: {axis}, center: {center_list}, len: {len(center_list)}\n')
    
    

    
#     X = center_list.values
    
#     X_internal = X[1:]-X[:-1]
#     print('\n','internal',X_internal,len(X_internal),'\n')
    
#     cluster = {}
#     ini_id = [np.where(X_internal==-i)[0][0] for i in np.sort(-X_internal)[0:n_clusters]]    
#     for i in range(n_clusters):
#         c = ini_id[i]
#         c_ids = [ini_id[i]]
#         c_datas = X[c_ids].tolist()
#         cluster[c] = [c_ids,c_datas]    
#     yhat = []
#     for x in X:
#         xc = 0
#         for c,d in cluster.items():
#             if x in d[1]:
#                 xc = 1
#                 yhat.append(c)
#                 break
#         if not xc: yhat.append(999)
    
#     rm_ids = []
#     [rm_ids.extend(v[0]) for k,v in cluster.items()]
#     rm_datas = []
#     [rm_datas.extend(v[1]) for k,v in cluster.items()]    
#     data = X.tolist()
#     [data.remove(d) for d in rm_datas]   
#     r = 0
#     while len(data):
    
#         if r%10==0:
#             cluster_dict = {}
#             legends = []
#             fig = plt.figure(figsize=(20,7)) 
#             plt.xlim(0,320)
#             for c in yhat:
#                 if c==999: continue
#                 if c in cluster_dict.keys(): continue
#                 legend = ''        
#                 row = np.where(np.array(yhat) == c)[0]
#                 row_region = region_list[row]
#                 cluster_dict[c] = row_region
#                 for ls in row_region:
#                     for l in ls:
#                         legend += l
#                     legend += ','
#                 legends.append(legend)
#                 print('plot',X[row], color[len(cluster_dict.keys())-1])  
#                 plt.scatter(x=X[row], y=0 * X[row], s=2, c=color[len(cluster_dict.keys())-1])  
            
#             plt.legend(legends)
#             plt.savefig(outfig_dir+f'/round{r}_cluster_{n_clusters}_{axis}.png', dpi=450)
#             plt.close()
#             print(cluster,'\n',yhat)
#             print(f'legend: {legends}\n')            
# #         try:
#     #    print(data,len(data))
#         data_dist_tomax = []
#         data_dist_tomin = []
#         for c in cluster.keys():
#             [data_dist_tomax.append(abs(d-max(cluster[c][1]))) for d in data]
#             [data_dist_tomin.append(abs(d-min(cluster[c][1]))) for d in data]
#         min_maxdist = min(data_dist_tomax)
#         min_mindist = min(data_dist_tomin)
#         rm_id_dist = data_dist_tomax.index(min_maxdist) if min_maxdist<min_mindist else data_dist_tomin.index(min_mindist)
#     #    print(len(data_dist_tomax),len(data_dist_tomin),min_maxdist,min_mindist,rm_id_dist,data_dist_tomax[rm_id_dist],data_dist_tomin[rm_id_dist])

#         rm_id = rm_id_dist%len(data)
#         rm_data = data[rm_id]
#         rm_c = int(rm_id_dist/len(data))
#         dc = list(cluster.keys())[rm_c]
#     #    print(rm_id,rm_data)
#     #    print(rm_c,dc)

#         cluster[dc][1].extend([rm_data]) 
#         x_id = np.where(X==rm_data)[0][0]
#         cluster[dc][0].extend([x_id])        
#     #    print(cluster,'\n')

#         data.remove(rm_data) 
#         r += 1
    
#         yhat = []
#         for x in X:
#             xc = 0
#             for c,d in cluster.items():
#                 if x in d[1]:
#                     xc = 1
#                     yhat.append(c)
#                     break
#             if not xc: yhat.append(999)
                

In [8]:
# x = []
# y = []
# z = []
# a = []
# for k,v in plot_tree.items():
#     x += [v['center_u25_lr'][0]]
#     y += [v['center_u25_lr'][1]]
#     z += [v['center_u25_lr'][2]]
#     a += [k]
# center_df = pd.DataFrame(np.array([x,y,z,a]).T,columns=['x','y','z','a'])

# fig = px.scatter(x=x,y=z,)
# fig.show()
# center_df

In [1]:
# Print floats without scientific notation in the clustering logs.
np.set_printoptions(suppress=True)

# Cluster regions along each axis; the JSON/PNG outputs are written per axis.
n_clusters = 20
for axis in ('AP', 'DV', 'RL'):
    cluster_dict = get_axis_cluster(
        level, hemi, axis, n_clusters, outfile_dir, outfig_dir
    )

NameError: name 'np' is not defined

# 4. Graph data for graphviz plots

In [10]:
# for axis in ['AP','DV','RL']:
#     plot_tree = {}
#     filename = f'{outfile_dir}/all_modalities_brain_acronym_n{level}_hemi{hemi}_{axis}.csv'
#     brains_acronyms_count_df = pd.read_csv(filename,index_col=[0])
#     for a in brains_acronyms_count_df.columns:
#         neighbor_id = acronym_tree[a][neighbor]
#         neighbor_acronym = [id_tree[i][acronym] for i in neighbor_id]
#         plot_tree[a] = {'neighbor':neighbor_acronym,
#                         center:acronym_tree[a][center], 
#                         'count':int(brains_acronyms_count_df[a].mean()),
#                         voxel:acronym_tree[a][voxel], 
#                         'rgb':acronym_tree[a]['rgb'],
#                         'soma':len(np.nonzero((soma_acronym_df[acronym]==a).values)[0]),
#                        } 
#     print(len(plot_tree), plot_tree)

#     tree_object = json.dumps(plot_tree)
#     with open(f"../data/graph_tree_{axis}.json", "w") as f:
#         f.write(tree_object)

In [26]:
# Build per-axis graph data: nodes are regions grouped by cluster; edges link
# neighbouring regions of consecutive clusters (or the clusters themselves when
# no pair neighbours). Saved as plot_<axis>.json.
tree = get_tree()  # NOTE(review): not used in this cell; kept in case later cells rely on it
acronym_tree = get_tree_from(acronym)
center = 'center_u25_lr' if hemi else 'center_u25'
voxel = 'voxel_u25_lr' if hemi else 'voxel_u25'
neighbor = f'in_n{level}_neighbor_u16_id' if hemi else f'in_n{level}_neighbor_u32_id'
id_tree = get_tree_from('u16_id') if hemi else get_tree_from('u32_id')

for axis in ['AP','DV','RL']:
    # Rows = brains, columns = acronyms (written by ModalitiesSignalAnalyzer).
    brains_acronyms_count_filename = f'{outfile_dir}/all_modalities_brain_acronym_n{level}_hemi{hemi}_{axis}.csv'
    brains_acronyms_count_df = pd.read_csv(brains_acronyms_count_filename,index_col=[0])

    cluster_filename = f'{outfile_dir}/cluster_{axis}.json'
    with open(cluster_filename, "r") as f:
        cluster_dict = json.load(f)

    # Nodes: per cluster, the member acronyms, a scaled size and a count density.
    nodes_acronym = {}
    nodes_size = {}
    nodes_color = {}
    for c, ca in cluster_dict.items():
        nodes_acronym[c] = list(ca)
        # NOTE(review): size = voxel count * 0.04**3, labelled "um*3" by the author -- confirm the unit.
        nodes_size[c] = [0.04*0.04*0.04*acronym_tree[a][voxel] for a in ca]
        # Mean count over brains (truncated to int) divided by the scaled size.
        nodes_color[c] = [int(brains_acronyms_count_df[a].mean())/size
                          for a, size in zip(ca, nodes_size[c])]
    nodes_acronym_list = [a for members in nodes_acronym.values() for a in members]
    nodes_size_list = [s for sizes in nodes_size.values() for s in sizes]
    nodes_color_list = [d for densities in nodes_color.values() for d in densities]
    print(nodes_acronym)
    print(nodes_acronym_list)
    print('\n')
    print(nodes_size)
    print(nodes_size_list)
    print('\n')
    print(nodes_color)
    print(nodes_color_list)
    print('\n')

    # Neighbour acronyms of every clustered region, looked up once per axis
    # instead of once per (region, region) pair.
    neighbors_of = {
        a: [id_tree[ni][acronym] for ni in acronym_tree[a][neighbor]]
        for members in cluster_dict.values() for a in members
    }

    edges_acronym = []
    edges_pdist = []
    edges_color = []
    edges_cluster = []
    lc, la = None, None  # previous cluster label / members
    for c, ca in cluster_dict.items():
        print(c,ca)
        if la is not None:
            ne = 0
            for a in la:
                for acronym_j in ca:
                    if acronym_j in neighbors_of[a] or a in neighbors_of[acronym_j]:
                        ne += 1
                        print((a,acronym_j))
                        edges_acronym.append((a, acronym_j))
                        # NOTE(review): centers are *_u25 coordinates; the /25 is labelled "um" -- confirm.
                        edges_pdist.append(calc_dist(acronym_tree[a][center],acronym_tree[acronym_j][center])/25)
                        edges_color.append([round((v+1)/256.,4) for v in acronym_tree[a]['rgb']])

            if ne == 0:
                # No neighbouring pair: link the two clusters through their last members.
                print((lc,c))
                edges_cluster.append((la[-1],ca[-1],lc,c))

        lc, la = c, ca
        print('\n')

    plot = {'nodes_acronym_list':nodes_acronym_list,
            'nodes_size_list':nodes_size_list,
            'nodes_color_list':nodes_color_list,
            'cluster_dict':cluster_dict,
            'nodes_acronym':nodes_acronym,
            'nodes_size':nodes_size,
            'nodes_color':nodes_color,
            'edges_acronym':edges_acronym,
            'edges_pdist':edges_pdist,
            'edges_color':edges_color,
            'edges_cluster':edges_cluster,
           }
    plot_object = json.dumps(plot)
    with open(f"{outfile_dir}/plot_{axis}.json", "w") as f:
        f.write(plot_object)

{'1': ['ORB_r', 'ORB_l'], '3': ['OLF_l', 'OLF_r'], '5': ['MO_r', 'MO_l'], '9': ['AI_l', 'AI_r', 'STRv_l', 'STRv_r'], '11': ['ACA_l', 'ACA_r'], '15': ['MSC_r', 'MSC_l', 'LS_l', 'LS_r'], '17': ['PALv_l', 'PALv_r'], '23': ['LSX_l', 'LSX_r', 'STRd_r', 'STRd_l', 'PALc_r', 'PALc_l'], '25': ['PALm_l', 'PALm_r'], '31': ['SSp_r', 'SSp_l', 'PVR_l', 'PVR_r', 'Isocortex_r', 'Isocortex_l'], '43': ['EP_r', 'EP_l', 'SS_l', 'SS_r', 'PALd_r', 'PALd_l', 'MTN_l', 'MTN_r', 'ATN_r', 'ATN_l', 'DORpm_r', 'DORpm_l'], '63': ['PVZ_r', 'PVZ_l', 'sAMY_l', 'sAMY_r', 'HY_r', 'HY_l', 'MEZ_r', 'MEZ_l', 'MED_l', 'MED_r', 'LZ_r', 'LZ_l', 'VENT_l', 'VENT_r', 'EPI_r', 'EPI_l', 'CTXsp_l', 'CTXsp_r', 'VP_r', 'VP_l'], '65': ['ILM_r', 'ILM_l'], '73': ['LAT_r', 'LAT_l', 'PTLp_r', 'PTLp_l', 'COA_r', 'COA_l', 'TM_l', 'TM_r'], '93': ['GENv_r', 'GENv_l', 'SPF_r', 'SPF_l', 'DORsm_r', 'DORsm_l', 'AUD_r', 'AUD_l', 'CA_l', 'CA_r', 'GENd_l', 'GENd_r', 'RSP_r', 'RSP_l', 'HIP_r', 'HIP_l', 'MBO_l', 'MBO_r', 'PRT_l', 'PRT_r'], '101': ['RA