# Hand Gesture Transformer – Inference notebook

훈련된 checkpoint를 불러와 실제 이미지를 Mediapipe로 전처리한 뒤 제스처를 분류합니다.


In [None]:
import cv2
import torch
import torch.nn as nn
import mediapipe as mp
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
from models.HandGestureTransformer import HandGestureTransformer
from models.img_to_landmarks import img_to_landmarks

# Model definition — kept identical to the training notebook so the checkpoint's
# state_dict keys and tensor shapes line up when loading.
# NOTE: this definition shadows the HandGestureTransformer imported from
# models.HandGestureTransformer above.
class HandGestureTransformer(nn.Module):
    """Transformer encoder that classifies a hand pose from 3-D landmarks.

    Each landmark's (x, y, z) triple is linearly projected to ``d_model``, a
    learnable CLS token is prepended, a learnable positional embedding is added,
    and the sequence runs through ``num_layers`` encoder layers. The
    LayerNorm-ed CLS token is then mapped to ``n_gestures`` logits.

    Args:
        d_model: Token embedding width.
        num_layers: Number of ``nn.TransformerEncoderLayer`` blocks.
        num_heads: Attention heads per encoder layer.
        n_gestures: Number of output classes.
    """

    def __init__(self, d_model=128, num_layers=4, num_heads=8, n_gestures=4):
        super().__init__()
        # Creation order matches the training notebook, so random init under a
        # fixed seed (and the state_dict layout) is unchanged.
        self.proj = nn.Linear(3, d_model)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, d_model))
        # 22 positions = 1 CLS token + 21 landmark tokens.
        self.pos = nn.Parameter(torch.randn(1, 22, d_model))
        encoder_blocks = []
        for _ in range(num_layers):
            encoder_blocks.append(
                nn.TransformerEncoderLayer(
                    d_model, num_heads, 4 * d_model, dropout=0.1, batch_first=True
                )
            )
        self.layers = nn.ModuleList(encoder_blocks)
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, n_gestures)

    def forward(self, xyz):
        """Return class logits of shape (batch, n_gestures).

        Args:
            xyz: Landmark tensor. The positional embedding fixes the sequence
                length at 21, so this is expected to be (batch, 21, 3).
        """
        batch_size = xyz.shape[0]
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = self.proj(xyz)
        hidden = torch.cat([cls, tokens], 1) + self.pos
        for encoder in self.layers:
            hidden = encoder(hidden)
        hidden = self.norm(hidden)
        # Classify from the CLS token only.
        return self.head(hidden[:, 0])


In [None]:
checkpoint_path = 'checkpoint/ckpt_best.pt'  # adjust to your checkpoint location
# Run on the GPU when available, otherwise on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Class names indexed by the model's output logits. The classifier head size is
# derived from this list so the labels and logits can never drift apart; a
# mismatch with the checkpoint then fails loudly in load_state_dict.
GESTURES = ['gesture0', 'gesture1', 'gesture2', 'gesture3']

model = HandGestureTransformer(n_gestures=len(GESTURES)).to(device)
# The checkpoint stores the model weights under the 'model' key.
# NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(checkpoint_path, map_location=device)['model'])
model.eval()  # inference mode: disables the encoder layers' dropout

# Single-hand, still-image MediaPipe detector shared by the inference cells.
mp_hands = mp.solutions.hands.Hands(static_image_mode=True, max_num_hands=1)


In [None]:
def img_to_landmarks(img_path, mp_hands=None):
    """Detect one hand in an image file and return its landmarks as a tensor.

    NOTE: this definition shadows the ``img_to_landmarks`` imported from
    ``models.img_to_landmarks`` above.

    Args:
        img_path: Path (str or path-like) to an image readable by ``cv2.imread``.
        mp_hands: A MediaPipe ``Hands`` detector. If omitted, a temporary
            single-hand, static-image detector is created and closed for
            this call only.

    Returns:
        A float32 CPU tensor of shape (1, num_landmarks, 3) holding each
        landmark's ``(x, y, z)`` exactly as reported by MediaPipe.

    Raises:
        FileNotFoundError: If the image cannot be read.
        ValueError: If no hand is detected.
    """
    img_bgr = cv2.imread(str(img_path))
    # cv2.imread returns None instead of raising on a missing or unreadable
    # file; without this check the failure surfaces as an opaque cvtColor error.
    if img_bgr is None:
        raise FileNotFoundError(f'Could not read image: {img_path}')
    # OpenCV loads BGR; MediaPipe is fed RGB.
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    if mp_hands is None:
        # Same settings as the notebook's shared detector; closed on exit.
        with mp.solutions.hands.Hands(static_image_mode=True, max_num_hands=1) as detector:
            res = detector.process(img_rgb)
    else:
        res = mp_hands.process(img_rgb)

    if not res.multi_hand_landmarks:
        raise ValueError('No hand detected')
    # Only the first detected hand is used.
    lm = res.multi_hand_landmarks[0]
    xyz = np.array([[pt.x, pt.y, pt.z] for pt in lm.landmark], dtype=np.float32)
    # Add a leading batch dimension of 1 for the model.
    return torch.tensor(xyz).unsqueeze(0)

In [None]:
img_path = 'assets/example1.png'

# Reuse the shared MediaPipe detector created in the setup cell.
xyz = img_to_landmarks(img_path, mp_hands).to(device)
with torch.no_grad():
    logits = model(xyz)
# One image in the batch: take its row as a probability vector over GESTURES.
prob = torch.softmax(logits, -1).cpu().numpy()[0]
for gesture, p in zip(GESTURES, prob):
    print(f'{gesture}: {p:.4f}')
print('Predicted:', GESTURES[int(prob.argmax())])