# Imports

In [None]:
import os
import glob
import datetime
import random
import copy
import json
import sys
import shelve

import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
import ptitprince as pt
import seaborn as sns
from statannotations.Annotator import Annotator
from scipy import signal, stats, interpolate
from scipy.io import loadmat
from scipy.fft import *
from sklearn import preprocessing
import sklearn as skl
from tqdm.notebook import tqdm, trange

%matplotlib qt

In [None]:
def transposeStringArray(stringArray):
    """Concatenate the first three rows of ``stringArray`` column by column.

    Element ``c`` of the result is ``stringArray[0][c] + stringArray[1][c] +
    stringArray[2][c]``; the number of columns is taken from the first row.

    Returns:
        numpy.ndarray: One joined string per column.
    """
    nCols = len(stringArray[0])
    joined = [stringArray[0][c] + stringArray[1][c] + stringArray[2][c] for c in range(nCols)]
    return np.array(joined)

def morlet_transform(array, sampling_freq, scale_max):
    """Complex-Morlet continuous wavelet transform of a min-max normalised signal.

    Parameters:
        array: 1-D sequence of samples.
        sampling_freq: Sampling frequency of ``array``; passed to pywt as
            ``sampling_period = 1 / sampling_freq``.
        scale_max: Exclusive upper bound of the integer scales ``1 .. scale_max - 1``.

    Returns:
        tuple: ``(coef, freqs, power, average_power, scales)`` where ``power`` is
        ``abs(coef)`` and ``average_power`` is its mean over axis 1.
    """
    # pywt is not imported at file level; without this every call raised NameError.
    import pywt

    # Linearly spaced integer scales.
    scales = np.arange(1, scale_max)

    # Min-max normalise in one vectorised pass (the per-element comprehension
    # recomputed np.min/np.max for every sample, i.e. O(n^2)).
    samples = np.asarray(array)
    lo, hi = np.min(samples), np.max(samples)
    samples = (samples - lo) / (hi - lo)

    coef, freqs = pywt.cwt(samples, scales, "cmor1.5-1.0", sampling_period=1 / sampling_freq)
    power = np.abs(coef)
    average_power = np.mean(power, axis=1)
    return coef, freqs, power, average_power, scales


def surrogate(df, col_index, sampling_freq, scale_max, n_surrogate=10):
    """Morlet-transform shuffled copies of one column to build a null spectrum.

    Column ``df.iloc[:, col_index]`` is randomly permuted with ``random.sample``,
    each permutation is passed through ``morlet_transform``, and the outputs are
    averaged across permutations so that peaks in the original data can be
    contrasted with this shuffled baseline.

    Returns:
        tuple: ``(coefs, freqs, power, average_power, ssd, shuffled)`` — the first
        four are means over permutations, ``ssd`` is the standard deviation of
        ``average_power`` across permutations, and ``shuffled`` is the list of
        permuted sequences.
    """
    # list(...) already copies the column values; no deepcopy of the frame needed.
    values = list(df.iloc[:, col_index])

    # NOTE(review): n_surrogate + 1 permutations are drawn (unchanged from the
    # original); confirm whether exactly n_surrogate was intended.
    shuffled = [random.sample(values, k=len(values)) for _ in range(n_surrogate + 1)]

    coefs, freqs, power, average_power = [], [], [], []
    for arr in shuffled:
        # One CWT per permutation (previously recomputed four times per permutation).
        coef, freq, pw, avg_pw, _ = morlet_transform(array=arr, sampling_freq=sampling_freq, scale_max=scale_max)
        coefs.append(coef)
        freqs.append(freq)
        power.append(pw)
        average_power.append(avg_pw)

    # Spread of average power across permutations, computed once after the loop.
    ssd = np.std(average_power, axis=0)
    coefs, freqs, power, average_power = [np.mean(i, axis=0) for i in [coefs, freqs, power, average_power]]

    return coefs, freqs, power, average_power, ssd, shuffled

def processEventFile(eventDict, maxDuration=3):
    """Keep only trials whose first-to-last event span is at most ``maxDuration``.

    Parameters:
        eventDict: Mapping with ``'time'`` (2-D, trials x events), ``'cond'``
            (1-D, one entry per trial) and ``'sar'`` (2-D, one row per trial).
        maxDuration: Largest accepted ``time[:, -1] - time[:, 0]``, in the units
            of ``'time'``. Defaults to 3, the previously hard-coded limit.

    Returns:
        tuple: ``(maskedEventDict, eventMask)`` — a new dict with the kept rows of
        ``'time'``, ``'cond'`` and ``'sar'``, and the boolean mask that selected them.
    """
    eventTimes = eventDict['time']
    eventDurs = eventTimes[:, -1] - eventTimes[:, 0]
    eventMask = eventDurs <= maxDuration

    maskedEventDict = {
        'time': eventDict['time'][eventMask, :],
        'cond': eventDict['cond'][eventMask],
        'sar': eventDict['sar'][eventMask, :],
    }
    return maskedEventDict, eventMask

def find_nearest(array, values):
    """Snap each query value to the nearest element of an ascending 1-D array.

    Parameters:
        array: Ascending 1-D array-like (``np.searchsorted`` requires sorted input).
        values: Scalar or array-like of query values.

    Returns:
        tuple: ``(nearest, idxs)`` — the chosen elements of ``array`` and their
        indices, shaped like ``values``. On an exact tie the higher index wins.

    Raises:
        ValueError: If ``array`` is empty.
    """
    array = np.asarray(array)
    values = np.asarray(values)
    if array.size == 0:
        raise ValueError('find_nearest needs a non-empty array')

    # Insertion points satisfy array[idxs - 1] < value <= array[idxs].
    idxs = np.searchsorted(array, values, side="left")

    left = array[np.maximum(idxs - 1, 0)]
    right = array[np.minimum(idxs, len(array) - 1)]
    # Step back one index when the value lies past the end, or when the left
    # neighbour is strictly closer than the right one.
    prev_is_closer = (idxs == len(array)) | (np.abs(values - left) < np.abs(values - right))

    # np.where also handles a scalar query (0-d idxs); the in-place masked
    # assignment used before raised TypeError there.
    idxs = np.where(prev_is_closer, idxs - 1, idxs)
    return array[idxs], idxs


