In [None]:
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from pathlib import Path
import seaborn as sns
from sklearn.metrics import confusion_matrix
import pynapple as nap
from spatial_manifolds.toroidal import *
from spatial_manifolds.behaviour_plots import *
from matplotlib.colors import TwoSlopeNorm
from scipy.spatial import distance
from spatial_manifolds.circular_decoder import circular_decoder, cross_validate_decoder, cross_validate_decoder_time, circular_nanmean

from spatial_manifolds.data.curation import curate_clusters
from scipy.stats import zscore
from spatial_manifolds.util import gaussian_filter_nan
from spatial_manifolds.predictive_grid import compute_travel_projected, wrap_list
from spatial_manifolds.behaviour_plots import *
from spatial_manifolds.behaviour_plots import trial_cat_priority
from spatial_manifolds.detect_grids import cell_classification_of1, HDBSCAN_grid_modules, plot_grid_modules_rate_maps, compute_vr_tcs
import numpy as np
import matplotlib.pyplot as plt
import hdbscan
from sklearn.preprocessing import StandardScaler
import matplotlib.cm as cm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from IPython.display import HTML

import warnings
warnings.filterwarnings('ignore')
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:

def compute_correlation_matrix(spectrogram1, spectrogram2, window_size=1):
    """
    Compute a correlation matrix across sliding windows between two spectrograms.

    Windows slide along axis 1 (columns). Entry ``[i, j]`` is the Pearson
    correlation between the flattened window ``spectrogram1[:, i:i+window_size]``
    and the flattened window ``spectrogram2[:, j:j+window_size]``. This is the
    same value ``np.corrcoef`` gives for that pair, but it is computed for all
    pairs at once with a single matrix product.

    Parameters:
    spectrogram1 (2D array-like): First spectrogram
    spectrogram2 (2D array-like): Second spectrogram, same shape as the first
    window_size (int): Number of columns per sliding window (>= 1)

    Returns:
    2D np.array: Correlation matrix of shape (num_windows, num_windows), where
    num_windows = spectrogram1.shape[1] - window_size + 1. An entry is NaN when
    either window has zero variance or contains NaN (as with np.corrcoef).

    Raises:
    AssertionError: (via ``assert``) if the spectrograms differ in shape.
    ValueError: if window_size is < 1 or larger than the number of columns.
    """
    spectrogram1 = np.asarray(spectrogram1, dtype=float)
    spectrogram2 = np.asarray(spectrogram2, dtype=float)
    # Ensure the spectrograms have the same shape
    assert spectrogram1.shape == spectrogram2.shape, "Spectrograms must have the same shape"

    n_cols = spectrogram1.shape[1]
    if window_size < 1 or window_size > n_cols:
        raise ValueError(f"window_size must be between 1 and {n_cols}, got {window_size}")
    num_windows = n_cols - window_size + 1

    def _standardised_windows(spec):
        # Build the windows as shape (rows, num_windows, window_size), then
        # reshape them to (num_windows, rows * window_size). The transpose keeps
        # each row equal to spec[:, i:i+window_size].flatten() (row-major order).
        windows = np.lib.stride_tricks.sliding_window_view(spec, window_size, axis=1)
        flat = windows.transpose(1, 0, 2).reshape(num_windows, -1)
        centred = flat - flat.mean(axis=1, keepdims=True)
        norms = np.sqrt((centred ** 2).sum(axis=1))
        # A zero-variance window gives 0/0 -> NaN here, which propagates through
        # the matmul into NaN correlations (the same outcome as np.corrcoef).
        with np.errstate(divide='ignore', invalid='ignore'):
            return centred / norms[:, None]

    unit1 = _standardised_windows(spectrogram1)
    unit2 = _standardised_windows(spectrogram2)
    correlation_matrix = unit1 @ unit2.T

    # np.corrcoef clips to [-1, 1] to absorb rounding error; do the same.
    # NaN entries pass through np.clip unchanged.
    return np.clip(correlation_matrix, -1.0, 1.0)


