In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
from os.path import join as pjoin
from tqdm.notebook import tqdm
from sklearn.metrics import mutual_info_score
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr, zscore
from natsort import natsorted
import statsmodels.api as sm
from statsmodels.formula.api import ols
from numpy.random import RandomState, SeedSequence, MT19937

sys.path.append('../')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

def set_group(mouse, control_mice):
    """Return the experimental group label for a mouse.

    Args:
        mouse: mouse identifier, e.g. 'mc46'.
        control_mice: container of identifiers that belong to the two-context group.

    Returns:
        'Two-context' if ``mouse`` is in ``control_mice``, otherwise 'Multi-context'.
    """
    if mouse in control_mice:
        return 'Two-context'
    return 'Multi-context'

In [None]:
## Settings shared by every analysis cell below.

## -- Data and figure locations (relative to this notebook)
project_folder = ['MultiCon_Imaging']
experiment_folders = [f'MultiCon_Imaging{cohort}' for cohort in (5, 6, 7)]
dpath = '../../' + project_folder[0]
fig_path = '../../../Manuscripts/MultiCon/intermediate_plots/spatial_info'

## -- Plot colors and marker symbols
chance_color = '#7d7d7d'
avg_color = '#287347'
subject_color = '#7d7d7d'
ce_colors = ['#7A22BC', '#378616']
ce_colors_dict = {
    'Two-context': '#378616',
    'Multi-context': '#7A22BC',
}
symbol_dict = {
    'Two-context': 'x',
    'Multi-context': 'circle',
}
symbols_list = ['x', 'circle']
context_colors = {
    'A': '#00802d',
    'B': '#006c79',
    'C': '#004da4',
    'D': '#430073',
}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']

## -- Cohort membership
male_mice = ['mc44', 'mc46', 'mc54', 'mc55']
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc59', 'mc60']
imaging5 = ['mc44', 'mc46', 'mc48', 'mc49', 'mc51', 'mc52']

## -- Session and day labels
session_list = [f'{context}{day}' for context in 'ABCD' for day in range(1, 6)]
control_list = [f'A{day}' for day in range(1, 16)] + [f'B{day}' for day in range(1, 6)]
day_list = [f'Day {day}' for day in range(1, 21)]

## -- Analysis parameters
bin_size = 0.16 ## size of linear position bins equivalent to 2cm-wide bins
velocity_thresh = 10
centroid_distance = 4
data_of_interest = 'aligned_place_cells' ## one of behav, aligned_minian, aligned_place_cells, lin_behav
z_thresh = 1.96

## Make sure the figure directory exists before any cell writes an image into it.
os.makedirs(fig_path, exist_ok=True)

xr.set_options(keep_attrs=True)

## Seeded legacy random generator.
rs = RandomState(MT19937(SeedSequence(24601)))

### Testing the spatial information shuffling procedure on a single session.

In [None]:
## Load one session for one mouse and compute the observed and shuffled spatial metrics
## (spatial information, spatial coherence, first/second-half and odd/even stability).
data_type = 'S'
mouse = 'mc51'
session = '20'
correct_dir = True
only_running = True
minimum_firing_amount = 0.2 ## NOTE(review): not referenced again in the visible cells
nshuffles = 10
data_path = f'../../{project_folder[0]}/{experiment_folders[0]}/output/aligned_place_cells/{mouse}/{data_type}/{mouse}_{data_type}_{session}.nc'
yra_path = f'../../{project_folder[0]}/{experiment_folders[0]}/output/aligned_minian/{mouse}/YrA/{mouse}_YrA_{session}.nc'

sdata = xr.open_dataset(data_path)[data_type]
yra = xr.open_dataset(yra_path)['YrA']
num_neurons = sdata.shape[0]

## `running` is reused further down to index traces when plotting running-only activity.
x_cm, y_cm = ctb.convert_to_cm(x=sdata['x'], y=sdata['y'])
velocity, running = pc.define_running_epochs(x_cm, y_cm, sdata['behav_t'], velocity_thresh=velocity_thresh)
neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running, 
                                                                velocity_thresh=velocity_thresh)

## Earlier manual version of the shuffle (kept for reference, not executed):
# ar = neural_data.copy()
# nshuffles = 10
# shuffled_info_cell = np.zeros((nshuffles, 40))
# for shuffle in np.arange(0, nshuffles):
#     ## Shift position data forward by a random number of frames
#     shuffled_data = np.array(())
#     for trial in np.unique(neural_data['trials']):
#         position = ar['lin_position'][ar['trials'] == trial].values
#         random_shift = np.random.randint(60, position.shape[0]-1)  
#         rolled_position = np.roll(position, random_shift)
#         shuffled_data = np.concatenate((shuffled_data, rolled_position))
#     ## Calculate spatial coherence and spatial information with the reordered data
#     population_activity, occupancy, _ = pc.spatial_activity(ar, shuffled_data, bin_size=bin_size, binarized=binarize)
#     shuffled_tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
#     shuffled_info_cell[shuffle] = shuffled_tuning_curves[:, 8]
#     shuffled_info = pc.skaggs_information_content(shuffled_tuning_curves, occupancy)

## Calculate observed Skaggs information content and observed spatial coherence
## NOTE(review): `binarize` is not assigned in any visible cell, so on a fresh kernel the next line
## raises NameError. Define it (e.g. in the settings cell) before running - value to be confirmed.
population_activity, occupancy, _ = pc.spatial_activity(neural_data, position_data, bin_size=bin_size, binarized=binarize)
tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
avg_of_avg, spatial_coherence_values = pc.calculate_spatial_coherence(tuning_curves, ksize=8)
bits_per_event = pc.skaggs_information_content(tuning_curves, occupancy)
first_second = pc.first_second_half_stability(neural_data, bin_size=bin_size)
odd_even = pc.odd_even_stability(neural_data, bin_size=bin_size)

## Shuffled (null) versions of the same metrics, `nshuffles` shuffles each.
shuffled_si, shuffled_sc = pc.shuffle_spatial_metrics(neural_data, lin_pos_col='lin_position', bin_size=bin_size, nshuffles=nshuffles)
shuffled_first_second, shuffled_odd_even = pc.shuffle_stability_metrics(neural_data, bin_size=bin_size, nshuffles=nshuffles)

In [None]:
## NOTE(review): `px` is needed by the per-trial tuning-curve cell that follows; consider moving
## this import to the import cell at the top so the notebook does not depend on this scratch cell.
import plotly.express as px
color = px.colors.sequential.Blues
## Scratch check of the modulo used to wrap a trial number onto the color scale.
10 % len(color)

In [None]:
## Overlay the per-trial tuning curves of one example neuron, one line per forward trial.
neuron = 0
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate', titles=['Example Neuron'], width=500, height=500)
color = px.colors.sequential.Blues

forward, _ = ctb.get_forward_reverse_trials(sdata)
for trial in forward[:-1]:  ## the last forward trial is left out
    idx = int(trial % len(color))  ## wrap the trial number onto the color scale
    ## Use trial-local names: the session-level `tuning_curves` and `population_activity` computed
    ## earlier are read again by later cells and must not be replaced by a single trial's values.
    trial_mask = sdata['trials'] == trial
    trial_activity = sdata[:, trial_mask].values
    trial_position = sdata['lin_position'][trial_mask].values
    trial_pop_activity, trial_occupancy, _ = pc.spatial_activity(trial_activity, trial_position, bin_size=bin_size, binarized=binarize)
    trial_tuning = pc.get_tuning_curves(trial_pop_activity, trial_occupancy)
    fig.add_trace(go.Scatter(x=np.arange(1, trial_tuning.shape[0]+1), y=trial_tuning[:, neuron], mode='lines',
                             line_color=color[idx], opacity=0.8, showlegend=False))
fig.show()
fig.write_image(pjoin(fig_path, f'stability_example_neuron_{neuron}.png'))

In [None]:
## Heatmap of one neuron's tuning curve on every forward trial (rows = trials, columns = spatial bins).
neuron = 0
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate', titles=[f'Neuron {neuron+1}'], width=800, height=800)

forward, _ = ctb.get_forward_reverse_trials(sdata)
forward_trials = forward[:-1]  ## the last forward trial is left out
if len(forward_trials) == 0:
    raise ValueError('No forward trials available to build the per-trial heatmap.')

## Collect one row per trial in loop order. Rows are keyed by position in `forward_trials`, not by
## trial number, so non-contiguous trial numbers cannot index out of bounds or leave empty rows.
## Trial-local names keep the session-level `tuning_curves` / `population_activity` intact for later cells.
trial_curves = []
for trial in forward_trials:
    trial_mask = sdata['trials'] == trial
    trial_activity = sdata[:, trial_mask].values
    trial_position = sdata['lin_position'][trial_mask].values
    trial_pop_activity, trial_occupancy, _ = pc.spatial_activity(trial_activity, trial_position, bin_size=bin_size, binarized=binarize)
    trial_tuning = pc.get_tuning_curves(trial_pop_activity, trial_occupancy)
    trial_curves.append(trial_tuning[:, neuron])
trial_ar = np.vstack(trial_curves)  ## bin count comes from the data rather than a hard-coded 40
n_trials, n_bins = trial_ar.shape
fig.add_trace(go.Heatmap(x=np.arange(1, n_bins+1), y=np.arange(0, n_trials), z=trial_ar, colorscale='viridis'))
fig.show()

In [None]:
## Inspect the loop variable left over from the per-trial loop above (the last trial processed).
trial

In [None]:
## Tuning curve of a single cell, saved under a mouse/session/neuron specific file name.
neuron = 7 # 7 is a non-place cell
xaxis = np.arange(population_activity.shape[0]) + 1
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Event Rate', titles=['Non-Place Cell'])
cell_curve = go.Scatter(x=xaxis, y=tuning_curves[:, neuron])
fig.add_trace(cell_curve)
fig.update_yaxes(range=[0, 0.035])
fig.show()
tc_file = f'{mouse}_{session}_{neuron}_tc.png'
fig.write_image(pjoin(fig_path, tc_file))

