In [1]:
from tqdm import tqdm
import pandas as pd
from collections import Counter
import os
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import cv2
import mediapipe as mp

from datetime import date

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm

In [3]:
class WLASLGlossPoseDataset(Dataset):
    def __init__(self, json_path, video_dir, gloss_vocab=None, max_samples=100, cache_dir="pose_cache", force_reprocess=False):
        """
        WLASL Gloss to Pose Dataset
        Args:
            json_path: Path to WLASL_v0.3.json
            video_dir: Directory containing video files
            gloss_vocab: Existing gloss vocabulary (or create new)
            max_samples: Maximum samples to process
            cache_dir: Directory to store extracted poses
            force_reprocess: Reprocess even if cached exists
        """
        self.video_dir = video_dir
        self.cache_dir = cache_dir
        self.force_reprocess = force_reprocess
        os.makedirs(cache_dir, exist_ok=True)
        
        with open(json_path, 'r') as f:
            data = json.load(f)
        
        if gloss_vocab is None:
            self.gloss_vocab = {}
            for i, entry in enumerate(data):
                self.gloss_vocab[entry['gloss']] = i
        else:
            self.gloss_vocab = gloss_vocab
            
        self.samples = []
        for entry in data[:max_samples]:
            gloss = entry['gloss']
            for instance in entry['instances']:
                video_id = instance['video_id']
                video_path = os.path.join(video_dir, f"{video_id}.mp4")
                if os.path.exists(video_path):
                    self.samples.append((gloss, video_path, video_id))
        
        # Initialize MediaPipe Pose
        self.mp_pose = mp.solutions.pose.Pose(
            static_image_mode=False,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5
        )
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        gloss, video_path, video_id = self.samples[idx]
        cache_path = os.path.join(self.cache_dir, f"{video_id}.pt")
        
        # Load from cache or process video
        if os.path.exists(cache_path) and not self.force_reprocess:
            pose_seq = torch.load(cache_path)
        else:
            pose_seq = self.process_video(video_path)
            torch.save(pose_seq, cache_path)
        
        return {
            'gloss': self.gloss_vocab[gloss],
            'pose': pose_seq,
            'video_id': video_id
        }
    
    def process_video(self, video_path):
        # extract pose sequence from video using
        cap = cv2.VideoCapture(video_path)
        pose_sequence = []
        
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
                
            # Process frame with MediaPipe
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = self.mp_pose.process(frame_rgb)
            
            if results.pose_landmarks:
                # Extract 33 pose landmarks (x, y, visibility)
                landmarks = []
                for landmark in results.pose_landmarks.landmark:
                    landmarks.extend([landmark.x, landmark.y, landmark.visibility])
                pose_sequence.append(landmarks)
            else:
                # Pad with zeros if no detection
                pose_sequence.append([0.0]*99)
        
        cap.release()
        return torch.tensor(pose_sequence, dtype=torch.float32)


In [4]:
class Gloss2Pose(nn.Module):
    def __init__(self, gloss_vocab_size, pose_dim=99):  # 33 keypoints * 3 (x, y, conf)
        super().__init__()
        self.embed = nn.Embedding(gloss_vocab_size, 128)
        self.conv = nn.Sequential(
            nn.Conv1d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(512, pose_dim, 3, padding=1)
        )

    def forward(self, gloss_seq):
        # Input shape: (batch_size, seq_len)
        x = self.embed(gloss_seq)  # (batch_size, seq_len, embed_dim=128)
        x = x.permute(0, 2, 1)  # (batch_size, embed_dim, seq_len)
        x = self.conv(x)  # (batch_size, pose_dim, seq_len)
        return x.permute(0, 2, 1)  # (batch_size, seq_len, pose_dim)

In [None]:

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 1e-4
SAVE_PATH = "g2p_model_trained.pth"

losses = []
accuracies = []

def collate_fn(batch):
    """Custom collate function to pad pose sequences"""
    glosses = torch.tensor([item['gloss'] for item in batch])
    poses = [item['pose'] for item in batch]
    video_ids = [item['video_id'] for item in batch]
    
    poses_padded = torch.nn.utils.rnn.pad_sequence(
        poses, batch_first=True, padding_value=0.0
    )
    
    return {
        'gloss': glosses,
        'pose': poses_padded,
        'pose_lengths': torch.tensor([len(p) for p in poses]),
        'video_id': video_ids
    }

