In [9]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, sys, argparse
import tensorflow.keras.backend as K
#from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.callbacks import TensorBoard, TerminateOnNaN

from hourglass.model import get_hourglass_model
from hourglass.data import hourglass_dataset
from hourglass.loss import get_loss
from hourglass.callbacks import EvalCallBack, CheckpointCleanCallBack, EvalCallBackNew
from common.utils import get_classes, get_matchpoints, get_model_type, optimize_tf_gpu
from common.model_utils import get_optimizer
from common.data_utils import generate_gt_heatmap
# Try to enable Auto Mixed Precision on TF 2.0
# os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
# os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
import numpy as np
# optimize_tf_gpu(tf, K)

In [10]:
# Arguments parsed from command line can be set here
# Model definition options:
num_stacks=2
mobile=True
tiny=True
model_input_shape="256x256"
weights_path=None

# Data options
dataset_path="data/mpii"
classes_path="configs/mpii_classes.txt"
matchpoint_path="configs/mpii_match_point.txt"

# Training options
batch_size=16
optimizer="RMSProp"
loss_type="mse"
learning_rate=5e-4
decay_type=None
mixed_precision=False
init_epoch=0
total_epoch=100
gpu_num=1

height, width = model_input_shape.split('x')
model_input_shape = (int(height), int(width)) 
orig_img_shape = (1280, 720)    # Height and width in pixels of input images, can be read from file or manually specified
output_shape = (int(model_input_shape[0]/4), int(model_input_shape[1]/4))

In [11]:
# def main(args):
log_dir = 'logs/000'
os.makedirs(log_dir, exist_ok=True)

class_names = get_classes(classes_path)
num_classes = len(class_names)
if matchpoint_path:
    matchpoints = get_matchpoints(matchpoint_path)
else:
    matchpoints = None

# choose model type
if tiny:
    num_channels = 128
else:
    num_channels = 256

if mixed_precision:
    tf_major_version = float(tf.__version__[:3])
    if tf_major_version >= 2.1:
        # apply mixed_precision for valid TF version
        from tensorflow.keras.mixed_precision import experimental as mixed_precision

        policy = mixed_precision.Policy('mixed_float16')
        mixed_precision.set_policy(policy)
    else:
        raise ValueError('Tensorflow {} does not support mixed precision'.format(tf.__version__))


In [12]:
# # get train/val dataset
# train_generator = hourglass_dataset(dataset_path, batch_size, class_names,
#                                     input_shape=model_input_shape,
#                                     num_hgstack=num_stacks,
#                                     is_train=True,
#                                     with_meta=False,
#                                     matchpoints=matchpoints)

# num_train = train_generator.get_dataset_size()
# num_val = len(train_generator.get_val_annotations())

# model_type = get_model_type(num_stacks, mobile, tiny, model_input_shape)


In [13]:
import numpy as np
import json
resized_scale = np.divide(output_shape, orig_img_shape)
resized_scale = np.append(resized_scale, [1.0])

def dataset_from_annotations(annotations, image_path, validation_set=False):

    image_filenames = []
    centers = []
    keypoints = []
    scales = []
    
    for annotation in annotations:
        image_filename = os.path.join(image_path, annotation['img_paths'])
        center = np.array(annotation['objpos'])
        keypoint = np.array(annotation['joint_self'])
        scale = annotation['scale_provided']
        
        # adjust center/scale slightly to avoid cropping limbs

        if center[0] != -1:
            center[1] = center[1] + 15 * scale
            scale = scale * 1.25
        
        if annotation['isValidation'] == validation_set:        
            image_filenames.append(image_filename)
            centers.append(center)
            keypoints.append(keypoint*resized_scale)
            scales.append(scale)
        else:
            pass
    img_filenames = tf.convert_to_tensor(image_filenames)
    img_centers = tf.convert_to_tensor(centers, dtype=tf.float32)
    img_scales = tf.convert_to_tensor(scales, dtype=tf.float32)
    img_keypoints = tf.convert_to_tensor(keypoints, dtype=tf.float32)
#     return image_filenames, centers, scales, keypoints
    return image_filenames, keypoints, centers, scales
  
                 
            
    
        