In [None]:
# Disabled cell. It used to hold a string-quoted copy of the per-group rate-map
# plotting that runs inside the session loop further down. As a bare string it
# only echoed itself as cell output. It also referred to loop variables (mouse,
# day, cluster_ids_by_group, clusters_OF) that do not exist at this point in a
# fresh kernel.
# Its one extra step was to print each cluster's probe y-coordinate. To restore
# that, put this inside the session loop, per group:
#     for index in cluster_id_group:
#         print(clusters_OF.coord_probe_y[index])

In [None]:

# Sessions to analyse. The full session list is kept here for reference. Only the
# subset assigned to `mice` / `days` below is currently run.
#   mice = [25, 21, 21, 22, 22, 25, 26, 26, 26, 26, 27, 28]
#   days = [24, 19, 22, 40, 41, 23, 12, 13, 16, 18, 26, 11]
mice = [25, 26, 27]
days = [24, 18, 26]
# zip() silently truncates mismatched lists, so check the pairing explicitly.
assert len(mice) == len(days), "mice and days must pair up one-to-one"

# `tl` / `bs` arguments shared by spectral_analysis and plot_firing_rate_map
# (presumably track length and bin size -- TODO confirm units).
tl = 200
bs = 2

# NOTE(review): absolute, machine-specific paths; consider moving to a config cell.
active_projects_path = Path("/Volumes/cmvm/sbms/groups/CDBS_SIDB_storage/NolanLab/ActiveProjects/")
anatomy_path = active_projects_path / "Chris/Cohort12/derivatives/labels/anatomy/cluster_annotations.csv"
# The same file is used for every session, so read it once, outside the loop.
cluster_locations = pd.read_csv(anatomy_path)
session = 'OF1'

# Spectrogram correlations to draw: (row group, column group, figure title).
CORRELATION_PAIRS = [
    ('grid', 'non-grid spatial', 'grids vs nongrid spatial'),
    ('grid', 'speed', 'grids vs speed'),
    ('speed', 'non-grid spatial', 'speed vs nongrid spatial'),
]


def _plot_mean_spectrogram(spectrogram):
    """Show a group's mean spectrogram (frequency rows x trial columns) on a fixed colour scale."""
    plt.figure(figsize=(3, 3))
    plt.imshow(spectrogram, origin='lower', aspect='auto', vmax=0.25, cmap='magma')
    n_rows = len(spectrogram)
    plt.yticks([0, n_rows / 2, n_rows], [0, 1, 2])
    plt.ylabel('Frequency (m-1)')
    plt.xlabel('Trials')
    plt.show()


def _plot_rate_map_grid(tcs, cluster_ids, ncols=10, bs=2, tl=200, p=95):
    """Plot z-scored VR rate maps for `cluster_ids` on a grid with `ncols` columns.

    Clusters without an entry in `tcs` are reported and skipped, using the same
    filter as the spectral analysis. Unused panels are hidden, so they no longer
    inherit the previous panel's title. An empty group draws nothing (a zero-row
    subplot grid is invalid).
    """
    present = [cid for cid in cluster_ids if cid in tcs]
    missing = [cid for cid in cluster_ids if cid not in tcs]
    if missing:
        print(f'no tuning curve for clusters {missing}; not plotted')
    if not present:
        return
    nrows = int(np.ceil(len(present) / ncols))
    fig, axes = plt.subplots(ncols=ncols, nrows=nrows, figsize=(10, 2 * nrows), squeeze=False)
    # axes.flat walks the grid row by row, the same fill order as the original loops.
    for k, ax in enumerate(axes.flat):
        if k < len(present):
            cid = present[k]
            plot_firing_rate_map(ax, zscore(tcs[cid]), bs=bs, tl=tl, p=p)
            ax.set_title(f'{cid}')
        else:
            ax.axis('off')
        ax.set_xticks([])
        ax.set_yticks([])
        ax.xaxis.set_tick_params(labelbottom=False)
        ax.yaxis.set_tick_params(labelleft=False)
    plt.tight_layout()
    plt.show()


