# Hand Gesture Transformer – Training notebook

이 노트북은 Mediapipe로 추출한 21개 손 랜드마크(3‑D)를 Transformer Encoder에 넣어 4‑클래스 손동작을 분류하는 모델을 학습합니다.

* 데이터 구성: `dataset/<gesture_label>/*.jpg`
* Train : Val : Test = 70 : 20 : 10
* 결과: 에포크별 loss 그래프, Attention heat‑map 시각화


In [None]:
# (Colab 환경이라면) 필요 패키지 설치
# !pip install mediapipe nbformat torch torchvision matplotlib seaborn tqdm

In [None]:
import os, random, glob, json, math, itertools, time, copy
import numpy as np
import cv2
import mediapipe as mp
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from models.HandGestureTransformer import HandGestureTransformer

In [None]:
# ------- Dataset definition -------
# Class names. Each name is also the sub-directory of the dataset root that
# holds the images of that class; the class label is the index in this list.
# Edit to match your data.
GESTURES = ['gesture0', 'gesture1', 'gesture2', 'gesture3']


class HandDataset(Dataset):
    """Hand-gesture images turned into per-image landmark arrays.

    Files are collected from ``root_dir/<gesture>/*`` for every name in
    ``GESTURES``, shuffled with a fixed seed (42) and cut into consecutive
    train / val / test slices.  Every instance rebuilds the same shuffled
    list, so instances created for different splits do not overlap.

    Landmarks are extracted on the fly in ``__getitem__`` via MediaPipe Hands.

    Args:
        root_dir: Directory containing one sub-directory per gesture name.
        split: ``'train'`` or ``'val'``; any other value selects the test
            split.
        val_ratio: Fraction of all items assigned to validation.
        test_ratio: Fraction of all items assigned to test.
        transform: Optional callable applied to the landmark array before it
            is converted to a tensor.

    Raises:
        ValueError: If a ratio is negative or the two ratios sum to more
            than 1.
    """

    def __init__(self, root_dir, split='train', val_ratio=0.2, test_ratio=0.1, transform=None):
        if val_ratio < 0 or test_ratio < 0 or val_ratio + test_ratio > 1:
            raise ValueError(
                'val_ratio and test_ratio must be non-negative and sum to at most 1, '
                f'got val_ratio={val_ratio}, test_ratio={test_ratio}')

        self.samples = []  # not used by this class; kept so existing attribute access still works
        self.transform = transform
        # The detector is created lazily (see ``mp_hands``) so that building
        # the dataset does not instantiate MediaPipe and so that each process
        # using the dataset gets its own detector.
        self._mp_hands = None

        # Gather all (img_path, label) pairs.  glob() returns paths in
        # arbitrary order, so sort first: otherwise the seeded shuffle below
        # would not yield the same split on every run / machine.
        all_items = []
        for idx, g in enumerate(GESTURES):
            for p in sorted(glob.glob(os.path.join(root_dir, g, '*'))):
                all_items.append((p, idx))
        # Private generator: same permutation as seeding the global RNG with
        # 42, without reseeding ``random`` for the rest of the program.
        random.Random(42).shuffle(all_items)

        n = len(all_items)
        n_val = int(n * val_ratio)
        n_test = int(n * test_ratio)
        n_train = n - n_val - n_test
        if split == 'train':
            self.items = all_items[:n_train]
        elif split == 'val':
            self.items = all_items[n_train:n_train + n_val]
        else:
            # Slice from an explicit start index: ``all_items[-n_test:]``
            # would return the WHOLE list when n_test == 0.
            self.items = all_items[n_train + n_val:]

    @property
    def mp_hands(self):
        """MediaPipe Hands detector, created on first access."""
        if self._mp_hands is None:
            self._mp_hands = mp.solutions.hands.Hands(static_image_mode=True,
                                                      max_num_hands=1,
                                                      min_detection_confidence=0.5)
        return self._mp_hands

    def __getstate__(self):
        # Drop the detector when the dataset is pickled (e.g. handed to
        # DataLoader workers started with "spawn"); the receiving process
        # recreates it on first use.
        # NOTE(review): assumes the native detector object is not safely
        # picklable / shareable across processes - confirm for your MediaPipe
        # version.
        state = self.__dict__.copy()
        state['_mp_hands'] = None
        return state

    def _extract_landmarks(self, img_bgr):
        """Return the landmarks of the first detected hand as a float32 array.

        Each row is ``(x, y, z)`` exactly as reported by MediaPipe.  An
        all-zero ``(21, 3)`` array is returned when no hand is detected or
        when ``img_bgr`` is ``None`` (the image could not be read).
        """
        if img_bgr is None:
            return np.zeros((21, 3), dtype=np.float32)
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        res = self.mp_hands.process(img_rgb)
        if not res.multi_hand_landmarks:
            return np.zeros((21, 3), dtype=np.float32)
        lm = res.multi_hand_landmarks[0]
        xyz = [(pt.x, pt.y, pt.z) for pt in lm.landmark]
        return np.array(xyz, dtype=np.float32)

    def __len__(self):
        """Number of items in this split."""
        return len(self.items)

    def __getitem__(self, idx):
        """Return ``(landmarks, label)`` tensors for item ``idx``."""
        path, label = self.items[idx]
        # cv2.imread returns None (no exception) for unreadable / non-image
        # files; _extract_landmarks maps that to the all-zero fallback.
        img = cv2.imread(path)
        xyz = self._extract_landmarks(img)
        if self.transform:
            xyz = self.transform(xyz)
        return torch.tensor(xyz), torch.tensor(label)


