In [12]:
import sys, os
sys.path.append(os.path.abspath(".."))

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [15]:
from models.mobilenetv3 import MobileNetV3Extractor
from models.lstm_attention import BiLSTMWithAttention
from preprocessing.dataset import SignLanguageDataset

In [16]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Image preprocessing pipeline: Resize, Normalize, ToTensor
train_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2()
])

In [23]:
# Automatically obtain all category names and map them to digital labels
train_root = "../data/frames/train"
class_names = sorted(os.listdir(train_root))  
# Name
label_map = {name: idx for idx, name in enumerate(class_names)}
print("Discovered category:", label_map)

# Dataset
train_dataset = SignLanguageDataset(
    root_dir=train_root,
    label_map=label_map,
    num_frames=20,
    split="train",
    transform=train_transform
)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)

Discovered category: {'about': 0, 'accident': 1, 'africa': 2, 'again': 3, 'all': 4, 'always': 5, 'animal': 6, 'apple': 7, 'approve': 8, 'argue': 9, 'arrive': 10, 'baby': 11, 'back': 12, 'backpack': 13, 'bad': 14, 'bake': 15, 'balance': 16, 'ball': 17, 'banana': 18, 'bar': 19, 'basketball': 20, 'bath': 21, 'bathroom': 22, 'beard': 23, 'because': 24, 'bed': 25, 'before': 26, 'behind': 27, 'bird': 28, 'birthday': 29, 'black': 30, 'blanket': 31, 'blue': 32, 'book': 33, 'bowling': 34, 'boy': 35, 'bring': 36, 'brother': 37, 'brown': 38, 'business': 39, 'but': 40, 'buy': 41, 'call': 42, 'can': 43, 'candy': 44, 'careful': 45, 'cat': 46, 'catch': 47, 'center': 48, 'cereal': 49, 'chair': 50, 'champion': 51, 'change': 52, 'chat': 53, 'cheat': 54, 'check': 55, 'cheese': 56, 'children': 57, 'christmas': 58, 'city': 59, 'class': 60, 'clock': 61, 'close': 62, 'clothes': 63, 'coffee': 64, 'cold': 65, 'college': 66, 'color': 67, 'computer': 68, 'convince': 69, 'cook': 70, 'cool': 71, 'copy': 72, 'corn'

In [26]:
class FullSLRModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = MobileNetV3Extractor()
        self.temporal_model = BiLSTMWithAttention(
            input_dim=960, hidden_dim=256, num_classes=num_classes)

    def forward(self, x):  # x: [B, T, C, H, W]
        features = self.feature_extractor(x)          # [B, T, 960]
        logits, _ = self.temporal_model(features)     # [B, num_classes]
        return logits

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FullSLRModel(num_classes=len(label_map)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for videos, labels in tqdm(train_loader):
        # videos = videos.permute(0, 2, 1, 3, 4) # [B, C, T, H, W] → [B, T, C, H, W]
        videos, labels = videos.to(device), labels.to(device)

        outputs = model(videos)  # [B, num_classes]
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")

  0%|          | 0/622 [00:00<?, ?it/s]

100%|██████████| 622/622 [13:11:02<00:00, 76.31s/it]    

Epoch [1/1] | Loss: 5.7179 | Acc: 0.44%





In [9]:
save_path = f"../checkpoints/baseline_epoch{epoch+1}.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
print(f"model saved to: {save_path}")

model saved to: ../checkpoints/baseline_epoch1.pth
