In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
from os.path import join as pjoin
from tqdm.notebook import tqdm
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr, zscore, skew 
from natsort import natsorted
import statsmodels.api as sm
from statsmodels.formula.api import ols
from numpy.random import RandomState, SeedSequence, MT19937

sys.path.append('../')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

In [None]:
## Settings
## Paths
project_folder = ['MultiCon_Imaging']
experiment_folders = ['MultiCon_Imaging5', 'MultiCon_Imaging6', 'MultiCon_Imaging7']
dpath = f'../../{project_folder[0]}'
fig_path = '../../../Manuscripts/MultiCon/intermediate_plots/firing_rates'

## Plot colors and marker symbols
chance_color = '#7d7d7d'
avg_color = '#287347'
subject_color = '#7d7d7d'
ce_colors = ['#7A22BC', '#378616']
ce_colors_dict = {'Two-context': '#378616', 'Multi-context': '#7A22BC'}
symbol_dict = {'Two-context': 'x', 'Multi-context': 'circle'}
symbols_list = ['x', 'circle']
context_colors = {'A': '#00802d', 'B': '#006c79', 'C': '#004da4', 'D': '#430073'}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']

## Mouse metadata
male_mice = ['mc44', 'mc46', 'mc54', 'mc55', 'mc64', 'mc65']
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc59', 'mc60', 'mc61', 'mc64']

## Session / day labels: A1..A5, B1..B5, C1..C5, D1..D5; A1..A15 then B1..B5; Day 1..Day 20
session_list = [f'{ctx}{num}' for ctx in 'ABCD' for num in np.arange(1, 6)]
control_list = [f'A{num}' for num in np.arange(1, 16)] + [f'B{num}' for num in np.arange(1, 6)]
day_list = [f'Day {num}' for num in np.arange(1, 21)]

## Analysis parameters
bin_size = 1 ## in seconds
velocity_thresh = 10
centroid_distance = 4
data_of_interest = 'aligned_place_cells' ## one of behav, aligned_minian, aligned_place_cells, lin_behav
data_type = 'S'

## Make sure the figure output directory exists before any write_image call
if not os.path.exists(fig_path):
    os.makedirs(fig_path)

### Plot histograms of firing rates for each day for a mouse.

In [None]:
experiment = 'MultiCon_Imaging6'
mouse = 'mc56'
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, titles=day_list, height=1000, width=1000,
                               shared_y=True, shared_x=True)

## One histogram of per-cell average activity per session, laid out five days per row.
## `sdata` is left pointing at the last plotted session; the next cell relies on that.
exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
mpath = pjoin(exp_path, f'{mouse}/{data_type}')
for idx, session in enumerate(natsorted(os.listdir(mpath))):
    if '21' in session:
        continue  ## session files whose name contains '21' are not plotted
    idx = ctn.mouse_indices(mouse, idx)
    if mouse == 'mc56':
        ## NOTE(review): this extra offset for mc56 is applied only in this cell, not in the summary loops below — confirm
        idx += 1
    ## 0-based day index -> (row, col) on the 4x5 grid
    row_offset, col_offset = divmod(idx, 5)
    row, col = row_offset + 1, col_offset + 1

    sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
    act_bin = ctn.bin_activity(sdata.values, bin_size_seconds=bin_size, func=np.mean)
    avg_act = np.mean(act_bin, axis=1) / bin_size ## convert to Hz for any bin size
    hist = go.Histogram(x=avg_act, histnorm='probability', marker_color='darkgrey', showlegend=False,
                        marker_line_color='black', marker_line_width=1)
    fig.add_trace(hist, row=row, col=col)
