In [1]:
#To determine phase locking of all neurons in the PFC of an animal & report as population stats
import matplotlib.pyplot as plt
import numpy as np
from scipy.signal import butter, hilbert, sosfilt, sosfiltfilt
from scipy.stats import circmean, circvar

from set_input import load_input_params_base, path_to_input_file, set_input_params
#Initialise the shared input parameters, then unpack them into module-level names
    #The order of the names below must match the order of the list returned by load_input_params_base()
set_input_params()
[cell_types, Data_folder, plot_folder_vlad, analysis_folder_vlad, plot_folder_andrea, analysis_folder_andrea, 
     parallel, days_8arm, days_prob, days_andrea, rats_vlad, rats_andrea, speed_threshold, sample_rate_whl, 
     sample_rate_data, tab20_colors]=load_input_params_base()

print(days_andrea[-1]) #last entry of days_andrea; per the original note these are the days for just JC315

print(rats_andrea) #list of rat IDs; JC315 is the one analysed in this notebook

# print(load_input_params_base())

# print(speed_threshold)

#base input parameters used later in this notebook:
    #sample_rate_data (read by filter_for_speed and calculate_spike_phase)
    #speed_threshold (default argument of filter_for_speed)

['20240330', '20240401', '20240402', '20240403', '20240404', '20240406', '20240407', '20240408']
['JC258', 'JC274', 'JC283', 'JC315']


In [2]:
def get_pfc_clusters(des_filepath):
    '''
    Purpose: To list the cluster IDs whose .des entry marks them as PFC cells
        A cluster is selected when the 2nd character of its .des line is 'p'
        The 1st character is not checked, so this includes putative pyramidal (first value is 'p')
            and putative interneuron/basket (first value is 'b') cells
    Parameter: des_filepath (str): full file path of the .des file (one line per cluster)
    Return value (list of ints): cluster IDs as used in the .clu file
        The value given in clu is described by des[value - 2] since cluster 1 is noise
    '''
    with open(des_filepath) as file: #the file is closed even if reading raises
        des = [line.strip('\n') for line in file]
    #Lines with fewer than 2 characters (e.g. a trailing blank line) cannot carry a region code, so they
        #are skipped instead of raising IndexError; they still count towards the line index
    return [i + 2 for i, x in enumerate(des) if len(x) > 1 and x[1] == 'p']

def get_hpc_clusters(des_filepath):
    '''
    Purpose: To list the cluster IDs whose .des entry marks them as HPC cells
        A cluster is selected when the 2nd character of its .des line is '1'
            NOTE(review): presumably '1' is the hippocampal (CA1) region code - confirm against the .des convention
        The 1st character (cell type) is not checked, so every cell type in that region is included
    Parameter: des_filepath (str): full file path of the .des file (one line per cluster)
    Return value (list of ints): cluster IDs as used in the .clu file
        The value given in clu is described by des[value - 2] since cluster 1 is noise
    '''
    with open(des_filepath) as file: #the file is closed even if reading raises
        des = [line.strip('\n') for line in file]
    #Lines with fewer than 2 characters (e.g. a trailing blank line) cannot carry a region code, so they
        #are skipped instead of raising IndexError; they still count towards the line index
    hpc_clusters = [i + 2 for i, x in enumerate(des) if len(x) > 1 and x[1] == '1']
    return hpc_clusters
# print(pfc_clusters(des_filepath))

In [3]:
'''Import spikes'''
#Import .res and .clu (code from Cell_map.ipynb)
def to_int_list(filename):
    '''
    Purpose: To import .res and .clu files
        .clu file is the cluster ID (putative neuron ID) for the corresponding .res file spikes
        .res file is the "frame" at sampling frequency 20kHz in which spikes occur
    Parameter: Full file path + ending of the .res or .clu file
    Return value (list of ints): Either the .res or .clu values, one per non-blank line

    -Whitespace-only lines (e.g. a trailing empty line) are skipped instead of raising ValueError
    -Any other non-integer line still raises ValueError
    '''
    with open(filename) as file: #the file is closed even if a line fails to parse
        #int() ignores surrounding whitespace, so the newline does not need to be stripped first
        return [int(line) for line in file if line.strip()]

In [4]:
def import_speed(speed_filepath):
    '''
    Purpose: To import a .speed file (one float value per line)
        Different than to_int_list() because these are float values
    Parameter: speed_filepath (str): full file path of the .speed file
    Return value (list of floats): the speed values, one per non-blank line

    -Whitespace-only lines (e.g. a trailing empty line) are skipped instead of raising ValueError
    '''
    with open(speed_filepath) as file: #the file is closed even if a line fails to parse
        #float() ignores surrounding whitespace, so the newline does not need to be stripped first
        return [float(line) for line in file if line.strip()]

