In [None]:
# Libraries
import torch
import torch.nn as nn
from transformers import AutoProcessor, VideoMAEForVideoClassification
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from tqdm import tqdm
import json
import os
import random
import numpy as np
import cv2
from torchvision import transforms
from torch.optim import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.cuda.amp import GradScaler, autocast

## VideoMAE: 20 or 30 Epochs with Even Frames

In [None]:
# Step 1: Read the annotation file and derive the label vocabulary
annotation_file = 'nslt_100.json'
root_dir = 'WLASL100_videos'
with open(annotation_file, "r") as f:
    annotations = json.load(f)

# The first element of every entry's "action" list is used as the class label.
# Sorting the unique labels gives a stable label -> contiguous index mapping.
all_actions = {entry["action"][0] for entry in annotations.values()}
gloss_to_index = {
    action: position for position, action in enumerate(sorted(all_actions))
}
num_classes = len(gloss_to_index)

# Reverse lookup: contiguous index -> original label
index_to_gloss = {position: action for action, position in gloss_to_index.items()}

# Step 2: Build the processor and the model
pretrained_weights = "MCG-NJU/videomae-base-finetuned-kinetics"
processor = AutoProcessor.from_pretrained(pretrained_weights)

# ignore_mismatched_sizes lets the checkpoint load even though the requested
# number of labels differs from the one stored in the checkpoint.
model = VideoMAEForVideoClassification.from_pretrained(
    pretrained_weights, num_labels=num_classes, ignore_mismatched_sizes=True
)

