In [1]:
import os
# import argparse
import cv2
import numpy as np
import sys
import time
from threading import Thread
# import importlib.util
from PIL import Image
import math
from pkg_resources import parse_version

from edgetpu import __version__ as edgetpu_version
assert parse_version(edgetpu_version) >= parse_version('2.11.1'), \
        'This demo requires Edge TPU version >= 2.11.1'
from edgetpu.basic.basic_engine import BasicEngine
# from edgetpu.utils import image_processing

In [2]:
class VideoStream:
    """Threaded camera reader built on ``cv2.VideoCapture`` (device 0).

    The constructor opens the camera and reads one frame synchronously so that
    ``read()`` has something to return immediately; ``start()`` then keeps
    replacing that frame from a background thread until ``stop()`` is called.
    """
    def __init__(self,resolution=(640,480),framerate=30):
        """Open camera 0 and request MJPG frames at the given size and rate.

        Args:
          resolution: (width, height) requested from the capture device.
          framerate: frames per second requested from the capture device.

        The driver may silently ignore any of these requests.

        Raises:
          RuntimeError: if capture device 0 cannot be opened.
        """
        self.stream = cv2.VideoCapture(0)
        if not self.stream.isOpened():
            # Fail here with a clear message instead of later, when a None
            # frame reaches cv2.resize.
            raise RuntimeError('Could not open video capture device 0')
        # Use the named property ids rather than the bare indices 3/4.
        self.stream.set(cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'MJPG'))
        self.stream.set(cv2.CAP_PROP_FRAME_WIDTH, resolution[0])
        self.stream.set(cv2.CAP_PROP_FRAME_HEIGHT, resolution[1])
        # The framerate argument used to be accepted but never applied.
        self.stream.set(cv2.CAP_PROP_FPS, framerate)

        # Read the first frame so read() works before start() is called.
        (self.grabbed, self.frame) = self.stream.read()

        # Set by stop(); polled by the capture thread.
        self.stopped = False
        self._thread = None

    def start(self):
        """Start the background capture thread and return ``self``.

        The thread is a daemon so a forgotten ``stop()`` cannot keep the
        interpreter alive; call ``stop()`` to release the camera cleanly.
        """
        self._thread = Thread(target=self.update, args=(), daemon=True)
        self._thread.start()
        return self

    def update(self):
        """Capture-thread body: keep grabbing frames until ``stop()`` is called.

        A failed grab updates ``grabbed`` but keeps the previous frame, so
        ``read()`` does not switch from a valid image to ``None`` mid-stream.
        The camera is released when the loop exits.
        """
        while not self.stopped:
            grabbed, frame = self.stream.read()
            self.grabbed = grabbed
            if grabbed:
                self.frame = frame
        self.stream.release()

    def read(self):
        """Return the most recently captured frame."""
        return self.frame

    def stop(self):
        """Ask the capture thread to exit; it releases the camera on its way out."""
        self.stopped = True

In [3]:
# Model directory (relative to the working directory) and its label file.
MODEL_NAME = "pose_TFLite_model"
LABELMAP_NAME = 'labelmap.txt'

# Capture resolution as "WIDTHxHEIGHT". Frame rates noted for each option:
#   1280x720 -> 18 fps, 640x480 -> 60 fps, 480x352 -> 110 fps
resW, resH = '1280x720'.split('x', 1)
imW, imH = (int(dim) for dim in (resW, resH))

# Keypoints scoring below this are not drawn.
min_thresh = 0.2
# NOTE(review): not referenced in the visible code; the engine is always an Edge TPU BasicEngine.
use_TPU = False
# When True, ParseOutput flips keypoint x coordinates horizontally.
_mirror = False

In [4]:
def draw_arrows(frame):
    """Draw the current offset vector, from the image centre, onto ``frame``.

    Reads the module-level globals ``imW``/``imH`` (display size) and
    ``xoff``/``yoff`` (offset of the tracked point from the centre; ``yoff``
    grows upwards, hence it is subtracted from the row). The arrow is drawn
    in place in colour (0, 0, 255), 1 px thick.

    Returns:
      The same ``frame`` object.
    """
    center_x = int(imW / 2)
    center_y = int(imH / 2)
    tip = (center_x + xoff, center_y - yoff)
    cv2.arrowedLine(frame, (center_x, center_y), tip, (0, 0, 255), 1)
    return frame

