## Install

In [1]:
!pip install tensorflow==2.10.0 tensorflow-gpu==2.10.0 tensorflow-hub opencv-python matplotlib



## Imports

In [2]:
import tensorflow as tf
import tensorflow_hub as hub
import cv2
from matplotlib import pyplot as plt
import numpy as np

In [3]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')

try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as err:
    # Memory growth can only be configured before the GPUs are initialized.
    # Re-running this cell after the model has executed raises RuntimeError;
    # report it instead of aborting the notebook.
    print(err)

## Model

In [4]:
# MoveNet MultiPose (Lightning) from TF Hub. The cells below reshape its output
# into 6 persons x 17 keypoints x (y, x, score).
MOVENET_URL = 'https://tfhub.dev/google/movenet/multipose/lightning/1'

model = hub.load(MOVENET_URL)
# Callable inference entry point of the SavedModel.
final_model = model.signatures['serving_default']

## Real-time Pose Estimation

In [5]:
def person_points(frame, keypoints, score_threshold, radius=4, color=(0, 255, 0)):
    """Draw one person's confident keypoints onto ``frame`` in place.

    Args:
        frame: Image array of shape (H, W) or (H, W, C); modified in place.
        keypoints: Array-like of shape (K, 3) whose rows are ``(y, x, score)``.
            ``y``/``x`` are presumably normalized to [0, 1]; they are scaled by
            the frame height/width here.
        score_threshold: A keypoint is drawn when ``score >= score_threshold``.
        radius: Circle radius in pixels (default 4).
        color: Circle color in the frame's channel order (default (0, 255, 0)).

    Returns:
        None.
    """
    # Only height and width are needed; slicing also accepts 2-D (grayscale) frames.
    height, width = frame.shape[:2]

    # Scale normalized (y, x) to pixels; the score column is multiplied by 1.
    shaped_keypoints = np.multiply(keypoints, [height, width, 1])

    for p_y, p_x, p_s in shaped_keypoints:
        if p_s >= score_threshold:
            # OpenCV takes integer (x, y) centers; thickness -1 fills the circle.
            cv2.circle(frame, (int(p_x), int(p_y)), radius, color, -1)
        

In [6]:
# Skeleton edges to draw: (keypoint index, keypoint index) -> line color.
# Colors are BGR tuples, the channel order OpenCV drawing functions use.
# NOTE(review): the index pairs appear to follow MoveNet's 17-keypoint order
# (0 nose, 1/2 eyes, 3/4 ears, 5/6 shoulders, 7/8 elbows, 9/10 wrists,
# 11/12 hips, 13/14 knees, 15/16 ankles; odd = left, even = right) -- confirm
# against the MoveNet model card.
_BGR_GREEN = (0, 255, 0)
_BGR_RED = (0, 0, 255)
_BGR_BLUE = (255, 0, 0)

EDGES = {
    # head
    (0, 1): _BGR_GREEN,
    (0, 2): _BGR_RED,
    (1, 3): _BGR_GREEN,
    (2, 4): _BGR_RED,
    # head to shoulders
    (0, 5): _BGR_GREEN,
    (0, 6): _BGR_RED,
    # arms
    (5, 7): _BGR_GREEN,
    (7, 9): _BGR_GREEN,
    (6, 8): _BGR_RED,
    (8, 10): _BGR_RED,
    # torso
    (5, 6): _BGR_BLUE,
    (5, 11): _BGR_GREEN,
    (6, 12): _BGR_RED,
    (11, 12): _BGR_BLUE,
    # legs
    (11, 13): _BGR_GREEN,
    (13, 15): _BGR_GREEN,
    (12, 14): _BGR_RED,
    (14, 16): _BGR_RED,
}

In [7]:
def person_edges(frame, keypoints, edges, score_threshold, thickness=2):
    """Draw the skeleton lines for one person onto ``frame`` in place.

    Args:
        frame: Image array of shape (H, W) or (H, W, C); modified in place.
        keypoints: Array-like of shape (K, 3) whose rows are ``(y, x, score)``.
            ``y``/``x`` are presumably normalized to [0, 1]; they are scaled by
            the frame height/width here.
        edges: Mapping ``(i, j) -> color`` of keypoint index pairs to connect.
        score_threshold: A line is drawn only when both endpoint scores are
            strictly greater than this value.
        thickness: Line thickness in pixels (default 2).

    Returns:
        None.
    """
    # Only height and width are needed; slicing also accepts 2-D (grayscale) frames.
    height, width = frame.shape[:2]

    # Scale normalized (y, x) to pixels; the score column is multiplied by 1.
    shaped_keypoints = np.multiply(keypoints, [height, width, 1])

    for (p1, p2), color in edges.items():
        p1_y, p1_x, p1_s = shaped_keypoints[p1]
        p2_y, p2_x, p2_s = shaped_keypoints[p2]

        # Both endpoints must be confident (strict >), otherwise the line is skipped.
        if p1_s > score_threshold and p2_s > score_threshold:
            # OpenCV takes integer (x, y) points.
            cv2.line(frame, (int(p1_x), int(p1_y)), (int(p2_x), int(p2_y)), color, thickness)
    

