## Classification pipeline

In [None]:
# imports
import numpy as np
import pandas as pd
import pickle
import mne
from mne_bids import BIDSPath, read_raw_bids
from typing import List
import xarray as xr
from pathlib import Path
from scipy.signal import resample, butter, sosfiltfilt, iirnotch, tf2sos
from IPython.display import clear_output
from tqdm import tqdm

## Preprocessing OTKA

In [77]:
def process_input(path: str,
                  subject: str,
                  task: str,
                  channels: List[str],
                  resampling_frq=200,
                  notchfilter_frq=50,
                  filter_bounds=[0.3, 75],
                  verbose=False):
    """Load one BIDS recording and cut it into 30-second epochs of 1-second patches.

    Parameters
    ----------
    path : str
        Root of the BIDS dataset.
    subject, task : str
        BIDS subject and task labels; the session is fixed to '01'.
    channels : list of str
        Channel names to keep, in the order they should appear in the output.
    resampling_frq : number or None
        Target sampling rate in Hz. ``None`` keeps the recording's own rate.
    notchfilter_frq : number
        Frequency passed to the notch filter.
    filter_bounds : sequence of two numbers
        Lower and upper band-pass edges. Only read, never mutated.
    verbose : bool or int
        Forwarded to the MNE / MNE-BIDS calls.

    Returns
    -------
    numpy.ndarray
        Array of shape (n_epochs, n_channels, 30, sfreq) in microvolts, where
        ``sfreq`` is the sampling rate after the optional resampling step.
    """
    # open data
    print(f'>>>>>>>>preprocessing {subject}-{task}')
    bids_path = BIDSPath(subject=subject,
                         session='01',
                         task=task,
                         root=path
                         )
    raw = read_raw_bids(bids_path, extra_params={'preload': True}, verbose=verbose)
    raw.set_montage('standard_1020')

    # pick eeg channels
    print('picking eeg channels...')
    # ['T3', 'T4', 'T5', 'T6'] channels that are only in the 10-20 system
    # replaced with their equivalent name in the 10-10 system [T7, T8, P7, P8]
    raw.pick(channels, verbose=verbose)

    # interpolate bad channels if there is any
    print('interpolating bad channels...')
    raw.interpolate_bads(verbose=verbose)

    # resampling (the original tested the imported scipy `resample` function here,
    # which is never None, so the guard could not be switched off)
    if resampling_frq is not None:
        raw.resample(resampling_frq, verbose=verbose)

    raw.filter(l_freq=filter_bounds[0], h_freq=filter_bounds[1], verbose=verbose)
    raw.notch_filter((notchfilter_frq), verbose=verbose)
    eeg_array = raw.get_data().T * 10**6  # volts to microvolts

    # Epoch with the rate the data actually has, instead of a hard-coded 200 Hz.
    sfreq = int(round(raw.info['sfreq']))
    epoch_len = 30 * sfreq   # samples per 30-second epoch
    margin = 60 * sfreq      # samples dropped at the start and at the end

    points, chs = eeg_array.shape
    remainder = points % epoch_len
    # The two margins together span exactly four epochs, so removing them plus
    # the remainder leaves a whole number of epochs.
    eeg_array = eeg_array[margin:-(remainder + margin), :]
    eeg_array = eeg_array.reshape(-1, 30, sfreq, chs)
    eeg_array = eeg_array.transpose(0, 3, 1, 2)

    return eeg_array

In [78]:
# Locations of the BIDS EEG recordings and of the behavioural master table.
eeg_path = '/Volumes/Extreme_SSD/PhD/OTKA_study1/EEG_data/BIDS/'
behvioral_path = '../EEGModalNet/data/OTKA/PLB_HYP_data_MASTER.csv'

