In [None]:
from common_preamble import *

from meshparty import skeleton_io
from ais_pipeline_scripting_tools import point_orientation_from_zx_slice
from sklearn import cluster
import pycircstat as circstats
from scipy import stats
import matplotlib.lines as mlines
from collections import defaultdict

In [None]:
# Rotation matrix passed as rotation_matrix= to point_orientation_from_zx_slice for contact orientations
Rtrans = np.load(f'{base_dir}/data/in/pinky_rotation.npy')

In [None]:
# Root ids of neurons in the 'ais_analysis_soma' table, and the aggregated AIS synapse rows onto them
full_soma_ids = dl.query_cell_ids('ais_analysis_soma')['pt_root_id'].values
arbor_ais_df = aggregated_ais_syn_df.loc[aggregated_ais_syn_df['post_pt_root_id'].isin(full_soma_ids)]

In [None]:
er_table = 'er_points'
er_df = dl.query_cell_ids(er_table)

# Keep only root ids with more than 10 ER annotations. A few points fall outside their
# intended root object, which gives spurious ids with very few points.
ids, cts = np.unique(er_df['pt_root_id'], return_counts=True)
ais_ids = ids[np.greater(cts, 10)]
er_df = er_df.loc[er_df['pt_root_id'].isin(ais_ids)]

In [None]:
# Mesh loader backed by the cloudvolume source and an on-disk cache
# (cache_size=0 presumably disables in-memory caching — TODO confirm)
mm = trimesh_io.MeshMeta(cv_path=mesh_cv_path,
                         disk_cache_path=mesh_dir,
                         voxel_scaling=voxel_scaling,
                         cache_size=0)

Mesh contact angular distributions

In [None]:
# Bind the tqdm callable explicitly. The `tqdm.tqdm(...)` spelling fails if this cell is re-run
# after the later `from tqdm import tqdm` cell has rebound `tqdm` to the function.
from tqdm import tqdm

# ChC meshes, fetched by segment id through the MeshMeta loader
chc_meshes = {}
for oid in chc_ids:
    chc_meshes[oid] = mm.mesh(seg_id=oid)

# Pre-cut AIS meshes and AIS skeletons from disk; missing files are stored as None and reported
ais_meshes = {}
ais_sks = {}
for oid in tqdm(complete_ais_ids):
    ais_file = ais_mesh_dir + '/{}_ais.h5'.format(oid)
    if os.path.exists(ais_file):
        ais_meshes[oid] = mm.mesh(filename=ais_file)
    else:
        ais_meshes[oid] = None
        print('{} AIS not found!'.format(oid))

    ais_sk_file = base_dir + '/data/skeletons/sk_ais_{}.h5'.format(oid)
    if os.path.exists(ais_sk_file):
        sk = skeleton_io.read_skeleton_h5(ais_sk_file)
        sk.voxel_scaling = voxel_scaling
        ais_sks[oid] = sk
    else:
        ais_sks[oid] = None
        print('{} AIS sk not found!'.format(oid))

Let's look at mesh-mesh contacts between each AIS mesh and the meshes of its presynaptic partners.

In [None]:
def mesh_mesh_proximity(target_mesh, other_mesh, max_dist=75):
    '''
    Flag the vertices of target_mesh that have at least one vertex of other_mesh within max_dist.

    :param target_mesh: Mesh whose vertices are tested (its .kdtree is used).
    :param other_mesh: Mesh searched for nearby vertices (its .kdtree is used).
    :param max_dist: Euclidean search radius, in the meshes' coordinate units.
    :returns: Boolean array with one entry per point in target_mesh.kdtree.
    '''
    neighbor_lists = target_mesh.kdtree.query_ball_tree(other_mesh.kdtree, max_dist)
    return np.array([len(neighbors) > 0 for neighbors in neighbor_lists], dtype=bool)


def mesh_mesh_contact_points(target_mesh, other_mesh, max_dist=75, cluster_min_size=5, cluster_eps=500):
    '''
    Cluster the vertices of target_mesh that touch other_mesh into distinct contacts.

    :param target_mesh: Mesh for which we return contact indices and do vertex clustering
    :param other_mesh: Mesh for which we search for contacts.
    :param max_dist: Distance (in euclidean space) for which to consider contact between mesh vertices.
    :param cluster_min_size: Minimum size of a contact point cluster. Smaller outliers are ignored.
    :param cluster_eps: Distance (along mesh graph) for the max distance for two points to be located and
                        be considered within the same neighborhood for clustering. See sklearn.cluster.DBSCAN
    :returns: Array with row for each vertex in target mesh, with 0 for no contact and unique nonzero integers
              for each distinct contact. All zeros if no vertex is within max_dist of other_mesh.
    '''
    is_proximate = mesh_mesh_proximity(target_mesh, other_mesh, max_dist)
    proximate_inds = np.flatnonzero(is_proximate)
    vert_labels = np.zeros(len(is_proximate), dtype=int)
    if len(proximate_inds) == 0:
        # No contact at all; DBSCAN cannot be fit on an empty distance matrix.
        return vert_labels

    # Mesh-graph (geodesic) distances restricted to the proximate vertices
    ds = sparse.csgraph.dijkstra(target_mesh.csgraph, indices=proximate_inds)[:, proximate_inds]
    # Unreachable pairs (disconnected mesh pieces) come back as inf, which DBSCAN's precomputed
    # metric rejects. Any finite value above eps means "not neighbors" equally well.
    ds[np.isinf(ds)] = 2 * cluster_eps + 1

    dbscan = cluster.DBSCAN(eps=cluster_eps, metric='precomputed', min_samples=cluster_min_size)
    # DBSCAN marks noise as -1 and numbers clusters from 0; +1 maps noise to 0 ("no contact").
    vert_labels[proximate_inds] = dbscan.fit(ds).labels_ + 1
    return vert_labels

def er_location_df(ais_sk, ais_mesh, er_df, voxel_resolution=voxel_resolution):
    '''
    Locate ER annotation points on an AIS mesh and its skeleton.

    :param ais_sk: AIS skeleton (mesh_to_skel_map and distance_to_root are used).
    :param ais_mesh: AIS mesh; its kdtree is queried with the annotation positions.
    :param er_df: ER annotations with a 'pt_position' column in voxel coordinates.
    :param voxel_resolution: Scale from voxel to mesh coordinates (module-level default).
    :returns: A copy of er_df with added columns 'orientation', 'd_top' (skeleton distance to root),
              'mesh_ind' (unmasked mesh index) and 'sk_ind'. The input frame is not modified.
    '''
    # Work on a copy. Callers pass filtered slices, and adding columns to a slice triggers
    # SettingWithCopyWarning and may or may not modify the caller's data.
    er_df = er_df.copy()
    er_pts = np.vstack(er_df['pt_position'].values) * voxel_resolution
    _, er_inds = ais_mesh.kdtree.query(er_pts)
    # NOTE(review): unlike the contact-orientation calls in this notebook (slice 300,
    # rotation_matrix=Rtrans), this uses slice 400 and no rotation matrix; confirm both
    # orientations are in the same frame before comparing them.
    er_orientation = point_orientation_from_zx_slice(er_inds, ais_mesh, 400)

    er_unmasked_inds = ais_mesh.map_indices_to_unmasked(er_inds)
    er_skinds = ais_sk.mesh_to_skel_map[er_unmasked_inds]

    er_df['orientation'] = er_orientation
    er_df['d_top'] = ais_sk.distance_to_root[er_skinds]
    er_df['mesh_ind'] = er_unmasked_inds
    er_df['sk_ind'] = er_skinds
    return er_df

