In [3]:
import sys, os
sys.path.append(os.path.abspath(".."))

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

In [4]:
from models.mobilenetv3 import MobileNetV3Extractor
from models.lstm_attention import BiLSTMWithAttention
from preprocessing.dataset import SignLanguageDataset

In [5]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Image preprocessing pipeline: Resize, Normalize, ToTensor
train_transform = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
])

In [6]:
# Automatically obtain all category names and map them to digital labels
train_root = "../data/frames/train"
class_names = sorted(os.listdir(train_root))  
# Name
label_map = {name: idx for idx, name in enumerate(class_names)}
print("Discovered category:", label_map)

# Dataset
train_dataset = SignLanguageDataset(
    root_dir=train_root,
    label_map=label_map,
    num_frames=30,
    split="train",
    transform=train_transform
)

# DataLoader
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

Discovered category: {'.DS_Store': 0, 'about': 1, 'accident': 2, 'africa': 3, 'again': 4, 'all': 5, 'always': 6, 'animal': 7, 'apple': 8, 'approve': 9, 'argue': 10, 'arrive': 11, 'baby': 12, 'back': 13, 'backpack': 14, 'bad': 15, 'bake': 16, 'balance': 17, 'ball': 18, 'banana': 19, 'bar': 20, 'basketball': 21, 'bath': 22, 'bathroom': 23, 'beard': 24, 'because': 25, 'bed': 26, 'before': 27, 'behind': 28, 'bird': 29, 'birthday': 30, 'black': 31, 'blanket': 32, 'blue': 33, 'book': 34, 'bowling': 35, 'boy': 36, 'bring': 37, 'brother': 38, 'brown': 39, 'business': 40, 'but': 41, 'buy': 42, 'call': 43, 'can': 44, 'candy': 45, 'careful': 46, 'cat': 47, 'catch': 48, 'center': 49, 'cereal': 50, 'chair': 51, 'champion': 52, 'change': 53, 'chat': 54, 'cheat': 55, 'check': 56, 'cheese': 57, 'children': 58, 'christmas': 59, 'city': 60, 'class': 61, 'clock': 62, 'close': 63, 'clothes': 64, 'coffee': 65, 'cold': 66, 'college': 67, 'color': 68, 'computer': 69, 'convince': 70, 'cook': 71, 'cool': 72, '

In [8]:
class FullSLRModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = MobileNetV3Extractor()
        self.temporal_model = BiLSTMWithAttention(
            input_dim=960, hidden_dim=256, num_classes=num_classes)

    def forward(self, x):  # x: [B, T, C, H, W]
        features = self.feature_extractor(x)          # [B, T, 960]
        logits, _ = self.temporal_model(features)     # [B, num_classes]
        return logits

In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = FullSLRModel(num_classes=len(label_map)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)



In [10]:
num_epochs = 2

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for videos, labels in tqdm(train_loader):
        videos, labels = videos.to(device), labels.to(device)

        outputs = model(videos)  # [B, num_classes]
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")

100%|██████████| 622/622 [3:23:11<00:00, 19.60s/it]   


Epoch [1/2] | Loss: 5.7166 | Acc: 0.20%


100%|██████████| 622/622 [11:51:21<00:00, 68.62s/it]    

Epoch [2/2] | Loss: 5.5823 | Acc: 0.84%





In [11]:
save_path = f"../checkpoints/baseline_epoch{epoch+1}.pth"
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model.state_dict(), save_path)
print(f"model saved to: {save_path}")

model saved to: ../checkpoints/baseline_epoch2.pth
