In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
from os.path import join as pjoin
from tqdm.notebook import tqdm
import plotly.graph_objects as go
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix
from scipy.stats import zscore

sys.path.append('../')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf

def training_data_perturbation(sdata, test):
    """Prepare one training session's activity for a decoding test.

    Training data is never shuffled: in this notebook only the testing
    sessions are circularly shifted for 'Roll' (see
    ``testing_data_perturbation``), so 'Roll' and 'Control' both return an
    unmodified copy here.

    Args:
        sdata : xarray.DataArray
            Yra or C or S from Minian output; must have a 'frame' dimension
            (and a 'unit_id' dimension for 'Z-Scored', which chunks on it).
        test : str
            one of 'Z-Scored', 'Control', 'Roll'

    Returns:
        xarray.DataArray
            For 'Z-Scored', every trace z-scored along 'frame' (computed
            eagerly); otherwise a copy of ``sdata``.

    Raises:
        ValueError
            If ``test`` is not one of the supported names.
    """
    if test == 'Z-Scored':
        ## Normalize each trace over time. apply_ufunc moves the core dim
        ## ('frame') to the last axis, so z-score along axis -1.
        ## NOTE(review): scipy's zscore yields NaN for a zero-variance trace.
        zdata = xr.apply_ufunc(
                zscore,
                sdata.chunk({'frame': -1, 'unit_id': -1}),
                input_core_dims=[['frame']],
                output_core_dims=[['frame']],
                kwargs={'axis': -1},
                dask='parallelized'
        ).compute()
    elif test in ('Control', 'Roll'):
        zdata = sdata.copy()
    else:
        raise ValueError(f"Unknown test {test!r}; expected one of 'Z-Scored', 'Control', 'Roll'")
    return zdata

def testing_data_perturbation(sdata, test):
    """ 
    Args:
        sdata : xarray.DataArray
            Yra or C or S from Minian output
        test : str
            one of 'Z-Scored', 'Control', 'Roll'
    """
    if test == 'Z-Scored':
        ## Normalize data
        zdata = xr.apply_ufunc(
                zscore,
                sdata.chunk({'frame': -1, 'unit_id': 50}),
                input_core_dims=[['frame']],
                output_core_dims=[['frame']],
                kwargs={'axis': 1},
                dask='parallelized'
        ).compute()
    elif test == 'Control':
        zdata = sdata.copy()
    elif test == 'Roll':
        for neuron in np.arange(0, sdata.shape[0]):
            sdata[neuron, :] = np.roll(sdata[neuron, :], shift=np.random.randint(0, sdata.shape[1]))
        zdata = sdata.copy()
    return zdata

In [None]:
## Settings
## Paths
project_folder = ['MultiCon_Imaging']
experiment_folders = ['MultiCon_Imaging5', 'MultiCon_Imaging6']
dpath = f'../../{project_folder[0]}'
fig_path = '../../../Manuscripts/MultiCon/intermediate_plots/decoding'

## Plot colors and marker symbols
chance_color = '#7d7d7d'
avg_color = '#287347'
subject_color = '#7d7d7d'
ce_colors = ['#7A22BC', '#378616']
ce_colors_dict = {'Two-context': '#378616', 'Multi-context': '#7A22BC'}
symbol_dict = {'Two-context': 'x', 'Multi-context': 'circle'}
symbols_list = ['x', 'circle']
context_colors = {'A': '#00802d', 'B': '#006c79', 'C': '#004da4', 'D': '#430073'}
mouse_colors = ['midnightblue', 'darkred', 'darkorchid', 'darkturquoise']

## Mouse groups
male_mice = ['mc44', 'mc46', 'mc54', 'mc55']
control_mice = ['mc46', 'mc49', 'mc52', 'mc54', 'mc59', 'mc60']
experimental_mice = ['mc44', 'mc51', 'mc55', 'mc56', 'mc58']
imaging5 = ['mc44', 'mc46', 'mc49', 'mc51', 'mc52']  ## mice whose data are read from experiment_folders[0]

## Session / day labels: A1..A5, B1..B5, C1..C5, D1..D5; A1..A15, B1..B5; Day 1..Day 20
session_list = [f'{context}{day}' for context in 'ABCD' for day in range(1, 6)]
control_list = [f'A{day}' for day in range(1, 16)] + [f'B{day}' for day in range(1, 6)]
day_list = [f'Day {day}' for day in range(1, 21)]

