In [None]:
from torcheeg.datasets import SEEDDataset
from torcheeg import transforms
import os

# Project root; set the EEG_PROJECT_ROOT environment variable to relocate the data
# without editing the notebook. The default reproduces the original location exactly.
PROJECT_ROOT = os.environ.get('EEG_PROJECT_ROOT', 'D:/FAST/EEg-based-Emotion-Recognition')
# Preprocessed SEED recordings that torcheeg reads.
SEED_ROOT_PATH = f'{PROJECT_ROOT}/Preprocessed_EEG'
# torcheeg processing cache (the log below shows it is reused when present). The directory
# name looks auto-generated by an earlier run -- TODO confirm before sharing the notebook.
SEED_IO_PATH = f'{PROJECT_ROOT}/.torcheeg/datasets_1733174610032_5iJyS'

raw_dataset = SEEDDataset(
    root_path=SEED_ROOT_PATH,
    io_path=SEED_IO_PATH,
    online_transform=None,
    label_transform=None,
    num_worker=4
)

raw_sample = raw_dataset[0]
print(f"Raw EEG data shape: {raw_sample[0].shape}")

[2024-12-06 18:10:06] INFO (torcheeg/MainThread) 🔍 | Detected cached processing results, reading cache from D:/FAST/EEg-based-Emotion-Recognition/.torcheeg/datasets_1733174610032_5iJyS.


In [None]:
import torch
from torcheeg.utils import plot_2d_tensor

# Render the first raw sample (channels x time) as a 2-D image.
raw_eeg_tensor = torch.tensor(raw_sample[0])
img = plot_2d_tensor(raw_eeg_tensor)

In [None]:
from torcheeg.model_selection import LeaveOneSubjectOut
from torch.utils.data import DataLoader

print(f"Dataset contains {len(raw_dataset)} samples.")
# len(raw_sample) counts the elements of ONE sample tuple (EEG array + metadata),
# not the number of samples in the dataset, so label it accordingly.
print(f"Elements per sample: {len(raw_sample)}")
# Reuse the already-loaded first sample instead of reading index 0 from disk again.
print(f"Sample format: {raw_sample}")  # Check a single sample

In [None]:
# Every raw chunk is expected to be (channels, time points).
EXPECTED_SAMPLE_SHAPE = (62, 200)

for i, sample in enumerate(raw_dataset):
    actual_shape = tuple(sample[0].shape)
    # Report the offending shape so a failure is diagnosable without re-loading the sample.
    assert actual_shape == EXPECTED_SAMPLE_SHAPE, (
        f"Mismatch at sample {i}: expected {EXPECTED_SAMPLE_SHAPE}, got {actual_shape}"
    )

In [None]:
from collections import defaultdict

# Organize data by emotion classes
# Bucket every (sample, metadata) pair under its emotion label, then report bucket sizes.
emotion_classes = defaultdict(list)
for sample, metadata in raw_dataset:
    # Assume emotion: -1 (negative), 0 (neutral), 1 (positive)
    emotion_classes[metadata['emotion']].append((sample, metadata))

class_sizes = {label: len(pairs) for label, pairs in emotion_classes.items()}
print(class_sizes)


In [None]:
from scipy.signal import butter, lfilter

# Define bandpass filter
def bandpass_filter(data, lowcut=4, highcut=47, fs=200, order=4, zero_phase=True):
    """Band-pass filter ``data`` along its last axis with a Butterworth filter.

    Args:
        data: Array-like signal. Filtering runs along the last axis, so a
            (channels, time) array is filtered channel by channel.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sampling rate in Hz.
        order: Design order passed to ``scipy.signal.butter``.
        zero_phase: If True (default), filter forward and backward with
            ``sosfiltfilt``. This gives no phase distortion, and padding plus
            steady-state initial conditions suppress the start-up transient that a
            single zero-state pass produces on short windows with a DC offset.
            If False, apply one causal pass with ``sosfilt``.

    Returns:
        np.ndarray with the same shape as ``data``.

    Raises:
        ValueError: if the band edges do not satisfy 0 < lowcut < highcut < fs / 2.
            With ``zero_phase=True``, ``sosfiltfilt`` also raises ValueError when the
            last axis is too short for its default padding.
    """
    from scipy.signal import butter, sosfilt, sosfiltfilt

    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Band edges must satisfy 0 < lowcut < highcut < fs/2 ({nyquist} Hz); "
            f"got lowcut={lowcut}, highcut={highcut}"
        )
    # Second-order sections are numerically more robust than (b, a) coefficients for band-pass designs.
    sos = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    if zero_phase:
        return sosfiltfilt(sos, data, axis=-1)
    return sosfilt(sos, data, axis=-1)

