In [2]:
import matplotlib.pyplot as plt
import mne
import numpy as np
import pandas as pd
import pyentrp.entropy as ent
import pywt
import scipy.stats as stats
from torch.utils.data import DataLoader, Dataset

In [3]:
def load_eeg_data(participant_id, file_path_template="eegs/{participant_id}.bdf"):
    """
    Read one participant's BDF recording and keep only its EEG channels.

    Args:
    - participant_id (str): Participant ID substituted into the path template.
    - file_path_template (str): Path template containing a ``{participant_id}`` placeholder.

    Returns:
    - raw_eeg (mne.io.Raw): Whatever ``pick_types(eeg=True)`` returns for the preloaded recording.
    """
    bdf_path = file_path_template.format(participant_id=participant_id)

    # preload=True is passed so the reader loads the samples immediately.
    recording = mne.io.read_raw_bdf(bdf_path, preload=True)

    # Restrict the recording to channels typed as EEG.
    return recording.pick_types(eeg=True)


def reorder_eeg_channels(raw_eeg, channel_names_geneva):
    """
    Put the channels of a recording into the Geneva channel order.

    Args:
    - raw_eeg (mne.io.Raw): The raw EEG data; its ``reorder_channels`` method is called directly.
    - channel_names_geneva (list): Channel names in the desired (Geneva) order.

    Returns:
    - raw_reordered (mne.io.Raw): The value returned by ``raw_eeg.reorder_channels``.
    """
    # NOTE(review): no copy is taken here, so the call presumably reorders
    # `raw_eeg` itself as well as returning it - confirm against MNE docs.
    return raw_eeg.reorder_channels(channel_names_geneva)


def process_participants(participant_ids, channel_names_geneva):
    """
    Load every participant's recording and reorder its channels.

    Args:
    - participant_ids (list): List of participant IDs.
    - channel_names_geneva (list): List of channel names in Geneva order.

    Returns:
    - eeg_data_reordered (dict): Maps each participant ID to its reordered EEG data.
    """
    # For each ID: load with the default path template, then reorder.
    return {
        pid: reorder_eeg_channels(load_eeg_data(pid), channel_names_geneva)
        for pid in participant_ids
    }


def print_eeg_channel_info(participant_id, raw, raw_reordered):
    """
    Prints channel information before and after reordering for a participant.

    The data shape is reported as ``(number of channels, raw.n_times)`` using
    public attributes only, so the function no longer depends on the private
    ``_data`` array being present on the recording.

    Args:
    - participant_id (str): The ID of the participant.
    - raw (mne.io.Raw): The original raw EEG data.
    - raw_reordered (mne.io.Raw): The reordered EEG data.
    """
    print(f"Participant ID: {participant_id}")

    names_before = raw.info['ch_names']
    names_after = raw_reordered.info['ch_names']

    # Print channel names before and after reordering
    print("Channel names before reordering:", names_before)
    print("Channel names after reordering:", names_after)

    # Print EEG data shape before and after reordering (channels x samples)
    print("EEG data shape before reordering:", (len(names_before), raw.n_times))
    print("EEG data shape after reordering:", (len(names_after), raw_reordered.n_times))

def plot_sensor_locations(raw, raw_reordered, show_plots=True):
    """
    Draw the sensor layout of the original and the reordered recording.

    Args:
    - raw (mne.io.Raw): The original raw EEG data.
    - raw_reordered (mne.io.Raw): The reordered EEG data.
    - show_plots (bool): When falsy nothing is plotted and the function returns immediately.
    """
    if not show_plots:
        return

    panels = (
        (raw, "Channel Locations Before Reordering"),
        (raw_reordered, "Channel Locations After Reordering"),
    )
    for recording, heading in panels:
        # One figure per recording, shown before the next one is drawn.
        recording.plot_sensors(show_names=True, title=heading)
        plt.show()


