# Imports

In [1]:
from tqdm import tqdm
import pandas as pd
from collections import Counter
import os
import json

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

import numpy as np
import cv2
import mediapipe as mp

from datetime import date

# Config

In [2]:

train_index = date.today().toordinal()

DATA_PATH = '../../datasets/train-00000-of-00001.parquet'
OUTPUT_PATH = '../../datasets/Processed/Gloss_feat/processed_data.pt'

G2P_PATH = '../../datasets/archive/'
G2P_PATH_JSON = '../../datasets/archive/WLASL_v0.3.json'

MAX_SEQ_LEN = 50

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

device: cuda


# Dataset

## T2G

### Preprocess

In [3]:
def build_vocab(df):
    
    text_counter = Counter()
    gloss_counter = Counter()
    
    for _, row in df.iterrows():
        full_text = row["text"]
        
        gloss = full_text.split("[INST]")[1].split("[/INST]")[0].strip()
        gloss_tokens = gloss.split()
        gloss_counter.update(gloss_tokens)
        
        text = full_text.split("[/INST]")[1].replace("</s>", "").strip().lower()
        text_tokens = text.split()
        text_counter.update(text_tokens)
        
        base_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
        
        text_vocab = {tok: idx for idx, tok in enumerate(base_tokens)}
        for tok, _ in text_counter.most_common():
            if tok not in text_vocab:
                text_vocab[tok] = len(text_vocab)
                
        gloss_vocab = {tok: idx for idx, tok in enumerate(base_tokens)}
        for tok, _ in gloss_counter.most_common():
            if tok not in gloss_vocab:
                gloss_vocab[tok] = len(gloss_vocab)
                
        return text_vocab, gloss_vocab

def tokenize_and_pad(sequence, vocab, max_len, is_gloss):
    if is_gloss:
        tokens = sequence.split("[INST]")[1].split("[/INST]")[0].strip().split()
    else:
        tokens = sequence.split("[/INST]")[1].replace("</s>", "").strip().lower().split()
    
    idxs = [vocab.get(w, vocab["<unk>"]) for w in tokens]
    idxs = [vocab["<sos>"]] + idxs + [vocab["<eos>"]]
    
    if len(idxs) < max_len:
        idxs = idxs + [vocab["<pad>"]] * (max_len - len(idxs))
    else:
        idxs = idxs[: max_len - 1] + [vocab["<eos>"]]

    return torch.tensor(idxs, dtype=torch.long)

In [4]:
df = pd.read_parquet(DATA_PATH)

text_vocab, gloss_vocab = build_vocab(df)
print(f"> Built text_vocab (size: {len(text_vocab)}) and gloss_vocab (size: {len(gloss_vocab)})")

all_text_tensors  = []
all_gloss_tensors = []

for _, row in tqdm(df.iterrows(), total=len(df), desc="Tokenizing"):
    raw = row["text"]
    gloss_tensor = tokenize_and_pad(raw, gloss_vocab, max_len=MAX_SEQ_LEN, is_gloss=True)
    text_tensor  = tokenize_and_pad(raw, text_vocab,  max_len=MAX_SEQ_LEN, is_gloss=False)

    all_gloss_tensors.append(gloss_tensor.to(device))
    all_text_tensors.append(text_tensor.to(device))

gloss_matrix = torch.stack(all_gloss_tensors)   # (N, MAX_SEQ_LEN)
text_matrix  = torch.stack(all_text_tensors)    # ( N,MAX_SEQ_LEN)

save_dict = {
    "text_vocab": text_vocab,
    "gloss_vocab": gloss_vocab,
    "inv_gloss": {v: k for k, v in gloss_vocab.items()},
    "text_matrix": text_matrix,   # torch.LongTensor
    "gloss_matrix": gloss_matrix, # torch.LongTensor
}
torch.save(save_dict, OUTPUT_PATH)
print(f"> Saved processed data to {OUTPUT_PATH}")


> Built text_vocab (size: 17) and gloss_vocab (size: 12)


