In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
from os.path import join as pjoin

import numpy as np
import pandas as pd
import xarray as xr
import plotly.express as px
import plotly.graph_objects as go
# import pingouin as pg
from scipy.stats import pearsonr, spearmanr, zscore, skew
from natsort import natsorted
from tqdm.notebook import tqdm

sys.path.append('/home/austinbaggetta/csstorage3/CircleTrack/CircleTrackAnalysis')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

In [None]:
## Settings
project_folder = ['MultiCon_Imaging']
## Experiments to analyze. The full set is ['MultiCon_Imaging5', 'MultiCon_Imaging6'].
experiment_folders = ['MultiCon_Imaging5']
dpath = f'../../{project_folder[0]}'
fig_path = '../../../Manuscripts/MultiCon/intermediate_plots/spatial_info'
chance_color = '#7d7d7d'
avg_color = '#287347'
subject_color = '#7d7d7d'
ce_colors = ['#7d7d7d', '#287347']
ce_colors_dict = {'Control': '#B8293D', 'Experimental': '#287347', 'Two-Context': 'midnightblue', 'Multi-Context': '#287347'}
symbol_dict = {'Control': 'x', 'Experimental': 'circle'}
symbols_list = ['x', 'circle']
context_colors = {'A': '#00802d', 'B': '#006c79', 'C': '#004da4', 'D': '#430073'}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']
male_mice = ['mc44', 'mc46', 'mc54', 'mc55']
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc59', 'mc60']
imaging5 = ['mc44', 'mc46', 'mc48', 'mc49', 'mc51', 'mc52']
## A1..A5, B1..B5, C1..C5, D1..D5
session_list = [f'{context}{x}' for context in 'ABCD' for x in range(1, 6)]
## A1..A15, B1..B5
control_list = [f'A{x}' for x in range(1, 16)] + [f'B{x}' for x in range(1, 6)]
day_list = [f'Day {x}' for x in range(1, 21)]
bin_size = 0.16 ## size of linear position bins equivalent to 2cm-wide bins
velocity_thresh = 10
centroid_distance = 10
data_of_interest = 'aligned_place_cells' ## one of behav, aligned_minian, aligned_place_cells, lin_behav
z_thresh = 2.325 ## z-score threshold (close to the one-tailed p = 0.01 critical value)
## Passed as `binarized=` to pc.spatial_activity in the cells below.
## NOTE(review): this name was never assigned in the notebook before (it only existed as
## leftover kernel state), so the notebook failed on a fresh kernel. Confirm the intended value.
binarize = False

os.makedirs(fig_path, exist_ok=True)

xr.set_options(keep_attrs=True)

### Testing the spatial-metric shuffles on a single example session.

In [None]:
## Single example session: compute the observed spatial metrics and their shuffle distributions.
data_type = 'S'
mouse = 'mc51'
session = '20'
correct_dir = True
only_running = True
minimum_firing_amount = 0.2
nshuffles = 10
data_path = f'../../{project_folder[0]}/{experiment_folders[0]}/output/aligned_place_cells/{mouse}/{data_type}/{mouse}_{data_type}_{session}.nc'
yra_path = f'../../{project_folder[0]}/{experiment_folders[0]}/output/aligned_minian/{mouse}/YrA/{mouse}_YrA_{session}.nc'

sdata = xr.open_dataset(data_path)[data_type]
yra = xr.open_dataset(yra_path)['YrA']
num_neurons = sdata.shape[0]

## Running epochs from position converted to cm; `running` is reused by the trace plots further down.
x_cm, y_cm = ctb.convert_to_cm(x=sdata['x'], y=sdata['y'])
velocity, running = pc.define_running_epochs(x_cm, y_cm, sdata['behav_t'], velocity_thresh=velocity_thresh)
neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running,
                                                                velocity_thresh=velocity_thresh)

## Observed Skaggs information content, spatial coherence and within-session stability
population_activity, occupancy, _ = pc.spatial_activity(neural_data, position_data, bin_size=bin_size, binarized=binarize)
tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
avg_of_avg, spatial_coherence_values = pc.calculate_spatial_coherence(tuning_curves, ksize=8)
bits_per_event = pc.skaggs_information_content(tuning_curves, occupancy)
first_second = pc.first_second_half_stability(neural_data, bin_size=bin_size)
odd_even = pc.odd_even_stability(neural_data, bin_size=bin_size)

## Shuffle distributions of the same metrics (`nshuffles` shuffles each)
shuffled_si, shuffled_sc = pc.shuffle_spatial_metrics(neural_data, lin_pos_col='lin_position', bin_size=bin_size, nshuffles=nshuffles)
shuffled_first_second, shuffled_odd_even = pc.shuffle_stability_metrics(neural_data, bin_size=bin_size, nshuffles=nshuffles)

In [None]:
## Sequential palette used to colour per-trial tuning curves; check how a trial number wraps onto it.
## (`px` is imported in the import cell at the top.)
color = px.colors.sequential.Blues
10 % len(color)

In [None]:
## Per-trial tuning curves of one neuron over forward trials, coloured by trial number.
## Trial-level results use their own names so this cell does not overwrite the session-level
## `neural_data`, `position_data`, `population_activity`, `occupancy` and `tuning_curves`
## computed above, which later cells still plot.
neuron = 0
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate', titles=['Example Neuron'], width=500, height=500)
color = px.colors.sequential.Blues

forward, _ = ctb.get_forward_reverse_trials(sdata)
## The last forward trial is excluded, as before. NOTE(review): presumably an incomplete lap; confirm.
for trial in forward[:-1]:
    idx = int(trial % len(color))  ## wrap the trial number onto the palette
    trial_mask = sdata['trials'] == trial
    trial_activity = sdata[:, trial_mask].values
    trial_position = sdata['lin_position'][trial_mask].values
    trial_population, trial_occupancy, _ = pc.spatial_activity(trial_activity, trial_position, bin_size=bin_size, binarized=binarize)
    trial_tuning = pc.get_tuning_curves(trial_population, trial_occupancy)
    fig.add_trace(go.Scatter(x=np.arange(1, trial_tuning.shape[0]+1), y=trial_tuning[:, neuron], mode='lines',
                             line_color=color[idx], opacity=0.8, showlegend=False))
fig.show()
fig.write_image(pjoin(fig_path, f'stability_example_neuron_{neuron}.png'))

In [None]:
## Heatmap of one neuron's tuning curve on each forward trial (rows) across spatial bins (columns).
## Rows are filled by position in the trial list, not by trial id, so non-contiguous trial ids
## cannot overflow or leave empty rows. The number of bins comes from the tuning curves rather
## than being hard-coded.
## Trial-level results use their own names so the session-level `tuning_curves` etc. are kept.
neuron = 0
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate', titles=[f'Neuron {neuron+1}'], width=800, height=800)

forward, _ = ctb.get_forward_reverse_trials(sdata)
trials_to_plot = forward[:-1]  ## the last forward trial is excluded, as in the per-trial plot above
trial_rows = []
for trial in trials_to_plot:
    trial_mask = sdata['trials'] == trial
    trial_activity = sdata[:, trial_mask].values
    trial_position = sdata['lin_position'][trial_mask].values
    trial_population, trial_occupancy, _ = pc.spatial_activity(trial_activity, trial_position, bin_size=bin_size, binarized=binarize)
    trial_rows.append(pc.get_tuning_curves(trial_population, trial_occupancy)[:, neuron])