def filter_for_speed(speed, cell_number, clu, res, speed_threshold = speed_threshold, frames_per_speed_sample = 512):
    '''
    Purpose: To determine which spikes occured when the animal was in motion = abs(speed) >= speed_threshold
        Will also return 'NaN' for cells which have an average firing rate less than 0.25 spikes/second, as with Nardin et al., 2023
            (see "Dataset details" in the methods section)
        The firing rate uses ALL spikes of the cell (not only those during motion), divided by the
            session duration estimated from the last entry of res
    Parameters: 
        speed (list): imported .speed values (see import_speed()), indexed by spike frame // frames_per_speed_sample
        cell_number (int): identity of the putative neuron, as with in the .clu file
        clu (list): imported clu file, with first entry removed so that it lines up with res
        res (list): imported res file
        speed_threshold (float): minimum abs(speed) for a spike to be kept
        frames_per_speed_sample (int): number of spike-recording frames covered by one speed value
            NOTE(review): 512 presumably equals sample_rate_data / sample_rate_whl - confirm
    Return value (list or str): recording "frames" in which the animal was in motion and the specified cell spiked,
        or the string 'NaN' for a sub-threshold cell (kept as a string because callers compare against 'NaN')

    -An empty res, or a last frame of 0, is treated as sub-threshold instead of raising IndexError/ZeroDivisionError
    -A spike frame beyond the end of speed still raises IndexError (no silent clipping)
    '''
    #clu and res have corresponding indices, so the cell's spikes are the res entries where clu matches
    all_spikes = [res[i] for i, x in enumerate(clu) if x == cell_number] #in frame number of the spike recording

    if len(res) == 0:
        return 'NaN'
    session_seconds = 1/sample_rate_data * res[-1]
    #The firing-rate check runs first so the speed filter is not computed for cells that get discarded
    if session_seconds <= 0 or len(all_spikes) / session_seconds < 0.25:
        return 'NaN'

    #if the speed during a spike is at or above the threshold, the spike is kept
    return [i for i in all_spikes if abs(speed[i // frames_per_speed_sample]) >= speed_threshold]



In [5]:
'''Import eegh, determine spike phases'''
def calculate_eeg_phase(eegh_filepath, tetrode_number, eeg_fs = 5000, n_channels = 32, band = (5, 12), order = 3, zero_phase = False):
    '''
    Purpose: Bandpass (Butterworth) filter one channel of the eegh data for the theta range,
        extract the phase using the Hilbert transform, and return the unwrapped instantaneous phase
        for every sample of that channel (spike phases are looked up later by calculate_spike_phase())
    Parameter(s): 
        eegh_filepath (str): full file path of the .eegh file
        tetrode_number (int): 0-based index of the channel block to read, must be in [0, n_channels)
            NOTE(review): notes elsewhere in the notebook mention tetrodes "17-32", which reads as 1-based
                numbering; here 17 selects the 18th block and 32 is out of range - confirm the convention
        eeg_fs (int): sampling rate of the eegh data [Hz]
        n_channels (int): number of channels in the eegh recording (default 32, as previously hard-coded)
        band (pair of floats): lower and upper edge of the bandpass [Hz] (default theta, 5-12 Hz)
        order (int): order passed to butter() (default 3)
        zero_phase (bool): False (default) keeps the original single-pass sosfilt();
            True uses forward-backward sosfiltfilt(), which has no phase delay (the filter is applied twice)
            NOTE(review): a single-pass IIR filter delays the signal and therefore shifts the extracted phase;
                consider zero_phase = True for phase-locking analyses (this changes the reported mean directions)
    Return value (numpy array): inst_phase, the unwrapped phase [rad] of each sample of the selected channel
    Raises: ValueError if tetrode_number is outside [0, n_channels) or the selected channel is empty

    -NOTE(review): samples are read as unsigned 16-bit; if the recording stores signed int16 values the dtype
        should be np.int16 - confirm against the acquisition format
    -NOTE(review): the channel is taken as one contiguous 1/n_channels slice of the file; if the file is
        sample-interleaved the channel would instead be eeg[tetrode_number::n_channels] - confirm
    '''
    if not 0 <= tetrode_number < n_channels:
        raise ValueError('tetrode_number must be in [0, ' + str(n_channels) + '), got ' + str(tetrode_number))

    eeg = np.fromfile(eegh_filepath, dtype = np.uint16)
    start = int(tetrode_number*len(eeg)/n_channels)
    stop = int((tetrode_number+1)*len(eeg)/n_channels)
    one_channel = eeg[start:stop]
    if len(one_channel) == 0:
        raise ValueError('no samples found for tetrode_number ' + str(tetrode_number) + ' in ' + str(eegh_filepath))

    #Cast to float before filtering: sosfilt() does this internally, but the edge padding in
        #sosfiltfilt() is computed in the input dtype and would wrap around for unsigned integers
    signal = one_channel.astype(np.float64)

    #calculate inst_phase for the entire eegh recording
    sos = butter(order, list(band), btype = 'bandpass', output = 'sos', analog = False, fs = eeg_fs)
    if zero_phase:
        filtered = sosfiltfilt(sos, signal)
    else:
        filtered = sosfilt(sos, signal)

    analytic_signal = hilbert(filtered) #analytic signal of the whole band-pass filtered channel
    return np.unwrap(np.angle(analytic_signal))


def calculate_spike_phase(inst_phase, spike_frame, eeg_fs = 5000, sample_rate_data = sample_rate_data):
    '''
    Purpose: To look up the instantaneous eeg phase at the moment of each spike
    Parameters:
        inst_phase (sequence): phase of every eeg sample, as returned by calculate_eeg_phase()
        spike_frame (list of ints): spike frames at sample_rate_data, the output of filter_for_speed()
        eeg_fs (int): sampling rate of the eeg that inst_phase was computed from [Hz]
        sample_rate_data (int): sampling rate of the spike recording [Hz]
    Return value (list): spike_phase, a list of equal length to spike_frame

    -Each spike frame is divided by sample_rate_data / eeg_fs and rounded to the nearest eeg sample
    -No bounds check is done: a spike that maps past the end of inst_phase raises IndexError
    '''
    frames_per_eeg_sample = sample_rate_data / eeg_fs
    spike_phase = []
    for frame in spike_frame:
        eeg_index = round(frame / frames_per_eeg_sample) #nearest eeg sample to this spike
        spike_phase.append(inst_phase[eeg_index])
    return spike_phase

In [12]:
def stats(spike_phase):
    '''
    Purpose: Rayleigh test for non-uniformity of the spike phases
        From https://www.mathworks.com/matlabcentral/fileexchange/10676-circular-statistics-toolbox-directional-statistics
    Parameter: spike_phase (list of floats): phases in radians (may be unwrapped, only sin/cos are used)
    Return value (tuple): mean_direction [rad], r (mean resultant vector length, 0-1),
        R (Rayleigh's R = n*r), z (Rayleigh's z = R**2 / n), pval (approximation in Zar, p. 617)
        All five values are nan when spike_phase is empty

    -r is computed directly from the mean cosine and sine rather than as 1 - circvar(), because
        scipy's circvar() only equals 1 - r from SciPy 1.8 onwards (older releases return a different quantity)
    '''
    angles = np.asarray(spike_phase, dtype = float)
    n = angles.size
    if n == 0:
        #No spikes: nothing to test, and the formulas below would divide by zero
        return np.nan, np.nan, np.nan, np.nan, np.nan

    mean_direction = circmean(angles) # in rads
    r = np.hypot(np.mean(np.cos(angles)), np.mean(np.sin(angles))) #The resultant vector length

    R = n*r #Rayleigh's R
    z = R**2 / n #Rayleigh's z

    # compute p value using approxation in Zar, p. 617
    # pval = exp(sqrt(1+4*n+4*(n^2-R^2))-(1+2*n));
    pval = np.exp(np.sqrt(1 + 4 * n + 4 * (n**2 - R**2)) - (1 + 2 * n))
    return mean_direction, r, R, z, pval

#Could also convert mean_direction to degrees

In [10]:
#Session-specific inputs for one recording day (JC315, 20240408); everything below is read as a global
    #by get_selective_cells()
tetrode_number = 17 #Can check in google sheets which tetrodes work based on training day, usually 17-32 is fine
    #NOTE(review): calculate_eeg_phase() uses this value directly as a 0-based slice index over 32 channels,
        #so a value of 32 would select nothing - confirm whether the tetrode numbers here are 0- or 1-based
    #For JC315, tetrode 17 is in the HPC pyramidale every day
    #JC283 varies a lot more and doesn't always have pyramidale
        #20230930 [25,27, 30, 21] are pyramidale; 
        #20231001 and 02 [26, 30] are oriens/sup. pyramidale; 
        #20231003 [17, 19, 25] in pyramidale

#Paths are relative to the notebook's working directory
clu_res_filepath = 'JC315-20240408/JC315-20240408_training1'
eegh_filepath = 'JC315-eegh/20240408/JC315-20240408_02.eegh'
    #NOTE(review): presumably session _02 is the same recording as training1 - confirm, since spike frames
        #are used to index into this eeg
speed_filepath = 'eight_arm_fig_data_andrea/analysis/JC315-20240408/JC315_20240408_training1.speed'
des_filepath = 'JC315-20240408/JC315-20240408.des'

clu = to_int_list(clu_res_filepath + '.clu')[1:]#Drops the first entry of the .clu file so that clu lines up index-for-index with res
    #NOTE(review): the original note calls this "removing the first (noise) cluster"; the slice removes the
        #first LINE of the file, presumably a header value - confirm
res = to_int_list(clu_res_filepath + '.res')
pfc_clusters = get_pfc_clusters(des_filepath)
hpc_clusters = get_hpc_clusters(des_filepath)

speed = import_speed(speed_filepath)
inst_phase = calculate_eeg_phase(eegh_filepath, tetrode_number) #phase of every eeg sample on the chosen channel




def get_selective_cells(cell_list, alpha = 0.05):
    '''
    Purpose: To run the Rayleigh test (see stats()) on every cell in cell_list and report which cells
        are significantly phase locked, plus the proportion of such cells
    Parameters:
        cell_list (list of ints): cluster IDs to test, e.g. the output of get_pfc_clusters()
        alpha (float): a cell is reported as selective when its p value is below alpha (default 0.05)
            NOTE(review): p values are not corrected for the number of cells tested
    Return value (tuple of 3 lists): cells, mean_directions [rad], pvals
        Only cells for which filter_for_speed() did not return 'NaN' (sub-threshold firing rate) are included

    -Reads the module-level speed, clu, res and inst_phase, so those must be defined before calling
    -A cell that passes the rate threshold but has no spikes during motion is kept with nan mean direction
        and nan p value (never selective, still counted in the proportion) without calling stats() on an empty list
    -When no cell passes the rate threshold the proportion is printed as nan instead of raising ZeroDivisionError
    '''
    cells = []
    mean_directions = []
    pvals = []
    for cell_number in cell_list:
        spike_frame = filter_for_speed(speed, cell_number, clu, res)
        if spike_frame == 'NaN': #sub-threshold firing rate: the cell is left out entirely
            continue
        if len(spike_frame) == 0:
            mean_direction, pval = np.nan, np.nan
        else:
            spike_phase = calculate_spike_phase(inst_phase, spike_frame)
            mean_direction, _, _, _, pval = stats(spike_phase)
        cells.append(cell_number)
        mean_directions.append(mean_direction)
        pvals.append(pval)

    #nan < alpha is False, so cells without a valid p value are never reported as selective
    selective_indices = [i for i, x in enumerate(pvals) if x < alpha]
    for i in selective_indices:
        print('Cell ' + str(cells[i]) + ' is selective for ' + str(mean_directions[i] * (180 / np.pi)) + ' degrees')

    proportion = len(selective_indices) / len(cells) if cells else float('nan')
    print('Proportion of cells that are selective:', proportion)
    return cells, mean_directions, pvals

#Run the phase-locking test on both populations; each call prints the selective cells and their proportion
    #and returns only the cells that passed the firing-rate threshold
print('PFC cells')
pfc_cells, pfc_mean_directions, pfc_pvals = get_selective_cells(pfc_clusters)

print('\nHPC cells')
hpc_cells, hpc_mean_directions, hpc_pvals = get_selective_cells(hpc_clusters)


# print(cells)
# print(pvals)

PFC cells
Cell 43 is selective for 346.4287503059515 degrees
Proportion of cells that are selective: 0.01639344262295082

HPC cells
Cell 102 is selective for 42.08760285148507 degrees
Cell 109 is selective for 10.561791252902951 degrees
Cell 115 is selective for 273.22026242369986 degrees
Cell 128 is selective for 264.6051715003737 degrees
Cell 132 is selective for 287.173810453618 degrees
Cell 136 is selective for 242.8103627043702 degrees
Cell 144 is selective for 313.7923727318179 degrees
Cell 156 is selective for 344.00316903590743 degrees
Cell 162 is selective for 140.86768088746456 degrees
Proportion of cells that are selective: 0.1956521739130435


In [11]:
#Number of PFC cells that passed the firing-rate threshold, i.e. the denominator of the PFC proportion above
print(len(pfc_cells))

61
