<a href="https://colab.research.google.com/github/Victoooooor/ErgoChairML/blob/main/Demo_Vid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup Environment

In [None]:
#@title
!apt-get install libmagic-dev
!pip install git+https://github.com/Victoooooor/ErgoChairML.git
!pip install -q imageio
!pip install tensorflow-io
!echo "deb http://packages.cloud.google.com/apt gcsfuse-`lsb_release -c -s` main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
!sudo apt-get -y -q update
!sudo apt-get -y -q install gcsfuse

# LOGIN

In [None]:
# Authenticate the Colab session with the user's Google account
# (presumably required for the gsutil bucket copy in the next cell).
from google.colab import auth
auth.authenticate_user()

# Load trained model

In [None]:
# Name of the Cloud Storage bucket whose contents are copied below.
bucket_name = 'ergo_chair_ml'
# Recursively copy everything in the bucket into the current working directory
# (-m runs the copy in parallel).
!gsutil -m cp -r gs://{bucket_name}/* ./

# Init

In [None]:
from gen_chair import pix2pix

from gen_chair import coco
from gen_chair.gen_multi import Preprocess
from gen_chair import pix2pix

from google.colab import files
import shutil
import os
import tqdm
import cv2
import tensorflow as tf
from PIL import ImageSequence
import PIL
import numpy as np
import tensorflow_io as tfio 

from IPython.display import display, Javascript, Image
from google.colab.output import eval_js
from base64 import b64decode, b64encode

import io
import html
import time

class InferenceConfig(coco.CocoConfig):
  """CocoConfig variant for single-image inference.

  Batch size = GPU_COUNT * IMAGES_PER_GPU; both are set to 1 so the batch
  size is 1, because inference is run on one image at a time.
  """

  GPU_COUNT = 1
  IMAGES_PER_GPU = 1

# Names handed to pix2pix(...) in colab_ref.__init__ before loadcp() is called;
# presumably checkpoint directories copied from the bucket -- confirm.
cpdir_ske = 'skeleton_checkpoints'
cpdir_mask = 'masked_checkpoints'

class colab_ref(object):
  """Bundles the preprocessor and the two pix2pix models used by the demo."""

  def __init__(self):
    """Build the preprocessor and load both pix2pix checkpoints."""
    self.prep = Preprocess(InferenceConfig)

    self.masked = pix2pix(cpdir_mask)
    self.masked.loadcp()

    self.skeleton = pix2pix(cpdir_ske)
    self.skeleton.loadcp()

  def Infer(self, preproc, generator, origin, masked = True):
    """Run one image through `preproc` and `generator` and tile the results.

    Params:
            preproc: callable applied to `origin`. When it returns a tuple,
                     element 0 is used if `masked` is true, otherwise element 1;
                     any other return type makes the resized `origin` the
                     generator input.
            generator: callable invoked as generator(batch, training=True).
            origin: input image tensor.
            masked: selects which element of the preproc tuple is used.
    Returns:
            PIL RGB image with three panels side by side: resized origin,
            generator input, generator output.
    """
    processed = preproc(origin)
    if type(processed) is tuple:
      chosen = processed[0] if masked else processed[1]
    else:
      chosen = None

    # Shadow `origin` with its 256x256 padded-resize, as integer pixel values.
    origin = tf.cast(tf.image.resize_with_pad(origin, 256, 256), dtype=tf.int32)

    if chosen is None:
      net_in = origin
    else:
      net_in = tf.keras.preprocessing.image.img_to_array(chosen)

    # Add the batch axis, then bring the generator input to 256x256 int32.
    net_in = tf.expand_dims(net_in, axis=0)
    net_in = tf.image.resize_with_pad(net_in, 256, 256)
    net_in = tf.cast(net_in, dtype=tf.int32)

    generated = generator(net_in, training=True)
    gen_img = tf.keras.utils.array_to_img(generated[0])
    seg_img = tf.keras.utils.array_to_img(net_in[0])
    origin_img = tf.keras.utils.array_to_img(origin)

    # Lay the three panels out left to right on one canvas.
    panels = (origin_img, seg_img, gen_img)
    total_width = origin_img.width + seg_img.width + gen_img.width
    strip = PIL.Image.new('RGB', (total_width, origin_img.height))
    x_offset = 0
    for panel in panels:
      strip.paste(panel, (x_offset, 0))
      x_offset += panel.width
    return strip
    