trial_ar = np.vstack(trial_rows)  ## shape: (number of plotted trials, number of spatial bins)
fig.add_trace(go.Heatmap(x=np.arange(1, trial_ar.shape[1]+1), y=np.arange(0, trial_ar.shape[0]), z=trial_ar, colorscale='viridis'))
fig.show()

In [None]:
## Scratch check: value of the `trial` loop variable left over from the per-trial loops above
## (depends on kernel state).
trial

In [None]:
## Session tuning curve of a single neuron, saved as an example non-place cell.
neuron = 7 # 7 is a non-place cell
## x is derived from the plotted array itself, so x and y always have the same length.
xaxis = np.arange(1, tuning_curves.shape[0]+1)
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Event Rate', titles=['Non-Place Cell'])
fig.add_trace(go.Scatter(x=xaxis, y=tuning_curves[:, neuron]))
fig.update_yaxes(range=[0, 0.035])
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_{session}_{neuron}_tc.png'))

In [None]:
## Positional indices of neurons with more than 850 frames of positive activity.
num_events = (sdata > 0).sum(axis=1)
np.where(num_events > 850)[0]

In [None]:
## Look up the unit_id label of the neuron at positional index 131.
sdata['unit_id'].values[131]

In [None]:
## Overlay one neuron's YrA trace with its S activity (scaled up so both are visible together).
## NOTE(review): the same positional index is used for yra and sdata, which assumes both files
## list units in the same order. Confirm.
scaling_factor = 10
neuron = 131
raw_trace = yra[neuron, :]
scaled_events = sdata[neuron, :] * scaling_factor
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Activity (a.u.)', width=1000)
fig.add_trace(go.Scatter(x=yra['behav_t'], y=raw_trace, mode='lines', line_color='darkgrey', opacity=0.8))
fig.add_trace(go.Scatter(x=sdata['behav_t'], y=scaled_events, mode='lines', line_color='red', opacity=0.5))
fig.show()

In [None]:
## Same overlay as above, restricted to running frames.
## The timestamps are masked with `running` as well, so each sample keeps its own time;
## previously only y was masked, which shifted the traces along the time axis.
## NOTE(review): `running` was computed from sdata's behaviour, so applying it to yra
## assumes both share the same frames. Confirm.
scaling_factor = 10
neuron = 6
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Activity (a.u.)', width=1000)
fig.add_trace(go.Scatter(x=yra['behav_t'][running], y=yra[neuron, :][running], mode='lines', line_color='darkgrey', opacity=0.8))
fig.add_trace(go.Scatter(x=sdata['behav_t'][running], y=sdata[neuron, :][running] * scaling_factor, mode='lines', line_color='red', opacity=0.5))
fig.show()

In [None]:
## Compare one neuron's tuning curve with the `avg_of_avg` curve from pc.calculate_spatial_coherence,
## and print their Pearson correlation.
neuron = 64
spatial_bins = np.arange(1, tuning_curves.shape[0] + 1)
curve_specs = [
    (tuning_curves[:, neuron], 'midnightblue', 'Tuning Curve'),
    (avg_of_avg[:, neuron], 'darkgrey', 'Coherence Curve'),
]
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate', width=600)
for curve, colour, label in curve_specs:
    fig.add_trace(go.Scatter(x=spatial_bins, y=curve, line_color=colour, name=label))
fig.show()
print(pearsonr(tuning_curves[:, neuron], avg_of_avg[:, neuron]))
fig.write_image(pjoin(fig_path, 'tuning_coherence_curves.png'))

In [None]:
## One point per cell: stored spatial information against stored spatial coherence.
si_coherence = go.Scatter(x=sdata['skaggs_info'], y=sdata['coherence'], mode='markers', marker_color='darkgrey')
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Spatial Coherence')
fig.add_trace(si_coherence)
fig.show()
# fig.write_image(pjoin(fig_path, 'coherence_info_scatter.png'))

In [None]:
## Probability histogram of spatial information over all cells of the example session.
si_hist = go.Histogram(x=sdata['skaggs_info'], marker_color='darkgrey', histnorm='probability', showlegend=False)
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Probability', titles=['All Cells'])
fig.add_trace(si_hist)
fig.show()
fig.write_image(pjoin(fig_path, 'all_cells_si.png'))

In [None]:
## Spatial-information histograms for all cells, for cells passing `skaggs_stability_z`, and for the rest.
## NOTE(review): `skaggs_stability_z` is not assigned anywhere in this notebook, so this cell only
## runs with leftover kernel state. It presumably needs a per-cell boolean mask (SI + stability
## criterion) defined before running. Confirm.
fig = pf.custom_graph_template(x_title='Spatial Information', y_title='', rows=1, columns=3, width=1000,
                               titles=['All Cells', 'Place Cells', 'Non-Place Cells'], shared_x=True, shared_y=True)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=1)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'][skaggs_stability_z], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=2)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'][~skaggs_stability_z], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=3)
fig.update_yaxes(title='Probability', col=1)
fig.show()

In [None]:
## Scratch: positions (within the cells not selected by `skaggs_z`) whose spatial information exceeds 5.
## NOTE(review): `skaggs_z` is not assigned anywhere in this notebook (leftover kernel state).
np.where(sdata['skaggs_info'][~skaggs_z] > 5)

In [None]:
## Scratch: spatial information of the third cell not selected by `skaggs_z`.
## NOTE(review): `skaggs_z` is not assigned anywhere in this notebook (leftover kernel state).
sdata['skaggs_info'][~skaggs_z][2]

In [None]:
## Spatial-information histograms for all cells, SI-defined place cells and the remaining cells.
## Place cells come from the stored `skaggs_place` flag (the same flag the session loops below use)
## instead of `skaggs_z`, which is never assigned in this notebook.
## NOTE(review): confirm `skaggs_place` is the SI-only criterion that `skaggs_z` used to hold.
place_mask = sdata['skaggs_place'].values.astype(bool)
spatial_info_values = sdata['skaggs_info']
panels = [spatial_info_values, spatial_info_values[place_mask], spatial_info_values[~place_mask]]

fig = pf.custom_graph_template(x_title='Spatial Information', y_title='', rows=1, columns=3, width=1000,
                               titles=['All Cells', 'Place Cells', 'Non-Place Cells'], shared_x=True, shared_y=True)
for panel_col, values in enumerate(panels, start=1):
    fig.add_trace(go.Histogram(x=values, marker_color='darkgrey',
                               histnorm='probability', showlegend=False), row=1, col=panel_col)
fig.update_yaxes(title='Probability', col=1)
fig.show()
# fig.write_image(pjoin(fig_path, 'all_place_nonplace_dist_only_si.png'))

In [None]:
## Shuffle distribution of one neuron's spatial information, with the observed value drawn in red.
neuron = 14
null_si = shuffled_si[:, neuron]
observed_si = bits_per_event[neuron]
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Probability')
fig.add_trace(go.Histogram(x=null_si, marker_color='darkgrey', histnorm='probability'))
fig.add_vline(x=observed_si, line_width=1, line_color='red', opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'shuffled_si_{neuron}.png'))

