In [1]:
import os
import cv2
import mediapipe as mp
import numpy as np
import pandas as pd
from tqdm import tqdm
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# (The previous bare `torch.cuda.is_available()` call discarded its result.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [5]:
def keypoints_transform(keypoints):
    """Normalise a flat list of hand-landmark coordinates.

    The 63 input values are viewed as 21 rows of (x, y, z). Every row is
    expressed relative to row 0 (treated as the wrist), and the result is
    divided by its largest absolute value so that all entries fall in
    [-1, 1]. A single common scale is used for all axes, which keeps the
    proportions of the hand intact.

    Args:
        keypoints: 63 numeric values (anything ``torch.tensor`` accepts).

    Returns:
        A float32 tensor of shape (63,). If every point coincides with the
        wrist, the scaling step is skipped and the tensor is all zeros.
    """
    landmarks = torch.tensor(keypoints, dtype=torch.float32).reshape(21, 3)

    # Translate so the wrist (row 0) becomes the origin; broadcasting
    # subtracts the (3,) wrist vector from each of the 21 rows.
    relative = landmarks - landmarks[0]

    # Scale by the largest absolute coordinate; skip when it is zero to
    # avoid dividing by zero.
    scale = relative.abs().max()
    if scale > 0:
        relative = relative / scale

    return relative.flatten()

In [6]:
class MLP(nn.Module):
    """Small fully connected classifier: 63 input features -> 3 logits.

    Architecture: Linear(63, 64) -> ReLU -> Dropout -> Linear(64, 64) ->
    ReLU -> Dropout -> Linear(64, 3). The output is raw logits (no softmax).

    Note: ``self.bn`` is registered but never applied in ``forward``. It
    still contributes ``bn.*`` entries to the ``state_dict``, so removing it
    could break loading checkpoints saved with this definition.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stable so repr() and parameter
        # initialisation order do not change.
        self.fc1 = nn.Linear(63, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 3)
        self.bn = nn.BatchNorm1d(64)  # unused in forward (see class docstring)
        self.dropout = nn.Dropout(0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a (batch, 63) tensor to (batch, 3) class logits."""
        hidden = self.dropout(self.relu(self.fc1(x)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))
        return self.fc3(hidden)

In [7]:
# Presumably the checkpoint was produced with:
#     torch.save(model.state_dict(), "mlp_model_weights.pth")

model = MLP().to(device)
# map_location remaps tensors that were saved on another device (e.g. CUDA)
# onto `device`, so the checkpoint also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (recent PyTorch versions offer `weights_only=True`).
model.load_state_dict(torch.load("mlp_model_weights.pth", map_location=device))
# Switch to inference mode (disables dropout); the last expression also
# displays the module summary as the cell output.
model.eval()

MLP(
  (fc1): Linear(in_features=63, out_features=64, bias=True)
  (fc2): Linear(in_features=64, out_features=64, bias=True)
  (fc3): Linear(in_features=64, out_features=3, bias=True)
  (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.5, inplace=False)
  (relu): ReLU()
)

In [9]:
# ---------- Load the gesture PNGs ----------
# Keys are classifier output indices (0 = rock, 1 = paper, 2 = scissors).
# IMREAD_UNCHANGED keeps the alpha channel when the file has one.
# Paths are built with os.path.join so they also resolve on non-Windows
# systems; cv2.imread returns None instead of raising when a file cannot be
# read, so fail fast here with a clear message.
imgs = {}
for class_id, file_name in {0: "rock.png", 1: "paper.png", 2: "scissor.png"}.items():
    image_path = os.path.join("images", file_name)
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        raise FileNotFoundError(f"Could not read gesture image: {image_path}")
    imgs[class_id] = image

# ---------- Counter-gesture rules ----------
# Maps the predicted gesture to the gesture that beats it:
# rock -> paper, paper -> scissors, scissors -> rock.
opposite = {0: 1, 1: 2, 2: 0}

# ---------- Mediapipe ----------
mp_draw = mp.solutions.drawing_utils
mp_hands = mp.solutions.hands
hands = mp_hands.Hands(
    static_image_mode=False,
    max_num_hands=1,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5
)

