# Step 7: Colab Train + Export (WLASL Processed)
This notebook trains a tiny baseline and exports TorchScript weights for CPU inference.

In [1]:
!git clone https://github.com/BhumipatSaengduan/wlasl.git
%cd wlasl_demo

Cloning into 'wlasl'...
remote: Enumerating objects: 91, done.[K
remote: Counting objects: 100% (91/91), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 91 (delta 35), reused 91 (delta 35), pack-reused 0 (from 0)[K
Receiving objects: 100% (91/91), 40.59 KiB | 186.00 KiB/s, done.
Resolving deltas: 100% (35/35), done.
/content/wlasl_demo


In [2]:
!pip -q install torch torchvision opencv-python tqdm


## Kaggle dataset download (requires kaggle.json)

In [3]:
!pip -q install kaggle
from google.colab import files
files.upload()  # upload kaggle.json
!mkdir -p ~/.kaggle && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d risangbaskoro/wlasl-processed -p /content/data --unzip
!ls -la /content/data

Saving kaggle.json to kaggle.json
Dataset URL: https://www.kaggle.com/datasets/risangbaskoro/wlasl-processed
License(s): other
Downloading wlasl-processed.zip to /content/data
 99% 4.78G/4.82G [03:34<00:01, 23.0MB/s]
100% 4.82G/4.82G [03:34<00:00, 24.1MB/s]
total 14252
drwxr-xr-x 3 root root     4096 Feb 24 14:16 .
drwxr-xr-x 1 root root     4096 Feb 24 14:07 ..
-rw-r--r-- 1 root root    54617 Feb 24 14:15 missing.txt
-rw-r--r-- 1 root root   704255 Feb 24 14:15 nslt_1000.json
-rw-r--r-- 1 root root   107142 Feb 24 14:15 nslt_100.json
-rw-r--r-- 1 root root  1136283 Feb 24 14:15 nslt_2000.json
-rw-r--r-- 1 root root   272399 Feb 24 14:15 nslt_300.json
drwxr-xr-x 2 root root   339968 Feb 24 14:05 videos
-rw-r--r-- 1 root root    22907 Feb 24 14:16 wlasl_class_list.txt
-rw-r--r-- 1 root root 11932637 Feb 24 14:15 WLASL_v0.3.json


In [4]:
import os, json, math
import numpy as np
import cv2
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm

os.chdir('/content/wlasl')  # สำคัญมาก อย่าลืม

# --- Load dataset จาก nslt_100.json ---
def load_nslt_dataset(data_root, json_path, max_classes=100, max_samples_per_class=50):
    with open(json_path) as f:
        data = json.load(f)
    video_dir = os.path.join(data_root, 'videos')
    class_to_samples = {}
    for video_id, info in data.items():
        class_idx = info['action'][0]
        video_file = os.path.join(video_dir, f"{int(video_id):05d}.mp4")
        if not os.path.exists(video_file):
            continue
        class_to_samples.setdefault(class_idx, []).append({
            'path': video_file, 'subset': info['subset']
        })
    selected_classes = sorted(class_to_samples.keys())[:max_classes]
    label_map = {orig: new for new, orig in enumerate(selected_classes)}
    train_samples, val_samples = [], []
    for orig_class in selected_classes:
        items = class_to_samples[orig_class][:max_samples_per_class]
        new_label = label_map[orig_class]
        for item in items:
            sample = (item['path'], new_label)
            if item['subset'] == 'val':
                val_samples.append(sample)
            else:
                train_samples.append(sample)
    labels = [str(c) for c in selected_classes]
    print(f"dataset: train={len(train_samples)} val={len(val_samples)} classes={len(selected_classes)}")
    return train_samples, val_samples, labels

# --- Model ---
class TinyFrameCNN(nn.Module):
    def __init__(self, in_ch=3, feat_dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_ch, 32, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(inplace=True), nn.MaxPool2d(2),
            nn.Conv2d(64, feat_dim, 3, padding=1), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d(1),
        )
    def forward(self, x):
        return self.net(x).flatten(1)

class TinyVideoClassifier(nn.Module):
    def __init__(self, num_classes, frames=8, size=112, feat_dim=128):
        super().__init__()
        self.backbone = TinyFrameCNN(feat_dim=feat_dim)
        self.classifier = nn.Linear(feat_dim, num_classes)
    def forward(self, x):
        b, t, c, h, w = x.shape
        feat = self.backbone(x.view(b*t, c, h, w)).view(b, t, -1).mean(1)
        return self.classifier(feat)

# --- Dataset ---
FRAMES, SIZE = 8, 112

def sample_frames(path, num_frames=FRAMES, size=SIZE):
    cap = cv2.VideoCapture(path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    if total > 0:
        for idx in np.linspace(0, total-1, num_frames, dtype=int):
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = cap.read()
            if ok and frame is not None:
                frame = cv2.resize(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), (size, size))
                frames.append(frame.astype(np.float32) / 255.0)
    cap.release()
    if not frames:
        frames = [np.zeros((size, size, 3), np.float32)]
    while len(frames) < num_frames:
        frames.append(frames[-1])
    arr = np.stack(frames[:num_frames])
    return torch.from_numpy(arr.transpose(0, 3, 1, 2)).float()

class VideoDataset(Dataset):
    def __init__(self, samples):
        self.samples = samples
    def __len__(self):
        return len(self.samples)
    def __getitem__(self, idx):
        path, label = self.samples[idx]
        return sample_frames(path), label

# --- Train ---
EPOCHS = 10
BATCH_SIZE = 8
LR = 1e-3
OUT_DIR = '/content/out'
os.makedirs(OUT_DIR, exist_ok=True)

train_samples, val_samples, labels = load_nslt_dataset(
    '/content/data', '/content/data/nslt_100.json',
    max_classes=100, max_samples_per_class=50
)

with open(f'{OUT_DIR}/labels.json', 'w') as f:
    json.dump(labels, f)

train_loader = DataLoader(VideoDataset(train_samples), batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader   = DataLoader(VideoDataset(val_samples),   batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

model = TinyVideoClassifier(num_classes=len(labels)).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
criterion = nn.CrossEntropyLoss()

best_loss = math.inf
for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0.0
    for x, y in tqdm(train_loader, desc=f"train {epoch}"):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * x.size(0)
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, y in tqdm(val_loader, desc=f"val {epoch}"):
            x, y = x.to(device), y.to(device)
            val_loss += criterion(model(x), y).item() * x.size(0)
    train_loss /= max(1, len(train_samples))
    val_loss   /= max(1, len(val_samples))
    print(f"epoch={epoch} train_loss={train_loss:.4f} val_loss={val_loss:.4f}")
    if val_loss < best_loss:
        best_loss = val_loss
        torch.save({
            'state_dict': model.state_dict(),
            'meta': {'num_classes': len(labels), 'frames': FRAMES, 'size': SIZE}
        }, f'{OUT_DIR}/best.pt')
        print(f"  saved best checkpoint")

dataset: train=848 val=165 classes=100
device: cuda


train 1: 100%|██████████| 106/106 [05:18<00:00,  3.00s/it]
val 1: 100%|██████████| 21/21 [01:09<00:00,  3.32s/it]


epoch=1 train_loss=4.6122 val_loss=4.6070
  saved best checkpoint


train 2: 100%|██████████| 106/106 [05:13<00:00,  2.96s/it]
val 2: 100%|██████████| 21/21 [01:09<00:00,  3.31s/it]


epoch=2 train_loss=4.6048 val_loss=4.6060
  saved best checkpoint


train 3: 100%|██████████| 106/106 [05:15<00:00,  2.98s/it]
val 3: 100%|██████████| 21/21 [01:09<00:00,  3.32s/it]


epoch=3 train_loss=4.6022 val_loss=4.6040
  saved best checkpoint


train 4: 100%|██████████| 106/106 [05:14<00:00,  2.97s/it]
val 4: 100%|██████████| 21/21 [01:08<00:00,  3.28s/it]


epoch=4 train_loss=4.6003 val_loss=4.6045


train 5: 100%|██████████| 106/106 [05:12<00:00,  2.95s/it]
val 5: 100%|██████████| 21/21 [01:09<00:00,  3.33s/it]


epoch=5 train_loss=4.5928 val_loss=4.6030
  saved best checkpoint


train 6: 100%|██████████| 106/106 [05:12<00:00,  2.94s/it]
val 6: 100%|██████████| 21/21 [01:10<00:00,  3.34s/it]


epoch=6 train_loss=4.5901 val_loss=4.6049


train 7: 100%|██████████| 106/106 [05:13<00:00,  2.96s/it]
val 7: 100%|██████████| 21/21 [01:09<00:00,  3.31s/it]


epoch=7 train_loss=4.5810 val_loss=4.6038


train 8: 100%|██████████| 106/106 [05:13<00:00,  2.96s/it]
val 8: 100%|██████████| 21/21 [01:09<00:00,  3.31s/it]


epoch=8 train_loss=4.5649 val_loss=4.5962
  saved best checkpoint


train 9: 100%|██████████| 106/106 [05:13<00:00,  2.96s/it]
val 9: 100%|██████████| 21/21 [01:09<00:00,  3.32s/it]


epoch=9 train_loss=4.5472 val_loss=4.6145


train 10: 100%|██████████| 106/106 [05:12<00:00,  2.95s/it]
val 10: 100%|██████████| 21/21 [01:09<00:00,  3.32s/it]

epoch=10 train_loss=4.5381 val_loss=4.5975





In [5]:
import json

# map index → ชื่อคำภาษาอังกฤษ
class_list = {}
with open('/content/data/wlasl_class_list.txt') as f:
    for line in f:
        parts = line.strip().split('\t')
        if len(parts) >= 2:
            class_list[int(parts[0])] = parts[1]

# โหลด labels เดิม (ยังเป็นตัวเลข)
with open('/content/out/labels.json') as f:
    old_labels = json.load(f)

# แปลงเป็นชื่อคำ
new_labels = [class_list.get(int(l), f"unknown_{l}") for l in old_labels]
print("ตัวอย่าง:", new_labels[:10])
print(f"จำนวนทั้งหมด: {len(new_labels)} คำ")

# บันทึกทับ
with open('/content/out/labels.json', 'w') as f:
    json.dump(new_labels, f)
print("✅ labels.json อัปเดตแล้ว")

ตัวอย่าง: ['book', 'drink', 'computer', 'before', 'chair', 'go', 'clothes', 'who', 'candy', 'cousin']
จำนวนทั้งหมด: 100 คำ
✅ labels.json อัปเดตแล้ว


In [6]:
import os
os.chdir('/content/wlasl')

!python3 -m src.export_torchscript \
    --ckpt /content/out/best.pt \
    --labels /content/out/labels.json \
    --out_ts /content/out/model.ts

TorchScript output shape: (1, 100)
Saved TorchScript: /content/out/model.ts


In [7]:
from google.colab import files
files.download("/content/out/model.ts")
files.download("/content/out/labels.json")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>