def train():
    # Initialize dataset and loader
    dataset = WLASLGlossPoseDataset(
        json_path="archive/WLASL_v0.3.json",
        video_dir="archive/videos",
        max_samples=1000
    )
    
    # Create train/test split
    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
    
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn
    )
    
    # Initialize model
    model = Gloss2Pose(
        gloss_vocab_size=len(dataset.gloss_vocab),
        pose_dim=99
    ).to(DEVICE)
    
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)
    
    best_loss = float('inf')
    
    for epoch in range(EPOCHS):
        model.train()
        train_loss = 0.0
        
        # Training phase
        for batch in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
            gloss = batch['gloss'].to(DEVICE)  # shape: (batch_size,)
            pose = batch['pose'].to(DEVICE)    # shape: (batch_size, max_seq_len, 99)
            lengths = batch['pose_lengths']
            
            optimizer.zero_grad()
            
            # Forward pass - need to unsqueeze gloss if it's 1D
            if gloss.dim() == 1:
                gloss = gloss.unsqueeze(1)  # shape: (batch_size, 1)
            
            outputs = model(gloss)  # shape: (batch_size, seq_len, 99)
            
            # Calculate loss only on valid frames
            loss = 0.0
            for i in range(outputs.size(0)):
                # Ensure we don't exceed sequence length
                valid_len = min(lengths[i], outputs.size(1))
                loss += criterion(outputs[i, :valid_len], pose[i, :valid_len])
            
            loss /= outputs.size(0)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
            train_loss += loss.item()
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch in tqdm(test_loader, desc=f"Val Epoch {epoch+1}"):
                gloss = batch['gloss'].to(DEVICE)
                pose = batch['pose'].to(DEVICE)
                lengths = batch['pose_lengths']
                
                if gloss.dim() == 1:
                    gloss = gloss.unsqueeze(1)
                
                outputs = model(gloss)
                
                batch_loss = 0.0
                for i in range(outputs.size(0)):
                    valid_len = min(lengths[i], outputs.size(1))
                    batch_loss += criterion(outputs[i, :valid_len], pose[i, :valid_len])
                
                val_loss += batch_loss.item() / outputs.size(0)
        
        train_loss /= len(train_loader)
        val_loss /= len(test_loader)
        scheduler.step(val_loss)
        
        print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")
        
        # Save best model
        if val_loss < best_loss:
            best_loss = val_loss
            torch.save(model.state_dict(), "best_g2p_model.pth")
            print("Saved new best model")
    
    print("Training complete")

In [6]:
train()

I0000 00:00:1749793641.327379  141639 gl_context_egl.cc:85] Successfully initialized EGL. Major : 1 Minor: 5
I0000 00:00:1749793641.378231  141943 gl_context.cc:369] GL version: 3.2 (OpenGL ES 3.2 NVIDIA 535.230.02), renderer: Quadro P2000/PCIe/SSE2
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
W0000 00:00:1749793641.456598  141927 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1749793641.572985  141938 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
Train Epoch 1:   0%|          | 0/181 [00:00<?, ?it/s]W0000 00:00:1749793642.430326  141942 landmark_projection_calculator.cc:186] Using NORM_RECT without IMAGE_DIMENSIONS is only supported for the square ROI. Provide IMAGE_DIMENSIONS or use PROJECTION_MATRIX.
Train Epoch 1:  56%|█████▋    | 102/181 [3:25:31<2:37:18, 119.48s/it]

Epoch 1: Train Loss = 0.2468, Val Loss = 0.0390
Saved new best model


Train Epoch 2: 100%|██████████| 181/181 [00:03<00:00, 55.20it/s]
Val Epoch 2: 100%|██████████| 46/46 [00:00<00:00, 76.06it/s]


Epoch 2: Train Loss = 0.0274, Val Loss = 0.0239
Saved new best model


Train Epoch 3: 100%|██████████| 181/181 [00:03<00:00, 56.80it/s]
Val Epoch 3: 100%|██████████| 46/46 [00:00<00:00, 75.27it/s]


