In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, sys, argparse
import tensorflow.keras.backend as K
#from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.callbacks import TensorBoard, TerminateOnNaN

from hourglass.model import get_hourglass_model
from hourglass.data import hourglass_dataset
from hourglass.loss import get_loss
from hourglass.callbacks import EvalCallBack, CheckpointCleanCallBack, EvalCallBackNew
from common.utils import get_classes, get_matchpoints, get_model_type, optimize_tf_gpu
from common.model_utils import get_optimizer

# Try to enable Auto Mixed Precision on TF 2.0
# os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
# os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
# optimize_tf_gpu(tf, K)

In [2]:
# Notebook configuration (stands in for the command-line arguments of the training script)

# --- Model definition options ---
num_stacks = 2
mobile = True
tiny = False
model_input_shape = "256x256"
weights_path = None

# --- Data options ---
dataset_path = "data/mpii"
classes_path = "configs/mpii_classes.txt"
matchpoint_path = "configs/mpii_match_point.txt"

# --- Training options ---
batch_size = 8
optimizer = "RMSProp"
loss_type = "mse"
learning_rate = 5e-4
decay_type = None
mixed_precision = False
init_epoch = 0
total_epoch = 100
gpu_num = 1

# Parse the "<height>x<width>" string into a tuple of ints
height, width = model_input_shape.split('x')
model_input_shape = tuple(map(int, (height, width)))

# Height and width in pixels of input images, can be read from file or manually specified
orig_img_shape = (1280, 720)

In [3]:
# Heatmap resolution: each of the two spatial sides of the model input divided by 4
output_shape = tuple(int(side / 4) for side in model_input_shape[:2])
output_shape

(64, 64)

In [4]:
# def main(args):
# Directory that receives the TensorBoard logs and checkpoints of this run
log_dir = 'logs/000'
os.makedirs(log_dir, exist_ok=True)

# Keypoint class names and (optional) match points
class_names = get_classes(classes_path)
num_classes = len(class_names)
matchpoints = get_matchpoints(matchpoint_path) if matchpoint_path else None

# choose model type
num_channels = 128 if tiny else 256

if mixed_precision:
    # Compare (major, minor) as integers: parsing the first three characters as a float
    # cannot tell "2.1" from "2.10" and breaks on multi-digit major versions.
    tf_version = tuple(int(part) for part in tf.__version__.split('.')[:2])
    if tf_version < (2, 1):
        raise ValueError('Tensorflow {} does not support mixed precision'.format(tf.__version__))

    # Use a dedicated name for the module so the boolean `mixed_precision` option is not
    # rebound to a module object when this cell runs.
    keras_mixed_precision = tf.keras.mixed_precision
    if hasattr(keras_mixed_precision, 'set_global_policy'):
        # non-experimental API, available in newer TF releases
        policy = keras_mixed_precision.Policy('mixed_float16')
        keras_mixed_precision.set_global_policy(policy)
    else:
        # older releases only ship the experimental API
        policy = keras_mixed_precision.experimental.Policy('mixed_float16')
        keras_mixed_precision.experimental.set_policy(policy)


In [5]:
# Legacy generator-based dataset; it provides the train/val sample counts
generator_options = dict(
    input_shape=model_input_shape,
    num_hgstack=num_stacks,
    is_train=True,
    with_meta=False,
    matchpoints=matchpoints,
)
train_generator = hourglass_dataset(dataset_path, batch_size, class_names, **generator_options)

num_train = train_generator.get_dataset_size()
val_annotations = train_generator.get_val_annotations()
num_val = len(val_annotations)

# Model type descriptor derived from the architecture options
model_type = get_model_type(num_stacks, mobile, tiny, model_input_shape)


In [6]:
import numpy as np
import json

