In [16]:
import mne
import os

In [17]:
import numpy as np
from scipy.ndimage import zoom

# Define augmentation functions
def time_shift(epoch, shift_max):
    shift = np.random.randint(-shift_max, shift_max)
    return np.roll(epoch, shift, axis=-1)  # roll along the time axis

def time_scaling(epoch, scaling_factor):
    return zoom(epoch, (1, scaling_factor))  # zoom only along the time axis

def amplitude_scaling(epoch, scaling_factor):
    return epoch * scaling_factor

def add_noise(epoch, noise_level):
    noise = np.random.normal(0, noise_level, epoch.shape)
    return epoch + noise

# Function to augment each epoch
def augment_epoch(epoch):
    augmented_epoch = time_shift(epoch, shift_max=10)
    augmented_epoch = time_scaling(augmented_epoch, scaling_factor=1.1)
    augmented_epoch = amplitude_scaling(augmented_epoch, scaling_factor=1.05)
    augmented_epoch = add_noise(augmented_epoch, noise_level=0.01)
    return augmented_epoch


In [24]:
def create_epochs_for_digits(data_dir, output_dir, digit_events=('digit 1 shown', 'digit 2 shown', 'digit 3 shown', 'digit 4 shown', 'digit 5 shown'), to_ica=False):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    files = [f for f in os.listdir(data_dir) if f.endswith('.fif')]

    for file in files:
        file_path = os.path.join(data_dir, file)
        
        # Conversion to microvolts
        raw = mne.io.read_raw_fif(file_path, preload=True)
        raw.apply_function(fun=lambda x: x * 10**(-6))
        raw.filter(l_freq=1, h_freq=40)
        events, event_dict = mne.events_from_annotations(raw)
        raw.pick_types(eeg=True)

        # print events for debugging
        # print(f"Events found in file {file}: {events}")

        event_dict = {
            'digit 1 shown': 2,
            'digit 2 shown': 3,
            'digit 3 shown': 4,
            'digit 4 shown': 5,
            'digit 5 shown': 6
        }
        events, digit_events = mne.events_from_annotations(raw)
        
        selected_event_dict = {key: value for key, value in event_dict.items() if key in digit_events.keys()}

        if to_ica:
            ica = mne.preprocessing.ICA(n_components=16, random_state=97, max_iter=800)
            ica.fit(raw)
            raw_ica = ica.apply(raw)
            epochs = mne.Epochs(raw_ica, events, event_dict, tmin=-0.2, tmax=0.5, baseline=(None, 0), preload=True)
            output_path = os.path.join(output_dir, f'{os.path.splitext(file)[0]}.epo.fif')

            epochs.save(output_path, overwrite=True)
            print(f'Saved epochs for file {file} to {output_path}')
        else:
            try:
                epochs = mne.Epochs(raw, events, event_id=selected_event_dict, tmin=-0.2, tmax=0.5, baseline=(None, 0))
                
                output_path = os.path.join(output_dir, f'{os.path.splitext(file)[0]}_epo.fif')

                epochs.save(output_path, overwrite=True)
                print(f'Saved epochs for file {file} to {output_path}')
            except ValueError as e:
                print(f"Error processing file {file}: {e}")
                
                


In [25]:
data_dir = '/home/grzesiek/documents/programming/projects/eeg_digit_classification/data/cyfry'
output_dir = '/home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs'
create_epochs_for_digits(data_dir, output_dir, to_ica=True)

Saved epochs for file GS004.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/GS004.epo.fif
Saved epochs for file GS005.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/GS005.epo.fif
Saved epochs for file GS006.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/GS006.epo.fif
Saved epochs for file GS003.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/GS003.epo.fif
Saved epochs for file GS002.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/GS002.epo.fif
Saved epochs for file DJ002.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/DJ002.epo.fif
Saved epochs for file DJ001.fif to /home/grzesiek/documents/programming/projects/eeg_digit_classification/data/epochs/DJ001.epo.fif
Saved epochs for file GS001.fif to /home/grzesiek/documents/programming/proj