# Swap in a fresh head: dropout followed by a linear layer over num_classes
model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(model.config.hidden_size, num_classes),
)

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Step 3: Implement Dataset Class
class VideoDataset(Dataset):
    """Video-classification dataset that samples evenly spaced frames per clip.

    Every item is read from ``<root_dir>/<video_id>.mp4``, reduced to exactly
    ``num_frames`` RGB frames, optionally cropped/flipped, and handed to
    ``processor`` to build the ``pixel_values`` tensor.

    Subsets named in ``_AUGMENTED_SUBSETS`` receive random augmentation (crop,
    flip, padding side). Every other subset is preprocessed deterministically
    so that repeated evaluations see identical inputs.

    Note: ``__getitem__`` resolves labels through the module-level
    ``gloss_to_index`` mapping, which must be defined before items are read.

    Args:
        annotation_file: Path to the JSON annotation file. Each entry is read
            for its ``"subset"`` value and the first element of ``"action"``.
        root_dir: Directory containing ``<video_id>.mp4`` files.
        processor: Callable invoked as
            ``processor(frames, return_tensors="pt", do_rescale=False)``.
        num_frames: Number of frames returned for every video.
        transform: Truthy to run ``apply_transforms`` before the processor.
        subset: Name of the split to expose (e.g. ``"train"``, ``"test"``).
    """

    # Subsets that are augmented randomly; "val" is included because this
    # notebook concatenates it with "train" for fitting.
    _AUGMENTED_SUBSETS = ("train", "val")

    def __init__(self, annotation_file, root_dir, processor, num_frames, transform, subset):
        self.annotation_file = annotation_file
        self.root_dir = root_dir
        self.processor = processor
        self.num_frames = num_frames
        self.transform = transform
        self.subset = subset

        # Load annotations
        with open(self.annotation_file, "r") as f:
            self.annotations = json.load(f)

        # Keep only the videos that belong to the requested subset
        self.video_list = [
            vid for vid, data in self.annotations.items() if data.get("subset") == self.subset
        ]
        # Labels are kept for every video, not just the selected subset
        self.labels = {vid: data["action"][0] for vid, data in self.annotations.items()}

    def __len__(self):
        """Return the number of videos in the selected subset."""
        return len(self.video_list)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": tensor}`` for video ``idx``."""
        video_id = self.video_list[idx]
        label = gloss_to_index[self.labels[video_id]]
        video_path = os.path.join(self.root_dir, f"{video_id}.mp4")

        frames = self.load_video_frames(video_path)
        # load_video_frames returns [] only when the file could not be opened
        if not frames:
            print(f"Warning: No frames loaded for video {video_id}.")
            frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(self.num_frames)]

        if self.transform:
            frames = self.apply_transforms(frames)

        # do_rescale=False: apply_transforms already emits ToTensor() output.
        # NOTE(review): with a falsy ``transform`` the frames stay uint8 arrays
        # and are still passed with do_rescale=False - confirm that is intended.
        inputs = self.processor(frames, return_tensors="pt", do_rescale=False)["pixel_values"]
        return {"pixel_values": inputs.squeeze(0), "labels": torch.tensor(label)}

    @staticmethod
    def _resize_frame(frame):
        """Rescale one frame using the thresholds 226 (short side) and 256 (long side).

        Both checks deliberately use the dimensions of the frame as decoded,
        exactly as the original inline code did.

        NOTE(review): when the long side exceeds 256 the second resize can
        leave the short side below 224, which the random-crop path cannot
        handle - confirm such videos do not occur in the dataset.
        """
        h, w, _ = frame.shape
        if min(w, h) < 226:
            scale = 226 / min(w, h)
            frame = cv2.resize(frame, None, fx=scale, fy=scale)
        if max(w, h) > 256:
            scale = 256 / max(w, h)
            frame = cv2.resize(frame, None, fx=scale, fy=scale)
        return frame

    def _pad_frames(self, frames):
        """Extend ``frames`` to ``num_frames`` entries.

        Augmented subsets repeat either the first or the last frame, chosen at
        random; other subsets always repeat the last frame so evaluation is
        reproducible. An empty list becomes all-black 224x224 frames.
        """
        missing = self.num_frames - len(frames)
        if missing <= 0:
            return frames
        if not frames:
            return [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(self.num_frames)]

        pad_front = self.subset in self._AUGMENTED_SUBSETS and random.random() < 0.5
        if pad_front:
            return [frames[0]] * missing + frames
        return frames + [frames[-1]] * missing

    def load_video_frames(self, video_path):
        """Read up to ``num_frames`` evenly spaced RGB frames from ``video_path``.

        Returns:
            ``[]`` when the file cannot be opened, otherwise a list of exactly
            ``num_frames`` frames (padded when the video is too short).
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Error: Cannot open video file {video_path}.")
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Stride between sampled frames; never below 1 so short or
            # mis-reported videos are simply read frame by frame.
            step = max(1, total_frames // self.num_frames)

            frames = []
            for i in range(0, step * self.num_frames, step):
                cap.set(cv2.CAP_PROP_POS_FRAMES, i)
                ret, frame = cap.read()
                if not ret:
                    break
                frame = self._resize_frame(frame)
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Always free the decoder, including when decoding raises
            cap.release()

        return self._pad_frames(frames)

    # consistent transformation for a video
    def apply_transforms(self, frames):
        """
        Apply one spatial transform to all frames of a video.

        Augmented subsets share a single random 224x224 crop window and a
        single flip decision across the clip; other subsets get a center crop.
        Returns a list with one ``ToTensor()`` output per frame.
        """
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        if self.subset in self._AUGMENTED_SUBSETS:
            # One flip decision and one crop window for the whole clip
            random_flip = random.random() < 0.5
            i, j, h, w = transforms.RandomCrop.get_params(to_pil(frames[0]), output_size=(224, 224))

            def consistent_transform(frame):
                frame = transforms.functional.crop(to_pil(frame), i, j, h, w)
                if random_flip:
                    frame = transforms.functional.hflip(frame)
                return to_tensor(frame)

            return [consistent_transform(frame) for frame in frames]

        # Held-out subsets (e.g. "test"): deterministic center crop
        center_crop = transforms.CenterCrop(224)
        return [to_tensor(center_crop(to_pil(frame))) for frame in frames]

# Step 4: Build the datasets and DataLoaders
# One VideoDataset per split, all producing 16-frame clips.
train_dataset = VideoDataset(
    annotation_file=annotation_file,
    root_dir=root_dir,
    processor=processor,
    num_frames=16,
    transform=True, # truthy -> VideoDataset.apply_transforms is applied
    subset="train"
)
val_dataset = VideoDataset(
    annotation_file=annotation_file,
    root_dir=root_dir,
    processor=processor,
    num_frames=16,
    transform=True, # truthy -> VideoDataset.apply_transforms is applied
    subset="val"
)
test_dataset = VideoDataset(
    annotation_file=annotation_file,
    root_dir=root_dir,
    processor=processor,
    num_frames=16,
    transform=True, # truthy -> VideoDataset.apply_transforms is applied
    subset="test"
)

# The model is fitted on train + val together.
combined_train_dataset = ConcatDataset([train_dataset, val_dataset])

# DataLoader settings shared by both loaders
batch_size = 6  # clips per batch
num_workers = 8  # worker processes per DataLoader
prefetch_factor = 4  # batches each worker loads ahead of time

# Create DataLoaders
train_loader = DataLoader(
    combined_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=prefetch_factor,
    persistent_workers=True  # keep workers alive between epochs
)
# NOTE(review): despite its name, val_loader iterates the *test* subset. The
# "validation" metrics and the best-checkpoint selection below are therefore
# computed on test data - confirm this is intended, since it makes the final
# test accuracy an optimistic estimate.
val_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=prefetch_factor,
    persistent_workers=True
)

# Step 5: Define Optimizer, Scheduler, and Mixed Precision Scaler
learning_rate = 1e-5
num_epochs = 20       # Set epochs: 20 or 30
weight_decay = 1e-4
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# The scheduler is stepped once per batch, so its horizon is batches * epochs.
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),  # 10% of steps for warm-up
    num_training_steps=total_steps
)
scaler = GradScaler()  # loss scaler for mixed-precision training

# Step 6: Training Loop with Validation and Best Model Saving
best_val_accuracy = 0.0  # best accuracy seen so far on val_loader
best_model_path = "best_model_MEA_even_frames_20_epochs.pth"  # checkpoint path; use "best_model_MEA_even_frames_30_epochs.pth" for a 30-epoch run

for epoch in range(num_epochs):

    # Log learning rate at the start of each epoch
    current_lrs = [param_group['lr'] for param_group in optimizer.param_groups]
    print(f"Epoch {epoch + 1}/{num_epochs} - Current Learning Rates: {current_lrs}")

    # Training Phase
    model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Training")

    for batch in progress_bar:
        pixel_values = batch['pixel_values'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        optimizer.zero_grad()

        with autocast():  # forward pass under mixed precision
            # Passing labels makes the model return the loss alongside the logits
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        # AMP update order: scale loss -> backward -> optimizer step via the
        # scaler -> refresh the scale factor -> advance the LR schedule.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # NOTE(review): the schedule advances on every batch, including any
        # batch where the scaler may have skipped the optimizer step - verify.
        scheduler.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    # Mean of the per-batch losses
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch + 1} Average Training Loss: {avg_loss:.4f}")

    # Validation Phase (runs over val_loader, i.e. the test subset)
    model.eval()
    total_correct = 0
    total_samples = 0
    epoch_val_loss = 0.0
    with torch.no_grad():
        progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Validation")
        for batch in progress_bar_val:
            pixel_values = batch['pixel_values'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            with autocast():  # same precision mode as the training forward pass
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss
                logits = outputs.logits

            epoch_val_loss += loss.item()
            predictions = torch.argmax(logits, dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

    # Accuracy is per sample; the loss is averaged per batch
    val_accuracy = total_correct / total_samples
    avg_val_loss = epoch_val_loss / len(val_loader)
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

    # Save a checkpoint only on a strict improvement in accuracy
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_accuracy': val_accuracy,
        }, best_model_path)
        print(f"✅ Best model improved to {val_accuracy:.4f}. Saved to {best_model_path}")
    else:
        print(f"➖ Validation accuracy did not improve from {best_val_accuracy:.4f}")

print("Training Complete!")
print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")

### Model Assessment

In [None]:
import torch
import numpy as np

# Load the full checkpoint written by the training loop
checkpoint = torch.load("best_model_MEA_even_frames_20_epochs.pth", map_location=device)        # or "best_model_MEA_even_frames_30_epochs.pth" for 30 epochs

# Restore the best weights and switch to inference behaviour
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict)
model.eval()

# Number of samples whose true label is among the k highest-scoring classes
top_ks = (1, 5, 10)
correct_counts = {k: 0 for k in top_ks}
total = 0

# No gradient needed during evaluation
with torch.no_grad():
    for batch in val_loader:
        # Feed the batch exactly as the training/validation loops do. The
        # previous squeeze(1) was a no-op for 16-frame clips and would have
        # removed the frame axis for single-frame clips.
        videos = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values=videos)
        logits = outputs.logits

        # One sorted top-k call covers every k: the first k columns of the
        # ranking are the top-k classes. Capped at the number of classes so
        # heads with fewer than 10 outputs do not raise.
        max_k = min(max(top_ks), logits.size(-1))
        ranked_classes = logits.topk(max_k, dim=-1).indices
        hits = ranked_classes == labels.unsqueeze(-1)
        for k in top_ks:
            correct_counts[k] += hits[:, :k].any(dim=-1).sum().item()

        total += labels.size(0)

correct_top1 = correct_counts[1]
correct_top5 = correct_counts[5]
correct_top10 = correct_counts[10]

# Calculate accuracies
top1_accuracy = correct_top1 / total
top5_accuracy = correct_top5 / total
top10_accuracy = correct_top10 / total

# Print accuracies
print(f"Top-1 Accuracy: {top1_accuracy:.4f}")
print(f"Top-5 Accuracy: {top5_accuracy:.4f}")
print(f"Top-10 Accuracy: {top10_accuracy:.4f}")



## VideoMAE: 30 Epochs with Consecutive Frames

In [None]:
# Step 1: Read the annotation file and derive the label vocabulary
annotation_file = 'nslt_100.json'
root_dir = 'WLASL100_videos'
with open(annotation_file, "r") as f:
    annotations = json.load(f)

# The first element of every entry's "action" list is used as the class label.
# Sorting the unique labels gives a stable label -> contiguous index mapping.
all_actions = {entry["action"][0] for entry in annotations.values()}
gloss_to_index = {
    action: position for position, action in enumerate(sorted(all_actions))
}
num_classes = len(gloss_to_index)

# Reverse lookup: contiguous index -> original label
index_to_gloss = {position: action for action, position in gloss_to_index.items()}

# Step 2: Build the processor and the model
pretrained_weights = "MCG-NJU/videomae-base-finetuned-kinetics"
processor = AutoProcessor.from_pretrained(pretrained_weights)

# ignore_mismatched_sizes lets the checkpoint load even though the requested
# number of labels differs from the one stored in the checkpoint.
model = VideoMAEForVideoClassification.from_pretrained(
    pretrained_weights, num_labels=num_classes, ignore_mismatched_sizes=True
)

# Swap in a fresh head: dropout followed by a linear layer over num_classes
model.classifier = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(model.config.hidden_size, num_classes),
)

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Step 3: Implement Dataset Class
class VideoDataset(Dataset):
    """Video-classification dataset that reads a run of consecutive frames per clip.

    Every item is read from ``<root_dir>/<video_id>.mp4``, reduced to exactly
    ``num_frames`` RGB frames, optionally cropped/flipped, and handed to
    ``processor`` to build the ``pixel_values`` tensor.

    Subsets named in ``_AUGMENTED_SUBSETS`` receive random augmentation (clip
    start, crop, flip, padding side). Every other subset is preprocessed
    deterministically (centred clip, center crop, tail padding) so that
    repeated evaluations see identical inputs.

    Note: ``__getitem__`` resolves labels through the module-level
    ``gloss_to_index`` mapping, which must be defined before items are read.

    Args:
        annotation_file: Path to the JSON annotation file. Each entry is read
            for its ``"subset"`` value and the first element of ``"action"``.
        root_dir: Directory containing ``<video_id>.mp4`` files.
        processor: Callable invoked as
            ``processor(frames, return_tensors="pt", do_rescale=False)``.
        num_frames: Number of frames returned for every video.
        transform: Truthy to run ``apply_transforms`` before the processor.
        subset: Name of the split to expose (e.g. ``"train"``, ``"test"``).
    """

    # Subsets that are augmented randomly; "val" is included because this
    # notebook concatenates it with "train" for fitting.
    _AUGMENTED_SUBSETS = ("train", "val")

    def __init__(self, annotation_file, root_dir, processor, num_frames, transform, subset):
        self.annotation_file = annotation_file
        self.root_dir = root_dir
        self.processor = processor
        self.num_frames = num_frames
        self.transform = transform
        self.subset = subset

        # Load annotations
        with open(self.annotation_file, "r") as f:
            self.annotations = json.load(f)

        # Keep only the videos that belong to the requested subset
        self.video_list = [
            vid for vid, data in self.annotations.items() if data.get("subset") == self.subset
        ]
        # Labels are kept for every video, not just the selected subset
        self.labels = {vid: data["action"][0] for vid, data in self.annotations.items()}

    def __len__(self):
        """Return the number of videos in the selected subset."""
        return len(self.video_list)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": tensor}`` for video ``idx``."""
        video_id = self.video_list[idx]
        label = gloss_to_index[self.labels[video_id]]
        video_path = os.path.join(self.root_dir, f"{video_id}.mp4")

        frames = self.load_video_frames(video_path)
        # load_video_frames returns [] only when the file could not be opened
        if not frames:
            print(f"Warning: No frames loaded for video {video_id}.")
            frames = [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(self.num_frames)]

        if self.transform:
            frames = self.apply_transforms(frames)

        # do_rescale=False: apply_transforms already emits ToTensor() output.
        # NOTE(review): with a falsy ``transform`` the frames stay uint8 arrays
        # and are still passed with do_rescale=False - confirm that is intended.
        inputs = self.processor(frames, return_tensors="pt", do_rescale=False)["pixel_values"]
        return {"pixel_values": inputs.squeeze(0), "labels": torch.tensor(label)}

    @staticmethod
    def _resize_frame(frame):
        """Rescale one frame using the thresholds 226 (short side) and 256 (long side).

        Both checks deliberately use the dimensions of the frame as decoded,
        exactly as the original inline code did.

        NOTE(review): when the long side exceeds 256 the second resize can
        leave the short side below 224, which the random-crop path cannot
        handle - confirm such videos do not occur in the dataset.
        """
        h, w, _ = frame.shape
        if min(w, h) < 226:
            scale = 226 / min(w, h)
            frame = cv2.resize(frame, None, fx=scale, fy=scale)
        if max(w, h) > 256:
            scale = 256 / max(w, h)
            frame = cv2.resize(frame, None, fx=scale, fy=scale)
        return frame

    def _pad_frames(self, frames):
        """Extend ``frames`` to ``num_frames`` entries.

        Augmented subsets repeat either the first or the last frame, chosen at
        random; other subsets always repeat the last frame so evaluation is
        reproducible. An empty list becomes all-black 224x224 frames.
        """
        missing = self.num_frames - len(frames)
        if missing <= 0:
            return frames
        if not frames:
            return [np.zeros((224, 224, 3), dtype=np.uint8) for _ in range(self.num_frames)]

        pad_front = self.subset in self._AUGMENTED_SUBSETS and random.random() < 0.5
        if pad_front:
            return [frames[0]] * missing + frames
        return frames + [frames[-1]] * missing

    def load_video_frames(self, video_path):
        """Read ``num_frames`` consecutive RGB frames from ``video_path``.

        The clip starts at frame 0 for videos shorter than ``num_frames``, at a
        random offset for augmented subsets, and at the centred offset for all
        other subsets.

        Returns:
            ``[]`` when the file cannot be opened, otherwise a list of exactly
            ``num_frames`` frames (padded when the video is too short).
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                print(f"Error: Cannot open video file {video_path}.")
                return []

            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            slack = total_frames - self.num_frames  # frames left over around the clip
            if slack < 0:
                start_frame = 0
            elif self.subset in self._AUGMENTED_SUBSETS:
                start_frame = random.randint(0, slack)
            else:
                # Deterministic temporal crop, mirroring the spatial center crop
                start_frame = slack // 2
            cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)

            frames = []
            for _ in range(self.num_frames):
                ret, frame = cap.read()
                if not ret:
                    break
                frame = self._resize_frame(frame)
                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            # Always free the decoder, including when decoding raises
            cap.release()

        return self._pad_frames(frames)

    # consistent transformation for a video
    def apply_transforms(self, frames):
        """
        Apply one spatial transform to all frames of a video.

        Augmented subsets share a single random 224x224 crop window and a
        single flip decision across the clip; other subsets get a center crop.
        Returns a list with one ``ToTensor()`` output per frame.
        """
        to_pil = transforms.ToPILImage()
        to_tensor = transforms.ToTensor()

        if self.subset in self._AUGMENTED_SUBSETS:
            # One flip decision and one crop window for the whole clip
            random_flip = random.random() < 0.5
            i, j, h, w = transforms.RandomCrop.get_params(to_pil(frames[0]), output_size=(224, 224))

            def consistent_transform(frame):
                frame = transforms.functional.crop(to_pil(frame), i, j, h, w)
                if random_flip:
                    frame = transforms.functional.hflip(frame)
                return to_tensor(frame)

            return [consistent_transform(frame) for frame in frames]

        # Held-out subsets (e.g. "test"): deterministic center crop
        center_crop = transforms.CenterCrop(224)
        return [to_tensor(center_crop(to_pil(frame))) for frame in frames]

# Step 4: Build the datasets and DataLoaders
# One VideoDataset per split, all producing 16-frame clips.
train_dataset = VideoDataset(
    annotation_file=annotation_file,
    root_dir=root_dir,
    processor=processor,
    num_frames=16,
    transform=True, # truthy -> VideoDataset.apply_transforms is applied
    subset="train"
)
val_dataset = VideoDataset(
    annotation_file=annotation_file,
    root_dir=root_dir,
    processor=processor,
    num_frames=16,
    transform=True, # truthy -> VideoDataset.apply_transforms is applied
    subset="val"
)
test_dataset = VideoDataset(
    annotation_file=annotation_file,
    root_dir=root_dir,
    processor=processor,
    num_frames=16,
    transform=True, # truthy -> VideoDataset.apply_transforms is applied
    subset="test"
)

# The model is fitted on train + val together.
combined_train_dataset = ConcatDataset([train_dataset, val_dataset])

# DataLoader settings shared by both loaders
batch_size = 6  # clips per batch
num_workers = 8  # worker processes per DataLoader
prefetch_factor = 4  # batches each worker loads ahead of time

# Create DataLoaders
train_loader = DataLoader(
    combined_train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=prefetch_factor,
    persistent_workers=True  # keep workers alive between epochs
)
# NOTE(review): despite its name, val_loader iterates the *test* subset. The
# "validation" metrics and the best-checkpoint selection below are therefore
# computed on test data - confirm this is intended, since it makes the final
# test accuracy an optimistic estimate.
val_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    prefetch_factor=prefetch_factor,
    persistent_workers=True
)

# Step 5: Define Optimizer, Scheduler, and Mixed Precision Scaler
learning_rate = 1e-5
# NOTE(review): the section heading mentions 30 epochs, but this cell trains
# for 20 and saves to a "..._20_epochs.pth" file - confirm which is intended.
num_epochs = 20
weight_decay = 1e-4
optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
# The scheduler is stepped once per batch, so its horizon is batches * epochs.
total_steps = len(train_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * total_steps),  # 10% of steps for warm-up
    num_training_steps=total_steps
)
scaler = GradScaler()  # loss scaler for mixed-precision training

# Step 6: Training Loop with Validation and Best Model Saving
best_val_accuracy = 0.0  # best accuracy seen so far on val_loader
best_model_path = "best_model_MEA_consec_frames_20_epochs.pth"  # checkpoint written on each improvement

for epoch in range(num_epochs):

    # Log learning rate at the start of each epoch
    current_lrs = [param_group['lr'] for param_group in optimizer.param_groups]
    print(f"Epoch {epoch + 1}/{num_epochs} - Current Learning Rates: {current_lrs}")

    # Training Phase
    model.train()
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Training")

    for batch in progress_bar:
        pixel_values = batch['pixel_values'].to(device, non_blocking=True)
        labels = batch['labels'].to(device, non_blocking=True)
        optimizer.zero_grad()

        with autocast():  # forward pass under mixed precision
            # Passing labels makes the model return the loss alongside the logits
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        # AMP update order: scale loss -> backward -> optimizer step via the
        # scaler -> refresh the scale factor -> advance the LR schedule.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # NOTE(review): the schedule advances on every batch, including any
        # batch where the scaler may have skipped the optimizer step - verify.
        scheduler.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix({'loss': loss.item()})

    # Mean of the per-batch losses
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch + 1} Average Training Loss: {avg_loss:.4f}")

    # Validation Phase (runs over val_loader, i.e. the test subset)
    model.eval()
    total_correct = 0
    total_samples = 0
    epoch_val_loss = 0.0
    with torch.no_grad():
        progress_bar_val = tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Validation")
        for batch in progress_bar_val:
            pixel_values = batch['pixel_values'].to(device, non_blocking=True)
            labels = batch['labels'].to(device, non_blocking=True)

            with autocast():  # same precision mode as the training forward pass
                outputs = model(pixel_values=pixel_values, labels=labels)
                loss = outputs.loss
                logits = outputs.logits

            epoch_val_loss += loss.item()
            predictions = torch.argmax(logits, dim=-1)
            total_correct += (predictions == labels).sum().item()
            total_samples += labels.size(0)

    # Accuracy is per sample; the loss is averaged per batch
    val_accuracy = total_correct / total_samples
    avg_val_loss = epoch_val_loss / len(val_loader)
    print(f"Epoch {epoch + 1} Validation Loss: {avg_val_loss:.4f} | Validation Accuracy: {val_accuracy:.4f}")

    # Save a checkpoint only on a strict improvement in accuracy
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_accuracy': val_accuracy,
        }, best_model_path)
        print(f"✅ Best model improved to {val_accuracy:.4f}. Saved to {best_model_path}")
    else:
        print(f"➖ Validation accuracy did not improve from {best_val_accuracy:.4f}")

print("Training Complete!")
print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")

### Model Assessment

In [None]:
import torch
import numpy as np

# Load the full checkpoint written by the training loop
checkpoint = torch.load("best_model_MEA_consec_frames_20_epochs.pth", map_location=device)

# Restore the best weights and switch to inference behaviour
model_state_dict = checkpoint['model_state_dict']
model.load_state_dict(model_state_dict)
model.eval()

# Number of samples whose true label is among the k highest-scoring classes
top_ks = (1, 5, 10)
correct_counts = {k: 0 for k in top_ks}
total = 0

# No gradient needed during evaluation
with torch.no_grad():
    for batch in val_loader:
        # Feed the batch exactly as the training/validation loops do. The
        # previous squeeze(1) was a no-op for 16-frame clips and would have
        # removed the frame axis for single-frame clips.
        videos = batch['pixel_values'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(pixel_values=videos)
        logits = outputs.logits

        # One sorted top-k call covers every k: the first k columns of the
        # ranking are the top-k classes. Capped at the number of classes so
        # heads with fewer than 10 outputs do not raise.
        max_k = min(max(top_ks), logits.size(-1))
        ranked_classes = logits.topk(max_k, dim=-1).indices
        hits = ranked_classes == labels.unsqueeze(-1)
        for k in top_ks:
            correct_counts[k] += hits[:, :k].any(dim=-1).sum().item()

        total += labels.size(0)

correct_top1 = correct_counts[1]
correct_top5 = correct_counts[5]
correct_top10 = correct_counts[10]

# Calculate accuracies
top1_accuracy = correct_top1 / total
top5_accuracy = correct_top5 / total
top10_accuracy = correct_top10 / total

# Print accuracies
print(f"Top-1 Accuracy: {top1_accuracy:.4f}")
print(f"Top-5 Accuracy: {top5_accuracy:.4f}")
print(f"Top-10 Accuracy: {top10_accuracy:.4f}")

