In [1]:
import time
import math
import cv2
import mediapipe as mp
import numpy as np
from handUnits import HandDetector
from gestureUnits import GestureDetector
import os

In [2]:
import torch
import torch.optim
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import TensorDataset,DataLoader
from sklearn.model_selection import train_test_split
import pandas as pd

%matplotlib inline

In [3]:
class CNN(nn.Module):
    """Four-stage convolutional classifier for 1-channel 28x28 inputs, 26 output logits.

    Every stage is Conv2d(3x3, stride 1, padding 1) -> BatchNorm2d -> ReLU.
    Stages 2 and 3 end with a 2x2 max-pool, so a 28x28 input reaches the
    fully connected head as a 128 x 7 x 7 feature map.
    """

    @staticmethod
    def _conv_stage(in_channels, out_channels, pool):
        """Build one conv/batch-norm/ReLU stage, optionally followed by a 2x2 max-pool."""
        modules = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        if pool:
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*modules)

    def __init__(self):
        super().__init__()
        # Attribute names and in-Sequential positions define the state_dict keys,
        # so they must stay exactly as they are for saved checkpoints to load.
        self.layers1 = self._conv_stage(1, 16, pool=False)
        self.layers2 = self._conv_stage(16, 32, pool=True)
        self.layers3 = self._conv_stage(32, 64, pool=True)
        self.layers4 = self._conv_stage(64, 128, pool=False)
        self.fc = nn.Sequential(
            nn.Linear(7 * 7 * 128, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 100),
            nn.ReLU(inplace=True),
            nn.Linear(100, 26),
        )

    def forward(self, x):
        """Return the (batch, 26) logits for a (batch, 1, 28, 28) input."""
        for stage in (self.layers1, self.layers2, self.layers3, self.layers4):
            x = stage(x)
        flat = x.view(x.size(0), -1)  # collapse everything but the batch axis
        return self.fc(flat)

In [4]:
cnn = CNN()
# This notebook only runs inference. A freshly constructed module is in training
# mode, where BatchNorm normalises with the statistics of the current batch; eval
# mode makes it use the running statistics stored in the checkpoint instead.
cnn.eval()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the prediction code calls .numpy(), which needs CPU tensors.
cnn.load_state_dict(torch.load('E:/MyProject/GR/dataset/archive/cnn_1.pth', map_location='cpu'))  # load the trained network parameters

<All keys matched successfully>

In [26]:
def Frame(pTime):
    """Draw the current frames-per-second figure onto the global frame ``img``.

    Args:
        pTime: ``time.time()`` timestamp recorded for the previous frame.

    Returns:
        The timestamp taken during this call; pass it back in on the next frame.
    """
    cTime = time.time()
    elapsed = cTime - pTime
    # Two calls can fall on the same clock tick, which would make the interval
    # zero; show 0 for that frame instead of raising ZeroDivisionError.
    fps = 1 / elapsed if elapsed > 0 else 0
    cv2.putText(
        img, str(int(fps)), (10, 70), cv2.FONT_HERSHEY_PLAIN, 3, (255, 255, 255), 2
    )
    return cTime



# Get the landmark coordinates
def get_lms(position, direction):
    """Collect the 21 landmark points of one hand as integer ``[x, y]`` pairs.

    Args:
        position: mapping indexed by hand label; each entry is looked up with
            ``.get(i)`` for ``i`` in 0..20 and must yield an indexable (x, y, ...).
        direction: the hand label to read from ``position``.

    Returns:
        A list of 21 ``[int(x), int(y)]`` pairs, or an empty list when
        ``any(position[direction])`` is false.
    """
    hand = position[direction]
    if not any(hand):
        return []
    return [
        [int(hand.get(idx)[0]), int(hand.get(idx)[1])]
        for idx in range(21)
    ]



# Get the region-of-interest image
def get_roi(lms, draw=True, show=False):
    """Cut a fixed 160x160 region of interest out of the global frame ``img``.

    The box has its top-left corner at (100, 200). The returned crop is inset
    by 2 pixels on every side so that the outline drawn when ``draw`` is set
    does not end up inside the crop.

    Args:
        lms: hand landmarks; accepted for the caller's convenience but not
            used, because the box position is fixed.
        draw: when true, draw the box outline onto ``img``.
        show: when true, display the crop in a window named 'Roi'.

    Returns:
        The slice of ``img`` inside the box.
    """
    x_0 = 100
    y_0 = 200
    width = 160
    height = 160

    if draw:
        cv2.rectangle(img, (x_0, y_0), (x_0 + width, y_0 + height), (0, 255, 255), 2)

    roi = img[
        y_0 + 2 : y_0 + height - 2, x_0 + 2 : x_0 + width - 2
    ]  # inset so a drawn rectangle is not included in the roi

    # Must come after the crop: showing it earlier referenced `roi` before it
    # was assigned and raised UnboundLocalError whenever show=True.
    if show:
        cv2.imshow('Roi', roi)

    return roi



