<a href="https://colab.research.google.com/github/antonis00/EKPA/blob/main/GazebaseVRUsersSavedModel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pymovements
import pymovements as pm
dataset = pm.Dataset("GazeBaseVR", path='data/GazeBaseVR')
dataset.download()
dataset.load()

Collecting pymovements
  Downloading pymovements-0.18.0-py3-none-any.whl.metadata (7.9 kB)
Collecting matplotlib<3.9,>=3.8.0 (from pymovements)
  Downloading matplotlib-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.8 kB)
Downloading pymovements-0.18.0-py3-none-any.whl (164 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m164.4/164.4 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading matplotlib-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m53.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: matplotlib, pymovements
  Attempting uninstall: matplotlib
    Found existing installation: matplotlib 3.7.1
    Uninstalling matplotlib-3.7.1:
      Successfully uninstalled matplotlib-3.7.1
Successfully installed matplotlib-3.8.4 pymovements-0.18.0


Downloading https://figshare.com/ndownloader/files/38844024 to data/GazeBaseVR/downloads/gazebasevr.zip


gazebasevr.zip: 0.00B [00:00, ?B/s]

Checking integrity of gazebasevr.zip
Extracting gazebasevr.zip to data/GazeBaseVR/raw


  0%|          | 0/5020 [00:00<?, ?it/s]

<pymovements.dataset.dataset.Dataset at 0x78c5b84ec4f0>

# Eye Movement-Based User Identification

## Introduction

This notebook demonstrates a machine learning approach to identify users based on their eye movement patterns across different tasks. We use data collected from participants performing five distinct eye-tracking (ET) tasks:

1. Vergence task (VRG)
2. Horizontal smooth pursuit task (PUR)
3. Video-viewing task (VID)
4. Self-paced reading task (TEX)
5. Random oblique saccade task (RAN)

## Data Processing and Feature Extraction

We extract a comprehensive set of features from the eye-tracking data for each task. The features are calculated both globally and for multiple time-based portions of each task.

### Velocity Calculation

Before feature extraction, we calculate the velocity of eye movements using the Savitzky-Golay filter:
 dataset.pos2vel(method='savitzky_golay', degree=2, window_length=7)
*italicised text*
 This method applies a Savitzky-Golay filter to smooth the position data and calculate velocities. It uses a polynomial of degree 2 and a window length of 7 samples, which helps to reduce noise while preserving the underlying signal characteristics.

### Extracted Features

For each portion and globally:

1. **Velocity Statistics**:
   - Mean velocity
   - Standard deviation of velocity
   - Maximum velocity
   - Skewness of velocity distribution
   - Kurtosis of velocity distribution

2. **Position Statistics**:
   - Mean X and Y positions
   - Standard deviation of X and Y positions

3. **Target Position** (if available):
   - Mean X and Y target positions
   - Standard deviation of X and Y target positions

4. **Saccade and Fixation Metrics**:
   - Saccade rate (proportion of samples above saccade threshold)
   - Fixation rate (proportion of samples below fixation threshold)

5. **Task Duration**:
   - Total number of samples in the task

## Model Architecture

We use a Multi-Layer Perceptron (MLP) classifier with the following structure:
- Input layer (size depends on the number of features)
- Hidden layers: 64 units, 32 units, 16 units
- Output layer (size equals the number of unique users)
- A skip connection from input to output for improved learning

## Training Process

The model is trained using:
- Adam optimizer with a learning rate of 0.05
- Cross-entropy loss function
- 200 epochs
- Batch size of 256

## Results

The overall test accuracy achieved is 68%, indicating that the model can correctly identify users based on their eye movements in 68% of cases.

### Task-Specific Accuracies

1. Video-viewing task (VID): 93.03%
2. Random oblique saccade task (RAN): 72.11%
3. Horizontal smooth pursuit task (PUR): 68.33%
4. Self-paced reading task (TEX): 57.37%
5. Vergence task (VRG): 43.82%

## Analysis

The results show that different tasks have varying levels of effectiveness in identifying users:

1. The video-viewing task is the most effective, with an impressive 93.03% accuracy. This suggests that how people watch videos is highly individual and consistent.

2. The random oblique saccade task and horizontal smooth pursuit task show moderate effectiveness, with accuracies above the overall average.

3. The self-paced reading task and vergence task are less effective for user identification, with accuracies below the overall average.

These findings indicate that complex, naturalistic tasks like video viewing may be more suitable for eye movement-based user identification compared to simpler, controlled tasks like vergence exercises.

## Conclusion

This study demonstrates the potential of using eye movement patterns for user identification, with varying degrees of success across different tasks. The use of the Savitzky-Golay filter for velocity calculation provides a robust basis for feature extraction. Future work could focus on optimizing feature extraction for the most effective tasks or combining data from multiple tasks to improve overall accuracy.



In [2]:
position_columns = ['x_left', 'y_left', 'x_right', 'y_right']

# Calculate velocity
dataset.pos2vel( method='savitzky_golay', degree=2, window_length=7)
print(dataset.gaze[5].frame)

  0%|          | 0/5020 [00:00<?, ?it/s]

shape: (14_815, 16)
┌────────────┬────────────┬───────────┬────────┬───┬───────────┬───────────┬───────────┬───────────┐
│ time       ┆ x_target_p ┆ y_target_ ┆ zT     ┆ … ┆ session_i ┆ task_name ┆ position  ┆ velocity  │
│ ---        ┆ os         ┆ pos       ┆ ---    ┆   ┆ d         ┆ ---       ┆ ---       ┆ ---       │
│ f32        ┆ ---        ┆ ---       ┆ f32    ┆   ┆ ---       ┆ str       ┆ list[f32] ┆ list[f32] │
│            ┆ f32        ┆ f32       ┆        ┆   ┆ i64       ┆           ┆           ┆           │
╞════════════╪════════════╪═══════════╪════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡
│ 0.0        ┆ 0.0        ┆ 0.0       ┆ 0.4433 ┆ … ┆ 2         ┆ 1_VRG     ┆ [4.9966,  ┆ [5.709814 │
│            ┆            ┆           ┆        ┆   ┆           ┆           ┆ 0.9662, … ┆ , -2.6580 │
│            ┆            ┆           ┆        ┆   ┆           ┆           ┆ 0.7913]   ┆ 34, …     │
│            ┆            ┆           ┆        ┆   ┆           ┆       

In [3]:
import numpy as np
from scipy import stats
import polars as pl
from tqdm import tqdm

def extract_features(gaze_data, n_portions=10, set_type='train'):
    features = []
    y = []
    task_names = []

    for gaze in tqdm(gaze_data):
        frame = gaze.frame
        # Create a temporary index column
        frame = frame.with_row_count("temp_index")

        if set_type == 'train':
            frame = frame.filter(pl.col('temp_index') % 3 != 2)
        elif set_type == 'test':
            frame = frame.filter(pl.col('temp_index') % 3 == 2)

        # Drop the temporary index column
        frame = frame.drop('temp_index')

        total_samples = frame.shape[0]
        portion_size = total_samples // n_portions

        # Initialize lists to store portion-wise statistics
        portion_features = []

        # Calculate statistics for each portion
        for i in range(n_portions):
            start_idx = i * portion_size
            end_idx = (i + 1) * portion_size if i < n_portions - 1 else total_samples

            portion = frame.slice(start_idx, end_idx - start_idx)

            # Velocity statistics
            velocities = portion['velocity'].apply(lambda x: np.linalg.norm(x)).to_numpy()

            # Position statistics
            positions = np.vstack(portion['position'].to_numpy())

            portion_stats = [
                np.mean(velocities),
                np.std(velocities),
                np.max(velocities),
                *np.mean(positions, axis=0),
                *np.std(positions, axis=0),
                stats.skew(velocities),
                stats.kurtosis(velocities)
            ]

            # Add any other relevant statistics from other columns
            if 'x_target_pos' in portion.columns and 'y_target_pos' in portion.columns:
                portion_stats.extend([
                    portion['x_target_pos'].mean(),
                    portion['y_target_pos'].mean(),
                    portion['x_target_pos'].std(),
                    portion['y_target_pos'].std()
                ])

            portion_features.extend(portion_stats)

        # Global features
        velocities = frame['velocity'].apply(lambda x: np.linalg.norm(x)).to_numpy()
        positions = np.vstack(frame['position'].to_numpy())

        global_features = [
            np.mean(velocities),
            np.std(velocities),
            np.max(velocities),
            *np.mean(positions, axis=0),
            *np.std(positions, axis=0),
            stats.skew(velocities),
            stats.kurtosis(velocities)
        ]

        # Saccade and fixation features
        saccade_threshold = np.mean(velocities) + 2 * np.std(velocities)
        fixation_threshold = np.mean(velocities) - np.std(velocities)
        saccade_rate = np.sum(velocities > saccade_threshold) / total_samples
        fixation_rate = np.sum(velocities < fixation_threshold) / total_samples

        # Task-specific features
        task_name = frame['task_name'][0]
        task_duration = total_samples  # Assuming constant sampling rate

        # Combine all features
        feature_vector = portion_features + global_features + [saccade_rate, fixation_rate, task_duration]

        features.append(feature_vector)
        y.append(frame['subject_id'][0])
        task_names.append(task_name)

    return np.array(features), np.array(y), np.array(task_names)

# Extract features for training set (1st and 2nd rows)
X_train, y_train, task_names_train = extract_features(dataset.gaze, n_portions=100, set_type='train')

# Extract features for test set (every 3rd row)
X_test, y_test, task_names_test = extract_features(dataset.gaze, n_portions=100, set_type='test')

# Save the features
np.save('X_train.npy', X_train)
np.save('y_train.npy', y_train)
np.save('task_names_train.npy', task_names_train)

np.save('X_test.npy', X_test)
np.save('y_test.npy', y_test)
np.save('task_names_test.npy', task_names_test)

print(f"Training set shape: {X_train.shape}")
print(f"Testing set shape: {X_test.shape}")

  velocities = portion['velocity'].apply(lambda x: np.linalg.norm(x)).to_numpy()
  velocities = frame['velocity'].apply(lambda x: np.linalg.norm(x)).to_numpy()
100%|██████████| 5020/5020 [1:03:46<00:00,  1.31it/s]
100%|██████████| 5020/5020 [36:34<00:00,  2.29it/s]


Training set shape: (5020, 2120)
Testing set shape: (5020, 2120)


In [4]:
import os
from collections import defaultdict

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm


class EyeMovementDataset(Dataset):
    def __init__(self, X, y, user_to_index):
        self.X = torch.FloatTensor(X)
        self.y = torch.LongTensor([user_to_index[user] for user in y])

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


class MLPClassifier(nn.Module):
    def __init__(self, input_size, num_classes):
        super(MLPClassifier, self).__init__()
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 16)
        self.fc5 = nn.Linear(16, num_classes)
        self.fc6 = nn.Linear(input_size, num_classes)

    def forward(self, x):
        out = torch.relu(self.fc1(x))
        out = torch.relu(self.fc2(out))
        out = self.fc5(out)
        d = self.fc6(x)
        return out + 10 * d


def evaluate_task_accuracy(
    model,
    X_test,
    y_test,
    task_names_test,
    user_to_index,
    index_to_user,
    specific_task=None,
):
    model.eval()
    device = next(model.parameters()).device

    X_test_tensor = torch.FloatTensor(X_test).to(device)
    y_test_tensor = torch.LongTensor([user_to_index[user] for user in y_test]).to(
        device
    )

    with torch.no_grad():
        outputs = model(X_test_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs.data, 1)

    predicted = predicted.cpu().numpy()
    probabilities = probabilities.cpu().numpy()

    task_results = defaultdict(lambda: {"true": [], "pred": [], "prob": []})
    for true, pred, prob, task in zip(
        y_test, predicted, probabilities, task_names_test
    ):
        task_results[task]["true"].append(true)
        task_results[task]["pred"].append(index_to_user[pred])
        task_results[task]["prob"].append(prob[user_to_index[true]])

    task_accuracies = {}
    for task, results in task_results.items():
        if specific_task is None or task == specific_task:
            accuracy = accuracy_score(results["true"], results["pred"])
            avg_probability = np.mean(results["prob"])
            task_accuracies[task] = {
                "accuracy": accuracy,
                "avg_probability": avg_probability,
            }

    if specific_task:
        if specific_task in task_accuracies:
            print(f"Task: {specific_task}")
            print(f"Accuracy: {task_accuracies[specific_task]['accuracy']:.2%}")
            print(
                f"Average Probability: {task_accuracies[specific_task]['avg_probability']:.4f}"
            )
        else:
            print(f"Task {specific_task} not found in the test set.")
    else:
        sorted_tasks = sorted(
            task_accuracies.items(), key=lambda x: x[1]["accuracy"], reverse=True
        )
        print("Task Accuracies and Average Probabilities:")
        for task, metrics in sorted_tasks:
            print(
                f"{task}: Accuracy: {metrics['accuracy']:.2%}, Avg Probability: {metrics['avg_probability']:.4f}"
            )

    return task_accuracies


# Load the data
X_train = np.load("X_train.npy", allow_pickle=True).astype(np.float32)
y_train = np.load("y_train.npy", allow_pickle=True)
X_test = np.load("X_test.npy", allow_pickle=True).astype(np.float32)
y_test = np.load("y_test.npy", allow_pickle=True)
task_names_test = np.load("task_names_test.npy", allow_pickle=True)

# Get unique task names
unique_task_names = np.unique(task_names_test)
# print("Available tasks:", unique_task_names)

# Create user ID to index mapping
unique_users = np.unique(np.concatenate([y_train, y_test]))
user_to_index = {user: idx for idx, user in enumerate(unique_users)}
index_to_user = {idx: user for user, idx in user_to_index.items()}

# Check for NaN or infinite values
X_train = np.nan_to_num(X_train, nan=0, posinf=1e6, neginf=-1e6)
X_test = np.nan_to_num(X_test, nan=0, posinf=1e6, neginf=-1e6)

# Initialize model
input_size = X_train.shape[1]
num_classes = len(unique_users)
model = MLPClassifier(input_size, num_classes)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Check if a trained model exists
model_path = "model.pth"
if os.path.exists(model_path):
    print("Loading existing model...")
    model.load_state_dict(torch.load(model_path))
else:
    print("Training new model...")
    # Create datasets and dataloaders
    train_dataset = EyeMovementDataset(X_train, y_train, user_to_index)
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

    # Training loop
    num_epochs = 200

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch_X, batch_y in tqdm(
            train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"
        ):
            batch_X, batch_y = batch_X.to(device), batch_y.to(device)
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_dataloader):.4f}"
        )

    # Save the model
    torch.save(model.state_dict(), model_path)

