In [None]:
%load_ext autoreload
%autoreload 2
import os
import re 
import sys
import math
import numpy as np
import pandas as pd
import xarray as xr
from os.path import join as pjoin
from tqdm.notebook import tqdm
import plotly.graph_objects as go
from scipy.stats import zscore, skew

sys.path.append('../')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

In [None]:
## Settings
## Paths (relative to this notebook)
project_folder = ['MultiCon_Imaging']
experiment_folders = ['MultiCon_Imaging5', 'MultiCon_Imaging6']
dpath = f'../../{project_folder[0]}'
fig_path = '../../../Manuscripts/MultiCon/intermediate_plots/pop_bursts'

## Plot colours and markers
chance_color = 'darkgrey'
sem_color = 'rgba(125, 125, 125, 0.5)'
ce_colors_dict = {'Two-Context': 'midnightblue', 'Multi-Context': '#287347'}
group_colors = {'TC M': 'midnightblue', 'TC F': 'darkgreen', 'MC M': 'blue', 'MC F': 'green'}
symbol_dict = {'Two-Context': 'x', 'Multi-Context': 'circle'}
symbols_list = ['x', 'circle']
context_colors = {'A': '#00802d', 'B': '#006c79', 'C': '#004da4', 'D': '#430073'}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']

## Mouse groupings
male_mice = ['mc44', 'mc46', 'mc54', 'mc55']
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc59', 'mc60']
experimental_mice = ['mc44', 'mc48', 'mc51', 'mc55', 'mc56', 'mc58']
imaging5 = ['mc44', 'mc46', 'mc48', 'mc49', 'mc51', 'mc52']

## Session / day labels: A1..A5, B1..B5, C1..C5, D1..D5; A1..A15 then B1..B5; Day 1..Day 20
session_list = [f'{context}{x}' for context in 'ABCD' for x in range(1, 6)]
control_list = [f'A{x}' for x in range(1, 16)] + [f'B{x}' for x in range(1, 6)]
day_list = [f'Day {x}' for x in range(1, 21)]

## Analysis parameters
bin_size = 0.11 ## size of linear position bins equivalent to 4cm-wide bins
velocity_thresh = 10
centroid_distance = 5
data_of_interest = 'aligned_minian' ## one of behav, aligned_minian, aligned_place_cells, lin_behav
zthresh = 2 ## threshold applied to population activity for burst detection
min_len = 3 ## minimum number of consecutive supra-threshold frames that make up one burst
nshuffles = 500
fps = 30 ## frames per second for calcium data

## exist_ok=True replaces the exists-then-create check (and its race) with a single call
os.makedirs(fig_path, exist_ok=True)

xr.set_options(keep_attrs=True)

## Set seed for randomization
np.random.seed(24601)

### Example mouse.

In [None]:
## Load the example session and detect population bursts in it
data_type = 'S'
mouse = 'mc51' ## mc56, session 20 has a massive population event
session = '1'
data_path = '/'.join(['..', '..', project_folder[0], experiment_folders[0], 'output',
                      data_of_interest, mouse, data_type, f'{mouse}_{data_type}_{session}.nc'])
sdata = xr.open_dataset(data_path)[data_type]
xaxis = sdata['behav_t'].values

## Population activity: z-score each neuron (row) over time, average across neurons ignoring NaNs,
## then z-score the resulting population trace
pop_act = zscore(np.nanmean(zscore(sdata, axis=1), axis=0))

## Original burst definition: every frame above zthresh counts on its own
bursts_og = pop_act > zthresh
num_bursts_og = np.sum(bursts_og)

## New burst definition, where anything above the threshold with a minimum length of min_len are combined into one burst
burst_start, burst_end = ctn.define_population_bursts(sdata, min_len=min_len, zthresh=zthresh,
                                                      first_zscore=True, second_zscore=True)
burst_mid = np.round((burst_start + burst_end) / 2).astype(int)

In [None]:
## Three panels on a shared time axis:
##   row 1 - linearized position, with water deliveries as red markers
##   row 2 - population activity, burst midpoints in red, zthresh as a dashed line
##   row 3 - raster of frames where each neuron's activity is > 0
fig = pf.custom_graph_template(x_title='', y_title='', rows=3, columns=1, shared_x=True, font_size=22, titles=[f'{mouse}'],
                               height=900, width=1000, row_heights=[0.2, 0.2, 0.6], vertical_spacing=0.04)
water_frames = sdata['water']
position_trace = sdata['lin_position']
row_traces = [
    (1, go.Scattergl(x=xaxis, y=position_trace, mode='lines', line_color='darkgrey', showlegend=False)),
    (1, go.Scattergl(x=xaxis[water_frames], y=position_trace[water_frames], mode='markers', marker_color='red', showlegend=False)),
    (2, go.Scattergl(x=xaxis, y=pop_act, mode='lines', line_color='black', showlegend=False)),
    (2, go.Scattergl(x=xaxis[burst_mid], y=pop_act[burst_mid], mode='markers', marker_color='red', showlegend=False)),
]
for row, trace in row_traces:
    fig.add_trace(trace, row=row, col=1)
fig.add_hline(y=zthresh, line_width=1, line_color='red', line_dash='dash', opacity=1, row=2, col=1)
fig.add_trace(go.Heatmap(x=xaxis, y=np.arange(0, sdata.shape[0]), z=(sdata > 0).astype(int),
                         colorscale='gray_r', showscale=False), row=3, col=1)

