# MoveNet Training
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LorBordin/bodynet/blob/heatmaps/dev_utils/training.ipynb)

## 1. Setup
### 1.1 Prepare the environment

In [None]:
# Clone the `heatmaps` branch of bodynet and (quietly, with upgrades) install
# the packages listed in its requirements file.
!git clone --branch heatmaps https://github.com/LorBordin/bodynet.git
!pip install -q -U -r bodynet/requirements.txt

In [None]:
# Make the freshly cloned repository importable.
import sys

sys.path.append("./bodynet")

# Mount Google Drive; the zipped dataset is read from it in the next cell.
from google.colab import drive

drive.mount('/content/gdrive')

In [None]:
import os

DS_PATH='gdrive/MyDrive/bodynet_ds/tfrecords.zip'

# Load the data from gdrive
if not os.path.isdir("/content/dataset"):
  !unzip $DS_PATH -d /content/

### 1.2 Training settings

In [None]:
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow import keras
from imutils import paths
import matplotlib as mpl
import tensorflow as tf
import numpy as np
import cv2
import os

from bodypose.training.architecture.custom_layers import get_max_mask
from bodypose.training.preprocessing import load_TFRecords_dataset
from bodypose.training.preprocessing import augmentations
from bodypose.training.metrics import ClassificationLoss, RegrCoordsLoss, RegrCoordsLossRaw
from bodypose.training.metrics import Accuracy, avgMDE_2D, avgMDE_2D_Raw
from bodypose.training.architecture import MoveNet

from bodypose.demo.graphics import draw_keypoints

import config as cfg

In [7]:
INPUT_SHAPE = (224, 224, 3)
STRIDES = (32, 16, 8, 4)
NUM_KPTS = len(cfg.MPII_KEYPOINT_DICT)

# Grid side length at the last (smallest) stride; passed to the dataset
# loaders as `grid_dim`. Uses INPUT_SHAPE[0] only, i.e. assumes a square input.
GRID_SIZE = INPUT_SHAPE[0] // STRIDES[-1]

# exist_ok replaces the check-then-create pair (no race, no error on re-run).
os.makedirs("saved_models", exist_ok=True)

MODEL_PATH = f"./saved_models/movenet_{INPUT_SHAPE[0]}.models"

## 2. Load the Dataset

In [8]:
# Training-only augmentations: random vertical and horizontal shifts
# (max_shift_range=0.15) and a horizontal flip applied with probability 0.5.
# The flip receives the keypoint index map, presumably so it can swap
# left/right joints — see the augmentations module.
augs = [
    augmentations.VerticalShift(max_shift_range=0.15),
    augmentations.HorizontalShift(max_shift_range=0.15),
    augmentations.HorizontalFlip(
        probability=0.5,
        keypoints_idxs=cfg.MPII_KEYPOINT_IDXS,
    ),
]

In [None]:
# Gather the TFRecord shards of each split and randomise their order
# (train first, then validation).
train_paths = list(paths.list_files("./tfrecords/mpii/train/"))
valid_paths = list(paths.list_files("./tfrecords/mpii/validation/"))
for split_paths in (train_paths, valid_paths):
    np.random.shuffle(split_paths)

# Loader settings shared by both splits; only the augmentations differ
# (the validation split is loaded without any).
shared_loader_kwargs = dict(
    batch_size=32,
    target_size=INPUT_SHAPE[:2],
    grid_dim=GRID_SIZE,
    roi_thresh=1.,
)

train_ds = load_TFRecords_dataset(
    filePaths=train_paths,
    augmentations=augs,
    **shared_loader_kwargs,
)

val_ds = load_TFRecords_dataset(
    filePaths=valid_paths,
    augmentations=[],
    **shared_loader_kwargs,
)

In [None]:
# Sanity check: print the shapes of one training batch (images, then the
# two label tensors). Loop names are kept because later cells read them.
for img, (y1, y2) in train_ds.take(1):
    for tensor in (img, y1, y2):
        print(tensor.shape)