# dataset with tf.data.Dataset and training using model.fit()
json_file = "data/mpii/annotations.json"
image_path = "data/mpii/images/"
with open(json_file) as f:
    annotations = json.load(f)
    
    
    
    

In [14]:
AUTOTUNE = tf.data.AUTOTUNE
img_filenames_train, img_keypoints_train, img_centers_train, img_scales_train = dataset_from_annotations(annotations, image_path,validation_set=False)   
tfdataset_train= tf.data.Dataset.from_tensor_slices((img_filenames_train, img_keypoints_train))

img_filenames_test, img_keypoints_test, img_centers_test, img_scales_test = dataset_from_annotations(annotations, image_path,validation_set=True)      
tfdataset_val= tf.data.Dataset.from_tensor_slices((img_filenames_test, img_keypoints_test))


In [19]:
IS_TRAIN= True

def map_image_open(image_filename):
    image = tf.io.decode_png(tf.io.read_file(image_filename))
    img_size = tf.shape(image)[0:2]
       


    return (tf.image.resize(image, model_input_shape), img_size)

def map_rescale_keypoints(image_data, keypoint_data): 
    # Simply rescale the keypoints by the factor of model_output_shape to original image shape
    scale_factor = tf.divide(output_shape, image_data[1])
    scale_factor = tf.concat([scale_factor, [1]], 0)
    rescaled_keypoints = tf.cast(tf.multiply(keypoint_data, scale_factor), dtype=tf.float32)
    
    gt_heatmap = generate_gt_heatmap(rescaled_keypoints, output_shape)

    out_heatmaps = []
    for m in range(num_stacks):
        out_heatmaps.append(gt_heatmap)
    if IS_TRAIN:
        # Data Augmentation
        #image, keypoints = crop_single_object(image, keypoints, center, scale, model_input_shape)
        seed = tf.random.uniform(shape=[2], maxval=3, dtype=tf.int32)
        image = tf.image.stateless_random_brightness(image_data[0], max_delta=0.95, seed=seed)
        image = tf.image.stateless_random_contrast(image_data[0], lower=0.1, upper=0.9, seed=seed)
        image = tf.image.stateless_random_hue(image_data[0], 0.2, seed)
        image = tf.image.stateless_random_jpeg_quality(image_data[0], 75, 95, seed)
        image = tf.image.stateless_random_saturation(image_data[0], lower=0.5, upper=1.0, seed=seed)

        return (image_data[0], tf.stack(out_heatmaps, axis=-1))   # Return the image tensor and rescaled keypoints
    else:
        return (image_data[0], rescaled_keypoints)

# train_image_dataset = (tf.data.Dataset.from_tensor_slices(img_filenames_train)
#            .map(lambda x: tf.image.resize(tf.io.decode_png(tf.io.read_file(x)), model_input_shape),
#                 num_parallel_calls=AUTOTUNE)
#            .prefetch(AUTOTUNE))

# train_image_dataset = (tf.data.Dataset.from_tensor_slices(img_filenames_train)
#            .map(lambda x: tf.io.decode_png(tf.io.read_file(x)),
#                 num_parallel_calls=AUTOTUNE)
#            .prefetch(AUTOTUNE))


In [20]:
IS_TRAIN = True
train_image_dataset = (tf.data.Dataset.from_tensor_slices(img_filenames_train)
           .map(map_image_open,
                num_parallel_calls=AUTOTUNE)
           .prefetch(AUTOTUNE))
train_keypoints_dataset = (tf.data.Dataset.from_tensor_slices(img_keypoints_train).prefetch(AUTOTUNE))
tfdataset_mapped_train = tf.data.Dataset.zip((train_image_dataset, train_keypoints_dataset)).map(map_rescale_keypoints).batch(batch_size).prefetch(AUTOTUNE)
IS_TRAIN = False
test_image_dataset = (tf.data.Dataset.from_tensor_slices(img_filenames_test)
           .map(map_image_open,
                num_parallel_calls=AUTOTUNE)
           .prefetch(AUTOTUNE))
test_keypoints_dataset = (tf.data.Dataset.from_tensor_slices(img_keypoints_test).prefetch(AUTOTUNE))
tfdataset_mapped_val = tf.data.Dataset.zip((test_image_dataset, test_keypoints_dataset)).map(map_rescale_keypoints).batch(batch_size).prefetch(AUTOTUNE)