def process_participant_info(eeg_data_reordered, raw_data, show_plots=False):
    """
    Report channel names and data shapes for every participant and, on request,
    plot the sensor layouts.

    Args:
    - eeg_data_reordered (dict): Participant ID -> reordered EEG data.
    - raw_data (dict): Participant ID -> original raw EEG data; must contain every key of
      ``eeg_data_reordered``.
    - show_plots (bool): Whether to display sensor plots.
    """
    for pid in eeg_data_reordered:
        original = raw_data[pid]
        reordered = eeg_data_reordered[pid]

        # Textual summary first ...
        print_eeg_channel_info(pid, original, reordered)

        # ... then the optional figures.
        if not show_plots:
            continue
        plot_sensor_locations(original, reordered, show_plots)




In [None]:

# Define parameters: target channel order and the participants to process
channel_names_geneva = ['Fp1', 'AF3', 'F3', 'F7', 'FC5', 'FC1', 'C3', 'T7', 'CP5', 'CP1', 'P3', 'P7', 'PO3', 'O1', 'Oz', 'Pz', 'Fp2', 'AF4', 'Fz', 'F4', 'F8', 'FC6', 'FC2', 'Cz', 'C4', 'T8', 'CP6', 'CP2', 'P4', 'P8', 'PO4', 'O2']
participant_ids = [f"s{i:02d}" for i in range(1, 3)]  # range(1, 3) -> two participants: s01 and s02

# Load the raw EEG data for each participant (kept as the "before reordering" reference)
raw_data = {participant_id: load_eeg_data(participant_id) for participant_id in participant_ids}

# Process (reorder) the participants' data.
# process_participants calls load_eeg_data itself, so every file is read a second time here
# and the reordered objects are separate from those stored in raw_data.
eeg_data_reordered = process_participants(participant_ids, channel_names_geneva)

# Print information about the loaded and reordered data
process_participant_info(eeg_data_reordered, raw_data, show_plots=False)  # Toggle `show_plots` to True if you want to see plots


In [5]:
def load_participant_ratings(file_path):
    """
    Read the participant ratings CSV into a DataFrame.

    Args:
    - file_path (str): Path to the CSV file containing participant ratings.

    Returns:
    - DataFrame: The parsed ratings, exactly as returned by ``pd.read_csv``.
    """
    return pd.read_csv(file_path)

def sort_ratings(df):
    """
    Order the ratings by participant first and by experiment second.

    Args:
    - df (DataFrame): Ratings with 'Participant_id' and 'Experiment_id' columns.

    Returns:
    - DataFrame: A new, sorted DataFrame; ``df`` itself is left as it was.
    """
    sort_keys = ['Participant_id', 'Experiment_id']
    return df.sort_values(by=sort_keys)


In [6]:
def extract_trial_data(participant_id_str, participant_data, raw_data, trial_duration=60.0):
    """
    Extracts trial data for a specific participant.

    Each value of the 'Start_time' column is converted from microseconds to
    seconds and used as the start of a window of ``trial_duration`` seconds that
    is cropped out of a copy of ``raw_data``.

    Args:
    - participant_id_str (str): The ID of the participant (kept for interface
      compatibility; not used by the extraction itself).
    - participant_data (DataFrame): Trial information for the participant; must contain
      a 'Start_time' column in microseconds.
    - raw_data (mne.io.Raw): The raw EEG data for the participant. It is copied before
      every crop, so the recording passed in is not shortened.
    - trial_duration (float): Length of each trial window in seconds (default: 60.0,
      i.e. the previously hard-coded one-minute duration).

    Returns:
    - List: One cropped recording per row of ``participant_data``, in row order.

    Raises:
    - ValueError: If ``trial_duration`` is not positive.
    """
    if trial_duration <= 0:
        raise ValueError(f"trial_duration must be positive, got {trial_duration!r}")

    trials = []

    # Only the start time is needed, so iterate over that column directly.
    for start_time_us in participant_data['Start_time']:
        start_time = start_time_us / 1e6  # Convert microseconds to seconds
        end_time = start_time + trial_duration

        # Crop a copy so the full recording stays available for the next trial.
        trials.append(raw_data.copy().crop(tmin=start_time, tmax=end_time))

    return trials