In [32]:
def get_data_img(roi, size=28):
    """Shrink a BGR crop to ``size`` x ``size`` and convert it to grayscale.

    Args:
        roi: BGR image to convert.
        size: edge length of the square output (28 by default).

    Returns:
        The resized single-channel grayscale image.
    """
    resized = cv2.resize(roi, (size, size))
    return cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)

In [36]:
def cnnOut(img):
    """Classify a BGR image with the global ``cnn`` and return the predicted class index.

    The image is converted to a 28x28 grayscale array by ``get_data_img``,
    scaled to [0, 1] and reshaped to a (1, 1, 28, 28) float32 batch.

    Args:
        img: BGR image to classify.

    Returns:
        The index of the highest-scoring class, as a NumPy integer scalar.
    """
    gray = get_data_img(img)
    batch = torch.from_numpy(gray / 255).to(torch.float32).view(1, 1, 28, 28)
    # Prediction only: skip autograd bookkeeping instead of recording a graph on
    # every frame. This also replaces the deprecated Variable wrapper and `.data`.
    with torch.no_grad():
        out = cnn(batch)
    pre = torch.max(out, 1)[1]
    return pre.numpy()[0]

In [None]:
def modelOut(model, img):
    """Classify a BGR image with ``model`` and return the predicted class index.

    NOTE(review): this function was left without a body, which is a syntax
    error. It is implemented here as the model-parameterised counterpart of
    ``cnnOut``, which is what the signature suggests; confirm that is the intent.

    Args:
        model: callable network taking a (1, 1, 28, 28) float32 tensor and
            returning per-class scores; the caller is expected to have put it
            in eval mode.
        img: BGR image to classify.

    Returns:
        The index of the highest-scoring class, as a NumPy integer scalar.
    """
    gray = get_data_img(img)
    batch = torch.from_numpy(gray / 255).to(torch.float32).view(1, 1, 28, 28)
    with torch.no_grad():  # prediction only, no gradients needed
        out = model(batch)
    return torch.max(out, 1)[1].numpy()[0]

In [37]:
camera = cv2.VideoCapture(0)
hand_detector = HandDetector()

# Timestamp of the previous frame, used for the FPS overlay
pTime = 0

# try/finally so the camera and the windows are released even if a frame
# raises part-way through; otherwise the device stays locked until the kernel restarts.
try:
    while True:
        success, img = camera.read()

        if not success:
            print("获取失败")
            break

        img = cv2.flip(img, 1)  # mirror horizontally

        hand_detector.process(img, draw=False)
        position = hand_detector.find_position(img)

        # Landmark points of the right hand (not used further in this loop)
        right_lms = get_lms(position, 'Right')

        # Landmark points of the left hand
        left_lms = get_lms(position, 'Left')

        # Only classify while a left hand is detected
        if left_lms:
            roi = get_roi(left_lms)
            res = cnnOut(roi)
            cv2.putText(img, str(res), (200, 200), cv2.FONT_HERSHEY_PLAIN, 3, (0, 255, 255), 2)

        # Overlay the frame rate
        pTime = Frame(pTime)

        cv2.imshow('Video', img)

        # Press 'q' to quit
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    camera.release()
    cv2.destroyAllWindows()

In [9]:
# Run the classifier once on a still image as a sanity check.
img = cv2.imread('E:/MyProject/GR/dataset/test_0.jpg',cv2.IMREAD_GRAYSCALE)
# cv2.imread reports a missing or unreadable file by returning None rather than
# raising; fail here with a clear message instead of inside cv2.resize.
if img is None:
    raise FileNotFoundError('E:/MyProject/GR/dataset/test_0.jpg')
img = cv2.resize(img,(28,28))
img = img/255  # scale pixel values to [0, 1]
img = torch.from_numpy(img)
x = img.to(torch.float32).view(1,1,28,28)  # one sample, one channel
# Prediction only: no autograd graph needed (replaces the deprecated Variable / .data).
with torch.no_grad():
    out = cnn(x)
pre = torch.max(out,1)[1]  # index of the highest-scoring class, shape (1,)

In [23]:
pre_ = pre.numpy()
# pre_ is a one-element array (batch of one); index it so the label is shown as
# '0' rather than the array repr '[0]', matching what cnnOut returns.
str(pre_[0])

'0'