In [4]:
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
import matplotlib.pyplot as plt
import timm
import glob

# ==========================================
# 1. UNIVERSAL TEMPORAL DATASET
# ==========================================
class UniversalTemporalDataset(Dataset):
    """
    A robust dataset class that handles image sequences for Gaze, Face, and Drowsiness.
    It expects a list of (image_path, label) tuples that are ALREADY sorted by time.
    """
    def __init__(self, samples, transform=None, sequence_length=32, stride=8):
        self.transform = transform
        self.sequence_length = sequence_length
        self.sequences = []

        # Build sequences with sliding window
        # We assume 'samples' is a list of (path, label) sorted by time
        for i in range(0, len(samples) - sequence_length + 1, stride):
            window = samples[i : i + sequence_length]

            # extract paths and labels
            paths = [item[0] for item in window]
            labels = [item[1] for item in window]

            # Integrity check: Ensure files exist (optimization: check lazily or first few)
            # For speed, we assume the splitter provided valid paths.

            # Determine label for the sequence
            # For Face: Label is constant.
            # For Gaze/Drowsy: Use Majority Vote or Last Frame.
            # We use Majority Vote here for stability.
            target_label = max(set(labels), key=labels.count)

            self.sequences.append((paths, target_label))

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        img_paths, label = self.sequences[idx]
        frames = []

        for p in img_paths:
            try:
                # Load Image
                img = Image.open(p).convert("RGB")
                if self.transform:
                    img = self.transform(img)
                frames.append(img)
            except Exception as e:
                # Black frame fallback for corruption
                # print(f"Error loading {p}: {e}")
                frames.append(torch.zeros(3, 256, 256))

        # Stack to [T, C, H, W]
        try:
            frames = torch.stack(frames, dim=0)
        except:
            # Fallback if transforms failed completely
            frames = torch.zeros(self.sequence_length, 3, 256, 256)

        return frames, label

# Helper to prepare file lists
def get_face_files_sorted(subject_dir):
    """Returns sorted list of (path, label_id) for a subject"""
    # Assuming subject_dir ends in subject ID (e.g., .../01)
    try:
        subject_id = int(os.path.basename(subject_dir)) - 1 # 0-indexed
    except:
        subject_id = 0

    valid_dir = os.path.join(subject_dir, "Frames_RGB_Valid")
    if not os.path.exists(valid_dir):
        return []

    files = sorted(glob.glob(os.path.join(valid_dir, "*.jpg")))
    return [(f, subject_id) for f in files]

def get_gaze_files_sorted(subject_dir):
    """Reads CSV and returns sorted list of (path, gaze_label)"""
    # Adjust paths based on your structure
    try:
        subj_name = os.path.basename(subject_dir)
        valid_dir = os.path.join(subject_dir, "Frames_RGB_Valid")
        csv_path = os.path.join(subject_dir, f"Valid_gaze_label_{subj_name}.csv")

        if not os.path.exists(csv_path) or not os.path.exists(valid_dir):
            return []

        df = pd.read_csv(csv_path)
        # Ensure mapping matches file structure
        # df usually has 'Frame Index' and 'Gaze Zone'
        samples = []
        for _, row in df.iterrows():
            idx = int(row['Frame Index'])
            label = int(row['Gaze Zone'])
            fname = f"frame_{idx:04d}.jpg"
            fpath = os.path.join(valid_dir, fname)
            if os.path.exists(fpath):
                samples.append((fpath, label))
        return samples
    except:
        return []

def get_drowsy_files_sorted(subject_dir):
    """Reads CSV and returns sorted list of (path, label)"""
    frames_dir = os.path.join(subject_dir, "frames")
    csv_path = os.path.join(frames_dir, "labels.csv")

    if not os.path.exists(csv_path):
        return []

    df = pd.read_csv(csv_path).sort_values('filename')
    samples = []
    for _, row in df.iterrows():
        fpath = os.path.join(frames_dir, row['filename'])
        if os.path.exists(fpath):
            samples.append((fpath, int(row['label'])))
    return samples

