In [51]:
# https://www.tensorflow.org/lite/guide/python
# https://www.tensorflow.org/lite/guide/inference#load_and_run_a_model_in_python
# https://www.tensorflow.org/lite/models/pose_estimation/overview
# posenet paper: https://www.cv-foundation.org/openaccess/content_iccv_2015/papers/Kendall_PoseNet_A_Convolutional_ICCV_2015_paper.pdf
import tflite_runtime.interpreter as tflite
import tensorflow as tf

import numpy as np
from matplotlib import pyplot as plt
from skimage.util import img_as_float32, img_as_ubyte, img_as_uint
from skimage.transform import resize
import urllib.request
import json 

# pip install tfjs-graph-converter
import tfjs_graph_converter.api as tfjs

from urllib.request import urlretrieve
import requests
from pathlib import Path

import cv2
import os
import glob
from os.path import isfile, join
from tqdm.auto import tqdm, trange

In [50]:
# Base URL that every relative checkpoint path below is appended to.
base_uri = 'https://storage.googleapis.com/tfjs-models/savedmodel/posenet/'

# Files of the float MobileNet checkpoints, relative to base_uri, grouped by
# multiplier (050 / 075 / 100). Each checkpoint consists of its weight shards
# ("group1-shardXofN.bin") plus one model JSON per output stride.
mobilenet_uris = ["mobilenet/float/050/group1-shard1of1.bin",
                  "mobilenet/float/050/model-stride16.json",
                  "mobilenet/float/050/model-stride8.json",
                  "mobilenet/float/075/group1-shard1of2.bin",
                  "mobilenet/float/075/group1-shard2of2.bin",
                  "mobilenet/float/075/model-stride16.json",
                  "mobilenet/float/075/model-stride8.json",
                  "mobilenet/float/100/group1-shard1of4.bin",
                  "mobilenet/float/100/group1-shard2of4.bin",
                  "mobilenet/float/100/group1-shard3of4.bin",
                  # Fix: the shard names say "of4", but the fourth shard was
                  # missing, leaving the 100 checkpoint incomplete on disk.
                  "mobilenet/float/100/group1-shard4of4.bin",
                  "mobilenet/float/100/model-stride16.json",
                  "mobilenet/float/100/model-stride8.json"]
# Files of the ResNet50 checkpoints, relative to base_uri. Three variants are
# listed (float, quant1, quant2); each has its weight shards followed by the
# model JSON files for stride 16 and stride 32. The shard order is not numeric;
# get_model() downloads the files in exactly this order.
resnet_uris = ['resnet50/float/group1-shard11of23.bin',
               'resnet50/float/group1-shard10of23.bin',
               'resnet50/float/group1-shard12of23.bin',
               'resnet50/float/group1-shard13of23.bin',
               'resnet50/float/group1-shard14of23.bin',
               'resnet50/float/group1-shard15of23.bin',
               'resnet50/float/group1-shard16of23.bin',
               'resnet50/float/group1-shard17of23.bin',
               'resnet50/float/group1-shard18of23.bin',
               'resnet50/float/group1-shard19of23.bin',
               'resnet50/float/group1-shard1of23.bin',
               'resnet50/float/group1-shard20of23.bin',
               'resnet50/float/group1-shard21of23.bin',
               'resnet50/float/group1-shard22of23.bin',
               'resnet50/float/group1-shard23of23.bin',
               'resnet50/float/group1-shard2of23.bin',
               'resnet50/float/group1-shard3of23.bin',
               'resnet50/float/group1-shard4of23.bin',
               'resnet50/float/group1-shard5of23.bin',
               'resnet50/float/group1-shard6of23.bin',
               'resnet50/float/group1-shard7of23.bin',
               'resnet50/float/group1-shard8of23.bin',
               'resnet50/float/group1-shard9of23.bin',
               'resnet50/float/model-stride16.json',
               'resnet50/float/model-stride32.json',
               # quant1 variant: 6 weight shards
               'resnet50/quant1/group1-shard1of6.bin',
               'resnet50/quant1/group1-shard2of6.bin',
               'resnet50/quant1/group1-shard3of6.bin',
               'resnet50/quant1/group1-shard4of6.bin',
               'resnet50/quant1/group1-shard5of6.bin',
               'resnet50/quant1/group1-shard6of6.bin',
               'resnet50/quant1/model-stride16.json',
               'resnet50/quant1/model-stride32.json',
               # quant2 variant: 12 weight shards
               'resnet50/quant2/group1-shard10of12.bin',
               'resnet50/quant2/group1-shard11of12.bin',
               'resnet50/quant2/group1-shard12of12.bin',
               'resnet50/quant2/group1-shard1of12.bin',
               'resnet50/quant2/group1-shard2of12.bin',
               'resnet50/quant2/group1-shard3of12.bin',
               'resnet50/quant2/group1-shard4of12.bin',
               'resnet50/quant2/group1-shard5of12.bin',
               'resnet50/quant2/group1-shard6of12.bin',
               'resnet50/quant2/group1-shard7of12.bin',
               'resnet50/quant2/group1-shard8of12.bin',
               'resnet50/quant2/group1-shard9of12.bin',
               'resnet50/quant2/model-stride16.json',
               'resnet50/quant2/model-stride32.json']