def annotation_location_indices(mesh, pos_column, anno_df, mesh_to_sk_map=None, max_dist=np.inf,
                                voxel_resolution=voxel_resolution):
    '''
    For a dataframe of annotations on a given neuron, find the nearest mesh vertex for each annotation.

    :param mesh: trimesh Mesh
    :param pos_column: string, column of anno_df holding annotation positions (voxel coordinates)
    :param anno_df: DataFrame with at least the position column
    :param mesh_to_sk_map: Optional, Numpy array with skeleton vertex index for every (unmasked) mesh vertex index.
    :param max_dist: Optional, Maximum distance to the mesh allowed for assignment, else return -1.
    :param voxel_resolution: Optional, voxel-to-mesh scaling; defaults to the module-level voxel_resolution.
    :returns: Integer array of unmasked mesh indices (-1 when farther than max_dist) and, if mesh_to_sk_map
              is given, an integer array of skeleton indices (-1 where no mesh vertex was assigned).
    '''
    if len(anno_df) == 0:
        # Integer dtype so the empty result can still be used as an index array
        empty = np.array([], dtype=int)
        if mesh_to_sk_map is None:
            return empty
        return empty, empty.copy()

    anno_positions = np.vstack(anno_df[pos_column].values) * voxel_resolution
    ds, mesh_inds = mesh.kdtree.query(anno_positions)
    mesh_inds = mesh.map_indices_to_unmasked(mesh_inds)
    mesh_inds[ds > max_dist] = -1

    if mesh_to_sk_map is None:
        return mesh_inds

    mesh_to_sk_map = np.asarray(mesh_to_sk_map).astype(int)
    found_inds = mesh_inds >= 0
    # Integer output (np.zeros would give floats, which cannot index skeleton arrays)
    skinds = np.full(mesh_inds.shape, -1, dtype=int)
    skinds[found_inds] = mesh_to_sk_map[mesh_inds[found_inds]]
    return mesh_inds, skinds

def distance_to_co_pts(co_er_df_over, nrn_syn_df, median_r):
    '''
    For every synapse, find the nearest CO point on the unrolled AIS surface and return the offset to it.

    Nearest neighbors are found in (arc length, depth) space, where arc length = median_r * orientation.
    co_er_df_over is expected to also contain +/-2*pi shifted copies of each CO point, so the search
    wraps around the AIS circumference.

    :param co_er_df_over: CO points with columns 'orientation' (radians), 'd_orientation' (arc length)
                          and 'd_top' (depth).
    :param nrn_syn_df: Synapses with columns 'orientation' (radians) and 'd_top_skel' (depth). Not modified.
    :param median_r: Scale converting orientation (radians) to arc length.
    :returns: (delta_orientation, delta_z): per-synapse offsets (nearest CO point minus synapse). The
              orientation offset is wrapped into [-pi, pi). Both are all-NaN when there are no CO points.
    '''
    syn_orientation = np.asarray(nrn_syn_df['orientation'].values, dtype=float)
    syn_z = np.asarray(nrn_syn_df['d_top_skel'].values, dtype=float)
    if len(co_er_df_over) == 0:
        # Nothing to compare against; np.argmin would raise on an empty axis.
        return np.full(len(syn_orientation), np.nan), np.full(len(syn_orientation), np.nan)

    # The synapse arc length is computed locally rather than written into nrn_syn_df:
    # callers pass filtered slices, and writing into them raises SettingWithCopyWarning.
    syn_xy = np.column_stack((median_r * syn_orientation, syn_z))
    co_xy = co_er_df_over[['d_orientation', 'd_top']].values
    closest_ind = np.argmin(spatial.distance.cdist(syn_xy, co_xy), axis=1)

    co_rad = co_er_df_over[['orientation', 'd_top']].values[closest_ind]
    delta_orientation = np.mod(co_rad[:, 0] - syn_orientation + np.pi, 2 * np.pi) - np.pi
    delta_z = co_rad[:, 1] - syn_z
    return delta_orientation, delta_z

In [None]:
from tqdm import tqdm

In [None]:
is_chc = True  # True: ChC synapses / cached ChC meshes; False: non-ChC inputs, meshes fetched on demand
all_contact_orientations = {}
all_vert_labels = {}
for nrn_oid in tqdm(ais_ids):
    ais_mesh = ais_meshes[nrn_oid]

    nrn_chc_syn_df = ais_synapse_data[(ais_synapse_data['post_pt_root_id'] == nrn_oid)
                                      & (ais_synapse_data['is_chandelier'] == is_chc)]
    if len(nrn_chc_syn_df) == 0:
        continue
    nrn_chc_ids = np.unique(nrn_chc_syn_df['pre_pt_root_id'])

    try:
        # Label AIS vertices by contact cluster. Labels from successive presynaptic partners are
        # offset by the current maximum, so every contact gets a unique positive integer.
        vert_labels = np.zeros(ais_mesh.n_vertices, dtype=int)
        for nrn_chc_id in nrn_chc_ids:
            if is_chc is True:
                chc_mesh = chc_meshes[nrn_chc_id]
            else:
                chc_mesh = mm.mesh(seg_id=nrn_chc_id)
            contact_pts = mesh_mesh_contact_points(ais_mesh, chc_mesh, max_dist=150, cluster_eps=400)
            contact_pts[contact_pts > 0] = contact_pts[contact_pts > 0] + np.max(vert_labels)
            vert_labels = np.max(np.vstack((vert_labels, contact_pts)), axis=0)

        # Keep only contact clusters that contain the nearest vertex of at least one real synapse
        presyn_locs = np.vstack(nrn_chc_syn_df['ctr_pt_position'].values) * voxel_resolution
        _, syn_mesh_inds = ais_mesh.kdtree.query(presyn_locs)
        seen_contacts = np.unique(vert_labels[syn_mesh_inds])
        if 0 in seen_contacts:
            print(f'A real synapse is not in a contact cluster! {nrn_oid}')
        vert_labels[~np.isin(vert_labels, seen_contacts)] = 0

        all_vert_labels[nrn_oid] = vert_labels
        is_contact = vert_labels > 0
        all_contact_orientations[nrn_oid] = point_orientation_from_zx_slice(
            np.flatnonzero(is_contact), ais_mesh, 300, rotation_matrix=Rtrans)
    except Exception as err:
        # Report why the neuron failed. The old bare `except:` hid the cause and also caught KeyboardInterrupt.
        print('Failed on {}! ({}: {})'.format(nrn_oid, type(err).__name__, err))
        continue

In [None]:
from matplotlib import cm 

In [None]:
def paintball_color_mesh(categories, base_color=(0.5, 0.5, 0.5), h=0.01, l=0.6, s=0.65, permute_colors=True,
                         seed=None):
    '''
    Per-vertex RGB colors for a mesh labelled with integer categories.

    Category 0 is always the background and gets base_color. Every other distinct category gets
    its own color from a seaborn HLS palette.

    :param categories: Array of non-negative integer labels, one per vertex (0 = background).
    :param base_color: RGB color for category 0.
    :param h, l, s: Hue offset, lightness and saturation passed to sns.hls_palette.
    :param permute_colors: Shuffle the palette so neighboring labels get dissimilar colors.
    :param seed: Optional seed for the shuffle, for reproducible renders.
    :returns: Array of shape (len(categories), 3) with one RGB row per vertex.
    '''
    categories = np.asarray(categories, dtype=int)
    # Always include 0 as the background category. Otherwise, when no vertex is labelled 0,
    # the smallest real label would silently be painted with base_color.
    unique_categories = np.union1d([0], categories)
    n_colors = len(unique_categories) - 1

    # Map each label to its rank among the unique labels (0 -> row 0 -> base_color)
    color_lookup = np.zeros(unique_categories.max() + 1, dtype=int)
    color_lookup[unique_categories] = np.arange(len(unique_categories))

    colors = np.array([base_color], dtype=float)
    if n_colors > 0:
        color_colors = np.asarray(sns.hls_palette(n_colors, h, l, s))
        if permute_colors:
            color_colors = np.random.default_rng(seed).permutation(color_colors)
        colors = np.vstack((colors, color_colors))
    return colors[color_lookup[categories]]


In [None]:
ais_oid = ais_ids[6]

# Synapses onto this AIS of the type selected by is_chc, drawn as black points
nrn_chc_syn_df = ais_synapse_data[(ais_synapse_data['post_pt_root_id'] == ais_oid)
                                  & (ais_synapse_data['is_chandelier'] == is_chc)]