Epoch 3: Train Loss = 0.0228, Val Loss = 0.0225
Saved new best model


Train Epoch 4: 100%|██████████| 181/181 [00:03<00:00, 56.47it/s]
Val Epoch 4: 100%|██████████| 46/46 [00:00<00:00, 74.37it/s]


Epoch 4: Train Loss = 0.0218, Val Loss = 0.0220
Saved new best model


Train Epoch 5: 100%|██████████| 181/181 [00:03<00:00, 55.78it/s]
Val Epoch 5: 100%|██████████| 46/46 [00:00<00:00, 75.20it/s]


Epoch 5: Train Loss = 0.0212, Val Loss = 0.0217
Saved new best model


Train Epoch 6: 100%|██████████| 181/181 [00:03<00:00, 57.05it/s]
Val Epoch 6: 100%|██████████| 46/46 [00:00<00:00, 76.67it/s]


Epoch 6: Train Loss = 0.0208, Val Loss = 0.0217


Train Epoch 7: 100%|██████████| 181/181 [00:03<00:00, 56.39it/s]
Val Epoch 7: 100%|██████████| 46/46 [00:00<00:00, 74.54it/s]


Epoch 7: Train Loss = 0.0205, Val Loss = 0.0216
Saved new best model


Train Epoch 8: 100%|██████████| 181/181 [00:03<00:00, 56.89it/s]
Val Epoch 8: 100%|██████████| 46/46 [00:00<00:00, 76.73it/s]


Epoch 8: Train Loss = 0.0201, Val Loss = 0.0214
Saved new best model


Train Epoch 9: 100%|██████████| 181/181 [00:03<00:00, 56.70it/s]
Val Epoch 9: 100%|██████████| 46/46 [00:00<00:00, 74.89it/s]


Epoch 9: Train Loss = 0.0198, Val Loss = 0.0215


Train Epoch 10: 100%|██████████| 181/181 [00:03<00:00, 56.83it/s]
Val Epoch 10: 100%|██████████| 46/46 [00:00<00:00, 78.28it/s]


Epoch 10: Train Loss = 0.0197, Val Loss = 0.0216


Train Epoch 11: 100%|██████████| 181/181 [00:03<00:00, 56.18it/s]
Val Epoch 11: 100%|██████████| 46/46 [00:00<00:00, 77.61it/s]


Epoch 11: Train Loss = 0.0195, Val Loss = 0.0214


Train Epoch 12: 100%|██████████| 181/181 [00:03<00:00, 57.53it/s]
Val Epoch 12: 100%|██████████| 46/46 [00:00<00:00, 61.33it/s]


Epoch 12: Train Loss = 0.0193, Val Loss = 0.0216


Train Epoch 13: 100%|██████████| 181/181 [00:03<00:00, 57.63it/s]
Val Epoch 13: 100%|██████████| 46/46 [00:00<00:00, 76.63it/s]


Epoch 13: Train Loss = 0.0185, Val Loss = 0.0213
Saved new best model


Train Epoch 14: 100%|██████████| 181/181 [00:03<00:00, 57.24it/s]
Val Epoch 14: 100%|██████████| 46/46 [00:00<00:00, 73.43it/s]


Epoch 14: Train Loss = 0.0183, Val Loss = 0.0213


Train Epoch 15: 100%|██████████| 181/181 [00:03<00:00, 58.39it/s]
Val Epoch 15: 100%|██████████| 46/46 [00:00<00:00, 71.91it/s]


Epoch 15: Train Loss = 0.0183, Val Loss = 0.0213


Train Epoch 16: 100%|██████████| 181/181 [00:03<00:00, 57.45it/s]
Val Epoch 16: 100%|██████████| 46/46 [00:00<00:00, 75.67it/s]


Epoch 16: Train Loss = 0.0182, Val Loss = 0.0213
Saved new best model


Train Epoch 17: 100%|██████████| 181/181 [00:03<00:00, 56.77it/s]
Val Epoch 17: 100%|██████████| 46/46 [00:00<00:00, 78.74it/s]


Epoch 17: Train Loss = 0.0182, Val Loss = 0.0213
Saved new best model


