# Imports

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
import cv2

from datetime import date

In [5]:
import tqdm
import pandas as pd

In [26]:
from collections import Counter

# Config

In [None]:
train_index = date.today().toordinal()

DATA_PATH = '../../datasets/train-00000-of-00001.parquet'
OUTPUT_PATH = '../../datasets/Processed/Gloss_feat/processed_data.pt'

MAX_SEQ_LEN = 50

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

# Dataset

## Preprocess

In [None]:
def build_vocab(df):
    
    text_counter = Counter()
    gloss_counter = Counter()
    
    for _, row in df.iterrows():
        full_text = row["text"]
        
        gloss = full_text.split("[INST]")[1].split("[/INST]")[0].strip()
        gloss_tokens = gloss.split()
        gloss_counter.update(gloss_tokens)
        
        text = full_text.split("[/INST]")[1].replace("</s>", "").strip().lower()
        text_tokens = text.split()
        text_counter.update(text_tokens)
        
        base_tokens = ["<pad>", "<sos>", "<eos>", "<unk>"]
        
        text_vocab = {tok: idx for idx, tok in enumerate(base_tokens)}
        for tok, _ in text_counter.most_common():
            if tok not in text_vocab:
                text_vocab[tok] = len(text_vocab)
                
        gloss_vocab = {tok: idx for idx, tok in enumerate(base_tokens)}
        for tok, _ in gloss_counter.most_common():
            if tok not in gloss_vocab:
                gloss_vocab[tok] = len(gloss_vocab)
                
        return text_vocab, gloss_vocab

def tokenize_and_pad(sequence, vocab, max_len, is_gloss):
    if is_gloss:
        tokens = sequence.split("[INST]")[1].split("[/INST]")[0].strip().split()
    else:
        tokens = sequence.split("[/INST]")[1].replace("</s>", "").strip().lower().split()
    
    idxs = [vocab.get(w, vocab["<unk>"]) for w in tokens]
    idxs = [vocab["<sos>"]] + idxs + [vocab["<eos>"]]
    
    if len(idxs) < max_len:
        idxs = idxs + [vocab["<pad>"]] * (max_len - len(idxs))
    else:
        idxs = idxs[: max_len - 1] + [vocab["<eos>"]]

    return torch.tensor(idxs, dtype=torch.long)

In [None]:
df = pd.read_parquet(PARQUET_PATH)

text_vocab, gloss_vocab = build_vocabularies(df)
print(f"> Built text_vocab (size: {len(text_vocab)}) and gloss_vocab (size: {len(gloss_vocab)})")

all_text_tensors  = []
all_gloss_tensors = []

for _, row in tqdm(df.iterrows(), total=len(df), desc="Tokenizing"):
    raw = row["text"]
    gloss_tensor = tokenize_and_pad(raw, gloss_vocab, max_len=MAX_SEQ_LEN, is_gloss=True)
    text_tensor  = tokenize_and_pad(raw, text_vocab,  max_len=MAX_SEQ_LEN, is_gloss=False)

    all_gloss_tensors.append(gloss_tensor.to(device))
    all_text_tensors.append(text_tensor.to(device))

gloss_matrix = torch.stack(all_gloss_tensors)   # (N, MAX_SEQ_LEN)
text_matrix  = torch.stack(all_text_tensors)    # ( N,MAX_SEQ_LEN)

save_dict = {
    "text_vocab": text_vocab,
    "gloss_vocab": gloss_vocab,
    "inv_gloss": {v: k for k, v in gloss_vocab.items()},
    "text_matrix": text_matrix,   # torch.LongTensor
    "gloss_matrix": gloss_matrix, # torch.LongTensor
}
torch.save(save_dict, OUTPUT_PATH)
print(f"> Saved processed data to {OUTPUT_PATH}")


## txtGlossDataset CLass

In [49]:
class TextGlossDataset(Dataset):
    def __init__(self):
        self.text_data = ["hello world", "good morning"]
        self.gloss_data = [["HELLO", "WORLD"], ["GOOD", "MORNING"]]
        self.text_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "hello":3, "world":4, "good":5, "morning":6}
        self.gloss_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "HELLO":3, "WORLD":4, "GOOD":5, "MORNING":6}
        self.inv_gloss = {v:k for k, v in self.gloss_vocab.items()}

    def __len__(self): return len(self.text_data)

    def __getitem__(self, idx):
        text = [self.text_vocab[w] for w in self.text_data[idx].split()]
        gloss = [self.gloss_vocab[g] for g in self.gloss_data[idx]]
        return torch.tensor(text), torch.tensor(gloss)

