# 1. Data Preprocessing

### 1.1 Adjusting the format of `WLASL_v0.3.json`. Also select a subset of 28 glosses.

In [1]:
import json
from pathlib import Path

def generate_flat_dataset(wlasl_json_path, videos_dir, output_path):
    """Flatten the WLASL annotation file into one record per available video.

    The input JSON is a list of entries shaped like
    ``{"gloss": ..., "instances": [{"video_id": ..., ...}, ...]}``. Only the
    selected glosses are kept, and only instances whose
    ``<video_id>.mp4`` file exists in ``videos_dir``.

    Args:
        wlasl_json_path: Path (``str`` or ``Path``) of the WLASL JSON file.
        videos_dir: Directory (``str`` or ``Path``) holding the ``.mp4`` files.
        output_path: Path (``str`` or ``Path``) of the flat JSON file to write.

    Each output record has the keys ``gloss``, ``video_path``,
    ``frame_start``, ``frame_end`` and ``split``.
    """
    # Only include these 28 selected glosses
    selected_glosses = {
        "cousin", "deaf", "help", "call", "give", "take", "like", "laugh",
        "order", "drop", "pizza", "candy", "shirt", "room", "bar", "language",
        "speech", "cool", "silly", "sweet", "careful", "thin", "last", "soon",
        "what", "california", "convince", "interest"
    }

    # Accept plain strings as well as Path objects for the video folder.
    videos_dir = Path(videos_dir)

    with open(wlasl_json_path, "r", encoding="utf-8") as f:
        wlasl_data = json.load(f)

    dataset = []

    for entry in wlasl_data:
        gloss = entry["gloss"]

        # Only include glosses from the selected list
        if gloss not in selected_glosses:
            continue

        # Entry-level split is the fallback when an instance has none.
        split = entry.get("split", "unknown")

        for instance in entry["instances"]:
            video_id = instance.get("video_id")
            if video_id is None:
                # Without an id there is no file name to look for.
                continue

            video_filename = f"{video_id}.mp4"

            # Only include if the actual video file exists
            if not (videos_dir / video_filename).exists():
                continue

            dataset.append({
                "gloss": gloss,
                # The recorded path always points into "data/", regardless of
                # which folder was used for the existence check above.
                "video_path": f"data/{video_filename}",
                "frame_start": instance.get("frame_start"),
                "frame_end": instance.get("frame_end"),
                "split": instance.get("split", split),
            })

    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(dataset, f, indent=2)

In [2]:
# All inputs and outputs are resolved against the current working directory.
base_dir = Path.cwd()
# Original WLASL annotation file
wlasl_json_path = base_dir / "WLASL_v0.3.json"
# Folder that is checked for the <video_id>.mp4 files
videos_dir = base_dir / "Videos"
# Flat annotation list written by generate_flat_dataset
output_path = base_dir / "WLASL_parsed_data_adjustedpath.json"

generate_flat_dataset(wlasl_json_path, videos_dir, output_path)

### 1.2 Restructure the files into `data/test`, `data/train` and `data/val`

In [None]:
import os
import shutil
import json

# Read the flat annotation list (one record per video)
with open('WLASL_parsed_data_adjustedpath.json', 'r') as f:
    data = json.load(f)

# Root folder that receives the <split>/<gloss>/ hierarchy
base_dir = 'data'
moved_files = 0
missing_files = 0

for item in data:
    # Relative path exactly as stored in the JSON (e.g. "data/14894.mp4")
    current_path = os.path.normpath(item['video_path'])

    # Nothing to move when the source file is absent: count it and go on
    if not os.path.exists(current_path):
        missing_files += 1
        print(f"The video {current_path} does not exist")
        continue

    split = item['split']
    gloss = item['gloss']

    # makedirs also creates the intermediate split folder when needed
    split_dir = os.path.join(base_dir, split)
    gloss_dir = os.path.join(split_dir, gloss)
    os.makedirs(gloss_dir, exist_ok=True)

    # Keep the original file name inside the gloss folder
    filename = os.path.basename(current_path)
    new_path = os.path.join(gloss_dir, filename)
    shutil.move(current_path, new_path)

    moved_files += 1
    print(f"The video {current_path} is moved to {new_path}")

print(f"Moved {moved_files} files and {missing_files} files are missing")

### 1.3 Display the video distribution of `data`

In [9]:
import os
from collections import defaultdict