# Keep only rows that carry a BIDS id and rewrite the id as a zero-padded string.
behavioral = pd.read_csv(behvioral_path).dropna(subset='bids_id')
behavioral = behavioral.assign(
    bids_id=behavioral['bids_id'].map(lambda raw_id: f'{int(raw_id):02}')
)
bid_id = behavioral['bids_id'].values

# Index by BIDS id so labels can be looked up per subject.
behavioral = behavioral.set_index('bids_id')
gender, hypnotizability = behavioral['gender'], behavioral['hypnotizability_total']

In [79]:
# Zero-padded labels for participants 1 through 51.
subjects = [str(number).zfill(2) for number in range(1, 52)]
# Task labels loaded for each of those participants.
tasks = [f'experience{number}' for number in range(1, 5)]
# Channels to keep, in output order.
channels = [
    'Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
    'F7', 'F8', 'T7', 'T8', 'P7', 'P8', 'Fz', 'Cz', 'Pz',
]

def load_trial(path, subj, task):
    """Preprocess one subject/task recording with the notebook-level `channels` list.

    Returns whatever `process_input` returns for that recording; MNE output is
    silenced by passing ``verbose=0``.
    """
    return process_input(path, subj, task, channels, verbose=0)

def process_subject_trials(path, subj, tasks):
    """Preprocess every task of one subject and stack the epochs along axis 0.

    Parameters
    ----------
    path : str
        Root of the BIDS dataset, forwarded to `load_trial`.
    subj : str
        Subject label.
    tasks : iterable of str
        Task labels to load for this subject.

    Returns
    -------
    numpy.ndarray
        The per-task arrays joined along their first (epoch) axis.
    """
    # Hand the list straight to vstack: tasks can yield different numbers of
    # epochs, and wrapping such ragged arrays in np.array() raises in current NumPy.
    trial_data = np.vstack([load_trial(path, subj, task) for task in tasks])
    clear_output()  # drop the accumulated progress prints for this subject
    return trial_data

# Preprocess every participant in `subjects`; entry i holds the stacked epochs of subjects[i].
all_eeg_data = list(map(lambda subj: process_subject_trials(eeg_path, subj, tasks), subjects))

In [80]:
# The last participant is missing two sessions, so that participant is processed
# separately with the two available tasks and appended afterwards.
subj, tasks = '52', ['experience1', 'experience4']
sub_52_data = process_subject_trials(eeg_path, subj, tasks)

all_eeg_data += [sub_52_data]

# Update the subject ids so they line up with all_eeg_data again.
subjects = [str(number).zfill(2) for number in range(1, 53)]

In [81]:
# Absolute amplitude above which an epoch is treated as an artifact and dropped.
amplitude_threshold = 100

# For each subject (keyed by position in all_eeg_data), the indices of the
# epochs whose largest absolute value exceeds the threshold.
segments_to_be_excluded = {
    f'sub_{sub}': [i for i, sample in enumerate(subject_epochs)
                   if np.max(np.abs(sample)) > amplitude_threshold]
    for sub, subject_epochs in enumerate(all_eeg_data)
}

# exclude bad epochs
filtered_x = [
    np.delete(subject_epochs, segments_to_be_excluded[f'sub_{sub}'], axis=0)
    for sub, subject_epochs in enumerate(all_eeg_data)
]

# Build the per-subject container explicitly. np.array(list, dtype=object) only
# yields a 1-D array of arrays when the subjects differ in epoch count; with
# equal counts it silently becomes an N-d object array of scalars, and stacking
# its rows later produces object dtype instead of floats.
x = np.empty(len(filtered_x), dtype=object)
for sub, subject_epochs in enumerate(filtered_x):
    x[sub] = subject_epochs

In [86]:
# Flatten the per-subject epoch arrays into one list of epochs, recording for
# every epoch the subject it came from, its index within that subject, and the
# subject-level labels.
all_segments = []
subject_ids = []
epoch_ids = []
gender_of_epoch = []
hypnotizablity_of_epoch = []