chc_syn_pts = np.vstack(nrn_chc_syn_df['ctr_pt_position'].values) * voxel_resolution
sa = trimesh_vtk.point_cloud_actor(chc_syn_pts, size=100, color=(0, 0, 0), opacity=1)

# AIS mesh with every contact cluster painted its own color (label 0 = no contact)
example_ais_mesh = ais_meshes[ais_oid]
ma = trimesh_vtk.mesh_actor(example_ais_mesh, opacity=1,
                            vertex_colors=paintball_color_mesh(all_vert_labels[ais_oid]))

camera = trimesh_vtk.oriented_camera(example_ais_mesh.centroid, backoff=150)
trimesh_vtk.render_actors([ma, sa], camera=camera)

### Generate a dataframe where each row is a contact point that has columns:
* ~~AIS root_id~~
* ~~contact_id (unique to neuron)~~
* ~~mesh index~~
* ~~mesh vertex position~~
* ~~synapse_id if there is a synapse onto that contact~~
* ~~contact point "centroid"~~

In [None]:
contact_dfs = []
for oid in tqdm(ais_ids):
    try:
        mesh = ais_meshes[oid]
        sk = ais_sks[oid]
        vert_labels = all_vert_labels[oid]

        # One row per contact vertex (masked mesh indexing, as in all_vert_labels)
        contact_pt_mesh_ind = np.flatnonzero(vert_labels > 0)
        contact_label = vert_labels[contact_pt_mesh_ind]
        contact_pt_orientation = all_contact_orientations[oid]
        root_id = np.full(len(contact_pt_mesh_ind), oid)
        contact_pt_position = mesh.vertices[contact_pt_mesh_ind]
        contact_pt_z_loc = sk.distance_to_root[sk.mesh_to_skel_map[mesh.map_indices_to_unmasked(contact_pt_mesh_ind)]]

        # Geodesic distances among contact vertices. A contact's "center" is its member vertex
        # with the smallest mean distance to the other members of the same contact.
        ds_contact = sparse.csgraph.dijkstra(mesh.csgraph, indices=contact_pt_mesh_ind)[:, contact_pt_mesh_ind]
        contact_label_values = np.unique(contact_label)
        contact_center = []
        for label_ind in contact_label_values:
            is_label = contact_label == label_ind
            lowest_mean_dist_local = np.argmin(np.mean(ds_contact[:, is_label][is_label], axis=0))
            contact_center.append(contact_pt_mesh_ind[np.flatnonzero(is_label)[lowest_mean_dist_local]])
        contact_center = np.array(contact_center)

        contact_df = pd.DataFrame(data={'root_id': root_id,
                                        'contact_label': contact_label,
                                        'contact_pt_mesh_ind': contact_pt_mesh_ind,
                                        'contact_pt_orientation': contact_pt_orientation,
                                        'contact_pt_z_loc': contact_pt_z_loc,
                                        'contact_pt_position': [x for x in contact_pt_position]})

        center_df = pd.DataFrame(data={'contact_label': contact_label_values,
                                       'contact_center_mesh_ind': contact_center,
                                       })

        # Exactly one center per contact label
        contact_df = contact_df.merge(center_df, on='contact_label', how='left', validate='many_to_one')

        # Assign each synapse to the contact containing its geodesically closest contact vertex.
        # Use the same synapse type (is_chc) that built the contacts. .copy() so that adding
        # 'contact_label' does not write into a slice of ais_synapse_data.
        nrn_chc_syn_df = ais_synapse_data[(ais_synapse_data['post_pt_root_id'] == oid)
                                          & (ais_synapse_data['is_chandelier'] == is_chc)].copy()
        chc_syn_inds = mesh.filter_unmasked_indices(annotation_location_indices(mesh, 'ctr_pt_position', nrn_chc_syn_df))
        ds = sparse.csgraph.dijkstra(mesh.csgraph, indices=chc_syn_inds)
        nrn_chc_syn_df['contact_label'] = contact_label[np.argmin(ds[:, contact_pt_mesh_ind], axis=1)]

        # One output row per (synapse, vertex of its contact)
        contact_df = contact_df.merge(nrn_chc_syn_df[['id', 'pre_pt_root_id', 'contact_label']], on='contact_label', how='right', validate='many_to_many')
        contact_df = contact_df.rename(columns={'id': 'synapse_id'})
        contact_dfs.append(contact_df)
    except Exception as err:
        # Neurons without contacts, meshes or skeletons end up here; report rather than skip silently
        print(f'Skipping {oid}: {type(err).__name__}: {err}')
        continue


In [None]:
all_contact_df = pd.concat(contact_dfs)
# A row is a contact "center" when its vertex is the vertex chosen as its contact's center
all_contact_df['is_center'] = all_contact_df['contact_pt_mesh_ind'].eq(all_contact_df['contact_center_mesh_ind'])

all_contact_center_df = all_contact_df.loc[all_contact_df['is_center']]

### 2. Build the ER cluster dataframe

In [None]:
nrn_syn_dfs = {}
co_er_dfs = {}
dist_to_cos = []
dist_to_cos_non = []
median_rs = {}
# DBSCAN on precomputed mesh-graph distances; the same estimator is refit for every AIS
dbscan = cluster.DBSCAN(eps=500, metric='precomputed', min_samples=5)
for ais_id in tqdm(ais_ids):
    ais_mesh = ais_meshes[ais_id]

    nrn_er_df = er_location_df(ais_sks[ais_id], ais_mesh, er_df[er_df['pt_root_id'] == ais_id])
    # .copy(): columns are added below, and assigning into a filtered slice raises SettingWithCopyWarning
    co_er_df = nrn_er_df[nrn_er_df['func_id'] == 1].copy()

    co_mesh_inds = ais_mesh.filter_unmasked_indices(co_er_df['mesh_ind'])
    ds = sparse.csgraph.dijkstra(ais_mesh.csgraph, indices=co_mesh_inds)
    dmat = ds[:, co_mesh_inds]
    clust_res = dbscan.fit(dmat)
    co_er_df['clust_label'] = clust_res.labels_

    # median_r (median skeleton 'rs' / 2) converts orientation in radians to arc length
    median_r = np.median(ais_sks[ais_id].vertex_properties['rs']) / 2
    median_rs[ais_id] = median_r
    co_er_df['d_orientation'] = median_r * co_er_df['orientation']

    # Copies shifted by -2*pi and +2*pi, so nearest-CO searches wrap around the circumference
    left_co_er_df = co_er_df.copy()
    right_co_er_df = co_er_df.copy()

    left_co_er_df['orientation'] = co_er_df['orientation'] - 2 * np.pi
    left_co_er_df['d_orientation'] = median_r * left_co_er_df['orientation']

    right_co_er_df['orientation'] = co_er_df['orientation'] + 2 * np.pi
    right_co_er_df['d_orientation'] = median_r * right_co_er_df['orientation']

    nrn_syn_df = ais_synapse_data[ais_synapse_data['post_pt_root_id'] == ais_id]
    nrn_syn_dfs[ais_id] = nrn_syn_df

    co_er_df_over = pd.concat((co_er_df, left_co_er_df, right_co_er_df))
    co_er_dfs[ais_id] = co_er_df_over
    dist_to_cos.append(distance_to_co_pts(co_er_df_over, nrn_syn_df[nrn_syn_df['is_chandelier'] == True], median_r))
    dist_to_cos_non.append(distance_to_co_pts(co_er_df_over, nrn_syn_df[nrn_syn_df['is_chandelier'] == False], median_r))

co_er_all_df = pd.concat(co_er_dfs.values())

### Integrate COs and Synapses

1. Plot synaptic contacts and CO clusters.

2. Measure the angle from each synaptic contact to the CO clusters.

In [None]:
# AIS ids ordered by ascending 'syn_net_chc' (presumably the net ChC synapse count per AIS)
ais_syn_ct_df = (aggregated_ais_syn_df
                 .loc[aggregated_ais_syn_df['post_pt_root_id'].isin(ais_ids), ['post_pt_root_id', 'syn_net_chc']]
                 .sort_values(by='syn_net_chc'))
ais_ids_sorted = ais_syn_ct_df['post_pt_root_id'].values