Train Epoch 18: 100%|██████████| 181/181 [00:03<00:00, 57.87it/s]
Val Epoch 18: 100%|██████████| 46/46 [00:00<00:00, 81.41it/s]


Epoch 18: Train Loss = 0.0182, Val Loss = 0.0213


Train Epoch 19: 100%|██████████| 181/181 [00:03<00:00, 57.10it/s]
Val Epoch 19: 100%|██████████| 46/46 [00:00<00:00, 76.31it/s]


Epoch 19: Train Loss = 0.0182, Val Loss = 0.0213
Saved new best model


Train Epoch 20: 100%|██████████| 181/181 [00:03<00:00, 57.80it/s]
Val Epoch 20: 100%|██████████| 46/46 [00:00<00:00, 78.79it/s]


Epoch 20: Train Loss = 0.0181, Val Loss = 0.0213
Saved new best model


Train Epoch 21: 100%|██████████| 181/181 [00:03<00:00, 56.90it/s]
Val Epoch 21: 100%|██████████| 46/46 [00:00<00:00, 73.00it/s]


Epoch 21: Train Loss = 0.0180, Val Loss = 0.0213


Train Epoch 22: 100%|██████████| 181/181 [00:03<00:00, 56.83it/s]
Val Epoch 22: 100%|██████████| 46/46 [00:00<00:00, 78.29it/s]


Epoch 22: Train Loss = 0.0181, Val Loss = 0.0213


Train Epoch 23: 100%|██████████| 181/181 [00:03<00:00, 57.42it/s]
Val Epoch 23: 100%|██████████| 46/46 [00:00<00:00, 74.60it/s]


Epoch 23: Train Loss = 0.0181, Val Loss = 0.0214


Train Epoch 24: 100%|██████████| 181/181 [00:03<00:00, 57.24it/s]
Val Epoch 24: 100%|██████████| 46/46 [00:00<00:00, 76.83it/s]


Epoch 24: Train Loss = 0.0181, Val Loss = 0.0213


Train Epoch 25: 100%|██████████| 181/181 [00:03<00:00, 56.92it/s]
Val Epoch 25: 100%|██████████| 46/46 [00:00<00:00, 78.82it/s]


Epoch 25: Train Loss = 0.0180, Val Loss = 0.0213
Saved new best model


Train Epoch 26: 100%|██████████| 181/181 [00:03<00:00, 56.15it/s]
Val Epoch 26: 100%|██████████| 46/46 [00:00<00:00, 75.57it/s]


Epoch 26: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 27: 100%|██████████| 181/181 [00:03<00:00, 58.42it/s]
Val Epoch 27: 100%|██████████| 46/46 [00:00<00:00, 78.37it/s]


Epoch 27: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 28: 100%|██████████| 181/181 [00:03<00:00, 57.77it/s]
Val Epoch 28: 100%|██████████| 46/46 [00:00<00:00, 73.32it/s]


Epoch 28: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 29: 100%|██████████| 181/181 [00:03<00:00, 56.71it/s]
Val Epoch 29: 100%|██████████| 46/46 [00:00<00:00, 75.81it/s]


Epoch 29: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 30: 100%|██████████| 181/181 [00:03<00:00, 58.11it/s]
Val Epoch 30: 100%|██████████| 46/46 [00:00<00:00, 75.64it/s]


Epoch 30: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 31: 100%|██████████| 181/181 [00:03<00:00, 57.70it/s]
Val Epoch 31: 100%|██████████| 46/46 [00:00<00:00, 78.40it/s]


Epoch 31: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 32: 100%|██████████| 181/181 [00:03<00:00, 57.92it/s]
Val Epoch 32: 100%|██████████| 46/46 [00:00<00:00, 75.82it/s]


Epoch 32: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 33: 100%|██████████| 181/181 [00:03<00:00, 57.36it/s]
Val Epoch 33: 100%|██████████| 46/46 [00:00<00:00, 75.91it/s]


Epoch 33: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 34: 100%|██████████| 181/181 [00:03<00:00, 57.39it/s]
Val Epoch 34: 100%|██████████| 46/46 [00:00<00:00, 75.83it/s]


