In [3]:
import joblib

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from caveclient import CAVEclient
from meshparty import meshwork

# from common_setup import project_info, project_paths
from apical_features import *
from plotting_code import plot_layers
from apical_model_utils import *

ModuleNotFoundError: No module named 'apical_features'

In [4]:
# Data/output locations derived from project_paths; disabled while `common_setup` is not imported.
# NOTE(review): later cells read skel_dir, preview_dir, synapse_dir, layer_bnds, height_bounds and
# width_bounds, and no visible cell defines them (unless a star import does) -- restore these before Run All.
# skel_dir = f'{project_paths.skeletons}/skeleton_files'
# preview_dir = f'{project_paths.skeletons}/previews'
# synapse_dir = f'{project_paths.data}/synapse_files'

# Directory holding the trained classifier artifacts (*_current.pkl), relative to the working directory.
model_dir = 'models'

# layer_bnds = np.load(f'{project_paths.data}/layer_bounds_v1.npy')
# height_bounds = np.load(f'{project_paths.data}/height_bounds_v1.npy')
# width_bounds = np.load(f'{project_paths.data}/width_bounds_v1.npy')

In [6]:
import joblib

# Trained artifacts: the point-level classifier, the feature columns it expects, and branch-level parameters.
# NOTE(review): joblib.load unpickles arbitrary Python objects -- only load these files from a trusted source.
rfc, feature_cols, branch_params = (
    joblib.load(f'{model_dir}/{artifact}_current.pkl')
    for artifact in ('point_model', 'feature_cols', 'branch_params')
)

# Currently unused: a branch-level classifier assembled from the loaded point model.
# BranchClassifier = BranchClassifierFactory(rfc, feature_cols)
# branch_classifier = BranchClassifier(**branch_params)

In [7]:
# Display the branch-level parameters loaded with the model.
# NOTE(review): the classification cells below hard-code their own thresholds rather than reading these.
branch_params

{'min_length': 50,
 'logodds_clip': 200,
 'softmax_scaler': 50,
 'bin_num': 80,
 'logodds_thresh': 0,
 'softmax_thresh': 0.3}

In [13]:
# CAVE client for the datastack named in project_info.
# NOTE(review): project_info comes from `common_setup`, whose import is commented out in the first cell;
# on a fresh kernel this raises NameError until that import is restored.
client = CAVEclient(project_info.datastack)

In [14]:
import tqdm
from loguru import logger

# Log to a file. Remove the sink this cell added on a previous run first; otherwise every re-run
# registers another file sink and each message is written once per run.
_previous_log_handler = globals().get('apical_log_handler_id')
if _previous_log_handler is not None:
    try:
        logger.remove(_previous_log_handler)
    except ValueError:
        pass  # handler was already removed elsewhere (e.g. by a bare logger.remove())
apical_log_handler_id = logger.add('apical_classification.log')

In [None]:
# Peek at the neuron table the cells below iterate over.
# (`column` is not defined in any visible cell; later cells read `column_df`.)
column_df.head()

In [None]:
import os

In [None]:
# Pick one neuron to inspect and fetch its metadata row from column_df.
oid = 864691135510640592
row = column_df.loc[column_df['pt_root_id'] == oid].iloc[0]

In [None]:
# Single-neuron check: classify apical vs. non-apical dendrite for `oid` (metadata in `row`),
# label the neuron's input synapses, save them to synapse_dir, and draw a skeleton preview.

# Branch-classification constants, at the values this notebook has always used.
# NOTE(review): branch_params['logodds_clip'] is 200 but this code clips at +/-100 -- confirm which is intended.
MIN_BRANCH_POINTS = 50   # branches with this many points or fewer are dropped
LOGODDS_CLIP = 100       # per-branch log-odds ratios are clipped to [-LOGODDS_CLIP, LOGODDS_CLIP]
LOGODDS_THRESH = 0       # an apical branch needs clipped log-odds above this...
SOFTMAX_THRESH = 0.3     # ...and a per-neuron softmax share above this

nrn = meshwork.load_meshwork(f'{skel_dir}/{oid}.h5')
is_excitatory = row['classification_system'] == 'aibs_coarse_excitatory'

# Both paths leave `nrn` masked to its dendrite (for excitatory cells this presumably happens inside
# process_apical_features -- TODO confirm). Record the dendritic synapses and skeleton mask, then unmask.
if is_excitatory:
    feature_df = process_apical_features(nrn)
else:
    nrn = apply_dendrite_mask(nrn)
dendrite_synapse_indices = nrn.anno.post_syn.df.index
# Read the mask from the skeleton nrn holds *now*; a reference taken before masking can be stale.
dendrite_skind_mask = np.array(nrn.skeleton.node_mask, dtype=bool)
nrn.reset_mask()
sk = nrn.skeleton

