In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from pathlib import Path
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pynapple as nap
from spatial_manifolds.toroidal import *
from spatial_manifolds.behaviour_plots import *
from matplotlib.colors import TwoSlopeNorm
from scipy.spatial import distance
from spatial_manifolds.circular_decoder import circular_decoder, cross_validate_decoder, cross_validate_decoder_time, circular_nanmean

from spatial_manifolds.data.curation import curate_clusters
from scipy.stats import zscore
from spatial_manifolds.util import gaussian_filter_nan
from spatial_manifolds.predictive_grid import compute_travel_projected, wrap_list
from spatial_manifolds.behaviour_plots import *
from spatial_manifolds.behaviour_plots import trial_cat_priority

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from IPython.display import HTML

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2
%matplotlib ipympl

In [None]:
# Select the recording (mouse, day) and the open-field session to analyse.
mouse = 25
day =  25
session = 'OF1'

# Input files for this open-field session.
of1_folder = f'/Users/harryclark/Downloads/COHORT12_nwb/M{mouse}/D{day:02}/{session}/'
grid_path, shifted_grid_path, spatial_path = (
    of1_folder + "tuning_scores/grid_score.parquet",
    of1_folder + "tuning_scores/shifted_grid_score.parquet",
    of1_folder + "tuning_scores/shifted_spatial_information.parquet",
)
spikes_path = of1_folder + f"sub-{mouse}_day-{day:02}_ses-{session}_srt-kilosort4_clusters.npz"
beh_path = of1_folder + f"sub-{mouse}_day-{day:02}_ses-{session}_beh.nwb"

# Anatomical cluster annotations are read from the lab's shared storage volume.
active_projects_path = Path("/Volumes/cmvm/sbms/groups/CDBS_SIDB_storage/NolanLab/ActiveProjects/")
anatomy_path = active_projects_path / "Chris/Cohort12/derivatives/labels/anatomy/cluster_annotations.csv"
cluster_locations = pd.read_csv(anatomy_path)

# Open-field behaviour and spike-sorted clusters.
beh_OF = nap.load_file(beh_path)
clusters_OF = nap.load_file(spikes_path)

# Travel-lag-shifted grid scores and spatial information; keep only non-negative lags.
shifted_grid_scores_of1 = pd.read_parquet(shifted_grid_path).query('travel >= 0')
spatial_information_score_of1 = pd.read_parquet(spatial_path).query('travel >= 0')

In [None]:
# Classify every open-field cluster as a grid cell, a non-grid spatial cell or a non-spatial cell.
# For each cluster the travel lag with the highest grid score is used. Its grid score and spatial
# information are compared with the NULL_PERCENTILE-th percentile of the cluster's null
# (shuffled) distributions.
cluster_ids_values = shifted_grid_scores_of1.query('travel == 0').cluster_id

NULL_PERCENTILE = 95  # percentile of the null distributions used as the significance threshold
MIN_GRID_SCORE = 0.4  # additional absolute floor on the grid score for grid cells


def _concat_rows(rows):
    """Concatenate single-row frames; an empty list gives an empty frame with the score columns."""
    if rows:
        return pd.concat(rows, ignore_index=True)
    # Keep the columns so sort_values / column access below still work when a class is empty.
    return shifted_grid_scores_of1.iloc[0:0].reset_index(drop=True)


grid_rows, non_grid_rows, non_spatial_rows, all_rows = [], [], [], []

for index in cluster_ids_values:

    cluster_spatial_information_of1 = spatial_information_score_of1[spatial_information_score_of1.cluster_id==index]
    cluster_shifted_grid_scores_of1 = shifted_grid_scores_of1[shifted_grid_scores_of1.cluster_id==index]

    # NOTE: despite the "percentile99" names (kept for compatibility) these are
    # NULL_PERCENTILE (95th) percentiles. The null distribution is read from the first row.
    percentile99_grid_score_of1 = np.nanpercentile(cluster_shifted_grid_scores_of1.null_grid_score.iloc[0], NULL_PERCENTILE)
    percentile99_spatial_information_of1 = np.nanpercentile(cluster_spatial_information_of1.null_spatial_information.iloc[0], NULL_PERCENTILE)

    # Row position of the best-scoring travel lag, computed once and reused.
    best = np.nanargmax(cluster_shifted_grid_scores_of1.grid_score.values)
    field_spacing = cluster_shifted_grid_scores_of1.field_spacing.values[best]
    orientation = cluster_shifted_grid_scores_of1.orientation.values[best]
    max_grid_score_of1 = cluster_shifted_grid_scores_of1.grid_score.values[best]

    # NOTE(review): assumes the spatial-information rows are in the same travel order as the
    # grid-score rows, and that the first row is travel == 0 — confirm against the score files.
    spatial_info = cluster_spatial_information_of1.spatial_information.values[best]
    spatial_info_no_lag = cluster_spatial_information_of1.spatial_information.iloc[0]

    # This cluster's best-lag row, selected by position. Matching on grid_score over the whole
    # table (as before) could also return rows of other clusters, or several lags, with equal scores.
    cell = cluster_shifted_grid_scores_of1.iloc[[best]]

    if (max_grid_score_of1 > percentile99_grid_score_of1) and (spatial_info > percentile99_spatial_information_of1) and (max_grid_score_of1 > MIN_GRID_SCORE):
        grid_rows.append(cell)
    elif (spatial_info_no_lag > percentile99_spatial_information_of1):
        non_grid_rows.append(cell)
    else:
        non_spatial_rows.append(cell)
    all_rows.append(cell)

# Build each frame once (concatenating inside the loop is quadratic).
cells = _concat_rows(all_rows)
all_cells = cells.copy()
grid_cells = _concat_rows(grid_rows).sort_values(by=['field_spacing'])
non_grid_cells = _concat_rows(non_grid_rows).sort_values(by=['field_spacing'])
non_spatial_cells = _concat_rows(non_spatial_rows).sort_values(by=['field_spacing'])
non_grid_and_non_spatial_cells = pd.concat([non_grid_cells, non_spatial_cells], ignore_index=True)

print(f'there are {len(non_grid_and_non_spatial_cells)} non_grid and non_spatial_cells')
print(f'there are {len(grid_cells)} grid_cells')
print(f'there are {len(non_grid_cells)} non grid spatial cells')
print(f'there are {len(non_spatial_cells)} non spatial cells')
print(f'there are {len(all_cells)} cells')


In [None]:
# Best-lag grid spacing vs orientation for every cluster, coloured by open-field class.
fig, ax = plt.subplots()
ax.scatter(grid_cells['field_spacing'], grid_cells['orientation'], color='tab:red', label='grid')
ax.scatter(non_grid_cells['field_spacing'], non_grid_cells['orientation'], color='tab:cyan', label='non-grid spatial')
ax.scatter(non_spatial_cells['field_spacing'], non_spatial_cells['orientation'], color='tab:grey', alpha=0.5, label='non-spatial')
ax.set_xlabel('Grid Spacing (cm)')
ax.set_ylabel(r'Grid Orientation ($^\circ$)')
ax.legend()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import hdbscan
from sklearn.preprocessing import StandardScaler
import matplotlib.cm as cm

# Cluster grid cells into modules with HDBSCAN on (spacing, cos/sin of orientation).
# Grid orientation has 60-degree symmetry, so it is mapped onto the unit circle with period 60.
orientation_phase = (np.array(grid_cells['orientation'])/60) * 2 * np.pi
samples = np.stack([np.array(grid_cells['field_spacing']),
                    np.cos(orientation_phase),
                    np.sin(orientation_phase)]).T

# Raw (spacing, orientation) pairs, used only for plotting.
samples2d = np.stack([np.array(grid_cells['field_spacing']),
                      np.array(grid_cells['orientation'])]).T

# Standardised features (orientation components down-weighted by sqrt(2)).
scaler = StandardScaler()
samples_standardized = scaler.fit_transform(samples)
samples_standardized[:, 1] /= np.sqrt(2)
samples_standardized[:, 2] /= np.sqrt(2)

# NOTE: the standardised features were computed but then discarded (`samples_scaled = samples`),
# so clustering runs on the RAW features and cluster_selection_epsilon=3 is in raw units.
# This flag makes that choice explicit. Switching it on would need epsilon re-tuned.
USE_STANDARDIZED_FEATURES = False
samples_scaled = samples_standardized if USE_STANDARDIZED_FEATURES else samples

# Apply HDBSCAN; label -1 marks unassigned (noise) points.
clusterer = hdbscan.HDBSCAN(min_cluster_size=5, cluster_selection_epsilon=3)
module_labels = clusterer.fit_predict(samples_scaled)

# Plot the results
plt.figure(figsize=(3, 3))
unique_labels = np.unique(module_labels)
# plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed in matplotlib 3.9.
module_cmap = plt.get_cmap('viridis', len(unique_labels))
label_colors = {label: module_cmap(i) for i, label in enumerate(unique_labels)}
for mi in unique_labels:
    mask = module_labels == mi
    print(f'for mi{mi}, there are {np.sum(mask)} points')
    # `color=` (not `c=`): an RGBA tuple passed as `c` is value-mapped through a colormap when a
    # cluster has exactly 4 points. `cmap` is dropped because it is ignored for a fixed colour.
    plt.scatter(samples2d[mask, 0], samples2d[mask, 1], color=label_colors[mi], s=20, label='Clustered Points')
# Highlight unassigned points (label -1)
unassigned = samples2d[module_labels == -1]
plt.scatter(unassigned[:, 0], unassigned[:, 1], s=21, color='red', label='Unassigned Points')
plt.scatter(all_cells['field_spacing'], all_cells['orientation'], s=20, color='tab:grey', alpha=0.5, zorder=-1)

#plt.legend()
plt.xlabel('Grid Spacing (cm)')
plt.ylabel(r'Grid Orientation ($^\circ$)')  # raw string: '\c' is an invalid escape sequence
plt.ylim(0,60)
plt.title(f'HDBSCAN M{mouse}D{day}')
plt.tight_layout()
plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/HDBSCAN_M{mouse}D{day}.pdf')
plt.show()

# If HDBSCAN found no clusters at all, treat every grid cell as a single module 0.
if np.unique(module_labels).size == 1 and np.unique(module_labels)[0] == -1:
    module_labels[:] = 0  # Assign all points to a single cluster if no clusters were found
    label_colors[0] = label_colors[-1]

In [None]:
# Group grid-cell cluster ids by HDBSCAN module label (noise label -1 excluded), then order the
# modules from smallest to largest mean field spacing.
grid_module_cluster_ids = []
grid_module_ids = []
avg_spacings = []
grid_ids = np.array(grid_cells['cluster_id'])  # row-aligned with module_labels; loop-invariant
for mi, module_label in enumerate(np.unique(module_labels[module_labels != -1])):
    cells = grid_cells[np.isin(grid_cells['cluster_id'], grid_ids[module_labels == module_label])]
    mean_spacing = np.nanmean(cells.field_spacing.values)
    avg_spacings.append(mean_spacing)
    grid_module_cluster_ids.append(cells['cluster_id'].tolist())
    grid_module_ids.append(mi)
    print(f'for module {mi}, there are {len(cells)} cells with average spacing {mean_spacing}')

# Apply ONE ordering to both lists so module ids stay paired with their cluster-id lists.
# Sorting two separate zips broke ties differently: the id list tie-broke on the module id, the
# cluster list on list contents. NaN spacings would also make the comparison unreliable; they go last.
spacing_order = sorted(range(len(avg_spacings)),
                       key=lambda k: (bool(np.isnan(avg_spacings[k])), avg_spacings[k]))
grid_module_cluster_ids = [grid_module_cluster_ids[k] for k in spacing_order]
grid_module_ids = [grid_module_ids[k] for k in spacing_order]

