In [None]:
!pip -q install torch torchvision torchaudio matplotlib numpy

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m96.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m96.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m58.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m15.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os, re, glob, math, random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# =========================================================
# 1) Config
# =========================================================
# Folder of per-clip .npy files named "<label>_<split>..." (parsed in section 4).
DATA_DIR = "/content/data_npy"

# Sequence geometry
SEQ_LEN = 60      # frames per clip after resampling (2 s at 30 fps)
F = 63            # 21 landmarks x (x, y, z) for one hand

# Optimisation
BATCH_SIZE = 16
EPOCHS = 60
LR = 1e-3

# Model size
HIDDEN = 256      # LSTM hidden width
NUM_LAYERS = 2    # stacked LSTM layers
EMB_DIM = 64      # label-embedding width
TIME_DIM = 32     # sinusoidal time-embedding width

SEED = 42

# Seed every RNG the notebook draws from (the augmentations use numpy's global RNG).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Device:", DEVICE)

Device: cuda


In [None]:
# =========================================================
# 2) Utilities: resample, masks, normalization
# =========================================================
def temporal_resample(seq, target_len):
    """Linearly resample a (T, F) sequence to ``target_len`` frames.

    Output frame i samples the input at fractional position
    i * (T - 1) / (target_len - 1), blending the two neighbouring input
    frames. When T already equals ``target_len`` the input object itself is
    returned.
    """
    n_frames = seq.shape[0]
    if n_frames == target_len:
        return seq
    positions = np.linspace(0, n_frames - 1, num=target_len)
    left = np.floor(positions).astype(int)
    right = np.minimum(left + 1, n_frames - 1)   # last frame has no right neighbour
    frac = positions[:, None] - left[:, None]     # blend weight of the right frame
    return (1 - frac) * seq[left] + frac * seq[right]

def make_frame_mask(seq):
    """Return a float32 (T,) validity mask for a (T, F) sequence.

    A frame is valid (1.0) when the sum of the absolute values of its
    coordinates is positive, i.e. some coordinate is nonzero; all-zero
    frames (and frames whose sum is NaN) get 0.0.
    """
    frame_energy = np.abs(seq).sum(axis=1)
    is_valid = frame_energy > 0
    return is_valid.astype(np.float32)

def per_frame_handscale(seq):
    """Per-frame hand-size proxy: distance wrist (landmark 0) <-> index_mcp (landmark 5).

    ``seq`` is reshaped to (T, 21, 3). Frames whose distance is <= 1e-6
    (e.g. all-zero missing frames) receive the median of the strictly
    positive distances; if no distance is positive, every entry becomes 1.0.
    Returns a new (T,) array.
    """
    joints = seq.reshape(-1, 21, 3)
    dist = np.linalg.norm(joints[:, 0, :] - joints[:, 5, :], axis=1)
    positive = dist > 0
    if not np.any(positive):
        dist[:] = 1.0
        return dist
    dist[dist <= 1e-6] = np.median(dist[positive])
    return dist

def scale_normalize(seq):
    """Divide each frame of a (T, 63) sequence by its hand-size proxy.

    Scales come from ``per_frame_handscale``. Any scale still <= 1e-6 is
    replaced by the clip's median positive scale (1.0 if none is positive),
    guarding the division against ~0 denominators.
    """
    scales = per_frame_handscale(seq)
    positive = scales > 0
    if np.any(positive):
        fallback = np.median(scales[positive])
    else:
        fallback = 1.0
    scales[scales <= 1e-6] = fallback
    return seq / scales[:, None]

In [None]:
# =========================================================
# 3) (Light) Augmentations for training
# =========================================================
def add_noise(x, std=0.002):
    """Return ``x`` plus i.i.d. Gaussian noise N(0, std**2) of the same shape.

    The input array is not modified; draws come from numpy's global RNG.
    """
    noise = np.random.normal(0, std, x.shape)
    return x + noise

def time_shift(x, max_shift=5, use_zeros=False):
    """Shift a (T, F) clip in time by a random offset, keeping length T.

    The offset is uniform over the integers in [-max_shift, max_shift].
    A positive offset delays the clip (content moves later), a negative one
    advances it. Vacated frames are zero-filled when ``use_zeros`` is True,
    otherwise filled by repeating the first (delay) or last (advance) frame.
    Offset 0 returns ``x`` itself.
    """
    n_frames = x.shape[0]
    offset = np.random.randint(-max_shift, max_shift + 1)
    if offset == 0:
        return x

    if use_zeros:
        widths = ((offset, 0), (0, 0)) if offset > 0 else ((0, -offset), (0, 0))
        padded = np.pad(x, widths, mode='constant')
    elif offset > 0:
        padded = np.concatenate([np.repeat(x[:1], offset, axis=0), x], axis=0)
    else:
        padded = np.concatenate([x, np.repeat(x[-1:], -offset, axis=0)], axis=0)

    # Keep the window that preserves the original length.
    return padded[:n_frames] if offset > 0 else padded[-n_frames:]