def DetectPosesInImage(img):
    """Pad and/or crop ``img`` to the model's input size.

    Despite the name, no inference happens here: the caller flattens the
    returned array and passes it to ``engine.run_inference``. Missing rows and
    columns are zero-padded at the bottom/right, and anything beyond
    (image_height, image_width) is cropped away, so pixel coordinates in the
    result match those in ``img``.

    Args:
      img: numpy array indexed as (rows, cols, channels).

    Returns:
      numpy array whose shape equals ``tuple(_input_tensor_shape[1:])``.

    Raises:
      AssertionError: if the prepared image does not match the model input
        shape (e.g. a different channel count).
    """
    rows, cols = img.shape[0], img.shape[1]
    if rows < image_height or cols < image_width:
        pad_rows = max(0, image_height - rows)
        pad_cols = max(0, image_width - cols)
        img = np.pad(img, [[0, pad_rows], [0, pad_cols], [0, 0]], mode='constant')
    img = img[:image_height, :image_width]
    assert img.shape == tuple(_input_tensor_shape[1:])
    return img

def ParseOutput(output=None,KEYPOINTS=None):
    """Split the raw Edge TPU output and return the keypoints of the first pose.

    Args:
      output: the ``(inference_time, flat_output)`` pair passed in from
        ``engine.run_inference``. ``flat_output`` holds every output tensor
        flattened and concatenated; it is split using the module-level
        ``_output_offsets`` boundaries.
      KEYPOINTS: keypoint names, in the order the model emits them.

    Returns:
      dict mapping each name in ``KEYPOINTS`` to ``[name, [y, x], score]``
      (the list built by ``Keypoint``). The same dict is also stored in the
      module-level ``kpt_dict``.
    """
    global kpt_dict
    _inference_time, flat_output = output
    outputs = [flat_output[i:j] for i, j in zip(_output_offsets, _output_offsets[1:])]
    num_keypoints = len(KEYPOINTS)
    # outputs[0]: (y, x) per keypoint; outputs[1]: per-keypoint scores.
    # outputs[2] (pose scores) and outputs[3] (pose count) are not used here.
    keypoints = outputs[0].reshape(-1, num_keypoints, 2)
    keypoint_scores = outputs[1].reshape(-1, num_keypoints)

    # Only the first pose is reported.
    # NOTE(review): pose 0 is read even when the model reports zero poses;
    # presumably its scores then fall below min_thresh — confirm.
    pose_i = 0
    kpt_dict = {}
    for point_i, (name, point) in enumerate(zip(KEYPOINTS, keypoints[pose_i])):
        keypoint = Keypoint(name, list(point), keypoint_scores[pose_i, point_i])
        if _mirror:
            # Keypoint() returns a plain list [name, [y, x], score]; it has no
            # .yx attribute, so the x coordinate is keypoint[1][1].
            keypoint[1][1] = image_width - keypoint[1][1]
        kpt_dict[name] = keypoint
    return kpt_dict

def Keypoint(k,yx,score):
    """Bundle one keypoint as a plain list ``[name, yx, score]``.

    ``yx`` is stored by reference (not copied), so mutating ``keypoint[1]``
    also mutates the list the caller passed in.
    """
    return list((k, yx, score))

In [5]:
# Send print() output to the kernel process's own stdout (typically the
# terminal that launched Jupyter) instead of the notebook output area.
# POSIX only: relies on /dev/stdout. Line-buffered so each line shows up
# immediately even when that stdout is a pipe or a file rather than a tty.
import sys
sys.stdout = open('/dev/stdout', 'w', buffering=1)

In [6]:
# Resolve the model directory relative to the current working directory.
CWD_PATH = os.getcwd()

# Label map: one keypoint name per line, in the order the model emits them
# (these names are later used as KEYPOINTS when parsing the model output).
PATH_TO_LABELS = os.path.join(CWD_PATH,MODEL_NAME,LABELMAP_NAME)

# Skip blank lines: an empty entry would add a bogus '' keypoint and make the
# keypoint count disagree with the model output when it is reshaped.
with open(PATH_TO_LABELS, 'r') as f:
    labels = [line.strip() for line in f if line.strip()]

# Frame-rate bookkeeping for the on-screen FPS counter (seeded with 1 until
# the first loop iteration measures a real value).
frame_rate_calc = 1
freq = cv2.getTickFrequency()

# Offset of the tracked keypoint from the image centre, updated by the main loop.
xoff = yoff = 0

# Open the camera at the configured resolution, start the capture thread,
# and give the camera a second to warm up.
videostream = VideoStream(framerate=30, resolution=(imW, imH))
videostream.start()
time.sleep(1)


# Pick the PoseNet decoder model compiled for the capture resolution: the file
# name encodes (height + 1) x (width + 1), e.g. 721_1281 for 1280x720.
model = "models/posenet_mobilenet_v1_075_{}_{}_quant_decoder_edgetpu.tflite".format(int(resH)+1, int(resW)+1)
KEYPOINTS = labels
engine = BasicEngine(model)

