In [1]:
import sys, os
sys.path.append(os.path.abspath(".."))  

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [2]:
from models.mobilenetv3 import MobileNetV3Extractor
from models.lstm_attention import BiLSTMWithAttention
from preprocessing.dataset import SignLanguageDataset

In [3]:
import albumentations as A
from albumentations.pytorch import ToTensorV2

train_transform = A.Compose([
    # resize to the standard input size
    A.Resize(224, 224),  
    # A.RandomCrop(200, 200),
    A.HorizontalFlip(p=0.5),
    A.ColorJitter(brightness=0.2, contrast=0.2),
    #A.Cutout(num_holes=2, max_h_size=40, max_w_size=40, fill_value=0),
     # resize it back to 224 to prevent inconsistent dimensions after enhancement
    A.Resize(224, 224), 
    A.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ToTensorV2()
])

In [4]:
train_root = "../data/frames/train"
# class_names = sorted(os.listdir(train_root))
class_names = sorted([
    d for d in os.listdir(train_root)
    if os.path.isdir(os.path.join(train_root, d)) and not d.startswith(".")
])

label_map = {name: idx for idx, name in enumerate(class_names)}

print("[DEBUG] label_map:")
for name, idx in label_map.items():
    print(f"  {name}: {idx}")

train_dataset = SignLanguageDataset(
    root_dir=train_root,
    label_map=label_map,
    num_frames=20,
    split="train",
    transform=train_transform
)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=2)

[DEBUG] label_map:
  about: 0
  accident: 1
  africa: 2
  again: 3
  all: 4
  always: 5
  animal: 6
  apple: 7
  approve: 8
  argue: 9
  arrive: 10
  baby: 11
  back: 12
  backpack: 13
  bad: 14
  bake: 15
  balance: 16
  ball: 17
  banana: 18
  bar: 19
  basketball: 20
  bath: 21
  bathroom: 22
  beard: 23
  because: 24
  bed: 25
  before: 26
  behind: 27
  bird: 28
  birthday: 29
  black: 30
  blanket: 31
  blue: 32
  book: 33
  bowling: 34
  boy: 35
  bring: 36
  brother: 37
  brown: 38
  business: 39
  but: 40
  buy: 41
  call: 42
  can: 43
  candy: 44
  careful: 45
  cat: 46
  catch: 47
  center: 48
  cereal: 49
  chair: 50
  champion: 51
  change: 52
  chat: 53
  cheat: 54
  check: 55
  cheese: 56
  children: 57
  christmas: 58
  city: 59
  class: 60
  clock: 61
  close: 62
  clothes: 63
  coffee: 64
  cold: 65
  college: 66
  color: 67
  computer: 68
  convince: 69
  cook: 70
  cool: 71
  copy: 72
  corn: 73
  cough: 74
  country: 75
  cousin: 76
  cow: 77
  crash: 78
  crazy: 7

In [5]:
class FullSLRModel(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.feature_extractor = MobileNetV3Extractor()
        self.temporal_model = BiLSTMWithAttention(
            input_dim=960, hidden_dim=256, num_classes=num_classes
        )

    def forward(self, x):  # x: [B, T, C, H, W]
        features = self.feature_extractor(x)          # [B, T, 960]
        logits, _ = self.temporal_model(features)     # [B, num_classes]
        return logits

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FullSLRModel(num_classes=len(label_map)).to(device)

# Label smoothing is helpful for generalization and reduces overfitting
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = optim.Adam(model.parameters(), lr=1e-4)



In [7]:
num_epochs = 1

for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    for videos, labels in tqdm(train_loader):
        
        print(torch.bincount(labels))
        print("labels.shape:", labels.shape)
        print("labels example:", labels[0])
        break
        # videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, T, H, W] → [B, T, C, H, W]
        videos, labels = videos.to(device), labels.to(device)

        outputs = model(videos)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    epoch_loss = total_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    print(f"[Epoch {epoch+1}/{num_epochs}] Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")

  0%|          | 0/622 [00:00<?, ?it/s]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 1])
labels.shape: torch.Size([4])
labels example: tensor(16)


  0%|          | 0/622 [00:20<?, ?it/s]


ZeroDivisionError: float division by zero

In [8]:
os.makedirs("../checkpoints", exist_ok=True)
torch.save(model.state_dict(), "../checkpoints/mobilenetv3_lstm_aug_smooth.pth")