In [None]:
import torch
from torch.utils.data import DataLoader,Dataset
import lightning as L
import pickle
import numpy as np
from tqdm import tqdm


In [None]:


class DGazeDataset(Dataset):
    """Frame-level gaze dataset that reads samples lazily from ``driver_data``.

    Only a flat index of ``(driver, seq_key, frame_idx)`` triples is built up
    front; the frames themselves are read in ``__getitem__``.

    Normalization:
    - Eye images are standardized per channel (last axis of the image).
    - Facial features (``headpose_pupil[frame, 1:]`` concatenated with
      ``face_location[frame]``) are standardized per dimension.
    - Gaze points are clipped to the screen and scaled to [0, 1).

    Args:
        driver_data: Mapping ``driver -> {"seq<k>": sequence_dict}`` where each
            sequence dict holds the arrays ``"left_eye"``, ``"headpose_pupil"``,
            ``"face_location"`` and ``"gaze_point"`` indexed by frame.
        drivers: Iterable of driver keys to include.
        sequences: Iterable of sequence ids; ``f"seq{id}"`` is looked up for
            every driver and silently skipped when absent.
        transform: Truthy to enable the (currently empty) transform hook.
        stats: Optional mapping with keys ``"eye_mean"``, ``"eye_std"``,
            ``"feat_mean"``, ``"feat_std"`` (e.g. ``train_dataset.stats``).
            When given, the statistics scan is skipped and these values are
            used, so val/test data can share the training normalization.
        screen_size: ``(width, height)`` in pixels used to normalize gaze
            coordinates.

    Raises:
        ValueError: If no frames are selected and ``stats`` is not supplied,
            because normalization statistics cannot be computed.
    """

    def __init__(self, driver_data, drivers, sequences, transform=False,
                 stats=None, screen_size=(1920, 1080)):
        self.driver_data = driver_data
        self.drivers = drivers
        self.sequences = sequences
        self.transform = transform
        self.screen_size = screen_size

        # Flat list of (driver, seq_key, frame_idx); the only eager structure.
        self.index = []
        for driver in drivers:
            data = driver_data[driver]
            for seq in tqdm(sequences, desc=f"Indexing driver {driver}"):
                seq_key = f"seq{seq}"
                if seq_key in data:
                    num_frames = len(data[seq_key]["left_eye"])
                    self.index.extend(
                        (driver, seq_key, frame_idx)
                        for frame_idx in range(num_frames)
                    )

        print(f"Lazy dataset ready! Total samples: {len(self.index)}")

        if stats is not None:
            self._set_stats(stats)
        else:
            print("Computing global stats for normalization...")
            self._compute_stats(tqdm(self.index, desc="Scanning data"))

        print("Normalization stats ready!")

    @property
    def stats(self):
        """Return the normalization statistics as a dict reusable via ``stats=``."""
        return {
            "eye_mean": self.eye_mean,
            "eye_std": self.eye_std,
            "feat_mean": self.feat_mean,
            "feat_std": self.feat_std,
        }

    def _set_stats(self, stats):
        """Adopt externally computed statistics, stored as float32 arrays."""
        self.eye_mean = np.asarray(stats["eye_mean"], dtype=np.float32)
        self.eye_std = np.asarray(stats["eye_std"], dtype=np.float32)
        self.feat_mean = np.asarray(stats["feat_mean"], dtype=np.float32)
        self.feat_std = np.asarray(stats["feat_std"], dtype=np.float32)

    @staticmethod
    def _facial_vector(data_seq, frame_idx):
        """Return head-pose columns 1.. joined with the face location (float32)."""
        headpose = data_seq["headpose_pupil"][frame_idx, 1:].astype(np.float32)
        face_loc = data_seq["face_location"][frame_idx].astype(np.float32)
        return np.concatenate((headpose, face_loc))

    def _compute_stats(self, samples):
        """Compute mean/std over ``samples`` in a single streaming pass.

        Running sums and sums of squares are kept instead of stacking every
        pixel of every frame, so peak memory stays at one frame. Accumulation
        is done in float64; results are stored as float32 so normalized
        samples remain float32.

        Args:
            samples: Iterable of ``(driver, seq_key, frame_idx)`` triples.

        Raises:
            ValueError: If ``samples`` yields nothing.
        """
        pixel_sum = np.zeros(3, dtype=np.float64)
        pixel_sq_sum = np.zeros(3, dtype=np.float64)
        pixel_count = 0

        feat_sum = None
        feat_sq_sum = None
        frame_count = 0

        for driver, seq_key, frame_idx in samples:
            data_seq = self.driver_data[driver][seq_key]

            # Per-channel pixel statistics: flatten the frame to (N, 3).
            pixels = np.asarray(
                data_seq["left_eye"][frame_idx], dtype=np.float64
            ).reshape(-1, 3)
            pixel_sum += pixels.sum(axis=0)
            pixel_sq_sum += np.square(pixels).sum(axis=0)
            pixel_count += pixels.shape[0]

            features = self._facial_vector(data_seq, frame_idx).astype(np.float64)
            if feat_sum is None:
                feat_sum = np.zeros_like(features)
                feat_sq_sum = np.zeros_like(features)
            feat_sum += features
            feat_sq_sum += np.square(features)
            frame_count += 1

        if frame_count == 0:
            raise ValueError(
                "Cannot compute normalization stats: no frames matched the "
                "requested drivers/sequences (pass `stats=` to reuse existing ones)."
            )

        eye_mean = pixel_sum / pixel_count
        feat_mean = feat_sum / frame_count
        # Population variance E[x^2] - E[x]^2 (matches np.std's default ddof=0);
        # clamp at 0 to absorb tiny negative values from rounding.
        eye_var = np.maximum(pixel_sq_sum / pixel_count - np.square(eye_mean), 0.0)
        feat_var = np.maximum(feat_sq_sum / frame_count - np.square(feat_mean), 0.0)

        self.eye_mean = eye_mean.astype(np.float32)
        self.eye_std = np.sqrt(eye_var).astype(np.float32)
        self.feat_mean = feat_mean.astype(np.float32)
        self.feat_std = np.sqrt(feat_var).astype(np.float32)

    def __len__(self):
        """Return the number of indexed frames."""
        return len(self.index)

    def __getitem__(self, idx):
        """Load and normalize one frame.

        Returns:
            Tuple ``(left_eye, facial_features, gaze_point)`` of float32
            tensors: the eye image in channel-first layout, the standardized
            facial feature vector, and the 2-element normalized gaze point.
        """
        driver, seq_key, frame_idx = self.index[idx]
        data_seq = self.driver_data[driver][seq_key]

        # --- Eye image: standardize per channel, then HWC -> CHW ---
        left_eye = data_seq["left_eye"][frame_idx].astype(np.float32)
        left_eye = (left_eye - self.eye_mean) / (self.eye_std + 1e-6)
        left_eye = np.transpose(left_eye, (2, 0, 1))
        left_eye = torch.tensor(left_eye, dtype=torch.float32)

        # --- Facial features: standardize per dimension ---
        facial = self._facial_vector(data_seq, frame_idx)
        facial = (facial - self.feat_mean) / (self.feat_std + 1e-6)
        facial_features = torch.tensor(facial, dtype=torch.float32)

        # --- Gaze point: clip to the last on-screen pixel, scale by screen size ---
        width, height = self.screen_size
        gaze_point = data_seq["gaze_point"][frame_idx, :2].astype(np.float32).copy()
        gaze_point[0] = np.clip(gaze_point[0], 0, width - 1) / float(width)
        gaze_point[1] = np.clip(gaze_point[1], 0, height - 1) / float(height)
        gaze_point = torch.tensor(gaze_point, dtype=torch.float32)

        # --- Optional transforms ---
        if self.transform:
            # Placeholder: no transform is applied yet.
            pass

        return left_eye, facial_features, gaze_point


