In [3]:
%pip install -q faiss-cpu mediapipe==0.10.14

Note: you may need to restart the kernel to use updated packages.


In [4]:
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)

In [5]:
import os
import numpy as np
import random
import json
from torch.utils.data import Dataset, Sampler
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
import time
import faiss
import csv
import mediapipe as mp

2026-01-25 14:37:46.795216: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769351866.990113      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769351867.048034      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769351867.519387      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769351867.519431      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769351867.519434      55 computation_placer.cc:177] computation placer alr

In [6]:
class VSLDataset(Dataset):
    def __init__(self, root_dir, label_map_path):
        self.samples = []
        self.labels = []
        self.gloss_to_indices = {}

        with open(label_map_path, "r", encoding="utf-8") as f:
            self.label_map = json.load(f)

        for gloss_name in sorted(os.listdir(root_dir)):
            gloss_path = os.path.join(root_dir, gloss_name)
            if not os.path.isdir(gloss_path):
                continue

            if gloss_name not in self.label_map:
                raise ValueError(f"Gloss '{gloss_name}' not found in label_map.json")

            gloss_id = int(self.label_map[gloss_name])

            for fname in os.listdir(gloss_path):
                if not fname.endswith(".npz"):
                    continue

                fpath = os.path.join(gloss_path, fname)
                idx = len(self.samples)

                self.samples.append(fpath)
                self.labels.append(gloss_id)

                if gloss_id not in self.gloss_to_indices:
                    self.gloss_to_indices[gloss_id] = []
                self.gloss_to_indices[gloss_id].append(idx)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        npz = np.load(self.samples[idx])
        x = npz["sequence"]
        y = self.labels[idx]
        return x, y

In [7]:
class PKSampler(Sampler):
    def __init__(self, gloss_to_indices, P=32, K=8, steps_per_epoch=1000):
        self.gloss_to_indices = gloss_to_indices
        self.P = P
        self.K = K
        self.steps_per_epoch = steps_per_epoch

        self.ptr = {}
        self.buffers = {}

        # Chỉ giữ gloss đủ K sample
        self.gloss_ids = [
            g for g, idxs in gloss_to_indices.items()
            if len(idxs) >= K
        ]

        if len(self.gloss_ids) < P:
            raise ValueError(
                f"Not enough glosses with >=K samples: "
                f"{len(self.gloss_ids)} < P={P}"
            )

        for g in self.gloss_ids:
            idxs = gloss_to_indices[g].copy()
            random.shuffle(idxs)
            self.buffers[g] = idxs
            self.ptr[g] = 0

    def __len__(self):
        # pseudo-length (steps per epoch)
        return self.steps_per_epoch

    def __iter__(self):
        for _ in range(self.steps_per_epoch):
            batch = []

            gloss_batch = random.sample(self.gloss_ids, self.P)

            for g in gloss_batch:
                idxs = self.buffers[g]
                p = self.ptr[g]

                if p + self.K > len(idxs):
                    random.shuffle(idxs)
                    p = 0

                batch.extend(idxs[p:p + self.K])
                self.ptr[g] = p + self.K

            yield batch

In [8]:
from torch.utils.data import DataLoader

label_map_path = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/label_map.json"
train_root = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/train"
val_root = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/val"
test_root = "/kaggle/input/vsl-vietnamese-sign-languages/Processed/test"

train_ds = VSLDataset(train_root, label_map_path)
val_ds = VSLDataset(val_root, label_map_path)
test_ds = VSLDataset(test_root, label_map_path)

sampler = PKSampler(
    gloss_to_indices=train_ds.gloss_to_indices,
    P=32, # Số gloss mỗi step
    K=4, # Số sequence mỗi step
    steps_per_epoch=1000 # Số step mỗi epoch
)

train_loader = DataLoader(
    train_ds,
    batch_sampler=sampler,
    num_workers=4,
    pin_memory=True
)

val_loader = DataLoader(
    val_ds,
    batch_size=256,
    num_workers=4,
    pin_memory=True 
)

test_loader = DataLoader(
    val_ds,
    batch_size=256,
    num_workers=4,
    pin_memory=True 
)

In [9]:
mp_holistic = mp.solutions.holistic

N_UPPER_BODY_POSE_LANDMARKS = 25
N_HAND_LANDMARKS = 21
N_TOTAL_LANDMARKS = 67


