In [1]:
%pip install -q faiss-cpu mediapipe==0.10.14

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.7/35.7 MB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m91.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m294.9/294.9 kB[0m [31m19.2 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
google-adk 1.22.1 requires google-cloud-bigquery-storage>=2.0.0, which is not installed.
bigframes 2.26.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
a2a-sdk 0.3.22 requires protobuf>=5.29.5, but you have protobuf 4.25.8 which is incompatible.
opentelemetry-proto 1.37.0 requires protobuf<7.0,>=5.0, but you have protobuf 4.25.8 which is incompatible.
ydf 0.13.0 requires protobuf<7

In [2]:
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [3]:
import os
import numpy as np
import random
import json
from torch.utils.data import Dataset, Sampler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import time
import faiss
import csv
import mediapipe as mp

2026-01-25 13:56:15.155732: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769349375.360212      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769349375.422121      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769349375.918334      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769349375.918369      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769349375.918372      55 computation_placer.cc:177] computation placer alr

In [5]:
class VSLDataset(Dataset):
    def __init__(self, root_dir, label_map_path):
        self.samples = []
        self.labels = []
        self.gloss_to_indices = {}

        with open(label_map_path, "r", encoding="utf-8") as f:
            self.label_map = json.load(f)

        for gloss_name in sorted(os.listdir(root_dir)):
            gloss_path = os.path.join(root_dir, gloss_name)
            if not os.path.isdir(gloss_path):
                continue

            if gloss_name not in self.label_map:
                raise ValueError(f"Gloss '{gloss_name}' not found in label_map.json")

            gloss_id = int(self.label_map[gloss_name])

            for fname in os.listdir(gloss_path):
                if not fname.endswith(".npz"):
                    continue

                fpath = os.path.join(gloss_path, fname)
                idx = len(self.samples)

                self.samples.append(fpath)
                self.labels.append(gloss_id)

                if gloss_id not in self.gloss_to_indices:
                    self.gloss_to_indices[gloss_id] = []
                self.gloss_to_indices[gloss_id].append(idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        npz = np.load(self.samples[idx])
        x = npz["sequence"]
        y = self.labels[idx]
        return x, y

In [6]:
class PKSampler(Sampler):
    def __init__(self, gloss_to_indices, P=32, K=8, steps_per_epoch=1000):
        self.gloss_to_indices = gloss_to_indices
        self.P = P
        self.K = K
        self.steps_per_epoch = steps_per_epoch

        self.ptr = {}
        self.buffers = {}

        # Chỉ giữ gloss đủ K sample
        self.gloss_ids = [
            g for g, idxs in gloss_to_indices.items()
            if len(idxs) >= K
        ]

        if len(self.gloss_ids) < P:
            raise ValueError(
                f"Not enough glosses with >=K samples: "
                f"{len(self.gloss_ids)} < P={P}"
            )

        for g in self.gloss_ids:
            idxs = gloss_to_indices[g].copy()
            random.shuffle(idxs)
            self.buffers[g] = idxs
            self.ptr[g] = 0

    def __len__(self):
        # pseudo-length (steps per epoch)
        return self.steps_per_epoch

    def __iter__(self):
        for _ in range(self.steps_per_epoch):
            batch = []

            gloss_batch = random.sample(self.gloss_ids, self.P)

            for g in gloss_batch:
                idxs = self.buffers[g]
                p = self.ptr[g]

                if p + self.K > len(idxs):
                    random.shuffle(idxs)
                    p = 0

                batch.extend(idxs[p:p + self.K])
                self.ptr[g] = p + self.K

            yield batch

In [8]:
from torch.utils.data import DataLoader

label_map_path = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/label_map.json"
train_root = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/train"
val_root = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/val"
test_root = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/test"

train_ds = VSLDataset(train_root, label_map_path)
val_ds = VSLDataset(val_root, label_map_path)
test_ds = VSLDataset(test_root, label_map_path)

sampler = PKSampler(
    gloss_to_indices=train_ds.gloss_to_indices,
    P=32, # Số gloss mỗi step
    K=4, # Số sequence mỗi step
    steps_per_epoch=1000 # Số step mỗi epoch
)

train_loader = DataLoader(
    train_ds,
    batch_sampler=sampler,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=256,
    num_workers=4,
    pin_memory=True 
)

test_loader = DataLoader(
    val_ds,
    batch_size=256,
    num_workers=4,
    pin_memory=True 
)

In [10]:
mp_holistic = mp.solutions.holistic

N_UPPER_BODY_POSE_LANDMARKS = 25
N_HAND_LANDMARKS = 21
N_TOTAL_LANDMARKS = 67


def build_adjacency(self_loop=True):
    A = torch.zeros(N_TOTAL_LANDMARKS, N_TOTAL_LANDMARKS)
    for i, j in mp_holistic.POSE_CONNECTIONS:
        if i < N_UPPER_BODY_POSE_LANDMARKS and j < N_UPPER_BODY_POSE_LANDMARKS:
            A[i, j] = 1
            A[j, i] = 1

    LEFT_HAND_OFFSET = N_UPPER_BODY_POSE_LANDMARKS
    for i, j in mp_holistic.HAND_CONNECTIONS:
        A[LEFT_HAND_OFFSET + i, LEFT_HAND_OFFSET + j] = 1
        A[LEFT_HAND_OFFSET + j, LEFT_HAND_OFFSET + i] = 1

    RIGHT_HAND_OFFSET = N_UPPER_BODY_POSE_LANDMARKS + N_HAND_LANDMARKS
    for i, j in mp_holistic.HAND_CONNECTIONS:
        A[RIGHT_HAND_OFFSET + i, RIGHT_HAND_OFFSET + j] = 1
        A[RIGHT_HAND_OFFSET + j, RIGHT_HAND_OFFSET + i] = 1

    POSE_LEFT_WRIST = 15
    POSE_RIGHT_WRIST = 16

    LEFT_HAND_WRIST = LEFT_HAND_OFFSET + 0
    RIGHT_HAND_WRIST = RIGHT_HAND_OFFSET + 0

    A[POSE_LEFT_WRIST, LEFT_HAND_WRIST] = 1
    A[LEFT_HAND_WRIST, POSE_LEFT_WRIST] = 1

    A[POSE_RIGHT_WRIST, RIGHT_HAND_WRIST] = 1
    A[RIGHT_HAND_WRIST, POSE_RIGHT_WRIST] = 1

    if self_loop:
        A += torch.eye(N_TOTAL_LANDMARKS)

    A = A / A.sum(dim=1, keepdim=True)

    return A


class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, A, kernel_size=9, stride=1):
        super().__init__()
        self.register_buffer("A", A)

        # Spatial GCN
        self.gcn = nn.Conv2d(in_channels, out_channels, kernel_size=1)

        # Temporal convolution
        padding = ((kernel_size - 1) // 2, 0)
        self.tcn = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=(kernel_size, 1),
            stride=(stride, 1),
            padding=padding
        )

        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # x: (B, C, T, V)

        # spatial graph conv
        x = torch.einsum("vw,bctw->bctv", self.A, x)
        x = self.gcn(x)

        # temporal conv
        x = self.tcn(x)
        x = self.bn(x)
        x = self.relu(x)

        return x

class VSL_STGCN(nn.Module):
    def __init__(self, in_dim=3, hidden_dim=256, emb_dim=256):
        super().__init__()

        A = build_adjacency()

        self.data_bn = nn.BatchNorm1d(N_TOTAL_LANDMARKS * in_dim)

        self.stgcn1 = STGCNBlock(in_dim, hidden_dim, A)
        self.stgcn2 = STGCNBlock(hidden_dim, hidden_dim, A)

        self.fc = nn.Linear(hidden_dim, emb_dim)

    def forward(self, x):
        """
        x: (B, T, V*C)
        """
        B, T, _ = x.shape
        x = x.view(B, T, N_TOTAL_LANDMARKS, 3)
        x = x.permute(0, 3, 1, 2)  # (B, C, T, V)

        x = self.stgcn1(x)
        x = self.stgcn2(x)

        # global pooling
        x = x.mean(dim=[2, 3])    # (B, hidden_dim)

        emb = self.fc(x)
        emb = F.normalize(emb, dim=1)
        return emb

In [11]:
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        features = F.normalize(features, dim=1)
        labels = labels.view(-1, 1)

        mask = torch.eq(labels, labels.T).float().to(features.device)

        logits = torch.matmul(features, features.T) / self.temperature
        logits = logits - logits.max(dim=1, keepdim=True)[0].detach()

        logits_mask = torch.ones_like(mask)
        logits_mask.fill_diagonal_(0)
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-8)

        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
        loss = -mean_log_prob_pos.mean()
        return loss