fig.update_xaxes(title='Events (Hz)', row=4)
fig.update_yaxes(title='Probability', col=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_eventrate_histograms.png'), width=1000, height=1000)

In [None]:
## Firing rate distribution for a single mouse and session.
## NOTE: uses `sdata` left over from the histogram loop in the previous cell (the last session
## it plotted), so that cell must be run first.
## Bin with the configured `bin_size` so the division below matches the binning.
act_bin = ctn.bin_activity(sdata.values, bin_size_seconds=bin_size, func=np.mean)
avg_act = np.mean(act_bin, axis=1) / bin_size
## Pull the attrs out first: reusing single quotes inside an f-string replacement field
## is a SyntaxError before Python 3.12.
animal = sdata.attrs['animal']
session_name = sdata.attrs['session_two']
fig = pf.custom_graph_template(x_title='Event Rate (events/s)', y_title='Probability', titles=[f'{animal} : {session_name}'])
fig.add_trace(go.Histogram(x=avg_act, histnorm='probability', marker_color='darkgrey', marker_line_color='black',
                           marker_line_width=2))
fig.show()
fig.write_image(pjoin(fig_path, f'{animal}_{session_name}_example_firing_rates.png'), width=500, height=500)

### Calculate average firing rate and skew of the firing rate distribution for every mouse and every session.

In [None]:
## One row per (mouse, session): mean and skew of the per-cell average-activity distribution.
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'act_mean': [], 'act_skew': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        ## NOTE(review): unlike the other session loops in this notebook, session files whose
        ## name contains '21' are NOT skipped here — confirm this is intended.
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            idx = ctn.mouse_indices(mouse, idx)

            S = xr.open_dataset(pjoin(mpath, session))[data_type]
            sdata = ctn.qc_matrix(S, threshold=True)
            ## Bin with the configured bin size (previously hard-coded to 1 s) so the
            ## division by `bin_size` below stays consistent if the setting changes.
            act_bin = ctn.bin_activity(sdata.values, bin_size_seconds=bin_size, func=np.mean)
            avg_act = np.mean(act_bin, axis=1) / bin_size ## convert to Hz for any bin size

            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['sex'].append(sex)
            cell_dict['session'].append(sdata.attrs['session_two'])
            cell_dict['day'].append(idx + 1)
            cell_dict['act_mean'].append(np.mean(avg_act))
            cell_dict['act_skew'].append(skew(avg_act))
fr_df = pd.DataFrame(cell_dict)

In [None]:
## Plot line graph of average firing rates across days within each context for Multi-context mice
yval = 'act_mean'
avg = fr_df.groupby(['group', 'session'], as_index=False).agg({'act_mean': ['mean', 'sem'], 'act_skew': ['mean', 'sem']})
mc_avg = avg[avg['group'] == 'Multi-context']
context_day = [1, 2, 3, 4, 5]

## y-axis title and range for each supported metric
yaxis_settings = {'act_mean': ('Event Rate (Hz)', [0, 0.01]),
                  'act_skew': ('Event Rate Skew', [0, 5])}

fig = pf.custom_graph_template(x_title='Day', y_title='', titles=['Multi-context'])
session_labels = mc_avg['session'].values
for ctx in ['A', 'B', 'C', 'D']:
    ## Rows whose session label contains this context letter (groupby sorted them by label).
    ctx_rows = mc_avg[[ctx in label for label in session_labels]]
    fig.add_trace(go.Scattergl(x=context_day, y=ctx_rows[yval]['mean'], mode='lines+markers',
                               line_color=context_colors[ctx],
                               error_y=dict(type='data', array=ctx_rows[yval]['sem']), name=ctx))
if yval in yaxis_settings:
    y_title, y_range = yaxis_settings[yval]
    fig.update_yaxes(title=y_title, range=y_range)
fig.show()
fig.write_image(pjoin(fig_path, f'{yval}_multi_context_context_days.png'), width=500, height=500)

In [None]:
## Plot line graph of average firing rates across days within each context for Multi-context mice split between sexes
yval = 'act_mean'
mc_fr = fr_df[fr_df['group'] == 'Multi-context']
avg_sex = mc_fr.groupby(['group', 'sex', 'session'], as_index=False).agg({'act_mean': ['mean', 'sem'], 'act_skew': ['mean', 'sem']})
context_day = [1, 2, 3, 4, 5]

fig = pf.custom_graph_template(x_title='Day', y_title='', rows=1, columns=2, titles=['Male', 'Female'],
                               width=1000, shared_y=True, shared_x=True)
## Which rows belong to each context does not depend on sex, so build the masks once.
session_labels = avg_sex['session'].values
context_masks = {ctx: [ctx in label for label in session_labels] for ctx in ['A', 'B', 'C', 'D']}
for panel, sex in enumerate(['Male', 'Female']):
    for ctx, mask in context_masks.items():
        ctx_rows = avg_sex[mask]
        ctx_rows = ctx_rows[ctx_rows['sex'] == sex]
        trace = go.Scattergl(x=context_day, y=ctx_rows[yval]['mean'], mode='lines+markers', line_color=context_colors[ctx],
                             error_y=dict(type='data', array=ctx_rows[yval]['sem']), name=ctx,
                             legendgroup=ctx, showlegend=False)
        fig.add_trace(trace, row=1, col=panel + 1)
if yval == 'act_mean':
    fig.update_yaxes(title='Event Rate (Hz)', range=[0, 0.01], col=1)
elif yval == 'act_skew':
    fig.update_yaxes(title='Event Rate Skew', range=[0, 6], col=1)
## Show each context once in the legend: the first four traces (male panel, contexts A-D).
for trace in fig.data[:4]:
    trace.showlegend = True
fig.show()
fig.write_image(pjoin(fig_path, f'{yval}_multi_context_context_days_malefemale.png'), width=1000, height=500)

### Event rates for two-context mice (contexts A and B).

In [None]:
## Plot firing rate for two-context mice across both contexts
yval = 'act_mean'

tc_fr = fr_df[fr_df['group'] == 'Two-context']
avg = tc_fr.groupby(['group', 'session', 'day'], as_index=False).agg({'act_mean': ['mean', 'sem'], 'act_skew': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='', titles=['Two-context'], width=500)

for session in ['A', 'B']:
    sess_bool = [session in label for label in avg['session'].values]
    plot_data = avg[sess_bool]
    ## groupby sorts session labels as strings ('A1', 'A10', ..., 'A15', 'A2', ...), so restore
    ## chronological order from the recorded day instead of hard-coded index positions
    ## (which silently break if the number of rows or day offsets change).
    plot_data = plot_data.iloc[np.argsort(plot_data['day'].to_numpy(), kind='stable')]
    ## Context A uses the recorded day; context B is plotted on its own 1..N day axis.
    ## Computed separately rather than written back into a slice of `avg`.
    if session == 'B':
        x_days = np.arange(1, len(plot_data) + 1)
    else:
        x_days = plot_data['day']

    fig.add_trace(go.Scattergl(x=x_days, y=plot_data[yval]['mean'], mode='lines+markers', line_color=context_colors[session],
                               error_y=dict(type='data', array=plot_data[yval]['sem']), name=session))
if yval == 'act_mean':
    fig.update_yaxes(title='Event Rate (Hz)', range=[0, 0.01])
elif yval == 'act_skew':
    fig.update_yaxes(title='Event Rate Skew', range=[0, 5])
fig.show()
fig.write_image(pjoin(fig_path, f'{yval}_two_context_days.png'), width=500, height=500)

In [None]:
## Separate by male and female for two-context mice
yval = 'act_mean'

tc_fr = fr_df[fr_df['group'] == 'Two-context']
avg_sex = tc_fr.groupby(['group', 'sex', 'session', 'day'], as_index=False).agg({'act_mean': ['mean', 'sem'], 'act_skew': ['mean', 'sem']})

fig = pf.custom_graph_template(x_title='Day', y_title='', titles=['Male', 'Female'], width=1000, rows=1, columns=2,
                               shared_y=True, shared_x=True)

for idx, sex in enumerate(['Male', 'Female']):
    for session in ['A', 'B']:
        sess_bool = [session in label for label in avg_sex['session'].values]
        plot_data = avg_sex[sess_bool]
        plot_data = plot_data[plot_data['sex'] == sex]
        ## Restore chronological order from the recorded day (groupby sorts 'A10' before 'A2');
        ## replaces a hard-coded index reorder that assumed exactly 15 A rows per sex.
        plot_data = plot_data.iloc[np.argsort(plot_data['day'].to_numpy(), kind='stable')]
        ## Context A uses the recorded day; context B is plotted on its own 1..N day axis.
        if session == 'B':
            x_days = np.arange(1, len(plot_data) + 1)
        else:
            x_days = plot_data['day']

        fig.add_trace(go.Scattergl(x=x_days, y=plot_data[yval]['mean'], mode='lines+markers', line_color=context_colors[session],
                                   error_y=dict(type='data', array=plot_data[yval]['sem']), name=session,
                                   legendgroup=session, showlegend=False), row=1, col=idx + 1)
## Axis title/range on the first column only, matching the multi-context male/female figure.
if yval == 'act_mean':
    fig.update_yaxes(title='Event Rate (Hz)', range=[0, 0.01], col=1)
elif yval == 'act_skew':
    fig.update_yaxes(title='Event Rate Skew', range=[0, 5], col=1)
## Show each context once in the legend: the first two traces (male panel, contexts A and B).
for trace in fig.data[:2]:
    trace.showlegend = True
fig.show()
fig.write_image(pjoin(fig_path, f'{yval}_two_context_context_days_malefemale.png'), width=1000, height=500)

### Calculate firing rates for place cells vs. non-place cells.

In [None]:
## One row per (mouse, session): mean and skew of per-cell average activity, split into cells
## flagged by `skaggs_place` (place) and the rest (non-place).
## NOTE: this overwrites the `fr_df` built by the earlier summary cell.
cell_dict = {'mouse': [], 'group': [], 'sex': [], 'session': [], 'day': [], 'act_mean_place': [], 'act_skew_place': [], 'act_mean_nonplace': [], 'act_skew_nonplace': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    for mouse in tqdm(os.listdir(exp_path)):
        mpath = pjoin(exp_path, f'{mouse}/{data_type}')
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        for idx, session in enumerate(natsorted(os.listdir(mpath))):
            if '21' in session:
                continue  ## session files whose name contains '21' are excluded
            idx = ctn.mouse_indices(mouse, idx)

            sdata = xr.open_dataset(pjoin(mpath, session))[data_type]
            ## Bin with the configured bin size (previously hard-coded to 1 s) so the
            ## division by `bin_size` stays consistent with it.
            act_bin = ctn.bin_activity(sdata.values, bin_size_seconds=bin_size, func=np.mean)
            avg_act = np.mean(act_bin, axis=1) / bin_size
            ## Cast to bool so `~` is a logical NOT. On an integer flag array `~` is a bitwise
            ## NOT (0 -> -1, 1 -> -2), which would silently index the wrong cells.
            is_place = sdata['skaggs_place'].values.astype(bool)
            pcs = avg_act[is_place]
            npcs = avg_act[~is_place]

            cell_dict['mouse'].append(mouse)
            cell_dict['group'].append(group)
            cell_dict['sex'].append(sex)
            cell_dict['session'].append(sdata.attrs['session_two'])
            cell_dict['day'].append(idx + 1)
            cell_dict['act_mean_place'].append(np.mean(pcs))
            cell_dict['act_skew_place'].append(skew(pcs))
            cell_dict['act_mean_nonplace'].append(np.mean(npcs))
            cell_dict['act_skew_nonplace'].append(skew(npcs))
fr_df = pd.DataFrame(cell_dict)

In [None]:
group = 'Multi-context'
avg = fr_df.groupby(['group', 'day'], as_index=False).agg({'act_mean_place': ['mean', 'sem'], 'act_mean_nonplace': ['mean', 'sem']})
data = avg[avg['group'] == group]
x = ['P', 'NP']

## One P vs NP bar pair per day, laid out five days per row on a 4x5 grid.
fig = pf.custom_graph_template(x_title='', y_title='', rows=4, columns=5, shared_x=True, shared_y=True,
                               width=1000, height=1000, titles=day_list)
for day in data['day']:
    ## 1-based day -> (row, col) on the grid
    row_offset, col_offset = divmod(day - 1, 5)
    row, col = row_offset + 1, col_offset + 1
    day_data = data[data['day'] == day]
    place_stats = day_data['act_mean_place']
    nonplace_stats = day_data['act_mean_nonplace']
    means = [place_stats['mean'].values[0], nonplace_stats['mean'].values[0]]
    sems = [place_stats['sem'].values[0], nonplace_stats['sem'].values[0]]

    bars = go.Bar(x=x, y=means, marker_color=['darkgrey', 'midnightblue'],
                  showlegend=False, marker_line_color='black', marker_line_width=1,
                  error_y=dict(type='data', array=sems))
    fig.add_trace(bars, row=row, col=col)
fig.update_yaxes(title='Event Rate (Hz)', col=1)
fig.show()
fig.write_image(pjoin(fig_path, f'{group}_place_nonplace_frs.png'), width=1000, height=1000)

### Place-cell overlap across days 15 and 16: for each cell registered on both days, is it a place cell on both?

In [None]:
## For each mouse, keep cells cross-registered on both `days` and flag those whose
## `skaggs_place` value is set on both days.
days = [15, 16]
output_dict = {'mouse': [], 'group': [], 'sex': [], 'sess_one_uid': [], 'sess_two_uid': [], 'place_in_both': []}
for experiment in os.listdir(dpath):
    if experiment not in experiment_folders:
        continue
    exp_path = pjoin(dpath, f'{experiment}/output/{data_of_interest}/')
    ## Resolve paths from the experiment being iterated. The previous version referenced an
    ## undefined `imaging5` list and hard-wired Imaging5/Imaging6 paths, which would also have
    ## mis-routed Imaging7 mice.
    is_imaging5 = experiment == experiment_folders[0]
    crossreg_path = pjoin(dpath, f'{experiment}/output/cross_registration_results')
    for mouse in tqdm(os.listdir(exp_path)):
        sex = ctn.set_sex(mouse, male_mice)
        group = ctn.set_group(mouse, control_mice)
        act_path = pjoin(dpath, f'{experiment}/output/aligned_place_cells/{mouse}/{data_type}')
        behav_path = pjoin(dpath, f'{experiment}/output/behav/{mouse}')  ## not used in this cell

        file_str = f'mappings_{centroid_distance}_None.pkl' ## cross registration file with all days
        mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/{file_str}'))
        if not is_imaging5:
            ## Drop the 2025_02_07 session column; ignore mice that have no such column.
            mappings = mappings.drop('2025_02_07', axis=1, level=1, errors='ignore').reset_index(drop=True)
        mappings.columns = mappings.columns.droplevel(0)

        act_one = xr.open_dataset(pjoin(act_path, f'{mouse}_{data_type}_{days[0]}.nc'))[data_type]
        act_two = xr.open_dataset(pjoin(act_path, f'{mouse}_{data_type}_{days[1]}.nc'))[data_type]

        ## Subset by selecting the shared neurons; row i of `shared_cells` pairs a day-one unit
        ## with a day-two unit, so after .sel position i refers to the same cell on both days.
        shared_cells = mappings[[act_one.attrs['date'], act_two.attrs['date']]].dropna().reset_index(drop=True)
        act_one = act_one.sel(unit_id=shared_cells.iloc[:, 0].values)
        act_two = act_two.sel(unit_id=shared_cells.iloc[:, 1].values)

        place_in_both = act_one['skaggs_place'].values.astype(bool) & act_two['skaggs_place'].values.astype(bool)
        n_shared = len(place_in_both)
        output_dict['mouse'].extend([mouse] * n_shared)
        output_dict['group'].extend([group] * n_shared)
        output_dict['sex'].extend([sex] * n_shared)
        output_dict['sess_one_uid'].extend(act_one['unit_id'].values)
        output_dict['sess_two_uid'].extend(act_two['unit_id'].values)
        output_dict['place_in_both'].extend(place_in_both)
place_both_df = pd.DataFrame(output_dict)