In [None]:
# Open-field rate map of every grid cell, one block of rows per module with a blank spacer row
# between modules. Each map uses the travel lag that maximised that cell's grid score.
ncols = 10
rows_per_module = {mi: int(np.ceil(len(module) / ncols)) for mi, module in zip(grid_module_ids, grid_module_cluster_ids)}
# One spacer row between consecutive modules; at least one row so plt.subplots never gets nrows <= 0.
nrows = max(sum(rows_per_module.values())+len(grid_module_cluster_ids)-1, 1)
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 1*nrows), squeeze=False)

# Unlagged open-field position; only the travel projection depends on the cell.
of_position = np.stack([beh_OF['P_x'], beh_OF['P_y']], axis=1)


def _lagged_of_rate_map(cluster_id, travel):
    """Smoothed 40x40 open-field rate map of one cluster with position projected by `travel`."""
    beh_lag = compute_travel_projected(["P_x", "P_y"], of_position, of_position, travel)
    position_lagged = np.stack([beh_lag['P_x'], beh_lag['P_y']], axis=1)
    tc = nap.compute_2d_tuning_curves(nap.TsGroup([clusters_OF[cluster_id]]), position_lagged, nb_bins=(40,40))[0]
    return gaussian_filter_nan(tc[0], sigma=(2.5,2.5))


row_counter = 0
for mi, module_ids in zip(grid_module_ids, grid_module_cluster_ids):
    cells = grid_cells[grid_cells['cluster_id'].isin(module_ids)]
    print(f'for module {mi}, there are {len(cells)} cells')
    counter = 0
    for j in range(rows_per_module[mi]):
        for i in range(ncols):
            if counter < len(cells):
                index = cells['cluster_id'].values[counter]
                cluster_shifted_grid_scores = shifted_grid_scores_of1[shifted_grid_scores_of1.cluster_id==index]
                travel = cluster_shifted_grid_scores.travel.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
                # Only this cell's map is computed. The old code recomputed the map of every cell in
                # the module for each panel (O(n^2)) and then used just this one.
                ax[row_counter, i].imshow(_lagged_of_rate_map(index, travel), cmap='jet')
                counter+=1
        row_counter += 1
    row_counter += 1

for axi in ax.flatten():
    axi.axis('off')
plt.tight_layout()
plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/GC_rate_maps_modules_M{mouse}D{day}.pdf')
plt.show()

In [None]:
# Disabled cell: the code below is kept as a string literal, so running this cell does nothing.
# It would plot travel-lagged open-field rate maps of the non-grid spatial + non-spatial cells.
# If re-enabled, note two things:
#  - `mi` in the savefig path is not defined in this cell; it would be a leftover loop variable.
#  - the inner loop recomputes the map of every cell for each panel and uses only one of them.
'''ncols = 10
nrows = int(np.ceil(len(non_grid_and_non_spatial_cells) / ncols))

fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 1*nrows), squeeze=False)
cells = non_grid_and_non_spatial_cells
counter = 0
for j in range(nrows):
    for i in range(ncols):
        if counter < len(cells):
            index = cells['cluster_id'].values[counter]
            score = cells['grid_score'].values[counter]
            cluster_shifted_grid_scores = shifted_grid_scores_of1[shifted_grid_scores_of1.cluster_id==index]
            travel = cluster_shifted_grid_scores.travel.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
            max_score = cluster_shifted_grid_scores.grid_score.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
            field_spacing = cluster_shifted_grid_scores.field_spacing.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
            
            tcs = {}    
            position = np.stack([beh_OF['P_x'], beh_OF['P_y']], axis=1)
            beh_lag = compute_travel_projected(["P_x", "P_y"], position, position, travel)
            position_lagged = np.stack([beh_lag['P_x'], beh_lag['P_y']], axis=1)
            for cell in cells['cluster_id'].values:
                tc = nap.compute_2d_tuning_curves(nap.TsGroup([clusters_OF[cell]]), position_lagged, nb_bins=(40,40))[0]
                tc = gaussian_filter_nan(tc[0], sigma=(2.5,2.5))
                tcs[cell] = tc
            #ax[j, i].text(0,-2, f'id: {index}, mgs: {np.round(max_score, decimals=1)}', size=7)
            #ax[j, i].text(0,44, f'fs:{int(field_spacing)}', size=7)
            ax[j, i].imshow(tcs[index], cmap='jet')
            counter+=1

for axi in ax.flatten():
    axi.axis('off')
plt.tight_layout()
plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/NGS_NS_rate_maps_module{mi}_{mouse}D{day}_lagged.pdf')
plt.show()'''

In [None]:
# Disabled cell: the code below is kept as a string literal, so running this cell does nothing.
# It would plot UNLAGGED (travel=0) open-field rate maps of the non-grid spatial + non-spatial cells.
# If re-enabled, note two things:
#  - `mi` in the savefig path is not defined in this cell; it would be a leftover loop variable.
#  - the inner loop recomputes the map of every cell for each panel and uses only one of them.
'''ncols = 10
nrows = int(np.ceil(len(non_grid_and_non_spatial_cells) / ncols))

fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 1*nrows), squeeze=False)
cells = non_grid_and_non_spatial_cells
counter = 0
for j in range(nrows):
    for i in range(ncols):
        if counter < len(cells):
            index = cells['cluster_id'].values[counter]
            score = cells['grid_score'].values[counter]
            cluster_shifted_grid_scores = shifted_grid_scores_of1[shifted_grid_scores_of1.cluster_id==index]
            travel = cluster_shifted_grid_scores.travel.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
            max_score = cluster_shifted_grid_scores.grid_score.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
            field_spacing = cluster_shifted_grid_scores.field_spacing.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]
            
            tcs = {}    
            position = np.stack([beh_OF['P_x'], beh_OF['P_y']], axis=1)
            beh_lag = compute_travel_projected(["P_x", "P_y"], position, position, travel=0)
            position_lagged = np.stack([beh_lag['P_x'], beh_lag['P_y']], axis=1)
            for cell in cells['cluster_id'].values:
                tc = nap.compute_2d_tuning_curves(nap.TsGroup([clusters_OF[cell]]), position_lagged, nb_bins=(40,40))[0]
                tc = gaussian_filter_nan(tc[0], sigma=(2.5,2.5))
                tcs[cell] = tc
            #ax[j, i].text(0,-2, f'id: {index}, mgs: {np.round(max_score, decimals=1)}', size=7)
            #ax[j, i].text(0,44, f'fs:{int(field_spacing)}', size=7)
            ax[j, i].imshow(tcs[index], cmap='jet')
            counter+=1

for axi in ax.flatten():
    axi.axis('off')
plt.tight_layout()
plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/NGS_NS_rate_maps_module{mi}_{mouse}D{day}_not_lagged.pdf')
plt.show()'''

We have now separated the cells into grid modules and non-grid cells (non-grid spatial cells and non-spatial cells).

In [None]:
# Load the VR session for the same mouse/day. Build, per cell: a distance-binned tuning curve,
# a time-binned activity trace, and the spatial autocorrelation of the tuning curve.
vr_folder = f'/Users/harryclark/Downloads/COHORT12_nwb/M{mouse}/D{day:02}/VR/'
spikes_path = vr_folder + f"sub-{mouse}_day-{day:02}_ses-VR_srt-kilosort4_clusters.npz"
beh_path = vr_folder + f"sub-{mouse}_day-{day:02}_ses-VR_beh.nwb"
beh = nap.load_file(beh_path)
clusters = nap.load_file(spikes_path)
clusters = curate_clusters(clusters)