def process_participant_trials(participant_ids, trial_info, eeg_data_reordered):
    """
    Processes trial data for all participants.

    Args:
    - participant_ids (list): List of participant IDs in string form (e.g. "s01").
    - trial_info (DataFrame): The sorted DataFrame containing trial information; grouped
      by its numeric 'Participant_id' column.
    - eeg_data_reordered (dict): Dictionary with participant IDs as keys and reordered EEG data as values.

    Returns:
    - dict: Maps every ID in ``participant_ids`` to its list of trials. IDs without any
      rows in ``trial_info`` keep an empty list.

    Raises:
    - KeyError: If a requested participant has rows in ``trial_info`` but no entry in
      ``eeg_data_reordered``.
    """
    participant_trials = {participant_id: [] for participant_id in participant_ids}

    # Iterate over each participant's data
    for participant_id, participant_data in trial_info.groupby('Participant_id'):
        # int() makes the "02d" format work even when pandas stored the IDs as floats
        # (the format code 'd' rejects float values).
        participant_id_str = f"s{int(participant_id):02d}"
        if participant_id_str not in participant_ids:
            continue

        # Get the raw EEG data for the current participant
        raw_data = eeg_data_reordered[participant_id_str]

        # Extract and store the trial data for the current participant
        participant_trials[participant_id_str] = extract_trial_data(
            participant_id_str, participant_data, raw_data
        )

    return participant_trials


In [None]:

# Load the ratings table and sort it by participant and experiment
ratings_file_path = "participant_ratings.csv"
df = load_participant_ratings(ratings_file_path)  # Load ratings data
df_sorted = sort_ratings(df)  # Sort by 'Participant_id', then 'Experiment_id'

# Cut each participant's reordered recording into per-trial segments
participant_trials = process_participant_trials(participant_ids, df_sorted, eeg_data_reordered)

# Example: Print the number of trials for each participant
for participant_id, trials in participant_trials.items():
    print(f"Participant ID: {participant_id}, Number of Trials: {len(trials)}")

   


In [None]:
# Inspect the first trial of participant s01: time series and power spectral density
first_participant_trials = participant_trials['s01']
first_trial_data = first_participant_trials[0]


# Plot the first trial data
first_trial_data.plot()
# Plot the power spectral density (PSD)
first_trial_data.compute_psd().plot()

PREPROCESSING


In [9]:
def apply_common_average_reference(trials):
    """
    Re-reference every trial to the common average (CAR).

    Each trial is copied first, so the objects in ``trials`` are left as they
    were. Diagnostic output is printed along the way: the ``projs`` entry of
    every input trial, then the per-channel mean of every re-referenced trial.

    Args:
    - trials (list): A list of mne.io.Raw objects representing the trials.

    Returns:
    - list: A list of mne.io.Raw objects with CAR applied.
    """
    def _car_copy(source):
        rereferenced = source.copy()
        # projection=False is passed explicitly. NOTE(review): with that setting the
        # reference is presumably applied directly, leaving nothing for apply_proj()
        # to do; the call is kept unchanged - confirm against the MNE docs.
        rereferenced.set_eeg_reference('average', projection=False)
        rereferenced.apply_proj()
        return rereferenced

    trials_with_car = []
    for source in trials:
        trials_with_car.append(_car_copy(source))
        # Diagnostic: projectors recorded on the *input* trial.
        print(source.info['projs'])

    for rereferenced in trials_with_car:
        # Mean over axis 1 of get_data(), i.e. one value per row of the data array.
        channel_means = rereferenced.get_data().mean(axis=1)
        print("Mean of each channel after CAR:", channel_means)

    return trials_with_car

def apply_bandpass_filter(trials, low_freq, high_freq):
    """
    Band-pass filter a copy of every trial.

    Args:
    - trials (list): A list of mne.io.Raw objects representing the trials.
    - low_freq (float): Lower cutoff frequency in Hz.
    - high_freq (float): Upper cutoff frequency in Hz.

    Returns:
    - list: Filtered copies of the trials, in the same order as the input.
    """
    def _band_limited(source):
        # Work on a copy so the caller's trial is not filtered.
        duplicate = source.copy()
        duplicate.filter(low_freq, high_freq)
        return duplicate

    return [_band_limited(source) for source in trials]


