# LSTM Model Training

For LSTM model training, we begin by loading previously collected EOG data. In the preprocessing phase, although we could use denoised data directly from the dataset, we choose to start with raw data for greater control over the denoising process. This allows us to select the optimal wavelet, 'sym20,' for its efficacy with EOG signal characteristics. We've developed an algorithm to pinpoint saccades, distinct from blinks, enhancing our data's clarity and precision. Utilizing PyTorch, we organize the dataset into training, validation, and test sets for our LSTM model. This model is designed to analyze EOG data, tracking eye movements and detecting blinks.
 

### Necessary Libraries

In [1]:
import glob
import copy
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import numpy as np
from math import sqrt
import pywt
from scipy import stats
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import KFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split, SubsetRandomSampler, TensorDataset

In [2]:
# Set the device to CUDA if available for GPU acceleration; otherwise, use CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


### Loading Data

In [3]:
# Define the path to the dataset directory and locate all Excel files within it
data_path = "D:\\Projects\\EyeTracking\\Dataset 2"
all_excel_files = glob.glob(data_path + "/*.xlsx")   # Use glob to find all .xlsx files in the directory

list_of_dataframes = []                              # Initialize an empty list to hold dataframes

# Iterate through each Excel file found, read its content into a dataframe, and append it to the list
for excel_file in all_excel_files:
    df = pd.read_excel(excel_file)                   # Read the Excel file into a pandas dataframe
    list_of_dataframes.append(df)                    # Add the dataframe to the list

In [4]:
# Extract a specific subset from each DataFrame in list_of_dataframes and convert it to a NumPy array
data = []                                # Initialize an empty list to store subsets

for df in list_of_dataframes:
    # Select columns 2 to 6, excluding the first row, from each DataFrame
    subset = df.iloc[1:, 1:6]  
    
    # Convert the subset DataFrame to a NumPy array for further processing
    array = subset.to_numpy()
    
    # Append the NumPy array to the data list
    data.append(array)

### Denoising and Shifting Data

In [2]:
def denoise_signal(signal, wavelet='sym20', level=1):
    """
    Denoise a signal using Discrete Wavelet Transform (DWT) and soft thresholding.
    
    Parameters:
    - signal: The input signal (1D numpy array).
    - wavelet: The type of wavelet to use (e.g., 'db4').
    - level: The level of wavelet decomposition.
    
    Returns:
    - The denoised signal as a 1D numpy array.
    """
    # Decompose to get the wavelet coefficients
    coeff = pywt.wavedec(signal, wavelet, mode='symmetric', level=level)
    
    # Calculate the threshold using the universal threshold method
    sigma = np.median(np.abs(coeff[-1])) / 0.6745
    threshold = sigma * np.sqrt(2 * np.log(len(signal)))
    
    # Apply soft thresholding to remove noise
    coeff_thresh = [pywt.threshold(c, threshold, mode='soft') for c in coeff]
    
    # Reconstruct the signal using the thresholded coefficients
    denoised_signal = pywt.waverec(coeff_thresh, wavelet, mode='symmetric')

    # Ensuring the denoised signal has the same length as the input signal.
    if len(denoised_signal) > len(signal):
        denoised_signal = denoised_signal[:len(signal)]
    
    return denoised_signal

In [1]:
# Shifts and denoises the EOG and pointer data, adjusting for a given time delay
def shift_denoise_trial(trial_data, t_delay=20):
    # Shift pointer data backwards to align with the EOG data shifted forwards by t_delay.
    pointer_x_shifted = trial_data[:-t_delay, 2]
    pointer_y_shifted = trial_data[:-t_delay, 3]

    # Apply denoising to the EOG data shifted forwards, removing initial t_delay samples.
    eog_hor_shifted = denoise_signal(trial_data[t_delay:, 0], wavelet='sym20') 
    eog_ver_shifted = denoise_signal(trial_data[t_delay:, 1], wavelet='sym20')

    # Blink data is also shifted forwards by t_delay to match EOG data.
    blink_shifted = trial_data[t_delay:, 4]
    
    # Combine the processed data into a single array for the trial.
    shifted = np.column_stack((eog_hor_shifted, eog_ver_shifted, pointer_x_shifted, pointer_y_shifted, blink_shifted))
    return shifted

In [None]:
t_delay = 20        # Set the time delay for data alignment

# Prepare to store shifted and denoised data from all trials
shifted_all_trials = []

# Process each trial to shift and denoise, then collect the results
for trial in data:
    # Apply shifting and denoising to each trial based on t_delay
    shifted_trial = shift_denoise_trial(trial, t_delay=t_delay)
    # Accumulate the processed trials
    shifted_all_trials.append(shifted_trial)
    
# Convert the list of processed trials into a NumPy array if needed for subsequent operations
shifted_all_trials = np.array(shifted_all_trials, dtype=object)

### Splitting Data to Saccade, Blink and Fixation

