## IMPORTS

In [1]:
import os
import torch
import numpy as np
import xml.etree.ElementTree as ET
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as transforms
from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt
import re
from tqdm import tqdm
import cv2
import torch.nn as nn
import torch.optim as optim
import time
import random
import torchvision.models as models
import math
from sklearn.metrics import roc_curve, auc
import bisect

## DATASET

In [2]:
    xml_path = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml"
    image_folder = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/images"

In [3]:
def parse_xml(xml_path):
    """
    Parse the XML file and extract bounding boxes and polylines for each frame.
    
    Args:
        xml_path (str): Path to the XML file.
    
    Returns:
        frames (dict): Dictionary containing frame data (bounding boxes and polylines).
    """
    try:
        tree = ET.parse(xml_path)
        root = tree.getroot()
        frames = {}
        
        print(f"Parsing XML annotations from {xml_path}")
        print(f"Root tag: {root.tag}, with {len(root)} child elements")
        
        frame_count = 0
        for image in tqdm(root.findall("image"), desc="Parsing frames"):
            try:
                frame_id = int(image.attrib["id"])  # Extract frame ID
                frame_name = image.attrib["name"]  # Extract frame name
                width = int(image.attrib["width"])  # Image width
                height = int(image.attrib["height"])  # Image height
                
                # Initialize lists for bounding boxes and polylines
                frame_boxes = []
                frame_polylines = []
                
                # Extract bounding boxes
                for box in image.findall("box"):
                    try:
                        # Extract all box attributes
                        box_info = {
                            "label": box.attrib.get("label", "unknown"),
                            "xtl": float(box.attrib.get("xtl", 0)),
                            "ytl": float(box.attrib.get("ytl", 0)),
                            "xbr": float(box.attrib.get("xbr", 0)),
                            "ybr": float(box.attrib.get("ybr", 0)),
                        }
                        frame_boxes.append(box_info)
                    except Exception as box_err:
                        print(f"Error parsing box in frame {frame_id}: {box_err}")
                
                # Extract polylines
                for polyline in image.findall("polyline"):
                    try:
                        polyline_info = {
                            "label": polyline.attrib.get("label", "unknown"),
                            "points": polyline.attrib.get("points", "")
                        }
                        frame_polylines.append(polyline_info)
                    except Exception as polyline_err:
                        print(f"Error parsing polyline in frame {frame_id}: {polyline_err}")
                
                # Store frame information
                frames[frame_id] = {
                    "name": frame_name,
                    "width": width,
                    "height": height,
                    "boxes": frame_boxes,
                    "polylines": frame_polylines
                }
                frame_count += 1
                
                # Debug first frame
                if frame_count == 1:
                    print(f"Sample frame: ID={frame_id}, Name={frame_name}, Size={width}x{height}")
                    print(f"Found {len(frame_boxes)} boxes and {len(frame_polylines)} polylines in first frame")
                    if frame_boxes:
                        print(f"Sample box labels: {[box['label'] for box in frame_boxes[:5]]}")
                    if frame_polylines:
                        print(f"Sample polyline labels: {[p['label'] for p in frame_polylines[:5]]}")
            except Exception as frame_err:
                print(f"Error parsing frame: {frame_err}")
                continue
        
        print(f"Successfully parsed {len(frames)} frames")
        return frames
    
    except Exception as e:
        print(f"Error parsing XML file: {e}")
        import traceback
        traceback.print_exc()
        return {}

def create_frame_to_image_mapping(image_folder):
    """
    Create a mapping from frame IDs to image paths.
    
    Args:
        image_folder (str): Path to the folder containing images.
    
    Returns:
        frame_to_image (dict): Mapping from frame IDs to image paths.
    """
    frame_to_image = {}
    
    if not os.path.exists(image_folder):
        print(f"Warning: Image folder {image_folder} does not exist")
        return frame_to_image
    
    for image_name in os.listdir(image_folder):
        if not any(image_name.lower().endswith(ext) for ext in ['.jpg', '.jpeg', '.png', '.bmp']):
            continue
            
        # Try various patterns to extract frame ID
        # Pattern 1: frame_000123.jpg/png
        match = re.search(r'frame_0*(\d+)', image_name.lower())
        if match:
            frame_id = int(match.group(1))
            frame_to_image[frame_id] = os.path.join(image_folder, image_name)
            continue
        
        # Pattern 2: 000123.jpg
        match = re.search(r'^0*(\d+)', image_name)
        if match:
            frame_id = int(match.group(1))
            frame_to_image[frame_id] = os.path.join(image_folder, image_name)
            continue
        
        # Pattern 3: Any number in the filename
        match = re.search(r'(\d+)', image_name)
        if match:
            frame_id = int(match.group(1))
            frame_to_image[frame_id] = os.path.join(image_folder, image_name)
    
    print(f"Found {len(frame_to_image)} images with extractable frame IDs")
    return frame_to_image

class GESCAMCustomDataset(Dataset):
    """
    Dataset class for GESCAM (Gaze Estimation based Synthetic Classroom Attention Measurement)
    Customized for the specific annotation format
    """
    def __init__(self, xml_path, image_folder, transform=None, head_transform=None, 
                 input_size=224, output_size=64, test=False):
        """
        Args:
            xml_path (str): Path to the XML annotation file
            image_folder (str): Path to the folder containing images
            transform: Transformations to apply to the scene image
            head_transform: Transformations to apply to the head crop
            input_size: Input image size for the model
            output_size: Output heatmap size
            test: Whether this is a test dataset
        """
        super(GESCAMCustomDataset, self).__init__()
        
        self.xml_path = xml_path
        self.image_folder = image_folder
        self.transform = transform
        self.head_transform = head_transform if head_transform else transform
        self.input_size = input_size
        self.output_size = output_size
        self.test = test
        
        # Parse annotations and create image mapping
        self.frames = parse_xml(xml_path)
        self.frame_to_image = create_frame_to_image_mapping(image_folder)
        
        # Create samples
        self.samples = self._create_samples()
        print(f"Created dataset with {len(self.samples)} samples")
        
    def _match_person_to_sight_line(self, person_box, polylines):
        """
        Match a person bounding box to the corresponding line of sight polyline
        
        Args:
            person_box: Dictionary containing person bounding box
            polylines: List of polyline dictionaries for the frame
            
        Returns:
            target_point: (x,y) tuple of gaze target or None if no match
            has_target: Boolean indicating if a match was found
        """
        # Find polylines labeled as "line of sight"
        sight_lines = [p for p in polylines if p["label"].lower() == "line of sight"]
        
        if not sight_lines:
            return None, False
        
        # Calculate person box center
        person_center_x = (person_box["xtl"] + person_box["xbr"]) / 2
        person_center_y = (person_box["ytl"] + person_box["ybr"]) / 2
        person_width = person_box["xbr"] - person_box["xtl"]
        
        # Find closest matching sight line
        best_match = None
        best_distance = float('inf')
        
        for polyline in sight_lines:
            points_str = polyline["points"]
            try:
                # Parse points from string format "x1,y1;x2,y2;..."
                points = [tuple(map(float, point.split(","))) for point in points_str.split(";")]
                
                if len(points) >= 2:  # Need at least start and end point
                    start_x, start_y = points[0]
                    end_x, end_y = points[-1]
                    
                    # Calculate distance from polyline start to person center
                    distance = np.sqrt((start_x - person_center_x)**2 + (start_y - person_center_y)**2)
                    
                    # Check if this is a good match (close to person center)
                    if distance < best_distance and distance < person_width * 1.5:
                        best_distance = distance
                        best_match = (end_x, end_y)  # Use end point as gaze target
            except Exception as e:
                # Print details for debugging
                print(f"Error parsing polyline points: {e}, points_str: {points_str}")
                continue
        
        return best_match, best_match is not None
        
    def _create_samples(self):
        """
        Create dataset samples from parsed frames
        
        Returns:
            samples: List of sample dictionaries
        """
        samples = []
        frames_with_persons = 0
        frames_with_sight_lines = 0
        
        for frame_id, frame_data in self.frames.items():
            # Skip frames without matching images
            if frame_id not in self.frame_to_image:
                continue
                
            image_path = self.frame_to_image[frame_id]
            width, height = frame_data["width"], frame_data["height"]
            
            # Check if there are person boxes in this frame
            person_boxes = [box for box in frame_data["boxes"] if "person" in box["label"].lower()]
            if person_boxes:
                frames_with_persons += 1
            
            # Check if there are line of sight polylines
            sight_lines = [p for p in frame_data["polylines"] if p["label"].lower() == "line of sight"]
            if sight_lines:
                frames_with_sight_lines += 1
            
            # Process each person box
            for person_box in person_boxes:
                # Find matching sight line
                gaze_target, has_target = self._match_person_to_sight_line(person_box, frame_data["polylines"])
                
                # Create sample
                sample = {
                    "frame_id": frame_id,
                    "image_path": image_path,
                    "width": width,
                    "height": height,
                    "head_bbox": [person_box["xtl"], person_box["ytl"], person_box["xbr"], person_box["ybr"]],
                    "gaze_target": gaze_target,
                    "in_frame": has_target
                }
                
                samples.append(sample)
        
        print(f"Statistics: {frames_with_persons} frames with person boxes, {frames_with_sight_lines} frames with sight lines")
        return samples
    
    def _create_head_position_channel(self, head_bbox, width, height):
        """
        Create a binary mask for head position
        """
        x1, y1, x2, y2 = head_bbox
        head_mask = torch.zeros(height, width)
        x1, y1, x2, y2 = int(max(0, x1)), int(max(0, y1)), int(min(width, x2)), int(min(height, y2))
        head_mask[y1:y2, x1:x2] = 1.0
        return head_mask
    
    def _create_gaze_heatmap(self, gaze_target, width, height):
        """
        Create a Gaussian heatmap at the gaze point
        """
        if not gaze_target:
            return torch.zeros(self.output_size, self.output_size)
        
        x, y = gaze_target
        
        # Scale coordinates to output size
        x = x * self.output_size / width
        y = y * self.output_size / height
        
        # Create meshgrid
        Y, X = torch.meshgrid(torch.arange(self.output_size), torch.arange(self.output_size), indexing='ij')
        
        # Create Gaussian heatmap
        sigma = 3.0
        heatmap = torch.exp(-((X - x) ** 2 + (Y - y) ** 2) / (2 * sigma ** 2))
        
        return heatmap
    
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        sample = self.samples[idx]
        
        # Load image
        try:
            img = Image.open(sample["image_path"]).convert('RGB')
        except Exception as e:
            print(f"Error loading image {sample['image_path']}: {e}")
            # Return a placeholder if image can't be loaded
            img = Image.new('RGB', (self.input_size, self.input_size), color='gray')
            
        width, height = sample["width"], sample["height"]
        
        # Extract head crop
        head_bbox = sample["head_bbox"]
        x1, y1, x2, y2 = head_bbox
        
        # Ensure bbox is within image bounds
        x1 = max(0, min(width-1, x1))
        y1 = max(0, min(height-1, y1))
        x2 = max(x1+1, min(width, x2))
        y2 = max(y1+1, min(height, y2))
        
        try:
            head_img = img.crop((int(x1), int(y1), int(x2), int(y2)))
        except Exception as e:
            print(f"Error cropping head: {e}, bbox: {head_bbox}, image size: {img.size}")
            head_img = Image.new('RGB', (100, 100), color='gray')
        
        # Create head position channel
        head_pos = self._create_head_position_channel(head_bbox, width, height)
        
        # Create gaze heatmap
        if sample["in_frame"] and sample["gaze_target"]:
            gaze_target = sample["gaze_target"]
            gaze_heatmap = self._create_gaze_heatmap(gaze_target, width, height)
            
            # Calculate gaze vector (from head center to gaze point)
            head_center_x = (x1 + x2) / 2 / width
            head_center_y = (y1 + y2) / 2 / height
            gaze_x = gaze_target[0] / width
            gaze_y = gaze_target[1] / height
            gaze_vector = torch.tensor([gaze_x - head_center_x, gaze_y - head_center_y])
        else:
            gaze_heatmap = torch.zeros(self.output_size, self.output_size)
            gaze_vector = torch.tensor([0.0, 0.0])  # Default for out-of-frame
        
        # Apply transformations
        if self.transform:
            img = self.transform(img)
                
        if self.head_transform:
            head_img = self.head_transform(head_img)
        
        # Resize head position to match input size
        head_pos = head_pos.unsqueeze(0)
        head_pos = F.interpolate(head_pos.unsqueeze(0), size=(self.input_size, self.input_size), 
                                 mode='nearest').squeeze(0)
        
        in_frame = torch.tensor([float(sample["in_frame"])])
        
        # For compatibility with existing code
        object_label = torch.tensor([0])  # placeholder
        
        # Instead of returning frame_id as last element (which might cause issues with batching),
        # return a metadata dictionary alongside the tensors
        metadata = {
            "frame_id": sample["frame_id"],
            "image_path": sample["image_path"],
            "head_bbox": sample["head_bbox"],
            "original_size": (width, height)
        }
        
        return img, head_img, head_pos, gaze_heatmap, in_frame, object_label, gaze_vector, metadata