# Evaluation on test set
model.eval()
test_dataset = EyeMovementDataset(X_test, y_test, user_to_index)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

with torch.no_grad():
    correct = 0
    total = 0
    for batch_X, batch_y in test_dataloader:
        batch_X, batch_y = batch_X.to(device), batch_y.to(device)
        outputs = model(batch_X)
        _, predicted = torch.max(outputs.data, 1)
        total += batch_y.size(0)
        correct += (predicted == batch_y).sum().item()

    print(f"Overall Test Accuracy: {100 * correct / total:.2f}%")

chosen_task = "3_VID"
task_accuracies = evaluate_task_accuracy(
    model, X_test, y_test, task_names_test, user_to_index, index_to_user, chosen_task
)


# Function to predict on user data
def predict_on_user_data(user_data, model, user_to_index, index_to_user, task_name):
    model.eval()
    device = next(model.parameters()).device

    # Ensure user_data is 2D (add batch dimension if necessary)
    if user_data.ndim == 1:
        user_data = user_data.reshape(1, -1)

    user_data_tensor = torch.FloatTensor(user_data).to(device)

    with torch.no_grad():
        outputs = model(user_data_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        _, predicted = torch.max(outputs.data, 1)

    predicted_index = 10 * predicted.cpu().numpy()[0]
    predicted_user = index_to_user[predicted_index]

    return predicted_user


# Example usage:
# predict user_to_index[12]
user_data = X_train[user_to_index[250]]
predicted_user = predict_on_user_data(
    user_data, model, user_to_index, index_to_user, chosen_task
)
print(f"Predicted user: {predicted_user}")


Training new model...


Epoch 1/200: 100%|██████████| 157/157 [00:00<00:00, 184.32it/s]


Epoch [1/200], Loss: 108178.3211


Epoch 2/200: 100%|██████████| 157/157 [00:00<00:00, 225.47it/s]


Epoch [2/200], Loss: 58058.3925


Epoch 3/200: 100%|██████████| 157/157 [00:00<00:00, 191.86it/s]


Epoch [3/200], Loss: 41543.6840


Epoch 4/200: 100%|██████████| 157/157 [00:00<00:00, 192.88it/s]


Epoch [4/200], Loss: 33989.5343


Epoch 5/200: 100%|██████████| 157/157 [00:00<00:00, 242.19it/s]


Epoch [5/200], Loss: 30636.0208


Epoch 6/200: 100%|██████████| 157/157 [00:00<00:00, 243.72it/s]


Epoch [6/200], Loss: 27478.8519


Epoch 7/200: 100%|██████████| 157/157 [00:00<00:00, 235.92it/s]


Epoch [7/200], Loss: 25691.0917


Epoch 8/200: 100%|██████████| 157/157 [00:00<00:00, 226.65it/s]


Epoch [8/200], Loss: 23097.4556


Epoch 9/200: 100%|██████████| 157/157 [00:00<00:00, 238.65it/s]


Epoch [9/200], Loss: 21976.5777


Epoch 10/200: 100%|██████████| 157/157 [00:00<00:00, 242.62it/s]


Epoch [10/200], Loss: 20662.7744


Epoch 11/200: 100%|██████████| 157/157 [00:00<00:00, 255.21it/s]


Epoch [11/200], Loss: 19464.9940


Epoch 12/200: 100%|██████████| 157/157 [00:00<00:00, 228.67it/s]


Epoch [12/200], Loss: 16818.9144


Epoch 13/200: 100%|██████████| 157/157 [00:00<00:00, 236.75it/s]


Epoch [13/200], Loss: 17442.1043


Epoch 14/200: 100%|██████████| 157/157 [00:00<00:00, 242.65it/s]


Epoch [14/200], Loss: 17100.0202


Epoch 15/200: 100%|██████████| 157/157 [00:00<00:00, 222.55it/s]


Epoch [15/200], Loss: 16503.1070


Epoch 16/200: 100%|██████████| 157/157 [00:00<00:00, 221.83it/s]


Epoch [16/200], Loss: 16782.9691


Epoch 17/200: 100%|██████████| 157/157 [00:00<00:00, 234.18it/s]


Epoch [17/200], Loss: 16300.2596


Epoch 18/200: 100%|██████████| 157/157 [00:00<00:00, 225.86it/s]


Epoch [18/200], Loss: 14268.8059


Epoch 19/200: 100%|██████████| 157/157 [00:00<00:00, 217.75it/s]


Epoch [19/200], Loss: 14437.7997


Epoch 20/200: 100%|██████████| 157/157 [00:00<00:00, 197.22it/s]


Epoch [20/200], Loss: 13533.7499


Epoch 21/200: 100%|██████████| 157/157 [00:00<00:00, 226.96it/s]


Epoch [21/200], Loss: 12601.0491


Epoch 22/200: 100%|██████████| 157/157 [00:00<00:00, 224.07it/s]


Epoch [22/200], Loss: 13156.7098


Epoch 23/200: 100%|██████████| 157/157 [00:00<00:00, 227.94it/s]


Epoch [23/200], Loss: 12199.4944


Epoch 24/200: 100%|██████████| 157/157 [00:00<00:00, 245.76it/s]


Epoch [24/200], Loss: 11654.0096


Epoch 25/200: 100%|██████████| 157/157 [00:00<00:00, 232.00it/s]


Epoch [25/200], Loss: 12746.0963


Epoch 26/200: 100%|██████████| 157/157 [00:00<00:00, 239.16it/s]


Epoch [26/200], Loss: 11192.6557


Epoch 27/200: 100%|██████████| 157/157 [00:00<00:00, 223.14it/s]


Epoch [27/200], Loss: 12528.8313


Epoch 28/200: 100%|██████████| 157/157 [00:00<00:00, 239.98it/s]


Epoch [28/200], Loss: 11619.9029


Epoch 29/200: 100%|██████████| 157/157 [00:00<00:00, 232.10it/s]


Epoch [29/200], Loss: 10425.9138


Epoch 30/200: 100%|██████████| 157/157 [00:00<00:00, 230.47it/s]


Epoch [30/200], Loss: 10124.2777


Epoch 31/200: 100%|██████████| 157/157 [00:00<00:00, 230.12it/s]


Epoch [31/200], Loss: 10993.2529


Epoch 32/200: 100%|██████████| 157/157 [00:00<00:00, 223.64it/s]


Epoch [32/200], Loss: 9672.1191


Epoch 33/200: 100%|██████████| 157/157 [00:00<00:00, 223.94it/s]


Epoch [33/200], Loss: 9517.6966


Epoch 34/200: 100%|██████████| 157/157 [00:00<00:00, 230.62it/s]


Epoch [34/200], Loss: 11004.9625


Epoch 35/200: 100%|██████████| 157/157 [00:00<00:00, 218.65it/s]


Epoch [35/200], Loss: 10838.7337


Epoch 36/200: 100%|██████████| 157/157 [00:00<00:00, 185.11it/s]


Epoch [36/200], Loss: 9169.9392


Epoch 37/200: 100%|██████████| 157/157 [00:00<00:00, 235.96it/s]


Epoch [37/200], Loss: 8999.7357


Epoch 38/200: 100%|██████████| 157/157 [00:00<00:00, 228.82it/s]


Epoch [38/200], Loss: 8505.2668


Epoch 39/200: 100%|██████████| 157/157 [00:00<00:00, 229.13it/s]


Epoch [39/200], Loss: 9180.6898


Epoch 40/200: 100%|██████████| 157/157 [00:00<00:00, 232.15it/s]


Epoch [40/200], Loss: 9074.2298


Epoch 41/200: 100%|██████████| 157/157 [00:00<00:00, 231.85it/s]


Epoch [41/200], Loss: 9233.2746


Epoch 42/200: 100%|██████████| 157/157 [00:00<00:00, 220.69it/s]


Epoch [42/200], Loss: 8927.5259


Epoch 43/200: 100%|██████████| 157/157 [00:00<00:00, 231.40it/s]


Epoch [43/200], Loss: 8497.9176


Epoch 44/200: 100%|██████████| 157/157 [00:00<00:00, 241.28it/s]


Epoch [44/200], Loss: 7675.8221


Epoch 45/200: 100%|██████████| 157/157 [00:00<00:00, 233.49it/s]


Epoch [45/200], Loss: 7744.6087


Epoch 46/200: 100%|██████████| 157/157 [00:00<00:00, 219.92it/s]


Epoch [46/200], Loss: 8457.6940


Epoch 47/200: 100%|██████████| 157/157 [00:00<00:00, 236.13it/s]


Epoch [47/200], Loss: 7370.8102


Epoch 48/200: 100%|██████████| 157/157 [00:00<00:00, 222.32it/s]


Epoch [48/200], Loss: 8132.5674


Epoch 49/200: 100%|██████████| 157/157 [00:00<00:00, 232.60it/s]


Epoch [49/200], Loss: 9057.1178


Epoch 50/200: 100%|██████████| 157/157 [00:00<00:00, 226.12it/s]


Epoch [50/200], Loss: 7927.8240


Epoch 51/200: 100%|██████████| 157/157 [00:00<00:00, 213.25it/s]


Epoch [51/200], Loss: 8130.2959


Epoch 52/200: 100%|██████████| 157/157 [00:00<00:00, 188.32it/s]


Epoch [52/200], Loss: 8178.1667


Epoch 53/200: 100%|██████████| 157/157 [00:00<00:00, 206.20it/s]


Epoch [53/200], Loss: 6960.2738


Epoch 54/200: 100%|██████████| 157/157 [00:00<00:00, 226.60it/s]


Epoch [54/200], Loss: 7176.8831


Epoch 55/200: 100%|██████████| 157/157 [00:00<00:00, 238.96it/s]


Epoch [55/200], Loss: 6439.4627


Epoch 56/200: 100%|██████████| 157/157 [00:00<00:00, 244.80it/s]


Epoch [56/200], Loss: 5795.2830


Epoch 57/200: 100%|██████████| 157/157 [00:00<00:00, 235.33it/s]


Epoch [57/200], Loss: 5831.9830


Epoch 58/200: 100%|██████████| 157/157 [00:00<00:00, 222.57it/s]


Epoch [58/200], Loss: 4878.8599


Epoch 59/200: 100%|██████████| 157/157 [00:00<00:00, 228.26it/s]


Epoch [59/200], Loss: 5607.7709


Epoch 60/200: 100%|██████████| 157/157 [00:00<00:00, 239.51it/s]


Epoch [60/200], Loss: 7338.0243


Epoch 61/200: 100%|██████████| 157/157 [00:00<00:00, 224.95it/s]


Epoch [61/200], Loss: 6674.9374


Epoch 62/200: 100%|██████████| 157/157 [00:00<00:00, 234.43it/s]


Epoch [62/200], Loss: 6288.1555


Epoch 63/200: 100%|██████████| 157/157 [00:00<00:00, 248.48it/s]


Epoch [63/200], Loss: 6547.9989


Epoch 64/200: 100%|██████████| 157/157 [00:00<00:00, 233.68it/s]


Epoch [64/200], Loss: 6476.7065


Epoch 65/200: 100%|██████████| 157/157 [00:00<00:00, 210.96it/s]


Epoch [65/200], Loss: 5825.4681


Epoch 66/200: 100%|██████████| 157/157 [00:00<00:00, 221.77it/s]


Epoch [66/200], Loss: 6294.1891


Epoch 67/200: 100%|██████████| 157/157 [00:00<00:00, 217.15it/s]


Epoch [67/200], Loss: 6069.0106


Epoch 68/200: 100%|██████████| 157/157 [00:00<00:00, 199.59it/s]


Epoch [68/200], Loss: 5912.2416


Epoch 69/200: 100%|██████████| 157/157 [00:00<00:00, 201.39it/s]


Epoch [69/200], Loss: 5161.6276


Epoch 70/200: 100%|██████████| 157/157 [00:00<00:00, 233.39it/s]


Epoch [70/200], Loss: 4969.0530


Epoch 71/200: 100%|██████████| 157/157 [00:00<00:00, 241.40it/s]


Epoch [71/200], Loss: 5069.4465


Epoch 72/200: 100%|██████████| 157/157 [00:00<00:00, 230.82it/s]


Epoch [72/200], Loss: 4991.8315


Epoch 73/200: 100%|██████████| 157/157 [00:00<00:00, 216.69it/s]


Epoch [73/200], Loss: 5468.8363


Epoch 74/200: 100%|██████████| 157/157 [00:00<00:00, 222.60it/s]


Epoch [74/200], Loss: 4568.3631


Epoch 75/200: 100%|██████████| 157/157 [00:00<00:00, 224.65it/s]


Epoch [75/200], Loss: 5055.3313


Epoch 76/200: 100%|██████████| 157/157 [00:00<00:00, 225.83it/s]


Epoch [76/200], Loss: 5651.2236


Epoch 77/200: 100%|██████████| 157/157 [00:00<00:00, 226.46it/s]


Epoch [77/200], Loss: 6613.9860


Epoch 78/200: 100%|██████████| 157/157 [00:00<00:00, 237.45it/s]


Epoch [78/200], Loss: 4156.0266


Epoch 79/200: 100%|██████████| 157/157 [00:00<00:00, 232.81it/s]


Epoch [79/200], Loss: 5409.2181


Epoch 80/200: 100%|██████████| 157/157 [00:00<00:00, 232.12it/s]


Epoch [80/200], Loss: 5652.9633


Epoch 81/200: 100%|██████████| 157/157 [00:00<00:00, 217.20it/s]


Epoch [81/200], Loss: 4864.6932


Epoch 82/200: 100%|██████████| 157/157 [00:00<00:00, 210.85it/s]


Epoch [82/200], Loss: 5779.9051


Epoch 83/200: 100%|██████████| 157/157 [00:00<00:00, 213.02it/s]


Epoch [83/200], Loss: 5120.8055


Epoch 84/200: 100%|██████████| 157/157 [00:00<00:00, 199.33it/s]


Epoch [84/200], Loss: 5081.3904


Epoch 85/200: 100%|██████████| 157/157 [00:00<00:00, 204.10it/s]


Epoch [85/200], Loss: 4492.6011


Epoch 86/200: 100%|██████████| 157/157 [00:00<00:00, 228.85it/s]


Epoch [86/200], Loss: 4982.7744


Epoch 87/200: 100%|██████████| 157/157 [00:00<00:00, 229.40it/s]


Epoch [87/200], Loss: 4537.4790


Epoch 88/200: 100%|██████████| 157/157 [00:00<00:00, 248.01it/s]


Epoch [88/200], Loss: 4539.6064


Epoch 89/200: 100%|██████████| 157/157 [00:00<00:00, 229.17it/s]


Epoch [89/200], Loss: 4220.3924


Epoch 90/200: 100%|██████████| 157/157 [00:00<00:00, 246.08it/s]


Epoch [90/200], Loss: 4385.0404


Epoch 91/200: 100%|██████████| 157/157 [00:00<00:00, 233.30it/s]


Epoch [91/200], Loss: 4921.1371


Epoch 92/200: 100%|██████████| 157/157 [00:00<00:00, 227.39it/s]


Epoch [92/200], Loss: 4305.4916


Epoch 93/200: 100%|██████████| 157/157 [00:00<00:00, 234.11it/s]


Epoch [93/200], Loss: 4299.4500


Epoch 94/200: 100%|██████████| 157/157 [00:00<00:00, 239.52it/s]


Epoch [94/200], Loss: 4099.6988


Epoch 95/200: 100%|██████████| 157/157 [00:00<00:00, 246.89it/s]


Epoch [95/200], Loss: 3288.3898


Epoch 96/200: 100%|██████████| 157/157 [00:00<00:00, 247.16it/s]


Epoch [96/200], Loss: 4542.1318


Epoch 97/200: 100%|██████████| 157/157 [00:00<00:00, 235.59it/s]


Epoch [97/200], Loss: 5663.7906


Epoch 98/200: 100%|██████████| 157/157 [00:00<00:00, 246.28it/s]


Epoch [98/200], Loss: 3796.7149


Epoch 99/200: 100%|██████████| 157/157 [00:00<00:00, 236.74it/s]


Epoch [99/200], Loss: 4241.3356


Epoch 100/200: 100%|██████████| 157/157 [00:00<00:00, 216.87it/s]


Epoch [100/200], Loss: 4662.5680


Epoch 101/200: 100%|██████████| 157/157 [00:00<00:00, 199.19it/s]


Epoch [101/200], Loss: 4781.2165


Epoch 102/200: 100%|██████████| 157/157 [00:00<00:00, 218.30it/s]


Epoch [102/200], Loss: 3744.2541


Epoch 103/200: 100%|██████████| 157/157 [00:00<00:00, 253.02it/s]


Epoch [103/200], Loss: 3490.5144


Epoch 104/200: 100%|██████████| 157/157 [00:00<00:00, 231.67it/s]


Epoch [104/200], Loss: 2598.2156


Epoch 105/200: 100%|██████████| 157/157 [00:00<00:00, 243.08it/s]


Epoch [105/200], Loss: 3368.5218


Epoch 106/200: 100%|██████████| 157/157 [00:00<00:00, 247.82it/s]


Epoch [106/200], Loss: 5121.4740


Epoch 107/200: 100%|██████████| 157/157 [00:00<00:00, 218.61it/s]


Epoch [107/200], Loss: 4243.2882


Epoch 108/200: 100%|██████████| 157/157 [00:00<00:00, 238.91it/s]


Epoch [108/200], Loss: 4812.6636


Epoch 109/200: 100%|██████████| 157/157 [00:00<00:00, 227.76it/s]


Epoch [109/200], Loss: 4385.9329


Epoch 110/200: 100%|██████████| 157/157 [00:00<00:00, 224.85it/s]


Epoch [110/200], Loss: 4328.2101


Epoch 111/200: 100%|██████████| 157/157 [00:00<00:00, 234.01it/s]


Epoch [111/200], Loss: 3582.3746


Epoch 112/200: 100%|██████████| 157/157 [00:00<00:00, 230.70it/s]


Epoch [112/200], Loss: 3573.0907


Epoch 113/200: 100%|██████████| 157/157 [00:00<00:00, 233.63it/s]


Epoch [113/200], Loss: 3194.5302


Epoch 114/200: 100%|██████████| 157/157 [00:00<00:00, 219.37it/s]


Epoch [114/200], Loss: 2895.8916


Epoch 115/200: 100%|██████████| 157/157 [00:00<00:00, 239.63it/s]


Epoch [115/200], Loss: 2864.6875


Epoch 116/200: 100%|██████████| 157/157 [00:00<00:00, 228.84it/s]


Epoch [116/200], Loss: 3635.7808


Epoch 117/200: 100%|██████████| 157/157 [00:00<00:00, 216.37it/s]


Epoch [117/200], Loss: 2956.3532


Epoch 118/200: 100%|██████████| 157/157 [00:00<00:00, 188.30it/s]


Epoch [118/200], Loss: 3515.8982


Epoch 119/200: 100%|██████████| 157/157 [00:00<00:00, 228.79it/s]


Epoch [119/200], Loss: 4154.3903


Epoch 120/200: 100%|██████████| 157/157 [00:00<00:00, 228.65it/s]


Epoch [120/200], Loss: 3554.9322


Epoch 121/200: 100%|██████████| 157/157 [00:00<00:00, 224.71it/s]


Epoch [121/200], Loss: 3327.8948


Epoch 122/200: 100%|██████████| 157/157 [00:00<00:00, 221.74it/s]


Epoch [122/200], Loss: 2501.3094


Epoch 123/200: 100%|██████████| 157/157 [00:00<00:00, 259.27it/s]


Epoch [123/200], Loss: 3090.9550


Epoch 124/200: 100%|██████████| 157/157 [00:00<00:00, 224.56it/s]


Epoch [124/200], Loss: 4504.9885


Epoch 125/200: 100%|██████████| 157/157 [00:00<00:00, 223.34it/s]


Epoch [125/200], Loss: 5033.0149


Epoch 126/200: 100%|██████████| 157/157 [00:00<00:00, 244.35it/s]


Epoch [126/200], Loss: 3670.3917


Epoch 127/200: 100%|██████████| 157/157 [00:00<00:00, 234.97it/s]


Epoch [127/200], Loss: 3068.1920


Epoch 128/200: 100%|██████████| 157/157 [00:00<00:00, 243.45it/s]


Epoch [128/200], Loss: 3253.4989


Epoch 129/200: 100%|██████████| 157/157 [00:00<00:00, 231.94it/s]


Epoch [129/200], Loss: 3454.4133


Epoch 130/200: 100%|██████████| 157/157 [00:00<00:00, 231.16it/s]


Epoch [130/200], Loss: 2670.7382


Epoch 131/200: 100%|██████████| 157/157 [00:00<00:00, 226.98it/s]


Epoch [131/200], Loss: 3705.3216


Epoch 132/200: 100%|██████████| 157/157 [00:00<00:00, 228.12it/s]


Epoch [132/200], Loss: 3425.5504


Epoch 133/200: 100%|██████████| 157/157 [00:00<00:00, 211.02it/s]


Epoch [133/200], Loss: 3174.1169


Epoch 134/200: 100%|██████████| 157/157 [00:00<00:00, 202.93it/s]


Epoch [134/200], Loss: 3135.8117


Epoch 135/200: 100%|██████████| 157/157 [00:00<00:00, 227.33it/s]


Epoch [135/200], Loss: 3397.7297


Epoch 136/200: 100%|██████████| 157/157 [00:00<00:00, 242.79it/s]


Epoch [136/200], Loss: 2721.9402


Epoch 137/200: 100%|██████████| 157/157 [00:00<00:00, 235.32it/s]


Epoch [137/200], Loss: 2660.7845


Epoch 138/200: 100%|██████████| 157/157 [00:00<00:00, 246.23it/s]


Epoch [138/200], Loss: 4145.2141


Epoch 139/200: 100%|██████████| 157/157 [00:00<00:00, 228.56it/s]


Epoch [139/200], Loss: 3770.6061


Epoch 140/200: 100%|██████████| 157/157 [00:00<00:00, 245.93it/s]


Epoch [140/200], Loss: 4044.6942


Epoch 141/200: 100%|██████████| 157/157 [00:00<00:00, 221.17it/s]


Epoch [141/200], Loss: 2770.8790


Epoch 142/200: 100%|██████████| 157/157 [00:00<00:00, 217.58it/s]


Epoch [142/200], Loss: 2488.5632


Epoch 143/200: 100%|██████████| 157/157 [00:00<00:00, 238.30it/s]


Epoch [143/200], Loss: 2183.5918


Epoch 144/200: 100%|██████████| 157/157 [00:00<00:00, 222.66it/s]


Epoch [144/200], Loss: 2753.3299


Epoch 145/200: 100%|██████████| 157/157 [00:00<00:00, 217.34it/s]


Epoch [145/200], Loss: 3329.5770


Epoch 146/200: 100%|██████████| 157/157 [00:00<00:00, 239.01it/s]


Epoch [146/200], Loss: 2335.9276


Epoch 147/200: 100%|██████████| 157/157 [00:00<00:00, 231.13it/s]


Epoch [147/200], Loss: 2408.4758


Epoch 148/200: 100%|██████████| 157/157 [00:00<00:00, 223.34it/s]


Epoch [148/200], Loss: 3763.6500


Epoch 149/200: 100%|██████████| 157/157 [00:00<00:00, 213.36it/s]


Epoch [149/200], Loss: 2857.7753


Epoch 150/200: 100%|██████████| 157/157 [00:00<00:00, 191.59it/s]


Epoch [150/200], Loss: 3078.5365


Epoch 151/200: 100%|██████████| 157/157 [00:00<00:00, 208.27it/s]


Epoch [151/200], Loss: 2730.5585


Epoch 152/200: 100%|██████████| 157/157 [00:00<00:00, 246.76it/s]


Epoch [152/200], Loss: 3815.6195


Epoch 153/200: 100%|██████████| 157/157 [00:00<00:00, 230.41it/s]


Epoch [153/200], Loss: 2503.4452


Epoch 154/200: 100%|██████████| 157/157 [00:00<00:00, 232.09it/s]


Epoch [154/200], Loss: 2927.4659


Epoch 155/200: 100%|██████████| 157/157 [00:00<00:00, 234.42it/s]


Epoch [155/200], Loss: 2735.2518


Epoch 156/200: 100%|██████████| 157/157 [00:00<00:00, 239.03it/s]


Epoch [156/200], Loss: 3720.4552


Epoch 157/200: 100%|██████████| 157/157 [00:00<00:00, 219.80it/s]


Epoch [157/200], Loss: 3436.3807


Epoch 158/200: 100%|██████████| 157/157 [00:00<00:00, 222.37it/s]


Epoch [158/200], Loss: 2549.4517


Epoch 159/200: 100%|██████████| 157/157 [00:00<00:00, 227.73it/s]


Epoch [159/200], Loss: 2735.9785


Epoch 160/200: 100%|██████████| 157/157 [00:00<00:00, 227.08it/s]


Epoch [160/200], Loss: 3032.8981


Epoch 161/200: 100%|██████████| 157/157 [00:00<00:00, 218.95it/s]


Epoch [161/200], Loss: 2635.1311


Epoch 162/200: 100%|██████████| 157/157 [00:00<00:00, 215.74it/s]


Epoch [162/200], Loss: 2179.7599


Epoch 163/200: 100%|██████████| 157/157 [00:00<00:00, 245.23it/s]


Epoch [163/200], Loss: 2885.4702


Epoch 164/200: 100%|██████████| 157/157 [00:00<00:00, 237.91it/s]


Epoch [164/200], Loss: 2490.0669


Epoch 165/200: 100%|██████████| 157/157 [00:00<00:00, 198.86it/s]


Epoch [165/200], Loss: 2228.0301


Epoch 166/200: 100%|██████████| 157/157 [00:00<00:00, 184.10it/s]


Epoch [166/200], Loss: 2095.3387


Epoch 167/200: 100%|██████████| 157/157 [00:00<00:00, 205.71it/s]


Epoch [167/200], Loss: 2150.3531


Epoch 168/200: 100%|██████████| 157/157 [00:00<00:00, 228.77it/s]


Epoch [168/200], Loss: 2158.4532


Epoch 169/200: 100%|██████████| 157/157 [00:00<00:00, 230.17it/s]


Epoch [169/200], Loss: 2941.3359


Epoch 170/200: 100%|██████████| 157/157 [00:00<00:00, 232.74it/s]


Epoch [170/200], Loss: 2177.0774


Epoch 171/200: 100%|██████████| 157/157 [00:00<00:00, 229.81it/s]


Epoch [171/200], Loss: 3295.3835


Epoch 172/200: 100%|██████████| 157/157 [00:00<00:00, 221.84it/s]


Epoch [172/200], Loss: 2587.4684


Epoch 173/200: 100%|██████████| 157/157 [00:00<00:00, 210.92it/s]


Epoch [173/200], Loss: 2782.5339


Epoch 174/200: 100%|██████████| 157/157 [00:00<00:00, 228.64it/s]


Epoch [174/200], Loss: 3548.8146


Epoch 175/200: 100%|██████████| 157/157 [00:00<00:00, 225.82it/s]


Epoch [175/200], Loss: 3767.2132


Epoch 176/200: 100%|██████████| 157/157 [00:00<00:00, 237.70it/s]


Epoch [176/200], Loss: 2003.3750


Epoch 177/200: 100%|██████████| 157/157 [00:00<00:00, 226.16it/s]


Epoch [177/200], Loss: 1334.4421


Epoch 178/200: 100%|██████████| 157/157 [00:00<00:00, 231.41it/s]


Epoch [178/200], Loss: 1893.1106


Epoch 179/200: 100%|██████████| 157/157 [00:00<00:00, 234.51it/s]


Epoch [179/200], Loss: 2467.4197


Epoch 180/200: 100%|██████████| 157/157 [00:00<00:00, 218.26it/s]


Epoch [180/200], Loss: 2336.8272


Epoch 181/200: 100%|██████████| 157/157 [00:00<00:00, 216.87it/s]


Epoch [181/200], Loss: 3018.5070


Epoch 182/200: 100%|██████████| 157/157 [00:00<00:00, 200.43it/s]


Epoch [182/200], Loss: 4794.7971


Epoch 183/200: 100%|██████████| 157/157 [00:00<00:00, 204.81it/s]


Epoch [183/200], Loss: 3072.8381


Epoch 184/200: 100%|██████████| 157/157 [00:00<00:00, 229.19it/s]


Epoch [184/200], Loss: 2204.4761


Epoch 185/200: 100%|██████████| 157/157 [00:00<00:00, 246.78it/s]


Epoch [185/200], Loss: 1884.2286


Epoch 186/200: 100%|██████████| 157/157 [00:00<00:00, 220.47it/s]


Epoch [186/200], Loss: 1807.2131


Epoch 187/200: 100%|██████████| 157/157 [00:00<00:00, 228.16it/s]


Epoch [187/200], Loss: 2901.1346


Epoch 188/200: 100%|██████████| 157/157 [00:00<00:00, 220.50it/s]


Epoch [188/200], Loss: 2306.0424


Epoch 189/200: 100%|██████████| 157/157 [00:00<00:00, 227.81it/s]


Epoch [189/200], Loss: 2197.9799


Epoch 190/200: 100%|██████████| 157/157 [00:00<00:00, 244.40it/s]


Epoch [190/200], Loss: 2292.9911


Epoch 191/200: 100%|██████████| 157/157 [00:00<00:00, 234.36it/s]


Epoch [191/200], Loss: 1959.3314


Epoch 192/200: 100%|██████████| 157/157 [00:00<00:00, 244.31it/s]


Epoch [192/200], Loss: 3288.5949


Epoch 193/200: 100%|██████████| 157/157 [00:00<00:00, 226.50it/s]


Epoch [193/200], Loss: 2525.2412


Epoch 194/200: 100%|██████████| 157/157 [00:00<00:00, 231.05it/s]


Epoch [194/200], Loss: 1689.7862


Epoch 195/200: 100%|██████████| 157/157 [00:00<00:00, 238.42it/s]


Epoch [195/200], Loss: 1950.9284


Epoch 196/200: 100%|██████████| 157/157 [00:00<00:00, 241.60it/s]


Epoch [196/200], Loss: 2390.1245


Epoch 197/200: 100%|██████████| 157/157 [00:00<00:00, 233.93it/s]


Epoch [197/200], Loss: 1794.8757


Epoch 198/200: 100%|██████████| 157/157 [00:00<00:00, 195.44it/s]


Epoch [198/200], Loss: 1613.6174


Epoch 199/200: 100%|██████████| 157/157 [00:00<00:00, 188.38it/s]


Epoch [199/200], Loss: 1695.2161


Epoch 200/200: 100%|██████████| 157/157 [00:00<00:00, 218.88it/s]


Epoch [200/200], Loss: 3158.2833
Overall Test Accuracy: 67.45%
Task: 3_VID
Accuracy: 94.12%
Average Probability: 0.9412
Predicted user: 250


In [38]:
# Example usage:
user_data = X_train[user_to_index[11]]
predicted_user = predict_on_user_data(
    user_data, model, user_to_index, index_to_user, chosen_task
)
print(f"Predicted user: {predicted_user}")

Predicted user: 2