def random_scale(x, min_scale=0.9, max_scale=1.1):
    """Multiply the whole array by one factor drawn uniformly from [min_scale, max_scale)."""
    factor = np.random.uniform(min_scale, max_scale)
    return x * factor

def temporal_dropout(x, drop_prob=0.05, mode="hold"):
    """Randomly drop whole frames of a (T, F) clip and fill the gaps.

    Each frame is independently kept when its uniform draw exceeds
    ``drop_prob`` (one draw per frame, always consumed). Gap filling:
      * "zeros":  dropped frames become all-zero.
      * "hold":   dropped frames repeat the most recent kept frame; leading
                  dropped frames repeat frame 0.
      * "interp": per-feature linear interpolation over the kept frames
                  (``x`` returned unchanged if fewer than 2 are kept).
    If no frame is dropped, or ``mode`` is unrecognised, ``x`` itself is
    returned. Filled results are copies with ``x``'s dtype.
    """
    n_frames, n_feats = x.shape
    keep = np.random.rand(n_frames) > drop_prob
    if keep.all():
        return x

    if mode == "zeros":
        filled = x.copy()
        filled[~keep] = 0
        return filled

    if mode == "hold":
        filled = x.copy()
        carried = filled[0]
        for t, kept in enumerate(keep):
            if kept:
                carried = filled[t]
            else:
                filled[t] = carried
        return filled

    if mode == "interp":
        kept_idx = np.flatnonzero(keep)
        if len(kept_idx) < 2:
            return x
        grid = np.arange(n_frames)
        filled = x.copy()
        for col in range(n_feats):
            filled[:, col] = np.interp(grid, kept_idx, x[keep, col])
        return filled

    return x

def rotate_z(x, max_angle=10):
    """Rotate every landmark of a (T, 63) clip about the z axis.

    One angle, uniform in [-max_angle, max_angle] degrees, is drawn per call
    and applied to all frames. The rotation is about the coordinate origin
    (no centering); z values are unchanged. Output shape equals input shape.
    """
    n_frames, n_feats = x.shape
    theta = np.radians(np.random.uniform(-max_angle, max_angle))
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t, 0],
                         [sin_t,  cos_t, 0],
                         [0,      0,     1]], dtype=x.dtype)
    rotated = x.reshape(n_frames, 21, 3) @ rotation.T
    return rotated.reshape(n_frames, n_feats)

def augment_sample_train(x):
    """Apply the light training-time augmentation chain to one (T, 63) clip.

    Order: Gaussian noise -> time shift (edge hold) -> global scale ->
    temporal dropout (hold) -> small rotation about the z axis.

    Frames that are entirely zero mark missing detections (the same rule
    ``make_frame_mask`` uses). They are kept exactly zero through the noise
    step: the loader computes the frame mask *after* augmentation, so noisy
    "missing" frames would otherwise all be reported as valid and the masked
    loss would be trained on pure noise.

    NOTE(review): ``random_scale`` is largely undone by ``scale_normalize``,
    which the loader applies next and which divides each frame by a
    wrist-to-landmark-5 distance that scales by the same factor.
    """
    # Record missing frames before the noise makes them nonzero.
    missing = ~(np.abs(x).sum(axis=1) > 0)
    x = np.where(missing[:, None], 0.0, add_noise(x, std=0.002))
    x = time_shift(x, max_shift=5, use_zeros=False)
    x = random_scale(x, 0.9, 1.1)
    x = temporal_dropout(x, drop_prob=0.05, mode="hold")
    x = rotate_z(x, max_angle=10)
    return x

def augment_sample_val(x):
    """Validation-time counterpart of augment_sample_train: the identity.

    Validation clips stay clean; the very same object is returned.
    """
    clean = x
    return clean

In [None]:
# =========================================================
# 4) Load data with labels
# =========================================================
SPLIT_RE = re.compile(r'^(train|val)\d*$', re.IGNORECASE)  # e.g. "train", "val2", "TRAIN1"


def _parse_clip_name(fname):
    """Split '<label>_<split>[_...].npy' into (label, split); None if there is no split part."""
    parts = fname[:-4].split("_")
    if len(parts) < 2:
        return None
    return parts[0], parts[1]