## Analysis parameters
bin_size = 0.2 ## in seconds
velocity_thresh = 10
centroid_distance = 4
data_of_interest = 'aligned_minian' ## one of behav, aligned_minian, aligned_place_cells, lin_behav
z_thresh = 1.96
folds = 5

## exist_ok avoids the race between an existence check and creation
os.makedirs(fig_path, exist_ok=True)

xr.set_options(keep_attrs=True)

np.random.seed(24601)

In [None]:
## Decode which of two contexts a time bin came from, using cells that are
## cross-registered across all four sessions of a comparison. The decoder is
## trained on day 4 of the two contexts and tested on day 5, once per test:
## 'Control' (raw), 'Z-Scored' (per-trace z-score), 'Roll' (testing traces
## circularly shifted per unit).
data_type = 'C'
training_sessions_list = [['4', '9'], ['9', '14'], ['14', '19']]
testing_sessions_list = [['5', '10'], ['10', '15'], ['15', '20']]
context_str_list = ['four_five_ab', 'four_five_bc', 'four_five_cd']
## Context label of each session in a training pair (the testing pair uses the same labels)
context_labels = {('4', '9'): ('A', 'B'), ('9', '14'): ('B', 'C'), ('14', '19'): ('C', 'D')}
correct_dir = True
only_running = True
output = {'mouse': [], 'comparison': [], 'test': [], 'accuracy': []}


def _session_path(mpath, mouse, session):
    """Path of one session's NetCDF file for the current ``data_type``."""
    return pjoin(mpath, f'{mouse}_{data_type}_{session}.nc')


def _session_date(mpath, mouse, session):
    """Return a session's ``date`` attribute, closing the file afterwards."""
    with xr.open_dataset(_session_path(mpath, mouse, session)) as ds:
        return ds[data_type].attrs['date']


def _binned_session(mpath, mouse, session, shared_cells, perturb, test):
    """Load one session restricted to the shared cells, apply ``perturb`` for
    ``test``, keep the frames chosen by ctn.subset_correct_dir_and_running,
    and return ctn.bin_activity's result (second axis = bins, as used below)."""
    with xr.open_dataset(_session_path(mpath, mouse, session)) as ds:
        sdata = ds[data_type]
        ## .load() pulls data and coords into memory so the file can be closed
        sdata = sdata.sel(unit_id=shared_cells[sdata.attrs['date']].values).load()
    zdata = perturb(sdata, test=test)
    ndata, _ = ctn.subset_correct_dir_and_running(zdata, correct_dir=correct_dir,
                                                  only_running=only_running, velocity_thresh=velocity_thresh)
    return ctn.bin_activity(ndata, bin_size_seconds=bin_size, func=np.mean)


def _build_dataset(mpath, mouse, sessions, labels, shared_cells, perturb, test):
    """Stack the binned sessions into a (bins x units) matrix with one
    context label per bin (labels[i] for every bin of sessions[i])."""
    ## Sessions are processed in order, so RNG draws for 'Roll' keep their order
    binned = [_binned_session(mpath, mouse, session, shared_cells, perturb, test) for session in sessions]
    X = np.concatenate(binned, axis=1).T
    y = np.concatenate([np.repeat([label], repeats=b.shape[1]) for label, b in zip(labels, binned)])
    return X, y


for test in ['Control', 'Z-Scored', 'Roll']:
    for mouse in experimental_mice:
        if mouse == 'mc44':
            ## mc44 is deliberately excluded here (reason not recorded in this notebook)
            continue
        experiment = experiment_folders[0] if mouse in imaging5 else experiment_folders[1]
        mpath = pjoin(dpath, f'{experiment}/output/aligned_minian/{mouse}/{data_type}')
        crossreg_path = pjoin(dpath, f'{experiment}/output/cross_registration_results')
        for idx, training_sessions in enumerate(training_sessions_list):
            testing_sessions = testing_sessions_list[idx]
            ## KeyError on an unknown pair instead of silently reusing stale labels
            labels = context_labels[tuple(training_sessions)]
            comparison = ' or '.join(labels)

            ## Cells present in all four sessions (mapping columns are recording dates).
            ## NOTE: read_pickle can execute code; only load trusted project files.
            file_str = f'mappings_meta_{centroid_distance}_{context_str_list[idx]}.pkl'
            mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/{file_str}'))
            mappings.columns = mappings.columns.droplevel(0)
            date_list = [_session_date(mpath, mouse, session) for session in training_sessions + testing_sessions]
            shared_cells = mappings[date_list].dropna().reset_index(drop=True)

            ## Train decoder using training data from day 4 in specified contexts
            X_train, y_train = _build_dataset(mpath, mouse, training_sessions, labels, shared_cells,
                                              training_data_perturbation, test)
            clf = RandomForestClassifier(n_estimators=100, max_depth=None, random_state=24601)
            clf.fit(X_train, y_train)

            ## Test on day 5 of the same contexts
            X_test, y_test = _build_dataset(mpath, mouse, testing_sessions, labels, shared_cells,
                                            testing_data_perturbation, test)
            preds = clf.predict(X_test)

            ## Accuracy (chance is also used by the plotting cells)
            overall_acc = np.mean(preds == y_test)
            chance = 1 / len(np.unique(y_train))

            ## Save results
            output['mouse'].append(mouse)
            output['comparison'].append(comparison)
            output['test'].append(test)
            output['accuracy'].append(overall_acc)
