In [None]:
# Adapted from an MNE tutorial by Laurits Dixen, PhD student at the University of Copenhagen, 2024
import mne

import numpy as np
import pandas as pd

from pathlib import Path
from importlib.resources import files

from matplotlib.pyplot import savefig


# Importing Data


In [None]:
# Working directories, relative to where the notebook is run:
# intermediate data files and generated figures.
DATADIR = Path.cwd().joinpath("data")
FIGDIR = Path.cwd().joinpath("figs")
print(DATADIR)

# When True, intermediate .fif files and figures are written under
# DATADIR/FIGDIR and re-read between pipeline stages.
RW = True

In [None]:
# Participant and session to process.
# (Previously used alternative: subject = "laurids".)
subject = "chris"
session = "4"

# Raw recordings are packaged as data/raw/<subject>/s<session>.pq. Rows with any
# missing value are dropped and the index is rebuilt as 0..n-1, so row labels
# equal row positions.
path = files('electroencephalogaming') / 'data' / 'raw' / subject / f's{session}.pq'
df = pd.read_parquet(path).dropna().reset_index(drop=True)

In [None]:
# Legacy recordings (made before the update to trigger IDs) used 99 as a
# sentinel: map it to -1 in `markers` and to 4 in `direction`.
df['markers'] = df['markers'].where(df['markers'] != 99, -1)
df['direction'] = df['direction'].where(df['direction'] != 99, 4)

df['markers'] = df['markers'].where(df['markers'] <= 4, 0)  # Sets markers column to 0 between trials
df['direction'] = df['direction'].where(df['markers'] == 3, 0)  # Sets direction column to 0 whenever arrow is _not_ present on the screen

# Remove "overflow" spans: each starts at a row where markers jumps 0 -> 4 and
# runs up to (not including) the row where markers goes 4 -> 1.
prev_marker = df['markers'].shift()
overflow_beg = df.index[(prev_marker == 0) & (df['markers'] == 4)]
overflow_end = df.index[(prev_marker == 4) & (df['markers'] == 1)]

# Pair every start with the first end that follows it. Zipping the two lists
# positionally misaligns all pairs if an end precedes the first start. Collect
# the rows and drop them in one call, so overlapping spans cannot trigger a
# KeyError for labels that were already removed.
overflow_rows = set()
for beg in overflow_beg:
    pos = overflow_end.searchsorted(beg, side='right')
    if pos < len(overflow_end):
        overflow_rows.update(range(beg, overflow_end[pos]))

df = df.drop(index=sorted(overflow_rows)).reset_index(drop=True)

In [None]:
# All columns except timestamp and markers become the data matrix, transposed
# to (n_columns, n_samples) as mne.io.RawArray expects (channels x times).
# NOTE(review): relies on the remaining columns being the 8 EEG channels
# followed by `direction`, which is used as the trigger row — confirm order.
eeg = df.drop(columns=['timestamp', 'markers']).T.to_numpy()
eeg

In [None]:
# Debug check on the trigger row (row 8): the last sample holding the row's
# maximum value, and the last sample whose value exceeds 4.
trigger_row = eeg[8, :]
a = np.flatnonzero(trigger_row == trigger_row.max())[-1]
b = np.flatnonzero(trigger_row > 4)[-1]
a, b

In [None]:
# Sampling frequency handed to mne.create_info: the Enobio EEG rate (500 Hz)
# plus the display frame rate (110 Hz), per the original note below.
# NOTE(review): summing a frame rate with an EEG sampling rate is unusual;
# verify the actual sample spacing from df['timestamp'] (e.g. the median of its
# diff) before relying on this value for filtering and epoch timing.
SFREQ = 500 + 110 ## Constant ### SFREQ of Enobio + Frame rate

In [None]:
# Crop the recording to the span containing cues, plus a safety margin.
# Row 8 of `eeg` is the 'trigger' row (see ch_positions). The marker-cleanup
# cell sets it to 0 wherever markers != 3, so positive values mark cue samples.
# The start used to be `trigger != -1`. That test matches sample 0 whenever the
# row contains no -1 values, so the start was never cropped. Now the same
# "> 0" criterion is used for both ends.
trigger = eeg[8, :]
trigger_samples = np.flatnonzero(trigger > 0)
if trigger_samples.size == 0:
    raise ValueError('No trigger samples > 0 found in row 8; cannot crop the recording.')
first_samp = trigger_samples[0]
last_samp = trigger_samples[-1]

first_conservative = first_samp - 1*SFREQ  # keep 1 s before the first cue
last_conservative = last_samp + 4*SFREQ    # keep 4 s after the last cue