In [8]:
def persons_points_edges(frame, persons_keypoints_scores, edges, score_threshold):
    """Render keypoints and skeleton edges for every person onto ``frame``.

    Args:
        frame: Image array drawn on in place.
        persons_keypoints_scores: Iterable of per-person keypoint arrays, each
            in the format accepted by ``person_points`` / ``person_edges``.
        edges: Mapping ``(i, j) -> color`` forwarded to ``person_edges``.
        score_threshold: Confidence cut-off forwarded to both helpers.

    Returns:
        None.
    """
    for single_person in persons_keypoints_scores:
        # Dots are drawn first, so the connecting lines end up on top of them.
        person_points(frame, single_person, score_threshold)
        person_edges(frame, single_person, edges, score_threshold)

    

In [9]:
cap = cv2.VideoCapture(0)
if not cap.isOpened():
    print('Could not open webcam (device 0); skipping real-time pose estimation.')

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No frame was delivered (e.g. camera disconnected); `frame` would be None.
            break

        # MoveNet expects RGB input; OpenCV delivers BGR. Drawing still happens on
        # the original BGR `frame`, so the EDGES colors keep OpenCV's channel order.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # resize_with_pad letterboxes to 192x256 without distorting the image.
        # NOTE(review): keypoints come back normalized to the padded canvas, so they
        # align with `frame` only when its aspect ratio is 256:192 (4:3) -- confirm
        # for your camera.
        img = tf.image.resize_with_pad(tf.expand_dims(rgb_frame, axis=0), 192, 256)
        input_img = tf.cast(img, dtype=tf.int32)

        outputs = final_model(input_img)  # outputs is a dictionary

        # shape of outputs['output_0']: [1, 6, 56]; the first 51 values per person
        # are reshaped into 17 (y, x, score) triples.
        persons_keypoints_scores = outputs['output_0'].numpy()[:, :, :51].reshape((6, 17, 3))

        persons_points_edges(frame, persons_keypoints_scores, EDGES, 0.2)

        cv2.imshow('pose', frame)

        if cv2.waitKey(10) & 0xFF == ord('q'):
            break
finally:
    # Release the camera even if the loop fails or the kernel is interrupted,
    # so later cells can open it again.
    cap.release()
    cv2.destroyAllWindows()

## Video Pose Estimation

In [10]:
VIDEO_PATH = 'mixkit-healthy-woman-jumping-a-rope-40234-medium.mp4'

cap = cv2.VideoCapture(VIDEO_PATH)
if not cap.isOpened():
    print(f'Could not open video {VIDEO_PATH!r}; skipping video pose estimation.')

try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # End of file (or a decode error). cap.isOpened() stays True after the
            # last frame, so without this check `frame` is None and the loop crashes.
            break

        # MoveNet expects RGB input; OpenCV decodes BGR. Drawing still happens on
        # the original BGR `frame`, so the EDGES colors keep OpenCV's channel order.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # resize_with_pad letterboxes to 288x512 without distorting the image.
        # NOTE(review): keypoints come back normalized to the padded canvas, so they
        # align with `frame` only when its aspect ratio is 512:288 (16:9) -- confirm
        # for this video.
        img = tf.image.resize_with_pad(tf.expand_dims(rgb_frame, axis=0), 288, 512)
        input_img = tf.cast(img, dtype=tf.int32)

        outputs = final_model(input_img)  # outputs is a dictionary

        # shape of outputs['output_0']: [1, 6, 56]; the first 51 values per person
        # are reshaped into 17 (y, x, score) triples.
        persons_keypoints_scores = outputs['output_0'].numpy()[:, :, :51].reshape((6, 17, 3))

        persons_points_edges(frame, persons_keypoints_scores, EDGES, 0.2)

        cv2.imshow('pose', frame)

        if cv2.waitKey(10) & 0xFF == ord('q'):
            break
finally:
    # Always release the file handle and close the window, even on error/interrupt.
    cap.release()
    cv2.destroyAllWindows()

## Image Pose Estimation

In [107]:

IMAGE_PATH = 'sport12.jpg'

img = cv2.imread(IMAGE_PATH)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None, not by raising.
    raise FileNotFoundError(f'Could not read image {IMAGE_PATH!r}')

# MoveNet expects RGB input; cv2.imread returns BGR. `img` itself stays BGR so the
# drawing colors and the saved file keep OpenCV's channel order.
input_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
# NOTE(review): resize_with_pad letterboxes, so keypoints align with `img` only
# when its aspect ratio is 192:128 (3:2) -- confirm for this image.
input_img = tf.image.resize_with_pad(tf.expand_dims(input_img, axis=0), 128, 192)
input_img = tf.cast(input_img, dtype=tf.int32)

outputs = final_model(input_img)  # outputs is a dictionary

# shape of outputs['output_0']: [1, 6, 56]; the first 51 values per person
# are reshaped into 17 (y, x, score) triples.
persons_keypoints_scores = outputs['output_0'].numpy()[:, :, :51].reshape((6, 17, 3))

persons_points_edges(img, persons_keypoints_scores, EDGES, 0.18)

cv2.imwrite('pose.jpg', img)

cv2.imshow('sport', img)
# The extra waitKey(1) presumably gives the GUI loop a tick to actually close the
# window; the trailing `;` hides its -1 return value in the cell output.
cv2.waitKey(0); cv2.destroyAllWindows(); cv2.waitKey(1);


-1