# In [0]:
#  UTILS

def display_9_images_from_dataset(dataset):
    """Plot the first nine images of ``dataset`` in a 3x3 grid.

    Args:
        dataset: Iterable yielding ``(image, label)`` pairs. Each ``image``
            must provide ``.numpy()``; the labels are ignored.

    If the dataset yields fewer than nine items, only those are drawn.
    """
    plt.figure(figsize=(13, 13))
    for i, (image, _) in enumerate(dataset):
        # Cells of the 3x3 grid are numbered 1..9.
        plt.subplot(3, 3, i + 1)
        plt.axis('off')
        plt.imshow(image.numpy().astype(np.uint8))
        # Stop after the ninth image (the previous check stopped after the
        # third, leaving six grid cells empty).
        if i == 8:
            break
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

def resize_and_crop_image(image, label):
    """Scale ``image`` so it covers the target size, then center-crop it.

    The resize keeps the aspect ratio and picks the scale factor that makes
    the image at least as large as the target along both axes; the excess
    is then cropped away symmetrically. ``label`` is returned unchanged.

    NOTE(review): axis 0 of the image is matched against ``TARGET_SIZE[1]``
    and axis 1 against ``TARGET_SIZE[0]``. This is only observable for a
    non-square ``TARGET_SIZE`` -- confirm the intended axis order.
    """
    source_shape = tf.shape(image)
    dim0 = source_shape[0]
    dim1 = source_shape[1]
    target0 = TARGET_SIZE[1]
    target1 = TARGET_SIZE[0]

    def _match_axis0():
        # Scale so that axis 0 equals its target; axis 1 overshoots.
        return tf.image.resize(image, [dim0*target0/dim0, dim1*target0/dim0])

    def _match_axis1():
        # Scale so that axis 1 equals its target; axis 0 overshoots.
        return tf.image.resize(image, [dim0*target1/dim1, dim1*target1/dim1])

    # Below 1 means the image is relatively shorter along axis 0 than the
    # target is, so axis 0 is the one that has to be matched.
    aspect_ratio_crit = (dim0 * target1) / (dim1 * target0)
    resized = tf.cond(aspect_ratio_crit < 1, _match_axis0, _match_axis1)

    resized_shape = tf.shape(resized)
    offset0 = (resized_shape[0] - target0) // 2
    offset1 = (resized_shape[1] - target1) // 2
    cropped = tf.image.crop_to_bounding_box(resized, offset0, offset1, target0, target1)
    return cropped, label

def normalize(image, label):
    """Standardize a single image; the label is passed through untouched.

    The work is delegated to ``tf.image.per_image_standardization``.
    """
    return tf.image.per_image_standardization(image), label


def augmentation(image, label):
    """Apply random geometric and colour augmentations to one image.

    The image is padded/cropped to 110% of ``TARGET_SIZE``, randomly cropped
    to 95% of ``TARGET_SIZE``, randomly flipped and rotated by a multiple of
    90 degrees, and then jittered in saturation, hue, contrast and
    brightness. ``label`` is returned unchanged.

    Args:
        image: A 3-channel image tensor (the random crop requests 3 channels).
        label: Passed through untouched.

    Returns:
        ``(augmented_image, label)``.
    """
    crop_w = int(TARGET_SIZE[0]*0.95)
    crop_h = int(TARGET_SIZE[1]*0.95)
    crop_or_pad_s = int(TARGET_SIZE[0]*0.1) + TARGET_SIZE[0]
    crop_or_pad_p = int(TARGET_SIZE[1]*0.1) + TARGET_SIZE[1]
    image = tf.image.resize_with_crop_or_pad(image, crop_or_pad_s, crop_or_pad_p)
    image = tf.image.random_crop(image, [crop_w, crop_h, 3])
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_flip_up_down(image)
    # The number of quarter turns must be drawn by a TensorFlow op: a Python
    # random number is evaluated only once, when Dataset.map traces this
    # function, and would give every image the same rotation.
    quarter_turns = tf.random.uniform([], minval=0, maxval=4, dtype=tf.int32)
    image = tf.image.rot90(image, k=quarter_turns)
    image = tf.image.random_saturation(image, 0.6, 1.6)
    image = tf.image.random_hue(image, 0.08)
    image = tf.image.random_contrast(image, 0.7, 1.3)
    image = tf.image.random_brightness(image, 0.05)
    return image, label

def count_data_items(filenames):
    """Sum the item counts embedded in a list of file names.

    The count of a file is the first run of digits that sits between a
    ``-`` and a ``.`` in its name, e.g. ``train-03-230.tfrec`` contributes
    230.

    Args:
        filenames: Iterable of file-name strings.

    Returns:
        The integer total (a NumPy integer); 0 for an empty iterable.

    Raises:
        ValueError: If a file name contains no ``-<digits>.`` segment.
    """
    # Compile once for the whole list; "+" (rather than "*") skips segments
    # such as "-." that carry no digits and could not be converted by int().
    count_pattern = re.compile(r"-([0-9]+)\.")
    counts = []
    for filename in filenames:
        match = count_pattern.search(filename)
        if match is None:
            raise ValueError(
                "no '-<count>.' segment found in file name: %r" % (filename,)
            )
        counts.append(int(match.group(1)))
    # An explicit integer dtype keeps the result integral for an empty list.
    return np.sum(counts, dtype=int)

