# plots of individual units

In [1]:
%matplotlib inline
import numpy as np
from scipy.stats import ttest_ind, ttest_1samp
import pandas as pd
from importlib import reload
from joblib import delayed, Parallel

import matplotlib.pyplot as plt
import seaborn as sns


import TreeMazeAnalyses2,Analyses.tree_maze_functions as tmf
import TreeMazeAnalyses2.Analyses.experiment_info as ei
import TreeMazeAnalyses2.Analyses.plot_functions as pf

import ipywidgets as widgets
from ipywidgets import interact, fixed, interact_manual

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=UserWarning)

In [2]:
ei = reload(ei)
info = ei.SummaryInfo()

In [3]:
seg_rates = info.get_bal_conds_seg_rates()
zrc = info.get_zone_rates_remap()
remap_measures = zrc.columns[10:]
id_vars = ['unit_id', 'subject', 'session', 'session_unit_id', 'unit_type']
test_vars = [metric for metric in remap_measures if (('corr' in metric) & (len(metric.split('-'))<=3))]
null_vars = [metric for metric in remap_measures if (('Even' in metric) & ('corr' in metric) & (len(metric.split('-'))<=3)) ]
z_vars = [metric for metric in remap_measures if (('z' in metric) & (len(metric.split('-'))>3))]

In [4]:
sw = info.select_session()

