<b>1. Setup and preprocessing</b>

In [28]:
from mne_import_xdf import *
import mne
import numpy as np
import pandas as pd
import os
import cv2
import shutil
import pathlib
import matplotlib.pyplot as plt
import pywt
import ewtpy
from ssqueezepy import cwt as ssq_cwt
from collections import OrderedDict
from sklearn.decomposition import PCA
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import StratifiedShuffleSplit
import itertools
import NeuralNetworks
curr_path = pathlib.Path.cwd()  # Absolute path of the notebook's working directory (project root)

&nbsp;&nbsp;&nbsp;&nbsp;1.1. Set bad electrodes

In [29]:
def get_subject_bad_electrodes(subject):
    """
    Return the set of electrodes that must be removed for a given subject.

    Parameters:
        subject (str): Subject name, as used in the recording file names.

    Returns:
        set: Electrode labels to drop. For a subject without an entry, an empty set is
        returned (and a note is printed).
    """
    # Subject-specific electrodes that must be removed from the data
    bad_elecs_dict = {
                      'Dekel':{'FT10', 'TP10', 'FT9'},
                      'Gilad':{'FT10', 'TP10', 'FT9', 'TP9'},
                      'Neta':{'TP9'},
                      'Ron-Block':{'PO7'},
                      'sub-Roei': {'TP9'},
                      'Or': {'FT9','T7','FC2','FT7','Iz'},
                      'Roei-MI': {'FT10', 'TP10','P2','AF8','AF7','AF4'},
                      'Fudge':{'Iz','FT10', 'TP10', 'FT9', 'TP9','F1'},
                      'g': {'T7','CP1','TP9','P7','PO7','O1'},
                      'Ron': {'Iz','Cz'}
                    }
    subject_bad_electrodes = bad_elecs_dict.get(subject)
    if subject_bad_electrodes is None:
        print('Note that no bad electrodes were defined for the current subject:', subject)
        # `{}` would be an empty *dict*; return an empty set so the type is always the same
        return set()
    return subject_bad_electrodes

&nbsp;&nbsp;&nbsp;&nbsp;1.2. Set parameters for preprocessing (make sure to specify subject name)

In [None]:
subject_name = 'Fudge'  # Subject to process; must appear in the XDF file names
recording_path = curr_path / 'Recordings' / subject_name  # Folder holding this subject's XDF recordings

# Scalp regions -> electrode labels (any key works; each value must be a list of electrode names)
Electrode_Groups = {
    'FP': ['Fp1', 'Fp2'],
    'AF': ['AF7', 'AF3', 'AFz', 'AF4', 'AF8'],
    'F': ['F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8'],
    'FC': ['FC5', 'FC3', 'FC1', 'FC2', 'FC4', 'FC6'],
    'C': ['C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6'],
    'CP': ['CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6'],
    'P': ['P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8'],
    'PO': ['PO7', 'PO3', 'POz', 'PO4', 'PO8'],
    'O': ['Oz', 'O2', 'O1', 'Iz'],
}

# Preprocessing configuration consumed by EEG_preprocessing and later cells
params_dict = dict(
    # Electrodes analysed: every region except FP and O, flattened in region order
    Electrode_Group=[
        electrode
        for region in ('AF', 'F', 'FC', 'C', 'CP', 'P', 'PO')
        for electrode in Electrode_Groups[region]
    ],
    bad_electrodes=get_subject_bad_electrodes(subject_name),  # Subject-specific electrodes to drop
    low_freq=0.5,  # High-pass cutoff (Hz): content below this frequency is filtered out
    high_freq=50,  # Low-pass cutoff (Hz): content above this frequency is filtered out
    desired_events=['ClosePalm', 'OpenPalm', 'ActiveRest'],  # Event labels to epoch around
    filter_method='fir',
    epoch_tmin=-4,  # Epoch start relative to the event (s)
    epoch_tmax=6,  # Epoch end relative to the event (s)
    classifier_window_s=0.2,  # Start of the window kept for classification (s)
    classifier_window_e=4.2,  # End of the window kept for classification (s)
)

&nbsp;&nbsp;&nbsp;&nbsp;1.3. Load XDF files

In [None]:
# Collect this subject's XDF recordings. Path.glob yields entries in arbitrary, filesystem-dependent
# order; sorting makes the processing (and therefore epoch concatenation) order reproducible.
xdf_files = sorted(file for file in recording_path.glob('*.xdf') if subject_name in file.name)
if not xdf_files:
    print(f'Warning: no XDF files containing {subject_name!r} were found in {recording_path}')
xdf_files  # Display the selected files

&nbsp;&nbsp;&nbsp;&nbsp;1.4. Define preprocessing algorithm and event ID (label) standardization.

In [32]:
def EEG_preprocessing(current_path, raw, params_dict):
    """
    Preprocess one EEG recording and cut it into per-event-centered epochs.

    Steps:
      1. drop accelerometer channels and apply the electrode montage;
      2. drop the subject's bad electrodes and any channels listed in ``raw.info['bads']``;
      3. apply an average reference and band-pass filter the data;
      4. epoch around the desired events and keep only the requested electrodes;
      5. subtract each event type's mean epoch from the epochs of that type.

    Parameters:
        current_path (str or pathlib.Path): Project root; the montage is read from
            ``<current_path>/Montages/CACS-64_REF.bvef``.
        raw (mne.io.Raw): Raw recording. It is modified in place (channels dropped, reference
            and filter applied).
        params_dict (dict): Reads 'low_freq', 'high_freq', 'filter_method', 'epoch_tmin',
            'epoch_tmax', 'bad_electrodes', 'Electrode_Group' and 'desired_events'.

    Returns:
        tuple:
            - mne.EpochsArray: per-event-centered epochs, in the original (chronological) order.
            - numpy.ndarray: mean over all epochs *before* centering, shape (n_channels, n_times).
            - dict: event name -> recording-specific event code, for the desired events found.

    Raises:
        ValueError: If none of the desired events occur in the recording.
    """
    low_freq = params_dict['low_freq']
    high_freq = params_dict['high_freq']
    filter_method = params_dict['filter_method']
    tmin, tmax = params_dict['epoch_tmin'], params_dict['epoch_tmax']

    # Drop accelerometer (non-EEG) channels; only those present, so a partial set does not raise
    acc_channels = [ch for ch in ('ACC_X', 'ACC_Y', 'ACC_Z') if ch in raw.ch_names]
    if acc_channels:
        raw.drop_channels(acc_channels)

    montage_file = pathlib.Path(current_path) / "Montages" / "CACS-64_REF.bvef"
    montage = mne.channels.read_custom_montage(str(montage_file), head_size=0.095, coord_frame=None)
    raw.set_montage(montage, match_case=True, match_alias=False, on_missing='raise', verbose=None)

    print('###########################################################'
          '\nremoving subject specific bad electrodes from the raw data'
          '\n###########################################################')

    # Remove the subject's bad electrodes (only those actually present in this recording)
    bad_electrodes = set(raw.ch_names).intersection(params_dict['bad_electrodes'])
    if bad_electrodes:
        raw.drop_channels(sorted(bad_electrodes))
    # Remove any channels marked as bad in the recording itself
    if raw.info['bads']:
        raw.drop_channels(list(raw.info['bads']))

    # Average reference, applied once (a second average reference would be a redundant no-op)
    raw.set_eeg_reference(ref_channels='average')
    print(f'\n{len(bad_electrodes)} bad electrodes were removed from the raw data: {bad_electrodes}')
    print('##########################################################')

    print('\n Filtering data...')
    raw = raw.filter(l_freq=low_freq, h_freq=high_freq, method=filter_method, pad='reflect_limited')
    events_from_annot, event_dict = mne.events_from_annotations(raw)

    print('##########################################################'
          '\nExtracting event info:', event_dict)
    # Keep only the desired event types (with this recording's codes)
    events_trigger_dict = {
        name: code for name, code in event_dict.items() if name in params_dict['desired_events']
    }
    print('##########################################################')
    if not events_trigger_dict:
        raise ValueError(
            f"None of the desired events {params_dict['desired_events']} were found; "
            f"available events: {list(event_dict)}"
        )

    # Requested electrodes that survived channel cleaning; filtering against the remaining channel
    # names (rather than only the subject's bad list) keeps pick() from asking for a channel
    # that was dropped via info['bads'] or never recorded.
    available = set(raw.ch_names)
    selected_electrodes = [elec for elec in params_dict['Electrode_Group'] if elec in available]
    epochs = mne.Epochs(
        raw,
        events_from_annot,
        event_id=events_trigger_dict,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        preload=True,
        detrend=0
    )
    epochs.pick(selected_electrodes)

    data = epochs.get_data()
    mean_across_epochs = data.mean(axis=0)  # Mean over all epochs, before centering
    centered_epochs = _center_epochs_by_event(epochs, data)

    return centered_epochs, mean_across_epochs, events_trigger_dict


