<a href="https://colab.research.google.com/github/alimomennasab/ASL-Translator/blob/main/CS4200_3DCNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive so the WLASL dataset folders and model checkpoints under
# /content/drive/MyDrive are reachable; force_remount=True refreshes an
# existing mount instead of reusing it.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


Data Preprocessing

In [None]:
# List label folders that contain fewer than MIN_VIDEOS_PER_LABEL videos; such
# labels are candidates for exclusion from training.
import os

MIN_VIDEOS_PER_LABEL = 5
# Check the set the model is actually trained on (TRAIN_DATASET_DIR in the Data
# Loading section uses the 64-frame set); an earlier run checked the 60-frame set.
dataset_dir = "/content/drive/MyDrive/WLASL/WLASL100_train_augmented_64frames"

ignore_labels = []
for label in sorted(os.listdir(dataset_dir)):
    label_dir = os.path.join(dataset_dir, label)
    # Stray files in the dataset root would make os.listdir() raise below.
    if not os.path.isdir(label_dir):
        continue
    # Count only .mp4 files, matching what VideoDataset loads.
    n_videos = sum(1 for name in os.listdir(label_dir) if name.lower().endswith('.mp4'))
    if n_videos < MIN_VIDEOS_PER_LABEL:
        ignore_labels.append(label)

print(f"Ignored labels: {sorted(ignore_labels)}")

# Last recorded run (60-frame set): Ignored labels: []

Ignored labels: []


In [None]:
# Histogram of per-video frame counts, used to choose the target clip length.

import cv2
import os
import matplotlib.pyplot as plt
import tqdm
import numpy as np

# Point vid_path at the raw augmented set to study the original length
# distribution (the recorded stats at the bottom came from it); the normalized
# 64-frame set should show a single spike at 64.
#vid_path = "/content/drive/MyDrive/WLASL/WLASL100_train_augmented/"
vid_path = "/content/drive/MyDrive/WLASL/WLASL100_train_augmented_64frames/"

frames = []  # (frame_count, video_filename) tuples
for folder in tqdm.tqdm(sorted(os.listdir(vid_path))):
    folder_path = os.path.join(vid_path, folder)
    if not os.path.isdir(folder_path):
        continue
    for vid in os.listdir(folder_path):
        video_file = os.path.join(folder_path, vid)
        cap = cv2.VideoCapture(video_file)
        try:
            if not cap.isOpened():
                print(f"Error: Could not open video {video_file}.")
                continue
            # Container metadata only: no frames are decoded here.
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            frames.append((total_frames, vid))
        finally:
            cap.release()

if not frames:
    raise RuntimeError(f"No readable videos found under {vid_path}")

frame_counts = [count for count, _ in frames]

# Plot. Bin edges are placed explicitly on the tick grid (one bin per `interval`
# frames over [start, end]) so each bar lines up with an x tick; counts outside
# [start, end] are not drawn. The y ticks are left to matplotlib because a fixed
# range cannot cover the tall single bin of the normalized set.
interval = 5
start = 15
end = 160
fig, ax = plt.subplots(figsize=(15, 6))
ax.hist(frame_counts, bins=np.arange(start, end + interval, interval))
ax.set_xlabel('Number of Frames')
ax.set_ylabel('Frequency')
ax.set_title('Histogram of Training Set Video Frames')
ax.set_xlim(start, end)
ax.set_xticks(np.arange(start, end + 1, interval))
plt.show()

# Stats
avg_frames = sum(frame_counts) / len(frame_counts)
min_tuple = min(frames, key=lambda x: x[0])
max_tuple = max(frames, key=lambda x: x[0])
print(f"Average number of frames: {avg_frames:.2f}")
print(f"Minimum frame count: {min_tuple[0]}, video={min_tuple[1]}")
print(f"Maximum frame count: {max_tuple[0]}, video={max_tuple[1]}")

# Recorded on WLASL100_train_augmented (before normalization):
# Average number of frames: 63.07
# Minimum frame count: 19, video=06334_AUG1_zoom_fast.mp4
# Maximum frame count: 155, video=63232_AUG2_bright_mirror.mp4

  1%|          | 1/99 [00:10<16:44, 10.25s/it]


KeyboardInterrupt: 

Data Augmentation

In [None]:
# Go through a dataset and transform every video to have a target number of frames

import cv2
import os
import tqdm
import numpy as np

# duplicate frames in a video starting with the middle frames and expanding outwards until we reach a target frame count
def duplicate_frames(frames, target):
    """Pad a clip to ``target`` frames by repeating frames, middle-out.

    Extra copies are taken in the order mid-1, mid+1, mid-2, mid+2, ... (once
    one side runs out the other side continues alone), followed by the middle
    frame itself. If still more padding is needed, the same order is cycled
    again, so long pads are spread across the whole clip instead of repeating
    one frame. The final index list is sorted, so playback order is preserved
    and every duplicate sits next to its source frame.

    Args:
        frames: sequence of frames; only indexed, never modified.
        target: desired number of frames. If ``target <= len(frames)`` the
            frames are returned unchanged.

    Returns:
        ``(new_frames, "duplicated")``. Empty input prints a warning and
        returns ``([], "duplicated")``.
    """
    n = len(frames)
    if n == 0:
        print("Warning: Empty frames list")
        return [], "duplicated"

    mid = n // 2
    # Middle-out visiting order over every index except mid, then mid last.
    order = []
    for offset in range(1, n):
        if mid - offset >= 0:
            order.append(mid - offset)
        if mid + offset < n:
            order.append(mid + offset)
    order.append(mid)

    needed = max(0, target - n)
    # Cycle the order so very short clips get every frame duplicated roughly
    # equally rather than a frozen run of the middle frame.
    extra = [order[i % len(order)] for i in range(needed)]
    indices = sorted(list(range(n)) + extra)
    return [frames[i] for i in indices], "duplicated"


