# Imports

In [31]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import numpy as np
import cv2

In [32]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"device: {device}")

device: cuda


In [33]:
train_index = 6

In [34]:
import tqdm
import pandas as pd

# text2gloss

In [35]:
class TextGlossDataset(Dataset):
    def __init__(self):
        self.text_data = ["hello world", "good morning"]
        self.gloss_data = [["HELLO", "WORLD"], ["GOOD", "MORNING"]]
        self.text_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "hello":3, "world":4, "good":5, "morning":6}
        self.gloss_vocab = {"<pad>":0, "<sos>":1, "<eos>":2, "HELLO":3, "WORLD":4, "GOOD":5, "MORNING":6}
        self.inv_gloss = {v:k for k, v in self.gloss_vocab.items()}

    def __len__(self): return len(self.text_data)

    def __getitem__(self, idx):
        text = [self.text_vocab[w] for w in self.text_data[idx].split()]
        gloss = [self.gloss_vocab[g] for g in self.gloss_data[idx]]
        return torch.tensor(text), torch.tensor(gloss)

In [36]:
class Text2GlossTransformer(nn.Module):
    def __init__(self, text_vocab_size, gloss_vocab_size):
        super().__init__()
        self.text_embed = nn.Embedding(text_vocab_size, 256)
        self.gloss_embed = nn.Embedding(gloss_vocab_size, 256)
        self.transformer = nn.Transformer(
            d_model=256, nhead=8, num_encoder_layers=3, num_decoder_layers=3
        ).to(device)
        self.fc = nn.Linear(256, gloss_vocab_size)

    def forward(self, src, tgt):
        src = self.text_embed(src).permute(1,0,2) # S, B, E
        tgt = self.gloss_embed(tgt).permute(1,0,2)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0)).to(device)
        output = self.transformer(src, tgt, tgt_mask=tgt_mask)
        return self.fc(output).permute(1,0,2) # B, S, V

In [37]:
class Gloss2Pose(nn.Module):
    def __init__(self, gloss_vocab_size, pose_dim=51): # 17 key points * 3 (x, y, conf)
        super().__init__()
        self.embed = nn.Embedding(gloss_vocab_size, 128)
        self.conv = nn.Sequential(
            nn.Conv1d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Conv1d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv1d(512, pose_dim, 3, padding=1)
        )

    def forward(self, gloss_seq):
        x = self.embed(gloss_seq).permute(0,2,1) # B, C, S
        return self.conv(x).permute(0,2,1) # B, S, D

In [38]:
def render_pose(pose, connections, frame_size=(512, 512)):
    frame = np.zeros((*frame_size, 3), dtype=np.uint8)
    keypoints = pose.reshape(-1,3)
    keypoints[:, :2] = keypoints[:, :2] * frame_size[0]

    # draw connections
    for i, j in connections:
        if keypoints[i, 2] > 0.2 and keypoints[j, 2] > 0.2: # conf thresh
            cv2.line(
                frame,
                (int(keypoints[i,0]), int(keypoints[i, 1])),
                (int(keypoints[j, 0]), int(keypoints[j, 1])),
                (255, 166, 2), # orange
                2
            )


    # draw points
    for i, (x, y, conf) in enumerate(keypoints):
        if conf > 0.2:
            cv2.circle(frame, (int(x), int(y)), 5, (0, 255, 255), -1) # yellow

    return frame

In [39]:
train_index = train_index + 1

dataset = TextGlossDataset()

t2g_model = Text2GlossTransformer(
    len(dataset.text_vocab),
    len(dataset.gloss_vocab)
).to(device)

g2p_model = Gloss2Pose(len(dataset.gloss_vocab)).to(device)

loader = DataLoader(dataset, batch_size=2, shuffle=True)
optimizer = torch.optim.Adam(
    list(t2g_model.parameters()) + list(g2p_model.parameters()),
    lr = 1e-4
)

losses = []

