# Define Variables

In [None]:
# Camera field of view in degrees (the CSI camera covers 160 degrees).
FOV = 160

# Path to the TRT keypoints model weights file.
OPTIMIZED_MODEL = "resnet18_baseline_att_224x224_A_epoch_249_trt.pth"

# Reset the camera

In [None]:
import getpass
import os
import shlex
import subprocess

# Reset Camera
# Restart the nvargus daemon through sudo. The command is run without a shell
# and the password is written to sudo's stdin, so the password is never
# spliced into a shell command line.
password = 'jetbot'
# Any single command works here (no shell syntax such as pipes), but keep -S:
# it makes sudo read the password from stdin.
command = "sudo -S systemctl restart nvargus-daemon"
result = subprocess.run(
    shlex.split(command),
    input=password + "\n",
    universal_newlines=True,  # send `input` as text rather than bytes
)
if result.returncode != 0:
    # Best effort: report the failure but let the notebook carry on.
    print("Camera reset failed with exit code %d" % result.returncode)

# Load pre-Trained keyPointsRCNN

In [None]:
from threading import Thread
from torch2trt import TRTModule
import trt_pose.models
import torch2trt
import os
import json
import trt_pose.coco
import torch
import time

# Topology: read the pose description from human_pose.json and convert it
# with trt_pose's coco_category_to_topology. Both `human_pose` and `topology`
# are used by later cells (keypoint extraction, parsing and drawing).
with open('human_pose.json', 'r') as f:
    human_pose = json.load(f)
topology = trt_pose.coco.coco_category_to_topology(human_pose)

# load Model
# WIDTH / HEIGHT are also used for the camera and the display widget below.
WIDTH = 224
HEIGHT = 224
# Zero-filled (1, 3, HEIGHT, WIDTH) tensor on the GPU; it is not referenced
# again in the cells shown here.
data = torch.zeros((1, 3, HEIGHT, WIDTH)).cuda()
# Create an empty TRTModule and restore its state from the OPTIMIZED_MODEL file.
model_trt = TRTModule()
model_trt.load_state_dict(torch.load(OPTIMIZED_MODEL))
print("KeyPoints RCNN Successfully Loaded. ")

# Prepare the Camera

In [None]:
import cv2
import torchvision.transforms as transforms
import PIL.Image
import time

# Per-channel mean and standard deviation that preprocess() subtracts / divides
# by, stored on the GPU. NOTE(review): the values match the commonly used
# ImageNet statistics - confirm they are what the model was trained with.
mean = torch.Tensor([0.485, 0.456, 0.406]).cuda()
std = torch.Tensor([0.229, 0.224, 0.225]).cuda()
# Device that preprocess() moves each frame tensor to.
device = torch.device('cuda')

def preprocess(image):
    """Turn a camera frame into a normalized, batched tensor for the model.

    Args:
        image: frame as an array in BGR channel order (it is converted with
            ``cv2.COLOR_BGR2RGB`` before anything else).

    Returns:
        The frame as a tensor on the module-level ``device``, normalized per
        channel with the module-level ``mean`` and ``std``, with a leading
        batch axis added.
    """
    # ``device`` is only read here; it is defined once at module level, so
    # there is no need to rebind it for every frame.
    rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    tensor = transforms.functional.to_tensor(PIL.Image.fromarray(rgb)).to(device)
    # In-place normalization; [:, None, None] lets the per-channel statistics
    # broadcast over the two spatial axes.
    tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
    return tensor[None, ...]

In [None]:
from jetcam.csi_camera import CSICamera

# Open the CSI camera at the model's input resolution (WIDTH x HEIGHT),
# capturing at 15 frames per second.
camera = CSICamera(width=WIDTH, height=HEIGHT, capture_fps=15)
# Presumably this starts jetcam's background capture so that camera.value keeps
# updating (a later cell observes 'value') - confirm against the jetcam docs.
camera.running = True

# Prepare KeyPoints drawing on Frames

In [None]:
from trt_pose.draw_objects import DrawObjects
from trt_pose.parse_objects import ParseObjects

# Callables built from the pose topology:
#   parse_objects(cmap, paf) -> (counts, objects, peaks), used by the execute functions
#   draw_objects(image, counts, objects, peaks), used by execute() on the frame it displays
parse_objects = ParseObjects(topology)
draw_objects = DrawObjects(topology)

# Function to obtain keypoints (for drawing & calculating stage)

In [None]:
def get_keypoints(image, human_pose, topology, object_counts, objects, normalized_peaks):
    """Map each detected keypoint name to its pixel position in ``image``.

    Args:
        image: cv2 image; only ``image.shape[0]`` (height) and
            ``image.shape[1]`` (width) are read, to scale the peaks.
        human_pose: parsed pose description; ``human_pose["keypoints"]`` is
            the list of keypoint names, indexed by body-part number.
        topology: accepted for interface compatibility; not used.
        object_counts: ``object_counts[0]`` is the number of detected objects.
        objects: ``objects[0][i][j]`` is the peak index of part ``j`` for
            object ``i``, or a negative value when that part was not found.
        normalized_peaks: ``normalized_peaks[0][j][k]`` is the ``(y, x)``
            position of peak ``k`` of part ``j``, as a fraction of the image
            height / width.

    Returns:
        dictionary: keypoint name -> ``(x, y)`` tuple of ints, in pixels.
        Every detected object writes into the same dictionary, so when
        several objects have the same part, the last object's position wins.
    """
    height = image.shape[0]
    width = image.shape[1]
    keypoints = {}
    count = int(object_counts[0])

    for i in range(count):
        for j, peak_index in enumerate(objects[0][i]):
            k = int(peak_index)
            if k < 0:
                # Part j was not found for this object.
                continue
            peak = normalized_peaks[0][j][k]
            # Peaks are stored as (y, x); the result is reported as (x, y).
            x = round(float(peak[1]) * width)
            y = round(float(peak[0]) * height)
            keypoints[human_pose["keypoints"][j]] = (x, y)

    return keypoints