In [21]:
# iterator = iter(train_dataset_zip)
# batch = next(iterator)
# for i in range(5):
#     batch = next(iterator)
#     cropped_img = crop_image(batch[0], img_centers_train[i], img_scales_train[i], model_input_shape, 0)

#     print("Image shapes: ", batch[1].shape)


In [22]:
# # Previous version slower mapping functions
# from PIL import Image
   
# # Rescale keypoints from original_img_shape to output_shape
# # resized_scale = tf.divide(output_shape, orig_img_shape)
# # resized_scale = tf.concat([resized_scale, [1]], 0)  # adding third dimension for visibility dimension in the keypoints 


# def map_dataset_to_image_heatmaps(imagefile,  keypoints):
            
#     img = tf.io.read_file(imagefile)
#     decoded_img = tf.io.decode_png(img, channels=3, dtype=tf.dtypes.uint8)
# #     orig_img_shape = decoded_img.shape
#     resized_img = tf.image.resize(decoded_img, model_input_shape)
#     image = resized_img
    
#     # Data Augmentation
# #     image, keypoints = crop_single_object(image, keypoints, center, scale, model_input_shape)
#     seed = tf.random.uniform(shape=[2], maxval=3, dtype=tf.int32)
#     image = tf.image.stateless_random_brightness(image, max_delta=0.95, seed=seed)
#     image = tf.image.stateless_random_contrast(image, lower=0.1, upper=0.9, seed=seed)
#     image = tf.image.stateless_random_hue(image, 0.2, seed)
#     image = tf.image.stateless_random_jpeg_quality(image, 75, 95, seed)
#     image = tf.image.stateless_random_saturation(image, lower=0.5, upper=1.0, seed=seed)


#     # Rescale keypoints from original_img_shape to output_shape -- Done in previous step
# #     resized_scale = tf.divide(orig_img_shape, model_input_shape)
# #     resized_scale = tf.concat([resized_scale, [1]], 0)  # adding third dimension for visibility dimension in the keypoints 
# #     keypoints = tf.multiply(keypoints, resized_scale)

#     # generate ground truth keypoint heatmap
#     gt_heatmap = generate_gt_heatmap(keypoints, output_shape)

#     out_heatmaps = []
#     for m in range(num_stacks):
#         out_heatmaps.append(gt_heatmap)
        
#     return (image, tf.stack(out_heatmaps, axis=-1))


# def map_dataset_to_image_heatmaps_val(imagefile,  keypoints):
            
#     img = tf.io.read_file(imagefile)
#     decoded_img = tf.io.decode_png(img, channels=3)
#     resized_img = tf.image.resize(decoded_img, model_input_shape)
#     image = resized_img
    




 
#     # Rescale keypoints from original_img_shape to output_shape   - Done in previous step already
# #     resized_scale = tf.divide(orig_img_shape, output_shape)
# #     resized_scale = tf.concat([resized_scale, [1]], 0)  # adding third dimension for visibility dimension in the keypoints 
# #     keypoints = tf.cast(tf.multiply(keypoints, resized_scale), dtype=tf.float32)
# #     keypoints = tf.multiply(keypoints, resized_scale)


#     return (image, keypoints)




In [23]:
# AUTOTUNE = tf.data.AUTOTUNE
# tfdataset_mapped_train = tfdataset_train.map(map_dataset_to_image_heatmaps, num_parallel_calls=AUTOTUNE).batch(batch_size).prefetch(AUTOTUNE)
# tfdataset_mapped_val = tfdataset_val.map(map_dataset_to_image_heatmaps_val, num_parallel_calls=AUTOTUNE).batch(batch_size, drop_remainder=True).prefetch(AUTOTUNE)

In [24]:
item = next(iter(tfdataset_mapped_val))
item[1].shape

TensorShape([16, 16, 3])

In [25]:
num_train = tfdataset_mapped_train.cardinality()*batch_size
num_val = tfdataset_mapped_val.cardinality()*batch_size

model_type = get_model_type(num_stacks, mobile, tiny, model_input_shape)

