In [1]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import numpy as np
import itertools
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from tqdm import tqdm
from natsort import natsort_keygen
import os
import datetime
import sys
from scipy.stats import zscore

In [2]:
sys.path.append('/media/caishuman/csstorage3/Austin/CircleTrack/CircleTrackAnalysis')
import cell_overlap as co
import plotting_functions as pf
import circletrack_neural as ctn

In [3]:
## Set parameters
## Mouse list
mouse_list = ['mc_EEG1_01', 'mc_EEG1_02']
## Load mappings of a specific param_distance from cross_registration results
param_distance = 5
## Specify session type; one of 'pre', 'circletrack', 'post'
session_type = 'circletrack'
## Set key file (derived from session_type)
key_file = '{}_data_keys.yml'.format(session_type)
## Set path; derived from session_type like key_file, so changing session_type keeps the
## data directory and the key file pointing at the same experiment stage
path = '/media/caishuman/csstorage3/Austin/CircleTrack/MultiCon_AfterHours/MultiCon_EEG1/{}_data'.format(session_type)
## Session names, presumably in recording order (blocks of five sessions per context) -- TODO confirm
session_list_one = ['A1', 'A2', 'A3', 'A4', 'A5', 'B1', 'B2', 'B3', 'B4', 'B5', 'C1', 'C2', 'C3', 'C4', 'C5', 'D1', 'D2', 'D3', 'D4', 'D5']
session_list_two = ['A1', 'A2', 'A3', 'A4', 'A5', 'B1', 'B2', 'B3', 'B4', 'B5', 'C1', 'C2', 'C3', 'C4', 'C5', 'A2_1', 'A2_2', 'A2_3', 'A2_4', 'A2_5']

In [4]:
## Plot the number of cells recorded in each session, one grey line per mouse
## Session order (x-axis) per mouse. The original referenced an undefined `session_list`.
## NOTE(review): assumes session_list_one belongs to mouse_list[0] and session_list_two to
## mouse_list[1] -- confirm against the recording schedule.
session_lists = dict(zip(mouse_list, [session_list_one, session_list_two]))
mouse_neural = {}
for mouse in mouse_list:
    sessions = ctn.import_mouse_neural_data(path, mouse, key_file = key_file, session = '20min', neural_type = 'spikes')
    ## Number of unit_ids (cells) in each session, keyed by session name
    mouse_neural[mouse] = {session: len(sessions[session].unit_id) for session in sessions}
fig = pf.custom_graph_template(title = 'Cells per Session', x_title = 'Session', y_title = 'Number of Cells')
for mouse, cell_counts in mouse_neural.items():
    ## Fall back to the order the sessions were loaded in for mice without an explicit schedule
    session_order = session_lists.get(mouse, list(cell_counts))
    ## reindex orders sessions by the schedule and inserts NaN (a gap in the line) for any
    ## session this mouse lacks, instead of raising a KeyError
    counts = pd.Series(cell_counts).reindex(session_order)
    fig.add_trace(go.Scatter(x = session_order, y = counts.values, mode = 'lines', line_color = 'grey', line_width = 0.5, opacity = 0.5, name = mouse, showlegend = False))
fig.show()

1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  xr.open_zarr(pjoin(dpath, d))
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.
  xr.open_zarr(pjoin(dpath, d))
1. Consolidating metadata in this existing store with zarr.consolidate_metadata().
2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or
3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.


KeyError: "not all values found in index 'frame'"

