# Sign-Language to Text Modelling

## 1. Install all dependencies

### 1.1 Install Required Packages

In [None]:
%pip install --upgrade pip
%pip install opencv-python
%pip install transformers
%pip install torchvision==0.16.2
%pip install pytorchvideo==0.1.5
%pip install imageio
%pip install accelerate
%pip install --upgrade mlflow

### 1.2 Restart the Python kernel

#### Please **restart the kernel** after running the cell above to apply newly installed packages.

### 1.3 Load all the required libraries

In [None]:
import torch
import os
import json
import pickle
import glob
import itertools
import cv2
import mlflow
import accelerate
import shutil
import pathlib
from collections import defaultdict
from pathlib import Path
from datetime import datetime
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification
from transformers import TrainingArguments, Trainer
from random import choice
import torch
from torchvision.transforms import Pad
from torchvision.transforms import functional as F
import pytorchvideo.data
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    RandomRotation,
    Resize,
    CenterCrop,
    RandomAutocontrast,
    RandomInvert,
    Grayscale,
    ElasticTransform
)
from typing import Any, Callable, Dict, Optional, Type
from pytorchvideo.data.clip_sampling import ClipSampler
import imageio
import numpy as np
from IPython.display import Image
import evaluate
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score, classification_report, matthews_corrcoef, confusion_matrix 
import seaborn as sns
import textwrap
import torch
import torch.nn as nn
import random

## 2. Data Pre-processing

### 2.1 Adjusting the format of `WLASL_v0.3.json`. Also select a subset of 28 glosses.

In [None]:
def generate_flat_dataset(wlasl_json_path, videos_dir, output_path, video_path_prefix="data"):
    """Write a flat JSON list of WLASL instances for a fixed subset of glosses.

    Only instances whose ``<video_id>.mp4`` file exists in ``videos_dir`` are kept.

    Args:
        wlasl_json_path: Path to ``WLASL_v0.3.json`` (a list of entries, each
            with a ``gloss`` and a list of ``instances``).
        videos_dir: Directory (str or Path) holding the ``<video_id>.mp4`` files.
        output_path: Destination of the flattened JSON file.
        video_path_prefix: Directory prefix written into each record's
            ``video_path`` (default ``"data"``, where section 2.2 looks for the
            videos).

    Each output record has the keys ``gloss``, ``video_path``, ``frame_start``,
    ``frame_end`` and ``split``.
    """
    # Only include these 28 selected glosses
    selected_glosses = {
        "cousin", "deaf", "help", "call", "give", "take", "like", "laugh",
        "order", "drop", "pizza", "candy", "shirt", "room", "bar", "language",
        "speech", "cool", "silly", "sweet", "careful", "thin", "last", "soon",
        "what", "california", "convince", "interest"
    }
    videos_dir = Path(videos_dir)  # accept plain strings as well as Paths

    with open(wlasl_json_path, "r", encoding="utf-8") as f:
        wlasl_data = json.load(f)

    dataset = []

    for entry in wlasl_data:
        gloss = entry["gloss"]
        if gloss not in selected_glosses:
            continue

        # Entry-level split is only a fallback for instances without their own
        split = entry.get("split", "unknown")

        for instance in entry["instances"]:
            video_id = instance.get("video_id")
            if video_id is None:
                continue  # would otherwise look for a file called "None.mp4"

            video_filename = f"{video_id}.mp4"
            # Only include if the actual video file exists
            if not (videos_dir / video_filename).exists():
                continue

            dataset.append({
                "gloss": gloss,
                # Final video path in output JSON, always with forward slashes
                "video_path": pathlib.PurePosixPath(video_path_prefix, video_filename).as_posix(),
                "frame_start": instance.get("frame_start"),
                "frame_end": instance.get("frame_end"),
                "split": instance.get("split", split),
            })

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=2)

In [None]:
base_dir = Path.cwd()
# Inputs: the WLASL annotation file and the folder of downloaded <video_id>.mp4 files
wlasl_json_path = base_dir.joinpath("WLASL_v0.3.json")
videos_dir = base_dir.joinpath("Videos")
# Output: the flattened annotation list read by section 2.2
output_path = base_dir.joinpath("WLASL_parsed_data_adjustedpath.json")

generate_flat_dataset(
    wlasl_json_path=wlasl_json_path,
    videos_dir=videos_dir,
    output_path=output_path,
)

### 2.2 Restructure the files into `data/test`, `data/train` and `data/val`

In [None]:
# Load the flattened annotation list written in section 2.1
with open('WLASL_parsed_data_adjustedpath.json', 'r', encoding='utf-8') as f:
    data = json.load(f)

# Root of the split/gloss tree, e.g. data/train/deaf/00123.mp4.
# A separate name keeps the `base_dir` Path from section 2.1 intact.
split_root = 'data'
moved_files = 0
missing_files = 0

for item in data:
    # `video_path` is relative to the working directory, e.g. "data/14894.mp4"
    current_path = os.path.normpath(item['video_path'])

    if not os.path.exists(current_path):
        missing_files += 1
        print(f"The video {current_path} does not exist")
        continue

    # Destination folder <split_root>/<split>/<gloss>; makedirs also creates the split level
    gloss_dir = os.path.join(split_root, item['split'], item['gloss'])
    os.makedirs(gloss_dir, exist_ok=True)

    new_path = os.path.join(gloss_dir, os.path.basename(current_path))
    shutil.move(current_path, new_path)

    moved_files += 1
    print(f"The video {current_path} is moved to {new_path}")

print(f"Moved {moved_files} files and {missing_files} files are missing")

### 2.3 Display the video distribution of `data`

In [None]:
def print_gloss_distribution(data_dir, min_per_split=2):
    """Print a ranked table of per-gloss video counts for train/val/test.

    Only ``.mp4`` files inside ``<data_dir>/<split>/<gloss>/`` are counted.
    Glosses with fewer than ``min_per_split`` videos in any of the three
    splits are omitted; the remaining rows are ordered by total count,
    largest first.
    """
    splits = ('train', 'val', 'test')
    # gloss -> {'train': n, 'val': n, 'test': n, 'total': n}
    tally = defaultdict(lambda: dict.fromkeys((*splits, 'total'), 0))

    for split_name in splits:
        split_root = os.path.join(data_dir, split_name)
        if not os.path.exists(split_root):
            continue

        for entry in os.listdir(split_root):
            entry_path = os.path.join(split_root, entry)
            if not os.path.isdir(entry_path):
                continue

            n_videos = sum(1 for name in os.listdir(entry_path) if name.endswith('.mp4'))
            tally[entry][split_name] += n_videos
            tally[entry]['total'] += n_videos

    kept = [
        (entry, counts) for entry, counts in tally.items()
        if all(counts[s] >= min_per_split for s in splits)
    ]
    kept.sort(key=lambda pair: pair[1]['total'], reverse=True)

    print(f"{'#':<4} {'Gloss':<20} {'Train':>5} {'Val':>5} {'Test':>5} {'Total':>6}")
    print("-" * 60)
    for rank, (entry, counts) in enumerate(kept, start=1):
        print(f"{rank:<4} {entry:<20} {counts['train']:>5} {counts['val']:>5} {counts['test']:>5} {counts['total']:>6}")

# Run the function
print_gloss_distribution(data_dir='data', min_per_split=2)

## 3. Fine-tune `VideoMAE` on a subset of 28 signs from the `WLASL` dataset.

### 3.1 Counting the total number of videos.

In [None]:
dataset_root_path = pathlib.Path("data")
dataset_splits = ("train", "val", "test")

# Only videos sorted into <split>/<gloss>/ belong to the dataset. A recursive
# "**/*.mp4" glob would also pick up files left at the top of data/ (e.g. ones
# the move step skipped), and section 3.2 would turn "data" into a class label.
videos_per_split = {
    split_name: sorted(dataset_root_path.glob(f"{split_name}/*/*.mp4"))
    for split_name in dataset_splits
}
all_video_file_paths = [path for split_name in dataset_splits for path in videos_per_split[split_name]]

video_count_train = len(videos_per_split["train"])
video_count_val = len(videos_per_split["val"])
video_count_test = len(videos_per_split["test"])
video_total = len(all_video_file_paths)
print(f"Total videos: {video_total}")

### 3.2 Derive the set of labels present in the dataset.

In [None]:
# Each video sits in <split>/<gloss>/, so its parent folder name is its label
class_labels = sorted(set(path.parent.name for path in all_video_file_paths))

# Contiguous integer ids in alphabetical label order, and the inverse mapping
label2id = dict(zip(class_labels, range(len(class_labels))))
id2label = dict(enumerate(class_labels))

# Summary of the classes found
print(f"{len(class_labels)} unique classes:")
labels_str = ', '.join(label2id)
print('\n'.join(textwrap.wrap(labels_str, width=150)))

# Print the label2id and id2label mappings
print("\nLabel to ID mapping:")
print(label2id)

print("\nID to Label mapping:")
print(id2label)

### 3.3 Define functions to use for fine-tuning

#### 3.3.1 AddDistortion Class

A PyTorch module that adds random Gaussian noise (distortion) to a video tensor.

- **Purpose:** Introduces noise to video data for augmentation, improving model robustness.
- **Input:** A video tensor of shape `(C, T, H, W)` — Channels, Time (frames), Height, Width.
- **Output:** The input tensor with added distortion noise.

> The noise is sampled from a normal distribution with mean 0 and a configurable standard deviation (`distortion`).


In [None]:
class AddDistortion(torch.nn.Module):
    """Add zero-mean Gaussian noise to every element of a video tensor.

    Args:
        distortion: Standard deviation of the noise, in the units of the input
            tensor (in this notebook's training pipeline it follows Normalize).
    """

    def __init__(self, distortion=0.5):
        super().__init__()
        self.distortion = distortion

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` plus element-wise noise drawn from N(0, distortion**2).

        Args:
            x (torch.Tensor): floating-point video tensor with shape (C, T, H, W).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError("video must have shape (C, T, H, W)")

        # Independent noise per element, rather than one scalar offset shared
        # by the whole clip (which is only a global brightness shift).
        noise = torch.randn_like(x) * self.distortion
        return x + noise

#### 3.3.2 Count Files Per Folder

This function counts the number of files in each subfolder of a given directory and groups folder names by their file counts.

- **Input:** `dir_path` — path to the root directory containing label subfolders.
- **Output:** Dictionary where keys are file counts and values are lists of folder names with that count.
- **Purpose:** Helps analyze dataset distribution by showing how many files each class folder contains.


In [None]:
def count_files_per_folder(dir_path):
    """Group folder names under ``dir_path`` by how many files each contains.

    Every directory visited by ``os.walk`` that directly holds at least one
    file is counted (files of any extension; files in its subfolders are not
    included). Folders are keyed by base name, so two folders with the same
    name overwrite each other.

    Returns:
        dict[int, list[str]]: file count -> folder names, ordered by ascending count.
    """
    per_label = {
        os.path.basename(root): len(files)
        for root, _dirs, files in os.walk(dir_path)
        if files
    }

    grouped = {}
    for label, n_files in per_label.items():
        grouped.setdefault(n_files, []).append(label)

    return {n_files: grouped[n_files] for n_files in sorted(grouped)}

#### 3.3.3 GIF Utilities for Video Tensors

These functions convert a normalized video tensor into a displayable GIF:

- **`unnormalize_img(img, mean, std)`**  
  Reverses normalization by applying mean and std, then scales pixel values back to the [0, 255] range.

- **`create_gif(video_tensor, mean, std, filename)`**  
  Converts a video tensor of shape `(num_frames, channels, height, width)` into a GIF by unnormalizing each frame and saving it.

- **`display_gif(video_tensor, mean, std, gif_name)`**  
  Prepares and displays the GIF directly in a notebook by calling `create_gif` and rendering the saved file.

> Useful for visualizing video clips during preprocessing or inference.

In [None]:
def unnormalize_img(img, mean, std):
    """Undo per-channel normalization and convert to 8-bit pixels.

    Args:
        img: float array normalized as ``(pixel - mean) / std`` with pixels in
            [0, 1]; expected channels-last (e.g. (H, W, C)) so that ``mean`` and
            ``std`` broadcast per channel, as passed by ``create_gif``.
        mean: per-channel means used for normalization.
        std: per-channel standard deviations used for normalization.

    Returns:
        uint8 array of the same shape with values in [0, 255].
    """
    img = (img * std) + mean
    # Clip *before* casting: astype("uint8") cannot represent values outside
    # [0, 255] and they typically wrap around. Augmentations (noise, invert)
    # can push pixels outside [0, 1].
    return np.clip(img * 255, 0, 255).astype("uint8")

def create_gif(video_tensor, mean, std, filename="sample.gif"):
    """Write a normalized video tensor to ``filename`` as an animated GIF.

    Args:
        video_tensor: tensor of shape (num_frames, num_channels, height, width)
            holding normalized frames.
        mean: per-channel means used for normalization.
        std: per-channel standard deviations used for normalization.
        filename: output path of the GIF.

    Returns:
        ``filename``.
    """
    # Each frame (C, H, W) -> channels-last uint8 image
    frames = [
        unnormalize_img(frame.permute(1, 2, 0).numpy(), mean, std)
        for frame in video_tensor
    ]
    # NOTE(review): recent imageio versions interpret GIF `duration` in
    # milliseconds; 0.25 may have been meant as seconds -- confirm.
    imageio.mimsave(filename, frames, "GIF", duration=0.25)
    return filename

def display_gif(video_tensor, mean, std, gif_name="sample.gif"):
    """Save a (C, T, H, W) video tensor as a GIF and return an IPython ``Image`` of it."""
    # create_gif walks frames along dim 0, so move time to the front: (T, C, H, W)
    frames_first = video_tensor.permute(1, 0, 2, 3)
    saved_path = create_gif(frames_first, mean, std, gif_name)
    return Image(filename=saved_path)

#### 3.3.4 Explanation of Functions

##### `compute_metrics(eval_pred)`
- Calculates multiple evaluation metrics for model predictions:
  - **Accuracy, F1-score, Precision, Recall** (weighted averages).
  - Computes **confusion matrix** and per-class precision and recall using `classification_report`.
- Returns a dictionary containing all metrics, including precision and recall for each class.

##### `collate_fn(examples)`
- Prepares a batch of video samples for the model.
- Converts a list of examples into:
  - `pixel_values`: stacked video tensors permuted to shape `(batch, num_frames, num_channels, height, width)`.
  - `labels`: tensor of corresponding labels.

##### `run_inference(model, video)`
- Performs inference on a single video tensor.
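- Expects a video tensor of shape `(num_channels, num_frames, height, width)` (as produced by the dataset transforms) and permutes it to `(num_frames, num_channels, height, width)`.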
- Permutes and batches the input, sends it to the device (CPU/GPU).
- Returns raw logits output by the model for further processing (e.g., applying softmax or argmax).


In [None]:
def compute_metrics(eval_pred):
    """Compute evaluation metrics for the Hugging Face ``Trainer``.

    Args:
        eval_pred: object with ``predictions`` (logits, one row per sample)
            and ``label_ids`` (true class ids).

    Returns:
        dict with ``accuracy``, weighted ``f1``/``precision``/``recall``, then
        ``precision_<label>`` and ``recall_<label>`` for every class id that
        occurs in the labels or the predictions.
    """
    references = eval_pred.label_ids
    # Predicted class = index of the largest logit
    predictions = np.argmax(eval_pred.predictions, axis=1)

    # Use the already-imported sklearn metrics rather than evaluate.load(...),
    # which re-loaded four metric modules on every evaluation call.
    accuracy = float(accuracy_score(references, predictions))
    f1 = float(f1_score(references, predictions, average='weighted', zero_division=0))
    precision = float(precision_score(references, predictions, average='weighted', zero_division=0))
    recall = float(recall_score(references, predictions, average='weighted', zero_division=0))

    # Per-class precision/recall. Report keys are the class ids as strings plus
    # summary rows ("accuracy", "macro avg", "weighted avg"), which are skipped.
    report = classification_report(references, predictions, output_dict=True, zero_division=0)
    class_ids = [int(key) for key in report if key.isdigit()]

    def _name(class_id):
        # Ids missing from id2label (e.g. a head with more outputs than glosses)
        # fall back to the numeric id instead of raising KeyError.
        return id2label.get(class_id, str(class_id))

    precision_per_class = {f'precision_{_name(c)}': report[str(c)]['precision'] for c in class_ids}
    recall_per_class = {f'recall_{_name(c)}': report[str(c)]['recall'] for c in class_ids}

    return {
        "accuracy": accuracy,
        "f1": f1,
        "precision": precision,
        "recall": recall,
        **precision_per_class,
        **recall_per_class,
    }

def collate_fn(examples):
    """Batch dataset samples for the video classifier.

    Each sample's ``"video"`` tensor (C, T, H, W) is permuted to
    (T, C, H, W) and the batch is stacked to (B, T, C, H, W); the
    ``"label"`` values become a 1-D tensor.
    """
    videos = []
    targets = []
    for sample in examples:
        # frames first: (num_frames, num_channels, height, width)
        videos.append(sample["video"].permute(1, 0, 2, 3))
        targets.append(sample["label"])
    return {"pixel_values": torch.stack(videos), "labels": torch.tensor(targets)}

def run_inference(model, video):
    """Classify a single video clip and return the raw logits.

    Args:
        model: video classifier called as ``model(pixel_values=...)`` whose
            output has a ``.logits`` attribute.
        video: tensor of shape (num_channels, num_frames, height, width), as
            produced by the dataset transforms.

    Returns:
        Logits tensor for the single-clip batch (first dim 1), on the
        inference device.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Disable dropout etc.; the model may still be in training mode.
    model.eval()

    # (C, T, H, W) -> (1, T, C, H, W): frames first, plus a batch dimension
    pixel_values = video.permute(1, 0, 2, 3).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)

    return outputs.logits

### 3.4 Load a model to fine-tune

##### Imports the relevant classes.

- Sets the checkpoint/model name (`"MCG-NJU/videomae-huge-finetuned-kinetics"`).

- Loads the `VideoMAEImageProcessor` for preprocessing your videos.

- Loads the `VideoMAEForVideoClassification` model pretrained on that checkpoint.

- Passes the `label2id` and `id2label` mappings so the model knows your specific classes.

- Uses `ignore_mismatched_sizes=True` so it can load the pretrained weights even if your classification head shape is different

In [None]:
 # Load the default pretrained model
model_ckpt = "MCG-NJU/videomae-huge-finetuned-kinetics"

image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt)
model = VideoMAEForVideoClassification.from_pretrained(
    model_ckpt,
    # Size the classification head for our glosses and make
    # model.config.id2label return gloss names (used in section 4).
    label2id=label2id,
    id2label=id2label,
    # The checkpoint's classifier has a different number of outputs; let it be
    # re-initialised instead of failing on the size mismatch.
    ignore_mismatched_sizes=True,
)

### 3.5  Prepare the datasets for training

#### 3.5.1 Video Preprocessing Parameters Setup

- **Mean and Standard Deviation:**  
  Retrieved from the `image_processor` to normalize video frames during preprocessing.

- **Resize Dimensions:**  
  Determines target height and width for resizing videos.  
  - If `shortest_edge` is specified in the processor's size, it is used for both height and width (square resize).  
  - Otherwise, height and width are set individually.

- **Frame Sampling Parameters:**  
  - `num_frames_to_sample`: Number of frames the model expects per video clip.  
  - `sample_rate`: Frame sampling rate (e.g., every 3rd frame).  
  - `fps`: Assumed frame rate of the original videos (hard-coded; section 3.5.2 prints the actual rate of a few files).

- **Clip Duration Calculation:**  
  Computes the total duration of the video clip to sample based on the number of frames, sample rate, and fps.  
  This ensures consistent clip length matching model input requirements.

- **Output:**  
  Prints out the final height, width, number of frames, and clip duration for reference.

In [None]:
# Normalization statistics of the pretrained checkpoint
mean, std = image_processor.image_mean, image_processor.image_std

# Target frame size: square when the processor only specifies a shortest edge
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

num_frames_to_sample = model.config.num_frames  # frames per clip the model expects
sample_rate = 3  # keep every 3rd frame
fps = 28  # assumed frame rate of the source videos
# Seconds of video needed to collect num_frames_to_sample frames at this rate
clip_duration = num_frames_to_sample * sample_rate / fps

print(f"""the height of the video is {height}, the width is {width}, 
the number of frames is {num_frames_to_sample}, the clip duration is {clip_duration}.""")

#### 3.5.2 Extracting Video Properties from MP4 Files

- **Finding Video Files:**  
  Reuses `all_video_file_paths`, the list of `.mp4` files collected in section 3.1.

- **Selecting Files:**  
  `itertools.islice` takes the first 10 paths from that list.

- **Reading Video Metadata:**  
  For each video:  
  - Opens the video with OpenCV's `VideoCapture`.  
  - Retrieves the video's width, height, and frames per second (FPS) using OpenCV properties.

- **Output:**  
  Prints out the resolution and frame rate for each sampled video file.

In [None]:
# Inspect resolution and frame rate of the first 10 videos with OpenCV
video_paths = list(itertools.islice(all_video_file_paths, 10))

for video_path in video_paths:
    cap = cv2.VideoCapture(str(video_path))  # OpenCV takes the filename as a string
    try:
        if not cap.isOpened():
            print(f"Could not open video {video_path}.")
            continue
        # Separate names so `width`, `height` and `fps` from section 3.5.1 are not overwritten
        video_width = cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        video_height = cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        video_fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        cap.release()  # free the capture handle

    print(f"For video {video_path}, the width is {video_width}, the height is {video_height}, the frame rate is {video_fps}.")

### 3.6  Define the dataset-specific transformations and the datasets respectively

#### 3.6.1 Starting with the training set:

- **Custom frame-wise augmentation:**  
  Using `RandomTransformCustom` to apply transforms like autocontrast and invert randomly to individual video frames.

- **Training transform pipeline:**  
  Combines temporal subsampling, normalization, resizing, standard augmentations (flip, rotation, elastic transform), distortion, and the custom frame-wise transforms.

- **Training dataset:**  
  Created using `pytorchvideo.data.Ucf101` with the training transform pipeline, random clip sampling, and no audio decoding.

In [None]:
# Define a generalized custom transform for applying any transform with a given probability
class RandomTransformCustom(torch.nn.Module):
    """Apply ``transform`` to each frame of a video independently with probability ``p``.

    The input tensor (C, T, H, W) is modified in place and also returned.
    """

    def __init__(self, transform, p=0.3):
        super().__init__()
        self.p = p
        # Any callable mapping one (C, H, W) frame to a same-shaped frame
        self.transform = transform

    def forward(self, x):
        # Unpacking also enforces a 4-D (C, T, H, W) input
        _, num_frames, _, _ = x.shape
        for frame_idx in range(num_frames):
            # One independent coin flip per frame along the time axis
            if random.random() >= self.p:
                continue
            x[:, frame_idx] = self.transform(x[:, frame_idx])
        return x

train_transform = Compose(
    [
        ApplyTransformToKey(
            key="video",
            transform=Compose(
                [
                    # same temporal sampling, scaling and size as the test set;
                    # Normalize follows below (it is a per-channel affine map,
                    # so applying it after Resize gives the same result)
                    UniformTemporalSubsample(num_frames_to_sample),
                    Lambda(lambda x: x / 255.0),
                    Resize(resize_to, antialias=True),

                    # additional noise to avoid overfitting
                    # (rotation/elastic padding is now 0 = black in [0, 1] space)
                    RandomHorizontalFlip(p=0.4),
                    RandomRotation(degrees=10),
                    ElasticTransform(alpha=30.0),

                    # Per-frame photometric augmentations. Autocontrast and invert
                    # expect float pixels in [0, 1] (autocontrast rescales a frame
                    # to [0, 1], invert computes 1 - x), so they must run before
                    # Normalize, or augmented frames leave the normalized range.
                    RandomTransformCustom(RandomAutocontrast(p=1.0), p=0.2),
                    RandomTransformCustom(RandomInvert(p=1.0), p=0.3),

                    # Normalize with the checkpoint statistics, as in the test set,
                    # then add noise in the normalized space.
                    Normalize(mean, std),
                    AddDistortion(0.1),
                ]
            ),
        ),
    ]
)

# Create train dataset from the data/train/<gloss>/ folders built in section 2.2
# (the raw Videos/ folder has no split subfolders).
# NOTE(review): Ucf101 presumably derives label ids from this split's sorted
# class folders; they match label2id only if every gloss has a folder in every
# split -- confirm.
train_dataset = pytorchvideo.data.Ucf101(
    data_path=str(dataset_root_path / "train"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    decode_audio=False,
    transform=train_transform,
)

#### 3.6.2 The same sequence of workflow can be applied to the test dataset:

- **Validation/test transform pipeline:**  
  A simpler transform pipeline that includes temporal subsampling, normalization, and resizing—no augmentations, to evaluate model performance reliably.

- **Test dataset:**  
  Created with `pytorchvideo.data.Ucf101` using the validation transform pipeline, random clip sampling, and no audio decoding.

In [None]:
val_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),  # fixed number of frames per clip
            Lambda(lambda x: x / 255.0),  # scale pixel values by 1/255
            Normalize(mean, std),  # checkpoint statistics
            Resize(resize_to, antialias=True),  # model input size
            # no augmentation: evaluation should be reproducible
        ]),
    ),
])

# Test dataset from the data/test/<gloss>/ folders built in section 2.2.
# NOTE(review): the "random" clip sampler draws a different clip per video on
# each pass, so evaluation scores vary between runs.
test_dataset = pytorchvideo.data.Ucf101(
    data_path=str(dataset_root_path / "test"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

#### 3.6.3 Count the number of video files per class (folder) in both the training and test directories using the `count_files_per_folder` function:

- `train_count_labels`: Dictionary grouping classes by the number of files in the training set.
- `test_count_labels`: Dictionary grouping classes by the number of files in the test set.

In [None]:
# Count videos per gloss folder in the split directories created in section 2.2
train_count_labels = count_files_per_folder(str(dataset_root_path / "train"))
test_count_labels = count_files_per_folder(str(dataset_root_path / "test"))

print("In the train folder, the number of videos per label are:")
for k, v in train_count_labels.items():
    print(f"{len(v)} labels with {k} video(s): {v}")

print("In the test folder, the number of videos per label are:")
for k, v in test_count_labels.items():
    print(f"{len(v)} labels with {k} video(s): {v}")

### 3.7 Visualize the pre-processed data

In [None]:
# Decode one training sample (a random clip, with training augmentations) and
# preview it as a GIF; the Image returned on the last line is what the cell shows.
sample_video = next(iter(train_dataset))
video_tensor = sample_video["video"]
display_gif(video_tensor, mean, std, "vid_example.gif")

In [None]:
# Same preview for a test sample, which only goes through val_transform (no augmentation)
sample_video = next(iter(test_dataset))
video_tensor = sample_video["video"]
display_gif(video_tensor, mean, std, "val_example.gif")

### 3.8 Train the model

#### 3.8.1 Check NVIDIA GPU information

In [None]:
# Shell command: print NVIDIA GPU information (requires the NVIDIA driver)
!nvidia-smi

#### 3.8.2 Check the number of videos in the training dataset
- `train_dataset.num_videos`:  
  Returns the total number of video samples in the training dataset.  
  This is useful for verifying dataset size or calculating steps per epoch.

In [None]:
# Number of videos in the training split (used below to derive max_steps)
train_dataset.num_videos 

#### 3.8.3 Most of the training arguments below are self-explanatory, but one particularly important one is:

- `remove_unused_columns=False`:  
  This prevents the `Trainer` from automatically dropping input features not explicitly used by the model's `forward` method.  
  In this case, we **need the `'video'` key** to generate `pixel_values` — a required input for the `VideoMAE` model.  
  If `remove_unused_columns` were left as `True`, those unused fields (like `'video'`) would be removed before reaching the model.

Other key parameters:
- `output_dir`: Where to save the model.
- `learning_rate`: Kept small (`0.5e-5`) so fine-tuning only makes small updates to the pretrained weights.
- `warmup_ratio`: Helps with training stability.
- `max_steps`: Total number of training steps (based on dataset size, batch size, and epochs).
- `logging_steps`, `save_steps`, etc.: Control how often logs and checkpoints are saved.

In [None]:
model_name = model_ckpt.split("/")[-1]
training_start_moment = datetime.now().isoformat(timespec='hours')
new_model_name = f"{model_name}-sign_finetuned-{training_start_moment}"

output_dir = os.path.join("Models", new_model_name)
num_epochs = 100
batch_size = 16
gradient_accumulation_steps = 2

# One optimizer step consumes batch_size * gradient_accumulation_steps clips,
# so this is the number of optimizer steps per pass over the training videos.
# max(1, ...) keeps max_steps positive when there are fewer videos than one
# effective batch (the streaming dataset gives the Trainer no length to use).
# NOTE(review): auto_find_batch_size may shrink the batch at runtime, which
# would make this an underestimate.
steps_per_epoch = max(1, train_dataset.num_videos // (batch_size * gradient_accumulation_steps))
max_steps = steps_per_epoch * num_epochs
logging_steps = 10 if max_steps >= 10 else 1

args = TrainingArguments(
    output_dir=output_dir,
    # Keep the "video" key: collate_fn builds pixel_values from it
    remove_unused_columns=False,
    # `evaluation_strategy` was renamed to `eval_strategy` in current transformers
    eval_strategy="steps",
    eval_steps=25,
    save_strategy="steps",
    save_steps=50,  # multiple of eval_steps, as load_best_model_at_end requires
    learning_rate=0.5e-5,
    auto_find_batch_size=True,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=batch_size,
    warmup_ratio=0.2,
    logging_steps=logging_steps,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    max_steps=max_steps,
    overwrite_output_dir=True,
    weight_decay=0.1,
)

#### 3.8.4 `max_steps` Calculation

This line prints the total number of training steps (`max_steps`) that will be executed during the fine-tuning process. It is computed based on:

- the size of the training dataset,
- the batch size,
- and the number of training epochs.

In [None]:
print("max_steps:", args.max_steps)  # total optimizer steps the Trainer will run

#### 3.8.5 Check Available CUDA Devices

This line prints the number of CUDA-enabled GPUs (e.g., NVIDIA GPUs) available on your machine.

In [None]:
# Number of CUDA devices visible to PyTorch (0 means no GPU is available)
print(torch.cuda.device_count())

#### 3.8.6 Initialize the Trainer

This creates a `Trainer` object from the Hugging Face `transformers` library, which handles the training and evaluation loop.

In [None]:
# Trainer combines the model, training arguments, datasets, metric function
# and collate function into one training/evaluation loop.
# NOTE(review): test_dataset is the eval_dataset, and with
# load_best_model_at_end=True / metric_for_best_model="accuracy" it also picks
# the checkpoint, so the later evaluation on it is optimistic. A dataset built
# from the "val" split would avoid this.
# NOTE(review): recent transformers releases deprecate `tokenizer=` in favour of
# `processing_class=` -- confirm against the installed version.
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

#### 3.8.7 Check if torch can use the GPU

In [None]:
# Device PyTorch would use; shown as the cell output
torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### 3.8.8 If there are a lot of labels in the dataset, params need to be flattened

In [None]:
# Presumably read by transformers' MLflow integration so that nested config
# params (such as the label mappings) are logged as separate params -- confirm.
os.environ["MLFLOW_FLATTEN_PARAMS"] = "true"

# Train inside a named MLflow run; label2id is logged as one JSON-string
# param after training returns.
with mlflow.start_run(run_name=new_model_name):
    train_results = trainer.train()
    mlflow.log_param("label2id", json.dumps(label2id))

#### 3.8.9 Save the best model

In [None]:
model_save_path = os.path.join("Models", new_model_name, "best_model")
os.makedirs(model_save_path, exist_ok=True)

# NOTE(review): assumes the Trainer (load_best_model_at_end=True) restored the
# best checkpoint into `model` in place -- confirm via trainer.state.best_model_checkpoint.
# Save both the model weights/config and the image processor next to each other
for artifact in (model, image_processor):
    artifact.save_pretrained(model_save_path)

print(f"Model saved to {model_save_path}")

### 3.9 Evaluate the model

In [None]:
# Evaluate on the Trainer's eval_dataset (test_dataset) and print the metrics
# dict produced by compute_metrics.
eval_results = trainer.evaluate()
print(eval_results)

## 4. Inference

### 4.1 Run Inference on Multiple Sample Test Videos

This code snippet runs inference on 5 sample videos randomly selected from the test dataset. For each sample, it prints the predicted and actual labels and displays the video as a GIF.

- The `run_inference` function processes each video tensor and returns the model's raw output logits.
- The predicted class is determined by taking the index of the highest logit.
- The actual class label is retrieved from the `id2label` dictionary.
- Each video is saved and displayed as a GIF for visual inspection.

In [None]:
from IPython.display import display

# The pytorchvideo dataset is consumed by iteration (see next(iter(...)) above);
# it does not support len()/indexing, which random.choice needs. Take the first
# 5 samples it yields instead.
# NOTE(review): their order comes from the dataset's video sampler, presumably
# random by default -- confirm.
for i, sample_test_video in enumerate(itertools.islice(test_dataset, 5), start=1):
    logits = run_inference(model, sample_test_video["video"])
    predicted_class_idx = logits.argmax(-1).item()

    print(f"Sample {i}:")
    print(f"  Predicted class: {model.config.id2label[predicted_class_idx]}")
    print(f"  Real class: {id2label[sample_test_video['label']]}")

    video_tensor = sample_test_video["video"]
    # Values returned inside a loop are not auto-rendered; display explicitly
    display(display_gif(video_tensor, mean, std, f"example_{i}.gif"))

## 5. Cleanup

### 5.1 Remove all .gif files in the current working directory

In [None]:
# Delete every .gif in the working directory (the previews written above)
for name in os.listdir(os.getcwd()):
    if not name.endswith(".gif"):
        continue
    os.remove(name)
    print(f"The file {name} has been removed")