def _center_epochs_by_event(epochs, data):
    """Return an EpochsArray where each event type's mean epoch is subtracted from that type's epochs.

    Rows keep their original order, so the temporal sequence of events is preserved.
    NOTE(review): the subtracted mean depends on each epoch's label and is computed on all data
    before any train/test split, so held-out epochs are transformed using their own labels —
    confirm this is intended.
    """
    codes = epochs.events[:, 2]
    centered = data.copy()
    for code in np.unique(codes):
        in_class = codes == code
        centered[in_class] -= data[in_class].mean(axis=0)
    return mne.EpochsArray(
        centered,
        epochs.info,
        events=epochs.events,
        event_id=epochs.event_id,
        tmin=epochs.tmin
    )

# Project-wide event label -> integer code, so epochs from different recordings share one coding
standard_event_id = {'ActiveRest': 1, 'OpenPalm': 22, 'ClosePalm': 33, 'Rating': 4, 'Rest': 55}

def remap_epoch_events_to_standard(epochs, standard_event_id, desired_events):
    """
    Recode epoch event codes to the standard codes, keeping only the desired events.

    Recordings may assign different integer codes to the same annotation label. This
    replaces each epoch's code with ``standard_event_id[label]``. It also rebuilds
    ``epochs.event_id`` in the order of the original ``epochs.event_id``, restricted to
    ``desired_events``.

    Parameters:
        epochs (mne.Epochs): Epochs whose ``events``/``event_id`` use recording-specific codes.
        standard_event_id (dict): Event label -> standard integer code.
        desired_events (list): Labels to keep; epochs with any other label are dropped.

    Returns:
        mne.Epochs: The remapped epochs. If every epoch already has a desired label, this is
        the same object as ``epochs``, modified in place. Otherwise it is the subset
        ``epochs[indices]``.

    Raises:
        KeyError: If a kept label has no entry in ``standard_event_id``.
    """
    desired = set(desired_events)
    original_order = [label for label in epochs.event_id if label in desired]
    code_to_label = {code: label for label, code in epochs.event_id.items()}

    # Drop epochs whose label is not desired (no-op when all labels are desired)
    keep = np.array([code_to_label.get(code) in desired for code in epochs.events[:, 2]], dtype=bool)
    if not keep.all():
        epochs = epochs[np.flatnonzero(keep)]

    missing = [label for label in original_order if label not in standard_event_id]
    if missing:
        raise KeyError(f'No standard event ID defined for: {missing}')

    # Old code -> standard code, built once and applied to the whole events column
    old_codes = epochs.events[:, 2]
    lookup = {code: standard_event_id[code_to_label[code]] for code in np.unique(old_codes)}
    epochs.events[:, 2] = [lookup[code] for code in old_codes]

    epochs.event_id = OrderedDict((label, standard_event_id[label]) for label in original_order)
    return epochs

&nbsp;&nbsp;&nbsp;&nbsp;1.5. Perform preprocessing and event ID (label) standardization.

In [None]:
# The standardized label -> code mapping is identical for every file, so build it once
events_trigger_dict = {event: standard_event_id[event] for event in params_dict['desired_events']}

epochs_list = []
for xdf_file in xdf_files:
    print(f'Processing file: {xdf_file}')
    raw = read_raw_xdf(xdf_file)
    # mean_across_epochs is overwritten per file; after the loop it holds the last file's value
    epochs, mean_across_epochs, _ = EEG_preprocessing(curr_path, raw, params_dict)
    # Give every file the same event codes so the epochs can be concatenated consistently
    epochs = remap_epoch_events_to_standard(epochs, standard_event_id, params_dict['desired_events'])
    epochs_list.append(epochs)

if not epochs_list:
    raise RuntimeError(f'No epochs to concatenate: no XDF files were selected for subject {subject_name!r}')
print('Concatenating all preprocessed epochs...')
epochs = mne.concatenate_epochs(epochs_list, on_mismatch='warn')

In [None]:
# Display a summary of the concatenated epochs (the notebook renders the object's repr)
epochs

In [35]:
# Look-up table: channel name -> position along the channel axis of the epochs data
channel_indices = dict((name, position) for position, name in enumerate(epochs.ch_names))

# Electrode sets of interest
motor_cortex_electrodes = ['C5', 'C3', 'C1', 'C2', 'C4', 'C6', 'CP5', 'CP3', 'CP1', 'CP2', 'CP4', 'CP6', 'P5', 'P3', 'P1', 'P2', 'P4', 'P6']
basic_three_electrodes = ['C3', 'Cz', 'C4']
frontal_electrodes = ['F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8']

# Translate each set to channel positions; electrodes absent from the data are skipped
motor_cortex_indices = [channel_indices[name] for name in motor_cortex_electrodes if name in channel_indices]
basic_three_indices = [channel_indices[name] for name in basic_three_electrodes if name in channel_indices]
motor_and_frontal_indices = motor_cortex_indices + [
    channel_indices[name] for name in frontal_electrodes if name in channel_indices
]

&nbsp;&nbsp;&nbsp;&nbsp;1.6. Balance classes by randomly subsampling each class down to the size of the smallest class.