tl=200 # cm
bs=2 # cm
time_bs = 100 # ms
tns = beh['trial_number']
# Distance travelled, re-based so that the first trial of the session starts at 0 cm.
dt = beh['travel']-((tns[0]-1)*tl)
# Upper edge of the distance axis, rounded up to a whole number of tracks.
max_bound = int(((np.ceil(np.nanmax(dt))//tl)+1)*tl)
n_bins = int(max_bound/bs)
min_bound = 0
dt_bins = np.arange(0, max_bound, bs)
plot_stops(beh, tl=200, sort=False, return_fig=False, 
           savepath=f'/Users/harryclark/Documents/figs/toroidal/M{mouse}D{day}_stops.pdf')
plt.close()
plot_stops(beh, tl=200, sort=True, return_fig=False, 
           savepath=f'/Users/harryclark/Documents/figs/toroidal/M{mouse}D{day}_stops_sorted.pdf')
plt.close()

bins_per_trial = tl/bs


def _smoothed_distance_tc(cluster_id):
    """Distance-binned tuning curve over moving epochs, NaN -> 0, Gaussian-smoothed (sigma=2.5 bins)."""
    tc = nap.compute_1d_tuning_curves(nap.TsGroup([clusters[cluster_id]]), 
                                      dt, 
                                      nb_bins=n_bins, 
                                      minmax=[min_bound, max_bound],
                                      ep=beh["moving"])[0]
    return gaussian_filter(np.nan_to_num(tc).astype(np.float64), sigma=2.5)


def _lagged_autocorr(signal, lags):
    """Pearson r between `signal` and a copy shifted by each lag; vacated bins are zero-filled."""
    values = []
    for lag in lags:
        if lag > 0:
            shifted = np.roll(signal, lag)
            shifted[:lag] = 0
        elif lag < 0:
            shifted = np.roll(signal, lag)
            shifted[lag:] = 0
        else:
            shifted = signal
        values.append(stats.pearsonr(signal, shifted)[0])
    return np.array(values)


# Trick to clip the tuning curves to around the end of the ephys recording:
# take the cell with the highest firing rate, find its last spatial bin with activity, and round
# up to the end of that trial (exclusive bin index).
top_rate_cell = clusters.index[np.nanargmax(clusters.firing_rate)]
tc = _smoothed_distance_tc(top_rate_cell)
last_active_bin = np.nonzero(tc)[0][-1]
last_ephys_bin = int(last_active_bin + bins_per_trial - last_active_bin % bins_per_trial)
last_ephys_time_bin = clusters[clusters.index[0]].count(bin_size=time_bs, time_units = 'ms').index[-1]

# time binned variables for later
ep = nap.IntervalSet(start=0, end=last_ephys_time_bin, time_units = 's')
speed_in_time = beh['S'].bin_average(bin_size=time_bs, time_units = 'ms', ep=ep)
dt_in_time = beh['travel'].bin_average(bin_size=time_bs, time_units = 'ms', ep=ep)-((tns[0]-1)*tl)
pos_in_time = dt_in_time%tl
trial_number_in_time = (dt_in_time//tl)+tns[0]

# Spatial lags in bins for the autocorrelation: 0..199 bins (0..398 cm at bs=2).
lags = np.arange(0, 200, 1)

tcs = {}
tcs_time = {}
autocorrs = {}
for cell in clusters.index:
    tc = _smoothed_distance_tc(cell)
    # Clip to bins with ephys data BEFORE z-scoring. Bins after the recording are zero only
    # because nothing was recorded there; including them biased the mean and std.
    tc = zscore(tc[:last_ephys_bin])
    tcs[cell] = tc
    
    tc_time = clusters[cell].count(bin_size=time_bs, time_units = 'ms', ep=ep)
    tc_time = gaussian_filter(np.nan_to_num(tc_time).astype(np.float64), sigma=2.5)
    tc_time = zscore(tc_time)
    tcs_time[cell] = tc_time

    autocorrs[cell] = _lagged_autocorr(tc, lags)

# drop beh trials from after last ephys bin
beh_trials = beh['trials']
beh_trials = beh_trials[:int(last_ephys_bin/(tl/bs))]


In [None]:
from scipy.signal import find_peaks

# Curate each grid module with the VR spatial autocorrelation. A cell's first autocorrelation
# peak (converted to cm) estimates its VR field spacing. Cells whose first peak is not within
# `peak_tolerance_cm` of the module median (including cells with no peak) are dropped.
peak_tolerance_cm = 20  # 20cm tolerance

# Stack all autocorrelograms once; rows follow the key order of `autocorrs`.
all_autocorr_ids = np.array(list(autocorrs.keys()))
all_autocorrs = np.array(list(autocorrs.values()))

for mi, module_ids in zip(grid_module_ids, grid_module_cluster_ids):
    in_module = np.isin(all_autocorr_ids, module_ids)
    matrix = all_autocorrs[in_module]
    matrix_cluster_ids = all_autocorr_ids[in_module]

    # First local maximum of each autocorrelogram (NaN if none), in cm.
    peak_idxs = [find_peaks(row)[0] for row in matrix]
    peaks = np.array([p[0] if len(p) > 0 else np.nan for p in peak_idxs], dtype=float) * bs
    median_peak = np.nanmedian(peaks)

    fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(2,2), squeeze=False)
    max_r = 200 if median_peak < 200 else 400
    ax[0,0].hist(peaks, bins=25, range=(0, max_r), color=label_colors[mi])
    # Dashed lines mark the acceptance band actually applied below. They were previously drawn
    # at +/-15 cm while the rejection rule used 20 cm.
    ax[0,0].axvline(median_peak-peak_tolerance_cm, color='grey', linestyle='--')
    ax[0,0].axvline(median_peak+peak_tolerance_cm, color='grey', linestyle='--')
    plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/GC_peaks_{mi}_{mouse}D{day}.pdf')
    plt.show()

    # `not x < tol` (rather than x >= tol) so NaN peaks are rejected as before.
    rejected = {cluster_id for peak, cluster_id in zip(peaks, matrix_cluster_ids)
                if not np.abs(peak-median_peak) < peak_tolerance_cm}
    # Build a new list instead of removing from module_ids in place. Cells without an
    # autocorrelogram are not tested and are kept, as before.
    grid_module_cluster_ids[grid_module_ids.index(mi)] = [cid for cid in module_ids if cid not in rejected]


In [None]:
# Open-field rate maps of the grid cells that survived autocorrelation curation, one block of rows
# per module with a blank spacer row between modules. Each map uses the cell's best-grid-score lag.
ncols = 10
rows_per_module = {mi: int(np.ceil(len(module) / ncols)) for mi, module in zip(grid_module_ids, grid_module_cluster_ids)}
# One spacer row between consecutive modules; at least one row so plt.subplots never gets nrows <= 0.
nrows = max(sum(rows_per_module.values())+len(grid_module_cluster_ids)-1, 1)
fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 1*nrows), squeeze=False)

# Unlagged open-field position (loop-invariant); only the travel projection depends on the cell.
of_position = np.stack([beh_OF['P_x'], beh_OF['P_y']], axis=1)

tcs_post_curation = {}
row_counter = 0
for mi, module_ids in zip(grid_module_ids, grid_module_cluster_ids):
    cells = grid_cells[grid_cells['cluster_id'].isin(module_ids)]
    print(f'for module {mi}, there are {len(cells)} cells')
    counter = 0
    for j in range(rows_per_module[mi]):
        for i in range(ncols):
            if counter < len(cells):
                index = cells['cluster_id'].values[counter]
                cluster_shifted_grid_scores = shifted_grid_scores_of1[shifted_grid_scores_of1.cluster_id==index]
                travel = cluster_shifted_grid_scores.travel.values[np.nanargmax(cluster_shifted_grid_scores.grid_score)]

                # Only this cell's map is computed. The old code recomputed the map of every cell in
                # the module for each panel (O(n^2)) and then used just this one.
                beh_lag = compute_travel_projected(["P_x", "P_y"], of_position, of_position, travel)
                position_lagged = np.stack([beh_lag['P_x'], beh_lag['P_y']], axis=1)
                tc = nap.compute_2d_tuning_curves(nap.TsGroup([clusters_OF[index]]), position_lagged, nb_bins=(40,40))[0]
                tcs_post_curation[index] = gaussian_filter_nan(tc[0], sigma=(2.5,2.5))
                ax[row_counter, i].imshow(tcs_post_curation[index], cmap='jet')
                counter+=1
        row_counter += 1
    row_counter += 1

for axi in ax.flatten():
    axi.axis('off')
plt.tight_layout()
plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/GC_rate_maps_modules_M{mouse}D{day}_post_curated.pdf')
plt.show()

In [None]:
from scipy.signal import find_peaks

# Re-plot, per curated module, the histogram of first autocorrelation peaks (cm). The dashed lines
# mark median +/- the tolerance used for curation (previously drawn at +/-15 cm instead of 20 cm).
peak_tolerance_cm = 20

all_autocorr_ids = np.array(list(autocorrs.keys()))
all_autocorrs = np.array(list(autocorrs.values()))

for mi, module_ids in zip(grid_module_ids, grid_module_cluster_ids):
    in_module = np.isin(all_autocorr_ids, module_ids)
    matrix = all_autocorrs[in_module]
    matrix_cluster_ids = all_autocorr_ids[in_module]

    # First local maximum of each autocorrelogram (NaN if none), in cm.
    peak_idxs = [find_peaks(row)[0] for row in matrix]
    peaks = np.array([p[0] if len(p) > 0 else np.nan for p in peak_idxs], dtype=float) * bs
    median_peak = np.nanmedian(peaks)

    fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(2,2), squeeze=False)
    max_r = 200 if median_peak < 200 else 400
    ax[0,0].hist(peaks, bins=25, range=(0, max_r), color=label_colors[mi])
    ax[0,0].axvline(median_peak-peak_tolerance_cm, color='grey', linestyle='--')
    ax[0,0].axvline(median_peak+peak_tolerance_cm, color='grey', linestyle='--')
    plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/GC_peaks_{mi}_{mouse}D{day}_post_curated.pdf')
    plt.show()


In [None]:
# We now have cluster ids classified into modules, non-grid spatial cells and non-spatial cells,
# as defined by activity in the open field. Order the modules largest-first.
# The module ids are reordered together with their cluster-id lists so the two stay paired.
# Previously only grid_module_cluster_ids was re-sorted, so zip(grid_module_ids,
# grid_module_cluster_ids) no longer matched after this cell. sorted() is stable with
# reverse=True, so the cluster-list order is unchanged from before.
size_order = sorted(range(len(grid_module_cluster_ids)),
                    key=lambda k: len(grid_module_cluster_ids[k]), reverse=True)
grid_module_cluster_ids = [grid_module_cluster_ids[k] for k in size_order]
grid_module_ids = [grid_module_ids[k] for k in size_order]

cluster_ids_by_group = []
cluster_ids_by_group.extend(grid_module_cluster_ids)
cluster_ids_by_group.append(non_grid_cells.cluster_id.values.tolist())
cluster_ids_by_group.append(non_spatial_cells.cluster_id.values.tolist())

for cluster_ids in cluster_ids_by_group:
    print(cluster_ids)

In [None]:
# Hack: remove cells whose VR TsGroup rate is below 1 (presumably Hz) from every group,
# printing each removal. Group order and within-group order are preserved.
def _has_usable_rate(cell):
    """Return False (and report it) when the cell's rate is below 1, True otherwise."""
    if nap.TsGroup([clusters[cell]]).rates[0] < 1:
        print(f'{cell} removed for having a very low rate')
        return False
    return True


new_cluster_ids_by_group = [
    [cell for cell in cluster_id_group if _has_usable_rate(cell)]
    for cluster_id_group in cluster_ids_by_group
]
cluster_ids_by_group = new_cluster_ids_by_group.copy()

for cluster_ids in cluster_ids_by_group:
    print(cluster_ids)

In [None]:
# One figure per cell group (grid modules, then non-grid spatial, then non-spatial cells):
# z-scored VR rate maps along the track, 10 per row.
for m, cluster_id_group in enumerate(cluster_ids_by_group):
    if len(cluster_id_group) == 0:
        # plt.subplots(nrows=0) raises. Groups can be emptied by the autocorrelation
        # curation or by the low-rate filter.
        print(f'group {m} has no cells, skipping')
        continue
    ncols = 10
    nrows = int(np.ceil(len(cluster_id_group)/ncols))
    fig, ax = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 1.4*nrows), squeeze=False)
    # ax.flat walks the panels row by row, i.e. in the same order as the nested j/i loops.
    for counter, panel in enumerate(ax.flat):
        if counter < len(cluster_id_group):
            index = cluster_id_group[counter]
            plot_firing_rate_map(panel, 
                                 zscore(tcs[index]),
                                 bs=bs,
                                 tl=tl,
                                 p=95)
        else:
            panel.axis('off')
        panel.set_xticks([])
        panel.set_yticks([])
        panel.xaxis.set_tick_params(labelbottom=False)
        panel.yaxis.set_tick_params(labelleft=False)
    plt.tight_layout()
    plt.savefig(f'/Users/harryclark/Documents/figs/SUPP_Grid_module_classification/VR_rate_maps_{m}_M{mouse}D{day}.pdf')
    plt.show()


In [None]:
# Run the spectral analysis on the first cell group (the largest grid module after sorting).
first_group = cluster_ids_by_group[0]
tcs_to_use = {cid: tcs[cid] for cid in first_group if cid in tcs}
tcs_time_to_use = {cid: tcs_time[cid] for cid in first_group if cid in tcs_time}

# Row-stacked distance-binned and time-binned z-scored maps, in group order.
zmaps = np.array(list(tcs_to_use.values()))
zmaps_time = np.array(list(tcs_time_to_use.values()))
N = len(tcs_to_use)

results = spectral_analysis(tcs_to_use, tl, bs=bs)
f_modules, phi_modules, grid_cell_idxs_modules, spectrograms = results[:4]
trial_starts = results[6]
L = tl

In [None]:
# Plot PSDs
plt.figure(figsize=(3,4))
nongrid_idxs = np.setdiff1d(np.arange(N), np.concatenate(grid_cell_idxs_modules))
fmax = 8/L
count = 0
Ps = []
for j in range(len(grid_cell_idxs_modules)):
    grid_cell_idxs = grid_cell_idxs_modules[j]
    for gi in grid_cell_idxs:
        mp = gaussian_filter1d(zmaps[gi].ravel(), 3)
        f, Pxx = welch(mp,nperseg=4000,noverlap=3000)
        # Ps.append(Pxx[f<fmax])
        Ps.append(Pxx[f<fmax]/(Pxx[f<fmax]).sum())
    count += len(grid_cell_idxs)
    plt.axhline(count, c='grey',linestyle='dashed',linewidth=0.4)
for ngi in nongrid_idxs:
    mp = gaussian_filter1d(zmaps[ngi].ravel(), 3)
    f, Pxx = welch(mp,nperseg=4000,noverlap=3000)
    Ps.append(Pxx[f<fmax]/(Pxx[f<fmax]).sum())
Ps = np.stack(Ps)

plt.pcolormesh(100*f[f<fmax]/2,np.arange(len(Ps)),np.stack(Ps),vmax=0.04)
plt.xlabel(f'Frequency (m-1)')
plt.xlim([0,2])
plt.ylabel('Neuron')
plt.tight_layout()
plt.title('PSDs')
plt.show()

In [None]:
# Module spectrograms
plt.figure(figsize=(20,3))
mi = 2
for i, grid_cell_idxs in enumerate(grid_cell_idxs_modules):
    plt.subplot(1,5,i+1)
    Ng = len(grid_cell_idxs)
    S = spectrograms[grid_cell_idxs].mean(0)
    plt.imshow(S,origin='lower',aspect='auto',vmax=0.25,cmap='magma')
    plt.yticks([0, len(S)/2, len(S)], [0, 1, 2])
    
    if i==0:
        plt.ylabel(f'Frequency (m-1)')

    for i in range(1,8):
        plt.axhline(100*i/L/2,linewidth=1,c='grey',alpha=0.5)
        
    for ts in trial_starts[1:-1]:
        plt.axvline(ts,linewidth=1,c='grey',alpha=0.5)
    plt.xlabel('Trials')
    plt.show()

