In [None]:
import sys
import os
from os.path import join, dirname, realpath, exists
import json
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import mne
from mne.preprocessing import (ICA, corrmap, create_ecg_epochs,
                               create_eog_epochs)
from mne.datasets.brainstorm import bst_auditory
from mne.io import read_raw_ctf
from mne.preprocessing import annotate_muscle_zscore

# Only let MNE messages at the 'error' level through, to reduce extraneous output.
mne.set_log_level('error')  # reduce extraneous MNE output

# IPython magic selecting the interactive 'widget' matplotlib backend.
# NOTE(review): this backend is provided by the ipympl package — confirm it is installed.
%matplotlib widget

# Example_dir = dirname(realpath(__file__)) # directory of this file
# modules_dir is appended to sys.path below so that modules stored there can be imported.
# NOTE(review): '/' is the filesystem root — presumably a placeholder; confirm
# where the TMSiFileFormats package actually lives.
modules_dir = '/' # directory with all modules
# Folder (relative to the working directory) that holds all measurements.
measurements_dir = './data_Sorted/' # directory with all measurements
sys.path.append(modules_dir)

from TMSiFileFormats.file_readers import Poly5Reader
from autoreject import get_rejection_threshold

# Load dataset and montage settings

In [None]:
# Collect every folder in measurements_dir whose name starts with 'pongFac23'.
# glob returns its matches in arbitrary, filesystem-dependent order, so the
# list is sorted to make positional indexing into it reproducible across
# machines and runs.
subjectFolders = sorted(glob.glob(measurements_dir + 'pongFac23*'))

In [None]:
# Pick one subject by its position in subjectFolders and derive its folders.
subjectInd = 8
subjectDir = subjectFolders[subjectInd]
subjectID = subjectDir.rsplit('_', 1)[-1]  # text after the last underscore
ephysDir = os.path.join(subjectDir, 'EEG', subjectID)
behavDir = os.path.join(subjectDir, 'Pong')

# ephysDir carries no trailing separator, so these patterns match files inside
# the 'EEG' folder whose names start with subjectID; the first match is used.
sub_Ephys = glob.glob(ephysDir + '*.Poly5')[0]
sub_Trigs = glob.glob(ephysDir + '*.csv')[0]


def _kinFileNumber(fname):
    """Return the integer found after the last '_kin_' and before the next '.'."""
    return int(fname.split('_kin_')[-1].split('.')[0])


# Per-trial kinematics files; the summary file shares their prefix up to '_kin_'.
sub_Trials = glob.glob(behavDir + '/test_*kin_*.csv')
sub_Behav = sub_Trials[0].split('_kin_')[0]+'.csv'
# Order the trial files numerically rather than lexicographically.
sub_Trials.sort(key=_kinFileNumber)

In [None]:
# Read the Poly5 recording selected for this subject.
# (Original note: called without arguments, Poly5Reader opens a file-selection
# pop-up; when a path is given it must be the full file path.)
dataPath = sub_Ephys
data = Poly5Reader(dataPath)

# Keep the sample matrix and the channel names exposed by the reader object.
samples = data.samples
ch_names = data.ch_names

#%% Reordering textile grid channels

# Switch for recordings whose channel names look like 'R<row>C<col>'.
isTextileGrid = False

if isTextileGrid:
    # Start from the identity ordering of all channels.
    channel_conversion_list = np.arange(0, len(ch_names), dtype=int)

    # Gather (row, zero-padded column, original index) for each channel whose
    # name starts with 'R'.
    RCch = []
    for i, ch in enumerate(ch_names):
        if ch.startswith('R'):
            R, C = ch[1:].split('C')
            RCch.append((R, str(C).zfill(2), i))

    # Tuple sort: by row string first, then by the padded column string.
    RCch.sort()

    # The first len(RCch) slots of the ordering receive the grid channels'
    # original indices in sorted order; the remaining slots stay as they were.
    channel_conversion_list[:len(RCch)] = [entry[2] for entry in RCch]

    # Apply the new ordering to both the samples (rows) and the channel names.
    samples = samples[channel_conversion_list, :]
    ch_names = [ch_names[idx] for idx in channel_conversion_list]

    print(ch_names)

In [None]:
# Conversion to MNE raw array: the reader object builds it and the result is
# kept as `raw` for all later processing.
# NOTE(review): the call takes no arguments, so the separately extracted
# `samples` / `ch_names` variables are not passed in — confirm that any
# reordering applied to them is not expected to carry over into `raw`.
raw = data.read_data_MNE()

In [None]:
# Split the channel list by position: the first 64 names are handled as EEG,
# everything after them as auxiliary channels.
# NOTE(review): 64 is presumably the number of EEG electrodes — confirm.
channels = np.array(raw.ch_names)
eegChs, miscChs = channels[:64], channels[64:]