for subj_idx, subj_array in zip(subjects, x):
    n_epochs = len(subj_array)
    if n_epochs == 0:
        continue  # nothing to add for a subject without remaining epochs
    all_segments.extend(subj_array)
    subject_ids.extend([subj_idx] * n_epochs)
    gender_of_epoch.extend([gender[subj_idx]] * n_epochs)
    hypnotizablity_of_epoch.extend([hypnotizability[subj_idx]] * n_epochs)
    epoch_ids.extend(range(n_epochs))

# Stack into a single array: (total_epochs, channels, segments per epoch, samples per segment)
all_segments = np.stack(all_segments)

subject_epoch = [
    f'{sub_id}_epoch-{epo_id}'
    for sub_id, epo_id in zip(subject_ids, epoch_ids)
]

# Wrap in a labelled DataArray; the per-epoch labels travel as attributes.
eeg_coords = {
    "subject_epoch": subject_epoch,
    "channel": channels,
    "segment": np.arange(all_segments.shape[2]),
    "time": np.arange(all_segments.shape[3]),
}
data = xr.DataArray(
    all_segments,
    dims=("subject_epoch", "channel", "segment", "time"),
    coords=eeg_coords,
    attrs={'gender': gender_of_epoch,
           'hypnotizablity': hypnotizablity_of_epoch},
    name="eeg"
)

# save
# data.to_netcdf('data/OTKA_preprocessed_for_Cbramod.nc5', engine='h5netcdf')

## Feature extraction

In [6]:
# Load the epoched OTKA recordings from disk. The file name matches the
# (commented-out) to_netcdf call at the end of the OTKA preprocessing section,
# so that export has to have been run once before this cell works.
ds = xr.open_dataarray('data/OTKA_preprocessed_for_Cbramod.nc5')

In [None]:
from argparse import Namespace
import torch 
from torch import nn
from CBraMod.models.cbramod import CBraMod
from einops.layers.torch import Rearrange

# Settings for feature extraction, built with keyword arguments instead of an
# unpacked dict. The TODO markers are carried over from the original cell.
DEFAULT_PARAMS = Namespace(
    foundation_dir="pretrained_weights/pretrained_weights.pth",
    features_file_path="data/CBraMod_features_<DOWNSTREAM_TASK>.pt",
    num_of_classes=2,
    device='cpu',
    data_dir="data/LEMON/",
    channels=['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2',
              'F7', 'F8', 'T7', 'T8', 'P7', 'P8', 'Fz', 'Cz', 'Pz'],
    downstream_task="gender",
    segment_size=30,  # TODO
    batch_size=1024,
    bandpass_filter=0.3,
    n_channels=19,
    n_segments=2,  # TODO ?
)

# Fill the placeholder in the output path with the lower-cased task name.
_task_label = DEFAULT_PARAMS.downstream_task.lower()
DEFAULT_PARAMS.features_file_path = DEFAULT_PARAMS.features_file_path.replace(
    "<DOWNSTREAM_TASK>", _task_label)

params = DEFAULT_PARAMS

# Model input: the stored epochs as a single-precision tensor.
x = torch.tensor(ds.to_numpy(), dtype=torch.float32)

In [None]:
# feature extraction
batch_size = 50  # Adjust batch size based on available memory
num_batches = len(x) // batch_size + (1 if len(x) % batch_size != 0 else 0)

backbone = CBraMod(
    in_dim=200, out_dim=200, d_model=200,
    dim_feedforward=800, seq_len=30,
    n_layer=12, nhead=8
)

backbone.load_state_dict(
    torch.load(params.foundation_dir,
                map_location=torch.device(params.device), weights_only=True))

backbone.eval()
# Replace the output projection with a pass-through so the saved tensors are
# what the backbone produces before that projection.
backbone.proj_out = nn.Identity()

# torch.save does not create missing directories.
Path("data/OTKA_extracted_features").mkdir(parents=True, exist_ok=True)