def annotateTFPlot(axe, eventDict, eventFocus = None, takeAllEvents = True, eventConds = None, onlyTicks = False):
    """Compute mean event times and optionally mark them on an axis.

    Parameters:
        axe: Axis with ``get_ylim``, ``vlines`` and ``text`` (unused when ``onlyTicks``).
        eventDict: Mapping whose ``'time'`` entry is trials x events.
        eventFocus: Index of the event to centre on. ``None`` leaves times relative
            to the first event (this default previously raised ValueError).
        takeAllEvents: Average over all trials; otherwise only rows ``eventConds``.
        eventConds: Row selector used when ``takeAllEvents`` is False.
        onlyTicks: Return the centred mean times without drawing.

    Returns:
        numpy.ndarray or None: The centred mean event times when ``onlyTicks``,
        otherwise None after drawing dashed lines and labels.
    """
    eventLabels = ['D-On', 'T-L', 'S-On', 'S-Off', 'M-On', 'M-Off', 'FB']

    if takeAllEvents:
        eventTimes = eventDict['time']
    else:
        eventTimes = eventDict['time'][eventConds, :]

    # Mean time of each event across trials, relative to the first event.
    meanET = np.mean(eventTimes, 0)
    meanET -= meanET[0]
    # Re-centre on the focus event. With eventFocus=None, meanET[None] added an
    # axis and the in-place subtraction failed to broadcast.
    if eventFocus is not None:
        meanET -= meanET[eventFocus]

    if onlyTicks:
        return meanET

    ymin, ymax = axe.get_ylim()
    axe.vlines(meanET, ymin=ymin, ymax=ymax, ls='dashed', colors='seashell')

    for lInd, label in enumerate(eventLabels):
        if lInd == 1:
            continue  # 'T-L' is deliberately left unlabelled
        xoffset = -0.1 if lInd == 3 else 0
        axe.text(x=meanET[lInd] + xoffset, y=ymax + 1, s=label, c='k', rotation=45)


