In [None]:
%load_ext autoreload
%autoreload 2
import itertools
import os
import re
import sys
from os.path import join as pjoin

import numpy as np
import pandas as pd
import plotly.graph_objects as go
import xarray as xr
from scipy.stats import pearsonr, spearmanr, zscore
from tqdm.notebook import tqdm

sys.path.append('/home/austinbaggetta/csstorage3/CircleTrack/CircleTrackAnalysis')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

In [None]:
## Settings
# Data and figure locations, relative to the notebook's working directory.
dpath = '../../../MultiCon_Imaging/MultiCon_Imaging2/output/'
fig_path = '../../../MultiCon_Imaging/MultiCon_Imaging2/intermediate_figures/'
# Linear-position bin width (in the units of 'lin_position'; TODO confirm).
bin_size = 0.3
# Speed cutoff handed to pc.define_running_epochs (presumably cm/s; TODO confirm).
velocity_thresh = 7
# Session labels in recording order: days 1-5 of each context A-D -> ['A1', ..., 'D5'].
session_list = [f'{context}{day}' for context in 'ABCD' for day in range(1, 6)]

In [None]:
## Single mouse correlation of activity across spatial bins
# Population-vector (PV) correlation for one session: average each cell's activity inside
# every linear-position bin (running frames, optionally correct-direction trials only),
# then Pearson-correlate the population vectors of every pair of bins.
mouse = 'mc26'
session_num = '1'
correct_direction = True
data_type = 'S'

avg_corr_dict = {'session': [], 'avg_corr': []}
mpath = pjoin(dpath, f'aligned_minian/{mouse}/{data_type}')
# NOTE(review): the panel title is hard-coded; keep it in sync with session_num.
fig = pf.custom_graph_template(x_title='Linear Position', y_title='Linear Position', height=500, width=550, titles=['A1'])
# Load into memory inside a context manager so the netCDF file handle is released.
with xr.open_dataset(pjoin(mpath, f'{mouse}_{data_type}_{session_num}.nc')) as ds:
    spike_data = ds[data_type].load()  ## select S matrix
x_cm, y_cm = ctb.convert_to_cm(x=spike_data['x'].values, y=spike_data['y'].values)
velocity, running = pc.define_running_epochs(x_cm,
                                             y_cm,
                                             spike_data['behav_t'].values,
                                             velocity_thresh=velocity_thresh)
if correct_direction:
    forward, reverse = ctb.forward_reverse_trials(spike_data, spike_data['trials'])
    # Frame mask: True where the frame's trial id is one of the forward trials.
    # np.isin also copes with an empty `forward`; the old accumulate-then-`.values`
    # loop raised AttributeError on the plain ndarray in that case.
    correct_trials = np.isin(spike_data['trials'].values, list(forward))
else:
    correct_trials = np.ones(spike_data.shape[1], dtype=bool)

running_data = spike_data[:, np.logical_and(running, correct_trials)]
bins = ctb.calculate_bins(x=spike_data['lin_position'].values, bin_size=bin_size)
# Materialize the activity matrix and positions once instead of once per bin.
activity = running_data.values
lin_pos = running_data['lin_position'].values
population_activity = np.full((len(bins) - 1, spike_data['unit_id'].shape[0]), np.nan)
for idx, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
    in_bin = (lin_pos >= start) & (lin_pos < end)
    # Bins without qualifying frames stay NaN (same as the mean of an empty slice, minus the warning).
    if in_bin.any():
        population_activity[idx, :] = activity[:, in_bin].mean(axis=1)
# Bin x bin correlation matrix; the diagonal (self-correlation) is excluded from the average.
pvc = np.corrcoef(population_activity)
np.fill_diagonal(pvc, val=np.nan)
off_diagonal = np.nanmean(pvc)

avg_corr_dict['session'].append(spike_data.attrs['session_two'])
avg_corr_dict['avg_corr'].append(off_diagonal)
# The colorscale is set on the trace: this heatmap is not bound to a layout coloraxis,
# so a layout-level coloraxis colorscale would not apply to it.
fig.add_trace(go.Heatmap(x=bins, y=bins, z=pvc, colorscale='Viridis', colorbar=dict(title="Pearson's r")))
fig.show()
# fig.write_image(pjoin(fig_path, f'{mouse}_A1_spatialbins.png'))

