In [88]:
from torch.utils.data import Dataset
import pandas as pd
import os, mne, torch
import numpy as np
from scipy.signal import welch

from torch.utils.data import DataLoader

In [79]:
class SSVEPDataset(Dataset):
    """PyTorch ``Dataset`` of single SSVEP trials read from per-session EEG CSV files.

    Each metadata row identifies one trial inside a session's ``EEGdata.csv``,
    which stores several consecutive fixed-length trials. ``__getitem__``:
    1. slices that trial and keeps the channels in ``CHANNELS``;
    2. optionally z-scores them;
    3. re-references and notch/band-pass filters the signal with MNE;
    4. returns either a PSD+SNR feature vector or the filtered time series,
       together with the integer class label.
    """

    # SSVEP stimulation frequency (Hz) per class; the class index is the
    # insertion order of this dict.
    FREQUENCIES = {
        'Left': 10,    # ---> 0
        'Right': 13,   # ---> 1
        'Forward': 7,  # ---> 2
        'Backward': 8  # ---> 3
    }

    # EEG channels (column names in EEGdata.csv) used for every trial.
    CHANNELS = ['PO7', 'OZ', 'PO8']

    # Label index returned for rows without a usable 'label' value
    # (e.g. a metadata CSV for an unlabeled split).
    UNLABELED = -1

    def __init__(self, csv_metadata_path,
                 task: str = 'SSVEP',
                 eeg_reference='average',
                 transform=None,
                 fs=250,
                 tmin=1,
                 tmax=6,
                 noise_margin=1,
                 bandpass_band=(5, 30),
                 notch_freq=50,
                 do_normalization=True,
                 get_psd_plus_snr=True,
                 noise_n_neighbor_freqs=3,
                 noise_skip_neighbor_freqs=0
                 ):
        """Load the trial metadata table and store the preprocessing settings.

        Args:
            csv_metadata_path: Path (or file-like object) of the metadata CSV.
                Rows need 'id', 'task', 'subject_id', 'trial_session' and
                'trial' columns, plus 'label' for labeled splits.
            task: Keep only rows whose 'task' equals this value; a falsy value
                keeps every row.
            eeg_reference: 'average', a channel name from ``CHANNELS``, or None
                to skip re-referencing.
            transform: Stored as ``self.transform``.
                NOTE(review): ``__getitem__`` never applies it.
            fs: Sampling frequency in Hz passed to MNE and Welch.
            tmin, tmax: Window in seconds, relative to trial start, used for
                the PSD/SNR features.
            noise_margin: Half-width in Hz of the bands stored in ``freq_bands``.
            bandpass_band: (low, high) cut-off frequencies in Hz of the band-pass filter.
            notch_freq: Frequency in Hz of the notch filter.
            do_normalization: Z-score each channel before filtering.
            get_psd_plus_snr: Return PSD/SNR features instead of the filtered signal.
            noise_n_neighbor_freqs: Bins on each side of a target bin averaged
                as the SNR noise estimate (must be >= 1).
            noise_skip_neighbor_freqs: Bins adjacent to the target bin excluded
                from the noise estimate (must be >= 0).
        """
        # EEG files are resolved relative to this directory.
        self.base_path = '.'
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        metadata = pd.read_csv(csv_metadata_path)
        self.metadata = metadata[metadata['task'] == task] if task else metadata
        self.eeg_reference = eeg_reference
        self.transform = transform
        self.sfreq = fs
        self.tmin = tmin
        self.tmax = tmax
        self.bandpass_band = bandpass_band
        self.notch_freq = notch_freq
        self.do_normalization = do_normalization
        self.get_psd_plus_snr = get_psd_plus_snr
        self.noise_n_neighbor_freqs = noise_n_neighbor_freqs
        self.noise_skip_neighbor_freqs = noise_skip_neighbor_freqs

        self.label2idx = {label: i for i, label in enumerate(self.FREQUENCIES)}
        self.idx2label = {i: label for label, i in self.label2idx.items()}

        self.freq_bands = {label: (freq - noise_margin, freq + noise_margin)
                           for label, freq in self.FREQUENCIES.items()}

    def __getitem__(self, index):
        """Return ``(features, label)`` for the trial at positional ``index``.

        ``features`` depends on ``get_psd_plus_snr``:
        - when set, it is the 1-D vector from ``_get_psds_pls_snrs_as_array``;
        - otherwise, it is the filtered ``(n_channels, n_samples)`` signal of
          the whole trial. ``tmin``/``tmax`` are not applied in that case.

        Both tensors are moved to ``self.device``.

        Raises:
            ValueError: if the EEG file lacks one of ``CHANNELS``.
            RuntimeError: if re-referencing fails.
            KeyError: if the row's label is not a key of ``FREQUENCIES``.
        """
        row = self.metadata.iloc[index]
        trial_array = self._load_trial_channels(row)

        if self.do_normalization:
            trial_array = self._normalize(trial_array)

        raw = self._preprocess(trial_array)

        features_as_array = (self._get_psds_pls_snrs_as_array(raw)
                             if self.get_psd_plus_snr else raw.get_data())

        X_trial = torch.from_numpy(features_as_array)
        Y_trial = torch.tensor(self._label_index(row))

        # NOTE(review): tensors are put on CUDA here when available, which may
        # fail with DataLoader(num_workers>0) or pin_memory=True; consider
        # moving them to the device in the training loop instead.
        return X_trial.to(self.device), Y_trial.to(self.device)

    def __len__(self):
        """Number of trials remaining after the task filter."""
        return len(self.metadata)

    def _eeg_path(self, row):
        """Return the session CSV path holding ``row``'s trial.

        The split folder is chosen from the row id.
        """
        id_num = row['id']
        # Split folder by id: <= 4800 train, <= 4900 validation, otherwise test.
        split = 'train' if id_num <= 4800 else 'validation' if id_num <= 4900 else 'test'
        return os.path.join(self.base_path, row['task'], split,
                            str(row['subject_id']), str(row['trial_session']),
                            'EEGdata.csv')

    def _load_trial_channels(self, row):
        """Read ``row``'s trial and return a float ``(len(CHANNELS), n_samples)`` array."""
        eeg_path = self._eeg_path(row)
        session_df = pd.read_csv(eeg_path)

        # The session file stores consecutive trials of a fixed sample count
        # (1750 for SSVEP, 2250 otherwise); 'trial' is 1-based.
        samples_per_trial = 1750 if row['task'].lower() == 'ssvep' else 2250
        start = (int(row['trial']) - 1) * samples_per_trial
        trial_df = session_df.iloc[start:start + samples_per_trial]

        if not all(ch in trial_df.columns for ch in self.CHANNELS):
            raise ValueError(f"Missing required EEG channels in file: {eeg_path}")

        # Transpose to channels-first, as expected by mne.io.RawArray.
        return trial_df[self.CHANNELS].to_numpy(dtype=float).T

    def _preprocess(self, trial_array):
        """Wrap ``trial_array`` in an MNE RawArray, then re-reference and filter it."""
        info = mne.create_info(ch_names=list(self.CHANNELS), sfreq=self.sfreq, ch_types='eeg')
        raw = mne.io.RawArray(trial_array, info, verbose=False)
        raw = self._apply_reference(raw)
        raw.notch_filter(freqs=self.notch_freq, verbose=False)
        raw.filter(l_freq=self.bandpass_band[0], h_freq=self.bandpass_band[1], verbose=False)
        return raw

    def _label_index(self, row):
        """Return the class index of ``row``'s label, or ``UNLABELED`` when it has none."""
        label = row.get('label')
        if label is None or pd.isna(label):
            return self.UNLABELED
        return self.label2idx[label]

    def _apply_reference(self, raw):
        """Apply the configured EEG reference to ``raw`` and return it.

        The reference is chosen from ``self.eeg_reference``:
        - 'average' applies an average reference;
        - a name from ``raw.ch_names`` references to that channel;
        - None leaves ``raw`` untouched.

        Raises:
            RuntimeError: wrapping any failure, including an unknown reference
                name; the original exception is kept as ``__cause__``.
        """
        try:
            if self.eeg_reference == 'average':
                raw.set_eeg_reference('average', verbose=False)
            elif self.eeg_reference in raw.ch_names:
                raw.set_eeg_reference([self.eeg_reference], verbose=False)
            elif self.eeg_reference is not None:
                raise ValueError(f"Invalid EEG reference: {self.eeg_reference}")
        except Exception as e:
            raise RuntimeError(f"EEG referencing failed: {e}") from e
        return raw

    def _normalize(self, data_array):
        """Z-score each row of a 2-D array along axis 1, returning a new array.

        The input is not modified. The 1e-9 term keeps constant rows finite
        (they become zeros).
        """
        centered = data_array - np.mean(data_array, axis=1, keepdims=True)
        return centered / (np.std(centered, axis=1, keepdims=True) + 1e-9)

    def _snr_spectrum(self, psd):
        """Return the SNR spectrum: each PSD bin divided by the mean of its neighbour bins.

        The noise estimate for bin ``i`` averages ``noise_n_neighbor_freqs``
        bins on each side. It skips ``i`` itself and the
        ``noise_skip_neighbor_freqs`` bins next to it. Bins without a full
        neighbourhood on both sides get NaN.

        Raises:
            ValueError: if ``noise_n_neighbor_freqs < 1`` or
                ``noise_skip_neighbor_freqs < 0``.
        """
        n_neighbors = self.noise_n_neighbor_freqs
        n_skip = self.noise_skip_neighbor_freqs
        if n_neighbors < 1 or n_skip < 0:
            # A zero-sized neighbourhood would make the kernel sum 0 (division by zero).
            raise ValueError(
                "noise_n_neighbor_freqs must be >= 1 and noise_skip_neighbor_freqs >= 0, "
                f"got {n_neighbors} and {n_skip}"
            )

        psd = np.asarray(psd, dtype=float)
        # Symmetric averaging kernel: ones over the neighbours, zeros over the
        # skipped bins and the centre (symmetry makes convolution == correlation).
        kernel = np.concatenate((
            np.ones(n_neighbors),
            np.zeros(2 * n_skip + 1),
            np.ones(n_neighbors)
        ))
        kernel /= kernel.sum()

        if psd.size < kernel.size:
            # No bin has a full neighbourhood.
            return np.full(psd.shape, np.nan)

        mean_noise = np.convolve(psd, kernel, mode='valid')
        # 'valid' drops edge_width bins at each end; pad them back as NaN.
        edge_width = n_neighbors + n_skip
        mean_noise = np.pad(mean_noise, (edge_width, edge_width), constant_values=np.nan)
        return psd / mean_noise

    def _get_psds_pls_snrs_as_array(self, raw):
        """Compute a z-scored 1-D vector of PSD and SNR values at the class frequencies.

        Steps:
        1. Crop the signal to samples ``[tmin * fs, tmax * fs)``.
        2. For each channel, compute a Welch PSD (boxcar window, segments of at
           most 256 samples) and its SNR spectrum.
        3. Sample both at the bin nearest each frequency in ``FREQUENCIES``.

        Values are ordered channel by channel as ``psd, snr`` pairs per
        frequency, then z-scored over the whole vector.

        Args:
            raw: object whose ``get_data()`` returns a 2-D
                ``(n_channels, n_samples)`` array.

        Returns:
            float64 array of length ``n_channels * len(FREQUENCIES) * 2``.
        """
        start_idx = int(self.tmin * self.sfreq)
        end_idx = int(self.tmax * self.sfreq)
        window = raw.get_data()[:, start_idx:end_idx]

        features = []
        for ch_data in window:
            freqs, psd = welch(ch_data,
                               fs=self.sfreq,
                               window="boxcar",
                               nperseg=min(256, len(ch_data)),
                               average='mean')
            snr = self._snr_spectrum(psd)
            for freq in self.FREQUENCIES.values():
                idx = np.argmin(np.abs(freqs - freq))  # nearest frequency bin
                features.extend([psd[idx], snr[idx]])

        # NOTE(review): PSD and SNR values (different scales) are z-scored
        # together, and a single NaN SNR turns the whole vector into NaN.
        features = np.asarray(features, dtype=float)
        features -= np.mean(features, axis=-1, keepdims=True)
        features /= (np.std(features, axis=-1, keepdims=True) + 1e-9)
        return features

In [89]:
# Batch size shared by all splits; only the training split is shuffled.
BATCH_SIZE = 64
SPLITS = ('train', 'validation', 'test')

datasets = {}
dataloaders = {}

for split in SPLITS:
    # Each split's metadata is read from '<split>.csv' in the working directory.
    datasets[split] = SSVEPDataset(f'{split}.csv')
    # num_workers stays at its default: SSVEPDataset.__getitem__ may return
    # tensors already moved to CUDA.
    dataloaders[split] = DataLoader(datasets[split],
                                    batch_size=BATCH_SIZE,
                                    shuffle=(split == 'train'))