In [None]:
# %% [1] Download Dataset
!pip install -q kaggle
!mkdir ~/.kaggle
!cp kaggle.json ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d debashishsau/aslamerican-sign-language-aplhabet-dataset
!unzip -q aslamerican-sign-language-aplhabet-dataset.zip

In [None]:
import torch
import torch.nn as nn
from torchvision.models import mobilenet_v3_small
from collections import deque
import numpy as np

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import os
from PIL import Image
import random
import string

In [None]:
# %% [2] Model Architecture (MobileNetV3 + Temporal)
class ASLNetMobile(nn.Module):
    """MobileNetV3-Small image backbone fused with a hand-landmark branch.

    The backbone output (classification head replaced by ``nn.Identity``) is
    concatenated with a 64-dim embedding of the flattened hand landmarks and
    passed to a single linear classifier. An optional ``history`` of earlier
    frames is averaged together with the current frame for temporal smoothing.

    Args:
        num_classes: Size of the output logits vector.
        temporal_window: Stored on the instance as ``temporal_window``. The
            forward pass averages every entry of the ``history`` it receives;
            bounding the history length is left to the caller.
    """

    def __init__(self, num_classes=29, temporal_window=5):
        super().__init__()
        self.temporal_window = temporal_window

        # Backbone MobileNetV3
        # NOTE(review): recent torchvision releases deprecate `pretrained=` in
        # favour of `weights=`; left unchanged so older installs keep working
        # -- confirm against the installed torchvision version.
        self.backbone = mobilenet_v3_small(pretrained=True)
        self.backbone.classifier = nn.Identity()  # Remove classification head

        # Landmarks branch
        self.landmarks_fc = nn.Sequential(
            nn.Linear(63, 64),  # 21 landmarks * 3 (x,y,z)
            nn.ReLU(),
            nn.Dropout(0.2)
        )

        # Classifier
        self.fc = nn.Linear(576 + 64, num_classes)  # 576 (MobileNet) + 64 (landmarks)

    def forward(self, x, landmarks, history=None):
        """Classify one batch of frames, optionally smoothed over earlier frames.

        Args:
            x: Image batch fed to the backbone.
            landmarks: Landmark batch fed to the landmark branch (last
                dimension 63).
            history: Optional iterable of tensors from earlier frames. Entries
                may be fused feature vectors (last dimension equal to the
                current fused feature size) or logits previously returned by
                this method. ``None`` or an empty iterable disables smoothing.

        Returns:
            Logits tensor with ``num_classes`` entries per sample.
        """
        # Feature extraction
        img_feat = self.backbone(x)
        ldmk_feat = self.landmarks_fc(landmarks)
        combined = torch.cat([img_feat, ldmk_feat], dim=1)

        past = list(history) if history is not None else []
        if not past:
            return self.fc(combined)

        # Temporal fusion over feature vectors (the original contract).
        if past[0].shape[-1] == combined.shape[-1]:
            fused = torch.stack(past + [combined], dim=1).mean(dim=1)
            return self.fc(fused)

        # Temporal fusion over logits: callers that store this method's output
        # in `history` used to hit a shape mismatch in torch.stack. Because the
        # classifier is a single linear layer, averaging logits gives the same
        # result as averaging the features before the classifier.
        logits = self.fc(combined)
        return torch.stack(past + [logits], dim=1).mean(dim=1)

In [None]:
# %% [3] Config & Dataset
class CFG:
    # Central configuration: dataset location, label vocabulary,
    # hyper-parameters and the image transforms for training / validation.

    TRAIN_PATH = "ASL_Alphabet_Dataset/asl_alphabet_train"
    # The position of a label in this list is its class index.
    LABELS = list(string.ascii_uppercase) + ["del", "nothing", "space"]
    NUM_CLASSES = len(LABELS)
    IMG_SIZE = 224
    # Length of the flattened landmark vector: 21 landmarks * 3 (x, y, z).
    # Read by the landmark feature extractor for its all-zero fallback.
    LANDMARKS_DIM = 63
    BATCH_SIZE = 128
    EPOCHS = 30
    LR = 3e-4
    WEIGHT_DECAY = 1e-4
    DROPOUT = 0.3
    TEMPORAL_WINDOW = 5

    # Per-channel normalisation statistics shared by both pipelines.
    NORM_MEAN = [0.485, 0.456, 0.406]
    NORM_STD = [0.229, 0.224, 0.225]

    # Augmentations
    TRAIN_TRANSFORM = transforms.Compose([
        transforms.RandomRotation(15),
        transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
        transforms.ColorJitter(0.1, 0.1, 0.1),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD)
    ])

    # Deterministic pipeline: plain resize, no augmentation.
    VAL_TRANSFORM = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD)
    ])