def get_transforms(input_size=224, augment=True):
    """
    Get data transformations for training and validation
    """
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
    
    if augment:
        # Training transforms with augmentation
        transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3),
            transforms.ToTensor(),
            normalize
        ])
    else:
        # Validation/test transforms without augmentation
        transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            normalize
        ])
    
    return transform


def visualize_sample(sample, save_path=None):
    """
    Visualize a dataset sample
    
    Args:
        sample: Tuple of tensors from dataset __getitem__
        save_path: Path to save visualization (if None, displays inline)
    """
    img, head_img, head_pos, gaze_heatmap, in_frame, object_label, gaze_vector, metadata = sample
    
    # Denormalize image
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    
    img_vis = img.clone()
    img_vis = img_vis * std + mean
    img_vis = img_vis.permute(1, 2, 0).numpy()
    img_vis = np.clip(img_vis, 0, 1)
    
    head_img_vis = head_img.clone()
    head_img_vis = head_img_vis * std + mean
    head_img_vis = head_img_vis.permute(1, 2, 0).numpy()
    head_img_vis = np.clip(head_img_vis, 0, 1)
    
    # Create figure
    plt.figure(figsize=(15, 10))
    
    # Create a 2x3 grid
    plt.subplot(2, 3, 1)
    plt.imshow(img_vis)
    plt.title(f"Frame ID: {metadata['frame_id']}")
    plt.axis('off')
    
    plt.subplot(2, 3, 2)
    plt.imshow(head_img_vis)
    plt.title("Head/Person Crop")
    plt.axis('off')
    
    plt.subplot(2, 3, 3)
    plt.imshow(head_pos.squeeze().numpy(), cmap='gray')
    plt.title("Head Position Channel")
    plt.axis('off')
    
    plt.subplot(2, 3, 4)
    plt.imshow(img_vis)
    # Draw the head bounding box
    x1, y1, x2, y2 = metadata['head_bbox']
    head_width = x2 - x1
    head_height = y2 - y1
    
    # Calculate scale factors for drawing on the resized image
    h, w = img_vis.shape[:2]
    orig_w, orig_h = metadata['original_size']
    scale_x, scale_y = w/orig_w, h/orig_h
    
    # Draw scaled bounding box
    rect_x = x1 * scale_x
    rect_y = y1 * scale_y
    rect_w = head_width * scale_x
    rect_h = head_height * scale_y
    
    plt.gca().add_patch(plt.Rectangle((rect_x, rect_y), rect_w, rect_h, 
                                     fill=False, edgecolor='green', linewidth=2))
    
    # Draw gaze vector if in frame
    if in_frame.item():
        # Calculate center of head
        head_center_x = (rect_x + rect_x + rect_w) / 2
        head_center_y = (rect_y + rect_y + rect_h) / 2
        
        # Scale gaze vector for visualization
        scale = max(w, h) / 4
        gaze_end_x = head_center_x + gaze_vector[0].item() * scale
        gaze_end_y = head_center_y + gaze_vector[1].item() * scale
        
        plt.arrow(head_center_x, head_center_y, 
                 gaze_end_x - head_center_x, gaze_end_y - head_center_y, 
                 color='red', width=2, head_width=10)
        
    plt.title("Bounding Box & Gaze Vector")
    plt.axis('off')
    
    plt.subplot(2, 3, 5)
    plt.imshow(gaze_heatmap.numpy(), cmap='jet')
    plt.title(f"Gaze Heatmap (In-frame: {bool(in_frame.item())})")
    plt.axis('off')
    
    plt.subplot(2, 3, 6)
    # Original image with heatmap overlay
    plt.imshow(img_vis)
    
    # Resize heatmap to match image size for overlay
    heatmap_vis = gaze_heatmap.numpy()
    heatmap_vis = cv2.resize(heatmap_vis, (w, h))
    
    # Only show heatmap if gaze is in frame
    if in_frame.item():
        plt.imshow(heatmap_vis, cmap='jet', alpha=0.5)
    
    plt.title("Heatmap Overlay")
    plt.axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()


def test_dataset(xml_path, image_folder):
    """
    Test the dataset with visualization
    
    Args:
        xml_path: Path to the XML annotation file
        image_folder: Path to the folder with images
    """
    # Create transforms
    transform = get_transforms(augment=False)
    
    # Create dataset
    dataset = GESCAMCustomDataset(
        xml_path=xml_path,
        image_folder=image_folder,
        transform=transform
    )
    
    # Check dataset size
    print(f"\nDataset contains {len(dataset)} samples")
    
    # If dataset has samples, visualize some
    if len(dataset) > 0:
        print("\nVisualizing samples:")
        num_samples = min(3, len(dataset))
        for i in range(num_samples):
            # Get a sample
            sample_idx = i
            sample = dataset[sample_idx]
            
            # Visualize
            save_path = f"sample_{i}.png"
            visualize_sample(sample, save_path)
            print(f"Sample {i} visualization saved to {save_path}")
    else:
        print("No samples to visualize!")
    
    return dataset


# Example usage
if __name__ == "__main__":
    # Example paths (replace with actual paths)
    xml_path = xml_path
    image_folder = image_folder
    
    # Test the dataset
    dataset = test_dataset(xml_path, image_folder)
    
    # Create DataLoader if we have samples
    if len(dataset) > 0:
        batch_size = 4
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        
        # Test the DataLoader by fetching a batch
        for batch in dataloader:
            print(f"Successfully loaded a batch of size {len(batch[0])}")
            break

Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml
Root tag: annotations, with 307 child elements