# remove frames alternately from start and end until we reach a target frame count
def remove_frames(frames, target):
    """Trim a clip to ``target`` frames by cutting from both ends.

    Equivalent to dropping frames alternately from the front and the back,
    front first: with an odd surplus the front loses one frame more than the
    back.

    Returns:
        ``(trimmed_frames, "removed")``; a clip already at or below ``target``
        comes back as an unchanged slice.
    """
    total = len(frames)
    surplus = total - target
    if surplus <= 0:
        return frames[0:total], "removed"
    drop_front = (surplus + 1) // 2
    drop_back = surplus // 2
    return frames[drop_front:total - drop_back], "removed"


target_frames = 64
data_dir = "/content/drive/MyDrive/WLASL/WLASL100_train_augmented/"
output_dir = f"/content/drive/MyDrive/WLASL/WLASL100_train_augmented_{target_frames}frames/"
os.makedirs(output_dir, exist_ok=True)
print(f"Output directory: {output_dir}")

# Writer settings are the same for every clip. Output is written at a fixed
# 30 fps; the data loader only reads frames, never the frame rate.
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
fps = 30

for folder in tqdm.tqdm(sorted(os.listdir(data_dir))):
    folder_path = os.path.join(data_dir, folder)
    if not os.path.isdir(folder_path):
        continue
    # tqdm.write keeps the progress bar intact instead of interleaving with print().
    tqdm.tqdm.write(f"Processing folder {folder}")
    output_folder_path = os.path.join(output_dir, folder)
    os.makedirs(output_folder_path, exist_ok=True)

    for vid in os.listdir(folder_path):
        cap = cv2.VideoCapture(os.path.join(folder_path, vid))
        if not cap.isOpened():
            print(f"Warning: Could not open video {vid} in folder {folder}. Skipping.")
            cap.release()
            continue
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frames.append(frame)
        finally:
            cap.release()

        if not frames:
            print(f"Warning: No frames were processed for video {vid} in folder {folder}. Skipping video writing.")
            continue

        # Pad short clips / trim long clips to exactly target_frames. Clips that
        # already match keep an empty suffix and their original filename.
        suffix = ""
        if len(frames) < target_frames:
            frames, suffix = duplicate_frames(frames, target_frames)
        elif len(frames) > target_frames:
            frames, suffix = remove_frames(frames, target_frames)

        base, extension = os.path.splitext(vid)
        output_filename = f"{base}_{suffix}{extension}" if suffix else vid
        output_video_path = os.path.join(output_folder_path, output_filename)

        height, width = frames[0].shape[:2]
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
        if not out.isOpened():
            # VideoWriter fails silently otherwise, leaving a missing/empty file.
            print(f"Warning: Could not open writer for {output_video_path}. Skipping.")
            out.release()
            continue
        try:
            for frame in frames:
                out.write(frame)
        finally:
            out.release()


Output directory: /content/drive/MyDrive/WLASL/WLASL100_train_augmented_64frames/


  0%|          | 0/100 [00:00<?, ?it/s]


Processing folder tall


  1%|          | 1/100 [00:15<25:42, 15.58s/it]


Processing folder woman


  2%|▏         | 2/100 [00:33<27:26, 16.80s/it]


Processing folder candy


  3%|▎         | 3/100 [01:05<38:54, 24.07s/it]


Processing folder orange


  4%|▍         | 4/100 [01:13<28:18, 17.69s/it]


Processing folder cow


  5%|▌         | 5/100 [01:25<24:48, 15.67s/it]


Processing folder same


  6%|▌         | 6/100 [01:36<21:38, 13.81s/it]


Processing folder help


  7%|▋         | 7/100 [01:50<21:37, 13.95s/it]


Processing folder shirt


  8%|▊         | 8/100 [02:05<22:00, 14.35s/it]


Processing folder africa


  9%|▉         | 9/100 [02:21<22:28, 14.82s/it]


Processing folder drink


 10%|█         | 10/100 [02:36<22:16, 14.85s/it]


Processing folder fish


 11%|█         | 11/100 [02:55<23:50, 16.07s/it]


Processing folder tell


 12%|█▏        | 12/100 [03:06<21:27, 14.63s/it]


Processing folder pizza


 13%|█▎        | 13/100 [03:25<23:06, 15.93s/it]


Processing folder give


 14%|█▍        | 14/100 [03:39<22:05, 15.42s/it]


Processing folder white


 15%|█▌        | 15/100 [03:50<19:57, 14.09s/it]


Processing folder computer


 16%|█▌        | 16/100 [04:35<32:37, 23.30s/it]


Processing folder school


 17%|█▋        | 17/100 [04:54<30:37, 22.13s/it]


Processing folder black


 18%|█▊        | 18/100 [05:16<29:58, 21.94s/it]