def apply_notch_filter(trials, notch_freqs):
    """
    Notch-filter a copy of every trial at each of the given frequencies.

    Args:
    - trials (list): A list of mne.io.Raw objects representing the trials.
    - notch_freqs (list): Frequencies to notch out; one ``notch_filter`` call is made per
      frequency, in list order.

    Returns:
    - list: Notch-filtered copies of the trials, in the same order as the input.
    """
    def _notched(source):
        # Work on a copy so the caller's trial is not filtered.
        duplicate = source.copy()
        for line_freq in notch_freqs:
            duplicate.notch_filter(freqs=line_freq, verbose=True)
        return duplicate

    return list(map(_notched, trials))


def apply_resampling(trials, new_sampling_rate):
    """
    Resample a copy of every trial to a new sampling rate.

    Args:
    - trials (list): A list of mne.io.Raw objects representing the trials.
    - new_sampling_rate (int): The desired sampling rate.

    Returns:
    - list: The results of ``resample`` on a copy of each trial, in input order.
    """
    # copy() first so the caller's trials keep their original sampling rate.
    return [
        source.copy().resample(new_sampling_rate, npad="auto")
        for source in trials
    ]





In [None]:
# Apply Common Average Reference (CAR) to every participant's trials
car_participant_data = {}
for participant_id, trials in participant_trials.items():
    car_participant_data[participant_id] = apply_common_average_reference(trials)

# Participant whose first trial is plotted after each preprocessing stage
first_participant_id = 's01'
first_car_trial = car_participant_data[first_participant_id][0]

# Plot the trial data after common average referencing
first_car_trial.plot(title="EEG Data for Participant s01 - Trial 1 (Common Average Reference)")

In [None]:
# Apply band-pass filter to the CAR-referenced trials
low_freq = 4  # Lower cutoff frequency in Hz
high_freq = 45  # Upper cutoff frequency in Hz
filtered_participant_data = {}
for participant_id, trials in car_participant_data.items():
    filtered_participant_data[participant_id] = apply_bandpass_filter(trials, low_freq, high_freq)

first_filtered_trial = filtered_participant_data[first_participant_id][0]

# Plot the trial data after band-pass filtering
first_filtered_trial.plot(title="EEG Data for Participant s01 - Trial 1 (After band-pass filtering)")

In [None]:
# Apply notch filter to the band-pass filtered trials
notch_freqs = [50, 60]  # Notch filter frequencies in Hz
notch_filtered_participant_data = {}
for participant_id, trials in filtered_participant_data.items():
    notch_filtered_participant_data[participant_id] = apply_notch_filter(trials, notch_freqs)

first_notch_filtered_trial = notch_filtered_participant_data[first_participant_id][0]

# Plot the trial data after notch filtering
first_notch_filtered_trial.plot(title="EEG Data for Participant s01 - Trial 1 (After notch filtering)")

In [None]:
# Resample the notch-filtered trials
new_sampling_rate = 128  # Desired sampling rate
resampled_participant_data = {}
for participant_id, trials in notch_filtered_participant_data.items():
    resampled_participant_data[participant_id] = apply_resampling(trials, new_sampling_rate)


first_resampled_trial = resampled_participant_data[first_participant_id][0]

# Plot the trial data after resampling
first_resampled_trial.plot(title="EEG Data for Participant s01 - Trial 1 (After Resampling)")

In [14]:
from mne.preprocessing import ICA
from mne_icalabel import label_components

