[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/akwasnie/SiMa-GUT/blob/main/Realtime_object_tracking.ipynb)

In [None]:
# Select TensorFlow 2.x through the notebook magic below. In plain Jupyter
# this magic is not available, so TensorFlow 2 would have to be installed
# via `!pip` instead.
# NOTE(review): no TensorFlow import is visible in this notebook — confirm
# whether this cell is still needed.
%tensorflow_version 2.x

In [None]:
# Install ONNX Runtime; it is imported below and used to run the .onnx model.
!pip install onnxruntime

In [None]:
import json
import os
import sys
from base64 import b64decode, b64encode

import cv2
import IPython
import numpy as np
import onnxruntime
import torch
from google.colab import output
from google.colab.output import eval_js
from google.colab.patches import cv2_imshow
from IPython.display import Image, Javascript, clear_output, display
from numpy import asarray
from PIL import Image as pimage

In [None]:
# Clone the CenterNet fork and check out a fixed commit so that the helper
# modules imported in the next cell come from a known revision.
!git clone https://github.com/akwasnie/CenterNet.git && cd CenterNet && git checkout 8ef87b433529ac8f8bd4f95707f6bc05052c55e9

In [None]:
# Add the cloned checkout to the import path so that the `src.lib...` modules
# below can be imported from it.
sys.path.append('CenterNet')
from CenterNet import src
# Helpers used by preprocess() and centernet2cocodict() further down.
from src.lib.utils.image import get_affine_transform
from src.lib.models.decode import ctdet_decode
from src.lib.utils.post_process import ctdet_post_process

In [None]:
DATA_DIR = 'data/'
MODEL_FILE = os.path.join(DATA_DIR, 'ctdet_coco_dlav0_1x.onnx')
CLASS_FILE = os.path.join(DATA_DIR, 'coco.json')
os.makedirs(DATA_DIR, exist_ok=True)
if not os.path.exists(MODEL_FILE):
    !gdown --id 1Nda64Ezeo1yObABzDpJNzr-3n1yn_VOa -O $MODEL_FILE
else:
    print('CSV file ({}) already exists.'.format(MODEL_FILE))

if not os.path.exists(CLASS_FILE):
    !gdown --id 1ddXo-vbPNvNRs4X-faUkcwBtaWSP7ZtO -O $CLASS_FILE
else:
    print('CSV file ({}) already exists.'.format(CLASS_FILE))

In [None]:
# Size the input frame is warped to in preprocess() before inference.
IN_SIZE = [512,512]
# Per-channel mean / std applied in preprocess() after scaling pixels to [0, 1].
MEAN = [0.408, 0.447, 0.470]
STD = [0.289, 0.274, 0.278]
# Passed as K to ctdet_decode(): the number of candidate boxes to decode.
MAX_BB_NUM = 10
# Detections whose score is not above this value are dropped.
THRESHOLD = 0.25
# Path of the sample image that is downloaded and then run through the model.
FILENAME = 'photo.jpg'

In [None]:
def preprocess(image_path):
  """Load an image from disk and turn it into a model input tensor.

  Args:
    image_path: Path of the image file to read.

  Returns:
    A tuple ``(frame, im_data, c, s)`` where ``frame`` is the image in RGB
    channel order, ``im_data`` is the normalized float32 batch of one image
    in channel-first layout, ``c`` is a one-element list holding the image
    center and ``s`` is a one-element list holding the longer image side.

  Raises:
    OSError: If the image cannot be read or decoded.
  """
  image = cv2.imread(image_path)
  if image is None:
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    raise OSError('Could not read image: {}'.format(image_path))

  frame = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  h, w = frame.shape[:2]
  c = [np.array([w / 2., h / 2.], dtype=np.float32)]
  s = [max(h, w) * 1.0]

  trans_input = get_affine_transform(c[0], s[0], 0, IN_SIZE)
  im_data = cv2.warpAffine(
      frame, trans_input,
      (IN_SIZE[0], IN_SIZE[1]),
      flags=cv2.INTER_LINEAR,
  )
  im_data = ((im_data / 255. - MEAN) / STD).astype(np.float32)
  # Go from HWC to a contiguous 1xCxHxW batch. Adding the batch axis instead
  # of reshaping to hard-coded sizes keeps the layout correct even when
  # IN_SIZE is not square.
  im_data = np.ascontiguousarray(im_data.transpose(2, 0, 1)[np.newaxis])
  return frame, im_data, c, s

def import_data(model_path, class_mapping_path):
  """Create the ONNX Runtime session and load the class mapping.

  Args:
    model_path: Path of the .onnx model file.
    class_mapping_path: Path of the JSON file with the class mapping.

  Returns:
    A tuple ``(session, mapping)``: the inference session built from
    ``model_path`` and the object parsed from ``class_mapping_path``.
  """
  session = onnxruntime.InferenceSession(model_path)
  with open(class_mapping_path, 'r') as mapping_file:
    mapping = json.loads(mapping_file.read())
  return session, mapping