# Decode a data-URL string coming from the browser into an OpenCV image.
def js_to_image(js_reply):
  """
  Params:
          js_reply: string of the form '<header>,<base64 data>' as produced by
                    canvas.toDataURL in the browser
  Returns:
          img: image decoded by cv2.imdecode with the colour flag (BGR order)
  """
  # everything after the first comma is the base64 payload
  encoded = js_reply.split(',')[1]
  raw = b64decode(encoded)
  # view the raw bytes as a flat uint8 buffer for OpenCV
  buf = np.frombuffer(raw, dtype=np.uint8)
  return cv2.imdecode(buf, flags=cv2.IMREAD_COLOR)

# Encode an RGBA pixel array as a PNG data URL that can be shown in the page.
def bbox_to_bytes(bbox_array):
  """
  Params:
          bbox_array: Numpy array (pixels), interpreted as an RGBA image.
  Returns:
        bytes: 'data:image/png;base64,...' string holding the PNG-encoded image
  """
  overlay = PIL.Image.fromarray(bbox_array, 'RGBA')
  # render the image to PNG in memory, then base64-encode the bytes
  with io.BytesIO() as png_buffer:
    overlay.save(png_buffer, format='png')
    encoded = b64encode(png_buffer.getvalue()).decode('utf-8')
  return 'data:image/png;base64,' + encoded

def take_photo(filename='photo.jpg', quality=0.8, wid = 768, hei = 512):
  """Show a webcam preview with a Capture button and save the captured frame.

  Params:
          filename: path the captured image is written to with cv2.imwrite
          quality: JPEG quality passed to canvas.toDataURL in the browser
          wid: ideal capture width requested from the camera
          hei: ideal capture height requested from the camera
  Returns:
          filename: the path that was written
  """
  # wid and hei are passed as arguments to takePhoto: the JavaScript runs in
  # the browser and cannot see Python variables, so referencing them as free
  # names would raise a ReferenceError there.
  js = Javascript('''
    async function takePhoto(quality, wid, hei) {
      const div = document.createElement('div');
      const capture = document.createElement('button');
      capture.textContent = 'Capture';
      div.appendChild(capture);

      const video = document.createElement('video');
      video.style.display = 'block';
      const stream = await navigator.mediaDevices.getUserMedia({video: { width: {ideal: wid}, height: {ideal: hei}, facingMode: 'user'}});
      document.body.appendChild(div);
      div.appendChild(video);
      video.srcObject = stream;
      await video.play();

      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

      // Wait for Capture to be clicked.
      await new Promise((resolve) => capture.onclick = resolve);

      const canvas = document.createElement('canvas');
      canvas.width = video.videoWidth;
      canvas.height = video.videoHeight;
      canvas.getContext('2d').drawImage(video, 0, 0);
      stream.getVideoTracks()[0].stop();
      div.remove();
      return canvas.toDataURL('image/jpeg', quality);
    }
    ''')
  display(js)

  # get photo data
  data = eval_js('takePhoto({}, {}, {})'.format(quality, wid, hei))
  # get OpenCV format image
  img = js_to_image(data)
  cv2.imwrite(filename, img)
  return filename