In [4]:
def saccade_segment(trial, window_length=20):
    """
    Adjust labels for every transition in pointer_x and pointer_y within the trial data,
    ensuring propagation for 'window_length' samples after each edge. For falling edges,
    the amplitude of the propagated signal is the same as the previous pulse but negative.
    """
    # Extract EOG signals, pointer positions, and blink signals from the trial
    pointer_x, pointer_y = trial[:, 2], trial[:, 3]  
    eog_hor, eog_ver = trial[:, 0], trial[:, 1]
    blink = trial[:, 4]

    # Preparation for identifying and adjusting saccade segments.
    saccade_segments = []
    fixation_segments = []
    is_saccade = np.zeros(len(trial), dtype=bool)
    adjusted_x = np.zeros_like(pointer_x)
    adjusted_y = np.zeros_like(pointer_y)

    # Track the amplitude to propagate after falling edges
    amplitude_to_propagate_x = 0
    amplitude_to_propagate_y = 0

    for i in range(1, len(trial)):
        # Detect changes in pointer positions as indicators of saccades.
        if pointer_x[i] != pointer_x[i-1] or pointer_y[i] != pointer_y[i-1]:
            # Calculate amplitudes to propagate based on pointer position changes
            amplitude_to_propagate_x = -(pointer_x[i-1] - pointer_x[i])
            amplitude_to_propagate_y = -(pointer_y[i-1] - pointer_y[i])

            seg = []
            # Propagate saccade amplitude over the window length
            for j in range(1, window_length + 1):
                if i + j < len(trial):
                    is_saccade[i:i+window_length] = True
                    adjusted_x[i + j] = amplitude_to_propagate_x
                    adjusted_y[i + j] = amplitude_to_propagate_y
                    seg.append([eog_hor[i + j], eog_ver[i + j], adjusted_x[i + j], adjusted_y[i + j], blink[i + j]])
            # Collect saccade segment data.
            saccade_segments.append(seg)

    # Compile adjusted trial data.
    adjusted_trial = trial.copy()
    adjusted_trial[:, 2], adjusted_trial[:, 3] = adjusted_x, adjusted_y
    saccade_segments = np.array(saccade_segments)

    return adjusted_trial, saccade_segments, is_saccade

In [5]:
def blink_interval_index(blink_signal):
    """
    Finds the start and end indices of continuous blink intervals in the blink_signal.
    """
    blink_starts_ends = []                                                  # List to store start and end indices of blink intervals
    start = None                                                            # Marker for the start of a blink interval
    
    for i, blink in enumerate(blink_signal):
        if blink == 1 and start is None:
            start = i                                                       # Mark the start of a blink interval      
        elif blink == 0 and start is not None:
            blink_starts_ends.append((start, i - 1))                        # Append interval to list and reset start marker
            start = None
    if start is not None:                                                   # Check if the last interval extends to the signal's end
        blink_starts_ends.append((start, len(blink_signal) - 1))
    return blink_starts_ends

def blink_segment(trial, segment_length=20):
    """
    Creates segments centered around blink events, padding with raw signal data if necessary.
    """
    segments = []                                        # List to store segments
    blink_signal = trial[:, 4]                           # Extract blink signal from the trial
    blink_events = blink_interval_index(blink_signal)    # Identify blink intervals
    
    for start, end in blink_events:
        # Calculate padding to ensure segments have uniform length
        blink_duration = end - start + 1
        padding_needed = segment_length - blink_duration
        pad_left = padding_needed // 2
        pad_right = padding_needed - pad_left

        # Determine segment boundaries, adjusting for trial edges
        effective_start = max(start - pad_left, 0)
        effective_end = min(end + pad_right + 1, trial.shape[0])       
        segment = trial[effective_start:effective_end, :]
        
        # Adjust segment size for edge cases
        if segment.shape[0] < segment_length:
            if effective_start == 0:  # Segment is near the start of the trial
                extra_needed = segment_length - segment.shape[0]
                effective_end = min(effective_end + extra_needed, trial.shape[0])
            elif effective_end == trial.shape[0]:  # Segment is near the end of the trial
                extra_needed = segment_length - segment.shape[0]
                effective_start = max(effective_start - extra_needed, 0)
            
            # Re-extract the segment 
            segment = trial[effective_start:effective_end, :]
        
        segments.append(segment)
        
    segments = np.array(segments)   
    return segments  # Return segments as a NumPy array

In [10]:
def fixation_segment(trial, fixation, window_length=20):
    # List to hold the segments
    segments = []
    
    # Identify continuous stretches of fixation == 1
    in_fixation = False
    start_idx = None
    for i, value in enumerate(fixation):
        if value == 1 and not in_fixation:
            in_fixation = True
            start_idx = i
        elif value == 0 and in_fixation:
            in_fixation = False
            # End of a continuous stretch found, segment it
            end_idx = i
            while start_idx + window_length <= end_idx:
                segment = trial[start_idx:start_idx + window_length]
                segments.append(segment)
                start_idx += window_length

    # Handle case where the last fixation goes to the end of the array
    if in_fixation and start_idx + window_length <= len(trial):
        end_idx = len(trial)
        while start_idx + window_length <= end_idx:
            segment = trial[start_idx:start_idx + window_length]
            segments.append(segment)
            start_idx += window_length

    return np.array(segments)


