In [None]:
# To be able to make edits to repo without having to restart notebook
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr
import seaborn as sns
import tkinter as tk
from scipy import signal
from scipy import ndimage
from math import ceil
import cv2
import ot as pot
import itertools
import matplotlib.patches as patches

# Resolve the repository layout relative to the directory the notebook was
# launched from. That directory is cached on the first run because this cell
# calls os.chdir() below: without the cache, re-running the cell would read the
# *new* cwd and walk two further levels up the directory tree each time.
NOTEBOOK_DIR = globals().get('NOTEBOOK_DIR', os.getcwd())

unit_matcher_path = NOTEBOOK_DIR
prototype_path = os.path.abspath(os.path.join(unit_matcher_path, os.pardir))
project_path = os.path.abspath(os.path.join(prototype_path, os.pardir))
lab_path = os.path.abspath(os.path.join(project_path, os.pardir))
# Kept for backward compatibility: the parent of the notebook directory.
PROJECT_PATH = prototype_path

# Make the repository root importable exactly once; re-runs must not keep
# appending duplicate entries to sys.path.
if project_path not in sys.path:
    sys.path.append(project_path)

from _prototypes.cell_remapping.src.remapping import pot_sliced_wasserstein
from _prototypes.cell_remapping.src.wasserstein_distance import _get_ratemap_bucket_midpoints, single_point_wasserstein

os.chdir(project_path)
print(project_path)

In [None]:
from _prototypes.cell_remapping.main import main
from _prototypes.cell_remapping.src.settings import settings_dict
from x_io.rw.axona.batch_read import make_study
from library.study_space import Animal

settings_dict['useMatchedCut'] = False
settings_dict['ppm'] = 711

# NOTE(review): absolute, user-specific data path — consider a config constant.
path = r"C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\all"


def _load_position_file(pos_fp):
    """Read an exported ``position.txt`` into a numeric DataFrame.

    Each data line holds ``time;x;y``. Header lines containing 'POSITION' are
    dropped, and the index is reset so row labels equal row positions.

    Params:
        pos_fp (str): path to the position file.

    Returns:
        pd.DataFrame with numeric columns ``time``, ``x``, ``y``.
    """
    raw = pd.read_csv(pos_fp, sep='\t', header=None)
    # split each line at the semicolon into its three fields
    fields = raw[0].str.split(';', expand=True)
    # drop header lines (first field mentions 'POSITION')
    fields = fields[~fields[0].str.contains('POSITION')].reset_index(drop=True)
    fields.columns = ['time', 'x', 'y']
    return fields.apply(pd.to_numeric)


studies = []
failed = []
animal_pos_dict = {}
for folder in os.listdir(path):
    newp = os.path.join(path, folder)
    if 'sub' not in folder:
        # and 'Session3' in folder:
        study = make_study(newp, settings_dict=settings_dict)
        study.make_animals()
        # All animals (tetrodes) of a folder share its position.txt. Read it once
        # (lazily, so folders without animals never touch the file) and reuse the
        # same read-only frame. os.path.join replaces a hard-coded '\' separator.
        pos_data = None
        for animal in study.animals:
            if pos_data is None:
                pos_data = _load_position_file(os.path.join(newp, 'position.txt'))
            animal_pos_dict[animal.animal_id] = pos_data
        studies.append(study)
    else:
        # NOTE(review): these are folders skipped because their name contains
        # 'sub', not folders that failed to load.
        failed.append(newp)

In [None]:
# Flatten every cell's spike train across all studies. Each train is tagged with
# a 1-based study counter (``ses_cts``) so later rasters can colour/group cells
# by study. ``ses_names`` maps that counter to the id of the last animal in the
# study that contributed at least one cell.
spikes = []
spike_dict = {}
ses_cts = []
ses_names = {}
ses_ct = 1
for study in studies:
    for animal in study.animals:
        ensemble = animal.sessions['session_1'].get_cell_data()['cell_ensemble']
        cells = ensemble.cells
        spike_dict[animal.animal_id] = cells

        for cell in cells:
            ses_names[ses_ct] = animal.animal_id
            ses_cts.append(ses_ct)
            spikes.append(cell.event_times)

    ses_ct = ses_ct + 1