In [None]:
# RGB color (0-1 scale) used for CO points in the figures of this notebook
co_color = (0.012, 0.843, 1)

In [None]:
# One panel per AIS (ordered by ais_ids_sorted): clustered CO points and ChC contact centers,
# plotted as orientation (x) against depth along the AIS (y, inverted so the top is up).
with sns.axes_style('whitegrid'):
    fig, axes = plt.subplots(figsize=(13, 5), ncols=len(ais_ids))
    for ii, ais_id in enumerate(ais_ids_sorted):
        ax = axes[ii]
        # clust_label >= 0 drops DBSCAN noise (-1)
        co_df = co_er_all_df[(co_er_all_df['pt_root_id'] == ais_id) & (co_er_all_df['clust_label'] >= 0)]
        contact_center_df = all_contact_df[np.isin(all_contact_df['root_id'], [ais_id]) & (all_contact_df['is_center'])]

        # Dashed lines at +/-pi mark the wrap-around; the +/-2*pi CO copies show up beyond them
        ax.plot([np.pi, np.pi], [0, 100000], linestyle='--', color=non_color, zorder=0)
        ax.plot([-np.pi, -np.pi], [0, 100000], linestyle='--', color=non_color, zorder=0)

        sns.scatterplot(x='orientation', y='d_top', s=50, data=co_df[co_df['func_id'] == 1], color=co_color, alpha=0.3, edgecolor=None, ax=ax)
        sns.scatterplot(x='contact_pt_orientation', y='contact_pt_z_loc', s=80, alpha=0.7, data=contact_center_df, color=chc_color, marker='o', edgecolor='w', ax=ax)

        ax.set_xticks([-1 * np.pi, 0, 1 * np.pi])
        if ii > 0:
            # Inner panels: no tick labels or axis labels
            ax.set_ylabel('')
            ax.set_yticklabels([])
            ax.set_xticklabels([])
            ax.set_xlabel('')
        else:
            # First panel carries the labels; depth ticks are placed explicitly at 10 um steps
            ax.set_yticks(np.arange(0, 50001, 10000))
            ax.set_yticklabels([0, 10, 20, 30, 40, 50], fontdict={'size': 12})
            ax.set_ylabel(r'Depth on AIS ($\mu$m)', fontdict={'size': 14})
            ax.set_xticklabels([r'-$\pi$', '0', r'$\pi$'], fontdict={'size': 14})
            ax.set_xlabel('Orientation', fontdict={'size': 14})
        ax.set_xlim([-1.25 * np.pi, 1.25 * np.pi])
        ax.set_ylim((0, 50000))
        ax.invert_yaxis()

    co_dot = mlines.Line2D([], [], color=co_color, linestyle='none', marker='o', markersize=5)
    syn_dot = mlines.Line2D([], [], color=chc_color, linestyle='none', marker='o', markersize=5)
    axes[0].legend([co_dot, syn_dot], ['CO', 'ChC Syn.'], bbox_to_anchor=(.5, 1.))
    fig.savefig(plot_dir + f'/CO_FIG_co_synapse_locations_{str(is_chc).lower()}.pdf')

In [None]:
bins = np.arange(0, 51, 4)
fig, ax = plt.subplots(figsize=(2, 2))
# Depth of CO points in um. co_er_all_df holds three orientation-shifted copies of each point with
# identical d_top, and density=True normalizes that triplication away.
ax.hist(co_er_all_df['d_top'] / 1000, bins=bins, density=True, color=co_color, linewidth=1)
ax.set_xlabel(r'Depth on AIS ($\mu$m)')
ax.set_ylabel('CO Points (norm.)')
ax.set_xticks(np.arange(0, 51, 10))
# Explicit .pdf like the other CO_FIG outputs; without an extension matplotlib uses rcParams['savefig.format']
fig.savefig(plot_dir + f'/CO_FIG_co_depth_{str(is_chc).lower()}.pdf')

In [None]:
fig, ax = plt.subplots(figsize=(2, 2))
bins = np.arange(-np.pi, np.pi + 0.1, np.pi / 6)
# Orientation distributions of CO points vs. ChC contact vertices, drawn on this figure's axes.
# If orientations lie in [-pi, pi], the +/-2*pi shifted CO copies fall outside these bins.
step_kw = dict(bins=bins, histtype='step', density=True, linewidth=3)
ax.hist(co_er_all_df['orientation'], color=co_color, **step_kw)
ax.hist(all_contact_df['contact_pt_orientation'], alpha=0.7, color=chc_color, **step_kw)

ax.set_ylabel('Points (norm.)')
ax.set_xlabel('Orientation')
ax.set_xticks([-1 * np.pi, 0, np.pi])
ax.set_xticklabels([r'-$\pi$', 0, r'$\pi$'])

In [None]:
def min_angle(a0, b):
    '''
    Index of the angle in b closest to angle a0 on the unit circle (wrap-around aware).

    :param a0: Reference angle in radians.
    :param b: Sequence of candidate angles in radians.
    :returns: Integer index into b.
    '''
    unit_b = np.exp(1j * np.array(b))
    unit_a0 = np.exp(1j * a0)
    # Chord length is monotonic in angular separation, so its argmin is the closest angle
    return np.argmin(np.abs(unit_b - unit_a0))

def map_syn_to_new_mesh(zs, orientations, mesh, sk, z_diff=300):
    '''
    Given a distance from top and orientation, find equivalent points in an AIS specified by its mesh and skeleton.

    For each (z, orientation) pair, take the mesh vertices whose skeleton vertex lies within z_diff of z
    (skeleton distance_to_root), then pick the one whose orientation is closest to the requested one.

    :param zs: Depths (same units as sk.distance_to_root).
    :param orientations: Orientations in radians, one per depth.
    :param mesh: Target AIS mesh.
    :param sk: Skeleton of the target AIS.
    :param z_diff: Half-width of the depth window.
    :returns: (mapped_inds, new_point_values): masked mesh indices of the chosen vertices, and an
              (n, 2) array with each chosen vertex's (depth, orientation).
    '''
    mapped_inds = []
    new_point_values = []
    for z, orientation in zip(zs, orientations):
        sk_inds = np.flatnonzero(np.abs(sk.distance_to_root - z) < z_diff)
        mesh_inds = np.flatnonzero(np.isin(sk.mesh_to_skel_map, sk_inds))  # unmasked indexing
        ais_mesh_inds = mesh.filter_unmasked_indices(mesh_inds)  # masked indexing; may drop entries

        pot_orientations = point_orientation_from_zx_slice(ais_mesh_inds, mesh, 300, rotation_matrix=Rtrans)
        best_orientation_ind = min_angle(orientation, pot_orientations)
        best_mesh_ind = ais_mesh_inds[best_orientation_ind]
        mapped_inds.append(best_mesh_ind)

        # best_orientation_ind indexes ais_mesh_inds, not mesh_inds. If filtering dropped any vertex,
        # the two lists are misaligned, so map the chosen vertex back to unmasked indexing.
        best_unmasked_ind = mesh.map_indices_to_unmasked(np.array([best_mesh_ind]))[0]
        new_z = sk.distance_to_root[sk.mesh_to_skel_map[best_unmasked_ind]]
        new_orientation = pot_orientations[best_orientation_ind]
        new_point_values.append((new_z, new_orientation))
    return np.array(mapped_inds), np.array(new_point_values)

In [None]:
max_z = 30000  # only contact centers with contact_pt_z_loc below this are remapped
remapped_contact_inds = defaultdict(dict)  # [mesh_oid][syn_oid] -> syn_oid's contacts placed on mesh_oid

# Depth-filtered contact centers per source AIS. They do not depend on the target mesh, so build them once.
syn_center_dfs = {}
for syn_oid in ais_ids:
    syn_df = all_contact_center_df[all_contact_center_df['root_id'] == syn_oid]
    syn_center_dfs[syn_oid] = syn_df[syn_df['contact_pt_z_loc'] < max_z]