In [11]:
# Process each trial to identify saccades, blinks, and fixations, then adjust trial data accordingly
shifted_all_trials_adjusted = []                      # Holds adjusted trial data with saccade and blink corrections
saccade_all_trial = []                                # Holds segments identified as saccades
blink_all_trial = []                                  # Holds segments identified as blinks
is_saccade_all_trial = []                             # Boolean array marking saccade occurrences in trials
fixation_all_trial = []                               # Holds segments identified as fixations

for trial in shifted_all_trials:
    # Segment and label saccades within each trial
    adjusted_trial, saccade_trial, is_saccade_trial = saccade_segment(trial, window_length=20)
    # Extract blink segments from the trial
    blink_trial = blink_segment(trial, segment_length=20)
    # Identify fixations as intervals not marked as saccades
    fixation = ~ is_saccade_trial.astype(bool)
    # Extract fixation segments based on identified fixations
    fixation_trial = fixation_segment(trial, fixation, window_length=20)

    # Store processed segments and labels for analysis
    shifted_all_trials_adjusted.append(adjusted_trial)
    saccade_all_trial.append(saccade_trial)
    blink_all_trial.append(blink_trial)
    is_saccade_all_trial.append(is_saccade_trial)
    fixation_all_trial.append(fixation_trial)

# Convert lists to NumPy arrays for consistent data handling
shifted_all_trials_adjusted = np.array(shifted_all_trials_adjusted, dtype=object)
blink_all_trial = np.array(blink_all_trial, dtype=object)
saccade_all_trial = np.array(saccade_all_trial, dtype=object)
is_saccade_all_trial = np.array(is_saccade_all_trial, dtype=object)
fixation_all_trial = np.array(fixation_all_trial, dtype = object)

### Calculate distribution of labels, flatten and shuffle recordings

In [12]:
# Calculate the total count of fixation, saccade, and blink events across all processed trials
total_fixation_events = sum(arr.shape[0] for arr in fixation_all_trial)  # Sum fixation event counts
total_saccade_events = sum(arr.shape[0] for arr in saccade_all_trial)    # Sum saccade event counts
total_blink_events = sum(arr.shape[0] for arr in blink_all_trial)        # Sum blink event counts

# Output the total event counts for each category for review
print(f"Total fixation events across all trials: {total_fixation_events}")
print(f"Total saccade events across all trials: {total_saccade_events}")
print(f"Total blink events across all trials: {total_blink_events}")

Total fixation events across all trials: 3682
Total saccade events across all trials: 320
Total blink events across all trials: 98


In [13]:
def flatten_and_shuffle(data):
    """
    Flattens and shuffles multidimensional event data into a 2D array for machine learning input.
    Filters out any empty trials to ensure data integrity. The function reshapes data from a
    complex structure of trials and events into a simpler, unified format and then randomizes
    the order to remove any potential sequence bias.

    Parameters:
    - data: List of numpy arrays with the shape (n_trials, n_events, window_length, n_col).

    Returns:
    - A numpy array of shape (n_total_recording, window_length, n_col) after filtering,
      flattening, and shuffling.
    """
    # Remove any trials without events
    filtered_data = [trial for trial in data if trial.size > 0]
    
    # Return an empty array with the correct dimensions if all trials are filtered out
    if not filtered_data:
        return np.empty((0, data[0].shape[1], data[0].shape[2]))

    # Combine data from all remaining trials into a single array
    flattened_data = np.vstack(filtered_data)
    # Randomize the order of all data points to ensure unbiased training samples
    np.random.shuffle(flattened_data)
    return flattened_data

# Flatten and shuffle each event category, preparing them for analysis or model training
fixation_flattened_shuffled = flatten_and_shuffle(fixation_all_trial)
saccade_flattened_shuffled = flatten_and_shuffle(saccade_all_trial)
blink_flattened_shuffled = flatten_and_shuffle(blink_all_trial)

# Display the new shapes of the event categories after processing
print(f"Fixation flattened and shuffled shape: {fixation_flattened_shuffled.shape}")
print(f"Saccade flattened and shuffled shape: {saccade_flattened_shuffled.shape}")
print(f"Blink flattened and shuffled shape: {blink_flattened_shuffled.shape}")

Fixation flattened and shuffled shape: (3682, 20, 5)
Saccade flattened and shuffled shape: (320, 20, 5)
Blink flattened and shuffled shape: (98, 20, 5)


### select random recordings from specific label to end up with a balanced data

In [14]:
def select_random_records(data, num_records):
    """
    Selects a random subset of records from a dataset, ensuring the selection does not exceed
    the available records.

    Parameters:
    - data: The dataset from which to select records.
    - num_records: The desired number of records to select.

    Returns:
    - A subset of the dataset containing 'num_records' randomly selected records.
    """
    # Limit num_records to the size of the dataset
    num_records = min(num_records, data.shape[0])
    # Randomly select indices without replacement
    indices = np.random.choice(data.shape[0], num_records, replace=False)
    return data[indices]

