In [1]:
!pip install mediapipe



In [2]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import scipy.io as sio
import time
import argparse
import os.path as osp
import torch
from tqdm.notebook import tqdm as tqdm
from IPython.display import HTML, Audio, display, Javascript, Image
from base64 import b64decode, b64encode
import numpy as np
import html
import time
import io
import PIL
from PIL import Image
import cv2
import matplotlib.pyplot as plt
from torch import nn as nn
import os
import mediapipe as mp

In [3]:
class handDetector():
    """Wrapper around MediaPipe Hands that detects hands in BGR frames and derives
    an orientation reading from two landmarks (5 and 17) of one detected hand.

    Call ``find_hands`` on a frame first (it runs detection and caches the result in
    ``self.results``), then ``find_pos`` to read the orientation from that cache.
    """

    def __init__(self, mode=True, maxHands=2, detectionCon=0.5, trackCon=0.5):
        """
        Args:
            mode: forwarded to MediaPipe as ``static_image_mode``.
            maxHands: forwarded as ``max_num_hands``.
            detectionCon: forwarded as ``min_detection_confidence``.
            trackCon: forwarded as ``min_tracking_confidence``.
        """
        self.mode = mode
        self.maxHands = maxHands
        self.detectionCon = detectionCon
        self.trackCon = trackCon
        # No detection has run yet; find_pos treats this as "no hands found".
        self.results = None

        self.mpHands = mp.solutions.hands
        # Keyword arguments are required: newer MediaPipe releases added
        # ``model_complexity`` as the third positional parameter of Hands, so the
        # old positional call put detectionCon into model_complexity and trackCon
        # into min_detection_confidence.
        self.hands = self.mpHands.Hands(
            static_image_mode=self.mode,
            max_num_hands=self.maxHands,
            min_detection_confidence=self.detectionCon,
            min_tracking_confidence=self.trackCon,
        )
        self.mpDraw = mp.solutions.drawing_utils

    def find_hands(self, img, draw=True):
        """Run hand detection on a BGR frame and cache the result in ``self.results``.

        Args:
            img: BGR image (converted to RGB before being handed to MediaPipe).
            draw: if True, draw the detected landmarks and connections onto ``img``
                in place.

        Returns:
            The same ``img`` object (annotated in place when ``draw`` is True).
        """
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        self.results = self.hands.process(img_rgb)

        if draw and self.results.multi_hand_landmarks:
            for hand_landmarks in self.results.multi_hand_landmarks:
                self.mpDraw.draw_landmarks(img, hand_landmarks, self.mpHands.HAND_CONNECTIONS)
        return img

    def get_num(self, num):
        """Truncate (not round) a decimal string to at most three fractional digits.

        ``"3.14159"`` -> ``"3.141"``. Non-string input is converted with ``str``
        first. Input without a ``.`` is returned unchanged; previously it raised
        IndexError.
        """
        parts = str(num).split(".")
        if len(parts) < 2:
            return str(num)
        return parts[0] + "." + parts[1][:3]

    def find_pos(self, img, hand_num=0, draw=False):
        """Estimate the orientation of one detected hand from landmarks 5 and 17.

        Uses the detection cached by the most recent ``find_hands`` call.

        Args:
            img: the frame given to ``find_hands``. Only its height and width are
                read, and it is drawn on when ``draw`` is True.
            hand_num: index into the detected hands (negative indices allowed).
            draw: if True, draw every landmark plus a debug pointer and text.

        Returns:
            ``(angle, xz_angle)`` as floats, or ``(None, None)`` when no usable hand
            is available. That covers three cases: nothing detected (or
            ``find_hands`` never ran), ``hand_num`` out of range, or landmarks 5 and
            17 coinciding so that no direction exists.

            With ``d = landmark5 - landmark17``:

            * ``angle = atan2(d.x, d.z) + pi`` is the heading of ``d`` in the
              x/z plane, in ``[0, 2*pi]``.
            * ``xz_angle = 1 - |2 * asin(d.y / |d|) / pi|`` is 1 when ``d`` lies in
              the x/z plane and falls to 0 as ``d`` aligns with the y axis.
        """
        results = getattr(self, "results", None)
        hands = results.multi_hand_landmarks if results is not None else None
        if not hands or not -len(hands) <= hand_num < len(hands):
            return None, None

        landmarks = hands[hand_num].landmark

        if draw:
            height, width = img.shape[:2]
            for lm in landmarks:
                cv2.circle(img, (int(lm.x * width), int(lm.y * height)), 5, (255, 0, 255), cv2.FILLED)

        # Landmark 5 minus landmark 17. In MediaPipe's hand topology these are the
        # index and pinky MCP knuckles; see the MediaPipe Hands docs.
        p5, p17 = landmarks[5], landmarks[17]
        difference = np.array([p5.x - p17.x, p5.y - p17.y, p5.z - p17.z], dtype=np.float64)

        magnitude = np.linalg.norm(difference)
        if magnitude == 0:
            # Degenerate detection: there is no direction. 0/0 would give NaN,
            # which crashes int() in the pointer-drawing code downstream.
            return None, None

        angle = float(np.arctan2(difference[0], difference[2]) + np.pi)

        # d.y / |d| is the dot product with the unit y axis, i.e. the sine of d's
        # elevation out of the x/z plane. The clip keeps rounding from pushing it
        # outside asin's domain.
        sin_elevation = np.clip(difference[1] / magnitude, -1.0, 1.0)
        xz_angle = float(1 - np.abs((np.arcsin(sin_elevation) / np.pi) * 2))

        if draw:
            tip_x = int(np.cos(angle) * xz_angle * 512) + 128
            tip_y = int(np.sin(angle) * xz_angle * 512) + 128

            cv2.circle(img, (tip_x, tip_y), 5, (255, 0, 0), cv2.FILLED)
            cv2.line(img, (128, 128), (tip_x, tip_y), (255, 255, 0), 2)

            cv2.putText(img, f'xz_angels {str(xz_angle)}', (2, 30), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 255), 2)
            cv2.putText(img, f'angles {angle}', (2, 50), cv2.FONT_HERSHEY_PLAIN, 1, (0, 0, 255), 2)

        return angle, xz_angle

