In [8]:
import os
import glob
import cv2
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import mediapipe as mp

# Dataset class
class PhoenixDataset(Dataset):
    """Frame-sequence dataset over a PHOENIX-2014 style directory layout.

    Each row ``id|folder|signer|annotation`` of the annotation file (after a
    header line) becomes one sample. Frames are looked up as ``*.png`` files in
    ``<root>/features/fullFrame-210x260px/<split>/<folder>/1/``, where
    ``<folder>`` is the part of the row's folder column before the first '/'.

    Args:
        root: dataset root directory.
        annotation_file: path to the ``|``-separated annotation CSV.
        split: ``'train'``, ``'dev'`` or ``'test'``; selects the frame subfolder.
        max_frames: frames returned per sample. Longer videos keep only their
            first ``max_frames`` frames; shorter ones are zero-padded at the end.
    """

    def __init__(self, root, annotation_file, split, max_frames=64):
        self.root = root
        self.split = split  # 'train', 'dev', or 'test'
        self.max_frames = max_frames
        self.samples = []

        frames_root = os.path.join(self.root, "features", "fullFrame-210x260px", self.split)
        with open(annotation_file, "r", encoding="utf-8") as f:
            next(f, None)  # skip header; an empty file simply yields no samples
            for line in f:
                parts = line.strip().split("|")
                if len(parts) < 4:
                    continue  # blank or malformed row
                # Only the first four columns are used, so extra trailing
                # columns no longer break the unpacking.
                sample_id, folder_pattern, signer, annotation = parts[:4]
                # Keep only the main folder name (before the first '/').
                folder_name = folder_pattern.split("/")[0]
                folder_path = os.path.join(frames_root, folder_name, "1")
                frame_paths = sorted(glob.glob(os.path.join(folder_path, "*.png")))
                self.samples.append({
                    "id": sample_id,
                    "frames": frame_paths,
                    "signer": signer,
                    "annotation": annotation.split(),
                })

    def __len__(self):
        """Return the number of parsed annotation rows."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(frames, glosses)`` for sample ``idx``.

        ``frames`` is the stack of up to ``max_frames`` decoded images converted
        BGR->RGB, zero-padded to exactly ``max_frames`` entries; ``glosses`` is
        the list of whitespace-separated annotation tokens.

        Raises:
            FileNotFoundError: no frame files were found for the sample.
            OSError: a frame file could not be decoded by OpenCV.
        """
        sample = self.samples[idx]
        if not sample["frames"]:
            raise FileNotFoundError(
                f"No frames found for sample {sample['id']!r} (split {self.split!r}); "
                "check the dataset root and directory layout."
            )
        frames = []
        for fp in sample["frames"][: self.max_frames]:
            img = cv2.imread(fp)
            if img is None:  # cv2.imread reports failure by returning None
                raise OSError(f"Could not read frame image {fp!r}")
            frames.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        # Zero-pad short sequences so every sample has max_frames frames.
        pad = np.zeros_like(frames[0])
        frames.extend(pad for _ in range(self.max_frames - len(frames)))
        return np.stack(frames), sample["annotation"]  # T, H, W, C

# Mediapipe setup
# Module-level MediaPipe detectors reused by extract_landmarks() for every frame.
mp_hands = mp.solutions.hands.Hands(
    static_image_mode=True,
    max_num_hands=2,  # matches the two 63-value hand slots in extract_landmarks
)
mp_face = mp.solutions.face_mesh.FaceMesh(
    static_image_mode=True,
)

def extract_landmarks(frames):
    """Extract per-frame MediaPipe hand and face landmarks.

    Uses the module-level ``mp_hands`` and ``mp_face`` detectors.

    Args:
        frames: iterable of RGB frames, each ``(H, W, 3)`` with float values
            in ``[0, 1]``.

    Returns:
        ``float32`` array of shape ``(T, 1530)``. Columns 0-125 hold up to two
        hands (21 points x (x, y, z) each); columns 126-1529 hold one face mesh
        (468 points x (x, y, z)). Undetected hands/face stay zero. An empty
        input yields shape ``(0, 1530)`` instead of raising.
    """
    n_hand_points, n_face_points = 21, 468
    hand_dim = 2 * n_hand_points * 3  # 126
    face_dim = n_face_points * 3      # 1404

    landmarks = []
    for frame in frames:
        # Round rather than truncate: scaled float32 values can land just below
        # an integer and would otherwise lose one intensity level.
        img = np.clip(np.rint(np.asarray(frame) * 255.0), 0, 255).astype(np.uint8)
        hand_lms = np.zeros(hand_dim, dtype=np.float32)
        face_lms = np.zeros(face_dim, dtype=np.float32)

        results_hands = mp_hands.process(img)
        if results_hands.multi_hand_landmarks:
            # NOTE(review): slots follow detection order, not left/right
            # handedness, so a given hand may switch slots between frames.
            for slot, hand in enumerate(results_hands.multi_hand_landmarks):
                if slot > 1:
                    break
                base = slot * n_hand_points * 3
                for j, lm in enumerate(hand.landmark):
                    if j >= n_hand_points:
                        break
                    hand_lms[base + j * 3:base + j * 3 + 3] = [lm.x, lm.y, lm.z]

        results_face = mp_face.process(img)
        if results_face.multi_face_landmarks:
            for i, lm in enumerate(results_face.multi_face_landmarks[0].landmark):
                # Guard against meshes with extra points (e.g. refined iris
                # landmarks) that would not fit the fixed-size buffer.
                if i >= n_face_points:
                    break
                face_lms[i * 3:i * 3 + 3] = [lm.x, lm.y, lm.z]

        landmarks.append(np.concatenate([hand_lms, face_lms]))

    if not landmarks:
        return np.zeros((0, hand_dim + face_dim), dtype=np.float32)
    return np.stack(landmarks)  # (T, 1530)