In [None]:
## Correlation of cell activity across spatial bins for each session for a mouse
mouse = 'mc26'
correct_direction = True
data_type = 'S'
# NOTE(review): this cell uses 0.4-wide bins while Settings defines bin_size = 0.3; confirm which is intended.
session_bin_size = 0.4
# Grid slots (0-based positions in session_list) left empty for a mouse, so that later files shift
# right to their session's panel: mc26 -> C2 and D4, mc23 -> D1 (presumably no data; TODO confirm).
missing_panels = {'mc26': [11, 18], 'mc23': [15]}


def _natural_key(filename):
    """Sort key that compares embedded integers numerically ('..._2.nc' before '..._10.nc')."""
    return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', filename)]


def _panel_position(file_index, skipped):
    """Map the n-th session file to its 0-based grid slot, stepping over the skipped slots."""
    position = file_index
    for slot in sorted(skipped):
        if slot <= position:
            position += 1
    return position


avg_corr_dict = {'session': [], 'avg_corr': []}
mpath = pjoin(dpath, f'aligned_minian/{mouse}/{data_type}')
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, height=1000, width=1000, titles=session_list)
# os.listdir order is arbitrary; sort so each file maps to a predictable panel.
session_files = sorted(os.listdir(mpath), key=_natural_key)
for index, session in tqdm(enumerate(session_files), total=len(session_files)):
    panel = _panel_position(index, missing_panels.get(mouse, []))
    # Load into memory and release the netCDF file handle before the next file is opened.
    with xr.open_dataset(pjoin(mpath, session)) as ds:
        spike_data = ds[data_type].load()  ## select S matrix
    ## PV correlations of position data
    x_cm, y_cm = ctb.convert_to_cm(x=spike_data['x'].values, y=spike_data['y'].values)
    velocity, running = pc.define_running_epochs(x_cm,
                                                 y_cm,
                                                 spike_data['behav_t'].values,
                                                 velocity_thresh=velocity_thresh)
    if correct_direction:
        forward, reverse = ctb.forward_reverse_trials(spike_data, spike_data['trials'])
        # Frame mask: True where the frame's trial id is one of the forward trials.
        correct_trials = np.isin(spike_data['trials'].values, list(forward))
    else:
        correct_trials = np.ones(spike_data.shape[1], dtype=bool)

    running_data = spike_data[:, np.logical_and(running, correct_trials)]
    bins = ctb.calculate_bins(x=spike_data['lin_position'].values, bin_size=session_bin_size)
    # Materialize the activity matrix and positions once instead of once per bin.
    activity = running_data.values
    lin_pos = running_data['lin_position'].values
    population_activity = np.full((len(bins) - 1, spike_data['unit_id'].values.shape[0]), np.nan)
    for idx, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
        in_bin = (lin_pos >= start) & (lin_pos < end)
        # Empty bins stay NaN (same as the mean of an empty slice, minus the warning).
        if in_bin.any():
            population_activity[idx, :] = activity[:, in_bin].mean(axis=1)
    # Bin x bin correlation matrix; the diagonal is excluded from the average.
    pvc = np.corrcoef(population_activity)
    np.fill_diagonal(pvc, val=np.nan)
    off_diagonal = np.nanmean(pvc)

    avg_corr_dict['session'].append(spike_data.attrs['session_two'])
    avg_corr_dict['avg_corr'].append(off_diagonal)
    ## Plot figure: grid filled row-major, one row per context (A-D), one column per day
    row, col = divmod(panel, 5)
    fig.add_trace(go.Heatmap(x=bins, y=bins, z=pvc, coloraxis='coloraxis1'), row=row + 1, col=col + 1)
