### Importing Dependencies

In [1]:
import os
import sys
import shutil
from typing import List
import requests
from tqdm import tqdm
from datetime import timedelta

from enum import Enum
import time
import cv2
import numpy as np
import matplotlib.pyplot as plt

from dataclasses import dataclass
from torchvision import transforms
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, hf_hub_url

### Pose Estimation

In [6]:
from ultralytics import YOLO

@dataclass
class DetectorConfig:
    """Settings for the YOLO-based person detector.

    Attributes:
        model_path: Path to the Ultralytics YOLO weights file.
        person_id: Class index that is kept; detections of every other class are dropped.
        conf_thres: Minimum detection confidence forwarded to the YOLO predictor.
    """

    # Weights file; Detector rewrites non-.pt paths to a .pt path before loading.
    model_path: str = "models/yolov8m.pt"
    # Class id retained by Detector.detect.
    person_id: int = 0
    # Passed as `conf=` to the YOLO model call.
    conf_thres: float = 0.25


def draw_boxes(img, boxes, color=(0, 255, 0), thickness=2):
    """Return a copy of ``img`` with every box outlined.

    Args:
        img: Image array to annotate; it is not modified.
        boxes: Iterable of ``(x1, y1, x2, y2)`` corner coordinates.
        color: Line color passed to ``cv2.rectangle``.
        thickness: Line thickness in pixels.

    Returns:
        The annotated copy.
    """
    canvas = img.copy()
    for left, top, right, bottom in boxes:
        canvas = cv2.rectangle(canvas, (left, top), (right, bottom), color, thickness)
    return canvas


class Detector:
    """Wrapper around an Ultralytics YOLO model that returns only person boxes."""

    def __init__(self, config: "Optional[DetectorConfig]" = None):
        """Load the YOLO weights described by ``config``.

        Args:
            config: Detector settings. When omitted, a fresh ``DetectorConfig()`` is
                built for this instance (instead of one object shared by all instances).
        """
        if config is None:
            config = DetectorConfig()
        self.model = YOLO(self._normalize_model_path(config.model_path), verbose=False)
        self.person_id = config.person_id
        self.conf_thres = config.conf_thres

    @staticmethod
    def _normalize_model_path(model_path: str) -> str:
        """Return ``model_path`` with its file extension replaced by ``.pt``.

        Paths already ending in ``.pt`` are returned unchanged. ``os.path.splitext``
        only strips the final extension, so dots in directory names
        (e.g. ``./models/yolov8m.onnx``) are preserved.
        """
        if model_path.endswith(".pt"):
            return model_path
        return os.path.splitext(model_path)[0] + ".pt"

    def __call__(self, img: np.ndarray) -> np.ndarray:
        """Alias for :meth:`detect`."""
        return self.detect(img)

    def detect(self, img: np.ndarray) -> np.ndarray:
        """Detect people in ``img``.

        Args:
            img: Image array passed directly to the YOLO model.

        Returns:
            ``(N, 4)`` int array of ``(x1, y1, x2, y2)`` boxes whose class equals
            ``person_id``. Coordinates are truncated to int.
        """
        # verbose=False stops Ultralytics from printing a log line for every call/frame.
        results = self.model(img, conf=self.conf_thres, verbose=False)
        detections = results[0].boxes.data.cpu().numpy()  # (x1, y1, x2, y2, conf, cls)

        # Keep only the person class (class id is the last column).
        person_detections = detections[detections[:, -1] == self.person_id]
        # Take the first four columns explicitly rather than "all but the last two",
        # so extra columns (e.g. a track id) never leak into the boxes.
        return person_detections[:, :4].astype(int)


Creating new Ultralytics Settings v0.0.6 file  
View Ultralytics Settings with 'yolo settings' or at 'C:\Users\dangh\AppData\Roaming\Ultralytics\settings.json'
Update Settings with 'yolo settings key=value', i.e. 'yolo settings runs_dir=path/to/dir'. For help see https://docs.ultralytics.com/quickstart/#ultralytics-settings.


In [7]:
from classes_and_palettes import (
    COCO_KPTS_COLORS,
    COCO_WHOLEBODY_KPTS_COLORS,
    GOLIATH_KPTS_COLORS,
    GOLIATH_SKELETON_INFO,
    GOLIATH_KEYPOINTS
)

