In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, sys, argparse
import tensorflow.keras.backend as K
#from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.callbacks import TensorBoard, TerminateOnNaN

from hourglass.model import get_hourglass_model
from hourglass.data import hourglass_dataset
from hourglass.loss import get_loss
from hourglass.callbacks import EvalCallBack, CheckpointCleanCallBack
from common.utils import get_classes, get_matchpoints, get_model_type, optimize_tf_gpu
from common.model_utils import get_optimizer

# Try to enable Auto Mixed Precision on TF 2.0
# os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
# os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
# optimize_tf_gpu(tf, K)

In [2]:
# Configuration values that the original script read from argparse; edit them here.
# Model definition options:
num_stacks=2            # number of stacked hourglass modules; one heatmap target is produced per stack
mobile=False
tiny=False              # True -> 128 hourglass channels instead of 256
model_input_shape="256x256"   # "HEIGHTxWIDTH" string, parsed into an (int, int) tuple below
weights_path=None       # currently unused: the weight-loading cell is commented out

# Data options
dataset_path="data/mpii"
classes_path="configs/mpii_classes.txt"
matchpoint_path="configs/mpii_match_point.txt"

# Training options
batch_size=64           # batch size of the tf.data pipelines; also used to derive steps_per_epoch
optimizer="RMSProp"     # optimizer *name*; a later cell rebinds this variable to the optimizer object
loss_type="mse"
learning_rate=5e-4
decay_type=None
mixed_precision=False   # True switches Keras to the 'mixed_float16' policy (requires TF >= 2.1)
init_epoch=0
total_epoch=100
gpu_num=2               # >= 2 trains with tf.distribute.MirroredStrategy over /gpu:0 .. /gpu:{gpu_num-1}

height, width = model_input_shape.split('x')
model_input_shape = (int(height), int(width))
# Nominal pixel size of the source images.
# NOTE(review): the previous comment called this (height, width), but 1280x720 is the usual
# (width, height) layout of MPII frames; confirm the axis order before relying on it, and note
# that individual images in the dataset may not all share this size.
orig_img_shape = (1280, 720)

In [3]:
# def main(args):
log_dir = 'logs/000'
os.makedirs(log_dir, exist_ok=True)

# Keypoint class names and the optional match-point file passed to the legacy data
# generator (presumably left/right keypoint pairs used for flip augmentation - confirm).
class_names = get_classes(classes_path)
num_classes = len(class_names)
matchpoints = get_matchpoints(matchpoint_path) if matchpoint_path else None

# Number of hourglass feature channels passed to get_hourglass_model; the tiny variant halves it.
num_channels = 128 if tiny else 256

if mixed_precision:
    # Compare (major, minor) as integers. Slicing the version string
    # ("2.10.0"[:3] == "2.1") misreads two-digit minor versions.
    tf_version = tuple(int(part) for part in tf.__version__.split('.')[:2])
    if tf_version >= (2, 1):
        # Use names other than `mixed_precision` so the boolean config flag is never
        # overwritten by a module object. Otherwise re-running this cell would always
        # take this branch.
        keras_mp = tf.keras.mixed_precision
        if hasattr(keras_mp, 'set_global_policy'):
            # Stable API (TF >= 2.4).
            keras_mp.set_global_policy('mixed_float16')
        else:
            # Experimental API (TF 2.1 - 2.3).
            from tensorflow.keras.mixed_precision import experimental as mp_experimental
            mp_experimental.set_policy(mp_experimental.Policy('mixed_float16'))
    else:
        raise ValueError('Tensorflow {} does not support mixed precision'.format(tf.__version__))


In [4]:
# get train/val dataset
# NOTE(review): model.fit() below trains on the tf.data pipelines built in later cells, not on
# this legacy generator. Here the generator only supplies the sample counts: num_train feeds
# steps_per_epoch / decay_steps, and num_val is used only in a commented-out print. Confirm that
# these counts match the tf.data splits built from annotations.json.
train_generator = hourglass_dataset(dataset_path, batch_size, class_names,
                                    input_shape=model_input_shape,
                                    num_hgstack=num_stacks,
                                    is_train=True,
                                    with_meta=False,
                                    matchpoints=matchpoints)