In [None]:
## For every neuron (first axis of `sdata`), count the samples with non-zero activity.
num_events = np.sum(sdata > 0, axis=1)
## Positions of the neurons with more than 850 such samples (850 is an ad-hoc cutoff).
np.where(num_events > 850)[0]

In [None]:
## unit_id label of the neuron stored at position 131.
sdata['unit_id'].values[131]

In [None]:
## YrA trace of one neuron with its scaled `sdata` activity overlaid, across the whole session.
scaling_factor = 10
neuron = 131
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Activity (a.u.)', width=1000)
raw_trace = go.Scatter(x=yra['behav_t'], y=yra[neuron, :], mode='lines', line_color='darkgrey', opacity=0.8)
scaled_events = sdata[neuron, :] * scaling_factor
event_trace = go.Scatter(x=sdata['behav_t'], y=scaled_events, mode='lines', line_color='red', opacity=0.5)
for trace in (raw_trace, event_trace):
    fig.add_trace(trace)
fig.show()

In [None]:
## YrA trace of one neuron with its scaled `sdata` activity overlaid, restricted to running samples.
scaling_factor = 10
neuron = 6
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Activity (a.u.)', width=1000)
## Apply the `running` mask to the time axis as well as to the activity: with an unmasked x the
## masked y values are paired with the first len(y) time stamps and the trace is misaligned in time.
fig.add_trace(go.Scatter(x=yra['behav_t'][running], y=yra[neuron, :][running], mode='lines', line_color='darkgrey', opacity=0.8))
fig.add_trace(go.Scatter(x=sdata['behav_t'][running], y=sdata[neuron, :][running] * scaling_factor, mode='lines', line_color='red', opacity=0.5))
fig.show()

In [None]:
## One neuron's tuning curve plotted against its coherence curve, with their Pearson correlation printed.
neuron = 64
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate', width=600)
bin_numbers = np.arange(1, tuning_curves.shape[0]+1)
cell_tuning = tuning_curves[:, neuron]
cell_coherence = avg_of_avg[:, neuron]
for curve, curve_color, label in [(cell_tuning, 'midnightblue', 'Tuning Curve'),
                                  (cell_coherence, 'darkgrey', 'Coherence Curve')]:
    fig.add_trace(go.Scatter(x=bin_numbers, y=curve, line_color=curve_color, name=label))
fig.show()
print(pearsonr(cell_tuning, cell_coherence))
fig.write_image(pjoin(fig_path, 'tuning_coherence_curves.png'))

In [None]:
## Spatial information against spatial coherence, one marker per cell of the loaded session.
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Spatial Coherence')
info_vs_coherence = go.Scatter(x=sdata['skaggs_info'], y=sdata['coherence'], mode='markers', marker_color='darkgrey')
fig.add_trace(info_vs_coherence)
fig.show()
# fig.write_image(pjoin(fig_path, 'coherence_info_scatter.png'))

In [None]:
## Probability histogram of spatial information over all cells of the loaded session.
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Probability', titles=['All Cells'])
si_hist = go.Histogram(x=sdata['skaggs_info'], marker_color='darkgrey', histnorm='probability', showlegend=False)
fig.add_trace(si_hist)
fig.show()
fig.write_image(pjoin(fig_path, 'all_cells_si.png'))

In [None]:
## Spatial information histograms for all cells and for the two subsets selected by `skaggs_stability_z`.
## NOTE(review): `skaggs_stability_z` is not defined in any visible cell; it is used here as a
## per-neuron mask (and negated with `~`), so presumably boolean - define it before running this cell.
fig = pf.custom_graph_template(x_title='Spatial Information', y_title='', rows=1, columns=3, width=1000,
                               titles=['All Cells', 'Place Cells', 'Non-Place Cells'], shared_x=True, shared_y=True)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=1)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'][skaggs_stability_z], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=2)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'][~skaggs_stability_z], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=3)
fig.update_yaxes(title='Probability', col=1)
fig.show()

In [None]:
## Positions, within the cells not selected by `skaggs_z`, whose spatial information exceeds 5.
## NOTE(review): `skaggs_z` is not defined in any visible cell - define it before running.
np.where(sdata['skaggs_info'][~skaggs_z] > 5)

In [None]:
## Spatial information of the third cell (position 2) among those not selected by `skaggs_z`.
## NOTE(review): `skaggs_z` is not defined in any visible cell - define it before running.
sdata['skaggs_info'][~skaggs_z][2]

In [None]:
## Spatial information histograms for all cells and for the two subsets selected by `skaggs_z`.
## NOTE(review): `skaggs_z` is not defined in any visible cell; it is used here as a per-neuron
## mask (and negated with `~`), so presumably boolean - define it before running this cell.
fig = pf.custom_graph_template(x_title='Spatial Information', y_title='', rows=1, columns=3, width=1000,
                               titles=['All Cells', 'Place Cells', 'Non-Place Cells'], shared_x=True, shared_y=True)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=1)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'][skaggs_z], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=2)
fig.add_trace(go.Histogram(x=sdata['skaggs_info'][~skaggs_z], marker_color='darkgrey', 
                            histnorm='probability', showlegend=False), row=1, col=3)
fig.update_yaxes(title='Probability', col=1)
fig.show()
# fig.write_image(pjoin(fig_path, 'all_place_nonplace_dist_only_si.png'))

In [None]:
## Shuffled spatial information distribution of one neuron, with its observed value as a red line.
neuron = 14
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Probability')
null_distribution = shuffled_si[:, neuron]
observed_si = bits_per_event[neuron]
fig.add_trace(go.Histogram(x=null_distribution, marker_color='darkgrey', histnorm='probability'))
fig.add_vline(x=observed_si, line_width=1, line_color='red', opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'shuffled_si_{neuron}.png'))

In [None]:
## Display the loaded session DataArray (dimensions, coordinates and attributes).
sdata

In [None]:
## Odd/even stability against first/second-half stability, one marker per cell.
fig = pf.custom_graph_template(x_title='Odd vs Even Trials Stability', y_title='First Half vs Second Half Stability')
stability_points = go.Scatter(x=sdata['odd_even'], y=sdata['first_second'], mode='markers', marker_color='darkgrey')
fig.add_trace(stability_points)
fig.show()
fig.write_image(pjoin(fig_path, 'first_odd_stability_scatter.png'))

In [None]:
## Example tuning curve for a neuron
neuron = 14 ## position along the neuron axis (not a unit_id label)
bins = np.arange(1, tuning_curves.shape[0]+1)
## Look the value up by position so it refers to the same cell as `tuning_curves[:, neuron]`.
## `sdata.sel(unit_id=neuron)` selects by unit_id label, which need not equal the position.
print(sdata['skaggs_info'].values[neuron])
fig = pf.custom_graph_template(x_title='Spatial Bin', y_title='Calcium Event Rate')
fig.add_trace(go.Scatter(x=bins, y=tuning_curves[:, neuron], mode='lines', line_color=avg_color))
fig.show()
# fig.write_image(pjoin(fig_path, f'example_place_cell_{neuron}.png'))

In [None]:
## Fisher z-transformed stability: odd/even against first/second half, one marker per cell.
fig = pf.custom_graph_template(x_title='Odd-Even Stability', y_title='First-Second Stability')
fisherz_points = go.Scatter(x=sdata['odd_even_fisherz'], y=sdata['first_second_fisherz'], mode='markers', marker_color='darkgrey')
fig.add_trace(fisherz_points)
fig.show()

In [None]:
## Spatial information against spatial coherence, one marker per cell.
fig = pf.custom_graph_template(x_title='Spatial Information', y_title='Spatial Coherence')
si_coherence_points = go.Scatter(x=sdata['skaggs_info'], y=sdata['coherence'], mode='markers', marker_color='darkgrey')
fig.add_trace(si_coherence_points)
fig.show()

In [None]:
## First/second-half stability against spatial information, one marker per cell.
fig = pf.custom_graph_template(x_title='First-Second Stability', y_title='Spatial Information')
stability_si_points = go.Scatter(x=sdata['first_second'], y=sdata['skaggs_info'], mode='markers', marker_color='darkgrey')
fig.add_trace(stability_si_points)
fig.show()

In [None]:
## Histogram of the per-cell mean of the two Fisher z-transformed stability measures.
av = (sdata['first_second_fisherz'] + sdata['odd_even_fisherz']) / 2
## Label the axes so the figure can be read on its own (they were empty strings before).
fig = pf.custom_graph_template(x_title='Mean Stability (Fisher z)', y_title='Count')
fig.add_trace(go.Histogram(x=av, marker_color='darkgrey'))
fig.show()

### Spatial information for each mouse for each session for place and non-place cells.

In [None]:
## One row per mouse and session: mean spatial information over all, place and non-place cells.
data_type = 'S'
cell_dict = {column: [] for column in ['mouse', 'group', 'sex', 'session', 'day', 'spatial_info_all',
                                       'spatial_info_place', 'spatial_info_nonplace']}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            index = ctn.mouse_indices(mouse, index)
            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            sdata = sdata[sdata['minimum_activity_met'], :] ## remove neurons who didn't meet minimum activity requirement
            info = sdata['skaggs_info'].values
            is_place = sdata['skaggs_place'].values
            session_row = {
                'mouse': mouse,
                'group': group,
                'sex': sex,
                'session': sdata.attrs['session_two'],
                'day': index+1,
                'spatial_info_all': np.mean(info),
                'spatial_info_place': np.mean(info[is_place]),
                'spatial_info_nonplace': np.mean(info[~is_place]),
            }
            for column, value in session_row.items():
                cell_dict[column].append(value)
spatial_info = pd.DataFrame(cell_dict)

In [None]:
## Group mean and SEM of spatial information across days, one panel per cell population.
panels = ['spatial_info_all', 'spatial_info_place', 'spatial_info_nonplace']
avg = spatial_info.groupby(['group', 'day'], as_index=False).agg({metric: ['mean', 'sem'] for metric in panels})

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=3, shared_x=True, shared_y=True,
                               titles=['All Cells', 'Place Cells', 'Non-place Cells'], width=1000)