In [None]:
## Inspect the current `sdata` DataArray and its per-cell coordinates
## (e.g. skaggs_info, coherence, odd_even, first_second, which the cells around it plot).
sdata

In [None]:
## One point per cell: odd/even-trial stability against first/second-half stability.
stability_points = dict(x=sdata['odd_even'], y=sdata['first_second'])
fig = pf.custom_graph_template(x_title='Odd vs Even Trials Stability', y_title='First Half vs Second Half Stability')
fig.add_trace(go.Scatter(**stability_points, mode='markers', marker_color='darkgrey'))
fig.show()
fig.write_image(pjoin(fig_path, 'first_odd_stability_scatter.png'))

In [None]:
## Example tuning curve for a neuron
## `neuron` is a positional index (it selects a column of `tuning_curves`), so the spatial
## information printed alongside is read by position too. The previous `sel(unit_id=neuron)`
## did a label lookup, which picks a different cell whenever unit_id labels differ from positions.
## NOTE(review): assumes `tuning_curves` columns follow sdata's unit order (neural_data is a frame subset of sdata).
neuron = 14
bins = np.arange(1, tuning_curves.shape[0]+1)
print(sdata['skaggs_info'].values[neuron])
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate')
fig.add_trace(go.Scatter(x=bins, y=tuning_curves[:, neuron], mode='lines', line_color=avg_color))
fig.show()
# fig.write_image(pjoin(fig_path, f'example_place_cell_{neuron}.png'))

In [None]:
## One point per cell: the `*_fisherz` versions of the two stability measures.
fisherz_scatter = go.Scatter(x=sdata['odd_even_fisherz'], y=sdata['first_second_fisherz'], mode='markers', marker_color='darkgrey')
fig = pf.custom_graph_template(x_title='Odd-Even Stability', y_title='First-Second Stability')
fig.add_trace(fisherz_scatter)
fig.show()

In [None]:
## Spatial information against spatial coherence, one point per cell (short axis labels).
x_vals, y_vals = sdata['skaggs_info'], sdata['coherence']
fig = pf.custom_graph_template(x_title='Spatial Information', y_title='Spatial Coherence')
fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='markers', marker_color='darkgrey'))
fig.show()

In [None]:
## First/second-half stability against spatial information, one point per cell.
x_vals, y_vals = sdata['first_second'], sdata['skaggs_info']
fig = pf.custom_graph_template(x_title='First-Second Stability', y_title='Spatial Information')
fig.add_trace(go.Scatter(x=x_vals, y=y_vals, mode='markers', marker_color='darkgrey'))
fig.show()

In [None]:
## Histogram of the per-cell mean of the two Fisher-z stability columns.
fisherz_pair = (sdata['first_second_fisherz'], sdata['odd_even_fisherz'])
av = (fisherz_pair[0] + fisherz_pair[1]) / 2
av_hist = go.Histogram(x=av, marker_color='darkgrey')
fig = pf.custom_graph_template(x_title='', y_title='')
fig.add_trace(av_hist)
fig.show()

### Mean spatial information for each mouse and session, over all cells and over place cells.

In [None]:
## Mean spatial information per mouse and day, over all cells and over place cells only.
data_type = 'S'
## Per-mouse enumeration-index thresholds carried over from the original if/elif chain. Sessions
## beyond the threshold get their day shifted by one (presumably a missing recording).
missing_session_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'spatial_info_all': [], 'spatial_info_place': []}

experiments = natsorted(exp for exp in os.listdir(dpath) if exp in experiment_folders)
for experiment in experiments:
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(natsorted(os.listdir(exp_path))):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        ## os.listdir order is arbitrary, but the enumeration index is used as the day number,
        ## so session files are visited in natural (numeric) order.
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            if '21' in session:
                continue
            if index > missing_session_after.get(mouse, np.inf):
                index += 1
            ## The file is closed after the per-session summaries are extracted.
            with xr.open_dataset(pjoin(mpath, session)) as ds:
                session_data = ds[data_type]
                session_label = session_data.attrs['session_two']
                ## float(): store plain numbers rather than 0-d arrays so the columns are numeric.
                si_all = float(session_data['skaggs_info'].mean())
                si_place = float(session_data['skaggs_info'][session_data['skaggs_place']].mean())
            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['sex'].append(sex)
            cell_dict['session'].append(session_label)
            cell_dict['day'].append(index+1)
            cell_dict['spatial_info_all'].append(si_all)
            cell_dict['spatial_info_place'].append(si_place)
spatial_info = pd.DataFrame(cell_dict)

In [None]:
## Rows of the spatial-information table for mouse mc52.
spatial_info.query("mouse == 'mc52'")

In [None]:
## Group mean ± SEM (over rows, one per mouse and session) of spatial information per day,
## for all cells (left panel) and place cells (right panel).
group_labels = {'Control': 'Two-Context', 'Experimental': 'Multi-Context'}
agg_spec = {'spatial_info_all': ['mean', 'sem'], 'spatial_info_place': ['mean', 'sem']}
avg = spatial_info.groupby(['group', 'day'], as_index=False).agg(agg_spec)
avg['group'] = avg['group'].replace(group_labels)

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=2, shared_x=True, shared_y=True,
                               titles=['All Cells', 'Place Cells'], width=800)

## (column, subplot column, extra trace options); only the first panel contributes legend entries.
panels = [('spatial_info_all', 1, {}), ('spatial_info_place', 2, {'showlegend': False})]
for group in ['Two-Context', 'Multi-Context']:
    gdata = avg[avg['group'] == group]
    for column, panel_col, extra in panels:
        fig.add_trace(go.Scatter(x=gdata['day'], y=gdata[column]['mean'], mode='lines+markers', name=group, legendgroup=group, **extra,
                                 line_color=ce_colors_dict[group], error_y=dict(type='data', array=gdata[column]['sem'])), row=1, col=panel_col)
fig.update_yaxes(title='Spatial Information (bits/event)', col=1)
fig.show()
# fig.write_image(pjoin(fig_path, 'all_cells_place_cells_si.png'))

### Within-session stability for each session (cell subset chosen by `place_cells_only`)

In [None]:
## Per-cell within-session stability (odd/even trials, first/second half, and their mean) for the
## cell subset selected by `place_cells_only`. Results are averaged per mouse, then per group and day.
data_type = 'S'
place_cells_only = 'non_place_cells'  ## one of 'place_cells', 'non_place_cells', 'all_cells'
if place_cells_only not in ('place_cells', 'non_place_cells', 'all_cells'):
    raise ValueError(f"place_cells_only must be 'place_cells', 'non_place_cells' or 'all_cells', got {place_cells_only!r}")
## Per-mouse enumeration-index thresholds carried over from the original if/elif chain. Sessions
## beyond the threshold get their day shifted by one (presumably a missing recording).
missing_session_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'unit_id': [], 'odd_even': [], 'first_second': [], 'stability': []}

