In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os, sys, argparse
import tensorflow.keras.backend as K
#from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.callbacks import TensorBoard, TerminateOnNaN

from hourglass.model import get_hourglass_model
from hourglass.data import hourglass_dataset
from hourglass.loss import get_loss
from hourglass.callbacks import EvalCallBack, CheckpointCleanCallBack
from common.utils import get_classes, get_matchpoints, get_model_type, optimize_tf_gpu
from common.model_utils import get_optimizer

# Try to enable Auto Mixed Precision on TF 2.0
# os.environ['TF_ENABLE_AUTO_MIXED_PRECISION'] = '1'
# os.environ['TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_IGNORE_PERFORMANCE'] = '1'
# os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
# optimize_tf_gpu(tf, K)

In [2]:
# Notebook-level configuration: values that the training script would otherwise
# receive from the command line.

# --- Model definition options ---
num_stacks = 2                  # number of stacked hourglass modules
mobile = False                  # forwarded to get_hourglass_model / get_model_type
tiny = False                    # True selects 128 channels instead of 256
model_input_shape = "256x256"   # "<height>x<width>", parsed into a tuple below
weights_path = None

# --- Data options ---
dataset_path = "data/mpii"
classes_path = "configs/mpii_classes.txt"
matchpoint_path = "configs/mpii_match_point.txt"

# --- Training options ---
batch_size = 16
optimizer = "RMSProp"           # optimizer name handed to get_optimizer
loss_type = "mse"               # loss name handed to get_loss
learning_rate = 5e-4
decay_type = None
mixed_precision = False
init_epoch = 0
total_epoch = 100
gpu_num = 1

# Turn the "HxW" string into an (int, int) tuple; `height` / `width` keep the
# raw string pieces.
height, width = model_input_shape.split('x')
model_input_shape = tuple(int(dim) for dim in (height, width))


In [3]:
# def main(args):
# Output directory of this run; it is handed to the TensorBoard, evaluation and
# checkpoint-cleanup callbacks further down.
log_dir = 'logs/000'
os.makedirs(log_dir, exist_ok=True)

# Keypoint class names; their count is the number of model output classes.
class_names = get_classes(classes_path)
num_classes = len(class_names)

# Match points are optional: without a path the generator receives None.
matchpoints = get_matchpoints(matchpoint_path) if matchpoint_path else None

# choose model type: the tiny variant uses half the channel count
num_channels = 128 if tiny else 256

# Optionally switch Keras to the 'mixed_float16' policy.
#
# The Keras module is imported under its own alias on purpose: importing it as
# `mixed_precision` would overwrite the boolean flag from the config cell, and
# the branch would then run unconditionally whenever this cell is executed again.
if mixed_precision:
    # Compare (major, minor) numerically instead of slicing the version string,
    # which only works by accident once a component has more than one digit.
    tf_version = tuple(int(part) for part in tf.__version__.split('.')[:2])
    if tf_version < (2, 1):
        raise ValueError('Tensorflow {} does not support mixed precision'.format(tf.__version__))

    from tensorflow.keras import mixed_precision as keras_mixed_precision

    if hasattr(keras_mixed_precision, 'set_global_policy'):
        # Stable API; the `experimental` namespace is deprecated and absent
        # from recent TensorFlow releases.
        keras_mixed_precision.set_global_policy('mixed_float16')
    else:
        # Older releases only ship the experimental API.
        policy = keras_mixed_precision.experimental.Policy('mixed_float16')
        keras_mixed_precision.experimental.set_policy(policy)


In [4]:
# get train/val dataset
train_generator = hourglass_dataset(
    dataset_path,
    batch_size,
    class_names,
    input_shape=model_input_shape,
    num_hgstack=num_stacks,
    is_train=True,
    with_meta=False,
    matchpoints=matchpoints,
)

# Sample counts as reported by the generator itself.
num_train = train_generator.get_dataset_size()
num_val = len(train_generator.get_val_annotations())

# Model type descriptor; it is passed to the evaluation callback later on.
model_type = get_model_type(num_stacks, mobile, tiny, model_input_shape)