def print_gloss_distribution(data_dir, min_per_split=2):
    """Print a ranked table of per-gloss video counts.

    Looks for ``<data_dir>/<split>/<gloss>/*.mp4`` with split in
    ``train``, ``val`` and ``test``. A gloss is listed only when every one
    of the three splits holds at least ``min_per_split`` videos. Rows are
    ordered by total count, largest first.
    """
    splits = ('train', 'val', 'test')

    # {gloss: {'train': x, 'val': y, 'test': z, 'total': t}}
    tallies = defaultdict(lambda: {'train': 0, 'val': 0, 'test': 0, 'total': 0})

    for split_name in splits:
        split_root = os.path.join(data_dir, split_name)
        if not os.path.exists(split_root):
            continue

        for entry in os.listdir(split_root):
            entry_path = os.path.join(split_root, entry)
            if not os.path.isdir(entry_path):
                continue

            # Only .mp4 files count as videos
            n_videos = sum(1 for name in os.listdir(entry_path) if name.endswith('.mp4'))
            tallies[entry][split_name] += n_videos
            tallies[entry]['total'] += n_videos

    # Keep glosses that are sufficiently represented in all three splits
    ranked = [
        (name, row)
        for name, row in tallies.items()
        if all(row[split_name] >= min_per_split for split_name in splits)
    ]
    # Stable sort, largest total first
    ranked.sort(key=lambda pair: pair[1]['total'], reverse=True)

    print(f"{'#':<4} {'Gloss':<20} {'Train':>5} {'Val':>5} {'Test':>5} {'Total':>6}")
    print("-" * 60)
    idx = 0
    for gloss, counts in ranked:
        idx += 1
        print(f"{idx:<4} {gloss:<20} {counts['train']:>5} {counts['val']:>5} {counts['test']:>5} {counts['total']:>6}")

# Show the distribution for the `data` folder; glosses with fewer than
# 2 videos in any of train/val/test are left out of the table.
print_gloss_distribution(data_dir='data', min_per_split=2)

#    Gloss                Train   Val  Test  Total
------------------------------------------------------------
1    cool                    11     3     2     16
2    thin                    11     3     2     16
3    cousin                   9     3     2     14
4    help                    10     2     2     14
5    candy                    8     2     3     13
6    call                     8     2     2     12
7    last                     8     2     2     12
8    pizza                    8     2     2     12
9    shirt                    8     2     2     12
10   what                     6     3     3     12
11   bar                      6     2     3     11
12   deaf                     7     2     2     11
13   laugh                    6     3     2     11
14   room                     7     2     2     11
15   soon                     7     2     2     11
16   take                     7     2     2     11
17   convince                 6     2     2     10
18   give            

# 2. Fine-tune VideoMAE on a subset of 28 signs from the WLASL dataset.

## 2.1 Install dependencies

In [None]:
%pip install -q torch torchvision pytorchvideo transformers albumentations imageio

#### Count the total number of videos.

In [None]:
import pathlib

dataset_root_path = pathlib.Path("data")

# Collect the videos split by split (layout: data/<split>/<gloss>/<video>.mp4).
# A recursive "**/*.mp4" search would also return files lying outside the
# train/val/test folders, and the name of their parent folder would then be
# treated as a class label when labels are derived from these paths.
train_video_file_paths = list(dataset_root_path.glob("train/*/*.mp4"))
val_video_file_paths = list(dataset_root_path.glob("val/*/*.mp4"))
test_video_file_paths = list(dataset_root_path.glob("test/*/*.mp4"))

# Get all video file paths in train, val, test folders
all_video_file_paths = (
    train_video_file_paths + val_video_file_paths + test_video_file_paths
)

video_count_train = len(train_video_file_paths)
video_count_val = len(val_video_file_paths)
video_count_test = len(test_video_file_paths)
video_total = video_count_train + video_count_val + video_count_test
print(f"Total videos: {video_total}")

#### Derive the set of labels present in the dataset.

In [None]:
# The name of the folder that directly contains a video is used as its label.
class_labels = sorted(set(path.parent.name for path in all_video_file_paths))

# Ids follow the alphabetical order of the labels: 0, 1, 2, ...
id2label = dict(enumerate(class_labels))
label2id = {label: idx for idx, label in id2label.items()}

print(f"Unique classes: {list(label2id)}")

## 2.2 Load a model to fine-tune

#### Imports the relevant classes.

- Sets the checkpoint/model name (`"MCG-NJU/videomae-base"`).

- Loads the `VideoMAEImageProcessor` for preprocessing your videos.

- Loads the `VideoMAEForVideoClassification` model pretrained on that checkpoint.

- Passes the `label2id` and `id2label` mappings so the model knows your specific classes.

- Uses `ignore_mismatched_sizes=True` so it can load the pretrained weights even if your classification head shape is different