acc_df = pd.DataFrame(output)
avg_acc = acc_df.groupby(['comparison', 'test'], as_index=False).agg({'accuracy': ['mean', 'sem']})

In [None]:
## Decoding accuracy per comparison on the unperturbed ('Control') data:
## mean ± SEM across mice as grey markers, plus (optionally) one thin line
## per mouse colored by sex. Dashed line = chance from the decoding cell.
add_mice = True
control_avg = avg_acc.loc[avg_acc['test'] == 'Control']
fig = pf.custom_graph_template(x_title='', y_title='Decoding Accuracy')
mean_trace = go.Scatter(
    x=control_avg['comparison'],
    y=control_avg['accuracy']['mean'],
    error_y=dict(type='data', array=control_avg['accuracy']['sem']),
    mode='markers',
    marker_color='darkgrey',
    showlegend=False,
)
fig.add_trace(mean_trace)
if add_mice:
    control_rows = acc_df[acc_df['test'] == 'Control']
    for mouse in experimental_mice:
        mouse_rows = control_rows[control_rows['mouse'] == mouse].reset_index(drop=True)
        line_color = 'midnightblue' if mouse in male_mice else 'darkorchid'
        fig.add_trace(go.Scatter(
            x=mouse_rows['comparison'],
            y=mouse_rows['accuracy'],
            mode='lines',
            line_color=line_color,
            line_width=1,
            opacity=0.7,
            name=mouse,
            showlegend=False,
        ))
fig.add_hline(y=chance, line_width=1, line_color=chance_color, line_dash='dash', opacity=1)
fig.update_yaxes(range=[0, 1.05])
fig.show()
# fig.write_image(pjoin(fig_path, 'decision_tree_aorb_borc_cord.png'))

In [None]:
## Decoding accuracy for each test condition, one panel per comparison:
## mean ± SEM across mice as markers, plus (optionally) one thin line per
## mouse. Dashed line = chance from the decoding cell.
add_mice = True
comparisons = ['A or B', 'B or C', 'C or D']
## Explicit x-axis order. Per-mouse rows are aligned to it by label, not by
## their row position in acc_df, so lines always connect in axis order.
test_order = ['Control', 'Roll', 'Z-Scored']
fig = pf.custom_graph_template(x_title='', y_title='', titles=['A or B', 'B or C', 'C or D'], 
                               rows=1, columns=3, shared_y=True, width=1200)
for idx, cmp in enumerate(comparisons):
    data = avg_acc[avg_acc['comparison'] == cmp]
    fig.add_trace(go.Scatter(x=data['test'], y=data['accuracy']['mean'], mode='markers', marker_color=avg_color,
                             error_y=dict(type='data', array=data['accuracy']['sem']), showlegend=False), row=1, col=idx+1)
if add_mice:
    for mouse in experimental_mice:
        for idx, cmp in enumerate(comparisons):
            ## One value per test; missing tests become NaN (gap) rather than shifting labels
            mouse_acc = (acc_df[(acc_df['mouse'] == mouse) & (acc_df['comparison'] == cmp)]
                         .set_index('test')['accuracy']
                         .reindex(test_order))
            fig.add_trace(go.Scatter(x=test_order, y=mouse_acc.to_numpy(), mode='lines', line_color=avg_color, 
                                     line_width=1, opacity=0.7, name=mouse, showlegend=False), row=1, col=idx+1)
fig.update_xaxes(categoryorder='array', categoryarray=test_order)
fig.update_yaxes(title='Decoding Accuracy', col=1)
fig.update_yaxes(range=[0, 1.05])
fig.add_hline(y=chance, line_width=1, line_color=chance_color, line_dash='dash', opacity=1)
fig.show()