interactive(children=(Dropdown(description='subject', options=('Li', 'Ne', 'Cl', 'Al', 'Ca'), value='Li'), Dro…

In [5]:
session_info = sw.result
print(session_info)


Session Information for subject Li, session Li_T3g_052818
Number of curated units: 2
Methods listed below can be executed with get_{method}(), eg. get_spikes():
  -> track_data. Executed = True
  -> spikes. Executed = True
  -> binned_spikes. Executed = True
  -> fr. Executed = True
  -> pos_zones. Executed = True
  -> event_table. Executed = True
  -> trial_zone_rates. Executed = False
  -> bigseg_comps. Executed = True
  -> zone_rates_comps. Executed = True
  -> zone_rates_remap. Executed = True
  -> pop_zone_rates_remap. Executed = True
  -> bal_conds_seg_rates. Executed = True
  -> bal_conds_seg_boot_rates. Executed = nan
  -> zone_encoder. Executed = True
  -> zone_decoder. Executed = True

To run all analyses use run_analyses().



In [6]:
%%time
tmf = reload(tmf)
ta = tmf.TrialAnalyses(session_info)
tree_maze = tmf.TreeMazeZones()

CPU times: user 5.7 s, sys: 137 ms, total: 5.84 s
Wall time: 1.88 s


In [7]:
ta.trial_table

Unnamed: 0,t0,tD,tE,tE_1,tE_2,tR,dur,cue,dec,correct,long,goal,grw,sw,vsw
0,439,1157,1360,1360,1490,1929,921,R,L,0.0,,,0,0,0
1,1935,2638,3014,3014,3610,3807,1079,R,R,1.0,0.0,1.0,1,0,0
2,3813,4089,4438,4438,4554,5288,624,R,L,0.0,,,0,0,0
3,5294,5818,6077,6077,6609,7335,783,R,L,0.0,,,0,0,0
4,7342,7551,8050,7897,8050,8426,708,L,L,1.0,1.0,4.0,1,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
56,57296,57510,58091,57754,58091,58521,795,L,L,1.0,1.0,3.0,1,0,0
57,58528,58772,59139,58960,59139,59497,611,R,R,1.0,1.0,1.0,1,1,1
58,59504,64470,64758,64758,64758,65571,5253,L,R,0.0,,,0,1,1
59,65578,66115,66381,66381,67403,67618,803,L,L,1.0,0.0,3.0,1,0,0


## Load session specific quantities

In [8]:
zrc[zrc.session==session].sort_values(z_vars[0]) [ id_vars + [z_vars[0]]] 

NameError: name 'session' is not defined

In [9]:
zone_rates = ta.get_avg_trial_zone_rates()
fr_maps_trials = ta.get_trial_rate_maps(occ_rate_mask=True)
bal_cond_bigseg_rates = session_info.get_bal_conds_seg_rates()

In [10]:
bal_cond_bigseg_rates.head()

Unnamed: 0,CR_bo-left-m,CR_bo-left-n,CL_bo-left-m,CL_bo-left-n,Co_bo-left-m,Co_bo-left-n,Inco_bo-left-m,Inco_bo-left-n,Co_bi-left-m,Co_bi-left-n,...,CR_bo-CL_bo-right-t,CR_bo-CL_bo-right-p,Co_bo-Inco_bo-right-t,Co_bo-Inco_bo-right-p,Co_bi-Inco_bi-right-t,Co_bi-Inco_bi-right-p,Even_bo-Odd_bo-right-t,Even_bo-Odd_bo-right-p,Even_bi-Odd_bi-right-t,Even_bi-Odd_bi-right-p
0,1.557546,15.2,1.737068,13.0,1.731303,21.0,1.571547,7.99,3.178401,21.0,...,-0.074674,0.484361,-0.055014,0.989928,1.994807,0.0,-1.974652,0.0,0.589448,0.216129
1,1.544121,15.2,1.723232,13.0,1.726945,21.0,1.54845,7.99,3.044099,21.0,...,-0.627631,0.022049,-0.576331,0.585474,1.971708,0.0,-0.953063,7e-06,-1.194884,1.938127e-10


### bootstrap quantitites

In [10]:
%%time
boot_subseg_rates = ta.get_avg_seg_rates_boot(segment_type='subseg', n_boot=50)
boot_bigseg_rates = ta.get_avg_seg_rates_boot(segment_type='bigseg', n_boot=50)

CPU times: user 35.2 s, sys: 0 ns, total: 35.2 s
Wall time: 46.2 s


In [11]:
boot_bigseg_rates.head()

Unnamed: 0,boot,cond,unit,seg,m
0,0,CR_bo,0,left,3.637282
1,0,CR_bo,1,left,6.873106
2,0,CR_bo,2,left,2.197301
3,0,CR_bo,3,left,5.273711
4,0,CR_bo,4,left,5.993457


In [12]:
%%time
cond_pairs = ta.bal_cond_pairs
with Parallel(n_jobs=5) as parallel:
    boot_corrs = {}
    for cond_pair in cond_pairs:
        boot_corrs[cond_pair] = ta.zone_rate_maps_bal_conds_boot_corr(bal_cond_pair=cond_pair, parallel=parallel)

CPU times: user 16.9 s, sys: 0 ns, total: 16.9 s
Wall time: 36.5 s


## plot all trials

In [22]:
def plot_trial_track_spikes(trial_analyses, unit=0, ax=None):
    
    lw = 0.1 # line width
    la = 0.3 # line alpha
    lc = '0.5' # line color
    
    ss = 1 # scatter scale
    sc = 'r' # scatter color 
    sa = 0.1 # scatter alpha
    
    if ax is None:
        f,ax = plt.subplots()
    else:
        f = ax.figure
        
    x,y, _ = trial_analyses.get_trial_track_pos()
    invalid_samps = ta.pz_invalid_samps
    spk = trial_analyses.get_trial_neural_data(data_type='spikes')
    
    for tr in range(trial_analyses.n_trials):
        ax.plot(x[tr], y[tr], linewidth=lw, alpha=la, color=lc, zorder=-1)
        ax.scatter(x[tr], y[tr], s=spk[unit,tr]*ss, color=sc, alpha=sa, linewidth=0)

    ax.axis("square")
    ax.axis("off")
    ax.set_ylim(trial_analyses.y_edges[0], trial_analyses.y_edges[-1])
    ax.set_xlim(trial_analyses.x_edges[0], trial_analyses.x_edges[-1])
    
    return ax

def plot_trial_rate_map(trial_analyses, unit=0, ax=None):
    cmap = 'viridis'
    
    if ax is None:
        f,ax = plt.subplots()
    else:
        f = ax.figure    
    
    ax = sns.heatmap(fr_maps_trials[unit], cbar=False, square=True, cmap=cmap, ax=ax)
    ax.invert_yaxis()
    ax.axis("off")
    
    data = fr_maps_trials[unit].flatten()
    data_colors, color_array = pf.get_colors_from_data(data, cmap=cmap)
    
    ax_p = ax.get_position()
    w, h = ax_p.width, ax_p.height
    x0,y0 = ax_p.x0, ax_p.y0

    cax_p = [x0+w*0.85, y0+h*0.05, w*0.05, h*0.15]
    cax = f.add_axes(cax_p)

    pf.get_color_bar_axis(cax, color_array, color_map=cmap, label='FR')

    return ax

def plot_zone_rates(zone_rates, ax=None, min_value=0, max_value=None, label='FR', color_map='viridis', div=False):
    if ax is None:
        f,ax = plt.subplots()
    else:
        f = ax.figure
    
    tree_maze.plot_zone_activity(zone_rates,  ax=ax, min_value=min_value, max_value=max_value, color_map=color_map, label=label)

In [28]:
unit_widget = widgets.IntSlider(value=0, max=ta.n_units-1)
save_button = widgets.Button(description='Save Figure')

def plot_maps(unit=np.arange(ta.n_units)):
    f,ax = plt.subplots(1,3,figsize=(6,2),dpi=400)
    plot_trial_track_spikes(ta, unit=unit, ax=ax[0])
    plot_trial_rate_map(ta,unit=unit, ax=ax[1])
    plot_zone_rates(zone_rates.loc[unit], ax=ax[2], color_map = 'YlOrRd')
    return f

def savefig(*args):
    f = f1.result
    p = info.paths['figures']
    figname = f"all-trials-{session_info.session}-{f1.children[0].value}.png"   
    f.savefig(p/figname, dpi=500, bbox_inches='tight')
    
f1 = widgets.interactive(plot_maps, {'unit': unit_widget})
save_button.on_click(savefig)
display(f1, save_button)

interactive(children=(Dropdown(description='unit', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10), value=0), Outpu…

Button(description='Save Figure', style=ButtonStyle())

## plot segment rates bootstraps

In [163]:
dfm = boot_subseg_rates.groupby(['cond', 'unit', 'seg']).median()
dfm = dfm.reset_index()

In [164]:
def savefig(*args):
    f = f2.result
    p = info.paths['figures']
    figname = f"zone_rates-{session_info.session}-{f2.children[0].value}-{f2.children[1].value}.png"   
    f.savefig(p/figname, dpi=500, bbox_inches='tight')
    
def plot_zr_conds(unit=np.arange(ta.n_units), cond_pair=cond_pairs):
    
    conds = cond_pair.split('-')
    if cond_pair == 'CR_bo-CL_bo':
        conds = conds[::-1]
        
    f,ax = plt.subplots(1,2,figsize=(4,2),dpi=400)
   
    max_val = dfm[(dfm.cond.isin(conds)) & (dfm.unit==unit)].m.max()
    
    for ii, cond in enumerate(conds):
        zr = dfm[(dfm.cond==cond) & (dfm.unit==unit)][['m','seg']]
        zr = zr.pivot_table(columns='seg', aggfunc= lambda x: x)
        zr = zr.reset_index().drop('index', axis=1)
        
        plot_zone_rates(zr.loc[0], ax=ax[ii], max_value=max_val)
        
        ax[ii].set_title(cond.split('_')[0])
    return f

cond_pairs_widget = widgets.Dropdown(options=cond_pairs)

f2 = widgets.interactive(plot_zr_conds, {'unit': unit_widget,'cond_pair':cond_pairs_widget })
save_button.on_click(savefig)
display(f2, save_button)


interactive(children=(Dropdown(description='unit', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, …

Button(description='Save Figure', style=ButtonStyle())

In [89]:
cond_pairs[0].split('-')[::-1]

['CL_bo', 'CR_bo']

In [145]:
def savefig(*args):
    f = f3.result
    p = info.paths['figures']
    figname = f"boot_corrs-{session_info.session}-{f3.children[0].value}-{f3.children[1].value}.png"   
    f.savefig(p/figname, dpi=500, bbox_inches='tight')
    
def plot_boot_corrs(unit=np.arange(ta.n_units), cond_pair=cond_pairs,):
    
    null_pair = ta.test_null_bal_cond_pairs[cond_pair]

    test_boots = boot_corrs[cond_pair].loc[unit]
    null_boots = boot_corrs[null_pair].loc[unit]
    
    f, ax = plt.subplots(figsize=(1,1),dpi=400)
    sns.kdeplot(data=test_boots, fill=False, color='r', linewidth=2, ax=ax, label=cond_pair)
    sns.kdeplot(data=null_boots, fill=False, color='0.5', linewidth=2, ax=ax, label=null_pair)
    ax.set_xlim([-0.2,1])
    ax.set_xlabel(r"$\tau$")
   # ax.legend(handlelength = 0.25, handletextpad=0.5, bbox_to_anchor=[-0,0], loc='lower left', frameon=False, fontsize=7)
    
    return f

cond_pairs_widget = widgets.Dropdown(options=cond_pairs)

f3 = widgets.interactive(plot_boot_corrs, {'unit': unit_widget,'cond_pair':cond_pairs_widget })
save_button.on_click(savefig)
display(f3, save_button)


interactive(children=(Dropdown(description='unit', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, …

Button(description='Save Figure', style=ButtonStyle())

In [281]:
zrc[ (zrc.session_unit_id.isin([3,5,19,27])) & (zrc.session==session_info.session) ][z_vars[0]].values

array([-2.33932016, -3.90215603, -4.68906125, -4.44077086])

In [247]:

cond_pair = cond_pairs[0]
null_pair = ta.test_null_bal_cond_pairs[cond_pair]

test_boots = boot_corrs[cond_pair].loc[unit]
null_boots = boot_corrs[null_pair].loc[unit]

boot_df = pd.DataFrame( np.array ((test_boots, null_boots)).T, columns=[cond_pair, null_pair] )
boot_df = boot_df.melt(var_name='cond_pair', value_name='corr')


In [234]:
test_boots

0     0.627530
1     0.619433
2     0.676113
3     0.608637
4     0.543860
        ...   
95    0.627530
96    0.646424
97    0.570850
98    0.665317
99    0.662618
Name: 0, Length: 100, dtype: float64

In [282]:
boot_bigseg_rates

Unnamed: 0,boot,cond,unit,seg,m
0,0,CR_bo,0,left,14.873397
1,0,CR_bo,1,left,9.128395
2,0,CR_bo,2,left,18.248957
3,0,CR_bo,3,left,19.502219
4,0,CR_bo,4,left,15.039575
...,...,...,...,...,...
44995,49,Odd_bi,25,right,5.987235
44996,49,Odd_bi,26,right,5.694151
44997,49,Odd_bi,27,right,5.445853
44998,49,Odd_bi,28,right,0.852974


In [None]:
cond_pair = ta.bal_cond_pairs[0].split('-')

ax=sns.violinplot(data=df[ (df.unit==0) & (df.cond.isin(cond_pair))], x='seg', hue='cond', y='m', split=True, inner='quartile', hue_order=[cond_pair[1], cond_pair[0]], palette=['green','purple'], alpha=0.7, saturation=1)
plt.setp(ax.collections, alpha=.7)

In [154]:
a = dfm[(dfm.cond.isin( ['CL_bo', 'CR_bo'])) & (dfm.unit==0)]
a.head().m.max(

SyntaxError: unmatched ']' (3725056251.py, line 1)

In [155]:
df[(df.cond1) & (df.unit==0)]

Unnamed: 0,boot,cond,unit,seg,m
0,0,CR_bo,0,H,32.735157
1,0,CR_bo,1,H,36.220822
2,0,CR_bo,2,H,46.683986
3,0,CR_bo,3,H,0.898082
4,0,CR_bo,4,H,5.990283
...,...,...,...,...,...
389995,49,Odd_bi,15,G4,1.746129
389996,49,Odd_bi,16,G4,0.200968
389997,49,Odd_bi,17,G4,5.620847
389998,49,Odd_bi,18,G4,5.094654


In [287]:
def savefig(*args):
    f = f4.result
    p = info.paths['figures']
    figname = f"boot_bigsegs-{session_info.session}-{f4.children[0].value}-{f3.children[1].value}.png"   
    f.savefig(p/figname, dpi=500, bbox_inches='tight')
    
def plot_boot_segs(unit=np.arange(ta.n_units), cond_pair=cond_pairs):
    conds = cond_pair.split('-')
    
    f, ax = plt.subplots(figsize=(3,3),dpi=400)
    ax=sns.violinplot(data=boot_bigseg_rates[ (boot_bigseg_rates.unit==unit) & (boot_bigseg_rates.cond.isin(conds))], 
                      x='seg', hue='cond', y='m', split=True, inner='quartile', 
                      hue_order=[conds[1], conds[0]], palette=['green','purple'], 
                      alpha=0.7, saturation=1, legend=False)
    plt.setp(ax.collections, alpha=.7)
    ax.get_legend().remove()
    ax.set_ylabel('FR [spikes/s]')
    ax.set_xlabel('Segment')
    
    return f

cond_pairs_widget = widgets.Dropdown(options=cond_pairs)

f4 = widgets.interactive(plot_boot_segs, {'unit': unit_widget,'cond_pair':cond_pairs_widget })
save_button.on_click(savefig)
display(f4, save_button)

interactive(children=(Dropdown(description='unit', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, …

Button(description='Save Figure', style=ButtonStyle())

In [283]:
boot_bigseg_rates

Unnamed: 0,boot,cond,unit,seg,m
0,0,CR_bo,0,left,14.873397
1,0,CR_bo,1,left,9.128395
2,0,CR_bo,2,left,18.248957
3,0,CR_bo,3,left,19.502219
4,0,CR_bo,4,left,15.039575
...,...,...,...,...,...
44995,49,Odd_bi,25,right,5.987235
44996,49,Odd_bi,26,right,5.694151
44997,49,Odd_bi,27,right,5.445853
44998,49,Odd_bi,28,right,0.852974