## 3. Build the Model

In [10]:
model = MoveNet(
    input_shape = INPUT_SHAPE, 
    strides = STRIDES, 
    num_joints = NUM_KPTS, 
    alpha = .5, 
    use_depthwise = True,
    use_postproc = True
    )

img = (np.random.uniform(
    0, 255, (1,) + INPUT_SHAPE
    ).astype("uint8") / 255).astype(np.float32)

%timeit model(img)


126 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Adam optimizer with a fixed learning rate of 5e-5.
adam = keras.optimizers.Adam(learning_rate=5e-5)

def HeatmapWeightingLoss(y_true, y_pred):
  """Squared heatmap error weighted by the target value.

  Each element's squared error is multiplied by (y_true + 1), so locations
  with a higher target count more. The weighted errors are summed over
  axes 1 and 2 (presumably the spatial H, W axes — confirm) and the result
  is averaged over all remaining elements.
  """
  weighted_sq_err = (y_true + 1) * tf.square(y_pred - y_true)
  per_map = tf.reduce_sum(weighted_sq_err, axis=[1, 2])
  return tf.reduce_mean(per_map)

def heat_mae(y_true, y_pred):
  """Return the mean absolute error taken over every element of the maps."""
  return tf.reduce_mean(tf.abs(y_true - y_pred))

def total_loss(y_true, y_pred, regr_weight=5, regr_raw_weight=5):
  """Combined loss for the coordinate output ('output_1').

  Sum of ClassificationLoss and the two coordinate-regression losses, each
  regression term scaled by its weight.

  Args:
    y_true: ground-truth tensor for the coordinate output.
    y_pred: predicted tensor for the coordinate output.
    regr_weight: multiplier for RegrCoordsLoss (default 5, as before).
    regr_raw_weight: multiplier for RegrCoordsLossRaw (default 5, as before).

  Returns:
    The scalar combined loss.
  """
  # Named `loss` so the local no longer shadows this function's own name.
  loss = ClassificationLoss(y_true, y_pred)
  loss += regr_weight * RegrCoordsLoss(y_true, y_pred)
  loss += regr_raw_weight * RegrCoordsLossRaw(y_true, y_pred)
  return loss

# Keep only the weights from the epoch with the lowest validation
# avgMDE_2D on the coordinate output.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        MODEL_PATH,
        monitor="val_output_1_avgMDE_2D",
        # Explicit instead of "auto": auto resolves to "min" here only because
        # the metric name happens not to contain "acc"; this survives renames.
        mode="min",
        save_best_only=True,
        save_weights_only=True,
        ),
]

# loss_weights is keyed by output name, like `loss` and `metrics`, so the
# weights cannot be misassigned if the model's output order changes.
model.compile(
    optimizer=adam,
    loss={'output_1': total_loss, 'output_2': HeatmapWeightingLoss},
    loss_weights={'output_1': 5, 'output_2': 1},
    metrics={'output_1': [Accuracy, avgMDE_2D, avgMDE_2D_Raw], 'output_2': heat_mae}
)

In [None]:
# Validation metrics before any training, as a baseline.
model.evaluate(x=val_ds)

## 4. Model Training

In [None]:
# Train for 10 epochs, validating on val_ds after each one; the checkpoint
# callback writes the best weights to MODEL_PATH.
model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks,
)

In [None]:
# Restore the best checkpoint written by ModelCheckpoint and re-score it.
# Uses MODEL_PATH (not a hard-coded filename) so it stays correct if
# INPUT_SHAPE changes.
model.load_weights(MODEL_PATH)
model.evaluate(val_ds)