class ASLDataset(Dataset):
    """Image-folder dataset with a reproducible train / validation split.

    Images are discovered under ``CFG.TRAIN_PATH/<label>/`` for every label in
    ``CFG.LABELS``; label directories that do not exist are skipped. The full
    sample list is shuffled with a private, seeded RNG, so two instances
    created with the same ``seed`` and ``val_ratio`` see the same ordering and
    their ``'train'`` and non-``'train'`` parts never overlap.

    Args:
        split: ``'train'`` selects the first ``1 - val_ratio`` share of the
            shuffled samples; any other value selects the remainder.
        transform: Optional callable applied to each PIL image. When ``None``
            the matching ``CFG.TRAIN_TRANSFORM`` / ``CFG.VAL_TRANSFORM`` is used.
        val_ratio: Fraction of the samples held out for validation.
        seed: Seed of the private shuffle RNG; keep it identical across the
            train and validation instances.
    """

    # Lower-cased file extensions that count as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, split='train', transform=None, val_ratio=0.2, seed=42):
        if transform is not None:
            # Honour an explicitly supplied transform (it used to be ignored).
            self.transform = transform
        else:
            self.transform = CFG.TRAIN_TRANSFORM if split == 'train' else CFG.VAL_TRANSFORM

        samples = []
        for label_idx, label in enumerate(CFG.LABELS):
            label_dir = os.path.join(CFG.TRAIN_PATH, label)
            if not os.path.isdir(label_dir):
                continue
            # os.listdir order is arbitrary; sort so the seeded shuffle is
            # reproducible across runs and machines.
            for fname in sorted(os.listdir(label_dir)):
                if fname.lower().endswith(self.IMAGE_EXTENSIONS):
                    samples.append((os.path.join(label_dir, fname), label_idx))

        # A private RNG keeps the split deterministic and leaves the global
        # `random` state untouched. An unseeded shuffle per instance would let
        # train and validation samples overlap.
        random.Random(seed).shuffle(samples)
        split_idx = int(len(samples) * (1 - val_ratio))
        self.data = samples[:split_idx] if split == 'train' else samples[split_idx:]

    def __len__(self):
        """Return the number of samples in this split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(transformed_image, label_index)``."""
        img_path, label = self.data[idx]
        img = Image.open(img_path).convert('RGB')
        return self.transform(img), label

In [None]:
def extract_hand_landmarks_features(results, landmarks_dim=63):
    """Turn a MediaPipe Hands result into a flat landmark feature tensor.

    Only the first detected hand is used. Its (x, y, z) coordinates are
    flattened in landmark order and min-max scaled jointly, i.e. with a single
    minimum and maximum taken over all coordinates of that hand.

    Args:
        results: Object exposing ``multi_hand_landmarks``; each hand exposes
            ``landmark``, an iterable of points with ``x``, ``y`` and ``z``.
        landmarks_dim: Length of the all-zero vector returned when no
            landmarks are available. Defaults to 63 (21 points * 3 coordinates).

    Returns:
        A 1-D ``torch.float32`` tensor: the scaled coordinates of the first
        hand, or ``landmarks_dim`` zeros when nothing was detected.
    """
    detected_hands = results.multi_hand_landmarks
    if detected_hands:
        # First detected hand only.
        coords = np.array(
            [[point.x, point.y, point.z] for point in detected_hands[0].landmark],
            dtype=np.float64,
        ).reshape(-1)

        if coords.size:
            # Joint min-max scaling; the epsilon guards against a zero range.
            lowest = coords.min()
            highest = coords.max()
            coords = (coords - lowest) / (highest - lowest + 1e-8)
            return torch.tensor(coords, dtype=torch.float32)

    # No hand (or no points): fall back to zeros. The size is a parameter
    # rather than a config attribute so this path cannot fail on a missing one.
    return torch.zeros(landmarks_dim, dtype=torch.float32)

In [None]:
# %% [4] Training Loop with Metrics
from sklearn.metrics import precision_score, recall_score, f1_score