for mesh_oid in tqdm(ais_ids):
    mesh = ais_meshes[mesh_oid]
    sk = ais_sks[mesh_oid]
    for syn_oid in ais_ids:
        syn_df = syn_center_dfs[syn_oid]
        if syn_oid == mesh_oid:
            # Own contacts: keep the original center vertices (ndarray, like the remapped entries)
            remapped_contact_inds[mesh_oid][syn_oid] = syn_df['contact_center_mesh_ind'].values
        else:
            zs = syn_df['contact_pt_z_loc'].values
            orientations = syn_df['contact_pt_orientation'].values
            remapped_contact_inds[mesh_oid][syn_oid], _ = map_syn_to_new_mesh(zs, orientations, mesh, sk)

In [None]:
dist_matched = []
dist_unmatched = []

# CO vertex indices per AIS, computed once. They do not depend on which synapse set is mapped onto the AIS.
# DBSCAN labels noise -1 and numbers clusters from 0, so `>= 0` keeps every clustered CO point
# (`> 0` would drop cluster 0; the other cells already use `>= 0`).
co_minds_by_mesh = {}
for mesh_oid in ais_ids:
    co_df = co_er_all_df[(co_er_all_df['pt_root_id'] == mesh_oid) & (co_er_all_df['clust_label'] >= 0)]
    co_minds_by_mesh[mesh_oid] = ais_meshes[mesh_oid].filter_unmasked_indices(co_df['mesh_ind'].values)

# Geodesic distance from each (remapped) contact center to its nearest clustered CO vertex
for syn_oid in tqdm(ais_ids):
    for mesh_oid in ais_ids:
        syn_minds = remapped_contact_inds[mesh_oid][syn_oid]
        ds = sparse.csgraph.dijkstra(ais_meshes[mesh_oid].csgraph, indices=syn_minds)
        min_dist_to_co = np.min(ds[:, co_minds_by_mesh[mesh_oid]], axis=1)
        if mesh_oid == syn_oid:
            dist_matched.append(min_dist_to_co)
        else:
            dist_unmatched.append(min_dist_to_co)

In [None]:
# Long-form table of min synapse->CO distances (um), labelled by own-AIS (Matched) vs. other-AIS mapping
matched_um = np.concatenate(dist_matched) / 1000
unmatched_um = np.concatenate(dist_unmatched) / 1000
dist_data = np.concatenate([matched_um, unmatched_um])
is_matched = np.concatenate([np.full(len(matched_um), 'Matched'), np.full(len(unmatched_um), 'Unmatched')])
match_data = pd.DataFrame(data={r'Min. Distance ($\mu m$)': dist_data, 'CO/Synapse Locations': is_matched})

In [None]:
# Two-sided Mann-Whitney U test, matched vs. unmatched min distances. `alternative` is explicit (as in
# the later tests) so the result does not depend on the SciPy version's default. The name
# `ttest_result` is kept because later cells read it.
dist_col = r'Min. Distance ($\mu m$)'
is_matched_row = match_data['CO/Synapse Locations'] == 'Matched'
ttest_result = stats.mannwhitneyu(match_data.loc[is_matched_row, dist_col],
                                  match_data.loc[~is_matched_row, dist_col],
                                  alternative='two-sided')
stat_df = pd.DataFrame(data={'Min_Distance_pvalue': [ttest_result.pvalue]})
stat_df.to_csv(plot_dir + f'/CO_FIG_min_distance_distribution_stat_{str(is_chc).lower()}.csv')

In [None]:
# Show the Mann-Whitney p-value table that was written to CSV
stat_df

In [None]:
fig, ax = plt.subplots(figsize=(3, 4))
dist_col = r'Min. Distance ($\mu m$)'
sns.violinplot(data=match_data, x='CO/Synapse Locations', y=dist_col, bw=0.25, linewidth=1.5, width=0.95,
               ax=ax, palette=(co_color, (0.5, 0.5, 0.5)), inner='quartile', cut=0)
sns.despine(offset=5, ax=ax)
ax.set_ylim(-0.1, 8)
ax.xaxis.label.set_fontsize(12)
ax.yaxis.label.set_fontsize(12)
# The violins show per-synapse minimum distances (not means), so label the axis accordingly
ax.set_ylabel(r'Min. Syn/CO Dist. ($\mu m$)', fontdict={'size': 14})
n_stars = int(assign_stars(np.array([ttest_result.pvalue]), stars)[0])
ax.text(0.5, 7, n_stars * '*', fontdict={'horizontalalignment': 'center', 'size': 14})
fig.savefig(plot_dir + f'/CO_FIG_min_distance_distribution_violin_{str(is_chc).lower()}.pdf', bbox_inches='tight')

In [None]:
bins = np.arange(0, 8, 0.25)

# Step histograms (um) of min synapse->CO distance: other-AIS mapping (black) vs. own AIS
fig, ax = plt.subplots(figsize=(3, 3))
step_kw = dict(density=True, bins=bins, histtype='step', linewidth=3)
ax.hist(np.concatenate(dist_unmatched) / 1000, color='k', alpha=1, **step_kw)
ax.hist(np.concatenate(dist_matched) / 1000, color=co_color, alpha=0.8, **step_kw)
ax.set_xlabel(r'Synapse to CO Distance ($\mu$m)', fontdict={'size': 12})
ax.set_xlim((-0.25, 8))
ax.set_ylim((-0.05, 2.2))
ax.set_ylabel('Density')
ax.legend(('Shuffled CO/Synapse', 'Observed CO/Synapse'), bbox_to_anchor=(1.1, 0.91))
sns.despine(ax=ax, offset=5)
fig.savefig(plot_dir + f'/CO_FIG_min_distance_distribution_hist_{str(is_chc).lower()}.pdf')

Control: the same AIS, with contact orientations rotated by a fixed offset

In [None]:
delta_orientations = np.arange(-np.pi, np.pi + 0.0001, np.pi / 8)  # 17 rotations, -pi .. pi inclusive
z_thresh = 30000  # only contact centers with contact_pt_z_loc below this are rotated

dist_rotated = []  # [ais index][rotation index] -> min distance to a CO vertex per contact center
for ii, ais_id in tqdm(enumerate(ais_ids), total=len(ais_ids)):
    ais_mesh = ais_meshes[ais_id]
    ais_sk = ais_sks[ais_id]

    # None of this depends on the rotation, so compute it once per AIS instead of per rotation
    contact_center_df = all_contact_df[np.isin(all_contact_df['root_id'], [ais_id]) & (all_contact_df['is_center'])]
    zs = contact_center_df['contact_pt_z_loc']
    good_zs = zs < z_thresh
    zs = zs[good_zs]
    base_orientation = contact_center_df['contact_pt_orientation'][good_zs]

    co_df = co_er_all_df[(co_er_all_df['pt_root_id'] == ais_id) & (co_er_all_df['clust_label'] >= 0)]
    co_minds = ais_mesh.filter_unmasked_indices(co_df['mesh_ind'].values)

    dist_rotated.append([])
    for delta_orientation in delta_orientations:
        # Move every contact center around the same AIS by delta_orientation, keeping its depth
        new_orientation = base_orientation + delta_orientation
        new_mesh_inds, new_values = map_syn_to_new_mesh(zs, new_orientation, ais_mesh, ais_sk)

        ds = sparse.csgraph.dijkstra(ais_mesh.csgraph, indices=new_mesh_inds)
        min_dist_to_co = np.min(ds[:, co_minds], axis=1)
        dist_rotated[ii].append(min_dist_to_co)

In [None]:
def list_of_lists_to_categorical(data, categories=None):
    '''
    Flatten a list of sequences into long form, with a per-value label naming its source sequence.

    :param data: List of array-likes.
    :param categories: Optional label per sequence; defaults to 0..len(data)-1.
    :returns: (flat values, flat labels), both 1-D arrays of the same length.
    '''
    labels = np.arange(len(data)) if categories is None else categories
    flat_labels = np.concatenate([np.full(len(group), label) for group, label in zip(data, labels)])
    flat_values = np.concatenate(data)
    return flat_values, flat_labels

In [None]:
# For each rotation, pool the per-AIS distance arrays into one array
all_dists = [list_of_lists_to_categorical([per_ais[ii] for per_ais in dist_rotated])[0]
             for ii in range(len(delta_orientations))]