## Axis titles; only the bottom panel keeps its time tick labels
fig.update_yaxes(title='Position', row=1)
fig.update_yaxes(title='Z-Scored Activity', row=2)
fig.update_yaxes(title='Neuron', autorange='reversed', row=3)
for row in (1, 2):
    fig.update_xaxes(showticklabels=False, ticks='', row=row)
fig.update_xaxes(title='Time (s)', showticklabels=True, row=3)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_raster_position_pop_act.png'), width=1000, height=900)

In [None]:
## Old and new burst definition plot
## Top: original definition (every frame above zthresh is marked).
## Bottom: new definition (one marker at the midpoint of each detected burst).
fig = pf.custom_graph_template(x_title='', y_title='Z-Scored Activity', rows=2, columns=1, shared_x=True, 
                               titles=['Original Definition', 'New Definition'], height=600, width=1000)
## Plot old definition
fig.add_trace(go.Scatter(x=xaxis, y=pop_act, mode='lines', line_color='black', showlegend=False), row=1, col=1)
## marker_color (not line_color) is what colours a markers-only trace
fig.add_trace(go.Scatter(x=xaxis[bursts_og], y=pop_act[bursts_og], mode='markers', marker_color='red', showlegend=False), row=1, col=1)
## New definition
fig.add_trace(go.Scatter(x=xaxis, y=pop_act, mode='lines', line_color='black', showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=xaxis[burst_mid], y=pop_act[burst_mid], mode='markers', marker_color='red', showlegend=False), row=2, col=1)
## No row/col given, so the threshold line is drawn on both panels
fig.add_hline(y=zthresh, opacity=1, line_width=1, line_dash='dash', line_color='red')
fig.update_xaxes(title='Time (s)', row=2)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_old_new_bursts_popact.png'))
fig.write_html(pjoin(fig_path, f'{mouse}_old_new_bursts_popact.html'))

In [None]:
## Shuffle activity example
## Circular-shift shuffle: each neuron's trace is rolled by its own random offset. Rolling keeps each
## neuron's values and only moves them in time, so the shuffle is meant to disrupt coordination across neurons.
raw_act = sdata.values  ## convert once instead of indexing the DataArray for every neuron
shuffled_act = np.zeros((sdata.shape[0], sdata.shape[1]))
for uid in np.arange(0, sdata.shape[0]):
    rand_shift = np.random.randint(0, sdata.shape[1])
    shuffled_act[uid] = np.roll(raw_act[uid], rand_shift)
## nanmean, as for pop_act: a neuron with constant activity z-scores to NaN and would otherwise make
## every frame of the averaged trace NaN
shuffled_pop_act = zscore(np.nanmean(zscore(shuffled_act, axis=1), axis=0))
shuff_burst_start, shuff_burst_end = ctn.define_population_bursts(shuffled_act, min_len=min_len, zthresh=zthresh)
shuff_burst_mid = np.round((shuff_burst_start + shuff_burst_end) / 2).astype(int)
## Plot original and shuffled data
fig = pf.custom_graph_template(x_title='', y_title='Z-Scored Activity', rows=2, columns=1, shared_x=True, 
                               titles=['Original', 'Shuffled'], height=600, width=1000)
fig.add_trace(go.Scatter(x=xaxis, y=pop_act, mode='lines', line_color='black', showlegend=False), row=1, col=1)
## marker_color (not line_color) is what colours a markers-only trace
fig.add_trace(go.Scatter(x=xaxis[burst_mid], y=pop_act[burst_mid], mode='markers', marker_color='red', showlegend=False), row=1, col=1)
## Shuffled data
fig.add_trace(go.Scatter(x=xaxis, y=shuffled_pop_act, mode='lines', line_color='black', showlegend=False), row=2, col=1)
fig.add_trace(go.Scatter(x=xaxis[shuff_burst_mid], y=shuffled_pop_act[shuff_burst_mid], mode='markers', marker_color='red', showlegend=False), row=2, col=1)
fig.add_hline(y=zthresh, opacity=1, line_width=1, line_dash='dash', line_color='red')
fig.update_xaxes(title='Time (s)', row=2)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_original_shuffled_popact.png'))
fig.write_html(pjoin(fig_path, f'{mouse}_original_shuffled_popact.html'))

In [None]:
## Shuffle activity and plot histogram of observed vs null for new definition
## Build null distributions by repeating the circular-shift shuffle nshuffles times:
##   num_bursts_shuff - number of bursts under the new definition
##   shuff_num_og     - number of supra-threshold frames under the original definition
raw_act = sdata.values  ## convert once instead of once per neuron per shuffle
n_cells, n_frames = sdata.shape[0], sdata.shape[1]
shuff_num_og = np.zeros(nshuffles)
num_bursts_shuff = np.zeros(nshuffles)
for shuff in np.arange(0, nshuffles):
    ## One randint per neuron, in row order, so the seeded random stream is consumed exactly as before
    shuffled_act = np.zeros((n_cells, n_frames))
    for uid in np.arange(0, n_cells):
        rand_shift = np.random.randint(0, n_frames)
        shuffled_act[uid] = np.roll(raw_act[uid], rand_shift)
    ## New definition
    shuff_burst_start, shuff_burst_end = ctn.define_population_bursts(shuffled_act, min_len=min_len, zthresh=zthresh)
    shuff_burst_mid = np.round((shuff_burst_start + shuff_burst_end) / 2).astype(int)
    num_bursts_shuff[shuff] = shuff_burst_mid.shape[0]
    ## Old definition; nanmean matches the observed pop_act (a constant neuron z-scores to NaN)
    shuff_bursts_og = zscore(np.nanmean(zscore(shuffled_act, axis=1), axis=0)) > zthresh
    shuff_num_og[shuff] = np.sum(shuff_bursts_og)