# Clamp to the valid sample range so the reported numbers match the slice.
crop_start = max(0, first_conservative)
crop_stop = min(eeg.shape[1], last_conservative)

# Crop data to first and last sample based on triggers
print(f'First sample: {crop_start}, last sample: {crop_stop}')
print('Cropping data based on triggers...')
print(f'data cut from {eeg.shape[1]} to {crop_stop - crop_start} samples')
eeg = eeg[:, crop_start:crop_stop]

In [None]:
# Map the amplifier's generic labels ('CH 1'..'CH 8') to 10-20 electrode
# positions, in recording order.
_ELECTRODES = ['C1', 'C2', 'C3', 'C4', 'FC1', 'Cz', 'FC2', 'Pz']
CH_POSITIONS = {f'CH {n}': pos for n, pos in enumerate(_ELECTRODES, start=1)}

In [None]:
# Generic labels 'CH 1'..'CH n', one per entry in CH_POSITIONS.
ch_names = ['CH ' + str(i) for i in range(1, len(CH_POSITIONS) + 1)]

# Swap each generic label for its electrode position (keeping the label if it
# has no mapping), then append the trigger row as a stimulus channel.
ch_positions = list(map(lambda ch: CH_POSITIONS.get(ch, ch), ch_names))
ch_positions.append('trigger')
ch_types = ['eeg'] * len(ch_names) + ['stim']

print(ch_positions)
print(ch_types)

In [None]:
# Measurement info (channel names, sampling rate, channel types), tagged with
# a "<subject>_<session>" identifier.
info = mne.create_info(ch_names=ch_positions, sfreq=SFREQ, ch_types=ch_types)
info['subject_info'] = dict(his_id=f'{subject}_{session}')

# Wrap the cropped array in an MNE Raw object.
raw = mne.io.RawArray(eeg, info)
raw

In [None]:
# Multiply the EEG channels by 1e-6, converting microvolts to volts (the unit
# MNE uses for EEG). NOTE(review): assumes the recorded values are in µV.
# Restrict to picks='eeg'. Without picks, apply_function also rescales the
# 'trigger' stim channel, whose integer event codes must stay intact for
# mne.find_events.
raw = raw.apply_function(lambda x: x*1e-6, picks='eeg')

In [None]:
# Attach standard 10-20 electrode locations, matched by channel name.
montage = mne.channels.make_standard_montage(kind='standard_1020')
raw.set_montage(montage=montage)

In [None]:
# Sanity check of the electrode layout (assignment suppresses the repr).
sensor_fig = raw.plot_sensors(show_names=True)

In [None]:
# raw.plot(n_channels=8, start=0, duration=10, title='Raw EEG data')

In [None]:
# Persist the unfiltered Raw so preprocessing can restart from disk.
if RW:
    outfile = DATADIR.joinpath("scratch", subject, session, f"{session}_raw.fif")
    outfile.parent.mkdir(parents=True, exist_ok=True)
    raw.save(outfile, overwrite=True)

# Preprocessing

In [None]:
# Start preprocessing from the file just written (read lazily; see next cell).
if RW:
    raw = mne.io.read_raw_fif(fname=outfile)

In [None]:
# Bring the samples into memory; filtering requires preloaded data.
raw = raw.load_data()
raw

In [None]:
# Band-pass 0.1-40 Hz on a copy, then notch a narrow peak at 17.8 Hz.
# NOTE(review): the origin of the 17.8 Hz line is not documented here (an
# earlier variant notched 17.45 Hz) — confirm.
high_pass = 0.1
low_pass = 40
notch_freqs = [17.8]

filtered = raw.copy().filter(l_freq=high_pass, h_freq=low_pass)
filtered = filtered.notch_filter(freqs=notch_freqs, trans_bandwidth=3)

In [None]:
# PSD of the unfiltered data up to 55 Hz.
# Save the Figure returned by plot() rather than pyplot's "current figure".
# With show=True, MNE calls plt.show(), which in the inline backend closes the
# figures it displayed, so a later pyplot savefig() can write a blank figure.
psd_raw_fig = raw.compute_psd(fmax=55).plot()
if RW:
    outfile = FIGDIR / subject / session / "psd_raw.png"
    outfile.parent.mkdir(parents=True, exist_ok=True)
    psd_raw_fig.savefig(outfile)

In [None]:
# PSD of the band-passed / notched data up to 55 Hz.
# Save the Figure returned by plot() rather than pyplot's "current figure".
# With show=True, MNE calls plt.show(), which in the inline backend closes the
# figures it displayed, so a later pyplot savefig() can write a blank figure.
psd_filtered_fig = filtered.compute_psd(fmax=55).plot()
if RW:
    outfile = FIGDIR / subject / session / "psd_filtered.png"
    outfile.parent.mkdir(parents=True, exist_ok=True)
    psd_filtered_fig.savefig(outfile)