In [None]:
# Trajectories on the neural sheet. The first module's smoothed rate maps (distance- and
# time-binned) are projected onto that module's phase vectors `phi` to get torus angles. Each
# angle is then shown as a (trial x position) image.
grid_cell_idxs = grid_cell_idxs_modules[0]
phi = phi_modules[0]
Ng = len(grid_cell_idxs)

maps = gaussian_filter1d(zmaps[grid_cell_idxs].reshape(Ng, -1), 2, axis=1)
maps_time = gaussian_filter1d(zmaps_time[grid_cell_idxs], 2)

# NOTE(review): np.arctan2 takes (y, x), but here the cos projection is passed as y and the sin
# projection as x. Kept unchanged — confirm this angle convention is intended.
cos_phi, sin_phi = np.cos(phi), np.sin(phi)
angles = np.arctan2(cos_phi @ maps, sin_phi @ maps)
angles_time = np.arctan2(cos_phi @ maps_time, sin_phi @ maps_time)

n_pos_bins = int(L/bs)  # spatial bins per trial -> one image row per trial
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(10, 5), squeeze=False)
for k, title in enumerate((r'$\theta_1$', r'$\theta_2$', r'$\theta_3$')):
    panel = ax[0, k]
    panel.set_title(title)
    panel.imshow(angles[k].reshape(-1, n_pos_bins), cmap='hsv')
    for ts in trial_starts[1:-1]:
        panel.axhline(ts, color='k', linestyle='dashed', linewidth=1)
    panel.set_xlabel('Pos. (cm)')
    panel.set_ylabel('Trial')

plt.tight_layout(w_pad=0.2)