In [None]:
# Long form: every distance paired with the rotation it was measured under
d_dist, do_list = list_of_lists_to_categorical(all_dists, categories=delta_orientations)

In [None]:
# Distances converted to um (/1000), one row per contact center per rotation
do_df = pd.DataFrame({'min_dist': d_dist / 1000, 'delta_orientation': do_list})

In [None]:
# `sm` is otherwise only imported in a later cell; import here so Restart & Run All works
import statsmodels.api as sm

sns.set_style("ticks", {"xtick.major.size": 8, "ytick.major.size": 8})
fig, ax = plt.subplots(figsize=(3, 3))
dist_matched_long = np.concatenate(dist_matched) / 1000
dist_unmatched_long = np.concatenate(dist_unmatched) / 1000

# Reference condition: the zero rotation (observed orientations)
zero_rotation = delta_orientations[np.argmin(np.abs(delta_orientations))]
base_data = do_df[do_df['delta_orientation'] == zero_rotation]

all_test_res = []
for del_or in delta_orientations:
    data = do_df[do_df['delta_orientation'] == del_or]
    all_test_res.append(stats.mannwhitneyu(data['min_dist'], base_data['min_dist'], alternative='two-sided').pvalue)

data_stars = assign_stars(sm.stats.multipletests(all_test_res)[1], stars)

sns.lineplot(x='delta_orientation', y='min_dist', data=do_df, ci=90, color=co_color, ax=ax, marker='s')

plot_stars(delta_orientations, do_df.groupby('delta_orientation').mean()['min_dist'].values,
           data_stars, ax, horizontalalignment='center', xytext=(0, 45))
ax.set_xticks([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi])
# The tick at -pi/2 needs its minus sign
ax.set_xticklabels([r'-$\pi$', r'-$\pi$/2', '0', r'$\pi$/2', r'$\pi$'], fontdict={'size': 12})
ax.set_xlabel(r'$\Delta$ Orientation', fontdict={'size': 14})
ax.set_ylabel(r'Mean Syn/CO Dist. ($\mu m$)', fontdict={'size': 14})
ax.grid(False)
sns.despine(offset=2, ax=ax, trim=True)
fig.savefig(plot_dir + f'/CO_synapse_rotation_{str(is_chc).lower()}.pdf', bbox_inches='tight')

In [None]:
# `sm` is otherwise only imported in a later cell; import here so this cell does not rely on that cell
import statsmodels.api as sm

# Per-rotation p-values (vs. zero rotation), raw and multiple-comparison corrected, saved to CSV
rotate_stats_df = pd.DataFrame(data={'delta_orientation': delta_orientations,
                                     'mann-whitney_p': all_test_res,
                                     'mann-whitney_p_multitest': sm.stats.multipletests(all_test_res)[1]})
rotate_stats_df.to_csv(plot_dir + f'/CO_rotation_tests_{str(is_chc).lower()}.csv')
del rotate_stats_df

---

In [None]:
fig, ax = plt.subplots(figsize=(4, 4))
dist_matched_long = np.concatenate(dist_matched) / 1000
dist_unmatched_long = np.concatenate(dist_unmatched) / 1000

sns.lineplot(x='delta_orientation', y='min_dist', data=do_df.query('delta_orientation<3.15'), ci=90, color=co_color, ax=ax, marker='s')
# Observed (left) and shuffled (right) means, with error bars from the 5th-95th percentile of bootstrapped means
for x_pos, dist_values in ((-0.5, dist_matched_long), (np.pi + 0.5, dist_unmatched_long)):
    mean_dist = np.mean(dist_values)
    yerr = np.abs(mean_dist - np.percentile(sns.algorithms.bootstrap(dist_values), [5, 95])).reshape(2, 1)
    ax.errorbar([x_pos], mean_dist, yerr, marker='s')
ax.set_ylim((0.2, 1.4))
sns.despine(offset=5, ax=ax)

ax.set_xticks([-0.5, 0, np.pi / 2, np.pi, np.pi + 0.5])
ax.set_xticklabels(['Observed', '0', r'$\pi$/2', r'$\pi$', 'Shuffled'], rotation=45, fontdict={'size': 12})
ax.set_xlabel('Synapse/CO Alignment', fontdict={'size': 14})
ax.grid(axis='x')

In [None]:
# Welch-free (pooled-variance) t-tests: observed distances vs. each rotation, then vs. the shuffled set
test_res = [stats.ttest_ind(dist_matched_long, do_df.loc[do_df['delta_orientation'] == d_or, 'min_dist'])
            for d_or in delta_orientations]
test_res.append(stats.ttest_ind(dist_matched_long, dist_unmatched_long))

In [None]:
import statsmodels.api as sm

In [None]:
# Multiple-comparison correction (statsmodels' default method) of the t-test p-values in test_res;
# element [1] of the result holds the corrected p-values, as used elsewhere in this notebook
sm.stats.multipletests([t.pvalue for t in test_res])

In [None]:
# Observed synapse->CO distances vs. each rotation. do_df['min_dist'] is in um (d_dist / 1000), so the
# observed distances are converted too; otherwise the rank test compares values in different units.
dist_matched_um = np.concatenate(dist_matched) / 1000
all_stats = []
for ii, row in do_df.groupby('delta_orientation').agg(list).iterrows():
    all_stats.append(stats.mannwhitneyu(dist_matched_um, row['min_dist'], alternative='two-sided'))

In [None]:
# Mann-Whitney results, one per delta_orientation (in groupby order)
all_stats

In [None]:
# Observed (matched) distances vs. all rotated distances pooled; both are in the original (pre-/1000) units
stats.mannwhitneyu(np.concatenate(dist_matched), d_dist, alternative='two-sided')

In [None]:
# Mean min distance (um) per rotation, with the rotation on the x axis
mean_by_rotation = do_df.groupby('delta_orientation').mean()
plt.plot(mean_by_rotation, 's--')

In [None]:
fig, ax = plt.subplots(figsize=(6, 3))
sns.boxplot(x='delta_orientation', y='min_dist', data=do_df, ax=ax)
# do_df['min_dist'] is in um (d_dist / 1000); a 0-5000 limit assumed the original units and squashed the boxes
ax.set_ylim(0, 5)

In [None]:
# Min distances (original units) for the first AIS at delta_orientations[8], i.e. zero rotation
dist_rotated[0][8]

In [None]:
# Sanity check on the last (AIS, rotation) pair processed by the rotation loop: original contact
# orientation vs. orientation of the vertex it was mapped to. map_syn_to_new_mesh only received
# contacts with z < z_thresh, so apply the same filter here to keep the lengths equal.
last_good_zs = contact_center_df['contact_pt_z_loc'] < z_thresh
fig, ax = plt.subplots()  # plt.plot() returns a list of lines and cannot be unpacked into (fig, ax)
ax.plot(contact_center_df['contact_pt_orientation'][last_good_zs], new_values[:, 1], 's')

#### 3d Plot of ER points

In [None]:
for mesh_oid in ais_ids:
    nrn_mesh = ais_meshes[mesh_oid]

    # Translucent AIS, CO points (all copies/labels for this AIS) and ChC contact centers
    co_nrn_df = co_er_all_df.query('pt_root_id == @mesh_oid')
    ma = trimesh_vtk.mesh_actor(nrn_mesh, color=(0.6, 0.6, 0.6), opacity=0.4)
    co_pts = np.vstack(co_nrn_df['pt_position']) * voxel_resolution
    pa = trimesh_vtk.point_cloud_actor(co_pts, size=250, color=co_color, opacity=1)

    nrn_contact_center_df = all_contact_center_df[all_contact_center_df['root_id'] == mesh_oid]
    chc_contact_pts = np.vstack(nrn_contact_center_df['contact_pt_position'])
    ca = trimesh_vtk.point_cloud_actor(chc_contact_pts, size=300, color=chc_color, opacity=0.8)

    # Whole-AIS view centered on the mesh centroid
    camera = trimesh_vtk.oriented_camera(nrn_mesh.centroid, backoff=150)
    trimesh_vtk.render_actors([ma, pa, ca], camera=camera, do_save=True,
                              filename=plot_dir + f'/CO_FIG_co_ais_{mesh_oid}.png')

    # Zoomed view centered 17000 units along +y from the vertex with the smallest y coordinate
    top_pt_ind = np.argmin(nrn_mesh.vertices[:, 1])
    top_pt = nrn_mesh.vertices[top_pt_ind]
    delta_y = np.array([0, -17000, 0])
    camera = trimesh_vtk.oriented_camera(top_pt - delta_y, backoff=65)
    trimesh_vtk.render_actors([ma, pa, ca], camera=camera, do_save=True,
                              filename=plot_dir + f'/CO_FIG_co_ais_zoom_{mesh_oid}.png')