def dataset_from_annotations(annotations, image_path, validation_set=False):
    """Collect the per-sample columns of one split of the annotation list.

    Args:
        annotations: iterable of dicts with the keys 'img_paths', 'objpos',
            'joint_self', 'scale_provided' and 'isValidation'.
        image_path: directory that is joined in front of every 'img_paths' entry.
        validation_set: value that 'isValidation' must equal for a sample to be kept
            (False selects the training split, True the validation split).

    Returns:
        A tuple ``(image_filenames, centers, scales, keypoints)`` of four parallel
        Python lists. Centers and keypoints are float64 numpy arrays, scales are the
        (possibly enlarged) 'scale_provided' values. The lists are returned as-is;
        no tensor conversion happens here.
    """
    image_filenames = []
    centers = []
    keypoints = []
    scales = []

    for annotation in annotations:
        # keep only the requested split; skip the rest before doing any work
        if annotation['isValidation'] != validation_set:
            continue

        # float dtype so the fractional offset added below is not truncated
        # when 'objpos' holds integers
        center = np.array(annotation['objpos'], dtype=np.float64)
        scale = annotation['scale_provided']

        # adjust center/scale slightly to avoid cropping limbs
        # (a first center coordinate of -1 is left untouched)
        if center[0] != -1:
            center[1] = center[1] + 15 * scale
            scale = scale * 1.25

        image_filenames.append(os.path.join(image_path, annotation['img_paths']))
        centers.append(center)
        keypoints.append(np.array(annotation['joint_self'], dtype=np.float64))
        scales.append(scale)

    return image_filenames, centers, scales, keypoints
  
                 
            
    
        
# Inputs of the tf.data.Dataset pipeline that is trained with model.fit()
image_path = "data/mpii/images/"
json_file = "data/mpii/annotations.json"

# Parse the annotation file into Python objects
with open(json_file) as f:
    annotations = json.loads(f.read())
    
    
    
    

In [7]:
# Training split: one dataset element per annotation -> (filename, center, scale, keypoints)
train_columns = dataset_from_annotations(annotations, image_path, validation_set=False)
tfdataset_train = tf.data.Dataset.from_tensor_slices(train_columns)

# Validation split, same element layout
val_columns = dataset_from_annotations(annotations, image_path, validation_set=True)
tfdataset_val = tf.data.Dataset.from_tensor_slices(val_columns)

# The notebook-level column names end up bound to the validation columns
img_filenames, img_centers, img_scales, img_keypoints = val_columns


In [8]:
from common.data_utils import generate_gt_heatmap, label_heatmap
import inspect

# converted_f = tf.autograph.to_graph(label_heatmap.python_function)
# print(inspect.getsource(converted_f))
# img_keypointsT[0].shape
# generate_gt_heatmap(img_keypoints[0], (64,64))

In [9]:
from common.data_utils import random_horizontal_flip, random_vertical_flip, random_brightness
from common.data_utils import random_grayscale, random_chroma, random_contrast, random_sharpness, random_blur, random_histeq, random_rotate_angle
from common.data_utils import crop_single_object, rotate_single_object, crop_image, normalize_image, transform_keypoints, generate_gt_heatmap
from PIL import Image