In [None]:
class SapiensPoseEstimation:
    """Top-down pose estimation: YOLO person boxes followed by a Sapiens heatmap model.

    Each person crop is converted to RGB, resized to 1024x768 (height x width),
    normalized with ImageNet statistics and passed through a TorchScript model. The
    arg-max of each output heatmap gives one keypoint per ``GOLIATH_KEYPOINTS`` entry.
    """

    # Keypoints (and skeleton links) whose heatmap peak is at or below this value are not drawn.
    KPT_CONF_THRES = 0.3

    def __init__(self,
                 path='sapiens_1b_goliath_best_goliath_AP_639_torchscript.pt2',
                 device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
                 dtype: torch.dtype = torch.float32):
        """Load the TorchScript model, build the preprocessing pipeline and the person detector.

        Args:
            path: Path to the TorchScript Sapiens checkpoint.
            device: Device the model and input tensors are placed on.
            dtype: Floating dtype for model weights and inputs (e.g. ``torch.float16`` on GPU).
        """
        self.device = device
        self.dtype = dtype
        # map_location lets a checkpoint serialized on one device load on another (e.g. CUDA -> CPU).
        self.model = torch.jit.load(path, map_location=device).eval().to(device).to(dtype)
        # Input: RGB HxWx3 uint8 array. Resize takes (height, width).
        self.preprocessor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((1024, 768)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ])
        self.detector = Detector()

    def __call__(self, img: np.ndarray) -> tuple:
        """Detect people in a BGR image and estimate their poses.

        Returns:
            ``(result_img, all_keypoints)`` exactly as returned by :meth:`estimate_pose`.
        """
        bboxes = self.detector.detect(img)
        return self.estimate_pose(img, bboxes)

    @torch.inference_mode()
    def estimate_pose(self, img: np.ndarray, bboxes: List[List[float]]) -> tuple:
        """Run the pose model on every box and draw the results.

        Args:
            img: BGR image (as delivered by ``cv2.imread`` / ``cv2.VideoCapture``).
            bboxes: Boxes as ``(x1, y1, x2, y2, ...)`` in pixel coordinates.

        Returns:
            ``(result_img, all_keypoints)``:
            - ``result_img`` is an annotated copy of ``img``.
            - ``all_keypoints`` holds one dict per box, in box order, mapping keypoint
              name to ``(x, y, conf)`` in heatmap coordinates. Boxes with no area inside
              the image get an empty dict.
        """
        all_keypoints = []
        result_img = img.copy()

        for bbox in bboxes:
            clipped_bbox = self._clip_bbox(bbox, img.shape)
            cropped_img = self.crop_image(img, clipped_bbox)
            if cropped_img.size == 0:
                # Nothing to run the model on; keep the output list aligned with bboxes.
                all_keypoints.append({})
                continue

            # heatmaps: (num_keypoints, H, W) for the single crop in the batch.
            heatmaps = self.model(self._preprocess(cropped_img))[0].float().cpu().numpy()
            keypoints = self.heatmaps_to_keypoints(heatmaps)
            all_keypoints.append(keypoints)

            # Map keypoints back using the real heatmap grid (width, height) and the
            # same clipped box the crop was taken from.
            heatmap_size = (heatmaps.shape[-1], heatmaps.shape[-2])
            result_img = self.draw_keypoints(result_img, keypoints, clipped_bbox, heatmap_size)

        return result_img, all_keypoints

    @staticmethod
    def _clip_bbox(bbox: List[float], img_shape) -> tuple:
        """Return ``bbox[:4]`` as ints clamped to the image, so slicing never wraps around."""
        height, width = img_shape[:2]
        x1, y1, x2, y2 = map(int, bbox[:4])
        x1, x2 = max(0, min(x1, width)), max(0, min(x2, width))
        y1, y2 = max(0, min(y1, height)), max(0, min(y2, height))
        return x1, y1, x2, y2

    def _preprocess(self, crop_bgr: np.ndarray) -> torch.Tensor:
        """Turn a BGR crop into a normalized ``(1, 3, 1024, 768)`` tensor on the model's device/dtype."""
        # OpenCV delivers BGR, but the Normalize mean/std above are in RGB order (and
        # Sapiens expects RGB input), so flip the channel axis first.
        if crop_bgr.ndim == 3 and crop_bgr.shape[2] == 3:
            crop_bgr = crop_bgr[..., ::-1]
        crop_rgb = np.ascontiguousarray(crop_bgr)
        return self.preprocessor(crop_rgb).unsqueeze(0).to(self.device).to(self.dtype)

    def crop_image(self, img: np.ndarray, bbox: List[float]) -> np.ndarray:
        """Return the region of ``img`` inside ``bbox`` after clamping it to the image bounds.

        The result may be empty (size 0) when the box lies outside the image.
        """
        x1, y1, x2, y2 = self._clip_bbox(bbox, img.shape)
        return img[y1:y2, x1:x2]

    def heatmaps_to_keypoints(self, heatmaps: np.ndarray) -> dict:
        """Convert per-keypoint heatmaps into ``{name: (x, y, conf)}``.

        Args:
            heatmaps: Array indexed as ``heatmaps[k, y, x]``; channel ``k`` corresponds
                to ``GOLIATH_KEYPOINTS[k]``. Extra names without a channel are skipped.

        Returns:
            Dict in ``GOLIATH_KEYPOINTS`` order. ``(x, y)`` is the arg-max location in
            heatmap coordinates and ``conf`` is the heatmap value there.
        """
        keypoints = {}
        for i, name in enumerate(GOLIATH_KEYPOINTS):
            if i < heatmaps.shape[0]:
                y, x = np.unravel_index(np.argmax(heatmaps[i]), heatmaps[i].shape)
                conf = heatmaps[i, y, x]
                keypoints[name] = (float(x), float(y), float(conf))
        return keypoints

    def draw_keypoints(self, img: np.ndarray, keypoints: dict, bbox: List[float],
                       heatmap_size=(192, 256)) -> np.ndarray:
        """Draw one person's keypoints and skeleton links onto a copy of ``img``.

        Args:
            img: Image to draw on (not modified).
            keypoints: ``{name: (x, y, conf)}`` in heatmap coordinates, in
                ``GOLIATH_KEYPOINTS`` order (as built by ``heatmaps_to_keypoints``).
            bbox: Box ``(x1, y1, x2, y2, ...)`` the heatmaps were computed for.
            heatmap_size: ``(width, height)`` of the heatmap grid; defaults to 192x256.

        Returns:
            Annotated copy of ``img``. Points and links with confidence at or below
            ``KPT_CONF_THRES`` are skipped.
        """
        x1, y1, x2, y2 = map(int, bbox[:4])
        bbox_width, bbox_height = x2 - x1, y2 - y1
        heatmap_w, heatmap_h = heatmap_size
        img_copy = img.copy()

        def to_image(x, y):
            # Scale a heatmap-grid coordinate to the box size, then offset into image space.
            return (int(x * bbox_width / heatmap_w) + x1,
                    int(y * bbox_height / heatmap_h) + y1)

        kpt_colors = {}
        for i, (name, (x, y, conf)) in enumerate(keypoints.items()):
            kpt_colors[name] = GOLIATH_KPTS_COLORS[i]
            if conf > self.KPT_CONF_THRES:
                cv2.circle(img_copy, to_image(x, y), 3, kpt_colors[name], -1)

        for link_info in GOLIATH_SKELETON_INFO.values():
            pt1_name, pt2_name = link_info['link']
            if pt1_name not in keypoints or pt2_name not in keypoints:
                continue
            (px1, py1, conf1), (px2, py2, conf2) = keypoints[pt1_name], keypoints[pt2_name]
            if conf1 > self.KPT_CONF_THRES and conf2 > self.KPT_CONF_THRES:
                # Use the link's own color when the skeleton table provides one, otherwise
                # the color of its first endpoint (rather than a stale loop index).
                color = link_info.get('color', kpt_colors[pt1_name])
                cv2.line(img_copy, to_image(px1, py1), to_image(px2, py2), color, 2)

        return img_copy