# Apply bandpass filter to dataset
def filter_dataset(dataset, lowcut=4, highcut=47, fs=200, order=4):
    """Band-pass every sample of ``dataset`` with ``bandpass_filter``.

    Args:
        dataset: Iterable of ``(sample, metadata)`` pairs.
        lowcut: Lower pass-band edge in Hz.
        highcut: Upper pass-band edge in Hz.
        fs: Sampling rate in Hz.
        order: Filter order forwarded to ``bandpass_filter``. Before this, the order
            could not be changed from here.

    Returns:
        List of ``(filtered_sample, metadata)`` pairs in input order; metadata is
        passed through unchanged.
    """
    return [
        (bandpass_filter(sample, lowcut, highcut, fs, order=order), metadata)
        for sample, metadata in dataset
    ]

filtered_dataset = filter_dataset(dataset=raw_dataset)
print(len(filtered_dataset))  # number of filtered (sample, metadata) pairs

In [None]:
import torch
from torcheeg.utils import plot_2d_tensor

# Inspect the first band-passed sample: its shape, then a 2-D image of it.
first_filtered_sample = filtered_dataset[0][0]
print(first_filtered_sample.shape)

img = plot_2d_tensor(torch.tensor(first_filtered_sample))

In [None]:
import mne
import numpy as np
from multiprocessing import Pool

# Function to process a single sample with ICA
def process_sample_with_ica(args):
    """Fit ICA to one EEG sample and return the ICA-applied data.

    Written as a ``Pool.map`` worker, so it takes a single tuple argument.

    Args:
        args: ``(sample, metadata, fs, n_components, max_iter)``.
            ``sample`` is a 2-D array whose axis 0 is treated as channels.
            ``metadata`` is passed through untouched.
            ``fs`` is the sampling rate in Hz.
            ``n_components`` and ``max_iter`` are forwarded to ``mne.preprocessing.ICA``.

    Returns:
        ``(cleaned_data, metadata)`` on success, or ``None`` if an exception was
        raised. Errors are printed rather than propagated so one bad sample does not
        abort the whole pool.

    NOTE(review): ``ica.exclude`` is never populated, so ``ica.apply`` has no
    components to remove. In MNE versions where ``apply`` keeps all PCA components
    by default, the output is essentially a reconstruction of the input.
    TODO: choose components to exclude (e.g. ``ica.find_bads_eog`` with frontal
    channels as EOG proxies) before applying.
    NOTE(review): each sample is decomposed independently. With only a few hundred
    time points per sample, the ICA estimate may be unreliable; consider fitting
    on longer, continuous data.
    """
    sample, metadata, fs, n_components, max_iter = args
    # MNE logs at INFO level on each of the calls below by default. Across thousands
    # of samples and several workers that floods the output, so keep warnings only.
    mne_verbose = 'warning'
    try:
        # Create MNE Info object
        info = mne.create_info(
            ch_names=[f'ch_{i}' for i in range(sample.shape[0])],
            sfreq=fs,
            ch_types='eeg',
            verbose=mne_verbose,
        )
        raw = mne.io.RawArray(sample, info, verbose=mne_verbose)

        # Fit ICA with reduced components and higher max_iter
        ica = mne.preprocessing.ICA(
            n_components=n_components, random_state=42, max_iter=max_iter, verbose=mne_verbose
        )
        print(f"Fitting ICA for sample with metadata: {metadata}")
        ica.fit(raw, verbose=mne_verbose)

        # Apply ICA (see the NOTE above: nothing is excluded yet)
        print(f"Applying ICA for sample with metadata: {metadata}")
        raw_cleaned = ica.apply(raw, verbose=mne_verbose)

        # Return cleaned data and metadata
        return raw_cleaned.get_data(), metadata

    except ValueError as e:
        print(f"ValueError: {e} for sample with metadata: {metadata}. Skipping...")
    except Exception as e:
        print(f"Unexpected error: {e} for sample with metadata: {metadata}. Skipping...")
    return None

