In [1]:
class Config:
    DATASET_PATH = "../.local/datasets/PupilCoreV1_all/"
    MODEL_CHKPT_PATH = "../.local/checkpoints/"
    
conf = Config()

In [2]:
from sklearn.preprocessing import StandardScaler
import os
import numpy as np
import pandas as pd
from torch.utils.data import Dataset

class GazeDataset(Dataset):
    CLASS_MAP = {
        'Inspection': 0,
        'Reading': 1,
        'Search': 2,
    }
    def __init__(self, 
                 data_files, 
                 selected_cols, 
                 seq_len, 
                 transform=None,
                 min_confidence=0.8,
                 interpolation_method='linear'):
        self.data_files = data_files
        self.selected_cols = selected_cols
        self.seq_len = seq_len
        self.transform = transform
        self.interpolation_method = interpolation_method
        self.min_confidence = min_confidence
        
        self.data = []
        self.labels = []
        self.participant_ids = []
        
        for file in self.data_files:
            participant_id = file.split('/')[-1].split('_')[0]
            label = file.split('/')[-1].split('_')[1].split('.')[0]
            
            sample = self.load_single_file(file)
            # sample = self.set_low_confidence_data_to_nan(sample)
            # sample = self.remove_low_confidence_data(sample)
            # sample = self.clip_data(sample)
            sample = self.interpolate_data(sample)
            sample = self.drop_nan_rows(sample)
            sequences = self.split_data_into_sequences(sample)
            
            sequences = np.array(sequences).astype(np.float32)
            
            self.data.extend(sequences)  # Append sequences to the dataset
            self.labels.extend([int(self.CLASS_MAP[label])] * len(sequences))
            self.participant_ids.extend([int(participant_id)] * len(sequences))
            
        self.data = np.array(self.data)
        self.labels = np.array(self.labels)
        self.participant_ids = np.array(self.participant_ids)
        
    def load_single_file(self, file: str):
        raw_data = pd.read_csv(file)
        raw_data = raw_data[self.selected_cols]
        return raw_data
    
    def remove_low_confidence_data(self, df: pd.DataFrame):
        low_confidence_data = df[df['confidence'] <= self.min_confidence]
        print(f"Removed {len(low_confidence_data)} low confidence data points")
        return df[df['confidence'] > self.min_confidence]
    
    def set_low_confidence_data_to_nan(self, df: pd.DataFrame):
        df.loc[df['confidence'] <= self.min_confidence, ['norm_pos_x', 'norm_pos_y']] = np.nan
        return df
    
    def clip_data(self, df: pd.DataFrame):
        start = int(len(df) * 0.02)
        end = int(len(df) * 0.98)
        return df[start:end]
    
    def drop_nan_rows(self, df: pd.DataFrame):
        return df.dropna()
    
    def interpolate_data(self, df: pd.DataFrame):
        # Replave inf values with NaN
        df = df.replace([np.inf, -np.inf], np.nan)
        # Interpolate the NaN values
        df = df.interpolate(method=self.interpolation_method)
        return df
    
    def split_data_into_sequences(self, arr: np.ndarray):
        # split the data into sequences of length seq_len.
        sequences = [arr[i:i + self.seq_len] for i in range(0, len(arr), self.seq_len)]
        # remove the last sequence if it is not of length seq_len
        if len(sequences[-1]) != self.seq_len:
            sequences.pop()
        return sequences

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        
        if self.transform:
            sample = self.transform(sample)
        
        return sample, label


In [3]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import torch.optim as optim
from torchmetrics import Accuracy

# Define the PyTorch Lightning Module
class GazeRNN(pl.LightningModule):
    def __init__(self, 
                 input_size, 
                 hidden_size, 
                 num_layers, 
                 num_classes, 
                 rnn_type='LSTM',
                 dropout=0,
                 learning_rate=0.001,
                 opt_step_size=5,
                 opt_gamma=0.1,
                 opt_wd=1e-5):
        super(GazeRNN, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.learning_rate = learning_rate
        self.opt_step_size = opt_step_size
        self.opt_gamma = opt_gamma
        self.opt_wd = opt_wd

        # Choose between LSTM, GRU, or vanilla RNN
        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=True)
        else:
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=True)

        # Fully connected layer for classification
        self.fc = nn.Linear(hidden_size, num_classes)

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Accuracy metric
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        if isinstance(self.rnn, nn.LSTM):
            c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
            out, _ = self.rnn(x, (h0, c0))
        else:
            out, _ = self.rnn(x, h0)

        # Take the last time-step's output
        out = out[:, -1, :]
        out = self.fc(out)
        return out

    def training_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        acc = self.train_acc(outputs, labels)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        inputs, labels = batch
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        acc = self.val_acc(outputs, labels)
        self.log('val_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('val_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('lr', self.trainer.optimizers[0].param_groups[0]['lr'], on_epoch=True, prog_bar=True)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate, weight_decay=self.opt_wd)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=self.opt_step_size, gamma=self.opt_gamma)
        # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
        return [optimizer], [scheduler]