In [None]:
from transformers import VideoMAEImageProcessor, VideoMAEForVideoClassification

model_ckpt = "MCG-NJU/videomae-base"
# Optional Hugging Face access token; leave empty when none is needed.
hf_token = ""

# An empty placeholder is not a usable token. Passing None instead lets the
# library apply its default behaviour (e.g. a locally stored login, if any)
# rather than sending an empty credential.
auth_token = hf_token or None

# Preprocessing configuration (mean/std, input size) of the checkpoint
image_processor = VideoMAEImageProcessor.from_pretrained(model_ckpt, token=auth_token)

# Pretrained backbone with a classification head sized for our label set;
# ignore_mismatched_sizes allows loading even if the head shape differs.
model = VideoMAEForVideoClassification.from_pretrained(
    model_ckpt,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
    token=auth_token,
)

## 2.3 Prepare the datasets for training

#### Import dependencies

In [None]:
import pytorchvideo.data

from pytorchvideo.transforms import (
    ApplyTransformToKey,
    Normalize,
    RandomShortSideScale,
    RemoveKey,
    ShortSideScale,
    UniformTemporalSubsample,
)

from torchvision.transforms import (
    Compose,
    Lambda,
    RandomCrop,
    RandomHorizontalFlip,
    RandomRotation,
    RandomAutocontrast,
    RandomInvert,
    Resize,
)

from albumentations import ElasticTransform

### Setting the input normalization, resolution, frame sampling, and clip duration so the `WLASL` videos are transformed into the exact shape/distribution that VideoMAE expects.

In [None]:
# Normalization statistics come from the checkpoint's image processor
mean, std = image_processor.image_mean, image_processor.image_std

# Target resolution: explicit height/width, or one square "shortest_edge"
if "shortest_edge" not in image_processor.size:
    height, width = image_processor.size["height"], image_processor.size["width"]
else:
    height = width = image_processor.size["shortest_edge"]
resize_to = (height, width)

# Number of frames the model consumes per clip
num_frames_to_sample = model.config.num_frames
# NOTE(review): fps is a fixed assumption about the source videos — confirm
sample_rate, fps = 4, 30
# Length of one clip: frames * sampling stride / frames-per-second
clip_duration = num_frames_to_sample * sample_rate / fps

### Define the dataset-specific transformations and the datasets respectively. Starting with the training set

In [None]:
import torch
import numpy as np

class AddDistortion(torch.nn.Module):
    """
    Additive Gaussian-noise augmentation for one video clip.

    Every element of the input receives independent zero-mean noise scaled
    by ``distortion``. The input must be 4-dimensional, (C, T, H, W).
    """

    def __init__(self, distortion=0.5):
        super().__init__()
        # Scale factor applied to the standard-normal noise
        self.distortion = distortion

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        assert x.ndim == 4, "video must have shape (C, T, H, W)"
        # One noise sample per element, same shape/dtype/device as the clip
        return x + self.distortion * torch.randn_like(x)

In [None]:
import random
from torchvision.transforms import RandomAutocontrast, RandomInvert

def apply_with_prob(transform, p=0.3):
    """Return a callable that applies ``transform`` with probability ``p``.

    Each call draws one number from ``random.random()``; when it is below
    ``p`` the transformed input is returned, otherwise the input is returned
    unchanged.
    """
    def wrapper(x):
        return transform(x) if random.random() < p else x
    return wrapper

In [None]:
# Augmentation objects are built once instead of on every clip.
_elastic = ElasticTransform(alpha=30.0, sigma=4.0, alpha_affine=4.0)
_autocontrast = apply_with_prob(RandomAutocontrast(p=1.0), p=0.2)
_invert = apply_with_prob(RandomInvert(p=1.0), p=0.3)


def _elastic_warp(video):
    """Apply the elastic transform to every frame of a (C, T, H, W) clip.

    albumentations transforms work on a single (H, W, C) NumPy image, have to
    be called with keyword arguments and return a dict, so the clip is
    processed frame by frame and reassembled afterwards.
    NOTE(review): each frame is transformed independently (own random draw),
    so the warp is not identical across the frames of a clip.
    """
    frames = video.permute(1, 2, 3, 0).contiguous().numpy()  # (T, H, W, C)
    warped = [torch.from_numpy(_elastic(image=frame)["image"]) for frame in frames]
    return torch.stack(warped).permute(3, 0, 1, 2)  # back to (C, T, H, W)