In [None]:
# One row per recording session: animal ids with their '_tetN' suffix removed,
# in first-seen order.
unq_names = list(dict.fromkeys(ky.split('_tet')[0] for ky in animal_pos_dict))

if unq_names:
    # Height scales with the number of sessions instead of a fixed 40 in.
    fig, axes = plt.subplots(len(unq_names), 2, figsize=(12, 3 * len(unq_names)), squeeze=False)
    for (ax1, ax2), ky in zip(axes, unq_names):
        # Tetrodes of a session share one position file. Prefer tet1, but fall
        # back to any tetrode of the session instead of raising KeyError.
        pos_key = ky + '_tet1'
        if pos_key not in animal_pos_dict:
            pos_key = next(k for k in animal_pos_dict if k.split('_tet')[0] == ky)
        dta = animal_pos_dict[pos_key]

        ax1.plot(dta['time'], dta['x'])
        ax1.set_title(ky + ' - X vs Time')
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('X')

        ax2.plot(dta['x'])
        ax2.set_title(ky + ' - X vs index')
        ax2.set_xlabel('Index')
        ax2.set_ylabel('X')

    fig.tight_layout()
    plt.show()


In [None]:
# Split each session's x trace into fixed-length spatial chunks.
# Chunk k covers the samples from the first one with x >= k*CHUNK_SIZE up to
# (excluding) the first one with x > (k+1)*CHUNK_SIZE. A chunk is stored in
# ``valid_chunks`` when its x, re-zeroed at the chunk start, stays within
# [0, CHUNK_SIZE] and covers more than MIN_CHUNK_TRAVEL units.
CHUNK_SIZE = 128        # chunk length in position units
MIN_CHUNK_TRAVEL = 50   # minimum distance covered for a chunk to be valid

valid_chunks = {}

unq_names = list(dict.fromkeys(ky.split('_tet')[0] for ky in animal_pos_dict))

if unq_names:
    fig, axes = plt.subplots(len(unq_names), 2, figsize=(20, 3 * len(unq_names)), squeeze=False)

    for (ax1, ax2), ky in zip(axes, unq_names):
        # Tetrodes of a session share one position file; prefer tet1.
        pos_key = ky + '_tet1'
        if pos_key not in animal_pos_dict:
            pos_key = next(k for k in animal_pos_dict if k.split('_tet')[0] == ky)
        dta = animal_pos_dict[pos_key]
        # Plain numpy arrays make the positional slicing below unambiguous.
        x_all = dta['x'].to_numpy()
        t_all = dta['time'].to_numpy()

        num_chunks = int(x_all.max() / CHUNK_SIZE) + 1

        combined_chunk_x = []
        combined_chunk_time = []

        for chunk_index in range(num_chunks):
            start_pos = chunk_index * CHUNK_SIZE
            end_pos = (chunk_index + 1) * CHUNK_SIZE

            # First sample at/after the chunk start; first sample past its end (exclusive).
            starts = np.flatnonzero(x_all >= start_pos)
            if starts.size == 0:
                continue
            start_idx = starts[0]
            ends = np.flatnonzero(x_all > end_pos)
            end_idx = ends[0] if ends.size else None

            chunk_time = t_all[start_idx:end_idx]
            chunk_x = x_all[start_idx:end_idx]
            if len(chunk_x) > 0:
                chunk_x = chunk_x - chunk_x[0]

                if (np.min(chunk_x) >= 0 and np.max(chunk_x) <= CHUNK_SIZE
                        and chunk_x[-1] - chunk_x[0] > MIN_CHUNK_TRAVEL):
                    valid_chunks[ky + '_chunk_' + str(chunk_index)] = {
                        'chunk_x': chunk_x,
                        'chunk_time': chunk_time,
                    }

            combined_chunk_x.extend(chunk_x)
            combined_chunk_time.extend(chunk_time)

        ax1.plot(combined_chunk_time, combined_chunk_x)
        ax1.set_title(f"{ky} - X vs Time")
        ax1.set_xlabel('Time (s)')
        ax1.set_ylabel('X')

        ax2.plot(range(len(combined_chunk_x)), combined_chunk_x)
        ax2.set_title(f"{ky} - X vs Index")
        ax2.set_xlabel('Index')
        ax2.set_ylabel('X')

    fig.tight_layout()
    plt.show()