def force_image_sizes(dataset):
    """Return ``dataset`` with every image reshaped to ``SHAPE``.

    Labels are passed through unchanged; the mapping runs with
    ``num_parallel_calls=AUTO``.
    """
    def _reshape_image(image, label):
        return tf.reshape(image, SHAPE), label

    return dataset.map(_reshape_image, num_parallel_calls=AUTO)

## READ_RECORD
def read_tfrecord(example):
    """Parse one serialized TFRecord example into ``(image, targets)``.

    Args:
        example: A scalar string tensor holding one serialized example.

    Returns:
        A tuple ``(image, targets)`` where ``image`` is the JPEG decoded to
        3 channels and ``targets`` is a dict with the dense
        ``head_root``, ``head_vowel`` and ``head_consonant`` vectors.
    """
    # "label" stays in the feature spec so the set of parsed features is
    # unchanged, but its value is not part of the returned tuple.
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "label": tf.io.FixedLenFeature([], tf.string),
        "head_root_hot": tf.io.VarLenFeature(tf.float32),
        "head_vowel_hot": tf.io.VarLenFeature(tf.float32),
        "head_consonant_hot": tf.io.VarLenFeature(tf.float32),
    }
    parsed = tf.io.parse_single_example(example, features)
    image = tf.image.decode_jpeg(parsed['image'], channels=3)
    # VarLenFeature values are parsed as sparse tensors; densify them.
    targets = {
        "head_root": tf.sparse.to_dense(parsed['head_root_hot']),
        "head_vowel": tf.sparse.to_dense(parsed['head_vowel_hot']),
        "head_consonant": tf.sparse.to_dense(parsed['head_consonant_hot']),
    }
    return image, targets

# LOAD DATASETS
# Module-level tf.data options object with deterministic ordering disabled.
# NOTE(review): no code visible here reads this object (load_dataset builds
# its own options) -- confirm whether it is used elsewhere.
option_no_order = tf.data.Options()
option_no_order.experimental_deterministic = False


def load_dataset(filenames):
    """Build a dataset of parsed examples from TFRecord files.

    Files are read with ``num_parallel_reads=AUTO``, deterministic ordering
    is disabled, and every record is decoded by ``read_tfrecord``.

    NOTE: ``force_image_sizes(dataset)`` used to be applied here and is
    currently disabled.
    """
    unordered = tf.data.Options()
    unordered.experimental_deterministic = False
    records = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    return records.with_options(unordered).map(read_tfrecord, num_parallel_calls=AUTO)

def get_training_dataset():
    """Build the training pipeline from ``TRAINING_FILENAMES``.

    Steps: load, shuffle, augment, normalize, resize/crop, repeat three
    times, batch (dropping the incomplete last batch) and prefetch.

    Returns:
        The batched, prefetched training dataset.
    """
    dataset = load_dataset(TRAINING_FILENAMES)
    # NOTE(review): the shuffle buffer holds only BATCH_SIZE elements, which
    # mixes examples weakly -- confirm this is intended.
    dataset = dataset.shuffle(BATCH_SIZE)
    # Every per-element map runs in parallel, matching the resize step;
    # previously augmentation and normalization were processed serially.
    dataset = dataset.map(augmentation, num_parallel_calls=AUTO)
    dataset = dataset.map(normalize, num_parallel_calls=AUTO)
    dataset = dataset.map(resize_and_crop_image, num_parallel_calls=AUTO)
    dataset = dataset.repeat(3)
    dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)
    return dataset

def get_validation_dataset():
    """Build the validation pipeline from ``VALIDATION_FILENAMES``.

    Steps: load, normalize, resize/crop, batch (dropping the incomplete
    last batch) and prefetch. No augmentation, shuffling or repetition.

    Returns:
        The batched, prefetched validation dataset.
    """
    dataset = load_dataset(VALIDATION_FILENAMES)
    # Apply the same normalization, in the same position, as the training
    # pipeline; without it validation images are fed in a different value
    # range from the images seen during training.
    dataset = dataset.map(normalize, num_parallel_calls=AUTO)
    dataset = dataset.map(resize_and_crop_image, num_parallel_calls=AUTO)
    dataset = dataset.batch(VALIDATION_BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(AUTO)
    return dataset



# Feed the placeholders through a batched dataset and hand the resulting
# tensors to the model.
dataset = tf.data.Dataset.from_tensor_slices((placeholder_X, placeholder_y))
dataset = dataset.batch(batch_size)
# The Dataset.make_initializable_iterator() method is deprecated in TF 1.x
# and no longer exists in TF 2.x; the compat.v1 function builds the same
# kind of iterator and is available in both.
iterator = tf.compat.v1.data.make_initializable_iterator(dataset)

data_X, data_y = iterator.get_next()
data_y = tf.cast(data_y, tf.int32)
model = Model(data_X, data_y)