def _photometric(video):
    """Randomly apply autocontrast and inversion to a (C, T, H, W) clip.

    The torchvision ops expect the channel axis third from the end
    ([..., C, H, W]), so the clip is viewed as (T, C, H, W) while they run.
    """
    frames = video.permute(1, 0, 2, 3)  # (T, C, H, W)
    frames = _invert(_autocontrast(frames))
    return frames.permute(1, 0, 2, 3)  # back to (C, T, H, W)


train_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            UniformTemporalSubsample(num_frames_to_sample),
            Lambda(lambda x: x / 255.0),

            # RandomInvert & RandomAutocontrast with probability.
            # They run before Normalize because, for float tensors, they
            # treat the pixel range as [0, 1].
            Lambda(_photometric),

            Normalize(mean, std),
            Resize(resize_to, antialias=True),

            # Augmentations to reduce overfitting
            RandomHorizontalFlip(p=0.4),
            RandomRotation(degrees=10),
            Lambda(_elastic_warp),
            AddDistortion(0.1),
        ])
    )
])

# Training set: clips are taken from the videos under data/train with a
# "random" clip sampler of length clip_duration, audio is not decoded, and
# every sample goes through train_transform.
# NOTE(review): presumably the loader derives labels from the sub-folder
# names under data_path — confirm against the pytorchvideo documentation.
train_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "train"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("random", clip_duration),
    decode_audio=False,
    transform=train_transform,
)

### Now apply the same to the validation and evaluation sets

In [None]:
# Deterministic preprocessing for validation/test: the same subsample,
# scale, normalize and resize steps as training, without any augmentation.
val_transform = Compose([
    ApplyTransformToKey(
        key="video",
        transform=Compose([
            # Keep num_frames_to_sample frames per clip
            UniformTemporalSubsample(num_frames_to_sample),
            # Scale pixel values by 1/255
            Lambda(lambda x: x / 255.0),
            # Normalize with the image processor's mean/std
            Normalize(mean, std),
            # Resize to the resolution read from the image processor
            Resize(resize_to, antialias=True),
        ])
    )
])

# Datasets
# Validation and test read from data/val and data/test respectively. Both use
# a "uniform" clip sampler (training uses "random") and share val_transform.
val_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "val"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

test_dataset = pytorchvideo.data.Ucf101(
    data_path=os.path.join(dataset_root_path, "test"),
    clip_sampler=pytorchvideo.data.make_clip_sampler("uniform", clip_duration),
    decode_audio=False,
    transform=val_transform,
)

### Access the `num_videos` attribute to get the number of videos in each dataset.

In [None]:
# Number of videos in the train, val and test datasets, in that order
print(train_dataset.num_videos, val_dataset.num_videos, test_dataset.num_videos)

## 2.4 Visualize the preprocessed video for better debugging

In [None]:
import imageio
import numpy as np
from IPython.display import Image

def unnormalize_img(img):
    """Un-normalizes the image pixels and converts them to ``uint8``.

    Reverses the mean/std normalization using the module-level ``mean`` and
    ``std``, rescales to the 0-255 range and returns a ``uint8`` array.

    Args:
        img: NumPy array whose last axis broadcasts against ``mean``/``std``
            (``create_gif`` passes one (height, width, channels) frame).
    """
    img = (img * std) + mean
    # Clip BEFORE the cast: casting out-of-range values to uint8 wraps them
    # around (e.g. 260 -> 4), so clipping afterwards cannot repair them.
    return (img * 255).clip(0, 255).astype("uint8")

def create_gif(video_tensor, filename="sample.gif"):
    """Write a video tensor to ``filename`` as a GIF and return the file name.

    The tensor is iterated along its first axis, one frame at a time, so it
    is expected to be shaped (num_frames, num_channels, height, width).
    """
    # Channels-last, un-normalized uint8 frames
    frames = [
        unnormalize_img(frame.permute(1, 2, 0).numpy())
        for frame in video_tensor
    ]
    imageio.mimsave(filename, frames, "GIF", duration=0.25)
    return filename

def display_gif(video_tensor, gif_name="sample.gif"):
    """Render a video tensor as a GIF file and return it as an IPython image.

    The first two axes are swapped before the tensor is handed to
    ``create_gif`` (which expects frames first), so the input is expected
    as (num_channels, num_frames, height, width).
    """
    frames_first = video_tensor.permute(1, 0, 2, 3)
    return Image(filename=create_gif(frames_first, gif_name))

# Pull the first sample from the training dataset (after train_transform)
sample_video = next(iter(train_dataset))
# The transformed clip is stored under the "video" key
video_tensor = sample_video["video"]
# Being the last expression of the cell, the returned image is displayed
display_gif(video_tensor)