In [None]:
######### SETUP CODE #########
IN_COLAB = False
import os
import pickle
import matplotlib.pyplot as plt
import imageio
import numpy as np
import time
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Conv2D, UpSampling2D, MaxPool2D
from tensorflow.keras import Model
import random
from qmindcolors import cstr
print("TensorFlow version:", tf.__version__)
#NOTE: Good resource. -> https://www.tensorflow.org/tutorials/quickstart/advanced
import cv2 # opencv, for image resizing.
!pip install chumpy 
# NOTE(Noah): Adapted from a Stack Overflow answer.
def rgb2gray(rgb):
    """Convert an RGB(A) image to a single grayscale channel.

    Only the first three channels are read (an alpha channel is ignored) and combined
    with fixed luma weights. A new axis is inserted at position 2, so an (H, W, C) input
    comes back as (H, W, 1).
    """
    luma_weights = (0.2989, 0.5870, 0.1140)
    gray = np.dot(rgb[..., :3], luma_weights)
    return np.expand_dims(gray, axis=2)
def resize(img, size):
    """Resize img to a square size x size image using bicubic interpolation.

    cv2.resize returns a 2-D array for single-channel input, silently dropping a trailing
    channel axis of length 1. We restore that axis so a (H, W, 1) grayscale image comes back
    as (size, size, 1) and still fits the (..., IMAGE_CHANNELS) batch buffers.
    """
    resized = cv2.resize(img, dsize=(size, size), interpolation=cv2.INTER_CUBIC)
    if img.ndim == 3 and resized.ndim == 2:
        resized = resized[..., np.newaxis]
    return resized

# DATA LOADING

In [2]:
def download_image(path):
  """Read one image from disk and preprocess it into a model input.

  Despite the name, nothing is downloaded: the file at `path` is read locally via imageio.
  Pixels are divided by 255 (assumes 8-bit images — TODO confirm for other sources).
  The image is converted to grayscale when the module-level GRAYSCALE flag is set, then
  resized to IMAGE_SIZE x IMAGE_SIZE.

  Args:
    path: filesystem path of the image.

  Returns:
    The preprocessed image as a numpy array.
  """
  image = imageio.imread(path)
  # Drop an alpha channel if the file has one. Otherwise an RGBA PNG would keep 4 channels
  # and fail to fit the IMAGE_CHANNELS-wide batch buffers.
  if image.ndim == 3 and image.shape[-1] > 3:
    image = image[..., :3]
  _image = image.astype('float32') / 255
  if GRAYSCALE:
    _image = rgb2gray(_image)
  return resize(_image, IMAGE_SIZE)

# Local copy of the hand-pose image dataset (the "gcs" name presumably dates from when it lived in GCS).
gcs_path = os.path.join("..", "SH_RHD")
# sorted(): os.listdir returns entries in arbitrary order. Later cells pick files by position
# (e.g. train_list[3]), so sort the lists to keep that choice the same across machines and runs.
train_list = sorted(os.listdir(os.path.join(gcs_path, "training/color")))
eval_list = sorted(os.listdir(os.path.join(gcs_path, "evaluation/color")))

# Setup some params.
IMAGE_SIZE = 224
GRAYSCALE = False
IMAGE_CHANNELS = 1 if GRAYSCALE else 3
BATCH_SIZE = 32

# Numpy "buckets" that one batch of images is loaded into. They are allocated as float32
# because the training loop feeds the model float32 input anyway.
x_train = np.zeros((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS), dtype=np.float32)
x_test = np.zeros((BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS), dtype=np.float32)


In [None]:
# Annotation pickles: under data/anno when running in Colab, under ../RHD_small locally.
if IN_COLAB:
    anno_train_path = os.path.join("data", "anno", "anno_training.pickle")
    anno_eval_path = os.path.join("data", "anno", "anno_evaluation.pickle")
else:
    anno_train_path = os.path.join("..", "RHD_small", "training", "anno_training.pickle")
    anno_eval_path = os.path.join("..", "RHD_small", "evaluation", "anno_evaluation.pickle")

