# MoveNet Training
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/LorBordin/bodynet/blob/master/dev_utils/training.ipynb)

## 1. Setup
### 1.1 Prepare environment

In [None]:
# Clone the bodynet repository into ./bodynet (relative to the current working directory).
!git clone https://github.com/LorBordin/bodynet
# Install/upgrade the packages listed in the repository's requirements file (-q keeps pip quiet).
!pip install -q -U -r bodynet/requirements.txt

In [None]:
import sys

from google.colab import drive

# Make the cloned repository importable from this notebook.
sys.path.append("./bodynet")

# Mount Google Drive, which holds the dataset archive used below.
drive.mount('/content/gdrive')

In [None]:
import os

DS_PATH='gdrive/MyDrive/bodynet_ds/tfrecords.zip'

# Load the data from gdrive
if not os.path.isdir("/content/dataset"):
  !unzip $DS_PATH -d /content/

### 1.2 Training settings

In [None]:
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from tensorflow import keras
from imutils import paths
import tensorflow as tf
import numpy as np
import cv2
import os

from bodypose.training.metrics import avgMDE_2D, avgMDE_2D_RAW, Accuracy
from bodypose.training.metrics import RegressionLoss2D, AuxiliaryLoss  
from bodypose.training.architecture.custom_layers import get_max_mask
from bodypose.training.preprocessing import load_TFRecords_dataset
from bodypose.training.preprocessing import augmentations
from bodypose.training.architecture import MoveNet

from bodypose.demo.graphics import draw_keypoints

import config as cfg

In [7]:
# Shape of the images fed to the network (passed to MoveNet as `input_shape`).
# NOTE(review): presumably channels-last (height, width, channels) — confirm.
INPUT_SHAPE = (224, 224, 3)
# Strides handed to MoveNet; the last entry determines the output grid size.
STRIDES = (32, 16, 8, 4)
# One joint per entry of the MPII keypoint dictionary.
NUM_KPTS = len(cfg.MPII_KEYPOINT_DICT)

# Side length of the output grid: input side divided by the last stride (224 // 4 = 56).
GRID_SIZE = INPUT_SHAPE[0] // STRIDES[-1]

# Create the checkpoint folder; exist_ok avoids the check-then-create race and
# makes the cell safe to re-run.
os.makedirs("saved_models", exist_ok=True)

# Checkpoint path shared by the checkpoint callbacks defined later.
MODEL_PATH = f"./saved_models/movenet_{INPUT_SHAPE[0]}.models"

## 2. Load the Dataset

In [8]:
# Augmentations applied to the training split only (the validation loader
# below receives an empty list).
augs = [
    augmentations.VerticalShift(max_shift_range=0.15),
    augmentations.HorizontalShift(max_shift_range=0.15),
    augmentations.HorizontalFlip(
        probability=0.5,
        keypoints_idxs=cfg.MPII_KEYPOINT_IDXS,
    ),
]

In [None]:
# Gather the files of each split and randomise their order in place.
train_paths = list(paths.list_files("./tfrecords/mpii/train/"))
valid_paths = list(paths.list_files("./tfrecords/mpii/validation/"))
for split_paths in (train_paths, valid_paths):
    np.random.shuffle(split_paths)

# Loader settings shared by the training and validation datasets.
loader_kwargs = dict(
    batch_size=32,
    target_size=INPUT_SHAPE[:2],
    grid_dim=GRID_SIZE,
    roi_thresh=0.9,
)

# Training data is augmented; validation data is loaded without augmentations.
train_ds = load_TFRecords_dataset(
    filePaths=train_paths,
    augmentations=augs,
    **loader_kwargs,
)

val_ds = load_TFRecords_dataset(
    filePaths=valid_paths,
    augmentations=[],
    **loader_kwargs,
)

In [None]:
# Sanity check: print the shapes of the image batch and of the two targets
# for a single training batch.
for img, (y1, y2) in train_ds.take(1):
    print(img.shape, y1.shape, y2.shape, sep="\n")

## 3. Build the model

In [10]:
model = MoveNet(
    input_shape = INPUT_SHAPE, 
    strides = STRIDES, 
    num_joints = NUM_KPTS, 
    alpha = .5, 
    use_depthwise = True
    )

img = (np.random.uniform(
    0, 255, (1,) + INPUT_SHAPE
    ).astype("uint8") / 255).astype(np.float32)

%timeit model(img)


126 ms ± 15 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Adam wrapped in the TensorFlow Addons MovingAverage optimizer.
adam = keras.optimizers.Adam()
moving_avg_opt = tfa.optimizers.MovingAverage(adam)


