'1072': Left Hand MI

'276': Right Hand MI

In [2]:
import mne
import numpy as np
import pywt
import torch
import os
from datetime import datetime

In [1]:
# CHANGE LABELS

In [3]:
def get_epochs_and_labels_bcic_iv_2a(eeg_signal):
    """
    Load the EEG data from a FIF file, extract the epochs and labels, and return them.

    Arguments:
        - FIF_file (str): Path to the FIF file containing the EEG data.

    Returns:
        - Epochs (mne.Epochs): EEG epochs extracted from the FIF file.
        - Labels (numpy.ndarray): Labels corresponding to the epochs.
    """
    # Get the events from the annotations
    events, _ = mne.events_from_annotations(eeg_signal)

    event_id = {'left_hand': 1, 'right_hand': 2}

    # Epochs start 0s before the trigger and end 0.5s after
    epochs = mne.Epochs(eeg_signal, events, event_id, tmin=0, tmax=0.5, baseline=None, preload=True)

    # Get the labels of the epochs
    labels = epochs.events[:, -1]

    # Change the labels to 0 and 1
    labels[labels == 2] = 0
    labels[labels == 3] = 1
    
    return epochs, labels

In [4]:
def z_score(epoch):
    """
    Apply z-score normalization to each channel of the EEG data.
    
    Arguments:
        - Epoch (numpy.ndarray): EEG data to be normalized.
        
    Returns:
        - Z-scored epoch (numpy.ndarray): Normalized EEG data.
    """    
    # Apply z-score normalization to each channel, saved in epoch
    for i in range(epoch.shape[0]):
        channel_epoch = epoch[i, :]
        mean = np.mean(channel_epoch)
        std = np.std(channel_epoch)
        z_scored_epoch = (channel_epoch - mean) / std
        epoch[i, :] = z_scored_epoch
    
    return epoch

In [5]:
def frequency_to_scale(freq, wavelet='morl', sampling_rate=250):
    """
    Convert frequency values to scales for continuous wavelet transform (CWT).

    Arguments:
        = freq (array): Array of frequency values.
        wavelet (str, optional): Type of wavelet to use. Defaults to 'morl'.
        sampling_rate (int): Sampling rate of the EEG data. Defaults to 250 Hz.

    Returns:
        - scales (array): Array of scales corresponding to the input frequencies.
    """
    # For the Morlet wavelet, scales are inversely proportional to frequency
    center_freq = pywt.central_frequency(wavelet)
    return center_freq / (freq / sampling_rate)

In [6]:
def apply_wavelet_transform(data_norm, wavelet='morl', freq_range=(8, 30), sampling_rate=250):
    """
    Apply wavelet transform to EEG data.
    
    Arguments:
        - data_norm (ndarray): 2D array with shape (n_channels, n_time_points)
        - wavelet (str): Wavelet type (default 'morl')
        - freq_range (tuple): Frequency range for the CWT (default (8, 30) Hz)
        - sampling_rate (int): Sampling rate of the EEG data (default 250 Hz)
    
    Returns:
    ndarray: 3D array with shape (n_channels, n_scales, n_time_points)
    """
    n_channels, n_times = data_norm.shape
    # Define scales based on the desired frequency range
    scales = frequency_to_scale(np.arange(freq_range[0], freq_range[1]+1), wavelet=wavelet, sampling_rate=sampling_rate)
    
    coeffs = []
    for i in range(n_channels):
        # Compute the wavelet transform coefficients
        coef, _ = pywt.cwt(data_norm[i], scales=scales, wavelet=wavelet)
        coeffs.append(coef)
    
    # Stack coefficients to form a 3D tensor
    coeffs_done = np.stack(coeffs, axis=0)
    
    return coeffs_done