# Dataset wrapper
class PhoenixTorchDataset(Dataset):
    """Turn ``(frames, glosses)`` samples into model-ready tensors.

    Each item is ``(frames, landmarks, label_idx)``:
      * ``frames``: float32 tensor ``(T, C, H, W)`` scaled to ``[0, 1]``;
      * ``landmarks``: float32 tensor from ``extract_landmarks``, one row per frame;
      * ``label_idx``: vocab index of the FIRST gloss of the sample.

    Args:
        phoenix_ds: indexable dataset returning ``(frames_uint8_THWC, glosses)``.
        vocab: mapping ``{gloss: index}``.
        unk_idx: index used when the first gloss is not in ``vocab`` (or the
            sample has no glosses). The default 0 keeps the historical
            behaviour, but if ``vocab`` is 0-based that collides with a real
            gloss and unknown labels are silently scored as that class.
        cache_landmarks: if True, keep each sample's landmark tensor in memory
            after its first extraction so later epochs skip MediaPipe. Memory
            grows with ``len(dataset) * T * landmark_dim`` floats.
    """

    def __init__(self, phoenix_ds, vocab, unk_idx=0, cache_landmarks=False):
        self.phoenix_ds = phoenix_ds
        self.vocab = vocab  # dict {gloss: idx}
        self.unk_idx = unk_idx
        self.cache_landmarks = cache_landmarks
        self._landmark_cache = {}  # idx -> landmarks tensor (only if cache_landmarks)

    def __len__(self):
        return len(self.phoenix_ds)

    def __getitem__(self, idx):
        frames, glosses = self.phoenix_ds[idx]
        frames = frames.astype(np.float32) / 255.0  # still (T, H, W, C)

        landmarks_tensor = self._landmark_cache.get(idx) if self.cache_landmarks else None
        if landmarks_tensor is None:
            # Landmark extraction wants channel-last frames, so run it before
            # transposing instead of transposing twice.
            landmarks_tensor = torch.from_numpy(extract_landmarks(frames))
            if self.cache_landmarks:
                self._landmark_cache[idx] = landmarks_tensor

        frames_tensor = torch.from_numpy(np.ascontiguousarray(frames.transpose(0, 3, 1, 2)))
        # Only the first gloss is used as the clip label.
        label = glosses[0] if glosses else ""
        label_idx = self.vocab.get(label, self.unk_idx)
        return frames_tensor, landmarks_tensor, label_idx

# Model definitions
class CNNEncoder(nn.Module):
    """Small per-frame CNN: two stride-2 3x3 convolutions, then global average pooling.

    Maps ``(N, 3, H, W)`` images to ``(N, 64)`` feature vectors.
    """

    def __init__(self):
        super().__init__()
        # Layer order is kept so state_dict keys (conv.0.*, conv.2.*) stay
        # compatible with any previously saved weights.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2, 1), nn.ReLU(),
            nn.Conv2d(32, 64, 3, 2, 1), nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1)),
        )

    def forward(self, x):
        return self.conv(x).flatten(1)