In [14]:
def eval_recall_at_k_faiss(
    model,
    val_loader,
    device,
    Ks=(1, 5)
):
    model.eval()

    # ===== 1. Extract embeddings =====
    all_embs = []
    all_labels = []

    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Extract val emb", leave=False):
            x = x.to(device).float()
            emb = model(x)                 # (B, D)
            emb = F.normalize(emb, dim=1)  # cosine

            all_embs.append(emb.cpu())
            all_labels.append(y.cpu())

    all_embs = torch.cat(all_embs, dim=0)      # (N, D)
    all_labels = torch.cat(all_labels, dim=0)  # (N,)

    # ===== 2. FAISS index =====
    emb_np = all_embs.numpy().astype("float32")
    labels_np = all_labels.numpy()

    dim = emb_np.shape[1]

    index = faiss.IndexFlatIP(dim)  # Inner Product
    index.add(emb_np)               # N vectors

    # ===== 3. Search =====
    max_k = max(Ks) + 1  # +1 để bỏ self-match
    D, I = index.search(emb_np, max_k)

    recalls = {}
    for K in Ks:
        correct = 0
        for i in range(len(I)):
            # bỏ chính nó
            neighbors = I[i][I[i] != i][:K]
            if labels_np[i] in labels_np[neighbors]:
                correct += 1

        recalls[K] = correct / len(I)

    return recalls

