In [1]:
import numpy as np
import scipy.io as sio
import pandas as pd
from tqdm.notebook import tqdm
import os
import pickle

import scipy.signal
import matplotlib.pyplot as plt

In [2]:
mice = ['ym212','ym213','ym214','ym215','ym218','ym219','ym220','ym222','ym223','ym224','ym226','ym227']  # 12 in total
sessions = ['5FC','7FC']  # recent and remote

# Data location. Set the MEMORYAGE_DATA_DIR environment variable to run on another machine;
# otherwise the original author's path is used.
data_dir = os.environ.get(
    'MEMORYAGE_DATA_DIR',
    r'/Users/david/Projects/Yuichi/MemoryAge_WinLen_2.56_Step_0.1_FREEZING_ONLY/Data',
)

sample_rate = 1600   # Hz
window_sec = 2.56    # length of one LFP segment, in seconds
step_sec = 0.1       # hop between consecutive segment starts, in seconds
# round() before int(): int() truncates, so a product that lands a hair below an integer
# in floating point (e.g. 1000 * 0.29 == 289.99999999999994) would silently lose one sample.
window_len = int(round(sample_rate * window_sec))   # 4096 samples
step = int(round(sample_rate * step_sec))            # 160 samples
Z_SCORE = True   # z-score each session's LFP with its freezing-period mean/std (see segment cell)

PRINT = True     # print per-session segment counts

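# 1. Define functions for feature extraction in the attention analysis
The function `calculate_features` processes each small segment (3 × 256 samples) of an LFP window (3 × 4096 samples) from one mouse in one session. It returns 24 features:
- the mean in-band PSD of each of the 3 channels in each of 4 frequency bands (12 features);
- the envelope correlation of each of the 3 channel pairs in the same 4 bands (12 features).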

In [3]:
# SET PARAMS
# Sampling rate of the LFP in Hz (the same value is also set in the earlier config cell).
sample_rate = 1600

# Frequency bands as (low, high) cutoffs in Hz.
# Insertion order matters: it fixes the order of the features produced by calculate_features.
freq_bands = dict(
    theta=(6, 12),
    beta=(20, 30),
    sGamma=(30, 50),
    fGamma=(60, 90),
)

# Samples per transformer chunk; one feature vector is computed per chunk.
chunk_size = 256

In [4]:
# Function to filter the signal
def bandpass_filter(signal, band, sample_rate, axis=1):
    '''
    Zero-phase Butterworth band-pass filter.

    signal: array filtered along `axis`; for the (n_channels, n_times) layout used in this
            notebook, e.g., (3, 256), the default axis=1 filters each channel independently
    band: (low, high) cutoff frequencies in Hz, e.g., (6, 12); must satisfy
          0 < low < high < sample_rate / 2 (scipy raises otherwise)
    sample_rate: sampling rate in Hz, e.g., 1600
    axis: axis along which to filter (default 1, i.e., time for (n_channels, n_times) input)

    output: array with the same shape as `signal`; each row is a filtered signal

    The filter is built in second-order-sections (SOS) form. Narrow, low bands such as 6-12 Hz
    at 1600 Hz place the poles very close to z = 1. There, the (b, a) polynomial form is
    numerically fragile, and SciPy recommends SOS for such filters.
    '''
    nyquist = 0.5 * sample_rate
    low, high = band

    # Cutoffs normalised to the Nyquist frequency. Order 4 gives an 8th-order band-pass,
    # run forward and backward by sosfiltfilt so the result has zero phase shift.
    sos = scipy.signal.butter(4, [low / nyquist, high / nyquist], btype='band', output='sos')
    return scipy.signal.sosfiltfilt(sos, signal, axis=axis)


# Function to calculate the envelope of the signal
def calculate_envelop(signal, axis=1):
    '''
    Amplitude envelope (magnitude of the analytic signal) of each filtered trace.

    signal: filtered signal, e.g., shape (n_channels, n_times) = (3, 256)
    axis: axis along which the Hilbert transform is applied; the default 1 is the time axis
          for (n_channels, n_times) input, which is what the existing callers rely on

    output: non-negative array with the same shape as `signal`
    '''
    analytic_signal = scipy.signal.hilbert(signal, axis=axis)
    return np.abs(analytic_signal)


# Function to calculate the PSD
def calculate_psd(signal, sample_rate):
    '''
    Welch power spectral density of each row of `signal`.

    signal: shape (n_channels, n_times); the PSD is estimated along axis 1
    sample_rate: sampling rate in Hz

    output: (freqs, psd) as returned by scipy.signal.welch with its default segment settings;
            for a 256-sample chunk, freqs.shape = (129,) and psd.shape = (n_channels, 129)
    '''
    return scipy.signal.welch(signal, fs=sample_rate, axis=1)

In [5]:
def calculate_features(LFP, freq_bands, sample_rate, exclude_edge=15):
    '''
    Band-power and envelope-correlation features of one LFP chunk.

    LFP: shape (n_channels, n_times), e.g., (3, 256); the channel order is assumed to be
         ACC, CA1, BLA (matching the feature names below)
    freq_bands: dict {band_name: (low, high)} in Hz; its iteration order fixes the feature order
    sample_rate: sampling rate in Hz, e.g., 1600
    exclude_edge: number of samples dropped at each end of the envelopes before correlating,
                  because the envelopes are distorted near the chunk edges (default 15)

    output: 1-D array made of n_bands * n_channels mean-PSD features, followed by
            n_bands * n_channels * (n_channels - 1) / 2 envelope-correlation features
            (24 features for 3 channels and 4 bands)
    raises: ValueError if exclude_edge is negative or leaves fewer than 2 samples to correlate
    '''
    n_channels, n_times = LFP.shape
    if exclude_edge < 0 or n_times - 2 * exclude_edge < 2:
        raise ValueError(f'exclude_edge={exclude_edge} leaves too few samples out of {n_times}.')

    # Upper-triangle channel pairs (i < j): (0,1), (0,2), (1,2) for three channels.
    pair_rows, pair_cols = np.triu_indices(n_channels, k=1)
    # An explicit stop index keeps exclude_edge=0 meaning "keep everything"
    # (a `-0` stop would select nothing).
    core = slice(exclude_edge, n_times - exclude_edge)

    # ACC_theta, CA1_theta, BLA_theta, ACC_beta, CA1_beta, BLA_beta, ACC_sGamma, CA1_sGamma, BLA_sGamma, ACC_fGamma, CA1_fGamma, BLA_fGamma
    psd_features = []

    # ACC-CA1-theta_corr, ACC-BLA-theta_corr, CA1-BLA-theta_corr, ACC-CA1-beta_corr, ACC-BLA-beta_corr, CA1-BLA-beta_corr,
    # ACC-CA1-sGamma_corr, ACC-BLA-sGamma_corr, CA1-BLA-sGamma_corr, ACC-CA1-fGamma_corr, ACC-BLA-fGamma_corr, CA1-BLA-fGamma_corr
    corr_features = []

    for low, high in freq_bands.values():   # theta, beta, sGamma, fGamma
        filtered_signal = bandpass_filter(LFP, (low, high), sample_rate)  # (n_channels, n_times)
        envelop = calculate_envelop(filtered_signal)                      # (n_channels, n_times)

        # Mean PSD inside [low, high] for every channel. The PSD uses the full chunk;
        # only the envelopes have their edges excluded.
        freqs, psd = calculate_psd(filtered_signal, sample_rate)          # e.g. (129,), (3, 129)
        in_band = (freqs >= low) & (freqs <= high)
        psd_features.extend(psd[:, in_band].mean(axis=1))

        # Pairwise correlation of the edge-trimmed envelopes. atleast_2d keeps the indexing valid
        # when there is a single channel (np.corrcoef then returns a scalar).
        corr_matrix = np.atleast_2d(np.corrcoef(envelop[:, core]))
        corr_features.extend(corr_matrix[pair_rows, pair_cols])

    return np.array(psd_features + corr_features)

# 2. Freezing-segment creation (original code), extended with per-chunk feature extraction for the attention analysis

In [6]:
def _freeze_ts_to_lfp_idx(fre_B_E, ts):
    '''
    Convert freeze Begin/End machine timestamps into indices on the LFP array.

    fre_B_E: DataFrame whose first two columns are fB and fE (same clock and units as ts)
    ts: 1-D array of LFP timestamps; must be non-decreasing (checked)

    Timestamps are visited row by row (fB, then fE). Each one maps to the first LFP index whose
    timestamp is >= it, searching only after the previously matched index. The returned indices
    therefore strictly increase in visiting order, exactly as in a sample-by-sample scan.

    output: DataFrame with int columns 'fB' and 'fE', one row per freeze period
    raises: ValueError if ts is unsorted or a timestamp is later than the last LFP sample
    '''
    # Compare neighbours directly; np.diff would wrap around for unsigned timestamp dtypes.
    if np.any(ts[1:] < ts[:-1]):
        raise ValueError('LFP timestamps are not sorted in ascending order.')

    behav_ts_flat = fre_B_E.iloc[:, :2].to_numpy().ravel()   # fB0, fE0, fB1, fE1, ...
    idx_flat = np.empty(behav_ts_flat.shape[0], dtype=int)
    next_allowed = 0
    for k, behav_ts in enumerate(behav_ts_flat):
        # For sorted ts, "first index >= next_allowed with ts[idx] >= behav_ts"
        # equals max(searchsorted, next_allowed).
        ts_idx = max(int(np.searchsorted(ts, behav_ts, side='left')), next_allowed)
        if ts_idx >= ts.shape[0]:
            raise ValueError(f'Behavior timestamp {behav_ts} is later than the last LFP timestamp.')
        idx_flat[k] = ts_idx
        next_allowed = ts_idx + 1
    return pd.DataFrame(idx_flat.reshape(-1, 2), columns=['fB', 'fE'])


# Each window is split into window_len // chunk_size feature chunks, which must tile it exactly.
assert window_len % chunk_size == 0, 'window_len must be a multiple of chunk_size'

# Output folders; created if missing so the pickle writes below cannot fail on a fresh data_dir.
pickle_dir = os.path.join(data_dir, 'pickle_for_transformer')
feature_dir = os.path.join(data_dir, 'features_for_attention_analysis')
os.makedirs(pickle_dir, exist_ok=True)
os.makedirs(feature_dir, exist_ok=True)

# Freezing-period LFP mean/std for every mouse-session, used for z-scoring (read once, not per session).
if Z_SCORE:
    freeze_lfp_mean_std = pd.read_csv(r'./check_freeze_lfp_mean_std_figs/freeze_lfp_mean_std.csv')
    freeze_lfp_mean_std = freeze_lfp_mean_std.set_index(['Mouse-Session'])

tot_num_seg_recent = 0
tot_num_seg_remote = 0

for mouse in tqdm(mice):
    recent_seg_num_this_mouse = 0
    remote_seg_num_this_mouse = 0

    for session in sessions:
        # ---------------------------------------------------
        # Load the LFP and its timestamps
        lfp_name = os.path.join(data_dir, mouse + '_' + session + '_' + 'LFP.mat')
        lfp_and_ts = sio.loadmat(lfp_name)
        lfp = lfp_and_ts['LFP_3_regions'].transpose()   # (n_channels, n_samples), e.g., (3, 987136)
        ts = lfp_and_ts['LFP_ts_usec'].squeeze()        # (n_samples,), microseconds, ~625 us apart at 1600 Hz
        del lfp_and_ts

        # Load the freeze (Begin, End) timestamps, in microseconds
        fre_ts_name = os.path.join(data_dir, mouse + '_' + session + '_' + 'Freeze_Ts.csv')
        fre_B_E = pd.read_csv(fre_ts_name, header=None)
        fre_B_E = fre_B_E.rename(columns={0:'fB',1:'fE'})

        # Sanity check: freeze timestamps should lie inside the LFP recording
        if len(fre_B_E) > 0:
            if fre_B_E.iloc[0,0] < ts[0]:
                print(f'Attention: Mouse{mouse}, Session{session}.')
            if fre_B_E.iloc[-1,1] > ts[-1]:
                print(f'Attention: Mouse{mouse}, Session{session}.')

        # ---------------------------------------------------
        #   Z-SCORE normalization with the session's freezing-period mean/std
        #   (the freezing segments alone then have mean 0 and std 1)
        # ---------------------------------------------------
        if Z_SCORE:
            lfp_mean = freeze_lfp_mean_std.loc[mouse + '_' + session, 'Mean']
            lfp_std = freeze_lfp_mean_std.loc[mouse + '_' + session, 'Std']
            lfp = (lfp - lfp_mean)/lfp_std
            # TODO: How to deal with the outliers?

        # ---------------------------------------------------
        # Machine timestamps -> LFP array indices, keeping only periods of at least window_len samples
        fre_B_E_idx = _freeze_ts_to_lfp_idx(fre_B_E, ts)
        fre_periods = fre_B_E_idx['fE'] - fre_B_E_idx['fB']
        fre_B_E_idx = fre_B_E_idx.loc[window_len <= fre_periods, :].reset_index(drop=True)

        # ---------------------------------------------------
        # Cut window_len-long LFP segments (hop = step) from every freezing period
        segment_all = []
        feature_all = []

        for B_tmp, E_tmp in fre_B_E_idx.to_numpy():
            # The last start is E_tmp - window_len + 1, so the final segment
            # [start, start + window_len) ends exactly at E_tmp (inclusive).
            for segment_start in range(B_tmp, E_tmp - window_len + 2, step):
                segment = lfp[:, segment_start: segment_start + window_len]
                segment_all.append(segment)

                # ----------- UPDATE for the attention analysis -----------
                # One 24-feature vector per chunk_size chunk (the transformer's chunk size),
                # giving a (window_len // chunk_size, 24) array per segment, e.g., (16, 24).
                feature_all.append(np.stack([
                    calculate_features(segment[:, chunk_i:chunk_i + chunk_size], freq_bands, sample_rate)
                    for chunk_i in range(0, window_len, chunk_size)
                ]))

        # ---------------------------------------------------
        # Bookkeeping and printing
        n_segments = len(segment_all)
        if session == '5FC':  # recent
            label = 'Recent'
            tot_num_seg_recent += n_segments
            recent_seg_num_this_mouse = n_segments
            if PRINT:
                print(f'Mouse:{mouse} Session:{session}| There are {n_segments} segments.')
        else:                 # remote
            label = 'Remote'
            tot_num_seg_remote += n_segments
            remote_seg_num_this_mouse = n_segments
            if PRINT:
                print(f'            Session:{session}| There are {n_segments} segments.')
                print(f'            Total:{recent_seg_num_this_mouse+remote_seg_num_this_mouse}')

        # segments x channels x time; np.stack rejects an empty list, so build an empty array explicitly
        if segment_all:
            segment_all = np.stack(segment_all)
        else:
            print(f'Attention: Mouse{mouse}, Session{session} has no freezing period of at least {window_len} samples.')
            segment_all = np.empty((0, lfp.shape[0], window_len), dtype=lfp.dtype)

        # Save segments (for the transformer) and per-chunk features (for the attention analysis)
        file_name = mouse + '_' + label
        with open(os.path.join(pickle_dir, file_name), 'wb') as f:
            pickle.dump(segment_all, f)
        with open(os.path.join(feature_dir, file_name), 'wb') as f:
            pickle.dump(feature_all, f)

print(f'There are {tot_num_seg_recent} segments in recent.')
print(f'There are {tot_num_seg_remote} segments in remote.')

A Jupyter Widget

Mouse:ym212 Session:5FC| There are 1076 segments.
            Session:7FC| There are 452 segments.
            Total:1528
Mouse:ym213 Session:5FC| There are 274 segments.
            Session:7FC| There are 443 segments.
            Total:717
Mouse:ym214 Session:5FC| There are 142 segments.
            Session:7FC| There are 344 segments.
            Total:486
Mouse:ym215 Session:5FC| There are 1608 segments.
            Session:7FC| There are 442 segments.
            Total:2050
Mouse:ym218 Session:5FC| There are 1601 segments.
            Session:7FC| There are 1075 segments.
            Total:2676
Mouse:ym219 Session:5FC| There are 288 segments.
            Session:7FC| There are 153 segments.
            Total:441
Mouse:ym220 Session:5FC| There are 1597 segments.
            Session:7FC| There are 745 segments.
            Total:2342
Mouse:ym222 Session:5FC| There are 587 segments.
            Session:7FC| There are 1936 segments.
            Total:2523
Mouse:ym223 Session:5FC| Ther