# Checkpoint variants
# resnet:
#   stride can be 16 or 32
#
# mobilenet:
#   stride can be 8, 16 or 32 (only the stride-8 and stride-16 JSON files are listed above)
#   multiplier can be "100", "075", or "050"

def get_model(architecture='resnet', dest_dir='json-models', overwrite=False):
    """Download all checkpoint files of one PoseNet architecture.

    Every relative path in the selected URI list is fetched from
    ``base_uri`` and stored under ``dest_dir`` with the same relative layout.

    Parameters
    ----------
    architecture : str
        ``'resnet'`` selects ``resnet_uris``; any other value selects
        ``mobilenet_uris`` (unchanged from the original behaviour).
    dest_dir : str or Path
        Root directory for the downloaded files (default ``'json-models'``,
        the directory the original code wrote to).
    overwrite : bool
        When False (default) files that already exist locally are skipped, so
        re-running the cell does not download everything again. Pass True to
        force a fresh download.

    Raises
    ------
    urllib.error.URLError / HTTPError
        Propagated from ``urlretrieve`` when a file cannot be fetched.
    """
    file_uris = resnet_uris if architecture == 'resnet' else mobilenet_uris

    for file_uri in file_uris:
        target = Path(dest_dir) / file_uri
        if target.exists() and not overwrite:
            continue
        target.parent.mkdir(parents=True, exist_ok=True)

        # Download to a temporary name first: an interrupted transfer must not
        # leave a truncated file that a later run would mistake for complete.
        partial = target.with_name(target.name + '.part')
        try:
            urlretrieve(base_uri + file_uri, str(partial))
            partial.replace(target)
        finally:
            if partial.exists():
                partial.unlink()

# Fetch the checkpoint files of both architectures, MobileNet first.
for _architecture in ('mobilenet', 'resnet'):
    get_model(_architecture)

#!rm -rf saved_model
#tfjs.graph_model_to_saved_model('json-model', 'saved_model')


# https://github.com/octiapp/KerasPersonLab/blob/master/demo.ipynb


In [58]:
# Remove any previous export before converting again.
!rm -rf saved_model
# Convert the TFJS graph model in 'json-model' into a TensorFlow SavedModel
# written to 'saved_model'.
# NOTE(review): get_model() stores its downloads under 'json-models/<arch>/...',
# not 'json-model' — confirm that 'json-model' exists and holds the intended
# checkpoint before running this cell on a fresh kernel.
tfjs.graph_model_to_saved_model('json-model', 'saved_model', compat_mode=True)

INFO:tensorflow:No assets to save.
INFO:tensorflow:No assets to write.
INFO:tensorflow:SavedModel written to: saved_model/saved_model.pb


b'saved_model/saved_model.pb'

In [57]:
# Turn the SavedModel export into a TFLite flatbuffer and persist it as model.tflite.
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model')
tflite_model = converter.convert()
with Path('model.tflite').open('wb') as model_fh:
    model_fh.write(tflite_model)