Processing folder fine


 19%|█▉        | 19/100 [05:25<24:19, 18.02s/it]


Processing folder play


 20%|██        | 20/100 [05:29<18:38, 13.98s/it]


Processing folder dark


 21%|██        | 21/100 [05:39<16:54, 12.84s/it]


Processing folder wrong


 22%|██▏       | 22/100 [05:50<15:37, 12.01s/it]


Processing folder yes


 23%|██▎       | 23/100 [06:11<18:54, 14.73s/it]


Processing folder forget


 24%|██▍       | 24/100 [06:30<20:17, 16.02s/it]


Processing folder later


 25%|██▌       | 25/100 [06:59<24:57, 19.97s/it]


Processing folder hat


 26%|██▌       | 26/100 [07:04<19:05, 15.48s/it]


Processing folder what


 27%|██▋       | 27/100 [07:06<14:01, 11.53s/it]


Processing folder son


 28%|██▊       | 28/100 [07:34<19:37, 16.36s/it]


Processing folder thursday


 29%|██▉       | 29/100 [07:44<17:16, 14.59s/it]


Processing folder hot


 30%|███       | 30/100 [07:55<15:40, 13.43s/it]


Processing folder book


 31%|███       | 31/100 [07:59<12:15, 10.66s/it]


Processing folder dance


 32%|███▏      | 32/100 [08:15<13:55, 12.29s/it]


Processing folder before


 33%|███▎      | 33/100 [08:46<19:46, 17.71s/it]


Processing folder corn


 34%|███▍      | 34/100 [09:08<20:55, 19.02s/it]


Processing folder thin


 35%|███▌      | 35/100 [10:10<34:49, 32.15s/it]


Processing folder who


 36%|███▌      | 36/100 [11:21<46:38, 43.73s/it]


Processing folder purple


 37%|███▋      | 37/100 [11:48<40:40, 38.74s/it]


Processing folder medicine


 38%|███▊      | 38/100 [12:09<34:27, 33.35s/it]


Processing folder cousin


 39%|███▉      | 39/100 [12:43<34:04, 33.51s/it]


Processing folder year


 40%|████      | 40/100 [13:16<33:19, 33.33s/it]


Processing folder work


 41%|████      | 41/100 [14:03<36:54, 37.54s/it]


Processing folder cook


 42%|████▏     | 42/100 [14:33<34:01, 35.20s/it]


Processing folder finish


 43%|████▎     | 43/100 [15:06<32:50, 34.57s/it]


Processing folder bird


 44%|████▍     | 44/100 [15:55<36:12, 38.79s/it]


Processing folder deaf


 45%|████▌     | 45/100 [16:24<33:01, 36.03s/it]


Processing folder pink


 46%|████▌     | 46/100 [16:51<29:56, 33.26s/it]


Processing folder cool


 47%|████▋     | 47/100 [17:40<33:26, 37.85s/it]


Processing folder wife


 48%|████▊     | 48/100 [18:06<29:56, 34.55s/it]


Processing folder eat


 49%|████▉     | 49/100 [18:25<25:15, 29.71s/it]


Processing folder time


 50%|█████     | 50/100 [18:52<24:02, 28.85s/it]


Processing folder birthday


 51%|█████     | 51/100 [19:00<18:24, 22.54s/it]


Processing folder cheat


 52%|█████▏    | 52/100 [19:44<23:23, 29.24s/it]


Processing folder visit


 53%|█████▎    | 53/100 [20:13<22:48, 29.12s/it]


Processing folder all


 54%|█████▍    | 54/100 [20:45<22:54, 29.89s/it]


Processing folder want


 55%|█████▌    | 55/100 [21:13<21:59, 29.33s/it]


Processing folder basketball


 56%|█████▌    | 56/100 [22:08<27:13, 37.12s/it]


Processing folder chair


 57%|█████▋    | 57/100 [22:31<23:23, 32.65s/it]


Processing folder but


 58%|█████▊    | 58/100 [22:44<18:56, 27.05s/it]


Processing folder letter


 59%|█████▉    | 59/100 [23:21<20:25, 29.88s/it]


Processing folder language


 60%|██████    | 60/100 [24:01<21:58, 32.96s/it]


Processing folder doctor


 61%|██████    | 61/100 [24:20<18:39, 28.71s/it]


Processing folder graduate


 62%|██████▏   | 62/100 [24:54<19:14, 30.37s/it]


Processing folder pull


 63%|██████▎   | 63/100 [25:46<22:38, 36.71s/it]


Processing folder short


 64%|██████▍   | 64/100 [26:26<22:42, 37.84s/it]


Processing folder family


 65%|██████▌   | 65/100 [27:01<21:33, 36.94s/it]


Processing folder color


 66%|██████▌   | 66/100 [27:32<19:56, 35.20s/it]


Processing folder bowling


 67%|██████▋   | 67/100 [28:27<22:41, 41.25s/it]


Processing folder enjoy


 68%|██████▊   | 68/100 [29:08<21:53, 41.06s/it]


Processing folder kiss


 69%|██████▉   | 69/100 [29:37<19:24, 37.56s/it]


Processing folder meet


 70%|███████   | 70/100 [30:06<17:28, 34.93s/it]


Processing folder clothes


 71%|███████   | 71/100 [30:19<13:41, 28.31s/it]


