In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
from os.path import join as pjoin
from tqdm.notebook import tqdm
import plotly.graph_objects as go
import pingouin as pg
from scipy.stats import pearsonr, spearmanr, zscore
import itertools

sys.path.append('/home/austinbaggetta/csstorage3/CircleTrack/CircleTrackAnalysis')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

In [None]:
## Settings
# Data are read from <dpath>/<experiment>/output/<data_of_interest>/<mouse>/...
project_folder = ['MultiCon_Imaging']
experiment_folders = ['MultiCon_Imaging5', 'MultiCon_Imaging6', 'MultiCon_Imaging7']
dpath = f'../../{project_folder[0]}'
fig_path = '../../../Manuscripts/MultiCon/intermediate_plots'

# Plot styling shared by the cells below.
chance_color = '#7d7d7d'
avg_color = '#287347'
subject_color = '#7d7d7d'
ce_colors = ['#7A22BC', '#378616']
ce_colors_dict = {'Two-context': '#378616', 'Multi-context': '#7A22BC'}
symbol_dict = {'Two-context': 'x', 'Multi-context': 'circle'}
symbols_list = ['x', 'circle']
context_colors = {'A': '#00802d', 'B': '#006c79', 'C': '#004da4', 'D': '#430073'}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']

# Animal metadata used to assign sex and group labels.
male_mice = ['mc44', 'mc46', 'mc54', 'mc55', 'mc64', 'mc65']
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc59', 'mc60', 'mc61', 'mc64']

# Session labels, used as subplot titles: A1-A5, B1-B5, C1-C5, D1-D5 for
# non-control mice and A1-A15, B1-B5 for control mice.
session_list = [f'{context}{day}' for context in 'ABCD' for day in range(1, 6)]
control_list = [f'A{day}' for day in range(1, 16)] + [f'B{day}' for day in range(1, 6)]

# Analysis parameters.
bin_size = 0.1  # spatial bin width; presumably radians of linearized position -- TODO confirm
velocity_thresh = 10  # running threshold; presumably cm/s -- TODO confirm
data_of_interest = 'aligned_minian' ## one of behav, aligned_minian, lin_behav

# exist_ok avoids the check-then-create race and makes the cell safe to re-run.
os.makedirs(fig_path, exist_ok=True)

xr.set_options(keep_attrs=True)

### Plot number of cells across days.