In [None]:
# filtered.plot(n_channels=8, start=0, duration=10, title='Filtered EEG data')

In [None]:
# Re-reference the EEG channels (in place) to the Cz electrode.
# NOTE(review): after this, Cz itself is identically zero.
standard_ref = ['Cz']
filtered.set_eeg_reference(standard_ref)

# Marking bad channels

In [None]:
def ransac_bad_channels(raw, tstep=None, n_jobs=-1):
    """Detect bad channels with RANSAC on back-to-back fixed-length segments.

    Parameters
    ----------
    raw : mne.io.Raw
        Continuous data to scan.
    tstep : float | None
        Segment length in seconds. ``None`` (default) reads
        ``configs['bad_tstep']``, as before, so the global ``configs`` is only
        needed when this argument is omitted.
    n_jobs : int
        Parallel jobs passed to ``Ransac`` (default -1, as before).

    Returns
    -------
    list of str
        Names of the channels flagged as bad (``ransac.bad_chs_``).

    Notes
    -----
    ``Ransac`` must be in scope. NOTE(review): presumably
    ``from autoreject import Ransac``; it is not imported in this notebook.
    """
    if tstep is None:
        tstep = configs['bad_tstep']
    events = mne.make_fixed_length_events(raw, duration=tstep)
    epochs = mne.Epochs(raw, events, tmin=0.0, tmax=tstep, baseline=None, preload=True)

    ransac = Ransac(n_jobs=n_jobs).fit(epochs)

    print('Bad channels detected: ')
    print('\n'.join(ransac.bad_chs_))  # list of bad channels
    return ransac.bad_chs_

In [None]:
def run_ICA(raw):
    """Fit ICA on 1 Hz high-passed fixed-length epochs; let the user choose exclusions.

    Parameters
    ----------
    raw : mne.io.Raw
        Preloaded continuous data. It is not modified: the 1 Hz high-pass is
        applied to a copy.

    Returns
    -------
    mne.preprocessing.ICA
        The fitted ICA, with ``exclude`` set from the indices typed by the user.

    Notes
    -----
    Reads the global ``configs`` (``ica_variances_explained``, ``random_state``,
    ``segmentation_tstep``, ``reject_blink``, ``ica_method``) and saves figures
    under the global ``exp_folder / 'plots'``; both must be defined.
    """
    # High-pass at 1 Hz to remove slow drifts before fitting ICA. Work on a copy
    # because Raw.filter operates in place, and the caller's data should keep its
    # own filter settings (the ICA is applied back to it afterwards).
    raw = raw.copy().filter(l_freq=1, h_freq=None)

    # Set parameters from configs
    ica_variances_explained = configs['ica_variances_explained']
    random_state = configs['random_state']
    t_step = configs['segmentation_tstep']

    # Subject identifier, used in plot titles and file names
    s_info = raw.info['subject_info']['his_id']

    # Cut into fixed-length epochs and drop those exceeding the rejection threshold
    events = mne.make_fixed_length_events(raw, duration=t_step)
    epochs = mne.Epochs(raw, events, tmin=0, tmax=t_step, baseline=None, preload=True, reject=None, reject_by_annotation=True)
    # threshold = get_rejection_threshold(epochs)
    threshold = configs['reject_blink']
    print(f'Rejection threshold: {threshold}')
    epochs.drop_bad(reject=threshold)

    # Run ICA
    ica = mne.preprocessing.ICA(n_components=ica_variances_explained,
                                method=configs['ica_method'],
                                random_state=random_state)
    ica.fit(epochs, tstep=t_step)

    # Plot components and sources
    title = f'ICA - {s_info}'
    # plot_components may return a single Figure or a list of Figures; save the
    # first page either way.
    components = ica.plot_components(show=True, title=f'{title} - components')
    components_fig = components[0] if isinstance(components, list) else components
    components_fig.savefig(exp_folder / 'plots' / f'{title} - components.png')

    sources_fig = ica.plot_sources(raw, show=False, title=f'{title} - sources', show_scrollbars=True, start=200, stop=300)
    sources_fig.savefig(exp_folder / 'plots' / f'{title} - sources.png')

    # Manually select components to exclude: comma-separated indices, e.g. "0, 3".
    # An empty answer excludes nothing instead of raising ValueError.
    blink_component = input()
    ica.exclude = [int(x) for x in blink_component.split(',') if x.strip()]

    return ica