# Validate the input tensor shape before indexing into it.
_input_tensor_shape = engine.get_input_tensor_shape()
if (_input_tensor_shape.size != 4 or
        _input_tensor_shape[3] != 3 or
        _input_tensor_shape[0] != 1):
    # Module scope: there is no `self` here (the old message raised NameError).
    raise ValueError(
        ('Image model should have input shape [1, height, width, 3]!'
         ' This model has {}.'.format(_input_tensor_shape)))
input_shape = _input_tensor_shape
inference_size = (input_shape[2], input_shape[1])
_, image_height, image_width, image_depth = _input_tensor_shape

# run_inference returns every output tensor flattened and concatenated;
# _output_offsets[i]:_output_offsets[i + 1] is the slice for output tensor i.
offset = 0
_output_offsets = [0]
for size in engine.get_all_output_tensors_sizes():
    offset += size
    _output_offsets.append(offset)
    
# Main loop: grab a frame, run PoseNet on the Edge TPU, draw the confident
# keypoints and the offset arrow, then derive a steering command from the
# offset. Press 'q' in the 'posenet' window to quit. The capture thread is
# stopped and the window closed however the loop exits (including
# KeyboardInterrupt), so the camera is not left open.

# Size frames are resized to before DetectPosesInImage pads/crops them to the
# model input; keypoint coordinates come back in this coordinate system.
# NOTE(review): the model input is larger than this, so most of it is zero
# padding — confirm this downscale is intended.
width, height = 640,480
# Dead zone (pixels) around the image centre inside which no command is issued.
distance = 100
tracking = True

try:
    while True:
        # Start timer (for calculating frame rate)
        t1 = cv2.getTickCount()

        # Work on a copy: the capture thread swaps in a new array only when a
        # new frame arrives, so drawing on the shared one would feed this
        # frame's annotations into the next inference if the loop outpaces
        # the camera.
        frame = videostream.read().copy()
        frame_resized = cv2.resize(frame, (width, height))
        frame_resized = DetectPosesInImage(frame_resized)
        keypoint_dict = ParseOutput(engine.run_inference(frame_resized.flatten()),KEYPOINTS)

        # Draw keypoints scoring at least min_thresh, scaled from the resized
        # frame back to the display size. Labels containing neither 'right'
        # nor 'left' also update the tracking offset.
        for label in labels:
            keypoint = keypoint_dict[label]
            if keypoint[-1] >= min_thresh:
                kp_y, kp_x = keypoint[1]
                x = round((kp_x/width)*imW)
                y = round((kp_y/height)*imH)
                center = (int(x), int(y))
                if 'right' in label:
                    cv2.circle(frame, center, 5, (0,255,0), -1)
                elif 'left' in label:
                    cv2.circle(frame, center, 5, (0,0,255), -1)
                else:
                    # Offset from the image centre; yoff grows upwards.
                    xoff, yoff = int(x-(imW/2)),int((imH/2)-y)
                    cv2.circle(frame, center, 5, (255,0,0), -1)

        draw_arrows(frame)
        # Draw framerate in corner of frame
        cv2.putText(frame,
                    'FPS: {0:.2f}'.format(frame_rate_calc),
                    (imW-200,30),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1,
                    (255,255,0),
                    1,
                    cv2.LINE_AA)

        # Map the offset to a command; only one axis is acted on per frame,
        # horizontal first.
        # NOTE(review): cmd is computed but not sent anywhere in the visible code.
        cmd = ""
        if tracking:
            if xoff < -distance and xoff>-imW/2:
                cmd = "counter_clockwise"
            elif xoff > distance and xoff<imW/2:
                cmd = "clockwise"
            elif yoff < -distance and yoff>-imH/2:
                print("DOWNNNNN",yoff)
                cmd = "down"
            elif yoff > distance and yoff<imH/2:
                print("UPPPPPPPPPPPPPPP",yoff)
                # cmd = "up"
            elif xoff==0 and yoff == 0:
                print("ignore")

        # All the results have been drawn on the frame, so it's time to display it.
        cv2.imshow('posenet', frame)

        # Calculate framerate
        t2 = cv2.getTickCount()
        time1 = (t2-t1)/freq
        frame_rate_calc= 1/time1

        # Press 'q' to quit
        if cv2.waitKey(1) == ord('q'):
            break
finally:
    # Stop the capture thread (it releases the camera) and close the window.
    videostream.stop()
    cv2.destroyAllWindows()