# Sample counts for the two splits, as published at
# https://lmb.informatik.uni-freiburg.de/resources/datasets/RenderedHandposeDataset.en.html
TRAIN_TOTAL_COUNT = 41258
EVALUATION_TOTAL_COUNT = 2728

# One (21 keypoints x 3 coordinates) slot per sample; filled in place by load_anno.
y_train = np.zeros((TRAIN_TOTAL_COUNT, 21, 3))
y_test = np.zeros((EVALUATION_TOTAL_COUNT, 21, 3))

def load_anno(path, y):
    """Load per-sample 3D keypoints from an annotation pickle into `y`, in place.

    The pickle holds a dict whose values each provide a 'uv_vis' array (column 2 is a
    visibility flag) and an 'xyz' array of 42 keypoints (two hands of 21). For every sample
    we keep one hand:
      - keypoints 0..20 if any of them is marked visible;
      - otherwise keypoints 21..41.

    Samples are stored densely, in the dict's iteration order, and validity is not checked.
    NOTE(review): later cells index y by the number in the image filename, which assumes
    the dict iterates in sample-id order — confirm against the pickle.

    Args:
      path: path to the annotation pickle. Only load trusted files, since pickle.load can
        execute arbitrary code.
      y: preallocated array with at least as many (21, 3) slots as there are samples.

    Returns:
      The number of samples written into `y`.
    """
    with open(path, 'rb') as f:
        anno_all = pickle.load(f)
    count = 0
    for value in anno_all.values():
        kp_visible = value['uv_vis'][:, 2] == 1
        if np.any(kp_visible[:21]):
            y[count] = value['xyz'][:21]
        else:
            y[count] = value['xyz'][21:42]
        count += 1
    return count

# Fill y_train / y_test from their pickles, timing each split.
for split_name, split_path, split_target in (
    ("training", anno_train_path, y_train),
    ("evaluation", anno_eval_path, y_test),
):
    print("Loading in {} annotations".format(split_name))
    time_start = time.time()
    load_anno(split_path, split_target)
    time_end = time.time()
    print(cstr("{} annotations loaded in {} s".format(split_name.capitalize(), time_end - time_start)))

# MODEL LOADING

In [None]:
# TODO(Noah): Get the MANO folders hosted in GCS so that this works again in Colab.
#   The model and loss set up below were tested and work; the only missing piece there
#   is that MANO_DIR does not exist.
if IN_COLAB:
    MANO_DIR = os.path.join("data", "mano_v1_2")
else:
    MANO_DIR = os.path.join("..", "mano_v1_2")

from mobilehand import MAKE_MOBILE_HAND
from mobilehand_lfuncs import LOSS_3D

MOBILE_HAND = MAKE_MOBILE_HAND(IMAGE_SIZE, IMAGE_CHANNELS, BATCH_SIZE, MANO_DIR)

# INTEGRATION TEST: push one random float32 batch through the freshly built network.
input_test = tf.cast(
    tf.random.uniform(shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, IMAGE_CHANNELS)),
    tf.float32,
)
output_test = MOBILE_HAND(input_test)
print(cstr("output_test ="), output_test)

# The training cells below refer to the network as `model` and to the loss as `loss_fn`.
model = MOBILE_HAND


def loss_fn(pred, gt):
    """3D keypoint loss used by the training loop (thin wrapper over LOSS_3D)."""
    return LOSS_3D(pred, gt)

# TRAINING

In [None]:
class StupidSimpleLossMetric():
    """Running mean of scalar loss values, a minimal stand-in for tf.keras.metrics.Mean.

    Call the instance with a loss to record it. result() returns the mean of everything
    recorded since the last reset_states().
    """
    def __init__(self):
        self.losses = []  # every value passed to __call__ since the last reset

    def __call__(self, loss):
        self.losses.append(loss)

    def result(self):
        # Like tf.keras.metrics.Mean, report 0.0 instead of dividing by zero when empty
        # (e.g. reading test_loss in an epoch where no evaluation ran).
        if not self.losses:
            return 0.0
        return sum(self.losses) / len(self.losses)

    def reset_states(self):
        self.losses = []