def map_dataset_to_image_heatmaps(imagefile, center, scale, keypoints):
    """Turn one training record into an ``(image, stacked ground-truth heatmaps)`` pair.

    Args:
        imagefile: path of the image file to read.
        center: unused here (no cropping is applied); accepted because every
            dataset element carries four components.
        scale: unused here, see ``center``.
        keypoints: per-joint rows whose third column is passed through unscaled
            (the visibility flag, per the comment in the code below).

    Returns:
        ``(image, heatmaps)`` where ``image`` is the decoded file resized to
        ``model_input_shape`` and augmented, and ``heatmaps`` is the ground-truth
        heatmap repeated ``num_stacks`` times along a new last axis.
    """
    img = tf.io.read_file(imagefile)
    decoded_img = tf.io.decode_png(img, channels=3, dtype=tf.dtypes.uint8)
    # The whole image is resized (no person crop), so keypoints only need a plain rescale
    image = tf.image.resize(decoded_img, model_input_shape)

    # Data Augmentation
    # NOTE(review): maxval=3 limits the seed pair to 9 combinations and the same seed
    # feeds every op below - confirm this is intended.
    # NOTE(review): the resized image is float and is not rescaled here, while the
    # tf.image colour/JPEG ops document float inputs as [0, 1] - confirm the value range.
    seed = tf.random.uniform(shape=[2], maxval=3, dtype=tf.int32)
    image = tf.image.stateless_random_brightness(image, max_delta=0.95, seed=seed)
    image = tf.image.stateless_random_contrast(image, lower=0.1, upper=0.9, seed=seed)
    image = tf.image.stateless_random_hue(image, 0.2, seed)
    image = tf.image.stateless_random_jpeg_quality(image, 75, 95, seed)
    image = tf.image.stateless_random_saturation(image, lower=0.5, upper=1.0, seed=seed)

    # Rescale keypoints from original_img_shape to output_shape: the factor is
    # output / original, so that the points land inside the heatmap grid.
    # NOTE(review): element 0 of orig_img_shape scales keypoint column 0 - confirm that
    # orig_img_shape is ordered like the keypoint columns.
    resized_scale = tf.divide(output_shape, orig_img_shape)
    resized_scale = tf.concat([resized_scale, [1]], 0)  # adding third dimension for visibility dimension in the keypoints
    keypoints = tf.multiply(keypoints, resized_scale)

    # generate ground truth keypoint heatmap
    gt_heatmap = generate_gt_heatmap(keypoints, output_shape)

    # every hourglass stack is supervised with the same target
    return (image, tf.stack([gt_heatmap] * num_stacks, axis=-1))


def map_dataset_to_image_heatmaps_val(imagefile, center, scale, keypoints):
    """Turn one validation record into an ``(image, keypoints)`` pair.

    No augmentation is applied and no heatmap is generated; the keypoints themselves
    are returned as the target.

    Args:
        imagefile: path of the image file to read.
        center: unused here (no cropping is applied); accepted because every
            dataset element carries four components.
        scale: unused here, see ``center``.
        keypoints: per-joint rows whose third column is passed through unscaled
            (the visibility flag, per the comment in the code below).

    Returns:
        ``(image, keypoints)`` where ``image`` is the decoded file resized to
        ``model_input_shape`` and ``keypoints`` are float32 rows rescaled to
        ``output_shape`` coordinates.
    """
    img = tf.io.read_file(imagefile)
    decoded_img = tf.io.decode_png(img, channels=3)
    # The whole image is resized (no person crop), so keypoints only need a plain rescale
    image = tf.image.resize(decoded_img, model_input_shape)

    # Rescale keypoints from original_img_shape to output_shape: the factor is
    # output / original, so that the points land inside the heatmap grid.
    # NOTE(review): element 0 of orig_img_shape scales keypoint column 0 - confirm that
    # orig_img_shape is ordered like the keypoint columns, and that the evaluation
    # callback consuming this dataset expects output_shape coordinates.
    resized_scale = tf.divide(output_shape, orig_img_shape)
    resized_scale = tf.concat([resized_scale, [1]], 0)  # adding third dimension for visibility dimension in the keypoints
    keypoints = tf.cast(tf.multiply(keypoints, resized_scale), dtype=tf.float32)

    return (image, keypoints)




In [10]:
AUTOTUNE = tf.data.AUTOTUNE