In [16]:
model = VSL_STGCN(
    in_dim=3,
    hidden_dim=256,
    emb_dim=256
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = SupConLoss(temperature=0.07)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=1e-4
)

use_amp = (device == "cuda")
scaler = GradScaler(enabled=use_amp)

epochs = 5
eval_interval = 1

save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)

best_recall1 = 0.0

  scaler = GradScaler(enabled=use_amp)


In [17]:
lr_patience = 2
early_stop_patience = 6
best_recall1 = 0.0
no_improve_count = 0

log_path = os.path.join(save_dir, "train_log.csv")

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.3,
    patience=lr_patience,
    min_lr=1e-6,
)


if not os.path.exists(log_path):
    with open(log_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch",
            "train_loss",
            "recall@1",
            "lr",
            "epoch_time_sec"
        ])


for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    start_time = time.time()

    pbar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{epochs}",
        ncols=120
    )

    for step, (x, y) in enumerate(pbar):
        x = x.to(device).float()
        y = y.to(device).long()

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            emb = model(x)
            loss = criterion(emb, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "lr": optimizer.param_groups[0]["lr"]
        })

    avg_loss = epoch_loss / (step + 1)
    epoch_time = time.time() - start_time
    current_lr = optimizer.param_groups[0]["lr"]

    print(
        f"\nEpoch {epoch+1}: "
        f"train_loss={avg_loss:.4f}, "
        f"lr={current_lr:.2e}, "
        f"time={epoch_time:.1f}s"
    )

    # -------- SAVE LAST --------
    torch.save(
        {
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"{save_dir}/last_encoder.pt"
    )

    recall1 = None

    if (epoch + 1) % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            recalls = eval_recall_at_k_faiss(
                model,
                val_loader,
                device,
                Ks=(1, 5)
            )

        recall1 = recalls[1]

        print(
            f"\n Valid Epoch {epoch+1} | "
            f"R@1: {recall1*100:.2f}% | "
        )

        scheduler.step(recall1)

        if recall1 > best_recall1:
            best_recall1 = recall1
            no_improve_count = 0

            torch.save(
                model.state_dict(),
                f"{save_dir}/best_encoder.pt"
            )
            print("Saved BEST encoder (Recall@1)")

        else:
            no_improve_count += 1
            print(
                f"⏸ No Recall@1 improvement "
                f"({no_improve_count}/{early_stop_patience})"
            )

    with open(log_path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            epoch + 1,
            round(avg_loss, 6),
            None if recall1 is None else round(recall1, 6),
            f"{current_lr:.2e}",
            round(epoch_time, 2)
        ])

    if no_improve_count >= early_stop_patience:
        print(
            f"\nEarly stopping at epoch {epoch+1} "
            f"(no Recall@1 improvement for "
            f"{early_stop_patience} evals)"
        )
        break