Processing folder water


 72%|███████▏  | 72/100 [30:39<11:58, 25.65s/it]


Processing folder how


 73%|███████▎  | 73/100 [31:14<12:50, 28.52s/it]


Processing folder no


 74%|███████▍  | 74/100 [31:39<11:52, 27.39s/it]


Processing folder dog


 75%|███████▌  | 75/100 [32:10<11:53, 28.54s/it]


Processing folder blue


 76%|███████▌  | 76/100 [32:31<10:32, 26.35s/it]


Processing folder secretary


 77%|███████▋  | 77/100 [33:10<11:35, 30.23s/it]


Processing folder many


 78%|███████▊  | 78/100 [33:49<11:57, 32.63s/it]


Processing folder need


 79%|███████▉  | 79/100 [34:05<09:40, 27.64s/it]


Processing folder go


 80%|████████  | 80/100 [34:44<10:22, 31.15s/it]


Processing folder city


 81%|████████  | 81/100 [35:16<09:57, 31.46s/it]


Processing folder bed


 82%|████████▏ | 82/100 [36:08<11:16, 37.57s/it]


Processing folder decide


 83%|████████▎ | 83/100 [36:33<09:34, 33.81s/it]


Processing folder study


 84%|████████▍ | 84/100 [37:10<09:16, 34.78s/it]


Processing folder walk


 85%|████████▌ | 85/100 [37:34<07:51, 31.44s/it]


Processing folder last


 86%|████████▌ | 86/100 [38:05<07:18, 31.29s/it]


Processing folder can


 87%|████████▋ | 87/100 [38:28<06:15, 28.89s/it]


Processing folder like


 88%|████████▊ | 88/100 [38:59<05:53, 29.45s/it]


Processing folder wait


 89%|████████▉ | 89/100 [39:30<05:31, 30.11s/it]


Processing folder change


 90%|█████████ | 90/100 [40:26<06:17, 37.80s/it]


Processing folder paint


 91%|█████████ | 91/100 [41:12<06:01, 40.19s/it]


Processing folder hearing


 92%|█████████▏| 92/100 [41:35<04:41, 35.23s/it]


Processing folder brown


 93%|█████████▎| 93/100 [41:58<03:39, 31.41s/it]


Processing folder paper


 94%|█████████▍| 94/100 [42:18<02:48, 28.10s/it]


Processing folder jacket


 95%|█████████▌| 95/100 [42:53<02:29, 29.97s/it]


Processing folder full


 96%|█████████▌| 96/100 [43:23<02:00, 30.12s/it]


Processing folder apple


 97%|█████████▋| 97/100 [43:46<01:24, 28.02s/it]


Processing folder man


 98%|█████████▊| 98/100 [44:23<01:01, 30.56s/it]


Processing folder right


 99%|█████████▉| 99/100 [44:52<00:30, 30.11s/it]


Processing folder accident


100%|██████████| 100/100 [45:55<00:00, 27.56s/it]


Data Loading

In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import cv2
from pathlib import Path
from PIL import Image
import numpy as np

# Dataset locations and loader settings. target_frames is the per-video frame
# count the videos were normalized to in the Data Augmentation section.
target_frames = 64
TRAIN_DATASET_DIR = f"/content/drive/MyDrive/WLASL/WLASL100_train_augmented_{target_frames}frames"
VAL_DATASET_DIR = f'/content/drive/MyDrive/WLASL/WLASL100_val_{target_frames}frames'
IMG_SIZE = 112   # frames are resized to IMG_SIZE x IMG_SIZE
BATCH_SIZE = 4

# Prefer the GPU whenever PyTorch can see one.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using GPU" if device == "cuda" else "Using CPU")

# Per-frame preprocessing: resize, convert to a float tensor, normalize. The
# mean/std are the Kinetics-400 RGB statistics used by torchvision's pretrained
# video models; the ImageNet values are kept in the comment for reference.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.43216, 0.394666, 0.37645],
        std=[0.22803, 0.22145, 0.216989],
    ),
    # ImageNet alternative: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
])