fig.update_layout(coloraxis=dict(colorscale='Viridis'))
fig.update_yaxes(title='Lin Position', col=1)
fig.update_xaxes(title='Lin Position', row=4)
fig.show()
# fig.write_image(pjoin(fig_path, f'{mouse}_spatial_bin_correlations.png'))

In [None]:
## Plot the average correlation value across all spatial bins
# For every mouse and session: mean off-diagonal PV correlation across linear-position bins,
# using running frames only (NOTE: unlike the per-mouse cells above, no correct-direction
# filtering is applied here).
mouse_list = ['mc23', 'mc26']
# NOTE(review): 0.4-wide bins here vs bin_size = 0.3 in Settings; confirm which is intended.
summary_bin_size = 0.4
corr_dict = {'mouse': [], 'session': [], 'avg_corr': []}
for mouse in mouse_list:
    mpath = pjoin(dpath, f'aligned_minian/{mouse}/S')
    # Deterministic file order (os.listdir order is arbitrary) so corr_df rows are reproducible.
    for session in tqdm(sorted(os.listdir(mpath))):
        # Load into memory and release the netCDF file handle before the next file is opened.
        with xr.open_dataset(pjoin(mpath, session)) as ds:
            spike_data = ds['S'].load()  ## select S matrix
        ## PV correlations of position data
        # Convert to cm first, as the per-mouse cells do, so velocity_thresh is applied in the same units.
        x_cm, y_cm = ctb.convert_to_cm(x=spike_data['x'].values, y=spike_data['y'].values)
        velocity, running = pc.define_running_epochs(x_cm, y_cm, spike_data['behav_t'].values, velocity_thresh=velocity_thresh)
        running_data = spike_data[:, running]
        bins = ctb.calculate_bins(x=spike_data['lin_position'].values, bin_size=summary_bin_size)
        # Materialize the activity matrix and positions once instead of once per bin.
        activity = running_data.values
        lin_pos = running_data['lin_position'].values
        population_activity = np.full((len(bins) - 1, spike_data['unit_id'].values.shape[0]), np.nan)
        for idx, (start, end) in enumerate(zip(bins[:-1], bins[1:])):
            in_bin = (lin_pos >= start) & (lin_pos < end)
            # Empty bins stay NaN (same as the mean of an empty slice, minus the warning).
            if in_bin.any():
                population_activity[idx, :] = activity[:, in_bin].mean(axis=1)
        pvc = np.corrcoef(population_activity)
        np.fill_diagonal(pvc, val=np.nan)
        off_diagonal = np.nanmean(pvc)
        corr_dict['mouse'].append(mouse)
        corr_dict['session'].append(spike_data.attrs['session_two'])
        corr_dict['avg_corr'].append(off_diagonal)
corr_df = pd.DataFrame(corr_dict)
# Mean and SEM across mice for each session label (rows sorted by session label).
avg_corr = corr_df.groupby(['session'], as_index=False).agg({'avg_corr': ['mean', 'sem']})

In [None]:
## Plot data
# One line per context: mean (error bars = SEM across mice) of the average spatial-bin correlation per day.
# avg_corr rows are sorted by session label, so rows 0-4, 5-9, 10-14 and 15+ are taken to be
# contexts A, B, C and D (assumes session_two labels follow session_list; TODO confirm).
assert len(avg_corr) == len(session_list), 'expected exactly one avg_corr row per session in session_list'
context_colors = ['darkorchid', 'darkgrey', 'turquoise', 'green']
A_data = avg_corr.loc[0:4, :]
B_data = avg_corr.loc[5:9, :]
C_data = avg_corr.loc[10:14, :]
D_data = avg_corr.loc[15:, :]
x_data = np.arange(1, 6) ## days
fig = pf.custom_graph_template(x_title='Day', y_title='Pearson Correlation')
for context, color, context_data in zip('ABCD', context_colors, [A_data, B_data, C_data, D_data]):
    fig.add_trace(go.Scatter(x=x_data, y=context_data['avg_corr']['mean'], mode='lines+markers',
                             line_color=color, name=context,
                             error_y=dict(type='data', array=context_data['avg_corr']['sem'])))
fig.show()