# Hand Gesture Transformer – Training notebook

이 노트북은 MediaPipe로 추출한 21개 손 랜드마크(3‑D)를 Transformer Encoder에 넣어 4‑클래스 손동작을 분류하는 모델을 학습합니다.

* Train : Val = 80 : 20
* 결과: 에포크별 loss & accuracy 그래프, action별 Attention heat‑map 시각화


In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
from tqdm import tqdm
import json
import matplotlib.pyplot as plt
from models.transformer_only import LandmarkTransformer
import seaborn as sns
from dataset_xyzva import GestureDatasetCSV
from config import GESTURE

In [None]:
class Settings:
    """Hyper-parameters and paths shared by every cell of this notebook.

    Used as a plain namespace (attributes are read as ``Settings.<name>``);
    it is never instantiated.
    """

    # ---- data ----------------------------------------------------------
    dataset_dir = './data/csv/train_data'  # passed to GestureDatasetCSV
    window_size = 30                       # window length given to GestureDatasetCSV
    num_classes = len(GESTURE)             # one class per entry in config.GESTURE

    # ---- optimisation --------------------------------------------------
    batch_size = 64
    epochs = 30
    learning_rate = 1e-4
    seed = 42

    # ---- runtime / output ----------------------------------------------
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model_save_dir = './checkpoint/'

# Path that the training loop writes the best weights to; the directory is
# created up front (no error if it already exists).
ckpt_path = os.path.join(Settings.model_save_dir, "best_model.pth")
os.makedirs(Settings.model_save_dir, exist_ok=True)

In [3]:
# 80/20 train/validation split.
# The generator is seeded with Settings.seed so the split is reproducible:
# without it, re-running this cell in a later session draws a different split,
# and windows that were used to select best_model.pth can end up in training.
dataset = GestureDatasetCSV(Settings.dataset_dir, Settings.window_size)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(Settings.seed)
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=split_generator
)
# NOTE(review): random_split shuffles individual samples. If GestureDatasetCSV
# cuts overlapping windows from the same recording, near-duplicate windows can
# fall into both splits and inflate validation accuracy; consider splitting
# by recording/file instead. Confirm against dataset_xyzva.
train_loader = DataLoader(train_dataset, batch_size=Settings.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=Settings.batch_size, shuffle=False)

In [4]:
# Alternative (disabled) 70/20/10 train/val/test split with a seeded generator
# and multi-worker, pinned-memory loaders. Use it in place of the 80/20 split
# cell if a held-out test set is needed.
# NOTE(review): the original draft crossed the loaders: val_loader was built
# from test_set, and test_loader from val_set with shuffle=True. Corrected below.
# dataset = GestureDatasetCSV(Settings.dataset_dir, Settings.window_size)

# n_total = len(dataset)
# n_train = int(n_total * 0.7)
# n_val = int(n_total * 0.2)
# n_test  = n_total - n_train - n_val

# train_set, val_set, test_set = random_split(
#     dataset, [n_train, n_val, n_test],
#     generator=torch.Generator().manual_seed(Settings.seed)
# )


# train_loader = DataLoader(train_set, batch_size=Settings.batch_size,
#                           shuffle=True,  num_workers=4, pin_memory=True)
# val_loader   = DataLoader(val_set,   batch_size=Settings.batch_size,
#                           shuffle=False, num_workers=4, pin_memory=True)
# test_loader = DataLoader(test_set, batch_size=Settings.batch_size,
#                           shuffle=False, num_workers=4, pin_memory=True)

In [5]:
# Classifier on the configured device, cross-entropy loss over its outputs,
# and an Adam optimiser at Settings.learning_rate.
# NOTE(review): LandmarkTransformer() is built with its own defaults and
# Settings.num_classes is not passed; confirm its output size matches
# len(GESTURE).
model = LandmarkTransformer()
model = model.to(Settings.device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=Settings.learning_rate)