optimizer = tf.keras.optimizers.Adam()  # library defaults
train_loss = StupidSimpleLossMetric()
test_loss = StupidSimpleLossMetric()

# Loss function sanity check against the first training annotation.
# The mock predictions are named pred_* rather than `input`, so this cell does not shadow
# Python's built-in input() for the rest of the notebook.
label = np.expand_dims(y_train[0], axis=0)  # [1, 21, 3]
pred_zeros = tf.zeros([1, 21, 3])
loss_zeros = loss_fn(pred_zeros, label)
print('Loss for pred of all zeros', loss_zeros.numpy())
pred_ones = tf.ones([1, 21, 3])
loss_ones = loss_fn(pred_ones, label)
print('Loss for pred of all ones', loss_ones.numpy())

@tf.function
def train_step(input, gt):
    """Run one optimization step on a batch and return the batch loss.

    Reads the notebook globals `model`, `loss_fn` and `optimizer`.

    Args:
      input: batch of images for `model`. The name shadows the builtin but is kept for
        compatibility with keyword callers.
      gt: ground-truth 3D keypoints, passed to `loss_fn` as the target.

    Returns:
      The loss tensor for this batch.
    """
    with tf.GradientTape() as tape:
        # A custom loop must request training mode explicitly. Otherwise layers such as
        # BatchNorm/Dropout (if the model has any) behave as in inference.
        # test_step passes training=False for the mirror-image reason.
        _mesh, keypoints = model(input, training=True)
        loss = loss_fn(keypoints, gt)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
  
@tf.function
def test_step(images, labels):
    """Evaluate `model` on one batch in inference mode and return the loss.

    Reads the notebook globals `model` and `loss_fn`. The model's first output (the mesh)
    is ignored; only the predicted keypoints are scored against `labels`.
    """
    # training=False only matters for layers that behave differently at train vs.
    # inference time (e.g. Dropout).
    _mesh, predicted_keypoints = model(images, training=False)
    step_loss = loss_fn(predicted_keypoints, labels)
    return step_loss

# Where model weights and per-epoch preview images are written.
if IN_COLAB:
    checkpoint_path = os.path.join("data", "checkpoints")
else:
    checkpoint_path = os.path.join("..", "checkpoints/")