In [4]:
import os
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

class GazeDataModule(pl.LightningDataModule):
    def __init__(self, ds_train, ds_val, batch_size=64, num_workers=4):
        super(GazeDataModule, self).__init__()
        self.ds_train = ds_train
        self.ds_val = ds_val
        self.batch_size = batch_size
        self.num_workers = num_workers

    def train_dataloader(self):
        return DataLoader(self.ds_train, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.ds_val, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)


In [5]:
df = pd.read_csv(conf.DATASET_PATH + '01_Inspection.csv')

print(list(df.columns))

times = df['pupil_timestamp'].values

# convert the timestamps to seconds
print(times[:5])
window = 120
window_duration = times[window] - times[0]
print(f"Window duration: {window_duration:.2f} seconds")

['pupil_timestamp', 'norm_pos_x_2d_0', 'norm_pos_y_2d_0', 'diameter_2d_0', 'ellipse_center_x_2d_0', 'ellipse_center_y_2d_0', 'ellipse_axis_a_2d_0', 'ellipse_axis_b_2d_0', 'ellipse_angle_2d_0', 'norm_pos_x_2d_1', 'norm_pos_y_2d_1', 'diameter_2d_1', 'ellipse_center_x_2d_1', 'ellipse_center_y_2d_1', 'ellipse_axis_a_2d_1', 'ellipse_axis_b_2d_1', 'ellipse_angle_2d_1', 'norm_pos_x_3d_0', 'norm_pos_y_3d_0', 'diameter_3d_0', 'ellipse_center_x_3d_0', 'ellipse_center_y_3d_0', 'ellipse_axis_a_3d_0', 'ellipse_axis_b_3d_0', 'ellipse_angle_3d_0', 'diameter_3d_3d_0', 'model_confidence_3d_0', 'sphere_center_x_3d_0', 'sphere_center_y_3d_0', 'sphere_center_z_3d_0', 'sphere_radius_3d_0', 'circle_3d_center_x_3d_0', 'circle_3d_center_y_3d_0', 'circle_3d_center_z_3d_0', 'circle_3d_normal_x_3d_0', 'circle_3d_normal_y_3d_0', 'circle_3d_normal_z_3d_0', 'circle_3d_radius_3d_0', 'theta_3d_0', 'phi_3d_0', 'projected_sphere_center_x_3d_0', 'projected_sphere_center_y_3d_0', 'projected_sphere_axis_a_3d_0', 'projecte

In [35]:
import pickle

from sklearn.preprocessing import RobustScaler

col_of_interest = [
    # 'pupil_timestamp', 
    'norm_pos_x_2d_0', 'norm_pos_y_2d_0', 
    'diameter_2d_0', 
    'ellipse_center_x_2d_0', 'ellipse_center_y_2d_0', 'ellipse_axis_a_2d_0', 'ellipse_axis_b_2d_0', 'ellipse_angle_2d_0', 
    'norm_pos_x_2d_1', 'norm_pos_y_2d_1', 
    'diameter_2d_1', 
    'ellipse_center_x_2d_1', 'ellipse_center_y_2d_1', 'ellipse_axis_a_2d_1', 'ellipse_axis_b_2d_1', 'ellipse_angle_2d_1', 
    'norm_pos_x_3d_0', 'norm_pos_y_3d_0', 
    'diameter_3d_0', 
    'ellipse_center_x_3d_0', 'ellipse_center_y_3d_0', 'ellipse_axis_a_3d_0', 'ellipse_axis_b_3d_0', 'ellipse_angle_3d_0', 
    # 'diameter_3d_3d_0', 
    # 'model_confidence_3d_0', 
    'sphere_center_x_3d_0', 'sphere_center_y_3d_0', 'sphere_center_z_3d_0', 'sphere_radius_3d_0', 
    'circle_3d_center_x_3d_0', 'circle_3d_center_y_3d_0', 'circle_3d_center_z_3d_0', 
    'circle_3d_normal_x_3d_0', 'circle_3d_normal_y_3d_0', 'circle_3d_normal_z_3d_0', 
    'circle_3d_radius_3d_0', 
    'theta_3d_0', 'phi_3d_0', 
    # 'projected_sphere_center_x_3d_0', 'projected_sphere_center_y_3d_0', 'projected_sphere_axis_a_3d_0', 'projected_sphere_axis_b_3d_0', 'projected_sphere_angle_3d_0', 
    'norm_pos_x_3d_1', 'norm_pos_y_3d_1', 
    'diameter_3d_1', 
    'ellipse_center_x_3d_1', 'ellipse_center_y_3d_1', 'ellipse_axis_a_3d_1', 'ellipse_axis_b_3d_1', 'ellipse_angle_3d_1', 
    # 'diameter_3d_3d_1', 
    # 'model_confidence_3d_1', 
    'sphere_center_x_3d_1', 'sphere_center_y_3d_1', 'sphere_center_z_3d_1', 'sphere_radius_3d_1', 
    'circle_3d_center_x_3d_1', 'circle_3d_center_y_3d_1', 'circle_3d_center_z_3d_1', 
    'circle_3d_normal_x_3d_1', 'circle_3d_normal_y_3d_1', 'circle_3d_normal_z_3d_1', 
    'circle_3d_radius_3d_1', 
    'theta_3d_1', 'phi_3d_1', 
    # 'projected_sphere_center_x_3d_1', 'projected_sphere_center_y_3d_1', 'projected_sphere_axis_a_3d_1', 'projected_sphere_axis_b_3d_1', 'projected_sphere_angle_3d_1'
]


other_cols = ['pupil_timestamp']


dir_files = os.listdir(conf.DATASET_PATH)
test_participants = ['10', '02']
train_files = [os.path.join(conf.DATASET_PATH, f) for f in dir_files if f.split('_')[0] not in test_participants and f.endswith('.csv')]
test_files = [os.path.join(conf.DATASET_PATH, f) for f in dir_files if f.split('_')[0] in test_participants and f.endswith('.csv')]

train_ds = GazeDataset(train_files, col_of_interest, 300, min_confidence=0.5)
test_ds = GazeDataset(test_files, col_of_interest, 300, min_confidence=0.5)

scaler = RobustScaler()
all_data = np.concatenate([train_ds.data, test_ds.data])
all_data = all_data.reshape(-1, all_data.shape[-1])
scaler.fit(all_data.reshape(-1, all_data.shape[-1]))

with open('../.local/scaler.pickle', 'wb') as f:
    pickle.dump(scaler, f)

train_ds.data = scaler.transform(train_ds.data.reshape(-1, train_ds.data.shape[-1])).reshape(train_ds.data.shape)
test_ds.data = scaler.transform(test_ds.data.reshape(-1, test_ds.data.shape[-1])).reshape(test_ds.data.shape)

print(f"Training dataset shape: {train_ds.data.shape}")
print(f"Test dataset shape: {test_ds.data.shape}")
print(f"Shape of each sample: {train_ds[0][0].shape}")
print(f"Unique classes: {list(train_ds.CLASS_MAP.keys())}")

# print the number of nan values in the dataset
print(f"Number of NaN values in the training dataset: {np.isnan(train_ds.data).sum()}")

Training dataset shape: (702, 300, 58)
Test dataset shape: (165, 300, 58)
Shape of each sample: (300, 58)
Unique classes: ['Inspection', 'Reading', 'Search']
Number of NaN values in the training dataset: 0


In [36]:
from pytorch_lightning.loggers import TensorBoardLogger
from datetime import datetime

# Hyperparameters
INPUT_SIZE = len(col_of_interest)
HIDDEN_SIZE = 32
NUM_LAYERS = 1
NUM_CLASSES = 3
LEARNING_RATE = 0.0003
MODEL_TYPE = 'LSTM'
DROPOUT = 0.3

BATCH_SIZE = 128
NUM_WORKERS = 4

NUM_EPOCHS = 50
GRAD_CLIP = 0.8
OPT_STEP_SIZE = 10
OPT_GAMMA = 0.9
OPT_WD = 0.000001


# Initialize the DataModule
data_module = GazeDataModule(
    train_ds, test_ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS
)

# Initialize the model
model = GazeRNN(
    input_size=len(col_of_interest),
    hidden_size=HIDDEN_SIZE, 
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES, 
    learning_rate=LEARNING_RATE,
    rnn_type=MODEL_TYPE,
    opt_step_size=OPT_STEP_SIZE,
    opt_gamma=OPT_GAMMA,
    dropout=DROPOUT,
    opt_wd=OPT_WD
)

# Checkpoint callback
current_time = datetime.now().strftime('%Y-%m-%d_%H-%M')
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_acc',
    dirpath=conf.MODEL_CHKPT_PATH + f'{MODEL_TYPE}/{current_time}',
    filename='gaze_rnn-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='max',
)