for group in ['Two-context', 'Multi-context']:
    gdata = avg[avg['group'] == group]
    group_color = ce_colors_dict[group]
    ## Panels in order: all cells, place cells, non-place cells.
    for panel, metric in enumerate(panels, start=1):
        ## Only the first panel contributes a legend entry for the group.
        legend_kwargs = {} if panel == 1 else {'showlegend': False}
        fig.add_trace(go.Scattergl(x=gdata['day'], y=gdata[metric]['mean'], mode='lines+markers', name=group, legendgroup=group,
                                   line_color=group_color, error_y=dict(type='data', array=gdata[metric]['sem']),
                                   **legend_kwargs), row=1, col=panel)
for boundary in (5.5, 10.5, 15.5, 20.5):
    fig.add_vline(x=boundary, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Spatial Information (bits/event)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, 'all_cells_place_cells_si.png'), width=1000, height=500)

In [None]:
## Fit an OLS model of mean spatial information (all cells) on group, using day 16 only,
## and print the type-3 ANOVA table.
data = spatial_info.loc[spatial_info['day'].eq(16)].reset_index(drop=True)
model = ols(formula='spatial_info_all ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

### Stability for each session

In [None]:
## One row per cell and session with its odd/even, first/second-half and averaged stability,
## followed by per-mouse and per-group daily averages.
data_type = 'S'
place_cells_only = 'place_cells' ## which cells to keep: 'place_cells', 'non_place_cells' or 'all_cells'
valid_selections = ('place_cells', 'non_place_cells', 'all_cells')
## Fail loudly on a typo; previously an unknown value silently kept every cell.
if place_cells_only not in valid_selections:
    raise ValueError(f'place_cells_only must be one of {valid_selections}, got {place_cells_only!r}')

cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'unit_id': [], 'odd_even': [], 'first_second': [], 'stability': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            index = ctn.mouse_indices(mouse, index)
            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            if place_cells_only == 'place_cells':
                sdata = sdata[sdata['skaggs_place'], :]
            elif place_cells_only == 'non_place_cells':
                sdata = sdata[~sdata['skaggs_place'], :]
            ## 'all_cells': keep `sdata` unchanged
            odd_even = sdata['odd_even'].values
            first_second = sdata['first_second'].values
            stability = (odd_even + first_second) / 2
            unit_ids = sdata['unit_id'].values

            ## Append one entry per cell; the session-level labels are repeated for every cell.
            n_cells = sdata.shape[0]
            cell_dict['mouse'].extend([mouse] * n_cells)
            cell_dict['group'].extend([group] * n_cells)
            cell_dict['sex'].extend([sex] * n_cells)
            cell_dict['session'].extend([sdata.attrs['session_two']] * n_cells)
            cell_dict['day'].extend([index+1] * n_cells)
            cell_dict['unit_id'].extend(unit_ids[:n_cells])
            cell_dict['odd_even'].extend(odd_even[:n_cells])
            cell_dict['first_second'].extend(first_second[:n_cells])
            cell_dict['stability'].extend(stability[:n_cells])
stability_df = pd.DataFrame(cell_dict)
## Average cells within each mouse first, then mice within each group.
avg_mouse = stability_df.groupby(['group', 'mouse', 'day'], as_index=False).agg({'odd_even': 'mean', 'first_second': 'mean', 'stability': 'mean'})
avg_stab = avg_mouse.groupby(['group', 'day'], as_index=False).agg({'odd_even': ['mean', 'sem'], 'first_second': ['mean', 'sem'], 'stability': ['mean', 'sem']})

In [None]:
## Odd vs even stability: group mean and SEM across days
metric = 'odd_even'
fig = pf.custom_graph_template(x_title='Day', y_title='Odd-Even Trial Stability')
for group in ('Two-context', 'Multi-context'):
    group_rows = avg_stab.loc[avg_stab['group'] == group]
    mean_sem = group_rows[metric]
    fig.add_trace(go.Scatter(x=group_rows['day'], y=mean_sem['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=mean_sem['sem'])))
for boundary in (5.5, 10.5, 15.5, 20.5):
    fig.add_vline(x=boundary, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 1])
fig.show()
fig.write_image(pjoin(fig_path, f'odd_even_trial_stability_across_days_{place_cells_only}.png'), width=600, height=500)

In [None]:
## First half-second half trial stability: group mean and SEM across days
metric = 'first_second'
fig = pf.custom_graph_template(x_title='Day', y_title='First-Second Half Trial Stability')
for group in ('Two-context', 'Multi-context'):
    group_rows = avg_stab.loc[avg_stab['group'] == group]
    mean_sem = group_rows[metric]
    fig.add_trace(go.Scatter(x=group_rows['day'], y=mean_sem['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=mean_sem['sem'])))
for boundary in (5.5, 10.5, 15.5, 20.5):
    fig.add_vline(x=boundary, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 1])
fig.show()
fig.write_image(pjoin(fig_path, f'first_second_stability_across_days_{place_cells_only}.png'), width=600, height=500)

In [None]:
## Stability metric (average of the two stability measures): group mean and SEM across days
metric = 'stability'
fig = pf.custom_graph_template(x_title='Day', y_title='Stability')
for group in ('Two-context', 'Multi-context'):
    group_rows = avg_stab.loc[avg_stab['group'] == group]
    mean_sem = group_rows[metric]
    fig.add_trace(go.Scatter(x=group_rows['day'], y=mean_sem['mean'], mode='lines+markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=mean_sem['sem'])))
for boundary in (5.5, 10.5, 15.5, 20.5):
    fig.add_vline(x=boundary, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 1])
fig.show()
fig.write_image(pjoin(fig_path, f'stability_metric_across_days_{place_cells_only}.png'), width=600, height=500)

### Plot percentage of place cells across days for both groups.

In [None]:
## One row per mouse and session: percentage of cells flagged as place cells by spatial
## information (`skaggs_place`) and by spatial coherence (`coherence_place`).
## NOTE(review): unlike the spatial-information loop, cells are not filtered by
## `minimum_activity_met` here - confirm that this is intended.
data_type = 'S'
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'place_skaggs_only': [], 'place_coherence_only': []}

for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            index = ctn.mouse_indices(mouse, index)
            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            num_cells = sdata.shape[0]

            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['sex'].append(sex)
            cell_dict['session'].append(sdata.attrs['session_two'])
            cell_dict['day'].append(index+1)
            cell_dict['place_skaggs_only'].append((np.sum(sdata['skaggs_place'].values) / num_cells) * 100)
            ## Use `.values` as for `skaggs_place`: summing the DataArray itself yields a 0-d
            ## DataArray, which would be stored in the frame as an object instead of a number.
            cell_dict['place_coherence_only'].append((np.sum(sdata['coherence_place'].values) / num_cells) * 100)
place_df = pd.DataFrame(cell_dict)

In [None]:
## Plot for percentage of place cells based on spatial information
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({'place_skaggs_only': ['mean', 'sem']})
## Map legacy group labels onto the labels used by the loop below and by `ce_colors_dict`.
## The replacements must be spelled 'Two-context' / 'Multi-context' (lower-case c); the previous
## 'Two-Context' / 'Multi-Context' spelling matched neither and would have produced an empty figure.
avg_place = avg_place.replace({'Control': 'Two-context', 'Experimental': 'Multi-context'})
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)', width=600)
for group in ['Two-context', 'Multi-context']:
    gdata = avg_place[avg_place['group'] == group]
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['place_skaggs_only']['mean'], mode='lines+markers', name=group,
                             line_color=ce_colors_dict[group], error_y=dict(type='data', array=gdata['place_skaggs_only']['sem'])))
for val in [5.5, 10.5, 15.5, 20.5]:
    fig.add_vline(x=val, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percentage_place_skaggsonly.png'), width=600, height=500)

In [None]:
## Fit an OLS model of the place-cell percentage on group, using day 16 only,
## and print the type-3 ANOVA table.
data = place_df.loc[place_df['day'].eq(16)].reset_index(drop=True)
model = ols(formula='place_skaggs_only ~ C(group)', data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Percentage of place cells on days 1, 6, 11 and 16 (first day in each context)
first_days = avg_place[avg_place['day'].isin([1, 6, 11, 16])]
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)')
for group in ('Two-context', 'Multi-context'):
    group_rows = first_days.loc[first_days['group'] == group]
    percent = group_rows['place_skaggs_only']
    fig.add_trace(go.Scatter(x=group_rows['day'], y=percent['mean'], mode='markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=percent['sem'])))
fig.update_yaxes(range=[0, 100])
fig.show()

In [None]:
## Plot percentage of place cells on the final switch
## Work on an explicit copy: writing into a filtered slice of `avg_place` triggers
## SettingWithCopyWarning and may not modify the frame that is plotted.
switch = avg_place[(avg_place['day'] == 15) | (avg_place['day'] == 16)].copy()
## Replace the column (rather than assigning through `.loc`) so the integer days can become
## strings; the days are then drawn as two categories instead of a numeric axis.
switch['day'] = switch['day'].astype(str)
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)')
for group in ['Two-context', 'Multi-context']:
    gdata = switch[switch['group'] == group]
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['place_skaggs_only']['mean'], mode='markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=gdata['place_skaggs_only']['sem'])))
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percent_place_final_switch.png'), width=500, height=500)

In [None]:
## Percentage of place cells under each classification criterion, one panel per criterion.
panel_titles = {'place_skaggs_only': 'SI', 'place_coherence_only': 'Coherence', 'place_coherence_skagg': 'Coherence and SI',
                'place_stability_only': 'Stability', 'place_skaggs_stability': 'Stability and SI'}
## Only plot the criteria that were actually collected into `place_df`: aggregating a column the
## frame does not have raises a KeyError and stops the notebook.
yvars = [y_var for y_var in panel_titles if y_var in place_df.columns]
if not yvars:
    raise KeyError(f'place_df contains none of the expected criteria columns: {list(panel_titles)}')
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({y_var: ['mean', 'sem'] for y_var in yvars})

fig = pf.custom_graph_template(x_title='Day', y_title='',
                               titles=[panel_titles[y_var] for y_var in yvars],
                               rows=1, columns=len(yvars), shared_y=True, shared_x=True, width=1200)
for i, y_var in enumerate(yvars):
    for group in avg_place['group'].unique():
        gdata = avg_place[avg_place['group'] == group]
        fig.add_trace(go.Scatter(x=gdata['day'], y=gdata[y_var]['mean'], mode='lines+markers',
                                 marker_color=ce_colors_dict[group], showlegend=False, legendgroup=group, name=group,
                                 error_y=dict(type='data', array=gdata[y_var]['sem'])), row=1, col=i+1)
for value in [5.5, 10.5, 15.5, 20.5]:
    fig.add_vline(x=value, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Place Cells (%)', col=1)
fig.show()
# fig.write_image(pjoin(fig_path, 'all_place_criteria.png'))

In [None]:
## Percentage of place cells by spatial information and by coherence, one panel each.
yvars = ['place_skaggs_only', 'place_coherence_only']
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({y_var: ['mean', 'sem'] for y_var in yvars})

fig = pf.custom_graph_template(x_title='Day', y_title='',
                               titles=['Skaggs Only', 'Coherence Only'],
                               rows=1, columns=2, shared_y=True, shared_x=True, width=800)
for panel, y_var in enumerate(yvars, start=1):
    for group in avg_place['group'].unique():
        group_rows = avg_place.loc[avg_place['group'] == group]
        percent = group_rows[y_var]
        fig.add_trace(go.Scatter(x=group_rows['day'], y=percent['mean'], mode='lines+markers',
                                 marker_color=ce_colors_dict[group], showlegend=False, legendgroup=group, name=group,
                                 error_y=dict(type='data', array=percent['sem'])), row=1, col=panel)
for boundary in (5.5, 10.5, 15.5, 20.5):
    fig.add_vline(x=boundary, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(title='Place Cells (%)', col=1, range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percent_place_si_coherence.png'), width=800, height=500)

In [None]:
## Plot heatmap of where the mouse was when a non-place cell fired
## NOTE(review): `sdata` is whatever session the preceding collection loop loaded last - load the
## intended session explicitly if a specific mouse/session is meant.
neurons = [1, 13]
correct_dir = True
only_running = True
velocity_thresh = 10
cscale = 'gray_r'

neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running, velocity_thresh=velocity_thresh)
x_pos = neural_data['x'].values
y_pos = neural_data['y'].values
## Mean position of the samples recorded at each of the two reward ports.
rw_one_x = np.mean(sdata['x'][(sdata['lick_port'] == sdata.attrs['reward_one'])])
rw_one_y = np.mean(sdata['y'][(sdata['lick_port'] == sdata.attrs['reward_one'])])
rw_two_x = np.mean(sdata['x'][(sdata['lick_port'] == sdata.attrs['reward_two'])])
rw_two_y = np.mean(sdata['y'][(sdata['lick_port'] == sdata.attrs['reward_two'])])

fig = pf.custom_graph_template(x_title='', y_title='', rows=1, columns=2, titles=[f'Neuron {neurons[0]}', f'Neuron {neurons[1]}'], width=1000)
for col, neuron in enumerate(neurons):
    activity = neural_data.values[neuron]
    ## Count, per 2-D position bin, the samples in which the neuron was active.
    H, x_edges, y_edges = np.histogram2d(x_pos, y_pos, bins=40, weights=(activity > 0))
    peak = np.max(H)
    ## Normalize to the peak bin; a neuron with no active samples gives an all-NaN (blank) map
    ## instead of a divide-by-zero warning.
    norm_vals = H / peak if peak > 0 else np.full_like(H, np.nan)
    norm_vals[norm_vals == 0] = np.nan
    ## np.histogram2d returns H indexed as [x_bin, y_bin], whereas go.Heatmap reads z as
    ## [y_index][x_index]; transpose so the map lines up with the reward markers drawn in (x, y).
    fig.add_trace(go.Heatmap(x=x_edges, y=y_edges, z=norm_vals.T, colorscale=cscale, coloraxis='coloraxis1'), row=1, col=col + 1)

    fig.add_trace(go.Scattergl(x=[rw_one_x], y=[rw_one_y], mode='markers', marker_color='red',
                            showlegend=False, name='Reward 1', marker_size=8), row=1, col=col + 1)
    fig.add_trace(go.Scattergl(x=[rw_two_x], y=[rw_two_y], mode='markers', marker_color='red',
                            showlegend=False, name='Reward 2', marker_size=8), row=1, col=col + 1)
    ## Double-quoted f-strings: reusing the same quote type inside the braces is a SyntaxError
    ## before Python 3.12.
    fig.add_annotation(x=np.mean(x_pos), y=np.mean(y_pos),
            text=f"SI: {np.round(sdata['skaggs_info'].values[neuron], 2)}",
            showarrow=False, font=dict(size=12), row=1, col=col + 1)
    fig.add_annotation(x=np.mean(x_pos), y=np.mean(y_pos) - 15,
            text=f"Shuffled SI: {np.round(sdata['shuffled_avg_si'].values[neuron], 2)}",
            showarrow=False, font=dict(size=12), row=1, col=col + 1)

fig.update_coloraxes(colorscale=cscale)
fig.show()
# fig.write_image(pjoin(fig_path, f'example_nonplace_cells.png'), width=1000, height=500)

### Percentage of place cells when only looking at minimum number of trials.

In [None]:
## One row per mouse and session from the `aligned_place_cells_min_trials` output: percentage of
## place cells and mean spatial information over all, place and non-place cells.
## Note: this replaces the `place_df` built earlier; here `day` is a string taken from the file name.
data_type = 'S'
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'place_skaggs_only': [], 'spatial_info_all': [],
             'spatial_info_place': [], 'spatial_info_nonplace': []}

for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/aligned_place_cells_min_trials/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        for session in natsorted(os.listdir(mpath)):
            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            sdata = sdata[sdata['minimum_activity_met'], :] ## subset for cells that meet activity criterion
            num_cells = sdata.shape[0]
            info = sdata['skaggs_info'].values
            is_place = sdata['skaggs_place'].values
            ## File names end in `_<day>.nc`; take the text after the last underscore rather than a
            ## fixed character slice, which returns '_5' for a single-digit day.
            day = os.path.splitext(session)[0].split('_')[-1]

            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['sex'].append(sex)
            cell_dict['session'].append(sdata.attrs['session_two'])
            cell_dict['day'].append(day)
            cell_dict['place_skaggs_only'].append((np.sum(is_place) / num_cells) * 100)
            cell_dict['spatial_info_all'].append(np.mean(info))
            cell_dict['spatial_info_place'].append(np.mean(info[is_place]))
            cell_dict['spatial_info_nonplace'].append(np.mean(info[~is_place]))
place_df = pd.DataFrame(cell_dict)

In [None]:
## Plot percentage of place cells on the final switch
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({'place_skaggs_only': ['mean', 'sem']})
## Relabel days in the day column only. A frame-wide `replace` would also rewrite any
## mean/sem value that happens to equal exactly 1 or 2.
switch = avg_place.copy()
switch[('day', '')] = switch[('day', '')].replace({1: '15', 2: '16'})
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)')
for group in ['Two-context', 'Multi-context']:
    gdata = switch[switch['group'] == group]
    ## Group mean across mice with SEM error bars.
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['place_skaggs_only']['mean'], mode='markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=gdata['place_skaggs_only']['sem'])))
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percentage_place_skaggs_only_min_trials.png'), width=500, height=500)

In [None]:
## Fit an OLS model of place-cell percentage on group for day 16 and print the
## type-3 ANOVA table.
is_day16 = place_df['day'] == '16'
data = place_df.loc[is_day16].reset_index(drop=True)
formula = 'place_skaggs_only ~ C(group)'
model = ols(formula, data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

In [None]:
## Plot of average spatial information for final switch
si_columns = ['spatial_info_all', 'spatial_info_place', 'spatial_info_nonplace']
avg = place_df.groupby(['group', 'day'], as_index=False).agg({name: ['mean', 'sem'] for name in si_columns})

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=3, shared_x=True, shared_y=True,
                               titles=['All Cells', 'Place Cells', 'Non-place Cells'], width=1000)

for group in ['Two-context', 'Multi-context']:
    gdata = avg[avg['group'] == group]
    ## One panel per cell population: all cells, place cells, non-place cells.
    for panel_col, metric in enumerate(si_columns, start=1):
        fig.add_trace(go.Scattergl(x=gdata['day'], y=gdata[metric]['mean'], mode='lines+markers', name=group, legendgroup=group,
                                   ## only the first panel contributes a legend entry per group
                                   showlegend=None if panel_col == 1 else False,
                                   line_color=ce_colors_dict[group],
                                   error_y=dict(type='data', array=gdata[metric]['sem'])), row=1, col=panel_col)
fig.update_yaxes(title='Spatial Information (bits/event)', col=1)
fig.update_yaxes(range=[0, 5])
fig.show()
fig.write_image(pjoin(fig_path, 'all_cells_place_cells_si_first27_trials.png'), width=1000, height=500)

In [None]:
## Fit an OLS model of non-place-cell spatial information on group for day 16 and
## print the type-3 ANOVA table.
is_day16 = place_df['day'] == '16'
data = place_df.loc[is_day16].reset_index(drop=True)
formula = 'spatial_info_nonplace ~ C(group)'
model = ols(formula, data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

### Percentage of place cells with the last 27 trials.

In [None]:
## Build one row per mouse and session from the `aligned_place_cells_last27` output:
## percentage of cells flagged `skaggs_place` and mean `skaggs_info` for all cells,
## flagged cells, and unflagged cells.
data_type = 'S'
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'place_skaggs_only': [], 'spatial_info_all': [],
             'spatial_info_place': [], 'spatial_info_nonplace': []}

for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue  ## only folders listed in `experiment_folders` are analysed
    exp_path = pjoin(dpath, f'{experiment}/output/aligned_place_cells_last27/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            index = ctn.mouse_indices(mouse, index)
            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            sdata = sdata[sdata['minimum_activity_met'], :] ## keep only cells that meet the activity criterion
            num_cells = sdata.shape[0]

            is_place = sdata['skaggs_place'].values
            skaggs_info = sdata['skaggs_info'].values
            session_row = {
                'mouse': mouse,
                'group': group,
                'sex': sex,
                'session': sdata.attrs['session_two'],
                'day': session[-5:-3],  ## two characters preceding the 3-character file extension
                'place_skaggs_only': (np.sum(is_place) / num_cells) * 100,
                'spatial_info_all': np.mean(skaggs_info),
                'spatial_info_place': np.mean(skaggs_info[is_place]),
                'spatial_info_nonplace': np.mean(skaggs_info[~is_place]),
            }
            for column, value in session_row.items():
                cell_dict[column].append(value)
place_df = pd.DataFrame(cell_dict)

In [None]:
## Plot percentage of place cells on the final switch
avg_place = place_df.groupby(['group', 'day'], as_index=False).agg({'place_skaggs_only': ['mean', 'sem']})
## Relabel days in the day column only. A frame-wide `replace` would also rewrite any
## mean/sem value that happens to equal exactly 1 or 2.
switch = avg_place.copy()
switch[('day', '')] = switch[('day', '')].replace({1: '15', 2: '16'})
fig = pf.custom_graph_template(x_title='Day', y_title='Place Cells (%)')
for group in ['Two-context', 'Multi-context']:
    gdata = switch[switch['group'] == group]
    ## Group mean across mice with SEM error bars.
    fig.add_trace(go.Scatter(x=gdata['day'], y=gdata['place_skaggs_only']['mean'], mode='markers', line_color=ce_colors_dict[group],
                             name=group, error_y=dict(type='data', array=gdata['place_skaggs_only']['sem'])))
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percentage_place_skaggs_last27.png'), width=500, height=500)

In [None]:
## Plot of average spatial information for final switch
si_columns = ['spatial_info_all', 'spatial_info_place', 'spatial_info_nonplace']
avg = place_df.groupby(['group', 'day'], as_index=False).agg({name: ['mean', 'sem'] for name in si_columns})

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=3, shared_x=True, shared_y=True,
                               titles=['All Cells', 'Place Cells', 'Non-place Cells'], width=1000)

for group in ['Two-context', 'Multi-context']:
    gdata = avg[avg['group'] == group]
    ## One panel per cell population: all cells, place cells, non-place cells.
    for panel_col, metric in enumerate(si_columns, start=1):
        fig.add_trace(go.Scattergl(x=gdata['day'], y=gdata[metric]['mean'], mode='lines+markers', name=group, legendgroup=group,
                                   ## only the first panel contributes a legend entry per group
                                   showlegend=None if panel_col == 1 else False,
                                   line_color=ce_colors_dict[group],
                                   error_y=dict(type='data', array=gdata[metric]['sem'])), row=1, col=panel_col)
fig.update_yaxes(title='Spatial Information (bits/event)', col=1)
fig.update_yaxes(range=[0, 5])
fig.show()
fig.write_image(pjoin(fig_path, 'all_cells_place_cells_si_last27_trials.png'), width=1000, height=500)

In [None]:
## Fit an OLS model of place-cell spatial information on group for day 16 and print
## the type-3 ANOVA table.
is_day16 = place_df['day'] == '16'
data = place_df.loc[is_day16].reset_index(drop=True)
formula = 'spatial_info_place ~ C(group)'
model = ols(formula, data=data).fit()
anova_table = sm.stats.anova_lm(model, typ=3)
print(anova_table)

### Percentage of cells vs spatial information.

In [None]:
## Deciles (0th-100th percentile in steps of 10) of `skaggs_info` for every mouse and
## session, then the group mean/SEM of each decile per day.
data_type = 'S'
## File index after which the day counter is shifted by one for a given mouse.
## NOTE(review): presumably each of these mice is missing one recording -- confirm.
skipped_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}
decile_array = np.arange(0, 101, 10)
spatial_info_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'decile': [], 'spatial_info': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian_place/')
    for mouse in os.listdir(exp_path):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        ## `day` is derived from the position of the file in the listing, so the listing
        ## must be in natural order; `os.listdir` alone returns an arbitrary order.
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            if '21' in session:
                continue
            if index > skipped_after.get(mouse, np.inf):
                index += 1

            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            deciles = np.percentile(sdata['skaggs_info'], decile_array)
            for decile, decile_info_val in zip(decile_array, deciles):
                spatial_info_dict['mouse'].append(mouse)
                spatial_info_dict['group'].append(group)
                spatial_info_dict['sex'].append(sex)
                spatial_info_dict['session'].append(sdata.attrs['session_two'])
                spatial_info_dict['day'].append(index+1)
                spatial_info_dict['decile'].append(decile)
                spatial_info_dict['spatial_info'].append(decile_info_val)
spatial_df = pd.DataFrame(spatial_info_dict)
avg_spatial = spatial_df.groupby(['group', 'day', 'decile'], as_index=False).agg({'spatial_info': ['mean', 'sem']})

In [None]:
## 4 x 5 grid (one panel per day) of spatial-information deciles, one trace per group.
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, titles=day_list, height=1000, width=1000,
                               shared_y=True, shared_x=True)

## `ce_colors_dict` is keyed by 'Two-context'/'Multi-context'; the 'Control'/'Experimental'
## labels (both derived from membership in `control_mice`) are translated before the lookup.
## Labels that are already colour keys pass through unchanged.
legacy_group_labels = {'Control': 'Two-context', 'Experimental': 'Multi-context'}

for group in avg_spatial['group'].unique():
    gdata = avg_spatial[avg_spatial['group'] == group]
    group_color = ce_colors_dict[legacy_group_labels.get(group, group)]
    for day in gdata['day'].unique():
        day_data = gdata[gdata['day'] == day]
        day_idx = int(day) - 1 ## zero-based day used to locate the panel
        ## Five panels per row; anything from day index 15 onwards stays on row 4.
        row = min(day_idx // 5, 3) + 1
        col = day_idx - 5 * (row - 1) + 1

        fig.add_trace(go.Scatter(x=day_data['spatial_info']['mean'], y=day_data['decile'], mode='markers', showlegend=False,
                                marker_color=group_color, opacity=0.6,
                                error_x=dict(type='data', array=day_data['spatial_info']['sem'])), row=row, col=col)
fig.update_yaxes(title='Percent of Cells', col=1)
fig.update_xaxes(title='Spatial Info', row=4)
fig.show()
fig.write_image(pjoin(fig_path, 'percent_of_cells_with_spatial_info.png'))

### Snake plots for a single mouse for each session.

In [None]:
## Snake plots for one mouse: one heatmap panel per session, cells sorted by the
## position bin of their peak tuning-curve value within that session.
data_type = 'S'
mouse = 'mc58'
experiment = 'MultiCon_Imaging6'
correct_dir = True
only_running = True
normalize_plot = True
colorscale = 'cividis'

## NOTE(review): other cells in this notebook re-assign `session_list`; re-run the settings
## cell first if the panel titles look wrong.
title_list = control_list if mouse in control_mice else session_list

fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, width=1100, height=1100,
                               titles=title_list, font_size=15)

## File index after which the panel index is shifted by one for a given mouse.
## NOTE(review): presumably each of these mice is missing one recording -- confirm.
skipped_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}

exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian/')
mpath = pjoin(exp_path, f'{mouse}/{data_type}')
for idx, session in enumerate(natsorted(os.listdir(mpath))):
    if '21' in session:
        continue
    if idx > skipped_after.get(mouse, np.inf):
        idx += 1

    ## Five panels per row; anything from index 15 onwards stays on row 4.
    row = min(idx // 5, 3) + 1
    col = idx - 5 * (row - 1) + 1

    sdata = xr.open_dataset(pjoin(mpath, f'{session}'))[data_type]
    ## Mean linear position of samples at each rewarded port where water was delivered.
    reward_one_pos = np.mean(sdata['lin_position'][(sdata['lick_port'] == sdata.attrs['reward_one']) & (sdata['water'])].values)
    reward_two_pos = np.mean(sdata['lin_position'][(sdata['lick_port'] == sdata.attrs['reward_two']) & (sdata['water'])].values)

    neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running, velocity_thresh=velocity_thresh)
    population_activity, occupancy, bins = pc.spatial_activity(neural_data, position_data, bin_size=bin_size)
    tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
    ## Rows of `resh_pop` are used as neurons below (argmax over axis 1, y axis of length
    ## sdata.shape[0]), so the second axis of `tuning_curves` must be neurons. Swapping the
    ## axes needs a transpose: reshape((shape[1], shape[0])) keeps the flattened order and
    ## mixes bins and neurons.
    ## NOTE(review): confirm pc.get_tuning_curves returns (bins, neurons).
    resh_pop = tuning_curves.T
    centers_resh = np.argmax(resh_pop, axis=1)
    order_resh = np.argsort(centers_resh)

    if normalize_plot:
        ## Scale every cell by its own peak; a cell whose peak is 0 becomes NaN.
        resh_pop = resh_pop / np.max(resh_pop, axis=1, keepdims=True)

    fig.add_trace(go.Heatmap(x=bins, y=np.arange(0, sdata.shape[0]), z=resh_pop[order_resh], coloraxis='coloraxis1'), row=row, col=col)
    for pos in [reward_one_pos, reward_two_pos]:
        fig.add_vline(x=pos, line_width=1, line_color='red', opacity=1, row=row, col=col)

fig.update_yaxes(autorange='reversed')
fig.update_yaxes(title='Neuron ID', col=1)
fig.update_xaxes(title='Lin Position (rad)', row=4)
fig.update_layout(coloraxis1={'colorscale': colorscale})
fig.show()
fig.write_image(pjoin(fig_path, f'snake_plots_all_sessions_{mouse}.png'), width=1100, height=1100)

### Snake plots for a single mouse for each context.

In [None]:
## Snake plots for the cells cross-registered across the five sessions of one context.
## Cells are ordered by their peak position bin in `ordering_session`, and that order
## is reused for every panel.
data_type = 'S'
## NOTE(review): this re-binds the notebook-wide `session_list` from the settings cell;
## cells that read `session_list` afterwards see these five day numbers instead.
session_list = [f'{x}' for x in np.arange(16, 21)]
ordering_session = '20'
context_str = 'fourth_five'
mouse = 'mc46'
experiment = 'MultiCon_Imaging5'
correct_dir = True
only_running = True
smooth = False
normalize_plot = True
colorscale = 'viridis'

fig = pf.custom_graph_template(x_title='Position (rad)', y_title='', rows=1, columns=5, height=600, width=1200,
                               shared_y=True, shared_x=True)
exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian/')
mpath = pjoin(exp_path, f'{mouse}/{data_type}')
crossreg_path = pjoin(dpath, f'{experiment}/output/cross_registration_results')
mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/mappings_meta_{centroid_distance}_{context_str}.pkl'))
mappings.columns = mappings.columns.droplevel(0)

## Recording dates of the sessions that exist on disk. Missing session files are
## skipped explicitly instead of swallowing every exception.
date_list = []
for session in session_list:
    session_file = pjoin(mpath, f'{mouse}_{data_type}_{session}.nc')
    if os.path.isfile(session_file):
        date_list.append(xr.open_dataset(session_file)[data_type].attrs['date'])

## Keep only cells that were registered in every one of those sessions.
shared_cells = mappings[date_list].dropna().reset_index(drop=True)

## Order cells based on specified session
sdata = xr.open_dataset(pjoin(mpath, f'{mouse}_{data_type}_{ordering_session}.nc'))[data_type]
sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values)
reward_one_pos = np.mean(sdata['lin_position'][sdata['lick_port'] == sdata.attrs['reward_one']].values)
reward_two_pos = np.mean(sdata['lin_position'][sdata['lick_port'] == sdata.attrs['reward_two']].values)

neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, lin_pos_col='lin_position',
                                                                only_running=only_running, velocity_thresh=velocity_thresh)
if smooth:
    smoothed_data = ctn.moving_average(neural_data, ksize=8) ## 264 ms
else:
    smoothed_data = neural_data.copy()
population_activity, occupancy, bins = pc.spatial_activity(smoothed_data, position_data, bin_size=bin_size)
tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
## Rows of `resh_pop` are used as cells (it is stored in an array shaped
## (sessions, shared cells, bins)), so the second axis of `tuning_curves` must be cells.
## Swapping the axes needs a transpose: reshape((shape[1], shape[0])) keeps the flattened
## order and mixes bins and cells.
## NOTE(review): confirm pc.get_tuning_curves returns (bins, cells).
resh_pop = tuning_curves.T
centers_resh = np.argmax(resh_pop, axis=1)
order_resh = np.argsort(centers_resh)

## NaN-filled so sessions that are skipped below do not hold uninitialised memory.
cross_session_act = np.full((len(session_list), shared_cells.shape[0], bins.shape[0]-1), np.nan)
for idx, session in enumerate(session_list):
    if (mouse == 'mc56') and (context_str == 'first_five') and (idx == 0):
        continue
    session_file = pjoin(mpath, f'{mouse}_{data_type}_{session}.nc')
    if not os.path.isfile(session_file):
        continue  ## leave this panel empty

    sdata = xr.open_dataset(session_file)[data_type]
    sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values)
    neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, lin_pos_col='lin_position',
                                                                    only_running=only_running, velocity_thresh=velocity_thresh)
    if smooth:
        smoothed_data = ctn.moving_average(neural_data, ksize=8) ## 264 ms
    else:
        smoothed_data = neural_data.copy()
    population_activity, occupancy, bins = pc.spatial_activity(smoothed_data, position_data, bin_size=bin_size)
    tuning_curves = pc.get_tuning_curves(population_activity, occupancy)
    resh_pop = tuning_curves.T  ## (cells, bins); see note above
    cross_session_act[idx, :, :] = resh_pop  ## stored before any normalisation
    neur_array = np.arange(0, sdata.shape[0])

    if normalize_plot:
        ## Scale every cell by its own peak; a cell whose peak is 0 becomes NaN.
        resh_pop = resh_pop / np.max(resh_pop, axis=1, keepdims=True)

    fig.add_trace(go.Heatmap(x=bins, y=neur_array, z=resh_pop[order_resh], coloraxis='coloraxis1'), row=1, col=idx+1)