if not IN_COLAB:
  import open3d as o3d
  import open3d.visualization.rendering as rendering
  from mano_layer import MANO_Model

  # MANO_Model instances keyed by directory. render_checkpoint_image runs once per epoch,
  # so caching avoids re-reading the MANO files on every call.
  # TODO(Noah): expose the MANO params on the model itself instead.
  _MANO_MODEL_CACHE = {}

  def _get_mano_model(mano_dir):
    """Return a MANO_Model for mano_dir, constructing it only on first use."""
    if mano_dir not in _MANO_MODEL_CACHE:
      _MANO_MODEL_CACHE[mano_dir] = MANO_Model(mano_dir)
    return _MANO_MODEL_CACHE[mano_dir]

  def _lit_material(rgba):
    """Return a "defaultLit" MaterialRecord with the given RGBA base color."""
    material = rendering.MaterialRecord()
    material.base_color = rgba
    material.shader = "defaultLit"
    return material

  def _skeleton_line_set(points, parents):
    """Return a red LineSet joining keypoint i to keypoint parents[i], for i in 1..20."""
    lines = [[i, parents[i]] for i in range(1, 21)]
    line_set = o3d.geometry.LineSet()
    line_set.points = o3d.utility.Vector3dVector(points)
    line_set.lines = o3d.utility.Vector2iVector(lines)
    line_set.colors = o3d.utility.Vector3dVector([[1, 0, 0] for _ in lines])
    return line_set

  def _add_keypoint_spheres(render, points, name_prefix, material, paint_color=None):
    """Add one small sphere per point to the scene, named name_prefix + index."""
    for i, point in enumerate(points):
      sphere = o3d.geometry.TriangleMesh.create_sphere(0.005)
      if paint_color is not None:
        sphere.paint_uniform_color(paint_color)
      sphere.compute_vertex_normals()
      sphere.translate(point)
      render.scene.add_geometry("{}{}".format(name_prefix, i), sphere, material)

  def render_checkpoint_image(ckpt_index, model, eval_image, annot, template_override=False):
    """Render the model's prediction for one image and save it next to the input image.

    Args:
      ckpt_index: index of the checkpoint whose weights `model` currently holds. It is used
        in the output filename, image-<ckpt_index>.png under checkpoint_path.
      model: network returning (mesh vertices, 3D keypoints) for a batch of images.
      eval_image: one preprocessed image, as returned by download_image.
      annot: ground-truth 3D keypoints for eval_image, indexed annot[i] for i in 0..20.
      template_override: if True, replace the prediction with the MANO template at a fixed
        test pose (useful for debugging the renderer).
    """
    # MAKE_MOBILE_HAND was built with BATCH_SIZE (presumably a fixed batch dimension), so tile
    # the single image into a full batch and use element 0 of the outputs.
    batch = np.repeat(np.expand_dims(eval_image, 0), BATCH_SIZE, axis=0)
    T_posed, keypoints3D = model(batch)

    render = rendering.OffscreenRenderer(1080, 1080)
    mpi_model = _get_mano_model(MANO_DIR)
    parents = mpi_model.RHD_K.numpy()  # parents[i]: keypoint that keypoint i is joined to

    if template_override:
      # Zero shape params. The first pose row is set to 1.57/2 along x and the second to 1.57/2
      # along y; every other pose value is zero.
      batch_size = 1
      beta = tf.zeros([batch_size, 10])
      pose = np.zeros([1, 16, 3], dtype=np.float32)
      pose[0, 0, 0] = 1.57 / 2  # Root
      pose[0, 1, 1] = 1.57 / 2
      pose = tf.repeat(tf.constant(pose), repeats=[batch_size], axis=0)
      T_posed, keypoints3D = mpi_model(beta, pose, tf.constant([[0, 0, 0]]))

    green = _lit_material([0.0, 0.5, 0.0, 1.0])
    red = _lit_material([0.5, 0.0, 0.0, 1.0])
    yellow = _lit_material([1.0, 0.75, 0.0, 1.0])

    # Ground truth: red skeleton plus yellow joints.
    render.scene.add_geometry("line_set_RHD", _skeleton_line_set(annot, parents), red)
    _add_keypoint_spheres(render, annot[:21], "sphere_RHD", yellow)

    # Prediction: green joints plus red skeleton (keypoints3D presumably [batch, 21, 3]).
    pred_keypoints = keypoints3D[0].numpy()
    _add_keypoint_spheres(render, pred_keypoints, "sphere", green, paint_color=[0, 0.75, 0])
    render.scene.add_geometry("line_set", _skeleton_line_set(pred_keypoints, parents), red)

    # Predicted MANO mesh, shown as a uniformly sampled point cloud.
    mesh = o3d.geometry.TriangleMesh()
    mesh.triangles = o3d.utility.Vector3iVector(mpi_model.F.numpy())
    mesh.vertices = o3d.utility.Vector3dVector(T_posed[0].numpy())
    mesh.compute_vertex_normals()
    pcd = mesh.sample_points_uniformly(number_of_points=1000)
    render.scene.add_geometry("pcd", pcd, red)

    # setup_camera(FOV, center, eye, up) behaves like OpenGL's lookAt; for the meaning of
    # center/eye/up see https://stackoverflow.com/questions/21830340/understanding-glmlookat
    # The eye sits z_dist along -z from the root annotation's depth and looks back at it.
    center = [0, 0, annot[0][2]]
    z_dist = 0.5
    render.setup_camera(60.0, center, [center[0], center[1], center[2] - z_dist], [0, -1, 0])
    render.scene.scene.set_sun_light([0.707, 0.0, -.707], [1.0, 1.0, 1.0], 75000)
    render.scene.scene.enable_sun_light(True)
    render.scene.show_axes(False)
    rendered = render.render_to_image()
    # TODO(Noah): also render side viewpoints (eye rotated +/-45 degrees about the y axis)
    # and lay all views out in one grid together with the input image.

    img_filepath = os.path.join(checkpoint_path, "image-{:0d}.png".format(ckpt_index))
    print(cstr("Saving image(s) at"), img_filepath)
    fig = plt.figure(figsize=(15, 10))
    fig.add_subplot(1, 2, 1)
    plt.imshow(eval_image)
    fig.add_subplot(1, 2, 2)
    plt.imshow(rendered)
    plt.savefig(img_filepath)
    # Display the figure now. With the inline backend this also closes it; otherwise one
    # open figure would accumulate per epoch.
    plt.show()
      #print(cstr("Sent email of"), img_filepath)