INFO:tensorflow:Saver not created because there are no variables in the graph to restore
INFO:tensorflow:Saver not created because there are no variables in the graph to restore


In [25]:
# TFLite model loaded by pose_est_mult(); the commented-out lines are alternatives.
#model_file = 'posenet_mobilenet_v1_100_257x257_multi_kpt_stripped.tflite'
# NOTE(review): the recorded output of this cell is "Encountered unresolved custom
# op: PosenetDecoderOp", so this model apparently cannot be prepared by the plain
# interpreter created in pose_est_mult() — confirm, or select another model.
model_file = 'posenet_resnet_50_928_672_16_quant_cpu_decoder.tflite'
#model_file = 'posenet_resnet_50_928_672_16_quant_edgetpu_decoder.tflite'
#model_file = 'model.tflite'
# Image used by the single-image demo at the end of this cell.
img_file = 'elon-musk.jpg'

# Number of keypoints decoded per image by pose_est().
no_kps = 17

def pose_est(interpreter, img):
    """Estimate one set of keypoints for a single image.

    The image is resized to the model's input shape, run through the
    interpreter, and the first two output tensors are decoded as a keypoint
    heatmap and an offset map.

    Parameters
    ----------
    interpreter
        A TFLite interpreter on which ``allocate_tensors()`` was already called.
    img : numpy.ndarray
        Image array; it is given a leading batch axis before resizing.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(no_kps, 2)`` holding ``(row, column)`` per keypoint,
        scaled to the pixel grid of ``img``.
    """
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Add the batch axis and resize to exactly what the model expects.
    input_shape = tuple(int(dim) for dim in input_details[0]['shape'])
    resized = resize(img[np.newaxis], input_shape)

    # Feed the dtype the model declares; always sending uint8 makes
    # set_tensor() reject the input of float models.
    in_dtype = input_details[0]['dtype']
    if np.issubdtype(in_dtype, np.floating):
        # NOTE(review): assumes the float model expects inputs scaled to
        # [-1, 1] (MobileNet-style) — confirm for the model actually used.
        input_data = ((img_as_float32(resized) - 0.5) * 2.0).astype(in_dtype)
    else:
        input_data = img_as_ubyte(resized)

    interpreter.set_tensor(input_details[0]['index'], input_data)
    interpreter.invoke()

    # get_tensor() returns a copy of the tensor data; [0] drops the batch axis.
    heatmap = interpreter.get_tensor(output_details[0]['index'])[0]
    offset = interpreter.get_tensor(output_details[1]['index'])[0]

    in_height, in_width = input_data.shape[1:3]
    # Guard against a 1-cell heatmap, which would otherwise divide by zero.
    last_row = max(heatmap.shape[0] - 1, 1)
    last_col = max(heatmap.shape[1] - 1, 1)

    coords = np.empty((no_kps, 2))
    for i in range(no_kps):
        # Take row and column from the SAME peak cell. Picking them with two
        # independent argmaxes can combine the row of one peak with the column
        # of another when the maximum value occurs more than once.
        y, x = np.unravel_index(heatmap[:, :, i].argmax(), heatmap.shape[:2])
        coords[i, 0] = (y / last_row) * in_height + offset[y, x, i]
        coords[i, 1] = (x / last_col) * in_width + offset[y, x, i + no_kps]

    # Rescale from model-input pixels to the pixel grid of the original image.
    coords[:, 0] = (coords[:, 0] / in_height) * img.shape[0]
    coords[:, 1] = (coords[:, 1] / in_width) * img.shape[1]
    return coords


def pose_est_mult(imgs, interpreter=None):
    """Run pose_est() on several images and concatenate the keypoints.

    Parameters
    ----------
    imgs : iterable of numpy.ndarray
        Images to process, in order.
    interpreter : optional
        An already allocated TFLite interpreter to reuse. When omitted, a new
        interpreter is created from ``model_file`` (the original behaviour).
        Passing one avoids reloading the model on every call.

    Returns
    -------
    numpy.ndarray
        The per-image ``(no_kps, 2)`` arrays concatenated along axis 0, i.e.
        shape ``(len(imgs) * no_kps, 2)``. An empty input yields an empty
        ``(0, 2)`` array instead of raising inside ``np.concatenate``.
    """
    imgs = list(imgs)
    if not imgs:
        # Nothing to do: do not even load the model.
        return np.empty((0, 2))

    if interpreter is None:
        interpreter = tflite.Interpreter(model_path=model_file)
        interpreter.allocate_tensors()

    return np.concatenate([pose_est(interpreter, img) for img in imgs])