# JavaScript to properly create our live video stream using our webcam as input
def video_stream():
  """Inject the JavaScript that drives the live webcam demo.

  The script defines `stream_frame(label, imgData)`, which video_frame calls
  through eval_js. On each call it creates the page elements if needed
  (status label, 768x512 video, 768x512 output image, stop instruction),
  updates the label and output image, and waits for the next animation frame
  to capture the video as a JPEG data URL. It resolves to '' once the user
  has clicked the video, the image or the instruction text; otherwise to an
  object with timing fields and the captured frame under 'img'.
  """
  js = Javascript('''
    var video;
    var div = null;
    var stream;
    var captureCanvas;
    var imgElement;
    var labelElement;
    
    var pendingResolve = null;
    var shutdown = false;
    
    function removeDom() {
       stream.getVideoTracks()[0].stop();
       video.remove();
       div.remove();
       video = null;
       div = null;
       stream = null;
       imgElement = null;
       captureCanvas = null;
       labelElement = null;
    }
    
    function onAnimationFrame() {
      if (!shutdown) {
        window.requestAnimationFrame(onAnimationFrame);
      }
      if (pendingResolve) {
        var result = "";
        if (!shutdown) {
          captureCanvas.getContext('2d').drawImage(video, 0, 0, 768, 512);
          result = captureCanvas.toDataURL('image/jpeg', 0.8)
        }
        var lp = pendingResolve;
        pendingResolve = null;
        lp(result);
      }
    }
    
    async function createDom() {
      if (div !== null) {
        return stream;
      }

      div = document.createElement('div');
      div.style.border = '2px solid black';
      div.style.padding = '3px';
      div.style.width = 774;
      // div.style.maxWidth = '774px';
      document.body.appendChild(div);
      
      const modelOut = document.createElement('div');
      modelOut.innerHTML = "<span>Status:</span>";
      labelElement = document.createElement('span');
      labelElement.innerText = 'No data';
      labelElement.style.fontWeight = 'bold';
      modelOut.appendChild(labelElement);
      div.appendChild(modelOut);
           
      video = document.createElement('video');
      video.style.display = 'block';
      video.width = 768;
      video.height = 512;
      video.setAttribute('playsinline', '');
      video.onclick = () => { shutdown = true; };
      stream = await navigator.mediaDevices.getUserMedia(
          {video: { facingMode: "environment"}});
      div.appendChild(video);

      imgElement = document.createElement('img');
      // imgElement.style.position = 'absolute';
      // imgElement.style.zIndex = 1;
      imgElement.width = 768;
      imgElement.height = 512;
      imgElement.onclick = () => { shutdown = true; };
      div.appendChild(imgElement);
      
      const instruction = document.createElement('div');
      instruction.innerHTML = 
          '<span style="color: red; font-weight: bold;">' +
          'When finished, click here or on the video to stop this demo</span>';
      div.appendChild(instruction);
      instruction.onclick = () => { shutdown = true; };
      
      video.srcObject = stream;
      await video.play();

      captureCanvas = document.createElement('canvas');
      captureCanvas.width = 768; //video.videoWidth;
      captureCanvas.height = 512; //video.videoHeight;
      window.requestAnimationFrame(onAnimationFrame);
      
      return stream;
    }
    async function stream_frame(label, imgData) {
      if (shutdown) {
        removeDom();
        shutdown = false;
        return '';
      }

      var preCreate = Date.now();
      stream = await createDom();
      
      var preShow = Date.now();
      if (label != "") {
        labelElement.innerHTML = label;
      }
            
      if (imgData != "") {
        var videoRect = video.getClientRects()[0];
        // imgElement.style.top = videoRect.top + "px";
        // imgElement.style.left = videoRect.left + "px";
        // imgElement.style.width = videoRect.width + "px";
        // imgElement.style.height = videoRect.height + "px";
        imgElement.src = imgData;
      }
      
      var preCapture = Date.now();
      var result = await new Promise(function(resolve, reject) {
        pendingResolve = resolve;
      });
      if (shutdown) {
        // Stop was clicked while waiting for the frame: release the camera
        // and signal the caller instead of returning an empty image.
        removeDom();
        shutdown = false;
        return '';
      }
      
      return {'create': preShow - preCreate, 
              'show': preCapture - preShow, 
              'capture': Date.now() - preCapture,
              'img': result};
    }
    ''')

  display(js)
  
def video_frame(label, bbox):
  """Call the page's stream_frame(label, bbox) and return its result.

  Params:
          label: text/HTML shown in the status label (an empty string leaves
                 the label unchanged)
          bbox: data URL shown in the output image (an empty string leaves
                the image unchanged)
  Returns:
          whatever stream_frame resolves to, as returned by eval_js
  """
  # Local import: json is not among this notebook's top-level imports.
  import json

  # json.dumps emits a double-quoted, fully escaped literal that is also valid
  # JavaScript, so quotes, backslashes or newlines in the arguments cannot
  # break out of the string in the evaluated expression.
  call = 'stream_frame({}, {})'.format(json.dumps(str(label)),
                                       json.dumps(str(bbox)))
  return eval_js(call)

# Video Input

In [None]:
import warnings
warnings.filterwarnings('ignore')

cl = colab_ref()
# start streaming video from webcam
video_stream()
# label for video
label_html = 'Capturing...'
# initialize the output image to empty; it is filled in after the first frame
bbox = ''
count = 0 
while True:
    js_reply = video_frame(label_html, bbox)
    # An empty reply means the page stopped the stream. An empty 'img' means
    # the stop click landed while a frame was being captured; there is nothing
    # to decode in that case, so stop as well instead of crashing.
    if not js_reply or not js_reply["img"]:
        break

    # convert JS response to OpenCV Image
    img = js_to_image(js_reply["img"])
    img = tf.convert_to_tensor(img)
    img = tfio.experimental.color.bgr_to_rgb(img)
    # each Infer call returns a 768x256 strip: input, model input, model output
    mframe = cl.Infer(cl.prep.img_seg, cl.masked.generator, img, True)
    sframe = cl.Infer(cl.prep.img_seg, cl.skeleton.generator, img, False)
    # opaque white RGBA canvas: masked-model strip on top, skeleton-model strip below
    bbox_array = np.full([512, 768, 4], 255, dtype=np.uint8)
    bbox_array[:256, :, :3] = mframe
    bbox_array[256:, :, :3] = sframe

    # convert the canvas into a PNG data URL
    bbox_bytes = bbox_to_bytes(bbox_array)
    # hand it to the next video_frame call so the page displays it
    bbox = bbox_bytes