In [12]:
def apply_wavelet_transform_freq_bands(data_norm, wavelet='morl', freq_ranges=[(8, 13), (14, 30), (31, 50)], sampling_rate=250):
    """
    Apply wavelet transform to EEG data for specified frequency bands.
    
    Arguments:
        - data_norm (ndarray): 2D array with shape (n_channels, n_time_points)
        - wavelet (str): Wavelet type (default 'morl')
        - freq_ranges (list of tuples): List of frequency ranges for the CWT (default [(8, 13), (14, 30), (31, 50)] Hz)
        - sampling_rate (int): Sampling rate of the EEG data (default 250 Hz)
    
    Returns:
        - ndarray: 3D array with shape (n_channels, total_n_scales, n_time_points)
    """
    n_channels, n_times = data_norm.shape
    all_coeffs = []

    for freq_range in freq_ranges:
        # Define scales based on the desired frequency range
        scales = frequency_to_scale(np.arange(freq_range[0], freq_range[1] + 1), wavelet=wavelet, sampling_rate=sampling_rate)
        
        band_coeffs = []
        for i in range(n_channels):
            # Compute the wavelet transform coefficients for each channel
            coef, _ = pywt.cwt(data_norm[i], scales=scales, wavelet=wavelet)
            band_coeffs.append(coef)
        
        # Stack coefficients for the current frequency band along the channel axis
        band_coeffs_stacked = np.stack(band_coeffs, axis=0)
        all_coeffs.append(band_coeffs_stacked)

    # Concatenate coefficients for all frequency bands along the scale axis
    coeffs_done = np.concatenate(all_coeffs, axis=1)  # Concatenate along the scale axis

    return coeffs_done

In [14]:
def bcic_iv_2a_preprocessing(file_path):    
    # Suppress `mne` library logging
    mne.set_log_level('CRITICAL')

    # Load the EEG data
    raw = mne.io.read_raw_gdf(file_path, preload=True)

    # Filter data between 4 and 35 Hz
    filtered_raw = raw.filter(4., 35., fir_design='firwin', skip_by_annotation='edge')

    epochs, labels = get_epochs_and_labels_bcic_iv_2a(filtered_raw)

    # Empty lists to store the transformed epochs and their labels
    all_transformed_epochs = []
    all_labels = []

    # Process each epoch
    for epoch, label in zip(epochs, labels):
        # Z-score each epoch
        epoch_norm = z_score(epoch)
        
        # Apply wavelet transformation
        epoch_wavelet = apply_wavelet_transform_freq_bands(epoch_norm)
        
        # Append the transformed epoch and its label to the lists
        all_transformed_epochs.append(epoch_wavelet)
        all_labels.append(label)
        
    # Convert the list of all transformed epochs into a single tensor dataset
    tensor_dataset = torch.tensor(all_transformed_epochs, dtype=torch.float)

    # Convert the list of all labels into a tensor
    labels_tensor = torch.tensor(all_labels, dtype=torch.long)
    
    return tensor_dataset, labels_tensor

In [15]:
def bcic_iv_2a_folder_to_tensor(input_folder, output_folder):
    all_data = []
    all_labels = []

    # Loop through each file in the folder
    for filename in os.listdir(input_folder):
        if filename.endswith('.gdf'):  # Assuming all files are gdf format
            file_path = os.path.join(input_folder, filename)
            # Apply preprocessing function to each file
            tensor_dataset, labels_tensor = bcic_iv_2a_preprocessing(file_path)
            all_data.append(tensor_dataset)
            all_labels.append(labels_tensor)

    # Concatenate all tensors along the first dimension to create a single tensor
    all_data = torch.cat(all_data, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    
    # Get the current date and time
    current_datetime = datetime.now()
    
    # Format the date and time as a string for the dataset name
    time_name = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")

    # Make a subfolder for the tensor dataset in the output folder, name it with the current date and time
    subfolder = os.path.join(output_folder, time_name)    
    
    # Save the combined dataset to the subfolder with the specified name
    os.makedirs(subfolder, exist_ok=True)
    tensor_dataset_file_name = os.path.join(subfolder, f"dataset_{time_name}.pt")
    torch.save(all_data, tensor_dataset_file_name)
    
    labels_tensor_file_name = os.path.join(subfolder, f"labels_{time_name}.pt")
    torch.save(all_labels, labels_tensor_file_name)

    return all_data, all_labels

In [16]:
folder_path = r"C:\School\EE_Y3\Q4\BAP\eeg_thesis_cnn_repo\data\bcic_iv_2a\raw"
output_folder = r"C:\School\EE_Y3\Q4\BAP\eeg_thesis_cnn_repo\data\bcic_iv_2a\processed"

all_data, all_labels = bcic_iv_2a_folder_to_tensor(folder_path, output_folder)

  tensor_dataset = torch.tensor(all_transformed_epochs, dtype=torch.float)


In [21]:
all_data.shape, all_labels.shape

(torch.Size([506, 25, 43, 126]), torch.Size([506]))