In [None]:
%cp -r ./saved_models/* /content/gdrive/MyDrive/bodynet_ds/saved_models

## 5. Model Evaluation

In [None]:
# Fetch one validation batch and print its shapes. `img`, `y1` and `y2` stay
# bound afterwards because draw_sample() reads them as globals.
for img, (y1, y2) in val_ds.take(1):
    for tensor in (img, y1, y2):
        print(tensor.shape)

In [None]:
def _resize_to_input(arr):
    """Resize a 2-D map to the model's input resolution.

    cv2.resize takes dsize as (width, height), whereas INPUT_SHAPE is
    (height, width, channels); passing INPUT_SHAPE[:2] directly is only
    correct for square inputs.
    """
    return cv2.resize(arr, (INPUT_SHAPE[1], INPUT_SHAPE[0]))


def _heatmap_to_color(heat, colormap=cv2.COLORMAP_VIRIDIS):
    """Resize a single-channel heatmap to input size and apply `colormap`.

    Values are clipped to [0, 1] before scaling to uint8: a sum over several
    keypoint channels can exceed 1, and casting out-of-range floats to uint8
    overflows instead of saturating.
    """
    heat = _resize_to_input(heat)
    heat = (np.clip(heat, 0.0, 1.0) * 255).astype('uint8')
    return cv2.applyColorMap(heat, colormap)


def draw_sample(i):
    """Plot labels vs. predictions for sample `i` of the current batch.

    Reads the notebook globals `img`, `y1`, `y2` (bound by the most recent
    `for img, (y1, y2) in ....take(1)` cell), `model` and `cfg`. Draws a 2x2
    figure: centre-map label / prediction on the top row, keypoint-map label
    and predicted keypoint mask on the bottom row.
    """
    colormap = cv2.COLORMAP_VIRIDIS
    alpha = .5

    # Inverse of a [-1, 1] input normalisation (assumed from this formula).
    sample = ((img[i].numpy() + 1) * 127.5).astype('uint8')

    coords, heatmaps = model(img[i:i+1])
    # Column reorder [1, 2, 0], applied identically to the labels below —
    # presumably the layout draw_keypoints expects.
    coords = coords[0].numpy()[:, [1, 2, 0]]

    # Separate copies, so that if draw_keypoints draws in place the predicted
    # and ground-truth keypoints cannot end up on the same image.
    # .01 / .5 are passed as draw_keypoints' threshold argument (presumably a
    # confidence cut-off).
    sample_pred = draw_keypoints(sample.copy(), coords, .01, cfg.MPII_KEYPOINT_DICT)
    sample_orig = draw_keypoints(
        sample.copy(), y1[i].numpy()[:, [1, 2, 0]], .5, cfg.MPII_KEYPOINT_DICT)

    # Predicted maps: channel 0 is shown as the centre map; channels 1: are
    # reduced by get_max_mask and summed into one keypoint mask.
    centremap = _heatmap_to_color(heatmaps[0, :, :, 0].numpy(), colormap)
    kptsmask = get_max_mask(heatmaps[:, :, :, 1:])[0].numpy().sum(axis=-1)
    kptsmask = _resize_to_input(kptsmask)

    # Label maps, colorised the same way.
    center_label = _heatmap_to_color(y2[i, :, :, 0].numpy(), colormap)
    kpts_label = _heatmap_to_color(y2[i, :, :, 1:].numpy().sum(axis=-1), colormap)

    fig, axs = plt.subplots(2, 2, figsize=(20, 20))

    overlays = (
        (axs[0, 0], sample_orig, center_label, "Label - CentreMap"),
        (axs[0, 1], sample_pred, centremap, "Predicted - CentreMap"),
        (axs[1, 0], sample_orig, kpts_label, "Label - KeypointsMap"),
    )
    for ax, base, heat, title in overlays:
        ax.imshow(cv2.addWeighted(base, alpha, heat, 1 - alpha, 0))
        ax.axis('off')
        ax.set_title(title)

    # Bottom-right: predicted image blended (by matplotlib) with the keypoint mask.
    axs[1, 1].imshow(sample_pred, alpha=.5)
    axs[1, 1].imshow(kptsmask, alpha=.5)
    axs[1, 1].axis('off')
    axs[1, 1].set_title("Predicted - KeypointsMap")

In [None]:
# Plot the first ten samples of the current batch.
for idx in range(10):
    draw_sample(idx)