In [36]:
def balance_epochs_by_subsampling(epochs, random_state=None):
    """
    Randomly subsample every class down to the size of the smallest class.

    Parameters:
        epochs (mne.Epochs): Labeled epochs; the classes are the keys of ``epochs.event_id``.
        random_state (None, int or numpy.random.Generator): Seed or generator for the subsampling.
            ``None`` (default) uses NumPy's global random state, so ``np.random.seed`` still
            controls it.

    Returns:
        mne.Epochs: New Epochs object with the same number of epochs per class, kept in the
        original (chronological) order.

    Raises:
        ValueError: If a class in ``epochs.event_id`` has no epochs, since balancing would then
        discard everything.
    """
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    # Positional indices of each class's epochs, taken from the event codes.
    # (``epochs[label].selection`` indexes the pre-drop epochs, which differs from the position
    # whenever epochs have been dropped, so it must not be used for positional indexing.)
    codes = epochs.events[:, 2]
    class_positions = {label: np.flatnonzero(codes == code) for label, code in epochs.event_id.items()}

    empty_classes = [label for label, positions in class_positions.items() if positions.size == 0]
    if empty_classes:
        raise ValueError(f'Cannot balance classes: no epochs for {empty_classes}')

    min_count = min(positions.size for positions in class_positions.values())
    selected = [
        positions if positions.size == min_count
        else rng.choice(positions, min_count, replace=False)
        for positions in class_positions.values()
    ]

    # Sort so the subset keeps the original temporal order
    return epochs[np.sort(np.concatenate(selected))]

# Equalize class sizes by random subsampling. No seed is passed here, so unless NumPy's global
# random state is seeded beforehand, the kept epochs differ between runs.
epochs = balance_epochs_by_subsampling(epochs)

&nbsp;&nbsp;&nbsp;&nbsp;1.7. Crop epoch timeframes

In [37]:
# Keep only the classifier window (seconds relative to event onset). By default crop() includes the
# sample at tmax, e.g. 2001 samples for this 4 s window at 500 Hz.
epochs = epochs.crop(tmin=float(params_dict['classifier_window_s']), tmax=float(params_dict['classifier_window_e']))

&nbsp;&nbsp;&nbsp;&nbsp; 1.8. Define dataset and labels

In [38]:
# Model inputs: array of shape (n_epochs, n_channels, n_times) taken from the cropped epochs
X = epochs.get_data()
# Targets: the integer event code of each epoch (third and last column of the MNE events array)
Y = epochs.events[:, 2]

<b>2. MSPCA: Signal denoising</b>

In [None]:
def MSPCA(X, n_decomps=5):
    """
    Denoise EEG epochs with Multiscale Principal Component Analysis (MSPCA).

    Each epoch is processed independently:
      1. every channel is decomposed with the empirical wavelet transform (ewtpy.EWT1D)
         into ``n_decomps`` modes;
      2. for each mode, PCA is fit across channels (samples as observations), components
         whose eigenvalue is below the mean eigenvalue (Kaiser's rule) are zeroed, and the
         PCA is inverted;
      3. each channel is re-synthesised from its denoised modes and its EWT filter bank;
      4. a final PCA with the same Kaiser's rule is applied across the channels of the
         re-synthesised epoch.

    Parameters:
        X (numpy.ndarray): Input data of shape (n_epochs, n_channels, n_samples).
        n_decomps (int): Number of EWT modes per channel (default 5, the value cited from the literature).

    Returns:
        numpy.ndarray: Denoised data with the same shape and dtype as X (allocated via np.zeros_like).
    """
    
    n_epochs, n_channels, n_samples = X.shape
    X_denoised = np.zeros_like(X)
    for epoch in range(n_epochs):
        # Step 1: Decompose each signal (channel) in the epoch using EWT
        # Real-valued (float64) buffer for all modes of all channels: (n_channels, n_decomps, n_samples)
        ewt_coeffs = np.zeros((n_channels, n_decomps, n_samples))
        mfb_values = []  # Per-channel EWT filter banks, reused for the reconstruction in Step 3
        for channel in range(n_channels):
            signal = X[epoch, channel, :]
            ewt, mfb, _ = ewtpy.EWT1D(signal, N=n_decomps)
            # NOTE(review): assumes EWT1D returns exactly n_decomps modes of n_samples each; if fewer
            # boundaries are detected, this assignment fails with a shape error — TODO confirm.
            ewt_coeffs[channel, :, :] = ewt.T # ewt.shape == (n_samples, n_decomps), ewt_coeffs.shape[channel, :, :] == (n_decomps, n_samples) 
            mfb_values.append(mfb)
        
        # Step 2: Perform PCA on each level/mode/subband across all channels
        denoised_ewt_coeffs = np.zeros_like(ewt_coeffs)
        for level in range(n_decomps):
            level_data = ewt_coeffs[:, level, :].T # Shape (n_samples, n_channels)
            # Perform PCA
            pca = PCA()
            level_data_pca = pca.fit_transform(level_data)
            # Apply Kaiser's rule (keep components of eigenvalues >= mean_eigenvalue).
            # sklearn orders components by decreasing explained variance, so the kept ones are the first columns.
            mean_eigenvalue = np.mean(pca.explained_variance_)
            n_components_to_keep = np.sum(pca.explained_variance_ >= mean_eigenvalue)
            level_data_pca[:, n_components_to_keep:] = 0  # Set components below the mean eigenvalue to zero, effectively removing unwanted PCs
            denoised_ewt_coeffs[:, level, :] = pca.inverse_transform(level_data_pca).T # Inverse transform to get back to the original space, shape (n_channels, n_samples)
        
        # Step 3: Reconstruct the denoised signals 
        for channel in range(n_channels):
            ewt = denoised_ewt_coeffs[channel, :, :] # This instance of ewt is of shape (n_decomps, n_samples)
            slice_index = int(np.ceil(n_samples / 2)) 
            # NOTE(review): the filter bank is cropped along its *frequency* axis, with offsets derived from
            # n_samples, presumably so its length matches the cropped time-domain modes. Cropping a frequency
            # response is not equivalent to cropping a signal in time, so this inverse may not reproduce the
            # input exactly — verify against ewtpy's filter-bank layout.
            # mfb[i] must have n_samples entries for the elementwise product with fft(ewt[i]) below; presumably
            # mfb_values[channel] is (~2*n_samples, n_decomps), so .T is (n_decomps, ~2*n_samples) — TODO confirm.
            mfb = mfb_values[channel].T[:, slice_index - 1: -slice_index]
            # denoised_ewt_coeffs is a real float array, so this check is always True and the
            # complex branch below is effectively unreachable.
            real = all(np.isreal(ewt[0]))
            if real: 
                # Sum of each mode filtered again by its filter-bank response
                reconstructed_signal = np.zeros(ewt.shape[1])
                for i in range(0, ewt.shape[0]):
                    reconstructed_signal += np.real(np.fft.ifft(np.fft.fft(ewt[i]) * mfb[i]))
            else:
                reconstructed_signal = np.zeros(ewt.shape[1]) * 0j
                for i in range(0, ewt.shape[0]):
                    reconstructed_signal += np.fft.ifft(np.fft.fft(ewt[i]) * mfb[i])
            X_denoised[epoch, channel, :] = reconstructed_signal
        
        # Step 4: Perform global PCA on the reconstructed signals (samples as observations, channels as features)
        pca = PCA()
        reconstructed_epoch = X_denoised[epoch, :, :].T
        reconstructed_epoch_pca = pca.fit_transform(reconstructed_epoch)
        # Apply Kaiser's rule (keep components of eigenvalues >= mean_eigenvalue)
        mean_eigenvalue = np.mean(pca.explained_variance_)
        n_components_to_keep = np.sum(pca.explained_variance_ >= mean_eigenvalue)
        reconstructed_epoch_pca[:, n_components_to_keep:] = 0  # Set components below the mean eigenvalue to zero, effectively removing unwanted PCs
        X_denoised[epoch, :, :] = pca.inverse_transform(reconstructed_epoch_pca).T
    
    return X_denoised