In [None]:
## Plot pairwise cell-overlap heatmaps, one subplot per mouse
## First recording date per mouse, passed to co.dates_to_days so the axes can be shown as days
start_dates = {mouse_list[0]: '2022_10_19', mouse_list[1]: '2022_10_21'}
fig = make_subplots(rows = 1, cols = len(mouse_list), subplot_titles = mouse_list)
for col, mouse in enumerate(tqdm(mouse_list), start = 1):
    mappings = pd.read_pickle('../../cross_registration_results/{}_data/{}/mappings_{}.pkl'.format(session_type, mouse, param_distance))
    overlap = co.calculate_overlap(mappings)
    overlap = co.dates_to_days(overlap, start_dates[mouse], days = 20)
    ## Heatmap matrix: rows = session_id1, columns = session_id2, both sorted by label
    matrix = (overlap.pivot_table(index = 'session_id1', columns = 'session_id2', values = 'overlap')
              .sort_index(axis = 0)
              .sort_index(axis = 1))
    ## go.Heatmap draws z[i][j] at (x[j], y[i]), so x must be the columns and y the index
    fig.add_trace(go.Heatmap(z = matrix.values, x = matrix.columns, y = matrix.index, coloraxis = 'coloraxis'), row = 1, col = col)
## Figure alterations (update_xaxes/update_yaxes without row/col apply to every subplot)
fig.update_layout(template = 'simple_white', width = 1600, height = 800, coloraxis = {'colorscale': 'Viridis'})
fig.update_layout(title = {'text': 'Cell Overlap Between Days', 'xanchor': 'center', 'y': 0.95, 'x': 0.5})
fig.update_layout(coloraxis_colorbar = {'title': 'Percent Overlap'})
fig.update_xaxes(title_text = 'Day', dtick = 1)
fig.update_yaxes(title_text = 'Day', dtick = 1)
## Red lines between blocks of five days (presumably one context per block); +0.5 places
## each line on the edge between two day cells
boundaries = [5, 10, 15]
for boundary in boundaries:
    fig.add_vline(x = boundary + 0.5, line_width = 2, line_color = 'red', opacity = 1)
    fig.add_hline(y = boundary + 0.5, line_width = 2, line_color = 'red', opacity = 1)
fig.show()

In [None]:
## Load neural activity for every session of every mouse into {mouse: sessions}
## The path is rebuilt from session_type so this cell does not depend on an earlier `path`
path = '/media/caishuman/csstorage3/Austin/CircleTrack/MultiCon_AfterHours/MultiCon_EEG1/{}_data'.format(session_type)
neural_dictionary = {
    mouse: ctn.import_mouse_neural_data(path, mouse, key_file = key_file, session = '20min', plot_frame_usage = False)
    for mouse in tqdm(mouse_list)
}

In [None]:
## Spearman correlation of activity between every pair of sessions (each session with itself
## included), restricted to cells cross-registered in both sessions.
## Uses param_distance from the config cell; it is deliberately not re-assigned here.
activity_dictionary = {}
for mouse in tqdm(mouse_list):
    mouse_sessions = neural_dictionary[mouse]
    mappings = pd.read_pickle('../../cross_registration_results/{}_data/{}/mappings_{}.pkl'.format(session_type, mouse, param_distance))
    activity_summary = []
    for d1, d2 in itertools.combinations_with_replacement(mouse_sessions, r = 2):
        session_one = mouse_sessions[d1]
        session_two = mouse_sessions[d2]
        ## Columns of mappings.session holding unit_ids for the two sessions
        session_columns = [session_one.session.values.tolist(), session_two.session.values.tolist()]
        cell_ids = mappings.session[session_columns].dropna(how = 'any')
        if d1 == d2:
            ## Same column selected twice: drop repeated rows so each cell is used once
            cell_ids = cell_ids.drop_duplicates()
        cell_ids = cell_ids.reset_index(drop = True)
        ## Select unit_ids based on cell_ids
        first_session = session_one.sel(unit_id = np.array(cell_ids.iloc[:, 0]))
        second_session = session_two.sel(unit_id = np.array(cell_ids.iloc[:, 1]))
        ## res[0] is stored as the statistic and res[1] as the p-value
        res = ctn.calculate_activity_correlation(first_session, second_session, test = 'spearman')
        activity_summary.append([d1, d2, res[0], res[1]])
        ## Store the mirrored pair so the pivoted matrix is symmetric; the diagonal is stored once
        if d1 != d2:
            activity_summary.append([d2, d1, res[0], res[1]])
    activity_dictionary[mouse] = pd.DataFrame(activity_summary, columns = ['session_id1', 'session_id2', 'statistic', 'pvalue'])