def perform_ica(trials, montage, variance_proportion=0.999, keep_labels=('brain',)):
    """
    Apply Independent Component Analysis (ICA) to clean EEG trials.

    For every trial a copy is made, the montage is set on the copy, an extended
    Infomax ICA is fitted, the components are labelled with ``label_components``
    (method 'iclabel'), and every component whose label is not in
    ``keep_labels`` is removed. Working on a copy keeps the caller's trials
    untouched, in line with the other preprocessing helpers in this notebook.

    Args:
    - trials (list): A list of mne.io.Raw objects representing the trials.
    - montage (mne.channels.DigMontage): The montage to set for the EEG data.
    - variance_proportion (float): Passed to ICA as ``n_components`` (0.999 by default).
    - keep_labels (iterable of str or str): Component labels that are retained. The default
      keeps only 'brain', so components labelled 'other' are removed as well; pass
      ``('brain', 'other')`` to retain those too.

    Returns:
    - cleaned_trials (list): A list of cleaned mne.io.Raw objects (new objects, one per trial).
    - ica_models (list): A list of ICA models fitted to each trial, with ``exclude`` set.
    """
    # A bare string would otherwise be split into single characters by set().
    if isinstance(keep_labels, str):
        keep_labels = (keep_labels,)
    labels_to_keep = set(keep_labels)

    cleaned_trials = []
    ica_models = []

    for trial_data in trials:
        # Copy first: set_montage and ica.apply below both operate on the object they are given.
        trial = trial_data.copy()
        trial.set_montage(montage)

        # Fit ICA with a proportion of variance; fixed random_state for reproducibility
        ica = ICA(n_components=variance_proportion, method='infomax', fit_params=dict(extended=True), random_state=97, max_iter=1000)
        ica.fit(trial)
        ica_models.append(ica)

        # Label the components and exclude every one whose label is not kept
        labels = label_components(trial, ica, method='iclabel')
        ica.exclude = [
            component_index
            for component_index, label in enumerate(labels['labels'])
            if label not in labels_to_keep
        ]

        # Remove the excluded components from the copied trial
        cleaned_trials.append(ica.apply(trial, exclude=ica.exclude))

    return cleaned_trials, ica_models


In [None]:
# Build the standard 10-20 montage used for every trial
montage = mne.channels.make_standard_montage('standard_1020')

# Apply ICA to clean the resampled EEG data for each participant.
# Only the cleaned trials are stored; `ica_models` is rebound on every iteration, so after
# the loop it holds the models of the last participant only.
cleaned_participant_data = {}
for participant_id, trials in resampled_participant_data.items():
    cleaned_trials, ica_models = perform_ica(trials, montage)
    cleaned_participant_data[participant_id] = cleaned_trials

In [None]:
# Example: Plot the EEG data for the first trial of the first participant after ICA cleaning
first_participant_id = 's01'
first_cleaned_trial = cleaned_participant_data[first_participant_id][0]

# Plot the cleaned trial and its power spectral density
first_cleaned_trial.plot(title="EEG Data for Participant s01 - Trial 1 (After ICA)")
first_cleaned_trial.compute_psd().plot()

In [20]:
def epoch_trials(data, epoch_duration=1.0, discard_duration=3.0):
    """
    Cut every trial of every participant into fixed-length epochs.

    Args:
    - data (dict): Participant IDs as keys and lists of mne.io.Raw trials as values.
    - epoch_duration (float): Duration of each epoch in seconds (default: 1.0).
    - discard_duration (float): Duration to discard from the start of each trial in seconds (default: 3.0).

    Returns:
    - dict: Participant IDs as keys and lists of mne.Epochs (one per usable trial) as values.
      Trials that yield no events or no epoch data are reported on stdout and left out.
    """
    def _segment(participant_id, recording):
        """Return the Epochs of one trial, or None when the trial must be skipped."""
        # Fixed-length events starting after the discarded lead-in.
        markers = mne.make_fixed_length_events(
            recording, start=discard_duration, duration=epoch_duration
        )
        if len(markers) == 0:
            print(f"No events created for participant {participant_id}'s trial. Skipping.")
            return None

        # NOTE(review): tmin=0.0 and tmax=epoch_duration are passed as-is; if MNE treats
        # tmax as inclusive each epoch carries one sample more than sfreq * epoch_duration
        # - confirm against the mne.Epochs documentation.
        segments = mne.Epochs(
            recording, markers, tmin=0.0, tmax=epoch_duration,
            baseline=None, preload=True, detrend=1,
        )
        if segments.get_data().size > 0:
            return segments

        print(f"No data after epoching for participant {participant_id}'s trial. Skipping.")
        return None

    epoched_data = {}
    for participant_id, recordings in data.items():
        candidates = (_segment(participant_id, recording) for recording in recordings)
        epoched_data[participant_id] = [
            segments for segments in candidates if segments is not None
        ]

    return epoched_data