Tokenizing: 100%|██████████| 5000/5000 [00:02<00:00, 2301.53it/s]


> Saved processed data to ../../datasets/Processed/Gloss_feat/processed_data.pt


### txtGlossDataset CLass

In [5]:
class TextGlossDataset(Dataset):
    def __init__(self):
        self.text_data = ["hello world", "good morning"]
        self.gloss_data = [["HELLO", "WORLD"], ["GOOD", "MORNING"]]
        self.text_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "hello":3, "world":4, "good":5, "morning":6}
        self.gloss_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "HELLO":3, "WORLD":4, "GOOD":5, "MORNING":6}
        self.inv_gloss = {v:k for k, v in self.gloss_vocab.items()}

    def __len__(self): return len(self.text_data)

    def __getitem__(self, idx):
        text = [self.text_vocab[w] for w in self.text_data[idx].split()]
        gloss = [self.gloss_vocab[g] for g in self.gloss_data[idx]]
        return torch.tensor(text), torch.tensor(gloss)

In [6]:
class TextGlossDataset2(Dataset):
    def __init__(self, parquet_path, max_seq_length=50):
        
        self.df = pd.read_parquet(parquet_path)
        
        self.text_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "<unk>":3}
        self.gloss_vocab = {"<pad>":0, "<sos>": 1, "<eos>": 2, "<unk>":3}
        self.max_seq_length = max_seq_length
        
        self._build_vocabs()
        
        self.inv_gloss = {v: k for k, v in self.gloss_vocab.items()}
        
    def _build_vocabs(self):
        
        text_words = []
        gloss_words = []
        
        for _, row in self.df.iterrows():
            
            gloss = row['text'].split('[INST]')[1].split('[/INST]')[0].strip()
            gloss_words.extend(gloss.split())
            
            text = row['text'].split('[/INST]')[1].replace('</s>', '').strip()
            text_words.extend(text.lower().split())
            
            # build text vocab
            text_counter = Counter(text_words)
            for word, _ in text_counter.most_common():
                if word not in self.text_vocab:
                    self.text_vocab[word] = len(self.text_vocab)
            
            #build gloss vocab
            gloss_counter = Counter(gloss_words)
            for gloss, _ in gloss_counter.most_common():
                if gloss not in self.gloss_vocab:
                    self.gloss_vocab[gloss] = len(self.gloss_vocab)
                    
    def __len__(self):
        return len(self.df)
    
    def _process_sequence(self, sequence, vocab, is_gloss=False):
        
        if is_gloss:
            seq = sequence.split('[INST]')[1].split('[/INST]')[0].strip().split()
        else:
            seq = sequence.split('[/INST]')[1].replace('</s>', '').strip().lower().split()
        
        indices = [vocab.get(word, vocab["<unk>"]) for word in seq]
        indices = [vocab["<sos>"]] + indices + [vocab["<eos>"]]
        
        if len(indices) < self.max_seq_length:
            indices = indices + [vocab["<pad>"]] * (self.max_seq_length - len(indices))
        else:
            indices = indices[:self.max_seq_length-1] + [vocab["<eos>"]]
        
        return torch.tensor(indices).to(device)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]['text']
        
        gloss = self._process_sequence(row, self.gloss_vocab, is_gloss=True)
        text = self._process_sequence(row, self.text_vocab, is_gloss=False)
        
        return text, gloss
    
    def decode_gloss(self, indices):
        return ' '.join([self.inv_gloss.get(idx, '<unk>') for idx in indices if idx not in {0, 1, 2}])