In [5]:
# callbacks for training process
# (`write_grads` is no longer passed: it is a leftover TF1 argument that the
# TF2 TensorBoard callback does not use.)
tensorboard = TensorBoard(log_dir=log_dir, histogram_freq=0, write_graph=True, write_images=False, update_freq='batch')
eval_callback = EvalCallBack(log_dir, dataset_path, class_names, model_input_shape, model_type)
checkpoint_clean = CheckpointCleanCallBack(log_dir, max_val_keep=5)
terminate_on_nan = TerminateOnNaN()

callbacks = [tensorboard, eval_callback, terminate_on_nan, checkpoint_clean]
# callbacks = [tensorboard, terminate_on_nan, checkpoint_clean]

# prepare optimizer
steps_per_epoch = max(1, num_train // batch_size)
decay_steps = steps_per_epoch * (total_epoch - init_epoch)

# `optimizer` holds the optimizer *name* (set in the config cell) on the first
# run and is rebound to the optimizer instance below. Remember the name
# separately so that re-running this cell does not feed an optimizer object
# back into get_optimizer.
if isinstance(optimizer, str):
    optimizer_name = optimizer
optimizer = get_optimizer(optimizer_name, learning_rate, decay_type=decay_type, decay_steps=decay_steps)

# prepare loss function
loss_func = get_loss(loss_type)


In [6]:
# support multi-gpu training
def _create_compiled_model():
    """Build the hourglass network and compile it with the prepared optimizer and loss.

    The model gets a fixed input shape. You can also use "model_input_shape=None"
    to create a dynamic input shape model, but multiscale train/inference
    doesn't work for it.
    """
    net = get_hourglass_model(num_classes, num_stacks, num_channels, model_input_shape=model_input_shape, mobile=mobile)
    net.compile(optimizer=optimizer, loss=loss_func)
    return net


if gpu_num < 2:
    # single device: normal train model
    model = _create_compiled_model()
else:
    # one device string per GPU, e.g. ["/gpu:0", "/gpu:1"]
    devices_list = ["/gpu:{}".format(n) for n in range(gpu_num)]
    strategy = tf.distribute.MirroredStrategy(devices=devices_list)
    print ('Number of devices: {}'.format(strategy.num_replicas_in_sync))
    # model creation and compilation both happen inside the strategy scope
    with strategy.scope():
        model = _create_compiled_model()

print('Create {} Stacked Hourglass model with stack number {}, channel number {}. train input shape {}'.format('Mobile' if mobile else '', num_stacks, num_channels, model_input_shape))
model.summary()



Create  Stacked Hourglass model with stack number 2, channel number 256. train input shape (256, 256)
Model: "model_1"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 image_input (InputLayer)       [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 front_conv_1x1_1 (Conv2D)      (None, 128, 128, 64  9472        ['image_input[0][0]']            
                                )                                                                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 128, 128, 64  256        ['front_conv_1x1_1[0][0]

 hg0_downsample_1_conv_1x1_1 (C  (None, 64, 64, 128)  32896      ['front_residual_3_add[0][0]']   
 onv2D)                                                                                           
                                                                                                  
 batch_normalization_10 (BatchN  (None, 64, 64, 128)  512        ['hg0_downsample_1_conv_1x1_1[0][
 ormalization)                                                   0]']                             
                                                                                                  
 hg0_downsample_1_conv_3x3_2 (C  (None, 64, 64, 128)  147584     ['batch_normalization_10[0][0]'] 
 onv2D)                                                                                           
                                                                                                  
 batch_normalization_11 (BatchN  (None, 64, 64, 128)  512        ['hg0_downsample_1_conv_3x3_2[0][
 ormalizat

                                                                                                  
 batch_normalization_21 (BatchN  (None, 8, 8, 256)   1024        ['hg0_downsample_8_conv_1x1_3[0][
 ormalization)                                                   0]']                             
                                                                                                  
 hg0_downsample_8_add (Add)     (None, 8, 8, 256)    0           ['max_pooling2d_3[0][0]',        
                                                                  'batch_normalization_21[0][0]'] 
                                                                                                  
 hg0_downsample_f8_1_conv_1x1_1  (None, 8, 8, 128)   32896       ['hg0_downsample_8_add[0][0]']   
  (Conv2D)                                                                                        
                                                                                                  
 batch_nor

  (Conv2D)                                                                                        
                                                                                                  
 hg0_downsample_f8_short_conv_1  (None, 8, 8, 256)   33024       ['batch_normalization_23[0][0]'] 
 x1_3 (Conv2D)                                                                                    
                                                                                                  
 batch_normalization_35 (BatchN  (None, 16, 16, 128)  512        ['hg0_upsample_f4_short_conv_3x3_
 ormalization)                                                   2[0][0]']                        
                                                                                                  
 batch_normalization_33 (BatchN  (None, 8, 8, 256)   1024        ['hg0_downsample_f8_3_conv_1x1_3[
 ormalization)                                                   0][0]']                          
          

                                                                                                  
 add_2 (Add)                    (None, 32, 32, 256)  0           ['hg0_upsample_f2_short_add[0][0]
                                                                 ',                               
                                                                  'up_sampling2d_1[0][0]']        
                                                                                                  
 hg0_upsample_f2_merged_conv_1x  (None, 32, 32, 128)  32896      ['add_2[0][0]']                  
 1_1 (Conv2D)                                                                                     
                                                                                                  
 hg0_upsample_f1_short_conv_1x1  (None, 64, 64, 128)  32896      ['hg0_downsample_1_add[0][0]']   
 _1 (Conv2D)                                                                                      
          

 hg0_conv_1x1_2 (Conv2D)        (None, 64, 64, 256)  65792       ['batch_normalization_52[0][0]'] 
                                                                                                  
 hg0_conv_1x1_3 (Conv2D)        (None, 64, 64, 256)  4352        ['hg0_conv_1x1_predict[0][0]']   
                                                                                                  
 add_4 (Add)                    (None, 64, 64, 256)  0           ['hg0_conv_1x1_2[0][0]',         
                                                                  'hg0_conv_1x1_3[0][0]',         
                                                                  'front_residual_3_add[0][0]']   
                                                                                                  
 hg1_downsample_1_conv_1x1_1 (C  (None, 64, 64, 128)  32896      ['add_4[0][0]']                  
 onv2D)                                                                                           
          

 hg1_downsample_8_conv_3x3_2 (C  (None, 8, 8, 128)   147584      ['batch_normalization_62[0][0]'] 
 onv2D)                                                                                           
                                                                                                  
 batch_normalization_63 (BatchN  (None, 8, 8, 128)   512         ['hg1_downsample_8_conv_3x3_2[0][
 ormalization)                                                   0]']                             
                                                                                                  
 hg1_downsample_8_conv_1x1_3 (C  (None, 8, 8, 256)   33024       ['batch_normalization_63[0][0]'] 
 onv2D)                                                                                           
                                                                                                  
 batch_normalization_64 (BatchN  (None, 8, 8, 256)   1024        ['hg1_downsample_8_conv_1x1_3[0][
 ormalizat

                                                                                                  
 batch_normalization_66 (BatchN  (None, 8, 8, 128)   512         ['hg1_downsample_f8_short_conv_3x
 ormalization)                                                   3_2[0][0]']                      
                                                                                                  
 hg1_upsample_f4_short_conv_3x3  (None, 16, 16, 128)  147584     ['batch_normalization_77[0][0]'] 
 _2 (Conv2D)                                                                                      
                                                                                                  
 hg1_downsample_f8_3_conv_1x1_3  (None, 8, 8, 256)   33024       ['batch_normalization_75[0][0]'] 
  (Conv2D)                                                                                        
                                                                                                  
 hg1_downs

 hg1_upsample_f4_merged_add (Ad  (None, 16, 16, 256)  0          ['add_6[0][0]',                  
 d)                                                               'batch_normalization_82[0][0]'] 
                                                                                                  
 hg1_upsample_f2_short_add (Add  (None, 32, 32, 256)  0          ['hg1_downsample_2_add[0][0]',   
 )                                                                'batch_normalization_85[0][0]'] 
                                                                                                  
 up_sampling2d_4 (UpSampling2D)  (None, 32, 32, 256)  0          ['hg1_upsample_f4_merged_add[0][0
                                                                 ]']                              
                                                                                                  
 add_7 (Add)                    (None, 32, 32, 256)  0           ['hg1_upsample_f2_short_add[0][0]
          

 hg1_conv_1x1_1 (Conv2D)        (None, 64, 64, 256)  65792       ['hg1_upsample_f1_merged_add[0][0
                                                                 ]']                              
                                                                                                  
 batch_normalization_95 (BatchN  (None, 64, 64, 256)  1024       ['hg1_conv_1x1_1[0][0]']         
 ormalization)                                                                                    
                                                                                                  
 hg1_conv_1x1_predict (Conv2D)  (None, 64, 64, 16)   4112        ['batch_normalization_95[0][0]'] 
                                                                                                  
 tf.stack (TFOpLambda)          (None, 64, 64, 16,   0           ['hg0_conv_1x1_predict[0][0]',   
                                2)                                'hg1_conv_1x1_predict[0][0]']   
          

In [7]:
# if weights_path:
#     model.load_weights(weights_path, by_name=True)#, skip_mismatch=True)
#     print('Load weights {}.'.format(weights_path))

# # start training
# print('Train on {} samples, val on {} samples, with batch size {}, model input shape {}.'.format(num_train, num_val, batch_size, model_input_shape))
# model.fit_generator(generator=train_generator,
#                     steps_per_epoch=num_train // batch_size,
#                     epochs=total_epoch,
#                     initial_epoch=init_epoch,
#                     workers=1,
#                     use_multiprocessing=False,
#                     max_queue_size=10,
                    
#                     callbacks=callbacks)

# model.save(os.path.join(log_dir, 'trained_final.h5'))


In [8]:
# Heatmap size: each side of the model input divided by four.
output_shape = tuple(int(side / 4) for side in model_input_shape)
output_shape

(64, 64)

In [39]:
from common.data_utils import random_horizontal_flip, random_vertical_flip, random_brightness
from common.data_utils import random_grayscale, random_chroma, random_contrast, random_sharpness, random_blur, random_histeq, random_rotate_angle
from common.data_utils import crop_single_object, rotate_single_object, crop_image, normalize_image, transform_keypoints, generate_gt_heatmap
from PIL import Image


def map_dataset_to_image_heatmaps(imagefile, center, scale, keypoints):
    """Turn one annotation record into an (image, heatmaps) training pair.

    Written as the mapping function for a `tf.data.Dataset` whose elements are
    `(imagefile, center, scale, keypoints)` tuples.

    Args:
        imagefile: path of the image file to read.
        center: object center from the annotation (currently unused).
        scale: object scale from the annotation (currently unused).
        keypoints: keypoint annotation, passed unchanged to `generate_gt_heatmap`.

    Returns:
        Tuple `(image, heatmaps)`. `image` is a float32 tensor resized to
        `model_input_shape` with values clipped to [0, 1]. `heatmaps` is the
        ground truth heatmap repeated `num_stacks` times along a new last axis.
    """
    raw_bytes = tf.io.read_file(imagefile)
    decoded_img = tf.io.decode_png(raw_bytes, channels=3)

    # Rescale to float32 in [0, 1] *before* augmenting. tf.image.resize returns
    # floats in the original 0..255 range, while the hue / saturation / JPEG ops
    # expect float images in [0, 1]; the JPEG op converts to uint8 internally
    # and would saturate practically every pixel to white otherwise.
    image = tf.image.convert_image_dtype(decoded_img, tf.float32)
    image = tf.image.resize(image, model_input_shape)

    # NOTE(review): the object is not cropped around `center` / `scale`, and
    # `keypoints` are not transformed to follow the resize. Confirm which
    # coordinate frame generate_gt_heatmap expects before training on this.

    # Data augmentation.
    # Draw the seed from the full non-negative int32 range; maxval=3 allowed
    # only nine distinct seed pairs, i.e. nine distinct augmentations overall.
    seed = tf.random.uniform(shape=[2], maxval=tf.int32.max, dtype=tf.int32)
    # NOTE(review): max_delta=0.95 is a very strong shift on a [0, 1] image;
    # the value is kept from the original code and may need tuning.
    image = tf.image.stateless_random_brightness(image, max_delta=0.95, seed=seed)
    image = tf.image.stateless_random_contrast(image, lower=0.1, upper=0.9, seed=seed)
    # brightness / contrast can leave [0, 1]; restore the range for the ops below
    image = tf.clip_by_value(image, 0.0, 1.0)
    image = tf.image.stateless_random_hue(image, 0.2, seed)
    image = tf.image.stateless_random_jpeg_quality(image, 75, 95, seed)
    image = tf.image.stateless_random_saturation(image, lower=0.5, upper=1.0, seed=seed)
    image = tf.clip_by_value(image, 0.0, 1.0)

    # generate ground truth keypoint heatmap
    gt_heatmap = generate_gt_heatmap(keypoints, output_shape)

    # every hourglass stack is trained against the same target
    out_heatmaps = tf.stack([gt_heatmap] * num_stacks, axis=-1)

    return (image, out_heatmaps)




In [40]:
import numpy as np
import json

def dataset_from_annotations(annotations, image_path, validation_set=False):
    """Collect the per-sample fields of one dataset split from the annotation list.

    Args:
        annotations: iterable of dicts providing the keys 'img_paths', 'objpos',
            'joint_self', 'scale_provided' and 'isValidation'.
        image_path: directory joined in front of every 'img_paths' entry.
        validation_set: keep the samples whose 'isValidation' value equals this.

    Returns:
        Tuple `(image_filenames, centers, scales, keypoints)` of equally long
        Python lists: file paths, float numpy arrays with the (adjusted) object
        center, the (adjusted) scales, and numpy arrays built from 'joint_self'.
    """
    image_filenames, centers, scales, keypoints = [], [], [], []

    for annotation in annotations:
        # skip samples of the other split before doing any work on them
        if annotation['isValidation'] != validation_set:
            continue

        # Force a float array: with integer coordinates in the annotation the
        # in-place offset below would be truncated silently.
        center = np.array(annotation['objpos'], dtype=float)
        scale = annotation['scale_provided']

        # adjust center/scale slightly to avoid cropping limbs
        # (a center x of -1 is left untouched)
        if center[0] != -1:
            center[1] += 15 * scale
            scale = scale * 1.25

        image_filenames.append(os.path.join(image_path, annotation['img_paths']))
        centers.append(center)
        scales.append(scale)
        keypoints.append(np.array(annotation['joint_self']))

    # Plain lists are returned; tf.data.Dataset.from_tensor_slices converts them.
    return image_filenames, centers, scales, keypoints
  
                 
            
    
        
# dataset with tf.data.Dataset and training using model.fit()
# Annotation file and image directory used by the tf.data pipeline below.
json_file = "data/mpii/annotations.json"
image_path = "data/mpii/images/"
# JSON text is UTF-8; state the encoding explicitly so the read does not depend
# on the platform's default text encoding.
with open(json_file, encoding="utf-8") as f:
    annotations = json.load(f)
    
    
    
    

In [41]:
# Training split: one dataset element per annotated sample.
train_split = dataset_from_annotations(annotations, image_path, validation_set=False)
img_filenames, img_centers, img_scales, img_keypoints = train_split
tfdataset_train = tf.data.Dataset.from_tensor_slices(train_split)

# Validation split. After this point the `img_*` names refer to the validation
# samples, not the training ones.
val_split = dataset_from_annotations(annotations, image_path, validation_set=True)
img_filenames, img_centers, img_scales, img_keypoints = val_split
tfdataset_val = tf.data.Dataset.from_tensor_slices(val_split)


In [42]:
from common.data_utils import generate_gt_heatmap, label_heatmap
import inspect

# converted_f = tf.autograph.to_graph(label_heatmap.python_function)
# print(inspect.getsource(converted_f))
# img_keypointsT[0].shape
# generate_gt_heatmap(img_keypoints[0], (64,64))

In [43]:
AUTOTUNE = tf.data.AUTOTUNE


def _to_batched_pipeline(dataset):
    """Map annotation records to (image, heatmaps) pairs, then batch and prefetch."""
    mapped = dataset.map(map_dataset_to_image_heatmaps, num_parallel_calls=AUTOTUNE)
    return mapped.batch(batch_size).prefetch(AUTOTUNE)


tfdataset_mapped_train = _to_batched_pipeline(tfdataset_train)
tfdataset_mapped_val = _to_batched_pipeline(tfdataset_val)

In [44]:
# Pull a single batch from the training pipeline and show the shape of its
# first image.
batch_iterator = iter(tfdataset_mapped_train)
data_item = next(batch_iterator)
data_item[0][0].shape

TensorShape([256, 256, 3])

In [None]:
from common.data_utils import crop_single_object

# Re-derive the *training* split here. Once the dataset-building cell has run,
# the `img_*` globals hold the validation split, which does not line up with
# `data_item` (a batch drawn from the unshuffled training pipeline).
_, train_centers, train_scales, train_keypoints = dataset_from_annotations(annotations, image_path, validation_set=False)

sample_idx = 1  # second sample of the first training batch
# NOTE(review): the batch image has already been resized to model_input_shape,
# while the annotations are taken as loaded — confirm crop_single_object is
# meant to be applied to such inputs.
new_image, new_keypoints = crop_single_object(data_item[0][sample_idx], train_keypoints[sample_idx], train_centers[sample_idx], train_scales[sample_idx], model_input_shape)
new_image.shape, new_keypoints.shape

In [None]:
# Shape of element 1 of the second component of the generator's item 1.
train_generator[1][1][1].shape

In [None]:
# changing GPU device configuration
# import tensorflow as tf
# physical_devices = tf.config.list_physical_devices('GPU')

# tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])
# logical_devices = tf.config.list_logical_devices('GPU')
# logical_devicesphysical_devices = tf.config.list_physical_devices('GPU')
# try:
#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=100),
#      tf.config.LogicalDeviceConfiguration(memory_limit=100)])

#   logical_devices = tf.config.list_logical_devices('GPU')
#   assert len(logical_devices) == len(physical_devices) + 1

#   tf.config.set_logical_device_configuration(
#     physical_devices[0],
#     [tf.config.LogicalDeviceConfiguration(memory_limit=10),
#      tf.config.LogicalDeviceConfiguration(memory_limit=10)])
# except:
#   # Invalid device or cannot modify logical devices once initialized.
#   pass

In [50]:
# Train from the tf.data pipelines; batching is already done by the pipelines.
# `workers`, `use_multiprocessing` and `max_queue_size` are not passed: Keras
# documents them as applying to generator / keras.utils.Sequence input only,
# and newer Keras versions no longer accept them in fit().
model.fit(tfdataset_mapped_train,
          validation_data=tfdataset_mapped_val,
          epochs=total_epoch,
          initial_epoch=init_epoch,
          callbacks=callbacks)


Epoch 1/100


ResourceExhaustedError: Graph execution error:

Detected at node 'gradient_tape/model_1/hg1_conv_1x1_predict/Conv2D/Conv2DBackpropInput' defined at (most recent call last):
    File "c:\users\aryan\appdata\local\programs\python\python37\lib\runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "c:\users\aryan\appdata\local\programs\python\python37\lib\runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel_launcher.py", line 16, in <module>
      app.launch_new_instance()
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\traitlets\config\application.py", line 845, in launch_instance
      app.start()
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\kernelapp.py", line 677, in start
      self.io_loop.start()
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\tornado\platform\asyncio.py", line 199, in start
      self.asyncio_loop.run_forever()
    File "c:\users\aryan\appdata\local\programs\python\python37\lib\asyncio\base_events.py", line 523, in run_forever
      self._run_once()
    File "c:\users\aryan\appdata\local\programs\python\python37\lib\asyncio\base_events.py", line 1758, in _run_once
      handle._run()
    File "c:\users\aryan\appdata\local\programs\python\python37\lib\asyncio\events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\kernelbase.py", line 457, in dispatch_queue
      await self.process_one()
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\kernelbase.py", line 446, in process_one
      await dispatch(*args)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\kernelbase.py", line 353, in dispatch_shell
      await result
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\kernelbase.py", line 648, in execute_request
      reply_content = await reply_content
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\ipkernel.py", line 345, in do_execute
      res = shell.run_cell(code, store_history=store_history, silent=silent)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\ipykernel\zmqshell.py", line 532, in run_cell
      return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\IPython\core\interactiveshell.py", line 2899, in run_cell
      raw_cell, store_history, silent, shell_futures)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\IPython\core\interactiveshell.py", line 2944, in _run_cell
      return runner(coro)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\IPython\core\async_helpers.py", line 68, in _pseudo_sync_runner
      coro.send(None)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\IPython\core\interactiveshell.py", line 3170, in run_cell_async
      interactivity=interactivity, compiler=compiler, result=result)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\IPython\core\interactiveshell.py", line 3361, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\IPython\core\interactiveshell.py", line 3441, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\aryan\AppData\Local\Temp/ipykernel_25808/3950785756.py", line 8, in <module>
      callbacks=callbacks)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\utils\traceback_utils.py", line 64, in error_handler
      return fn(*args, **kwargs)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\engine\training.py", line 1384, in fit
      tmp_logs = self.train_function(iterator)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\engine\training.py", line 1021, in train_function
      return step_function(self, iterator)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\engine\training.py", line 1010, in step_function
      outputs = model.distribute_strategy.run(run_step, args=(data,))
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\engine\training.py", line 1000, in run_step
      outputs = model.train_step(data)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\engine\training.py", line 863, in train_step
      self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\optimizer_v2\optimizer_v2.py", line 531, in minimize
      loss, var_list=var_list, grad_loss=grad_loss, tape=tape)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\optimizer_v2\optimizer_v2.py", line 583, in _compute_gradients
      grads_and_vars = self._get_gradients(tape, loss, var_list, grad_loss)
    File "C:\Users\aryan\AppData\Local\Programs\Python\Python37\Lib\site-packages\keras\optimizer_v2\optimizer_v2.py", line 464, in _get_gradients
      grads = tape.gradient(loss, var_list, grad_loss)
Node: 'gradient_tape/model_1/hg1_conv_1x1_predict/Conv2D/Conv2DBackpropInput'
OOM when allocating tensor with shape[16,256,64,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node gradient_tape/model_1/hg1_conv_1x1_predict/Conv2D/Conv2DBackpropInput}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_train_function_32195]

In [None]:
# Attach the current model to the evaluation callback and invoke its
# end-of-epoch hook by hand, with epoch index 0.
eval_callback.model = model
eval_callback.on_epoch_end(0)

In [None]:
# Run the prediction helper from `eval` on the images of the sampled batch
# (`data_item[0]` is the image component of the batch).
from eval import hourglass_predict_keras, post_process_heatmap_simple
heatmap = hourglass_predict_keras(model, data_item[0])

In [None]:
# First two dimensions of the predicted heatmap.
heatmap.shape[0:2]

In [None]:
# Post-process the predicted heatmap with the helper imported from `eval`.
# NOTE(review): 0.01 is presumably a confidence threshold — confirm against the helper.
post_process_heatmap_simple(heatmap, 0.01)

In [None]:
# Index positions at which the last-axis slice 0 of the heatmap reaches its maximum.
_map = heatmap[..., 0]
np.nonzero(_map == _map.max())

In [None]:
# Training set size as reported by train_generator.get_dataset_size().
num_train