In [None]:
# -------- Training loop --------
# Trains the model for EPOCHS epochs, tracks the per-sample mean train / val
# loss of every epoch, writes a checkpoint per epoch plus a separate "best"
# checkpoint (lowest validation loss), then plots both loss curves.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'{device}\n')
root_dir = './dataset'  # change to where the data lives
batch = 256
ds_train = HandDataset(root_dir, 'train')
ds_val   = HandDataset(root_dir, 'val')
# Fail fast with a readable message (e.g. wrong root_dir) instead of a
# sampler error or a ZeroDivisionError when the epoch means are computed.
if len(ds_train) == 0 or len(ds_val) == 0:
    raise RuntimeError(
        f'Empty split under {root_dir!r}: train={len(ds_train)}, val={len(ds_val)}')
# NOTE: HandDataset extracts landmarks inside __getitem__, so every epoch
# re-reads and re-processes every image; the workers parallelise that cost.
train_loader = DataLoader(ds_train, batch_size=batch, shuffle=True, num_workers=4)
val_loader   = DataLoader(ds_val, batch_size=batch, shuffle=False, num_workers=4)

# NOTE(review): return_attn=True presumably makes the model record attention
# maps (they are read from model.attn_maps in the heat-map cell) while the
# forward call still returns plain logits - confirm against
# models/HandGestureTransformer.
model = HandGestureTransformer(return_attn=True).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
crit = nn.CrossEntropyLoss()
EPOCHS = 40
train_losses, val_losses = [], []

best_val = float('inf')
best_path = 'ckpt_best.pt'

for epoch in range(1, EPOCHS+1):
    # ---- train ----
    model.train()
    tloss = 0
    for xyz, label in train_loader:
        xyz, label = xyz.to(device), label.to(device)
        opt.zero_grad()
        out = model(xyz)
        loss = crit(out, label)
        loss.backward()
        opt.step()
        # Weight the batch-mean loss by the batch size so that the epoch
        # figure is a true per-sample mean even if the last batch is smaller.
        tloss += loss.item()*xyz.size(0)
    train_losses.append(tloss/len(ds_train))

    # ---- validate ----
    model.eval(); vloss=0
    with torch.no_grad():
        for xyz,label in val_loader:
            xyz, label = xyz.to(device), label.to(device)
            vloss += crit(model(xyz), label).item()*xyz.size(0)
    val_loss = vloss / len(ds_val)
    val_losses.append(val_loss)
    print(f"Epoch {epoch}: train {train_losses[-1]:.4f}  val {val_losses[-1]:.4f}")
    # One checkpoint per epoch ...
    torch.save({'epoch':epoch,'model':model.state_dict()}, f'ckpt_{epoch}.pt')

    # ... plus a separate copy of the best (lowest validation loss) model.
    if val_loss < best_val:
        best_val = val_loss
        torch.save({'epoch': epoch,
                    'model': model.state_dict()},
                    best_path)
        print(f'★ New best -> epoch: {epoch}, val loss: {val_loss:.4f}')

# -------- Loss Plot --------
# Pass explicit x values: without them the curves are drawn at 0..EPOCHS-1,
# one off from the 1-based epoch numbers printed above and used in the
# checkpoint file names.
epoch_axis = range(1, len(train_losses) + 1)
plt.figure()
plt.plot(epoch_axis, train_losses, label='train'); plt.plot(epoch_axis, val_losses, label='val')
plt.xlabel('epoch'); plt.ylabel('loss'); plt.legend(); plt.show()

In [None]:
# --------- Attention heat-map ---------
# Runs one validation sample through the model and plots the attention
# weights of the first layer / first head for that sample.
sample_xyz, _ = ds_val[0]
# Inference only: make eval mode explicit (so this cell does not depend on the
# state the training cell happened to leave the model in) and skip autograd.
model.eval()
with torch.no_grad():
    sample_logits = model(sample_xyz.unsqueeze(0).to(device))
# Layout per the original author's note: [L,B,nH,Len,Len]
# NOTE(review): assumes model.attn_maps holds the maps of the forward pass
# above, one tensor per layer - confirm against the model implementation.
attn_weights = torch.stack(model.attn_maps)
# detach() guards against maps that still require grad; calling .numpy() on
# such a tensor raises a RuntimeError.
layer0_head0 = attn_weights[0, 0, 0].detach().cpu().numpy()
sns.heatmap(layer0_head0, cmap='viridis')
plt.title('Layer0-Head0 Attention'); plt.show()