# Map each channel name to a type: names outside miscChs are 'eeg', the
# 'STATUS' channel is the stimulus channel, all other auxiliaries are 'misc'.
chTypes = {
    channel: ('eeg' if channel not in miscChs
              else 'stim' if channel == 'STATUS'
              else 'misc')
    for channel in channels
}
raw.set_channel_types(chTypes)
raw.set_montage('standard_1005')

In [None]:
# Number of rows read from the behavioural summary file; also used later as
# the number of trials.
nTrials = 160
# Behavioural summary table, limited to its first nTrials rows.
behDF = pd.read_csv(sub_Behav, nrows = nTrials)
# CSV found next to the recording.
# NOTE(review): presumably the trigger log — confirm its schema.
trigDF = pd.read_csv(sub_Trigs)
# The 'cond' and 'result' columns of the summary table.
conds = behDF['cond']
trialRes = behDF['result']

In [None]:
# Load the per-trial kinematics tables in the order given by sub_Trials.
behTrialData = [pd.read_csv(trialFile) for trialFile in sub_Trials]

# One slot per expected trial. A value of 0 means "never filled in" and is
# what nanTimes looks for at the end of this cell.
feedbackInds = np.zeros((nTrials,))
feedbackTimes = np.zeros((nTrials,))

for trialInd, trialDF in enumerate(behTrialData):

    # Value of the 'result' column on the last frame of the trial.
    trialResult = trialDF['result'].to_numpy()[-1]

    if isinstance(trialResult, str):
        # The last frame carries a result label: feedback starts on the first
        # frame that shows that same label.
        fbFrames = np.where(trialDF['result'] == trialResult)[0]
        criterion = "'result' == %r" % (trialResult,)
    else:
        # No label on the last frame: fall back to the first frame on which
        # 'by' is at or below -508.
        # NOTE(review): -508 is presumably the ball's y-position at the edge
        # of the play field — confirm against the task code.
        fbFrames = np.where(trialDF['by'] <= -508)[0]
        criterion = "'by' <= -508"

    if fbFrames.size == 0:
        # Same exception type as the bare [0] lookup would raise, but with a
        # message that says which trial is affected.
        raise IndexError(
            'trial %d: no frame satisfies %s, cannot determine feedback onset'
            % (trialInd, criterion))

    firstFBFrame = fbFrames[0]
    feedbackInds[trialInd] = firstFBFrame
    feedbackTimes[trialInd] = trialDF.iloc[firstFBFrame]['t']

# Indices of trials whose feedback time is still 0 (e.g. no file was loaded
# for that trial slot).
nanTimes = np.where(feedbackTimes == 0)[0]

In [None]:
# Trigger onsets found on the stim channel of the recording.
events = mne.find_events(raw, output = 'onset')

# NOTE(review): 507 is presumably the expected number of triggers for a clean
# session; for any other count the first event is dropped — confirm that this
# is the right correction for every subject.
if events.shape[0] != 507:
    events = events[1:,:]

# Sample index of every trigger.
trigs = events[:,0]

# Per-trial times taken from the behavioural summary table.
bFbTimes = behDF['feedbackTime']
bThTimes = behDF['threshTime']
tg1Times = behDF['startTrig0']

fb2Thresh = bFbTimes - bThTimes

# Classify triggers by the gap (in samples) to the following trigger:
# a gap of at most 25 marks a trial-start trigger, and the trigger two
# positions after it is taken as that trial's end trigger.
tsDiffs = np.diff(trigs)
sTrigs = np.where(tsDiffs <= 25)[0]
eTrigs2 = np.where((tsDiffs <= 3000) & (tsDiffs >= 900))[0]
eTrigs = sTrigs + 2

# Cross-check: the end triggers derived from the start triggers should match
# the ones found independently from the 900-3000 sample gaps. array_equal also
# copes with the two index arrays having different lengths. The mismatch is
# reported rather than raised so the remaining cells can still be inspected.
if not np.array_equal(eTrigs2 + 1, eTrigs):
    print('WARNING: end-trigger indices derived from start triggers (eTrigs) '
          'do not match those derived from trigger gaps (eTrigs2 + 1); '
          'check the trigger sequence before trusting the event codes.')

# Recode the event ids: start trigger -> 1, the following trigger -> 2,
# the end trigger -> 3 (labels in event_dict).
events[sTrigs,2] = 1
events[sTrigs+1,2] = 2
events[eTrigs,2] = 3
sEvents = events[sTrigs]
eEvents = events[eTrigs]
event_dict = {'BallStart': 1, 'BallThresh': 2, 'Feedback': 3}