num_train = train_generator.get_dataset_size()
num_val = len(train_generator.get_val_annotations())

# Model-type identifier passed to EvalCallBack in the next cell.
model_type = get_model_type(num_stacks, mobile, tiny, model_input_shape)


In [5]:
# callbacks for training process
# NOTE(review): write_grads is a TF1-era TensorBoard option; TF2 presumably ignores it with a warning - confirm.
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_grads=False, write_images=False, update_freq='batch')
eval_callback = EvalCallBack(log_dir, dataset_path, class_names, model_input_shape, model_type)
checkpoint_clean = CheckpointCleanCallBack(log_dir, max_val_keep=5)
terminate_on_nan = TerminateOnNaN()

# This list is replaced by a later cell (without eval_callback) before model.fit() runs.
callbacks = [tensorboard, eval_callback, terminate_on_nan, checkpoint_clean]
# callbacks = [tensorboard, terminate_on_nan, checkpoint_clean]

# prepare optimizer
# steps_per_epoch / decay_steps are based on num_train from the legacy generator. The tf.data
# pipeline batches without drop_remainder, so a real epoch may have one extra, partial step.
steps_per_epoch = max(1, num_train//batch_size)
decay_steps = steps_per_epoch * (total_epoch - init_epoch)
# NOTE(review): this rebinds `optimizer` from the name string set in the config cell to the
# optimizer object. Re-running this cell on its own therefore passes an object where
# get_optimizer presumably expects a name, so re-run the config cell first.
optimizer = get_optimizer(optimizer, learning_rate, decay_type=decay_type, decay_steps=decay_steps)
#optimizer = RMSprop(lr=5e-4)

# prepare loss function
loss_func = get_loss(loss_type)


In [6]:
def _build_compiled_model():
    """Create the stacked-hourglass model from the notebook config and compile it.

    Shared by the single-device and multi-GPU paths; the multi-GPU path calls it
    inside ``strategy.scope()``. Returns the compiled Keras model.
    """
    # "model_input_shape=None" would create a dynamic input shape model,
    # but multiscale train/inference doesn't work for it
    hourglass_model = get_hourglass_model(num_classes, num_stacks, num_channels,
                                          model_input_shape=model_input_shape, mobile=mobile)
    hourglass_model.compile(optimizer=optimizer, loss=loss_func)
    return hourglass_model


if gpu_num < 2:
    # single device: build directly, no distribution strategy
    model = _build_compiled_model()
else:
    # one mirrored replica per GPU, e.g. ["/gpu:0", "/gpu:1"]
    devices_list = ["/gpu:{}".format(n) for n in range(gpu_num)]
    strategy = tf.distribute.MirroredStrategy(devices=devices_list)
    print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    with strategy.scope():
        model = _build_compiled_model()

print('Create {} Stacked Hourglass model with stack number {}, channel number {}. train input shape {}'.format('Mobile' if mobile else '', num_stacks, num_channels, model_input_shape))
model.summary()



INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
Number of devices: 2
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:

In [7]:
# if weights_path:
#     model.load_weights(weights_path, by_name=True)#, skip_mismatch=True)
#     print('Load weights {}.'.format(weights_path))

# # start training
# print('Train on {} samples, val on {} samples, with batch size {}, model input shape {}.'.format(num_train, num_val, batch_size, model_input_shape))
# model.fit_generator(generator=train_generator,
#                     steps_per_epoch=num_train // batch_size,
#                     epochs=total_epoch,
#                     initial_epoch=init_epoch,
#                     workers=1,
#                     use_multiprocessing=False,
#                     max_queue_size=10,
                    
#                     callbacks=callbacks)

# model.save(os.path.join(log_dir, 'trained_final.h5'))


In [8]:
# Heatmap resolution: the input size divided by 4 on each axis (256x256 input -> 64x64 heatmaps).
output_shape = tuple(dim // 4 for dim in model_input_shape)
output_shape

(64, 64)

In [9]:
from common.data_utils import random_horizontal_flip, random_vertical_flip, random_brightness
from common.data_utils import random_grayscale, random_chroma, random_contrast, random_sharpness, random_blur, random_histeq, random_rotate_angle
from common.data_utils import crop_single_object, rotate_single_object, crop_image, normalize_image, transform_keypoints, generate_gt_heatmap
from PIL import Image


def map_dataset_to_image_heatmaps(imagefile, center, scale, keypoints):
    """Turn one annotation record into an ``(image, heatmaps)`` training pair.

    Intended for ``tf.data.Dataset.map`` over the
    ``(filename, center, scale, keypoints)`` slices built from the annotations.
    Reads the notebook globals ``model_input_shape``, ``output_shape`` and
    ``num_stacks``.

    Args:
        imagefile: scalar string tensor with the path of the image file.
        center: annotated person center. Unused for now, because the
            person-centred crop is not implemented yet.
        scale: annotated person scale. Unused for now (see ``center``).
        keypoints: keypoint tensor whose last dimension has 3 entries. The
            third entry is the visibility flag and is left unscaled. The first
            two are presumably (x, y) in original-image pixels (MPII layout);
            confirm.

    Returns:
        ``(image, heatmaps)``:
        - ``image`` is a float32 tensor of shape ``model_input_shape + (3,)``
          with values in [0, 1].
        - ``heatmaps`` is the output of ``generate_gt_heatmap`` repeated
          ``num_stacks`` times and stacked along a new last axis.
    """
    # decode_image detects JPEG/PNG/BMP/GIF from the bytes instead of assuming PNG;
    # expand_animations=False guarantees a rank-3 (H, W, 3) result.
    raw_bytes = tf.io.read_file(imagefile)
    decoded_img = tf.io.decode_image(raw_bytes, channels=3, expand_animations=False)
    # Actual (height, width) of this image; images are not assumed to share one size.
    orig_hw = tf.shape(decoded_img)[:2]

    # The tf.image colour ops expect float pixels in [0, 1]. adjust_jpeg_quality in
    # particular saturates float input to [0, 1] when converting it to uint8.
    image = tf.image.convert_image_dtype(decoded_img, tf.float32)
    image = tf.image.resize(image, model_input_shape)

    # TODO: the person-centred crop (crop_image with center/scale), the matching
    # transform_keypoints call and normalize_image are not applied yet.

    # Data augmentation. Each op gets its own full-range seed, so the ops are
    # uncorrelated and not limited to a handful of possible outcomes.
    # NOTE(review): brightness max_delta=0.95 and contrast 0.1-0.9 are very strong
    # for [0, 1] pixels; tune these values.
    seeds = tf.random.uniform(shape=[5, 2], maxval=tf.int32.max, dtype=tf.int32)
    image = tf.image.stateless_random_brightness(image, max_delta=0.95, seed=seeds[0])
    image = tf.image.stateless_random_contrast(image, lower=0.1, upper=0.9, seed=seeds[1])
    image = tf.clip_by_value(image, 0.0, 1.0)
    image = tf.image.stateless_random_hue(image, 0.2, seeds[2])
    image = tf.image.stateless_random_jpeg_quality(image, 75, 95, seeds[3])
    image = tf.image.stateless_random_saturation(image, lower=0.5, upper=1.0, seed=seeds[4])
    image = tf.clip_by_value(image, 0.0, 1.0)

    # Rescale keypoints from original-image pixels to heatmap (output_shape) pixels.
    # resize() stretched the whole image, so each axis scales by
    # output_size / original_size.
    if not keypoints.dtype.is_floating:
        keypoints = tf.cast(keypoints, tf.float32)
    out_hw = tf.constant(output_shape, dtype=keypoints.dtype)
    scale_y, scale_x = tf.unstack(out_hw / tf.cast(orig_hw, keypoints.dtype))
    # Per-column factors (x, y, visibility); visibility is kept as-is.
    keypoints = keypoints * tf.stack([scale_x, scale_y, tf.ones_like(scale_x)])

    # generate ground truth keypoint heatmap
    gt_heatmap = generate_gt_heatmap(keypoints, output_shape)

    # One identical target per hourglass stack.
    # NOTE(review): Keras multi-output models usually take a tuple of targets (one per
    # output); confirm that the loss expects this single stacked tensor instead.
    out_heatmaps = [gt_heatmap] * num_stacks
    return (image, tf.stack(out_heatmaps, axis=-1))




In [10]:
import numpy as np
import json

def dataset_from_annotations(annotations, image_path, validation_set=False):
    """Split MPII-style annotation records into parallel per-sample columns.

    Args:
        annotations: iterable of dicts with the keys ``'img_paths'``, ``'objpos'``,
            ``'joint_self'``, ``'scale_provided'`` and ``'isValidation'``.
        image_path: directory prepended to each ``'img_paths'`` entry.
        validation_set: keep only records whose ``'isValidation'`` equals this value
            (``False`` -> training split, ``True`` -> validation split).

    Returns:
        Tuple of four lists ``(image_filenames, centers, scales, keypoints)`` with
        one entry per kept record. ``centers`` and ``keypoints`` are float64 numpy
        arrays. For records whose center is valid (``center[0] != -1``), the second
        center coordinate is increased by ``15 * scale`` and the scale is enlarged
        by 25%, so the crop is less likely to cut off limbs.
    """
    image_filenames, centers, scales, keypoints = [], [], [], []

    for annotation in annotations:
        # Filter first, so records from the other split are not processed at all.
        if annotation['isValidation'] != validation_set:
            continue

        # float64 so the in-place center shift below is not truncated when the JSON
        # stores integer coordinates.
        center = np.array(annotation['objpos'], dtype=np.float64)
        keypoint = np.array(annotation['joint_self'], dtype=np.float64)
        scale = annotation['scale_provided']

        # adjust center/scale slightly to avoid cropping limbs
        if center[0] != -1:
            center[1] += 15 * scale
            scale *= 1.25

        image_filenames.append(os.path.join(image_path, annotation['img_paths']))
        centers.append(center)
        keypoints.append(keypoint)
        scales.append(scale)

    return image_filenames, centers, scales, keypoints
  
                 
            
    
        
# dataset with tf.data.Dataset and training using model.fit()
# Paths derive from dataset_path so they stay in sync with the config cell.
json_file = os.path.join(dataset_path, "annotations.json")
image_path = os.path.join(dataset_path, "images", "")  # keeps the trailing separator, as in "data/mpii/images/"
# JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding.
with open(json_file, encoding='utf-8') as f:
    annotations = json.load(f)
    
    
    
    

In [11]:
# Build one tf.data.Dataset per split; each element is a (filename, center, scale, keypoints)
# tuple. The train split is built first, then the validation split.
split_datasets = {}
for use_validation in (False, True):
    img_filenames, img_centers, img_scales, img_keypoints = dataset_from_annotations(
        annotations, image_path, validation_set=use_validation)
    split_datasets[use_validation] = tf.data.Dataset.from_tensor_slices(
        (img_filenames, img_centers, img_scales, img_keypoints))
tfdataset_train = split_datasets[False]
tfdataset_val = split_datasets[True]


In [12]:
from common.data_utils import generate_gt_heatmap, label_heatmap
import inspect

# converted_f = tf.autograph.to_graph(label_heatmap.python_function)
# print(inspect.getsource(converted_f))
# img_keypointsT[0].shape
# generate_gt_heatmap(img_keypoints[0], (64,64))

In [13]:
AUTOTUNE = tf.data.AUTOTUNE
# Shuffle the cheap annotation records (before the expensive decode/augment map) over the
# whole split, and re-draw the order on every pass so each epoch sees a different order.
tfdataset_mapped_train = (
    tfdataset_train
    .shuffle(buffer_size=tfdataset_train.cardinality(), reshuffle_each_iteration=True)
    .map(map_dataset_to_image_heatmaps, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)
# NOTE(review): the validation split goes through the same map function, so it also gets the
# random colour/JPEG augmentation; consider a deterministic, non-augmenting variant for it.
tfdataset_mapped_val = (
    tfdataset_val
    .map(map_dataset_to_image_heatmaps, num_parallel_calls=AUTOTUNE)
    .batch(batch_size)
    .prefetch(AUTOTUNE)
)

[<tf.Tensor 'cond/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_1/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_1/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_1/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_1/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_2/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_2/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_2/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_2/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_3/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_3/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_3/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_3/Cast_3:0' shape=() dtype=int32>]
[<tf.Tensor 'cond_4/Cast:0' shape=() dtype=int32>, <tf.Tensor 'cond_4/Cast_1:0' shape=() dtype=int32>] [<tf.Tensor 'cond_4/Cast_2:0' shape=() dtype=int32>, <tf.Tensor 'cond_4/C

In [14]:
# changing GPU device configuration
# import tensorflow as tf
# physical_devices = tf.config.list_physical_devices('GPU')

# tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
# logical_devices = tf.config.list_logical_devices('GPU')
# logical_devicesphysical_devices = tf.config.list_physical_devices('GPU')
# try:
#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])

#   logical_devices = tf.config.list_logical_devices('GPU')
#   assert len(logical_devices) == len(physical_devices) + 1

#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
#      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
# except:
#   # Invalid device or cannot modify logical devices once initialized.
#   pass

In [15]:
# eval_callback is deliberately not in this list; it is invoked by hand in a later cell.
callbacks = [
    tensorboard,
    terminate_on_nan,
    checkpoint_clean,
]

In [16]:
# Train on the tf.data pipelines.
# workers / use_multiprocessing / max_queue_size are not passed: Keras uses them only for
# generator or keras.utils.Sequence inputs and ignores them for a tf.data.Dataset.
# steps_per_epoch is omitted, so each epoch runs until the finite training dataset is exhausted.
model.fit(tfdataset_mapped_train,
          validation_data=tfdataset_mapped_val,
          epochs=total_epoch,
          initial_epoch=init_epoch,
          callbacks=callbacks)


Epoch 1/100
INFO:tensorflow:batch_all_reduce: 396 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 396 all-reduces with algorithm = nccl, num_packs = 1
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epo

KeyboardInterrupt: 

In [19]:
# Save the current model as an .h5 file named after the last completed epoch and its training loss.
# Keras attaches a History callback to `model.history` during fit(). History appends every finished
# epoch in on_epoch_end, so even after an interrupted fit() it covers the epochs that completed.
# The filename therefore reflects what was actually trained, not hand-typed numbers.
fit_history = getattr(model, 'history', None)
completed_epochs = getattr(fit_history, 'epoch', None)
if completed_epochs and 'loss' in fit_history.history:
    epoch = completed_epochs[-1]                    # 0-based index of the last finished epoch
    last_loss = fit_history.history['loss'][-1]
else:
    # No usable History (e.g. the model was never fit in this kernel): fall back to
    # the values previously typed in by hand.
    epoch = 97
    last_loss = 9.7319e-04
checkpoint_dir = os.path.join(log_dir, 'ep{epoch:03d}-loss{loss:.3f}.h5'.format(epoch=(epoch+1), loss=last_loss))
model.save(checkpoint_dir)


In [None]:
# Run the evaluation callback once by hand on the current model. fit() normally assigns
# callback.model; this call happens outside fit(), so the model is attached explicitly.
eval_callback.model = model
eval_callback.on_epoch_end(0)

In [None]:
from eval import hourglass_predict_keras, post_process_heatmap_simple
# NOTE(review): `data_item` is not defined in any cell shown here, so this raises NameError on a
# fresh kernel. It presumably came from a since-deleted cell (e.g. a batch pulled from one of the
# tf.data pipelines); define it explicitly before this call.
heatmap = hourglass_predict_keras(model, data_item[0])

In [None]:
# First two dimensions of the predicted heatmap array
heatmap.shape[:2]

In [None]:
# NOTE(review): 0.01 is passed positionally as the second argument; it is presumably a
# confidence threshold - confirm against eval.post_process_heatmap_simple.
post_process_heatmap_simple(heatmap, 0.01)

In [None]:
# Channel 0 of the prediction (indexing assumes a 4-D heatmap, presumably (batch, H, W, keypoint) - confirm).
_map = heatmap[:, :,:, 0]
# Index arrays of every position holding the channel's maximum value.
np.where(_map == _map.max())

In [None]:
# Training-sample count reported by the legacy hourglass_dataset generator (drives steps_per_epoch).
num_train