Epoch 34: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 35: 100%|██████████| 181/181 [00:03<00:00, 57.71it/s]
Val Epoch 35: 100%|██████████| 46/46 [00:00<00:00, 76.05it/s]


Epoch 35: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 36: 100%|██████████| 181/181 [00:03<00:00, 58.08it/s]
Val Epoch 36: 100%|██████████| 46/46 [00:00<00:00, 76.05it/s]


Epoch 36: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 37: 100%|██████████| 181/181 [00:03<00:00, 57.32it/s]
Val Epoch 37: 100%|██████████| 46/46 [00:00<00:00, 78.96it/s]


Epoch 37: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 38: 100%|██████████| 181/181 [00:03<00:00, 57.47it/s]
Val Epoch 38: 100%|██████████| 46/46 [00:00<00:00, 77.64it/s]


Epoch 38: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 39: 100%|██████████| 181/181 [00:03<00:00, 55.94it/s]
Val Epoch 39: 100%|██████████| 46/46 [00:00<00:00, 76.26it/s]


Epoch 39: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 40: 100%|██████████| 181/181 [00:03<00:00, 57.62it/s]
Val Epoch 40: 100%|██████████| 46/46 [00:00<00:00, 79.05it/s]


Epoch 40: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 41: 100%|██████████| 181/181 [00:03<00:00, 58.10it/s]
Val Epoch 41: 100%|██████████| 46/46 [00:00<00:00, 75.72it/s]


Epoch 41: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 42: 100%|██████████| 181/181 [00:03<00:00, 57.82it/s]
Val Epoch 42: 100%|██████████| 46/46 [00:00<00:00, 78.97it/s]


Epoch 42: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 43: 100%|██████████| 181/181 [00:03<00:00, 57.24it/s]
Val Epoch 43: 100%|██████████| 46/46 [00:00<00:00, 79.91it/s]


Epoch 43: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 44: 100%|██████████| 181/181 [00:03<00:00, 57.41it/s]
Val Epoch 44: 100%|██████████| 46/46 [00:00<00:00, 79.39it/s]


Epoch 44: Train Loss = 0.0179, Val Loss = 0.0214


Train Epoch 45: 100%|██████████| 181/181 [00:03<00:00, 57.73it/s]
Val Epoch 45: 100%|██████████| 46/46 [00:00<00:00, 76.69it/s]


Epoch 45: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 46: 100%|██████████| 181/181 [00:03<00:00, 57.84it/s]
Val Epoch 46: 100%|██████████| 46/46 [00:00<00:00, 77.08it/s]


Epoch 46: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 47: 100%|██████████| 181/181 [00:03<00:00, 57.22it/s]
Val Epoch 47: 100%|██████████| 46/46 [00:00<00:00, 75.77it/s]


Epoch 47: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 48: 100%|██████████| 181/181 [00:03<00:00, 57.49it/s]
Val Epoch 48: 100%|██████████| 46/46 [00:00<00:00, 76.17it/s]


Epoch 48: Train Loss = 0.0179, Val Loss = 0.0213


Train Epoch 49: 100%|██████████| 181/181 [00:03<00:00, 58.17it/s]
Val Epoch 49: 100%|██████████| 46/46 [00:00<00:00, 75.65it/s]


Epoch 49: Train Loss = 0.0179, Val Loss = 0.0212
Saved new best model


Train Epoch 50: 100%|██████████| 181/181 [00:03<00:00, 57.51it/s]
Val Epoch 50: 100%|██████████| 46/46 [00:00<00:00, 78.00it/s]

Epoch 50: Train Loss = 0.0179, Val Loss = 0.0213
Training complete





In [7]:
class TextGlossDataset3(Dataset): # takes from .pth
    def __init__(self, processed_path):
        
        data = torch.load(processed_path, map_location=torch.device("cpu"))
        self.text_vocab  = data["text_vocab"]
        self.gloss_vocab = data["gloss_vocab"]
        self.inv_gloss   = data["inv_gloss"]

        # Pre‐tokenized (N, max_seq_len)
        self.text_matrix  = data["text_matrix"]
        self.gloss_matrix = data["gloss_matrix"]
        
        self.pose_matrix = data['pose_matrix']

        assert self.text_matrix.size(0) == self.gloss_matrix.size(0), "Mismatch in example count"

    def __len__(self):
        return self.text_matrix.size(0)

    def __getitem__(self, idx):
        text_indices  = self.text_matrix[idx]
        gloss_indices = self.gloss_matrix[idx]
        pose_indices = self.pose_matrix[idx]
        return text_indices, gloss_indices, pose_indices

    def decode_gloss(self, indices):
        return " ".join(
            [self.inv_gloss.get(int(idx), "<unk>") for idx in indices if idx not in {0, 1, 2}]
        )