def _load_resampled(data_dir, fname, seq_len):
    """Load one clip as float32 and linearly resample it to ``seq_len`` frames."""
    arr = np.load(os.path.join(data_dir, fname)).astype(np.float32)
    return temporal_resample(arr, seq_len)


def _stack_split(X, M, y, seq_len, feat_dim):
    """Stack one split's per-clip lists into arrays.

    An empty split yields correctly-shaped empty arrays instead of letting
    ``np.stack([])`` raise.
    """
    if not X:
        return (np.empty((0, seq_len, feat_dim), dtype=np.float32),
                np.empty((0, seq_len), dtype=np.float32),
                np.empty((0,), dtype=np.int64))
    return np.stack(X), np.stack(M), np.array(y)


def load_data_with_labels(data_dir, seq_len=SEQ_LEN):
    """Load, normalize and standardize every labelled clip in ``data_dir``.

    Filenames must look like ``<label>_<split>[_anything].npy`` where
    ``<split>`` matches SPLIT_RE (``train``/``val`` plus optional digits,
    case-insensitive). Files without ``_`` are ignored; files whose split is
    not recognized are reported and skipped.

    Per clip: load as float32 -> resample to ``seq_len`` frames ->
    (train only) augment -> frame mask -> per-frame hand-scale
    normalization -> standardize with the train mean/std. The mean/std come
    from the valid (nonzero) frames of the *un-augmented*, scale-normalized
    train clips.

    Returns:
        (Xtr, Mtr, ytr, fntr, Xva, Mva, yva, fnva, label_map, mean, std)
        X*: (N, seq_len, F) standardized clips; M*: (N, seq_len) frame masks;
        y*: (N,) label ids; fn*: source filenames; label_map: OrderedDict
        label -> id (ids in sorted-filename order of first appearance);
        mean/std: (1, F). An empty split yields empty arrays with the same
        trailing shapes.

    Raises:
        RuntimeError: no .npy files, or no valid train frames for the stats.
    """
    from collections import OrderedDict

    # Sorted once so label ids, statistics and output order are deterministic.
    files = sorted(f for f in os.listdir(data_dir) if f.endswith(".npy"))
    if not files:
        raise RuntimeError("No .npy files found in data_dir")

    # Build label map from filename prefix. (Files with an unrecognized split
    # still register their label here, as before.)
    label_map = OrderedDict()
    for f in files:
        parsed = _parse_clip_name(f)
        if parsed is not None and parsed[0] not in label_map:
            label_map[parsed[0]] = len(label_map)

    # Pass 1: mean/std from valid TRAIN frames (after scale normalization).
    # Validation clips are not needed for the statistics, so skip them early.
    tmp_train = []
    for f in files:
        parsed = _parse_clip_name(f)
        if parsed is None:
            continue
        split_raw = parsed[1]
        if not (SPLIT_RE.match(split_raw) and split_raw.lower().startswith("train")):
            continue
        arr = _load_resampled(data_dir, f, seq_len)
        m = make_frame_mask(arr)
        if not m.sum() > 0:
            continue
        arr = scale_normalize(arr)
        tmp_train.append(arr[m > 0.5])
    if not tmp_train:
        raise RuntimeError("No valid train frames to compute mean/std.")
    cat = np.vstack(tmp_train)
    mean = cat.mean(axis=0, keepdims=True)
    std = cat.std(axis=0, keepdims=True) + 1e-6  # epsilon guards constant features

    # Pass 2: load all, augment train (before mask/normalize), standardize.
    Xtr, Mtr, ytr, fntr = [], [], [], []
    Xva, Mva, yva, fnva = [], [], [], []
    for f in files:
        parsed = _parse_clip_name(f)
        if parsed is None:
            continue
        label, split_raw = parsed
        if not SPLIT_RE.match(split_raw):
            print(f"Skipping {f}: split not recognized")
            continue
        is_train = split_raw.lower().startswith("train")

        x = _load_resampled(data_dir, f, seq_len)
        # NOTE(review): augmentation runs once here, at load time, so each
        # training clip carries one fixed perturbation for every epoch.
        x = augment_sample_train(x) if is_train else augment_sample_val(x)

        m = make_frame_mask(x)
        x = scale_normalize(x)
        x = (x - mean) / std

        if is_train:
            Xtr.append(x); Mtr.append(m); ytr.append(label_map[label]); fntr.append(f)
        else:
            Xva.append(x); Mva.append(m); yva.append(label_map[label]); fnva.append(f)

    feat_dim = mean.shape[1]
    Xtr_arr, Mtr_arr, ytr_arr = _stack_split(Xtr, Mtr, ytr, seq_len, feat_dim)
    Xva_arr, Mva_arr, yva_arr = _stack_split(Xva, Mva, yva, seq_len, feat_dim)
    return (Xtr_arr, Mtr_arr, ytr_arr, fntr,
            Xva_arr, Mva_arr, yva_arr, fnva,
            label_map, mean, std)