# Built once: flattens everything except the batch axis.
flatten_features = Rearrange('b c s p -> b (c s p)')

# Inference only: without no_grad every forward pass would also build and hold
# an autograd graph for the batch.
with torch.no_grad():
    for batch_idx in tqdm(range(num_batches), leave=True):
        start_idx = batch_idx * batch_size
        end_idx = min((batch_idx + 1) * batch_size, len(x))
        batch = x[start_idx:end_idx]
        features = flatten_features(backbone(batch)).contiguous()
        # store batch features
        torch.save(features.detach(),
                   f"data/OTKA_extracted_features/CBraMod_features_{batch_idx:02}.pt")

100%|██████████| 15/15 [01:45<00:00,  7.01s/it]


In [14]:
# Re-assemble the per-batch feature files into one (n_epochs, n_features) tensor.
# batch_size must equal the value used when the files were written.
batch_size = 50
num_feature_files = -(-len(x) // batch_size)  # ceil(len(x) / batch_size)

all_features = None
rows_filled = 0
for i in range(num_feature_files):
    # Address files by batch index rather than by sorted glob: lexicographic
    # order breaks once indices need more than two digits, and a glob would
    # also pick up stale files left over from earlier runs.
    path_dir = Path('data/OTKA_extracted_features/') / f'CBraMod_features_{i:02}.pt'
    feature = torch.load(path_dir, weights_only=True)
    if all_features is None:
        # Take the feature width from the file itself instead of relying on a
        # `features` variable left behind by the extraction cell.
        all_features = torch.zeros((x.shape[0], feature.shape[1]))
    start_idx = i * batch_size
    all_features[start_idx:start_idx + feature.shape[0]] = feature
    rows_filled += feature.shape[0]

assert rows_filled == len(x), (
    f"feature files cover {rows_filled} epochs but x has {len(x)}")

In [16]:
# Per-epoch labels: the subject id is the part of each "subject_epoch" label
# before the first underscore; gender is encoded as 0 for 'Female' and 1 for
# every other value.
subject_ids = [label.partition('_')[0] for label in ds.subject_epoch.values]
gender = [int(label != 'Female') for label in ds.gender]

In [18]:
# Bundle the features with their labels and write them to the configured path.
feature_bundle = {
    'features': all_features,
    'gender': gender,
    'subject_ids': subject_ids,
}
torch.save(feature_bundle, params.features_file_path)
print("features saved to", params.features_file_path)

features saved to data/LEMON/CBraMod_features_gender.pt


## Preprocessing LEMON

In [None]:
# open lemon data

# channels: ['T3', 'T4', 'T5', 'T6'] channels that are only in the 10-20 system replaced with their equivalent name in the 10-10 system [T7, T8, P7, P8].
# NOTE: this rebinds the notebook-level `channels` with a different ordering than the OTKA section uses.
channels = ['Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8', 'T7', 'C3',
            'Cz', 'C4', 'T8', 'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'O2']  # FIXME: odering might be different in the data they inputed to their model

n_subjects = 202
fs = 128  # sampling rate (Hz) this cell assumes for the stored recordings
lemon = xr.open_dataset('data/LEMON/eeg_eo_ec.nc5')

x = lemon['eye_closed'].sel(subject=lemon.subject[:n_subjects], channel=channels).to_numpy() * 10**6  # Volts to microvolts

# FIXME resampling to 200Hz
# n_samples = int((x.shape[-1] / sampling_rate) * downsample_frq)
# x = resample(x, num=n_samples, axis=-1)

# high-pass filter (only a lower edge is applied here)
sos = butter(4, 0.3, btype='highpass', fs=fs, output='sos')
x = sosfiltfilt(sos, x, axis=-1)

# notchfilter
f0 = 50
Q = 30
w0 = f0 / (fs / 2)  # Normalized Frequency
b, a = iirnotch(w0, Q)
sos = tf2sos(b, a)
x = sosfiltfilt(sos, x, axis=-1)

subs, chs, points = x.shape

# epoching
# x is laid out as (subjects, channels, time). Time has to be split while the
# channel axis is still in front of it and the axes swapped afterwards: a direct
# reshape to (subs, -1, chs, 30, fs) does not move any data and would mix
# samples from different channels into each segment.
epoch_len = 30 * fs
n_epochs = points // epoch_len
x = x[..., :n_epochs * epoch_len]  # drop a trailing partial epoch, if any
x = x.reshape(subs, chs, n_epochs, 30, fs)
x = x.transpose(0, 2, 1, 3, 4)  # subjects, segments, channels, 30 seconds, sampling rate

In [None]:
# Absolute amplitude above which an epoch is treated as an artifact and dropped.
amplitude_threshold = 100

# For each subject (keyed by position along the first axis of x), the indices
# of the epochs whose largest absolute value exceeds the threshold.
segments_to_be_excluded = {
    f'sub_{sub}': [i for i, sample in enumerate(subject_epochs)
                   if np.max(np.abs(sample)) > amplitude_threshold]
    for sub, subject_epochs in enumerate(x)
}

# exclude bad epochs
filtered_x = [
    np.delete(subject_epochs, segments_to_be_excluded[f'sub_{sub}'], axis=0)
    for sub, subject_epochs in enumerate(x)
]

# Build the per-subject container explicitly. Every subject starts with the
# same number of epochs here, so whenever the rejection leaves equal counts
# np.array(list, dtype=object) would produce an N-d object array of scalars
# rather than a 1-D array of float arrays.
x = np.empty(len(filtered_x), dtype=object)
for sub, subject_epochs in enumerate(filtered_x):
    x[sub] = subject_epochs

In [None]:
# Demographic table; the gender column name documents its coding (1 = female, 2 = male).
demog = pd.read_csv('data/LEMON/Demographics.csv')
age = demog['Age']
gender = demog['Gender_ 1=female_2=male'].values

In [None]:
# Flatten the per-subject epoch arrays into one list of epochs, recording for
# every epoch the subject it came from, its index within that subject, and the
# subject-level labels.
all_segments = []
subject_ids = []
epoch_ids = []
gender_of_epoch = []
age_of_epoch = []

# Convert the subject coordinate once instead of once per epoch.
subject_labels = lemon.subject.values.tolist()

# NOTE(review): gender/age are looked up by position, which presumes the rows of
# the demographics table follow the same order as lemon.subject - TODO confirm.
for subj_idx, subj_array in enumerate(x):
    n_epochs = len(subj_array)
    if n_epochs == 0:
        continue  # every epoch of this subject was rejected
    all_segments.extend(subj_array)
    subject_ids.extend([subject_labels[subj_idx]] * n_epochs)
    gender_of_epoch.extend([int(gender[subj_idx])] * n_epochs)
    age_of_epoch.extend([age[subj_idx]] * n_epochs)
    epoch_ids.extend(range(n_epochs))

# Step 2: Convert list to array
# shape: (total_epochs, channels, segments per epoch, samples per segment)
all_segments = np.stack(all_segments)

subject_epoch = [f'{sub_id}_epoch-{epo_id}' for sub_id, epo_id in zip(subject_ids, epoch_ids)]

data = xr.DataArray(
    all_segments,
    dims=("subject_epoch", "channel", "segment", "time"),
    coords={
        "subject_epoch": subject_epoch,
        "channel": channels,
        "segment": np.arange(all_segments.shape[2]),
        "time": np.arange(all_segments.shape[3]),
    },
    attrs={'gender': gender_of_epoch,
           'age': age_of_epoch},
    name="eeg"
)

In [None]:
# data.to_netcdf('data/LEMON/lemon_preprocessed_for_cbramod.nc5', engine='h5netcdf')