In [None]:
def rolling_pearson_r(x, y, window):
    """Trailing-window Pearson correlation between two series.

    r_values[i] is the Pearson r of x[i-window+1 : i+1] and y[i-window+1 : i+1].
    Entries with no complete window are NaN: the first ``window - 1`` samples, and
    any index at or beyond ``len(y)`` when ``y`` is shorter than ``x``. A NaN inside
    a window, or a constant window, makes only that window's r NaN.

    Parameters
    ----------
    x, y : array-like, 1-D
        Series to correlate. The output is aligned to ``x``.
    window : int
        Window length in samples (>= 1).

    Returns
    -------
    numpy.ndarray, shape (len(x),)
        Rolling correlation (float64).

    Raises
    ------
    ValueError
        If ``window < 1``.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    r_values = np.full(len(x), np.nan)
    # Windows are complete only where both series have data.
    n_valid = min(len(x), len(y))
    if n_valid < window:
        return r_values

    # (n_windows, window) strided views; no copy of the inputs.
    x_win = np.lib.stride_tricks.sliding_window_view(x[:n_valid], window)
    y_win = np.lib.stride_tricks.sliding_window_view(y[:n_valid], window)
    x_c = x_win - x_win.mean(axis=1, keepdims=True)
    y_c = y_win - y_win.mean(axis=1, keepdims=True)

    # Constant windows give 0/0 -> NaN, matching np.corrcoef.
    with np.errstate(invalid='ignore', divide='ignore'):
        num = np.einsum('ij,ij->i', x_c, y_c)
        den = np.sqrt(np.einsum('ij,ij->i', x_c, x_c) * np.einsum('ij,ij->i', y_c, y_c))
        r = num / den
    # np.corrcoef clips to [-1, 1] to absorb rounding; do the same.
    r_values[window - 1:n_valid] = np.clip(r, -1, 1)
    return r_values

In [None]:
# Speed of the population state on the neural torus, from the first two angles.
# Each per-bin angular step is wrapped into [-pi, pi) so crossing the +/-pi seam
# is not counted as a jump of nearly 2*pi.
dtheta1, dtheta2 = (np.mod(np.diff(angles_time[k]) + np.pi, 2 * np.pi) - np.pi for k in (0, 1))

# Step length in (flat) angle space, converted to a rate per second
# (time_bs is presumably the time-bin width in ms).
torus_distances = np.sqrt(dtheta1**2 + dtheta2**2)
bin_seconds = time_bs / 1000
torus_speed = torus_distances / bin_seconds

# Light temporal smoothing, then z-score both traces.
# NOTE(review): speed_in_time is overwritten by its smoothed version, so re-running
# this cell smooths it again.
torus_speed = gaussian_filter1d(torus_speed, 1)
speed_in_time = gaussian_filter1d(speed_in_time, 1)

ztorus_speed, zspeed = zscore(torus_speed), zscore(speed_in_time)

In [None]:
# Per-time-bin trial metadata: the (context, type, performance) group, its colour,
# and the trial type. Many time bins share a trial, so each trial's metadata is
# looked up once instead of re-filtering the whole trials table for every bin.
group_by_tn = {}
color_by_tn = {}
for tn in dict.fromkeys(trial_number_in_time):  # unique trial numbers, first-seen order
    trial = beh['trials'][beh['trials']['number']==tn]
    group = (trial['context'][0],
             trial['type'][0],
             trial['performance'][0])
    group_by_tn[tn] = group
    color_by_tn[tn] = get_color_for_group(group)

trial_group_in_time = np.array([group_by_tn[tn] for tn in trial_number_in_time])
trial_colors_in_time = np.array([color_by_tn[tn] for tn in trial_number_in_time])
trial_type_in_time = np.array([group_by_tn[tn][1] for tn in trial_number_in_time])

In [None]:
# Decoding in time: a cross-validated decoder on time-binned activity of the grid
# cells assigned to modules. It is trained on beaconed ('b') and non-beaconed ('nb')
# trials and decodes every trial up to the last ephys trial.
grid_ids_from_modules = [item for sublist in grid_module_cluster_ids for item in sublist]
grid_tcs_time = {cluster_id: tcs_time[cluster_id] for cluster_id in grid_ids_from_modules if cluster_id in tcs_time}

speed_in_time = np.array(speed_in_time)
pos_in_time = np.array(pos_in_time)
trial_number_in_time = np.array(trial_number_in_time)
dt_in_time = np.array(dt_in_time)
tns_to_decode_with = np.array(beh['trials']['number'])
tns_to_decode_with = tns_to_decode_with[tns_to_decode_with<=np.nanmax(trial_number_in_time)]
trial_types = np.array(beh['trials']['type'])

tns_to_decode = np.array(beh['trials']['number']) # decode all trials to visualise
tns_to_train = np.array(beh['trials']['number'][np.isin(beh['trials']['type'], np.array(['b','nb']))])
tns_to_decode = tns_to_decode[tns_to_decode<=np.nanmax(trial_number_in_time)] # handles last ephys trials
tns_to_train = tns_to_train[tns_to_train<=np.nanmax(trial_number_in_time)] # handles last ephys trials

predictions_in_time, errors_in_time = cross_validate_decoder_time(grid_tcs_time,
                                                true_position=pos_in_time,
                                                trial_numbers=trial_number_in_time,
                                                tns_to_decode=tns_to_decode,
                                                tns_to_train=tns_to_train,
                                                tl=tl, bs=bs, train=0.9, n=10, verbose=False)

# Average the n cross-validation repeats with a CIRCULAR mean. Track position
# wraps at tl, so an arithmetic mean would put e.g. 195 and 5 cm at 100 cm. This
# matches the circular_nanmean used for the position-binned decoder in this notebook.
avg_predictions_in_time = [circular_nanmean(np.stack(preds_n), tl=tl, axis=0) for preds_n in predictions_in_time]
avg_predictions_in_time = np.concatenate(avg_predictions_in_time).ravel()

In [None]:
# Rolling correlation between track speed and torus speed. Windows are in time
# bins; the "N seconds" names assume 100 ms bins (TODO confirm against time_bs).
reg_30seconds_speed = rolling_pearson_r(zspeed, ztorus_speed, window=300)
reg_20seconds_speed = rolling_pearson_r(zspeed, ztorus_speed, window=200)
reg_10seconds_speed = rolling_pearson_r(zspeed, ztorus_speed, window=100)
reg_5seconds_speed = rolling_pearson_r(zspeed, ztorus_speed, window=50)
reg_1seconds_speed = rolling_pearson_r(zspeed, ztorus_speed, window=10)

dt_s = time_bs / 1000  # time-bin width in seconds (time_bs presumably in ms)

# Example windows (start_bin, stop_bin, snippet_id) saved for FIGURE2. Snippet 3
# previously started at 1400, a 15600-bin window, unlike every other 3000-bin
# snippet and the 14000-17000 window of the matching decoding snippet 3.
for start, stop, snippet_id in [(3200, 6200, 1), (10000, 13000, 2), (14000, 17000, 3)]:
    # Rolling R (top) above z-scored torus and track speed (bottom).
    fig, ax = plt.subplots(ncols=1, nrows=2, figsize=(10, 2), sharex=True)
    print(f'showing window from {trial_number_in_time[start]-trial_number_in_time[0]} to {trial_number_in_time[stop]-trial_number_in_time[0]}')
    ax[0].axhline(0, color='grey', linestyle='--', linewidth=0.5)
    ax[0].set_ylim([-0.5, 1])
    # Each trace gets a time axis built from its own length.
    ax[0].plot((np.arange(len(reg_20seconds_speed)) * dt_s)[start:stop], reg_20seconds_speed[start:stop], label='20 seconds', color='tab:grey')
    ax[1].plot((np.arange(len(ztorus_speed)) * dt_s)[start:stop], ztorus_speed[start:stop], color='tab:purple', label='Torus Speed')
    ax[1].plot((np.arange(len(zspeed)) * dt_s)[start:stop], zspeed[start:stop], color='black', label='Track Speed')
    ax[1].set_xlabel('Time (seconds)')
    ax[1].set_ylabel('Z-scored\nSpeed')
    ax[0].set_ylabel('R')
    plt.tight_layout()
    plt.savefig(f'/Users/harryclark/Documents/figs/FIGURE2/GC_torus_snippet{snippet_id}_{mouse}D{day}.pdf')
    plt.show()

In [None]:
# Rolling correlation between true position and position decoded from grid
# cells. Windows are in time bins; the "N seconds" names assume 100 ms bins
# (TODO confirm against time_bs).
reg_30seconds_pos = rolling_pearson_r(pos_in_time, avg_predictions_in_time, window=300)
reg_20seconds_pos = rolling_pearson_r(pos_in_time, avg_predictions_in_time, window=200)
reg_10seconds_pos = rolling_pearson_r(pos_in_time, avg_predictions_in_time, window=100)
reg_5seconds_pos = rolling_pearson_r(pos_in_time, avg_predictions_in_time, window=50)
reg_1seconds_pos = rolling_pearson_r(pos_in_time, avg_predictions_in_time, window=10)

dt_s = time_bs / 1000  # time-bin width in seconds (time_bs presumably in ms)

# Example windows (start_bin, stop_bin, snippet_id) saved for FIGURE2.
for start, stop, snippet_id in [(3200, 6200, 1), (10000, 13000, 2), (14000, 17000, 3)]:
    # Rolling R (top) above decoded and true position (bottom).
    fig, ax = plt.subplots(ncols=1, nrows=2, figsize=(10, 2), sharex=True)
    print(f'showing window from {trial_number_in_time[start]-trial_number_in_time[0]} to {trial_number_in_time[stop]-trial_number_in_time[0]}')
    ax[0].axhline(0, color='grey', linestyle='--', linewidth=0.5)
    ax[0].set_ylim([-0.5, 1])
    # Each trace gets a time axis built from its own length (previously borrowed
    # from ztorus_speed / avg_predictions_in_time).
    ax[0].plot((np.arange(len(reg_20seconds_pos)) * dt_s)[start:stop], reg_20seconds_pos[start:stop], label='20 seconds', color='tab:grey')
    ax[1].plot((np.arange(len(avg_predictions_in_time)) * dt_s)[start:stop], avg_predictions_in_time[start:stop], color='red', label='decoded from grid')
    ax[1].plot((np.arange(len(pos_in_time)) * dt_s)[start:stop], pos_in_time[start:stop], color='black',label='true pos')
    ax[1].set_xlabel('Time (seconds)')
    ax[1].set_ylabel('Pos (cm)')
    ax[0].set_ylabel('R')
    plt.tight_layout()
    plt.savefig(f'/Users/harryclark/Documents/figs/FIGURE2/GC_decode_snippet{snippet_id}_{mouse}D{day}.pdf')
    plt.show()

In [None]:
# Rolling position-decoding R against rolling speed R, one panel per trial type,
# with points coloured by trial category.
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(4, 2), sharex=True, sharey=True)
for panel, ttype in zip(ax, ['b', 'nb']):
    sel = trial_type_in_time == ttype
    panel.scatter(reg_20seconds_pos[sel],
                  reg_20seconds_speed[sel], s=0.3,
                  color=trial_colors_in_time[sel], alpha=0.1)
    panel.set_xlabel('R pos')
ax[0].set_ylabel('R speed')
plt.tight_layout()
plt.show()

In [None]:
from scipy.stats import linregress
import statsmodels.formula.api as smf

# Relationship between rolling position-decoding R and rolling speed R for
# beaconed ('b', left) and non-beaconed ('nb', right) time bins: scatter, an
# ordinary least-squares line, and (for 'b') a linear mixed-effects model.
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(4, 2), sharex=True, sharey=True)

# For 'b' trial type
nan_mask_b = np.isfinite(reg_20seconds_pos[trial_type_in_time == 'b']) & np.isfinite(reg_20seconds_speed[trial_type_in_time == 'b'])
x_b = reg_20seconds_pos[trial_type_in_time == 'b'][nan_mask_b]
y_b = reg_20seconds_speed[trial_type_in_time == 'b'][nan_mask_b]
c_b = trial_colors_in_time[trial_type_in_time == 'b'][nan_mask_b]
ax[0].scatter(x_b, y_b, s=0.3, color=c_b, alpha=0.1, rasterized=True)

# Mixed-model groups are trial NUMBERS. Grouping by trial type, as before, gives a
# single group here because only 'b' bins are kept.
# NOTE(review): 'time' is the index among the retained bins, not time within the trial.
df_b = pd.DataFrame({'pos': x_b,
                     'speed': y_b,
                     'tn': trial_number_in_time[trial_type_in_time == 'b'][nan_mask_b],
                     'time': np.arange(len(x_b))})

# Fit the linear mixed-effects model
model = smf.mixedlm("speed ~ pos", data=df_b, groups=df_b["tn"], re_formula="~time")
result = model.fit()
# Extract p-values for fixed effects
print("P-values for fixed effects:")
print(result.pvalues)


# Linear regression for 'b'
slope_b, intercept_b, r_value_b, p_value_b, _ = linregress(x_b, y_b)
ax[0].plot(x_b, slope_b * x_b + intercept_b, color='black', linewidth=1)
#ax[0].text(0.05, 0.95, f"$R^2$: {r_value_b**2:.2f}\n$p$: {p_value_b:.2e}", 
#           transform=ax[0].transAxes, fontsize=8, verticalalignment='top')

# For 'nb' trial type
nan_mask_nb = np.isfinite(reg_20seconds_pos[trial_type_in_time == 'nb']) & np.isfinite(reg_20seconds_speed[trial_type_in_time == 'nb'])
x_nb = reg_20seconds_pos[trial_type_in_time == 'nb'][nan_mask_nb]
y_nb = reg_20seconds_speed[trial_type_in_time == 'nb'][nan_mask_nb]
c_nb = trial_colors_in_time[trial_type_in_time == 'nb'][nan_mask_nb]
ax[1].scatter(x_nb, y_nb, s=0.3, color=c_nb, alpha=0.1, rasterized=True)

# Linear regression for 'nb'
slope_nb, intercept_nb, r_value_nb, p_value_nb, _ = linregress(x_nb, y_nb)
ax[1].plot(x_nb, slope_nb * x_nb + intercept_nb, color='black', linewidth=1)
#ax[1].text(0.05, 0.95, f"$R^2$: {r_value_nb**2:.2f}\n$p$: {p_value_nb:.2e}", 
#           transform=ax[1].transAxes, fontsize=8, verticalalignment='top')

ax[0].set_ylabel('R speed')
ax[0].set_xlabel('R pos')
ax[1].set_xlabel('R pos')

plt.tight_layout()
plt.savefig(f'/Users/harryclark/Documents/figs/FIGURE2/GC_regression_{mouse}D{day}.pdf')
plt.show()

In [None]:
# Trial categories for the trials covered by ephys (the first
# int(last_ephys_bin/(tl/bs)) trials), ordered by trial_cat_priority:
# - sorted_trial_indices reorders trials by category,
# - sorted_trial_colors holds the matching category colours,
# - sorted_block_sizes holds the number of trials per category.
ephys_trials = beh['trials'][:int(last_ephys_bin/(tl/bs))]
sorted_cats = ephys_trials.groupby(by=['context','type','performance'])
sorted_cats = sort_dict_by_priority(sorted_cats, trial_cat_priority)

sorted_trial_indices = []
sorted_trial_colors = []
sorted_block_sizes = []
for group, cat_indices in zip(sorted_cats.keys(), sorted_cats.values()):
    c = get_color_for_group(group)
    # [c] * n keeps each colour intact; np.repeat(c, n) would flatten an RGB(A)
    # tuple colour into its components.
    sorted_trial_colors.extend([c] * len(cat_indices))
    sorted_trial_indices.extend(cat_indices.tolist())
    sorted_block_sizes.append(len(cat_indices))
sorted_trial_colors = np.array(sorted_trial_colors)
sorted_trial_indices = np.array(sorted_trial_indices)

# Per-trial (context, type, performance) group and colour, in recording order.
trial_colors = []
trial_groups = []
for trial in ephys_trials:
    group=(trial['context'][0], 
           trial['type'][0],
           trial['performance'][0])
    c = get_color_for_group(group)
    trial_colors.append(c)
    trial_groups.append(group)
trial_colors = np.array(trial_colors)
trial_groups = np.array(trial_groups)

In [None]:
# Cross-validated decoding of track position from grid-cell tuning curves.
# Trained on beaconed ('b') and non-beaconed ('nb') trials; every trial up to
# the last ephys trial is decoded so all trial types can be visualised.
grid_tcs = {cid: tcs[cid] for cid in grid_cells.cluster_id if cid in tcs}

x_true_dt = dt_bins[:last_ephys_bin]
true_position = x_true_dt % tl
trial_numbers = (x_true_dt // tl) + beh['trials']['number'][0]
last_decodable_tn = np.nanmax(trial_numbers)

all_tns = np.array(beh['trials']['number'])
trial_types = np.array(beh['trials']['type'])

tns_to_decode_with = all_tns[all_tns <= last_decodable_tn]
tns_to_decode = all_tns[all_tns <= last_decodable_tn]  # decode all trials to visualise
tns_to_train = all_tns[np.isin(trial_types, np.array(['b', 'nb']))]
tns_to_train = tns_to_train[tns_to_train <= last_decodable_tn]  # handles last ephys trials

predictions, errors = cross_validate_decoder(grid_tcs, true_position, trial_numbers, tns_to_decode, tns_to_train, tl, bs, train=0.9, n=10, verbose=False)
avg_predictions = circular_nanmean(predictions, tl, axis=2)

# Decoded position per trial (rows) and position bin (columns), with a strip of
# trial-category colours to the right.
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(0.8, 2), width_ratios=[1,0.05], sharey=True)
x = np.arange(1, len(avg_predictions) + 1)
y = np.arange(0, len(avg_predictions[0]) * bs, bs)
X, Y = np.meshgrid(x, y)
heatmap = ax[0].pcolormesh(Y, X, avg_predictions.T, shading='auto', cmap='hsv')
heatmap.set_rasterized(True)
ax[0].set_xlabel('Pos. (cm)')
ax[1].axis('off')
ax[1].scatter(np.ones(len(trial_colors)),
              np.arange(0, len(trial_colors)),
              c=trial_colors,
              marker='s')
ax[0].set_xlim(0, tl)
ax[0].set_ylim(0, len(avg_predictions))
ax[0].invert_yaxis()
fig.savefig(f'/Users/harryclark/Documents/figs/decoding/GC_M{mouse}D{day}.pdf', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Grid-cell decoding sorted by trial category, plus the mean decoded position
# and decoding error per position bin for beaconed ('b') and non-beaconed ('nb') trials.
sorted_predictions = predictions[sorted_trial_indices]
sorted_errors = errors[sorted_trial_indices]

avg_sorted_predictions = circular_nanmean(sorted_predictions, tl, axis=2)
avg_sorted_errors = np.nanmean(sorted_errors, axis=2)

# Select rows by their actual trial type, instead of assuming the first N sorted
# rows are the N 'b' trials of the whole session. That assumption breaks when
# trials after the last ephys bin or other trial types are present.
sorted_trial_types = trial_groups[sorted_trial_indices, 1]
b_rows = sorted_trial_types == 'b'
nb_rows = sorted_trial_types == 'nb'
b_mean_prediction = circular_nanmean(avg_sorted_predictions[b_rows], tl=tl, axis=0)
nb_mean_prediction = circular_nanmean(avg_sorted_predictions[nb_rows], tl=tl, axis=0)

# True position is taken as the centre of each position bin (bin start + bs/2).
pos_bin_starts = np.arange(0, len(avg_sorted_predictions[0])*bs, bs)
pos_bin_centres = pos_bin_starts + bs/2
b_error = pos_bin_centres - b_mean_prediction
nb_error = pos_bin_centres - nb_mean_prediction

plt.hist(b_error, color='tab:blue', bins=100,alpha=0.4)
plt.hist(nb_error, color='tab:orange', bins=100,alpha=0.4)
plt.title('errors before circular correction')
plt.show()
# Wrap errors onto the circular track, into [-tl/2, tl/2). The previous rule
# (tl - e for e > 0.75*tl) flipped the sign and left 0.5*tl..0.75*tl unwrapped.
b_error = np.mod(b_error + tl/2, tl) - tl/2
nb_error = np.mod(nb_error + tl/2, tl) - tl/2

plt.hist(b_error, color='tab:blue', bins=100,alpha=0.4)
plt.hist(nb_error, color='tab:orange', bins=100,alpha=0.4)
plt.title('errors after circular correction')
plt.show()

norm = TwoSlopeNorm(vmin=-35,vcenter=0, vmax=35)
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(2.5, 2), width_ratios=[1,1], height_ratios=[0.3,1], sharex=True)
x = np.arange(1, len(avg_sorted_predictions)+1)
y = np.arange(0, len(avg_sorted_predictions[0])*bs, bs)
X, Y = np.meshgrid(x, y)
heatmap1 = ax[1,0].pcolormesh(Y, X, avg_sorted_predictions.T, shading='auto', cmap='hsv')
heatmap1.set_rasterized(True)
ax[1,0].set_xlim(0,tl)
ax[1,0].set_ylim(0,len(avg_sorted_predictions))
ax[1,0].invert_yaxis()
heatmap = ax[1,1].pcolormesh(Y, X, avg_sorted_errors.T, shading='auto', norm=norm, cmap='bwr')
heatmap.set_rasterized(True)
ax[1,1].set_xlim(0,tl)
ax[1,1].set_ylim(0,len(avg_sorted_errors))
ax[1,1].invert_yaxis()
ax[0,0].plot(y,y, color='black', linestyle='dashed')
ax[0,0].plot(y, b_mean_prediction, color='tab:blue')
ax[0,0].plot(y, nb_mean_prediction, color='tab:orange')
# x axis derived from tl/bs instead of the hard-coded np.arange(0,200,2).
ax[0,1].plot(pos_bin_starts, b_error, color='tab:blue')
ax[0,1].plot(pos_bin_starts, nb_error, color='tab:orange')
ax[1,0].set_xlabel('Pos (cm)')
ax[1,1].set_xlabel('Pos (cm)')
ax[1,1].set_yticklabels([])
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.5, hspace=None)
plt.savefig(f'/Users/harryclark/Documents/figs/decoding/GC_M{mouse}D{day}_sorted.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Disabled analysis (stops vs decoded position), kept as an inert string literal;
# running this cell only evaluates/displays the string.
# NOTE(review): the disabled code reads beh['trial_number'], beh['P'], beh['trial_type']
# and beh['S'] directly, and builds sorted_trial_indices without the ephys-trial
# truncation used by the live cells — verify before re-enabling.
'''# stops versus decoded position plot
trial_numbers = np.array(beh['trial_number'])
position = np.array(beh['P'])
trial_types = np.array(beh['trial_type'])
speed = np.array(beh['S'])
stop_mask = speed<3
fig, ax = plt.subplots(ncols=1, nrows=1, figsize=(2, 2))

sorted_cats = beh['trials'].groupby(by=['context','type','performance'])
sorted_trial_indices = []
sorted_trial_colors = []
for group, cat_indices in zip(sorted_cats.keys(), sorted_cats.values()):
    c = get_color_for_group(group)
    sorted_trial_colors.extend(np.repeat(c, len(cat_indices)).tolist())
    sorted_trial_indices.extend(cat_indices.tolist())
sorted_trial_colors = np.array(sorted_trial_colors)
sorted_trial_indices = np.array(sorted_trial_indices)

for i, sti in enumerate(sorted_trial_indices):
    tn_mask = trial_numbers==beh['trials'][sti]['number'].iloc[0]
    stops = position[(stop_mask & tn_mask)]
    decoded_pos = avg_predictions[sti]
    argmax_pos = np.argmax(decoded_pos>90)
    projected_stop = decoded_pos[argmax_pos]
    print(projected_stop)
    if len(stops)>0:
        error = projected_stop - stops[0]
    else:
        error = np.nan
    ax.plot([0, error], [i, i], color=sorted_trial_colors[i], alpha=0.5, linewidth=0.5)
plt.show()
'''

In [None]:
# Position decoding from NON-grid cells: the same cross-validated decoder as the
# grid-cell analysis, trained on beaconed ('b') and non-beaconed ('nb') trials,
# and applied to every trial up to the last ephys trial.
ng_tcs = {cluster_id: tcs[cluster_id] for cluster_id in non_grid_cells.cluster_id if cluster_id in tcs}

x_true_dt = dt_bins[:last_ephys_bin]
true_position = x_true_dt%tl
trial_numbers = (x_true_dt//tl)+beh['trials']['number'][0]
tns_to_decode_with = np.array(beh['trials']['number'])
tns_to_decode_with = tns_to_decode_with[tns_to_decode_with<=np.nanmax(trial_numbers)]
trial_types = np.array(beh['trials']['type'])

tns_to_decode = np.array(beh['trials']['number']) # decode all trials to visualise
tns_to_train = np.array(beh['trials']['number'][np.isin(beh['trials']['type'], np.array(['b','nb']))])
tns_to_decode = tns_to_decode[tns_to_decode<=np.nanmax(trial_numbers)] # handles last ephys trials
tns_to_train = tns_to_train[tns_to_train<=np.nanmax(trial_numbers)] # handles last ephys trials

predictions, errors = cross_validate_decoder(ng_tcs, true_position, trial_numbers, tns_to_decode, tns_to_train, tl, bs, train=0.9, n=10, verbose=False)
avg_predictions = circular_nanmean(predictions, tl, axis=2)

# Decoded position per trial (rows) and position bin (columns), with a strip of
# trial-category colours to the right.
fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(0.8, 2), width_ratios=[1,0.05], sharey=True)
x = np.arange(1, len(avg_predictions)+1)
y = np.arange(0, len(avg_predictions[0])*bs, bs)
X, Y = np.meshgrid(x, y)
heatmap = ax[0].pcolormesh(Y, X, avg_predictions.T, shading='auto', cmap='hsv')
heatmap.set_rasterized(True)
ax[0].set_xlabel('Pos. (cm)')
ax[1].axis('off')
ax[1].scatter(np.ones(len(trial_colors)), 
                np.arange(0,len(trial_colors)), 
                c = trial_colors,
                marker='s')
ax[0].set_xlim(0,tl)
ax[0].set_ylim(0,len(avg_predictions))
ax[0].invert_yaxis()
fig.savefig(f'/Users/harryclark/Documents/figs/decoding/NG_M{mouse}D{day}.pdf', dpi=300, bbox_inches='tight')
plt.show()

sorted_predictions = predictions[sorted_trial_indices]
sorted_errors = errors[sorted_trial_indices]

avg_sorted_predictions = circular_nanmean(sorted_predictions, tl, axis=2)
avg_sorted_errors = np.nanmean(sorted_errors, axis=2)

# Select rows by their actual trial type, instead of assuming the first N sorted
# rows are the N 'b' trials of the whole session. That assumption breaks when
# trials after the last ephys bin or other trial types are present.
sorted_trial_types = trial_groups[sorted_trial_indices, 1]
b_rows = sorted_trial_types == 'b'
nb_rows = sorted_trial_types == 'nb'
b_mean_prediction = circular_nanmean(avg_sorted_predictions[b_rows], tl=tl, axis=0)
nb_mean_prediction = circular_nanmean(avg_sorted_predictions[nb_rows], tl=tl, axis=0)

# True position is taken as the centre of each position bin (bin start + bs/2).
pos_bin_starts = np.arange(0, len(avg_sorted_predictions[0])*bs, bs)
pos_bin_centres = pos_bin_starts + bs/2
b_error = pos_bin_centres - b_mean_prediction
nb_error = pos_bin_centres - nb_mean_prediction

plt.hist(b_error, color='tab:blue', bins=100,alpha=0.4)
plt.hist(nb_error, color='tab:orange', bins=100,alpha=0.4)
plt.title('errors before circular correction')
plt.show()
# Wrap errors onto the circular track, into [-tl/2, tl/2). The previous rule
# (tl - e for e > 0.75*tl) flipped the sign and left 0.5*tl..0.75*tl unwrapped.
b_error = np.mod(b_error + tl/2, tl) - tl/2
nb_error = np.mod(nb_error + tl/2, tl) - tl/2

plt.hist(b_error, color='tab:blue', bins=100,alpha=0.4)
plt.hist(nb_error, color='tab:orange', bins=100,alpha=0.4)
plt.title('errors after circular correction')
plt.show()

norm = TwoSlopeNorm(vmin=-35,vcenter=0, vmax=35)
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(2.5, 2), width_ratios=[1,1], height_ratios=[0.3,1], sharex=True)
x = np.arange(1, len(avg_sorted_predictions)+1)
y = np.arange(0, len(avg_sorted_predictions[0])*bs, bs)
X, Y = np.meshgrid(x, y)
heatmap1 = ax[1,0].pcolormesh(Y, X, avg_sorted_predictions.T, shading='auto', cmap='hsv')
heatmap1.set_rasterized(True)
ax[1,0].set_xlim(0,tl)
ax[1,0].set_ylim(0,len(avg_sorted_predictions))
ax[1,0].invert_yaxis()
heatmap = ax[1,1].pcolormesh(Y, X, avg_sorted_errors.T, shading='auto', norm=norm, cmap='bwr')
heatmap.set_rasterized(True)
ax[1,1].set_xlim(0,tl)
ax[1,1].set_ylim(0,len(avg_sorted_errors))
ax[1,1].invert_yaxis()
ax[0,0].plot(y,y, color='black', linestyle='dashed')
ax[0,0].plot(y, b_mean_prediction, color='tab:blue')
ax[0,0].plot(y, nb_mean_prediction, color='tab:orange')
# x axis derived from tl/bs instead of the hard-coded np.arange(0,200,2).
ax[0,1].plot(pos_bin_starts, b_error, color='tab:blue')
ax[0,1].plot(pos_bin_starts, nb_error, color='tab:orange')
ax[1,0].set_xlabel('Pos (cm)')
ax[1,1].set_xlabel('Pos (cm)')
ax[1,1].set_yticklabels([])
plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=0.5, hspace=None)
plt.savefig(f'/Users/harryclark/Documents/figs/decoding/NG_M{mouse}D{day}_sorted.pdf', dpi=300, bbox_inches='tight')
plt.show()




In [None]:
# Per-trial maps of each toroidal angle (rows = trials, columns = position bins)
# with a strip of trial-category colours on the right; one PDF per angle.
# NOTE(review): angles0, i, fig and ax keep their last values after the loop and
# are read by later cells.
angles1, angles2, angles3 = (angles[k].reshape(-1, int(L/bs)) for k in range(3))

for i, angles0 in enumerate((angles1, angles2, angles3)):
    fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(0.8, 2), width_ratios=[1,0.05], sharey=True)
    map_ax, strip_ax = ax

    x = np.arange(1, len(angles0) + 1)
    y = np.arange(0, len(angles0[0]) * bs, bs)
    X, Y = np.meshgrid(x, y)
    heatmap = map_ax.pcolormesh(Y, X, angles0.T, shading='auto', cmap='hsv')
    heatmap.set_rasterized(True)
    map_ax.set_xlabel('Pos. (cm)')
    strip_ax.axis('off')
    strip_ax.scatter(np.ones(len(trial_colors)),
                     np.arange(0, len(trial_colors)),
                     c=trial_colors,
                     marker='s')
    map_ax.set_xlim(0, tl)
    map_ax.set_ylim(0, len(angles0))
    map_ax.invert_yaxis()
    fig.savefig(f'/Users/harryclark/Documents/figs/toroidal/M{mouse}D{day}A{i}.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
def average_correlation_per_matrix(matrix):
    """Mean pairwise Pearson correlation between the rows of a 2-D array.

    All ordered row pairs are included, including each row with itself (r = 1).
    Pairs whose correlation is undefined (a constant row, or a row containing
    NaN) give NaN and are ignored by the mean.

    Parameters
    ----------
    matrix : array-like, shape (n_rows, n_cols)

    Returns
    -------
    float
        ``np.nanmean`` of the (n_rows, n_rows) correlation matrix; NaN if there
        are no rows.
    """
    matrix = np.asarray(matrix, dtype=float)
    if matrix.shape[0] == 0:
        return np.nan
    # One np.corrcoef call (rows = variables) replaces n_rows**2 scipy.stats.pearsonr
    # calls and no longer depends on a `stats` name being in scope.
    with np.errstate(invalid='ignore', divide='ignore'):
        corrs = np.corrcoef(matrix)
    return np.nanmean(corrs)

# Trial-to-trial consistency of each toroidal angle's position map.
corr1 = average_correlation_per_matrix(angles1)
corr2 = average_correlation_per_matrix(angles2)
corr3 = average_correlation_per_matrix(angles3)

print("corr angles1 :\n", corr1)
print("corr angles2 :\n", corr2)
print("corr angles3 :\n", corr3)

# Keep the two most consistent angles: angles3 replaces whichever of
# angles1/angles2 is LESS consistent, if angles3 beats it. The previous if/elif
# replaced angles1 whenever corr3 > corr1, even when angles2 was the weaker one.
# NaN consistencies are treated as the lowest possible.
score1, score2, score3 = (-np.inf if np.isnan(c) else c for c in (corr1, corr2, corr3))

best_angle1 = angles1
best_angle2 = angles2

if score3 > min(score1, score2):
    if score1 <= score2:
        best_angle1 = angles3
    else:
        best_angle2 = angles3


In [None]:
angles1_sorted = angles[0].reshape(-1,int(L/bs))[sorted_trial_indices]
angles2_sorted = angles[1].reshape(-1,int(L/bs))[sorted_trial_indices]
angles3_sorted = angles[2].reshape(-1,int(L/bs))[sorted_trial_indices]

for i, angles0_sorted in enumerate([angles1_sorted, angles2_sorted, angles3_sorted]):
    fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(0.8, 2), width_ratios=[1,0.05], sharey=True)

    x = np.arange(1, len(angles0_sorted)+1)
    y = np.arange(0, len(angles0_sorted[0])*bs, bs)
    X, Y = np.meshgrid(x, y)
    ax[0].pcolormesh(Y, X, angles0_sorted.T, shading='auto', cmap='hsv')
    ax[0].set_xlabel('Pos. (cm)')
    ax[1].axis('off')
    ax[1].scatter(np.ones(len(sorted_trial_colors)), 
                  np.arange(0,len(sorted_trial_colors)), 
                  c = sorted_trial_colors,
                  marker='s')
    ax[0].set_xlim(0,tl)
    ax[0].set_ylim(0,len(angles0_sorted))
    ax[0].invert_yaxis()
    fig.savefig(f'/Users/harryclark/Documents/figs/toroidal/M{mouse}D{day}A{i}_sorted.pdf', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Occupancy of the (best_angle1, best_angle2) sheet for 'hit' trials in context
# 'rz1': all such hits, beaconed hits, and non-beaconed hits (left to right).
fig, ax = plt.subplots(ncols=3, nrows=1, figsize=(6,2), sharey=True, sharex=True)
b_group = ('rz1', 'b', 'hit')
nb_group = ('rz1', 'nb', 'hit')

# Rows of trial_groups whose (context, type, performance) matches each category.
b_mask = (trial_groups == b_group).all(axis=1)
nb_mask = (trial_groups == nb_group).all(axis=1)
hit_mask = b_mask | nb_mask

angle_range = [[-np.pi, np.pi], [-np.pi, np.pi]]
heatmap, xedges, yedges = np.histogram2d(best_angle1[hit_mask].flatten(), best_angle2[hit_mask].flatten(), bins=50, range=angle_range, density=True)
heatmap1, xedges, yedges = np.histogram2d(best_angle1[b_mask].flatten(), best_angle2[b_mask].flatten(), bins=50, range=angle_range, density=True)
heatmap2, xedges, yedges = np.histogram2d(best_angle1[nb_mask].flatten(), best_angle2[nb_mask].flatten(), bins=50, range=angle_range, density=True)

heatmap, heatmap1, heatmap2 = (gaussian_filter_nan(h, sigma=(3,3)) for h in (heatmap, heatmap1, heatmap2))

if np.isnan(heatmap).any():
    # NOTE(review): presumably NaN when there are no samples to normalise; all
    # three maps are replaced in place by a flat map of ones so the figure renders.
    for h in (heatmap, heatmap1, heatmap2):
        h[:] = 1

for panel, h in zip(ax, (heatmap, heatmap1, heatmap2)):
    panel.pcolormesh(xedges, yedges, h.T, cmap='jet')




In [None]:
# Average time spent in, and speed through, each track-position bin per trial.
# time_at_bin_centre (cumulative time across position bins) is used later to
# resample the per-position-bin torus maps onto a uniform time axis.
time = np.append(np.diff(np.array(beh['S'].index)), 0)  # duration of each sample; last padded with 0
position = np.array(beh['P'])
travel = np.array(beh['travel'])
speed = np.array(beh['S'])
# NOTE(review): angles0 is the loop variable left over from the per-angle map cell;
# time_spent_in_each_bin is not used in this cell.
time_spent_in_each_bin = angles0.copy()

dt_npy = np.array(beh['travel']-((tns[0]-1)*tl))
n_bins = int(int(((np.ceil(np.nanmax(dt_npy))//tl)+1)*tl)/bs)
max_bound = int(((np.ceil(np.nanmax(dt_npy))//tl)+1)*tl)
min_bound = 0
bins_visited, _ = np.histogram(position, bins=int(tl/bs), range=[0,tl])
time_spent_in_bins, _ = np.histogram(position, weights=time, bins=int(tl/bs), range=[0,tl])
# Number of complete track lengths travelled. Uses np.nanmax, like the bounds
# above, so a NaN in travel does not become int(nan). Floored at 1 to avoid
# dividing by zero when less than one full trial was travelled.
n_trials_travelled = max(1, int(np.nanmax(dt_npy)/tl))
time_spent_in_bins_per_trial = time_spent_in_bins/n_trials_travelled
speed_in_bin_per_trial = bs/time_spent_in_bins_per_trial
speed_in_bin_per_trial = gaussian_filter1d(speed_in_bin_per_trial, sigma=2)
time_spent_in_bins_per_trial = bs/speed_in_bin_per_trial
time_at_bin_centre = np.cumsum(time_spent_in_bins_per_trial)

In [None]:
# interpolate 2d
def interpolate_along_time_axis(time_at_bin_centre, heatmap, dt=0.1):
    """Resample a stack of maps onto a uniform time grid.

    Parameters
    ----------
    time_at_bin_centre : array-like, shape (n_times,)
        Time (s) of each map: ``heatmap[k]`` occurs at ``time_at_bin_centre[k]``.
    heatmap : array-like, shape (n_times, ...)
        Stack of maps; the first axis matches ``time_at_bin_centre``.
    dt : float, default 0.1
        Step (s) of the new time grid.

    Returns
    -------
    numpy.ndarray, shape (n_new_times, ...)
        Maps linearly interpolated in time at
        ``np.arange(time_at_bin_centre.min(), time_at_bin_centre.max(), dt)``.
    """
    from scipy.interpolate import interp1d

    time_values = np.asarray(time_at_bin_centre)
    new_time_grid = np.arange(time_values.min(), time_values.max(), dt)

    # Interpolate along time only. This equals a trilinear RegularGridInterpolator
    # evaluated at the original spatial grid nodes, without building a full
    # (time, x, y) query mesh, and it works for maps of any trailing shape.
    interpolator = interp1d(time_values, np.asarray(heatmap), kind='linear', axis=0)
    return interpolator(new_time_grid)

def interpolate_bin_centres_in_time(time_at_bin_centre, bin_centres, dt=0.1):
    """Resample per-bin positions onto a uniform time grid.

    Uses the same grid as ``interpolate_along_time_axis`` for the same ``dt``, so
    the two outputs line up frame by frame.

    Parameters
    ----------
    time_at_bin_centre : array-like, shape (n_bins,)
        Time (s) associated with each bin.
    bin_centres : array-like, shape (n_bins,)
        Position associated with each bin.
    dt : float, default 0.1
        Step (s) of the new time grid.

    Returns
    -------
    numpy.ndarray
        Positions linearly interpolated at
        ``np.arange(time_at_bin_centre.min(), time_at_bin_centre.max(), dt)``.
    """
    from scipy.interpolate import interp1d

    time_values = np.asarray(time_at_bin_centre)
    new_time_grid = np.arange(time_values.min(), time_values.max(), dt)
    interpolator = interp1d(time_values, bin_centres, kind='linear')
    return interpolator(new_time_grid)

In [None]:
def tile_smooth_crop_heatmap(heatmap, sigma=(3,3)):
    """Smooth a 2-D map with periodic (toroidal) boundaries.

    The map is surrounded by wrapped copies of itself (equivalent to a 3x3
    tiling), smoothed with ``gaussian_filter_nan``, and the central copy is
    returned. Values near one edge are therefore smoothed together with values
    near the opposite edge. This is exactly periodic only while the smoothing
    kernel is no wider than one map (presumably true for small sigma — TODO
    confirm against gaussian_filter_nan's truncation).
    """
    n_rows, n_cols = heatmap.shape
    padded = np.pad(heatmap, ((n_rows, n_rows), (n_cols, n_cols)), mode='wrap')
    smoothed = gaussian_filter_nan(padded, sigma=sigma)
    return smoothed[n_rows:2 * n_rows, n_cols:2 * n_cols]


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
norm = TwoSlopeNorm(vmin=-0.2,vcenter=0,vmax=0.2)

# One smoothed, toroidally wrapped occupancy map on the (best_angle1, best_angle2)
# sheet per position bin, for all hits, beaconed hits and non-beaconed hits.
# The number of position bins comes from the data (previously hard-coded as 100),
# so it matches the per-bin times in time_at_bin_centre used for interpolation.
n_angle_pos_bins = best_angle1.shape[1]

heatmap_data_all = []
for i in range(n_angle_pos_bins):
    heatmap, xedges, yedges = np.histogram2d(best_angle1[hit_mask][:, i], best_angle2[hit_mask][:, i], bins=50, range=[[-np.pi, np.pi], [-np.pi, np.pi]], density=True)
    heatmap= tile_smooth_crop_heatmap(heatmap, sigma=(3,3))
    heatmap_data_all.append(heatmap)
heatmap_data_all=np.array(heatmap_data_all)

heatmap_data_b = []
for i in range(n_angle_pos_bins):
    heatmap, xedges, yedges = np.histogram2d(best_angle1[b_mask][:, i], best_angle2[b_mask][:, i], bins=50, range=[[-np.pi, np.pi], [-np.pi, np.pi]], density=True)
    heatmap= tile_smooth_crop_heatmap(heatmap, sigma=(3,3))
    heatmap_data_b.append(heatmap)
heatmap_data_b=np.array(heatmap_data_b)

heatmap_data_nb = []
for i in range(n_angle_pos_bins):
    heatmap, xedges, yedges = np.histogram2d(best_angle1[nb_mask][:, i], best_angle2[nb_mask][:, i], bins=50, range=[[-np.pi, np.pi], [-np.pi, np.pi]], density=True)
    heatmap= tile_smooth_crop_heatmap(heatmap, sigma=(3,3))
    heatmap_data_nb.append(heatmap)
heatmap_data_nb=np.array(heatmap_data_nb)

# Resample the per-position-bin maps (and the bin positions) onto a uniform time grid.
heatmap_data_all = interpolate_along_time_axis(time_at_bin_centre, heatmap_data_all)
heatmap_data_b = interpolate_along_time_axis(time_at_bin_centre, heatmap_data_b)
heatmap_data_nb = interpolate_along_time_axis(time_at_bin_centre, heatmap_data_nb)
bin_centres = interpolate_bin_centres_in_time(time_at_bin_centre, np.arange(0,tl,bs))

# Create a figure and axis
fig, axs = plt.subplots(ncols=4, nrows=2, figsize=(6,2.5), sharey=True, sharex=True, height_ratios=[0.1,1])
ax0 = axs[1,0].imshow(heatmap_data_all[0], cmap='jet', vmin=0, vmax=0.4)
ax1 = axs[1,1].imshow(heatmap_data_b[0], cmap='jet', vmin=0, vmax=0.4)
ax2 = axs[1,2].imshow(heatmap_data_nb[0], cmap='jet', vmin=0, vmax=0.4)
ax3 = axs[1,3].imshow(heatmap_data_b[0]-heatmap_data_nb[0], cmap='bwr', norm=norm)

axs[1,0].set_title('all hits')
axs[1,1].set_title('cued hits')
axs[1,2].set_title('uncued hits')
axs[1,3].set_title('delta')
axs[1,3].set_ylim(0,50)

scat = axs[0,0].scatter(0, -50, color='black')
text = axs[0,1].text(0,0,f'{0}')
scalar=6
axs[0,0].axvspan(0,200/scalar,
        alpha=0.2,
        zorder=-10,
        edgecolor='none',
        facecolor='grey',
    )
#axs[0,0].axvspan(90/scalar,110/scalar,
#        alpha=1,
#        zorder=-10,
#        edgecolor='none',
#        facecolor='lightgreen',
#    )

axs[0,0].axvline(90/scalar, color='black', linestyle='dotted', linewidth=0.5)
axs[0,0].axvline(110/scalar, color='black', linestyle='dotted', linewidth=0.5)


axs[0,0].axvspan(0,30/scalar,
        alpha=0.5,
        zorder=-10,
        edgecolor='none',
        facecolor='darkgrey',
    )
axs[0,0].axvspan(170/scalar,200/scalar,
        alpha=0.5,
        zorder=-10,
        edgecolor='none',
        facecolor='darkgrey',
    )

x_data = bin_centres/scalar
y_data = np.ones(len(x_data))

for ax in axs.flatten():
    ax.axis('off')

# Update function for animation
def update(frame):
    text.set_text(f'Trial time: {np.round(frame*frame_interval/1000, decimals=1)} seconds')
    scat.set_offsets(np.c_[x_data[frame], y_data[frame]])
    ax0.set_array(heatmap_data_all[frame])
    ax1.set_array(heatmap_data_b[frame])
    ax2.set_array(heatmap_data_nb[frame])
    ax3.set_array(heatmap_data_b[frame]-heatmap_data_nb[frame])

    return ax,

# Create animation
frame_interval = 100
ani = animation.FuncAnimation(fig, update, frames=len(heatmap_data_b), blit=True, interval=frame_interval, repeat_delay=0)
plt.tight_layout()
ani.save(f'/Users/harryclark/Downloads/M{mouse}D{day}_activity_bump2.gif', writer='pillow')
# Display the animation
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib

# Map the uncued-hit density animation (heatmap_data_nb) onto a torus surface and
# save one GIF per viewpoint.

# Torus parameters
R = 0.5  # Major radius
r = 0.2  # Minor radius
COLOR_GAIN = 2        # densities are multiplied by this before indexing plt.cm.jet, which maps [0, 1]
frame_interval = 100  # ms between frames

heatmap = heatmap_data_nb[0]

# Meshgrid for the torus, one vertex per heatmap bin.
u = np.linspace(0, 2 * np.pi, heatmap.shape[0])
v = np.linspace(0, 2 * np.pi, heatmap.shape[1])
u, v = np.meshgrid(u, v)

# Parametric equations for the torus
x = (R + r * np.cos(v)) * np.cos(u)
y = (R + r * np.cos(v)) * np.sin(u)
z = r * np.sin(v)


def make_torus_updater(ax, frames, xs, ys, zs):
    """Return a FuncAnimation callback that redraws the torus on `ax` coloured by frames[frame].

    The current surface lives in a per-axis holder rather than a module-level
    global, so callbacks belonging to different figures never remove each
    other's surfaces.
    """
    holder = []

    def update(frame):
        if holder:
            holder.pop().remove()
        # facecolors are given explicitly, so plot_surface's vmin/vmax would be ignored;
        # the colour range is set by COLOR_GAIN instead.
        surface = ax.plot_surface(xs, ys, zs, facecolors=plt.cm.jet(frames[frame] * COLOR_GAIN),
                                  rstride=1, cstride=1, antialiased=True, alpha=.7, edgecolor='none')
        holder.append(surface)
        return surface,

    return update


torus_animations = []  # keep references so interactive animations are not garbage-collected
for elev, azim, view_label in zip([90,0,45], [0,0,45], ['top', 'side', 'iso']):
    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111, projection='3d')
    ax.set_axis_off()
    ax.view_init(elev=elev, azim=azim)

    update = make_torus_updater(ax, heatmap_data_nb, x, y, z)
    # Draw frame 0 up front so x/y autoscale to the torus before the z-limits and layout are fixed.
    update(0)
    ax.set_zlim3d(-0.5, 0.5)

    ani = animation.FuncAnimation(fig, update, frames=len(heatmap_data_nb), blit=True, interval=frame_interval, repeat_delay=0)
    torus_animations.append(ani)
    plt.tight_layout()
    ani.save(f'/Users/harryclark/Downloads/M{mouse}D{day}_activity_bump_torus_{view_label}.gif', writer='pillow')
    plt.show()


In [None]:

# Smoothed (best_angle1, best_angle2) densities for six rz1 trial groups.
# Columns 0-5: one trial group each; row 0 = all time bins, rows 1-5 = successive
# 20-bin segments. Column 6: cued-minus-uncued hit difference ('b' - 'nb'),
# row 0 = all time bins, rows 1-6 = the windows in `segments`.

ANGLE_RANGE = [[-np.pi, np.pi], [-np.pi, np.pi]]


def smoothed_angle_density(angle1, angle2):
    """Return (wrap-smoothed 50x50 density, xedges, yedges) of the flattened angle arrays."""
    density, xe, ye = np.histogram2d(angle1.flatten(), angle2.flatten(), bins=50,
                                     range=ANGLE_RANGE, density=True)
    return tile_smooth_crop_heatmap(density, sigma=(3, 3)), xe, ye


fig, ax = plt.subplots(ncols=7, nrows=7, figsize=(10,10), sharey=True, sharex=True)
norm = TwoSlopeNorm(vmin=-0.2,vcenter=0,vmax=0.2)

density_groups = [('rz1', 'b', 'hit'), ('rz1', 'b', 'try'), ('rz1', 'b', 'run'),
                  ('rz1', 'nb', 'hit'), ('rz1', 'nb', 'try'), ('rz1', 'nb', 'run')]
group_segments = [0, 20, 40, 60, 80, 100]

for col, group in enumerate(density_groups):
    t_mask = np.all(trial_groups == group, axis=1)
    angle1, angle2 = best_angle1[t_mask], best_angle2[t_mask]

    full_density, xedges, yedges = smoothed_angle_density(angle1, angle2)
    panels = [full_density] + [
        smoothed_angle_density(angle1[:, start:stop], angle2[:, start:stop])[0]
        for start, stop in zip(group_segments[:-1], group_segments[1:])
    ]

    # NaNs in the full density (e.g. no matching trials, where density
    # normalisation divides by zero) -> paint every panel with the constant 1,
    # which lies above vmax and so shows as the top jet colour.
    if np.any(np.isnan(full_density)):
        panels = [np.ones_like(panel) for panel in panels]

    for row, panel in enumerate(panels):
        ax[row, col].pcolormesh(xedges, yedges, panel.T, cmap='jet', vmin=0, vmax=0.4)
    ax[0, col].set_title(f'{group[1]} {group[2]}', fontsize=8)

h_b_mask = np.all(trial_groups == ('rz1', 'b', 'hit'), axis=1)
h_nb_mask = np.all(trial_groups == ('rz1', 'nb', 'hit'), axis=1)
# NOTE(review): this column has one more window (100:120) than the group columns;
# if best_angle1 has only 100 time bins that slice is empty and row 6 is NaN (blank).
segments = [0,20,40,60,80,100,120]


def cued_minus_uncued(start=None, stop=None):
    """Smoothed b-hit density minus nb-hit density over time-bin columns start:stop (all if None)."""
    window = slice(start, stop)
    b_density, xe, ye = smoothed_angle_density(best_angle1[h_b_mask][:, window],
                                               best_angle2[h_b_mask][:, window])
    nb_density, _, _ = smoothed_angle_density(best_angle1[h_nb_mask][:, window],
                                              best_angle2[h_nb_mask][:, window])
    return b_density - nb_density, xe, ye


diff_h, xedges, yedges = cued_minus_uncued()
ax[0,6].pcolormesh(xedges, yedges, diff_h.T, cmap='bwr', norm=norm)
ax[0,6].set_title('b - nb hit', fontsize=8)
for row, (start, stop) in enumerate(zip(segments[:-1], segments[1:]), start=1):
    diff_h, xedges, yedges = cued_minus_uncued(start, stop)
    ax[row,6].pcolormesh(xedges, yedges, diff_h.T, cmap='bwr', norm=norm)
plt.show()

In [None]:
from scipy.ndimage import gaussian_filter  # explicit: otherwise only available via hidden state / star import

# Per-cluster rate maps for cluster_ids_by_group[0], saved as PDFs tagged 'GC'.
# Each map: pynapple 1D tuning curve over `dt` restricted to beh["moving"],
# NaN bins zero-filled, Gaussian-smoothed (sigma=2.5 bins), z-scored, then
# truncated to the first `last_ephys_bin` bins.
for cell in cluster_ids_by_group[0]:
    tc = nap.compute_1d_tuning_curves(nap.TsGroup([clusters[cell]]),
                                      dt,
                                      nb_bins=n_bins,
                                      minmax=[min_bound, max_bound],
                                      ep=beh["moving"])[0]
    # NOTE(review): NaN bins are zero-filled and not restored after smoothing
    # (the earlier version computed a NaN mask here but never used it).
    tc = gaussian_filter(np.nan_to_num(tc).astype(np.float64), sigma=2.5)
    tc = zscore(tc)  # a constant (e.g. silent) map yields NaNs here; warnings are filtered globally
    tc = tc[:last_ephys_bin] # only want bins with ephys data in it

    fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(0.8, 2), width_ratios=[1,0.05], sharey=True)
    plot_firing_rate_map(ax[0], tc, bs=bs, tl=tl,p=95, sort_indices=None)
    # Narrow right-hand strip: one coloured square per trial, in trial order.
    ax[1].axis('off')
    ax[1].scatter(np.ones(len(trial_colors)),
                  np.arange(0,len(trial_colors)),
                  c = trial_colors,
                  marker='s')
    ax[0].set_xlabel('Pos (cm)')

    fig.savefig(f'/Users/harryclark/Documents/figs/rate_map_examples/M{mouse}D{day}GC{cell}.pdf', dpi=300, bbox_inches='tight')
    plt.show()

In [None]:
from scipy.ndimage import gaussian_filter  # explicit: otherwise only available via hidden state / star import

# Per-cluster rate maps for cluster_ids_by_group[-2], saved as PDFs tagged 'NGS'.
# Each map: pynapple 1D tuning curve over `dt` restricted to beh["moving"],
# NaN bins zero-filled, Gaussian-smoothed (sigma=2.5 bins), z-scored, then
# truncated to the first `last_ephys_bin` bins.
for cell in cluster_ids_by_group[-2]:
    tc = nap.compute_1d_tuning_curves(nap.TsGroup([clusters[cell]]),
                                      dt,
                                      nb_bins=n_bins,
                                      minmax=[min_bound, max_bound],
                                      ep=beh["moving"])[0]
    # NOTE(review): NaN bins are zero-filled and not restored after smoothing
    # (the earlier version computed a NaN mask here but never used it).
    tc = gaussian_filter(np.nan_to_num(tc).astype(np.float64), sigma=2.5)
    tc = zscore(tc)  # a constant (e.g. silent) map yields NaNs here; warnings are filtered globally
    tc = tc[:last_ephys_bin] # only want bins with ephys data in it

    fig, ax = plt.subplots(ncols=2, nrows=1, figsize=(0.8, 2), width_ratios=[1,0.05], sharey=True)
    plot_firing_rate_map(ax[0], tc, bs=bs, tl=tl,p=95, sort_indices=None)
    # Narrow right-hand strip: one coloured square per trial, in trial order.
    ax[1].axis('off')
    ax[1].scatter(np.ones(len(trial_colors)),
                  np.arange(0,len(trial_colors)),
                  c = trial_colors,
                  marker='s')
    ax[0].set_xlabel('Pos (cm)')

    fig.savefig(f'/Users/harryclark/Documents/figs/rate_map_examples/M{mouse}D{day}NGS{cell}.pdf', dpi=300, bbox_inches='tight')
    plt.show()
