In [2]:
import cv2

import torchvision.models.detection as torch_detect
import torchvision
from torchvision import transforms
import torch

from utils.unet import UNet

In [3]:
DETECTION_PATH = 'detection/m_keypoints.pth'
ESTIMATION_PATH = 'model/m_keypoints.pth'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

KEYPOINTS_NUM = 21
NUM_CLASSES = 2

RESIZED_IMG_SIZE = 128
MEANS = [0.3950, 0.4323, 0.2954]
STDS = [0.1966, 0.1734, 0.1836]

In [4]:
detection = torch_detect.fasterrcnn_resnet50_fpn()

in_features = detection.roi_heads.box_predictor.cls_score.in_features
detection.roi_heads.box_predictor = torch_detect.faster_rcnn.FastRCNNPredictor(in_features, NUM_CLASSES)

detection.load_state_dict(torch.load(DETECTION_PATH))
detection = detection.to(DEVICE)
detection.eval()

detection_transform = transforms.ToTensor()

  detection.load_state_dict(torch.load(DETECTION_PATH))


In [5]:
estimation = UNet(KEYPOINTS_NUM)
estimation.load_state_dict(torch.load(ESTIMATION_PATH))
estimation = estimation.to(DEVICE)
estimation.eval()

estimation_transform = transforms.Compose([
    transforms.Resize((RESIZED_IMG_SIZE, RESIZED_IMG_SIZE)),
    transforms.Normalize(mean=MEANS, std=STDS)
])

  estimation.load_state_dict(torch.load(ESTIMATION_PATH))


In [6]:
connections = [
    (0, 1), (1, 2), (2, 3), (3, 4),  # Thumb
    (0, 5), (5, 6), (6, 7), (7, 8),  # Index finger
    (0, 9), (9, 10), (10, 11), (11, 12),  # Middle finger
    (0, 13), (13, 14), (14, 15), (15, 16),  # Ring finger
    (0, 17), (17, 18), (18, 19), (19, 20)  # Pinky
]

In [7]:
import numpy as np
from utils.utils import keypoints_from_heatmaps
import torchvision.transforms.functional as F

cap = cv2.VideoCapture(0)
out = cv2.VideoWriter('output.avi', cv2.VideoWriter_fourcc(*'MJPG'), 10, (640,480))
if not cap.isOpened():
    print("Cannot open camera")
    exit()

while True:
    ret, frame = cap.read()

    if not ret:
        print("Can't receive frame (stream end?). Exiting ...")
        break
    
    frame_model = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_model = detection_transform(frame_model).unsqueeze(0).to(DEVICE)
    
    with torch.no_grad():
        detection_res = detection(frame_model)

    boxes, scores = detection_res[0]['boxes'], detection_res[0]['scores']
    boxes = boxes[scores > 0.9]

    padding = 100

    for i, box in enumerate(boxes):
        x1, y1, x2, y2 = box.cpu().detach().numpy().astype(int)

        crop_x1, crop_y1 = x1 - padding, y1 - padding
        width, height = x2 - crop_x1 + padding, y2 - crop_y1 + padding

        hand_image = F.crop(frame_model.squeeze(0), crop_y1, crop_x1, height, width)

        flip = x1 > (frame_model.shape[3] / 2)
        if flip:
            hand_image = F.vflip(hand_image)
        hand_image = estimation_transform(hand_image.unsqueeze(0)).to(DEVICE)
        
        with torch.no_grad():
            heatmap = estimation(hand_image)
        if flip:
            heatmap = F.vflip(heatmap)

        keypoints = keypoints_from_heatmaps(heatmap[0], RESIZED_IMG_SIZE).cpu().detach().numpy()
        
        keypoints[:, 0] = (keypoints[:, 0] * width + crop_x1)
        keypoints[:, 1] = (keypoints[:, 1] * height + crop_y1)

        keypoints = keypoints.astype(int)

        cv2.rectangle(frame, (x1,y1), (x2,y2), (0,0,255), 2)
        for point in keypoints:
            cv2.circle(frame, (point[0], point[1]), radius=2, color=(255, 0, 0), thickness=3)

        for connection in connections:
            x1, y1 = keypoints[connection[0]]
            x2, y2 = keypoints[connection[1]]
            
            # Draw each line between two connected keypoints
            cv2.line(frame, (x1, y1), (x2, y2), color=(255, 0, 0), thickness=1) 

    out.write(frame)
    cv2.imshow('frame', frame)
    if cv2.waitKey(1) == ord('q'):
        break
out.release
cap.release()
cv2.destroyAllWindows()