# Balance the number of fixation events to match the number of saccade events
fixation_balanced = select_random_records(fixation_flattened_shuffled, saccade_flattened_shuffled.shape[0])

# Output the shape of the balanced fixation dataset
print(f"fixation balanced shape: {fixation_balanced.shape}")

fixation balanced shape: (320, 20, 5)


### Adjust one label for a segment

In [6]:
def label_segments(segments, segment_type):
    """
    Assigns labels to segments based on their type (fixation, saccade, blink), adding
    specific details for saccades based on pointer movements.

    Parameters:
    - segments: A numpy array of segments, each with shape (window_length, n_cols).
    - segment_type: A string indicating the type of segment ('fixation', 'saccade', or 'blink').

    Returns:
    - A numpy array of labels for each segment, tailored to the segment type.
    """
    labels = []
    for segment in segments:
        # Label for blinks
        if segment_type == 'blink':
            labels.append([0, 0, 1])
        # Label for fixations
        elif segment_type == 'fixation':
            labels.append([0, 0, 0])
        # Detailed labeling for saccades based on mode pointer positions
        elif segment_type == 'saccade':
            # Calculate mode of pointer positions for a more nuanced saccade label
            pointer_x_mode_result, _ = stats.mode(segment[:, 2].astype(np.float64), axis=0)
            pointer_y_mode_result, _ = stats.mode(segment[:, 3].astype(np.float64), axis=0)
            
            # Extract the mode values explicitly and convert to scalar
            pointer_x_mode = pointer_x_mode_result.item()  # Convert the first mode value to scalar
            pointer_y_mode = pointer_y_mode_result.item()  # Convert the first mode value to scalar
            # Append mode positions for saccades, ensuring scalar values
            labels.append([pointer_x_mode, pointer_y_mode, 0])
            
    return np.array(labels)

In [16]:
# Extract EOG signals from each event type as inputs for model training
inputs_fixation = fixation_balanced[:, :, :2]                   # EOG signals from fixation events
inputs_saccade = saccade_flattened_shuffled[:, :, :2]           # EOG signals from saccade events
inputs_blink = blink_flattened_shuffled[:, :, :2]               # EOG signals from blink events

# Generate labels for each event type using the label_segments function
labels_fixation = label_segments(fixation_balanced, 'fixation')
labels_saccade = label_segments(saccade_flattened_shuffled, 'saccade')
labels_blink = label_segments(blink_flattened_shuffled, 'blink')

# Display shapes of the generated inputs and labels for each event type
print(f"Fixation inputs shape: {inputs_fixation.shape}, labels shape: {labels_fixation.shape}")
print(f"Saccade inputs shape: {inputs_saccade.shape}, labels shape: {labels_saccade.shape}")
print(f"Blink inputs shape: {inputs_blink .shape}, labels shape: {labels_blink .shape}")

Fixation inputs shape: (320, 20, 2), labels shape: (320, 3)
Saccade inputs shape: (320, 20, 2), labels shape: (320, 3)
Blink inputs shape: (98, 20, 2), labels shape: (98, 3)


### Generate dataset

In [None]:
class EOGDataset(Dataset):
    """
    A custom dataset class for EOG signal data, compatible with PyTorch DataLoader for batching and processing.
    It encapsulates inputs (EOG signals) and their corresponding labels for use in training or evaluating models.
    """
    def __init__(self, inputs, labels):
        """
        Initializes the dataset with inputs and labels.
        
        Parameters:
        - inputs: A numpy array of EOG signals with shape (n_total_recording, window_length, n_cols).
        - labels: A numpy array of corresponding labels for the inputs.
        """
        self.inputs = torch.tensor(inputs, dtype=torch.float32)        # Convert inputs to torch tensors
        self.labels = torch.tensor(labels, dtype=torch.float32)        # Convert labels to torch tensors
    
    def __len__(self):
        # Return the total number of records in the dataset.
        return len(self.inputs)
    
    def __getitem__(self, idx):
        # Retrieve the input-output pair at the specified index
        return self.inputs[idx], self.labels[idx]


In [19]:
# Concatenate inputs and labels separately
all_inputs = np.concatenate([inputs_fixation.astype(np.float32), inputs_saccade.astype(np.float32), inputs_blink.astype(np.float32)], axis=0)
all_labels = np.concatenate([labels_fixation.astype(np.float32), labels_saccade.astype(np.float32), labels_blink.astype(np.float32)], axis=0)
dataset = EOGDataset(all_inputs, all_labels)

### split dataset to train, validation and test

In [20]:
# Splitting the dataset into training, validation, and test sets with specified proportions
dataset_size = len(dataset)                       # Total number of samples in the dataset
train_size = int(dataset_size * 0.7)              # 70% of the dataset for training
val_size = int(dataset_size * 0.2)                # 20% for validation
test_size = dataset_size - train_size - val_size  # Remaining 10% for testing, ensuring complete dataset coverage