fig.update_yaxes(autorange='reversed')
fig.update_yaxes(title='Neuron ID', col=1)
fig.update_layout(coloraxis1={'colorscale': colorscale})
## Reward positions come from the ordering session and are drawn on every panel.
for pos in [reward_one_pos, reward_two_pos]:
    fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color='red', opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_snake_plots_{context_str}_{ordering_session}.png'))

In [None]:
## Debug cell: displays whatever `position_data` was last assigned by a previously run cell.
position_data

In [None]:
## For every mouse and context (block of five sessions): correlate each cross-registered
## cell's spatial activity in a reference session (first or last of the block) with its
## activity in every session of the block. One row per mouse, context, day and cell.
data_type = 'S'
session_dict = {'first_five': [f'{x}' for x in np.arange(1, 6)], 'second_five': [f'{x}' for x in np.arange(6, 11)],
                'third_five': [f'{x}' for x in np.arange(11, 16)], 'fourth_five': [f'{x}' for x in np.arange(16, 21)]}
context_str_list = ['first_five', 'second_five', 'third_five', 'fourth_five']
experiment_list = ['MultiCon_Imaging5', 'MultiCon_Imaging6']
order_by = 'first' ## last or first
if order_by not in ('first', 'last'):
    ## Otherwise `ordering_session` would silently keep a value left over from another cell.
    raise ValueError(f"order_by must be 'first' or 'last', got {order_by!r}")