<div style="margin-left: 30px;">2.1. Perform MSPCA denoising

In [40]:
# Denoise every epoch with MSPCA (default: 5 EWT modes); the result has the same shape as X
X_denoised = MSPCA(X)  # Denoise data

Sanity check (optional: the following code cell is not needed by the pipeline)

In [None]:
# Sanity check: label/data shapes and the memory footprint of the noisy vs. denoised arrays
print(Y.shape)
print(X_denoised.shape, X.shape)
print(f"X size: {X.nbytes / 1024**2:.2f} MB")
print(f"X_denoised size: {X_denoised.nbytes / 1024**2:.2f} MB")

Back to the pipeline

<b>3. Continuous Wavelet Transform (CWT) for transformation of data into scalograms</b>
<div style="margin-left: 30px;">The CWT turns a 1D signal into a 2D time-frequency representation that can be rendered as an image.</div>

In [None]:
pathlib.Path('Scalograms').mkdir(exist_ok=True)  # Root folder for all scalogram images (no error if it already exists)
def cwt_transform(data, wavelet='morse', scales=None, samp_period=1/500, ssq_gamma=3, ssq_beta=60):
    """
    Apply a Continuous Wavelet Transform (CWT) to every channel of every epoch.

    Uses pywt with a Morlet wavelet (``wavelet='morlet'``) or ssqueezepy with generalized
    Morse wavelets (``wavelet='morse'``).

    Parameters:
        data (numpy.ndarray or MNE epochs): EEG data of shape (n_epochs, n_channels, n_times),
            or any MNE epochs object (converted with ``get_data()``).
        wavelet (str): 'morlet' for pywt or 'morse' for ssqueezepy.
        scales (numpy.ndarray): Scales for the Morlet CWT. If None, ``np.arange(8.125, 812.5, 8.125)``
            is used (roughly 50 Hz down to 0.5 Hz at a 500 Hz sampling rate). Ignored for 'morse',
            where ssqueezepy chooses the scales.
        samp_period (float): Sampling period (1/sampling frequency); used by the Morlet transform only.
        ssq_gamma (float): Morse wavelet gamma parameter (ssqueezepy only).
        ssq_beta (float): Morse wavelet beta parameter (ssqueezepy only).

    Returns:
        tuple:
            coeffs (numpy.ndarray): CWT coefficients, shape (n_epochs, n_channels, n_scales, n_times).
            freqs (numpy.ndarray): For 'morlet', the frequency (Hz) of each scale; for 'morse', the
                scales returned by ssqueezepy (not frequencies). Shape (n_epochs, n_channels, n_scales).

    Raises:
        ValueError: If ``wavelet`` is neither 'morlet' nor 'morse'.
    """
    # Validate up front: otherwise an unknown name leaves `coeffs` unbound inside the loop
    if wavelet not in ('morlet', 'morse'):
        raise ValueError(f"wavelet must be 'morlet' or 'morse', got {wavelet!r}")
    if scales is None:
        scales = np.arange(8.125, 812.5, 8.125)
    # Duck-typed so that mne.EpochsArray (which is not a subclass of mne.Epochs) is converted too
    if hasattr(data, 'get_data'):
        data = data.get_data()

    coeffs_mtrx = []
    freqs_mtrx = []
    for epoch in data:
        epoch_coeffs = []
        epoch_freqs = []
        for channel in epoch:
            if wavelet == 'morlet':
                coeffs, freqs = pywt.cwt(channel, scales, 'morl', sampling_period=samp_period)
            else:
                # ssqueezepy returns (Wx, scales); the scales are stored where frequencies would go
                coeffs, freqs = ssq_cwt(channel, wavelet=('gmw', {'gamma': ssq_gamma, 'beta': ssq_beta}))
            epoch_coeffs.append(coeffs)
            epoch_freqs.append(freqs)
        coeffs_mtrx.append(np.array(epoch_coeffs))
        freqs_mtrx.append(np.array(epoch_freqs))
    return np.array(coeffs_mtrx), np.array(freqs_mtrx)

# Event code -> class name (the inverse of the standard event coding used for Y)
event_id_name = {1: 'ActiveRest' , 22: 'OpenPalm', 33: 'ClosePalm', 4: 'Rating', 55: 'Rest'}

def save_scalogram(coeffs, freqs, epoch_index=0, epoch_offset=0, channel_index=0, path='Scalograms', close=True, detailed=False, labels=None, ch_names=None):
    """
    Save the scalogram (|CWT coefficients|) of one epoch/channel as PNG image(s).

    A clean image (no axes, 2.90x2.91 in figure, 100 dpi) is always written to ``<path>/Clean``.
    With ``detailed=True``, an annotated version is also written to ``<path>/Detailed``. Files are
    named ``epoch_<epoch_index + epoch_offset>-channel_<channel_index>.png``.

    Parameters:
        coeffs (numpy.ndarray): CWT coefficients of shape (n_epochs, n_channels, n_scales, n_times).
        freqs (numpy.ndarray): Frequency (Hz) or scale axis per epoch/channel, indexed as freqs[epoch][channel].
        epoch_index (int): Index of the epoch within ``coeffs``.
        epoch_offset (int): Added to ``epoch_index`` to get the epoch's index in the full dataset
            (used in file names and, in detailed mode, for the label lookup).
        channel_index (int): Index of the channel to plot.
        path (str): Root directory for the images.
        close (bool): Whether to close the figure(s) after saving.
        detailed (bool): Also save an annotated scalogram with axes, colorbar and title.
        labels (array-like, optional): Event code of every dataset epoch, used in the detailed title.
            Defaults to the notebook-level ``Y``.
        ch_names (list, optional): Channel names used in the detailed title. Defaults to the
            notebook-level ``epochs.ch_names``.

    Returns:
        None: The image(s) are written to disk.
    """
    magnitude = np.abs(coeffs[epoch_index, channel_index])
    axis_values = freqs[epoch_index][channel_index]
    dataset_epoch = epoch_index + epoch_offset

    if detailed:
        labels = Y if labels is None else labels
        ch_names = epochs.ch_names if ch_names is None else ch_names
        os.makedirs(path + '/Detailed', exist_ok=True)

        plt.figure()
        # With origin='upper', row 0 of the image is drawn at the extent's `top` value, so `top`
        # must be the first entry of the frequency/scale axis for the y labels to match the rows.
        plt.imshow(magnitude, extent=[0, coeffs.shape[-1], axis_values[-1], axis_values[0]], origin='upper', aspect='auto', cmap='jet')
        plt.colorbar(label='Magnitude')
        plt.title(f'Scalogram for Epoch {dataset_epoch} ({event_id_name[labels[dataset_epoch]]}), Channel {channel_index} ({ch_names[channel_index]})')
        plt.xlabel('Time (samples)')
        plt.ylabel('Frequency (Hz) / Scales')
        plt.savefig(f'{path}/Detailed/epoch_{dataset_epoch}-channel_{channel_index}.png', bbox_inches='tight', pad_inches=0)
        if close:
            plt.close()

    # Clean scalogram (network input). Axes are hidden, so the extent's orientation only affects
    # the (invisible) axis labelling here.
    os.makedirs(path + '/Clean', exist_ok=True)
    plt.figure(figsize=(2.90, 2.91))
    plt.imshow(magnitude, extent=[0, coeffs.shape[-1], axis_values[0], axis_values[-1]], aspect='auto', cmap='jet')
    plt.axis('off')
    plt.savefig(f'{path}/Clean/epoch_{dataset_epoch}-channel_{channel_index}.png', dpi=100, bbox_inches='tight', pad_inches=0)
    if close:
        plt.close()