In [None]:
 # Load bad channels 
# with open(badsfile, 'r') as f:
#     bads = f.read().splitlines()
# filtered.info['bads'] = bads

# if configs['run_ICA']:
#     if icafile.exists() and not configs['overwrite']:
#         print("skipping ICA")
#     else:
#         ica = run_ICA(filtered)
#         ica.save(icafile, overwrite=True)

# ica = mne.preprocessing.read_ica(icafile)
# filtered = ica.apply(filtered)

In [None]:
# Persist the preprocessed (filtered, Cz-referenced) data.
# NOTE(review): MNE warns that "_prep.fif" does not follow its raw-file naming
# convention (e.g. "_raw.fif"); harmless, but consider renaming.
if RW:
    outfile = DATADIR.joinpath("scratch", subject, session, f"{session}_prep.fif")
    filtered.save(outfile, overwrite=True)

# Segmentation

In [None]:
# Load the preprocessed data for segmentation. With RW, the saved file is
# re-read. Without RW, fall back to the in-memory `filtered` data. Previously
# `raw` silently remained the unfiltered, unreferenced recording in that case.
if RW:
    infile = DATADIR / "scratch" / subject / session / f"{session}_prep.fif"
    raw = mne.io.read_raw_fif(infile)
else:
    raw = filtered.copy()

In [None]:
# Event rows (sample, previous value, new value) for value changes on the
# 'trigger' stim channel.
events = mne.find_events(raw, stim_channel='trigger', verbose='INFO')

In [None]:
# Epoch from 0.5 s to 3.5 s after each event. Baseline correction is deferred
# to a later cell (after the time shift).
tmin, tmax = 0.5, 3.5
epochs = mne.Epochs(raw, events, tmin=tmin, tmax=tmax, baseline=None, preload=True)

In [None]:
# Inspect the final event row.
# TODO: decide whether a trailing event should be dropped.
last_event = epochs.events[-1]
last_event

In [None]:
# Shift every epoch's time axis by -60 ms.
# NOTE(review): presumably compensates a fixed trigger/display latency — confirm.
tshift = -0.060
epochs = epochs.shift_time(tshift, relative=True)

In [None]:
# Baseline-correct with the 0.5-1 s window of the (shifted) epoch time axis.
baseline = (0.5, 1)
epochs = epochs.apply_baseline(baseline)

In [None]:
# Rows where a cue appears (direction switches from 0 to nonzero); output is
# suppressed. Same selection as the metadata cell that follows.
df.loc[df['direction'].shift().eq(0) & df['direction'].ne(0)].copy();

In [None]:
# One metadata row per cue onset: rows where `direction` switches from 0 to a
# nonzero value. These are intended to line up one-to-one with the events found
# on the trigger channel. (Also tried: dropping the last row with .iloc[:-1].)
cue_onset = (df['direction'].shift() == 0) & (df['direction'] != 0)
# .loc[...].copy() yields an independent frame, so the column assignments
# below cannot hit pandas' SettingWithCopy path.
metadata = df.loc[cue_onset].copy()
metadata['subject'] = subject
metadata['session'] = session
# `direction` stays numeric so queries such as 'direction == 90' keep working.
# The former `metadata['direction'].astype(str)` discarded its result and had
# no effect, so it is removed.

metadata.tail()

In [None]:
# Add metadata to epochs
# One row per epoch, enabling pandas-query selection such as
# epochs['direction == 90'] in the cells below.
# NOTE(review): presumably MNE rejects a row count that differs from the
# number of epochs — confirm if cropping/dropping can change either side.
epochs.metadata = metadata

In [None]:
# Epochs whose cue direction was 90 (labelled 'right' in the evoked cell).
right_epochs = epochs['direction == 90']
right_epochs

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Average of all direction == 90 epochs; saved as figs/<subject>/<session>/epochs_90.png when RW.
fig_90 = epochs['direction == 90'].average().plot()
if RW:
    outfile = FIGDIR.joinpath(subject, session, "epochs_90.png")
    outfile.parent.mkdir(parents=True, exist_ok=True)
    fig_90.savefig(outfile)

In [None]:
# Average of all direction == 180 epochs; saved as figs/<subject>/<session>/epochs_180.png when RW.
fig_180 = epochs['direction == 180'].average().plot()
if RW:
    outfile = FIGDIR.joinpath(subject, session, "epochs_180.png")
    outfile.parent.mkdir(parents=True, exist_ok=True)
    fig_180.savefig(outfile)