def genDataDicts(fileLoc):
    """Load per-session LFP and event ``.mat`` files found under ``fileLoc``.

    A sub-directory is treated as a session when its name is 8 characters long
    and it contains both ``<folder>_LFP.mat`` and ``evt.mat``. Entries that are
    not directories are skipped (``os.listdir`` on them used to raise).

    Returns:
        tuple: ``(lfpDict, eventDict, prosEventDict)`` keyed by folder name:
          * lfpDict — ``'LFP'``, ``'sr'`` (from ``LFP_SR``), ``'timestamps'``,
            ``'blocs'`` (rows joined by ``transposeStringArray``) and
            ``'blocTimes'`` (from ``blocRanges``);
          * eventDict — the raw ``evt`` struct;
          * prosEventDict — ``{'processedEvents', 'eMask'}`` from ``processEventFile``.
    """
    lfpDict = {}
    eventDict = {}
    prosEventDict = {}

    for folder in os.listdir(fileLoc):
        folderLoc = os.path.join(fileLoc, folder)
        if not os.path.isdir(folderLoc):
            continue

        files = os.listdir(folderLoc)
        lfpFile = folder + "_LFP.mat"
        if len(folder) != 8 or lfpFile not in files or 'evt.mat' not in files:
            continue

        data = loadmat(os.path.join(folderLoc, lfpFile), squeeze_me=True, simplify_cells=True)
        eDict = loadmat(os.path.join(folderLoc, 'evt.mat'), squeeze_me=True, simplify_cells=True)

        lfpDict[folder] = {'LFP': data['LFP'], 'sr': data['LFP_SR'],
                           'timestamps': data['timestamps'],
                           'blocs': transposeStringArray(data['blocs']),
                           'blocTimes': data['blocRanges']}

        eventDict[folder] = eDict['evt']
        eDict2, eMask = processEventFile(eDict['evt'])
        prosEventDict[folder] = {'processedEvents': eDict2, 'eMask': eMask}

    return lfpDict, eventDict, prosEventDict


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in types.

    Integer scalars become ``int``, floating scalars ``float``, boolean scalars
    ``bool`` and ``ndarray`` objects nested lists (via ``tolist``). Anything else
    is passed to ``json.JSONEncoder.default``, which raises TypeError.
    """

    def default(self, obj):
        # Abstract scalar bases cover every sized variant and work on NumPy 2.x,
        # where the ``np.float_`` alias used previously no longer exists.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


def _parseNeuronID(nID):
    """Split an ``E<electrode>N<number>`` ID into electrode and neuron strings.

    IDs that do not match that pattern fall back to the fixed-position rule
    (``nID[-3]``, ``nID[-1]``). Multi-digit parts such as ``'E10N12'`` parse as
    ``('10', '12')``; the fixed-position rule gave ``('N', '2')``.
    """
    electrode, sep, number = nID[1:].partition('N')
    if nID.startswith('E') and sep and electrode.isdigit() and number.isdigit():
        return electrode, number
    return nID[-3], nID[-1]


def genDataDicts_Spike(fileLoc):
    """Collect single-neuron spike timestamps from session folders under ``fileLoc``.

    A sub-directory is read when its name is 8 characters long (parsed as a
    ``%d%m%Y`` date) and it contains both ``spk_lfp_evt_eog.mat`` and ``evt.mat``.
    Neurons come from ``nexFile['neurons']``, which ``loadmat`` yields either as a
    list of dicts or as a single dict; any other layout is ignored.

    Returns:
        dict: Keyed by ``<ddmmyy><ID>``, where ``<ID>`` is the neuron name with
        ``'Elec_'`` -> ``'E'`` and ``'_Neuron_'`` -> ``'N'``. Values hold
        ``'timestamps'``, ``'date'``, ``'Electrode'``, ``'Number'``, ``'xPos'`` and
        ``'yPos'`` in that order (later code unpacks them positionally).
    """
    neuronDict = {}

    for folder in os.listdir(fileLoc):
        folderLoc = os.path.join(fileLoc, folder)
        # Skip stray files; os.listdir on them used to raise.
        if not os.path.isdir(folderLoc):
            continue

        files = os.listdir(folderLoc)
        # evt.mat must be present (as before) but its contents are not used here.
        if len(folder) != 8 or 'spk_lfp_evt_eog.mat' not in files or 'evt.mat' not in files:
            continue

        matSpikes = loadmat(os.path.join(folderLoc, 'spk_lfp_evt_eog.mat'), simplify_cells=True)
        neurons = matSpikes['nexFile']['neurons']
        if isinstance(neurons, dict):
            neurons = [neurons]  # a single neuron is stored as a bare struct
        elif not isinstance(neurons, list):
            continue

        for neuron in neurons:
            nID = neuron['name'].replace('Elec_', 'E').replace('_Neuron_', 'N')
            electrode, number = _parseNeuronID(nID)
            rDate = datetime.datetime.strptime(folder, '%d%m%Y').strftime('%d%m%y')
            neuronDict[rDate + nID] = {'timestamps': neuron['timestamps'], 'date': rDate,
                                       'Electrode': electrode, 'Number': number,
                                       'xPos': neuron['xPos'], 'yPos': neuron['yPos']}

    return neuronDict

def formatData(data, time, timeLocks, binSize, maxLead, maxLag, sr=781.25, isComplex=False):
    """Cut fixed-length epochs of ``data`` around each time-lock.

    Parameters:
        data: 2-D array (samples x channels) aligned with ``time``.
        time: Ascending sample times (``find_nearest`` requires sorted input).
        timeLocks: Event times; each is snapped to the nearest sample in ``time``.
        binSize: Unused; kept for interface compatibility.
        maxLead, maxLag: Time kept before / after each time-lock, converted to
            samples with ``sr``.
        sr: Sampling rate used for that conversion.
        isComplex: Allocate a complex output array.

    Returns:
        numpy.ndarray: ``(nLead + nLag, channels, len(timeLocks))``. Samples of an
        epoch that fall outside the recording stay 0.

    Raises:
        Exception: If ``data`` and ``time`` differ in length.
    """
    if len(data) != len(time):
        raise Exception('Data and time must have equal length')

    num_trials = len(timeLocks)
    nSamples = data.shape[0]

    # Round before casting; casting first truncated (e.g. 390.625 -> 390, not 391).
    rangeInds = np.round(np.array([maxLead * sr, maxLag * sr])).astype(np.int64)
    epochLen = int(np.abs(rangeInds).sum())

    _, nearestInds = find_nearest(time, timeLocks)
    beginInds = nearestInds - rangeInds[0]
    endInds = nearestInds + rangeInds[-1]

    fData = np.zeros((epochLen, data.shape[1], num_trials), dtype=complex if isComplex else float)

    for trialInd in range(num_trials):
        begin = int(beginInds[trialInd])
        # Clip the window to the recording and keep it aligned in the epoch.
        # A negative begin index used to wrap around to the end of the data.
        start = max(begin, 0)
        stop = min(int(endInds[trialInd]), nSamples)
        if stop <= start:
            continue
        offset = start - begin
        fData[offset:offset + (stop - start), :, trialInd] = data[start:stop, :]

    return fData

def formatData_Spike(data, time, timeLocks, binSize, maxLead, maxLag, method):
    """Bin event weights (e.g. spikes) into peri-event histograms.

    Bins of width ``binSize`` span ``[-maxLead, maxLag]`` around each time-lock,
    and each bin is half-open ``[lo, hi)``. For every trial and bin, the entries
    of ``data`` whose ``time`` falls in the bin are summed.

    Parameters:
        data: Per-event weights (callers pass ``np.ones`` to count spikes).
        time: Event times, same length as ``data``; need not be sorted.
        timeLocks: One alignment time per trial.
        binSize, maxLead, maxLag: Bin width and window extent, in ``time`` units.
        method: ``'count'`` for summed weights, ``'rate'`` to divide by bin width.

    Returns:
        numpy.ndarray: ``(nBins, nTrials)``.

    Raises:
        Exception: If ``data`` and ``time`` differ in length.
        ValueError: For an unknown ``method`` (previously zeros were returned).
    """
    if len(data) != len(time):
        raise Exception('Data and time must have equal length')
    if method not in ('count', 'rate'):
        raise ValueError("method must be 'count' or 'rate', got %r" % (method,))

    # round() guards against division landing just below an integer
    # (0.3 / 0.1 == 2.999...), which int() truncated to one bin too few.
    nBins = int(round((maxLead + maxLag) / binSize))
    binBounds = np.linspace(-np.round(maxLead / binSize, 1), np.round(maxLag / binSize, 1), nBins + 1) * binSize

    # Sort once and use cumulative sums: weight in [lo, hi) equals
    # cum[#times < hi] - cum[#times < lo]. This replaces a scan of every event
    # for each (trial, bin) pair.
    time = np.asarray(time, dtype=float)
    data = np.asarray(data, dtype=float)
    order = np.argsort(time, kind='stable')
    sortedTime = time[order]
    cumData = np.concatenate(([0.0], np.cumsum(data[order])))

    # edges[b, t] = binBounds[b] + timeLocks[t]: the same thresholds as before.
    edges = binBounds[:, np.newaxis] + np.asarray(timeLocks, dtype=float)[np.newaxis, :]
    below = np.searchsorted(sortedTime, edges, side='left')
    fData = cumData[below[1:]] - cumData[below[:-1]]

    if method == 'rate':
        # Scalar width per bin (assigning a shape-(1,) array to one cell is deprecated).
        fData = fData / np.diff(binBounds)[:, np.newaxis]

    return fData

In [None]:
def spike_to_rate(spiketimes, nbins = 100, remove_tails = False, axis = -1, time_range = (0, 1000)):
    """Histogram spike times into ``nbins`` bins and smooth them with ``smooth_rates``.

    ``None`` entries are dropped first. ``time_range`` is the histogram range;
    its default ``(0, 1000)`` is the previously hard-coded value.
    """
    # np.asarray accepts plain lists too; boolean-indexing a list raised TypeError.
    spiketimes = np.asarray(spiketimes)
    spiketimes = spiketimes[spiketimes != None]  # noqa: E711 - element-wise on arrays

    binned_spikes, _ = np.histogram(spiketimes.astype(float), bins=nbins, range=time_range)

    return smooth_rates(binned_spikes, nbins=nbins, remove_tails=remove_tails, axis=axis)

def smooth_rates(firing_rate, nbins = 100, remove_tails = False, axis = -1, order = 5, lp_savgol = 3, lp_filtfilt = 2):
    """Smooth a rate signal with a zero-phase Butterworth low-pass along ``axis``.

    ``nbins`` is the sampling rate handed to ``signal.butter`` (``fs``) and
    ``order`` is the filter order. With ``remove_tails`` the signal is first
    passed through a Savitzky-Golay filter (window 70, polynomial order
    ``order``, ``'mirror'`` edges) and low-passed at ``lp_savgol``; otherwise it
    is low-passed at ``lp_filtfilt`` only.
    """
    savgol_window = 70

    if not remove_tails:
        cutoff = lp_filtfilt
    else:
        cutoff = lp_savgol
        firing_rate = signal.savgol_filter(firing_rate, savgol_window, order, mode='mirror', axis=axis)

    sos = signal.butter(order, cutoff, 'lp', fs=nbins, output='sos')
    return signal.sosfiltfilt(sos, firing_rate, axis=axis)

In [None]:
def plot_raster(raster, ax, axis = -1, offset = 0, marker = '.', color = 'b', alpha = 1, markersize = 1):
    """Scatter-plot a spike raster, one row of dots per trial.

    Parameters:
        raster: 2-D array-like. ``axis`` indexes trials and the other axis
            indexes time bins. Values above 1 are clipped to 1 on a copy.
        ax: Axes object providing ``scatter``.
        axis: Trial axis, by default the last one. The loop previously ran over
            ``raster.shape[axis]`` but always sliced columns, so ``axis=0`` broke.
        offset: Added to every trial's y position.
        marker, color, alpha, markersize: Forwarded to ``ax.scatter``
            (``markersize`` as ``s``).
    """
    raster = np.array(raster)  # copy, so clipping never touches the caller's data
    raster[raster > 1] = 1
    # Move trials to the last axis so that column rInd is always trial rInd.
    raster = np.moveaxis(raster, axis, -1)

    for rInd in range(raster.shape[-1]):
        t_spikes = raster[:, rInd]
        # Every bin that is not exactly 0 is passed on (NaN != 0, so NaN bins are
        # included with a NaN y value).
        sinds = np.where(t_spikes != 0)[0]
        ax.scatter(sinds, t_spikes[sinds] * rInd + offset, marker=marker, color=color, alpha=alpha, s=markersize)

In [None]:
def plot_all_ERPs(fig, ax, data, color='k', lw=1, alpha=0.5):
    """Overlay every ERP trace in ``data`` on ``ax`` against time in ms.

    ``data`` must expose ``attrs['MaxLead']`` and ``attrs['MaxLag']`` (scaled by
    1000 for the ms axis), ``shape`` and ``to_numpy()``. Its first axis is time;
    all remaining axes are flattened into separate traces. The x ticks sit at the
    two ends of the window, and the right one is labelled 'FB'.
    """
    max_lead = data.attrs['MaxLead'] * 1000
    max_lag = data.attrs['MaxLag'] * 1000

    n_time = data.shape[0]
    time_vec = np.linspace(-max_lead, max_lag, n_time)
    # One column per trace, whatever the number of trailing dimensions.
    ax.plot(time_vec, data.to_numpy().reshape(n_time, -1), c=color, lw=lw, alpha=alpha)

    ax.set_xticks([-max_lead, max_lag])
    ax.set_xticklabels([str(int(-max_lead)), 'FB'])
    ax.set_ylabel(r'Amplitude (mV)')
    ax.set_xlabel('Time (ms)')
    ax.set_aspect('auto', adjustable='box')
    fig.show()

In [None]:
def generate_colors(n):
    """
    Sample n colors evenly from Matplotlib's 'hsv' colormap.

    Parameters:
    n (int): The number of colors to generate.

    Returns:
    list: n RGBA tuples (floats in [0, 1]) in colormap order, not hex strings.
    An empty list when n <= 0.
    """
    if n <= 0:
        return []
    # matplotlib.cm.get_cmap (reached as plt.cm.get_cmap) was removed in
    # Matplotlib 3.9; pyplot.get_cmap(name, lut) still resamples to n entries.
    colors = plt.get_cmap('hsv', n)
    return [colors(i) for i in range(n)]

In [None]:
def plot_erp_maxs(fig, ax, data, linewidth=1, point_size=5, labelpad=10, tickpad=10, saturation=0.8, alpha=1, width_viol=0.4, width_box=0.4, bw=0.4,
                  pointplot=False, orient='v', linecolor='darkslategray', move=0, offset=0, legend=False, hide_ticks=True):
    """Raincloud plot of peak ERP amplitude per condition with a significance bar.

    ``data`` is long-form with ``'condition'`` ('Presence' / 'Absence') and
    ``'value'`` columns. ``orient='h'`` puts values on x, otherwise on y. A
    Mann-Whitney test between the two conditions is drawn as stars outside the
    axes. ``fig``, ``move`` and ``offset`` are accepted but not used.
    """
    horizontal = orient == 'h'
    x, y = ('value', 'condition') if horizontal else ('condition', 'value')

    # Keyword arguments shared by the raincloud plot and the annotator.
    shared_kwargs = dict(
        x=x, y=y, data=data, dodge=True,
        palette={'Presence': 'dodgerblue', 'Absence': 'crimson'},
        linecolor=linecolor, point_size=point_size, linewidth=linewidth,
        box_linewidth=linewidth, saturation=saturation, alpha=alpha,
        width_viol=width_viol, width_box=width_box, bw=bw,
    )

    rainclouds = pt.RainCloud(ax=ax, orient=orient, **shared_kwargs, cut=0, pointplot=pointplot,
                              box_fliersize=0, box_whiskerprops=dict(linewidth=linewidth))

    if legend:
        handles, _ = rainclouds.get_legend_handles_labels()
        ax.legend(handles, labels=['Presence', 'Absence'], frameon=False)

    annotator = Annotator(ax=ax, pairs=[('Presence', 'Absence')], **shared_kwargs, plot='boxplot', verbose=False)
    annotator.configure(test='Mann-Whitney', text_format='star', loc='outside', line_width=linewidth).apply_test().annotate()

    # The category axis gets no label and zero-length ticks; the value axis gets the title.
    value_axis, category_axis = ('x', 'y') if horizontal else ('y', 'x')
    getattr(ax, 'set_%slabel' % category_axis)('')
    getattr(ax, 'set_%slabel' % value_axis)('Peak ERP amplitude', labelpad=labelpad)
    ax.tick_params(axis=category_axis, length=0, pad=tickpad)
    ax.tick_params(axis=value_axis)
    if hide_ticks:
        getattr(ax, 'set_%sticks' % category_axis)([])

    ax.patch.set_alpha(0.0)


# Data Dictionary

In [None]:
# Prefix for the NetCDF files written below ('' = current working directory).
parent_preprocess_dir = ''

# NOTE(review): the data roots are blank placeholders (os.listdir('') raises);
# point them at the session directories before running.
lfpDict, eventDict, prosEventDict = genDataDicts('')
neurDict = genDataDicts_Spike('')

In [None]:
# Session IDs are the folder names (ddmmYYYY) accepted by genDataDicts.
Sessions = list(lfpDict.keys())
nSessions = len(Sessions)
# Sampling rate read from one hard-coded session; assumes every session
# shares it — TODO confirm.
sampling_rate = lfpDict['01042014']['sr']

# Events

In [None]:
# Longest recording and largest post-masking trial count across sessions;
# maxTrialLens sizes the NaN-padded trial axis of the aggregate arrays.
lenLFPs = [len(v['timestamps']) for k,v in lfpDict.items()]
lenTrials = [len(v['processedEvents']['cond']) for k,v in prosEventDict.items()]
maxTimestamps = max(lenLFPs)
maxTrialLens = max(lenTrials)    

# Event column used as the time-lock (-1 = last column).
centerEvent = -1

# NOTE(review): these first-session variables are not used in this cell;
# later cells reassign them before use.
session = Sessions[0]
LFP = lfpDict[session]['LFP']
timestamps = lfpDict[session]['timestamps']        
sessionEventDict = prosEventDict[session]['processedEvents']
timeLocks = sessionEventDict['time'][:,centerEvent]

# (trial, event, session) times and behaviour values, and (trial, session)
# condition codes; rows beyond a session's trial count are filled with NaN.
aggEvtArray = np.empty((maxTrialLens, 3, nSessions), dtype = np.float64)
aggBehArray = np.empty((maxTrialLens, 3, nSessions), dtype = np.float64)
aggCndArray = np.empty((maxTrialLens, nSessions), dtype = np.float64)
    
# genDataDicts fills prosEventDict and lfpDict in the same loop, so sInd
# follows the order of Sessions.
for sInd, session in enumerate(prosEventDict):

    sessionEventDict = prosEventDict[session]['processedEvents']

    session_time = sessionEventDict['time']
    session_events = sessionEventDict['sar']
    session_conds = sessionEventDict['cond']

    num_trials = session_events.shape[0]
    n_sess_evts = session_events.shape[1]

    # With 7 columns, keep columns 2, 4, 6 (the same ones taken from 'time'
    # below) and shift the middle one so its session minimum is 0.
    # NOTE(review): other widths are stored as-is and must already have 3
    # columns to fit aggBehArray — confirm.
    if n_sess_evts == 7:
        session_events = session_events[:, [2,4,6]]
        session_events[:,1] -= session_events[:,1].min()
    
    # Event-time columns 2, 4, 6 — with annotateTFPlot's label list these would
    # be 'S-On', 'M-On', 'FB'; they are saved as cue/movement/feedback.
    session_time = session_time[:, [2,4,6]]
    
    aggEvtArray[:num_trials, :, sInd] = session_time
    aggEvtArray[num_trials:, :, sInd] = np.nan
    aggBehArray[:num_trials, :, sInd] = session_events
    aggBehArray[num_trials:, :, sInd] = np.nan
    aggCndArray[:num_trials, sInd] = session_conds
    aggCndArray[num_trials:, sInd] = np.nan

In [None]:
# Label the aggregate event/behaviour/condition arrays and save them as NetCDF.
eventDims = ["trial", "event", "session"]
eventCoordDict = {"session": Sessions, "event": ['cue', 'movement', 'feedback']}

xr.DataArray(aggEvtArray, dims = eventDims, coords = eventCoordDict).to_netcdf(parent_preprocess_dir + 'Demolliens_aggEvents_All.nc')
xr.DataArray(aggBehArray, dims = eventDims, coords = eventCoordDict).to_netcdf(parent_preprocess_dir + 'Demolliens_aggBehaviors_All.nc')
xr.DataArray(aggCndArray, dims = ["trial", "session"], coords = {"session": Sessions}).to_netcdf(parent_preprocess_dir + 'Demolliens_aggConds_All.nc')

In [None]:
# Reload as labelled DataArrays; later cells select from them with .sel(session=...).
aggEvtArray = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggEvents_All.nc')
aggBehArray = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggBehaviors_All.nc')
aggCndArray = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggConds_All.nc')

# Spikes

In [None]:
# Peri-event window and bin width for spike counts (same units as the spike
# timestamps).
maxLead = 0.5
maxLag = 0.5
centerEvent = -1
binsize = 0.005

# Spike count per neuron. A timestamps entry that is not an ndarray counts as
# one spike (presumably loadmat returns a scalar for a single spike).
# typeList is collected for inspection only.
lenList = []
typeList = []
for name, neuron in neurDict.items():
    ts = neuron['timestamps']
    typeList.append(type(ts))
    if type(ts) == np.ndarray:
        lenList.append(len(neuron['timestamps']))
    else:
        lenList.append(1)
lenList = np.array(lenList)

# Keep neurons with at least 1000 spikes whose date is one of the LFP sessions.
neuronNames = np.array(list(neurDict.keys()))
goodNeurons = neuronNames[np.where(lenList >= 1000)[0]]
neurDict = {k:v for k,v in neurDict.items()if (k in goodNeurons)
            and (datetime.datetime.strptime(v['date'], '%d%m%y').strftime('%d%m%Y') in Sessions)}
neuronNames = np.array(list(neurDict.keys()))

# Bin one example neuron only to learn the number of time bins.
neuron = neurDict['010414E1N1']
session = datetime.datetime.strptime(neuron['date'], '%d%m%y').strftime('%d%m%Y')
timeLocks = prosEventDict[session]['processedEvents']['time'][:,centerEvent]
timestamps = neuron['timestamps']

firstArray = formatData_Spike(np.ones((len(timestamps))), timestamps,
                                  timeLocks, binsize, maxLead, maxLag, 'count')

# Per-neuron metadata rows (date, electrode, number, x, y) and a
# (time bin, trial, neuron) array of spike counts.
neuron_properties = np.zeros((len(neuronNames), 5), dtype=object)
neuron_array = np.zeros((firstArray.shape[0], maxTrialLens, len(list(neurDict.keys()))))

In [None]:
# Bin each neuron's spikes around the time-lock events of its own session.
for nInd, (n_name, neuron) in enumerate(neurDict.items()):

    if nInd%10 == 0:
        print(nInd)

    # Positional unpacking relies on the key order written by genDataDicts_Spike:
    # timestamps, date, Electrode, Number, xPos, yPos.
    timestamps, n_date, e_ind, n_num, x, y = list(neuron.values())
    
    # Guards against an electrode value of 'N', which fixed-position parsing of
    # an ID such as 'E1N12' produces.
    # NOTE(review): mapping it to '1' assumes such neurons are on electrode 1.
    if e_ind == 'N':
        e_ind = '1'
    
    session = datetime.datetime.strptime(neuron['date'], '%d%m%y').strftime('%d%m%Y')
    timeLocks = prosEventDict[session]['processedEvents']['time'][:,centerEvent]
    nTrials = len(timeLocks)

    neuronTrials = formatData_Spike(np.ones((len(timestamps))), timestamps, timeLocks, binsize, maxLead, maxLag, 'count')

    # Trial slots beyond this session's trial count are NaN.
    neuron_array[:, :nTrials, nInd] = neuronTrials
    neuron_array[:, nTrials:, nInd] = np.nan
    neuron_properties[nInd, :] = np.array([session, e_ind, n_num, x, y])

In [None]:
# Wrap spike counts and metadata as labelled arrays and save them together.
propCoords = {'neuron': neuronNames, 'property': ["date", "electrode", "number", "x", "y"]}
propDims = ['neuron', 'property']
neuronDims = ["time", "trial", "neuron"]
neuronCoords = {"neuron": neuronNames}

neuron_array = xr.DataArray(neuron_array, dims=neuronDims, coords=neuronCoords)
neuron_properties = xr.DataArray(neuron_properties, dims=propDims, coords=propCoords)
neuron_dataset = xr.Dataset({'neuron_array':neuron_array, 'neuron_properties': neuron_properties})
neuron_dataset.to_netcdf(parent_preprocess_dir + 'Demolliens_aggNeurons_All.nc')

In [None]:
# Reload the saved dataset and unpack its two variables.
neuron_dataset = xr.load_dataset(parent_preprocess_dir + 'Demolliens_aggNeurons_All.nc')
neuron_array = neuron_dataset.neuron_array
neuron_properties = neuron_dataset.neuron_properties

In [None]:
# (time, trial, condition, neuron) spike counts of negative-feedback trials,
# with up to 50 trial slots per condition; unused slots are NaN.
spikes_array = xr.DataArray(np.zeros((neuron_array.time.shape[0], 50, 2, neuron_array.neuron.shape[0])), dims=('time', 'trial', 'condition', 'neuron'),
        coords=dict(condition=['Presence', 'Absence'], neuron=neuron_array.neuron), attrs={'feedback': 'negative', 'maxLead':maxLead, 'maxLag': maxLag})

for n_ind, neuron_name in enumerate(neuron_array.neuron):

    # Progress counter. This used to test `nInd`, a stale variable left over from
    # an earlier cell, so it printed the same number every iteration or never.
    if n_ind % 10 == 0:
        print(n_ind)

    neuron = neuron_array.sel(neuron=neuron_name)

    # Event data of the session this neuron was recorded in.
    neuron_session = neuron_properties.sel(neuron=neuron_name, property='date')
    neuron_timestamps = aggEvtArray.sel(session=neuron_session)
    neuron_behaviours = aggBehArray.sel(session=neuron_session)
    neuron_conditions = aggCndArray.sel(session=neuron_session)

    # Condition 3 -> 'Presence', 1 -> 'Absence'; keep only trials whose last
    # behaviour column is 0 (negative feedback, per attrs['feedback']).
    p_cond = neuron_conditions == 3
    a_cond = neuron_conditions == 1

    neg_fbs = neuron_behaviours[:,-1] == 0

    p_trials = (p_cond & neg_fbs).to_numpy()
    a_trials = (a_cond & neg_fbs).to_numpy()

    # NOTE(review): more than 50 matching trials would not fit the trial axis.
    n_p = int(p_trials.sum())
    n_a = int(a_trials.sum())

    spikes_array[:, :n_p, 0, n_ind] = neuron[:, p_trials].to_numpy()
    spikes_array[:, n_p:, 0, n_ind] = np.nan

    spikes_array[:, :n_a, 1, n_ind] = neuron[:, a_trials].to_numpy()
    spikes_array[:, n_a:, 1, n_ind] = np.nan
    

In [None]:
# Example rasters for two hand-picked neurons (one per column): Presence
# trials on top in blue, Absence trials below in red.
p_name = '310314E1N3'
a_name = '150414E1N2'

fig, axes = plt.subplots(2,2, sharex=True)
markersize=10

plot_raster(spikes_array.sel(neuron=p_name, condition='Presence'), axes[0,0], color='b', markersize=markersize)
plot_raster(spikes_array.sel(neuron=p_name, condition='Absence'), axes[1,0], color='r', markersize=markersize)

plot_raster(spikes_array.sel(neuron=a_name, condition='Presence'), axes[0,1], color='b', markersize=markersize)
plot_raster(spikes_array.sel(neuron=a_name, condition='Absence'), axes[1,1], color='r', markersize=markersize)
fig.show()

In [None]:
# Save the condition-split negative-feedback spike counts.
spikes_array.to_netcdf(parent_preprocess_dir + 'Demolliens_nFB_Spikes.nc')

# Spikes from previous data

In [None]:
# save_string is not used in the visible cells.
save_string = 'SPK/'
# NOTE(review): blank placeholder — set to the directory holding one
# sub-directory per neuron (the sub-directory name is the neuron name).
spikesParentDir = ''
neuron_dirs = glob.glob(spikesParentDir + '*')

neuron_names = np.array([n_dir.split('/')[-1] for n_dir in neuron_dirs])
# sheet_name=None loads every sheet into a dict of DataFrames keyed by sheet name.
neuronStatDict = pd.read_excel('./Social_Asocial_List_2014_NF.xlsx', sheet_name = None, skipfooter = 2)

In [None]:
# Condition -> variable name inside each neuron's .mat file.
cond_spikes = {'Presence': 'cnd3_FN', 'Absence': 'cnd1_FN'}

# (condition, trial, time, neuron) object array with room for 80 trials x 200
# time bins; rows beyond a neuron's trial count are NaN.
aggSpikes = np.empty((2, 80, 200, len(neuron_names)),dtype = object)
aggSpikes = xr.DataArray(aggSpikes, dims = ['condition', 'trial', 'time', 'neuron'], coords = {'condition': list(cond_spikes.keys()), 'neuron': neuron_names})

for nInd, nDir in enumerate(neuron_dirs):
    neuron = neuron_names[nInd]
    
    # The file inside each neuron directory is named after the neuron itself.
    spikeData = loadmat(nDir + '/' + neuron, squeeze_me = True)    
        
    for cInd, condition in enumerate(cond_spikes):
        # Assumes each variable is (trials, 200) with <= 80 trials — TODO confirm.
        tempSpikes = np.array(spikeData[cond_spikes[condition]])
        aggSpikes[cInd, :tempSpikes.shape[0], :, nInd] = np.array(spikeData[cond_spikes[condition]])
        aggSpikes[cInd, tempSpikes.shape[0]:, :, nInd] = np.nan

In [None]:
# Restrict to time bins 50-149.
# NOTE(review): 'time' was created without coordinate values, so these labels
# presumably act as positions — confirm with the xarray version in use.
time_slice = np.arange(50,150)
spike_array = aggSpikes.sel(time=time_slice)

neuron_names = spike_array.coords['neuron'].to_numpy()

In [None]:
# Neurons listed in exactly one of the 'Social' / 'Asocial' sheets.
socialNames = neuronStatDict['Social']['Neuron_Name'].to_numpy()
asocialNames = neuronStatDict['Asocial']['Neuron_Name'].to_numpy()

sNeurons = [neuron for neuron in neuron_names if (neuron in socialNames) and (neuron not in asocialNames)]
aNeurons = [neuron for neuron in neuron_names if (neuron in asocialNames) and (neuron not in socialNames)]

sel_neurons = sNeurons + aNeurons

In [None]:
# Trial-averaged, smoothed and peak-normalised firing rates of the selected neurons.
firing_rates_emp = aggSpikes.sel(neuron=sel_neurons)
time_vec = firing_rates_emp['time']

firing_rates_emp = firing_rates_emp.copy().mean('trial')
# Dims are (condition, time, neuron) after the mean, so axis=1 smooths along
# time; smoothing covers all 200 bins before cropping to time_slice.
firing_rates_emp = xr.DataArray(smooth_rates(firing_rates_emp.to_numpy(), remove_tails=False, axis=1),
                                dims=firing_rates_emp.dims, coords=firing_rates_emp.coords)
firing_rates_emp = firing_rates_emp[:, time_slice, :]
# Divide by each neuron's maximum over time and both conditions.
firing_rates_emp /= firing_rates_emp.max('time').max('condition')
firing_rates_emp.to_netcdf(parent_preprocess_dir + 'Demolliens_SPK_rates_norm.nc')
# NOTE(review): this replaces the normalised rates just saved with the raw
# per-trial spikes; confirm this is intended.
firing_rates_emp = aggSpikes.sel(neuron=sel_neurons)

spike_array.to_netcdf(parent_preprocess_dir + 'Demolliens_SPKs.nc')                      
# Long-form table: per neuron and condition, the peak over time of the
# trial-averaged spike counts.
rate_neurons = spike_array.sel(neuron=sel_neurons).copy().mean('trial').max('time').to_dataframe(name = 'value').reset_index()

# LFPs

In [None]:
Sessions = list(lfpDict.keys())
nSessions = len(Sessions)

# Epoch from 0.35 before to 0 after the time-lock event (last event column).
maxLead = 0.35
maxLag = 0.00
binSize = 1
centerEvent = -1

# 3rd-order Butterworth high-pass at 2 Hz.
hp_filt = signal.butter(3, 2, 'hp', fs=sampling_rate, output='sos')
bandKeys = ['high_pass']

# Epoch the first session once, only to learn the epoch length.
# NOTE(review): formatData is called without sr=, so it converts maxLead/maxLag
# to samples with its default sr=781.25, not sampling_rate — confirm they match.
session = Sessions[0]
LFP = lfpDict[session]['LFP']
timestamps = lfpDict[session]['timestamps']        
sessionEventDict = prosEventDict[session]['processedEvents']
timeLocks = sessionEventDict['time'][:,centerEvent]

filteredBand = signal.sosfiltfilt(hp_filt, LFP, axis=0)
firstArray = formatData(data=filteredBand, time=timestamps, timeLocks=timeLocks,
                                maxLead=maxLead, maxLag=maxLag, binSize=binSize)
nTrialTimesteps = firstArray.shape[0]

# (time, trial, electrode, session); assumes 4 electrodes per session — TODO confirm.
lfp_array = np.empty((nTrialTimesteps, maxTrialLens, 4, nSessions))

# Indices of rejected channels per session (kept for inspection).
bads_list = []

# color_list is not used in this cell.
color_list = generate_colors(nSessions)

for sInd, session in enumerate(lfpDict):

    if sInd%10 == 0:
        print((sInd, session))

    # A channel whose minimum never goes below -3000 (raw units) is rejected
    # and blanked with NaN.
    LFP = lfpDict[session]['LFP'].astype(np.float32)
    bad_channels = LFP.min(0) > -3000

    bad_inds = np.where(bad_channels)[0]

    bads_list.append(bad_inds)
    if bad_channels.sum() > 0:
        for _, ch_ind in enumerate(bad_inds):    
            LFP[:, ch_ind] = np.nan
        
    timestamps = lfpDict[session]['timestamps']        
    sessionEventDict = prosEventDict[session]['processedEvents']
    timeLocks = sessionEventDict['time'][:,centerEvent]
    nTrials = len(timeLocks)
    
    # NOTE(review): the stored epochs use causal sosfilt, while the shape probe
    # above used zero-phase sosfiltfilt. Causal filtering keeps post-event samples
    # out of this pre-event window but adds phase lag — confirm which is intended.
    filteredBand = signal.sosfilt(hp_filt, LFP, axis = 0)

    sessTrialLFPs = formatData(data = filteredBand, time = timestamps, timeLocks = timeLocks,
                                    maxLead = maxLead, maxLag = maxLag, binSize=binSize)
    
    # formatData returns (time, electrode, trial); reorder to (time, trial, electrode).
    lfp_array[:, :nTrials, :, sInd] = sessTrialLFPs.swapaxes(1,2)
    lfp_array[:, nTrials:, :, sInd] = np.nan

In [None]:
# Label the LFP epochs, record the window settings as attributes, and save.
lfpDims = ["time", "trial", "electrode", "session"]
lfpCoordDict = {"session": Sessions}
lfpAttributes = {'MaxLead': maxLead, 'MaxLag': maxLag, 'CenterEvent': centerEvent}
lfp_array = xr.DataArray(lfp_array, coords = lfpCoordDict, dims = lfpDims, attrs = lfpAttributes)
lfp_array.to_netcdf(parent_preprocess_dir + 'Demolliens_aggLFP_All.nc')

In [None]:
lfp_array = xr.load_dataarray(parent_preprocess_dir + 'Demolliens_aggLFP_All.nc')

In [None]:
# Quick look: trial-averaged epochs of every session, one line per electrode
# (x axis is the sample index; time_vec is computed but not passed to plot).
fig, ax = plt.subplots()
time_vec = np.linspace(-maxLead*1000, maxLag*1000, lfp_array.time.shape[0])

# for sInd in np.arange(3, num_sessions, 5)+1:
for sInd in range(nSessions):
    ax.plot(lfp_array.isel(session=sInd).mean('trial'));
    fig.show();

In [None]:
# Manually rated electrode quality per session. 'Session' is read as str so it
# can be compared with the folder-name IDs in Sessions.
# NOTE(review): the spreadsheet path is a blank placeholder.
sessionQuality = pd.read_excel('', converters = {'Session': lambda x: str(x)})
# Shift electrode numbers down by one (presumably 1-based in the sheet) to
# match 0-based array indices.
sessionQuality['Electrode'] -= 1
noted_sessions = sessionQuality['Session'].to_numpy()
goodElecs = (sessionQuality['Quality'] == 'excellent') | (sessionQuality['Quality'] == 'good')

In [None]:
# Per session, the electrodes rated 'excellent' or 'good'; sessions missing from
# the quality sheet get an empty int array.
elec_props = {}

session_array = []
elec_array = []

for sInd, session in enumerate(Sessions):

    if session in noted_sessions:
        quality_session = sessionQuality[sessionQuality['Session'] == session]
        good_mask = quality_session['Quality'].isin(['excellent', 'good'])

        # Index with this session's own mask. The previous code computed it and
        # then indexed with the sheet-wide `goodElecs`, relying on pandas to
        # re-align that Series to the subset's index.
        good_elecs = quality_session.loc[good_mask, 'Electrode'].to_numpy()
        num_good_elecs = good_elecs.shape[0]
        elec_props[session] = good_elecs

        session_array.extend([session]*num_good_elecs)
        elec_array.extend(good_elecs)
    else:
        elec_props[session] = np.empty(0, dtype=np.int64)

# One row per (session, electrode) pair. Mixing str and int makes NumPy store
# both columns as strings, hence int(elec) in the ERP cell.
erp_props = np.array([session_array, elec_array]).T

# ERPs

In [None]:
# Condition-averaged ERPs per (session, electrode) pair: (time, pair, condition).
num_elec = erp_props.shape[0]
erp_array = np.zeros((lfp_array.shape[0], num_elec, 2))

# Averages that failed the amplitude checks, kept for inspection.
bad_erps = []

# An average is kept only if, over samples 75-124, its maximum exceeds
# flat_thresh and its mean exceeds 100; otherwise the entry is set to NaN.
flat_thresh = 1200
peak_window = np.arange(75, 125)

for eInd in range(num_elec):

    session, elec = erp_props[eInd]
    elec = int(elec)
    
    # session_elecs is not used below.
    session_elecs = elec_props[session]
    
    lfpSession = lfp_array.sel(session=session)
    condSession = aggCndArray.sel(session=session)
    behSession = aggBehArray.sel(session=session)
    
    # Last behaviour column: 1 / 0 (positive / negative feedback per the names).
    pFB = behSession[:,-1] == 1
    nFB = behSession[:,-1] == 0

    # Conditions 3/4 -> Presence, 1/2 -> Absence; only negative-feedback trials
    # are averaged. NaN-padded trials compare False and drop out.
    pCond = (condSession == 3) | (condSession == 4)
    aCond = (condSession == 1) | (condSession == 2)

    presenceCond = pCond & nFB
    absenceCond = aCond & nFB
    condition_list = [presenceCond, absenceCond]
    
    for cInd in range(2):

        # NaN-aware mean over the selected trials (axis 1).
        condition_erp = np.nanmean(lfpSession[:, condition_list[cInd], elec],1)
        
        if (condition_erp[peak_window].max() > flat_thresh) & (condition_erp[peak_window].mean() > 100):  
            erp_array[:, eInd, cInd] = condition_erp
        else:
            erp_array[:, eInd, cInd] = np.nan
            bad_erps.append(condition_erp)

# NOTE(review): baselineRange is not used below.
baselineRange = np.round(0.02*sampling_rate).astype(int)

erpCoordDict = {"condition":['Presence', 'Absence'], "electrode": [prop[0] + '_' + prop[1] for prop in erp_props]}
erpAttributes = {'MaxLead': maxLead, 'MaxLag': maxLag, 'CenterEvent': centerEvent}

erp_array = xr.DataArray(erp_array, dims = ["time", "electrode", "condition"], coords=erpCoordDict, attrs=erpAttributes)
# Drop every electrode that has any NaN, i.e. where either condition failed.
erp_array = erp_array.dropna('electrode')
num_elec = erp_array.electrode.shape[0]

# NOTE(review): written relative to '..' rather than parent_preprocess_dir.
erp_array.to_netcdf('../task_erps_all_nFB.nc')

In [None]:
parent_preprocess_dir = ''

# save_string, max_lead, max_lag and eInd below are not used in this cell.
save_string = 'ERP/'

erpAttrs = erp_array.attrs
max_lead = erpAttrs['MaxLead']*1000
max_lag = erpAttrs['MaxLag']*1000

# Target grid for the resampled ERPs: t0_erp to tend_erp in steps of dt_erp.
tend_erp = 100.0
dt_erp=0.1
t0_erp=0.0
ts_erp = np.arange(t0_erp, tend_erp + dt_erp, dt_erp)
nt = ts_erp.shape[0]

eInd = 0
# Overall amplitude scale: largest trial-averaged LFP value, rounded to the
# nearest thousand, divided by 1000.
emp_scale = int(lfp_array.mean('trial').max().round(-3)/1000)


# Sampling rate implied by stretching lfp_array.shape[0] samples onto nt points.
# NOTE(review): the next cells apply this filter to erp_array at its original
# rate, before interpolation, so the effective cutoff is not 20 Hz — confirm.
srOrig = sampling_rate
srOrig = nt*srOrig/lfp_array.shape[0]
lowpass = signal.butter(3, 20, 'lp', fs=srOrig, output='sos')

# Original ERP samples placed on [0, ts_erp[-1]]; they are interpolated onto ts_erp.
tVec = np.linspace(0, ts_erp[-1], erp_array.shape[0])
tVec_Extended = ts_erp

task_ERPs = xr.DataArray(np.zeros((tVec_Extended.shape[0], erp_array.shape[1], 2)), dims=erp_array.dims, coords=erp_array.coords, attrs=erp_array.attrs)

In [None]:
# Per electrode: low-pass both condition ERPs, cubic-interpolate onto ts_erp,
# min-max scale each condition, rescale by its range relative to the
# larger-range condition times emp_scale, and shift both to start at 0.
for eInd in range(num_elec):
    
    newData = []

    for cInd in range(2):
        
        data = erp_array[:, eInd, cInd].copy()

        data = signal.sosfiltfilt(lowpass, data, axis = 0)
        interpFunc = interpolate.interp1d(tVec, data, kind='cubic')
        data = interpFunc(tVec_Extended)
        newData.append(data)

    # (time, condition)
    newData = np.array(newData).T

    dataMax = np.max(newData, axis=0)
    dataMin = np.min(newData, axis=0)

    # Each condition's peak-to-peak range as a fraction of the larger one.
    dataDiffs = dataMax - dataMin
    dataDiffs /= dataDiffs.max()

    newData = preprocessing.MinMaxScaler().fit_transform(newData)
    newData = newData * dataDiffs * emp_scale
    newData -= newData[0,:]

    # NOTE(review): the noise term is multiplied by 0.0, so it adds nothing but
    # still advances NumPy's global random state.
    newData = newData.T + np.random.randn(newData.shape[0])*0.0
    newData = newData.T
    
    task_ERPs[:, eInd, :] = newData

In [None]:
# Overlay all processed ERPs: Presence in blue, Absence in red (drawn on top).
fig, ax = plt.subplots()
colors = ['b', 'r']
zords = [1,2]
for eInd in range(num_elec):
    for cInd in range(2):
        ax.plot(task_ERPs.isel(electrode=eInd, condition=cInd), c=colors[cInd], alpha=0.5, zorder=zords[cInd]);
fig.show()    

In [None]:
# Long-form table of each electrode's peak processed ERP per condition.
max_erps_df = task_ERPs.max('time').to_dataframe(name='value').reset_index().dropna()

In [None]:
# Same processing as task_ERPs but without the emp_scale factor, so each
# electrode's larger-range condition spans a range of 1.
task_ERPs_MinMax = task_ERPs.copy()

for eInd in range(num_elec):
    
    newData = []

    for cInd in range(2):
        
        data = erp_array[:, eInd, cInd].copy()
    
        data = signal.sosfiltfilt(lowpass, data, axis = 0)
        interpFunc = interpolate.interp1d(tVec, data, kind='cubic')
        data = interpFunc(tVec_Extended)
        newData.append(data)

    # (time, condition)
    newData = np.array(newData).T

    dataMax = np.max(newData, axis=0)
    dataMin = np.min(newData, axis=0)

    # Each condition's peak-to-peak range as a fraction of the larger one.
    dataDiffs = dataMax - dataMin
    dataDiffs /= dataDiffs.max()
    
    newData = preprocessing.MinMaxScaler().fit_transform(newData)
    newData = newData * dataDiffs
    
    newData -= newData[0,:]

    # NOTE(review): zero-amplitude noise term; it only advances the RNG state.
    newData = newData.T + np.random.randn(newData.shape[0])*0.0
    newData = newData.T
    
    task_ERPs_MinMax[:, eInd, :] = newData

In [None]:
# NOTE(review): the first assignment is overwritten immediately; the comparison
# below uses the min-max version.
sData = task_ERPs
sData = task_ERPs_MinMax

# Peak over time per (electrode, condition), divided by each condition's
# largest peak across electrodes.
sessERP_Peaks = np.max(sData, axis = 0)
sessERP_Peaks = sessERP_Peaks/sessERP_Peaks.max(axis=0)

# Electrodes whose normalised Presence (column 0) peak is above / below the
# Absence peak; exact ties land in neither group.
pDomSessInds = np.where(sessERP_Peaks[:,0]>sessERP_Peaks[:,1])[0]
aDomSessInds = np.where(sessERP_Peaks[:,0]<sessERP_Peaks[:,1])[0]

pDomERPs = sData[:,pDomSessInds,:]
aDomERPs = sData[:,aDomSessInds,:]

task_ERPs.to_netcdf(parent_preprocess_dir + 'task_ERPs_AllRegions.nc')
task_ERPs_MinMax.to_netcdf(parent_preprocess_dir + 'normalized_task_ERPs_AllRegions.nc')