In [6]:
def run(loader, train=True):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``Settings.device``.

    Args:
        loader: iterable of ``(x, y)`` batches; ``y`` holds the targets that
            are compared against ``out.argmax(1)``.
        train: if True, put the model in training mode and take one optimizer
            step per batch. If False, run in eval mode with autograd disabled.

    Returns:
        A tuple ``(loss, accuracy)``: the sample-weighted mean loss and the
        fraction of correctly classified samples over the whole loader.

    Raises:
        ValueError: if the loader yields no samples.
    """
    model.train(train)
    total, correct, loss_sum = 0, 0, 0.0
    # model.train(False) only switches layers such as dropout to eval mode; it
    # does not stop autograd from recording a graph for every validation batch.
    # Disable grad tracking outside training to save memory and time.
    with torch.set_grad_enabled(train):
        for x, y in tqdm(loader, leave=False):
            x, y = x.to(Settings.device), y.to(Settings.device)
            if train:
                optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            if train:
                loss.backward()
                optimizer.step()
            batch = y.size(0)
            total += batch
            correct += (out.argmax(1) == y).sum().item()
            # criterion uses the default 'mean' reduction, so re-weight by
            # batch size to get a per-sample average over uneven last batches.
            loss_sum += loss.item() * batch
    if total == 0:
        raise ValueError("run(): loader yielded no samples")
    return loss_sum / total, correct / total

In [7]:
# Train for Settings.epochs epochs. After each epoch, evaluate on the
# validation split and save the weights to ckpt_path when they beat the best
# so far. Per-epoch metrics are kept in `history` for plotting.
best = 0.0
best_loss = float("inf")
best_epoch = 0
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
for ep in range(1, Settings.epochs + 1):
    train_loss, train_acc = run(train_loader, True)
    val_loss, val_acc = run(val_loader, False)
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)
    print(f"[{ep}] train loss:{train_loss:.4f} | accuracy:{train_acc:.3f}\nval loss:{val_loss:.4f} | accuracy:{val_acc:.3f}")
    # A bare `val_acc >= best` re-saves on every tie. Once validation accuracy
    # saturates, every later epoch would overwrite the checkpoint regardless
    # of how it performs, so break ties on lower validation loss instead.
    if val_acc > best or (val_acc == best and val_loss < best_loss):
        best, best_loss, best_epoch = val_acc, val_loss, ep
        torch.save(model.state_dict(), ckpt_path)
        print(f"  ↳ best saved ({best:.3f})")

print(f"\nBest accuracy = {best:.3f} (epoch {best_epoch})")

  0%|          | 0/103 [00:00<?, ?it/s]

                                                 

[1] train loss:1.4481 | accuracy:0.541
val loss:0.6487 | accuracy:0.911
  ↳ best saved (0.911)


                                                  

[2] train loss:0.5737 | accuracy:0.916
val loss:0.2277 | accuracy:0.989
  ↳ best saved (0.989)


                                                 

[3] train loss:0.2778 | accuracy:0.985
val loss:0.0989 | accuracy:0.996
  ↳ best saved (0.996)


                                                 

[4] train loss:0.1656 | accuracy:0.995
val loss:0.0583 | accuracy:0.998
  ↳ best saved (0.998)


                                                  

[5] train loss:0.1112 | accuracy:0.998
val loss:0.0390 | accuracy:0.999
  ↳ best saved (0.999)


                                                 

[6] train loss:0.0837 | accuracy:0.999
val loss:0.0285 | accuracy:0.999
  ↳ best saved (0.999)


                                                 

[7] train loss:0.0646 | accuracy:0.999
val loss:0.0220 | accuracy:0.999
  ↳ best saved (0.999)


                                                  

[8] train loss:0.0534 | accuracy:1.000
val loss:0.0179 | accuracy:0.999
  ↳ best saved (0.999)


                                                 

[9] train loss:0.0439 | accuracy:1.000
val loss:0.0142 | accuracy:0.999
  ↳ best saved (0.999)


                                                  

[10] train loss:0.0387 | accuracy:1.000
val loss:0.0119 | accuracy:0.999
  ↳ best saved (0.999)


                                                 

[11] train loss:0.0328 | accuracy:1.000
val loss:0.0100 | accuracy:0.999
  ↳ best saved (0.999)


                                                  

[12] train loss:0.0280 | accuracy:1.000
val loss:0.0085 | accuracy:0.999
  ↳ best saved (0.999)


                                                 

[13] train loss:0.0249 | accuracy:1.000
val loss:0.0071 | accuracy:0.999
  ↳ best saved (0.999)


                                                 

[14] train loss:0.0218 | accuracy:1.000
val loss:0.0062 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[15] train loss:0.0194 | accuracy:1.000
val loss:0.0056 | accuracy:0.999


                                                 

[16] train loss:0.0172 | accuracy:1.000
val loss:0.0049 | accuracy:1.000
  ↳ best saved (1.000)


                                                  

[17] train loss:0.0157 | accuracy:1.000
val loss:0.0041 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[18] train loss:0.0134 | accuracy:1.000
val loss:0.0037 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[19] train loss:0.0125 | accuracy:1.000
val loss:0.0032 | accuracy:1.000
  ↳ best saved (1.000)


                                                  

[20] train loss:0.0112 | accuracy:1.000
val loss:0.0029 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[21] train loss:0.0104 | accuracy:1.000
val loss:0.0026 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[22] train loss:0.0093 | accuracy:1.000
val loss:0.0023 | accuracy:1.000
  ↳ best saved (1.000)


                                                  

[23] train loss:0.0089 | accuracy:1.000
val loss:0.0021 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[24] train loss:0.0079 | accuracy:1.000
val loss:0.0019 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[25] train loss:0.0074 | accuracy:1.000
val loss:0.0017 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[26] train loss:0.0069 | accuracy:1.000
val loss:0.0015 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[27] train loss:0.0065 | accuracy:1.000
val loss:0.0014 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[28] train loss:0.0060 | accuracy:1.000
val loss:0.0013 | accuracy:1.000
  ↳ best saved (1.000)


                                                 

[29] train loss:0.0054 | accuracy:1.000
val loss:0.0011 | accuracy:1.000
  ↳ best saved (1.000)


                                                  

[30] train loss:0.0050 | accuracy:1.000
val loss:0.0010 | accuracy:1.000
  ↳ best saved (1.000)

Best accuracy = 1.000 (epoch 30)




In [8]:
# Evaluate the saved best checkpoint on a held-out test set (disabled).
# NOTE(review): test_loader is only created by the commented-out 70/20/10 split
# cell; enable that first. map_location lets a checkpoint saved on GPU load on
# a CPU-only machine.
# model.load_state_dict(torch.load(ckpt_path, map_location=Settings.device))
# test_loss, test_acc = run(test_loader, False)
# print(f"Test loss:{test_loss:.4f} | accuracy:{test_acc:.3f}")