In [None]:
# Offset between the feedback time found in the kinematics files and the
# behavioural 'threshTime', scaled by the sampling frequency and rounded.
# NOTE(review): presumably both times are in seconds, making this an offset in
# samples — confirm. Also note the offset is added to the triggers recoded as
# 3 ('Feedback'), not to those recoded as 2 ('BallThresh') — confirm intended.
fb2Thresh = np.round((feedbackTimes - bThTimes) * raw.info['sfreq'])
lastEndSamples = events[eTrigs, 0][-nTrials:]
feedbackTimestamps = (fb2Thresh + lastEndSamples).to_numpy()

# Feedback events: the last nTrials end events, moved to the computed samples.
fbEvents = eEvents[-nTrials:].copy()
fbEvents[:, 0] = feedbackTimestamps

# Ball-start events: the last nTrials start events, timing unchanged.
bmEvents = sEvents[-nTrials:].copy()

# Replace the event ids with the trial outcome code ('p' -> 1, 'n' -> -1).
resDict = {'p': 1, 'n': -1}
condDict = {'Presence':1, 'Absence':0}
resArray = [resDict[outcome] for outcome in trialRes]
fbEvents[:, 2] = resArray
bmEvents[:, 2] = resArray

# Split both event sets by the trial's condition code.
trialCond = behDF['cond']
isPresence = trialCond == condDict['Presence']
isAbsence = trialCond == condDict['Absence']

pBMEvs, aBMEvs = bmEvents[isPresence], bmEvents[isAbsence]
pFBEvs, aFBEvs = fbEvents[isPresence], fbEvents[isAbsence]

In [None]:
# Plot the power spectral density of `raw`, computed with fmax = 100.
# The trailing ';' suppresses the textual repr of the returned figure.
raw.compute_psd(fmax = 100).plot();

In [None]:
# Reset the list of bad channels. To mark channels as bad for this subject,
# extend the list with their names (earlier choices are kept below as examples).
raw.info['bads'] = []
# raw.info['bads'].extend(['P7'])
# raw.info['bads'].extend(['F8', 'M2', 'PO7'])
# raw.info['bads'].extend(['F8','M1', 'M2', 'PO7', 'Pz', 'O1', 'P6'])
# Interpolate whatever is listed in raw.info['bads'] (nothing, as written).
raw.interpolate_bads()
# Re-reference using the 'average' reference.
raw.set_eeg_reference(ref_channels='average')

In [None]:
def plotEEG(
    data,
    seColor = {1:'tomato', 2:'magenta', 3:'green'},
    sdColor = 'b',
    events = None,
    butterfly = False,
    highpass = None,
    lowpass = None,
):
    """Open the plot of `data` with this notebook's preferred appearance.

    Parameters
    ----------
    data : object with a ``plot`` method
        The recording to display; its ``plot`` method receives every option
        below as a keyword argument.
    seColor : dict
        Mapping from event id to color, forwarded as ``event_color``.
    sdColor : str
        Trace color, forwarded as ``color``.
    events : array-like or None
        Events to overlay, forwarded unchanged.
    butterfly, highpass, lowpass
        Forwarded unchanged to ``data.plot``.

    Returns
    -------
    None
        The value returned by ``data.plot`` is discarded.
    """
    plotOptions = dict(
        events=events,
        event_color=seColor,
        theme='dark',  # the theme is fixed, not configurable by the caller
        color=sdColor,
        butterfly=butterfly,
        highpass=highpass,
        lowpass=lowpass,
    )
    data.plot(**plotOptions)

In [None]:
# Browse `raw` with the feedback events overlaid: event id 1 is drawn in
# 'tomato' and id -1 in 'b'. highpass / lowpass are passed on to the plot call.
plotEEG(raw, events = fbEvents, seColor = {1:'tomato', -1:'b'}, highpass = 0.1, lowpass = 45)

In [None]:
# Settings shared by both filtered copies made in this cell.
nJobs = 8
method = 'spectrum_fit'   # alternative that was tried: 'iir'
freqs = [50]              # notch frequency list

# Copy 1 (raw_filt): band-pass 0.1-80 on a copy of raw, then notch at `freqs`.
low_cut = 0.1
high_cut = 80
raw_filt = raw.copy().filter(low_cut, high_cut)
raw_filt.notch_filter(method = method, freqs = freqs, n_jobs = nJobs)
# raw_filt.compute_psd(fmax=100, picks = channels).plot();

# Copy 2 (rawCopy): high-pass only at 1 (no upper edge), same notch settings.
# This copy is the one a later cell uses as the ICA input.
low_cut = 1
rawCopy = raw.copy().filter(low_cut, None)
rawCopy.notch_filter(method = method, freqs = freqs, n_jobs=nJobs)