#### Mapping Synapses across AISes

In [None]:
# Example pair: contact centers of one AIS (blue, on its own mesh) and where they land when remapped
# onto another AIS (red), linked point-to-point.
# NOTE(review): this cell used `mesh_id`, `mapped_inds` and `orig_inds`, which no cell in this notebook
# defines, and `oid` was whatever the last loop left behind. All four are rebuilt here from
# remapped_contact_inds; change the example pair as needed.
oid, mesh_id = ais_ids[0], ais_ids[1]
orig_inds = np.asarray(remapped_contact_inds[oid][oid])        # own (depth-filtered) contact centers
mapped_inds = np.asarray(remapped_contact_inds[mesh_id][oid])  # the same centers mapped onto mesh_id

pa = trimesh_vtk.point_cloud_actor(ais_meshes[mesh_id].vertices[mapped_inds], size=200, color=(1, 0, 0), opacity=1)

ma = trimesh_vtk.mesh_actor(ais_meshes[mesh_id], color=(0.4, 0.4, 0.4), opacity=0.7)
pa2 = trimesh_vtk.point_cloud_actor(ais_meshes[oid].vertices[orig_inds], size=200, color=(0, 0, 0.7), opacity=1)
ma2 = trimesh_vtk.mesh_actor(ais_meshes[oid], color=(0.6, 0.6, 0.6), opacity=0.7)
lpa = trimesh_vtk.linked_point_actor(ais_meshes[oid].vertices[orig_inds], ais_meshes[mesh_id].vertices[mapped_inds])
trimesh_vtk.render_actors([ma, ma2, pa, pa2, lpa])

In [None]:
min_dist_to_cos = []

for oid in ais_ids:
    ais_mesh = ais_meshes[oid]
    # DBSCAN labels noise -1 and numbers clusters from 0, so `>= 0` keeps every clustered CO point
    co_df = co_er_all_df[(co_er_all_df['pt_root_id'] == oid) & (co_er_all_df['clust_label'] >= 0)]
    contact_center_df = all_contact_df[np.isin(all_contact_df['root_id'], [oid]) & (all_contact_df['is_center'])]

    syn_minds = contact_center_df['contact_center_mesh_ind'].values
    co_minds = ais_mesh.filter_unmasked_indices(co_df['mesh_ind'].values)

    # Geodesic distance from each contact center to its nearest clustered CO vertex
    ds = sparse.csgraph.dijkstra(ais_mesh.csgraph, indices=syn_minds)
    min_dist_to_co = np.min(ds[:, co_minds], axis=1)
    min_dist_to_cos.append(min_dist_to_co)

In [None]:
bins = np.arange(-np.pi, np.pi + 0.1, np.pi / 4)
fig, ax = plt.subplots()
# If orientations lie in [-pi, pi], the +/-2*pi shifted CO copies in co_er_all_df fall outside these bins
ax.hist(co_er_all_df['orientation'], bins=bins, histtype='step', density=True, linewidth=2, color=co_color, label='CO')
# `linewidth)` was a SyntaxError (positional argument after keywords); the value matches the first histogram
ax.hist(all_contact_df['contact_pt_orientation'], bins=bins, histtype='step', density=True, linewidth=2,
        color=chc_color, label='ChC contact')
ax.legend()

In [None]:
# Split the per-AIS (delta_orientation, delta_z) pairs into pooled arrays, ChC and non-ChC synapses
del_orient, del_z = (np.concatenate(parts) for parts in zip(*dist_to_cos))
del_orient_non, del_z_non = (np.concatenate(parts) for parts in zip(*dist_to_cos_non))

In [None]:
from itertools import product

ds_mismatched = []
ds_mismatched_non = []

# Every ordered pair of distinct AISes: COs from ais_id_a against the synapses of ais_id_b
distinct_pairs = ((a, b) for a, b in product(ais_ids, repeat=2) if a != b)
for ais_id_a, ais_id_b in distinct_pairs:
    ns_df = nrn_syn_dfs[ais_id_b]
    co_over_a = co_er_dfs[ais_id_a]
    r_a = median_rs[ais_id_a]
    ds_mismatched.append(distance_to_co_pts(co_over_a, ns_df[ns_df['is_chandelier'] == True], r_a))
    ds_mismatched_non.append(distance_to_co_pts(co_over_a, ns_df[ns_df['is_chandelier'] == False], r_a))

In [None]:
# Pool the mismatched-pair (delta_orientation, delta_z) results
del_orient_mismatched, del_z_mismatched = (np.concatenate(parts) for parts in zip(*ds_mismatched))

In [None]:
fig, ax = plt.subplots(figsize=(3, 3))
bins = np.arange(-np.pi, np.pi + 0.001, np.pi / 10)
# Delta orientation to the nearest CO point: own AIS (filled) vs. other AIS (black outline)
hist_kw = dict(density=True, bins=bins)
ax.hist(del_orient, edgecolor=chc_color, color=chc_color, **hist_kw)
ax.hist(del_orient_mismatched, color='k', linewidth=3, histtype='step', **hist_kw)

In [None]:
fig, ax = plt.subplots(figsize=(3, 3))
bins = np.arange(-5250, 5251, 500)  # 500-unit bins, one of them centered on 0
# Delta depth to the nearest CO point: own AIS (filled) vs. other AIS (black outline)
hist_kw = dict(density=True, bins=bins)
ax.hist(del_z, edgecolor=chc_color, color=chc_color, **hist_kw)
ax.hist(del_z_mismatched, color='k', linewidth=3, histtype='step', **hist_kw)

In [None]:
# |delta orientation| to the nearest CO point for ChC synapses: own AIS vs. other AIS
# (jitterbar is a project helper; mode='median' presumably marks the group medians — confirm)
jitterbar.jitterbar((np.abs(del_orient), np.abs(del_orient_mismatched)), (chc_color, non_color), scatter_kwargs={'alpha':0.2}, width=0.15, mode='median')

In [None]:
import pycircstat

In [None]:
# pycircstat's circular common-median test on matched vs. mismatched orientation offsets
# (see pycircstat docs for the test's hypotheses and return values).
pycircstat.cmtest(del_orient, del_orient_mismatched)

In [None]:
# np.histogram -> (counts, edges) of the pooled per-AIS results; counts feed the CDF cell that follows.
# NOTE(review): `bins` is not defined in this cell. On Restart & Run All it holds the z-offset
# edges (-5250..5250, step 500) from the del_z histogram cell. Confirm that is intended for a
# distance CDF whose x-axis is limited to start at 0.
# NOTE(review): entries of these lists are indexed elsewhere as pairs (v[0], v[1]). If so,
# np.concatenate here stacks both components together, or raises if per-AIS lengths differ.
# Confirm that each entry is a 1-D distance array.
mismatch_synapses = np.histogram(np.concatenate(ds_mismatched), bins=bins)
match_synapses = np.histogram(np.concatenate(dist_to_cos), bins=bins)
mismatch_non = np.histogram(np.concatenate(ds_mismatched_non), bins=bins)
match_non = np.histogram(np.concatenate(dist_to_cos_non), bins=bins)

In [None]:
fig, ax = plt.subplots(figsize=(4,4))