def plot_img_pose(img, coords, include_oob=False):
    """Render an image with its keypoints drawn on top and return the pixels.

    Parameters
    ----------
    img : numpy.ndarray
        Image to draw; the first two axes are treated as (height, width).
    coords : array-like of shape (n, 2)
        Keypoints as ``(row, column)`` pairs. The caller's array is not
        modified.
    include_oob : bool
        When False (default), keypoints outside the image are clipped onto
        the image border before drawing.

    Returns
    -------
    numpy.ndarray
        uint8 RGB array of the rendered figure, with columns that are entirely
        white (255) removed.
    """
    height, width = img.shape[:2]

    # Work on a copy: clipping in place would silently alter the caller's data.
    coords = np.array(coords, dtype=float)
    if not include_oob:
        # The last valid pixel index is size - 1.
        coords[:, 0] = coords[:, 0].clip(min=0, max=height - 1)
        coords[:, 1] = coords[:, 1].clip(min=0, max=width - 1)

    # figsize is (width, height) in inches; the original passed (height, width),
    # which letterboxed the image. dpi=100 makes the "/100" yield image pixels.
    fig = plt.figure(figsize=(width / 100, height / 100), dpi=100)
    try:
        ax = fig.add_axes([0, 0, 1, 1])  # axes cover the whole canvas
        ax.imshow(img)
        ax.scatter(coords[:, 1], coords[:, 0])
        ax.axis('off')
        ax.margins(0)
        fig.canvas.draw()

        # buffer_rgba() replaces the deprecated tostring_rgb(); drop alpha and
        # copy so the array outlives the figure.
        image_from_plot = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
    finally:
        # Close this specific figure even when rendering fails.
        plt.close(fig)

    # Strip columns that are pure white from top to bottom (figure background).
    walls = (image_from_plot == 255).all(axis=2).all(axis=0)
    return image_from_plot[:, np.logical_not(walls), :]

# Single-image demo: estimate the pose for img_file and show the result.
imgs = [plt.imread(img_file)]
coords = pose_est_mult(imgs)

img = plot_img_pose(imgs[0], coords)
# Fix: show the annotated rendering; the original displayed the untouched
# input image, so the keypoints never appeared.
plt.imshow(img)
plt.axis('off')
plt.show()

RuntimeError: Encountered unresolved custom op: PosenetDecoderOp.Node number 89 (PosenetDecoderOp) failed to prepare.


In [3]:
def video2Frames(filename, outpath, fps=2.0):
    """Sample a video at a fixed rate and store the frames as numbered JPEGs.

    Frames are written to ``outpath`` as ``1.jpg``, ``2.jpg``, ... Files that
    already exist directly inside ``outpath`` are deleted first.

    Parameters
    ----------
    filename : str
        Path of the video to read.
    outpath : str
        Output directory; created (including parents) when missing. A trailing
        slash is optional.
    fps : float
        Number of frames to sample per second of video; must be positive.

    Raises
    ------
    ValueError
        If ``fps`` is not positive.
    """
    if fps <= 0:
        raise ValueError(f'fps must be positive, got {fps!r}')

    vidcap = cv2.VideoCapture(filename)
    try:
        vidcap.set(cv2.CAP_PROP_POS_MSEC, 0)
        hasFrames, image = vidcap.read()

        # makedirs also creates missing parent directories (os.mkdir does not).
        os.makedirs(outpath, exist_ok=True)
        # Join instead of concatenating: without a trailing slash the pattern
        # outpath + '*' would match, and delete, sibling files of outpath.
        for stale in glob.glob(os.path.join(outpath, '*')):
            if os.path.isfile(stale):
                os.remove(stale)

        count = 0
        while hasFrames:
            count += 1
            cv2.imwrite(os.path.join(outpath, f'{count}.jpg'), image)

            # Derive the next timestamp from the frame index. Accumulating a
            # step rounded to 2 decimals drifts for rates such as 30 fps.
            vidcap.set(cv2.CAP_PROP_POS_MSEC, (count / fps) * 1000)
            hasFrames, image = vidcap.read()
    finally:
        # Always free the decoder handle, even when writing fails.
        vidcap.release()