print("Encoder training finished")


  with autocast(enabled=use_amp):
Epoch 1/5: 100%|█████████████████████████████████████████████| 1000/1000 [11:40<00:00,  1.43it/s, loss=1.7058, lr=0.001]



Epoch 1: train_loss=1.8482, lr=1.00e-03, time=700.3s


                                                                


 Valid Epoch 1 | R@1: 96.88% | 
Saved BEST encoder (Recall@1)


Epoch 2/5: 100%|█████████████████████████████████████████████| 1000/1000 [11:38<00:00,  1.43it/s, loss=1.5281, lr=0.001]



Epoch 2: train_loss=1.6287, lr=1.00e-03, time=698.7s


                                                                


 Valid Epoch 2 | R@1: 98.23% | 
Saved BEST encoder (Recall@1)


Epoch 3/5: 100%|█████████████████████████████████████████████| 1000/1000 [11:38<00:00,  1.43it/s, loss=1.9454, lr=0.001]



Epoch 3: train_loss=1.5509, lr=1.00e-03, time=698.6s


                                                                


 Valid Epoch 3 | R@1: 99.16% | 
Saved BEST encoder (Recall@1)


Epoch 4/5: 100%|█████████████████████████████████████████████| 1000/1000 [11:38<00:00,  1.43it/s, loss=1.3279, lr=0.001]



Epoch 4: train_loss=1.5264, lr=1.00e-03, time=698.4s


                                                                


 Valid Epoch 4 | R@1: 99.33% | 
Saved BEST encoder (Recall@1)


Epoch 5/5: 100%|█████████████████████████████████████████████| 1000/1000 [11:38<00:00,  1.43it/s, loss=1.5164, lr=0.001]



Epoch 5: train_loss=1.4960, lr=1.00e-03, time=698.4s


                                                                


 Valid Epoch 5 | R@1: 99.30% | 
⏸ No Recall@1 improvement (1/6)
Encoder training finished


In [18]:
model_dir = "./checkpoints"

model = VSL_STGCN(
    in_dim=3,
    hidden_dim=256,
    emb_dim=256
)

model.to("cuda")

ckpt_path = f"{model_dir}/best_encoder.pt"
model.load_state_dict(torch.load(ckpt_path, map_location=device))

recalls = eval_recall_at_k_faiss(
            model,
            test_loader,
            device,
            Ks=(1, 5)
        )

recall1 = recalls[1]
recall5 = recalls[5]

print(
    f"\n Valid | "
    f"R@1: {recall1*100:.2f}% | "
    f"R@5: {recall5*100:.2f}%"
)

                                                                


 Valid | R@1: 99.33% | R@5: 99.76%


In [19]:
!zip best.zip ./checkpoints/best_encoder.pt

  adding: checkpoints/best_encoder.pt (deflated 8%)