In [8]:
class Text2GlossTransformer(nn.Module):
    def __init__(self, text_vocab_size, gloss_vocab_size):
        super().__init__()
        self.text_embed = nn.Embedding(text_vocab_size, 256)
        self.gloss_embed = nn.Embedding(gloss_vocab_size, 256)
        self.transformer = nn.Transformer(
            d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3
        ).to(device)
        self.fc = nn.Linear(256, gloss_vocab_size)

    def forward(self, src, tgt):
        src = self.text_embed(src).permute(1,0,2) # S, B, E
        tgt = self.gloss_embed(tgt).permute(1,0,2)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc(output).permute(1,0,2) # B, S, V

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [11]:
WLASL_DATASET = 'wlasl_dataset.pt'

In [14]:
t2g_losses = []

def train2():
    dataset = TextGlossDataset3(WLASL_DATASET)
    
    t2g_model = Text2GlossTransformer(
        len(dataset.text_vocab),
        len(dataset.gloss_vocab)
    ).to(device)
    
    t2g_optim = torch.optim.Adam(t2g_model.parameters(), lr=1e-4)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    
    for epoch in range(30):
        for batch in loader:
            text, gloss, pose = batch
            text, gloss = text.to(device), gloss.to(device)
            
            t2g_optim.zero_grad()
            decoder_input = torch.cat([
                torch.ones_like(gloss[:, :1]) * dataset.gloss_vocab["<sos>"],
                gloss[:, :-1]
            ], dim=1)
            
            gloss_logits = t2g_model(text, decoder_input)
            
            # FIXED: Make tensors contiguous before view
            t2g_loss = F.cross_entropy(
                gloss_logits.contiguous().view(-1, gloss_logits.size(-1)),
                gloss.contiguous().view(-1),
                ignore_index=0  # Optional: ignore padding index
            )
            t2g_loss.backward()
            t2g_optim.step()
            
            t2g_losses.append(t2g_loss.item())
        print(f"Epoch {epoch+1}: T2G Loss={t2g_loss.item():.4f}")
        
    return t2g_model

In [15]:
t2g_model = train2()



Epoch 1: T2G Loss=2.4554
Epoch 2: T2G Loss=2.5578
Epoch 3: T2G Loss=2.6521
Epoch 4: T2G Loss=2.6818
Epoch 5: T2G Loss=2.6016
Epoch 6: T2G Loss=2.5468
Epoch 7: T2G Loss=2.5108
Epoch 8: T2G Loss=2.5678
Epoch 9: T2G Loss=2.5760
Epoch 10: T2G Loss=2.9889
Epoch 11: T2G Loss=2.6295
Epoch 12: T2G Loss=2.5566
Epoch 13: T2G Loss=2.8760
Epoch 14: T2G Loss=2.5514
Epoch 15: T2G Loss=2.5751
Epoch 16: T2G Loss=2.5277
Epoch 17: T2G Loss=2.5481
Epoch 18: T2G Loss=2.5730
Epoch 19: T2G Loss=2.4859
Epoch 20: T2G Loss=2.5169
Epoch 21: T2G Loss=2.5904
Epoch 22: T2G Loss=2.5398
Epoch 23: T2G Loss=2.4936
Epoch 24: T2G Loss=2.8740
Epoch 25: T2G Loss=2.8764
Epoch 26: T2G Loss=2.8979
Epoch 27: T2G Loss=3.0429
Epoch 28: T2G Loss=2.7748
Epoch 29: T2G Loss=2.8568
Epoch 30: T2G Loss=2.7141


In [16]:
torch.save(t2g_model.state_dict(), "t2g_model_weights.pth")