# Randomly split the dataset into the three subsets using the calculated sizes
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Create DataLoader instances for each subset to facilitate mini-batch processing and data shuffling
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)                   # Shuffle training data for randomness
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)                      # No shuffling for validation and test sets
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

### LSTM model

In [21]:
class LSTMModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_layers):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)
        # Regression output layers for pointer_x and pointer_y
        self.regressor = nn.Linear(hidden_dim, 2)
        # Classification output layer for blink
        self.classifier = nn.Linear(hidden_dim, 1)
    
    def forward(self, x):
        lstm_out, (hn, cn) = self.lstm(x)
        # Assuming the output for regression is the last time step
        regress_output = self.regressor(lstm_out[:, -1, :])
        # Classification output uses a sigmoid activation to get probabilities
        class_output = torch.sigmoid(self.classifier(lstm_out[:, -1, :]))
        return regress_output, class_output.squeeze()

In [22]:
# Setting up the device for model training and moving the model to the selected device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use CUDA if available, else CPU
model = LSTMModel(input_dim=2, hidden_dim=64, num_layers=1).to(device)  # Initialize the LSTM model and transfer it to the device
print(f"Model is using {device}")  # Display which device (CUDA or CPU) the model will be trained on

Model is using cuda


In [29]:
# Defining loss functions and optimizer for training the LSTM model.
regression_loss_fn = nn.MSELoss()  # Mean Squared Error Loss for regression tasks (e.g., pointer movement)
classification_loss_fn = nn.BCELoss()  # Binary Cross-Entropy Loss for classification tasks (e.g., blink detection)
optimizer = optim.Adam(model.parameters(), lr=0.001)  # Adam optimizer with learning rate set to 0.001 for model parameters optimization

### Training without k-fold cross validation

In [30]:
def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss, all_targets, all_classification_outputs = 0, [], []
    all_regression_outputs, all_regression_targets = [], []

    for inputs, labels in dataloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Split labels into regression and classification targets
        regression_targets, classification_targets = labels[:, :2], labels[:, 2]

        # Forward pass
        regression_output, classification_output = model(inputs)

        # Compute losses
        regression_loss = regression_loss_fn(regression_output, regression_targets)
        classification_loss = classification_loss_fn(classification_output, classification_targets)

        # Combined loss
        loss = regression_loss + classification_loss
        total_loss += loss.item()

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Collect data for metrics calculation (detach tensors to avoid tracking history)
        all_targets.extend(classification_targets.cpu().detach().numpy())
        all_classification_outputs.extend(classification_output.cpu().detach().numpy())
        all_regression_outputs.extend(regression_output.cpu().detach().numpy())
        all_regression_targets.extend(regression_targets.cpu().detach().numpy())

    # Calculate metrics after the epoch
    accuracy = accuracy_score(all_targets, np.round(all_classification_outputs))
    precision = precision_score(all_targets, np.round(all_classification_outputs), zero_division=0)
    recall = recall_score(all_targets, np.round(all_classification_outputs), zero_division=0)
    f1 = f1_score(all_targets, np.round(all_classification_outputs), zero_division=0)
    if len(np.unique(all_targets)) > 1:  # Ensure ROC AUC is calculable
        roc_auc = roc_auc_score(all_targets, all_classification_outputs)
    else:
        roc_auc = None

    # Regression Metrics
    mse = mean_squared_error(all_regression_targets, all_regression_outputs)
    mae = mean_absolute_error(all_regression_targets, all_regression_outputs)
    rmse = sqrt(mse)
    r_squared = r2_score(all_regression_targets, all_regression_outputs)

    avg_loss = total_loss / len(dataloader)

    # Return all metrics along with the average loss
    return avg_loss, accuracy, precision, recall, f1, roc_auc, mse, mae, rmse, r_squared


In [31]:
def validate_epoch(model, dataloader, device):
    """
    Evaluates the model's performance on a validation dataset, computing both regression and classification metrics.
    """
    model.eval()  ## Initialize lists to collect outputs and targets for metric calculations
    total_loss, all_targets, all_classification_outputs = 0, [], []
    all_regression_predictions, all_regression_targets = [], []  # Correctly defined lists for regression
    
    with torch.no_grad():                       # No gradient calculation to save memory and computation
        for inputs, labels in dataloader:
            # Move data to the configured device
            inputs, labels = inputs.to(device), labels.to(device)

            # Separate labels for regression and classification tasks
            regression_targets, classification_targets = labels[:, :2], labels[:, 2]
            regression_output, classification_output = model(inputs)

            # Compute losses for both tasks
            regression_loss = regression_loss_fn(regression_output, regression_targets)
            classification_loss = classification_loss_fn(classification_output, classification_targets)
            loss = regression_loss + classification_loss
            total_loss += loss.item()
            
            # Collect data for metrics calculation
            all_targets.append(classification_targets.cpu().numpy())
            all_classification_outputs.append(classification_output.cpu().numpy())
            all_regression_predictions.append(regression_output.cpu().numpy())  # Collect regression predictions
            all_regression_targets.append(regression_targets.cpu().numpy())  # Collect true regression targets
    
    # Concatenate all batch outputs for metrics calculation
    all_targets = np.concatenate(all_targets)
    all_classification_outputs = np.concatenate(all_classification_outputs)
    all_regression_predictions = np.concatenate(all_regression_predictions, axis=0)  # Correct concatenation
    all_regression_targets = np.concatenate(all_regression_targets, axis=0)  # Correct concatenation
    
    # Classification Metrics
    accuracy = accuracy_score(all_targets, all_classification_outputs.round())
    precision = precision_score(all_targets, all_classification_outputs.round(), zero_division=0)
    recall = recall_score(all_targets, all_classification_outputs.round(), zero_division=0)
    f1 = f1_score(all_targets, all_classification_outputs.round(), zero_division=0)
    roc_auc = roc_auc_score(all_targets, all_classification_outputs)
    
    # Regression Metrics
    mse = mean_squared_error(all_regression_targets, all_regression_predictions)
    mae = mean_absolute_error(all_regression_targets, all_regression_predictions)
    rmse = sqrt(mse)
    r_squared = r2_score(all_regression_targets, all_regression_predictions)

    # Include ROC Curve plotting here if needed

    return total_loss / len(dataloader), accuracy, precision, recall, f1, roc_auc, mse, mae, rmse, r_squared


