In [9]:
import sys, os

# Put the notebook's parent directory on the import path (presumably the
# project root that holds the `models` and `preprocessing` packages imported
# below -- verify against the repository layout).
# The membership check keeps the cell idempotent: re-running it in a live
# kernel no longer appends a duplicate entry each time.
project_root = os.path.abspath("..")
if project_root not in sys.path:
    sys.path.append(project_root)

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [10]:
from models.mobilenetv3 import MobileNetV3Extractor
from models.lstm_attention import BiLSTMWithAttention
from preprocessing.dataset import SignLanguageDataset

In [11]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Per-image training augmentation pipeline, applied in the order listed.
# Disabled experiments (kept here for reference, not executed):
#   A.RandomCrop(200, 200)
#   A.Cutout(num_holes=2, max_h_size=40, max_w_size=40, fill_value=0)
augmentation_steps = [
    # Bring every frame to the standard 224x224 input size.
    A.Resize(height=224, width=224),
    # NOTE(review): mirroring swaps left/right handedness -- confirm this is
    # acceptable for the sign vocabulary being trained on.
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2),
    # Second resize guards against size-changing augmentations (e.g. the
    # disabled RandomCrop) leaving frames with inconsistent dimensions.
    A.Resize(height=224, width=224),
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2(),
]

train_transform = A.Compose(augmentation_steps)

In [12]:
train_root = "../data/frames/train"

# One class per visible sub-directory of the training root. Plain files and
# dot-prefixed entries are skipped so they cannot become bogus class labels.
class_names = sorted(
    entry
    for entry in os.listdir(train_root)
    if not entry.startswith(".")
    and os.path.isdir(os.path.join(train_root, entry))
)

# Class name -> integer label, numbered by position in the sorted list.
label_map = dict(zip(class_names, range(len(class_names))))

train_dataset = SignLanguageDataset(
    root_dir=train_root,
    label_map=label_map,
    num_frames=20,
    split="train",
    transform=train_transform,
)

train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=2,
)


[DEBUG] First 20 samples loaded:
Sample 0: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/66815/frame_000.jpg
Sample 1: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/70261/frame_000.jpg
Sample 2: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/64065/frame_000.jpg
Sample 3: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/64067/frame_000.jpg
Sample 4: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/64058/frame_000.jpg
Sample 5: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/64060/frame_000.jpg
Sample 6: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/64056/frame_000.jpg
Sample 7: class=write, label=292
    Frame count: 20 | First frame: ../data/frames/train/write/64057/frame_000.jpg
Sample 8: class=write, label=292
    Frame cou

In [13]:
class FullSLRModel(nn.Module):
    """End-to-end sign-language recogniser: frame feature extractor + temporal classifier.

    A ``MobileNetV3Extractor`` processes the input clip and its output is fed
    to a ``BiLSTMWithAttention`` head configured with ``input_dim=960`` and
    ``hidden_dim=256``.

    Args:
        num_classes: Number of output classes, forwarded to the temporal head.
    """

    def __init__(self, num_classes):
        super().__init__()
        # These attribute names become the state_dict key prefixes, so they
        # must stay stable for saved checkpoints to remain loadable.
        self.feature_extractor = MobileNetV3Extractor()
        self.temporal_model = BiLSTMWithAttention(
            input_dim=960,
            hidden_dim=256,
            num_classes=num_classes,
        )

    def forward(self, x):
        """Return class logits for a batch of clips.

        Args:
            x: Clip batch; the original annotations give the layout as
                [B, T, C, H, W].

        Returns:
            The first element of the temporal model's two-element output,
            annotated in the original as logits of shape [B, num_classes].
        """
        # Annotated in the original as [B, T, 960].
        frame_features = self.feature_extractor(x)
        # The temporal head returns a pair; only the first item is used here.
        class_logits, _ = self.temporal_model(frame_features)
        return class_logits

In [14]:
# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# One output unit per entry in the label map.
model = FullSLRModel(num_classes=len(label_map))
model = model.to(device)

# Label smoothing (0.1) softens the one-hot targets, which helps
# generalisation and reduces overfitting.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
num_epochs = 3

# Supervised training: for each epoch, run one optimisation step per batch
# from `train_loader` and report the epoch's mean loss and top-1 accuracy.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0  # running SUM of per-sample losses for this epoch
    correct = 0       # number of correctly classified samples
    total = 0         # number of samples seen

    for videos, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        # The model's forward() is annotated as expecting [B, T, C, H, W];
        # the dataset is assumed to already yield that layout -- TODO confirm.
        videos, labels = videos.to(device), labels.to(device)

        # Clear stale gradients before computing the new ones.
        optimizer.zero_grad()
        outputs = model(videos)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # `criterion` uses the default mean reduction, so scale the batch mean
        # back to a sum. This keeps the epoch average exact even when the last
        # batch is smaller than the others.
        total_loss += loss.item() * batch_size
        total += batch_size
        # Metric only: detach so bookkeeping never touches the autograd graph.
        predicted = outputs.detach().argmax(dim=1)
        correct += (predicted == labels).sum().item()

    if total == 0:
        # Previously surfaced as an opaque ZeroDivisionError below.
        raise RuntimeError(
            "train_loader yielded no samples; check the training data directory "
            "and the dataset configuration."
        )

    epoch_loss = total_loss / total
    epoch_acc = 100. * correct / total
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")

 27%|██▋       | 167/622 [11:56:06<5:04:03, 40.10s/it]   

In [8]:
# Persist only the learned weights (state_dict), not the whole module object.
checkpoint_path = "../checkpoints/mobilenetv3_lstm_aug_smooth.pth"

# Create the target directory on first use; a no-op when it already exists.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model_weights = model.state_dict()
torch.save(model_weights, checkpoint_path)