# ---------- Camera ----------
# Opened last so the device is not held if loading the assets above fails.
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    cap.release()
    raise RuntimeError("Could not open camera 0")


# ---------- PNG 叠加函数 ----------
def overlay_png(background, png, x, y):
    """Alpha-blend ``png`` onto ``background`` with its top-left corner at (x, y).

    ``background`` is modified in place and also returned.

    Args:
        background: image array of shape (H, W, C) with C >= 3; only the
            first three channels are written.
        png: overlay array of shape (h, w, 4) whose fourth channel is an
            8-bit alpha mask, or (h, w, 3), which is treated as fully opaque.
        x, y: column and row of the overlay's top-left corner in
            ``background``.

    Returns:
        ``background``. It is returned unchanged when the overlay does not
        fit entirely inside it (including negative offsets).

    Raises:
        ValueError: if ``png`` is None (e.g. a failed ``cv2.imread``) or does
            not have at least three channels.
    """
    if png is None:
        raise ValueError("overlay_png: png is None (did the image fail to load?)")
    if png.ndim != 3 or png.shape[2] < 3:
        raise ValueError("overlay_png: png must have shape (h, w, 3) or (h, w, 4)")

    h, w = png.shape[:2]
    bg_h, bg_w = background.shape[:2]
    # Skip overlays that are not fully inside the frame. Negative offsets are
    # rejected as well: as slice bounds they would wrap around and select a
    # region of the wrong size.
    if x < 0 or y < 0 or x + w > bg_w or y + h > bg_h:
        return background

    # Basic slicing yields a view, so writing to `region` updates `background`.
    region = background[y:y+h, x:x+w, :3]
    bgr = png[:, :, :3]

    if png.shape[2] == 3:
        # No alpha channel: plain opaque copy.
        region[...] = bgr
        return background

    # Keep a trailing axis of length 1 so alpha broadcasts over the channels.
    alpha = png[:, :, 3:4] / 255.0
    # The float result is cast back to the background dtype on assignment.
    region[...] = bgr * alpha + region * (1 - alpha)
    return background


# Label for each classifier output index (hoisted out of the frame loop).
class_names = ["Rock", "Paper", "Scissors"]

# try/finally guarantees the camera and the window are released even when
# the loop is interrupted or a frame raises an exception.
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # frame = cv2.flip(frame, 1)
        # OpenCV delivers BGR frames; the hand detector is fed RGB.
        img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = hands.process(img_rgb)

        if results.multi_hand_landmarks:
            # Classify the first detected hand.
            hand_landmarks = results.multi_hand_landmarks[0]
            keypoints = []
            for lm in hand_landmarks.landmark:
                keypoints.extend([lm.x, lm.y, lm.z])

            keypoints = keypoints_transform(keypoints).unsqueeze(0).to(device)
            # Inference only: skip building the autograd graph on every frame.
            with torch.no_grad():
                res = model(keypoints)
            predicted_class = torch.argmax(res, dim=1).item()
            cv2.putText(frame, f'Prediction: {class_names[predicted_class]}', (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

            # Show the gesture that beats the predicted one; skip the overlay
            # if that image is missing (cv2.imread yields None on failure).
            counter_img = imgs[opposite[predicted_class]]
            if counter_img is not None:
                frame = overlay_png(frame, counter_img, 10, 50)

            # Draw the skeleton and a bounding box for every detected hand.
            frame_h, frame_w = frame.shape[:2]
            for detected_hand in results.multi_hand_landmarks:
                mp_draw.draw_landmarks(frame, detected_hand, mp_hands.HAND_CONNECTIONS)
                # Landmark coordinates are scaled by the frame size to get pixels.
                xs = [int(lm.x * frame_w) for lm in detected_hand.landmark]
                ys = [int(lm.y * frame_h) for lm in detected_hand.landmark]
                cv2.rectangle(frame, (min(xs), min(ys)), (max(xs), max(ys)), (0, 255, 0), 2)

        cv2.imshow("Hand Tracking", frame)

        # Press 'q' to quit.
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    cap.release()
    cv2.destroyAllWindows()