for epoch in tqdm.tqdm(range(50)):
    for src, tgt in loader:
        src, tgt = src.to(device), tgt.to(device)

        optimizer.zero_grad()

        decoder_input = torch.cat([
            torch.ones_like(tgt[:, :1]) * dataset.gloss_vocab["<sos>"],
            tgt[:, :-1]
        ], dim=1)

        gloss_logits = t2g_model(src, decoder_input)
        # gloss_loss = nn.CrossEntropyLoss()(
        #     gloss_logits.view(-1, gloss_logits.size(-1)),
        #     tgt.view(-1)
        # )

        gloss_logits = gloss_logits.contiguous()

        loss = nn.CrossEntropyLoss()(
            gloss_logits.view(-1, gloss_logits.size(-1)),
            tgt.contiguous().view(-1)
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    losses.append(loss.item())
    print(f"Epoch {epoch+1}: Loss={loss.item():.4f}")


  6%|▌         | 3/50 [00:00<00:03, 12.23it/s]

Epoch 1: Loss=2.1472
Epoch 2: Loss=1.3307
Epoch 3: Loss=0.9652
Epoch 4: Loss=0.6635


 14%|█▍        | 7/50 [00:00<00:02, 14.80it/s]

Epoch 5: Loss=0.7357
Epoch 6: Loss=0.6038
Epoch 7: Loss=0.5030
Epoch 8: Loss=0.3529


 22%|██▏       | 11/50 [00:00<00:02, 15.06it/s]

Epoch 9: Loss=0.3299
Epoch 10: Loss=0.3349
Epoch 11: Loss=0.2360
Epoch 12: Loss=0.2328


 30%|███       | 15/50 [00:01<00:02, 16.28it/s]

Epoch 13: Loss=0.1673
Epoch 14: Loss=0.0940
Epoch 15: Loss=0.1292
Epoch 16: Loss=0.0900


 38%|███▊      | 19/50 [00:01<00:01, 16.29it/s]

Epoch 17: Loss=0.1401
Epoch 18: Loss=0.0979
Epoch 19: Loss=0.0631
Epoch 20: Loss=0.0542


 46%|████▌     | 23/50 [00:01<00:01, 16.13it/s]

Epoch 21: Loss=0.0345
Epoch 22: Loss=0.0336
Epoch 23: Loss=0.0308
Epoch 24: Loss=0.0255


 54%|█████▍    | 27/50 [00:01<00:01, 16.43it/s]

Epoch 25: Loss=0.0189
Epoch 26: Loss=0.0202
Epoch 27: Loss=0.0169
Epoch 28: Loss=0.0157


 62%|██████▏   | 31/50 [00:02<00:01, 15.88it/s]

Epoch 29: Loss=0.0128
Epoch 30: Loss=0.0183
Epoch 31: Loss=0.0171
Epoch 32: Loss=0.0113


 70%|███████   | 35/50 [00:02<00:00, 15.02it/s]

Epoch 33: Loss=0.0153
Epoch 34: Loss=0.0153
Epoch 35: Loss=0.0119


 74%|███████▍  | 37/50 [00:02<00:00, 14.39it/s]

Epoch 36: Loss=0.0110
Epoch 37: Loss=0.0101
Epoch 38: Loss=0.0100


 82%|████████▏ | 41/50 [00:02<00:00, 13.85it/s]

Epoch 39: Loss=0.0078
Epoch 40: Loss=0.0091
Epoch 41: Loss=0.0071
Epoch 42: Loss=0.0077


 90%|█████████ | 45/50 [00:03<00:00, 14.52it/s]

Epoch 43: Loss=0.0071
Epoch 44: Loss=0.0065
Epoch 45: Loss=0.0055
Epoch 46: Loss=0.0051


100%|██████████| 50/50 [00:03<00:00, 15.15it/s]

Epoch 47: Loss=0.0059
Epoch 48: Loss=0.0075
Epoch 49: Loss=0.0060
Epoch 50: Loss=0.0064





In [40]:
df = pd.DataFrame({
    'Loss' : losses
})

df.to_csv(f'train{train_index}.csv', index=True)

In [41]:
def text_to_sign(text):
    tokens = [dataset.text_vocab.get(w, 0) for w in text.split()]
    src_tokens = torch.tensor([tokens]).to(device)

    gloss_seq = [dataset.gloss_vocab["<sos>"]]
    for i in tqdm.tqdm(range(20)):
        gloss_decoder_input = torch.tensor([gloss_seq]).to(device)
        with torch.no_grad():
            logits = t2g_model(src_tokens, gloss_decoder_input)
        next_id = logits[0, -1].argmax().item()
        if next_id == dataset.gloss_vocab["<eos>"]:
            break
        gloss_seq.append(next_id)

    glosses = [dataset.inv_gloss[idx] for idx in gloss_seq[1:]]
    print(f"Glosses: {' '.join(glosses)}")

    gloss_tensor = torch.tensor([gloss_seq[1:]]).to(device)  # Exclude <sos>
    with torch.no_grad():
        poses = g2p_model(gloss_tensor).cpu().numpy()[0]  # [S, D]

    connections = [
        (0,1), (0,2), (1,3), (2,4),        # Head
        (5,6), (5,7), (7,9), (6,8), (8,10), # Arms
        (11,12), (11,13), (13,15), (12,14), (14,16) # Legs
    ]

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(f'sign_output{train_index}.mp4', fourcc, 5.0, (512, 512))

    for pose in poses:
        pose_norm = (pose - np.min(pose)) / (np.max(pose) - np.min(pose) + 1e-8)
        frame = render_pose(pose_norm, connections)
        out.write(frame)

    out.release()
    return glosses, f'sign_output{train_index}.mp4'


In [42]:
text_to_sign("hello world")

100%|██████████| 20/20 [00:00<00:00, 70.92it/s]


Glosses: HELLO WORLD HELLO WORLD WORLD WORLD WORLD WORLD WORLD WORLD WORLD WORLD WORLD WORLD WORLD WORLD HELLO WORLD WORLD WORLD


(['HELLO',
  'WORLD',
  'HELLO',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'WORLD',
  'HELLO',
  'WORLD',
  'WORLD',
  'WORLD'],
 'sign_output7.mp4')