def build_adjacency(self_loop=True):
    A = torch.zeros(N_TOTAL_LANDMARKS, N_TOTAL_LANDMARKS)
    for i, j in mp_holistic.POSE_CONNECTIONS:
        if i < N_UPPER_BODY_POSE_LANDMARKS and j < N_UPPER_BODY_POSE_LANDMARKS:
            A[i, j] = 1
            A[j, i] = 1

    LEFT_HAND_OFFSET = N_UPPER_BODY_POSE_LANDMARKS
    for i, j in mp_holistic.HAND_CONNECTIONS:
        A[LEFT_HAND_OFFSET + i, LEFT_HAND_OFFSET + j] = 1
        A[LEFT_HAND_OFFSET + j, LEFT_HAND_OFFSET + i] = 1

    RIGHT_HAND_OFFSET = N_UPPER_BODY_POSE_LANDMARKS + N_HAND_LANDMARKS
    for i, j in mp_holistic.HAND_CONNECTIONS:
        A[RIGHT_HAND_OFFSET + i, RIGHT_HAND_OFFSET + j] = 1
        A[RIGHT_HAND_OFFSET + j, RIGHT_HAND_OFFSET + i] = 1

    POSE_LEFT_WRIST = 15
    POSE_RIGHT_WRIST = 16

    LEFT_HAND_WRIST = LEFT_HAND_OFFSET + 0
    RIGHT_HAND_WRIST = RIGHT_HAND_OFFSET + 0

    A[POSE_LEFT_WRIST, LEFT_HAND_WRIST] = 1
    A[LEFT_HAND_WRIST, POSE_LEFT_WRIST] = 1

    A[POSE_RIGHT_WRIST, RIGHT_HAND_WRIST] = 1
    A[RIGHT_HAND_WRIST, POSE_RIGHT_WRIST] = 1

    if self_loop:
        A += torch.eye(N_TOTAL_LANDMARKS)

    A = A / A.sum(dim=1, keepdim=True)

    return A

class MS_GCN(nn.Module):
    def __init__(self, in_channels, out_channels, A, scales=(1, 2, 3)):
        super().__init__()
        self.As = build_multiscale_adjacency(A, scales)
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            for _ in scales
        ])

    def forward(self, x):
        """
        x: (B, C, T, V)
        """
        out = 0
        for A, conv in zip(self.As, self.convs):
            A = A.to(x.device)
            z = torch.einsum("vw,bctw->bctv", A, x)
            out = out + conv(z)
        return out

def build_multiscale_adjacency(A, scales=(1, 2, 3)):
    """
    A: (V, V)
    return: list of A^k
    """
    As = []
    A_k = A.clone()
    for k in scales:
        if k == 1:
            As.append(A)
        else:
            A_k = torch.matmul(A_k, A)
            As.append(A_k)
    return As


class MS_G3D_Block(nn.Module):
    def __init__(self, in_channels, out_channels, A, stride=1):
        super().__init__()

        self.msgcn = MS_GCN(in_channels, out_channels, A)

        self.temporal_conv = nn.Sequential(
            nn.Conv3d(
                out_channels,
                out_channels,
                kernel_size=(3, 1, 1),
                stride=(stride, 1, 1),
                padding=(1, 0, 0)
            ),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

        if in_channels != out_channels or stride != 1:
            self.residual = nn.Conv3d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=(stride, 1, 1)
            )
        else:
            self.residual = None

    def forward(self, x):
        """
        x: (B, C, T, V)
        return: (B, C, T, V)
        """

        # Residual path
        if self.residual is not None:
            res = self.residual(x.unsqueeze(-1))
        else:
            res = x.unsqueeze(-1)

        # MS-GCN (4D)
        x = self.msgcn(x)                  # (B, C, T, V)

        # Temporal Conv (5D)
        x = self.temporal_conv(x.unsqueeze(-1))  # (B, C, T', V, 1)

        # Add & REMOVE person dim
        x = x + res
        x = x.squeeze(-1)                  # ✅ BACK TO (B, C, T, V)

        return x