# ---- Load now
# Load every clip once. X*: standardized sequences, M*: per-frame validity
# masks, y*: integer labels, fn*: source filenames; also returns the label map
# and the train mean/std that later cells use to un-standardize for plotting.
(Xtr, Mtr, ytr, fntr,
 Xva, Mva, yva, fnva,
 LABEL_MAP, TRAIN_MEAN, TRAIN_STD) = load_data_with_labels(DATA_DIR, SEQ_LEN)

print(f"Classes: {list(LABEL_MAP.keys())}")
print("Train:", Xtr.shape, Mtr.shape, ytr.shape)
print("Val:  ", Xva.shape, Mva.shape, yva.shape)

Classes: ['find', 'fuck', 'hello', 'home', 'me', 'need', 'no', 'smart', 'thanks', 'yes']
Train: (70, 60, 63) (70, 60) (70,)
Val:   (20, 60, 63) (20, 60) (20,)


In [None]:
# =========================================================
# 5) Torch datasets
# =========================================================
class SeqLabelDataset(Dataset):
    """Map-style dataset of (sequence, mask, label, filename) tuples.

    Arrays are converted to tensors once, up front: ``X`` and ``M`` to
    float32, ``y`` to int64. Item ``i`` is ``(X[i], M[i], y[i], fnames[i])``.
    """

    def __init__(self, X, M, y, fnames):
        self.X, self.M = (torch.from_numpy(arr).float() for arr in (X, M))
        self.y = torch.from_numpy(y).long()
        self.fnames = fnames

    def __len__(self):
        return self.X.size(0)

    def __getitem__(self, i):
        return (self.X[i], self.M[i], self.y[i], self.fnames[i])

# Wrap the loaded arrays; the train loader reshuffles every epoch while the
# validation order is fixed. drop_last=False keeps the final partial batch.
train_ds = SeqLabelDataset(Xtr, Mtr, ytr, fntr)
val_ds   = SeqLabelDataset(Xva, Mva, yva, fnva)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

In [None]:
# =========================================================
# 6) Model: Label-conditioned autoregressive LSTM
# =========================================================
class TimeEmbedding(nn.Module):
    """Fixed sinusoidal embedding of the frame index t (Transformer-style).

    Row t of the table is [sin(t*w_0), cos(t*w_0), sin(t*w_1), cos(t*w_1), ...]
    with w_k = 10000 ** (-2k / d_model); an odd ``d_model`` ends on a sine
    column. The table is a registered buffer named ``pe`` (saved in the
    state_dict and moved by ``.to(device)``).
    """

    def __init__(self, d_model, max_len=1024):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)/d_model))
        angles = pos * div                           # (max_len, ceil(d_model/2))
        pe[:, 0::2] = torch.sin(angles)
        # Odd d_model has one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]
        self.register_buffer("pe", pe)  # (max_len, d_model)

    def forward(self, T):
        """Return the embeddings for t = 0..T-1, shape (T, d_model).

        Raises:
            ValueError: if T exceeds the table length (slicing would
                otherwise silently return fewer than T rows).
        """
        max_len = self.pe.shape[0]
        if T > max_len:
            raise ValueError(f"T={T} exceeds TimeEmbedding max_len={max_len}")
        return self.pe[:T]  # (T,d_model)