In [7]:
class TextGlossDataset3(Dataset): # takes from .pth
    def __init__(self, processed_path):
        
        data = torch.load(processed_path, map_location=torch.device("cpu"))
        self.text_vocab  = data["text_vocab"]
        self.gloss_vocab = data["gloss_vocab"]
        self.inv_gloss   = data["inv_gloss"]

        # Pre‐tokenized (N, max_seq_len)
        self.text_matrix  = data["text_matrix"]
        self.gloss_matrix = data["gloss_matrix"]

        assert self.text_matrix.size(0) == self.gloss_matrix.size(0), "Mismatch in example count"

    def __len__(self):
        return self.text_matrix.size(0)

    def __getitem__(self, idx):
        text_indices  = self.text_matrix[idx]
        gloss_indices = self.gloss_matrix[idx]
        return text_indices, gloss_indices

    def decode_gloss(self, indices):
        return " ".join(
            [self.inv_gloss.get(int(idx), "<unk>") for idx in indices if idx not in {0, 1, 2}]
        )

## G2P

In [8]:
class WLASLGlossPoseDataset(Dataset):
    def __init__(self, json_path, video_dir, gloss_vocab=None, max_samples=100, cache_dir="pose_cache", force_reprocess=False):
        """
        WLASL Gloss to Pose Dataset
        Args:
            json_path: Path to WLASL_v0.3.json
            video_dir: Directory containing video files
            gloss_vocab: Existing gloss vocabulary (or create new)
            max_samples: Maximum samples to process
            cache_dir: Directory to store extracted poses
            force_reprocess: Reprocess even if cached exists
        """
        self.video_dir = video_dir
        self.cache_dir = cache_dir
        self.force_reprocess = force_reprocess
        os.makedirs(cache_dir, exist_ok=True)
        
        with open(json_path, 'r') as f:
            data = json.load(f)
        
        if gloss_vocab is None:
            self.gloss_vocab = {}
            for i, entry in enumerate(data):
                self.gloss_vocab[entry['gloss']] = i
        else:
            self.gloss_vocab = gloss_vocab
            
        self.samples = []
        for entry in data[:max_samples]:
            gloss = entry['gloss']
            for instance in entry['instances']:
                video_id = instance['video_id']
                video_path = os.path.join(video_dir, f"{video_id}.mp4")
                if os.path.exists(video_path):
                    self.samples.append((gloss, video_path, video_id))
        
        # Initialize MediaPipe Pose
        self.mp_pose = mp.solutions.pose.Pose(
            static_image_mode=False,
            model_complexity=2,
            enable_segmentation=False,
            min_detection_confidence=0.5
        )
        
    def __len__(self):
        return len(self.samples)
    
    def __getitem__(self, idx):
        gloss, video_path, video_id = self.samples[idx]
        cache_path = os.path.join(self.cache_dir, f"{video_id}.pt")
        
        # Load from cache or process video
        if os.path.exists(cache_path) and not self.force_reprocess:
            pose_seq = torch.load(cache_path)
        else:
            pose_seq = self.process_video(video_path)
            torch.save(pose_seq, cache_path)
        
        return {
            'gloss': self.gloss_vocab[gloss],
            'pose': pose_seq,
            'video_id': video_id
        }
    
    def process_video(self, video_path):
        # extract pose sequence from video using
        cap = cv2.VideoCapture(video_path)
        pose_sequence = []
        
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
                
            # Process frame with MediaPipe
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = self.mp_pose.process(frame_rgb)
            
            if results.pose_landmarks:
                # Extract 33 pose landmarks (x, y, visibility)
                landmarks = []
                for landmark in results.pose_landmarks.landmark:
                    landmarks.extend([landmark.x, landmark.y, landmark.visibility])
                pose_sequence.append(landmarks)
            else:
                # Pad with zeros if no detection
                pose_sequence.append([0.0]*99)
        
        cap.release()
        return torch.tensor(pose_sequence, dtype=torch.float32)


In [9]:
dataset = WLASLGlossPoseDataset(
    json_path="WLASL_v0.3.json",
    video_dir="videos",
    max_samples=100  
)

sample = dataset[0]
print(f"Gloss ID: {sample['gloss']}")
print(f"Pose sequence shape: {sample['pose'].shape}")  # (num_frames, 33*3=99)
print(f"Video ID: {sample['video_id']}")