class VSL_MS_G3D(nn.Module):
    def __init__(
        self,
        in_dim=3,
        hidden_dim=256,
        emb_dim=256
    ):
        super().__init__()

        A = build_adjacency()

        self.data_bn = nn.BatchNorm1d(N_TOTAL_LANDMARKS * in_dim)

        self.block1 = MS_G3D_Block(in_dim, 128, A)
        self.block2 = MS_G3D_Block(128, 256, A, stride=2)
        self.block3 = MS_G3D_Block(256, hidden_dim, A)

        self.fc = nn.Linear(hidden_dim, emb_dim)

    def forward(self, x):
        """
        x: (B, T, V*C)
        """
        B, T, _ = x.shape
        x = x.view(B, T, N_TOTAL_LANDMARKS * 3)
        x = self.data_bn(x.view(B * T, -1)).view(B, T, N_TOTAL_LANDMARKS, 3)

        # (B, C, T, V)
        x = x.permute(0, 3, 1, 2)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)

        # Global Pooling
        x = x.mean(dim=[2, 3])  # (B, C)

        emb = self.fc(x)
        emb = F.normalize(emb, dim=1)

        return emb


In [10]:
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        features = F.normalize(features, dim=1)
        labels = labels.view(-1, 1)

        mask = torch.eq(labels, labels.T).float().to(features.device)

        logits = torch.matmul(features, features.T) / self.temperature
        logits = logits - logits.max(dim=1, keepdim=True)[0].detach()

        logits_mask = torch.ones_like(mask)
        logits_mask.fill_diagonal_(0)
        mask = mask * logits_mask

        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-8)

        mean_log_prob_pos = (mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
        loss = -mean_log_prob_pos.mean()
        return loss

In [11]:
def eval_recall_at_k_faiss(
    model,
    val_loader,
    device,
    Ks=(1, 5)
):
    model.eval()

    # ===== 1. Extract embeddings =====
    all_embs = []
    all_labels = []

    with torch.no_grad():
        for x, y in tqdm(val_loader, desc="Extract val emb", leave=False):
            x = x.to(device).float()
            emb = model(x)                 # (B, D)
            emb = F.normalize(emb, dim=1)  # cosine

            all_embs.append(emb.cpu())
            all_labels.append(y.cpu())

    all_embs = torch.cat(all_embs, dim=0)      # (N, D)
    all_labels = torch.cat(all_labels, dim=0)  # (N,)

    # ===== 2. FAISS index =====
    emb_np = all_embs.numpy().astype("float32")
    labels_np = all_labels.numpy()

    dim = emb_np.shape[1]

    index = faiss.IndexFlatIP(dim)  # Inner Product
    index.add(emb_np)               # N vectors

    # ===== 3. Search =====
    max_k = max(Ks) + 1  # +1 để bỏ self-match
    D, I = index.search(emb_np, max_k)

    recalls = {}
    for K in Ks:
        correct = 0
        for i in range(len(I)):
            # bỏ chính nó
            neighbors = I[i][I[i] != i][:K]
            if labels_np[i] in labels_np[neighbors]:
                correct += 1

        recalls[K] = correct / len(I)

    return recalls

In [18]:
model = VSL_MS_G3D(
    in_dim=3,
    hidden_dim=256,
    emb_dim=256
)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

criterion = SupConLoss(temperature=0.07)
optimizer = torch.optim.AdamW(
    model.parameters(), lr=1e-3, weight_decay=1e-4
)

use_amp = (device == "cuda")
scaler = GradScaler(enabled=use_amp)

epochs = 5
eval_interval = 1

save_dir = "./checkpoints"
os.makedirs(save_dir, exist_ok=True)

best_recall1 = 0.0

  scaler = GradScaler(enabled=use_amp)


In [13]:
lr_patience = 2
early_stop_patience = 6
best_recall1 = 0.0
no_improve_count = 0

log_path = os.path.join(save_dir, "train_log.csv")

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.3,
    patience=lr_patience,
    min_lr=1e-6,
)


if not os.path.exists(log_path):
    with open(log_path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            "epoch",
            "train_loss",
            "recall@1",
            "lr",
            "epoch_time_sec"
        ])


