In [10]:
# Libraries

In [14]:
import os
import warnings

import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

# Apply seaborn's "whitegrid" style to the plots in this notebook.
sns.set(style="whitegrid")

# Prefer the GPU when PyTorch reports one is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [16]:
# Data Loading & Preprocessing

In [None]:
# Root folder of the dataset. load_all_subjects() lists this directory and
# treats every entry whose name starts with "sub" as one subject. The path is
# relative, so it is resolved against the kernel's current working directory.
dataset_path = "../APPLESEED_Dataset"  # adjust path if needed

def load_sub_session(sub_path, ses, l_freq=1.0, h_freq=40.0):
    """Load and preprocess every BrainVision recording of one subject session.

    Looks in ``<sub_path>/ses-<ses>/eeg`` for ``.vhdr`` header files, reads each
    one fully into memory (``preload=True``), band-pass filters it and
    re-references it to the average reference.

    Parameters
    ----------
    sub_path : str
        Path of the subject directory.
    ses : str or int
        Session identifier; interpolated into the ``ses-<ses>`` folder name.
    l_freq : float, optional
        Lower band-pass edge in Hz handed to ``raw.filter`` (default 1.0).
    h_freq : float, optional
        Upper band-pass edge in Hz handed to ``raw.filter`` (default 40.0).

    Returns
    -------
    list
        One preprocessed raw object per ``.vhdr`` file, ordered by file name.
        Empty when the folder exists but holds no ``.vhdr`` file.

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` when the session's ``eeg`` folder does
        not exist or is not a directory.
    """
    eeg_path = os.path.join(sub_path, f"ses-{ses}", "eeg")
    # os.listdir returns entries in arbitrary order; sort so that the order of
    # the returned recordings is reproducible across runs and platforms.
    vhdr_files = sorted(f for f in os.listdir(eeg_path) if f.endswith(".vhdr"))
    raws = []
    for vhdr in vhdr_files:
        raw = mne.io.read_raw_brainvision(os.path.join(eeg_path, vhdr), preload=True)
        raw.filter(l_freq, h_freq)  # bandpass filter
        raw.set_eeg_reference('average')
        raws.append(raw)
    return raws

def load_all_subjects(dataset_path):
    """Load sessions 1-4 of every subject found under ``dataset_path``.

    Every entry of ``dataset_path`` whose name starts with ``"sub"`` is treated
    as a subject directory. For each subject, sessions "1" to "4" are loaded
    with ``load_sub_session``. A session that cannot be loaded (for example
    because its folder is missing) is skipped, and a ``RuntimeWarning`` naming
    the subject, session and error is emitted so the gap is visible.

    Parameters
    ----------
    dataset_path : str
        Root directory that contains the ``sub*`` folders.

    Returns
    -------
    all_data : list
        Recordings from every session that loaded, subjects in sorted name
        order and sessions in ascending order.
    labels : numpy.ndarray
        Integer array aligned with ``all_data``; each entry is the session
        number (1-4) the recording came from.

    Raises
    ------
    OSError
        Propagated from ``os.listdir`` when ``dataset_path`` cannot be listed.
    """
    # Sorted because os.listdir gives no ordering guarantee.
    subjects = sorted(d for d in os.listdir(dataset_path) if d.startswith("sub"))
    all_data = []
    labels = []
    for sub in subjects:
        sub_path = os.path.join(dataset_path, sub)
        for ses in ["1", "2", "3", "4"]:
            try:
                raws = load_sub_session(sub_path, ses)
            except Exception as exc:
                # Best-effort loading: skip this session but report why.
                # `Exception` (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate so a long load can still be aborted.
                warnings.warn(
                    f"Skipping {sub} ses-{ses}: {exc!r}",
                    RuntimeWarning,
                    stacklevel=2,
                )
                continue
            all_data.extend(raws)
            # Example labelling: the label is the session index 1-4 (the
            # sessions are described as the 4, 8, 12 and 16 week visits), not
            # the number of weeks itself.
            labels.extend([int(ses)] * len(raws))
    # Explicit dtype keeps the array integer-typed even when nothing loaded.
    return all_data, np.array(labels, dtype=int)

# Load every recording under dataset_path. all_raws holds the preprocessed
# recordings and labels the matching session numbers, as returned by
# load_all_subjects. Each file is read with preload=True, so this cell can be
# slow and memory-hungry.
print("Loading EEG data...")
all_raws, labels = load_all_subjects(dataset_path)
# Check this count: sessions that failed to load are skipped inside
# load_all_subjects rather than raising here.
print(f"Total EEG recordings: {len(all_raws)}")

Loading EEG data...
Extracting parameters from ../APPLESEED_Dataset\sub-01\ses-1\eeg\sub-01_ses-1_task-appleseedexample_eeg.vhdr...
Setting channel info structure...
Reading 0 ... 3431099  =      0.000 ...   686.220 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 16501 samples (3.300 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting parameters from ../APPLESEED_Dataset\sub-01\ses-2\eeg\sub-01_ses-2_task-appleseedexample_ee