class SignLanguageRecognizer(nn.Module):
    """Per-frame CNN features + landmarks -> LSTM -> one label per clip.

    Args:
        cnn_feat_dim: width of the CNNEncoder output (must match it: 64).
        landmark_dim: width of the per-frame landmark vector.
        hidden_dim: LSTM hidden size.
        num_classes: number of output logits.
    """

    def __init__(self, cnn_feat_dim=64, landmark_dim=1530, hidden_dim=128, num_classes=100):
        super().__init__()
        self.cnn = CNNEncoder()
        self.lstm = nn.LSTM(cnn_feat_dim + landmark_dim, hidden_dim, batch_first=True)
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, frames, landmarks, lengths=None):
        """Return logits of shape ``(B, num_classes)``.

        Args:
            frames: ``(B, T, C, H, W)`` float tensor.
            landmarks: ``(B, T, landmark_dim)`` float tensor.
            lengths: optional per-sample number of real (non-padding) frames.
                When given, the LSTM output at step ``lengths[b] - 1`` is
                classified instead of the final step, so trailing padding does
                not influence the prediction. Values are clamped to ``[1, T]``.
        """
        B, T, C, H, W = frames.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted batches, work.
        x_cnn = self.cnn(frames.reshape(B * T, C, H, W)).reshape(B, T, -1)
        x = torch.cat([x_cnn, landmarks], dim=2)
        out, _ = self.lstm(x)
        if lengths is None:
            last = out[:, -1, :]
        else:
            steps = torch.as_tensor(lengths, device=out.device).long().clamp(1, T) - 1
            last = out[torch.arange(B, device=out.device), steps]
        return self.classifier(last)

def compute_accuracy(output, target):
    """Count top-1 hits in a batch.

    Args:
        output: ``(N, num_classes)`` scores/logits.
        target: ``(N,)`` integer class indices.

    Returns:
        ``(n_correct, n_total)`` as Python ints.
    """
    predicted = torch.argmax(output, dim=1)
    hits = predicted.eq(target)
    n_correct = hits.sum().item()
    n_total = target.shape[0]
    return n_correct, n_total

def _build_vocab(annotation_file):
    """Map every gloss in a ``|``-separated annotation CSV to a 0-based index, in first-seen order."""
    vocab = {}
    with open(annotation_file, "r", encoding="utf-8") as f:
        next(f, None)  # skip header
        for line in f:
            parts = line.strip().split("|")
            if len(parts) < 4:
                continue
            for gloss in parts[3].split():
                vocab.setdefault(gloss, len(vocab))
    return vocab


def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` (0.0 if the loader is empty)."""
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for frames, landmarks, labels in loader:
            outputs = model(frames.to(device), landmarks.to(device))
            batch_correct, batch_total = compute_accuracy(outputs, labels.to(device))
            correct += batch_correct
            total += batch_total
    return correct / total if total else 0.0


def main(root="phoenix-2014.v3.tar/phoenix2014-release/phoenix-2014-multisigner",
         epochs=5, batch_size=4, lr=1e-3):
    """Train and evaluate SignLanguageRecognizer on first-gloss labels.

    The vocabulary is built from every gloss in the training annotations.
    Train loss/accuracy and validation accuracy are printed once per epoch,
    and test accuracy is printed after training.

    Args:
        root: dataset root containing ``annotations/`` and ``features/``.
        epochs: number of training epochs.
        batch_size: DataLoader batch size for all splits.
        lr: Adam learning rate.
    """
    ann_dir = os.path.join(root, "annotations", "manual")
    train_ann = os.path.join(ann_dir, "train.corpus.csv")
    val_ann = os.path.join(ann_dir, "dev.corpus.csv")
    test_ann = os.path.join(ann_dir, "test.corpus.csv")

    vocab = _build_vocab(train_ann)
    n_classes = len(vocab)
    print(f"Vocab size: {n_classes}")

    train_ds = PhoenixTorchDataset(PhoenixDataset(root, train_ann, split='train'), vocab)
    val_ds = PhoenixTorchDataset(PhoenixDataset(root, val_ann, split='dev'), vocab)
    test_ds = PhoenixTorchDataset(PhoenixDataset(root, test_ann, split='test'), vocab)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size)
    test_loader = DataLoader(test_ds, batch_size=batch_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SignLanguageRecognizer(num_classes=n_classes).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = train_total = 0
        # The previous per-batch "Correct/Total" print was removed: it emitted
        # one line per batch and drowned out the per-epoch summaries.
        for frames, landmarks, labels in train_loader:
            frames, landmarks, labels = frames.to(device), landmarks.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(frames, landmarks)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            correct, total = compute_accuracy(outputs, labels)
            train_correct += correct
            train_total += total

        n_batches = len(train_loader)
        avg_loss = train_loss / n_batches if n_batches else 0.0
        train_acc = train_correct / train_total if train_total else 0.0
        print(f"Epoch {epoch}: Train Loss={avg_loss:.4f} Accuracy={train_acc:.4f}")

        val_acc = _evaluate(model, val_loader, device)
        print(f"Epoch {epoch}: Validation Accuracy={val_acc:.4f}")

    test_acc = _evaluate(model, test_loader, device)
    print(f"Test Accuracy: {test_acc:.4f}")


if __name__ == "__main__":
    main()


Vocab size: 1231


KeyboardInterrupt: 