In [32]:
class color:
   PURPLE = '\033[95m'
   CYAN = '\033[96m'
   DARKCYAN = '\033[36m'
   BLUE = '\033[94m'
   GREEN = '\033[92m'
   YELLOW = '\033[93m'
   RED = '\033[91m'
   BOLD = '\033[1m'
   UNDERLINE = '\033[4m'
   END = '\033[0m'

In [34]:
# Example training and validation
num_epochs = 50
for epoch in range(num_epochs):
    #train_loss = train_epoch(model, train_dataloader, optimizer, device)
    #val_loss = validate_epoch(model, val_dataloader, device)
    #print(f"Epoch {epoch+1}, Train Loss: {train_loss}, Val Loss: {val_loss}")

    train_loss, train_accuracy, train_precision, train_recall, train_f1, train_roc_auc, train_mse, train_mae, train_rmse, train_r_squared = train_epoch(model, train_dataloader, optimizer, device)
    val_loss, val_accuracy, val_precision, val_recall, val_f1, val_roc_auc, val_mse, val_mae, val_rmse, val_r_squared = validate_epoch(model, val_dataloader, device)
    
    print(color.PURPLE + color.BOLD + 'Training Results:' + 'Epoch'+ f'{epoch+1}' + color.END)
    print(f"Loss: {train_loss:.2f}, Accuracy: {train_accuracy:.2f}, Precision: {train_precision:.2f}, Recall: {train_recall:.2f}, F1: {train_f1:.2f}, ROC-AUC: {train_roc_auc:.2f}")
    print(f"MSE: {train_mse:.2f}, MAE: {train_mae:.2f}, RMSE: {train_rmse:.2f}, R-2: {train_r_squared:.2f}")
    print('...................................................................................................................')
    
    print(color.GREEN + color.BOLD + 'Validation Results:' + color.END)
    print(f"Validation Loss: {val_loss:.2f}, Accuracy: {val_accuracy:.2f}, Precision: {val_precision:.2f}, Recall: {val_recall:.2f}, F1 Score: {val_f1:.2f}, ROC-AUC: {val_roc_auc:.2f}")
    print(f"MSE: {val_mse:.2f}, MAE: {val_mae:.2f}, RMSE: {val_rmse:.2f}, R-2: {val_r_squared:.2f}")
    print('...................................................................................................................')