# Initialize PyTorch Lightning trainer
trainer = pl.Trainer(
    max_epochs=NUM_EPOCHS, 
    accelerator='gpu', 
    enable_progress_bar=True,
    logger=False,
    enable_checkpointing=True,
    callbacks=[checkpoint_callback],
    #gradient_clip_val=GRAD_CLIP
)

# Train the model
trainer.fit(model, data_module)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | rnn       | LSTM               | 31.7 K | train
1 | fc        | Linear             | 195    | train
2 | criterion | CrossEntropyLoss   | 0      | train
3 | train_acc | MulticlassAccuracy | 0      | train
4 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
31.9 K    Trainable params
0         Non-trainable params
31.9 K    Total params
0.128     Total estimated model params size (MB)
5         Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...

KeyboardInterrupt



In [42]:
# Load the best model from the checkpoint
INPUT_SIZE = len(col_of_interest)
HIDDEN_SIZE = 128
NUM_LAYERS = 3
NUM_CLASSES = 3
LEARNING_RATE = 0.0003
MODEL_TYPE = 'LSTM'
DROPOUT = 0.3

best_model = GazeRNN.load_from_checkpoint(
    '../.local/checkpoints/LSTM/2024-10-17_15-12/gaze_rnn-epoch=25-val_loss=0.47.ckpt',#checkpoint_callback.best_model_path,
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=NUM_LAYERS,
    num_classes=NUM_CLASSES,
    rnn_type=MODEL_TYPE,
    dropout=DROPOUT,
    learning_rate=LEARNING_RATE,
    opt_step_size=OPT_STEP_SIZE,
    opt_gamma=OPT_GAMMA
)