In [4]:
def compute_zplane_angle(posmat, joints=(17,5), radians=True):
  """Signed angle between the vector posmat[:, joints[0]] - posmat[:, joints[1]]
  and the x/z plane.

  Computed as asin(v . y_hat / |v|) with y_hat = (0, 1, 0), so the result lies in
  [-pi/2, pi/2] radians, or [-90, 90] degrees when ``radians`` is False.

  Args:
    posmat: tensor indexed as ``posmat[:, joint]``. The last dimension of that
      slice must broadcast against a length-3 (x, y, z) vector. Presumably the
      shape is (1, num_joints, 3); TODO confirm against the caller. The norm and
      the sum both reduce over the whole difference tensor, so this is only
      meaningful for a single vector (a batch of one).
    joints: (start, end) joint indices.
    radians: return radians if True, degrees otherwise.

  Returns:
    A 0-dim tensor. NaN if the two joints coincide (zero-length vector).
  """
  # Only the subtrahend is cast, as before. Integer/uint8 positions are promoted
  # to float before subtracting, so unsigned values cannot wrap around.
  vec = posmat[:,joints[0]] - posmat[:,joints[1]].type(torch.float)
  # Build the reference axis on vec's device and dtype. A CPU float32 constant
  # raised a device-mismatch error for CUDA inputs.
  y_axis = torch.tensor([0, 1, 0], dtype=vec.dtype, device=vec.device)

  sin_elevation = torch.sum(vec * y_axis) / (vec.norm() * y_axis.norm())
  # Rounding can push the ratio marginally outside [-1, 1], where asin is NaN.
  angle = torch.asin(torch.clamp(sin_elevation, -1.0, 1.0))

  if radians:
    return angle

  return (angle/torch.pi)*180


def compute_x_z_angle(posmat, joints=(17,5), radians=True):
  """Heading of the vector posmat[:, joints[0]] - posmat[:, joints[1]] in the x/z plane.

  The difference is squeezed, and its first axis is read as x = [0], z = [2].
  Presumably posmat is (1, num_joints, 3), so the squeeze leaves one 3-vector;
  TODO confirm against the caller.

  Returns:
    atan2(x, z) as a tensor: radians in [-pi, pi] by default, degrees when
    ``radians`` is False.
  """
  start_joint, end_joint = joints[0], joints[1]
  delta = (posmat[:, start_joint] - posmat[:, end_joint]).squeeze()
  heading = torch.atan2(delta[0], delta[2])
  return heading if radians else (heading / torch.pi) * 180