experiments = natsorted(exp for exp in os.listdir(dpath) if exp in experiment_folders)
for experiment in experiments:
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(natsorted(os.listdir(exp_path))):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Two-Context' if mouse in control_mice else 'Multi-Context'
        ## os.listdir order is arbitrary, but the enumeration index is used as the day number,
        ## so session files are visited in natural (numeric) order.
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            if '21' in session:
                continue
            if index > missing_session_after.get(mouse, np.inf):
                index += 1
            ## The file is closed once the per-cell values have been read.
            with xr.open_dataset(pjoin(mpath, session)) as ds:
                session_data = ds[data_type]
                session_label = session_data.attrs['session_two']
                place_mask = session_data['skaggs_place'].values.astype(bool)
                if place_cells_only == 'place_cells':
                    session_data = session_data[place_mask, :]
                elif place_cells_only == 'non_place_cells':
                    session_data = session_data[~place_mask, :]
                odd_even = session_data['odd_even'].values
                first_second = session_data['first_second'].values
                unit_ids = session_data['unit_id'].values
            stability = (odd_even + first_second) / 2
            n_cells = len(unit_ids)
            cell_dict['mouse'].extend([mouse] * n_cells)
            cell_dict['group'].extend([group] * n_cells)
            cell_dict['sex'].extend([sex] * n_cells)
            cell_dict['session'].extend([session_label] * n_cells)
            cell_dict['day'].extend([index + 1] * n_cells)
            cell_dict['unit_id'].extend(unit_ids)
            cell_dict['odd_even'].extend(odd_even)
            cell_dict['first_second'].extend(first_second)
            cell_dict['stability'].extend(stability)
stability_df = pd.DataFrame(cell_dict)
avg_mouse = stability_df.groupby(['group', 'mouse', 'day'], as_index=False).agg({'odd_even': 'mean', 'first_second': 'mean', 'stability': 'mean'})
avg_stab = avg_mouse.groupby(['group', 'day'], as_index=False).agg({'odd_even': ['mean', 'sem'], 'first_second': ['mean', 'sem'], 'stability': ['mean', 'sem']})

In [None]:
## Odd vs even stability: group mean ± SEM (across per-mouse means) for each day.
fig = pf.custom_graph_template(x_title='Day', y_title='Odd-Even Trial Stability')
for group in ['Two-Context', 'Multi-Context']:
    gdata = avg_stab[avg_stab['group'] == group]
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['odd_even']['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=gdata['odd_even']['sem'])))
## Dashed lines mark the boundaries between the 5-day blocks.
for pos in [5.5, 10.5, 15.5]:
    fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 1])
fig.show()
## The file name follows the analysed cell subset (unchanged for the default 'non_place_cells').
fig.write_image(pjoin(fig_path, f'odd_even_trial_stability_across_days_{place_cells_only}.png'))

In [None]:
## First half-second half trial stability: group mean ± SEM (across per-mouse means) for each day.
fig = pf.custom_graph_template(x_title='Day', y_title='First-Second Half Trial Stability')
for group in ['Two-Context', 'Multi-Context']:
    gdata = avg_stab[avg_stab['group'] == group]
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['first_second']['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=gdata['first_second']['sem'])))
## Dashed lines mark the boundaries between the 5-day blocks.
for pos in [5.5, 10.5, 15.5]:
    fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 1])
fig.show()
## The file name follows the analysed cell subset (unchanged for the default 'non_place_cells').
fig.write_image(pjoin(fig_path, f'first_second_stability_across_days_{place_cells_only}.png'))

In [None]:
## Stability metric (mean of the two stability measures): group mean ± SEM for each day.
fig = pf.custom_graph_template(x_title='Day', y_title='Stability')
for group in ['Two-Context', 'Multi-Context']:
    gdata = avg_stab[avg_stab['group'] == group]
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['stability']['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=gdata['stability']['sem'])))
## Dashed lines mark the boundaries between the 5-day blocks.
for pos in [5.5, 10.5, 15.5]:
    fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 1])
fig.show()
## The file name follows the analysed cell subset (unchanged for the default 'non_place_cells').
fig.write_image(pjoin(fig_path, f'stability_metric_across_days_{place_cells_only}.png'))

### Plot percentage of place cells across days for both groups.

In [None]:
## Percentage of cells flagged as place cells (`skaggs_place`) and lick accuracy, per mouse and day.
data_type = 'S'
## Per-mouse enumeration-index thresholds carried over from the original if/elif chain. Sessions
## beyond the threshold get their day shifted by one (presumably a missing recording).
missing_session_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}
## NOTE(review): per-criterion columns such as 'place_coherence_only', 'place_coherence_skagg',
## 'place_stability_only' and 'place_skaggs_stability' are not filled in here (their masks are not
## computed in this notebook), but the multi-criterion plots further down still read them.
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'place_skaggs_only': [], 'lick_acc': []}

experiments = natsorted(exp for exp in os.listdir(dpath) if exp in experiment_folders)
for experiment in experiments:
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(natsorted(os.listdir(exp_path))):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        ## os.listdir order is arbitrary, but the enumeration index is used as the day number,
        ## so session files are visited in natural (numeric) order.
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            if '21' in session:
                continue
            if index > missing_session_after.get(mouse, np.inf):
                index += 1
            ## The file is closed once the per-session values have been computed.
            with xr.open_dataset(pjoin(mpath, session)) as ds:
                session_data = ds[data_type]
                lick_acc = ctb.lick_accuracy(session_data, port_list=[session_data.attrs['reward_one'], session_data.attrs['reward_two']], lick_threshold=5, by_trials=False)
                ## Mean of the per-cell flag times 100 = percentage of place cells.
                percent_place = np.mean(session_data['skaggs_place'].values) * 100
                session_label = session_data.attrs['session_two']
            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['sex'].append(sex)
            cell_dict['session'].append(session_label)
            cell_dict['day'].append(index+1)
            cell_dict['place_skaggs_only'].append(percent_place)
            cell_dict['lick_acc'].append(lick_acc)
place_df = pd.DataFrame(cell_dict)

In [None]:
## Lick accuracy divided by the percentage of place cells in the same session.
## Sessions with 0 % place cells would give inf, which would swamp the group mean/SEM below,
## so they are treated as undefined (NaN) instead.
place_df['normalized_acc'] = place_df['lick_acc'] / place_df['place_skaggs_only'].replace(0, np.nan)
place_df