def train_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``.

    The landmark input of the model is fed an all-zero placeholder of width 63
    for every batch, since the loader only yields images and labels.

    Args:
        model: Network called as ``model(images, landmarks)``.
        loader: Iterable of ``(images, labels)`` batches exposing ``dataset``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches and placeholder are moved to.

    Returns:
        ``(mean_loss, accuracy)``: the sum of per-batch losses divided by the
        number of batches, and the number of correct predictions divided by
        ``len(loader.dataset)``.
    """
    model.train()
    batch_losses = []
    n_correct = 0

    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        placeholder_landmarks = torch.zeros(batch_imgs.size(0), 63).to(device)  # Dummy landmarks

        optimizer.zero_grad()
        logits = model(batch_imgs, placeholder_landmarks)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_losses.append(batch_loss.item())
        hits = logits.argmax(1) == batch_labels
        n_correct += hits.sum().item()

    mean_loss = sum(batch_losses) / len(loader)
    accuracy = n_correct / len(loader.dataset)
    return mean_loss, accuracy

@torch.no_grad()
def eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    As in training, the landmark input is an all-zero placeholder of width 63.

    Args:
        model: Network called as ``model(images, landmarks)``.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the batches and placeholder are moved to.

    Returns:
        Dict with keys ``'loss'`` (sum of batch losses divided by the number
        of batches), ``'accuracy'``, and macro-averaged ``'precision'``,
        ``'recall'`` and ``'f1'``.
    """
    model.eval()
    loss_sum = 0.0
    predicted = []
    expected = []

    for batch_imgs, batch_labels in loader:
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        placeholder_landmarks = torch.zeros(batch_imgs.size(0), 63).to(device)

        logits = model(batch_imgs, placeholder_landmarks)
        loss_sum += criterion(logits, batch_labels).item()

        predicted += logits.argmax(1).cpu().tolist()
        expected += batch_labels.cpu().tolist()

    mean_loss = loss_sum / len(loader)
    accuracy = (np.array(predicted) == np.array(expected)).mean()
    macro = {'average': 'macro'}
    return {
        'loss': mean_loss,
        'accuracy': accuracy,
        'precision': precision_score(expected, predicted, **macro),
        'recall': recall_score(expected, predicted, **macro),
        'f1': f1_score(expected, predicted, **macro),
    }

In [None]:
# %% [5] Inference with Temporal Window
import cv2
import mediapipe as mp
from collections import deque

def temporal_inference(model_path, source=0):
    """Run live ASL classification on a video source with temporal smoothing.

    Each frame is passed through MediaPipe Hands for landmarks and through the
    network for logits. The logits of the most recent ``CFG.TEMPORAL_WINDOW``
    frames (current one included) are averaged, and the arg-max label is drawn
    on the frame. Press ``q`` in the preview window to stop.

    Args:
        model_path: Path to a ``state_dict`` checkpoint for ``ASLNetMobile``.
        source: Anything ``cv2.VideoCapture`` accepts; defaults to camera 0.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ASLNetMobile().to(device)
    # NOTE(review): torch.load unpickles the file -- only load checkpoints
    # that come from a trusted source.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    # Initialize MediaPipe Hands
    mp_hands = mp.solutions.hands
    hands = mp_hands.Hands(static_image_mode=False, max_num_hands=2, min_detection_confidence=0.5)

    # Temporal buffer of per-frame logits. Smoothing is done here on logits
    # rather than via the model's `history` argument: the model only returns
    # logits, and feeding those back as feature history fails on the second
    # frame with a shape mismatch.
    history = deque(maxlen=CFG.TEMPORAL_WINDOW)

    cap = cv2.VideoCapture(source)
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Process landmarks
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = hands.process(img_rgb)
            landmarks = extract_hand_landmarks_features(results).unsqueeze(0).to(device)

            # Preprocess image
            img_tensor = CFG.VAL_TRANSFORM(Image.fromarray(img_rgb)).unsqueeze(0).to(device)

            # Inference
            with torch.no_grad():
                logits = model(img_tensor, landmarks)
            history.append(logits)
            smoothed = torch.stack(tuple(history), dim=0).mean(dim=0)
            pred = smoothed.argmax().item()

            # Display
            cv2.putText(frame, f"Pred: {CFG.LABELS[pred]}", (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
            cv2.imshow('ASL Temporal', frame)

            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Release the capture, the MediaPipe graph and the windows even when
        # a frame raises.
        cap.release()
        hands.close()
        cv2.destroyAllWindows()

In [None]:
# %% [6] Export for Mobile
# NOTE(review): this builds a freshly constructed network; load the trained
# state_dict here first if the exported artefacts should carry trained weights.
float_model = ASLNetMobile()
# Inference mode, so dropout is inactive during tracing and quantization.
float_model.eval()

# Export to ONNX
# NOTE(review): the ONNX exporter is not expected to handle dynamically
# quantized Linear modules, so the graph is traced from the float model --
# confirm against the installed PyTorch version.
dummy_input = torch.randn(1, 3, 224, 224)
dummy_landmarks = torch.randn(1, 63)
torch.onnx.export(
    float_model,
    (dummy_input, dummy_landmarks),
    "asl_mobilenet_temporal.onnx",
    input_names=["image", "landmarks"],
    output_names=["output"],
    opset_version=13
)

# Quantization: `model` ends up as the dynamically quantized module (int8
# weights for nn.Linear layers), as before; float_model is left untouched.
model = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)