class CondLSTMGenerator(nn.Module):
    """
    Label-conditioned autoregressive LSTM over (standardized) keypoint frames.

    Each step's input is [prev_frame, label_emb, time_emb]; it is projected to
    ``hidden`` by a linear layer, run through a stacked LSTM, and projected
    back to an F-dimensional frame.

    Training (teacher-forcing): prev_frame = ground-truth previous frame
        (zeros at t=0) -> predict the current frame.
    Inference (generate): feed zeros (or a seed) as prev at t=0, then roll
        out T steps feeding each prediction back in as the next prev.
    """
    def __init__(self, F=63, num_classes=10, emb_dim=64, time_dim=32, hidden=256, layers=2):
        super().__init__()
        self.F = F
        self.label_emb = nn.Embedding(num_classes, emb_dim)
        self.time_emb  = TimeEmbedding(time_dim, max_len=2048)
        self.in_lin    = nn.Linear(F + emb_dim + time_dim, hidden)
        self.lstm      = nn.LSTM(hidden, hidden, num_layers=layers, batch_first=True)
        self.out_lin   = nn.Linear(hidden, F)

    def forward(self, x, y_labels):
        """Standard nn.Module entry point, so ``model(x, y)`` works.

        Identical to :meth:`forward_teacher_forcing`.
        """
        return self.forward_teacher_forcing(x, y_labels)

    def forward_teacher_forcing(self, x, y_labels):
        """
        x: (B,T,F) ground truth standardized sequence
        y_labels: (B,) label indices
        returns preds: (B,T,F), where preds[:, t] is predicted from x[:, :t]
        """
        B,T,F = x.shape
        lab = self.label_emb(y_labels)           # (B,emb)
        lab = lab[:,None,:].expand(B,T,-1)       # (B,T,emb)
        te  = self.time_emb(T)[None,:,:].expand(B,-1,-1)  # (B,T,time)

        # teacher forcing: prev = shifted gt, first prev = zeros
        x_prev = torch.zeros_like(x)
        x_prev[:,1:,:] = x[:,:-1,:]

        h_in = torch.cat([x_prev, lab, te], dim=-1)  # (B,T,F+emb+time)
        h = self.in_lin(h_in)
        h,_ = self.lstm(h)
        pred = self.out_lin(h)                        # (B,T,F)
        return pred

    @torch.no_grad()
    def generate(self, y_labels, T, device, start_frame=None):
        """
        Autoregressively roll out T frames per label.

        y_labels: (B,) label indices; they index ``self.label_emb``, so they
            must be on the same device as the model's parameters.
        T: int length
        start_frame: optional seed used as the t=0 "previous frame", shape
            (B,F), or (F,) to use the same seed for every batch element;
            else zeros. It is fed straight into the network, so it is
            presumably expected in the standardized space — confirm with
            the caller.
        returns: (B,T,F)
        """
        B = y_labels.shape[0]
        lab = self.label_emb(y_labels).to(device)       # (B,emb)
        te  = self.time_emb(T).to(device)               # (T,time)
        state = None
        if start_frame is None:
            prev = torch.zeros(B, 1, self.F, device=device)
        else:
            start_frame = start_frame.to(device)
            if start_frame.dim() == 1:                  # (F,) -> (B,F)
                start_frame = start_frame.unsqueeze(0).expand(B, -1)
            prev = start_frame[:, None, :]              # (B,1,F)

        outs = []
        for t in range(T):
            lab_t = lab[:,None,:]                      # (B,1,emb)
            te_t  = te[t].expand(B,1,-1)               # (B,1,time)
            h_in  = torch.cat([prev, lab_t, te_t], dim=-1)  # (B,1,F+emb+time)
            h0    = self.in_lin(h_in)
            h1, state = self.lstm(h0, state)           # carry LSTM state across steps
            out   = self.out_lin(h1)                   # (B,1,F)
            outs.append(out)
            prev = out                                 # autoregressive
        return torch.cat(outs, dim=1)                  # (B,T,F)

def masked_mse(pred, target, mask):
    """Masked mean-squared error for (B, T, F) predictions.

    The squared error is averaged over features per frame, multiplied by the
    (B, T) frame mask, averaged over each sequence's masked frames (with a
    1e-6 epsilon in the count), and finally averaged over the batch.
    A sequence with an all-zero mask contributes 0 to the batch mean.
    """
    per_frame = ((pred - target) ** 2).mean(dim=-1)   # (B, T)
    masked_sum = (per_frame * mask).sum(dim=1)        # (B,)
    frame_count = mask.sum(dim=1) + 1e-6              # (B,)
    return (masked_sum / frame_count).mean()

# Instantiate model
# One label embedding per class found by the loader; sizes come from the
# config cell. Plain Adam at a fixed LR (no scheduler is created here).
num_classes = len(LABEL_MAP)
model = CondLSTMGenerator(F=F, num_classes=num_classes, emb_dim=EMB_DIM, time_dim=TIME_DIM,
                          hidden=HIDDEN, layers=NUM_LAYERS).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)


