>[Downloading Model Data](#scrollTo=m_Vt_gHBOjxd)

>[Configuring TensorFlow](#scrollTo=5IjH5gusOhy9)

>[Webcam Demo](#scrollTo=L95kdZi1QDFw)



# Downloading Model Data

Run these steps first to download the TensorFlow model data.


In [0]:
!git clone https://github.com/tensorflow/models
checkpoint_name = 'ssdlite_mobilenet_v2_coco_2018_05_09'
!wget http://download.tensorflow.org/models/object_detection/{checkpoint_name}.tar.gz
!tar -xf {checkpoint_name}.tar.gz
checkpoint = '{0}.ckpt'.format(checkpoint_name)
!cd /content/models/research && protoc object_detection/protos/*.proto --python_out=.
print('Setup successful!')

# Configuring TensorFlow

These steps build the TensorFlow graph and load the downloaded model data into memory so that it can be used for detection.

In [0]:
# NOTE: Pieces taken from https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb
# The github.com/tensorflow/models is distributed under the Apache 2.0 license.

import numpy
import tensorflow as tf

tf.reset_default_graph()
!rm -f /content/logs/*

with tf.get_default_graph().as_default() as graph:  
  summary_writer = tf.summary.FileWriter("/content/logs/", flush_secs=1)

  # This section builds a "graph" in TensorFlow to explain how to process the data.
  jpeg_input_tensor = tf.placeholder(tf.string, ())  # We will provide a JPEG to TF.

  # First, instruct TF to decode the JPEG string into a matrix.
  image = tf.image.decode_image(jpeg_input_tensor)
  image_tensor = tf.expand_dims(image, 0)

  # Load the Mobilenetv2 + SSDLite graph from disk.
  ssdlite_graph = tf.GraphDef()
  with open('{}/frozen_inference_graph.pb'.format(checkpoint_name), 'rb') as f:
    ssdlite_graph.ParseFromString(f.read())
    
  # Tell TensorFlow we would like to inspect these parts of the network.
  output_names = ['num_detections:0', 
                  'detection_boxes:0', 
                  'detection_scores:0',
                  'detection_classes:0']
  ops = dict(zip(output_names, tf.graph_util.import_graph_def(
      ssdlite_graph, 
      input_map={'image_tensor': image_tensor},
      return_elements=output_names)))
    
  # Also extract the decoded image from the network to draw bounding boxes.
  ops['image'] = image
  
  summary_writer.add_graph(graph)
  summary_writer.flush()


def run_detection(sess, img):
  """Run the detection graph once on a single encoded image.

  Args:
    sess: the session used to evaluate `ops`.
    img: encoded image bytes fed to `jpeg_input_tensor`.

  Returns:
    The dict returned by `sess.run(ops, ...)`, keyed by the fetched names,
    extended with convenience entries that drop the leading batch index:
    'num_detections' (int), 'detection_classes' (int64 array),
    'detection_boxes' and 'detection_scores'.
  """
  fetched = sess.run(ops, feed_dict={jpeg_input_tensor: img})

  # The detection outputs are float32 arrays whose first index selects the
  # single image in the batch; the count and class ids are cast to integers.
  count = fetched['num_detections:0'][0]
  classes = fetched['detection_classes:0'][0]
  fetched['num_detections'] = int(count)
  fetched['detection_classes'] = classes.astype(numpy.int64)
  for name in ('detection_boxes', 'detection_scores'):
    fetched[name] = fetched[name + ':0'][0]

  return fetched

# Reached only if the graph above was built without raising.
print('Model configured')

## Optional: Visualize the Graph with TensorBoard

In [0]:
%load_ext tensorboard
!mkdir /content/logs
!ps ax | grep tensorboard | awk '{print $1}' | xargs kill
%tensorboard --logdir /content/logs

# Webcam Demo

This section creates the video input element and connects it to TensorFlow.

In [0]:
import base64
import html
import io
import time

# Make the cloned models/research directory importable so that the
# object_detection package (label_map_util and visualization_utils, used by
# the demo below) can be found.
import sys
sys.path.append('/content/models/research')
from object_detection.utils import label_map_util
from object_detection.utils import visualization_utils

# Taken from https://colab.research.google.com/notebooks/snippets/advanced_outputs.ipynb#scrollTo=SucxddsPhOmj
from IPython.display import display, Javascript
from google.colab.output import eval_js
import numpy
import PIL.Image

def start_input():
  """Inject the webcam capture UI and its `takePhoto` helper into the page.

  Displays a Javascript snippet that defines `takePhoto(label, imgData)` in
  the output frame. On first use it creates the video/overlay DOM, then it
  shows `label` (as HTML) and the overlay image `imgData`, waits for the next
  animation frame and returns a dict with the captured 512x512 JPEG data URL
  under 'img' and JS-side timings in milliseconds under 'create', 'show' and
  'capture'. Clicking the video, the overlay or the instruction text stops
  the demo: the pending or next call tears down the DOM and returns ''.
  """
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var imgElement;
    var labelElement;
    
    var pendingResolve = null;
    var shutdown = false;
    
    function removeDom() {
       stream.getVideoTracks()[0].stop();
       video.remove();
       div.remove();
       video = null;
       div = null;
       stream = null;
       imgElement = null;
       captureCanvas = null;
       labelElement = null;
    }
    
    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 512, 512);
          result = captureCanvas.toDataURL('image/jpeg', 0.8)
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }
    
    async function createDom() {
      if (div !== null) {
        return stream;
      }

      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = '100%';
      div.style.maxWidth = '600px';
      document.body.appendChild(div);
      
      const modelOut = document.createElement('div');
      modelOut.innerHTML = "<span>Status:</span>";
      labelElement = document.createElement('span');
      labelElement.innerText = 'No data';
      labelElement.style.fontWeight = 'bold';
      modelOut.appendChild(labelElement);
      div.appendChild(modelOut);
           
      video = document.createElement('video');
      video.style.display = 'block';
      video.width = div.clientWidth - 6;
      video.setAttribute('playsinline', '');
      video.onclick = () => { shutdown = true; };
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "environment"}});
      div.appendChild(video);

      imgElement = document.createElement('img');
      imgElement.style.position = 'absolute';
      imgElement.style.zIndex = 1;
      imgElement.onclick = () => { shutdown = true; };
      div.appendChild(imgElement);
      
      const instruction = document.createElement('div');
      instruction.innerHTML = 
          '<span style="color: red; font-weight: bold;">' +
          'When finished, click here or on the video to stop this demo</span>';
      div.appendChild(instruction);
      instruction.onclick = () => { shutdown = true; };
      
      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      captureCanvas.width = 512; //video.videoWidth;
      captureCanvas.height = 512; //video.videoHeight;
      window.requestAnimationFrame(onAnimationFrame);
      
      return stream;
    }
    async function takePhoto(label, imgData) {
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }

      var preCreate = Date.now();
      stream = await createDom();
      
      var preShow = Date.now();
      if (label != "") {
        labelElement.innerHTML = label;
      }
            
      if (imgData != "") {
        var videoRect = video.getClientRects()[0];
        imgElement.style.top = videoRect.top + "px";
        imgElement.style.left = videoRect.left + "px";
        imgElement.style.width = videoRect.width + "px";
        imgElement.style.height = videoRect.height + "px";
        imgElement.src = imgData;
      }
      
      var preCapture = Date.now();
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      // If the stop click arrived while this capture was pending,
      // onAnimationFrame resolved with "" and stopped its loop. Tear down
      // here and report the stop with '' rather than clearing the request
      // and handing back an empty image the caller cannot decode.
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }
      
      return {'create': preShow - preCreate, 
              'show': preCapture - preShow, 
              'capture': Date.now() - preCapture,
              'img': result};
    }
    ''')

  display(js)
  
def take_photo(label, img_data):
  data = eval_js('takePhoto("{}", "{}")'.format(label, img_data))
  return data

# Build the category lookup that visualize_boxes_and_labels_on_image_array
# uses to name each detection, from the COCO label map shipped in the cloned
# repository (presumably keyed by class id - see label_map_util).
category_index = label_map_util.create_category_index_from_labelmap(
    '/content/models/research/object_detection/data/mscoco_label_map.pbtxt',
    use_display_name=True)
    
with tf.Session() as sess:
  start_input()

  label_html = 'Capturing...'
  img_data = ''
  while True:
    capture_start = time.time()
    js_reply = take_photo(label_html, img_data)
    capture_end = time.time()
    # The page answers '' once the user has stopped the demo. If the stop
    # click lands while a capture is pending, the reply can instead carry an
    # empty 'img'; treat that as a stop too rather than failing to decode it.
    if not js_reply or not js_reply.get('img'):
      break

    # Javascript returns a data URL, like:
    #     data:image/jpeg;base64,<base-64 encoded data>
    # To use the image, decode the base-64 encoded part and treat it as a JPEG.
    jpeg_input = base64.b64decode(js_reply['img'].split(',', 1)[1])
    result = run_detection(sess, jpeg_input)
    detect_end = time.time()

    # To reduce transfer sizes, we send just the bounding boxes drawn on a
    # transparent PNG. Here, we create a blank (all-zero, fully transparent)
    # RGBA image with the same height and width as the input frame.
    height, width = result['image'].shape[0:2]
    image_np = numpy.zeros((height, width, 4), dtype=numpy.uint8)

    # Draw the bounding boxes in the RGB channels.
    visualization_utils.visualize_boxes_and_labels_on_image_array(
      image_np[:, :, 0:3],  # sub-select RGB channels only; alpha is done below.
      result['detection_boxes'],
      result['detection_classes'],
      result['detection_scores'],
      category_index,
      instance_masks=result.get('detection_masks'),
      use_normalized_coordinates=True,
      line_thickness=8)

    # To be visible, the alpha channel also needs to be edited: make every
    # pixel on which anything was drawn (any non-zero RGB channel) fully
    # opaque (255) and leave the rest fully transparent (0).
    drawn = image_np[:, :, 0:3].any(axis=2)
    image_np[:, :, 3] = numpy.where(drawn, 255, 0)
    viz_end = time.time()

    # Save the image as a PNG in memory and assemble a data URL.
    im = PIL.Image.fromarray(image_np, 'RGBA')
    iobuf = io.BytesIO()
    im.save(iobuf, format='png')
    img_data = 'data:image/png;base64,{}'.format(
        base64.b64encode(iobuf.getvalue()).decode('ascii'))

    perf_measures = {
        'server': (
            ('take_photo', capture_end - capture_start),
            ('run_detection', detect_end - capture_end),
            ('visualize', viz_end - detect_end)
        ),
        'js': (
            ('create', js_reply['create']),
            ('show', js_reply['show']),
            ('capture', js_reply['capture']),
        ),
    }

    # Server timings are in seconds; the page reports milliseconds (Date.now
    # differences), so those are divided by 1000 for display.
    label_text = 'img size: {}b\ntime:\n  server: {}\n  js: {}'.format(
        len(js_reply['img']),
        ', '.join('{}: {:2.3f}s'.format(*x) for x in perf_measures['server']),
        ', '.join('{}: {:2.3f}s'.format(x[0], x[1] / 1000) for x in perf_measures['js']),
    )

    # The page assigns the label via innerHTML, so escape it first and turn
    # newlines into explicit line breaks.
    label_html = html.escape(label_text).replace('\n', '<br/>')

print('Finished')