if order_by == 'last':
    mouse_list = ['mc44', 'mc46', 'mc48', 'mc49', 'mc51', 'mc52', 'mc54', 'mc56', 'mc58']
else:
    mouse_list = ['mc44', 'mc48', 'mc49', 'mc51', 'mc52', 'mc54', 'mc58']
correct_dir = True
only_running = True
res_dict = {'mouse': [], 'group': [], 'sex': [], 'day': [], 'context_str': [], 'cell_idx': [], 'r': [], 'pval': []}

for experiment in experiment_list:
    exp_path = pjoin(dpath, f'{experiment}/output/aligned_minian/')
    for mouse in os.listdir(exp_path):
        if mouse not in mouse_list:
            continue
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        for context_str in context_str_list:
            context_sessions = session_dict[context_str]
            crossreg_path = pjoin(dpath, f'{experiment}/output/cross_registration_results')
            mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/mappings_meta_{centroid_distance}_{context_str}.pkl'))
            mappings.columns = mappings.columns.droplevel(0)

            ## Recording dates of the sessions that exist on disk. Missing session files
            ## are skipped explicitly instead of swallowing every exception.
            date_list = []
            for session in context_sessions:
                session_file = pjoin(mpath, f'{mouse}_{data_type}_{session}.nc')
                if os.path.isfile(session_file):
                    date_list.append(xr.open_dataset(session_file)[data_type].attrs['date'])

            ## Keep only cells that were registered in every one of those sessions.
            shared_cells = mappings[date_list].dropna().reset_index(drop=True)

            ## Order cells based on specified session
            ordering_session = context_sessions[-1] if order_by == 'last' else context_sessions[0]

            sdata = xr.open_dataset(pjoin(mpath, f'{mouse}_{data_type}_{ordering_session}.nc'))[data_type]
            sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values)
            reward_one_pos = np.mean(sdata['lin_position'][sdata['lick_port'] == sdata.attrs['reward_one']].values)
            reward_two_pos = np.mean(sdata['lin_position'][sdata['lick_port'] == sdata.attrs['reward_two']].values)

            neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running, velocity_thresh=velocity_thresh)
            population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
            ## Rows of `resh_pop` are used as cells (stored below in an array shaped
            ## (sessions, shared cells, bins) and correlated row by row), so the second axis
            ## of `population_activity` must be cells. Swapping the axes needs a transpose:
            ## reshape((shape[1], shape[0])) keeps the flattened order and mixes bins and cells.
            ## NOTE(review): confirm pc.average_spatial_activity returns (bins, cells).
            resh_pop = population_activity.T
            centers_resh = np.argmax(resh_pop, axis=1)
            order_resh = np.argsort(centers_resh)

            ## NaN marks sessions that were skipped or are missing on disk.
            cross_session_act = np.full((len(context_sessions), shared_cells.shape[0], bins.shape[0]-1), np.nan)
            for idx, session in enumerate(context_sessions):
                if (mouse == 'mc56') and (context_str == 'first_five') and (idx == 0):
                    continue
                session_file = pjoin(mpath, f'{mouse}_{data_type}_{session}.nc')
                if not os.path.isfile(session_file):
                    continue

                sdata = xr.open_dataset(session_file)[data_type]
                sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values)
                neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir,
                                                                                only_running=only_running, velocity_thresh=velocity_thresh)
                population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
                cross_session_act[idx, :, :] = population_activity.T  ## (cells, bins); see note above

            ref_idx = context_sessions.index(ordering_session)
            ref_sess = cross_session_act[ref_idx]

            for day in np.arange(0, cross_session_act.shape[0]):
                mat = cross_session_act[day]
                if np.isnan(mat).any():
                    ## Sessions with any NaN are reported and left out of the results.
                    print(f'Nans found for {mouse} in the {context_str} on {day}!')
                    continue
                for cell in np.arange(0, mat.shape[0]):
                    ## Pearson correlation of this cell's activity across position bins
                    ## between the reference session and the current session.
                    res = pearsonr(ref_sess[cell], mat[cell])

                    res_dict['mouse'].append(mouse)
                    res_dict['group'].append(group)
                    res_dict['sex'].append(sex)
                    res_dict['day'].append(day+1)
                    res_dict['context_str'].append(context_str)
                    res_dict['cell_idx'].append(cell)
                    res_dict['r'].append(res[0])
                    res_dict['pval'].append(res[1])
res_df = pd.DataFrame(res_dict)