apical_sk_mask = np.full(len(sk.vertices), False)
apical_post_inds = []
if is_excitatory:
    # Point-level apical probabilities from the classifier loaded above as `rfc` / `feature_cols`.
    # NOTE(review): this previously called `point_model` / `point_feature_columns`, which no visible cell defines.
    feature_df['apical_prob'] = rfc.predict_proba(feature_df[feature_cols].values)[:, 1]

    # One row per branch (keyed by its base skeleton index) holding the list of its point probabilities.
    ap_df_branches = (
        feature_df[['root_id', 'base_skind', 'apical_prob']]
        .groupby(['root_id', 'base_skind'])
        .agg(list)
        .reset_index()
    )
    ap_df_branches['logodds_ratio'] = ap_df_branches['apical_prob'].apply(log_odds_ratio)
    ap_df_branches['len_br'] = ap_df_branches['apical_prob'].apply(len)
    ap_df_branches = ap_df_branches[ap_df_branches['len_br'] > MIN_BRANCH_POINTS].reset_index(drop=True)
    ap_df_branches['logodds_clipped'] = ap_df_branches['logodds_ratio'].apply(
        lambda x: np.clip(x, -LOGODDS_CLIP, LOGODDS_CLIP))

    # Per-neuron normalisation of the clipped log-odds (softmax helpers come from a star import).
    ap_df_branches['softmax_denom'] = ap_df_branches.groupby('root_id')['logodds_clipped'].transform(softmax_denominator)
    ap_df_branches['median_logprob'] = ap_df_branches['apical_prob'].apply(np.median)
    ap_df_branches['softmax_num'] = ap_df_branches['logodds_clipped'].transform(softmax_numerator)
    ap_df_branches['apical_softmax'] = ap_df_branches['softmax_num'] / ap_df_branches['softmax_denom']
    ap_df_branches['logodds_clipped_norm'] = ap_df_branches['logodds_clipped'] / LOGODDS_CLIP
    ap_df_branches['predict_apical'] = np.logical_and(ap_df_branches['logodds_clipped'] > LOGODDS_THRESH,
                                                      ap_df_branches['apical_softmax'] > SOFTMAX_THRESH)

    apical_base = ap_df_branches.query('predict_apical')['base_skind'].values
    if len(apical_base) > 0:
        # Every skeleton node / mesh vertex downstream of a predicted apical branch base is apical.
        sk_downstream = sk.downstream_nodes(apical_base)
        apical_sk_mask = np.any(np.vstack([x.to_skel_mask for x in sk_downstream]), axis=0)
        apical_mesh_mask = np.any(np.vstack([x.to_mesh_mask for x in sk_downstream]), axis=0)
        apical_post_inds = nrn.anno.post_syn.filter_query(apical_mesh_mask).df.index

# Label every input synapse. Membership is tested against the ORIGINAL post_syn index (the one the
# *_inds collections above carry) and assigned positionally, so labels stay aligned with input_df
# even when that index is not 0..n-1 (input_df itself is positionally re-indexed).
post_syn_df = nrn.anno.post_syn.df
soma_post_inds = nrn.anno.post_syn.filter_query(nrn.root_region.to_mesh_mask).df.index
input_df = post_syn_df.reset_index(drop=True)
input_df['is_apical'] = post_syn_df.index.isin(apical_post_inds)
input_df['is_soma'] = post_syn_df.index.isin(soma_post_inds)
input_df['is_dendrite'] = post_syn_df.index.isin(dendrite_synapse_indices)
input_df['dist_to_root'] = nrn.distance_to_root(input_df['post_pt_mesh_ind']) / 1000

input_df.to_feather(f'{synapse_dir}/{oid}_inputs.feather')
logger.info(f'Saved synapses for {oid}')

# Preview: skeleton nodes coloured by compartment (x/y coordinates divided by 1000).
fig, ax = plt.subplots(figsize=(5, 5), facecolor='w', dpi=150)
basal_mask = np.logical_and(dendrite_skind_mask, ~apical_sk_mask)
for part_mask, part_color in [
    (~dendrite_skind_mask, (0.059, 0.780, 1.000)),  # nodes outside the dendrite mask (axon)
    (basal_mask, 'k'),                              # dendrite that is not apical
    (apical_sk_mask, 'r'),                          # apical dendrite
]:
    if np.any(part_mask):
        ax.scatter(x=sk.vertices[part_mask, 0] / 1000,
                   y=sk.vertices[part_mask, 1] / 1000,
                   s=0.2,
                   alpha=0.5,
                   color=part_color)