for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    start_time = time.time()

    pbar = tqdm(
        train_loader,
        desc=f"Epoch {epoch+1}/{epochs}",
        ncols=120
    )

    for step, (x, y) in enumerate(pbar):
        x = x.to(device).float()
        y = y.to(device).long()

        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=use_amp):
            emb = model(x)
            loss = criterion(emb, y)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()

        pbar.set_postfix({
            "loss": f"{loss.item():.4f}",
            "lr": optimizer.param_groups[0]["lr"]
        })

    avg_loss = epoch_loss / (step + 1)
    epoch_time = time.time() - start_time
    current_lr = optimizer.param_groups[0]["lr"]

    print(
        f"\nEpoch {epoch+1}: "
        f"train_loss={avg_loss:.4f}, "
        f"lr={current_lr:.2e}, "
        f"time={epoch_time:.1f}s"
    )

    # -------- SAVE LAST --------
    torch.save(
        {
            "epoch": epoch + 1,
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
        },
        f"{save_dir}/last_encoder.pt"
    )

    recall1 = None

    if (epoch + 1) % eval_interval == 0:
        model.eval()
        with torch.no_grad():
            recalls = eval_recall_at_k_faiss(
                model,
                val_loader,
                device,
                Ks=(1, 5)
            )

        recall1 = recalls[1]

        print(
            f"\n Valid Epoch {epoch+1} | "
            f"R@1: {recall1*100:.2f}% | "
        )

        scheduler.step(recall1)

        if recall1 > best_recall1:
            best_recall1 = recall1
            no_improve_count = 0

            torch.save(
                model.state_dict(),
                f"{save_dir}/best_encoder.pt"
            )
            print("Saved BEST encoder (Recall@1)")

        else:
            no_improve_count += 1
            print(
                f"⏸ No Recall@1 improvement "
                f"({no_improve_count}/{early_stop_patience})"
            )

    with open(log_path, "a", newline="") as f:
        writer = csv.writer(f)
        writer.writerow([
            epoch + 1,
            round(avg_loss, 6),
            None if recall1 is None else round(recall1, 6),
            f"{current_lr:.2e}",
            round(epoch_time, 2)
        ])

    if no_improve_count >= early_stop_patience:
        print(
            f"\nEarly stopping at epoch {epoch+1} "
            f"(no Recall@1 improvement for "
            f"{early_stop_patience} evals)"
        )
        break

print("Encoder training finished")


  with autocast(enabled=use_amp):
Epoch 1/5: 100%|█████████████████████████████████████████████| 1000/1000 [10:24<00:00,  1.60it/s, loss=1.4466, lr=0.001]



Epoch 1: train_loss=1.7749, lr=1.00e-03, time=624.2s


                                                                


 Valid Epoch 1 | R@1: 99.17% | 
Saved BEST encoder (Recall@1)


Epoch 2/5: 100%|█████████████████████████████████████████████| 1000/1000 [10:22<00:00,  1.61it/s, loss=1.5192, lr=0.001]



Epoch 2: train_loss=1.5390, lr=1.00e-03, time=622.3s


                                                                


 Valid Epoch 2 | R@1: 99.75% | 
Saved BEST encoder (Recall@1)


Epoch 3/5: 100%|█████████████████████████████████████████████| 1000/1000 [10:22<00:00,  1.61it/s, loss=1.3844, lr=0.001]



Epoch 3: train_loss=1.4775, lr=1.00e-03, time=622.5s


                                                                


 Valid Epoch 3 | R@1: 99.65% | 
⏸ No Recall@1 improvement (1/6)


Epoch 4/5: 100%|█████████████████████████████████████████████| 1000/1000 [10:22<00:00,  1.61it/s, loss=1.6228, lr=0.001]



Epoch 4: train_loss=1.4540, lr=1.00e-03, time=622.2s


                                                                


 Valid Epoch 4 | R@1: 99.74% | 
⏸ No Recall@1 improvement (2/6)


Epoch 5/5: 100%|█████████████████████████████████████████████| 1000/1000 [10:22<00:00,  1.61it/s, loss=1.4208, lr=0.001]



Epoch 5: train_loss=1.4268, lr=1.00e-03, time=622.0s


                                                                


 Valid Epoch 5 | R@1: 99.76% | 
Saved BEST encoder (Recall@1)
Encoder training finished


In [20]:
model_dir = "./checkpoints"

model = VSL_MS_G3D(
    in_dim=3,
    hidden_dim=256,
    emb_dim=256
)

model.to("cuda")

ckpt_path = f"{model_dir}/best_encoder.pt"
model.load_state_dict(torch.load(ckpt_path, map_location=device))

recalls = eval_recall_at_k_faiss(
            model,
            test_loader,
            device,
            Ks=(1, 5)
        )

recall1 = recalls[1]
recall5 = recalls[5]

print(
    f"\n Valid | "
    f"R@1: {recall1*100:.2f}% | "
    f"R@5: {recall5*100:.2f}%"
)

                                                                


 Valid | R@1: 99.76% | R@5: 99.82%


In [22]:
!zip best.zip ./checkpoints/best_encoder.pt

updating: checkpoints/best_encoder.pt (deflated 8%)