class VideoDataset(Dataset):
    """Clip dataset over a ``root_dir/<label>/<video>.mp4`` folder tree.

    Each item is ``(clip, label_idx)``. ``clip`` is the per-frame ``transform``
    outputs stacked and permuted to (C, num_frames, H, W).

    Args:
        root_dir: directory whose sub-folders are class labels.
        transform: callable applied to each frame (a PIL RGB image) returning a
            (C, H, W) tensor. If None, ``transforms.functional.to_tensor`` is used.
        class_map: optional ``{'class_to_idx': ..., 'idx_to_class': ...}`` to
            reuse another split's label indexing; folders whose label is not in
            it are skipped.
        num_frames: number of frames per returned clip.
        mode: ``"Train"`` (randomly jittered sampling) or ``"Val"``
            (deterministic uniform sampling). Both spread the clip over the whole
            video, so training and validation see the same temporal stride.

    Raises:
        ValueError: if ``mode`` is not ``"Train"`` or ``"Val"``.
    """

    def __init__(self, root_dir, transform=None, class_map=None, num_frames=16, mode="Train"):
        # Validate once up front instead of asserting on every item
        # (asserts are also stripped under `python -O`).
        if mode not in ("Train", "Val"):
            raise ValueError("Mode must be either 'Train' or 'Val'")
        self.root_dir = root_dir
        self.transform = transform
        self.num_frames = num_frames
        self.mode = mode
        self.samples = []  # (video_path, label_idx) tuples

        if class_map is None:
            label_folders = sorted(d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)))
            self.class_to_idx = {label_name: i for i, label_name in enumerate(label_folders)}
            self.idx_to_class = list(label_folders)
        else:
            self.class_to_idx = class_map['class_to_idx']
            self.idx_to_class = class_map['idx_to_class']

        # Sorted listings make sample order (and so evaluation output order)
        # reproducible across filesystems.
        for label_name in sorted(os.listdir(root_dir)):
            label_path = os.path.join(root_dir, label_name)
            if label_name not in self.class_to_idx or not os.path.isdir(label_path):
                continue
            label_idx = self.class_to_idx[label_name]
            for video_file in sorted(os.listdir(label_path)):
                if video_file.lower().endswith('.mp4'):
                    self.samples.append((os.path.join(label_path, video_file), label_idx))

        self.class_names = self.idx_to_class
        print(f"Found {len(self.samples)} video samples across {len(set(label for _, label in self.samples))} classes in {root_dir}")

    def __len__(self):
        return len(self.samples)

    def random_sample(self, frames):
        """Pick ``num_frames`` frames spread over the whole clip, with random jitter.

        The clip is split into ``num_frames`` equal segments and one frame is
        drawn uniformly from each (TSN-style segment sampling). This matches the
        temporal coverage of ``uniform_sample`` used for validation. A contiguous
        16-of-64 window would instead show the model a 4x finer stride than it is
        evaluated on. Clips shorter than ``num_frames`` repeat frames.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        n = len(frames)
        if n == 0:
            raise ValueError("Cannot sample from an empty frame list")
        edges = (np.arange(self.num_frames + 1) * n) // self.num_frames
        low = edges[:-1]
        # Keep every segment non-empty even when n < num_frames.
        high = np.maximum(edges[1:], low + 1)
        idxs = np.random.randint(low, high)
        return [frames[i] for i in idxs]

    def uniform_sample(self, frames):
        """Pick ``num_frames`` evenly spaced frames, first and last included.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        if len(frames) == 0:
            raise ValueError("Cannot sample from an empty frame list")
        idxs = np.linspace(0, len(frames) - 1, self.num_frames).astype(int)
        return [frames[i] for i in idxs]

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frames = self._load_all_frames(video_path)
        if not frames:
            # Fail with the offending path instead of an opaque sampling error.
            raise RuntimeError(f"No frames could be read from {video_path}")

        if self.mode == "Train":
            frames = self.random_sample(frames)
        else:
            frames = self.uniform_sample(frames)

        to_tensor = self.transform if self.transform is not None else transforms.functional.to_tensor
        frames_tensor = torch.stack([to_tensor(Image.fromarray(frame)) for frame in frames])  # (T, C, H, W)
        frames_tensor = frames_tensor.permute(1, 0, 2, 3)  # (C, T, H, W)
        return frames_tensor, label

    def _load_all_frames(self, video_path):
        """Decode every frame of ``video_path`` as an RGB array; ``[]`` if it cannot be opened."""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return []

        extracted_frames_rgb = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; PIL and the normalization expect RGB.
                extracted_frames_rgb.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        if len(extracted_frames_rgb) == 0:
            print(f"Warning: No frames extracted from video {video_path}")

        return extracted_frames_rgb



# Build the train/validation datasets. Both draw NUM_FRAMES frames per clip:
# "Train" mode samples randomly, "Val" mode deterministically.
NUM_FRAMES = 16

print("Creating dataset")
train_ds = VideoDataset(TRAIN_DATASET_DIR, transform=transform, num_frames=NUM_FRAMES, mode="Train")

# Validation reuses the training label indexing so class ids agree; validation
# folders whose label is not in the training map are skipped.
val_class_map = {
    'class_to_idx': train_ds.class_to_idx,
    'idx_to_class': train_ds.idx_to_class,
}
val_ds = VideoDataset(VAL_DATASET_DIR, transform=transform, class_map=val_class_map,
                      num_frames=NUM_FRAMES, mode="Val")

# Only the training loader shuffles.
print("Creating dataloaders")
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

class_names = train_ds.class_names
num_classes = len(class_names)
print("Classes:", class_names[:10], "...")
print("num_classes =", num_classes)


Using GPU
Creating dataset
Found 2784 video samples across 99 classes in /content/drive/MyDrive/WLASL/WLASL100_train_augmented_64frames
Found 195 video samples across 98 classes in /content/drive/MyDrive/WLASL/WLASL100_val_64frames
Creating dataloaders
Classes: ['accident', 'africa', 'all', 'apple', 'basketball', 'bed', 'before', 'bird', 'birthday', 'black'] ...
num_classes = 99


In [None]:
# ensure both training and val datasets have exactly target_frame number of frames in each vid

import os
import cv2
from collections import Counter
from tqdm import tqdm