# Predict on the validation set and set device to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
best_model.to(device)
best_model.eval()

# Initialize the DataLoader for the validation set
val_loader = data_module.val_dataloader()

# Predict on the validation set
y_true = []
y_pred = []
for inputs, labels in val_loader:
    inputs = inputs.to(device)
    labels = labels.to(device)
    outputs = best_model(inputs)
    _, predicted = torch.max(outputs, 1)
    y_true.extend(labels.cpu().numpy())
    y_pred.extend(predicted.cpu().numpy())
    
# Calculate the accuracy
accuracy = np.mean(np.array(y_true) == np.array(y_pred))
print(f"Validation Accuracy: {accuracy}")

RuntimeError: Error(s) in loading state_dict for GazeRNN:
	Missing key(s) in state_dict: "rnn.weight_ih_l2", "rnn.weight_hh_l2", "rnn.bias_ih_l2", "rnn.bias_hh_l2". 
	size mismatch for rnn.weight_ih_l0: copying a param with shape torch.Size([128, 37]) from checkpoint, the shape in current model is torch.Size([512, 58]).
	size mismatch for rnn.weight_hh_l0: copying a param with shape torch.Size([128, 32]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for rnn.bias_ih_l0: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for rnn.bias_hh_l0: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for rnn.weight_ih_l1: copying a param with shape torch.Size([128, 32]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for rnn.weight_hh_l1: copying a param with shape torch.Size([128, 32]) from checkpoint, the shape in current model is torch.Size([512, 128]).
	size mismatch for rnn.bias_ih_l1: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for rnn.bias_hh_l1: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([512]).
	size mismatch for fc.weight: copying a param with shape torch.Size([3, 32]) from checkpoint, the shape in current model is torch.Size([3, 128]).

In [None]:
# Plot the confusion matrix
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score
import seaborn as sns

cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=GazeDataset.CLASS_MAP.keys(), yticklabels=GazeDataset.CLASS_MAP.keys(), cmap='Blues')
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

In [None]:
import msgpack
import cv2
import numpy as np
import zmq
from PIL import Image
import socket
import toml


app_config = toml.load("../Dashboard_App/config.toml")

def decode_dict(d):
    result = {}
    for key, value in d.items():
        if isinstance(key, bytes):
            key = key.decode()
        if isinstance(value, bytes):
            value = value.decode()
        elif isinstance(value, dict):
            value = decode_dict(value)
        result.update({key: value})
    return result


class PupilLabsController:
    TOPIC_GAZE = app_config["pupil_topics"]["gaze"]
    TOPIC_FRAME_WORLD = app_config["pupil_topics"]["front_camera"]
    SERVICE_HOST = app_config["pupil_service"]["host"]
    SERVICE_PORT = app_config["pupil_service"]["port"]

    def __init__(self):
        if self.is_service_online():
            self.ctx = zmq.Context()
            self.sub_port, self.pub_port = self.__get_sub_pub_ports()

            self.gaze_socket = self.__create_pupil_sub_socket(self.TOPIC_GAZE)
            self.frame_world_socket = self.__create_pupil_sub_socket(self.TOPIC_FRAME_WORLD)
        else:
            raise ConnectionError("Pupil service is offline or not reachable.")

    def reconnect_sockets(self):
        self.gaze_socket.close()
        self.frame_world_socket.close()
        self.gaze_socket = self.__create_pupil_sub_socket(self.TOPIC_GAZE)
        self.frame_world_socket = self.__create_pupil_sub_socket(self.TOPIC_FRAME_WORLD)

    def receive_gaze_data(self, num_gazes=1):
        gazes = []
        for _ in range(num_gazes):
            topic, payload = self.gaze_socket.recv_multipart()
            message = msgpack.loads(payload)

            # Decode the message
            gaze = decode_dict(message)
            gaze["base_data"] = [decode_dict(data) for data in gaze["base_data"]]

            gazes.append(gaze)
        return gazes

    def receive_cam_frames(self, num_frames=1):
        frames = []
        for _ in range(num_frames):
            _, _, payload = self.frame_world_socket.recv_multipart()

            # Decode the image
            frame = cv2.imdecode(np.frombuffer(payload, dtype=np.uint8), cv2.IMREAD_COLOR)

            # Convert the image to RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

            frame = Image.fromarray(frame)
            frames.append(frame)
        return frames

    def close_connection(self):
        self.ctx.term()

    def is_service_online(self):
        """Check if the service is online by attempting to connect to it."""
        try:
            # Create a socket object
            with socket.create_connection((self.SERVICE_HOST, self.SERVICE_PORT), timeout=5):
                return True  # Service is online
        except (socket.timeout, ConnectionRefusedError):
            return False  # Service is offline or not reachable

    def __create_pupil_sub_socket(self, topic=None):
        pupil_socket = self.ctx.socket(zmq.SUB)
        pupil_socket.connect(f'tcp://{self.SERVICE_HOST}:{self.sub_port}')
        pupil_socket.subscribe(topic)
        return pupil_socket

    def __create_pupil_req_socket(self):
        pupil_socket = self.ctx.socket(zmq.REQ)
        pupil_socket.connect(f'tcp://{self.SERVICE_HOST}:{self.SERVICE_PORT}')
        return pupil_socket

    def __get_sub_pub_ports(self):
        pupil_socket = self.__create_pupil_req_socket()

        pupil_socket.send_string('SUB_PORT')
        sub_port = pupil_socket.recv_string()

        pupil_socket.send_string('PUB_PORT')
        pub_port = pupil_socket.recv_string()

        pupil_socket.close()

        return sub_port, pub_port


In [None]:
import pandas as pd
import numpy as np

def dict_to_dataframe(data_dict, col_of_interest):
    
    # Initialize a dictionary for the DataFrame
    df_data = {col: [np.nan] for col in col_of_interest}
    
    # Populate the dictionary with data from data_dict
    for eye_data in data_dict.get('base_data', []):
        eye_id = eye_data['id']
        
        # 2D data
        df_data[f'norm_pos_x_2d_{eye_id}'][0] = eye_data['norm_pos'][0]
        df_data[f'norm_pos_y_2d_{eye_id}'][0] = eye_data['norm_pos'][1]
        df_data[f'diameter_2d_{eye_id}'][0] = eye_data['diameter']
        df_data[f'ellipse_center_x_2d_{eye_id}'][0] = eye_data['ellipse']['center'][0]
        df_data[f'ellipse_center_y_2d_{eye_id}'][0] = eye_data['ellipse']['center'][1]
        df_data[f'ellipse_axis_a_2d_{eye_id}'][0] = eye_data['ellipse']['axes'][0]
        df_data[f'ellipse_axis_b_2d_{eye_id}'][0] = eye_data['ellipse']['axes'][1]
        df_data[f'ellipse_angle_2d_{eye_id}'][0] = eye_data['ellipse']['angle']
        
        # 3D data
        df_data[f'norm_pos_x_3d_{eye_id}'][0] = eye_data['norm_pos'][0]
        df_data[f'norm_pos_y_3d_{eye_id}'][0] = eye_data['norm_pos'][1]
        df_data[f'diameter_3d_{eye_id}'][0] = eye_data['diameter_3d']
        df_data[f'ellipse_center_x_3d_{eye_id}'][0] = eye_data['ellipse']['center'][0]
        df_data[f'ellipse_center_y_3d_{eye_id}'][0] = eye_data['ellipse']['center'][1]
        df_data[f'ellipse_axis_a_3d_{eye_id}'][0] = eye_data['ellipse']['axes'][0]
        df_data[f'ellipse_axis_b_3d_{eye_id}'][0] = eye_data['ellipse']['axes'][1]
        df_data[f'ellipse_angle_3d_{eye_id}'][0] = eye_data['ellipse']['angle']
        
        df_data[f'sphere_center_x_3d_{eye_id}'][0] = eye_data['sphere']['center'][0]
        df_data[f'sphere_center_y_3d_{eye_id}'][0] = eye_data['sphere']['center'][1]
        df_data[f'sphere_center_z_3d_{eye_id}'][0] = eye_data['sphere']['center'][2]
        df_data[f'sphere_radius_3d_{eye_id}'][0] = eye_data['sphere']['radius']
        
        df_data[f'circle_3d_center_x_3d_{eye_id}'][0] = eye_data['circle_3d']['center'][0]
        df_data[f'circle_3d_center_y_3d_{eye_id}'][0] = eye_data['circle_3d']['center'][1]
        df_data[f'circle_3d_center_z_3d_{eye_id}'][0] = eye_data['circle_3d']['center'][2]
        
        df_data[f'circle_3d_normal_x_3d_{eye_id}'][0] = eye_data['circle_3d']['normal'][0]
        df_data[f'circle_3d_normal_y_3d_{eye_id}'][0] = eye_data['circle_3d']['normal'][1]
        df_data[f'circle_3d_normal_z_3d_{eye_id}'][0] = eye_data['circle_3d']['normal'][2]
        df_data[f'circle_3d_radius_3d_{eye_id}'][0] = eye_data['circle_3d']['radius']
        
        df_data[f'theta_3d_{eye_id}'][0] = eye_data['theta']
        df_data[f'phi_3d_{eye_id}'][0] = eye_data['phi']
    
    # Create a DataFrame from the dictionary
    df = pd.DataFrame(df_data)
    return df

In [None]:
import pprint

col_of_interest = [
    'norm_pos_x_2d_0', 'norm_pos_y_2d_0', 'diameter_2d_0', 
    'ellipse_center_x_2d_0', 'ellipse_center_y_2d_0', 'ellipse_axis_a_2d_0', 'ellipse_axis_b_2d_0', 'ellipse_angle_2d_0', 
    'norm_pos_x_2d_1', 'norm_pos_y_2d_1', 'diameter_2d_1', 
    'ellipse_center_x_2d_1', 'ellipse_center_y_2d_1', 'ellipse_axis_a_2d_1', 'ellipse_axis_b_2d_1', 'ellipse_angle_2d_1', 
    'norm_pos_x_3d_0', 'norm_pos_y_3d_0', 'diameter_3d_0', 
    'ellipse_center_x_3d_0', 'ellipse_center_y_3d_0', 'ellipse_axis_a_3d_0', 'ellipse_axis_b_3d_0', 'ellipse_angle_3d_0', 
    'sphere_center_x_3d_0', 'sphere_center_y_3d_0', 'sphere_center_z_3d_0', 'sphere_radius_3d_0', 
    'circle_3d_center_x_3d_0', 'circle_3d_center_y_3d_0', 'circle_3d_center_z_3d_0', 
    'circle_3d_normal_x_3d_0', 'circle_3d_normal_y_3d_0', 'circle_3d_normal_z_3d_0', 
    'circle_3d_radius_3d_0', 'theta_3d_0', 'phi_3d_0', 
    'norm_pos_x_3d_1', 'norm_pos_y_3d_1', 'diameter_3d_1', 
    'ellipse_center_x_3d_1', 'ellipse_center_y_3d_1', 'ellipse_axis_a_3d_1', 'ellipse_axis_b_3d_1', 'ellipse_angle_3d_1', 
    'sphere_center_x_3d_1', 'sphere_center_y_3d_1', 'sphere_center_z_3d_1', 'sphere_radius_3d_1', 
    'circle_3d_center_x_3d_1', 'circle_3d_center_y_3d_1', 'circle_3d_center_z_3d_1', 
    'circle_3d_normal_x_3d_1', 'circle_3d_normal_y_3d_1', 'circle_3d_normal_z_3d_1', 
    'circle_3d_radius_3d_1', 'theta_3d_1', 'phi_3d_1',
]

pupil_cap = PupilLabsController()

gaze_data = pupil_cap.receive_gaze_data(300)
pprint.pprint(gaze_data[0])

# convert the gaze data to a DataFrame
gaze_df = dict_to_dataframe(gaze_data[0], col_of_interest)

for i in range(1, len(gaze_data)):
    gaze_df = pd.concat([gaze_df, dict_to_dataframe(gaze_data[i], col_of_interest)], ignore_index=True)

print(f"Shape of the gaze DataFrame: {gaze_df.shape}")

# Print the number of NaN values in the DataFrame for each column
for col in gaze_df.columns:
    print(f"{gaze_df[col].isna().sum()} : {col}")

# Interpolate the data
gaze_df = gaze_df.replace([np.inf, -np.inf], np.nan)
gaze_df = gaze_df.interpolate(method='linear')
gaze_df = gaze_df.dropna()

# Scale the data
print(gaze_df.values.shape)
gaze_df = scaler.transform(gaze_df.values)

# Convert the DataFrame to a PyTorch tensor
gaze_tensor = torch.tensor(gaze_df, dtype=torch.float32).unsqueeze(0)
gaze_tensor = gaze_tensor.to(device)

# Predict the class
best_model.eval()
with torch.no_grad():
    outputs = best_model(gaze_tensor)
    _, predicted = torch.max(outputs, 1)
    
predicted_class = list(GazeDataset.CLASS_MAP.keys())[predicted.item()]
print(f"Predicted class: {predicted}")
print(f"Predicted class: {predicted_class}")

In [None]:
pupil_cap = PupilLabsController()

for i in range(10):
    pupil_cap.reconnect_sockets()
    gaze_data = pupil_cap.receive_gaze_data(300)
    
    # convert the gaze data to a DataFrame
    gaze_df = dict_to_dataframe(gaze_data[0], col_of_interest)
    
    for i in range(1, len(gaze_data)):
        gaze_df = pd.concat([gaze_df, dict_to_dataframe(gaze_data[i], col_of_interest)], ignore_index=True)

    # Interpolate the data
    gaze_df = gaze_df.replace([np.inf, -np.inf], np.nan)
    gaze_df = gaze_df.interpolate(method='linear')
    gaze_df = gaze_df.dropna()
    
    # Scale the data
    gaze_df = scaler.transform(gaze_df.values)
    
    # Convert the DataFrame to a PyTorch tensor
    gaze_tensor = torch.tensor(gaze_df, dtype=torch.float32).unsqueeze(0)
    gaze_tensor = gaze_tensor.to(device)
    
    # Predict the class
    best_model.eval()
    with torch.no_grad():
        outputs = best_model(gaze_tensor)
        _, predicted = torch.max(outputs, 1)
        
    predicted_class = list(GazeDataset.CLASS_MAP.keys())[predicted.item()]
    print(f"Predicted class: {predicted}")
    print(f"Predicted class: {predicted_class}")
    print()

In [34]:
from pytorch_lightning.callbacks import Callback

class CustomEarlyStopping(Callback):
    def __init__(self, patience, loss_tr):
        super().__init__()
        self.patience = patience
        self.counter = 0
        self.loss_tr = loss_tr

    def on_validation_epoch_end(self, trainer, pl_module):
        # Get the latest validation and training accuracy
        val_loss = trainer.callback_metrics.get('val_loss_epoch')
        train_loss = trainer.callback_metrics.get('train_loss_epoch')

        # Check if val_acc is smaller than train_acc_epoch
        if val_loss is not None and train_loss is not None:
            if val_loss > train_loss + self.loss_tr:
                self.counter += 1
                print(f"Validation accuracy ({val_loss:.4f}) is less than training accuracy ({train_loss:.4f}). Counter: {self.counter}/{self.patience}")
            else:
                self.counter = 0  # Reset counter if condition is not met

            # Stop training if the condition is met for `patience` consecutive epochs
            if self.counter >= self.patience:
                print("Early stopping: Validation accuracy did not improve compared to training accuracy for 3 consecutive epochs.")
                trainer.should_stop = True

In [None]:
from pytorch_lightning.callbacks import EarlyStopping
import wandb
from pytorch_lightning.loggers import WandbLogger
import pytorch_lightning as pl
from datetime import datetime

WANDB_PROJECT = 'gaze_lstm_hp_tuning'

# Initialize WandB
wandb.init(project=WANDB_PROJECT)

# Sweep configuration for hyperparameter tuning
sweep_config = {
    'method': 'random',
    'metric': {
        'name': 'val_acc',
        'goal': 'maximize'
    },
    'parameters': {
        'learning_rate': {
            'values': list(np.linspace(0.00001, 0.0003, 10))
        },
        'batch_size': {
            'values': [128]
        },
        'hidden_size': {
            'values': [8, 16, 32]
        },
        'num_layers': {
            'values': [1, 2, 3]
        },
        'dropout': {
            'values': list(np.linspace(0.2, 0.5, 10))
        },
        'opt_step_size': {
            'values': [20]
        },
        'opt_gamma': {
            'values': [0.8]
        },
        'opt_wd': {
            'values': [1e-3, 1e-4, 1e-5, 1e-6, 1e-7]
        },
        'grad_clip': {
            'values': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
        }
    }
}

HP_NUM_CLASSES = 3
HP_MODEL_TYPE = 'LSTM'
HP_NUM_EPOCHS = 100
HP_NUM_WORKERS = 6
HP_EARLY_STOPPING = 5
HP_LOSS_TR = 0.2

# Sweep ID
sweep_id = wandb.sweep(sweep_config, project=WANDB_PROJECT)

# Define training function
def train(config=None):
    with wandb.init(config=config):
        config = wandb.config

        # Hyperparameters from WandB config
        learning_rate = config.learning_rate
        batch_size = config.batch_size
        hidden_size = config.hidden_size
        num_layers = config.num_layers
        dropout = config.dropout
        opt_step_size = config.opt_step_size
        opt_gamma = config.opt_gamma
        opt_wd = config.opt_wd
        grad_clip = config.grad_clip

        # Initialize DataModule with current hyperparameters
        data_module = GazeDataModule(
            train_ds, test_ds,
            batch_size=batch_size,
            num_workers=HP_NUM_WORKERS
        )

        # Initialize the model with current hyperparameters
        model = GazeRNN(
            input_size=len(col_of_interest),
            hidden_size=hidden_size, 
            num_layers=num_layers,
            num_classes=HP_NUM_CLASSES, 
            learning_rate=learning_rate,
            rnn_type=HP_MODEL_TYPE,
            opt_step_size=opt_step_size,
            opt_gamma=opt_gamma,
            dropout=dropout,
            opt_wd=opt_wd
        )

        # Checkpoint callback
        current_time = datetime.now().strftime('%Y-%m-%d_%H-%M')

        # Wandb Logger
        wandb_logger = WandbLogger(project=WANDB_PROJECT)
        
        # Early stopping callback
        early_stopping_1 = CustomEarlyStopping(patience=HP_EARLY_STOPPING, loss_tr=HP_LOSS_TR)
        # early_stopping_2 = EarlyStopping(monitor='val_acc', patience=HP_EARLY_STOPPING, mode='max')
        

        # Trainer with Wandb logger and checkpointing
        trainer = pl.Trainer(
            max_epochs=HP_NUM_EPOCHS, 
            accelerator='gpu', 
            enable_progress_bar=True,
            logger=wandb_logger,
            enable_checkpointing=False,
            callbacks=[early_stopping_1],
            gradient_clip_val=grad_clip,
            log_every_n_steps=10,
        )

        # Train the model
        trainer.fit(model, data_module)

# Run the sweep
wandb.agent(sweep_id, train, count=50)