In [None]:
import json
import cv2
import matplotlib.pyplot as plt
import matplotlib.image as mpimg 
import trt_pose.coco
import math
import os
import numpy as np
import traitlets


In [None]:
# Build the hand-pose network and obtain a TensorRT-optimized copy of it.
import torch
import torch2trt
import trt_pose.models
from torch2trt import TRTModule

with open('hand_pose.json', 'r') as f:
    hand_pose = json.load(f)

# Keypoint/skeleton description -> topology used by the parser and drawer below.
topology = trt_pose.coco.coco_category_to_topology(hand_pose)

num_parts = len(hand_pose['keypoints'])
num_links = len(hand_pose['skeleton'])

# Output channels: num_parts confidence maps and 2 * num_links part-affinity
# channels (the model's outputs are unpacked as `cmap, paf` at inference time).
model = trt_pose.models.resnet18_baseline_att(num_parts, 2 * num_links).cuda().eval()

# Network input size; the camera is also opened at this resolution.
WIDTH = 256
HEIGHT = 256
# Dummy batch used only to trace the model during TensorRT conversion.
data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()

# Checkpoint file names, defined once so the cache check, save and load can
# never disagree. Kept verbatim to match the files on disk ("244x224" may be a
# typo for 224x224 upstream -- do not "fix" it here without renaming the files).
MODEL_WEIGHTS = 'resnet18_244x224_epoch_4150.pth'
OPTIMIZED_MODEL = 'resnet18_244x224_epoch_4150_trt.pth'

# TensorRT conversion is slow, so convert once and cache the result next to the notebook.
if not os.path.exists(OPTIMIZED_MODEL):
    model.load_state_dict(torch.load(MODEL_WEIGHTS))
    model_trt = torch2trt.torch2trt(model, [data], fp16_mode=True, max_workspace_size=1 << 25)
    torch.save(model_trt.state_dict(), OPTIMIZED_MODEL)

# Always load from the cached file so both code paths yield the same module.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))

In [None]:
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects

# Parser that turns the network's (cmap, paf) output into (counts, objects, peaks);
# both confidence thresholds are fixed at 0.15 here.
parse_objects = ParseObjects(
    topology,
    cmap_threshold=0.15,
    link_threshold=0.15,
)
# Renders the parsed keypoints/links onto a frame for display.
draw_objects = DrawObjects(topology)

In [None]:

import torchvision.transforms as transforms
import PIL.Image

device = torch.device('cuda')
# ImageNet channel statistics in RGB order (frames are converted BGR -> RGB first).
mean = torch.Tensor([0.485, 0.456, 0.406]).to(device)
std = torch.Tensor([0.229, 0.224, 0.225]).to(device)

def preprocess(image):
    """Convert a BGR8 / HWC camera frame into a normalized 1x3xHxW batch on ``device``.

    Args:
        image: camera frame in BGR channel order, HWC layout (assumed uint8 --
            ``to_tensor`` then scales it to [0, 1]).

    Returns:
        Float tensor of shape (1, 3, H, W) on ``device``, standardized per channel
        with the module-level ``mean`` / ``std``.
    """
    # Uses the module-level `device`; it is no longer re-created (and the global
    # rebound) on every frame.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(rgb)).to(device)
    # In-place normalization avoids allocating extra full-size tensors per frame.
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    return tensor[None, ...]

The `preprocess` function above converts each camera frame, which arrives in BGR8 / HWC format, into the normalized RGB / CHW batch tensor (1×3×H×W) that the network expects.

In [None]:
from jetcam.usb_camera import USBCamera
from jetcam.csi_camera import CSICamera
from jetcam.utils import bgr8_to_jpeg

# USB camera at video-device index 1 (adjust for your setup), capturing at the
# network's input resolution so frames need no resizing before preprocess().
camera = USBCamera(
    capture_device=1,
    capture_fps=30,
    width=WIDTH,
    height=HEIGHT,
)
# For a CSI ribbon camera, use this line instead:
#camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=30)

# Presumably starts jetcam's continuous capture so `camera.value` keeps updating
# (the observer registered later reacts to changes of that trait).
camera.running = True

In [None]:
# Import the class under a distinct alias so that binding the instance to
# `preprocessdata` (the name execute() uses) does not shadow the class itself.
from preprocessdata import preprocessdata as PreprocessData