# Keras does not shuffle tf.data inputs, so shuffle here. Shuffling happens on the raw
# records (a filename plus a few small arrays) before map(), which keeps a full-size
# buffer cheap; the order is redrawn every epoch.
shuffle_buffer = max(1, int(tfdataset_train.cardinality()))
tfdataset_mapped_train = (
    tfdataset_train
    .shuffle(shuffle_buffer, reshuffle_each_iteration=True)
    .map(map_dataset_to_image_heatmaps, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

# Validation keeps its order; drop_remainder=True gives every batch the same static size
tfdataset_mapped_val = (
    tfdataset_val
    .map(map_dataset_to_image_heatmaps_val, num_parallel_calls=AUTOTUNE)
    .batch(batch_size, drop_remainder=True)
    .prefetch(AUTOTUNE)
)

[<tf.Tensor 'cond/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_1/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_1/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_1/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_1/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_2/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_2/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_2/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_2/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_3/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_3/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_3/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_3/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_4/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_4/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_4/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_4/C

In [11]:
# callbacks for training process
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_grads=False, write_images=False, update_freq='batch')
# evaluation runs on the tf.data validation pipeline
eval_callback = EvalCallBackNew(log_dir, tfdataset_mapped_val, class_names, model_input_shape, model_type)
checkpoint_clean = CheckpointCleanCallBack(log_dir, max_val_keep=5)
terminate_on_nan = TerminateOnNaN()

callbacks = [tensorboard, eval_callback, terminate_on_nan, checkpoint_clean]

# prepare optimizer
steps_per_epoch = max(1, num_train // batch_size)
decay_steps = steps_per_epoch * (total_epoch - init_epoch)

# `optimizer` starts out as the configured name and is replaced by the optimizer object
# below. Remember the name so that re-running this cell does not hand the already built
# object to get_optimizer() as if it were a name.
if isinstance(optimizer, str):
    optimizer_type = optimizer
optimizer = get_optimizer(optimizer_type, learning_rate, decay_type=decay_type, decay_steps=decay_steps)

# prepare loss function
loss_func = get_loss(loss_type)


In [12]:
def _build_compiled_model():
    """Create the hourglass network and compile it with the prepared optimizer and loss."""
    # you can also use "model_input_shape=None" to create a dynamic input shape model,
    # but multiscale train/inference doesn't work for it
    network = get_hourglass_model(num_classes, num_stacks, num_channels, model_input_shape=model_input_shape, mobile=mobile)
    network.compile(optimizer=optimizer, loss=loss_func)
    return network


if gpu_num >= 2:
    # multi-gpu training: build and compile the model inside the mirrored strategy scope
    devices_list = ["/gpu:{}".format(n) for n in range(gpu_num)]
    strategy = tf.distribute.MirroredStrategy(devices=devices_list)
    print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    with strategy.scope():
        model = _build_compiled_model()
else:
    # single device: normal train model
    model = _build_compiled_model()

print('Create {} Stacked Hourglass model with stack number {}, channel number {}. train input shape {}'.format('Mobile' if mobile else '', num_stacks, num_channels, model_input_shape))
model.summary()



Create Mobile Stacked Hourglass model with stack number 2, channel number 256. train input shape (256, 256)
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 image_input (InputLayer)       [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 front_conv_1x1_1 (Conv2D)      (None, 128, 128, 64  9472        ['image_input[0][0]']            
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 64  256        ['front_conv_1x1_1

                                                                                                  
 hg0_downsample_1_conv_1x1_1 (S  (None, 64, 64, 128)  33152      ['front_residual_3_add[0][0]']   
 eparableConv2D)                                                                                  
                                                                                                  
 batch_normalization_10 (BatchN  (None, 64, 64, 128)  512        ['hg0_downsample_1_conv_1x1_1[0][
 ormalization)                                                   0]']                             
                                                                                                  
 hg0_downsample_1_conv_3x3_2 (S  (None, 64, 64, 128)  17664      ['batch_normalization_10[0][0]'] 
 eparableConv2D)                                                                                  
                                                                                                  
 batch_nor

 eparableConv2D)                                                                                  
                                                                                                  
 batch_normalization_21 (BatchN  (None, 8, 8, 256)   1024        ['hg0_downsample_8_conv_1x1_3[0][
 ormalization)                                                   0]']                             
                                                                                                  
 hg0_downsample_8_add (Add)     (None, 8, 8, 256)    0           ['max_pooling2d_3[0][0]',        
                                                                  'batch_normalization_21[0][0]'] 
                                                                                                  
 hg0_downsample_f8_1_conv_1x1_1  (None, 8, 8, 128)   33152       ['hg0_downsample_8_add[0][0]']   
  (SeparableConv2D)                                                                               
          

 hg0_downsample_f8_3_conv_1x1_3  (None, 8, 8, 256)   33152       ['batch_normalization_32[0][0]'] 
  (SeparableConv2D)                                                                               
                                                                                                  
 hg0_downsample_f8_short_conv_1  (None, 8, 8, 256)   33152       ['batch_normalization_23[0][0]'] 
 x1_3 (SeparableConv2D)                                                                           
                                                                                                  
 batch_normalization_35 (BatchN  (None, 16, 16, 128)  512        ['hg0_upsample_f4_short_conv_3x3_
 ormalization)                                                   2[0][0]']                        
                                                                                                  
 batch_normalization_33 (BatchN  (None, 8, 8, 256)   1024        ['hg0_downsample_f8_3_conv_1x1_3[
 ormalizat

                                                                 ]']                              
                                                                                                  
 add_2 (Add)                    (None, 32, 32, 256)  0           ['hg0_upsample_f2_short_add[0][0]
                                                                 ',                               
                                                                  'up_sampling2d_1[0][0]']        
                                                                                                  
 hg0_upsample_f2_merged_conv_1x  (None, 32, 32, 128)  33152      ['add_2[0][0]']                  
 1_1 (SeparableConv2D)                                                                            
                                                                                                  
 hg0_upsample_f1_short_conv_1x1  (None, 64, 64, 128)  33152      ['hg0_downsample_1_add[0][0]']   
 _1 (Separ

                                                                                                  
 hg0_conv_1x1_2 (Conv2D)        (None, 64, 64, 256)  65792       ['batch_normalization_52[0][0]'] 
                                                                                                  
 hg0_conv_1x1_3 (Conv2D)        (None, 64, 64, 256)  4352        ['hg0_conv_1x1_predict[0][0]']   
                                                                                                  
 add_4 (Add)                    (None, 64, 64, 256)  0           ['hg0_conv_1x1_2[0][0]',         
                                                                  'hg0_conv_1x1_3[0][0]',         
                                                                  'front_residual_3_add[0][0]']   
                                                                                                  
 hg1_downsample_1_conv_1x1_1 (S  (None, 64, 64, 128)  33152      ['add_4[0][0]']                  
 eparableC

                                                                                                  
 hg1_downsample_8_conv_3x3_2 (S  (None, 8, 8, 128)   17664       ['batch_normalization_62[0][0]'] 
 eparableConv2D)                                                                                  
                                                                                                  
 batch_normalization_63 (BatchN  (None, 8, 8, 128)   512         ['hg1_downsample_8_conv_3x3_2[0][
 ormalization)                                                   0]']                             
                                                                                                  
 hg1_downsample_8_conv_1x1_3 (S  (None, 8, 8, 256)   33152       ['batch_normalization_63[0][0]'] 
 eparableConv2D)                                                                                  
                                                                                                  
 batch_nor

 ormalization)                                                   0][0]']                          
                                                                                                  
 batch_normalization_66 (BatchN  (None, 8, 8, 128)   512         ['hg1_downsample_f8_short_conv_3x
 ormalization)                                                   3_2[0][0]']                      
                                                                                                  
 hg1_upsample_f4_short_conv_3x3  (None, 16, 16, 128)  17664      ['batch_normalization_77[0][0]'] 
 _2 (SeparableConv2D)                                                                             
                                                                                                  
 hg1_downsample_f8_3_conv_1x1_3  (None, 8, 8, 256)   33152       ['batch_normalization_75[0][0]'] 
  (SeparableConv2D)                                                                               
          

                                                                                                  
 hg1_upsample_f4_merged_add (Ad  (None, 16, 16, 256)  0          ['add_6[0][0]',                  
 d)                                                               'batch_normalization_82[0][0]'] 
                                                                                                  
 hg1_upsample_f2_short_add (Add  (None, 32, 32, 256)  0          ['hg1_downsample_2_add[0][0]',   
 )                                                                'batch_normalization_85[0][0]'] 
                                                                                                  
 up_sampling2d_4 (UpSampling2D)  (None, 32, 32, 256)  0          ['hg1_upsample_f4_merged_add[0][0
                                                                 ]']                              
                                                                                                  
 add_7 (Ad

                                                                                                  
 hg1_conv_1x1_1 (Conv2D)        (None, 64, 64, 256)  65792       ['hg1_upsample_f1_merged_add[0][0
                                                                 ]']                              
                                                                                                  
 batch_normalization_95 (BatchN  (None, 64, 64, 256)  1024       ['hg1_conv_1x1_1[0][0]']         
 ormalization)                                                                                    
                                                                                                  
 hg1_conv_1x1_predict (Conv2D)  (None, 64, 64, 16)   4112        ['batch_normalization_95[0][0]'] 
                                                                                                  
 tf.stack (TFOpLambda)          (None, 64, 64, 16,   0           ['hg0_conv_1x1_predict[0][0]',   
          

In [13]:
# if weights_path:
#     model.load_weights(weights_path, by_name=True)#, skip_mismatch=True)
#     print('Load weights {}.'.format(weights_path))

# # start training
# print('Train on {} samples, val on {} samples, with batch size {}, model input shape {}.'.format(num_train, num_val, batch_size, model_input_shape))
# model.fit_generator(generator=train_generator,
#                     steps_per_epoch=num_train // batch_size,
#                     epochs=total_epoch,
#                     initial_epoch=init_epoch,
#                     workers=1,
#                     use_multiprocessing=False,
#                     max_queue_size=10,
                    
#                     callbacks=callbacks)

# model.save(os.path.join(log_dir, 'trained_final.h5'))


In [14]:
# Pull one batch from the training pipeline and show the shape of a single image
train_batches = iter(tfdataset_mapped_train)
data_item = next(train_batches)
data_item[0].shape[1:]

TensorShape([256, 256, 3])

In [15]:
# changing GPU device configuration
# import tensorflow as tf
# physical_devices = tf.config.list_physical_devices('GPU')

# tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
# logical_devices = tf.config.list_logical_devices('GPU')
# logical_devicesphysical_devices = tf.config.list_physical_devices('GPU')
# try:
#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])

#   logical_devices = tf.config.list_logical_devices('GPU')
#   assert len(logical_devices) == len(physical_devices) + 1

#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
#      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
# except:
#   # Invalid device or cannot modify logical devices once initialized.
#   pass

In [16]:
# Checkpoint to resume from; an empty/None value skips loading
weights_path = "logs/000/ep035-loss0.001-val_acc0.759.h5"

if weights_path:
    # weights are matched to layers by name
    model.load_weights(weights_path, by_name=True)
    print('Load weights {}.'.format(weights_path))

Load weights logs/000/ep035-loss0.001-val_acc0.759.h5.


In [17]:
# Train on the tf.data pipeline. `workers`, `use_multiprocessing` and `max_queue_size`
# are dropped: they only apply to generator / keras.utils.Sequence inputs and newer Keras
# releases no longer accept them. Input parallelism comes from map()/prefetch() instead.
# The validation dataset is handed to the evaluation callback rather than passed here as
# validation_data.
model.fit(
    tfdataset_mapped_train,
    epochs=total_epoch,
    initial_epoch=init_epoch,
    callbacks=callbacks,
)


Epoch 1/100


KeyboardInterrupt: 

In [None]:
# Run the evaluation callback by hand, outside of model.fit(): attach the current model,
# then trigger its epoch-end hook with epoch index 0.
eval_callback.model = model
eval_callback.on_epoch_end(0)

In [None]:
from eval import hourglass_predict_keras, post_process_heatmap_simple
# Predict heatmaps for the images of the sampled training batch
sample_images = data_item[0]
heatmap = hourglass_predict_keras(model, sample_images)

In [None]:
# First two dimensions of the predicted heatmap
heatmap.shape[:2]

In [None]:
# Post-process the predicted heatmap with the project helper.
# NOTE(review): 0.01 is passed as the second positional argument - presumably a
# confidence threshold; confirm against post_process_heatmap_simple.
post_process_heatmap_simple(heatmap, 0.01)

In [None]:
# Take index 0 along the fourth axis of the prediction and list the positions of its maximum
_map = heatmap[:, :, :, 0]
np.nonzero(_map == _map.max())

In [None]:
# Training sample count reported by the legacy generator
num_train

In [None]:
# Static batch dimension of the first (image) component of the validation batches;
# it is known statically because the validation pipeline batches with drop_remainder=True.
tfdataset_mapped_val.element_spec[0].shape[0]