## Updated figures notebook, built from the plot functions in the library

In [1]:
%matplotlib inline

from pathlib import Path
from types import SimpleNamespace
from joblib import delayed, Parallel
import time
import traceback

import numpy as np
import pandas as pd
import scipy.stats as stats
pd.set_option('display.max_rows', 50)

import matplotlib as mpl

import matplotlib.pyplot as plt
import matplotlib.collections as mcoll
import matplotlib.path as mpath
import seaborn as sns
sns.set(style='whitegrid', palette='muted')
from matplotlib_venn import venn2, venn3

import TreeMazeAnalyses2.Utils.robust_stats as rs
import TreeMazeAnalyses2.Analyses.experiment_info as ei
import TreeMazeAnalyses2.Analyses.spatial_functions as sf
import TreeMazeAnalyses2.Analyses.open_field_functions as of
import TreeMazeAnalyses2.Analyses.plot_functions as pf
#import TreeMazeAnalyses2.Analyses.cluster_match_functions as cmf

from importlib import reload

import ipywidgets as widgets
from ipywidgets import interact, fixed, interact_manual
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

In [2]:
# Reload experiment_info so library edits are picked up without restarting the kernel,
# then build the summary-info object that the following cells query.
ei = reload(ei)
info = ei.SummaryInfo()

In [3]:
# Open-field results. `metric_scores` is the table filtered by the cells below
# (columns used there: analysis_type, score, unit_type, session_valid, unit_id, cl_name);
# `model_scores` is not used in the visible cells.
metric_scores, model_scores = info.get_of_results()
# Per-unit metadata table (subject, session, tetrode, unit type, snr, fr, ...).
unit_table = info.get_unit_table()

In [23]:
# Display the unit table; pandas truncates the rendering because display.max_rows is set to 50.
unit_table

Unnamed: 0,subject_cl_id,subject,session,task,date,subsession,tt,depth,unique_cl_name,session_cl_id,unit_type,n_matches_con,subject_cl_match_con_id,n_matches_lib,subject_cl_match_lib_id,snr,fr,isi_viol_rate,cl_id,cl_match_con_id,cl_match_lib_id,task2,match_lib_multi_task_id,match_con_multi_task_id
0,0,Li,Li_T3g_052818,T3g,52818,0,2,16.500,Li_T3g_052818-tt2_d16.5_cl11,11,mua,,,,,,2.84,0.20,0,,,T3,-1,-1
1,1,Li,Li_T3g_052818,T3g,52818,0,2,16.500,Li_T3g_052818-tt2_d16.5_cl14,14,mua,,,,,,2.69,0.05,1,,,T3,-1,-1
2,2,Li,Li_OF_052818,OF,52818,0,7,17.250,Li_OF_052818-tt7_d17.25_cl8,8,cell,0.0,43.0,0.0,39.0,5.09,13.56,1.13,2,43.0,39.0,OF,-1,-1
3,3,Li,Li_OF_052818,OF,52818,0,8,16.250,Li_OF_052818-tt8_d16.25_cl4,4,cell,1.0,48.0,1.0,44.0,10.08,2.95,0.44,3,48.0,44.0,OF,-1,-1
4,4,Li,Li_OF_052818,OF,52818,0,8,16.250,Li_OF_052818-tt8_d16.25_cl18,18,cell,1.0,50.0,2.0,52.0,6.60,33.03,0.54,4,50.0,52.0,OF,-1,-1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
4459,73,Mi,Mi_OF_021220,OF,21220,0,12,16.125,Mi_OF_021220-tt12_d16.125_cl20,20,cell,0.0,8.0,0.0,8.0,8.11,25.24,0.11,4459,964.0,858.0,OF,-1,-1
4460,74,Mi,Mi_OF_021220,OF,21220,0,12,16.125,Mi_OF_021220-tt12_d16.125_cl10,10,mua,,,,,6.75,8.17,0.51,4460,,,OF,-1,-1
4461,75,Mi,Mi_OF_021720,OF,21720,0,14,16.375,Mi_OF_021720-tt14_d16.375_cl1,1,cell,0.0,11.0,0.0,11.0,12.38,4.92,0.00,4461,967.0,861.0,OF,-1,-1
4462,76,Mi,Mi_OF_021720,OF,21720,0,14,16.375,Mi_OF_021720-tt14_d16.375_cl0,0,mua,,,,,20.18,0.42,0.00,4462,,,OF,-1,-1