In [None]:
## Null distribution of burst counts (new definition) with the observed count as a red dashed line,
## followed by a two-sided 5th/95th-percentile comparison
observed_bursts = burst_mid.shape[0]
null_p5 = np.percentile(num_bursts_shuff, q=5)
null_p95 = np.percentile(num_bursts_shuff, q=95)

fig = pf.custom_graph_template(x_title='Number of Bursts', y_title='Probability')
fig.add_trace(go.Histogram(x=num_bursts_shuff, marker_color='darkgrey', marker_line_width=2,
                           marker_line_color='black', showlegend=False, histnorm='probability'))
fig.add_vline(x=observed_bursts, line_color='red', line_dash='dash', line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_shuffled_dist_new_definition.png'))

if observed_bursts < null_p5:
    verdict = 'Significantly fewer bursts than chance.'
elif observed_bursts > null_p95:
    verdict = 'Significantly more bursts than chance.'
else:
    verdict = 'Bursts are at chance.'
print(verdict)

In [None]:
## Null distribution under the ORIGINAL definition, where the count is the number of frames above zthresh
## (shuff_num_og), compared with the observed count num_bursts_og (red dashed line)
observed_og = num_bursts_og
og_p5 = np.percentile(shuff_num_og, q=5)
og_p95 = np.percentile(shuff_num_og, q=95)

fig = pf.custom_graph_template(x_title='Number of Bursts', y_title='Probability')
fig.add_trace(go.Histogram(x=shuff_num_og, marker_color='darkgrey', marker_line_width=2,
                           marker_line_color='black', showlegend=False, histnorm='probability'))
fig.add_vline(x=num_bursts_og, line_color='red', line_dash='dash', line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_shuffled_dist_og_definition.png'))

if observed_og < og_p5:
    verdict = 'Significantly fewer bursts than chance.'
elif observed_og > og_p95:
    verdict = 'Significantly more bursts than chance.'
else:
    verdict = 'Bursts are at chance.'
print(verdict)

### Sweep through different values of min_len and zthresh for all mice to see changes in the number of bursts.

In [None]:
## Plot number of bursts for different min_len and different zthresh for all mice
## Parameter sweep on real data (one session per mouse): count bursts for every (zthresh, min_len) pair.
## Loop variables use distinct names (z_val, len_val, mouse_id, mouse_sdata, ...). Previously the loops
## reused zthresh/min_len/mouse/sdata, which silently replaced the notebook-level settings
## (zthresh=2, min_len=3) and the example-mouse data used by every later cell.
data_type = 'S' 
session = '20'
zlist = [2, 2.5, 3, 3.5, 4]
len_list = [1, 2, 3, 4, 5, 6]
output_dict = {'mouse': [], 'group': [], 'group_two': [], 'sex': [], 'min_len': [], 'zthresh': [], 'num_bursts': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse_id in tqdm(os.listdir(exp_path)):
        if mouse_id == 'mc48':  ## excluded from the sweep
            continue
        mpath = pjoin(exp_path, f'{mouse_id}/{data_type}')
        mouse_sex = 'Male' if mouse_id in male_mice else 'Female'
        mouse_group = 'Two-Context' if mouse_id in control_mice else 'Multi-Context'
        ## 'MC'/'TC' from experimental_mice membership, ' M'/' F' from male_mice membership
        mouse_group_two = ('MC' if mouse_id in experimental_mice else 'TC') + (' M' if mouse_id in male_mice else ' F')
        ## Load into memory and close the file instead of leaving one open handle per mouse
        with xr.open_dataset(pjoin(mpath, f'{mouse_id}_{data_type}_{session}.nc')) as ds:
            mouse_sdata = ds[data_type].load()
        for z_val in zlist:
            for len_val in len_list:
                b_start, b_end = ctn.define_population_bursts(mouse_sdata, min_len=len_val, zthresh=z_val)
                output_dict['mouse'].append(mouse_id)
                output_dict['group'].append(mouse_group)
                output_dict['group_two'].append(mouse_group_two)
                output_dict['sex'].append(mouse_sex)
                output_dict['min_len'].append(len_val)
                output_dict['zthresh'].append(z_val)
                output_dict['num_bursts'].append(len(b_start))
burst_df = pd.DataFrame(output_dict)

In [None]:
## For real data
## Mean +/- SEM burst count across mice vs. zthresh, one line per (group_two, min_len). Colour encodes
## group_two; each min_len gets exactly one legend entry, shared across groups through legendgroup.
## The loop variable is len_val (not min_len) so this cell does not overwrite the notebook-level min_len.
avg = burst_df.groupby(['group_two', 'min_len', 'zthresh'], as_index=False).agg({'num_bursts': ['mean', 'sem']})
yvar = 'num_bursts'
fig = pf.custom_graph_template(x_title='Z-Threshold', y_title='Number of Bursts', titles=['Original Data'], width=700)
legend_shown = set()
for grp in ['TC M', 'TC F', 'MC M', 'MC F']:
    gdata = avg[avg['group_two'] == grp]
    for len_val in gdata['min_len'].unique():
        plot_data = gdata[gdata['min_len'] == len_val]
        ## Legend on the first trace of each min_len; this replaces fig.data[min_len - 1], which only
        ## worked when the min_len values were exactly 1, 2, 3, ...
        fig.add_trace(go.Scatter(x=plot_data['zthresh'], y=plot_data[yvar]['mean'], name=f'Minimum frames: {len_val}', mode='lines+markers',
                                 line_color=group_colors[grp], error_y=dict(type='data', array=plot_data[yvar]['sem']), 
                                 showlegend=len_val not in legend_shown, legendgroup=str(len_val)))
        legend_shown.add(len_val)
fig.show()
fig.write_html(pjoin(fig_path, 'real_data_zthresh_minlen_paramtersweep.html'))

In [None]:
## Plot number of bursts for different min_len and different zthresh for all mice
## but shuffle data going into burst detection first
## Same sweep as the real-data cell (reuses its data_type, session, zlist and len_list), but each mouse's
## activity is circularly shuffled once (every neuron rolled by its own random offset) before detection.
## Loop variables use distinct names, so the notebook-level zthresh/min_len settings and the example
## mouse/sdata are not overwritten for later cells.
output_dict = {'mouse': [], 'group': [], 'group_two': [], 'sex': [], 'min_len': [], 'zthresh': [], 'num_bursts': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse_id in tqdm(os.listdir(exp_path)):
        if mouse_id == 'mc48':  ## excluded, as in the real-data sweep
            continue
        mpath = pjoin(exp_path, f'{mouse_id}/{data_type}')
        mouse_sex = 'Male' if mouse_id in male_mice else 'Female'
        mouse_group = 'Two-Context' if mouse_id in control_mice else 'Multi-Context'
        ## 'MC'/'TC' from experimental_mice membership, ' M'/' F' from male_mice membership
        mouse_group_two = ('MC' if mouse_id in experimental_mice else 'TC') + (' M' if mouse_id in male_mice else ' F')
        ## Read the activity into memory and close the file instead of leaving a handle open per mouse
        with xr.open_dataset(pjoin(mpath, f'{mouse_id}_{data_type}_{session}.nc')) as ds:
            mouse_act = ds[data_type].values
        n_cells, n_frames = mouse_act.shape[0], mouse_act.shape[1]
        mouse_shuffled = np.zeros((n_cells, n_frames))
        for uid in np.arange(0, n_cells):
            rand_shift = np.random.randint(0, n_frames)
            mouse_shuffled[uid] = np.roll(mouse_act[uid], rand_shift)
        for z_val in zlist:
            for len_val in len_list:
                b_start, b_end = ctn.define_population_bursts(mouse_shuffled, min_len=len_val, zthresh=z_val)
                output_dict['mouse'].append(mouse_id)
                output_dict['group'].append(mouse_group)
                output_dict['group_two'].append(mouse_group_two)
                output_dict['sex'].append(mouse_sex)
                output_dict['min_len'].append(len_val)
                output_dict['zthresh'].append(z_val)
                output_dict['num_bursts'].append(len(b_start))
shuffled_burst_df = pd.DataFrame(output_dict)

In [None]:
## For shuffled data
## Same layout as the real-data sweep plot: mean +/- SEM burst count across mice vs. zthresh, coloured by
## group_two, one legend entry per min_len. The loop variable is len_val so min_len is not overwritten.
avg = shuffled_burst_df.groupby(['group_two', 'min_len', 'zthresh'], as_index=False).agg({'num_bursts': ['mean', 'sem']})
yvar = 'num_bursts'
fig = pf.custom_graph_template(x_title='Z-Threshold', y_title='Number of Bursts', titles=['Shuffled Data'], width=700)
legend_shown = set()
for grp in ['TC M', 'TC F', 'MC M', 'MC F']:
    gdata = avg[avg['group_two'] == grp]
    for len_val in gdata['min_len'].unique():
        plot_data = gdata[gdata['min_len'] == len_val]
        ## Legend on the first trace of each min_len; this replaces fig.data[min_len - 1], which only
        ## worked when the min_len values were exactly 1, 2, 3, ...
        fig.add_trace(go.Scatter(x=plot_data['zthresh'], y=plot_data[yvar]['mean'], name=f'Minimum frames: {len_val}', mode='lines+markers',
                                 line_color=group_colors[grp], error_y=dict(type='data', array=plot_data[yvar]['sem']), 
                                 showlegend=len_val not in legend_shown, legendgroup=str(len_val)))
        legend_shown.add(len_val)
fig.show()

### Shuffle activity, but use a threshold determined by the population mean and std of the original data. 

Calculate the number, average length (ms), amplitude, and length skew of bursts, and compare each with its shuffled null distribution.

In [None]:
## Shuffle activity but use a threshold from the original data
## Population activity is z-scored per neuron and averaged across neurons, but NOT z-scored a second time.
## The threshold (observed mean + zthresh * observed SD) is therefore in the same units as the shuffled
## population traces and can be applied to them unchanged.
## nanmean, as in the example-mouse cell: a neuron with constant activity z-scores to all-NaN, and a
## plain mean would then make every frame of the population trace NaN.
pop_act = np.nanmean(zscore(sdata, axis=1), axis=0)
pop_act_mean = np.nanmean(pop_act)
pop_act_sd = np.nanstd(pop_act)
pop_thresh = pop_act_mean + zthresh*pop_act_sd
## Get indexes where bursts start and end
burst_start, burst_end = ctn.define_population_bursts(sdata, min_len=min_len, zthresh=pop_thresh, first_zscore=True, second_zscore=False)
burst_mid = np.round((burst_start + burst_end) / 2).astype(int)
## Calculate the length of bursts (in milliseconds)
## NOTE(review): the amplitude slices below treat burst_end as an inclusive index. In that case a burst
## spans (burst_end - burst_start + 1) frames and these lengths are one frame short. The same formula
## is used for the shuffles, so the comparison is unaffected; confirm against ctn.define_population_bursts.
burst_lengths = ((burst_end - burst_start) / fps) * 1000
## Burst amplitude = peak population activity within each burst (slice end is exclusive, hence be + 1)
amp = [np.max(pop_act[bs:be + 1]) for bs, be in zip(burst_start, burst_end)]

## Shuffle activity but use the population threshold defined above
raw_act = sdata.values  ## convert once instead of once per neuron per shuffle
n_cells, n_frames = sdata.shape[0], sdata.shape[1]
num_bursts_shuff = np.zeros(nshuffles)
avglen_bursts_shuff = np.zeros(nshuffles)
skew_bursts_shuff = np.zeros(nshuffles)
burst_amplitude_shuff = np.zeros(nshuffles)
for shuff in np.arange(0, nshuffles):
    ## One randint per neuron, in row order, so the seeded random stream is consumed exactly as before
    shuffled_act = np.zeros((n_cells, n_frames))
    for uid in np.arange(0, n_cells):
        rand_shift = np.random.randint(0, n_frames)
        shuffled_act[uid] = np.roll(raw_act[uid], rand_shift)
    ## New definition
    shuff_burst_start, shuff_burst_end = ctn.define_population_bursts(shuffled_act, min_len=min_len, zthresh=pop_thresh, first_zscore=True, second_zscore=False)
    ## Shuffled population trace, built the same way as pop_act above
    pop_shuffled = np.nanmean(zscore(shuffled_act, axis=1), axis=0)
    shuff_amp = [np.max(pop_shuffled[bs:be + 1]) for bs, be in zip(shuff_burst_start, shuff_burst_end)]
    ## Get the number of bursts for each shuffle
    num_bursts_shuff[shuff] = len(shuff_burst_start)
    ## Length of bursts (ms), mean length, amplitude and skew
    burst_lengths_shuff = ((shuff_burst_end - shuff_burst_start) / fps) * 1000
    if len(shuff_amp) == 0:
        ## No bursts in this shuffle: record NaN explicitly instead of averaging an empty array
        burst_amplitude_shuff[shuff] = np.nan
        avglen_bursts_shuff[shuff] = np.nan
        skew_bursts_shuff[shuff] = np.nan
    else:
        burst_amplitude_shuff[shuff] = np.mean(shuff_amp)
        avglen_bursts_shuff[shuff] = np.mean(burst_lengths_shuff)
        skew_bursts_shuff[shuff] = skew(burst_lengths_shuff)

In [None]:
## Null distribution of burst counts (population threshold) with the observed count as a red dashed line,
## followed by a two-sided 5th/95th-percentile comparison
observed_num = burst_mid.shape[0]
num_p5 = np.percentile(num_bursts_shuff, q=5)
num_p95 = np.percentile(num_bursts_shuff, q=95)

fig = pf.custom_graph_template(x_title='Number of Bursts', y_title='Probability', titles=[f'{mouse}'])
fig.add_trace(go.Histogram(x=num_bursts_shuff, marker_color=chance_color, marker_line_width=2,
                           marker_line_color='black', showlegend=False, histnorm='probability'))
fig.add_vline(x=observed_num, line_color='red', line_dash='dash', line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_shuffled_dist_number_bursts_{pop_thresh}_{min_len}_using_popthresh.png'))

if observed_num < num_p5:
    verdict = 'Significantly fewer bursts than chance.'
elif observed_num > num_p95:
    verdict = 'Significantly more bursts than chance.'
else:
    verdict = 'Bursts are at chance.'
print(verdict)

In [None]:
## Null distribution of the mean burst length (ms) across shuffles, with the observed mean length in red
observed_len = np.mean(burst_lengths)
len_p5 = np.percentile(avglen_bursts_shuff, q=5)
len_p95 = np.percentile(avglen_bursts_shuff, q=95)

fig = pf.custom_graph_template(x_title='Average Burst Length (ms)', y_title='Probability', titles=[f'{mouse}'])
fig.add_trace(go.Histogram(x=avglen_bursts_shuff, marker_color=chance_color, marker_line_width=2,
                           marker_line_color='black', showlegend=False, histnorm='probability'))
fig.add_vline(x=observed_len, line_color='red', line_dash='dash', line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_shuffled_dist_burst_len_{pop_thresh}_{min_len}_using_popthresh.png'))

if observed_len < len_p5:
    verdict = 'Significantly shorter bursts than chance.'
elif observed_len > len_p95:
    verdict = 'Significantly longer bursts than chance.'
else:
    verdict = 'Burst length is at chance.'
print(verdict)

In [None]:
## Null distribution of the mean burst amplitude across shuffles, with the observed mean amplitude in red
observed_amp = np.mean(np.array(amp))
amp_p5 = np.percentile(burst_amplitude_shuff, q=5)
amp_p95 = np.percentile(burst_amplitude_shuff, q=95)

fig = pf.custom_graph_template(x_title='Average Burst Amplitude', y_title='Probability', titles=[f'{mouse}'])
fig.add_trace(go.Histogram(x=burst_amplitude_shuff, marker_color=chance_color, marker_line_width=2,
                           marker_line_color='black', showlegend=False, histnorm='probability'))
fig.add_vline(x=observed_amp, line_color='red', line_dash='dash', line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_shuffled_dist_burst_amplitude_{pop_thresh}_{min_len}_using_popthresh.png'))

if observed_amp < amp_p5:
    verdict = 'Significantly smaller bursts than chance.'
elif observed_amp > amp_p95:
    verdict = 'Significantly larger bursts than chance.'
else:
    verdict = 'Burst amplitude is at chance.'
print(verdict)

In [None]:
## Null distribution of burst-length skew across shuffles, with the observed skew in red (no test printed)
observed_skew = skew(burst_lengths)
fig = pf.custom_graph_template(x_title='Burst Length Skew', y_title='Probability', titles=[f'{mouse}'])
fig.add_trace(go.Histogram(x=skew_bursts_shuff, marker_color=chance_color, marker_line_width=2,
                           marker_line_color='black', showlegend=False, histnorm='probability'))
fig.add_vline(x=observed_skew, line_color='red', line_dash='dash', line_width=1, opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_shuffled_dist_skew_burst_len_{pop_thresh}_{min_len}_using_popthresh.png'))

### Align locomotion (velocity) around burst centers for a single mouse.

In [None]:
## Velocity in a window around each burst midpoint for the session in `sdata`.
## NOTE: mc48, session 19, gives inf in some frames of velocity calculation. Will need to investigate.
window_size = 2 ## in seconds, on each side of the burst midpoint
x_pos, y_pos, lin_pos = ctb.smooth_over_trials(sdata, filter_width=2)
x_cm, y_cm = ctb.convert_to_cm(x=x_pos, y=y_pos)
velocity, running = pc.define_running_epochs(x_cm, y_cm, sdata['behav_t'].values, velocity_thresh=velocity_thresh)

## One row per burst covering window_size seconds before and after the midpoint, inclusive:
## 2 * window_size * fps + 1 samples, matching the plot's time axis
## np.arange(-(window_size * fps), window_size * fps + 1). The previous
## window_size * window_size * fps + 1 only coincided with this when window_size == 2.
n_window = 2 * window_size * fps + 1
loco = np.full((burst_mid.shape[0], n_window), np.nan)
for idx, b in enumerate(burst_mid):
    d = ctn.extract_windowed_data_by_index(velocity, window_val=b, window_size=window_size, fps=fps)
    ## Windows of a different length (presumably truncated near the session edges) stay NaN
    if d.size == loco.shape[1]:
        loco[idx] = d

avg_act = np.nanmean(loco, axis=0)
## SEM = SD / sqrt(n), where n is the number of bursts contributing a value at each time point
n_valid = np.sum(~np.isnan(loco), axis=0)
sem = np.nanstd(loco, axis=0, ddof=1) / np.sqrt(n_valid)

In [None]:
## Mean velocity around burst midpoints with a shaded +/- SEM band; the red dashed line marks the midpoint
time = np.arange(-(window_size * fps), (window_size * fps) + 1) / fps
upper_bound = avg_act + sem
lower_bound = avg_act - sem

fig = pf.custom_graph_template(x_title='Time around Burst Center (s)', y_title='Velocity (cm/s)', titles=[f'{mouse}'])
fig.add_trace(go.Scatter(x=time, y=avg_act, mode='lines', line_color='midnightblue', showlegend=False))
## Invisible upper edge, then the lower edge fills back up to it ('tonexty') to draw the band
fig.add_trace(go.Scatter(x=time, y=upper_bound, mode='lines',
                         name='Upper Bound', line=dict(width=0), showlegend=False, line_color=sem_color))
fig.add_trace(go.Scatter(x=time, y=lower_bound, mode='lines', line_color=sem_color,
                         name='Lower Bound', line=dict(width=0), showlegend=False, fillcolor=sem_color, fill='tonexty'))
fig.add_vline(x=0, line_width=1, line_color='red', line_dash='dash', opacity=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_velocity_around_burst_center.png'))

### Where are mice on the circle track when bursts occur?

In [None]:
## Where bursts happen in space: put a weight of 1 at each burst's midpoint frame (one count per burst)
## and histogram those weights over smoothed x/y position. Reward locations are the mean sdata['x'] /
## sdata['y'] over water frames at each of the two reward ports.
## nanmean, as in the example-mouse cell: a neuron with constant activity z-scores to all-NaN.
pop_act = np.nanmean(zscore(sdata, axis=1), axis=0)
pop_act_mean = np.nanmean(pop_act)
pop_act_sd = np.nanstd(pop_act)
pop_thresh = pop_act_mean + zthresh*pop_act_sd
## Get indexes where bursts start and end
burst_start, burst_end = ctn.define_population_bursts(sdata, min_len=min_len, zthresh=pop_thresh, first_zscore=True, second_zscore=False)
burst_mid = np.round((burst_start + burst_end) / 2).astype(int)
## 1 at each burst midpoint frame, 0 elsewhere
burst_frames = np.zeros(sdata.shape[1])
burst_frames[burst_mid] = 1

## x_pos / y_pos are the smoothed positions from the locomotion cell
H, x_edges, y_edges = np.histogram2d(x_pos, y_pos, bins=40, weights=burst_frames)
water_frames = sdata['water']
rw_one_frames = water_frames & (sdata['lick_port'] == sdata.attrs['reward_one'])
rw_two_frames = water_frames & (sdata['lick_port'] == sdata.attrs['reward_two'])
rw_one_x = np.mean(sdata['x'][rw_one_frames])
rw_one_y = np.mean(sdata['y'][rw_one_frames])

rw_two_x = np.mean(sdata['x'][rw_two_frames])
rw_two_y = np.mean(sdata['y'][rw_two_frames])

In [None]:
## Burst-midpoint counts over 2D position, with the two reward locations in red.
## np.histogram2d returns H indexed [x_bin, y_bin], but go.Heatmap reads z as z[y_index][x_index];
## H is transposed so the image lines up with the X/Y axes and the reward markers.
fig = pf.custom_graph_template(x_title='X Position', y_title='Y Position', titles=[f'{mouse}'])
fig.add_trace(go.Heatmap(x=x_edges, y=y_edges, z=H.T, colorscale='gray_r'))
fig.add_trace(go.Scatter(x=[rw_one_x], y=[rw_one_y], mode='markers', marker_color='red', showlegend=False))
fig.add_trace(go.Scatter(x=[rw_two_x], y=[rw_two_y], mode='markers', marker_color='red', showlegend=False))
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_burst_occupancy.png'))

### Number of bursts across days.

In [None]:
## Number of bursts in every session of every mouse (session 21 excluded). Uses the population-threshold
## definition with the notebook-level zthresh and min_len: threshold = mean + zthresh * SD of the
## per-neuron z-scored, neuron-averaged activity of that session.
## Session files are named '{mouse}_{data_type}_{session}.nc' and are processed in numeric session order.
## os.listdir order is arbitrary, and a plain string sort would put '10' before '2', so `index` would
## otherwise not be the chronological position of the file. Files not matching the pattern are skipped.
session_file_re = re.compile(r'_(\d+)\.nc$')
## For these mice, every file after the given index is moved one day later
## (presumably a missed recording day — confirm)
day_shift_after = {'mc42': 14, 'mc43': 11, 'mc44': 7, 'mc46': 9, 'mc52': 2, 'mc55': 2}
output_dict = {'mouse': [], 'day': [], 'context_day': [], 'session': [], 'group': [], 'group_two': [], 'sex': [], 'num_bursts': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    ## Loop names (mouse_id, mouse_sdata, ...) avoid overwriting the example mouse/sdata globals
    for mouse_id in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse_id}/{data_type}')
        mouse_sex = 'Male' if mouse_id in male_mice else 'Female'
        mouse_group = 'Two-Context' if mouse_id in control_mice else 'Multi-Context'
        ## 'MC'/'TC' from experimental_mice membership, ' M'/' F' from male_mice membership
        mouse_group_two = ('MC' if mouse_id in experimental_mice else 'TC') + (' M' if mouse_id in male_mice else ' F')
        session_files = sorted((f for f in os.listdir(mpath) if session_file_re.search(f)),
                               key=lambda f: int(session_file_re.search(f).group(1)))
        for index, session_file in enumerate(session_files):
            if int(session_file_re.search(session_file).group(1)) == 21:
                continue
            day = index + 1
            if mouse_id in day_shift_after and index > day_shift_after[mouse_id]:
                day += 1

            ## Load into memory and close the file instead of leaving one handle open per session
            with xr.open_dataset(pjoin(mpath, session_file)) as ds:
                mouse_sdata = ds[data_type].load()
            ## Get indexes where bursts start and end
            m_pop_act = np.nanmean(zscore(mouse_sdata, axis=1), axis=0)
            m_pop_thresh = np.nanmean(m_pop_act) + zthresh * np.nanstd(m_pop_act)
            b_start, b_end = ctn.define_population_bursts(mouse_sdata, min_len=min_len, zthresh=m_pop_thresh, first_zscore=True, second_zscore=False)

            output_dict['mouse'].append(mouse_id)
            output_dict['group'].append(mouse_group)
            output_dict['group_two'].append(mouse_group_two)
            output_dict['sex'].append(mouse_sex)
            output_dict['session'].append(mouse_sdata.attrs['session'])
            output_dict['day'].append(day)
            output_dict['num_bursts'].append(len(b_start))
            output_dict['context_day'].append(int(re.search(pattern='[0-9]+', string=mouse_sdata.attrs['session_two'])[0]))
output_df = pd.DataFrame(output_dict)
avg_data = output_df.groupby(['group_two', 'session', 'context_day'], as_index=False).agg({'num_bursts': ['mean', 'sem']})

In [None]:
## Number of bursts vs. day in context for multi-context mice: males (left), females (right),
## one line per context session, coloured by context_colors
fig = pf.custom_graph_template(x_title='Day in Context', y_title='', rows=1, columns=2, shared_y=True, shared_x=True,
                               titles=['Multi-Context Male', 'Multi-Context Female'], width=800)
for col, grp in enumerate(['MC M', 'MC F'], start=1):
    grp_data = avg_data[avg_data['group_two'] == grp].reset_index(drop=True)
    for ctx in grp_data['session'].unique():
        ctx_data = grp_data[grp_data['session'] == ctx].reset_index(drop=True)
        ctx_trace = go.Scatter(x=ctx_data['context_day'], y=ctx_data['num_bursts']['mean'], mode='lines+markers',
                               name=ctx, legendgroup=ctx,
                               error_y=dict(type='data', array=ctx_data['num_bursts']['sem']), showlegend=False,
                               line_color=context_colors[ctx])
        fig.add_trace(ctx_trace, row=1, col=col)
fig.update_yaxes(title='Number of Bursts', col=1)
## Legend entries come from the first four traces (the first panel's contexts)
for trace_idx in range(4):
    fig.data[trace_idx]['showlegend'] = True
fig.show()
fig.write_image(pjoin(fig_path, 'number_of_bursts_multicontext.png'))

In [None]:
## Number of bursts vs. day in context for two-context mice: males (left), females (right),
## one line per context session; the y-range is fixed to 160-300
fig = pf.custom_graph_template(x_title='Day in Context', y_title='', rows=1, columns=2, shared_y=True, shared_x=True,
                               titles=['Two-Context Male', 'Two-Context Female'], width=800)
for col, grp in enumerate(['TC M', 'TC F'], start=1):
    grp_data = avg_data[avg_data['group_two'] == grp].reset_index(drop=True)
    for ctx in grp_data['session'].unique():
        ctx_data = grp_data[grp_data['session'] == ctx].reset_index(drop=True)
        ctx_trace = go.Scatter(x=ctx_data['context_day'], y=ctx_data['num_bursts']['mean'], mode='lines+markers',
                               name=ctx, legendgroup=ctx,
                               error_y=dict(type='data', array=ctx_data['num_bursts']['sem']), showlegend=False,
                               line_color=context_colors[ctx])
        fig.add_trace(ctx_trace, row=1, col=col)
fig.update_yaxes(title='Number of Bursts', col=1)
fig.update_yaxes(range=[160, 300])
## Legend entries come from the first two traces (the first panel's contexts)
for trace_idx in range(2):
    fig.data[trace_idx]['showlegend'] = True
fig.show()
fig.write_image(pjoin(fig_path, 'number_of_bursts_twocontext.png'))

In [None]:
## Plot number of bursts for multi-context mice
## NOTE(review): this repeats the multi-context plot made earlier in the notebook, but without the y-axis
## title, legend entries, or saved image. It looks like a leftover; delete it if it is not needed.
fig = pf.custom_graph_template(x_title='Day in Context', y_title='', rows=1, columns=2, shared_y=True, shared_x=True,
                               titles=['Multi-Context Male', 'Multi-Context Female'], width=800)
for idx, group in enumerate(['MC M', 'MC F']):
    gdata = avg_data[avg_data['group_two'] == group].reset_index(drop=True)
    col = idx + 1
    for session in gdata['session'].unique():
        plot_data = gdata[gdata['session'] == session].reset_index(drop=True)
        fig.add_trace(go.Scatter(x=plot_data['context_day'], y=plot_data['num_bursts']['mean'], mode='lines+markers', name=session, legendgroup=session,
                                 error_y=dict(type='data', array=plot_data['num_bursts']['sem']), showlegend=False,
                                 line_color=context_colors[session]), row=1, col=col)
fig.show()

### New shuffling code (work in progress).

In [None]:
def shuffle_data(ar, num_cells=None):
    """Circularly shift each row of a (cells x frames) array by an independent random offset.

    Each of the first ``num_cells`` rows is rolled along the last axis by
    ``np.random.randint(0, n_frames)``. Rolling keeps every row's values and only
    moves them in time.

    Parameters
    ----------
    ar : array-like, 2D
        Activity with one row per cell and time along the last axis.
    num_cells : int, optional
        Number of leading rows to shuffle. Defaults to ``ar.shape[0]`` (every row).

    Returns
    -------
    numpy.ndarray of float, shape (num_cells, n_frames)
        The shuffled rows; ``ar`` itself is not modified.
    """
    n_frames = ar.shape[-1]
    if num_cells is None:
        num_cells = ar.shape[0]
    shuff_ar = np.zeros((num_cells, n_frames))
    for cell in range(num_cells):
        ## One randint per row, in row order, so results are reproducible under np.random.seed
        random_shift = np.random.randint(0, n_frames)
        shuff_ar[cell] = np.roll(ar[cell], random_shift)
    return shuff_ar

In [None]:
%%time
## Testing new shuffling code
# ar = sdata.expand_dims(dim={'shuffle': 5})
ar = sdata.copy()
t = xr.apply_ufunc(
    shuffle_data,
    ar,
    dask='parallelized',
    kwargs={'num_cells': ar.shape[0]}
)
t