# ==========================================
# 2. FULLY TEMPORAL MODEL
# ==========================================
class FullyTemporalMultiTaskModel(nn.Module):
    def __init__(self, backbone_name='mobilevit_s', pretrained=True,
                 gaze_classes=9, drowsy_classes=2, face_classes=15,
                 hidden_dim=256):
        super().__init__()

        # 1. Shared Spatial Backbone
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained, num_classes=0)
        feat_dim = self.backbone.num_features

        # 2. Temporal Processing (LSTM)
        # We use separate LSTMs because the temporal dynamics differ:
        # - Gaze: Saccadic, quick changes
        # - Drowsy: Slow/Patterned (blinking sequence)
        # - Face: Constant (Identity doesn't change), but video helps smooth noise

        self.gaze_lstm = nn.LSTM(feat_dim, hidden_dim, batch_first=True, bidirectional=False)
        self.drowsy_lstm = nn.LSTM(feat_dim, hidden_dim, batch_first=True, bidirectional=True) # BiDir for blink context
        self.face_lstm = nn.LSTM(feat_dim, hidden_dim, batch_first=True, bidirectional=False)

        # 3. Heads
        self.gaze_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, gaze_classes)
        )

        self.drowsy_head = nn.Sequential(
            nn.LayerNorm(hidden_dim * 2), # *2 for Bidirectional
            nn.Linear(hidden_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, drowsy_classes)
        )

        self.face_head = nn.Sequential(
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, 128),
            nn.ReLU(),
            nn.Linear(128, face_classes)
        )

    def extract_spatial(self, x):
        # x: [B, T, C, H, W]
        b, t, c, h, w = x.shape
        x = x.view(b * t, c, h, w)
        feats = self.backbone.forward_features(x) # [B*T, Feat, H_grid, W_grid]
        feats = feats.mean(dim=[2, 3]) # Global Average Pooling -> [B*T, Feat]
        feats = feats.view(b, t, -1)   # Reshape back to sequence -> [B, T, Feat]
        return feats

    def forward(self, x, task):
        # 1. Spatial Features
        feats = self.extract_spatial(x) # [B, T, Feat]

        if task == 'gaze':
            # Run LSTM
            out, (hn, cn) = self.gaze_lstm(feats)
            # Use the feature from the LAST frame in sequence (most recent gaze)
            last_seq_feat = out[:, -1, :]
            return self.gaze_head(last_seq_feat)

        elif task == 'drowsy':
            out, (hn, cn) = self.drowsy_lstm(feats)
            # For drowsiness, the whole pattern matters. We can pool the LSTM output.
            # Mean pooling over time captures the "event" (blink) anywhere in window
            seq_feat = out.mean(dim=1)
            return self.drowsy_head(seq_feat)

        elif task == 'face':
            out, (hn, cn) = self.face_lstm(feats)
            # Face ID is constant. Mean pooling reduces noise/occlusion.
            seq_feat = out.mean(dim=1)
            return self.face_head(seq_feat)

# ==========================================
# 3. MAIN SCRIPT
# ==========================================
class MixedDataset(Dataset):
    def __init__(self, datasets):
        self.data = []
        for task, ds in datasets.items():
            if ds: self.data += [(task, i) for i in range(len(ds))]
        self.datasets = datasets
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        task, i = self.data[idx]
        x, y = self.datasets[task][i]
        return x, y, task