In [None]:
## Mean +/- SEM of the per-cell correlations for each group, context and day,
## one panel per context.
avg_res = res_df.groupby(['group', 'context_str', 'day'], as_index=False).agg({'r': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=4, width=1000, shared_y=True, shared_x=True,
                               titles=['A', 'B', 'C', 'D'])
## `res_df` labels groups 'Control'/'Experimental', while the colour table from the settings
## cell (`ce_colors_dict`) is keyed by 'Two-context'/'Multi-context'. Both pairs are derived
## from membership in `control_mice`.
group_color_key = {'Control': 'Two-context', 'Experimental': 'Multi-context'}
## Panel column per context; any other label falls into column 4.
context_columns = {'first_five': 1, 'second_five': 2, 'third_five': 3}
for group in ['Control', 'Experimental']:
    gdata = avg_res[avg_res['group'] == group]
    for context_str in avg_res['context_str'].unique():
        col = context_columns.get(context_str, 4)

        cdata = gdata[gdata['context_str'] == context_str]
        fig.add_trace(go.Scatter(x=cdata['day'], y=cdata['r']['mean'], mode='markers', marker_color=ce_colors_dict[group_color_key[group]], opacity=0.6,
                                 name=group, error_y=dict(type='data', array=cdata['r']['sem']), showlegend=False), row=1, col=col)
fig.update_yaxes(title='PV Correlation', col=1)
fig.show()

### Reproducing Will's snake plots

In [None]:
## Set-up for snake plots on a separate dataset: load the cross-registered cells of one
## mouse and compute the cell ordering from `ordering_session`.
data_type = 'S'
## NOTE(review): absolute, machine-specific path.
exp_path = '/media/caishuman/csstorage/phild/git/MazeProjects/output'
## NOTE(review): re-binds the notebook-wide `session_list` from the settings cell.
session_list = ['Training1', 'Training2', 'Training3', 'Training4', 'Reversal1']
ordering_session = 'Training4'
mouse = 'Atlas'
correct_dir = True
only_running = True
normalize_plot = False

fig = pf.custom_graph_template(x_title='Position (rad)', y_title='', rows=1, columns=5, width=1000,
                               shared_y=True, shared_x=True)

mappings = pd.read_feather(pjoin(exp_path, f'registration/{mouse}.feat'))
## Keep only cells that were registered in every session column of the mapping table.
shared_cells = mappings.dropna().reset_index(drop=True)

## Order cells based on specified session
sdata = xr.open_dataset(pjoin(exp_path, f'processed/{mouse}_{ordering_session}.nc'))[data_type]
sdata = sdata.sel(unit_id=shared_cells[ordering_session].values)
behav = pd.read_feather(pjoin(exp_path, f'behav/{mouse}_{ordering_session}.feat'))

## Attach the behaviour columns as per-frame coordinates so the subsetting helper can
## read them from the dataset. (The earlier direct assignments of `neural_data` and
## `position_data` were overwritten by the call below and have been dropped.)
sdata = sdata.assign_coords(trials=('frame', behav['trials']),
                            lin_position=('frame', behav['lin_position']),
                            x=('frame', behav['x']),
                            y=('frame', behav['y']))
neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir,
                                                                only_running=only_running, velocity_thresh=velocity_thresh)


population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
## Rows of `resh_pop` are used as cells (argmax over axis 1 gives each cell's peak bin).
## Swapping the axes needs a transpose: reshape((shape[1], shape[0])) keeps the flattened
## order and mixes bins and cells.
## NOTE(review): confirm pc.average_spatial_activity returns (bins, cells).
resh_pop = population_activity.T
centers_resh = np.argmax(resh_pop, axis=1)
order_resh = np.argsort(centers_resh)

# cross_session_act = np.empty((len(session_list), shared_cells.shape[0], bins.shape[0]-1))
# for idx, session in enumerate(session_list):
#     try:
#         sdata = xr.open_dataset(pjoin(exp_path, f'processed/{mouse}_{session}.nc'))[data_type]
#     except:
#         sdata = None

#     if sdata is not None:
#         sdata = sdata.sel(unit_id=shared_cells[session].values)
#         behav = pd.read_feather(pjoin(exp_path, f'behav/{mouse}_{session}.feat'))
#         neural_data = sdata.values
#         position_data = behav['lin_position'].values
#         # neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, 
#         #                                                                 only_running=only_running, velocity_thresh=velocity_thresh)
#         population_activity, occupancy, bins = pc.average_spatial_activity(neural_data, position_data, bin_size=bin_size)
#         resh_pop = population_activity.reshape((population_activity.shape[1], population_activity.shape[0]))
#         cross_session_act[idx, :, :] = resh_pop
#         neur_array = np.arange(0, sdata.shape[0])

#         if normalize_plot:
#             max_vals = np.max(resh_pop, axis=1)
#             for uid in np.arange(0, resh_pop.shape[0]):
#                 resh_pop[uid, :] = resh_pop[uid, :] / max_vals[uid]

#         fig.add_trace(go.Heatmap(x=bins, y=neur_array, z=resh_pop[order_resh], 
#                         coloraxis='coloraxis1'), row=1, col=idx+1)
#     else:
#         pass
# fig.update_yaxes(autorange='reversed')
# fig.update_yaxes(title='Neuron ID', col=1)
# fig.show()

### Distribution of spatial information on day 16

In [None]:
## Overlaid probability histograms of `skaggs_info`, pooled over all cells of every
## mouse in each group, for files whose name contains '16'.
data_type = 'S'
fig = pf.custom_graph_template(x_title='Spatial Information (bits/event)', y_title='Probability', width=600)
## Each list starts with a single NaN placeholder, which is filtered out before plotting.
info_chunks = {'Two-context': [np.full(1, np.nan)], 'Multi-context': [np.full(1, np.nan)]}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Two-context' if mouse in control_mice else 'Multi-context'
        for index, session in enumerate(natsorted(os.listdir(mpath))):
            if '16' in session:
                sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
                info_chunks[group].append(sdata['skaggs_info'].values)
tc_info = np.concatenate(info_chunks['Two-context'])
mc_info = np.concatenate(info_chunks['Multi-context'])

## Create figure
for group, pooled_info in (('Two-context', tc_info), ('Multi-context', mc_info)):
    plot_data = pooled_info[~np.isnan(pooled_info)]
    fig.add_trace(go.Histogram(x=plot_data, marker_color=ce_colors_dict[group], histnorm='probability',
                               name=group, opacity=0.6))
fig.update_layout(barmode='overlay')
fig.show()
fig.write_image(pjoin(fig_path, 'spatial_info_distribution_all_mice.png'), width=600, height=500)

### Look at mutual information

In [None]:
## Loop through multiple bin sizes and measure mutual information
## NOTE(review): `sdata` is whatever dataset a previously run cell loaded last.
correct_dir = True
only_running = True

mi_dict = {'bin_size': [], 'mi_avg': [], 'mi_sem': []}
## The loop variable has its own name so the notebook-wide `bin_size` setting is not
## overwritten with the last value of the sweep.
for mi_bin_size in np.arange(0.16, 1.16, 0.16):
    neural_data, position_data = ctn.subset_correct_dir_and_running(sdata, correct_dir=correct_dir, only_running=only_running,
                                                                    velocity_thresh=velocity_thresh)
    population_activity, occupancy, bins = pc.spatial_activity(neural_data, position_data, bin_size=mi_bin_size, binarized=False)

    ## Calculate mutual information
    ## Activity is scaled by 100 and truncated to integers so it can be used as a discrete
    ## label; the other variable is the position-bin index.
    discrete_bins = np.arange(0, bins.shape[0]-1)
    mi = [mutual_info_score((population_activity[:, uid] * 100).astype(int), discrete_bins) for uid in np.arange(0, population_activity.shape[1])]

    mi_dict['bin_size'].append(mi_bin_size)
    mi_dict['mi_avg'].append(np.mean(mi))
    ## Standard error of the mean is std / sqrt(n), not std / n.
    mi_dict['mi_sem'].append(np.std(mi, ddof=1) / np.sqrt(len(mi)))
df = pd.DataFrame(mi_dict)

In [None]:
## Plot distribution of MI for a single mouse
## `mi` is whatever the last MI computation left behind (after the bin-size sweep, that is the largest bin size)
## Outer double quotes: reusing the same quote type inside an f-string is a SyntaxError before Python 3.12
fig = pf.custom_graph_template(x_title='Mutual Information (nats)', y_title='Probability', titles=[f"{sdata.attrs['animal']}"])
fig.add_trace(go.Histogram(x=mi, marker_color='darkgrey', histnorm='probability'))
fig.show()
## NOTE(review): the file name uses the leftover loop variable `mouse`, the title uses sdata.attrs['animal'] -- confirm they refer to the same animal
fig.write_image(pjoin(fig_path, f'{mouse}_mi_example.png'), width=500, height=500)

In [None]:
## Plot line graph of MI for different bin sizes (mean across cells, SEM error bars)
fig = pf.custom_graph_template(x_title='Bin Size (rad)', y_title='Mutual Information (nats)')
mi_error_bars = dict(type='data', array=df['mi_sem'])
mi_curve = go.Scattergl(
    x=df['bin_size'],
    y=df['mi_avg'],
    mode='lines+markers',
    line_color='darkgrey',
    error_y=mi_error_bars,
)
fig.add_trace(mi_curve)
fig.show()
fig.write_image(pjoin(fig_path, 'mi_curve_w_different_bin_sizes.png'), width=500, height=500)

In [None]:
## Open one saved mutual-information output file and display the `mutual_info_bool` values stored on its 'S' variable
## NOTE(review): this is an absolute, machine-specific path; the rest of the notebook uses paths relative to the
## project folder -- consider building it from the Settings paths so the cell runs on other machines
test_path = '/media/caishuman/csstorage3/Austin/CircleTrack/MultiCon_Imaging/MultiCon_Imaging5/output/aligned_mutual_info/mc44/S/mc44_S_1.nc'
t = xr.open_dataset(test_path)['S']
## Last expression of the cell, so the array is shown as the cell output
t['mutual_info_bool'].values

### Mutual info shuffle testing