In [None]:
# Create the checkpoint directory. exist_ok avoids the error `!mkdir` printed on re-runs.
os.makedirs(checkpoint_path, exist_ok=True)

# Set to an epoch index >= 0 to resume from checkpoint_path/cp-<index>.ckpt.
last_checkpoint = -1
if last_checkpoint > -1:
  file_path = os.path.join(checkpoint_path, "cp-{:04d}.ckpt".format(last_checkpoint))
  model.load_weights(file_path)
  print(cstr("Loaded weights from {}".format(file_path)))

# Build the (single) training batch. NOTE: every slot holds the SAME file, train_list[3],
# so the loop below overfits one image. It is a pipeline sanity check, not real training.
y = np.zeros([BATCH_SIZE, 21, 3], dtype=np.float32)
for j in range(BATCH_SIZE):
  filename = train_list[3]
  train_image = download_image(os.path.join(gcs_path, "training", "color", filename))
  x_train[j, :, :, :] = train_image
  y_index = int(filename[0:5])  # image filenames start with the 5-digit sample id
  y[j, :, :] = y_train[y_index]
# The batch never changes, so cast it once here instead of every epoch.
x_train = x_train.astype('float32')

EPOCHS = 10
RUN_EVAL = False  # set True to also report the loss on the first evaluation batch

# Training sample rendered after every epoch to visualize progress (loaded once).
if not IN_COLAB:
  VIS_IMAGE_INDEX = 26
  vis_image = download_image(
      os.path.join(gcs_path, "training", "color", "{:05d}.png".format(VIS_IMAGE_INDEX))
  ).astype('float32')

# Continue epoch numbering after a resumed checkpoint so earlier checkpoints aren't overwritten.
first_epoch = last_checkpoint + 1
for epoch in range(first_epoch, first_epoch + EPOCHS):
  print("Begin epoch", epoch)
  start = time.time()
  # Reset the metrics at the start of each epoch.
  train_loss.reset_states()
  test_loss.reset_states()

  loss = train_step(x_train, y)
  train_loss(loss.numpy())

  if RUN_EVAL:
    # Use a separate label buffer so the training labels in `y` are not overwritten.
    y_eval = np.zeros([BATCH_SIZE, 21, 3], dtype=np.float32)
    for j in range(BATCH_SIZE):
      filename = eval_list[j]
      x_test[j, :, :, :] = download_image(os.path.join(gcs_path, "evaluation", "color", filename))
      y_eval[j, :, :] = y_test[int(filename[0:5])]
    loss_test = test_step(x_test.astype('float32'), y_eval)
    test_loss(loss_test.numpy())

  end = time.time()

  summary = f'Epoch {epoch}, Time {end - start} s, Loss: {train_loss.result()}'
  if RUN_EVAL:
    summary += f', Test Loss: {test_loss.result()}'
  print(summary)

  # Save the model parameters after every epoch.
  checkpoint_filepath = os.path.join(checkpoint_path, "cp-{:04d}.ckpt".format(epoch))
  model.save_weights(checkpoint_filepath)
  print(cstr("Saved weights to {}".format(checkpoint_filepath)))

  if not IN_COLAB:
    render_checkpoint_image(epoch, model, vis_image, y_train[VIS_IMAGE_INDEX])