In [5]:
class PoseToVecModel(torch.nn.Module):
    """Small MLP. It consists of Linear(inputSize, width) + PReLU, then ``depth``
    hidden Linear(width, width) + PReLU stages, then Linear(width, outputSize).

    The attribute names ``last``, ``linear`` and ``rels`` and their registration
    order are deliberate. Whole-model checkpoints are unpickled with
    ``torch.load``, and registration order fixes state_dict/parameters() order.
    """

    def __init__(self, inputSize=2, outputSize=2, depth=0, width=8):
        super().__init__()
        # Fan-in of each hidden Linear: the input size first, then ``depth`` x width.
        fan_ins = [inputSize] + [width] * depth
        hidden = [nn.Linear(fan_in, width) for fan_in in fan_ins]
        activations = [torch.nn.PReLU() for _ in fan_ins]

        # ``last`` is registered before the ModuleLists, matching the original order.
        self.last = nn.Linear(width, outputSize)
        self.linear = torch.nn.ModuleList(hidden)
        self.rels = torch.nn.ModuleList(activations)

    def forward(self, x):
        """Apply each Linear followed by its PReLU, then the output Linear."""
        for layer, activation in zip(self.linear, self.rels):
            x = activation(layer(x))
        return self.last(x)

In [7]:
import inspect
from collections import deque

# Whole-model checkpoint written by torch.save(model). Override with the
# HANDY_MODEL_CHECKPOINT environment variable instead of editing this
# machine-specific default.
model_checkpoint_path = os.environ.get(
    "HANDY_MODEL_CHECKPOINT",
    "C:/Users/Nir Ben Dor/Documents/Afeka/Handy Project/Midterm Report/21.03.21/checkpoint_6000_gtloss_1.0657_shouldbegood_new_model.pth",
)
CAMERA_INDEX = 1        # OpenCV device index of the webcam
SMOOTHING_WINDOW = 5    # number of recent predictions averaged for the pointer
POINTER_ORIGIN = 128    # pixel (x and y) where the pointer line starts
POINTER_SCALE = 128     # pointer length in pixels when the predicted xz value is 1

# torch.load unpickles arbitrary Python objects, so only load trusted checkpoints.
# PyTorch >= 2.6 defaults to weights_only=True, which refuses whole-module
# pickles; older releases do not accept the argument at all.
# map_location="cpu" lets a GPU-saved checkpoint load on a CPU-only machine and
# keeps the model on the same device as the CPU input tensors built below.
load_kwargs = {"map_location": "cpu"}
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
model = torch.load(model_checkpoint_path, **load_kwargs)
model.eval()

detector = handDetector()
cap = cv2.VideoCapture(CAMERA_INDEX)
if not cap.isOpened():
    raise RuntimeError(f"Could not open camera {CAMERA_INDEX}")

# Rolling windows of recent predictions; deque(maxlen=...) evicts the oldest.
# The previous 5-slot list was written at count % 4, so slot 4 stayed 0 forever
# and dragged every average toward zero.
prev_angles = deque(maxlen=SMOOTHING_WINDOW)
prev_xz_angles = deque(maxlen=SMOOTHING_WINDOW)

try:
    while True:
        success, img = cap.read()
        if not success:
            # No frame (camera disconnected or busy); img would be None.
            break

        with torch.no_grad():
            detector.find_hands(img, draw=True)
            angle, xz_angle = detector.find_pos(img, draw=False)

            if angle is not None:
                features = torch.tensor([angle, xz_angle], dtype=torch.float32)
                pred_angle, pred_xz = model(features).tolist()

                prev_angles.append(pred_angle)
                prev_xz_angles.append(pred_xz)

                # Circular mean for the heading. Only sin/cos of it are used
                # below, so samples either side of the 0/2*pi wrap must not
                # average to the opposite direction.
                angle_samples = np.asarray(prev_angles)
                avg_angle = np.arctan2(np.mean(np.sin(angle_samples)), np.mean(np.cos(angle_samples)))
                avg_xz = float(np.mean(np.asarray(prev_xz_angles)))

                y = int(np.cos(avg_angle) * avg_xz * POINTER_SCALE) + POINTER_ORIGIN
                x = int(np.sin(avg_angle) * avg_xz * POINTER_SCALE) + POINTER_ORIGIN

                cv2.circle(img, (x, y), 5, (255, 0, 0), cv2.FILLED)
                cv2.line(img, (POINTER_ORIGIN, POINTER_ORIGIN), (x, y), (255, 255, 0), 2)

        cv2.imshow("Image", img)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the camera and windows, even if the loop raised.
    cap.release()
    cv2.destroyAllWindows()