def scalogram_transform_and_save(dataset, wavelet='morse', save_path = 'Scalograms', chunks=1, sfreq=None):
    '''
    Transform a dataset into scalogram images with a Continuous Wavelet Transform (CWT).

    Uses a Morlet mother wavelet (8-30 Hz in 0.1 Hz steps) or generalized Morse wavelets. For
    Morse, the outer quarter of the scales is discarded at each end. The dataset is processed in
    ``chunks`` consecutive slices to bound memory use. Every epoch/channel is saved via
    ``save_scalogram``, with file indices relative to the full dataset.

    Parameters:
        dataset (numpy.ndarray): EEG data of shape (n_epochs, n_channels, n_times).
        wavelet (str): Wavelet of the CWT ('morlet' or 'morse').
        save_path (str): Directory the scalograms are saved under.
        chunks (int): Number of chunks the dataset is split into (default 1 = all at once).
            Chunks that would be empty (more chunks than epochs) are skipped.
        sfreq (float, optional): Sampling frequency in Hz. Defaults to the notebook-level
            ``epochs.info['sfreq']``.

    Return:
        None.

    Raises:
        ValueError: If ``wavelet`` is neither 'morlet' nor 'morse'.
    '''
    if wavelet not in ('morlet', 'morse'):
        raise ValueError(f"wavelet must be 'morlet' or 'morse', got {wavelet!r}")
    if sfreq is None:
        sfreq = epochs.info['sfreq']

    if wavelet == 'morlet':
        # Same scales for every chunk: 8-30 Hz in 0.1 Hz steps
        morlet_scales = pywt.frequency2scale('morl', np.arange(8, 30.1, 0.1) / sfreq)

    intervals = np.linspace(0, dataset.shape[0], chunks + 1, dtype=int)
    for start, stop in zip(intervals[:-1], intervals[1:]):
        if stop <= start:
            continue  # Empty chunk (more chunks than epochs)
        X = dataset[start:stop]  # Portion of the dataset to transform using CWT

        if wavelet == 'morlet':
            cwt_coeffs, cwt_freqs = cwt_transform(X, wavelet='morlet', scales=morlet_scales, samp_period=1 / sfreq)
        else:
            cwt_coeffs, cwt_freqs = cwt_transform(X, wavelet='morse', samp_period=1 / sfreq, ssq_gamma=3, ssq_beta=60)
            # Crop the top and bottom quarter of the scales, which the authors found to carry no useful information
            n_scales = cwt_coeffs.shape[2]
            trim = n_scales // 4
            cwt_coeffs = cwt_coeffs[:, :, trim:n_scales - trim, :]
            cwt_freqs = cwt_freqs[:, :, trim:n_scales - trim]

        # `start` converts chunk-local epoch indices to dataset indices in the file names
        for epoch in range(X.shape[0]):
            for channel in range(X.shape[1]):
                save_scalogram(cwt_coeffs, cwt_freqs, epoch_index=epoch, epoch_offset=start, channel_index=channel, path=save_path)

<div style="margin-left: 30px;">3.1. <b>Choose from two options</b>: pywt's CWT function (using a <u><i>Morlet mother wavelet</i></u>) or ssqueezepy's CWT function (using <u><i>generalized Morse wavelets</i></u>). <br> 
Empirical research shows a clear preference for Morse wavelets when performing the CWT, owing to their better time-frequency localization. <br>
Also, due to memory limitations, choose whether to continue running the pipeline with noisy or with denoised data. For the same reason, the transform is split into chunks.</div>

In [43]:
data_type = 'Denoised'  # Version of X to transform: 'Denoised' or 'Noisy'
wavelet = 'morse'  # Mother wavelet for the CWT: 'morse' or 'morlet'
chunks = 9  # The dataset is transformed in this many chunks to bound memory use
# Scalograms of the entire dataset are stored under Scalograms/<data_type>/All Data
save_path = f'Scalograms/{data_type}/All Data'
os.makedirs(save_path, exist_ok=True)

<div style="margin-left: 30px;">3.2. Perform CWT on dataset (noisy or denoised) and save as scalograms in storage (may take over an hour).

In [44]:
# Choose which version of the data is turned into scalograms
if data_type == 'Denoised':
    X = X_denoised
elif data_type == 'Noisy':
    # Free the denoised copy to save memory. X keeps what it currently holds (the noisy epochs
    # data, unless this cell previously ran with 'Denoised'). At notebook top level
    # locals() is globals(), so checking globals() is sufficient.
    if 'X_denoised' in globals():
        del X_denoised
else:
    raise ValueError(f"data_type must be 'Denoised' or 'Noisy', got {data_type!r}")
scalogram_transform_and_save(X, wavelet=wavelet, save_path=save_path, chunks=chunks)

<div style="margin-left: 30px;">3.3. Create a file containing mappings from epochs to labels for future use.

In [15]:
# Map each epoch's position in Y to its class name
epoch_class_mapping = {epoch_idx: event_id_name[code] for epoch_idx, code in enumerate(Y)}

# Number of epochs per class, for the three task classes
classes = ['ActiveRest', 'OpenPalm', 'ClosePalm']
class_counts = {
    class_name: sum(label == class_name for label in epoch_class_mapping.values())
    for class_name in classes
}

# Bundle mappings and statistics. np.save pickles the dict, so it is reloaded with allow_pickle=True
epoch_mappings = dict(
    epoch_to_class=epoch_class_mapping,
    class_counts=class_counts,
    classes=classes,
    total_epochs_size=len(epoch_class_mapping),
)
np.save('epoch_class_mappings.npy', epoch_mappings)

print(f"Dataset: {class_counts}")
print(epoch_mappings['epoch_to_class'][2])

Dataset: {'ActiveRest': 30, 'OpenPalm': 30, 'ClosePalm': 30}
ClosePalm


Sanity check