# Skeleton root as a white dot.
ax.plot(sk.vertices[sk.root, 0] / 1000,
        sk.vertices[sk.root, 1] / 1000,
        marker='o', color='w', markersize=5, markeredgecolor='k')

ax.set_aspect('equal')
plot_layers(layer_bnds, height_bounds, width_bounds, ax=ax, linestyle=':', linewidth=1, color='k')
# Title: oid | cell type | apical / soma / total input synapse counts.
ax.set_title(f'{oid} | {row["cell_type"]} | {len(input_df.query("is_apical"))}/{len(input_df.query("is_soma"))}/{len(input_df)}')

In [None]:
# Show where the per-neuron synapse tables are written.
synapse_dir

In [None]:
# Batch run: classify apical dendrite for every neuron in column_df, saving each input-synapse table
# to synapse_dir and a preview figure to preview_dir.
#   redo_rows   -- rows whose skeleton file does not exist
#   failed_rows -- rows that raised during processing (traceback is in the log)

# Branch-classification constants, at the values this notebook has always used.
# NOTE(review): branch_params['logodds_clip'] is 200 but this code clips at +/-100 -- confirm which is intended.
MIN_BRANCH_POINTS = 50   # branches with this many points or fewer are dropped
LOGODDS_CLIP = 100       # per-branch log-odds ratios are clipped to [-LOGODDS_CLIP, LOGODDS_CLIP]
LOGODDS_THRESH = 0       # an apical branch needs clipped log-odds above this...
SOFTMAX_THRESH = 0.3     # ...and a per-neuron softmax share above this