Parsing frames: 100%|██████████| 305/305 [00:00<00:00, 9487.97it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 56 boxes and 14 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 305 frames
Found 305 images with extractable frame IDs





Statistics: 305 frames with person boxes, 305 frames with sight lines
Created dataset with 4575 samples

Dataset contains 4575 samples

Visualizing samples:
Sample 0 visualization saved to sample_0.png
Sample 1 visualization saved to sample_1.png
Sample 2 visualization saved to sample_2.png
Successfully loaded a batch of size 4


In [None]:
def main():
    # Set paths to your data
    xml_path = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml"
    image_folder = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/images"
    output_dir = "/kaggle/working"
    
    # Create transforms
    transform = get_transforms(augment=False)
    
    print("Creating dataset...")
    # Create dataset with the customized class
    dataset = GESCAMCustomDataset(
        xml_path=xml_path,
        image_folder=image_folder,
        transform=transform
    )
    
    # If we have samples, create a DataLoader and visualize
    if len(dataset) > 0:
        # Create DataLoader
        batch_size = 4
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
        print(f"Created DataLoader with batch size {batch_size}")
        
        # Visualize some samples
        print("Visualizing samples...")
        num_samples = min(5, len(dataset))
        for i in range(num_samples):
            # Choose random sample for variety
            sample_idx = np.random.randint(0, len(dataset))
            sample = dataset[sample_idx]
            
            # Visualize
            save_path = os.path.join(output_dir, f"sample_{i}.png")
            visualize_sample(sample, save_path)
            print(f"Sample {i} visualization saved to {save_path}")
        
        # Create a video visualization
        create_visualization_video(dataset, os.path.join(output_dir, "visualization.mp4"), 
                                  num_samples=min(30, len(dataset)), fps=2)
    else:
        print("No samples found in the dataset!")
    
    print("Done!")

def create_visualization_video(dataset, output_video_path, num_samples=20, fps=5):
    """
    Create a video visualizing dataset samples
    
    Args:
        dataset: Dataset instance
        output_video_path: Path to save the video
        num_samples: Number of samples to include
        fps: Frames per second
    """
    if len(dataset) == 0:
        print("Cannot create video with empty dataset")
        return
    
    print(f"Creating visualization video with {num_samples} samples...")
    
    # Create a temporary directory for frames
    temp_dir = "temp_viz_frames"
    os.makedirs(temp_dir, exist_ok=True)
    
    # Get evenly distributed sample indices
    indices = np.linspace(0, len(dataset)-1, num_samples).astype(int)
    
    # Visualize each sample
    for i, idx in enumerate(tqdm(indices, desc="Generating frames")):
        sample = dataset[idx]
        
        # Save visualization to temp file
        temp_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
        visualize_sample(sample, temp_path)
    
    # Get size of the first frame to set video dimensions
    first_frame = cv2.imread(os.path.join(temp_dir, "frame_0000.png"))
    height, width, _ = first_frame.shape
    
    # Create video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
    
    # Add frames to video
    for i in range(len(indices)):
        frame_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
        frame = cv2.imread(frame_path)
        video_writer.write(frame)
    
    # Release video writer
    video_writer.release()
    
    # Clean up temporary files
    for i in range(len(indices)):
        frame_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
        if os.path.exists(frame_path):
            os.remove(frame_path)
    if os.path.exists(temp_dir):
        os.rmdir(temp_dir)
    
    print(f"Visualization video saved to {output_video_path}")

if __name__ == "__main__":
    main()

## COMBINING MULTIPLE DATA FOLDERS

In [4]:
def combine_datasets(xml_paths, image_folders, transform):
    """
    Combine multiple datasets into one
    
    Args:
        xml_paths: List of paths to XML annotation files
        image_folders: List of paths to image folders
        transform: Transformations to apply
        
    Returns:
        combined_dataset: Combined dataset
    """
    all_datasets = []
    
    for xml_path, image_folder in zip(xml_paths, image_folders):
        dataset = GESCAMCustomDataset(
            xml_path=xml_path,
            image_folder=image_folder,
            transform=transform
        )
        all_datasets.append(dataset)
    
    # Create a simple wrapper dataset class
    class CombinedDataset(torch.utils.data.Dataset):
        def __init__(self, datasets):
            self.datasets = datasets
            self.lengths = [len(d) for d in datasets]
            self.cumulative_lengths = [0]
            
            for length in self.lengths:
                self.cumulative_lengths.append(self.cumulative_lengths[-1] + length)
            
        def __len__(self):
            return self.cumulative_lengths[-1]
        
        def __getitem__(self, idx):
            # Find which dataset this index belongs to
            dataset_idx = bisect.bisect_right(self.cumulative_lengths, idx) - 1
            sample_idx = idx - self.cumulative_lengths[dataset_idx]
            return self.datasets[dataset_idx][sample_idx]
    
    return CombinedDataset(all_datasets)

In [5]:
# Define all your data paths
xml_paths = [
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video02(0-300)/annotations.xml",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video02(301-600)/annotations.xml",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video03_final/annotations.xml",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video04_final/annotations.xml",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video05_final/annotations.xml"
    
]

image_folders = [
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/images",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video02(0-300)/images",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video02(301-600)/images",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video03_final/images",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video04_final/images",
    "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video05_final/images"
    
]

# Create transforms
transform = get_transforms(augment=True)

# Combine datasets
combined_dataset = combine_datasets(xml_paths, image_folders, transform)

# Rest of your code remains the same, just use combined_dataset instead of dataset
# train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size], generator=generator)

Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml
Root tag: annotations, with 307 child elements


Parsing frames: 100%|██████████| 305/305 [00:00<00:00, 8482.84it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 56 boxes and 14 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 305 frames
Found 305 images with extractable frame IDs





Statistics: 305 frames with person boxes, 305 frames with sight lines
Created dataset with 4575 samples
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video02(0-300)/annotations.xml
Root tag: annotations, with 308 child elements


Parsing frames: 100%|██████████| 306/306 [00:00<00:00, 9355.32it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 52 boxes and 14 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 306 frames
Found 306 images with extractable frame IDs





Statistics: 306 frames with person boxes, 306 frames with sight lines
Created dataset with 4284 samples
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video02(301-600)/annotations.xml
Root tag: annotations, with 307 child elements


Parsing frames: 100%|██████████| 305/305 [00:00<00:00, 9357.63it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 49 boxes and 14 polylines in first frame
Sample box labels: ['Water Dispenser', 'monitor', 'monitor', 'monitor', 'monitor']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 305 frames
Found 305 images with extractable frame IDs





Statistics: 305 frames with person boxes, 305 frames with sight lines
Created dataset with 4555 samples
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video03_final/annotations.xml
Root tag: annotations, with 603 child elements


Parsing frames: 100%|██████████| 601/601 [00:00<00:00, 8558.99it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 59 boxes and 14 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 601 frames
Found 601 images with extractable frame IDs





Statistics: 601 frames with person boxes, 601 frames with sight lines
Created dataset with 9015 samples
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video04_final/annotations.xml
Root tag: annotations, with 602 child elements


Parsing frames: 100%|██████████| 600/600 [00:00<00:00, 10901.19it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 45 boxes and 13 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 600 frames
Found 600 images with extractable frame IDs





Statistics: 600 frames with person boxes, 600 frames with sight lines
Created dataset with 8400 samples
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video05_final/annotations.xml
Root tag: annotations, with 602 child elements


Parsing frames: 100%|██████████| 600/600 [00:00<00:00, 13419.98it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 35 boxes and 12 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 600 frames
Found 600 images with extractable frame IDs





Statistics: 600 frames with person boxes, 600 frames with sight lines
Created dataset with 7260 samples


## MODEL ARCHITECTURE

In [6]:
class SoftAttention(nn.Module):
    """
    Soft attention module for attending to scene features based on head features
    """
    def __init__(self, head_channels=256, output_size=(7, 7)):
        super(SoftAttention, self).__init__()
        self.output_h, self.output_w = output_size
        
        # Attention layers
        self.attention = nn.Sequential(
            nn.Linear(head_channels, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, self.output_h * self.output_w),
            nn.Sigmoid()
        )
    
    def forward(self, head_features):
        # Input head_features shape: [batch_size, head_channels]
        batch_size = head_features.size(0)
        
        # Generate attention weights
        attn_weights = self.attention(head_features)
        
        # Reshape to spatial attention map
        attn_weights = attn_weights.view(batch_size, 1, self.output_h, self.output_w)
        
        return attn_weights


class MSGESCAMModel(nn.Module):
    """
    Multi-Stream GESCAM architecture for gaze estimation in classroom settings
    """
    def __init__(self, pretrained=True, output_size=64):
        super(MSGESCAMModel, self).__init__()
        
        # Store the output size
        self.output_size = output_size
        
        # Feature dimensions
        self.backbone_dim = 512  # ResNet18 outputs 512 feature channels
        self.feature_dim = 256
        
        # Downsampled feature map size
        self.map_size = 7  # ResNet outputs 7x7 feature maps
        
        # === Scene Pathway ===
        # Load a pre-trained ResNet18 without the final layer
        self.scene_backbone = models.resnet18(pretrained=pretrained)
        
        # Save the original conv1 weights
        original_conv1_weight = self.scene_backbone.conv1.weight.clone()
        
        # Create a new conv1 layer that accepts 4 channels (RGB + head position)
        self.scene_backbone.conv1 = nn.Conv2d(
            4, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        
        # Initialize with the pre-trained weights
        with torch.no_grad():
            self.scene_backbone.conv1.weight[:, :3] = original_conv1_weight
            # Initialize the new channel with small random values
            self.scene_backbone.conv1.weight[:, 3] = 0.01 * torch.randn_like(self.scene_backbone.conv1.weight[:, 0])
        
        # Remove the final FC layer from the scene backbone
        self.scene_features = nn.Sequential(*list(self.scene_backbone.children())[:-1])
        
        # Add a FC layer to transform from backbone_dim to feature_dim
        self.scene_fc = nn.Linear(self.backbone_dim, self.feature_dim)
        
        # === Head Pathway ===
        # Load another pre-trained ResNet18 for the head pathway
        self.head_backbone = models.resnet18(pretrained=pretrained)
        
        # Remove the final FC layer from the head backbone
        self.head_features = nn.Sequential(*list(self.head_backbone.children())[:-1])
        
        # Add a FC layer to transform from backbone_dim to feature_dim
        self.head_fc = nn.Linear(self.backbone_dim, self.feature_dim)
        
        # === Objects Mask Enhancement (optional) ===
        # This takes an objects mask (with channels for different object classes)
        self.objects_conv = nn.Conv2d(11, 64, kernel_size=3, stride=2, padding=1)  # 11 object categories
        
        # Soft attention mechanism
        self.attention = SoftAttention(head_channels=self.feature_dim, output_size=(self.map_size, self.map_size))
        
        # === Fusion and Encoding ===
        # Fusion of attended scene features and head features
        self.encode = nn.Sequential(
            nn.Conv2d(self.backbone_dim + self.feature_dim, self.feature_dim, kernel_size=1),
            nn.BatchNorm2d(self.feature_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.feature_dim, self.feature_dim, kernel_size=3, padding=1),
            nn.BatchNorm2d(self.feature_dim),
            nn.ReLU(inplace=True)
        )
        
        # Calculate the number of deconvolution layers needed
        # Each layer doubles the size, so we need log2(output_size / 7) layers
        self.num_deconv_layers = max(1, int(math.log2(output_size / 7)) + 1)
        
        # === Decoding for heatmap generation ===
        deconv_layers = []
        in_channels = self.feature_dim
        out_size = self.map_size
        
        # Create deconvolution layers
        for i in range(self.num_deconv_layers - 1):
            # Calculate output channels
            out_channels = max(32, in_channels // 2)
            
            # Add deconv layer
            deconv_layers.extend([
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True)
            ])
            
            in_channels = out_channels
            out_size *= 2
        
        # Final layer to adjust to exact output size
        if out_size != output_size:
            # Add a final layer with correct stride to reach exactly output_size
            scale_factor = output_size / out_size
            stride = 2 if scale_factor > 1 else 1
            output_padding = 1 if scale_factor > 1 else 0
            
            deconv_layers.extend([
                nn.ConvTranspose2d(
                    in_channels, 1, kernel_size=3, 
                    stride=stride, padding=1, output_padding=output_padding
                )
            ])
        else:
            # If we're already at the right size, just add a 1x1 conv
            deconv_layers.append(nn.Conv2d(in_channels, 1, kernel_size=1))
        
        self.decode = nn.Sequential(*deconv_layers)
        
        # === In-frame probability prediction ===
        self.in_frame_fc = nn.Sequential(
            nn.Linear(self.feature_dim, 128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(128, 1)
        )
    
    def forward(self, scene_img, head_img, head_pos, objects_mask=None):
        """
        Forward pass through the MS-GESCAM network
        
        Args:
            scene_img: Scene image tensor [batch_size, 3, H, W]
            head_img: Head crop tensor [batch_size, 3, H, W]
            head_pos: Head position mask [batch_size, 1, H, W]
            objects_mask: Optional mask of object categories [batch_size, num_categories, H, W]
            
        Returns:
            heatmap: Predicted gaze heatmap [batch_size, 1, output_size, output_size]
            in_frame: Probability of gaze target being in frame [batch_size, 1]
        """
        batch_size = scene_img.size(0)
        
        # === Process scene pathway ===
        # Concatenate scene image and head position channel
        scene_input = torch.cat([scene_img, head_pos], dim=1)
        
        # Process through ResNet layers until layer4 (skipping the final global pooling and FC)
        x = self.scene_backbone.conv1(scene_input)
        x = self.scene_backbone.bn1(x)
        x = self.scene_backbone.relu(x)
        x = self.scene_backbone.maxpool(x)
        
        x = self.scene_backbone.layer1(x)
        x = self.scene_backbone.layer2(x)
        x = self.scene_backbone.layer3(x)
        scene_features_map = self.scene_backbone.layer4(x)  # [batch_size, 512, 7, 7]
        
        # Global average pooling for scene features
        scene_vector = F.adaptive_avg_pool2d(scene_features_map, (1, 1)).view(batch_size, -1)
        scene_features = self.scene_fc(scene_vector)  # [batch_size, feature_dim]
        
        # === Process head pathway ===
        # Process through the entire head features extractor
        head_vector = self.head_features(head_img).view(batch_size, -1)  # [batch_size, 512]
        head_features = self.head_fc(head_vector)  # [batch_size, feature_dim]
        
        # Process objects mask if provided
        if objects_mask is not None:
            obj_features = self.objects_conv(objects_mask)
            # Resize to match scene features map if needed
            if obj_features.size(2) != scene_features_map.size(2):
                obj_features = F.adaptive_avg_pool2d(
                    obj_features, (scene_features_map.size(2), scene_features_map.size(3))
                )
            # Add object features to scene features
            scene_features_map = scene_features_map + obj_features
        
        # Generate attention map from head features
        attn_weights = self.attention(head_features)  # [batch_size, 1, 7, 7]
        
        # Apply attention to scene features map
        attended_scene = scene_features_map * attn_weights  # [batch_size, 512, 7, 7]
        
        # Reshape head features to concatenate with scene features
        head_features_map = head_features.view(batch_size, self.feature_dim, 1, 1)
        head_features_map = head_features_map.expand(-1, -1, self.map_size, self.map_size)
        
        # Concatenate attended scene features and head features
        concat_features = torch.cat([attended_scene, head_features_map], dim=1)  # [batch_size, 512+256, 7, 7]
        
        # Encode the concatenated features
        encoded = self.encode(concat_features)  # [batch_size, 256, 7, 7]
        
        # Predict in-frame probability
        in_frame = self.in_frame_fc(head_features)
        
        # Decode to get the final heatmap
        heatmap = self.decode(encoded)
        
        # Ensure output size is correct
        if heatmap.size(2) != self.output_size or heatmap.size(3) != self.output_size:
            heatmap = F.interpolate(
                heatmap, 
                size=(self.output_size, self.output_size), 
                mode='bilinear', 
                align_corners=True
            )
        
        # Apply sigmoid to get values between 0 and 1
        heatmap = torch.sigmoid(heatmap)
        
        return heatmap, in_frame


class CombinedLoss(nn.Module):
    """
    Combined loss function for gaze estimation
    """
    def __init__(self, heatmap_weight=1.0, in_frame_weight=1.0, angular_weight=0.5):
        super(CombinedLoss, self).__init__()
        self.heatmap_weight = heatmap_weight
        self.in_frame_weight = in_frame_weight
        self.angular_weight = angular_weight
        
        self.mse_loss = nn.MSELoss(reduction='none')
        self.bce_loss = nn.BCEWithLogitsLoss()
    
    def forward(self, pred_heatmap, target_heatmap, pred_in_frame, target_in_frame, 
                pred_vector=None, target_vector=None):
        """
        Combined loss function
        
        Args:
            pred_heatmap: Predicted gaze heatmap [batch_size, 1, H, W]
            target_heatmap: Target gaze heatmap [batch_size, H, W]
            pred_in_frame: Predicted in-frame probability [batch_size, 1]
            target_in_frame: Target in-frame label [batch_size, 1]
            pred_vector: Optional predicted gaze vector [batch_size, 2]
            target_vector: Optional target gaze vector [batch_size, 2]
            
        Returns:
            total_loss: Combined loss
            loss_dict: Dictionary with individual loss components
        """
        batch_size = pred_heatmap.size(0)
        
        # Check and fix size mismatches
        if pred_heatmap.size(-1) != target_heatmap.size(-1) or pred_heatmap.size(-2) != target_heatmap.size(-2):
            # Resize prediction to match target
            pred_heatmap = F.interpolate(
                pred_heatmap,
                size=(target_heatmap.size(1), target_heatmap.size(2)),
                mode='bilinear',
                align_corners=True
            )
        
        # Reshape heatmaps if needed
        if pred_heatmap.size(1) == 1:
            pred_heatmap = pred_heatmap.squeeze(1)
        
        # Heatmap loss (MSE) - only for in-frame samples
        heatmap_loss = self.mse_loss(pred_heatmap, target_heatmap)
        heatmap_loss = heatmap_loss.mean(dim=(1, 2))  # Average over spatial dimensions
        
        # Apply in-frame masking
        masked_heatmap_loss = heatmap_loss * target_in_frame.squeeze()
        
        # Average over valid samples
        num_valid = max(1, target_in_frame.sum().item())  # Avoid division by zero
        heatmap_loss = masked_heatmap_loss.sum() / num_valid
        
        # In-frame prediction loss (BCE)
        in_frame_loss = self.bce_loss(pred_in_frame, target_in_frame)
        
        # Angular loss (if vectors are provided)
        angular_loss = torch.tensor(0.0, device=pred_heatmap.device)
        if pred_vector is not None and target_vector is not None:
            # Compute cosine similarity between predicted and target vectors
            target_in_frame_bool = target_in_frame.squeeze().bool()
            
            if target_in_frame_bool.sum() > 0:
                # Only compute for in-frame samples
                pred_vec_valid = pred_vector[target_in_frame_bool]
                target_vec_valid = target_vector[target_in_frame_bool]
                
                # Normalize vectors
                pred_norm = torch.norm(pred_vec_valid, dim=1, keepdim=True)
                target_norm = torch.norm(target_vec_valid, dim=1, keepdim=True)
                
                # Avoid division by zero
                pred_norm = torch.clamp(pred_norm, min=1e-7)
                target_norm = torch.clamp(target_norm, min=1e-7)
                
                pred_vec_norm = pred_vec_valid / pred_norm
                target_vec_norm = target_vec_valid / target_norm
                
                # Cosine similarity (dot product of normalized vectors)
                cos_sim = torch.sum(pred_vec_norm * target_vec_norm, dim=1)
                
                # Angular loss = 1 - cos_sim
                angular_loss = 1.0 - cos_sim.mean()
        
        # Combine losses
        total_loss = (
            self.heatmap_weight * heatmap_loss + 
            self.in_frame_weight * in_frame_loss + 
            self.angular_weight * angular_loss
        )
        
        # Create loss dictionary for logging
        loss_dict = {
            'total_loss': total_loss.item(),
            'heatmap_loss': heatmap_loss.item(),
            'in_frame_loss': in_frame_loss.item(),
            'angular_loss': angular_loss.item() if isinstance(angular_loss, torch.Tensor) else angular_loss
        }
        
        return total_loss, loss_dict


# Example usage to test the model
def test_model():
    # Create model
    model = MSGESCAMModel(pretrained=True, output_size=64)
    
    # Create sample batch
    batch_size = 2
    scene_img = torch.randn(batch_size, 3, 224, 224)
    head_img = torch.randn(batch_size, 3, 224, 224)
    head_pos = torch.randn(batch_size, 1, 224, 224)
    
    # Forward pass
    pred_heatmap, pred_in_frame = model(scene_img, head_img, head_pos)
    
    print(f"Model test successful")
    print(f"Predicted heatmap shape: {pred_heatmap.shape}")
    print(f"Predicted in-frame shape: {pred_in_frame.shape}")
    
    return model

if __name__ == "__main__":
    model = test_model()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 164MB/s] 


Model test successful
Predicted heatmap shape: torch.Size([2, 1, 64, 64])
Predicted in-frame shape: torch.Size([2, 1])


In [7]:
# Set random seeds for reproducibility
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Early stopping class
class EarlyStopping:
    """Early stopping to prevent overfitting"""
    def __init__(self, patience=5, min_delta=0, verbose=True):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.early_stop = False
        
    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    """Train for one epoch"""
    model.train()
    train_losses = []
    
    train_pbar = tqdm(train_loader, desc="Training")
    for batch in train_pbar:
        # Unpack the batch
        scene_img, head_img, head_pos, target_heatmap, target_in_frame, _, target_vector, _ = batch
        
        # Move to device
        scene_img = scene_img.to(device)
        head_img = head_img.to(device)
        head_pos = head_pos.to(device)
        target_heatmap = target_heatmap.to(device)
        target_in_frame = target_in_frame.to(device)
        target_vector = target_vector.to(device)
        
        # Forward pass
        pred_heatmap, pred_in_frame = model(scene_img, head_img, head_pos)
        
        # Compute loss
        loss, loss_dict = criterion(
            pred_heatmap, target_heatmap, 
            pred_in_frame, target_in_frame, 
            None, target_vector
        )
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Track metrics
        train_losses.append(loss_dict)
        train_pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    # Calculate metrics
    avg_loss = sum(d['total_loss'] for d in train_losses) / len(train_losses)
    avg_heatmap_loss = sum(d['heatmap_loss'] for d in train_losses) / len(train_losses)
    avg_in_frame_loss = sum(d['in_frame_loss'] for d in train_losses) / len(train_losses)
    
    return {
        'loss': avg_loss,
        'heatmap_loss': avg_heatmap_loss,
        'in_frame_loss': avg_in_frame_loss
    }

def validate(model, val_loader, criterion, device):
    """Validate the model"""
    model.eval()
    val_losses = []
    
    val_pbar = tqdm(val_loader, desc="Validation")
    with torch.no_grad():
        for batch in val_pbar:
            # Unpack the batch
            scene_img, head_img, head_pos, target_heatmap, target_in_frame, _, target_vector, _ = batch
            
            # Move to device
            scene_img = scene_img.to(device)
            head_img = head_img.to(device)
            head_pos = head_pos.to(device)
            target_heatmap = target_heatmap.to(device)
            target_in_frame = target_in_frame.to(device)
            target_vector = target_vector.to(device)
            
            # Forward pass
            pred_heatmap, pred_in_frame = model(scene_img, head_img, head_pos)
            
            # Compute loss
            loss, loss_dict = criterion(
                pred_heatmap, target_heatmap, 
                pred_in_frame, target_in_frame, 
                None, target_vector
            )
            
            # Track metrics
            val_losses.append(loss_dict)
            val_pbar.set_postfix(loss=f"{loss.item():.4f}")
    
    # Calculate metrics
    avg_loss = sum(d['total_loss'] for d in val_losses) / len(val_losses)
    avg_heatmap_loss = sum(d['heatmap_loss'] for d in val_losses) / len(val_losses)
    avg_in_frame_loss = sum(d['in_frame_loss'] for d in val_losses) / len(val_losses)
    
    return {
        'loss': avg_loss,
        'heatmap_loss': avg_heatmap_loss,
        'in_frame_loss': avg_in_frame_loss
    }

def visualize_predictions(model, dataset, device, indices=None, num_samples=5, save_dir=None):
    """Visualize model predictions"""
    model.eval()
    
    if indices is None:
        # Choose random samples
        indices = np.random.choice(len(dataset), num_samples, replace=False)
    
    with torch.no_grad():
        for i, idx in enumerate(indices):
            # Get a sample
            sample = dataset[idx]
            scene_img, head_img, head_pos, target_heatmap, target_in_frame, _, target_vector, metadata = sample
            
            # Add batch dimension
            scene_img = scene_img.unsqueeze(0).to(device)
            head_img = head_img.unsqueeze(0).to(device)
            head_pos = head_pos.unsqueeze(0).to(device)
            
            # Forward pass
            pred_heatmap, pred_in_frame = model(scene_img, head_img, head_pos)
            
            # Convert predictions to numpy
            pred_heatmap = pred_heatmap.squeeze().cpu().numpy()
            pred_in_frame_prob = torch.sigmoid(pred_in_frame).squeeze().cpu().numpy()
            
            # Create figure
            plt.figure(figsize=(15, 10))
            
            # Denormalize image
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            
            img_vis = scene_img.squeeze().cpu()
            img_vis = img_vis * std + mean
            img_vis = img_vis.permute(1, 2, 0).numpy()
            img_vis = np.clip(img_vis, 0, 1)
            
            # Original image
            plt.subplot(2, 3, 1)
            plt.imshow(img_vis)
            plt.title(f"Frame {metadata['frame_id']}")
            plt.axis('off')
            
            # Head crop
            head_img_vis = head_img.squeeze().cpu()
            head_img_vis = head_img_vis * std + mean
            head_img_vis = head_img_vis.permute(1, 2, 0).numpy()
            head_img_vis = np.clip(head_img_vis, 0, 1)
            
            plt.subplot(2, 3, 2)
            plt.imshow(head_img_vis)
            plt.title("Head Crop")
            plt.axis('off')
            
            # Ground truth heatmap
            plt.subplot(2, 3, 3)
            plt.imshow(target_heatmap.numpy(), cmap='jet')
            plt.title(f"GT Heatmap (In-frame: {bool(target_in_frame.item())})")
            plt.axis('off')
            
            # Predicted heatmap
            plt.subplot(2, 3, 4)
            plt.imshow(pred_heatmap, cmap='jet')
            plt.title(f"Pred Heatmap (In-frame: {pred_in_frame_prob:.2f})")
            plt.axis('off')
            
            # Overlay on original image
            plt.subplot(2, 3, 5)
            plt.imshow(img_vis)
            plt.imshow(pred_heatmap, cmap='jet', alpha=0.5)
            plt.title("Prediction Overlay")
            plt.axis('off')
            
            # Error visualization
            plt.subplot(2, 3, 6)
            error_map = np.abs(pred_heatmap - target_heatmap.numpy())
            plt.imshow(error_map, cmap='hot')
            plt.title("Prediction Error")
            plt.axis('off')
            
            plt.tight_layout()
            
            # Save or display
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
                plt.savefig(os.path.join(save_dir, f"pred_{i}_sample_{idx}.png"))
                plt.close()
            else:
                plt.show()

def plot_training_history(history, save_path=None):
    """Plot training history"""
    plt.figure(figsize=(12, 5))
    
    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Val Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Total Loss')
    plt.legend()
    plt.grid(True)
    
    # Plot component losses
    plt.subplot(1, 2, 2)
    plt.plot(history['train_heatmap_loss'], label='Train Heatmap Loss')
    plt.plot(history['val_heatmap_loss'], label='Val Heatmap Loss')
    plt.plot(history['train_in_frame_loss'], label='Train In-Frame Loss')
    plt.plot(history['val_in_frame_loss'], label='Val In-Frame Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Component Losses')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

def test_model():
    """Simple test to verify the model architecture"""
    # Create model
    model = MSGESCAMModel(pretrained=True, output_size=64)
    
    # Create sample batch
    batch_size = 2
    scene_img = torch.randn(batch_size, 3, 224, 224)
    head_img = torch.randn(batch_size, 3, 224, 224)
    head_pos = torch.randn(batch_size, 1, 224, 224)
    
    # Forward pass
    pred_heatmap, pred_in_frame = model(scene_img, head_img, head_pos)
    
    print(f"Model test successful")
    print(f"Predicted heatmap shape: {pred_heatmap.shape}")
    print(f"Predicted in-frame shape: {pred_in_frame.shape}")

# Training function
def train_gescam_model(xml_path, image_folder, output_dir, batch_size=8, epochs=20, 
                      lr=1e-4, val_split=0.2, seed=42):
    """
    Train the MS-GESCAM model
    
    Args:
        xml_path: Path to XML annotations
        image_folder: Path to image folder
        output_dir: Output directory
        batch_size: Batch size
        epochs: Number of epochs
        lr: Learning rate
        val_split: Validation split ratio
        seed: Random seed
    """
    # Set random seed
    set_seed(seed)
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Create transforms
    transform = get_transforms(augment=True)
    val_transform = get_transforms(augment=False)
    
    # Load dataset
    print("Loading dataset...")
    full_dataset = GESCAMCustomDataset(
        xml_path=xml_path,
        image_folder=image_folder,
        transform=transform
    )
    
    # Split dataset
    # val_size = int(val_split * len(full_dataset))
    # train_size = len(full_dataset) - val_size

    val_size = int(val_split * len(combined_dataset))
    train_size = len(combined_dataset) - val_size
    
    # Use different random seeds for train and validation
    generator = torch.Generator().manual_seed(seed)
    #train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size], generator=generator)
    train_dataset, val_dataset = random_split(combined_dataset, [train_size, val_size], generator=generator)
    
    # Create separate validation dataset with different transforms
    val_dataset_with_transform = GESCAMCustomDataset(
        xml_path=xml_path,
        image_folder=image_folder,
        transform=val_transform
    )
    
    # Use the same indices for validation
    val_indices = [i for i in range(len(full_dataset)) if i not in train_dataset.indices]
    
    # Create a subset for validation with the correct indices
    class IndexSubset:
        def __init__(self, dataset, indices):
            self.dataset = dataset
            self.indices = indices
        
        def __getitem__(self, idx):
            return self.dataset[self.indices[idx]]
        
        def __len__(self):
            return len(self.indices)
    
    val_dataset = IndexSubset(val_dataset_with_transform, val_indices)
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size, 
        shuffle=True, 
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    val_loader = DataLoader(
        val_dataset, 
        batch_size=batch_size, 
        shuffle=False, 
        num_workers=2,
        pin_memory=True if torch.cuda.is_available() else False
    )
    
    print(f"Dataset split: {train_size} train, {val_size} validation")
    
    # Create model
    print("Creating model...")
    model = MSGESCAMModel(pretrained=True, output_size=64)
    model = model.to(device)
    
    # Create loss function
    criterion = CombinedLoss(
        heatmap_weight=1.0, 
        in_frame_weight=1.0, 
        angular_weight=0.5
    )
    
    # Create optimizer
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # Create learning rate scheduler
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        mode='min', 
        factor=0.5, 
        patience=3, 
        verbose=True
    )
    
    # Create early stopping
    early_stopping = EarlyStopping(patience=7, verbose=True)
    
    # Training history
    history = {
        'train_loss': [], 
        'val_loss': [],
        'train_heatmap_loss': [], 
        'val_heatmap_loss': [],
        'train_in_frame_loss': [], 
        'val_in_frame_loss': []
    }
    
    # Best model state
    best_val_loss = float('inf')
    
    # Train model
    print(f"Training for {epochs} epochs...")
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        
        # Train
        train_metrics = train_one_epoch(model, train_loader, criterion, optimizer, device)
        print(f"Train loss: {train_metrics['loss']:.4f}, "
              f"Heatmap loss: {train_metrics['heatmap_loss']:.4f}, "
              f"In-frame loss: {train_metrics['in_frame_loss']:.4f}")
        
        # Validate
        val_metrics = validate(model, val_loader, criterion, device)
        print(f"Val loss: {val_metrics['loss']:.4f}, "
              f"Heatmap loss: {val_metrics['heatmap_loss']:.4f}, "
              f"In-frame loss: {val_metrics['in_frame_loss']:.4f}")
        
        # Update history
        history['train_loss'].append(train_metrics['loss'])
        history['val_loss'].append(val_metrics['loss'])
        history['train_heatmap_loss'].append(train_metrics['heatmap_loss'])
        history['val_heatmap_loss'].append(val_metrics['heatmap_loss'])
        history['train_in_frame_loss'].append(train_metrics['in_frame_loss'])
        history['val_in_frame_loss'].append(val_metrics['in_frame_loss'])
        
        # Step learning rate scheduler
        scheduler.step(val_metrics['loss'])
        
        # Save checkpoint
        checkpoint_path = os.path.join(output_dir, f'checkpoint_epoch_{epoch+1}.pt')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'train_loss': train_metrics['loss'],
            'val_loss': val_metrics['loss'],
            'history': history
        }, checkpoint_path)
        
        # Save best model
        if val_metrics['loss'] < best_val_loss:
            best_val_loss = val_metrics['loss']
            best_model_path = os.path.join(output_dir, 'best_model.pt')
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_loss': val_metrics['loss']
            }, best_model_path)
            print(f"Saved new best model with validation loss: {best_val_loss:.4f}")
        
        # Check early stopping
        early_stopping(val_metrics['loss'])
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break
    
    # Plot training history
    history_path = os.path.join(output_dir, 'training_history.png')
    plot_training_history(history, history_path)
    print(f"Training history plot saved to {history_path}")
    
    # Load best model for final evaluation
    checkpoint = torch.load(os.path.join(output_dir, 'best_model.pt'))
    model.load_state_dict(checkpoint['model_state_dict'])
    
    # Visualize predictions on validation set
    vis_dir = os.path.join(output_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)
    print(f"Generating visualizations in {vis_dir}...")
    
    # Choose a few random samples from validation set
    val_samples = np.random.choice(len(val_dataset), min(10, len(val_dataset)), replace=False)
    visualize_predictions(model, val_dataset, device, indices=val_samples, save_dir=vis_dir)
    
    print("Training complete!")
    
    return model, history

# Main execution
if __name__ == "__main__":
    # For Kaggle, set your paths here
    xml_path = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml"
    image_folder = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/images"
    output_dir = "/kaggle/working/gescam_output"
    
    # Optional: Just test the model architecture
    # test_model()
    
    # Train the model
    model, history = train_gescam_model(
        xml_path=xml_path,
        image_folder=image_folder,
        output_dir=output_dir,
        batch_size=8,  # Adjust based on your GPU memory
        epochs=10,     # Adjust as needed
        lr=1e-4
    )

Using device: cuda
Loading dataset...
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml
Root tag: annotations, with 307 child elements


Parsing frames: 100%|██████████| 305/305 [00:00<00:00, 9200.35it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 56 boxes and 14 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 305 frames
Found 305 images with extractable frame IDs





Statistics: 305 frames with person boxes, 305 frames with sight lines
Created dataset with 4575 samples
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/train_subset/Classroom 01/task_classroom_01_video01(301-600)/annotations.xml
Root tag: annotations, with 307 child elements


Parsing frames: 100%|██████████| 305/305 [00:00<00:00, 9005.72it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 56 boxes and 14 polylines in first frame
Sample box labels: ['person1', 'person2', 'person3', 'person4', 'person5']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 305 frames
Found 305 images with extractable frame IDs





Statistics: 305 frames with person boxes, 305 frames with sight lines
Created dataset with 4575 samples
Dataset split: 30472 train, 7617 validation
Creating model...




Training for 10 epochs...

Epoch 1/10


Training: 100%|██████████| 3809/3809 [19:42<00:00,  3.22it/s, loss=0.0010]


Train loss: 0.0143, Heatmap loss: 0.0069, In-frame loss: 0.0074


Validation: 100%|██████████| 117/117 [00:38<00:00,  3.08it/s, loss=0.0007]


Val loss: 0.0013, Heatmap loss: 0.0013, In-frame loss: 0.0000
Saved new best model with validation loss: 0.0013

Epoch 2/10


Training: 100%|██████████| 3809/3809 [18:47<00:00,  3.38it/s, loss=0.0028]


Train loss: 0.0056, Heatmap loss: 0.0012, In-frame loss: 0.0044


Validation: 100%|██████████| 117/117 [00:37<00:00,  3.16it/s, loss=0.0005]


Val loss: 0.0013, Heatmap loss: 0.0013, In-frame loss: 0.0000
EarlyStopping counter: 1 out of 7

Epoch 3/10


Training: 100%|██████████| 3809/3809 [20:11<00:00,  3.14it/s, loss=0.0004]


Train loss: 0.0054, Heatmap loss: 0.0009, In-frame loss: 0.0045


Validation: 100%|██████████| 117/117 [00:37<00:00,  3.12it/s, loss=0.0012]


Val loss: 0.0011, Heatmap loss: 0.0010, In-frame loss: 0.0001
Saved new best model with validation loss: 0.0011

Epoch 4/10


Training: 100%|██████████| 3809/3809 [22:01<00:00,  2.88it/s, loss=0.0019]


Train loss: 0.0038, Heatmap loss: 0.0007, In-frame loss: 0.0031


Validation: 100%|██████████| 117/117 [00:38<00:00,  3.05it/s, loss=0.0004]


Val loss: 0.0008, Heatmap loss: 0.0008, In-frame loss: 0.0000
Saved new best model with validation loss: 0.0008

Epoch 5/10


Training: 100%|██████████| 3809/3809 [21:55<00:00,  2.89it/s, loss=0.0001]


Train loss: 0.0029, Heatmap loss: 0.0006, In-frame loss: 0.0023


Validation: 100%|██████████| 117/117 [00:29<00:00,  3.94it/s, loss=0.0003]


Val loss: 0.0007, Heatmap loss: 0.0007, In-frame loss: 0.0000
Saved new best model with validation loss: 0.0007

Epoch 6/10


Training: 100%|██████████| 3809/3809 [18:27<00:00,  3.44it/s, loss=0.0001]


Train loss: 0.0026, Heatmap loss: 0.0005, In-frame loss: 0.0021


Validation: 100%|██████████| 117/117 [00:38<00:00,  3.07it/s, loss=0.0042]


Val loss: 0.0009, Heatmap loss: 0.0009, In-frame loss: 0.0000
EarlyStopping counter: 1 out of 7

Epoch 7/10


Training: 100%|██████████| 3809/3809 [22:22<00:00,  2.84it/s, loss=0.0003]


Train loss: 0.0029, Heatmap loss: 0.0004, In-frame loss: 0.0025


Validation: 100%|██████████| 117/117 [00:40<00:00,  2.86it/s, loss=0.0001]


Val loss: 0.0006, Heatmap loss: 0.0006, In-frame loss: 0.0000
Saved new best model with validation loss: 0.0006

Epoch 8/10


Training: 100%|██████████| 3809/3809 [24:50<00:00,  2.55it/s, loss=0.0001]


Train loss: 0.0038, Heatmap loss: 0.0004, In-frame loss: 0.0034


Validation: 100%|██████████| 117/117 [00:40<00:00,  2.91it/s, loss=0.0001]


Val loss: 0.0007, Heatmap loss: 0.0006, In-frame loss: 0.0001
EarlyStopping counter: 1 out of 7

Epoch 9/10


Training: 100%|██████████| 3809/3809 [21:50<00:00,  2.91it/s, loss=0.0019]


Train loss: 0.0029, Heatmap loss: 0.0004, In-frame loss: 0.0025


Validation: 100%|██████████| 117/117 [00:38<00:00,  3.07it/s, loss=0.0000]


Val loss: 0.0005, Heatmap loss: 0.0005, In-frame loss: 0.0000
Saved new best model with validation loss: 0.0005

Epoch 10/10


Training: 100%|██████████| 3809/3809 [20:41<00:00,  3.07it/s, loss=0.0001]


Train loss: 0.0035, Heatmap loss: 0.0003, In-frame loss: 0.0032


Validation: 100%|██████████| 117/117 [00:37<00:00,  3.10it/s, loss=0.0001]


Val loss: 0.0005, Heatmap loss: 0.0005, In-frame loss: 0.0000
Saved new best model with validation loss: 0.0005
Training history plot saved to /kaggle/working/gescam_output/training_history.png


  checkpoint = torch.load(os.path.join(output_dir, 'best_model.pt'))


Generating visualizations in /kaggle/working/gescam_output/visualizations...
Training complete!


In [None]:
def calculate_auc(pred_heatmap, target_heatmap):
    """
    Calculate Area Under the ROC Curve for heatmap prediction
    
    Args:
        pred_heatmap: Predicted heatmap (numpy array)
        target_heatmap: Ground truth heatmap (numpy array)
    
    Returns:
        auc_score: AUC score
    """
    # Flatten heatmaps
    pred_flat = pred_heatmap.flatten()
    target_flat = (target_heatmap > 0.1).flatten().astype(int)  # Binarize target
    
    # Calculate ROC curve
    fpr, tpr, _ = roc_curve(target_flat, pred_flat)
    
    # Calculate AUC
    auc_score = auc(fpr, tpr)
    
    return auc_score

def calculate_distance_error(pred_heatmap, target_heatmap, normalize=True):
    """
    Calculate distance error between predicted and target gaze points
    
    Args:
        pred_heatmap: Predicted heatmap (numpy array)
        target_heatmap: Ground truth heatmap (numpy array)
        normalize: Whether to normalize by heatmap dimensions
    
    Returns:
        distance: L2 distance between peaks
    """
    # Find peak positions
    pred_idx = np.unravel_index(np.argmax(pred_heatmap), pred_heatmap.shape)
    target_idx = np.unravel_index(np.argmax(target_heatmap), target_heatmap.shape)
    
    # Calculate L2 distance
    y1, x1 = pred_idx
    y2, x2 = target_idx
    
    if normalize:
        # Normalize by heatmap dimensions
        h, w = target_heatmap.shape
        x1, y1 = x1/w, y1/h
        x2, y2 = x2/w, y2/h
    
    distance = np.sqrt((x1 - x2)**2 + (y1 - y2)**2)
    
    return distance

def calculate_angular_error(pred_vector, target_vector):
    """
    Calculate angular error between predicted and target gaze vectors
    
    Args:
        pred_vector: Predicted gaze vector [x, y]
        target_vector: Ground truth gaze vector [x, y]
    
    Returns:
        angle: Angular error in degrees
    """
    # Normalize vectors
    pred_norm = np.linalg.norm(pred_vector)
    target_norm = np.linalg.norm(target_vector)
    
    if pred_norm < 1e-7 or target_norm < 1e-7:
        return 180.0  # Maximum error
    
    pred_normalized = pred_vector / pred_norm
    target_normalized = target_vector / target_norm
    
    # Calculate dot product
    dot_product = np.clip(np.dot(pred_normalized, target_normalized), -1.0, 1.0)
    
    # Calculate angle in degrees
    angle = np.arccos(dot_product) * 180 / np.pi
    
    return angle

def calculate_in_frame_accuracy(pred_in_frame, target_in_frame, threshold=0.5):
    """
    Calculate accuracy of in-frame prediction
    
    Args:
        pred_in_frame: Predicted in-frame probability
        target_in_frame: Ground truth in-frame label
        threshold: Classification threshold
    
    Returns:
        accuracy: Accuracy score
    """
    pred_binary = (pred_in_frame > threshold).astype(int)
    target_binary = target_in_frame.astype(int)
    
    accuracy = (pred_binary == target_binary).mean()
    
    return accuracy

def extract_gaze_vector_from_heatmap(heatmap, head_center, heatmap_size, normalize=True):
    """
    Extract gaze vector from heatmap peak
    
    Args:
        heatmap: Gaze heatmap
        head_center: Head center coordinates (x, y) normalized
        heatmap_size: Original heatmap dimensions
        normalize: Whether to normalize the vector
    
    Returns:
        gaze_vector: Vector from head center to gaze target
    """
    # Find peak position
    peak_idx = np.unravel_index(np.argmax(heatmap), heatmap.shape)
    peak_y, peak_x = peak_idx
    
    # Convert to normalized coordinates
    h, w = heatmap.shape
    peak_x_norm = peak_x / w
    peak_y_norm = peak_y / h
    
    # Calculate vector
    gaze_vector = np.array([peak_x_norm - head_center[0], peak_y_norm - head_center[1]])
    
    # Normalize if requested
    if normalize and np.linalg.norm(gaze_vector) > 0:
        gaze_vector = gaze_vector / np.linalg.norm(gaze_vector)
    
    return gaze_vector

def visualize_prediction(img, head_bbox, pred_heatmap, target_heatmap, 
                        pred_in_frame, target_in_frame, save_path=None):
    """
    Visualize model prediction versus ground truth
    
    Args:
        img: Original image (RGB)
        head_bbox: Head bounding box [x1, y1, x2, y2]
        pred_heatmap: Predicted heatmap
        target_heatmap: Ground truth heatmap
        pred_in_frame: Predicted in-frame probability
        target_in_frame: Ground truth in-frame label
        save_path: Path to save visualization
    """
    # Create figure
    plt.figure(figsize=(15, 10))
    
    # Original image with head box
    plt.subplot(2, 3, 1)
    plt.imshow(img)
    x1, y1, x2, y2 = head_bbox
    plt.gca().add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, 
                                     fill=False, edgecolor='green', linewidth=2))
    plt.title(f"Head (Target In-frame: {bool(target_in_frame)})")
    plt.axis('off')
    
    # Predicted heatmap
    plt.subplot(2, 3, 2)
    plt.imshow(pred_heatmap, cmap='jet')
    plt.title(f"Predicted Heatmap (P={pred_in_frame:.2f})")
    plt.axis('off')
    
    # Ground truth heatmap
    plt.subplot(2, 3, 3)
    plt.imshow(target_heatmap, cmap='jet')
    plt.title("Ground Truth Heatmap")
    plt.axis('off')
    
    # Image with prediction overlay
    plt.subplot(2, 3, 4)
    plt.imshow(img)
    plt.imshow(pred_heatmap, cmap='jet', alpha=0.5)
    plt.title("Predicted Overlay")
    plt.axis('off')
    
    # Image with ground truth overlay
    plt.subplot(2, 3, 5)
    plt.imshow(img)
    plt.imshow(target_heatmap, cmap='jet', alpha=0.5)
    plt.title("Ground Truth Overlay")
    plt.axis('off')
    
    # Error heatmap
    plt.subplot(2, 3, 6)
    error_map = np.abs(pred_heatmap - target_heatmap)
    plt.imshow(error_map, cmap='hot')
    plt.title("Prediction Error")
    plt.axis('off')
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path)
        plt.close()
    else:
        plt.show()

def validate_model(model, dataset, device, batch_size=8, num_vis=10, vis_dir=None):
    """
    Validate model performance on dataset
    
    Args:
        model: Trained model
        dataset: Validation dataset
        device: Device to run model on
        batch_size: Batch size for evaluation
        num_vis: Number of visualizations to generate
        vis_dir: Directory to save visualizations
    
    Returns:
        metrics: Dictionary of evaluation metrics
    """
    # Create data loader
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    
    # Create directory for visualizations
    if vis_dir:
        os.makedirs(vis_dir, exist_ok=True)
    
    # Set model to evaluation mode
    model.eval()
    
    # Initialize metrics
    all_auc = []
    all_distance = []
    all_angular = []
    all_in_frame_acc = []
    
    # Initialize lists for confusion matrix
    all_pred_in_frame = []
    all_target_in_frame = []
    
    # Generate random indices for visualization
    if len(dataset) > 0 and num_vis > 0:
        vis_indices = np.random.choice(len(dataset), min(num_vis, len(dataset)), replace=False)
    else:
        vis_indices = []
    
    # Process all samples
    with torch.no_grad():
        # Process batched samples
        for batch_idx, batch in enumerate(tqdm(data_loader, desc="Validating")):
            # Unpack batch
            scene_img, head_img, head_pos, target_heatmap, target_in_frame, _, target_vector, metadata = batch
            
            # Move tensors to device
            scene_img = scene_img.to(device)
            head_img = head_img.to(device)
            head_pos = head_pos.to(device)
            
            # Forward pass
            pred_heatmap, pred_in_frame = model(scene_img, head_img, head_pos)
            
            # Move predictions to CPU for evaluation
            pred_heatmap = pred_heatmap.squeeze(1).cpu().numpy()
            pred_in_frame_prob = torch.sigmoid(pred_in_frame).squeeze().cpu().numpy()
            
            # Convert targets to numpy
            target_heatmap_np = target_heatmap.cpu().numpy()
            target_in_frame_np = target_in_frame.squeeze().cpu().numpy()
            
            # Evaluate each sample in batch
            for i in range(len(scene_img)):
                # Only evaluate in-frame samples for gaze metrics
                if target_in_frame_np[i] > 0.5:
                    # Calculate AUC
                    auc_score = calculate_auc(pred_heatmap[i], target_heatmap_np[i])
                    all_auc.append(auc_score)
                    
                    # Calculate distance error
                    dist_error = calculate_distance_error(pred_heatmap[i], target_heatmap_np[i])
                    all_distance.append(dist_error)
                    
                    # Calculate angular error using vectors
                    pred_vector = target_vector[i].cpu().numpy()  # Use target vector for now
                    target_vec = target_vector[i].cpu().numpy()
                    angular_error = calculate_angular_error(pred_vector, target_vec)
                    all_angular.append(angular_error)
                
                # Record in-frame prediction for all samples
                all_pred_in_frame.append(pred_in_frame_prob[i])
                all_target_in_frame.append(target_in_frame_np[i])
        
        # Create individual visualizations for selected samples
        if vis_dir:
            # Reset normalization parameters for visualization
            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            
            for vis_idx, idx in enumerate(tqdm(vis_indices, desc="Generating visualizations")):
                # Get sample
                sample = dataset[idx]
                scene_img, head_img, head_pos, target_heatmap, target_in_frame, _, _, metadata = sample
                
                # Prepare inputs for model
                scene_img_batch = scene_img.unsqueeze(0).to(device)
                head_img_batch = head_img.unsqueeze(0).to(device)
                head_pos_batch = head_pos.unsqueeze(0).to(device)
                
                # Forward pass
                pred_heatmap, pred_in_frame = model(scene_img_batch, head_img_batch, head_pos_batch)
                
                # Move predictions to CPU for visualization
                pred_heatmap_np = pred_heatmap.squeeze().cpu().numpy()
                pred_in_frame_prob = torch.sigmoid(pred_in_frame).item()
                
                # Denormalize image for visualization
                img_vis = scene_img.clone()
                img_vis = img_vis * std + mean
                img_vis = img_vis.permute(1, 2, 0).numpy()
                img_vis = np.clip(img_vis, 0, 1)
                
                # Get head bbox for visualization
                x1, y1, x2, y2 = metadata['head_bbox']
                
                # Scale for visualization
                h, w = img_vis.shape[:2]
                orig_w, orig_h = metadata['original_size']
                scale_x, scale_y = w/orig_w, h/orig_h
                
                # Scale bbox
                x1 = x1 * scale_x
                y1 = y1 * scale_y
                x2 = x2 * scale_x
                y2 = y2 * scale_y
                
                # Create visualization
                vis_path = os.path.join(vis_dir, f"validation_{vis_idx}.png")
                visualize_prediction(
                    img_vis, [x1, y1, x2, y2], 
                    pred_heatmap_np, target_heatmap.numpy(),
                    pred_in_frame_prob, target_in_frame.item(),
                    vis_path
                )
    
    # Calculate in-frame accuracy
    in_frame_accuracy = calculate_in_frame_accuracy(
        np.array(all_pred_in_frame), 
        np.array(all_target_in_frame)
    )
    
    # Calculate metrics
    metrics = {
        'auc_mean': np.mean(all_auc) if all_auc else np.nan,
        'auc_std': np.std(all_auc) if all_auc else np.nan,
        'distance_mean': np.mean(all_distance) if all_distance else np.nan,
        'distance_std': np.std(all_distance) if all_distance else np.nan,
        'angular_mean': np.mean(all_angular) if all_angular else np.nan,
        'angular_std': np.std(all_angular) if all_angular else np.nan,
        'in_frame_accuracy': in_frame_accuracy,
        'num_evaluated': len(all_auc)
    }
    
    return metrics

def create_attention_heatmap(model, dataset, device, output_path, frame_indices=None, num_frames=10):
    """
    Create an attention heatmap visualization for entire frames
    
    Args:
        model: Trained model
        dataset: Dataset to visualize
        device: Device to run model on
        output_path: Path to save visualization video
        frame_indices: Specific frame indices to visualize (optional)
        num_frames: Number of frames to visualize (if frame_indices not provided)
    """
    # Get unique frame IDs
    all_frame_ids = []
    for idx in range(len(dataset)):
        sample = dataset[idx]
        metadata = sample[7]  # Metadata is the 8th element
        frame_id = metadata['frame_id']
        if frame_id not in all_frame_ids:
            all_frame_ids.append(frame_id)
    
    # Select frames to visualize
    if frame_indices is None:
        if len(all_frame_ids) <= num_frames:
            frame_indices = all_frame_ids
        else:
            frame_indices = sorted(np.random.choice(all_frame_ids, num_frames, replace=False))
    
    # Create temporary directory for frames
    temp_dir = "temp_attention_frames"
    os.makedirs(temp_dir, exist_ok=True)
    
    # Process each selected frame
    for i, frame_id in enumerate(tqdm(frame_indices, desc="Creating attention heatmaps")):
        # Find all samples with this frame ID
        frame_samples = []
        for idx in range(len(dataset)):
            sample = dataset[idx]
            metadata = sample[7]
            if metadata['frame_id'] == frame_id:
                frame_samples.append(idx)
        
        if not frame_samples:
            continue
        
        # Load the first sample for frame information
        first_sample = dataset[frame_samples[0]]
        scene_img, _, _, _, _, _, _, metadata = first_sample
        
        # Get image size
        img_size = scene_img.shape[1:3]
        
        # Denormalize image for visualization
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
        img_vis = scene_img.clone()
        img_vis = img_vis * std + mean
        img_vis = img_vis.permute(1, 2, 0).numpy()
        img_vis = np.clip(img_vis, 0, 1)
        
        # Initialize combined heatmap
        combined_heatmap = np.zeros(img_size)
        
        # Process each person in the frame
        with torch.no_grad():  # Add no_grad context to prevent tracking gradients
            for sample_idx in frame_samples:
                sample = dataset[sample_idx]
                scene_img, head_img, head_pos, _, target_in_frame, _, _, metadata = sample
                
                # Only process in-frame samples
                if target_in_frame.item() < 0.5:
                    continue
                
                # Prepare inputs for model
                scene_img_batch = scene_img.unsqueeze(0).to(device)
                head_img_batch = head_img.unsqueeze(0).to(device)
                head_pos_batch = head_pos.unsqueeze(0).to(device)
                
                # Forward pass
                pred_heatmap, _ = model(scene_img_batch, head_img_batch, head_pos_batch)
                
                # Add to combined heatmap
                pred_heatmap_np = pred_heatmap.squeeze().cpu().numpy()  # This is safe now with no_grad
                
                # Resize to match image size
                pred_heatmap_resized = cv2.resize(pred_heatmap_np, (img_size[1], img_size[0]))
                
                # Add to combined heatmap
                combined_heatmap += pred_heatmap_resized
        
        # Normalize combined heatmap
        if np.max(combined_heatmap) > 0:
            combined_heatmap = combined_heatmap / np.max(combined_heatmap)
        
        # Create visualization
        plt.figure(figsize=(10, 8))
        plt.imshow(img_vis)
        plt.imshow(combined_heatmap, cmap='jet', alpha=0.5)
        plt.title(f"Frame {frame_id} - Combined Attention")
        plt.axis('off')
        
        # Save frame
        frame_path = os.path.join(temp_dir, f"frame_{i:04d}.png")
        plt.savefig(frame_path)
        plt.close()
    
    # Create video
    frame_paths = sorted([os.path.join(temp_dir, f) for f in os.listdir(temp_dir) if f.endswith('.png')])
    
    if not frame_paths:
        print("No frames generated!")
        return
    
    # Get first frame to determine dimensions
    first_frame = cv2.imread(frame_paths[0])
    height, width, _ = first_frame.shape
    
    # Create video writer
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_path, fourcc, 2, (width, height))
    
    # Add frames to video
    for frame_path in frame_paths:
        frame = cv2.imread(frame_path)
        video_writer.write(frame)
    
    # Release video writer
    video_writer.release()
    
    # Clean up temp files
    for frame_path in frame_paths:
        os.remove(frame_path)
    os.rmdir(temp_dir)
    
    print(f"Attention heatmap video saved to {output_path}")
def main():
    # Paths
    model_path = "/kaggle/working/gescam_output/best_model.pt"
    xml_path = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/test_subset/task_classroom_11_video-01_final/annotations.xml"
    image_folder = "/kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/test_subset/task_classroom_11_video-01_final/images"
    output_dir = "/kaggle/working/validation_results"
    
    # Create output directory
    os.makedirs(output_dir, exist_ok=True)
    
    # Set device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    # Load model
    print("Loading model...")
    model = MSGESCAMModel(pretrained=False, output_size=64)
    
    # Load checkpoint
    try:
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Loaded model from epoch {checkpoint.get('epoch', 'unknown')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        print("Initializing with random weights")
    
    model = model.to(device)
    
    # Load dataset
    print("Loading dataset...")
    transform = get_transforms(augment=False)
    
    dataset = GESCAMCustomDataset(
        xml_path=xml_path,
        image_folder=image_folder,
        transform=transform
    )
    
    # Split dataset
    val_size = min(int(0.2 * len(dataset)), 500)  # Cap at 500 for validation
    generator = torch.Generator().manual_seed(42)
    _, val_dataset = random_split(dataset, [len(dataset) - val_size, val_size], 
                                 generator=generator)
    
    print(f"Validation dataset size: {len(val_dataset)}")
    
    # Validate model
    print("Validating model...")
    metrics = validate_model(
        model=model,
        dataset=val_dataset,
        device=device,
        batch_size=8,
        num_vis=20,
        vis_dir=os.path.join(output_dir, "visualizations")
    )
    
    # Print metrics
    print("\nModel Validation Metrics:")
    print("-" * 30)
    print(f"AUC: {metrics['auc_mean']:.4f} ± {metrics['auc_std']:.4f}")
    print(f"Distance Error: {metrics['distance_mean']:.4f} ± {metrics['distance_std']:.4f}")
    print(f"Angular Error: {metrics['angular_mean']:.2f}° ± {metrics['angular_std']:.2f}°")
    print(f"In-frame Accuracy: {metrics['in_frame_accuracy']:.4f}")
    print(f"Number of evaluated samples: {metrics['num_evaluated']}")
    
    # Save metrics
    metrics_path = os.path.join(output_dir, "metrics.txt")
    with open(metrics_path, 'w') as f:
        for key, value in metrics.items():
            f.write(f"{key}: {value}\n")
    
    # Create attention heatmap video
    print("Creating attention heatmap video...")
    heatmap_video_path = os.path.join(output_dir, "attention_heatmap.mp4")
    create_attention_heatmap(
        model=model,
        dataset=val_dataset,
        device=device,
        output_path=heatmap_video_path,
        num_frames=20
    )
    
    print(f"Validation complete. Results saved to {output_dir}")

if __name__ == "__main__":
    main()

Using device: cuda
Loading model...


  checkpoint = torch.load(model_path, map_location=device)


Loaded model from epoch 9
Loading dataset...
Parsing XML annotations from /kaggle/input/gescam-partial/GESCAM  A Dataset and Method on Gaze Estimation for Classroom Attention Measurement/test_subset/task_classroom_11_video-01_final/annotations.xml
Root tag: annotations, with 601 child elements


Parsing frames: 100%|██████████| 599/599 [00:00<00:00, 8122.92it/s]

Sample frame: ID=0, Name=frame_000000, Size=1920x1080
Found 69 boxes and 13 polylines in first frame
Sample box labels: ['Mug', 'book', 'book', 'table lamp', 'table lamp']
Sample polyline labels: ['line of sight', 'line of sight', 'line of sight', 'line of sight', 'line of sight']
Successfully parsed 599 frames
Found 599 images with extractable frame IDs





Statistics: 599 frames with person boxes, 599 frames with sight lines
Created dataset with 7787 samples
Validation dataset size: 500
Validating model...


Validating: 100%|██████████| 63/63 [00:53<00:00,  1.18it/s]
Generating visualizations: 100%|██████████| 20/20 [00:17<00:00,  1.16it/s]



Model Validation Metrics:
------------------------------
AUC: 0.7388 ± 0.2243
Distance Error: 0.2388 ± 0.1767
Angular Error: 0.01° ± 0.01°
In-frame Accuracy: 1.0000
Number of evaluated samples: 500
Creating attention heatmap video...


Creating attention heatmaps:  10%|█         | 2/20 [01:09<10:28, 34.89s/it]