# DataLoader with padding
def collate_fn(batch):
    glosses = torch.tensor([item['gloss'] for item in batch])
    poses = [item['pose'] for item in batch]
    video_ids = [item['video_id'] for item in batch]
    
    # Pad pose sequences to same length
    poses_padded = torch.nn.utils.rnn.pad_sequence(
        poses, batch_first=True, padding_value=0.0
    )
    
    return {
        'gloss': glosses,
        'pose': poses_padded,
        'pose_lengths': torch.tensor([len(p) for p in poses]),
        'video_id': video_ids
    }

from torch.utils.data import DataLoader
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

for batch in loader:
    print("\nBatch:")
    print(f"Glosses: {batch['gloss']}")
    print(f"Poses shape: {batch['pose'].shape}")  # (batch, max_frames, 99)
    print(f"Pose lengths: {batch['pose_lengths']}")
    break

FileNotFoundError: [Errno 2] No such file or directory: 'WLASL_v0.3.json'

# Text-to-Gloss

In [9]:
class Text2GlossTransformer(nn.Module):
    def __init__(self, text_vocab_size, gloss_vocab_size):
        super().__init__()
        self.text_embed = nn.Embedding(text_vocab_size, 256)
        self.gloss_embed = nn.Embedding(gloss_vocab_size, 256)
        self.transformer = nn.Transformer(
            d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3
        ).to(device)
        self.fc = nn.Linear(256, gloss_vocab_size)

    def forward(self, src, tgt):
        src = self.text_embed(src).permute(1,0,2) # S, B, E
        tgt = self.gloss_embed(tgt).permute(1,0,2)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc(output).permute(1,0,2) # B, S, V

# Gloss-to-Pose

In [10]:
class Gloss2Pose(nn.Module):
    def __init__(self, gloss_vocab_size, pose_dim=51): # 17 key points * 3 (x, y, conf)
        super().__init__()
        self.embed = nn.Embedding(gloss_vocab_size, 128)
        self.conv = nn.Sequential(
            nn.Conv1d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(512, pose_dim, 3, padding=1)
        )

    def forward(self, gloss_seq):
        x = self.embed(gloss_seq).permute(0,2,1) # B, C, S
        return self.conv(x).permute(0,2,1) # B, S, D

In [11]:

class PoseVQVAE(nn.Module):
    def __init__(self, pose_dim=51, num_tokens=1024, embedding_dim=64, hidden_dim=256):
        super().__init__()
            
        self.encoder = nn.Sequential(
            nn.Conv1d(pose_dim, hidden_dim, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, embedding_dim, 1)
        )
        
        self.codebook = nn.Embedding(num_tokens, embedding_dim)
        
        self.decoder = nn.Sequential(
            nn.Conv1d(embedding_dim, hidden_dim, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(hidden_dim, pose_dim, 1)
        )
        
    def forward(self, poses):
        # (B, T, D)
        x = poses.permute(0,2,1) # (B, D, T)
        z = self.encoder(x) # (B, E, T)
        z = z.permute(0,2,1) # (B,T,E)
        
        #quanntization
        distance = torch.cdist(z, self.codebook.weight) #(B, T, K)
        tokens = distance.argmin(-1) # (B, T)
        quantized = self.codebook(tokens)
        
        # decode
        quantized = quantized.permute(0,2,1) # (B,E, T)
        recon = self.decoder(quantized).permute(0,2,1) # (B,T,D)
        
        #losses
        recon_loss = F.mse_loss(recon, poses)
        commit_loss = F.mse_loss(z.detach(), quantized)
        return recon, recon_loss + 0.25 * commit_loss, tokens

In [12]:
class G2PWithVQ(nn.Module):
    def __init__(self, gloss_vocab_size,pose_dim = 51, num_tokens = 1024, k=5):
        super().__init__()
        self.k = k
        self.vqvae = PoseVQVAE(pose_dim, num_tokens)
        
        self.embed = nn.Embedding(gloss_vocab_size, 128)
        self.conv = nn.Sequential(
            nn.Conv1d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(256, num_tokens, 3, padding=1)
        )
        
    def forward(self, gloss_seq):
        
        x = self.embed(gloss_seq)
        
        #unsampling
        x = x.repeat_interleave(self.k, dim=1) # (B, S*k, 128)
        x = x.permute(0,2,1) # (V, T, num_tokens)
        
        logits = self.conv(x).permute(0,2,1) # (B, T, num_tokens)
        return logits
    
    def decode(self, logits):
        # logits -> Pose seqs
        pred_tokens = logits.argmax(-1)
        return self.vqvae.decoder(
            self.vqvae.codebook(pred_tokens).permute(0,2,1)
        ).permute(0,2,1)
        
        

# Render Pose

In [13]:
def render_pose(pose, connections, frame_size=(512, 512)):
    frame = np.zeros((*frame_size, 3), dtype=np.uint8)
    keypoints = pose.reshape(-1,3)
    keypoints[:, :2] = keypoints[:, :2] * frame_size[0]

    # draw connections
    for i, j in connections:
        if keypoints[i, 2] > 0.2 and keypoints[j, 2] > 0.2: # conf thresh
            cv2.line(
                frame,
                (int(keypoints[i,0]), int(keypoints[i, 1])),
                (int(keypoints[j, 0]), int(keypoints[j, 1])),
                (255, 166, 2), # orange
                2
            )


    # draw points
    for i, (x, y, conf) in enumerate(keypoints):
        if conf > 0.2:
            cv2.circle(frame, (int(x), int(y)), 5, (0, 255, 255), -1) # yellow

    return frame

# Train

In [14]:
train_index = train_index + 1

# dataset = TextGlossDataset2('../../datasets/train-00000-of-00001.parquet')
dataset = TextGlossDataset3(OUTPUT_PATH)

t2g_model = Text2GlossTransformer(
    len(dataset.text_vocab),
    len(dataset.gloss_vocab)
).to(device)

g2p_model = Gloss2Pose(len(dataset.gloss_vocab)).to(device)

loader = DataLoader(dataset, batch_size=2, shuffle=True)
optimizer = torch.optim.Adam(
    list(t2g_model.parameters()) + list(g2p_model.parameters()),
    lr = 1e-4
)

losses = []

for epoch in tqdm(range(30)):
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()

        decoder_input = torch.cat([
            torch.ones_like(tgt[:, :1]) * dataset.gloss_vocab["<sos>"],
            tgt[:, :-1]
        ], dim=1)

        gloss_logits = t2g_model(src, decoder_input)
        # gloss_loss = nn.CrossEntropyLoss()(
        #     gloss_logits.view(-1, gloss_logits.size(-1)),
        #     tgt.view(-1)
        # )

        gloss_logits = gloss_logits.contiguous()

        loss = nn.CrossEntropyLoss()(
            gloss_logits.view(-1, gloss_logits.size(-1)),
            tgt.contiguous().view(-1)
        )

        # optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    losses.append(loss.item())
    print(f"Epoch {epoch+1}: Loss={loss.item():.4f}")


  3%|▎         | 1/30 [02:30<1:12:40, 150.35s/it]

Epoch 1: Loss=0.0759


  7%|▋         | 2/30 [05:10<1:12:52, 156.16s/it]

Epoch 2: Loss=0.1692


 10%|█         | 3/30 [07:39<1:08:44, 152.76s/it]

Epoch 3: Loss=0.0529


 13%|█▎        | 4/30 [10:15<1:06:52, 154.31s/it]

Epoch 4: Loss=0.0378


 17%|█▋        | 5/30 [12:48<1:04:01, 153.66s/it]

Epoch 5: Loss=0.0435


 20%|██        | 6/30 [15:09<59:46, 149.43s/it]  

Epoch 6: Loss=0.0527


 23%|██▎       | 7/30 [17:54<59:15, 154.57s/it]

Epoch 7: Loss=0.0357


 27%|██▋       | 8/30 [24:22<1:23:53, 228.81s/it]

Epoch 8: Loss=0.0419


 30%|███       | 9/30 [27:18<1:14:18, 212.30s/it]

Epoch 9: Loss=0.0417


 33%|███▎      | 10/30 [30:13<1:06:53, 200.68s/it]

Epoch 10: Loss=0.0669


 37%|███▋      | 11/30 [33:05<1:00:49, 192.06s/it]

Epoch 11: Loss=0.0659


 40%|████      | 12/30 [35:56<55:41, 185.62s/it]  

Epoch 12: Loss=0.0401


 43%|████▎     | 13/30 [38:48<51:22, 181.32s/it]

Epoch 13: Loss=0.0560


 47%|████▋     | 14/30 [41:38<47:29, 178.11s/it]

Epoch 14: Loss=0.0284


 50%|█████     | 15/30 [44:30<44:02, 176.15s/it]

Epoch 15: Loss=0.0276


 53%|█████▎    | 16/30 [47:20<40:40, 174.29s/it]

Epoch 16: Loss=0.0418


 57%|█████▋    | 17/30 [50:20<38:10, 176.17s/it]

Epoch 17: Loss=0.0342


 60%|██████    | 18/30 [52:54<33:52, 169.35s/it]

Epoch 18: Loss=0.1051


 63%|██████▎   | 19/30 [55:52<31:32, 172.04s/it]

Epoch 19: Loss=0.0576


 67%|██████▋   | 20/30 [58:48<28:52, 173.20s/it]

Epoch 20: Loss=0.0308


 70%|███████   | 21/30 [1:01:43<26:03, 173.71s/it]

Epoch 21: Loss=0.1748


 73%|███████▎  | 22/30 [1:04:34<23:03, 173.00s/it]

Epoch 22: Loss=0.0313


 77%|███████▋  | 23/30 [1:07:24<20:04, 172.12s/it]

Epoch 23: Loss=0.0486


 80%|████████  | 24/30 [1:10:14<17:08, 171.36s/it]

Epoch 24: Loss=0.0345


 83%|████████▎ | 25/30 [1:13:11<14:24, 172.98s/it]

Epoch 25: Loss=0.0995


 87%|████████▋ | 26/30 [1:16:13<11:42, 175.72s/it]

Epoch 26: Loss=0.0513


 90%|█████████ | 27/30 [1:18:49<08:29, 169.82s/it]

Epoch 27: Loss=0.1861


 93%|█████████▎| 28/30 [1:21:40<05:40, 170.24s/it]

Epoch 28: Loss=0.0507


 97%|█████████▋| 29/30 [1:24:33<02:51, 171.06s/it]

Epoch 29: Loss=0.0959


100%|██████████| 30/30 [1:27:19<00:00, 174.66s/it]

Epoch 30: Loss=0.0528





In [16]:
t2g_losses = []
g2p_losses = []

In [15]:
torch.save(t2g_model.state_dict(), 't2g_model_weights.pth')
torch.save(g2p_model.state_dict(), 'g2p_model_weights.pth')

In [16]:
torch.save(t2g_model, 't2g_model.pth')
torch.save(g2p_model, 'g2p_model.pth')

In [17]:
def pretrain_vqvae(model, dataset, epochs=10):
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    
    for epoch in range(epochs):
        for batch in loader:
            _, _, pose = batch 
            pose = pose.to(device)
            
            optimizer.zero_grad()
            _, loss, _ = model(pose)
            loss.backward()
            optimizer.step()
        
        print(f"VQ-VAE Pre-train Epoch {epoch+1}: Loss={loss.item():.4f}")

In [18]:
def train2():
    dataset = TextGlossDataset3(OUTPUT_PATH)
    
    t2g_model = Text2GlossTransformer(
        len(dataset.text_vocab),
        len(dataset.gloss_vocab)
    ).to(device)
    
    g2p_model = G2PWithVQ(len(dataset.gloss_vocab)).to(device)
    
    
    
    t2g_optim = torch.optim.Adam(t2g_model.parameters(), lr=1e-4)
    g2p_optim = torch.optim.Adam(g2p_model.parameters(), lr=1e-4)
    
    # pre-train the VQ-VAE component
    print("Pre-training VQ-VAE...")
    pretrain_vqvae(g2p_model.vqvae, dataset)  
    
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    
    for epoch in range(30):
        for batch in loader:
            # batch: (text, gloss, pose)
            text, gloss, pose = batch
            text, gloss, pose = text.to(device), gloss.to(device), pose.to(device)
            
            t2g_optim.zero_grad()
            decoder_input = torch.cat([
                torch.ones_like(gloss[:, :1]) * dataset.gloss_vocab["<sos>"],
                gloss[:, :-1]
            ], dim=1)
            
            gloss_logits = t2g_model(text, decoder_input)
            t2g_loss = F.cross_entropy(
                gloss_logits.view(-1, gloss_logits.size(-1)),
                gloss.contiguous().view(-1)
            )
            t2g_loss.backward()
            t2g_optim.step()
            
            # Gloss to Pose
            g2p_optim.zero_grad()
            pred_tokens = g2p_model(gloss)
            
            # target tokens from VQ-VAE
            with torch.no_grad():
                _, _, target_tokens = g2p_model.vqvae(pose)
            
            g2p_loss = F.cross_entropy(
                pred_tokens.view(-1, pred_tokens.size(-1)),
                target_tokens.view(-1)
            )
            g2p_loss.backward()
            g2p_optim.step()
            
            t2g_losses.append(t2g_loss.item())
            g2p_losses.append(g2p_loss.item())
        print(f"Epoch {epoch+1}: T2G Loss={t2g_loss.item():.4f}, G2P Loss={g2p_loss.item():.4f}")

In [19]:
train2()



Pre-training VQ-VAE...


ValueError: not enough values to unpack (expected 3, got 2)

# Save

In [20]:
df = pd.DataFrame({
    'Loss' : losses
})

df.to_csv(f'train{train_index}.csv', index=True)

# Text-to-Sign testing

In [23]:
def text_to_sign(text):
    tokens = [dataset.text_vocab.get(w, 0) for w in text.split()]
    src_tokens = torch.tensor([tokens]).to(device)

    gloss_seq = [dataset.gloss_vocab["<sos>"]]
    for i in tqdm(range(20)):
        gloss_decoder_input = torch.tensor([gloss_seq]).to(device)
        with torch.no_grad():
            logits = t2g_model(src_tokens, gloss_decoder_input)
        next_id = logits[0, -1].argmax().item()
        if next_id == dataset.gloss_vocab["<eos>"]:
            break
        gloss_seq.append(next_id)

    glosses = [dataset.inv_gloss[idx] for idx in gloss_seq[1:]]
    print(f"Glosses: {' '.join(glosses)}")

    gloss_tensor = torch.tensor([gloss_seq[1:]]).to(device)  # Exclude <sos>
    with torch.no_grad():
        poses = g2p_model(gloss_tensor).cpu().numpy()[0]  # [S, D]

    connections = [
        (0,1), (0,2), (1,3), (2,4),        # Head
        (5,6), (5,7), (7,9), (6,8), (8,10), # Arms
        (11,12), (11,13), (13,15), (12,14), (14,16) # Legs
    ]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(f'sign_output{train_index}.mp4', fourcc, 5.0, (512, 512))

    for pose in poses:
        pose_norm = (pose - np.min(pose)) / (np.max(pose) - np.min(pose) + 1e-8)
        frame = render_pose(pose_norm, connections)
        out.write(frame)

    out.release()
    return glosses, f'sign_output{train_index}.mp4'


In [24]:
text_to_sign("hello world")

100%|██████████| 20/20 [00:00<00:00, 37.66it/s]


Glosses: <sos> <sos> <sos> <sos> <sos> <sos> <sos> <sos> <sos> <sos> <sos> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk> <unk>


(['<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<sos>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>',
  '<unk>'],
 'sign_output739415.mp4')