[95m[1mTraining Results:Epoch1[0m
Loss: 0.54, Accuracy: 0.88, Precision: 0.00, Recall: 0.00, F1: 0.00, ROC-AUC: 0.55
MSE: 0.16, MAE: 0.25, RMSE: 0.40, R-2: -0.02
...................................................................................................................
[92m[1mValidation Results:[0m
Validation Loss: 0.51, Accuracy: 0.85, Precision: 0.00, Recall: 0.00, F1 Score: 0.00, ROC-AUC: 0.90
MSE: 0.12, MAE: 0.21, RMSE: 0.35, R-2: 0.01
...................................................................................................................
[95m[1mTraining Results:Epoch2[0m
Loss: 0.53, Accuracy: 0.88, Precision: 0.00, Recall: 0.00, F1: 0.00, ROC-AUC: 0.59
MSE: 0.16, MAE: 0.25, RMSE: 0.39, R-2: 0.02
...................................................................................................................
[92m[1mValidation Results:[0m
Validation Loss: 0.51, Accuracy: 0.85, Precision: 0.00, Recall: 0.00, F1 Score: 0.00, ROC-AUC: 0.92
MSE: 0.12, MA

### Testing the model on unseen data

In [35]:
def evaluate_model(model, dataloader, device):
    """
    Evaluates the performance of the given model using the provided dataloader and computes both
    classification and regression metrics. Outputs the calculated metrics to provide insights into
    model accuracy, error rates, and predictive capabilities.

    Parameters:
    - model: The trained model to be evaluated.
    - dataloader: DataLoader containing the dataset for evaluation.
    - device: The device (CPU or GPU) on which to perform the evaluation.
    """
    model.eval()
    all_targets, all_classification_outputs = [], []
    all_regression_outputs, all_regression_targets = [], []
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            
            regression_targets, classification_targets = labels[:, :2], labels[:, 2]
            regression_output, classification_output = model(inputs)
            
            all_targets.extend(classification_targets.cpu().numpy())
            all_classification_outputs.extend(classification_output.cpu().numpy())
            all_regression_outputs.extend(regression_output.cpu().numpy())
            all_regression_targets.extend(regression_targets.cpu().numpy())

    # Calculate classification metrics
    accuracy = accuracy_score(all_targets, np.round(all_classification_outputs))
    precision = precision_score(all_targets, np.round(all_classification_outputs), zero_division=0)
    recall = recall_score(all_targets, np.round(all_classification_outputs), zero_division=0)
    f1 = f1_score(all_targets, np.round(all_classification_outputs), zero_division=0)
    roc_auc = roc_auc_score(all_targets, all_classification_outputs)
    
    # Calculate regression metrics
    all_regression_outputs = np.array(all_regression_outputs)
    all_regression_targets = np.array(all_regression_targets)
    mse = mean_squared_error(all_regression_targets, all_regression_outputs)
    mae = mean_absolute_error(all_regression_targets, all_regression_outputs)
    rmse = sqrt(mse)
    r_squared = r2_score(all_regression_targets, all_regression_outputs)
    
    # Display metrics
    print(f"Test Accuracy: {accuracy:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1 Score: {f1:.4f}")
    print(f"Test ROC-AUC: {roc_auc:.4f}")
    print(f"Test MSE: {mse:.4f}")
    print(f"Test MAE: {mae:.4f}")
    print(f"Test RMSE: {rmse:.4f}")
    print(f"Test R-2: {r_squared:.4f}")


In [36]:
evaluate_model(model, test_dataloader, device)

Test Accuracy: 1.0000
Test Precision: 1.0000
Test Recall: 1.0000
Test F1 Score: 1.0000
Test ROC-AUC: 1.0000
Test MSE: 0.0196
Test MAE: 0.0873
Test RMSE: 0.1399
Test R-2: 0.8595


### K-Fold cross validation

In [38]:
# Split the dataset into training+validation and test sets
data_train_val, data_test, labels_train_val, labels_test = train_test_split(all_inputs, all_labels, test_size=0.1, random_state=42)

# Convert to PyTorch datasets
dataset_train_val = TensorDataset(torch.tensor(data_train_val, dtype=torch.float32), torch.tensor(labels_train_val, dtype=torch.float32))
dataset_test = TensorDataset(torch.tensor(data_test, dtype=torch.float32), torch.tensor(labels_test, dtype=torch.float32))

# Parameters for cross-validation
k_folds = 5
num_epochs = 50

# KFold for splitting train+validation data
kf = KFold(n_splits=k_folds, shuffle=True, random_state=42)

best_model_wts = None
best_metric = float('inf')  # Example for loss. Use float('-inf') for accuracy.
best_fold = 0
best_epoch = 0

for fold, (train_idx, val_idx) in enumerate(kf.split(dataset_train_val)):
    print(f"FOLD {fold}")
    train_subsampler = SubsetRandomSampler(train_idx)
    val_subsampler = SubsetRandomSampler(val_idx)
    
    train_loader = DataLoader(dataset_train_val, batch_size=10, sampler=train_subsampler)
    val_loader = DataLoader(dataset_train_val, batch_size=10, sampler=val_subsampler)

    # Initialize your model, optimizer here
    model = LSTMModel(input_dim=2, hidden_dim=64, num_layers=1).to(device)
    
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Train and validate your model (implement these functions)
    for epoch in range(num_epochs):

        train_loss, train_accuracy, train_precision, train_recall, train_f1, train_roc_auc, train_mse, train_mae, train_rmse, train_r_squared = train_epoch(model, train_dataloader, optimizer, device)
        val_loss, val_accuracy, val_precision, val_recall, val_f1, val_roc_auc, val_mse, val_mae, val_rmse, val_r_squared = validate_epoch(model, val_dataloader, device)
    
        print(color.PURPLE + color.BOLD + 'Training Results:' + 'Epoch'+ f'{epoch+1}' + color.END)
        print(f"Loss: {train_loss:.2f}, Accuracy: {train_accuracy:.2f}, Precision: {train_precision:.2f}, Recall: {train_recall:.2f}, F1: {train_f1:.2f}, ROC-AUC: {train_roc_auc:.2f}")
        print(f"MSE: {train_mse:.2f}, MAE: {train_mae:.2f}, RMSE: {train_rmse:.2f}, R-2: {train_r_squared:.2f}")
        print('...................................................................................................................')
    
        print(color.GREEN + color.BOLD + 'Validation Results:' + color.END)
        print(f"Validation Loss: {val_loss:.2f}, Accuracy: {val_accuracy:.2f}, Precision: {val_precision:.2f}, Recall: {val_recall:.2f}, F1 Score: {val_f1:.2f}, ROC-AUC: {val_roc_auc:.2f}")
        print(f"MSE: {val_mse:.2f}, MAE: {val_mae:.2f}, RMSE: {val_rmse:.2f}, R-2: {val_r_squared:.2f}")
        print('...................................................................................................................')
        
        if val_loss < best_metric:
            best_metric = val_loss
            best_model_wts = copy.deepcopy(model.state_dict())
            best_fold = fold
            best_epoch = epoch

    
    print("---------------Best Model----------------")
    print(f'Fold {best_fold}, Epoch: {best_epoch}, Best validation loss: {best_metric:.4f}')
    print('--------------------------------')

# After cross-validation, evaluate the best model on the test set
# Load best model weights
best_model = LSTMModel(input_dim=2, hidden_dim=64, num_layers=1).to(device)
best_model.load_state_dict(best_model_wts)

test_loader = DataLoader(dataset_test, batch_size=10)
evaluate_model(best_model, test_loader, device)
# To use the last trained model which is not necessary the best model, use the following code:
#evaluate_model(model, test_loader, device) 


FOLD 0
[95m[1mTraining Results:Epoch1[0m
Loss: 0.72, Accuracy: 0.88, Precision: 0.00, Recall: 0.00, F1: 0.00, ROC-AUC: 0.56
MSE: 0.18, MAE: 0.29, RMSE: 0.42, R-2: -0.13
...................................................................................................................
[92m[1mValidation Results:[0m
Validation Loss: 0.52, Accuracy: 0.85, Precision: 0.00, Recall: 0.00, F1 Score: 0.00, ROC-AUC: 0.83
MSE: 0.13, MAE: 0.22, RMSE: 0.35, R-2: -0.02
...................................................................................................................
[95m[1mTraining Results:Epoch2[0m
Loss: 0.54, Accuracy: 0.88, Precision: 0.00, Recall: 0.00, F1: 0.00, ROC-AUC: 0.49
MSE: 0.16, MAE: 0.26, RMSE: 0.40, R-2: -0.03
...................................................................................................................
[92m[1mValidation Results:[0m
Validation Loss: 0.51, Accuracy: 0.85, Precision: 0.00, Recall: 0.00, F1 Score: 0.00, ROC-AUC: 0.91
MSE:

In [39]:
torch.save(best_model_wts, 'best_lstm_model.pth')

In [None]:
def grid_labels(pointer_x, pointer_y, n_row_quarter, n_col_quarter, orientation='portrait'):
    """
    Maps normalized pixel values to a grid on the monitor, with the center of the screen being (0,0),
    considering the monitor's orientation and the specified number of rows and columns for a quarter of the monitor.

    Args:
        pointer_x: Normalized x pixel value, ranging from -1 to 1.
        pointer_y: Normalized y pixel value, ranging from -1 to 1.
        n_row_quarter: Number of rows in a quarter of the monitor.
        n_col_quarter: Number of columns in a quarter of the monitor.
        orientation: Monitor orientation ('portrait' or 'landscape').
    
    Returns:
        grid_index: Index of the grid block where the point lies.
        grid_center_x: X coordinate of the grid block's center.
        grid_center_y: Y coordinate of the grid block's center.
    """
    scaled_x = pointer_x * 960
    scaled_y = pointer_y * 960

    # Determine full number of rows and columns based on orientation
    n_rows, n_cols = n_row_quarter * 2, n_col_quarter * 2
    
    # Calculate grid dimensions
    if orientation == 'portrait':
        height, width = 1920, 1080
    else:  # landscape
        width, height = 1920, 1080
    
    # Calculate the size of each grid block
    block_width, block_height = width / n_cols, height / n_rows

    # Determine the indices of the grid block
    col_index = (scaled_x + width / 2) // block_width
    row_index = (scaled_y + height / 2) // block_height

    # Ensure indices are within the grid bounds
    col_index = int(max(0, min(col_index, n_cols - 1)))
    row_index = int(max(0, min(row_index, n_rows - 1)))

    # Calculate grid index
    grid_index = row_index * n_cols + col_index

    # Calculate the center of the grid block
    grid_center_x = (col_index * block_width + block_width / 2) - width / 2
    grid_center_y = (row_index * block_height + block_height / 2) - height / 2

    return grid_index, grid_center_x, grid_center_y, block_width, block_height

# Example usage
pointer_x, pointer_y = -0.5, 0.5  # Example normalized positions
n_row_quarter, n_col_quarter = 2, 3  # Quarter grid specification
orientation = 'portrait'  # Monitor orientation
grid_index, grid_center_x, grid_center_y, block_width, block_height = grid_labels(pointer_x, pointer_y, n_row_quarter, n_col_quarter, orientation)

print(f"Point ({pointer_x}, {pointer_y}) maps to grid {grid_index}, with center at ({grid_center_x}, {grid_center_y}) with width of {block_width}, and height of {block_height}.")