In [50]:
class TextGlossDataset2(Dataset):
    def __init__(self, parquet_path, max_seq_length=50):
        
        self.df = pd.read_parquet(parquet_path)
        
        self.text_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "<unk>":3}
        self.gloss_vocab = {"<pad>":0, "<sos>": 1, "<eos>": 2, "<unk>":3}
        self.max_seq_length = max_seq_length
        
        self._build_vocabs()
        
        self.inv_gloss = {v: k for k, v in self.gloss_vocab.items()}
        
    def _build_vocabs(self):
        
        text_words = []
        gloss_words = []
        
        for _, row in self.df.iterrows():
            
            gloss = row['text'].split('[INST]')[1].split('[/INST]')[0].strip()
            gloss_words.extend(gloss.split())
            
            text = row['text'].split('[/INST]')[1].replace('</s>', '').strip()
            text_words.extend(text.lower().split())
            
            # build text vocab
            text_counter = Counter(text_words)
            for word, _ in text_counter.most_common():
                if word not in self.text_vocab:
                    self.text_vocab[word] = len(self.text_vocab)
            
            #build gloss vocab
            gloss_counter = Counter(gloss_words)
            for gloss, _ in gloss_counter.most_common():
                if gloss not in self.gloss_vocab:
                    self.gloss_vocab[gloss] = len(self.gloss_vocab)
                    
    def __len__(self):
        return len(self.df)
    
    def _process_sequence(self, sequence, vocab, is_gloss=False):
        
        if is_gloss:
            seq = sequence.split('[INST]')[1].split('[/INST]')[0].strip().split()
        else:
            seq = sequence.split('[/INST]')[1].replace('</s>', '').strip().lower().split()
        
        indices = [vocab.get(word, vocab["<unk>"]) for word in seq]
        indices = [vocab["<sos>"]] + indices + [vocab["<eos>"]]
        
        if len(indices) < self.max_seq_length:
            indices = indices + [vocab["<pad>"]] * (self.max_seq_length - len(indices))
        else:
            indices = indices[:self.max_seq_length-1] + [vocab["<eos>"]]
        
        return torch.tensor(indices).to(device)
    
    def __getitem__(self, idx):
        row = self.df.iloc[idx]['text']
        
        gloss = self._process_sequence(row, self.gloss_vocab, is_gloss=True)
        text = self._process_sequence(row, self.text_vocab, is_gloss=False)
        
        return text, gloss
    
    def decode_gloss(self, indices):
        return ' '.join([self.inv_gloss.get(idx, '<unk>') for idx in indices if idx not in {0, 1, 2}])

# Text-to-Gloss

In [51]:
class Text2GlossTransformer(nn.Module):
    def __init__(self, text_vocab_size, gloss_vocab_size):
        super().__init__()
        self.text_embed = nn.Embedding(text_vocab_size, 256)
        self.gloss_embed = nn.Embedding(gloss_vocab_size, 256)
        self.transformer = nn.Transformer(
            d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3
        ).to(device)
        self.fc = nn.Linear(256, gloss_vocab_size)

    def forward(self, src, tgt):
        src = self.text_embed(src).permute(1,0,2) # S, B, E
        tgt = self.gloss_embed(tgt).permute(1,0,2)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc(output).permute(1,0,2) # B, S, V

# Gloss-to-Pose

In [52]:
class Gloss2Pose(nn.Module):
    def __init__(self, gloss_vocab_size, pose_dim=51): # 17 key points * 3 (x, y, conf)
        super().__init__()
        self.embed = nn.Embedding(gloss_vocab_size, 128)
        self.conv = nn.Sequential(
            nn.Conv1d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(512, pose_dim, 3, padding=1)
        )

    def forward(self, gloss_seq):
        x = self.embed(gloss_seq).permute(0,2,1) # B, C, S
        return self.conv(x).permute(0,2,1) # B, S, D

# Render Pose