In [None]:
# Pick the device once and pass it to the estimator so this setting is the one actually used.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pose_estimator = SapiensPoseEstimation(device=device)

img_path = "football.jpg"
img = cv2.imread(img_path)
if img is None:
    # cv2.imread returns None (rather than raising) for a missing or unreadable file.
    raise FileNotFoundError(f"Could not read image: {img_path}")

# Time only detection + pose estimation (model loading and plotting are excluded).
start_time = time.perf_counter()
result_img, keypoints = pose_estimator(img)
print(f"Time taken: {time.perf_counter() - start_time:.4f} seconds")

# Size the figure so the image is shown at roughly its native pixel size (dpi=100).
height, width, _ = result_img.shape
fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
ax.imshow(result_img[:, :, ::-1])  # OpenCV BGR -> matplotlib RGB
ax.axis('off')
plt.show()

# Release the model before the next cell loads another one.
del pose_estimator
torch.cuda.empty_cache()

In [None]:
try:
    # Colab cannot open native OpenCV windows; it provides an inline replacement.
    from google.colab.patches import cv2_imshow
except ImportError:
    def cv2_imshow(image):
        """Fallback outside Colab: show the frame in a native OpenCV window."""
        cv2.imshow("Sapiens pose", image)

# float16 is only requested on CUDA; on CPU some half-precision ops may be unsupported or slow.
video_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
pose_estimator = SapiensPoseEstimation(dtype=video_dtype)

video_path = "test_images/Sapiens-video-test.mp4"
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise FileNotFoundError(f"Could not open video: {video_path}")

try:
    while True:
        ret, frame = cap.read()
        if not ret:
            break

        # The estimator returns exactly (annotated_image, keypoints).
        result_img, _ = pose_estimator(frame)
        cv2_imshow(result_img)

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always free the capture handle and cached GPU memory, even if a frame fails.
    cap.release()
    torch.cuda.empty_cache()