In [None]:

class DGazeDataModule(L.LightningDataModule):
    """Lightning data module wrapping train/val/test ``DGazeDataset`` splits.

    Args:
        data_path: Path to the pickle holding the per-driver data.
        split_path: Path to the pickle holding the split definition; it must
            provide the keys ``drivers_{train,val,test}`` and
            ``sequence_{train,val,test}``.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker processes used by all three dataloaders.
        transform: Forwarded to every ``DGazeDataset``.
    """

    def __init__(self, data_path, split_path, batch_size=64, num_workers=4, transform=False):
        super().__init__()
        self.data_path = data_path
        self.split_path = split_path
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform

        self.driver_data = None
        self.data_split = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def _load_pickles(self):
        """Load the data and split pickles unless they are already in memory."""
        # NOTE: pickle.load can execute arbitrary code; only point data_path /
        # split_path at trusted files.
        if self.driver_data is None:
            with open(self.data_path, "rb") as f:
                self.driver_data = pickle.load(f)
        if self.data_split is None:
            with open(self.split_path, "rb") as f:
                self.data_split = pickle.load(f)

    def prepare_data(self):
        """Load the pickles on the process Lightning calls this hook from.

        Lightning invokes this hook on a single process only, so state set
        here is not guaranteed to exist on other ranks; ``setup`` therefore
        repeats the (idempotent) load.
        """
        self._load_pickles()

    def setup(self, stage=None):
        """Build the three datasets once; safe to call repeatedly.

        Args:
            stage: Lightning stage hint; unused, all splits are always built.
        """
        # setup runs on every process, unlike prepare_data, so make sure the
        # raw data is present here instead of relying on prepare_data's state.
        self._load_pickles()

        if self.train_dataset is None:
            self.train_dataset = DGazeDataset(
                self.driver_data,
                self.data_split["drivers_train"],
                self.data_split["sequence_train"],
                transform=self.transform,
            )

        if self.val_dataset is None:
            self.val_dataset = DGazeDataset(
                self.driver_data,
                self.data_split["drivers_val"],
                self.data_split["sequence_val"],
                transform=self.transform,
            )
            self._share_train_stats(self.val_dataset)

        if self.test_dataset is None:
            self.test_dataset = DGazeDataset(
                self.driver_data,
                self.data_split["drivers_test"],
                self.data_split["sequence_test"],
                transform=self.transform,
            )
            self._share_train_stats(self.test_dataset)

    def _share_train_stats(self, dataset):
        """Make ``dataset`` normalize inputs with the training statistics.

        Evaluation inputs must be scaled exactly like the inputs the model was
        trained on; per-split statistics would shift the input distribution.
        """
        for name in ("eye_mean", "eye_std", "feat_mean", "feat_std"):
            setattr(dataset, name, getattr(self.train_dataset, name))

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader with this module's batch size and worker settings."""
        return DataLoader(dataset, batch_size=self.batch_size,
                          shuffle=shuffle, num_workers=self.num_workers, pin_memory=True)

    def train_dataloader(self):
        """Return the shuffled training dataloader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation dataloader (fixed order)."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the test dataloader (fixed order)."""
        return self._make_loader(self.test_dataset, shuffle=False)