# Empirical CDFs built from histogram counts, plotted at the right bin edges (scaled by 1/1000).
# Solid = matched, dotted = mismatched; chc_color vs. non_color separates the two synapse groups.
for (counts, _edges), line_color, extra_kw in (
    (match_synapses, chc_color, {}),
    (mismatch_synapses, chc_color, {'linestyle': ':'}),
    (match_non, non_color, {}),
    (mismatch_non, non_color, {'linestyle': ':'}),
):
    ax.plot(bins[1:] / 1000, np.cumsum(counts) / sum(counts), color=line_color, linewidth=2, **extra_kw)

ax.set_xlim(np.array([0, 5000]) / 1000)

sns.despine(offset=5, ax=ax)
ax.set_xlabel('Distance from CO point')
ax.set_ylabel('CDF')

In [None]:
# Two-sided Mann-Whitney U test: pooled matched results (dist_to_cos) vs. mismatched-pair results.
stats.mannwhitneyu(np.concatenate(dist_to_cos), np.concatenate(ds_mismatched), alternative='two-sided')

In [None]:
# Two-sided Mann-Whitney U test: dist_to_cos_non vs. dist_to_cos, both matched-AIS results
# (presumably non-chandelier vs. chandelier synapses, by analogy with ds_mismatched_non).
stats.mannwhitneyu(np.concatenate(dist_to_cos_non), np.concatenate(dist_to_cos), alternative='two-sided')

In [None]:
# One panel per AIS: its ER points at (d_orientation, d_top), split by func_id.
# squeeze=False keeps `axes` 2-D even when there is only one AIS.
fig, axes = plt.subplots(figsize=(30,7), ncols=len(ais_ids), squeeze=False)
# Keyed by AIS root id so it stays compatible with lookups of the form median_rs[ais_id];
# a positional list would break those lookups when this cell runs first.
median_rs = {}
for ii, ais_id in enumerate(ais_ids):
    ax = axes[0, ii]
    nrn_er_df = er_location_df(ais_sks[ais_id], ais_meshes[ais_id], er_df[er_df['pt_root_id']==ais_id])
    # .copy(): adding 'd_orientation' below then modifies an independent frame rather than a
    # boolean-mask slice of ais_synapse_data (avoids pandas' SettingWithCopyWarning).
    nrn_syn_df = ais_synapse_data[ais_synapse_data['post_pt_root_id']==ais_id].copy()
    median_r = np.median(ais_sks[ais_id].vertex_properties['rs']) / 2
    median_rs[ais_id] = median_r
    # d_orientation = pi * median_r * orientation
    nrn_er_df['d_orientation'] = np.pi * median_r * nrn_er_df['orientation']
    nrn_syn_df['d_orientation'] = np.pi * median_r * nrn_syn_df['orientation']

    sns.scatterplot(x='d_orientation', y='d_top', color=(0.3, 0.2, 0.3), s=15, data=nrn_er_df[nrn_er_df['func_id']==0], ax=ax, legend=False, edgecolor=None, alpha=0.3)
    sns.scatterplot(x='d_orientation', y='d_top', color=(0.7, 0.2, 0), s=15, data=nrn_er_df[nrn_er_df['func_id']==1], ax=ax, legend=False, edgecolor=None, alpha=0.3)
    # sns.scatterplot(x='d_orientation', y='d_top_skel', hue='is_chandelier', data=nrn_syn_df, s=100, alpha=0.8, palette='dark', ax=ax, legend=False)
    ax.set_ylim((0,50000))
    ax.invert_yaxis()
    ax.set_aspect('equal')

In [None]:
# Single example AIS (position 5 in ais_ids): ER points colored by func_id, with synapses
# colored by is_chandelier overlaid.
ais_id = ais_ids[5]

nrn_er_df = er_location_df(ais_sks[ais_id], ais_meshes[ais_id], er_df[er_df['pt_root_id']==ais_id])
# .copy(): adding 'd_orientation' below then modifies an independent frame rather than a
# boolean-mask slice of ais_synapse_data (avoids pandas' SettingWithCopyWarning).
nrn_syn_df = ais_synapse_data[ais_synapse_data['post_pt_root_id']==ais_id].copy()
median_r = np.median(ais_sks[ais_id].vertex_properties['rs']) / 2

nrn_er_df['d_orientation'] = np.pi * median_r * nrn_er_df['orientation']
nrn_syn_df['d_orientation'] = np.pi * median_r * nrn_syn_df['orientation']

fig, ax = plt.subplots(figsize=(3,7))
sns.scatterplot(x='d_orientation', y='d_top', hue='func_id', s=15, data=nrn_er_df, ax=ax, legend=False)
sns.scatterplot(x='d_orientation', y='d_top_skel', hue='is_chandelier', data=nrn_syn_df, s=100, alpha=0.8, palette='dark', ax=ax, legend=False)
ax.set_ylim((0,40000))
ax.invert_yaxis()
ax.set_aspect('equal')

In [None]:
# Cell-type table for soma_table, queried through the dataset client `dl`.
cell_df = dl.query_cell_types(soma_table)

In [None]:
from decimal import Decimal

In [None]:
# Element-wise Decimal copies of each pt_root_id.
# NOTE(review): `a` is not used in the visible notebook; presumably scratch, so confirm before removing.
a = cell_df['pt_root_id'].apply(Decimal)

In [None]:
# Scratch handles for the first AIS in ais_ids: its mesh and its rows of all_contact_center_df.
ais_mesh = ais_meshes[ais_ids[0]]
syn_locs = all_contact_center_df[all_contact_center_df['root_id']==ais_ids[0]]

In [None]:
# Skeleton for the same (first) AIS.
ais_sk = ais_sks[ais_ids[0]]

In [None]:
# Interactive IPython help lookup left over from development; produces no analysis result.
trimesh.base.Geometry?

In [None]:
# Two example cutout centers.
# NOTE(review): presumably voxel coordinates at the ImageryClient base_resolution [4,4,40]; confirm.
img_locs = [[94413, 49229, 1045], [100306, 59517, 970]]

In [None]:
from annotationframeworkclient import imagery

In [None]:
# Imagery client for the 'pinky100' dataset (note: `img` is the client, not an image).
img = imagery.ImageryClient(dataset_name='pinky100', base_resolution=[4,4,40])

In [None]:
def make_bounds(center, width, height, depth=0):
    """Axis-aligned bounding box of a given size around a center point.

    Parameters
    ----------
    center : sequence of numbers
        (x, y, z) center of the box; only the first three entries are used.
    width, height, depth : int
        Box size along x, y and z. For integer centers, each axis satisfies
        ``upper - lower + 1 == size``. So ``depth=1`` gives a single plane
        (``lower == upper``), and the default ``depth=0`` gives
        ``upper == lower - 1``.

    Returns
    -------
    list
        ``[lower_bounds, upper_bounds]``, each a list of ints.
        NOTE(review): the ``+ 1`` relation suggests the consumer treats both
        bounds as inclusive; confirm against ImageryClient.image_cutout.
    """
    import math

    half_width = 0.5 * (width - 1)
    half_height = 0.5 * (height - 1)
    half_depth = 0.5 * (depth - 1)
    offset = [half_width, half_height, half_depth]
    # math.floor rather than int(): int() truncates toward zero, so whenever
    # center +/- half_size is negative (e.g. near the volume origin) the bound
    # would shift by one. For non-negative values the two are identical.
    lbounds = [math.floor(x - h) for x, h in zip(center, offset)]
    ubounds = [math.floor(x + h) for x, h in zip(center, offset)]
    return [lbounds, ubounds]

In [None]:
# Cutout centered on the second example location (width=400, height=600, depth=1).
img2 = img.image_cutout(bounds=make_bounds(img_locs[1], 400, 600, 1))

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Shape of the cutout. `img1` is never assigned in this notebook (only `img2` is), so
# inspecting `img1` raised NameError on a fresh kernel.
img2.shape

In [None]:
# Drop singleton axes and transpose for display.
# NOTE(review): assumes the cutout's first axis is x, so .T puts y on imshow's rows; confirm.
plt.imshow(np.squeeze(img2).T)

In [None]:
# Save the cutout as 'CO_example_2' under plot_dir.
img.save_imagery(f'{plot_dir}/CO_example_2', precomputed_image=img2)