In [16]:
# Sanity check: group epoch positions by class name
ActiveRest, ClosePalm, OpenPalm = [], [], []
for i, code in enumerate(Y):
    label = event_id_name[code]
    if label == 'ActiveRest':
        ActiveRest.append(i)
    elif label == 'ClosePalm':
        ClosePalm.append(i)
    elif label == 'OpenPalm':
        OpenPalm.append(i)
    else:
        print(f'Unexpected label {code} at index {i}')

print("ActiveRest indices:", ActiveRest)
print("ClosePalm indices: ", ClosePalm)
print("OpenPalm indices:  ", OpenPalm)
print(X.shape)

ActiveRest indices: [0, 3, 5, 7, 8, 10, 12, 13, 21, 23, 24, 26, 30, 31, 33, 34, 38, 39, 42, 52, 54, 58, 59, 61, 68, 73, 75, 76, 79, 89]
ClosePalm indices:  [2, 4, 9, 15, 18, 25, 27, 28, 32, 36, 41, 43, 45, 49, 51, 53, 56, 60, 65, 69, 70, 72, 78, 80, 81, 82, 84, 85, 86, 88]
OpenPalm indices:   [1, 6, 11, 14, 16, 17, 19, 20, 22, 29, 35, 37, 40, 44, 46, 47, 48, 50, 55, 57, 62, 63, 64, 66, 67, 71, 74, 77, 83, 87]
(90, 48, 2001)


<b>4. Stacking and resizing of scalograms</b>
<div style="margin-left: 30px;">Here we process and adapt the scalograms to the neural network's input requirements (for example, ShuffleNet expects 224x224 images).</div>

<div style="margin-left: 30px;">4.1. Import epoch ID to class mappings from storage.

In [17]:
# Reload the epoch -> class mapping saved earlier. It is a pickled dict, hence allow_pickle=True;
# only load files that this notebook produced.
epoch_mappings = np.load('epoch_class_mappings.npy', allow_pickle=True).item()
# One row per epoch: 'epoch_id' (epoch position) and 'label' (class name)
full_dataset = (
    pd.DataFrame.from_dict(epoch_mappings['epoch_to_class'], orient='index', columns=['label'])
    .rename_axis('epoch_id')
    .reset_index()
)
print(full_dataset.loc[full_dataset['epoch_id'] == 2].values[0])

[2 'ClosePalm']


<div style="margin-left: 30px;">4.2. Choose whether to proceed with noisy data or denoised data.

In [None]:
data_type = 'Denoised'  # Which scalograms to stack: 'Denoised' or 'Noisy'
# Output folder for the stacked and resized scalograms of the chosen data type
os.makedirs(f'Scalograms/{data_type}/All Data/Stacked scalograms', exist_ok=True)

<div style="margin-left: 30px;">4.3. Define function that combines scalograms from different channels into one image for each epoch.  
The function can stack them vertically or horizontally.

In [None]:
def stack_scalograms(image_paths, output_path, plane='vertical', target_size=None):
    """
    Stack multiple scalograms (PNG images) vertically or horizontally and save the result.

    For vertical stacking, every image whose width differs from the narrowest image is rescaled
    (keeping its aspect ratio) to that width. For horizontal stacking, images are rescaled to the
    height of the shortest image.

    Parameters:
        image_paths (list): Paths of the images to stack, in order (top-to-bottom or left-to-right)
        output_path (str): Path where the combined image will be saved
        plane (str): 'vertical' or 'horizontal' to specify the stacking direction
        target_size (tuple): Optional target size (width, height) to resize the combined image before saving

    Returns:
        numpy.ndarray: The combined (and possibly resized) image as loaded by OpenCV.

    Raises:
        ValueError: If `image_paths` is empty or `plane` is neither 'vertical' nor 'horizontal'.
        FileNotFoundError: If one of the input images cannot be read.
        IOError: If the combined image cannot be written to `output_path`.
    """
    if plane not in ('vertical', 'horizontal'):
        raise ValueError(f"plane must be 'vertical' or 'horizontal', got {plane!r}")
    if not image_paths:
        raise ValueError('image_paths must contain at least one image path')

    # Read all images; cv2.imread returns None (instead of raising) for missing/unreadable files
    images = []
    for path in image_paths:
        img = cv2.imread(str(path))
        if img is None:
            raise FileNotFoundError(f'Could not read image: {path}')
        images.append(img)

    if plane == 'vertical':
        # Rescale every image to the narrowest width, keeping its aspect ratio, then stack top-to-bottom
        min_width = min(img.shape[1] for img in images)
        resized_images = []
        for img in images:
            if img.shape[1] != min_width:
                aspect_ratio = img.shape[0] / img.shape[1]
                img = cv2.resize(img, (min_width, int(min_width * aspect_ratio)))
            resized_images.append(img)
        combined_image = cv2.vconcat(resized_images)
    else:
        # Rescale every image to the smallest height, keeping its aspect ratio, then stack left-to-right
        min_height = min(img.shape[0] for img in images)
        resized_images = []
        for img in images:
            if img.shape[0] != min_height:
                aspect_ratio = img.shape[1] / img.shape[0]
                img = cv2.resize(img, (int(min_height * aspect_ratio), min_height))
            resized_images.append(img)
        combined_image = cv2.hconcat(resized_images)

    # Optionally resize the combined image (cv2.resize takes (width, height)) before saving
    if target_size is not None:
        combined_image = cv2.resize(combined_image, target_size)
    # cv2.imwrite returns False instead of raising when the file cannot be written
    if not cv2.imwrite(str(output_path), combined_image):
        raise IOError(f'Could not write image: {output_path}')
    return combined_image

<div style="margin-left: 30px;">4.4. Choose <b>ONE</b> of the following options:

<div style="margin-left: 40px;"> <b><u>Option 1</u></b>: Considering the limited size of the CNN input (224x224), the bandwidth of interest (Mu + Beta = 8Hz-30Hz) and the trial duration (4s after crop), <br> we start small and combine only 3 scalograms: C3, Cz, C4.
<br><br>
<u><b>Option 2</b></u>: A more advanced approach that includes more channels in the hope of improving data quality. <br> We are going to combine 18 scalograms from motor cortex channels C5, C3, C1, C2, C4, C6, CP5, CP3, CP1, CP2, CP4, CP6, P5, P3, P1, P2, P4, P6.<br>First, we will vertically stack each half of the 18 images (9 scalograms each), resulting in 2 stacked images. <br>
Then, we will stack the 2 images horizontally, creating one image with 2 columns by 9 rows of scalograms.
<br><br>
<u><b>Option 3</b></u>: Experimental approach which includes the frontal channels on top of the 18 motor cortex channels of the previous option. <br> The channels are: C5, C3, C1, C2, C4, C6, CP5, CP3, CP1, CP2, CP4, CP6, P5, P3, P1, P2, P4, P6, F7, F5, F3, F1, Fz, F2, F4, F6, F8. <br> 
The final image will contain 3 columns by 9 rows of scalograms (each third of the channel list is stacked vertically into one column).

In [20]:
# Stacking layout: 1 = C3/Cz/C4, 2 = 18 motor-cortex channels, 3 = 27 motor-cortex + frontal channels
chosen_option = 3
resize = (224,) * 2  # (width, height) the stacked scalograms are resized to