In [None]:
# Extract 20 frames per second of the recording into its own frame directory.
video2Frames(
    filename='videos/IMG_2139.MOV',
    outpath='frames/IMG_2139/',
    fps=20.0,
)

In [4]:
def get_frame_num(name):
    """Return the integer frame index encoded in a frame file name.

    Everything up to the last '/' is dropped, then everything from the first
    '.' of the remaining file name onwards; what is left is parsed with int().
    Raises ValueError when that remainder is not an integer.
    """
    file_part = name.rpartition('/')[2]
    stem = file_part.partition('.')[0]
    return int(stem)

def pose_est_frames(framepath, outname, fps=2.0, downsize=0.5, max_frames=1000):
    """Annotate numbered frames with pose keypoints and write them as a video.

    Parameters
    ----------
    framepath : str
        Directory holding frames named ``<number>.<ext>``. Files whose name
        does not start with an integer are ignored.
    outname : str
        Path of the video to write; an existing file is replaced and a missing
        parent directory is created.
    fps : float
        Frame rate of the written video.
    downsize : float
        Scale factor applied to the frame size for the output video
        (default 0.5, the value that used to be hard-coded).
    max_frames : int or None
        Upper bound on the number of frames processed (default 1000, the value
        that used to be hard-coded); None processes every frame.

    Raises
    ------
    IndexError
        If ``framepath`` contains no numbered frame (same exception type the
        original raised on ``framenames[0]``).
    """
    # Collect (frame number, file name) pairs, skipping non-frame files such
    # as '.DS_Store' that would otherwise crash the numeric sort.
    numbered = []
    for fname in os.listdir(framepath):
        if not isfile(join(framepath, fname)):
            continue
        try:
            numbered.append((get_frame_num(fname), fname))
        except ValueError:
            continue
    numbered.sort(key=lambda pair: pair[0])
    framenames = [join(framepath, fname) for _, fname in numbered][:max_frames]
    if not framenames:
        raise IndexError(f'no numbered frames found in {framepath!r}')

    first = cv2.imread(framenames[0])
    height, width = first.shape[:2]
    size = (int(width * downsize), int(height * downsize))

    if os.path.isfile(outname):
        os.remove(outname)
    out_dir = os.path.dirname(outname)
    if out_dir:
        # The writer cannot create a file inside a missing directory.
        os.makedirs(out_dir, exist_ok=True)

    # Load the model once; going through pose_est_mult() here would rebuild
    # the interpreter for every single frame.
    interpreter = tflite.Interpreter(model_path=model_file)
    interpreter.allocate_tensors()

    fourcc = cv2.VideoWriter_fourcc(*'MJPG')
    out = cv2.VideoWriter(outname, fourcc, fps, size)
    try:
        for framename in tqdm(framenames):
            frame = cv2.imread(framename)
            if frame is None:
                # Unreadable or corrupt frame file: skip it.
                continue
            # OpenCV decodes to BGR; the model and matplotlib work in RGB.
            img = img_as_float32(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            coords = pose_est(interpreter, img)

            rendered = plot_img_pose(img, coords, include_oob=False)
            rendered = cv2.resize(rendered, size)
            # Back to BGR, the channel order VideoWriter.write() expects.
            out.write(cv2.cvtColor(rendered, cv2.COLOR_RGB2BGR))
    finally:
        # Finalise the video even when a frame fails. No window was opened,
        # so the original cv2.destroyAllWindows() call is dropped.
        out.release()

In [6]:
# Build the annotated video from the frames extracted above, at the same 20 fps.
pose_est_frames(
    framepath='frames/IMG_2139/',
    outname='posevideos/IMG_2139.mp4',
    fps=20.0,
)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))