preprocessdata = PreprocessData(topology, num_parts)

In [None]:
import ipywidgets
from IPython.display import display

# Live preview widget, sized from the same WIDTH/HEIGHT constants as the camera
# and model input so the three cannot drift apart.
image_w = ipywidgets.Image(format='jpeg', width=WIDTH, height=HEIGHT)
display(image_w)

In [None]:
pen = []


def draw(image, joints, pen_down_distance=30, clear_distance=5):
    """Gesture "pen": trace keypoint 5 across frames while the hand holds a pose.

    Marks keypoints 1, 5, 9 and 17 on ``image``. While keypoint 9 is within
    ``pen_down_distance`` pixels of keypoint 1, keypoint 5's position is appended
    to the module-level ``pen`` stroke; the whole stroke is then drawn as
    connected line segments. When keypoint 17 comes within ``clear_distance``
    pixels of keypoint 1, the stroke is cleared (after this frame's drawing).

    Args:
        image: frame to draw on; modified in place by the cv2 drawing calls.
        joints: indexable of (x, y) pairs with at least 18 entries, presumably
            from ``preprocessdata.joints_inference`` -- TODO confirm the
            coordinates are ints, since cv2 requires integer points.
        pen_down_distance: pixel threshold for adding a stroke point (default 30).
        clear_distance: pixel threshold for erasing the stroke (default 5).
    """
    # Highlight the four keypoints the gesture logic looks at (BGR colours).
    for idx, color in ((17, (255, 0, 255)), (9, (0, 255, 0)),
                       (5, (255, 255, 255)), (1, (0, 0, 0))):
        cv2.circle(image, (joints[idx][0], joints[idx][1]), 1, color, 2)

    anchor_x, anchor_y = joints[1][0], joints[1][1]
    dist_17_to_1 = math.hypot(joints[17][0] - anchor_x, joints[17][1] - anchor_y)
    dist_9_to_1 = math.hypot(joints[9][0] - anchor_x, joints[9][1] - anchor_y)

    if dist_9_to_1 < pen_down_distance:
        pen.append((joints[5][0], joints[5][1]))
    # Connect consecutive stroke points. `pen` is only mutated in place (never
    # rebound), so no `global` declaration is needed.
    for start, end in zip(pen, pen[1:]):
        cv2.line(image, start, end, (0, 0, 0), 2)
    if dist_17_to_1 < clear_distance:
        pen.clear()

In [None]:
def draw_pose(image, joints):
    """Mark every keypoint in ``joints`` with a small red dot (BGR) on ``image``."""
    red = (0, 0, 255)
    for joint in joints:
        cv2.circle(image, (joint[0], joint[1]), 1, red, 2)

In [None]:
def execute(change):
    """Run hand-pose inference on the newest camera frame and show it in ``image_w``.

    ``change`` is the change dict passed by ``camera.observe`` (or built by hand
    for a one-off call); ``change['new']`` is the frame in BGR8 / HWC layout.
    The frame is annotated in place, then displayed mirrored left-right.
    """
    frame = change['new']
    cmap, paf = model_trt(preprocess(frame))
    cmap = cmap.detach().cpu()
    paf = paf.detach().cpu()
    # Confidence thresholds are configured on the ParseObjects instance itself.
    counts, objects, peaks = parse_objects(cmap, paf)
    draw_objects(frame, counts, objects, peaks)
    joints = preprocessdata.joints_inference(frame, counts, objects, peaks)
    joint_distances = preprocessdata.find_distance(joints)  # not used further yet
    # Optional overlays -- uncomment to enable:
    # draw(frame, joints)
    # draw_pose(frame, joints)
    image_w.value = bgr8_to_jpeg(frame[:, ::-1, :])

In [None]:
# Process the current frame once, synchronously, before subscribing to live updates.
execute(dict(new=camera.value))

In [None]:
# Re-run execute() every time the camera's `value` trait changes (each new frame).
camera.observe(execute, names=['value'])

In [None]:
# Stop the live stream: detaches every observer on the camera, including execute().
# NOTE: under "Run All" this executes immediately after the subscription cell,
# so the live preview stops at once -- run the cells above individually to keep it.
camera.unobserve_all()

In [None]:
# When finished, uncomment to stop capturing (the counterpart of setting
# `camera.running = True` when the camera was created).
#camera.running = False