In [26]:
# callbacks for training process
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_grads=False, write_images=False, update_freq='batch')
# eval_callback = EvalCallBack(log_dir, dataset_path, class_names, model_input_shape, model_type)
eval_callback = EvalCallBackNew(log_dir, tfdataset_mapped_val, class_names, model_input_shape, model_type)
checkpoint_clean = CheckpointCleanCallBack(log_dir, max_val_keep=5)
terminate_on_nan = TerminateOnNaN()

callbacks = [tensorboard, eval_callback, terminate_on_nan, checkpoint_clean]
# callbacks = [tensorboard, terminate_on_nan, checkpoint_clean]

# prepare optimizer
steps_per_epoch = max(1, num_train//batch_size)
decay_steps = steps_per_epoch * (total_epoch - init_epoch)
optimizer = get_optimizer(optimizer, learning_rate, decay_type=decay_type, decay_steps=decay_steps)
#optimizer = RMSprop(lr=5e-4)

# prepare loss function
loss_func = get_loss(loss_type)


In [27]:
# support multi-gpu training
if gpu_num >= 2:
    # devices_list=["/gpu:0", "/gpu:1"]
    devices_list=["/gpu:{}".format(n) for n in range(gpu_num)]
    strategy = tf.distribute.MirroredStrategy(devices=devices_list)
    print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    with strategy.scope():
        # get multi-gpu train model. you can also use "model_input_shape=None" to create a dynamic input shape model,
        # but multiscale train/inference doesn't work for it
        model = get_hourglass_model(num_classes, num_stacks, num_channels, model_input_shape=model_input_shape, mobile=mobile)
        # compile model
        model.compile(optimizer=optimizer, loss=loss_func)
else:
    # get normal train model. you can also use "model_input_shape=None" to create a dynamic input shape model,
    # but multiscale train/inference doesn't work for it
    model = get_hourglass_model(num_classes, num_stacks, num_channels, model_input_shape=model_input_shape, mobile=mobile)
    # compile model
    model.compile(optimizer=optimizer, loss=loss_func)

print('Create {} Stacked Hourglass model with stack number {}, channel number {}. train input shape {}'.format('Mobile' if mobile else '', num_stacks, num_channels, model_input_shape))
# model.summary()



Create Mobile Stacked Hourglass model with stack number 2, channel number 128. train input shape (256, 256)


In [28]:
data_item = next(iter(tfdataset_mapped_val))
tf.math.reduce_max(data_item[1][:,:,0:1])

<tf.Tensor: shape=(), dtype=float32, numpy=5.7025642>

In [29]:
# changing GPU device configuration
# import tensorflow as tf
# physical_devices = tf.config.list_physical_devices('GPU')

# tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
# logical_devices = tf.config.list_logical_devices('GPU')
# logical_devicesphysical_devices = tf.config.list_physical_devices('GPU')
# try:
#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])

#   logical_devices = tf.config.list_logical_devices('GPU')
#   assert len(logical_devices) == len(physical_devices) + 1

#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
#      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
# except:
#   # Invalid device or cannot modify logical devices once initialized.
#   pass

In [30]:
weights_path=None
# weights_path="logs/000/ep035-loss0.001-val_acc0.759.h5"
if weights_path:
    model.load_weights(weights_path, by_name=True)#, skip_mismatch=True)
    print('Load weights {}.'.format(weights_path))

In [None]:
model.fit(tfdataset_mapped_train, 
#                     validation_data=tfdataset_mapped_val,
#                   steps_per_epoch=num_train // batch_size,
                    epochs=total_epoch,
                    initial_epoch=init_epoch,
#                     workers=6,
#                     use_multiprocessing=True,
#                     max_queue_size=10,                    
                    callbacks=callbacks
         )


Epoch 1/100
  44/1391 [..............................] - ETA: 10:19 - loss: 0.5167

In [None]:
eval_callback.model = model
eval_callback.on_epoch_end(0)

In [None]:
from eval import hourglass_predict_keras, post_process_heatmap_simple
heatmap = hourglass_predict_keras(model, data_item[0])

In [None]:
heatmap.shape[0:2]

In [None]:
post_process_heatmap_simple(heatmap, 0.01)

In [None]:
_map = heatmap[:, :,:, 0]
np.where(_map == _map.max())

In [None]:
num_train

In [None]:
tfdataset_mapped_val.element_spec[0].shape