In [None]:
## Plot the between-session activity correlations (statistic column), one heatmap per mouse
fig = make_subplots(rows = 1, cols = len(mouse_list), subplot_titles = mouse_list)
for col, mouse in enumerate(tqdm(mouse_list), start = 1):
    ## Rows = session_id1, columns = session_id2, both sorted by label.
    ## NOTE(review): labels are session ids sorted as strings, so e.g. 'A2_1' sorts before 'A3'.
    matrix = (activity_dictionary[mouse]
              .pivot_table(index = 'session_id1', columns = 'session_id2', values = 'statistic')
              .sort_index(axis = 0)
              .sort_index(axis = 1))
    ## go.Heatmap draws z[i][j] at (x[j], y[i]), so x must be the columns and y the index
    fig.add_trace(go.Heatmap(z = matrix.values, x = matrix.columns, y = matrix.index, coloraxis = 'coloraxis'), row = 1, col = col)
fig.update_layout(template = 'simple_white', width = 1600, height = 800, coloraxis = {'colorscale': 'Viridis'})
## Axis labels are session ids (e.g. 'A2'), not day numbers
fig.update_xaxes(title_text = 'Session')
fig.update_yaxes(title_text = 'Session')
fig.update_layout(title = {'text': 'PVC Between Sessions', 'xanchor': 'center', 'y': 0.95, 'x': 0.5})
fig.update_layout(coloraxis_colorbar = {'title': 'r value'})
fig.show()

In [None]:
## Explore single-session cell dynamics, starting with how often each cell is active
## in one example session of one example mouse
example_mouse = 'mc_EEG1_01'
neural_data = neural_dictionary[example_mouse]['A2']
neural_data

In [None]:
## Plot one cell's activity across the session against sample index.
## NOTE(review): assumes neural_data.values is (cell, frame); row 17 is positional, not unit_id 17.
time_vec = np.arange(neural_data.shape[1])
cell_row = 17
fig = go.Figure(data = [go.Scatter(x = time_vec, y = neural_data.values[cell_row])])
fig.update_layout(template = 'simple_white')

In [None]:
## z-score each row of neural_data.values independently (axis = 1), then plot one example row.
## NOTE(review): a row with zero variance divides by zero and becomes NaN.
zdata = zscore(neural_data.values, axis = 1)
example_row = zdata[4]
fig = go.Figure(data = [go.Scatter(x = time_vec, y = example_row)])
fig.update_layout(template = 'simple_white')

In [None]:
## Inspect the z-scored values of row 1
zdata[1, :]

In [None]:
## Quick check: how many samples of row 0 exceed the z-score threshold
z_thresh = 2
test = zdata > z_thresh
ans = zdata[0, test[0]]
len(ans)

In [None]:
## For every cell (row of zdata), count the samples whose z-score exceeds z_thresh.
## NOTE(review): this counts above-threshold samples, so one long event counts several times.
z_thresh = 2
boolean = zdata > z_thresh  ## binarize: True where a sample is above threshold
## Sum this cell's own mask. The previous loop indexed `test` from an earlier cell, so
## changing z_thresh here had no effect.
spikes = boolean.sum(axis = 1).tolist()

In [None]:
## Distribution across cells of the per-cell above-threshold counts computed above
fig = go.Figure(data = [go.Histogram(x = spikes)])
fig.update_layout(template = 'simple_white', width = 500, height = 500,
                  title = 'Above-threshold samples per cell',
                  xaxis_title = 'Samples above z-score threshold',
                  yaxis_title = 'Number of cells')
fig.show()