In [None]:
# Average of all direction == 270 epochs; saved as figs/<subject>/<session>/epochs_270.png when RW.
fig_270 = epochs['direction == 270'].average().plot()
if RW:
    outfile = FIGDIR.joinpath(subject, session, "epochs_270.png")
    outfile.parent.mkdir(parents=True, exist_ok=True)
    fig_270.savefig(outfile)

In [None]:
# from autoreject import AutoReject

# # Auto reject epochs
# ar = AutoReject(random_state=42, n_jobs=-1)
# epochs = ar.fit_transform(epochs)

# Save Epochs

In [None]:
# Save the epochs; the "_epo.fif" suffix follows MNE's epochs-file naming.
if RW:
    outfile = DATADIR.joinpath("scratch", subject, session, f"{session}_epo.fif")
    epochs.save(fname=outfile, overwrite=True)

In [None]:
# Per-condition averages, keyed by condition label (mapping taken from the
# original cell: 90 -> right, 180 -> feet, 270 -> left).
_COND_DIRECTIONS = {'right': 90, 'feet': 180, 'left': 270}
evokeds = {cond: epochs[f'direction == {angle}'].average()
           for cond, angle in _COND_DIRECTIONS.items()}

In [None]:
# Topomap grid: one row per condition, one column per time point (s). `avges`
# gives the averaging window (s) used at each time point.
num_evos = len(evokeds)
times = [0.5, 1, 1.5, 2, 2.5, 3]
avges = [0.025, 0.05, 0.1, 0.1, 0.1, 0.1]
fig, axes = plt.subplots(nrows=num_evos, ncols=len(times),
                         figsize=(4*len(times), 4*num_evos), squeeze=False)
title = f'{subject}_{session} topos'

for row, (cond, evoked) in zip(axes, evokeds.items()):
    evoked.plot_topomap(ch_type='eeg', times=times, average=avges,
                        colorbar=False, axes=row, show=False)
    row[0].set_ylabel(cond)

fig.suptitle(title, fontsize=28)
if RW:
    outfile = FIGDIR.joinpath(subject, session, "topology.png")
    outfile.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(outfile)
plt.show()


In [None]:
# Compare the condition averages over central electrodes with four channel
# combinations: mean, median, global field power, and a per-sample maximum.
# Panels now fill row-major (mean top-left); the previous `int(i<2)` row index
# put the first two panels in the bottom row.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 16))

def custom_func(x):
    """Combine channels by taking the maximum over axis 1.

    NOTE(review): assumes MNE passes data shaped (n_evokeds, n_channels, n_times)
    to combine callables, making axis 1 the channel axis — confirm.
    """
    return x.max(axis=1)

# NOTE(review): Cz is the reference channel and is all zeros after
# re-referencing, so it pulls mean/median/gfp toward zero.
lpp_picks = ['C1', 'Cz', 'C2']
for ax, combine in zip(axes.flat, ['mean', 'median', 'gfp', custom_func]):
    mne.viz.plot_compare_evokeds(evokeds, picks=lpp_picks, combine=combine, axes=ax, show=False)

title = f'{subject}_compare'
# suptitle must be called; `fig.suptitle = title` replaced the method with a
# string and never set a title.
fig.suptitle(title)
if RW:
    outfile = FIGDIR / subject / session / "evoked.png"
    outfile.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(outfile)

In [None]:
# Deliberately stop "Run All" here: the cells below are scratch checks on `df`
# and are not part of the pipeline.
assert False

# Unused

In [None]:
# markers/direction at every row where `direction` switches between zero and
# nonzero (in either direction). Presumably compared against `after`.
direction_prev = df['direction'].shift()
direction_switch = ((direction_prev.ne(0) & df['direction'].eq(0))
                    | (direction_prev.eq(0) & df['direction'].ne(0)))
before = df.loc[direction_switch, ['markers', 'direction']]
# before

In [None]:
# Same selection as `before`: markers/direction at every row where `direction`
# switches between zero and nonzero (in either direction).
direction_prev = df['direction'].shift()
direction_switch = ((direction_prev.ne(0) & df['direction'].eq(0))
                    | (direction_prev.eq(0) & df['direction'].ne(0)))
after = df.loc[direction_switch, ['markers', 'direction']]
# after

In [None]:
# df.iloc[7500:7510]

In [None]:
# Row labels where markers jumps from 0 to 4.
df.index[df['markers'].shift().eq(0) & df['markers'].eq(4)]

In [None]:
# Distinct direction codes present in the cleaned frame.
df.direction.unique()

In [None]:
# Rows whose direction equals the previous row's direction.
df.loc[df['direction'].eq(df['direction'].shift())]

In [None]:
# df.copy()[(df['direction'].shift() == 0) & (df['direction'] != 0)] # pd.read_csv(behaviour_file)