In [None]:
## Group mean ± SEM of the normalised accuracy metric for each day.
avg = place_df.groupby(['group', 'day'], as_index=False).agg({'normalized_acc': ['mean', 'sem']})
fig = pf.custom_graph_template(x_title='Day', y_title='Acc Metric')
for group in ['Control', 'Experimental']:
    gdata = avg[avg['group'] == group]
    metric = gdata['normalized_acc']
    fig.add_trace(go.Scatter(x=gdata['day'], y=metric['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=metric['sem'])))
fig.show()

In [None]:
## Plot for percentage of place cells based on spatial information (group mean ± SEM per day).
group_labels = {'Control': 'Two-Context', 'Experimental': 'Multi-Context'}
avg_place = (place_df
             .groupby(['group', 'day'], as_index=False)
             .agg({'place_skaggs_only': ['mean', 'sem']})
             .replace(group_labels))
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)', width=600)
for group in ['Two-Context', 'Multi-Context']:
    gdata = avg_place[avg_place['group'] == group]
    percent = gdata['place_skaggs_only']
    fig.add_trace(go.Scatter(x=gdata['day'], y=percent['mean'], mode='lines+markers', name=group,
                             line_color=ce_colors_dict[group], error_y=dict(type='data', array=percent['sem'])))
## Dashed lines mark the boundaries between the 5-day blocks.
for boundary in (5.5, 10.5, 15.5):
    fig.add_vline(x=boundary, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percentage_place_skaggsonly.png'))

In [None]:
## Place-cell percentage on the last day of each 5-day block (days 5, 10, 15 and 20).
block_end_days = [5, 10, 15, 20]
last_days = avg_place[avg_place['day'].isin(block_end_days)]
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)')
for group in ['Two-Context', 'Multi-Context']:
    gdata = last_days[last_days['group'] == group]
    percent = gdata['place_skaggs_only']
    fig.add_trace(go.Scatter(x=gdata['day'], y=percent['mean'], mode='markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=percent['sem'])))
fig.update_yaxes(range=[0, 100])
fig.show()

In [None]:
## Mixed ANOVA on place-cell percentage: day within subjects, group between subjects, mouse as subject.
## NOTE(review): `DataFrame.mixed_anova` is provided by pingouin (it registers pandas methods when
## imported), but `import pingouin as pg` is commented out in the import cell, so this fails on a fresh kernel.
place_df.mixed_anova(dv='place_skaggs_only', within='day', between='group', subject='mouse')

In [None]:
## Percentage of place cells per day under each place-cell criterion, one panel per criterion.
yvars = ['place_skaggs_only', 'place_coherence_only', 'place_coherence_skagg', 'place_stability_only', 'place_skaggs_stability']
## NOTE(review): only 'place_skaggs_only' is currently produced by the place_df loop (the other
## criteria are commented out there). Fail early with a clear message instead of a bare KeyError.
missing_cols = [col for col in yvars if col not in place_df.columns]
assert not missing_cols, f'place_df is missing criterion columns: {missing_cols}'
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({col: ['mean', 'sem'] for col in yvars})

fig = pf.custom_graph_template(x_title='Day', y_title='',
                               titles=["SI", "Coherence", "Coherence and SI", 'Stability', 'Stability and SI'],
                               rows=1, columns=5, shared_y=True, shared_x=True, width=1200)
for i, y_var in enumerate(yvars):
    for group in avg_place['group'].unique():
        gdata = avg_place[avg_place['group'] == group]
        ## ce_colors_dict (the palette defined in the settings cell); `ce_color_dict` was a NameError.
        fig.add_trace(go.Scatter(x=gdata['day'], y=gdata[y_var]['mean'], mode='lines+markers',
                                 marker_color=ce_colors_dict[group], showlegend=False, legendgroup=group, name=group,
                                 error_y=dict(type='data', array=gdata[y_var]['sem'])), row=1, col=i+1)
## Dashed lines mark the boundaries between the 5-day blocks.
for value in [5.5, 10.5, 15.5]:
    fig.add_vline(x=value, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Place Cells (%)', col=1)
fig.show()
# fig.write_image(pjoin(fig_path, 'all_place_criteria.png'))

In [None]:
## Percentage of place cells per day under the SI, stability, and combined SI + stability criteria.
yvars = ['place_skaggs_only', 'place_stability_only', 'place_skaggs_stability']
## NOTE(review): only 'place_skaggs_only' is currently produced by the place_df loop. Fail early
## with a clear message instead of a bare KeyError. Only the plotted criteria are aggregated.
missing_cols = [col for col in yvars if col not in place_df.columns]
assert not missing_cols, f'place_df is missing criterion columns: {missing_cols}'
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({col: ['mean', 'sem'] for col in yvars})

fig = pf.custom_graph_template(x_title='Day', y_title='',
                               titles=["SI", 'Stability', 'Stability and SI'],
                               rows=1, columns=3, shared_y=True, shared_x=True, width=1200)
for i, y_var in enumerate(yvars):
    for group in avg_place['group'].unique():
        gdata = avg_place[avg_place['group'] == group]
        ## ce_colors_dict (the palette defined in the settings cell); `ce_color_dict` was a NameError.
        fig.add_trace(go.Scatter(x=gdata['day'], y=gdata[y_var]['mean'], mode='lines+markers',
                                 marker_color=ce_colors_dict[group], showlegend=False, legendgroup=group, name=group,
                                 error_y=dict(type='data', array=gdata[y_var]['sem'])), row=1, col=i+1)
## Dashed lines mark the boundaries between the 5-day blocks.
for value in [5.5, 10.5, 15.5]:
    fig.add_vline(x=value, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Place Cells (%)', col=1, range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percent_place_stability_si.png'))

### Percentage of cells vs spatial information.

In [None]:
## Spatial-information percentiles (0, 10, ..., 100) of the cells in each session, per mouse and day,
## then averaged per group, day and percentile.
data_type = 'S'
## Per-mouse enumeration-index thresholds carried over from the original if/elif chain. Sessions
## beyond the threshold get their day shifted by one (presumably a missing recording).
missing_session_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}
decile_array = np.arange(0, 101, 10)
spatial_info_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'decile': [], 'spatial_info': []}

experiments = natsorted(exp for exp in os.listdir(dpath) if exp in experiment_folders)
for experiment in experiments:
    ## NOTE(review): reads `aligned_minian_place`, not `data_of_interest` like the other loops; confirm this folder is current.
    exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian_place/')
    for mouse in natsorted(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        ## os.listdir order is arbitrary, but the enumeration index is used as the day number,
        ## so session files are visited in natural (numeric) order.
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            if '21' in session:
                continue
            if index > missing_session_after.get(mouse, np.inf):
                index += 1
            ## The file is closed once the percentiles have been computed.
            with xr.open_dataset(pjoin(mpath, session)) as ds:
                session_data = ds[data_type]
                deciles = np.percentile(session_data['skaggs_info'], decile_array)
                session_label = session_data.attrs['session_two']
            n_points = len(decile_array)
            spatial_info_dict['mouse'].extend([mouse] * n_points)
            spatial_info_dict['group'].extend([group] * n_points)
            spatial_info_dict['sex'].extend([sex] * n_points)
            spatial_info_dict['session'].extend([session_label] * n_points)
            spatial_info_dict['day'].extend([index + 1] * n_points)
            spatial_info_dict['decile'].extend(decile_array)
            spatial_info_dict['spatial_info'].extend(deciles)
spatial_df = pd.DataFrame(spatial_info_dict)
avg_spatial = spatial_df.groupby(['group', 'day', 'decile'], as_index=False).agg({'spatial_info': ['mean', 'sem']})

In [None]:
## Cell decile (y) vs. mean +/- SEM spatial information (x), one panel per day on a 4 x 5 grid.
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, titles=day_list, height=1000, width=1000,
                               shared_y=True, shared_x=True)

for group in avg_spatial['group'].unique():
    gdata = avg_spatial[avg_spatial['group'] == group]
    for day in gdata['day'].unique():
        day_data = gdata[gdata['day'] == day]
        ## Days are 1-indexed: day 1..20 -> 1-indexed (row, col) in row-major order on the 4 x 5 grid.
        row_idx, col_idx = divmod(int(day) - 1, 5)
        ## Colors come from `ce_colors_dict` (settings cell); `ce_color_dict` does not exist.
        fig.add_trace(go.Scatter(x=day_data['spatial_info']['mean'], y=day_data['decile'], mode='markers', showlegend=False,
                                 marker_color=ce_colors_dict[group], opacity=0.6,
                                 error_x=dict(type='data', array=day_data['spatial_info']['sem'])),
                      row=row_idx + 1, col=col_idx + 1)
fig.update_yaxes(title='Percent of Cells', col=1)
fig.update_xaxes(title='Spatial Info', row=4)
fig.show()
fig.write_image(pjoin(fig_path, 'percent_of_cells_with_spatial_info.png'))

### Snake plots for a single mouse for each session.

In [None]:
## Snake plots (cells ordered by the position bin of their peak) for every session of one mouse,
## laid out on a 4 x 5 grid with dashed lines at the two reward locations.
data_type = 'S'
mouse = 'mc60'
experiment = 'MultiCon_Imaging6'
correct_dir = True 
only_running = True
normalize_plot = True
colorscale = 'cividis'
## Per-mouse enumeration index after which sessions are shifted by one panel
## (presumably one missing recording day for that mouse — confirm against the raw data).
missing_day_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}

title_list = control_list if mouse in control_mice else session_list

fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, width=1400, height=1400,
                               titles=title_list)

exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian/')
mpath = pjoin(exp_path, f'{mouse}/{data_type}')
## os.listdir order is arbitrary; natural-sort so the enumeration index (which selects the
## subplot panel) follows session order.
for idx, session in enumerate(natsorted(os.listdir(mpath))):
    if '21' in session:
        continue
    if mouse in missing_day_after and idx > missing_day_after[mouse]:
        idx += 1
    ## Panel index 0..19 -> 1-indexed (row, col) in row-major order on the 4 x 5 grid.
    row_idx, col_idx = divmod(idx, 5)
    row, col = row_idx + 1, col_idx + 1

    ## Load into memory and close the file handle right away.
    with xr.open_dataset(pjoin(mpath, f'{session}')) as ds:
        sdata = ds[data_type].load()
    ## Mean linear position of frames at each reward port where `water` is set (dashed lines).
    reward_one_pos = np.mean(sdata['lin_position'][(sdata['lick_port'] == sdata.attrs['reward_one']) & (sdata['water'])].values)
    reward_two_pos = np.mean(sdata['lin_position'][(sdata['lick_port'] == sdata.attrs['reward_two']) & (sdata['water'])].values)

    neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running, velocity_thresh=velocity_thresh)
    population_activity, occupancy, bins = pc.spatial_activity(neural_data, position_data, bin_size=bin_size)
    tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
    ## Swap the two axes so rows are cells and columns are position bins. This must be a
    ## transpose: reshape((n1, n0)) keeps C order and interleaves values across cells/bins.
    resh_pop = np.asarray(tuning_curves).T.copy()
    centers_resh = np.argmax(resh_pop, axis=1)
    order_resh = np.argsort(centers_resh)

    if normalize_plot:
        ## Scale each cell to a peak of 1 (cells whose peak is 0 become NaN).
        resh_pop = resh_pop / np.max(resh_pop, axis=1, keepdims=True)

    fig.add_trace(go.Heatmap(x=bins, y=np.arange(0, sdata.shape[0]), z=resh_pop[order_resh], coloraxis='coloraxis1'), row=row, col=col)
    for pos in [reward_one_pos, reward_two_pos]:
        fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color='red', opacity=1, row=row, col=col)

fig.update_yaxes(autorange='reversed')
fig.update_yaxes(title='Neuron ID', col=1)
fig.update_layout(coloraxis1={'colorscale': colorscale})
fig.show()
fig.write_image(pjoin(fig_path, f'snake_plots_all_sessions_{mouse}.png'))

### Snake plots for a single mouse for each context.

In [None]:
## Snake plots for one mouse across the sessions of one context. Only cells matched in every
## available session are used, ordered by each cell's peak position in `ordering_session`.
data_type = 'S'
## Kept separate from the notebook-wide `session_list` (settings cell) so that list is not overwritten.
context_sessions = [f'{x}' for x in np.arange(16, 21)]
ordering_session = '20'
context_str = 'fourth_five'
mouse = 'mc46'
experiment = 'MultiCon_Imaging5'
correct_dir = True 
only_running = True
smooth = False
normalize_plot = True
colorscale = 'viridis'

fig = pf.custom_graph_template(x_title='Position (rad)', y_title='', rows=1, columns=5, height=600, width=1200,
                               shared_y=True, shared_x=True)
exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian/')
mpath = pjoin(exp_path, f'{mouse}/{data_type}')
crossreg_path = pjoin(dpath, f'{experiment}/output/cross_registration_results')
mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/mappings_meta_{centroid_distance}_{context_str}.pkl'))
mappings.columns = mappings.columns.droplevel(0)
session_files = {session: pjoin(mpath, f'{mouse}_{data_type}_{session}.nc') for session in context_sessions}

## Recording dates of the sessions present on disk. Missing files are skipped explicitly instead
## of through a bare except, which would also hide genuine read errors.
date_list = []
for session in context_sessions:
    if os.path.exists(session_files[session]):
        with xr.open_dataset(session_files[session]) as ds:
            date_list.append(ds[data_type].attrs['date'])

## Cells matched in every available session of this context.
shared_cells = mappings[date_list].dropna().reset_index(drop=True)

## Order cells based on specified session
with xr.open_dataset(pjoin(mpath, f'{mouse}_{data_type}_{ordering_session}.nc')) as ds:
    sdata = ds[data_type]
    sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values).load()
reward_one_pos = np.mean(sdata['lin_position'][sdata['lick_port'] == sdata.attrs['reward_one']].values)
reward_two_pos = np.mean(sdata['lin_position'][sdata['lick_port'] == sdata.attrs['reward_two']].values)

neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, lin_pos_col='lin_position',
                                                                only_running=only_running, velocity_thresh=velocity_thresh)
if smooth:
    smoothed_data = ctn.moving_average(neural_data, ksize=8) ## 264 ms
else:
    smoothed_data = neural_data.copy()
population_activity, occupancy, bins = pc.spatial_activity(smoothed_data, position_data, bin_size=bin_size)
tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
## Transpose so rows are cells and columns are position bins; reshape((n1, n0)) would keep
## C order and scramble values across cells/bins.
resh_pop = np.asarray(tuning_curves).T.copy()
centers_resh = np.argmax(resh_pop, axis=1)
order_resh = np.argsort(centers_resh)