In [None]:
## Count the cells (first dimension of the data array) in every session of every mouse.
data_type = 'S'
cell_dict = {'mouse': [], 'group': [], 'group_two': [], 'session': [], 'day': [], 'num_cells': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse == 'mc47':  # this mouse is skipped; the reason is not recorded in this notebook
            continue
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        # `day` is derived from the position of the file in the listing, but os.listdir
        # returns entries in arbitrary order. Sorting by (length, name) puts
        # non-zero-padded session numbers in numeric order (..._9.nc before ..._10.nc)
        # as long as the files of one mouse share a prefix.
        session_files = sorted(os.listdir(mpath), key=lambda name: (len(name), name))
        for index, session in enumerate(session_files):
            day_index = ctn.mouse_indices(mouse, index)
            # Only metadata is needed, so read it and close the file straight away.
            with xr.open_dataset(pjoin(mpath, session)) as dataset:
                data = dataset[data_type]
                session_label = data.attrs['session_two']
                num_cells = data.shape[0]
            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['group_two'].append(sex)
            cell_dict['session'].append(session_label)
            cell_dict['day'].append(day_index + 1)
            cell_dict['num_cells'].append(num_cells)
cell_df = pd.DataFrame(cell_dict)

In [None]:
## Mean (+/- SEM) number of cells per day across mice, with one thin line per mouse
avg_cells = cell_df.groupby(['day'], as_index=False).agg({'num_cells': ['mean', 'sem']})
avg_cells = avg_cells.loc[avg_cells['day'] < 21]
mean_cells = avg_cells['num_cells']['mean']
sem_cells = avg_cells['num_cells']['sem']

fig = pf.custom_graph_template(x_title='Day', y_title='Number of Cells')
fig.add_trace(
    go.Scatter(
        x=avg_cells['day'],
        y=mean_cells,
        mode='lines+markers',
        showlegend=False,
        line_color=avg_color,
        error_y=dict(type='data', array=sem_cells),
    )
)
# Individual mice are drawn from the unfiltered frame, in order of first appearance.
for mouse, mdata in cell_df.groupby('mouse', sort=False):
    fig.add_trace(
        go.Scatter(
            x=mdata['day'],
            y=mdata['num_cells'],
            mode='lines',
            line=dict(color=avg_color, width=1),
            opacity=0.5,
            name=mouse,
            showlegend=False,
        )
    )
fig.update_yaxes(range=[0, 900])
fig.show()
fig.write_image(pjoin(fig_path, 'number_of_cells.png'))

In [None]:
## Mean (+/- SEM) number of cells per day for each group, plus one thin line per mouse
avg_cells = cell_df.groupby(['day', 'group'], as_index=False).agg({'num_cells': ['mean', 'sem']})
avg_cells = avg_cells.loc[avg_cells['day'] < 21]

fig = pf.custom_graph_template(x_title='Day', y_title='Number of Cells', width=600)
for group in ('Two-context', 'Multi-context'):
    gdata = avg_cells.loc[avg_cells['group'] == group]
    group_mean = gdata['num_cells']['mean']
    group_sem = gdata['num_cells']['sem']
    fig.add_trace(
        go.Scatter(
            x=gdata['day'],
            y=group_mean,
            mode='lines+markers',
            name=group,
            line_color=ce_colors_dict[group],
            error_y=dict(type='data', array=group_sem),
        )
    )
# Each mouse is coloured by the group label of its first row.
for mouse, mdata in cell_df.groupby('mouse', sort=False):
    mouse_color = ce_colors_dict[mdata['group'].iloc[0]]
    fig.add_trace(
        go.Scatter(
            x=mdata['day'],
            y=mdata['num_cells'],
            mode='lines',
            line=dict(color=mouse_color, width=1),
            opacity=0.7,
            name=mouse,
            showlegend=False,
        )
    )
fig.update_yaxes(range=[0, 900])
fig.show()
fig.write_image(pjoin(fig_path, 'number_of_cells_control_experimental.png'))
# Mixed ANOVA on the unfiltered frame (all days): group is between-subject, day within-subject.
cell_df.mixed_anova(dv='num_cells', between='group', within='day', subject='mouse')

### Testing trial detection on a single session, and place-cell percentages.

In [None]:
# Load one behavior file (mc51_18.feat) for the trial-detection checks in the next cell.
# NOTE(review): this is an absolute, machine-specific path, whereas the rest of the
# notebook resolves data relative to `dpath` -- confirm it is reachable before running.
behav_file = '/media/caishuman/csstorage3/Austin/CircleTrack/MultiCon_Imaging/MultiCon_Imaging5/output/behav/mc51/mc51_18.feat'
df = pd.read_feather(behav_file)
df.head()

In [None]:
## Visual check of trial detection on the behavior file loaded above.
# Switches selecting which overlays are drawn on top of the raw position trace.
plot_all_trials = False
plot_forward_trials = False
plot_correct_direction = False
shift = 5
correct_dir = ctb.get_correct_direction(df['a_pos'], shift=shift)
trials = ctb.get_trials(df, shift_factor=0.8, angle_type='radians', counterclockwise=True)
df['new_trials'] = trials

# NOTE(review): the axis title says degrees while get_trials is called with
# angle_type='radians' -- confirm the units of `a_pos`.
fig = pf.custom_graph_template(x_title='Time (s)', y_title='Angular Position (deg)', width=700)
fig.add_trace(go.Scatter(x=df['t'], y=df['a_pos'], mode='lines', line_color='darkgrey', showlegend=False))

if plot_correct_direction:
    # Mark the samples selected by `correct_dir`.
    fig.add_trace(go.Scatter(x=df['t'][correct_dir], y=df['a_pos'][correct_dir], mode='markers',
                             marker_color='midnightblue', showlegend=False, marker=dict(line=dict(width=0.3))))

if plot_all_trials:
    # One trace per detected trial.
    for trial in np.unique(trials):
        tdata = df[df['new_trials'] == trial]
        fig.add_trace(go.Scatter(x=tdata['t'], y=tdata['a_pos'], showlegend=False))

if plot_forward_trials:
    # Highlight the trials returned as `forward` in red. The per-sample membership list
    # that used to be built here was never read, so it has been removed.
    forward, reverse = ctb.forward_reverse_trials(df, df['new_trials'])
    forward = np.asarray(forward)
    for trial in forward:
        d = df[df['new_trials'] == trial]
        fig.add_trace(go.Scatter(x=d['t'], y=d['a_pos'], line_color='red', showlegend=False))
fig.show()

In [None]:
## Save place_cell_df
# NOTE(review): `place_cell_df` is not created in the visible part of this notebook;
# it has to exist in the kernel before this cell is run.
intermediate_path = pjoin(fig_path, 'intermediate_data')
# to_csv does not create missing parent directories, so make sure the folder exists.
os.makedirs(intermediate_path, exist_ok=True)
place_cell_df.to_csv(pjoin(intermediate_path, 'percent_place_cells_nonshuffled_correct_direction.csv'))

In [None]:
## Plot percentage of place cells across days
# NOTE(review): `place_cell_df` is not created in the visible part of this notebook;
# it has to exist in the kernel before this cell is run.
place_cell_df['percentage_place'] = place_cell_df['percentage_place'].astype(np.float64)
# Group-by-day summary table; the plotting helper below receives the raw frame, not this table.
avg_place = place_cell_df.groupby(['group', 'day'], as_index=False).agg({'percentage_place': ['mean', 'sem']})
# The figure comes from the plotting helper. The empty template that used to be created
# first was overwritten immediately, so it has been dropped.
fig = pf.plot_behavior_across_days(place_cell_df, x_var='day', y_var='percentage_place', groupby_var=['day', 'group'], plot_transitions=[5.5, 10.5, 15.5],
                                   marker_color=ce_colors, avg_color=avg_color, expert_line=False, chance=False, transition_color=['darkgrey', 'darkgrey', 'darkgrey'],
                                   plot_datapoints=False, x_title='Day', y_title='Percent Place Cells (%)', titles=[''], height=500, width=500)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percent_place_cells_only_correct_direction.png'))
# Mixed ANOVA: group is between-subject, day within-subject.
place_cell_df.mixed_anova(within='day', between='group', dv='percentage_place', subject='mouse')

In [None]:
## Control analysis: circularly shift the neural data in time before calculating the percent of place cells
data_type = 'S'
alpha = 0.01
nbins = 20
# Number of frames every trace is rolled by; the original note calls this 5 seconds.
# NOTE(review): other cells in this notebook treat one frame as 1/30 s, which would
# make 300 frames 10 s -- confirm the frame rate.
shift_frames = 300
# For these mice, listing positions above the given index are moved one day later
# (presumably one session is missing from their recordings -- TODO confirm).
skipped_day_after = {'mc43': 11, 'mc42': 14, 'mc44': 7, 'mc46': 9, 'mc52': 2}
cell_dict = {'mouse': [], 'group': [], 'group_two': [], 'session': [], 'day': [], 'percentage_place': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in os.listdir(exp_path):
        if mouse == 'mc47':  # this mouse is skipped; the reason is not recorded in this notebook
            continue
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = 'Male' if mouse in male_mice else 'Female'
        group = 'Control' if mouse in control_mice else 'Experimental'
        # `day` is derived from the position of the file in the listing, but os.listdir
        # returns entries in arbitrary order. Sorting by (length, name) puts
        # non-zero-padded session numbers in numeric order (..._9.nc before ..._10.nc)
        # as long as the files of one mouse share a prefix.
        session_files = sorted(os.listdir(mpath), key=lambda name: (len(name), name))
        for index, session in enumerate(session_files):
            if mouse in skipped_day_after and index > skipped_day_after[mouse]:
                index += 1

            # Read the array and its coordinates into memory, then close the file.
            with xr.open_dataset(pjoin(mpath, session)) as dataset:
                data = dataset[data_type].load()

            # Roll along the frame axis only. Without `axis`, np.roll shifts the flattened
            # array, which moves the last samples of one cell into the start of the next
            # cell's trace instead of shifting every trace in time.
            shifted_data = np.roll(data.values, shift_frames, axis=1)

            x_cm, y_cm = ctb.convert_to_cm(x=data['x'].values, y=data['y'].values)
            pf_reference = pc.PlaceFields(x=x_cm,
                                          y=y_cm,
                                          t=data['behav_t'].values,
                                          neural_data=shifted_data,
                                          circular=True,
                                          shuffle_test=True,
                                          nbins=nbins,
                                          velocity_threshold=velocity_thresh)
            # A cell counts as a place cell when its spatial-information p-value is below alpha.
            data = data.assign_coords(place_cell_r = ('unit_id', pf_reference.data['spatial_info_pvals'] < alpha))
            # Store a plain float so the column is numeric rather than object dtype.
            percent_place = float((np.sum(data['place_cell_r']) / data.shape[0]) * 100)

            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['group_two'].append(sex)
            cell_dict['session'].append(data.attrs['session_two'])
            cell_dict['day'].append(index + 1)
            cell_dict['percentage_place'].append(percent_place)
place_cell_shuffled_df = pd.DataFrame(cell_dict)

In [None]:
## Save shuffled data (neural data rolled by 300 frames; noted as 5 seconds in the original)
intermediate_path = pjoin(fig_path, 'intermediate_data')
# to_csv does not create missing parent directories, so make sure the folder exists.
os.makedirs(intermediate_path, exist_ok=True)
place_cell_shuffled_df.to_csv(pjoin(intermediate_path, 'percent_place_cells_shuffled.csv'))

In [None]:
## Combine shuffled and non-shuffled
# Both `place_cell_df` and `place_cell_shuffled_df` must already exist in the kernel
# (computed above or re-loaded from the saved CSV files).
# Tag copies rather than the originals so re-running earlier cells is unaffected.
combined = pd.concat(
    [place_cell_df.assign(type='Nonshuffled'), place_cell_shuffled_df.assign(type='Shuffled')],
    ignore_index=True,
)
combined['percentage_place'] = combined['percentage_place'].astype(np.float64)

# Colour lookup accepting both group naming schemes that appear in this notebook.
# Assumes 'Control' corresponds to 'Two-context' and 'Experimental' to 'Multi-context'
# (control mice use the two-context session list) -- TODO confirm.
group_colors = {
    **ce_colors_dict,
    'Control': ce_colors_dict['Two-context'],
    'Experimental': ce_colors_dict['Multi-context'],
}

## Create plot
fig = pf.custom_graph_template(x_title='Day', y_title='Percent Place Cells (%)', width=600)
avg_combined = combined.groupby(['group', 'day', 'type'], as_index=False).agg({'percentage_place': ['mean', 'sem']})
for group in avg_combined['group'].unique():
    gdata = avg_combined[avg_combined['group'] == group]
    for shuffle_type in gdata['type'].unique():
        plot_data = gdata[gdata['type'] == shuffle_type]

        # Solid lines for the real data, dashed lines for the shuffled control.
        if shuffle_type == 'Nonshuffled':
            linetype = 'solid'
            name = f'Nonshuffled {group}'
        else:
            linetype = 'dash'
            name = f'Shuffled {group}'

        fig.add_trace(go.Scatter(x=plot_data['day'], y=plot_data['percentage_place']['mean'], mode='lines+markers',
                                 error_y=dict(type='data', array=plot_data['percentage_place']['sem']),
                                 line=dict(dash=linetype, color=group_colors[group]), name=name))
# Dashed vertical lines between days 5/6, 10/11 and 15/16.
for value in [5.5, 10.5, 15.5]:
    fig.add_vline(x=value, line_width=1, line_dash='dash', line_color=chance_color, opacity=1)
fig.update_yaxes(range=[0, 100])
fig.show()
fig.write_image(pjoin(fig_path, 'percent_place_cells_shuffled_nonshuffled.png'))

### Check whether speed threshold is selecting running.

In [None]:
## Load one session of one mouse and compute velocity / running epochs for the checks below
mouse_list = ['mc51']
experiment_list = ['MultiCon_Imaging5']
session_num = '2'
data_type = 'YrA'

for experiment in experiment_list:
    # Only experiments listed in the settings cell are processed.
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in mouse_list:
        mpath = pjoin(exp_path, mouse)
        avg_corr_dict = {'session': [], 'avg_corr': []}
        rpath = pjoin(mpath, data_type)
        fig = pf.custom_graph_template(x_title='Linear Position', y_title='Linear Position', height=500, width=550, titles=[f'Day {session_num}'])
        session_file = pjoin(rpath, f'{mouse}_{data_type}_{session_num}.nc')
        spike_data = xr.open_dataset(session_file)[data_type]

        # Drop the frames up to and including the first index returned by the helper, if any.
        index_start = pc.adjust_behavior_start(x=spike_data['x'].values, y=spike_data['y'].values, t=spike_data['behav_t'].values)
        if index_start.size > 0:
            first_frame = int(index_start[0]+1)
            spike_data = spike_data[:, first_frame:]

        # Mean linear position of the frames recorded at each reward port.
        lin_position = spike_data['lin_position']
        lick_port = spike_data['lick_port']
        reward_one_pos = np.mean(lin_position[lick_port == spike_data.attrs['reward_one']].values)
        reward_two_pos = np.mean(lin_position[lick_port == spike_data.attrs['reward_two']].values)

        x_cm, y_cm = ctb.convert_to_cm(x=spike_data['x'].values, y=spike_data['y'].values)
        behav_time = spike_data['behav_t'].values
        velocity, running = pc.define_running_epochs(x_cm, y_cm, behav_time, velocity_thresh=velocity_thresh)

In [None]:
## Overlay histograms of linear position with and without the running (velocity) threshold
fig = pf.custom_graph_template(x_title='Linearized Position (rad)', y_title='Probability', width=700)
position_histograms = [
    ('No running threshold', spike_data['lin_position'], 'midnightblue'),
    ('Running threshold', spike_data['lin_position'][running], 'red'),
]
for trace_name, positions, trace_color in position_histograms:
    fig.add_trace(go.Histogram(x=positions, histnorm='probability', name=trace_name, marker_color=trace_color))
# Dashed red lines mark the two reward positions computed in the previous cell.
for reward_pos in (reward_one_pos, reward_two_pos):
    fig.add_vline(x=reward_pos, line_width=1, line_dash='dash', opacity=1, line_color='red')
fig.show()
fig.write_image(pjoin(fig_path, f'running_threshold_{velocity_thresh}.png'))

In [None]:
## Distribution of frame-to-frame differences in x position
delta_x = np.diff(x_cm)
# histnorm='probability' plots fractions, so the axis is labelled accordingly.
fig = pf.custom_graph_template(x_title='Delta X', y_title='Probability')
fig.add_trace(go.Histogram(x=delta_x, histnorm='probability'))
# Dashed red lines mark the largest and smallest step.
for extreme in (np.max(delta_x), np.min(delta_x)):
    fig.add_vline(x=extreme, line_width=1, line_dash='dash', opacity=1, line_color='red')
fig.show()

In [None]:
## Distribution of frame-to-frame differences in behavior timestamps
delta_t = np.diff(spike_data['behav_t'])
# histnorm='probability' plots fractions, so the axis is labelled accordingly.
fig = pf.custom_graph_template(x_title='Delta T', y_title='Probability')
fig.add_trace(go.Histogram(x=delta_t, histnorm='probability'))
# Dashed red line marks the largest gap between consecutive timestamps.
fig.add_vline(x=np.max(delta_t), line_width=1, line_dash='dash', opacity=1, line_color='red')
fig.show()

In [None]:
# Distribution of velocity, with dashed reference lines at 7 (red) and 14 (purple).
fig = pf.custom_graph_template(x_title='Velocity (cm/s)', y_title='')
fig.add_trace(go.Histogram(x=velocity, histnorm='probability'))
reference_lines = ((7, 'red'), (14, 'darkorchid'))
for reference_velocity, reference_color in reference_lines:
    fig.add_vline(x=reference_velocity, line_width=1, line_dash='dash', opacity=1, line_color=reference_color)
fig.show()

In [None]:
# Velocity against linear position, one marker per frame; dashed red lines mark the reward positions.
fig = pf.custom_graph_template(x_title='Lin Position (rad)', y_title='Velocity (cm/s)')
velocity_scatter = go.Scatter(
    x=spike_data['lin_position'].values,
    y=velocity,
    mode='markers',
    line_color='darkgrey',
)
fig.add_trace(velocity_scatter)
for reward_pos in (reward_one_pos, reward_two_pos):
    fig.add_vline(x=reward_pos, line_width=1, line_dash='dash', opacity=1, line_color='red')
fig.show()

### Calculate spatial bin correlations for one session of one mouse.

In [None]:
## Single mouse correlation of activity across spatial bins
mouse_list = ['mc51']
experiment_list = ['MultiCon_Imaging5']
session_num = '15'
correct_direction = True
data_type = 'YrA'
test = 'spearman'
normalize = False

# Fail early on an unknown test name instead of silently reusing a `pvc` left over
# from an earlier run.
if test not in ('pearson', 'spearman'):
    raise ValueError(f"test must be 'pearson' or 'spearman', got {test!r}")
colorbar_title = "Pearson's r" if test == 'pearson' else "Spearman's r"

for experiment in experiment_list:
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')

    for mouse in mouse_list:
        mpath = pjoin(exp_path, mouse)

        avg_corr_dict = {'session': [], 'avg_corr': []}
        rpath = pjoin(mpath, f'{data_type}')
        fig = pf.custom_graph_template(x_title='Linear Position', y_title='Linear Position', height=500, width=550, titles=[f'Day {session_num}'])
        # Read the array and its coordinates into memory, then close the file.
        with xr.open_dataset(pjoin(rpath, f'{mouse}_{data_type}_{session_num}.nc')) as dataset:
            spike_data = dataset[data_type].load()

        # Drop the frames up to and including the first index returned by the helper, if any.
        index_start = pc.adjust_behavior_start(x=spike_data['x'].values, y=spike_data['y'].values, t=spike_data['behav_t'].values)
        if index_start.size > 0:
            spike_data = spike_data[:, int(index_start[0]+1):]

        if normalize:
            # z-score every cell along the frame dimension.
            spike_data = xr.apply_ufunc(
                zscore,
                spike_data.chunk({'frame': -1, 'unit_id': 50}),
                input_core_dims=[['frame']],
                output_core_dims=[['frame']],
                kwargs={'axis': 1},
                dask='parallelized'
            ).compute()

        # Mean linear position of the frames recorded at each reward port.
        reward_one_pos = np.mean(spike_data['lin_position'][spike_data['lick_port'] == spike_data.attrs['reward_one']].values)
        reward_two_pos = np.mean(spike_data['lin_position'][spike_data['lick_port'] == spike_data.attrs['reward_two']].values)
        x_cm, y_cm = ctb.convert_to_cm(x=spike_data['x'].values, y=spike_data['y'].values)
        velocity, running = pc.define_running_epochs(x_cm,
                                                     y_cm,
                                                     spike_data['behav_t'].values,
                                                     velocity_thresh=velocity_thresh)
        if correct_direction:
            # Keep only frames that belong to a forward trial. np.isin also copes with
            # an empty `forward`, where the old accumulate-then-`.values` code raised.
            forward, reverse = ctb.forward_reverse_trials(spike_data, spike_data['trials'])
            correct_trials = np.isin(spike_data['trials'].values, np.asarray(forward))
        else:
            correct_trials = np.ones(spike_data.shape[1], dtype=bool)

        running_data = spike_data[:, np.logical_and(running, correct_trials)]
        bins = ctb.calculate_bins(x=spike_data['lin_position'].values, bin_size=bin_size)
        running_activity = running_data.values
        running_position = running_data['lin_position'].values
        # One row per spatial bin, one column per cell.
        population_activity = np.zeros((len(bins)-1, spike_data['unit_id'].shape[0]))
        for idx, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
            in_bin = (running_position >= start) & (running_position < end)
            mean_activity = running_activity[:, in_bin].mean(axis=1)
            # Divide by the time spent in the bin, counting each frame as 1/30 s.
            population_activity[idx, :] = mean_activity / (in_bin.sum() * 1/30)

        # Correlate the population vectors of every pair of bins; the diagonal is
        # blanked so it does not enter the average.
        if test == 'pearson':
            pvc = np.corrcoef(population_activity)
        else:
            pvc = spearmanr(population_activity, axis=1).correlation
        np.fill_diagonal(pvc, val=np.nan)
        off_diagonal = np.nanmean(pvc)

        avg_corr_dict['session'].append(spike_data.attrs['session_two'])
        avg_corr_dict['avg_corr'].append(off_diagonal)
        # Attach the heatmap to the layout colour axis so the Viridis scale set below
        # is actually applied to it.
        fig.add_trace(go.Heatmap(x=bins, y=bins, z=pvc, coloraxis='coloraxis'))
        for pos in [reward_one_pos, reward_two_pos]:
            fig.add_vline(x=pos, line_width=1, line_dash='dash', line_color='red', opacity=1)
            fig.add_hline(y=pos, line_width=1, line_dash='dash', line_color='red', opacity=1)

        # With a shared colour axis the colorbar (and its title) lives on the layout.
        fig.update_layout(coloraxis=dict(colorscale='Viridis', colorbar=dict(title=dict(text=colorbar_title))))
fig.show()

### Calculate spatial bin correlations for one mouse across all sessions.

In [None]:
## Correlation of cell activity across spatial bins for each session for a mouse
mouse_list = ['mc44']
experiment_list = ['MultiCon_Imaging5']
correct_direction = True
data_type = 'YrA' 
test = 'spearman'
normalize = False
bin_size = 0.1
n_cols = 5  # the 4 x 5 subplot grid is filled row by row, one session per panel
# For these mice, listing positions above the given index are moved one panel later
# (presumably one session is missing from their recordings -- TODO confirm).
skipped_day_after = {'mc43': 11, 'mc42': 14, 'mc44': 7, 'mc46': 9, 'mc52': 2}

# Fail early on an unknown test name instead of silently reusing a `pvc` left over
# from an earlier run.
if test not in ('pearson', 'spearman'):
    raise ValueError(f"test must be 'pearson' or 'spearman', got {test!r}")
colorbar_title = "Pearson's r" if test == 'pearson' else "Spearman's r"

for experiment in experiment_list:
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')

    for mouse in mouse_list:
        mpath = pjoin(exp_path, mouse)

        # Subplot titles follow the session labels of the mouse's group.
        titles = control_list if mouse in control_mice else session_list

        avg_corr_dict = {'session': [], 'avg_corr': []}
        data_path = pjoin(mpath, f'{data_type}')
        fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=n_cols, height=1000, width=1000, titles=titles)
        # The panel of a session is derived from its position in the listing, but
        # os.listdir returns entries in arbitrary order. Sorting by (length, name) puts
        # non-zero-padded session numbers in numeric order (..._9.nc before ..._10.nc)
        # as long as the files of one mouse share a prefix.
        session_files = sorted(os.listdir(data_path), key=lambda name: (len(name), name))
        for index, session in enumerate(tqdm(session_files)):
            if mouse in skipped_day_after and index > skipped_day_after[mouse]:
                index += 1

            # Read the array and its coordinates into memory, then close the file.
            with xr.open_dataset(pjoin(data_path, session)) as dataset:
                spike_data = dataset[data_type].load()

            if normalize:
                # z-score every cell along the frame dimension.
                spike_data = xr.apply_ufunc(
                    zscore,
                    spike_data.chunk({'frame': -1, 'unit_id': 50}),
                    input_core_dims=[['frame']],
                    output_core_dims=[['frame']],
                    kwargs={'axis': 1},
                    dask='parallelized'
                ).compute()

            # Mean linear position of the frames recorded at each reward port.
            # These are computed but currently not drawn on the subplots.
            reward_one_pos = np.mean(spike_data['lin_position'][spike_data['lick_port'] == spike_data.attrs['reward_one']].values)
            reward_two_pos = np.mean(spike_data['lin_position'][spike_data['lick_port'] == spike_data.attrs['reward_two']].values)
            ## PV correlations of position data
            x_cm, y_cm = ctb.convert_to_cm(x=spike_data['x'].values, y=spike_data['y'].values)
            velocity, running = pc.define_running_epochs(x_cm,
                                                         y_cm,
                                                         spike_data['behav_t'].values,
                                                         velocity_thresh=velocity_thresh)
            if correct_direction:
                # Keep only frames that belong to a forward trial. This replaces the
                # per-trial accumulation whose `.values` call sat inside the loop.
                forward, reverse = ctb.forward_reverse_trials(spike_data, spike_data['trials'])
                correct_trials = np.isin(spike_data['trials'].values, np.asarray(forward))
            else:
                correct_trials = np.ones(spike_data.shape[1], dtype=bool)

            running_data = spike_data[:, np.logical_and(running, correct_trials)]
            bins = ctb.calculate_bins(x=spike_data['lin_position'].values, bin_size=bin_size)
            running_activity = running_data.values
            running_position = running_data['lin_position'].values
            # One row per spatial bin, one column per cell.
            population_activity = np.zeros((len(bins)-1, spike_data['unit_id'].values.shape[0]))
            for idx, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
                in_bin = (running_position >= start) & (running_position < end)
                mean_activity = running_activity[:, in_bin].mean(axis=1)
                # Divide by the time spent in the bin, counting each frame as 1/30 s.
                population_activity[idx, :] = mean_activity / (in_bin.sum() * 1/30)

            # Correlate the population vectors of every pair of bins; the diagonal is
            # blanked so it does not enter the average.
            if test == 'pearson':
                pvc = np.corrcoef(population_activity)
            else:
                pvc = spearmanr(population_activity, axis=1).correlation
            np.fill_diagonal(pvc, val=np.nan)
            off_diagonal = np.nanmean(pvc)

            avg_corr_dict['session'].append(spike_data.attrs['session_two'])
            avg_corr_dict['avg_corr'].append(off_diagonal)
            ## Plot figure
            # Panel position: index 0-4 -> row 1, 5-9 -> row 2, and so on.
            row, col = divmod(index, n_cols)
            row, col = row + 1, col + 1

            # All panels share the layout colour axis so they use one colour scale.
            fig.add_trace(go.Heatmap(x=bins, y=bins, z=pvc, coloraxis='coloraxis'), row=row, col=col)

# With a shared colour axis the colorbar (and its title) lives on the layout, not on a trace.
fig.update_layout(coloraxis=dict(colorscale='Viridis', colorbar=dict(title=dict(text=colorbar_title))))
fig.update_yaxes(title='Lin Position', col=1)
fig.update_xaxes(title='Lin Position', row=4)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_spatial_bin_correlations_{data_type}_{bin_size}.png'))