In [53]:
def render_pose(pose, connections, frame_size=(512, 512)):
    frame = np.zeros((*frame_size, 3), dtype=np.uint8)
    keypoints = pose.reshape(-1,3)
    keypoints[:, :2] = keypoints[:, :2] * frame_size[0]

    # draw connections
    for i, j in connections:
        if keypoints[i, 2] > 0.2 and keypoints[j, 2] > 0.2: # conf thresh
            cv2.line(
                frame,
                (int(keypoints[i,0]), int(keypoints[i, 1])),
                (int(keypoints[j, 0]), int(keypoints[j, 1])),
                (255, 166, 2), # orange
                2
            )


    # draw points
    for i, (x, y, conf) in enumerate(keypoints):
        if conf > 0.2:
            cv2.circle(frame, (int(x), int(y)), 5, (0, 255, 255), -1) # yellow

    return frame

# Train

In [54]:
train_index = train_index + 1

dataset = TextGlossDataset2('../../datasets/train-00000-of-00001.parquet')

t2g_model = Text2GlossTransformer(
    len(dataset.text_vocab),
    len(dataset.gloss_vocab)
).to(device)

g2p_model = Gloss2Pose(len(dataset.gloss_vocab)).to(device)

loader = DataLoader(dataset, batch_size=2, shuffle=True)
optimizer = torch.optim.Adam(
    list(t2g_model.parameters()) + list(g2p_model.parameters()),
    lr = 1e-4
)

losses = []

for epoch in tqdm.tqdm(range(30)):
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()

        decoder_input = torch.cat([
            torch.ones_like(tgt[:, :1]) * dataset.gloss_vocab["<sos>"],
            tgt[:, :-1]
        ], dim=1)

        gloss_logits = t2g_model(src, decoder_input)
        # gloss_loss = nn.CrossEntropyLoss()(
        #     gloss_logits.view(-1, gloss_logits.size(-1)),
        #     tgt.view(-1)
        # )

        gloss_logits = gloss_logits.contiguous()

        loss = nn.CrossEntropyLoss()(
            gloss_logits.view(-1, gloss_logits.size(-1)),
            tgt.contiguous().view(-1)
        )

        # optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    losses.append(loss.item())
    print(f"Epoch {epoch+1}: Loss={loss.item():.4f}")


  3%|▎         | 1/30 [02:27<1:11:07, 147.15s/it]

Epoch 1: Loss=1.0963


  7%|▋         | 2/30 [04:53<1:08:28, 146.72s/it]

Epoch 2: Loss=1.5963


 10%|█         | 3/30 [07:19<1:05:56, 146.53s/it]

Epoch 3: Loss=1.7186


 13%|█▎        | 4/30 [10:11<1:07:47, 156.44s/it]

Epoch 4: Loss=0.9698


 17%|█▋        | 5/30 [13:11<1:08:41, 164.86s/it]

Epoch 5: Loss=0.8943


 20%|██        | 6/30 [16:12<1:08:11, 170.48s/it]

Epoch 6: Loss=0.8216


 23%|██▎       | 7/30 [19:04<1:05:32, 170.99s/it]

Epoch 7: Loss=0.8037


 27%|██▋       | 8/30 [21:46<1:01:40, 168.21s/it]

Epoch 8: Loss=0.7741


 30%|███       | 9/30 [24:45<1:00:00, 171.44s/it]

Epoch 9: Loss=0.4984


 33%|███▎      | 10/30 [27:42<57:42, 173.11s/it] 

Epoch 10: Loss=0.5031


 37%|███▋      | 11/30 [30:33<54:37, 172.47s/it]

Epoch 11: Loss=0.3680


 40%|████      | 12/30 [33:31<52:13, 174.07s/it]

Epoch 12: Loss=0.3595


 43%|████▎     | 13/30 [36:11<48:07, 169.83s/it]

Epoch 13: Loss=0.1871


 47%|████▋     | 14/30 [38:40<43:36, 163.51s/it]

Epoch 14: Loss=0.1747


 50%|█████     | 15/30 [41:35<41:47, 167.17s/it]

Epoch 15: Loss=0.1260


 53%|█████▎    | 16/30 [44:37<40:01, 171.50s/it]

Epoch 16: Loss=0.1181


 57%|█████▋    | 17/30 [47:42<38:04, 175.73s/it]