## (session, cell, position bin). NaN-filled so sessions that are skipped or missing are not left
## as uninitialized memory.
cross_session_act = np.full((len(context_sessions), shared_cells.shape[0], bins.shape[0]-1), np.nan)
for idx, session in enumerate(context_sessions):
    ## This session is explicitly excluded for mc56 in the first context.
    if (mouse == 'mc56') and (context_str == 'first_five') and (idx == 0):
        continue
    if not os.path.exists(session_files[session]):
        continue

    with xr.open_dataset(session_files[session]) as ds:
        sdata = ds[data_type]
        sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values).load()
    neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, lin_pos_col='lin_position',
                                                                    only_running=only_running, velocity_thresh=velocity_thresh)
    if smooth:
        smoothed_data = ctn.moving_average(neural_data, ksize=8) ## 264 ms
    else:
        smoothed_data = neural_data.copy()
    population_activity, occupancy, bins = pc.spatial_activity(smoothed_data, position_data, bin_size=bin_size)
    tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
    resh_pop = np.asarray(tuning_curves).T.copy()
    ## Store the un-normalized curves; normalization below is for display only.
    cross_session_act[idx, :, :] = resh_pop
    neur_array = np.arange(0, sdata.shape[0])

    if normalize_plot:
        ## Scale each cell to a peak of 1 (cells whose peak is 0 become NaN).
        resh_pop = resh_pop / np.max(resh_pop, axis=1, keepdims=True)

    fig.add_trace(go.Heatmap(x=bins, y=neur_array, z=resh_pop[order_resh], coloraxis='coloraxis1'), row=1, col=idx+1)
fig.update_yaxes(autorange='reversed')
fig.update_yaxes(title='Neuron ID', col=1)
fig.update_layout(coloraxis1={'colorscale': colorscale})
for pos in [reward_one_pos, reward_two_pos]:
    fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color='red', opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_snake_plots_{context_str}_{ordering_session}.png'))

In [None]:
## Inspect `position_data` as last assigned by an earlier cell (when run in order, the per-session
## loop of the context snake-plot cell above). NOTE(review): relies on kernel state from that cell.
position_data

In [None]:
## Per-cell Pearson correlation between each session's spatial activity curve and a reference
## session (first or last) of the same context, for every mouse in `mouse_list`.
data_type = 'S'
session_dict = {'first_five': [f'{x}' for x in np.arange(1, 6)], 'second_five': [f'{x}' for x in np.arange(6, 11)],
                'third_five': [f'{x}' for x in np.arange(11, 16)], 'fourth_five': [f'{x}' for x in np.arange(16, 21)]}
context_str_list = ['first_five', 'second_five', 'third_five', 'fourth_five']
experiment_list = ['MultiCon_Imaging5', 'MultiCon_Imaging6']
order_by = 'first' ## last or first
if order_by == 'last':
    mouse_list = ['mc44', 'mc46', 'mc48', 'mc49', 'mc51', 'mc52', 'mc54', 'mc56', 'mc58']
elif order_by == 'first':
    mouse_list = ['mc44', 'mc48', 'mc49', 'mc51', 'mc52', 'mc54', 'mc58']
else:
    ## Fail fast: otherwise `ordering_session` below would be undefined (or stale from another cell).
    raise ValueError(f"order_by must be 'first' or 'last', got {order_by!r}")
correct_dir = True 
only_running = True
res_dict = {'mouse': [], 'group': [], 'sex': [], 'day': [], 'context_str': [], 'cell_idx': [], 'r': [], 'pval': []}

for experiment in experiment_list:
    exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian/')
    crossreg_path = pjoin(dpath, f'{experiment}/output/cross_registration_results')
    for mouse in os.listdir(exp_path):
        if mouse not in mouse_list:
            continue
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        for context_str in context_str_list:
            context_sessions = session_dict[context_str]
            session_files = {session: pjoin(mpath, f'{mouse}_{data_type}_{session}.nc') for session in context_sessions}
            mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/mappings_meta_{centroid_distance}_{context_str}.pkl'))
            mappings.columns = mappings.columns.droplevel(0)

            ## Dates of the sessions present on disk. Missing files are skipped explicitly instead
            ## of through a bare except, which would also hide genuine read errors.
            date_list = []
            for session in context_sessions:
                if os.path.exists(session_files[session]):
                    with xr.open_dataset(session_files[session]) as ds:
                        date_list.append(ds[data_type].attrs['date'])

            ## Cells matched in every available session of this context.
            shared_cells = mappings[date_list].dropna().reset_index(drop=True)

            ## Reference session that every session of this context is correlated against.
            ordering_session = context_sessions[-1] if order_by == 'last' else context_sessions[0]
            with xr.open_dataset(session_files[ordering_session]) as ds:
                sdata = ds[data_type]
                sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values).load()
            ## Computed here only to get the position bins that size `cross_session_act`.
            neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running, velocity_thresh=velocity_thresh)
            population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)

            ## (session, cell, position bin); NaN marks sessions that were skipped or missing.
            cross_session_act = np.full((len(context_sessions), shared_cells.shape[0], bins.shape[0]-1), np.nan)
            for idx, session in enumerate(context_sessions):
                ## This session is explicitly excluded for mc56 in the first context.
                if (mouse == 'mc56') and (context_str == 'first_five') and (idx == 0):
                    continue
                if not os.path.exists(session_files[session]):
                    continue

                with xr.open_dataset(session_files[session]) as ds:
                    sdata = ds[data_type]
                    sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values).load()
                neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, 
                                                                                only_running=only_running, velocity_thresh=velocity_thresh)
                population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
                ## Transpose so rows are cells and columns are bins; reshape((n1, n0)) would keep C order
                ## and scramble values across cells/bins.
                cross_session_act[idx, :, :] = np.asarray(population_activity).T

            ref_sess = cross_session_act[context_sessions.index(ordering_session)]

            for day in np.arange(0, cross_session_act.shape[0]):
                mat = cross_session_act[day]
                if np.isnan(mat).any():
                    print(f'Nans found for {mouse} in the {context_str} on {day}!')
                    continue
                for cell in np.arange(0, mat.shape[0]):
                    r_val, p_val = pearsonr(ref_sess[cell], mat[cell])

                    res_dict['mouse'].append(mouse)
                    res_dict['group'].append(group)
                    res_dict['sex'].append(sex)
                    res_dict['day'].append(day+1)
                    res_dict['context_str'].append(context_str)
                    res_dict['cell_idx'].append(cell)
                    res_dict['r'].append(r_val)
                    res_dict['pval'].append(p_val)
res_df = pd.DataFrame(res_dict)

In [None]:
## Mean +/- SEM per-cell correlation with the reference session, by group, context and day.
avg_res = res_df.groupby(['group', 'context_str', 'day'], as_index=False).agg({'r': ['mean', 'sem']})
## Context block -> subplot column (contexts A-D); any other label falls back to the last column.
context_cols = {'first_five': 1, 'second_five': 2, 'third_five': 3, 'fourth_five': 4}

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=4, width=1000, shared_y=True, shared_x=True,
                               titles=['A', 'B', 'C', 'D'])
for group in ['Control', 'Experimental']:
    gdata = avg_res[avg_res['group'] == group]
    for context_str in avg_res['context_str'].unique():
        cdata = gdata[gdata['context_str'] == context_str]
        ## Colors come from `ce_colors_dict` (settings cell); `ce_color_dict` does not exist.
        fig.add_trace(go.Scatter(x=cdata['day'], y=cdata['r']['mean'], mode='markers', marker_color=ce_colors_dict[group], opacity=0.6,
                                 name=group, error_y=dict(type='data', array=cdata['r']['sem']), showlegend=False),
                      row=1, col=context_cols.get(context_str, 4))
fig.update_yaxes(title='PV Correlation', col=1)
fig.show()