def centernet2cocodict(out, c, m, category_index):
    """Convert raw model outputs into a dict of thresholded detections.

    Args:
        out: Sequence of numpy arrays returned by the ONNX session. Index 0 is
            used as the heatmap, index 2 as box sizes and index 1 as the
            offset regression (assumes the exported graph emits them in that
            order — TODO confirm against the model).
        c: Image center(s), forwarded to ``ctdet_post_process``.
        m: Image scale(s), forwarded to ``ctdet_post_process``.
        category_index: Mapping whose values are dicts with at least an
            ``'id'`` key; values sorted by ``'id'`` are matched to the
            1-based class indices produced by post-processing.

    Returns:
        Dict with ``'num_detections'``, ``'detection_boxes'`` (N x 4),
        ``'detection_scores'`` (N) and ``'detection_classes'`` (N), holding
        only detections whose score is above ``THRESHOLD``.
    """
    num_classes = 80

    # Use the out-of-place sigmoid: torch.from_numpy shares memory with the
    # numpy array, so sigmoid_() would silently overwrite the caller's buffer.
    hm = torch.from_numpy(out[0]).sigmoid()
    wh = torch.from_numpy(out[2])
    reg = torch.from_numpy(out[1])
    dets = ctdet_decode(hm, wh, reg=reg, cat_spec_wh=False, K=MAX_BB_NUM)
    dets = dets.detach().cpu().numpy()
    # 128 x 128 is presumably the heatmap resolution for the 512 x 512 input —
    # TODO confirm if IN_SIZE ever changes.
    dets = ctdet_post_process(dets.copy(), c, m, 128, 128, num_classes)[0]

    cats = sorted(category_index.values(), key=lambda cat: cat['id'])

    boxes, scores, classes = [], [], []
    for j in range(1, num_classes + 1):
        # Each row is (x1, y1, x2, y2, score) for class index j.
        class_dets = np.array(dets[j], dtype=np.float32).reshape(-1, 5)
        for bbox in class_dets:
            if bbox[4] > THRESHOLD:
                boxes.append(bbox[:4])
                scores.append(bbox[4])
                classes.append(cats[j - 1]['id'])

    return {
        # Kept as the decode budget for backward compatibility; the number of
        # boxes actually returned is len(detection_boxes).
        'num_detections': MAX_BB_NUM,
        # Explicit dtype/shape so an empty result is still a (0, 4) array.
        'detection_boxes': np.asarray(boxes, dtype=np.float32).reshape(-1, 4),
        'detection_scores': np.asarray(scores, dtype=np.float32),
        'detection_classes': np.asarray(classes),
    }

def run_inference(ort_session, category_index, image_data, c, s):
  """Run the model on one preprocessed image and decode its detections.

  Args:
    ort_session: ONNX Runtime session to execute.
    category_index: Class mapping forwarded to ``centernet2cocodict``.
    image_data: Input tensor fed to the session's first input.
    c: Image center(s) forwarded to ``centernet2cocodict``.
    s: Image scale(s) forwarded to ``centernet2cocodict``.

  Returns:
    The detection dict produced by ``centernet2cocodict``.
  """
  feed = {ort_session.get_inputs()[0].name: image_data}
  raw_outputs = ort_session.run(None, feed)
  return centernet2cocodict(raw_outputs, c, s, category_index)

# Load the model session and class mapping once at cell level; predict()
# further down reads these two globals.
ort_session, category_index = import_data(MODEL_FILE, CLASS_FILE)

In [None]:
# Download a sample image to FILENAME for the single-image demo below.
!wget -O $FILENAME https://m.media-amazon.com/images/I/71lkKY9oGWL._AC_SX450_.jpg

In [None]:
# Single-image demo: run the detector on FILENAME and draw the detections.
frame, image_data, c, s = preprocess(FILENAME)
# ort_session and category_index are created in the cell that defines the
# inference helpers, so they are reused here instead of reloading the model.
output_dict = run_inference(ort_session, category_index, image_data, c, s)
print(output_dict)
for bbox, class_id in zip(output_dict['detection_boxes'],
                          output_dict['detection_classes']):
  # category_index is keyed by the class id as a string.
  class_name = category_index[str(class_id)]['name']
  cv2.rectangle(frame,
                (int(bbox[0]), int(bbox[1])),
                (int(bbox[2]), int(bbox[3])),
                (0,255,0),
                2)
  # Put the label slightly above the top-left corner of the box.
  cv2.putText(frame,
              class_name,
              (int(bbox[0]), int(bbox[1]-10)),
              0,
              0.45,
              (255,0,0),
              0)
# frame is RGB; convert back to BGR, which is the order cv2_imshow is given.
result = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
cv2_imshow(result)

In [None]:

def ndarray_to_b64(ndarray):
    """Return an RGB image array as a base64 string of its PNG encoding.

    The result is meant to be embedded in an HTML img tag's data URL.
    """
    bgr_image = cv2.cvtColor(ndarray, cv2.COLOR_RGB2BGR)
    png_bytes = cv2.imencode('.png', bgr_image)[1]
    encoded = b64encode(png_bytes)
    return encoded.decode('utf-8')

def predict(img_64):
  """Detect objects in a frame sent from the browser and return it annotated.

  Args:
    img_64: Data URL string; everything after the first comma is taken as
      the base64-encoded image payload.

  Returns:
    An ``IPython.display.JSON`` object whose ``'result'`` entry is a data URL
    of the frame with boxes and class names drawn on it.
  """
  binary = b64decode(img_64.split(',', 1)[1])
  # preprocess() reads from a path, so the frame goes through a file on disk.
  with open('photo.jpg', 'wb') as f:
    f.write(binary)

  frame, image_data, c, s = preprocess('photo.jpg')
  out_dict = run_inference(ort_session, category_index, image_data, c, s)
  print(out_dict)

  for bbox, class_id in zip(out_dict['detection_boxes'],
                            out_dict['detection_classes']):
    # category_index is keyed by the class id as a string.
    class_name = category_index[str(class_id)]['name']
    cv2.rectangle(frame,
                  (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3])),
                  (0, 255, 0),
                  2)
    cv2.putText(frame,
                class_name,
                (int(bbox[0]), int(bbox[1]-10)),
                0,
                0.45,
                (0, 0, 0),
                0)
  # ndarray_to_b64 encodes the frame as PNG, so the data URL must declare
  # image/png rather than image/jpeg.
  result = 'data:image/png;base64,' + ndarray_to_b64(frame)
  return IPython.display.JSON({'result': result})

# Expose predict() under the name 'amld.predict'; the HTML cell below calls
# it by that name through google.colab.kernel.invokeFunction.
output.register_callback('amld.predict', predict)

In [None]:
%%html
<!-- Webcam demo UI: "start" opens the camera, pressing/dragging on the canvas
     grabs a frame and sends it to the registered 'amld.predict' callback, and
     "clear" stops the camera. The annotated result is shown in the image. -->
<meta name="viewport" content="width=device-width, initial-scale=1, user-scalable=no">
<button id="start">start</button><button id="clear">clear</button><br />
<canvas width="320" height="180" id="canvas" style="border:1px solid black"></canvas><br />
<video id="myVideo" width="320" height="180"></video><br />
<img id="image">
<script>
  let canvas = document.getElementById('canvas')
  let ctx = canvas.getContext('2d')
  let img_64
  let dragging = false
  let timeout
  let stream
  let video = document.getElementById('myVideo')

  // Send the most recently captured frame to the kernel and display the
  // annotated image it returns.
  let predict = () => {
    google.colab.kernel.invokeFunction('amld.predict', [img_64], {}).then(
        obj => document.getElementById('image').src = obj.data['application/json'].result)
  }

  // Ask for camera access and start playing the stream in the video element.
  async function startvideo(){
    stream = await navigator.mediaDevices.getUserMedia({ video: true, audio: false })
    video.srcObject = stream
    await video.play();
  }

  // Copy the current video frame to the canvas and schedule a prediction.
  // Resetting the timer means only the last frame of a burst is sent.
  async function sendforprediction() {
    ctx.drawImage(video, 0, 0, canvas.width, canvas.height)
    img_64 = canvas.toDataURL('image/jpeg', 0.9)
    clearTimeout(timeout)
    timeout = setTimeout(predict, 500)
  }

  const handler = e => {
    sendforprediction()
  }
  canvas.addEventListener('touchstart', e => {dragging=true; handler(e)})
  canvas.addEventListener('touchmove', e => {e.preventDefault(); dragging && handler(e)})
  canvas.addEventListener('touchend', () => dragging=false)
  canvas.addEventListener('mousedown', e => {dragging=true; handler(e)})
  canvas.addEventListener('mousemove', e => {dragging && handler(e)})
  canvas.addEventListener('mouseup', () => dragging=false)
  canvas.addEventListener('mouseleave', () => dragging=false)
  document.getElementById('clear').addEventListener('click', () => {
    // Drop any prediction that is still pending for the frame being cleared.
    clearTimeout(timeout)
    ctx.fillStyle = 'white'
    ctx.fillRect(0, 0, canvas.width, canvas.height)
    // The camera may not have been started yet, so guard before stopping it.
    if (stream) {
      stream.getTracks().forEach(function(track) {
          track.stop();
      });
      stream = null
    }
    // clear_output() is a Python-side helper and does not exist in the
    // browser; calling it here threw before the video could be detached.
    video.srcObject = null;
  })
  document.getElementById('start').addEventListener('click', () => {
    startvideo()
  })
</script>