Epoch 17: Loss=0.0306


 60%|██████    | 18/30 [50:45<35:33, 177.78s/it]

Epoch 18: Loss=0.1596


 63%|██████▎   | 19/30 [53:46<32:46, 178.78s/it]

Epoch 19: Loss=0.0515


 67%|██████▋   | 20/30 [56:36<29:22, 176.28s/it]

Epoch 20: Loss=0.0323


 70%|███████   | 21/30 [59:08<25:18, 168.72s/it]

Epoch 21: Loss=0.0514


 73%|███████▎  | 22/30 [1:01:42<21:55, 164.49s/it]

Epoch 22: Loss=0.0365


 77%|███████▋  | 23/30 [1:04:43<19:45, 169.42s/it]

Epoch 23: Loss=0.0945


 80%|████████  | 24/30 [1:07:41<17:12, 172.03s/it]

Epoch 24: Loss=0.0284


 83%|████████▎ | 25/30 [1:10:40<14:30, 174.18s/it]

Epoch 25: Loss=0.0170


 87%|████████▋ | 26/30 [1:13:37<11:39, 174.94s/it]

Epoch 26: Loss=0.0494


 90%|█████████ | 27/30 [1:16:12<08:27, 169.02s/it]

Epoch 27: Loss=0.0439


 93%|█████████▎| 28/30 [1:18:45<05:27, 163.99s/it]

Epoch 28: Loss=0.0448


 97%|█████████▋| 29/30 [1:21:22<02:41, 161.98s/it]

Epoch 29: Loss=0.0313


100%|██████████| 30/30 [1:23:56<00:00, 167.88s/it]

Epoch 30: Loss=0.0098





# Save

In [55]:
df = pd.DataFrame({
    'Loss' : losses
})

df.to_csv(f'train{train_index}.csv', index=True)

# Text-to-Sign testing

In [56]:
def text_to_sign(text):
    tokens = [dataset.text_vocab.get(w, 0) for w in text.split()]
    src_tokens = torch.tensor([tokens]).to(device)

    gloss_seq = [dataset.gloss_vocab["<sos>"]]
    for i in tqdm.tqdm(range(20)):
        gloss_decoder_input = torch.tensor([gloss_seq]).to(device)
        with torch.no_grad():
            logits = t2g_model(src_tokens, gloss_decoder_input)
        next_id = logits[0, -1].argmax().item()
        if next_id == dataset.gloss_vocab["<eos>"]:
            break
        gloss_seq.append(next_id)

    glosses = [dataset.inv_gloss[idx] for idx in gloss_seq[1:]]
    print(f"Glosses: {' '.join(glosses)}")

    gloss_tensor = torch.tensor([gloss_seq[1:]]).to(device)  # Exclude <sos>
    with torch.no_grad():
        poses = g2p_model(gloss_tensor).cpu().numpy()[0]  # [S, D]

    connections = [
        (0,1), (0,2), (1,3), (2,4),        # Head
        (5,6), (5,7), (7,9), (6,8), (8,10), # Arms
        (11,12), (11,13), (13,15), (12,14), (14,16) # Legs
    ]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(f'sign_output{train_index}.mp4', fourcc, 5.0, (512, 512))

    for pose in poses:
        pose_norm = (pose - np.min(pose)) / (np.max(pose) - np.min(pose) + 1e-8)
        frame = render_pose(pose_norm, connections)
        out.write(frame)

    out.release()
    return glosses, f'sign_output{train_index}.mp4'


In [57]:
text_to_sign("hello world")

100%|██████████| 20/20 [00:00<00:00, 48.68it/s]

Glosses: <sos> SON RETURN HOME, HOLD ROAD, THEN NOW GOURD LOWER VILLAGE, THEN NOW NOT THEN NOW GOURD FILL, NO LEAK.





(['<sos>',
  'SON',
  'RETURN',
  'HOME,',
  'HOLD',
  'ROAD,',
  'THEN',
  'NOW',
  'GOURD',
  'LOWER',
  'VILLAGE,',
  'THEN',
  'NOW',
  'NOT',
  'THEN',
  'NOW',
  'GOURD',
  'FILL,',
  'NO',
  'LEAK.'],
 'sign_output739415.mp4')