In [None]:
%load_ext autoreload
%autoreload 2
import os
import sys
import numpy as np
import pandas as pd
import xarray as xr
from os.path import join as pjoin
from tqdm.notebook import tqdm
import plotly.graph_objects as go
from scipy.stats import pearsonr, spearmanr, zscore
import itertools

sys.path.append('/home/austinbaggetta/csstorage3/CircleTrack/CircleTrackAnalysis')
import circletrack_behavior as ctb
import circletrack_neural as ctn
import place_cells as pc
import plotting_functions as pf
import cell_overlap as co

In [None]:
## Settings
project_dir = 'MultiCon_Imaging'
experiment_dir = 'MultiCon_Imaging5'
# All data paths expect `project_dir` three directory levels above this notebook.
minian_path = f'../../../{project_dir}/{experiment_dir}/minian_results'
crossreg_path = f'../../../{project_dir}/{experiment_dir}/output/cross_registration_results'
fig_path = f'../../../{project_dir}/{experiment_dir}/intermediate_figures/'
bin_size = 0.1
velocity_thresh = 14
# Session labels: contexts A-D with five sessions each -> A1..A5, B1..B5, C1..C5, D1..D5.
session_list = [f'{context}{session}' for context in 'ABCD' for session in range(1, 6)]
# Recording dates 2024_08_24 .. 2024_09_12 as zero-padded 'YYYY_MM_DD' strings, assumed to align
# one-to-one with session_list. Built from a date range so single-digit days are padded
# ('2024_09_01', not '2024_09_1'), matching the date format used in the rest of the notebook.
actual_dates = pd.date_range('2024-08-24', '2024-09-12', freq='D').strftime('%Y_%m_%d').tolist()
assert len(actual_dates) == len(session_list), 'expected one recording date per session'
# Max distance (pixels) between cell centers for cross-registration; selects the
# mappings_{centroid_distance}.pkl / shiftds_{centroid_distance}.nc files.
centroid_distance = 5
config = {'scrollZoom': True}
opacity = 0.6
# Other colorscales tried: 'viridis', 'thermal', 'temps', 'rainbow', 'turbid', 'tropic', 'twilight', 'turbo'
colors = ['blues', 'solar']

### Use a slider to look at cross-registered cells between sessions.

In [None]:
## Create plot to look at cross-registered cells between any number of sessions
np.random.seed(1)
num_cells = 25  ## either 'all' or some number
mouse = 'mc48'
date_list = ['2024_09_11', '2024_09_12']
fig = pf.custom_graph_template(x_title='Width', y_title='Height', width=600, height=600, titles=[''],
                               shared_x=True, shared_y=True)

mpath = pjoin(minian_path, mouse)
mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/mappings_{centroid_distance}.pkl'))
mappings.columns = mappings.columns.droplevel(0)
# Keep only rows (cells) matched in every session of date_list.
shared_cells = mappings[date_list].dropna().reset_index(drop=True)
A_shifted = xr.open_dataset(pjoin(crossreg_path, f'circletrack_data/{mouse}/A_shifted.nc'))
shiftds = xr.open_dataset(pjoin(crossreg_path, f'circletrack_data/{mouse}/shiftds_{centroid_distance}.nc'))
# Max projection across all sessions; reused as the grayscale background for every slider step.
max_proj = shiftds['temps_shifted'].max(dim='session')

if isinstance(num_cells, str):
    cells_of_interest = shared_cells.copy()
else:
    # Sample distinct rows without replacement so no cell is counted twice in the mean footprint,
    # and clamp so asking for more cells than exist does not raise.
    n_sample = min(num_cells, len(shared_cells))
    row_idx = np.random.choice(len(shared_cells), n_sample, replace=False)
    cells_of_interest = shared_cells.iloc[row_idx].reset_index(drop=True)

# Each session contributes two traces: background at index 2*idx, cell overlay at 2*idx + 1.
# Only the first session's pair starts visible; the slider toggles the rest.
for idx, session in enumerate(date_list):
    sub_a = A_shifted.sel(session=session).sel(unit_id=cells_of_interest[session].to_numpy())
    is_first = idx == 0
    fig.add_trace(go.Heatmap(z=max_proj.values, colorscale='gray', showscale=False,
                             visible=is_first, name=session))
    fig.add_trace(go.Heatmap(z=sub_a['A_shifted'].mean(dim='unit_id').values, colorscale='thermal',
                             showscale=False, opacity=opacity, showlegend=False, visible=is_first))

steps = []
for idx, session in enumerate(date_list):
    visible = [False] * len(fig.data)
    visible[2 * idx] = True
    visible[2 * idx + 1] = True
    steps.append(dict(
        method='update',
        label=session,
        args=[{'visible': visible}, {'title': 'Switched to: ' + session}],
    ))

fig.update_layout(sliders=[dict(active=0, pad=dict(t=50), steps=steps)])
fig.show(config=config)
# Filename lists every session shown (identical to the old name for two sessions).
fig.write_html(pjoin(fig_path, f"slider_verification_{mouse}_{'_'.join(date_list)}.html"))

In [None]:
## Cell overlap across days
## Uses centroid_distance (max pixels between cell centers) from the Settings cell rather than
## re-assigning it here, so every cell reads the same cross-registration files.
mouse = 'mc44'
fig = pf.custom_graph_template(x_title='Day', y_title='', width=600,
                               shared_y=True, titles=[mouse])

mappings = pd.read_pickle(pjoin(crossreg_path, f'circletrack_data/{mouse}/mappings_{centroid_distance}.pkl'))
overlap = co.calculate_overlap(mappings)
overlap = co.dates_to_days(overlap, '2024_08_24', days=20)
# Pivot rows are session_id1, columns are session_id2.
matrix = overlap.pivot_table(index='session_id1', columns='session_id2', values='overlap')
# Plotly draws z[i][j] at (x[j], y[i]): pivot columns label the x axis, the pivot index the y axis.
fig.add_trace(go.Heatmap(z=matrix.values, x=matrix.columns, y=matrix.index, coloraxis='coloraxis'))
# Lines between days 5/6, 10/11 and 15/16 -- presumably the context switches
# (session_list has five sessions per context).
boundaries = [5.5, 10.5, 15.5]
for boundary in boundaries:
    fig.add_vline(x=boundary, line_width=1.5, line_color='red', opacity=1)
    fig.add_hline(y=boundary, line_width=1.5, line_color='red', opacity=1)
fig.update_yaxes(title='Day', col=1)
fig.update_layout(coloraxis_colorbar={'title': 'Overlap (%)'})
fig.show()
fig.write_image(pjoin(fig_path, f'{mouse}_overlap_heatmap.png'))