In [12]:
analyses=['speed','hd','border','grid', 'stability']
@interact(analysis_type=analyses, unit_type=['cell', 'mua', None], thr=widgets.FloatSlider(min=-1, max=1, step=0.02))
def metric_filter_units(analysis_type, thr, unit_type=None):
    """Select metric-score rows for one analysis type that pass a threshold.

    A row is kept when its ``analysis_type`` matches, its ``score`` is
    ``>= thr`` and ``session_valid`` is true. When ``unit_type`` is given
    (e.g. 'cell' or 'mua') rows are further restricted to that unit type;
    ``None`` disables the unit-type filter.

    Reads the module-level ``metric_scores`` table. The ``@interact``
    decorator renders widgets for interactive use but returns this function
    unchanged, so it can also be called directly.

    Parameters
    ----------
    analysis_type : str
        One of the values in ``analyses``.
    thr : float
        Minimum score (inclusive).
    unit_type : str or None, default None
        Unit type to keep, or None for all.

    Returns
    -------
    pandas.DataFrame
        Matching rows sorted by ``score``, highest first.
    """
    # Conditions shared by both the filtered and the unfiltered-by-type cases.
    keep = ((metric_scores.analysis_type == analysis_type)
            & (metric_scores.score >= thr)
            & (metric_scores.session_valid))
    if unit_type is not None:
        keep = (metric_scores.unit_type == unit_type) & keep
    return metric_scores.loc[keep].sort_values(by=['score'], ascending=False)

