In [1]:
import spikeinterface.full as si
from probeinterface.plotting import plot_probe, plot_probegroup
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import gridspec
from scipy import stats
import os
import matplotlib.ticker as ticker
from P2_PostProcess.VirtualReality.spatial_information import calculate_spatial_information

In [2]:
def add_trial_statistics(spike_data_vr, processed_position_data):
    trial_numbers = np.asarray(processed_position_data["trial_number"])
    trial_types = np.asarray(processed_position_data["trial_type"])
    hit_miss_tries = np.asarray(processed_position_data["hit_miss_try"])

    trial_numbers_list = []
    trial_types_list = []
    hit_miss_tries_list = []

    for i in range(len(spike_data_vr)):
        trial_numbers_list.append(trial_numbers)
        trial_types_list.append(trial_types)
        hit_miss_tries_list.append(hit_miss_tries)

    spike_data_vr["trial_trial_numbers"] = trial_numbers_list
    spike_data_vr["trial_trial_types"] = trial_types_list
    spike_data_vr["trial_hit_miss_try"] = hit_miss_tries_list
    return spike_data_vr

In [3]:

def plot_temporal_autocorrelogram_fr(spike_data, save_path=None, track_length=200, suffix=""):

    for cluster_index, cluster_id in enumerate(spike_data.cluster_id):
        cluster_spike_data = spike_data[spike_data["cluster_id"] == cluster_id]
        firing_times_cluster = np.array(cluster_spike_data["firing_times_vr"].iloc[0])

        if len(firing_times_cluster)>1:
            firing_rate_list = cluster_spike_data['fr_time_binned_smoothed'].iloc[0]
            firing_rates = []
            for firing in firing_rate_list:
                firing_rates.extend(firing)
            fr = np.array(firing_rates)
            fr[np.isnan(fr)] = 0; 
            fr[np.isinf(fr)] = 0
            autocorr_window_size = track_length*10
            lags = np.arange(0, 50, 1)
            autocorrelogram = []
            for i in range(len(lags)):
                fr_lagged = fr[i:]
                corr = stats.pearsonr(fr_lagged, fr[:len(fr_lagged)])[0]
                autocorrelogram.append(corr)
            autocorrelogram= np.array(autocorrelogram)
            fig = plt.figure(figsize=(5,2.5))
            ax = fig.add_subplot(1, 1, 1)  # specify (nrows, ncols, axnum)
            ax.axhline(y=0, color="black", linewidth=2,linestyle="dashed")
            ax.plot(lags, autocorrelogram, color="black", linewidth=3)
            plt.ylabel('Autocorr', fontsize=25, labelpad = 10)
            plt.xlabel('Lag ()', fontsize=25, labelpad = 10)
            ax.set_xlim(left=0)
            ax.yaxis.set_ticks_position('left')
            ax.xaxis.set_ticks_position('bottom')
            ax.set_ylim([np.floor(min(autocorrelogram[5:])*10)/10,np.ceil(max(autocorrelogram[5:])*10)/10])
            if np.floor(min(autocorrelogram[5:])*10)/10 < 0:
                ax.set_yticks([np.floor(min(autocorrelogram[5:])*10)/10, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            else:
                ax.set_yticks([-0.1, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            tick_spacing = track_length
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_title(f'cluster_id: {cluster_id}')
            plt.xticks(fontsize=20)
            plt.yticks(fontsize=20)
            fig.tight_layout(pad=2.0)
            plt.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)
            plt.show()
            plt.close()

In [4]:

def plot_spatial_autocorrelogram_fr(spike_data, save_path=None, track_length=200, suffix=""):

    for cluster_index, cluster_id in enumerate(spike_data.cluster_id):
        cluster_spike_data = spike_data[spike_data["cluster_id"] == cluster_id]
        firing_rates = np.array(cluster_spike_data['fr_binned_in_space_smoothed'].iloc[0])
        firing_times_cluster = np.array(cluster_spike_data["firing_times_vr"].iloc[0])

        if len(firing_times_cluster)>1:
            fr = firing_rates.flatten()
            fr[np.isnan(fr)] = 0; 
            fr[np.isinf(fr)] = 0
            autocorr_window_size = track_length*10
            lags = np.arange(0, autocorr_window_size, 1) # were looking at 10 timesteps back and 10 forward
            autocorrelogram = []
            for i in range(len(lags)):
                fr_lagged = fr[i:]
                corr = stats.pearsonr(fr_lagged, fr[:len(fr_lagged)])[0]
                autocorrelogram.append(corr)
            autocorrelogram= np.array(autocorrelogram)
            fig = plt.figure(figsize=(5,2.5))
            ax = fig.add_subplot(1, 1, 1)  # specify (nrows, ncols, axnum)
            for f in range(1,11):
                ax.axvline(x=track_length*f, color="gray", linewidth=2,linestyle="solid", alpha=0.5)
            ax.axhline(y=0, color="black", linewidth=2,linestyle="dashed")
            ax.plot(lags, autocorrelogram, color="black", linewidth=3)
            plt.ylabel('Spatial Autocorr', fontsize=25, labelpad = 10)
            plt.xlabel('Lag (cm)', fontsize=25, labelpad = 10)
            plt.xlim(0,(track_length*4)+3)
            ax.yaxis.set_ticks_position('left')
            ax.xaxis.set_ticks_position('bottom')
            ax.set_ylim([np.floor(min(autocorrelogram[5:])*10)/10,np.ceil(max(autocorrelogram[5:])*10)/10])
            if np.floor(min(autocorrelogram[5:])*10)/10 < 0:
                ax.set_yticks([np.floor(min(autocorrelogram[5:])*10)/10, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            else:
                ax.set_yticks([-0.1, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            tick_spacing = track_length
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_title(f'cluster_id: {cluster_id}')
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
            plt.xticks(fontsize=20)
            plt.yticks(fontsize=20)
            fig.tight_layout(pad=2.0)
            plt.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)
            plt.show()
            #plt.savefig(save_path + '/spatial_autocorrelogram_' + spike_data.session_id.iloc[cluster_index] + '_' + str(int(cluster_id)) + suffix + '.png', dpi=200)
            plt.close()




In [5]:

def plot_spatial_autocorrelogram_fr(spike_data, save_path=None, track_length=200, suffix=""):

    for cluster_index, cluster_id in enumerate(spike_data.cluster_id):
        cluster_spike_data = spike_data[spike_data["cluster_id"] == cluster_id]
        firing_rates = np.array(cluster_spike_data['fr_binned_in_space_smoothed'].iloc[0])
        firing_times_cluster = np.array(cluster_spike_data["firing_times_vr"].iloc[0])

        if len(firing_times_cluster)>1:
            fr = firing_rates.flatten()
            fr[np.isnan(fr)] = 0; 
            fr[np.isinf(fr)] = 0
            autocorr_window_size = track_length*10
            lags = np.arange(0, autocorr_window_size, 1) # were looking at 10 timesteps back and 10 forward
            autocorrelogram = []
            for i in range(len(lags)):
                fr_lagged = fr[i:]
                corr = stats.pearsonr(fr_lagged, fr[:len(fr_lagged)])[0]
                autocorrelogram.append(corr)
            autocorrelogram= np.array(autocorrelogram)
            fig = plt.figure(figsize=(5,2.5))
            ax = fig.add_subplot(1, 1, 1)  # specify (nrows, ncols, axnum)
            for f in range(1,11):
                ax.axvline(x=track_length*f, color="gray", linewidth=2,linestyle="solid", alpha=0.5)
            ax.axhline(y=0, color="black", linewidth=2,linestyle="dashed")
            ax.plot(lags, autocorrelogram, color="black", linewidth=3)
            plt.ylabel('Spatial Autocorr', fontsize=25, labelpad = 10)
            plt.xlabel('Lag (cm)', fontsize=25, labelpad = 10)
            plt.xlim(0,(track_length*4)+3)
            ax.yaxis.set_ticks_position('left')
            ax.xaxis.set_ticks_position('bottom')
            ax.set_ylim([np.floor(min(autocorrelogram[5:])*10)/10,np.ceil(max(autocorrelogram[5:])*10)/10])
            if np.floor(min(autocorrelogram[5:])*10)/10 < 0:
                ax.set_yticks([np.floor(min(autocorrelogram[5:])*10)/10, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            else:
                ax.set_yticks([-0.1, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            tick_spacing = track_length
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_title(f'cluster_id: {cluster_id}')
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
            plt.xticks(fontsize=20)
            plt.yticks(fontsize=20)
            fig.tight_layout(pad=2.0)
            plt.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)
            plt.show()
            #plt.savefig(save_path + '/spatial_autocorrelogram_' + spike_data.session_id.iloc[cluster_index] + '_' + str(int(cluster_id)) + suffix + '.png', dpi=200)
            plt.close()


def plot_spatial_autocorrelogram_fr_subset2(spike_data, track_length=200, tt=None, hmt=None, c=None): 
    for cluster_index, cluster_id in enumerate(spike_data.cluster_id):
        cluster_spike_data = spike_data[spike_data["cluster_id"] == cluster_id]
        firing_rates = np.array(cluster_spike_data['fr_binned_in_space_smoothed'].iloc[0])
        firing_times_cluster = np.array(cluster_spike_data["firing_times_vr"].iloc[0])

        if len(firing_times_cluster)>1:
            trial_trial_numbers = np.asarray(cluster_spike_data["trial_trial_numbers"].iloc[0])
            trial_trial_types = np.asarray(cluster_spike_data["trial_trial_types"].iloc[0])
            trial_hit_miss_try = np.asarray(cluster_spike_data["trial_hit_miss_try"].iloc[0])

            trial_numbers = trial_trial_numbers[(trial_trial_types==tt) & (trial_hit_miss_try==hmt)]
            trial_numbers_to_remove = np.arange(1,len(firing_rates)+1)
            trial_numbers_to_remove = np.setdiff1d(trial_numbers_to_remove, trial_numbers)
            
            fr = firing_rates.flatten()
            fr[np.isnan(fr)] = 0 
            fr[np.isinf(fr)] = 0
            autocorr_window_size = track_length*10
            lags = np.arange(0, autocorr_window_size, 1) # were looking at 10 timesteps back and 10 forward
            autocorrelogram = []
            for i in range(len(lags)):
                fr_lagged = fr[i:]
                corr = stats.pearsonr(fr_lagged, fr[:len(fr_lagged)])[0]
                autocorrelogram.append(corr)
            autocorrelogram= np.array(autocorrelogram)
            fig = plt.figure(figsize=(5,2.5))
            ax = fig.add_subplot(1, 1, 1)  # specify (nrows, ncols, axnum)
            for f in range(1,11):
                ax.axvline(x=track_length*f, color="gray", linewidth=2,linestyle="solid", alpha=0.5)
            ax.axhline(y=0, color="black", linewidth=2,linestyle="dashed")
            ax.plot(lags, autocorrelogram, color="black", linewidth=3)
            plt.ylabel('Spatial Autocorr', fontsize=25, labelpad = 10)
            plt.xlabel('Lag (cm)', fontsize=25, labelpad = 10)
            plt.xlim(0,(track_length*4)+3)
            ax.yaxis.set_ticks_position('left')
            ax.xaxis.set_ticks_position('bottom')
            ax.set_ylim([np.floor(min(autocorrelogram[5:])*10)/10,np.ceil(max(autocorrelogram[5:])*10)/10])
            if np.floor(min(autocorrelogram[5:])*10)/10 < 0:
                ax.set_yticks([np.floor(min(autocorrelogram[5:])*10)/10, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            else:
                ax.set_yticks([-0.1, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
            tick_spacing = track_length
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.set_title(f'cluster_id: {cluster_id}')
            ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
            plt.xticks(fontsize=20)
            plt.yticks(fontsize=20)
            fig.tight_layout(pad=2.0)
            plt.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)
            plt.show()
            #plt.savefig(save_path + '/spatial_autocorrelogram_' + spike_data.session_id.iloc[cluster_index] + '_' + str(int(cluster_id)) + suffix + '.png', dpi=200)
            plt.close()
   


In [6]:
def plot_of_autocorrelogram(spike_data, save_path=None):
    for cluster_index, cluster_id in enumerate(spike_data.cluster_id):
        cluster_df = spike_data[(spike_data.cluster_id == cluster_id)]  # dataframe for that cluster
        rate_map_autocorr_fig = plt.figure()
        rate_map_autocorr_fig.set_size_inches(5, 5, forward=True)
        ax = rate_map_autocorr_fig.add_subplot(1, 1, 1)  # specify (nrows, ncols, axnum)
        rate_map_autocorr = cluster_df['rate_map_autocorrelogram'].iloc[0]
        if rate_map_autocorr.size:
            ax = plt.subplot(1, 1, 1)
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['left'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            plt.tick_params(axis='both', which='both',bottom=False,
                            top=False, right=False,left=False,
                            labelleft=False,labelbottom=False)
            ax.set_aspect('equal')
            autocorr_img = ax.imshow(rate_map_autocorr, cmap='jet', interpolation='nearest')
            rate_map_autocorr_fig.colorbar(autocorr_img)
            plt.tight_layout()
            plt.title('Autocorrelogram \n grid score: ' + str(round(cluster_df['grid_score'].iloc[0], 2)), fontsize=24)
            plt.show()
            #plt.savefig(save_path + '/open_field_rate_map_autocorrelogram_' + spike_data.session_id.iloc[cluster_index] + '_' + str(int(cluster_id)) + '.png', dpi=300)
        plt.close()

In [7]:
project_path = "/mnt/datastore/Harry/Cohort11_april2024/derivatives/"
# get sorting analyzer and unit locations
of_session_base_names = ["M21_D26_2024-05-28_16-35-31_OF1"]
vr_session_base_names = ["M21_D26_2024-05-28_17-04-41_VR1"]

master_data = pd.DataFrame()
for vr_name, of_name in zip(vr_session_base_names, of_session_base_names):
    mouse = vr_name.split("_")[0]
    day = vr_name.split("_")[1]
    sorting_analyzer_path = f"{project_path}{mouse}/{day}/ephys/sorting_analyzer"
    vr_path = f"{project_path}{mouse}/{day}/vr/{vr_name}/processed/kilosort4/spikes.pkl"
    ramp_path = f"{project_path}{mouse}/{day}/vr/{vr_name}/processed/kilosort4/ramp_classifications.pkl"
    of_path = f"{project_path}{mouse}/{day}/of/{of_name}/processed/kilosort4/spikes.pkl"
        
    position_path = f"{project_path}{mouse}/{day}/vr/{vr_name}/processed/position_data.csv"
    processed_position_path = f"{project_path}{mouse}/{day}/vr/{vr_name}/processed/processed_position_data.pkl"
    position_data = pd.read_csv(position_path)
    processed_position_data = pd.read_pickle(processed_position_path)
    
    spike_data_vr = pd.read_pickle(vr_path); spike_data_vr["firing_times_vr"] = spike_data_vr["firing_times"]
    spike_data_of = pd.read_pickle(of_path)
    spike_data_vr = add_trial_statistics(spike_data_vr, processed_position_data)
    spike_data_vr = calculate_spatial_information(spike_data_vr, position_data, track_length=200)
    spike_data = pd.merge(spike_data_vr, spike_data_of, on="cluster_id")
    master_data = pd.concat([master_data, spike_data])
    print(f"added {day}")


added D26


In [9]:
track_length = 200

master_data = master_data.sort_values(by=["spatial_information_score_Isec"], ascending=False) 

for cluster_id in master_data.cluster_id:
    fig = plt.figure(figsize=(5,2.5))
    ax = fig.add_subplot(1, 1, 1)  # specify (nrows, ncols, axnum)
    for f in range(1,11):
        ax.axvline(x=track_length*f, color="gray", linewidth=2,linestyle="solid", alpha=0.5)
 
    cluster_spike_data = spike_data[(spike_data.cluster_id == cluster_id)]  # dataframe for that cluster
    firing_rates = np.array(cluster_spike_data['fr_binned_in_space_smoothed'].iloc[0])
    firing_times_cluster = np.array(cluster_spike_data["firing_times_vr"].iloc[0])
    trial_trial_numbers = np.asarray(cluster_spike_data["trial_trial_numbers"].iloc[0])
    trial_trial_types = np.asarray(cluster_spike_data["trial_trial_types"].iloc[0])
    trial_hit_miss_try = np.asarray(cluster_spike_data["trial_hit_miss_try"].iloc[0])

    for tt, c in zip([0,1], ["black", "red"]): 
        for hmt, linestyle in zip(["hit", "run"], ["solid", "dotted"]): 

            trial_numbers = trial_trial_numbers[(trial_trial_types==tt) & (trial_hit_miss_try==hmt)]
            trial_numbers_to_remove = np.arange(1,len(firing_rates)+1)
            trial_numbers_to_remove = np.setdiff1d(trial_numbers_to_remove, trial_numbers)

            #print(f'using trials {trial_numbers} for tt {tt} and hmt {hmt}')
            fr = firing_rates.copy()
            fr[np.isnan(fr)] = 0 
            fr[np.isinf(fr)] = 0
            fr[trial_numbers_to_remove-1,:] = np.nan
            fr = fr.flatten()

            autocorr_window_size = track_length*10
            lags = np.arange(0, autocorr_window_size, 1) # were looking at 10 timesteps back and 10 forward
            autocorrelogram = []
            for i in range(len(lags)):
                fr_lagged = fr[i:]
                a = fr_lagged
                b = fr[:len(fr_lagged)]
                nans = np.ma.mask_or(np.isnan(a), np.isnan(b))
                #print(f" len(a[~nans]) = {len(a[~nans])}, len(b[~nans]) = {len(b[~nans])}")
                try:
                    corr = stats.pearsonr(a[~nans],b[~nans])[0]
                except:
                    corr = np.nan
                autocorrelogram.append(corr)
            autocorrelogram= np.array(autocorrelogram)
            ax.plot(lags[3:], autocorrelogram[3:], color=c, linestyle=linestyle, linewidth=2)

    ax.axhline(y=0, color="black", linewidth=2,linestyle="dashed")
    plt.ylabel('Spatial Autocorr', fontsize=25, labelpad = 10)
    plt.xlabel('Lag (cm)', fontsize=25, labelpad = 10)
    plt.xlim(0,(track_length*4)+3)
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    #ax.set_ylim([np.floor(min(autocorrelogram[5:])*10)/10,np.ceil(max(autocorrelogram[5:])*10)/10])
    #if np.floor(min(autocorrelogram[5:])*10)/10 < 0:
    #    ax.set_yticks([np.floor(min(autocorrelogram[5:])*10)/10, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
    #else:
    #    ax.set_yticks([-0.1, 0, np.ceil(max(autocorrelogram[5:])*10)/10])
    #ax.set_ylim([0, 0.5])
    #ax.set_yticks([-0.1, 0, 0.5]) 

    tick_spacing = track_length
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.set_title(f'cluster_id: {cluster_id}')
    ax.xaxis.set_major_locator(ticker.MultipleLocator(tick_spacing))
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    fig.tight_layout(pad=2.0)
    plt.subplots_adjust(hspace = .35, wspace = .35,  bottom = 0.2, left = 0.3, right = 0.87, top = 0.92)
    plt.show()


KeyError: 'spatial_information_score_Isec'