def create_epoch_dataframe(df, num_participants=2, epochs_per_trial=56):
    """
    Expand a per-trial ratings table into a per-epoch table.

    Every trial row of the selected participants is repeated ``epochs_per_trial``
    times, with 'Epoch_ID' counting from 1 and the trial's valence and arousal
    copied onto each repetition.

    Args:
    - df (pd.DataFrame): Ratings containing columns 'Participant_id', 'Experiment_id', 'Valence', 'Arousal'.
    - num_participants (int): How many participants to include, taken in order of first
      appearance in ``df`` (default: 2).
    - epochs_per_trial (int): Number of epochs per trial (default: 56).

    Returns:
    - epoch_df (pd.DataFrame): A DataFrame containing 'Participant_ID', 'Experiment_ID', 'Epoch_ID', 'Valence', 'Arousal'.
    """
    # First `num_participants` distinct IDs, in order of appearance.
    selected_ids = df['Participant_id'].unique()[:num_participants]

    def _rows_for(trial):
        """Build the per-epoch records of a single trial row."""
        experiment_id = trial['Experiment_id']
        valence = trial['Valence']
        arousal = trial['Arousal']
        return [
            {
                'Participant_ID': trial['Participant_id'],
                'Experiment_ID': experiment_id,
                'Epoch_ID': epoch_number,
                'Valence': valence,
                'Arousal': arousal
            }
            for epoch_number in range(1, epochs_per_trial + 1)
        ]

    records = []
    for _, trial in df.iterrows():
        # Rows of participants outside the selection contribute nothing.
        if trial['Participant_id'] in selected_ids:
            records.extend(_rows_for(trial))

    return pd.DataFrame(records)


In [None]:
# Epoch the ICA-cleaned trials of every participant: 1-second epochs, skipping the first 3 seconds
epoched_participant_data = epoch_trials(cleaned_participant_data, epoch_duration=1.0, discard_duration=3.0)



In [None]:
# Report the number of epoched trials per participant and the number of epochs in each trial.
# Distinct names are used for the per-participant list and the per-trial Epochs object so the
# inner loop no longer rebinds the variable it is iterating over.
for participant_id, participant_epochs in epoched_participant_data.items():
    print(f"Participant {participant_id} has {len(participant_epochs)} trials.")

    for trial_number, trial_epochs in enumerate(participant_epochs, start=1):
        print(f"  trial {trial_number} has {len(trial_epochs.events)} epochs")


In [None]:
# Index 0 of the first epoched trial of the first participant
example_epoch = epoched_participant_data[first_participant_id][0][0]

# Plot that epoch (data has gone through the full preprocessing chain and epoching)
example_epoch.plot(title="EEG Data for Participant s01 - Trial 1 (After Preprocessing and Resampling)")

In [28]:

# Build the per-epoch label table: 56 rows per trial for the first 2 participants in df_sorted
epoch_df = create_epoch_dataframe(df_sorted, num_participants=2, epochs_per_trial=56)

In [None]:
# Display the per-epoch label table (last expression of the cell)
epoch_df

END OF PREPROCESSING

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

class EEGFeatureExtractor(nn.Module):
    """
    Convolutional + GRU feature extractor for single-channel 2-D EEG inputs.

    Three Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2) -> Dropout stages are
    followed by a GRU and two fully connected layers; the final output passes
    through a ReLU and has ``output_dim`` values per sample.

    Args:
    - input_channels (int): Height of the input (number of EEG channels), default 32.
    - input_time_samples (int): Width of the input (time samples), default 128.
    - output_dim (int): Size of the returned feature vector, default 64.
    """

    def __init__(self, input_channels=32, input_time_samples=128, output_dim=64):
        super().__init__()

        # Stage 1: 1 -> 32 feature maps
        self.conv1 = nn.Conv2d(1, 32, kernel_size=(3, 3), padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d((2, 2))
        self.drop1 = nn.Dropout(0.3)

        # Stage 2: 32 -> 64 feature maps
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3), padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d((2, 2))
        self.drop2 = nn.Dropout(0.3)

        # Stage 3: 64 -> 128 feature maps
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3), padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d((2, 2))
        self.drop3 = nn.Dropout(0.4)

        # Three 2x2 poolings shrink height and width by a factor of 8 each.
        pooled_height = input_channels // 8
        pooled_width = input_time_samples // 8
        self.flatten_dim = 128 * pooled_height * pooled_width

        # Recurrent layer fed with vectors of size flatten_dim
        self.gru = nn.GRU(input_size=self.flatten_dim, hidden_size=256, batch_first=True)

        # Compression head: 256 -> 128 -> output_dim
        self.fc1 = nn.Linear(256, 128)
        self.drop_fc1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, x):
        """Map a (batch, 1, height, width) tensor to (batch, output_dim) features."""
        stages = (
            (self.conv1, self.bn1, self.pool1, self.drop1),
            (self.conv2, self.bn2, self.pool2, self.drop2),
            (self.conv3, self.bn3, self.pool3, self.drop3),
        )
        for conv, norm, pool, drop in stages:
            x = drop(pool(torch.relu(norm(conv(x)))))

        # Reshape to (batch, sequence_length, flatten_dim) for the GRU. flatten_dim already
        # accounts for all 128 maps and both pooled axes, so when the input matches the
        # constructor sizes the sequence length is 1.
        sequence = x.view(x.size(0), -1, self.flatten_dim)
        _, final_hidden = self.gru(sequence)

        # Last layer's final hidden state -> fully connected head
        compressed = self.drop_fc1(torch.relu(self.fc1(final_hidden[-1])))
        return torch.relu(self.fc2(compressed))