interactive(children=(Dropdown(description='analysis_type', options=('speed', 'hd', 'border', 'grid', 'stabili…

In [21]:
def plot_OF(cl_name, track_data, spikes, fr, fr_map, min_speed=3, figsize=None):
    """Draw a 4-panel open-field summary figure for one unit.

    Layout (2x3 gridspec, constrained layout):
    column 0 = trajectory with spikes (``pf.plot_xy_spks``),
    column 1 = firing-rate map (``pf.plot_firing_rate_map``),
    top-right = polar angular tuning, bottom-right = speed tuning.
    ``cl_name`` is written vertically along the left edge of the figure.

    Parameters
    ----------
    cl_name : str
        Label drawn on the figure.
    track_data : mapping
        Must provide 'x', 'y', 'hd' (presumably head direction; it feeds the
        polar panel) and 'sp' (passed as ``speed=``) entries.
    spikes, fr, fr_map
        Passed through to the plotting / binning helpers; not validated here.
    min_speed : float, default 3
        NOTE(review): not used anywhere in the body.
    figsize : tuple, optional
        Figure size in inches; defaults to (10, 5).

    Returns
    -------
    (matplotlib.figure.Figure, list of Axes)

    NOTE(review): this function cannot run as written in this notebook.
    ``plot_ang_fr``, ``ang_bin_centers``, ``ang_fr``, ``mean_ang``,
    ``vec_len``, ``get_bin_sp_fr`` and ``plot_sp_fr`` are not defined in any
    visible cell, and the result of ``sf.get_binned_angle_fr`` is never
    unpacked. ``plot_OF2`` (used by the interactive cell) appears to be the
    working variant — confirm, then fix or delete this function.
    """
    
    # NOTE(review): unused.
    label_fontsize = 14
    
    if figsize is None:
        f = plt.figure(constrained_layout=True, figsize=(10 , 5))
    else:
        f = plt.figure(constrained_layout=True, figsize=figsize)
        
    gs = f.add_gridspec(2,3)
    # Placeholder list; every slot is overwritten by an Axes below.
    ax = [[]]*4
    ax[0] = f.add_subplot(gs[:, 0])
    ax[1] = f.add_subplot(gs[:, 1])
    ax[2] = f.add_subplot(gs[0, 2], projection='polar')
    ax[3] = f.add_subplot(gs[1, 2])

    x,y = track_data['x'],track_data['y']
    theta, sp = track_data['hd'], track_data['sp']
    
    ax[0] = pf.plot_xy_spks(x,y,spikes,ax=ax[0])
    ax[1], cax = pf.plot_firing_rate_map(fr_map, ax=ax[1])
     
    # NOTE(review): `res` is discarded — the angular-tuning names used on the
    # next statement are never assigned from it (plot_OF2 unpacks the
    # equivalent dict via keys 'ang_bin_centers', 'ang_fr', 'mean_ang', 'vec_len').
    res = sf.get_binned_angle_fr(theta=theta, fr=fr, speed=sp)

    # NOTE(review): raises NameError — plot_ang_fr and its arguments are undefined here.
    ax[2], cax = plot_ang_fr(ang_bin_centers, ang_fr, mean_ang, vec_len,  ax[2])
    
    # NOTE(review): `get_bin_sp_fr` looks like a typo of `get_binned_sp_fr`
    # (the name plot_OF2 calls); neither is defined in a visible cell.
    res = get_bin_sp_fr(track_data, fr)
    sp_bin_centers, sp_fr_m, sp_fr_s = res['sp_bin_centers'], res['sp_fr_m'], res['sp_fr_s']
    ax[3] = plot_sp_fr(sp_bin_centers, sp_fr_m, sp_fr_s, ax[3])

    # Shrink and nudge the two right-hand panels.
    # NOTE(review): with constrained_layout=True, confirm these manual
    # positions survive the final layout pass.
    ap = ax[2].get_position()
    pos = [ap.x0+0.03, ap.y0, ap.width*0.8, ap.height*0.75]
    ax[2].set_position(pos)
    
    ap = ax[3].get_position()
    pos = [ap.x0+0.1, ap.y0+0.1, ap.width*0.55, ap.height*0.65]
    ax[3].set_position(pos)
    
    # Thin, hidden axes along the left edge that only holds the rotated unit label.
    axt = f.add_axes([0,0,0.02,1])
    axt.text(-0.2,.25, cl_name, rotation=90)
    axt.set_axis_off()
    return f, ax

def plot_OF2(cl_name, track_data, spikes, fr, fr_map):
    """Draw a compact 3-panel open-field summary figure for one unit.

    Layout (2x2 gridspec, 8x6 inches): the left column holds the trajectory
    with spikes, top-right the polar angular tuning, bottom-right the speed
    tuning. ``cl_name`` is written vertically just left of the trajectory.

    Parameters
    ----------
    cl_name : str
        Label drawn beside the trajectory panel.
    track_data : mapping
        Must provide 'x' and 'y'; also passed whole to the angle / speed
        binning helpers.
    spikes
        Spike data for the trajectory panel.
    fr
        Firing rate used for the tuning curves.
    fr_map
        Not used by this function (kept, presumably, so the 4-tuple returned
        by ``load_cell2`` can be splatted in directly).

    Returns
    -------
    (matplotlib.figure.Figure, list)
        The figure and a 4-slot list: slots 0-2 hold the panel axes, slot 3
        is an unused placeholder.

    NOTE(review): ``plot_xy_spks``, ``get_binned_ang_fr``, ``plot_ang_fr``,
    ``get_binned_sp_fr`` and ``plot_sp_fr`` are not defined in any visible
    cell (likely leftover kernel state from a deleted cell). Define or import
    them so the notebook survives Restart & Run All.
    """
    f = plt.figure(figsize=(8, 6))
    gs = f.add_gridspec(2, 2)
    # Slot 3 is never filled; kept so the returned list keeps its length.
    ax = [[]] * 4
    ax[0] = f.add_subplot(gs[:, 0])
    # Explicit `position` boxes override the grid-cell placement of the two
    # right-hand panels (they appear intended to make them smaller).
    ax[1] = f.add_subplot(gs[0, 1], projection='polar', position=[0.55, 0.55, 0.2, 0.2])
    ax[2] = f.add_subplot(gs[1, 1], position=[0.58, 0.28, 0.15, 0.2])

    x, y = track_data['x'], track_data['y']
    ax[0] = plot_xy_spks(x, y, spikes, ax[0])

    # Angular tuning on the polar panel.
    res = get_binned_ang_fr(track_data, fr)
    vec_len, mean_ang = res['vec_len'], res['mean_ang']
    ang_fr, ang_bin_centers = res['ang_fr'], res['ang_bin_centers']
    ax[1], _ = plot_ang_fr(ang_bin_centers, ang_fr, mean_ang, vec_len, ax[1])

    # Speed tuning (per-bin mean and spread).
    res = get_binned_sp_fr(track_data, fr)
    sp_bin_centers, sp_fr_m, sp_fr_s = res['sp_bin_centers'], res['sp_fr_m'], res['sp_fr_s']
    ax[2] = plot_sp_fr(sp_bin_centers, sp_fr_m, sp_fr_s, ax[2])

    # Thin, hidden axes just left of the trajectory panel holding only the rotated label.
    pos = ax[0].get_position()
    axt = f.add_axes([pos.x0 - 0.03, pos.y0, 0.02, pos.height])
    axt.text(0, .1, cl_name, rotation=90)
    axt.set_axis_off()

    return f, ax

def load_cell(idx, table):
    """Load tracking and activity data for the unit in row ``idx`` of ``table``.

    Parameters
    ----------
    idx : hashable
        Index label of the unit's row in ``table``.
    table : pandas.DataFrame
        Must contain 'subject', 'session' and 'session_unit_id' columns.

    Returns
    -------
    tuple
        ``(track_data, spikes, fr, fr_map)`` where ``track_data`` is the
        session's track data and the other three are the entries for
        ``session_unit_id`` from ``ei.SubjectSessionInfo(subject, session)``.
    """
    row = table.loc[idx]
    subject, session, session_unit_id = row['subject'], row['session'], row['session_unit_id']
    # Qualified with `ei.` (as in load_cell2): the bare name `SubjectSessionInfo`
    # is not imported by this notebook, so the unqualified call raised NameError.
    session_info = ei.SubjectSessionInfo(subject, session)

    track_data = session_info.get_track_data()
    spikes = session_info.get_binned_spikes()[session_unit_id]
    fr = session_info.get_fr()[session_unit_id]
    fr_map = session_info.get_fr_maps()[session_unit_id]

    return track_data, spikes, fr, fr_map

def load_cell2(uuid, table):
    """Load tracking and activity data for the unit whose ``unit_id`` is ``uuid``.

    Like ``load_cell`` but the row is found by its ``unit_id`` column rather
    than by index label. When several rows share the id, the first one (in
    table order) is used.

    Parameters
    ----------
    uuid
        Value to match against ``table.unit_id``.
    table : pandas.DataFrame
        Must contain 'unit_id', 'subject', 'session' and 'session_unit_id'
        columns.

    Returns
    -------
    tuple
        ``(track_data, spikes, fr, fr_map)`` for that unit, taken from
        ``ei.SubjectSessionInfo(subject, session)``.

    Raises
    ------
    IndexError
        If no row of ``table`` has ``unit_id == uuid``.
    """
    hits = table[table.unit_id == uuid]
    if hits.empty:
        # Same exception type the bare `.index[0]` lookup raised, but with a useful message.
        raise IndexError(f"no row with unit_id == {uuid!r} in table")
    # Positional access on the filtered frame avoids ambiguity if index labels repeat.
    row = hits.iloc[0]
    subject, session, session_unit_id = row['subject'], row['session'], row['session_unit_id']
    session_info = ei.SubjectSessionInfo(subject, session)

    track_data = session_info.get_track_data()
    spikes = session_info.get_binned_spikes()[session_unit_id]
    fr = session_info.get_fr()[session_unit_id]
    fr_map = session_info.get_fr_maps()[session_unit_id]

    return track_data, spikes, fr, fr_map

In [22]:
# Candidate units: grid-score rows at or above the threshold, any unit type.
analysis_type = 'grid'
thr = 0.2
sub_metric_table = metric_filter_units(analysis_type=analysis_type, thr=thr)

@interact(idx=sub_metric_table.index)
def plot_OF_unit(idx):
    """Plot the open-field summary for row ``idx`` of ``metric_scores``.

    Prints the unit's id and cluster name, draws the figure with
    ``plot_OF2`` and returns every ``metric_scores`` row that shares the
    same ``cl_name``.
    """
    uuid, cl_name = (metric_scores.loc[idx, column] for column in ('unit_id', 'cl_name'))
    print(f"UUID = {uuid}")
    print(f"cl_name = {cl_name}")

    unit_data = load_cell2(uuid, metric_scores)
    _fig, _axes = plot_OF2(cl_name, *unit_data)

    same_cluster = metric_scores.cl_name == cl_name
    return metric_scores[same_cluster]

interactive(children=(Dropdown(description='idx', options=(2879, 4476, 2788, 10839, 9473, 3139, 3010, 9690, 37…