In [None]:
# Function for batch processing with parallelization
def parallel_process_ica(dataset, fs=200, n_components=28, max_iter=2000, batch_size=50, num_workers=4):
    """Run ``process_sample_with_ica`` over ``dataset`` in parallel, batch by batch.

    Args:
        dataset: Sliceable sequence (e.g. a list) of ``(sample, metadata)`` pairs.
        fs: Sampling rate in Hz, forwarded to the worker.
        n_components: ICA component count, forwarded to the worker.
        max_iter: ICA iteration limit, forwarded to the worker.
        batch_size: Samples submitted per ``Pool.map`` call; also the granularity
            of the progress messages.
        num_workers: Number of worker processes.

    Returns:
        List of ``(cleaned_sample, metadata)`` pairs in input order. Samples whose ICA
        failed are omitted, so the result can be shorter than ``dataset``; the number
        skipped is printed at the end.

    NOTE: workers must be able to unpickle ``process_sample_with_ica``. Under the
    'spawn' start method (the default on Windows and macOS), functions defined
    interactively in a notebook's ``__main__`` may not be importable by child
    processes. If the pool errors or hangs, move the worker into a ``.py`` module
    and import it.
    """
    cleaned_dataset = []
    num_failed = 0
    print(f"Starting ICA processing with n_components={n_components}, max_iter={max_iter}, batch_size={batch_size}, num_workers={num_workers}...")

    # One pool for the whole run. Starting worker processes is expensive, and
    # creating a fresh Pool per batch paid that cost once per batch.
    with Pool(num_workers) as p:
        for start in range(0, len(dataset), batch_size):
            end = min(start + batch_size, len(dataset))
            batch = dataset[start:end]
            print(f"Processing batch {start // batch_size + 1} with {len(batch)} samples...")

            # Prepare arguments for parallel processing
            args = [(sample, metadata, fs, n_components, max_iter) for sample, metadata in batch]
            results = p.map(process_sample_with_ica, args)

            # Collect cleaned data; count failures instead of dropping them silently
            succeeded = [result for result in results if result is not None]
            num_failed += len(results) - len(succeeded)
            cleaned_dataset.extend(succeeded)

    if num_failed:
        print(f"Skipped {num_failed} of {len(dataset)} samples because ICA failed.")
    print("ICA processing completed.")
    return cleaned_dataset


In [None]:
# Apply to the dataset with optimizations
# ICA settings for this run (max_iter is lowered from the function default of 2000).
ica_run_settings = dict(
    fs=200,
    n_components=28,
    max_iter=800,
    batch_size=50,
    num_workers=4,
)
cleaned_dataset = parallel_process_ica(filtered_dataset, **ica_run_settings)


In [None]:
import numpy as np

def re_reference_to_average(dataset):
    """Re-reference each sample to the common average of its channels.

    For every ``(sample, metadata)`` pair, the mean over axis 0 (the channels) is
    taken at each time point and subtracted from every channel. Metadata is passed
    through unchanged.

    Returns:
        A new list of ``(re_referenced_sample, metadata)`` pairs in input order.
    """
    return [
        (sample - np.mean(sample, axis=0, keepdims=True), metadata)
        for sample, metadata in dataset
    ]

# Common-average re-reference of the ICA output.
re_referenced_dataset = re_reference_to_average(dataset=cleaned_dataset)


In [None]:
def segment_data(dataset, epoch_length=6000, step_size=3000):
    """Cut each sample into fixed-length, possibly overlapping windows along time.

    Args:
        dataset: Iterable of ``(sample, metadata)`` pairs; ``sample`` is 2-D and is
            windowed along axis 1.
        epoch_length: Window length in time points.
        step_size: Hop between window starts in time points. A value smaller than
            ``epoch_length`` gives overlapping windows.

    Returns:
        List of ``(segment, metadata)`` pairs. Only full windows are kept: a trailing
        partial window is dropped, and a sample shorter than ``epoch_length`` yields
        no segments.

    Raises:
        ValueError: if ``epoch_length`` or ``step_size`` is not positive.

    Warns:
        UserWarning: if any sample is shorter than ``epoch_length``. Such samples
        used to disappear from the output without notice.
    """
    import warnings

    if epoch_length <= 0 or step_size <= 0:
        raise ValueError(
            f"epoch_length and step_size must be positive; got epoch_length={epoch_length}, step_size={step_size}"
        )

    segmented_dataset = []
    num_samples = 0
    num_too_short = 0
    for sample, metadata in dataset:
        num_samples += 1
        num_points = sample.shape[1]
        if num_points < epoch_length:
            num_too_short += 1
            continue
        for start in range(0, num_points - epoch_length + 1, step_size):
            segmented_dataset.append((sample[:, start:start + epoch_length], metadata))

    if num_too_short:
        warnings.warn(
            f"{num_too_short} of {num_samples} samples are shorter than "
            f"epoch_length={epoch_length} time points and produced no segments.",
            stacklevel=2,
        )
    return segmented_dataset

# NOTE(review): the shape check earlier in the notebook asserts every raw sample is (62, 200).
# The filter, ICA and re-reference steps keep that length, so with the default
# epoch_length=6000 no window fits and this produces an empty list.
# TODO: pick epoch_length/step_size <= 200 points, or segment the continuous
# recordings before they are chunked into 200-point samples.
segmented_dataset = segment_data(re_referenced_dataset)