# Initialize model, optimizer, and loss (the model is placed on the GPU when one is available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EEGFeatureExtractor(input_channels=32, input_time_samples=128, output_dim=64).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()  # Modify if you have labels

# Smoke test with a random input tensor (batch_size, channels, height, width):
# a batch of 16 samples, each with 1 channel, 32 height, and 128 width
# NOTE(review): no model.eval() / torch.no_grad() is used here, so this forward pass
# presumably runs in training mode (dropout active, batch-norm statistics updated) - confirm
# whether that is intended before the model is used for feature extraction.
input_tensor = torch.randn(16, 1, 32, 128).to(device)
output = model(input_tensor)  # Forward pass
print(output.shape)  # Expected: (16, 64), i.e. (batch size, output_dim)


In [None]:
class EEGDataset(Dataset):
    """
    Map-style dataset pairing every row of ``epoch_df`` with its EEG epoch.

    Args:
    - epoched_data (dict): Participant ID strings ("s01", ...) mapped to a list of trials;
      ``epoched_data[pid][trial][epoch]`` must give one epoch, either as an array-like or
      as an object exposing ``get_data()`` / ``data``.
    - epoch_df (pd.DataFrame): One row per epoch with columns 'Participant_ID',
      'Experiment_ID', 'Epoch_ID' (the latter two 1-based), 'Valence' and 'Arousal'.

    Each item is ``(eeg, labels)``:
    - eeg: float32 tensor of shape (1, n_channels, n_times), ready for a Conv2d with one
      input channel once batched.
    - labels: float32 tensor ``[valence, arousal]``. A tensor (rather than a tuple of
      floats) lets the default DataLoader collation produce a single (batch, 2) tensor.
    """

    def __init__(self, epoched_data, epoch_df):
        self.epoched_data = epoched_data
        self.epoch_df = epoch_df  # Use epoch_df for labels

    def __len__(self):
        return len(self.epoch_df)

    @staticmethod
    def _to_array(sample):
        """Return the numeric content of one epoch as a 2-D (channels, times) float32 array."""
        # Arrays/lists are checked first: numpy arrays also expose a `.data` attribute
        # (a raw buffer), which must not be mistaken for the payload.
        if isinstance(sample, (np.ndarray, list)):
            values = sample
        elif hasattr(sample, 'get_data'):
            values = sample.get_data()
        elif hasattr(sample, 'data'):
            values = sample.data
        else:
            values = sample

        if not isinstance(values, (np.ndarray, list)):
            raise ValueError(f"Unexpected data format for eeg_data: {values}")

        values = np.asarray(values, dtype=np.float32)

        # Selecting one epoch from an epochs container keeps a leading axis of length 1;
        # drop it so that every sample is (channels, times).
        if values.ndim == 3 and values.shape[0] == 1:
            values = values[0]
        if values.ndim != 2:
            raise ValueError(f"Unexpected data shape for eeg_data: {values.shape}")
        return values

    def __getitem__(self, idx):
        # Get the participant, trial, and epoch IDs from the current epoch row
        row = self.epoch_df.iloc[idx]

        # Format the participant_id correctly
        participant_id = f"s{int(row['Participant_ID']):02d}"
        trial_id = int(row['Experiment_ID'])
        epoch_id = int(row['Epoch_ID'])

        # Check if the participant_id exists in the epoched data
        if participant_id not in self.epoched_data:
            raise KeyError(f"Participant ID '{participant_id}' not found in epoched_data.")

        # IDs are 1-based; an ID below 1 would silently wrap around to the end of the list.
        if trial_id < 1 or epoch_id < 1:
            raise IndexError(f"Error accessing data for Participant '{participant_id}', Trial '{trial_id}', Epoch '{epoch_id}': IDs must be >= 1")

        # Retrieve the EEG data from epoched_data
        try:
            sample = self.epoched_data[participant_id][trial_id - 1][epoch_id - 1]
        except IndexError as e:
            raise IndexError(f"Error accessing data for Participant '{participant_id}', Trial '{trial_id}', Epoch '{epoch_id}': {str(e)}") from e

        # (channels, times) -> (1, channels, times): add the image-channel axis for Conv2d
        eeg_data = torch.tensor(self._to_array(sample), dtype=torch.float32).unsqueeze(0)

        # Retrieve valence and arousal as one label vector
        labels = torch.tensor([float(row['Valence']), float(row['Arousal'])], dtype=torch.float32)

        return eeg_data, labels