### Reproducing Will's snake plots

In [None]:
## Reproduce Will's snake plots from the MazeProjects pipeline output for one mouse.
data_type = 'S'
## NOTE(review): absolute path to another user's data directory; not portable.
exp_path = '/media/caishuman/csstorage/phild/git/MazeProjects/output'
## Named separately so the notebook-wide `session_list` (settings cell) is not overwritten.
will_session_list = ['Training1', 'Training2', 'Training3', 'Training4', 'Reversal1']
ordering_session = 'Training4'
mouse = 'Atlas'
correct_dir = True 
only_running = True
normalize_plot = False

fig = pf.custom_graph_template(x_title='Position (rad)', y_title='', rows=1, columns=5, width=1000,
                               shared_y=True, shared_x=True)

mappings = pd.read_feather(pjoin(exp_path, f'registration/{mouse}.feat'))
## Keep only registration rows with no missing entries (cells matched in every column).
shared_cells = mappings.dropna().reset_index(drop=True)

## Order cells based on specified session
with xr.open_dataset(pjoin(exp_path, f'processed/{mouse}_{ordering_session}.nc')) as ds:
    sdata = ds[data_type].sel(unit_id=shared_cells[ordering_session].values).load()
behav = pd.read_feather(pjoin(exp_path, f'behav/{mouse}_{ordering_session}.feat'))

## Attach behavior columns as per-frame coordinates so the direction/running subset can use them.
sdata = sdata.assign_coords(trials=('frame', behav['trials']),
                            lin_position=('frame', behav['lin_position']),
                            x=('frame', behav['x']),
                            y=('frame', behav['y']))
neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, 
                                                                only_running=only_running, velocity_thresh=velocity_thresh)

population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
## Transpose so rows are cells and columns are position bins; reshape((n1, n0)) would keep
## C order and scramble values across cells/bins.
resh_pop = np.asarray(population_activity).T.copy()
centers_resh = np.argmax(resh_pop, axis=1)
order_resh = np.argsort(centers_resh)

# cross_session_act = np.empty((len(will_session_list), shared_cells.shape[0], bins.shape[0]-1))
# for idx, session in enumerate(will_session_list):
#     try:
#         sdata = xr.open_dataset(pjoin(exp_path, f'processed/{mouse}_{session}.nc'))[data_type]
#     except:
#         sdata = None

#     if sdata is not None:
#         sdata = sdata.sel(unit_id=shared_cells[session].values)
#         behav = pd.read_feather(pjoin(exp_path, f'behav/{mouse}_{session}.feat'))
#         neural_data = sdata.values
#         position_data = behav['lin_position'].values
#         # neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, 
#         #                                                                 only_running=only_running, velocity_thresh=velocity_thresh)
#         population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
#         resh_pop = np.asarray(population_activity).T.copy()
#         cross_session_act[idx, :, :] = resh_pop
#         neur_array = np.arange(0, sdata.shape[0])

#         if normalize_plot:
#             max_vals = np.max(resh_pop, axis=1)
#             for uid in np.arange(0, resh_pop.shape[0]):
#                 resh_pop[uid, :] = resh_pop[uid, :] / max_vals[uid]

#         fig.add_trace(go.Heatmap(x=bins, y=neur_array, z=resh_pop[order_resh], 
#                         coloraxis='coloraxis1'), row=1, col=idx+1)
#     else:
#         pass
# fig.update_yaxes(autorange='reversed')
# fig.update_yaxes(title='Neuron ID', col=1)
# fig.show()

In [None]:
import numpy as np
from scipy import ndimage as nd
def define_population_bursts(ar, min_len=3, zthresh=2):
    """Find population bursts: runs of consecutive frames with high population activity.

    Each cell (row of ``ar``) is z-scored over time, the z-scores are averaged across cells,
    and the resulting population trace is z-scored again. Every run of at least ``min_len``
    consecutive frames strictly above ``zthresh`` is reported as one burst.

    Parameters
    ----------
    ar : array-like, shape (n_cells, n_frames)
        Activity matrix; z-scoring per cell is along axis 1, averaging across cells along axis 0.
    min_len : int
        Minimum number of consecutive supra-threshold frames for a run to count as a burst.
    zthresh : float
        Threshold (strict ``>``) applied to the z-scored population trace.

    Returns
    -------
    burst_start, burst_end : np.ndarray of int
        Inclusive first and last frame index of each burst, in temporal order
        (empty int arrays when no burst qualifies).
    """
    ## Constant cells (e.g. silent ones) z-score to NaN; nanmean ignores them instead of letting
    ## a single such cell turn every frame of the population trace into NaN (=> no bursts ever).
    pop_act = zscore(np.nanmean(zscore(ar, axis=1), axis=0))
    above_thresh = pop_act > zthresh
    ## Label each contiguous supra-threshold run 1..n (0 = below threshold).
    burst_array, _ = nd.label(above_thresh)
    ## find_objects returns each label's bounding slice in one pass; for a 1-D run that slice
    ## is exactly the run.
    runs = [obj[0] for obj in nd.find_objects(burst_array)]
    kept = [run for run in runs if run.stop - run.start >= min_len]
    burst_start = np.array([run.start for run in kept], dtype=int)
    burst_end = np.array([run.stop - 1 for run in kept], dtype=int)
    return burst_start, burst_end

In [None]:
## Population burst testing: per-frame threshold crossings (top) vs. bursts of >= min_len frames
## marked at their middle frame (bottom).
## NOTE(review): uses `sdata` left over from an earlier cell; it must carry a `behav_t` coordinate.
thresh = 1.96
min_len = 3
## Z-score each cell, average across cells (nanmean: constant cells z-score to NaN and would
## otherwise blank the whole trace), then z-score the population trace.
pop_act = zscore(np.nanmean(zscore(sdata, axis=1), axis=0))
bursts_og = pop_act > thresh
burst_start, burst_end = define_population_bursts(sdata, min_len=min_len, zthresh=thresh)
burst_mid = np.round((burst_start + burst_end) / 2).astype(int)
frame_times = sdata['behav_t']

fig = pf.custom_graph_template(x_title='Time (s)', y_title='Z-Scored Activity', rows=2, columns=1, 
                               height=800, width=1000, shared_x=True, shared_y=True)
## OG bursts (every frame considered a separate burst)
fig.add_trace(go.Scattergl(x=frame_times, y=pop_act, mode='lines', line_color='black', showlegend=False), row=1, col=1)
## marker_color (not line_color) is what colors a mode='markers' trace.
fig.add_trace(go.Scattergl(x=frame_times[bursts_og], y=pop_act[bursts_og], mode='markers', marker_color='red', showlegend=False), row=1, col=1)
## New bursts
fig.add_trace(go.Scattergl(x=frame_times, y=pop_act, mode='lines', line_color='black', showlegend=False), row=2, col=1)
fig.add_trace(go.Scattergl(x=frame_times[burst_mid], y=pop_act[burst_mid], mode='markers', marker_color='red', showlegend=False), row=2, col=1)
fig.add_hline(y=thresh, line_width=1, line_color='red', line_dash='dash', opacity=1)
fig.show()
fig.write_html(pjoin(fig_path, 'population_bursts_example.html'))