# Function to Calculate Stage

In [None]:
def calculate_stage(keypoints):
    '''
    Classifies the frame into a phase from the skeleton parts detected and
    returns it as a one-hot list:

    [0, 0, 1] - Phase 1: both knees and both ankles are detected.
    [0, 1, 0] - Phase 2: a skeleton is detected but at least one knee or ankle
                is missing (obstructed by an obstacle).
    [1, 0, 0] - Phase 3: no skeleton is detected.
    '''
    # Phase 3: the camera sees no skeleton at all
    if not keypoints:
        return [1, 0, 0]

    detected = keypoints.keys()
    required = ("left_knee", "right_knee", "left_ankle", "right_ankle")
    legs_visible = all(part in detected for part in required)

    # Phase 1 when every leg part is present, otherwise Phase 2
    return [0, 0, 1] if legs_visible else [0, 1, 0]

# Thread to print stage

In [None]:
from IPython.display import clear_output as cls

def print_stage_thread():
    """Thread body: keep printing the stage of the latest camera frame.

    Sets the global ``print_stage_running`` flag to True, then loops forever:
    clear the cell output, run ``execute2`` on the current ``camera.value``,
    and print the stage computed from the returned keypoints.
    """
    global print_stage_running
    print_stage_running = True

    while True:
        # wait=True delays the clearing until new output is available
        cls(wait=True)
        frame_change = {'new': camera.value}
        detected, _frame, _counts, _objects, _peaks = execute2(frame_change)
        print("Stage: ", calculate_stage(detected))

# Define Execute (For Visualization)

In [None]:
def execute(change):
    """Camera callback: run the pose model on a new frame and display it.

    Args:
        change: dict whose ``'new'`` entry is the latest camera frame.

    Side effects:
        - stores the frame in the global ``image`` and draws the detected
          objects onto it;
        - writes the horizontally mirrored frame, JPEG-encoded, to
          ``image_w.value``;
        - starts the stage-printing thread the first time it is called.
    """
    global image, print_stage_running

    image = change['new']
    data = preprocess(image)
    cmap, paf = model_trt(data)
    cmap, paf = cmap.detach().cpu(), paf.detach().cpu()
    counts, objects, peaks = parse_objects(cmap, paf)#, cmap_threshold=0.15, link_threshold=0.15)
    draw_objects(image, counts, objects, peaks)
    # image[:, ::-1, :] reverses the column axis, i.e. mirrors the frame.
    image_w.value = bgr8_to_jpeg(image[:, ::-1, :])

    if not print_stage_running:
        # Claim the flag before the thread starts, so that another callback
        # arriving before the thread gets scheduled cannot start a second one.
        print_stage_running = True
        Thread(target=print_stage_thread).start()

# Define Execute (Used for stage calculation)

In [None]:
def execute2(change, visualize=True):
    """Run the pose model on a frame and return the detections, without drawing.

    Args:
        change: dict whose ``'new'`` entry is the frame to process.
        visualize: not used by this function.

    Returns:
        Tuple ``(keypoints, image, counts, objects, peaks)``: the keypoint
        dictionary from ``get_keypoints``, the input frame, and the three
        values returned by ``parse_objects``.
    """
    frame = change['new']
    cmap, paf = model_trt(preprocess(frame))
    # Detach the model outputs and move them to the CPU before parsing
    cmap_cpu = cmap.detach().cpu()
    paf_cpu = paf.detach().cpu()
    counts, objects, peaks = parse_objects(cmap_cpu, paf_cpu)

    detected = get_keypoints(frame, human_pose, topology, counts, objects, peaks)
    return detected, frame, counts, objects, peaks

###### Two different execute functions are defined because, if only one were used and its variables were made global so that both threads could share them, the delay would increase from about 2 seconds to 20-30 seconds.

In [None]:
# Guard flag (a plain boolean, not a real lock) read by execute(): the
# stage-printing thread is only started while this is False, and
# print_stage_thread() sets it to True when it begins.
print_stage_running = False

# Run Everything

In [None]:
import ipywidgets
from IPython.display import display
from jetcam.utils import bgr8_to_jpeg

# JPEG image widget sized to the camera frame (WIDTH x HEIGHT); execute()
# writes each encoded frame to image_w.value.
image_w = ipywidgets.Image(format='jpeg', width=WIDTH, height=HEIGHT)
display(image_w)

In [None]:
# Process the current frame once, then register execute() as an observer of
# the camera's 'value' attribute so it is called with each change.
execute({'new': camera.value})
camera.observe(execute, names='value')