# Dataset and DataLoader setup: one item per row of epoch_df, shuffled batches of 32
dataset = EEGDataset(epoched_participant_data, epoch_df)
# NOTE(review): num_workers=4 starts worker processes; with a dataset class defined inside
# the notebook this may fail on platforms that spawn workers - confirm, or use num_workers=0.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)  # Adjust num_workers as needed

# Feature extraction itself is run in a later cell, once extract_features has been defined.


In [None]:
# Model evaluation function to extract features
def extract_features(model, dataloader):
    """
    Run ``model`` over every batch of ``dataloader`` and collect features and labels.

    The model is switched to evaluation mode and run without gradient tracking.
    Inputs are moved to the device of the model's parameters (CPU when the model
    has none), so the function does not depend on a module-level ``device``.

    Args:
    - model (torch.nn.Module): Feature extractor returning one feature tensor per batch.
    - dataloader (iterable): Yields ``(batch, labels)`` pairs. ``labels`` may be a tensor of
      shape (batch, 2) or a list/tuple of per-column tensors, which is what the default
      collation produces when the dataset returns labels as a tuple of numbers.

    Returns:
    - pd.DataFrame: Columns 'feature_0' ... 'feature_{k-1}' followed by 'Valence' and
      'Arousal', one row per sample.

    Raises:
    - ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()

    try:
        model_device = next(model.parameters()).device
    except StopIteration:
        # Parameter-free model: nothing dictates a device, stay on the CPU.
        model_device = torch.device("cpu")

    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch, labels in dataloader:
            # Forward pass through the model
            features = model(batch.to(model_device))

            # Flatten features to (batch, k) and move to CPU
            features = features.reshape(features.size(0), -1).cpu().numpy()
            feature_chunks.append(features)

            # Per-column label tensors -> one (batch, n_labels) tensor
            if isinstance(labels, (list, tuple)):
                labels = torch.stack([torch.as_tensor(column) for column in labels], dim=1)
            label_array = torch.as_tensor(labels).detach().cpu().numpy()
            label_chunks.append(label_array.reshape(label_array.shape[0], -1))

    if not feature_chunks:
        raise ValueError("dataloader yielded no batches; nothing to extract features from")

    # Convert to DataFrame
    all_features = np.vstack(feature_chunks)
    all_labels = np.vstack(label_chunks)
    df_features = pd.DataFrame(all_features, columns=[f'feature_{i}' for i in range(all_features.shape[1])])
    df_labels = pd.DataFrame(all_labels, columns=['Valence', 'Arousal'])
    df_combined = pd.concat([df_features, df_labels], axis=1)

    return df_combined

In [None]:
# Run feature extraction over the whole dataset and save features plus labels as CSV
df_extracted_features = extract_features(model, dataloader)
df_extracted_features.to_csv('eeg_extracted_features.csv', index=False)