In [None]:
## Settings for the mutual-information shuffle test
project_dir = 'MultiCon_Imaging'
experiment_dir = 'MultiCon_Imaging5'
mouse_list = ['mc44']
## The input folder gets its own name: reassigning the notebook-wide `dpath` here would overwrite the
## project-level path that other cells pass to os.listdir() to find the experiment folders
place_cell_path = f'../../{project_dir}/{experiment_dir}/output/aligned_place_cells/'
spath = f'../../{project_dir}/{experiment_dir}/output/aligned_mutual_info/'
data_type = 'S'
only_running = True
correct_dir = True
smooth = False
binarized = False
velocity_thresh = 10
bin_size = 0.16
nshuffles = 500
percentile = 95
min_event_amount = 0.2
min_trials = None ## lowest number of trials a mouse ran on day 16 is 27 trials, which was a two-context mouse

for mouse in tqdm(mouse_list):
    mpath = pjoin(place_cell_path, f'{mouse}/{data_type}')
    for session in natsorted(os.listdir(mpath)):
        ## Only session files whose name contains '20' are analysed
        if '20' not in session:
            continue
        print(session)
        save_path = pjoin(spath, f'{mouse}/{data_type}')
        sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
        num_neurons = sdata.shape[0]
        minimum_act_bool = pc.minimum_activity_level(sdata, minimum_event_amount=min_event_amount, bin_size_seconds=60, fps=30, func=np.sum, binarized=0)

        ## Optionally keep only frames whose trial number is <= min_trials
        if min_trials is not None:
            min_data = sdata[:, sdata['trials'] <= min_trials]
        else:
            min_data = sdata.copy()

        ## Optionally smooth the activity with a moving average
        if smooth:
            smoothed_data = ctn.moving_average_xarray(min_data, ksize=8) 
        else:
            smoothed_data = min_data.copy()

        neural_data, position_data = ctn.subset_correct_dir_and_running(smoothed_data, correct_dir=correct_dir, only_running=only_running, 
                                                                        velocity_thresh=velocity_thresh)
        ## Calculate mutual information (nats) between each cell's binned activity and the spatial bin index
        population_activity, occupancy, bins = pc.spatial_activity(neural_data, position_data, bin_size=bin_size, binarized=binarized)
        discrete_bins = np.arange(0, bins.shape[0]-1)
        mi = [mutual_info_score((population_activity[:, uid] * 100).astype(int), discrete_bins) for uid in np.arange(0, population_activity.shape[1])]

        ## Shuffle position relative to neural activity and recalculate mutual info
        ## NOTE(review): np.random is not seeded in this cell, so the null distribution differs between runs
        ar = neural_data.copy()
        lin_position = ar['lin_position'].values
        trial_labels = ar['trials'].values
        ## The frames belonging to each trial are identical for every shuffle, so locate them once
        trial_frames = [np.flatnonzero(trial_labels == trial) for trial in np.unique(trial_labels)]
        shuffled_mutual = np.zeros((nshuffles, ar.shape[0]))
        for shuffle in np.arange(0, nshuffles):
            ## Circularly shift the position trace within each trial by a random number of frames.
            ## Writing by frame index keeps every shifted value on its own trial's frames even if
            ## the trials are not stored contiguously in ascending order.
            shuffled_data = np.empty(lin_position.shape[0], dtype=float)
            for frames in trial_frames:
                random_shift = np.random.randint(0, frames.shape[0])
                shuffled_data[frames] = np.roll(lin_position[frames], random_shift)
            ## Bin the (unchanged) activity against the shuffled positions
            ar_new = ar.copy()
            ar_new = ar_new.assign_coords(reordered_pos=('frame', shuffled_data))
            ## `_` holds the occupancy of the most recent shuffle; a later plotting cell reads it
            shuf_act, _, bins = pc.spatial_activity(ar_new.values, ar_new['reordered_pos'].values, bin_size=bin_size, binarized=False)
            ## Calculate mutual information for the shuffled data
            discrete_bins = np.arange(0, bins.shape[0]-1)
            shuff_mi = [mutual_info_score((shuf_act[:, uid] * 100).astype(int), discrete_bins) for uid in np.arange(0, shuf_act.shape[1])]
            shuffled_mutual[shuffle, :] = shuff_mi
        ## A cell passes when its observed MI exceeds the chosen percentile of its own shuffled distribution
        mutual_cells = np.array([mi[neuron] > np.percentile(shuffled_mutual[:, neuron], percentile) for neuron in np.arange(0, num_neurons)])

In [None]:
## Plot one neuron's shuffled-MI distribution with its observed MI marked by a dashed red line
uid = 59
fig = pf.custom_graph_template(x_title='Mutual Information (nats)', y_title='Probability', titles=[f'Neuron {uid}'])
shuffle_hist = go.Histogram(
    x=shuffled_mutual[:, uid],
    histnorm='probability',
    marker=dict(color='darkgrey', line=dict(width=2, color='black')),
)
fig.add_trace(shuffle_hist)
fig.add_vline(x=mi[uid], line_width=1, line_color='red', line_dash='dash', opacity=1)
fig.update_xaxes(range=[0, 4])
fig.show()
fig.write_image(pjoin(fig_path, f'shuffled_MI_neuron{uid}_{bin_size}rad.png'), width=500, height=500)

In [None]:
## Left panel: original vs. shuffled tuning curve for neuron `uid`; right panel: its activity over time
fig = pf.custom_graph_template(x_title='', y_title='Activity (a.u.)', titles=[f'Neuron {uid}'] * 2, rows=1, columns=2, width=1200)
tc_orig = pc.get_tuning_curves(population_activity, occupancy)
## `_` is the occupancy left over from the last iteration of the shuffle loop
tc_shuf = pc.get_tuning_curves(shuf_act, _)
for curve_name, curves in (('Original', tc_orig), ('Shuffled', tc_shuf)):
    fig.add_trace(go.Scattergl(x=bins[:-1], y=curves[:, uid], name=curve_name), row=1, col=1)
activity_trace = go.Scattergl(x=neural_data['behav_t'], y=neural_data.values[uid], showlegend=False)
fig.add_trace(activity_trace, row=1, col=2)
fig.update_xaxes(title='Spatial Bin (rad)', col=1)
fig.update_xaxes(title='Time (s)', col=2)
fig.show()
fig.write_image(pjoin(fig_path, f'orig_and_shuff_act_neuron{uid}_{bin_size}rad.png'), width=1200, height=500)

In [None]:
## Scatter each cell's mutual information against its stored `skaggs_info` value
fig = pf.custom_graph_template(x_title='Spatial information (bits/event)', y_title='Mutual Information (nats)')
mi_vs_si = go.Scattergl(
    x=neural_data['skaggs_info'].values,
    y=mi,
    mode='markers',
    opacity=0.7,
    marker=dict(color='darkgrey', line=dict(width=1)),
)
fig.add_trace(mi_vs_si)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_mi_vs_si_{bin_size}rad.png'), width=500, height=500)

### Mutual information for days 15 and 16

In [None]:
data_type = 'S'
correct_dir = True 
only_running = True

## Per-mouse correction of the enumerate index: once the index is past the listed value it is
## bumped by one before being turned into a day number
## (presumably to account for a missing session file -- TODO confirm)
idx_bump_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2, 'mc56': -1}

## Assumes `dpath` is the project-level folder that contains the experiment folders
out_dict = {'mouse': [], 'group': [], 'sex': [], 'day': [], 'avg_mi': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        if mouse in male_mice:
            sex = 'Male'
        else:
            sex = 'Female'
        if mouse in control_mice:
            group = 'Two-context'
        else:
            group = 'Multi-context'
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            ## Keep only session files whose name contains '15' or '16'
            if not ('15' in session or '16' in session):
                continue
            if mouse in idx_bump_after and idx > idx_bump_after[mouse]:
                idx += 1
            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            neural_data, position_data = ctn.subset_correct_dir_and_running(
                sdata, correct_dir=correct_dir, only_running=only_running, velocity_thresh=velocity_thresh)
            population_activity, occupancy, bins = pc.spatial_activity(
                neural_data, position_data, bin_size=bin_size, binarized=False)
            ## Mutual information (nats) between each cell's binned activity and the spatial bin index
            discrete_bins = np.arange(0, bins.shape[0]-1)
            mi = [
                mutual_info_score((population_activity[:, uid] * 100).astype(int), discrete_bins)
                for uid in np.arange(0, population_activity.shape[1])
            ]
            ## One row per mouse and day, holding the mean MI across cells
            row = {'mouse': mouse, 'group': group, 'sex': sex, 'day': idx+1, 'avg_mi': np.mean(mi)}
            for column, value in row.items():
                out_dict[column].append(value)
mi_df = pd.DataFrame(out_dict)

In [None]:
## Recode the group labels as '1' / '2', then plot group means (with SEM) and per-mouse points
group_names = {'1': 'Two-context', '2': 'Multi-context'}
mi_df = mi_df.replace({'Two-context': '1', 'Multi-context': '2'})
avg_mi = mi_df.groupby(['group', 'day'], as_index=False).agg({'avg_mi': ['mean', 'sem']})
## One horizontal jitter value per mouse so individual points do not overlap
shifts = rs.uniform(-0.12, 0.12, len(mi_df['mouse'].unique()))

fig = pf.custom_graph_template(x_title='Day', y_title='Mutual Information (nats)')
## Group averages with SEM error bars
for group, name in group_names.items():
    gdata = avg_mi[avg_mi['group'] == group]
    group_trace = go.Scattergl(
        x=gdata['day'],
        y=gdata[('avg_mi', 'mean')],
        name=name,
        mode='markers',
        marker_color=ce_colors_dict[name],
        error_y=dict(type='data', array=gdata[('avg_mi', 'sem')]),
        legendgroup=group,
    )
    fig.add_trace(group_trace)

## Individual mice, jittered along the day axis and tied to their group's legend entry
for idx, mouse in enumerate(mi_df['mouse'].unique()):
    mdata = mi_df[mi_df['mouse'] == mouse]
    group_code = mdata['group'].unique()[0]
    g = 'Two-context' if group_code == '1' else 'Multi-context'
    mouse_trace = go.Scattergl(
        x=mdata['day'] + shifts[idx],
        y=mdata['avg_mi'],
        mode='markers',
        marker_color=ce_colors_dict[g],
        marker=dict(line=dict(width=1, color='black')),
        showlegend=False,
        legendgroup=str(group_code),
    )
    fig.add_trace(mouse_trace)
fig.update_yaxes(range=[0, 2])
fig.show()