redo_rows = []
failed_rows = []
for _, row in tqdm.tqdm(column_df.iterrows(), total=len(column_df)):
    oid = row['pt_root_id']
    if not os.path.exists(f'{skel_dir}/{oid}.h5'):
        redo_rows.append(row)
        continue
    fig = None
    try:
        nrn = meshwork.load_meshwork(f'{skel_dir}/{oid}.h5')
        is_excitatory = row['classification_system'] == 'aibs_coarse_excitatory'

        # Both paths leave `nrn` masked to its dendrite (for excitatory cells this presumably happens
        # inside process_apical_features -- TODO confirm). Record the dendritic synapses and mask, then unmask.
        if is_excitatory:
            feature_df = process_apical_features(nrn)
        else:
            nrn = apply_dendrite_mask(nrn)
        dendrite_synapse_indices = nrn.anno.post_syn.df.index
        dendrite_skind_mask = np.array(nrn.skeleton.node_mask, dtype=bool)
        nrn.reset_mask()
        sk = nrn.skeleton

        apical_sk_mask = np.full(len(sk.vertices), False)
        apical_post_inds = []
        if is_excitatory:
            # Point-level apical probabilities from the classifier loaded above as `rfc` / `feature_cols`.
            # NOTE(review): this previously called `point_model` / `point_feature_columns`, which no visible cell defines.
            feature_df['apical_prob'] = rfc.predict_proba(feature_df[feature_cols].values)[:, 1]

            # One row per branch (keyed by its base skeleton index) holding the list of its point probabilities.
            ap_df_branches = (
                feature_df[['root_id', 'base_skind', 'apical_prob']]
                .groupby(['root_id', 'base_skind'])
                .agg(list)
                .reset_index()
            )
            ap_df_branches['logodds_ratio'] = ap_df_branches['apical_prob'].apply(log_odds_ratio)
            ap_df_branches['len_br'] = ap_df_branches['apical_prob'].apply(len)
            ap_df_branches = ap_df_branches[ap_df_branches['len_br'] > MIN_BRANCH_POINTS].reset_index(drop=True)
            ap_df_branches['logodds_clipped'] = ap_df_branches['logodds_ratio'].apply(
                lambda x: np.clip(x, -LOGODDS_CLIP, LOGODDS_CLIP))

            # Per-neuron normalisation of the clipped log-odds (softmax helpers come from a star import).
            ap_df_branches['softmax_denom'] = ap_df_branches.groupby('root_id')['logodds_clipped'].transform(softmax_denominator)
            ap_df_branches['median_logprob'] = ap_df_branches['apical_prob'].apply(np.median)
            ap_df_branches['softmax_num'] = ap_df_branches['logodds_clipped'].transform(softmax_numerator)
            ap_df_branches['apical_softmax'] = ap_df_branches['softmax_num'] / ap_df_branches['softmax_denom']
            ap_df_branches['logodds_clipped_norm'] = ap_df_branches['logodds_clipped'] / LOGODDS_CLIP
            ap_df_branches['predict_apical'] = np.logical_and(ap_df_branches['logodds_clipped'] > LOGODDS_THRESH,
                                                              ap_df_branches['apical_softmax'] > SOFTMAX_THRESH)

            apical_base = ap_df_branches.query('predict_apical')['base_skind'].values
            if len(apical_base) > 0:
                # Every skeleton node / mesh vertex downstream of a predicted apical branch base is apical.
                sk_downstream = sk.downstream_nodes(apical_base)
                apical_sk_mask = np.any(np.vstack([x.to_skel_mask for x in sk_downstream]), axis=0)
                apical_mesh_mask = np.any(np.vstack([x.to_mesh_mask for x in sk_downstream]), axis=0)
                apical_post_inds = nrn.anno.post_syn.filter_query(apical_mesh_mask).df.index

        # Label inputs by testing membership on the ORIGINAL post_syn index and assigning positionally,
        # so labels stay aligned with the positionally re-indexed input_df.
        post_syn_df = nrn.anno.post_syn.df
        soma_post_inds = nrn.anno.post_syn.filter_query(nrn.root_region.to_mesh_mask).df.index
        input_df = post_syn_df.reset_index(drop=True)
        input_df['is_apical'] = post_syn_df.index.isin(apical_post_inds)
        input_df['is_soma'] = post_syn_df.index.isin(soma_post_inds)
        input_df['is_dendrite'] = post_syn_df.index.isin(dendrite_synapse_indices)
        input_df['dist_to_root'] = nrn.distance_to_root(input_df['post_pt_mesh_ind']) / 1000

        input_df.to_feather(f'{synapse_dir}/{oid}_inputs.feather')
        logger.info(f'Saved synapses for {oid}')

        # Preview: skeleton nodes coloured by compartment (x/y coordinates divided by 1000).
        fig, ax = plt.subplots(figsize=(5, 5), facecolor='w', dpi=150)
        basal_mask = np.logical_and(dendrite_skind_mask, ~apical_sk_mask)
        for part_mask, part_color in [
            (~dendrite_skind_mask, (0.059, 0.780, 1.000)),  # nodes outside the dendrite mask (axon)
            (basal_mask, 'k'),                              # dendrite that is not apical
            (apical_sk_mask, 'r'),                          # apical dendrite
        ]:
            if np.any(part_mask):
                ax.scatter(x=sk.vertices[part_mask, 0] / 1000,
                           y=sk.vertices[part_mask, 1] / 1000,
                           s=0.2,
                           alpha=0.5,
                           color=part_color)

        # Skeleton root as a white dot.
        ax.plot(sk.vertices[sk.root, 0] / 1000,
                sk.vertices[sk.root, 1] / 1000,
                marker='o', color='w', markersize=5, markeredgecolor='k')

        ax.set_aspect('equal')
        plot_layers(layer_bnds, height_bounds, width_bounds, ax=ax, linestyle=':', linewidth=1, color='k')
        ax.set_title(f'{oid} | {row["cell_type"]} | {len(input_df.query("is_apical"))}/{len(input_df.query("is_soma"))}/{len(input_df)}')

        fig.savefig(f'{preview_dir}/apical_preview_{oid}.png', bbox_inches='tight')
        logger.info(f'Saved figure for {oid}')
    except Exception:
        # Best effort: log the traceback with the neuron id and keep going, but remember the row for a retry.
        logger.exception(f'Failed to process {oid}')
        failed_rows.append(row)
    finally:
        # Close the figure even if plotting failed part-way, so a long run does not accumulate open figures.
        if fig is not None:
            plt.close(fig)

In [None]:
# Save the current preview figure using the same naming scheme as the batch previews.
# (Previously `fig.savefig(f'{}')`, which is a SyntaxError: empty expression in an f-string.)
# NOTE(review): `fig` and `oid` are whatever the most recently executed cell left behind -- confirm the target.
fig.savefig(f'{preview_dir}/apical_preview_{oid}.png', bbox_inches='tight')

In [None]:
# Branch-level view for whichever neuron most recently produced ap_df_branches: clipped log-odds vs.
# softmax share, marker size by median point probability. The dotted line marks the 0.3 softmax threshold.
fig, ax = plt.subplots(figsize=(4, 4))
sns.scatterplot(data=ap_df_branches, x='logodds_clipped', y='apical_softmax',
                size='median_logprob', alpha=0.2, legend=False, ax=ax)
ax.plot((-100, 100), (0.3, 0.3), 'k:')
ax.set(xlabel='clipped log-odds ratio', ylabel='apical softmax');

In [None]:
# Scratch check: is column_df.columns a list, numpy array, or pandas Index?
# TODO(review): leftover exploration -- remove before sharing the notebook.
isinstance(column_df.columns, (list, np.ndarray, pd.Index))

In [None]:
# Inspect the metadata of the 'functional_coreg' annotation table through the CAVE client.
client.annotation.get_table_metadata('functional_coreg')