In [None]:
# ICA decomposition: 15 components, 'picard' method, fixed random_state so the
# fit can be repeated with the same seed.
ica = ICA(
    n_components=15,
    method='picard',
    max_iter='auto',
    random_state=97,
)

# Alternative that was tried: fit on fixed-length epochs and pass an
# autoreject rejection threshold to the fit.
# icaEpochs = mne.make_fixed_length_epochs(rawCopy)
# reject = get_rejection_threshold(icaEpochs)
# ica.fit(icaEpochs, reject = reject)

# The decomposition is fitted on the continuous recording held in rawCopy.
icaData = rawCopy
ica.fit(icaData)

In [None]:
# Store the fitted ICA in the subject's 'EEG' folder as
# '<recording name up to its first dot>-ica.fif', replacing any existing file.
# os.path.basename / os.path.join are used instead of splitting and joining on
# '/' so the file name is extracted correctly whatever path separator the
# operating system (and therefore glob) put into sub_Ephys.
icaFile = os.path.join(subjectDir, 'EEG',
                       os.path.basename(sub_Ephys).split('.')[0] + '-ica.fif')
ica.save(icaFile, overwrite=True);

# ICA and plot

In [None]:
# Load the ICA solution stored in the subject's 'EEG' folder as
# '<recording name up to its first dot>-ica.fif'.
# os.path.basename / os.path.join are used instead of splitting and joining on
# '/' so the file name is extracted correctly whatever path separator the
# operating system (and therefore glob) put into sub_Ephys.
icaFile = os.path.join(subjectDir, 'EEG',
                       os.path.basename(sub_Ephys).split('.')[0] + '-ica.fif')
ica = mne.preprocessing.read_ica(icaFile);

In [None]:
# Show the ICA source time courses computed from `raw`, with scrollbars, to
# help decide which components to exclude.
ica.plot_sources(raw, show_scrollbars=True)

In [None]:
# Plot the ICA components for visual inspection.
ica.plot_components()

In [None]:
# Components to remove, chosen by inspection per subject. Selections noted
# for other subjects:
#   SubID 1: [0, 1]
#   SubID 4: [0, 1, 3, 6]
# Other selections that were tried: [0,1,2,3,4] and
# [0,1,2,3,4,5,6,9,8,10,11,12,13,14].
ica.exclude = [0,1,2] # SubID 8

# Remove the excluded components from `raw`. Only `raw` is cleaned here;
# applying the same solution to raw_filt is left disabled.
# ica.apply(raw_filt)
ica.apply(raw)

In [None]:
# Browse both recordings with all trigger events overlaid.
# NOTE(review): when the cells are run in order, `raw` has had the ICA applied
# at this point while `raw_filt` has not (that call is commented out) —
# confirm this is the comparison intended here.
plotEEG(raw, events = events)
plotEEG(raw_filt, events = events)

## Feedback ERPs

In [None]:
# Event ids carried by the feedback events: 1 = 'Interception', -1 = 'Miss'.
# Only the 'Miss' epochs (id -1) are selected here.
fbID = [-1]

# Epoch window around each event and the baseline interval (window start
# up to time 0).
baseline = (None, 0)
# baseline = (-0.5, -0.2)
tmin = -0.2
tmax = 0.5

# Band-pass 1-40 for the ERP analysis. Raw.filter works in place and returns
# the same object, so a copy is filtered: `raw` itself stays untouched and
# re-running this cell does not filter already-filtered data again.
selData = raw.copy().filter(1,40)

# Channels shown in the ERP figures (alternatives: the fronto-central set
# ['FC1', 'FC2', 'FC3', 'FC4', 'FCz'], or all EEG channels via eegChs).
plotChannels = ['P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'Pz']

# Feedback-locked events per condition (ball-start alternatives: pBMEvs / aBMEvs).
pEvents = pFBEvs
aEvents = aFBEvs

# Both conditions are epoched with exactly the same settings.
epochKwargs = dict(tmin=tmin, tmax=tmax, baseline=baseline, preload=True)

pEpochs = mne.Epochs(selData, pEvents, fbID, **epochKwargs)
aEpochs = mne.Epochs(selData, aEvents, fbID, **epochKwargs)

# Average over epochs to obtain one ERP per condition.
pERP = pEpochs.average()
aERP = aEpochs.average()

In [None]:
# Overlay the two condition ERPs ('Presence' vs 'Absence') for the channels in
# plotChannels. The trailing ';' suppresses the repr of the returned value.
mne.viz.plot_compare_evokeds({'Presence': pERP, 'Absence':aERP}, 
                             picks=plotChannels);

In [None]:
# Joint plot of each condition's ERP, restricted to the channels in
# plotChannels: first pERP, then aERP.
pERP.plot_joint(picks = plotChannels);
aERP.plot_joint(picks = plotChannels);