def get_frame_count(video_path, exact=False):
    """Return the number of frames in ``video_path``, or None if it cannot be opened.

    By default this reads ``CAP_PROP_FRAME_COUNT``, which is fast but is only
    container metadata and depends on the OpenCV backend. Pass ``exact=True`` to
    count frames by grabbing them one by one. That is slower, but it matches
    what ``cap.read()``-based loaders will actually see.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            print("Failed to open:", video_path)
            return None
        if not exact:
            return int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        total = 0
        while cap.grab():
            total += 1
        return total
    finally:
        # Release on every path, including the failed-open one.
        cap.release()

def scan_directory(root):
    """Collect frame counts for every ``root/<class>/*.mp4`` video.

    Returns:
        ``(frame_counts, bad_files)``: the counts for videos that opened, and
        the paths of videos ``get_frame_count`` could not open.
    """
    video_files = [
        os.path.join(root, cls, name)
        for cls in os.listdir(root)
        if os.path.isdir(os.path.join(root, cls))
        for name in os.listdir(os.path.join(root, cls))
        if name.lower().endswith('.mp4')
    ]
    print(f"Found {len(video_files)} videos in {root}")

    frame_counts, bad_files = [], []
    for video_path in tqdm(video_files):
        n_frames = get_frame_count(video_path)
        if n_frames is not None:
            frame_counts.append(n_frames)
            continue
        bad_files.append(video_path)

    return frame_counts, bad_files


# Every normalized video is expected to have exactly this many frames.
EXPECTED_FRAMES = 64
train_root = f"/content/drive/MyDrive/WLASL/WLASL100_train_augmented_{EXPECTED_FRAMES}frames"
val_root   = f"/content/drive/MyDrive/WLASL/WLASL100_val_{EXPECTED_FRAMES}frames"

train_counts, train_bad = scan_directory(train_root)
val_counts, val_bad = scan_directory(val_root)


def _summarize_split(split_name, counts, bad_paths):
    """Print one split's frame-count distribution and flag off-target or unreadable videos."""
    print(f"{split_name} SET FRAME COUNTS")
    print("Unique:", sorted(set(counts)))
    print("Count histogram:", Counter(counts))
    off_target = sum(1 for c in counts if c != EXPECTED_FRAMES)
    print(f"Videos not at {EXPECTED_FRAMES} frames: {off_target}")
    if bad_paths:
        print(f"Unreadable videos ({len(bad_paths)}):")
        for path in bad_paths:
            print("  ", path)


_summarize_split("TRAIN", train_counts, train_bad)
_summarize_split("VAL", val_counts, val_bad)


Found 2784 videos in /content/drive/MyDrive/WLASL/WLASL100_train_augmented_64frames


  6%|▌         | 160/2784 [00:27<07:38,  5.72it/s]


KeyboardInterrupt: 

Model Training

In [None]:
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import json
import os
from torchvision.models.video import r3d_18
from tqdm import tqdm
from google.colab import drive

def plot_losses(train_losses: list, val_losses: list):
    """Draw training and validation loss curves on the current pyplot figure and show it.

    Args:
        train_losses: one training loss value per epoch.
        val_losses: one validation loss value per epoch; expected to be the
            same length as ``train_losses``.
    """
    epoch_numbers = range(1, len(train_losses) + 1)  # 1-based epochs on the x axis
    for series, series_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        plt.plot(epoch_numbers, series, label=series_label)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss Curves')
    plt.legend()
    plt.show()


# Load pretrained 3D CNN (torchvision r3d_18, Kinetics-400 weights) and replace
# its classifier with dropout + a linear layer sized for our label set.
num_classes = len(class_names)
model = r3d_18(weights="KINETICS400_V1")
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(model.fc.in_features, num_classes)
)
model = model.to(device)

# Hyperparameters
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
EPOCHS = 8
start_epoch = 0
run = 4  # tags every artifact path below so runs don't overwrite each other

# Artifact paths. CHECKPOINT_PATH is rewritten every epoch and is what resuming reads.
drive.mount('/content/drive')
CHECKPOINT_PATH = f"/content/drive/MyDrive/3d_cnn_asl_checkpoint_run{run}.pt"
MODEL_PATH = f"/content/drive/MyDrive/3d_cnn_asl_run{run}.pt"
# Weights from the epoch with the highest validation accuracy. Val accuracy can
# peak before the last epoch (run 4: 16.92% at epoch 6 vs 14.87% at epoch 8),
# so the final-epoch weights in MODEL_PATH are not necessarily the best model.
BEST_MODEL_PATH = f"/content/drive/MyDrive/3d_cnn_asl_best_run{run}.pt"
LABELS_PATH = f"/content/drive/MyDrive/3d_cnn_asl_labels_run{run}.json"

best_val_acc = 0.0
if os.path.exists(CHECKPOINT_PATH):
    print("Loading checkpoint")
    checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    train_losses = checkpoint["train_losses"]
    val_losses = checkpoint["val_losses"]
    # Checkpoints written before best-model tracking have no "best_val_acc".
    best_val_acc = checkpoint.get("best_val_acc", 0.0)
    print(f"Resuming from epoch {start_epoch}")
else:
    train_losses, val_losses = [], []

# Training loop
for epoch in range(start_epoch, EPOCHS):
    model.train()
    epoch_loss, correct, total = 0.0, 0, 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        _, preds = outputs.max(1)
        total += labels.size(0)
        correct += preds.eq(labels).sum().item()
        pbar.set_postfix(loss=loss.item(), acc=100.*correct/total)

    train_losses.append(epoch_loss / len(train_loader))

    # Validation
    model.eval()
    val_epoch_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_epoch_loss += loss.item()
            _, preds = outputs.max(1)
            val_total += labels.size(0)
            val_correct += preds.eq(labels).sum().item()
    val_losses.append(val_epoch_loss / len(val_loader))
    val_acc = 100. * val_correct / val_total

    print(f"Epoch {epoch+1}: Train Acc={100.*correct/total:.2f}% | Val Acc={val_acc:.2f}%")

    # Keep the best-by-validation weights separately from the rolling checkpoint.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), BEST_MODEL_PATH)
        print(f"New best model (Val Acc={val_acc:.2f}%) saved to: {BEST_MODEL_PATH}")

    # Save checkpoint every epoch
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "train_losses": train_losses,
        "val_losses": val_losses,
        "best_val_acc": best_val_acc,
    }
    torch.save(checkpoint, CHECKPOINT_PATH)
    print(f"Checkpoint saved at epoch {epoch+1}")


torch.save(model.state_dict(), MODEL_PATH)
with open(LABELS_PATH, "w") as f:
    json.dump(class_names, f)
print("Final model saved to:", MODEL_PATH)
if os.path.exists(BEST_MODEL_PATH):
    print(f"Best validation accuracy: {best_val_acc:.2f}% (weights saved to: {BEST_MODEL_PATH})")

plot_losses(train_losses, val_losses)

Downloading: "https://download.pytorch.org/models/r3d_18-b3b3357e.pth" to /root/.cache/torch/hub/checkpoints/r3d_18-b3b3357e.pth


100%|██████████| 127M/127M [00:01<00:00, 92.9MB/s]


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Epoch 1/8: 100%|██████████| 696/696 [19:41<00:00,  1.70s/it, acc=7.9, loss=4.71]


Epoch 1: Train Acc=7.90% | Val Acc=6.67%
Checkpoint saved at epoch 1


Epoch 2/8: 100%|██████████| 696/696 [09:39<00:00,  1.20it/s, acc=31.6, loss=2.72]


Epoch 2: Train Acc=31.61% | Val Acc=13.33%
Checkpoint saved at epoch 2


Epoch 3/8: 100%|██████████| 696/696 [09:36<00:00,  1.21it/s, acc=61.3, loss=2.14]


Epoch 3: Train Acc=61.28% | Val Acc=13.85%
Checkpoint saved at epoch 3


Epoch 4/8: 100%|██████████| 696/696 [09:38<00:00,  1.20it/s, acc=78.4, loss=0.311]


Epoch 4: Train Acc=78.45% | Val Acc=15.90%
Checkpoint saved at epoch 4


Epoch 5/8: 100%|██████████| 696/696 [09:44<00:00,  1.19it/s, acc=89.5, loss=0.342]


Epoch 5: Train Acc=89.48% | Val Acc=15.90%
Checkpoint saved at epoch 5


Epoch 6/8: 100%|██████████| 696/696 [09:52<00:00,  1.17it/s, acc=93.8, loss=0.111]


Epoch 6: Train Acc=93.75% | Val Acc=16.92%
Checkpoint saved at epoch 6


Epoch 7/8: 100%|██████████| 696/696 [09:42<00:00,  1.19it/s, acc=95, loss=0.272]


Epoch 7: Train Acc=94.97% | Val Acc=14.87%
Checkpoint saved at epoch 7


Epoch 8/8: 100%|██████████| 696/696 [09:44<00:00,  1.19it/s, acc=95, loss=0.254]


Epoch 8: Train Acc=95.04% | Val Acc=14.87%
Checkpoint saved at epoch 8
Final model saved to: /content/drive/MyDrive/3d_cnn_asl_run4.pt


Model Testing

In [None]:
import cv2
import numpy as np
import torch
import json
import matplotlib.pyplot as plt
import tqdm
import os
import torch.nn as nn
from torchvision.models.video import r3d_18
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from pathlib import Path
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay


# Evaluation settings; these must match the ones used during training.
IMG_SIZE = 112
BATCH_SIZE = 4
target_frames = 64
run = 4  # which training run's checkpoint/labels to evaluate

# choose device: GPU if PyTorch can see one
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using GPU" if device == "cuda" else "Using CPU")

# Test preprocessing: same resize + Kinetics-400 normalization as training,
# with no augmentation. The ImageNet values are kept in the comment for reference.
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.43216, 0.394666, 0.37645],
        std=[0.22803, 0.22145, 0.216989],
    ),
    # ImageNet alternative: mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
])


# Test dataset that only uses words the model trained on
class VideoDataset(Dataset):
    """Evaluation dataset over ``root_dir/<label>/<video>.mp4`` (redefined for the test cell).

    Only labels present in the class map are kept, so test folders for words the
    model never trained on are ignored. Each item is ``(clip, label_idx)``, where
    ``clip`` is the stacked per-frame ``transform`` outputs permuted to (C, T, H, W).

    Args:
        root_dir: directory whose sub-folders are class labels.
        transform: callable applied to each frame (a PIL RGB image) returning a
            (C, H, W) tensor. If None, ``transforms.functional.to_tensor`` is used.
        class_map: optional ``{'class_to_idx': ..., 'idx_to_class': ...}``.
        num_frames: if None (default), every decoded frame is used (T = video
            length). Otherwise ``num_frames`` evenly spaced frames are used, which
            lets evaluation mirror the clip sampling the model was trained on.
    """

    def __init__(self, root_dir, transform=None, class_map=None, num_frames=None):
        self.root_dir = root_dir
        self.transform = transform
        self.num_frames = num_frames
        self.samples = []  # (video_path, label_idx) tuples

        if class_map is None:
            label_folders = sorted(d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d)))
            self.class_to_idx = {label_name: i for i, label_name in enumerate(label_folders)}
            self.idx_to_class = list(label_folders)
        else:
            self.class_to_idx = class_map['class_to_idx']
            self.idx_to_class = class_map['idx_to_class']

        # Sorted listings keep the prediction order reproducible.
        for label_name in sorted(os.listdir(root_dir)):
            label_path = os.path.join(root_dir, label_name)
            if label_name not in self.class_to_idx or not os.path.isdir(label_path):
                continue
            label_idx = self.class_to_idx[label_name]
            for video_file in sorted(os.listdir(label_path)):
                if video_file.lower().endswith('.mp4'):
                    self.samples.append((os.path.join(label_path, video_file), label_idx))

        self.class_names = self.idx_to_class
        print(f"Found {len(self.samples)} video samples across {len(set(label for _, label in self.samples))} classes in {root_dir}")

    def __len__(self):
        return len(self.samples)

    def uniform_sample(self, frames):
        """Return ``num_frames`` evenly spaced frames, or all frames when ``num_frames`` is None."""
        if self.num_frames is None:
            return list(frames)
        if len(frames) == 0:
            raise ValueError("Cannot sample from an empty frame list")
        idxs = np.linspace(0, len(frames) - 1, self.num_frames).astype(int)
        return [frames[i] for i in idxs]

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        frames = self._load_all_frames(video_path)
        if not frames:
            # Fail with the offending path instead of an opaque torch.stack error.
            raise RuntimeError(f"No frames could be read from {video_path}")
        frames = self.uniform_sample(frames)

        to_tensor = self.transform if self.transform is not None else transforms.functional.to_tensor
        frames_tensor = torch.stack([to_tensor(Image.fromarray(frame)) for frame in frames])  # (T, C, H, W)
        frames_tensor = frames_tensor.permute(1, 0, 2, 3)  # (C, T, H, W)
        return frames_tensor, label

    def _load_all_frames(self, video_path):
        """Decode every frame of ``video_path`` as an RGB array; ``[]`` if it cannot be opened."""
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Error: Could not open video file: {video_path}")
            return []

        extracted_frames_rgb = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; PIL and the normalization expect RGB.
                extracted_frames_rgb.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        finally:
            cap.release()

        if len(extracted_frames_rgb) == 0:
            print(f"Warning: No frames extracted from video {video_path}")

        return extracted_frames_rgb


# Config
test_dir = f"/content/drive/MyDrive/WLASL/WLASL100_test_{target_frames}frames"
LABELS_PATH = f"/content/drive/MyDrive/3d_cnn_asl_labels_run{run}.json"
# This is the per-epoch training checkpoint (a dict with "model_state_dict",
# "optimizer_state_dict", ...), i.e. the weights of the last completed epoch.
MODEL_PATH = f"/content/drive/MyDrive/3d_cnn_asl_checkpoint_run{run}.pt"
SHOW_CONFUSION_MATRIX = False  # the full-class plot is large; enable on demand


# Build labels list in the same order as the model's output units.
with open(LABELS_PATH, "r") as f:
    class_names = json.load(f)
class_to_idx = {c: i for i, c in enumerate(class_names)}
num_classes = len(class_names)

train_labels = set(class_names)
test_labels = {d.name for d in Path(test_dir).iterdir() if d.is_dir()}
in_train_and_test = sorted(train_labels & test_labels)
print(f"Testing on {len(in_train_and_test)} shared classes between training & testing")

# Reuse the training label indexing so predicted ids map to the right words.
test_class_map = {'class_to_idx': class_to_idx, 'idx_to_class': class_names}

# Initialize dataset, dataloader, and model
test_dataset = VideoDataset(test_dir, transform=transform, class_map=test_class_map)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Same architecture as training. Pretrained weights are not downloaded: the
# strict load_state_dict below replaces every parameter and buffer anyway.
model = r3d_18(weights=None)
model.fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(model.fc.in_features, num_classes)
)

# Load model
checkpoint = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model = model.to(device)
model.eval()  # disable dropout; BatchNorm uses its running statistics

# Testing
all_preds = []
all_labels = []
with torch.no_grad():
    for images, labels in tqdm.tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        outputs = model(images)
        _, preds = outputs.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

if not all_labels:
    raise RuntimeError(f"No test samples found in {test_dir}")

# Compute & display metrics
accuracy = accuracy_score(all_labels, all_preds)
print(f"Test Accuracy: {accuracy * 100:.2f}%")

if SHOW_CONFUSION_MATRIX:
    # labels= forces a num_classes x num_classes matrix that lines up with
    # display_labels even though the test set covers only some of the classes.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(num_classes)))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(xticks_rotation='vertical', cmap='Blues')
    plt.title("Confusion Matrix")
    plt.show()

Using GPU
Testing on 85 shared classes between training & testing
Found 111 video samples across 85 classes in /content/drive/MyDrive/WLASL/WLASL100_test_64frames


Testing: 100%|██████████| 28/28 [01:19<00:00,  2.83s/it]

Test Accuracy: 17.12%