In [None]:
import numpy as np

def _speed1D(x, t, window_sizes=np.arange(1, 100, 2)):
    """Calculates smoothed speed using multiple window sizes and averages the results"""
    N = len(x)
    v_smoothed = np.zeros(N)

    for index in range(N):
        speeds = []
        for window_size in window_sizes:
            start_index = max(0, index - window_size // 2)
            end_index = min(N - 1, index + window_size // 2)

            dx = x[end_index] - x[start_index]
            dt = t[end_index] - t[start_index]

            if dt != 0:
                speed = np.abs(dx / dt)
                speeds.append(speed)

        if speeds:
            v_smoothed[index] = np.mean(speeds)
        else:
            v_smoothed[index] = 0  # Set speed to 0 if no valid speeds are computed

    return v_smoothed



def _gkern(kernlen: int, std: int) -> np.ndarray:

    '''
        Returns a 2D Gaussian kernel array.

        Params:
            kernlen, std (int):
                Kernel length and standard deviation

        Returns:
            np.ndarray:
                gkern2d
    '''

    gkern1d = signal.gaussian(kernlen, std=std).reshape(kernlen, 1)
    gkern2d = np.outer(gkern1d, gkern1d)
    return gkern2d



In [None]:

# Per-session spike rasters over the valid position chunks, with running speed
# on a twin axis. The same pass fills ``avg_chunks`` for each
# '<tetrode>_<cluster>' id with:
#   'collection' - summed spike counts per position bin
#   'occ_map'    - summed occupancy time per position bin
#   'count'      - number of chunks seen
#   'T'          - latest chunk end time
N_POS_BINS = 64
TRACK_EXTENT = 129  # position units spanned by the bins

# One set of bin centres for BOTH spike counts and occupancy, so their ratio is
# a meaningful rate map. Spikes previously used arange(0, 129, 129/64) while
# occupancy used linspace(0, 129, 64), i.e. slightly different centres.
position_bins = np.arange(0, TRACK_EXTENT, TRACK_EXTENT / N_POS_BINS)

animal_ids = np.unique([ky.split('_chunk')[0] for ky in valid_chunks])

avg_chunks = {}

for unq_animal in animal_ids:

    fig, ax = plt.subplots(figsize=(12, 3))
    axt = ax.twinx()
    ky_c = {}
    # Tetrode keys of this session, matched exactly on the part before '_tet'.
    # A substring test would also pick up e.g. 'X_2' inside 'X_21_tet1'.
    spk_keys = [k for k in spike_dict if k.split('_tet')[0] == unq_animal]

    # chunks
    for ky in valid_chunks:
        if ky.split('_chunk')[0] != unq_animal:
            continue
        avg_chunks.setdefault(unq_animal, {})
        ky_c[ky] = 1

        t = valid_chunks[ky]['chunk_time']
        x = valid_chunks[ky]['chunk_x']

        # Chunk-level quantities are identical for every cell: compute them once.
        # Nearest position bin of each position sample.
        sample_bin_idx = np.abs(position_bins[None, :] - x[:, None]).argmin(axis=1)
        # Occupancy: each sample is credited with the time since the previous
        # sample, and the first sample gets 0. Indexing t[i - 1] at i == 0 would
        # wrap to t[-1] and add a large negative duration.
        occ_map_chunk = np.bincount(sample_bin_idx, weights=np.diff(t, prepend=t[0]),
                                    minlength=N_POS_BINS)

        speed = _speed1D(x, t)
        if unq_animal == "Mouse3_session11_230530_151933_2":
            # NOTE(review): hand-tuned suppression of speed outliers for this session.
            speed[speed > 7] = 0
        axt.plot(t, speed, color='red', lw=0.5, alpha=0.5)

        # tetrodes
        for spk_ky in spk_keys:
            # cells
            for cell in spike_dict[spk_ky]:
                tet_cell_id = str(spk_ky.split('_tet')[-1]) + '_' + str(cell.cluster.cluster_label)
                spks_to_plot = cell.event_times
                spks_to_plot = spks_to_plot[(spks_to_plot >= t[0]) & (spks_to_plot <= t[-1])]

                # Each spike -> nearest position sample -> that sample's bin.
                spike_sample_idx = np.array([np.abs(t - spk).argmin() for spk in spks_to_plot], dtype=int)
                fr_map_raw = np.bincount(sample_bin_idx[spike_sample_idx],
                                         minlength=N_POS_BINS).astype(float)
                # Fresh copy: the first map stored per cell is later updated in place.
                occ_map_raw = occ_map_chunk.copy()

                cell_acc = avg_chunks[unq_animal].get(tet_cell_id)
                if cell_acc is None:
                    avg_chunks[unq_animal][tet_cell_id] = {
                        'collection': fr_map_raw,
                        'count': 1,
                        'occ_map': occ_map_raw,
                        'T': t[-1],
                    }
                else:
                    cell_acc['collection'] += fr_map_raw
                    cell_acc['occ_map'] += occ_map_raw
                    cell_acc['count'] += 1
                    if t[-1] > cell_acc['T']:
                        cell_acc['T'] = t[-1]

                ax.plot(spks_to_plot, np.ones(len(spks_to_plot)) * ky_c[ky], 'ko', markersize=1)
                ky_c[ky] += 1

    ax.set_title(unq_animal)
    print(unq_animal)
    ax.set_ylabel('Cell #')
    # The twin axis shows speed, not position.
    axt.set_ylabel('Speed')
    ax.set_xlabel('Time (s)')
    fig.tight_layout()
    plt.show()



In [None]:
# Deliberate checkpoint: halts "Run All" here so the exploratory cells below only
# run on demand. It raises explicitly instead of calling the undefined name
# ``stop()``, which would stop halting as soon as anything named ``stop`` existed.
raise RuntimeError("Checkpoint: execution intentionally stopped here; run the cells below manually.")

In [None]:
# Smoothed, peak-normalised rate maps (spike counts / occupancy) of every cell
# in ``avg_chunks``, shown as a position-by-cell heat map. A second map shows
# the raw summed spike counts in the same layout.
# NOTE(review): the y-axis is drawn in cm, assuming the position bins span a
# TRACK_LENGTH_CM-long track. The original tick labels (0-70 in steps of 10,
# "0 to 79.9") intended this; confirm the calibration.
TRACK_LENGTH_CM = 80

kernlen = int(1 * 8)
std = int(0.2 * kernlen)
smoothing_kernel = _gkern(kernlen, std)

avg_map = []
order = []  # mean normalised count per cell; kept for optional sorting of the map
for ky in avg_chunks:
    # if 'mouse3' in ky.lower():
    for unit, unit_data in avg_chunks[ky].items():
        # cv2.filter2D may return an (n, 1) column for 1-D input; flatten it.
        avg = np.ravel(cv2.filter2D(unit_data['collection'], -1, smoothing_kernel))
        avg_occ_map = np.ravel(cv2.filter2D(unit_data['occ_map'] / unit_data['T'], -1, smoothing_kernel))
        # Peak-normalise, skipping all-zero maps (0 / 0 would give NaN).
        if np.max(avg) > 0:
            avg = avg / np.max(avg)
        if np.max(avg_occ_map) > 0:
            avg_occ_map = avg_occ_map / np.max(avg_occ_map)
        order.append(np.mean(avg))
        # Rate = counts / occupancy, set to 0 where occupancy is negligible.
        # Dividing only where allowed avoids divide-by-zero warnings.
        rate_map = np.divide(avg, avg_occ_map, out=np.zeros_like(avg),
                             where=avg_occ_map >= 0.0001)
        if np.max(rate_map) > 0:
            rate_map = rate_map / np.max(rate_map)
        avg_map.append(rate_map)

if avg_map:
    # rows = position bins, columns = cells
    avg_map = np.vstack(avg_map).T
    print(avg_map.shape)

    fig, ax = plt.subplots(figsize=(4, 4))
    # origin='lower' puts bin 0 at the bottom, and extent maps the rows onto cm
    # so the tick labels match the data. set_yticklabels alone only relabels
    # whatever ticks matplotlib happened to pick.
    img = ax.imshow(avg_map, aspect='auto', cmap='jet', origin='lower',
                    extent=(-0.5, avg_map.shape[1] - 0.5, 0, TRACK_LENGTH_CM))
    cbar = fig.colorbar(img, ax=ax)
    cbar.set_label('Firing Rate Norm')
    ax.set_xlabel('Cell #')
    ax.set_ylabel('Linear Track Length (cm)')
    # ax.set_title('Mouse 3')
    fig.tight_layout()
    plt.show()

    # Raw summed spike counts, transposed like the map above. The original
    # plotted cells on the y-axis despite labelling it as track length.
    count_map = np.vstack([np.ravel(avg_chunks[ky][unit]['collection'])
                           for ky in avg_chunks for unit in avg_chunks[ky]]).T
    fig, ax = plt.subplots(figsize=(3, 3))
    img = ax.imshow(count_map, aspect='auto', cmap='jet', origin='lower',
                    extent=(-0.5, count_map.shape[1] - 0.5, 0, TRACK_LENGTH_CM))
    cbar = fig.colorbar(img, ax=ax)
    # These are un-normalised counts, not a normalised firing rate.
    cbar.set_label('Spike count')
    ax.set_xlabel('Cell #')
    ax.set_ylabel('Linear Track Length (cm)')
    fig.tight_layout()
    plt.show()



In [None]:

# Firing rate vs running speed, one figure per session/tetrode/cell.
# For each valid chunk, spikes are counted onto their nearest position sample.
# The chunks are concatenated per cell, and the counts are Gaussian-smoothed
# into a firing rate and plotted against speed, with their Pearson correlation
# in the title. ``avg_chunks`` is rebuilt along the way: per-bin spike counts,
# occupancy, chunk count and latest chunk end time.
N_POS_BINS = 64
TRACK_EXTENT = 129       # position units spanned by the bins
FR_SMOOTH_SIGMA = 100    # Gaussian sigma of the firing-rate trace, in samples

# One set of bin centres for both spike counts and occupancy so the two align.
position_bins = np.arange(0, TRACK_EXTENT, TRACK_EXTENT / N_POS_BINS)

animal_ids = np.unique([ky.split('_chunk')[0] for ky in valid_chunks])

avg_chunks = {}
unq_cell_plot_dict = {}

for unq_animal in animal_ids:

    ky_c = {}
    # Exact session match; a substring test would also match e.g. 'X_2' in 'X_21_tet1'.
    spk_keys = [k for k in spike_dict if k.split('_tet')[0] == unq_animal]

    # chunks
    for ky in valid_chunks:
        if ky.split('_chunk')[0] != unq_animal:
            continue
        avg_chunks.setdefault(unq_animal, {})
        ky_c[ky] = 1

        t = valid_chunks[ky]['chunk_time']
        x = valid_chunks[ky]['chunk_x']

        # Chunk-level quantities (identical for every cell), computed once.
        sample_bin_idx = np.abs(position_bins[None, :] - x[:, None]).argmin(axis=1)
        # Time since the previous sample; the first sample gets 0. A t[i - 1]
        # lookup at i == 0 would wrap to t[-1] and add a negative duration.
        occ_map_chunk = np.bincount(sample_bin_idx, weights=np.diff(t, prepend=t[0]),
                                    minlength=N_POS_BINS)

        speed = _speed1D(x, t)
        if unq_animal == "Mouse3_session11_230530_151933_2":
            # NOTE(review): hand-tuned suppression of speed outliers for this session.
            speed[speed > 7] = 0

        # tetrodes
        for spk_ky in spk_keys:
            tetrode = str(spk_ky.split('_tet')[-1])
            # cells
            for cell in spike_dict[spk_ky]:
                tet_cell_id = tetrode + '_' + str(cell.cluster.cluster_label)
                spks_to_plot = cell.event_times
                spks_to_plot = spks_to_plot[(spks_to_plot >= t[0]) & (spks_to_plot <= t[-1])]

                # Nearest position sample of every spike.
                spike_sample_idx = np.array([np.abs(t - spk).argmin() for spk in spks_to_plot], dtype=int)
                # Spike counts per position bin and per position sample.
                fr_map_raw = np.bincount(sample_bin_idx[spike_sample_idx],
                                         minlength=N_POS_BINS).astype(float)
                binned_spikes = np.bincount(spike_sample_idx, minlength=len(t)).astype(float)
                # Fresh copy: the first map stored per cell is later updated in place.
                occ_map_raw = occ_map_chunk.copy()

                cell_acc = avg_chunks[unq_animal].get(tet_cell_id)
                if cell_acc is None:
                    avg_chunks[unq_animal][tet_cell_id] = {
                        'collection': fr_map_raw,
                        'count': 1,
                        'occ_map': occ_map_raw,
                        'T': t[-1],
                    }
                else:
                    cell_acc['collection'] += fr_map_raw
                    cell_acc['occ_map'] += occ_map_raw
                    cell_acc['count'] += 1
                    if t[-1] > cell_acc['T']:
                        cell_acc['T'] = t[-1]

                unq_cell_plot_dict_key = unq_animal + '_' + tet_cell_id
                cell_plot = unq_cell_plot_dict.get(unq_cell_plot_dict_key)
                if cell_plot is None:
                    fig, ax = plt.subplots(figsize=(12, 3))
                    axt = ax.twinx()
                    ttle = unq_animal + ' tetrode ' + tetrode + ' cell ' + str(cell.cluster.cluster_label)
                    cell_plot = {'fig': fig, 'ax': ax, 'axt': axt, 'ttle': ttle,
                                 'binned_spikes': binned_spikes, 't': t, 'speed': speed}
                    unq_cell_plot_dict[unq_cell_plot_dict_key] = cell_plot
                else:
                    cell_plot['binned_spikes'] = np.hstack((cell_plot['binned_spikes'], binned_spikes))
                    cell_plot['t'] = np.hstack((cell_plot['t'], t))
                    cell_plot['speed'] = np.hstack((cell_plot['speed'], speed))

                cell_plot['l1'] = cell_plot['axt'].plot(t, speed, color='red', lw=2, alpha=0.7, label='Speed')

                ky_c[ky] += 1

for ky, cell_plot in unq_cell_plot_dict.items():
    ax = cell_plot['ax']
    axt = cell_plot['axt']
    fig = cell_plot['fig']
    ttle = cell_plot['ttle']
    binned_spikes = cell_plot['binned_spikes']
    t = cell_plot['t']
    speed = cell_plot['speed']
    l1 = cell_plot['l1']

    # Spikes per sample -> Hz. The median sample spacing is robust to timing
    # jitter and to the jumps between concatenated chunks; a single interval
    # such as t[1] - t[0] is not.
    sample_dt = np.median(np.diff(t))
    smoothed = ndimage.gaussian_filter(binned_spikes, sigma=FR_SMOOTH_SIGMA) / sample_dt
    correlation = pearsonr(speed, smoothed)
    l2 = ax.plot(t, smoothed, color='black', lw=2, alpha=1, label='Firing Rate')

    lns = l2 + l1
    labs = [l.get_label() for l in lns]
    ax.legend(lns, labs, loc='upper left')

    ax.set_title(ttle + ' - Correlation: ' + str(correlation[0]))
    ax.set_xlabel('Time (s)')
    axt.set_ylabel('Speed (cm/s)')
    ax.set_ylabel('Firing Rate (Hz)')
    fig.tight_layout()

plt.show()



In [None]:
def smooth_signal(signal, window_sizes):
    """Multi-scale moving average of a 1-D signal, ignoring NaNs.

    For each window size ``w``, compute the NaN-aware mean over
    ``signal[max(0, i - w//2) : min(n, i + w//2 + 1)]`` for every sample ``i``.
    The result is the average of these per-window means. A window containing
    only NaNs contributes NaN.

    Params:
        signal: 1-D array-like of finite numbers (NaN allowed, inf not supported).
        window_sizes: iterable of non-negative integer window widths.

    Returns:
        np.ndarray (float64), same length as ``signal``; zeros if no windows are given.
    """
    values = np.asarray(signal, dtype=np.float64)
    n = len(values)
    window_sizes = list(window_sizes)
    smoothed_signal = np.zeros(n, dtype=np.float64)
    if not window_sizes:
        return smoothed_signal

    is_valid = ~np.isnan(values)
    # Prefix sums of the non-NaN values and of their count. The sum and count
    # over any slice [lo, hi) are then differences of two prefix entries, which
    # makes each window O(n) instead of O(n * w).
    csum = np.concatenate(([0.0], np.cumsum(np.where(is_valid, values, 0.0))))
    ccount = np.concatenate(([0], np.cumsum(is_valid)))
    idx = np.arange(n)

    for window_size in window_sizes:
        half_window = window_size // 2
        lo = np.maximum(idx - half_window, 0)
        hi = np.minimum(idx + half_window + 1, n)
        with np.errstate(invalid='ignore', divide='ignore'):
            smoothed_signal += (csum[hi] - csum[lo]) / (ccount[hi] - ccount[lo])

    # Average over window sizes once, after all windows are accumulated. The
    # original divided inside the loop, shrinking earlier windows repeatedly.
    return smoothed_signal / len(window_sizes)


In [None]:
# NOTE(review): ``speed`` is whatever an earlier cell last assigned to it
# (hidden state). Bind it explicitly if a particular cell/session is intended.
fig, ax = plt.subplots(figsize=(8, 3))
ax.plot(smooth_signal(speed, np.arange(1, 1000, 10)))
ax.set_title('Multi-scale smoothed speed')
ax.set_xlabel('Sample index')
ax.set_ylabel('Speed')
plt.show()

In [None]:

# Raster of every cell across all sessions, coloured by session counter.
fig, ax = plt.subplots(figsize=(8, 3))

# make cmap based on value in ses_cts
cmp = plt.get_cmap('jet')
norm = plt.Normalize(vmin=0, vmax=max(ses_cts))

for c, spike in enumerate(spikes):
    ax.scatter(spike, np.ones(len(spike)) * c, s=1, color=cmp(norm(ses_cts[c])))

ax.set_ylabel('Cell #')
ax.set_xlim(0, 180)  # NOTE(review): fixed 180 s display window
ax.set_xlabel('Time (s)')
ax.set_title('Raster of all cells in all sessions')
# tight_layout must run before show(); afterwards it no longer affects the rendered figure.
fig.tight_layout()
plt.show()

In [None]:


# One raster figure per session. Cells are grouped by consecutive runs of the
# same session counter in ``ses_cts``, which is parallel to ``spikes``.
cmp = plt.get_cmap('jet')
norm = plt.Normalize(vmin=0, vmax=max(ses_cts))

for ses, cell_idxs in itertools.groupby(range(len(spikes)), key=lambda i: ses_cts[i]):
    fig, ax = plt.subplots(figsize=(8, 1))
    for c in cell_idxs:
        ax.scatter(spikes[c], np.ones(len(spikes[c])) * c, s=1, color=cmp(norm(ses)))
    # Each figure is titled with its own session. The original reused a stale
    # ``prev`` for the final figure, which mislabelled a single-cell last session.
    ax.set_title(ses_names[ses])
    ax.set_ylabel('Cell #')
    ax.set_xlabel('Time (s)')
    # Lay out before show(); calling tight_layout afterwards has no effect.
    fig.tight_layout()
    plt.show()

In [None]:
# Position traces of three test laps. Each is parsed from an exported
# position.txt ('time;x;y' per line; lines containing 'POSITION' are dropped).
# NOTE(review): absolute, user-specific data path.
LAP_DATA_DIR = r"C:\Users\aaoun\OneDrive - cumc.columbia.edu\Desktop\HussainiLab\neuroscikit_test_data\VR_test_data\10-2-23_lap"

lap_frames = []
for lap_name in ('lap1', 'lap2', 'lap3'):
    raw = pd.read_csv(os.path.join(LAP_DATA_DIR, lap_name, 'position.txt'), sep='\t', header=None)
    fields = raw[0].str.split(';', expand=True)
    # Reset the index so row labels equal row positions after dropping headers.
    fields = fields[~fields[0].str.contains('POSITION')].reset_index(drop=True)
    fields.columns = ['time', 'x', 'y']
    lap_frames.append(fields.apply(pd.to_numeric))
lap1_data, lap2_data, lap3_data = lap_frames

# Two 1x3 figures: X against time, then X against sample index.
for use_time, suptitle in ((True, 'X vs Time'), (False, 'X vs Index')):
    fig, axes = plt.subplots(1, 3, figsize=(8, 3))
    for lap_num, (ax, lap_data) in enumerate(zip(axes, lap_frames), start=1):
        if use_time:
            ax.plot(lap_data['time'], lap_data['x'])
            ax.set_xlabel('Time (s)')
        else:
            ax.plot(lap_data['x'])
            ax.set_xlabel('Index')
        ax.set_title(f'Lap {lap_num}')
        ax.set_ylabel('X')
    fig.suptitle(suptitle)
    fig.tight_layout()
    plt.show()



In [None]:
# For each lap recording, plot only the samples after the last backwards jump
# in time (by >= 1 s) or in x (by >= 20 units), whichever comes later.
for lap_num, lap_data in enumerate([lap1_data, lap2_data, lap3_data], start=1):

    # Samples after which time decreases by at least 1 s.
    time_resets = np.flatnonzero(np.diff(lap_data['time']) <= -1)
    # With no such jump, start from the first sample (the original raised IndexError).
    idx_t = time_resets[-1] + 1 if time_resets.size else 0

    # Samples after which x decreases by at least 20 units.
    x_resets = np.flatnonzero(np.diff(lap_data['x']) <= -20)
    idx_use = max(x_resets[-1] + 1, idx_t) if x_resets.size else idx_t

    # Positional slicing, independent of the frame's index labels.
    to_plot_x = lap_data['x'].iloc[idx_use:].to_numpy()
    to_plot_t = lap_data['time'].iloc[idx_use:].to_numpy()

    fig = plt.figure(figsize=(8, 3))
    plt.plot(to_plot_t, to_plot_x)
    plt.title(f'Lap {lap_num}: X vs Time')
    plt.xlabel('Time (s)')
    plt.ylabel('X position')
    plt.show()

In [None]:
# Number of lap-1 samples with x >= 124. len() of the boolean mask would return
# the total sample count regardless of the comparison; summing counts the Trues.
int((lap1_data['x'] >= 124).sum())