In [None]:
import os, random, torch, glob
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import timm
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix

# ====== Temporal Dataset Classes ======
class TemporalGazeDataset(Dataset):
    """Gaze dataset that yields sliding windows of consecutive labelled frames.

    Rows of ``label_csv`` are sorted by ``'Frame Index'``. Window ``idx`` spans
    sorted rows ``idx .. idx + sequence_length - 1``. Each frame is read from
    ``image_dir/frame_XXXX.jpg``, with the frame index zero-padded to 4 digits.

    Args:
        image_dir: Directory containing the ``frame_XXXX.jpg`` images.
        label_csv: CSV with at least ``'Frame Index'`` and ``'Gaze Zone'`` columns.
        transform: Optional callable applied to each PIL frame. It must return
            tensors, because the frames of a window are combined with ``torch.stack``.
        sequence_length: Number of frames per window (must be >= 1).
        consistent_transform: If True (default), the torch CPU RNG is replayed
            with one seed for every frame of a window. Random augmentations that
            draw from it (torchvision crop/flip/rotation/jitter/erasing) are then
            identical across the window instead of differing frame to frame.

    Each item is ``(frames, label)``:
        ``frames`` stacks the transformed frames along a new leading time axis
        (``[T, C, H, W]`` for CHW tensors).
        ``label`` is the integer ``'Gaze Zone'`` of the window's middle frame.

    Raises:
        ValueError: if ``sequence_length`` < 1.
    """

    def __init__(self, image_dir, label_csv, transform=None, sequence_length=5,
                 consistent_transform=True):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.image_dir = image_dir
        self.labels_df = pd.read_csv(label_csv).sort_values('Frame Index')
        self.transform = transform
        self.sequence_length = sequence_length
        self.consistent_transform = consistent_transform
        # Plain-list copies of the two columns read per item. Positional
        # DataFrame.iloc lookups inside __getitem__ are comparatively slow.
        self._frame_indices = self.labels_df['Frame Index'].tolist()
        self._gaze_labels = self.labels_df['Gaze Zone'].tolist()

    def __len__(self):
        # Clamp at 0. A CSV shorter than one window would otherwise report a
        # negative length, which len() and ConcatDataset reject.
        return max(0, len(self._frame_indices) - self.sequence_length + 1)

    def _transform_window(self, images):
        """Apply ``self.transform`` to every image, optionally with shared randomness."""
        if not self.transform:
            return images
        if not self.consistent_transform:
            return [self.transform(image) for image in images]
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        transformed = []
        for image in images:
            # fork_rng restores the global CPU RNG on exit, so replaying the
            # same seed for each frame does not disturb other RNG consumers.
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                transformed.append(self.transform(image))
        return transformed

    def __getitem__(self, idx):
        if not 0 <= idx < len(self):
            raise IndexError(f"window index {idx} out of range for {len(self)} windows")
        images = []
        for i in range(idx, idx + self.sequence_length):
            image_path = os.path.join(
                self.image_dir, f"frame_{int(self._frame_indices[i]):04d}.jpg")
            # The context manager closes the file once convert() has decoded it.
            with Image.open(image_path) as img:
                images.append(img.convert("RGB"))
        frames = torch.stack(self._transform_window(images), dim=0)  # [T, C, H, W]
        # Target is the label of the window's middle frame.
        label = int(self._gaze_labels[idx + self.sequence_length // 2])
        return frames, label

class TemporalDrowsinessDataset(Dataset):
    """Drowsiness / blink dataset yielding sliding windows of frames per subject.

    For every directory in ``subject_dirs`` the file ``<dir>/frames/labels.csv``
    (columns ``'filename'`` and ``'label'``) is read. Directories without it
    are skipped.

    Rows are put in temporal order by a natural sort of ``'filename'``: digit
    runs compare numerically, so ``frame_2`` precedes ``frame_10`` even without
    zero padding. Each run of ``sequence_length`` consecutive rows whose image
    files all exist becomes one sample, labelled with the middle frame's
    ``'label'``. Windows touching a missing file are dropped, and windows never
    span two subjects.

    Args:
        subject_dirs: Iterable of subject directory paths.
        transform: Optional callable applied to each PIL frame. It must return
            tensors so frames can be combined with ``torch.stack``.
        sequence_length: Number of frames per window (must be >= 1).
        consistent_transform: If True (default), the torch CPU RNG is replayed
            with one seed per window, so torchvision-style random augmentation
            is identical for all frames of that window.

    Each item is ``(frames, label)``. ``frames`` stacks the transformed frames
    along a new leading time axis (``[T, C, H, W]`` for CHW tensors).

    Raises:
        ValueError: if ``sequence_length`` < 1.
    """

    def __init__(self, subject_dirs, transform=None, sequence_length=5,
                 consistent_transform=True):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        self.transform = transform
        self.sequence_length = sequence_length
        self.consistent_transform = consistent_transform
        self.samples = []  # list of (list_of_frame_paths, middle_frame_label)
        for subject_dir in subject_dirs:
            self.samples.extend(self._subject_windows(subject_dir))

    @staticmethod
    def _natural_key(name):
        """Sort key that compares digit runs of ``name`` as integers."""
        import re
        parts = re.split(r"(\d+)", str(name))
        # re.split with a capturing group puts the digit runs at odd positions.
        return [int(part) if pos % 2 else part for pos, part in enumerate(parts)]

    def _subject_windows(self, subject_dir):
        """Return the valid ``(paths, label)`` windows of one subject directory."""
        frames_dir = os.path.join(subject_dir, "frames")
        csv_path = os.path.join(frames_dir, "labels.csv")
        if not os.path.exists(csv_path):
            return []

        df = pd.read_csv(csv_path)
        rows = sorted(zip(df['filename'], df['label']),
                      key=lambda row: self._natural_key(row[0]))
        paths = [os.path.join(frames_dir, name) for name, _ in rows]
        labels = [label for _, label in rows]
        # Check each file once rather than once per window that contains it.
        exists = [os.path.exists(path) for path in paths]

        length = self.sequence_length
        return [
            (paths[start:start + length], labels[start + length // 2])
            for start in range(len(rows) - length + 1)
            if all(exists[start:start + length])
        ]

    def __len__(self):
        return len(self.samples)

    def _transform_window(self, images):
        """Apply ``self.transform`` to every image, optionally with shared randomness."""
        if not self.transform:
            return images
        if not self.consistent_transform:
            return [self.transform(image) for image in images]
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        transformed = []
        for image in images:
            # fork_rng restores the global CPU RNG on exit.
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                transformed.append(self.transform(image))
        return transformed

    def __getitem__(self, idx):
        seq_files, label = self.samples[idx]
        images = []
        for img_path in seq_files:
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))
        frames = torch.stack(self._transform_window(images), dim=0)  # [T, C, H, W]
        return frames, label

class TemporalFaceDataset(Dataset):
    """Face-recognition dataset yielding sliding windows of frames per identity.

    Paths are grouped by label, sorted lexicographically (taken as temporal
    order), and every run of ``sequence_length`` consecutive paths becomes one
    sample carrying that label. Identities with fewer paths than
    ``sequence_length`` produce no samples.

    Args:
        samples: Iterable of ``(image_path, label)`` pairs.
        transform: Callable applied to each PIL frame (or None). It must return
            tensors so frames can be combined with ``torch.stack``.
        sequence_length: Number of frames per window (must be >= 1).
        consistent_transform: If True (default), the torch CPU RNG is replayed
            with one seed per window, so torchvision-style random augmentation
            is identical for all frames of that window.

    Each item is ``(frames, label)``. ``frames`` stacks the transformed frames
    along a new leading time axis (``[T, C, H, W]`` for CHW tensors).

    Raises:
        ValueError: if ``sequence_length`` < 1.
    """

    def __init__(self, samples, transform, sequence_length=5, consistent_transform=True):
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        from collections import defaultdict

        self.transform = transform
        self.sequence_length = sequence_length
        self.consistent_transform = consistent_transform

        # Group paths by identity, preserving first-seen label order.
        paths_by_label = defaultdict(list)
        for img_path, label in samples:
            paths_by_label[label].append(img_path)

        self.temporal_samples = []  # list of (list_of_paths, label)
        for label, paths in paths_by_label.items():
            paths.sort()  # filename order is used as temporal order
            for start in range(len(paths) - sequence_length + 1):
                self.temporal_samples.append((paths[start:start + sequence_length], label))

    def __len__(self):
        return len(self.temporal_samples)

    def _transform_window(self, images):
        """Apply ``self.transform`` to every image, optionally with shared randomness."""
        if not self.transform:
            return images
        if not self.consistent_transform:
            return [self.transform(image) for image in images]
        seed = int(torch.randint(0, 2**31 - 1, (1,)).item())
        transformed = []
        for image in images:
            # fork_rng restores the global CPU RNG on exit.
            with torch.random.fork_rng(devices=[]):
                torch.manual_seed(seed)
                transformed.append(self.transform(image))
        return transformed

    def __getitem__(self, idx):
        seq_paths, label = self.temporal_samples[idx]
        images = []
        for img_path in seq_paths:
            with Image.open(img_path) as img:
                images.append(img.convert("RGB"))
        frames = torch.stack(self._transform_window(images), dim=0)  # [T, C, H, W]
        return frames, label

# ====== Temporal Multi-Task Model ======
class TemporalMultiTaskModel(nn.Module):
    """Multi-task sequence classifier: shared timm backbone + per-task temporal branches.

    Every frame of a ``[batch, time, C, H, W]`` input is encoded by one shared
    timm backbone (``forward_features`` output averaged over its last two
    spatial dims). Each task (``'gaze'``, ``'drowsy'``, ``'face'``) then has
    its own:
      - 2-layer bidirectional LSTM,
      - multi-head self-attention over time,
      - mean pooling over time,
      - MLP head producing class logits.

    Args:
        backbone_name: timm model name for the shared frame encoder.
        pretrained: Whether timm loads pretrained backbone weights.
        gaze_classes, drowsy_classes, face_classes: Output sizes of the heads.
        sequence_length: Stored for reference only; forward accepts any length T.
        hidden_dim: LSTM hidden size per direction. Attention width is 2 * hidden_dim.
    """

    TASKS = ('gaze', 'drowsy', 'face')

    def __init__(self, backbone_name='mobilevit_s', pretrained=True,
                 gaze_classes=9, drowsy_classes=2, face_classes=15,
                 sequence_length=5, hidden_dim=256):
        super().__init__()

        # Spatial feature extractor shared by all tasks; num_classes=0 drops timm's classifier.
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        feat_dim = self.backbone.num_features

        self.sequence_length = sequence_length
        temporal_dim = hidden_dim * 2  # bidirectional LSTM output width

        # Task-specific temporal encoders.
        self.gaze_lstm = nn.LSTM(feat_dim, hidden_dim, num_layers=2,
                                 batch_first=True, dropout=0.3, bidirectional=True)
        self.drowsy_lstm = nn.LSTM(feat_dim, hidden_dim, num_layers=2,
                                   batch_first=True, dropout=0.3, bidirectional=True)
        self.face_lstm = nn.LSTM(feat_dim, hidden_dim, num_layers=2,
                                 batch_first=True, dropout=0.3, bidirectional=True)

        # Task-specific temporal self-attention (default layout: [time, batch, features]).
        self.gaze_attention = nn.MultiheadAttention(temporal_dim, num_heads=4, dropout=0.2)
        self.drowsy_attention = nn.MultiheadAttention(temporal_dim, num_heads=4, dropout=0.2)
        self.face_attention = nn.MultiheadAttention(temporal_dim, num_heads=4, dropout=0.2)

        # Task-specific classifier heads (registration order and layout kept so
        # existing checkpoints still load).
        self.gaze_head = self._mlp_head(temporal_dim, 256, gaze_classes)
        self.drowsy_head = self._mlp_head(temporal_dim, 128, drowsy_classes)
        self.face_head = self._mlp_head(temporal_dim, 256, face_classes)

    @staticmethod
    def _mlp_head(in_dim, hidden, num_classes):
        """Linear -> ReLU -> Dropout(0.3) -> Linear classifier."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden, num_classes),
        )

    def extract_spatial_features(self, x):
        """Encode every frame with the backbone.

        Args:
            x: Tensor of shape ``[batch, time, C, H, W]``.

        Returns:
            Tensor of shape ``[batch, time, feat_dim]``.

        Raises:
            ValueError: if ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(f"expected input of shape [B, T, C, H, W], got {tuple(x.shape)}")
        batch_size, seq_len = x.shape[:2]
        # reshape (not view) also accepts non-contiguous inputs.
        frames = x.reshape(batch_size * seq_len, *x.shape[2:])
        feature_maps = self.backbone.forward_features(frames)
        pooled = feature_maps.mean(dim=(2, 3))  # global average pooling
        return pooled.reshape(batch_size, seq_len, -1)

    def _branch(self, task):
        """Return ``(lstm, attention, head)`` for ``task``; raise ValueError if unknown."""
        if task not in self.TASKS:
            raise ValueError(f"unknown task {task!r}; expected one of {self.TASKS}")
        return (getattr(self, f"{task}_lstm"),
                getattr(self, f"{task}_attention"),
                getattr(self, f"{task}_head"))

    def forward(self, x, task):
        """Return logits ``[batch, num_classes(task)]`` for a ``[B, T, C, H, W]`` clip batch."""
        # Resolve the branch first so a bad task name fails before the backbone runs.
        lstm, attention, head = self._branch(task)

        spatial_features = self.extract_spatial_features(x)
        lstm_out, _ = lstm(spatial_features)            # [batch, seq_len, 2*hidden]
        seq_first = lstm_out.transpose(0, 1)            # [seq_len, batch, 2*hidden]
        # Only the attended values are used, so skip computing attention weights.
        attn_out, _ = attention(seq_first, seq_first, seq_first, need_weights=False)
        temporal_features = attn_out.transpose(0, 1).mean(dim=1)  # mean over time
        return head(temporal_features)

# ====== Temporal Dataset Wrapper for Validation/Test ======
class TemporalMultiTaskDataset(Dataset):
    """Flatten several per-task datasets into one indexable dataset.

    Items are laid out task by task, in ``datasets`` iteration order. Item
    ``k`` is ``(x, y, task_name)``, where ``(x, y)`` is the underlying task
    dataset's item. Tasks whose dataset is ``None`` (e.g. no subject folders
    were found) are skipped instead of crashing on ``len(None)``.

    Args:
        datasets: Mapping ``task_name -> dataset or None``.
    """

    def __init__(self, datasets):
        self.datasets = datasets
        # (task, index-within-task) for every item, in global index order.
        self.data = [
            (task, i)
            for task, ds in datasets.items()
            if ds is not None
            for i in range(len(ds))
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        task, i = self.data[idx]
        x, y = self.datasets[task][i]
        return x, y, task

# ====== Modified PCGrad for Temporal Model ======
class TemporalPCGrad(torch.optim.Optimizer):
    """PCGrad ("gradient surgery", Yu et al. 2020) wrapper around a torch optimizer.

    Shared parameters are those in the groups listed in
    ``shared_group_indices`` (default: group 0, the backbone in
    ``create_temporal_optimizer``). For them:
      - each task's flattened gradient is projected onto the normal plane of
        every other task gradient it conflicts with (negative dot product),
        visiting the other tasks in random order;
      - the projected gradients of all tasks are then summed.

    Parameters in the remaining groups (task-specific modules) simply
    accumulate the gradients of every loss that reaches them.

    Parameters that no loss reaches keep ``grad = None``, so the optimizer
    skips them. Writing zero tensors instead would let, e.g., Adam's L2
    ``weight_decay`` term keep changing idle task branches.

    ``Optimizer.__init__`` is intentionally not called. This object forwards
    to the wrapped optimizer, available as ``_optim``.
    """

    def __init__(self, optimizer, shared_group_indices=(0,)):
        self._optim = optimizer
        self._shared_group_indices = tuple(shared_group_indices)

    @property
    def param_groups(self):
        return self._optim.param_groups

    def state_dict(self):
        return self._optim.state_dict()

    def load_state_dict(self, state_dict):
        return self._optim.load_state_dict(state_dict)

    def zero_grad(self, set_to_none=True):
        return self._optim.zero_grad(set_to_none=set_to_none)

    def step(self):
        return self._optim.step()

    def pc_backward(self, objectives):
        """Backpropagate every loss and write PCGrad-combined shared gradients.

        Args:
            objectives: Sequence of scalar loss tensors, one per task.
        """
        objectives = list(objectives)
        self._optim.zero_grad(set_to_none=True)
        if not objectives:
            return

        shared = [
            p
            for group_idx in self._shared_group_indices
            for p in self._optim.param_groups[group_idx]['params']
        ]

        flat_grads = []                 # one flattened shared gradient per objective
        touched = [False] * len(shared)  # whether any objective reached each shared param
        last = len(objectives) - 1
        for k, loss in enumerate(objectives):
            # Shared grads are read per objective, so clear them before each
            # backward. Task-specific grads are left to accumulate.
            for p in shared:
                p.grad = None
            loss.backward(retain_graph=k < last)
            pieces = []
            for j, p in enumerate(shared):
                if p.grad is None:
                    pieces.append(torch.zeros(p.numel(), dtype=p.dtype, device=p.device))
                else:
                    touched[j] = True
                    pieces.append(p.grad.detach().reshape(-1))
            if pieces:
                flat_grads.append(torch.cat(pieces))

        if not shared:
            return

        # Project each task gradient against the ORIGINAL gradients of the others.
        projected = []
        for i, g_i in enumerate(flat_grads):
            g = g_i.clone()
            others = [j for j in range(len(flat_grads)) if j != i]
            random.shuffle(others)
            for j in others:
                g_j = flat_grads[j]
                dot = torch.dot(g, g_j)
                if dot < 0:
                    g = g - (dot / (torch.dot(g_j, g_j) + 1e-12)) * g_j
            projected.append(g)
        combined = torch.stack(projected).sum(dim=0)

        # Scatter the combined vector back, leaving untouched params at None.
        offset = 0
        for j, p in enumerate(shared):
            n = p.numel()
            if touched[j]:
                p.grad = combined[offset:offset + n].view_as(p).to(p.dtype)
            offset += n

# ====== Training Function for Temporal Model ======
def train_temporal(model, optim, device, criterions, task_probs, loaders, steps_per_epoch, loss_scale=None):
    """Run one epoch of multi-task training.

    Each of the ``steps_per_epoch`` steps:
      - takes one batch from every task whose value in ``task_probs`` is > 0,
        visiting those tasks in a freshly shuffled order. Only the sign of the
        value matters; its magnitude is not used.
      - restarts a task's iterator transparently when its loader is exhausted.
      - multiplies each task loss by ``loss_scale[task]`` when present.
      - passes the losses together to ``optim.pc_backward`` before ``optim.step()``.

    Returns:
        A tuple of three values:
          - the mean over ``steps_per_epoch`` of each step's average task loss;
          - ``{task: train accuracy in %}``;
          - ``{task: per-sample mean of the (scaled) task loss}``.
    """
    scale = {} if loss_scale is None else loss_scale

    model.train()
    batch_iters = {task: iter(loader) for task, loader in loaders.items()}
    running = 0.0
    loss_sum = dict.fromkeys(loaders, 0.0)
    seen = dict.fromkeys(loaders, 0)
    hits = dict.fromkeys(loaders, 0)
    evaluated = dict.fromkeys(loaders, 0)

    def _next_batch(task):
        # Restart the task's loader once it runs dry.
        try:
            return next(batch_iters[task])
        except StopIteration:
            batch_iters[task] = iter(loaders[task])
            return next(batch_iters[task])

    for _ in tqdm(range(steps_per_epoch), desc="Train"):
        step_losses, records = [], []
        order = [task for task, prob in task_probs.items() if prob > 0]
        random.shuffle(order)

        for task in order:
            inputs, targets = _next_batch(task)
            inputs = inputs.to(device)
            targets = targets.to(device)
            logits = model(inputs, task)
            task_loss = criterions[task](logits, targets)
            if task in scale:
                task_loss = task_loss * float(scale[task])
            step_losses.append(task_loss)
            records.append((task, logits, targets))

        if not step_losses:
            continue

        optim.zero_grad()
        optim.pc_backward(step_losses)
        optim.step()

        with torch.no_grad():
            running += sum(l.item() for l in step_losses) / len(step_losses)
            for (task, logits, targets), task_loss in zip(records, step_losses):
                hits[task] += (logits.argmax(1) == targets).sum().item()
                evaluated[task] += targets.numel()
                loss_sum[task] += task_loss.item() * targets.size(0)
                seen[task] += targets.size(0)

    mean_losses = {task: loss_sum[task] / max(seen[task], 1) for task in loss_sum}
    accuracies = {task: 100 * hits[task] / max(evaluated[task], 1) for task in evaluated}
    return running / steps_per_epoch, accuracies, mean_losses

# ====== Create Optimizer for Temporal Model ======
def create_temporal_optimizer(model, lr_backbone=1e-4, lr_lstm=1e-3, lr_heads=1e-3, weight_decay=1e-4):
    """Build a three-group Adam optimizer for the temporal model, wrapped in ``TemporalPCGrad``.

    Parameter groups, in this order:
      0. ``model.backbone`` parameters (learning rate ``lr_backbone``).
      1. The gaze/drowsy/face LSTMs followed by the gaze/drowsy/face attention
         modules (``lr_lstm``).
      2. The gaze/drowsy/face heads (``lr_heads``).

    Every group uses the same ``weight_decay``. ``TemporalPCGrad`` reads the
    groups by index, so this order matters.
    """
    def _params_of(*modules):
        # Flatten parameters of several modules, keeping module order.
        return [param for module in modules for param in module.parameters()]

    grouped = [
        (_params_of(model.backbone), lr_backbone),
        (_params_of(model.gaze_lstm, model.drowsy_lstm, model.face_lstm,
                    model.gaze_attention, model.drowsy_attention, model.face_attention), lr_lstm),
        (_params_of(model.gaze_head, model.drowsy_head, model.face_head), lr_heads),
    ]
    adam = torch.optim.Adam([
        {'params': params, 'lr': lr, 'weight_decay': weight_decay}
        for params, lr in grouped
    ])
    return TemporalPCGrad(adam)

# ====== Main Training Script ======
if __name__ == "__main__":
    import multiprocessing
    multiprocessing.set_start_method('spawn', force=True)

    # ---------------- Configuration ----------------
    SEQUENCE_LENGTH = 5  # Number of frames in temporal window
    TASKS = ('gaze', 'drowsy', 'face')
    output_dir = "temporal_training_results"
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    best_model_path = "best_temporal_model.pth"

    # torch drives weight init, DataLoader shuffling and torchvision augmentation.
    # The subject splits below keep their own random.seed(42) calls, so the
    # splits themselves are unchanged.
    torch.manual_seed(42)

    # ---------------- Transforms ----------------
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training augmentation for the drowsiness and face tasks.
    transform = transforms.Compose([
        transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.25, scale=(0.02, 0.15)),
    ])

    # Gaze training augmentation: the same pipeline minus the horizontal flip.
    # A flip mirrors where the subject is looking, but the gaze-zone label is
    # not remapped, so flipped samples would be paired with the wrong zone.
    gaze_transform = transforms.Compose([
        transforms.RandomResizedCrop(256, scale=(0.8, 1.0)),
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        normalize,
        transforms.RandomErasing(p=0.25, scale=(0.02, 0.15)),
    ])

    # Deterministic preprocessing for validation/test. Metrics and best-model
    # selection must not depend on random crops, flips or erasing.
    eval_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        normalize,
    ])

    # ---------------- Gaze Dataset (Temporal) ----------------
    gaze_base = "/Users/shikharsrivastava/Desktop/Thesis/Thesis/MultiTask/Gaze/S6_face_RGB"
    subjects = list(range(1, 16))
    random.seed(42)
    random.shuffle(subjects)
    train_subjects, val_subjects, test_subjects = subjects[:11], subjects[11:13], subjects[13:]

    def build_temporal_gaze(subjects, subject_transform=transform):
        """ConcatDataset of per-subject gaze windows, or None if no subject folder exists."""
        datasets = []
        for s in subjects:
            image_dir = os.path.join(gaze_base, str(s), "Frames_RGB_Valid")
            csv_path = os.path.join(gaze_base, str(s), f"Valid_gaze_label_{s}.csv")
            if os.path.exists(image_dir) and os.path.exists(csv_path):
                datasets.append(TemporalGazeDataset(image_dir, csv_path, subject_transform, SEQUENCE_LENGTH))
        return ConcatDataset(datasets) if datasets else None

    train_gaze = build_temporal_gaze(train_subjects, gaze_transform)
    val_gaze = build_temporal_gaze(val_subjects, eval_transform)
    test_gaze = build_temporal_gaze(test_subjects, eval_transform)
    if train_gaze is None:
        raise FileNotFoundError(f"No gaze training subjects found under {gaze_base}")
    print("Temporal Gaze data built")

    # ---------------- Drowsiness Dataset (Temporal - Blinking Detection) ----------------
    drowsy_base = "/Users/shikharsrivastava/Desktop/Thesis/Thesis/MultiTask/Drowsiness/S5"

    # Numeric subject directories, sorted so the seeded shuffle is reproducible.
    drowsy_subjects = sorted(
        os.path.join(drowsy_base, d) for d in os.listdir(drowsy_base)
        if os.path.isdir(os.path.join(drowsy_base, d)) and d.isdigit()
    )

    # 70 / 15 / 15 subject-level split.
    random.seed(42)
    random.shuffle(drowsy_subjects)
    n_subjects = len(drowsy_subjects)
    train_end = int(n_subjects * 0.7)
    val_end = int(n_subjects * 0.85)

    train_drowsy_subjects = drowsy_subjects[:train_end]
    val_drowsy_subjects = drowsy_subjects[train_end:val_end]
    test_drowsy_subjects = drowsy_subjects[val_end:]

    train_drowsy = TemporalDrowsinessDataset(train_drowsy_subjects, transform, SEQUENCE_LENGTH)
    val_drowsy = TemporalDrowsinessDataset(val_drowsy_subjects, eval_transform, SEQUENCE_LENGTH)
    test_drowsy = TemporalDrowsinessDataset(test_drowsy_subjects, eval_transform, SEQUENCE_LENGTH)
    print(f"Temporal Drowsiness data built: Train={len(train_drowsy)}, Val={len(val_drowsy)}, Test={len(test_drowsy)}")

    # ---------------- Face Recognition Dataset (Temporal) ----------------
    # NOTE(review): gaze subjects are looked up as "<root>/1".."15", while face
    # identities use zero-padded "<root>/01".."15" under the same root.
    # Confirm that both folder layouts really exist.
    face_root = "/Users/shikharsrivastava/Desktop/Thesis/Thesis/MultiTask/Gaze/S6_face_RGB"

    def build_temporal_face_datasets(face_root, transform, sequence_length, eval_transform=None):
        """Per-identity chronological 70/15/15 split of face frames into temporal datasets.

        Each identity's frames are sorted by filename and cut into three
        contiguous blocks. Sliding windows inside a split are therefore made of
        consecutive frames, and neighbouring (near-duplicate) frames cannot
        straddle train/val/test. A random frame-level split would break both.
        ``eval_transform`` (default: ``transform``) is used for val/test.
        """
        if eval_transform is None:
            eval_transform = transform
        samples_train, samples_val, samples_test = [], [], []
        pid_list = [f"{i:02d}" for i in range(1, 16)]
        id_to_label = {pid: i for i, pid in enumerate(pid_list)}

        for pid in pid_list:
            label = id_to_label[pid]
            img_dir = os.path.join(face_root, pid, "Frames_RGB_Valid")
            if not os.path.exists(img_dir):
                continue

            img_paths = sorted(
                os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith(".jpg")
            )
            if len(img_paths) < sequence_length:
                continue

            n_train = int(round(len(img_paths) * 0.7))
            n_val = int(round(len(img_paths) * 0.15))
            splits = (
                (samples_train, img_paths[:n_train]),
                (samples_val, img_paths[n_train:n_train + n_val]),
                (samples_test, img_paths[n_train + n_val:]),
            )
            for bucket, paths in splits:
                bucket.extend((p, label) for p in paths)

        return (
            TemporalFaceDataset(samples_train, transform, sequence_length),
            TemporalFaceDataset(samples_val, eval_transform, sequence_length),
            TemporalFaceDataset(samples_test, eval_transform, sequence_length),
        )

    train_face, val_face, test_face = build_temporal_face_datasets(
        face_root, transform, SEQUENCE_LENGTH, eval_transform)
    print("Temporal Face data built")

    # ---------------- DataLoaders ----------------
    batch_size = 16  # Reduced due to temporal sequences

    loaders = {
        'gaze': DataLoader(train_gaze, batch_size=batch_size, shuffle=True, num_workers=0),
        'drowsy': DataLoader(train_drowsy, batch_size=batch_size, shuffle=True, num_workers=0),
        'face': DataLoader(train_face, batch_size=batch_size, shuffle=True, num_workers=0),
    }

    def _present(datasets):
        """Drop tasks whose dataset could not be built (None)."""
        return {task: ds for task, ds in datasets.items() if ds is not None}

    val_loader = DataLoader(
        TemporalMultiTaskDataset(_present({'gaze': val_gaze, 'drowsy': val_drowsy, 'face': val_face})),
        batch_size=batch_size
    )

    test_loader = DataLoader(
        TemporalMultiTaskDataset(_present({'gaze': test_gaze, 'drowsy': test_drowsy, 'face': test_face})),
        batch_size=batch_size
    )

    # ---------------- Model Setup ----------------
    device = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    print(f"Using device: {device}")

    model = TemporalMultiTaskModel(
        backbone_name='mobilevit_s',
        pretrained=True,
        gaze_classes=9,
        drowsy_classes=2,  # Binary: blinking or not
        face_classes=15,
        sequence_length=SEQUENCE_LENGTH
    ).to(device)

    criterions = {
        'gaze': nn.CrossEntropyLoss(label_smoothing=0.1),
        'drowsy': nn.CrossEntropyLoss(label_smoothing=0.1),
        'face': nn.CrossEntropyLoss(label_smoothing=0.1),
    }

    optimizer = create_temporal_optimizer(
        model,
        lr_backbone=5e-5,
        lr_lstm=1e-3,
        lr_heads=1e-3,
        weight_decay=1e-4
    )

    scheduler = torch.optim.lr_scheduler.StepLR(optimizer._optim, step_size=10, gamma=0.5)

    # ---------------- Evaluation helpers (defined once, outside the epoch loop) ----------------
    def evaluate_temporal_with_loss(model, loader, device, criterions):
        """Single pass over ``loader``: per-task accuracy (%) and mean per-sample loss.

        Batches from TemporalMultiTaskDataset mix tasks. Each batch is split by
        task name, and every task's slice goes through the model in one
        forward pass. Tasks with no samples report 0.0 accuracy.
        """
        model.eval()
        correct = dict.fromkeys(TASKS, 0)
        seen = dict.fromkeys(TASKS, 0)
        loss_total, sample_total = 0.0, 0
        with torch.no_grad():
            for x, y, task_names in loader:
                x, y = x.to(device), y.to(device)
                for t in dict.fromkeys(task_names):  # unique names, first-seen order
                    rows = torch.tensor(
                        [i for i, name in enumerate(task_names) if name == t], device=x.device)
                    out = model(x.index_select(0, rows), t)
                    target = y.index_select(0, rows)
                    n = rows.numel()
                    # The criterion averages over the slice; scale back to a sum
                    # so the result is the mean loss per sample.
                    loss_total += criterions[t](out, target).item() * n
                    sample_total += n
                    correct[t] += (out.argmax(dim=1) == target).sum().item()
                    seen[t] += n
        accs = {t: (100.0 * correct[t] / seen[t]) if seen[t] else 0.0 for t in TASKS}
        return accs, loss_total / max(sample_total, 1)

    def evaluate_temporal(model, loader, device, criterions):
        """Per-task accuracy (%) on ``loader``."""
        return evaluate_temporal_with_loss(model, loader, device, criterions)[0]

    def compute_val_loss_temporal(model, loader, device, criterions):
        """Mean per-sample loss on ``loader``."""
        return evaluate_temporal_with_loss(model, loader, device, criterions)[1]

    # ---------------- Training Loop ----------------
    steps_per_epoch = max(len(loaders[t]) for t in loaders)
    train_losses, val_losses = [], []
    gaze_train_accs, gaze_val_accs = [], []
    drowsy_train_accs, drowsy_val_accs = [], []
    face_train_accs, face_val_accs = [], []
    num_epochs = 30
    gaze_task_losses_list, drowsy_task_losses_list, face_task_losses_list = [], [], []
    best_val_loss = float('inf')
    saved_best = False
    loss_scale = {'face': 0.25}

    for epoch in range(num_epochs):
        # Staged training strategy (train_temporal only uses whether each value is > 0).
        if epoch < 10:
            # Stage 1: Focus on gaze with temporal features
            task_probs = {'gaze': 1.0, 'drowsy': 0.0, 'face': 0.0}
        elif epoch < 20:
            # Stage 2: Add drowsiness (blinking detection)
            task_probs = {'gaze': 0.5, 'drowsy': 0.5, 'face': 0.0}
        else:
            # Stage 3: All tasks
            task_probs = {'gaze': 0.4, 'drowsy': 0.4, 'face': 0.2}

        print(f"\nEpoch {epoch+1}/{num_epochs} | task_probs={task_probs}")

        loss, accs, task_losses = train_temporal(
            model, optimizer, device, criterions,
            task_probs, loaders, steps_per_epoch, loss_scale
        )

        gaze_task_losses_list.append(task_losses['gaze'])
        drowsy_task_losses_list.append(task_losses['drowsy'])
        face_task_losses_list.append(task_losses['face'])

        # One validation pass yields both accuracies and loss.
        val_accs, val_loss = evaluate_temporal_with_loss(model, val_loader, device, criterions)

        scheduler.step()

        print(f"Epoch {epoch+1} TrainLoss {loss:.4f} "
              f"(Gaze {task_losses['gaze']:.4f} Drowsy {task_losses['drowsy']:.4f} Face {task_losses['face']:.4f}) "
              f"ValLoss {val_loss:.4f} "
              f"ValAcc Gaze {val_accs['gaze']:.2f}% Drowsy {val_accs['drowsy']:.2f}% Face {val_accs['face']:.2f}% "
              f"LR {optimizer._optim.param_groups[0]['lr']:.2e}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), best_model_path)
            saved_best = True
            print(f"*** Best model updated at Epoch {epoch+1} with Val Loss {val_loss:.4f} ***")

        train_losses.append(loss)
        val_losses.append(val_loss)
        gaze_train_accs.append(accs['gaze'])
        gaze_val_accs.append(val_accs['gaze'])
        drowsy_train_accs.append(accs['drowsy'])
        drowsy_val_accs.append(val_accs['drowsy'])
        face_train_accs.append(accs['face'])
        face_val_accs.append(val_accs['face'])

    # ---------------- Test with best model ----------------
    print("\nTesting with best temporal model...")
    if saved_best:
        # Only load a checkpoint written by this run (never a stale file).
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    else:
        print("No checkpoint was saved during this run; testing the final weights.")
    test_accs = evaluate_temporal(model, test_loader, device, criterions)
    print("Test Accuracies:", test_accs)

    # ---------------- Results ----------------
    print(f"\n{'=' * 60}")
    print("🎉 TEMPORAL TRAINING COMPLETE! 🎉")
    print(f"{'=' * 60}")
    print("Test Results:")
    print(f"  Gaze Detection:      {test_accs['gaze']:.2f}%")
    print(f"  Blinking Detection:  {test_accs['drowsy']:.2f}%")
    print(f"  Face Recognition:    {test_accs['face']:.2f}%")
    print(f"Best model saved: {best_model_path}")
    print(f"{'=' * 60}")

  from .autonotebook import tqdm as notebook_tqdm


Temporal Gaze data built
Temporal Drowsiness data built: Train=54438, Val=10782, Test=16176
Temporal Face data built
Using device: mps

Epoch 1/30 | task_probs={'gaze': 1.0, 'drowsy': 0.0, 'face': 0.0}


Train:   6%|▌         | 194/3403 [03:39<1:00:46,  1.14s/it]