In [None]:
# =========================================================
# 7) Training loop (teacher forcing)
# =========================================================
def run_epoch(loader, train=True):
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    Uses the module-level ``model``, ``optimizer`` and ``DEVICE``. The loss is
    the masked MSE between the teacher-forced prediction and the input clip.
    With ``train=True`` the model is in train mode and parameters are updated
    (gradient norm clipped to 1.0). Otherwise the model is in eval mode and
    autograd is disabled, so validation does not build a graph it never uses.

    Returns 0.0 for an empty loader.
    """
    model.train(train)
    total, n = 0.0, 0
    with torch.set_grad_enabled(train):
        for x, m, y, _ in loader:
            x = x.to(DEVICE); m = m.to(DEVICE); y = y.to(DEVICE)
            pred = model.forward_teacher_forcing(x, y)
            loss = masked_mse(pred, x, m)
            if train:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
            # Weight by batch size so a smaller last batch is not over-counted.
            total += loss.item() * x.size(0); n += x.size(0)
    return total / max(1, n)

# Train for EPOCHS epochs; after each, evaluate on the validation loader and
# overwrite cond_gen.pt whenever the validation loss improves (best-so-far).
# Both losses are teacher-forced (one-step-ahead), not free-running generation.
# NOTE(review): with an empty val_loader run_epoch returns 0.0, so only the
# epoch-1 weights would ever be saved.
best_val = float("inf")
for epoch in range(1, EPOCHS+1):
    tr = run_epoch(train_loader, True)
    va = run_epoch(val_loader, False)
    print(f"Epoch {epoch:03d} | train {tr:.5f} | val {va:.5f}")
    if va < best_val:
        best_val = va
        # Self-contained checkpoint: weights + standardization stats +
        # label map + architecture config (same keys save_checkpoint writes).
        torch.save({
            "model_state": model.state_dict(),
            "mean": TRAIN_MEAN,
            "std":  TRAIN_STD,
            "label_map": LABEL_MAP,
            "config": dict(F=F, SEQ_LEN=SEQ_LEN, HIDDEN=HIDDEN, NUM_LAYERS=NUM_LAYERS,
                           EMB_DIM=EMB_DIM, TIME_DIM=TIME_DIM)
        }, "cond_gen.pt")
        print("  ↳ saved cond_gen.pt")


Epoch 001 | train 0.96582 | val 0.77291
  ↳ saved cond_gen.pt
Epoch 002 | train 0.61597 | val 0.47398
  ↳ saved cond_gen.pt
Epoch 003 | train 0.46940 | val 0.37039
  ↳ saved cond_gen.pt
Epoch 004 | train 0.37960 | val 0.32178
  ↳ saved cond_gen.pt
Epoch 005 | train 0.32477 | val 0.27312
  ↳ saved cond_gen.pt
Epoch 006 | train 0.27776 | val 0.24335
  ↳ saved cond_gen.pt
Epoch 007 | train 0.24320 | val 0.22809
  ↳ saved cond_gen.pt
Epoch 008 | train 0.21918 | val 0.20538
  ↳ saved cond_gen.pt
Epoch 009 | train 0.19882 | val 0.18622
  ↳ saved cond_gen.pt
Epoch 010 | train 0.18312 | val 0.17671
  ↳ saved cond_gen.pt
Epoch 011 | train 0.16901 | val 0.16218
  ↳ saved cond_gen.pt
Epoch 012 | train 0.15653 | val 0.15251
  ↳ saved cond_gen.pt
Epoch 013 | train 0.15060 | val 0.15236
  ↳ saved cond_gen.pt
Epoch 014 | train 0.14913 | val 0.13865
  ↳ saved cond_gen.pt
Epoch 015 | train 0.13979 | val 0.12966
  ↳ saved cond_gen.pt
Epoch 016 | train 0.13653 | val 0.12522
  ↳ saved cond_gen.pt
Epoch 01

In [None]:
# =========================================================
# 8) Stick-figure animation helpers
# =========================================================
# Stick-figure edges over the 21 hand landmarks: the wrist (0) links to the
# first landmark of each 4-landmark chain (1-4, 5-8, 9-12, 13-16, 17-20), and
# consecutive landmarks within a chain are linked.
HAND_CONNECTIONS = [
    edge
    for base in range(1, 21, 4)
    for edge in ((0, base), (base, base + 1), (base + 1, base + 2), (base + 2, base + 3))
]

def unstandardize(seq_norm, mean, std):
    """Invert ``(x - mean) / std``: map standardized values back to the original scale."""
    rescaled = seq_norm * std
    return rescaled + mean

def _bounds_from_seq(seq_unstd):
    coords = seq_unstd.reshape(-1, 21, 3)
    xy_all = coords[:,:,:2].reshape(-1,2)
    good = ~np.isnan(xy_all).any(axis=1) & ~np.isinf(xy_all).any(axis=1)
    xy_all = xy_all[good]
    if xy_all.size == 0:
        return -1,1,-1,1
    xmin, ymin = xy_all.min(axis=0)
    xmax, ymax = xy_all.max(axis=0)
    pad = 0.1 * max(1e-6, xmax-xmin, ymax-ymin)
    return xmin-pad, xmax+pad, ymin-pad, ymax+pad

def animate_sequence(seq_unstd, title="Hand animation", mask=None, fps=30, save_path=None):
    """
    Animate one hand's keypoints as a 2-D stick figure (x/y only; z ignored).

    seq_unstd: (T,63) UNstandardized
    mask: optional (T,) 1=valid 0=invalid; invalid frames are 'held': they
          show the most recent valid frame, and leading invalid frames show
          the first valid one. A frame whose x/y contain NaN/inf is treated
          as invalid. A mask whose length is not T is ignored.
    fps: playback rate (frame interval 1000/fps ms).
    save_path: if given, try ``anim.save(save_path)``; on failure, or when no
          path is given, the animation is displayed inline as JS/HTML.

    y values are negated when drawn (presumably image coordinates with y
    pointing down — confirm against the keypoint source).
    """
    # `display` is only a builtin inside IPython; import it so this also
    # works when the code runs as a plain module.
    from IPython.display import display

    T = seq_unstd.shape[0]
    pts = seq_unstd.reshape(T, 21, 3)
    if mask is not None and len(mask) == T:
        pts = pts.copy()
        valid = [bool(mask[t] > 0.5) and bool(np.isfinite(pts[t, :, :2]).all())
                 for t in range(T)]
        # Seed the held frame with the first valid one so leading invalid
        # frames (e.g. all-zero missing detections) are not drawn as-is.
        first_valid = next((t for t, ok in enumerate(valid) if ok), 0)
        held = pts[first_valid].copy()
        for t, ok in enumerate(valid):
            if ok:
                held = pts[t].copy()
            else:
                pts[t] = held
    coords = pts[:, :, :2]

    # Axis limits from the frames actually drawn, so held-over invalid
    # frames do not stretch the view.
    xmin, xmax, ymin, ymax = _bounds_from_seq(pts.reshape(T, 63))
    fig, ax = plt.subplots()
    scat = ax.scatter([], [])
    lines = [ax.plot([], [])[0] for _ in HAND_CONNECTIONS]
    ax.set_xlim(xmin, xmax); ax.set_ylim(-ymax, -ymin); ax.set_aspect('equal', adjustable='box')

    def init():
        scat.set_offsets(np.empty((0,2)))
        for ln in lines: ln.set_data([], [])
        ax.set_title(title)
        return (scat, *lines)

    def update(t):
        xy = coords[t]
        scat.set_offsets(np.c_[xy[:,0], -xy[:,1]])
        for ln, (a,b) in zip(lines, HAND_CONNECTIONS):
            ln.set_data([xy[a,0], xy[b,0]], [-xy[a,1], -xy[b,1]])
        ax.set_title(f"{title}  t={t+1}/{T}")
        return (scat, *lines)

    anim = FuncAnimation(fig, update, init_func=init, frames=T, interval=1000/max(1,fps), blit=True)
    if save_path:
        try:
            anim.save(save_path, fps=fps, dpi=120)
            plt.close(fig)
            print(f"Saved animation to {save_path}")
        except Exception as e:
            # Best-effort: saving can fail (e.g. no writer for the format);
            # fall back to inline display rather than losing the animation.
            print(f"Saving failed ({e}). Showing inline instead.")
            plt.close(fig); display(HTML(anim.to_jshtml()))
    else:
        plt.close(fig); display(HTML(anim.to_jshtml()))

In [None]:
# =========================================================
# 9) Quick sanity viz: ground truth vs teacher-forced reconstruction
#    (optional; useful to see if the model learned something)
# =========================================================
# Compare the first validation clip with the model's teacher-forced
# (one-step-ahead) reconstruction of it; skipped when there is no val data.
model.eval()
with torch.no_grad():
    if len(val_ds) > 0:
        x, m, y, fname = val_ds[0]
        x = x.unsqueeze(0).to(DEVICE)  # add batch dim: (1,T,F)
        pred_tf = model.forward_teacher_forcing(x, y.unsqueeze(0).to(DEVICE)).cpu().numpy()[0]
        x_np = x.cpu().numpy()[0]

        # Back from standardized space to plottable (scale-normalized) coordinates.
        gt_seq   = unstandardize(x_np, TRAIN_MEAN, TRAIN_STD)
        pred_seq = unstandardize(pred_tf, TRAIN_MEAN, TRAIN_STD)

        sign_name = fname.split('_')[0]  # label prefix of "<label>_<split>..." filenames
        # Same mask for both, so invalid frames are held identically.
        animate_sequence(gt_seq,   title=f"GT – {sign_name}",   mask=m.numpy(), fps=30)
        animate_sequence(pred_seq, title=f"TF Recon – {sign_name}", mask=m.numpy(), fps=30)

Output hidden; open in https://colab.research.google.com to view.

In [None]:
# =========================================================
# 10) Pure generation from a LABEL (no keypoint input)
# =========================================================
inv_label = {v:k for k,v in LABEL_MAP.items()}  # label id -> label name


def _resolve_sign_name(sign_name):
    """Return the LABEL_MAP key for ``sign_name``.

    Tries an exact match first, then a unique case-insensitive match (so
    "HELLO" finds "hello"); raises ValueError otherwise.
    """
    if sign_name in LABEL_MAP:
        return sign_name
    wanted = str(sign_name).casefold()
    matches = [k for k in LABEL_MAP if str(k).casefold() == wanted]
    if len(matches) == 1:
        return matches[0]
    raise ValueError(f"Unknown sign '{sign_name}'. Known: {list(LABEL_MAP.keys())}")


def generate_sign(sign_name, length=SEQ_LEN, seed_frame=None, fps=30, save_path=None):
    """Generate a sequence given the sign label, then animate.

    sign_name: a LABEL_MAP key (exact, or unique case-insensitive match).
    length: number of frames to roll out.
    seed_frame: optional array-like of shape (F,) or (1, F) used as the
        model's t=0 previous frame. It is fed straight into the network, so
        it must be in the standardized space, i.e. (x - TRAIN_MEAN) / TRAIN_STD.
    fps, save_path: forwarded to animate_sequence.

    Raises:
        ValueError: if the sign name is unknown or ambiguous.
    """
    label = _resolve_sign_name(sign_name)
    y_lbl = torch.tensor([LABEL_MAP[label]], device=DEVICE)
    model.eval()
    with torch.no_grad():
        seed = None
        if seed_frame is not None:
            seed = torch.as_tensor(seed_frame, dtype=torch.float32, device=DEVICE)
            if seed.dim() == 1:
                seed = seed.unsqueeze(0)  # (F,) -> (1, F) to match the batch of one
        gen = model.generate(y_labels=y_lbl, T=length, device=DEVICE, start_frame=seed).cpu().numpy()[0]
    gen_unstd = unstandardize(gen, TRAIN_MEAN, TRAIN_STD)
    animate_sequence(gen_unstd, title=f"Generated – {label}", fps=fps, save_path=save_path)


# generate_sign(sign_name=list(LABEL_MAP.keys())[0], length=SEQ_LEN, fps=30)

In [None]:
# =========================================================
# 11) Save & Load helpers
# =========================================================
def save_checkpoint(path="cond_gen.pt"):
    """Write the current model plus everything needed to rebuild and use it.

    The saved dict holds the model state_dict, the train mean/std used for
    standardization, the label map, and the architecture config.
    """
    arch = dict(F=F, SEQ_LEN=SEQ_LEN, HIDDEN=HIDDEN, NUM_LAYERS=NUM_LAYERS,
                EMB_DIM=EMB_DIM, TIME_DIM=TIME_DIM)
    payload = {
        "model_state": model.state_dict(),
        "mean": TRAIN_MEAN,
        "std":  TRAIN_STD,
        "label_map": LABEL_MAP,
        "config": arch,
    }
    torch.save(payload, path)
    print(f"Saved to {path}")

def load_checkpoint(path="cond_gen.pt"):
    """Rebuild a CondLSTMGenerator from a checkpoint written by this notebook.

    Returns (model, mean, std, label_map); the model is on DEVICE in eval
    mode. Config entries missing from the checkpoint fall back to the
    current module-level hyperparameters.

    The checkpoint stores numpy arrays (mean/std). PyTorch >= 2.6 defaults
    ``torch.load`` to ``weights_only=True``, which refuses to unpickle them,
    so full unpickling is requested explicitly. That executes arbitrary
    pickle code: only load checkpoints from a trusted source.
    """
    ckpt = torch.load(path, map_location=DEVICE, weights_only=False)
    cfg = ckpt.get("config", {})
    mdl = CondLSTMGenerator(F=cfg.get("F", F),
                            num_classes=len(ckpt["label_map"]),
                            emb_dim=cfg.get("EMB_DIM", EMB_DIM),
                            time_dim=cfg.get("TIME_DIM", TIME_DIM),
                            hidden=cfg.get("HIDDEN", HIDDEN),
                            layers=cfg.get("NUM_LAYERS", NUM_LAYERS)).to(DEVICE)
    mdl.load_state_dict(ckpt["model_state"])
    mdl.eval()
    print("Loaded model.")
    return mdl, ckpt["mean"], ckpt["std"], ckpt["label_map"]

# save_checkpoint("cond_gen.pt")
# model, TRAIN_MEAN, TRAIN_STD, LABEL_MAP = load_checkpoint("cond_gen.pt")
# generate_sign("HELLO", length=SEQ_LEN)