def scheduler(epoch, lr):
    """Return the learning rate to use for ``epoch``.

    Currently a no-op: the incoming rate ``lr`` is returned unchanged, so the
    LearningRateScheduler callback leaves the optimizer's rate as it is.

    A disabled alternative from the original notebook kept ``lr`` for the
    first 10 epochs and returned ``lr * tf.math.exp(-0.1)`` afterwards.
    """
    return lr


# NOTE(review): both checkpoint callbacks are given MODEL_PATH — confirm they
# are not meant to write to separate locations.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        MODEL_PATH,
        # Keras logs validation metrics with a leading "val_", followed by the
        # output name and the metric name; the key must match exactly or the
        # callback skips saving.
        monitor = "val_output_1_avgMDE_2D",
        save_best_only = True,
        save_weights_only = True,
        initial_value_threshold=None,
        ),
    keras.callbacks.LearningRateScheduler(
        scheduler
        ),
    tfa.callbacks.AverageModelCheckpoint(
        filepath=MODEL_PATH,
        update_weights=True
        )
]

# One loss per model output, equally weighted; metrics are tracked on the
# first output only.
model.compile(
    optimizer = moving_avg_opt,
    loss = {'output_1': RegressionLoss2D, 'output_2': AuxiliaryLoss},
    loss_weights = [1., 1.],
    metrics = {'output_1': [Accuracy, avgMDE_2D_RAW, avgMDE_2D]},
)

## 4. Model Training

In [None]:
# Train for 10 epochs on train_ds, validating on val_ds and running the
# checkpoint and learning-rate callbacks defined in the `callbacks` list.
model.fit(train_ds, validation_data=val_ds, epochs=10, callbacks=callbacks)

In [None]:
# Copy everything under ./saved_models to the dataset folder on the mounted
# Google Drive so the saved files are kept there.
%cp -r ./saved_models/* /content/gdrive/MyDrive/bodynet_ds/

## 5. Model evaluation

In [None]:
# Pull one validation batch and print the shapes of the images and both
# targets; `img` and `y1` stay bound for the visualisation cell below.
for img, (y1, y2) in val_ds.take(1):
    print(img.shape, y1.shape, y2.shape, sep="\n")

In [None]:
# Plot ground truth, predictions and heatmaps for the first samples of the
# validation batch held in `img` / `y1`. At most 10 samples are shown, fewer
# if the batch is smaller.
num_samples = min(10, img.shape[0])
for i in range(num_samples):
    # Map the network input back to uint8 pixels.
    # NOTE(review): assumes inputs are scaled to [-1, 1] — confirm against the loader.
    frame = ((img[i].numpy() + 1) * 127.5).astype('uint8')
    pred_img = frame.copy()
    true_img = frame.copy()

    # Predict on a batch containing only sample i.
    preds, heatmaps = model.predict(img[i:i+1])
    # Keep the first three values per keypoint and reorder the columns
    # (0, 1, 2) -> (1, 2, 0), the layout draw_keypoints is called with.
    preds = preds[0, :, :3]
    preds = preds[:, [1,2,0]]

    # Heatmaps: channel 0 is plotted as the centre map, the remaining
    # channels are masked, summed and plotted as the keypoints map.
    kptsmask = get_max_mask(heatmaps.reshape(-1, GRID_SIZE, GRID_SIZE, NUM_KPTS+1))
    kptsmap = kptsmask[0, :, :, 1:].numpy().sum(axis=-1)
    kptsmap = cv2.resize(kptsmap, INPUT_SHAPE[:2])

    centremap = heatmaps[0, :, 0].reshape(GRID_SIZE, GRID_SIZE)
    centremap = cv2.resize(centremap, INPUT_SHAPE[:2])

    # Ground-truth keypoints, with the same column reordering as the predictions.
    labels = y1[i, :, :3].numpy()
    labels = labels[:, [1,2,0]]

    pred_img = draw_keypoints(pred_img, preds, .5, cfg.MPII_KEYPOINT_DICT)
    true_img = draw_keypoints(true_img, labels, .5, cfg.MPII_KEYPOINT_DICT)

    # 2x2 grid: ground truth, prediction, centre map overlay, keypoints map overlay.
    fig, axs = plt.subplots(2, 2)
    fig.set_figheight(20)
    fig.set_figwidth(20)

    axs[0,0].imshow(true_img)
    axs[0,0].axis('off')
    axs[0,0].set_title("Original")

    axs[0,1].imshow(pred_img)
    axs[0,1].axis('off')
    axs[0,1].set_title("Predicted")

    axs[1,0].imshow(pred_img, alpha=.5)
    axs[1,0].imshow(centremap,  alpha=.5)
    axs[1,0].axis('off')
    axs[1,0].set_title("Centremap")

    axs[1,1].imshow(pred_img, alpha=.5)
    axs[1,1].imshow(kptsmap,  alpha=.5)
    axs[1,1].axis('off')
    axs[1,1].set_title("Keypointsmap")