if __name__ == "__main__":
    # Settings
    WINDOW_SIZE = 16
    STRIDE = 8
    BATCH_SIZE = 1 # Lower batch size because [B, 5, 3, 256, 256] is heavy
    EPOCHS = 20
    LR = 1e-4
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Transform
    transform = transforms.Compose([
        transforms.Resize((224, 224)), # Slightly smaller for sequence efficiency
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    print(f"Initializing Sequence Training (Window={WINDOW_SIZE})...")

    # --- DATA PATHS ---
    base_gaze_face = r"C:\Users\shikh\OneDrive - The University of Western Ontario\Desktop\Thesis\MultiTask\Gaze\S6_face_RGB"
    base_drowsy = r"C:\Users\shikh\OneDrive - The University of Western Ontario\Desktop\Thesis\MultiTask\Drowsiness\S5"

    # --- 1. PREPARE DROWSINESS (Subject-wise Split) ---
    # Drowsiness MUST stay subject-wise to avoid learning personal blink patterns
    print("Preparing Drowsiness Data...")
    drowsy_subjs = sorted([d for d in os.listdir(base_drowsy) if d.isdigit()], key=int)
    random.seed(42); random.shuffle(drowsy_subjs)

    split1 = int(0.7 * len(drowsy_subjs))
    split2 = int(0.85 * len(drowsy_subjs))

    drowsy_train_files, drowsy_val_files, drowsy_test_files = [], [], []

    for s in drowsy_subjs[:split1]: drowsy_train_files.extend(get_drowsy_files_sorted(os.path.join(base_drowsy, s)))
    for s in drowsy_subjs[split1:split2]: drowsy_val_files.extend(get_drowsy_files_sorted(os.path.join(base_drowsy, s)))
    for s in drowsy_subjs[split2:]: drowsy_test_files.extend(get_drowsy_files_sorted(os.path.join(base_drowsy, s)))

    # --- 2. PREPARE GAZE (Subject-wise Split) ---
    # Gaze generalizes better with subject-wise split
    print("Preparing Gaze Data...")
    gaze_subjs = sorted([d for d in os.listdir(base_gaze_face) if d.isdigit()], key=int)
    random.seed(42); random.shuffle(gaze_subjs)

    gaze_train_files, gaze_val_files, gaze_test_files = [], [], []

    for s in gaze_subjs[:split1]: gaze_train_files.extend(get_gaze_files_sorted(os.path.join(base_gaze_face, s)))
    for s in gaze_subjs[split1:split2]: gaze_val_files.extend(get_gaze_files_sorted(os.path.join(base_gaze_face, s)))
    for s in gaze_subjs[split2:]: gaze_test_files.extend(get_gaze_files_sorted(os.path.join(base_gaze_face, s)))

    # --- 3. PREPARE FACE (Within-Subject Split) ---
    # Face recognition needs to see the person in Train to recognize in Test
    print("Preparing Face Data (Within-Subject Split)...")
    face_train_files, face_val_files, face_test_files = [], [], []

    # Iterate ALL subjects available in the folder
    all_face_subjs = sorted([d for d in os.listdir(base_gaze_face) if d.isdigit()], key=int)

    for s in all_face_subjs:
        s_path = os.path.join(base_gaze_face, s)
        # Get all files for this person, sorted by time/filename
        person_files = get_face_files_sorted(s_path)

        if len(person_files) < 10: continue # Skip if too few frames

        # Split THIS person's data 70/15/15
        n = len(person_files)
        n_tr = int(n * 0.7)
        n_va = int(n * 0.15)

        face_train_files.extend(person_files[:n_tr])
        face_val_files.extend(person_files[n_tr : n_tr + n_va])
        face_test_files.extend(person_files[n_tr + n_va:])

    # --- BUILD DATASETS ---
    datasets = {
        'train': {
            'gaze': UniversalTemporalDataset(gaze_train_files, transform, WINDOW_SIZE, STRIDE),
            'drowsy': UniversalTemporalDataset(drowsy_train_files, transform, WINDOW_SIZE, STRIDE),
            'face': UniversalTemporalDataset(face_train_files, transform, WINDOW_SIZE, STRIDE)
        },
        'val': {
            'gaze': UniversalTemporalDataset(gaze_val_files, transform, WINDOW_SIZE, STRIDE),
            'drowsy': UniversalTemporalDataset(drowsy_val_files, transform, WINDOW_SIZE, STRIDE),
            'face': UniversalTemporalDataset(face_val_files, transform, WINDOW_SIZE, STRIDE)
        },
        'test': {
            'gaze': UniversalTemporalDataset(gaze_test_files, transform, WINDOW_SIZE, STRIDE),
            'drowsy': UniversalTemporalDataset(drowsy_test_files, transform, WINDOW_SIZE, STRIDE),
            'face': UniversalTemporalDataset(face_test_files, transform, WINDOW_SIZE, STRIDE)
        }
    }

    print(f"\nDataset Sizes (Sequences):")
    print(f"  Train: Gaze={len(datasets['train']['gaze'])}, Drowsy={len(datasets['train']['drowsy'])}, Face={len(datasets['train']['face'])}")
    print(f"  Val:   Gaze={len(datasets['val']['gaze'])}, Drowsy={len(datasets['val']['drowsy'])}, Face={len(datasets['val']['face'])}")

    # --- DATALOADERS ---
    train_loaders = {
        k: DataLoader(v, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
        for k, v in datasets['train'].items() if len(v) > 0
    }

    val_loader = DataLoader(MixedDataset(datasets['val']), batch_size=1, shuffle=False)
    test_loader = DataLoader(MixedDataset(datasets['test']), batch_size=1, shuffle=False)

    # --- MODEL & OPTIMIZER ---
    model = FullyTemporalMultiTaskModel(face_classes=16).to(DEVICE) # Ensure face_classes matches dataset (1-15 -> 16 slots)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
    criterion = nn.CrossEntropyLoss()

    # --- LOGGING SETUP ---
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_acc_gaze': [],
        'val_acc_drowsy': [],
        'val_acc_face': []
    }

    # --- TRAINING LOOP ---
    print("\nStarting Fully Temporal Training...")

    for epoch in range(EPOCHS):
        model.train()
        total_train_loss = 0
        batches = 0

        # Iterate tasks in round-robin or random mix
        active_tasks = list(train_loaders.keys())
        iters = {t: iter(train_loaders[t]) for t in active_tasks}

        # Train for N steps (based on largest dataset)
        steps = max([len(l) for l in train_loaders.values()])

        pbar = tqdm(range(steps), desc=f"Epoch {epoch+1}/{EPOCHS}")

        for _ in pbar:
            step_loss = 0
            optimizer.zero_grad()

            for task in active_tasks:
                try:
                    x, y = next(iters[task])
                except StopIteration:
                    iters[task] = iter(train_loaders[task])
                    x, y = next(iters[task])

                x, y = x.to(DEVICE), y.to(DEVICE)

                out = model(x, task)
                loss = criterion(out, y)
                loss.backward()
                step_loss += loss.item()

            optimizer.step()
            total_train_loss += step_loss / len(active_tasks)
            batches += 1
            pbar.set_postfix({'loss': total_train_loss/batches})

        avg_train_loss = total_train_loss / batches
        history['train_loss'].append(avg_train_loss)

        # --- VALIDATION ---
        model.eval()
        val_loss_sum = 0
        val_samples = 0
        correct = {'gaze':0, 'drowsy':0, 'face':0}
        total = {'gaze':0, 'drowsy':0, 'face':0}

        with torch.no_grad():
            for x, y, task in val_loader:
                # x is [1, T, C, H, W] via MixedDataset batch_size=1
                # But DataLoader adds a batch dim, so if batch_size=1 it matches.
                # If MixedDataset returns tuple, collate might stack them.
                # Here, batch_size=1 in Loader means x is [1, T, C, H, W] directly if collated

                # Handling MixedDataset tuple return:
                # x tuple of tensors if batch > 1. Here batch=1.
                x = x[0].unsqueeze(0).to(DEVICE) # Ensure [1, T, C, H, W]
                y = y[0].unsqueeze(0).to(DEVICE)
                t_name = task[0]

                out = model(x, t_name)
                loss = criterion(out, y)
                val_loss_sum += loss.item()
                val_samples += 1

                pred = out.argmax(1)
                correct[t_name] += (pred == y).sum().item()
                total[t_name] += 1

        avg_val_loss = val_loss_sum / max(1, val_samples)
        history['val_loss'].append(avg_val_loss)

        acc_g = 100 * correct['gaze'] / max(1, total['gaze'])
        acc_d = 100 * correct['drowsy'] / max(1, total['drowsy'])
        acc_f = 100 * correct['face'] / max(1, total['face'])

        history['val_acc_gaze'].append(acc_g)
        history['val_acc_drowsy'].append(acc_d)
        history['val_acc_face'].append(acc_f)

        print(f"Epoch {epoch+1} Results:")
        print(f"  Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")
        print(f"  Val Accuracies: Gaze={acc_g:.2f}% | Drowsy={acc_d:.2f}% | Face={acc_f:.2f}%")

        # --- SAVE PLOTS EVERY EPOCH ---
        plt.figure(figsize=(12, 5))

        # Loss Plot
        plt.subplot(1, 2, 1)
        plt.plot(history['train_loss'], label='Train Loss')
        plt.plot(history['val_loss'], label='Val Loss')
        plt.title('Loss History')
        plt.legend()

        # Accuracy Plot
        plt.subplot(1, 2, 2)
        plt.plot(history['val_acc_gaze'], label='Gaze Acc')
        plt.plot(history['val_acc_drowsy'], label='Drowsy Acc')
        plt.plot(history['val_acc_face'], label='Face Acc')
        plt.title('Validation Accuracy')
        plt.legend()

        plt.tight_layout()
        plt.savefig(f'training_curves_epoch_{epoch+1}.png')
        plt.close()

        # Save Checkpoint
        torch.save(model.state_dict(), f"hybrid_temporal_model_ep{epoch+1}.pth")

    # --- FINAL TEST ---
    print("\nRunning Final Test...")
    model.eval()
    t_correct = {'gaze':0, 'drowsy':0, 'face':0}
    t_total = {'gaze':0, 'drowsy':0, 'face':0}

    with torch.no_grad():
        for x, y, task in tqdm(test_loader, desc="Testing"):
            x = x[0].unsqueeze(0).to(DEVICE)
            y = y[0].unsqueeze(0).to(DEVICE)
            t_name = task[0]
            out = model(x, t_name)
            pred = out.argmax(1)
            t_correct[t_name] += (pred == y).sum().item()
            t_total[t_name] += 1

    print("FINAL TEST RESULTS:")
    for t in ['gaze', 'drowsy', 'face']:
        acc = 100 * t_correct[t] / max(1, t_total[t])
        print(f"  {t.upper()}: {acc:.2f}%")

Initializing Sequence Training (Window=16)...
Preparing Drowsiness Data...
Preparing Gaze Data...
Preparing Face Data (Within-Subject Split)...

Dataset Sizes (Sequences):
  Train: Gaze=2962, Drowsy=6808, Face=3121
  Val:   Gaze=618, Drowsy=1347, Face=667


OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB. GPU 0 has a total capacity of 6.00 GiB of which 0 bytes is free. Of the allocated memory 5.29 GiB is allocated by PyTorch, and 48.43 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)