def _plot_window_correlation(spec_rows, spec_cols, title, row_label, col_label):
    """Plot and return the window-by-window correlation matrix of two spectrograms."""
    correlation = compute_correlation_matrix(spec_rows, spec_cols, window_size=1)
    plt.figure(figsize=(3, 3))
    plt.title(title)
    plt.imshow(correlation, aspect='auto', cmap='viridis')
    plt.ylabel(f'{row_label} window')
    plt.xlabel(f'{col_label} window')
    plt.colorbar(label='Pearson r')
    plt.show()
    return correlation


for mouse, day in zip(mice, days):
    # Classify clusters from open-field (OF1) activity, then cluster grid cells into modules.
    # `all_clusters` avoids shadowing the builtin `all` for the rest of the notebook.
    gcs, ngs, ns, sc, ngs_ns, all_clusters = cell_classification_of1(mouse, day, percentile_threshold=95)
    grid_module_ids, grid_module_cluster_ids = HDBSCAN_grid_modules(gcs, all_clusters, mouse, day)
    plot_grid_modules_rate_maps(gcs, grid_module_ids, grid_module_cluster_ids, mouse, day)

    # Order modules from largest to smallest.
    grid_module_cluster_ids = sorted(grid_module_cluster_ids, key=len, reverse=True)

    # One cluster-id list per group; group_labels[k] names cluster_ids_by_group[k].
    # The `sc` group is labelled 'speed', as in the original figure titles.
    cluster_ids_by_group = list(grid_module_cluster_ids) + [
        ngs.cluster_id.values.tolist(),
        ns.cluster_id.values.tolist(),
        gcs.cluster_id.values.tolist(),
        sc.cluster_id.values.tolist(),
    ]
    group_labels = [f'module {k}' for k in range(len(grid_module_cluster_ids))] + [
        'non-grid spatial', 'non-spatial', 'grid', 'speed',
    ]

    for cluster_ids in cluster_ids_by_group:
        print(cluster_ids)

    of1_folder = f'/Users/harryclark/Downloads/COHORT12_nwb/M{mouse}/D{day:02}/{session}/'
    spikes_path = of1_folder + f"sub-{mouse}_day-{day:02}_ses-{session}_srt-kilosort4_clusters.npz"
    clusters_OF = nap.load_file(spikes_path)

    # Load once per session; reused for both the spectrograms and the rate maps.
    tcs, _, _ = compute_vr_tcs(mouse, day)

    spects = []            # mean spectrograms in group order (name kept in case later cells use it)
    spects_by_group = {}   # the same spectrograms keyed by group label
    for label, cluster_ids_to_use in zip(group_labels, cluster_ids_by_group):
        tcs_to_use = {cluster_id: tcs[cluster_id] for cluster_id in cluster_ids_to_use if cluster_id in tcs}
        if not tcs_to_use:
            print(f'M{mouse} D{day:02}: no tuning curves for group {label!r}; skipping spectrogram')
            continue
        results = spectral_analysis(tcs_to_use, tl=tl, bs=bs)
        grid_cell_idxs_modules = results[2]
        spectrograms = results[3]
        S = spectrograms[grid_cell_idxs_modules[0]].mean(0)
        _plot_mean_spectrogram(S)
        spects.append(S)
        spects_by_group[label] = S

    for cluster_id_group in cluster_ids_by_group:
        _plot_rate_map_grid(tcs, cluster_id_group, bs=bs, tl=tl, p=95)

    # Window-by-window correlations between group spectrograms. Groups are picked
    # by label rather than by position in `spects`.
    window_correlations = {}
    for row_group, col_group, title in CORRELATION_PAIRS:
        if row_group not in spects_by_group or col_group not in spects_by_group:
            print(f'M{mouse} D{day:02}: missing spectrogram for {title!r}; skipping')
            continue
        window_correlations[title] = _plot_window_correlation(
            spects_by_group[row_group], spects_by_group[col_group], title, row_group, col_group)