In [None]:
def _clean_scalogram_path(epoch_id, channel):
    """Return the path of the clean scalogram of one channel of one epoch."""
    return 'Scalograms/' + data_type + f'/All Data/Clean/epoch_{epoch_id}-channel_{channel}.png'


def _stack_epoch_in_columns(epoch_id, channel_indices, n_columns, output_name):
    """
    Stack one epoch's scalograms into `n_columns` side-by-side columns and save the result.

    The channels are split into `n_columns` consecutive slices of (near-)equal length. Each slice is
    stacked vertically into a temporary image, the temporary images are stacked horizontally and
    resized to `resize`, and the temporary images are deleted (also if stacking fails).
    """
    stacked_dir = 'Scalograms/' + data_type + '/All Data/Stacked scalograms'
    n_channels = len(channel_indices)
    column_paths = []
    try:
        for c in range(n_columns):
            column = channel_indices[c * n_channels // n_columns:(c + 1) * n_channels // n_columns]
            column_path = f'{stacked_dir}/epoch_{epoch_id}-channels_{"_".join(map(str, column))}.png'
            column_paths.append(column_path)
            stack_scalograms([_clean_scalogram_path(epoch_id, ch) for ch in column],
                             column_path, plane='vertical')
        stack_scalograms(column_paths, f'{stacked_dir}/epoch_{epoch_id}-{output_name}',
                         plane='horizontal', target_size=resize)
    finally:
        # Intermediate column images are only needed to build the final image
        for column_path in column_paths:
            if os.path.exists(column_path):
                os.remove(column_path)


# Iterate over the actual epoch IDs so filenames match those used by move_epochs in the CV pipeline
epoch_ids = full_dataset['epoch_id'].tolist()
if chosen_option == 1:
    channels_tag = '_'.join(map(str, basic_three_indices))
    stacked_dir = 'Scalograms/' + data_type + '/All Data/Stacked scalograms'
    for epoch_id in epoch_ids:
        stack_scalograms([_clean_scalogram_path(epoch_id, ch) for ch in basic_three_indices],
                         f'{stacked_dir}/epoch_{epoch_id}-channels_{channels_tag}.png',
                         plane='vertical', target_size=resize)
    # Derive the pattern from the indices actually used: they depend on which bad electrodes were removed
    file_pattern = f'channels_{channels_tag}.png'
elif chosen_option == 2:
    # 2 columns (two halves of the motor cortex channels) side by side
    for epoch_id in epoch_ids:
        _stack_epoch_in_columns(epoch_id, motor_cortex_indices, 2, '18_motor_cortex_channels.png')
    file_pattern = '18_motor_cortex_channels.png'
elif chosen_option == 3:
    # 3 columns (three thirds of the motor cortex + frontal channels) side by side
    for epoch_id in epoch_ids:
        _stack_epoch_in_columns(epoch_id, motor_and_frontal_indices, 3, '27_motor_and_frontal_channels.png')
    file_pattern = '27_motor_and_frontal_channels.png'
else:
    raise ValueError(f'chosen_option must be 1, 2 or 3, got {chosen_option!r}')

<b>5. Nested cross validation pipeline of CNN transfer learning.</b>

<div style="margin-left: 30px;">5.1. Choose the <u>model</u> and classification <u>classes</u> (<i> 'ActiveRest vs. ClosePalm' / 'ActiveRest vs. OpenPalm' / 'ClosePalm vs. OpenPalm' / 'All classes'</i>).

In [None]:
model_name = 'shufflenet_v2'  # 'alexnet', 'shufflenet_v2', 'convnext_t', 'convnext_s' or 'efficientnet_v2'
chosen_classes = ['ClosePalm', 'OpenPalm'] # ['ActiveRest', 'ClosePalm'] or ['ActiveRest', 'OpenPalm'] or ['ClosePalm', 'OpenPalm'] or ['ActiveRest', 'ClosePalm', 'OpenPalm']

# The cross-validation pipeline handles exactly 2 (A/B) or 3 (A/B/C) classes
if len(chosen_classes) not in (2, 3):
    raise ValueError(f'chosen_classes must contain 2 or 3 class names, got {chosen_classes}')

# Each chosen class gets its own lettered sub-folder (A, B[, C]) in the train/valid/test splits
class_letters = ['A', 'B', 'C'][:len(chosen_classes)]
class_mappings = dict(zip(chosen_classes, class_letters))  # Mapping for the classes to be used in the cross-validation
# Keep only the epochs of the chosen classes
dataset = full_dataset[full_dataset['label'].isin(chosen_classes)].reset_index(drop=True)

cv_path = 'Scalograms/' + data_type
os.makedirs(cv_path + '/Outer Train/All Train', exist_ok=True)
for split in ('Test', 'Outer Train/Train', 'Outer Train/Valid'):
    for letter in class_letters:
        os.makedirs(f'{cv_path}/{split}/{letter}', exist_ok=True)
    if len(chosen_classes) == 2:
        # Drop a leftover class-C folder from an earlier 3-class run
        # (presumably so the training/evaluation code does not pick it up as a third class)
        shutil.rmtree(f'{cv_path}/{split}/C', ignore_errors=True)
os.makedirs('Results', exist_ok=True)

def move_epochs(X, epoch_indices, src_folder, dst_folder, file_pattern, copy=False):
    '''
    Moves or copies the files of the selected epochs from the source folder to the destination folder.

    Parameters:
        X (array-like): Epoch IDs; X[idx] is substituted into file_pattern to build each filename.
        epoch_indices (iterable of int): Positions in X of the epochs to move or copy.
        src_folder (str): Path to the source folder containing the files.
        dst_folder (str): Path to the destination folder (created if it does not exist).
        file_pattern (str): Filename template with one '{}' placeholder, e.g., 'epoch_{}.png'.
        copy (bool): If True, copies files; if False, moves files.

    Returns:
        None. Files that do not exist in src_folder are skipped and reported with a printed warning.
    '''
    os.makedirs(dst_folder, exist_ok=True)
    transfer = shutil.copy if copy else shutil.move
    missing = []
    for idx in epoch_indices:
        filename = file_pattern.format(X[idx])
        src_path = os.path.join(src_folder, filename)
        if os.path.exists(src_path):
            transfer(src_path, os.path.join(dst_folder, filename))
        else:
            missing.append(filename)
    # Skipping stays best-effort, but a silently smaller fold would bias the results, so report it
    if missing:
        print(f'Warning: {len(missing)} file(s) not found in {src_folder} and skipped, e.g. {missing[0]}')

def delete_all_files_in_folder(folder_path):
    """
    Removes every file located directly inside a folder; sub-folders and their contents are kept.

    Parameters:
        folder_path (str): Path of the folder whose files will be deleted
    """
    with os.scandir(folder_path) as entries:
        files_to_remove = [entry.path for entry in entries if entry.is_file()]
    for path in files_to_remove:
        os.remove(path)

<div style="margin-left: 30px;">5.2. Execute 5x4 nested cross-validation (5 outer folds for testing, 4 inner folds for hyperparameter selection) and save results. <br>
<i> Note.</i> If the code gets interrupted, make sure all cross validation directories are empty before running again.

In [None]:
X = dataset['epoch_id'].values  # epoch IDs, substituted into the scalogram filenames
y = dataset['label'].values
stack_type = ['3 Channels', '18 Channels', '27 Channels']
# Lettered class sub-folders (A, B[, C]) in the same order as chosen_classes
class_letters = ['A', 'B', 'C'][:len(chosen_classes)]
if len(chosen_classes) == 2:
    results_path = f'Results/{subject_name}/{stack_type[chosen_option - 1]}/{model_name}/{chosen_classes[0]} vs. {chosen_classes[1]}'
    results_prefix = f'{chosen_classes[0]}_vs_{chosen_classes[1]}'
else:
    results_path = f'Results/{subject_name}/{stack_type[chosen_option - 1]}/{model_name}/all classes'
    results_prefix = 'All_Classes'
os.makedirs(results_path, exist_ok=True)

stacked_dir = cv_path + '/All Data/Stacked scalograms'
all_train_dir = cv_path + '/Outer Train/All Train'
epoch_file_pattern = 'epoch_{}-' + file_pattern

outer_cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
inner_cv = StratifiedKFold(n_splits=4, shuffle=True, random_state=42)

# Define hyperparameters to test
optimizers = ['adamw', 'rmsprop', 'sgdm']
learning_rates = [0.001, 0.0001]
batch_sizes = [8, 16, 32]

# Create a DataFrame to store the results of the inner cross-validation
# (rows follow itertools.product order, which the training loop below relies on)
inner_results = pd.DataFrame(list(itertools.product(optimizers, learning_rates, batch_sizes)), columns=['optimizer', 'learning_rate', 'batch_size'])
final_model_mean_accuracy = 0.0


def _class_dirs(split):
    """Return the per-class sub-folders (A, B[, C]) of a split folder under cv_path."""
    return [f'{cv_path}/{split}/{letter}' for letter in class_letters]


def _copy_by_class(indices, src_folder, split):
    """Copy the epochs at positions `indices` of X into the class sub-folder of `split` matching their label."""
    for cls, dst_folder in zip(chosen_classes, _class_dirs(split)):
        move_epochs(X, indices[y[indices] == cls], src_folder, dst_folder, epoch_file_pattern, copy=True)


def _clear_folders(folders):
    """Delete the files inside every folder in `folders` that exists."""
    for folder in folders:
        if os.path.isdir(folder):
            delete_all_files_in_folder(folder)


train_valid_dirs = _class_dirs('Outer Train/Train') + _class_dirs('Outer Train/Valid')
outer_dirs = [all_train_dir] + _class_dirs('Test')

# Remove leftovers of an interrupted previous run so they cannot leak into the folds
_clear_folders(outer_dirs + train_valid_dirs)
try:
    # Perform 5x4 nested cross-validation
    for i, (outer_train_indices, test_indices) in enumerate(outer_cv.split(X, y)):
        print(f'Outer fold {i + 1}')
        move_epochs(X, outer_train_indices, stacked_dir, all_train_dir, epoch_file_pattern, copy=True)
        _copy_by_class(test_indices, stacked_dir, 'Test')

        inner_results['mean_validation_accuracy'] = 0.0
        y_outer_train = y[outer_train_indices]  # Get labels for outer training set
        for j, (inner_train_indices, valid_indices) in enumerate(inner_cv.split(outer_train_indices, y_outer_train)):
            print(f'  Inner fold {j + 1}')
            # Inner split indices are positions within outer_train_indices
            _copy_by_class(outer_train_indices[inner_train_indices], all_train_dir, 'Outer Train/Train')
            _copy_by_class(outer_train_indices[valid_indices], all_train_dir, 'Outer Train/Valid')
            train_path = cv_path + '/Outer Train'

            # Train one model per hyperparameter combination; row index matches inner_results
            for row, (optimizer, learning_rate, batch_size) in enumerate(itertools.product(optimizers, learning_rates, batch_sizes)):
                inner_results.loc[row, 'mean_validation_accuracy'] += NeuralNetworks.train_model(
                    model_name, train_path, optimizer, learning_rate, batch_size
                )

            # Delete all files in the train and valid folders
            _clear_folders(train_valid_dirs)

        # Calculate the mean validation accuracy for each hyperparameter combination
        inner_results['mean_validation_accuracy'] /= inner_cv.get_n_splits()
        inner_results.to_csv(f'{results_path}/{results_prefix}_inner_results_fold_{i + 1}.csv', index=False)
        # Find the best hyperparameters based on the mean validation accuracy
        best_hyperparams = inner_results.loc[inner_results['mean_validation_accuracy'].idxmax()]
        print(f'Best hyperparameters for outer fold {i + 1}: {best_hyperparams.to_dict()}')

        # Split the outer training set into train and valid sets for retraining of final model
        retrain_split = StratifiedShuffleSplit(n_splits=1, test_size=0.25, random_state=42)
        for train_indices, valid_indices in retrain_split.split(outer_train_indices, y_outer_train):
            _copy_by_class(outer_train_indices[train_indices], all_train_dir, 'Outer Train/Train')
            _copy_by_class(outer_train_indices[valid_indices], all_train_dir, 'Outer Train/Valid')

        # Retrain the model with the best hyperparameters on the training and validation sets
        NeuralNetworks.train_model(model_name,
                                   cv_path + '/Outer Train',
                                   optimizer=best_hyperparams['optimizer'],
                                   learning_rate=float(best_hyperparams['learning_rate']),
                                   bs=int(best_hyperparams['batch_size']),
                                   save_model=True)

        # Delete all files in the train and valid folders
        _clear_folders(train_valid_dirs)

        # Evaluate the model on the test set
        current_final_model_accuracy, conf_mat, clssf_rep = \
            NeuralNetworks.evaluate_final_model(model_name=model_name,
                                                test_path=cv_path + '/Test',
                                                model_path='final_model.pt',
                                                num_classes=len(chosen_classes))

        np.savetxt(f'{results_path}/{results_prefix}_confusion_matrix_{i + 1}.csv', conf_mat, delimiter=',', fmt='%d')
        with open(f'{results_path}/{results_prefix}_classification_report_{i + 1}.txt', 'w') as f:
            f.write(clssf_rep)
        final_model_mean_accuracy += current_final_model_accuracy
        print(f'Final model accuracy for current outer fold: {current_final_model_accuracy:.4f}')

        # Delete all files in the outer train and test folders
        _clear_folders(outer_dirs)
finally:
    # Leave the CV folders empty even if the run is interrupted, so the next run starts clean
    _clear_folders(outer_dirs + train_valid_dirs)

# Calculate the mean accuracy across all outer folds
final_model_mean_accuracy /= outer_cv.get_n_splits()

# Save the final model mean accuracy to a file
with open(f'{results_path}/{results_prefix}_final_model_mean_accuracy.txt', 'w') as f:
    f.write(f'Mean accuracy of the final model across all outer folds: {final_model_mean_accuracy:.4f}\n')
print(f'Mean accuracy of the final model across all outer folds: {final_model_mean_accuracy:.4f}')