In [1]:
import os
from time import time
import tensorflow as tf

def data_generator(features,labels,batch_size):
  """Build an endlessly repeating, shuffled and batched input pipeline.

  Args:
    features: Values that support `/ 255`; the result is cast to float32.
    labels: Labels paired element-wise with `features`; `len(labels)` sizes
      the shuffle buffer.
    batch_size: Number of examples per batch. Incomplete trailing batches
      are dropped.

  Returns:
    A `tf.data.Dataset` yielding `(scaled_features, labels)` batches that is
    shuffled, batched, repeated indefinitely and prefetched.
  """
  scaled_features = tf.cast(features / 255, tf.float32)
  # One slot more than the number of examples, so the buffer spans the whole set.
  shuffle_buffer = len(labels) + 1
  return (
      tf.data.Dataset.from_tensor_slices((scaled_features, labels))
      .shuffle(buffer_size=shuffle_buffer)
      .batch(batch_size=batch_size, drop_remainder=True)
      .repeat()
      .prefetch(tf.data.experimental.AUTOTUNE)
  )

def model(num_classes):
  """Create a frozen ResNet50 feature extractor topped with a softmax classifier.

  Args:
    num_classes: Number of units in the final softmax `Dense` layer.

  Returns:
    A `tf.keras.Model` mapping float32 inputs of shape (32, 32, 3) to
    `num_classes` softmax outputs.
  """
  image_input = tf.keras.layers.Input(shape=(32,32,3), dtype=tf.float32)
  backbone = tf.keras.applications.ResNet50(
      include_top=False,
      pooling='avg',
      input_tensor=image_input,
  )
  # Freeze the backbone; only the Dense head created below stays trainable.
  backbone.trainable = False
  classifier_head = tf.keras.layers.Dense(num_classes, activation="softmax")
  predictions = classifier_head(backbone.output)
  return tf.keras.Model(inputs=backbone.input, outputs=predictions)

if __name__ == "__main__":
  # Baseline timing run: float32, no XLA flags and no mixed-precision policy.
  batch_size, epochs, num_classes = 32, 2, 10

  (x_train,y_train),(x_test,y_test) = tf.keras.datasets.cifar10.load_data()

  train_dataset = data_generator(x_train, y_train, batch_size)
  test_dataset = data_generator(x_test, y_test, batch_size)

  # Despite its name, `xception` holds whatever `model()` builds (ResNet50-based).
  xception = model(num_classes)
  optimizers = tf.keras.optimizers.Adam()
  xception.compile(
      loss='sparse_categorical_crossentropy',
      optimizer=optimizers,
  )

  # The datasets repeat forever, so explicit step counts delimit each epoch.
  train_steps = len(x_train) // batch_size
  validation_steps = len(x_test) // batch_size

  start_time = time()
  xception.fit(
      train_dataset,
      epochs=epochs,
      steps_per_epoch=train_steps,
      validation_data=test_dataset,
      validation_steps=validation_steps,
  )
  end_time = time()

  elapsed = end_time - start_time
  print("Time without xla and mpt with fp32", elapsed)

Epoch 1/2
Epoch 2/2
Time without xla and mpt with fp32 55.788514852523804


In [4]:
import os
from time import time
import tensorflow as tf

# Request XLA auto-clustering via the environment and enable the JIT through
# the optimizer config.
# NOTE(review): the variable is set after `import tensorflow`; confirm the
# flags are still picked up, or set them before the import / kernel start.
os.environ['TF_XLA_FLAGS'] = "--tf_xla_auto_jit=2 --tf_xla_cpu_global_jit"
tf.config.optimizer.set_jit(True)

# Use the stable mixed-precision API. The `experimental` Policy/set_policy
# pair is deprecated since TF 2.4, and this file already relies on the stable
# `tf.keras.mixed_precision.LossScaleOptimizer` from the same namespace.
policy = tf.keras.mixed_precision.Policy('mixed_float16')
tf.keras.mixed_precision.set_global_policy(policy)

def data_generator(features,labels,batch_size):
  """Build an endlessly repeating, shuffled and batched input pipeline.

  Args:
    features: Values that support `/ 255`; the result is cast to float32.
    labels: Labels paired element-wise with `features`; `len(labels)` sizes
      the shuffle buffer.
    batch_size: Number of examples per batch. Incomplete trailing batches
      are dropped.

  Returns:
    A `tf.data.Dataset` yielding `(scaled_features, labels)` batches that is
    shuffled, batched, repeated indefinitely and prefetched.
  """
  # Dividing an integer NumPy array by 255 produces float64. Cast to float32
  # so the pipeline matches the fp32 baseline's input pipeline and does not
  # carry double-width tensors; the model casts to its own input dtype.
  scaled_features = tf.cast(features/255, tf.float32)
  dataset = tf.data.Dataset.from_tensor_slices((scaled_features,labels))
  # One slot more than the number of examples, so the buffer spans the whole set.
  dataset = dataset.shuffle(buffer_size=len(labels)+1)
  dataset = dataset.batch(batch_size=batch_size,
                          drop_remainder=True)
  dataset = dataset.repeat()
  dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
  return dataset

def model(num_classes):
  """Create a frozen ResNet50 feature extractor with a float32 softmax output.

  Args:
    num_classes: Number of units in the final `Dense` layer.

  Returns:
    A `tf.keras.Model` taking float16 inputs of shape (32, 32, 3) and
    producing `num_classes` softmax outputs from a float32 activation layer.
  """
  image_input = tf.keras.layers.Input(shape=(32,32,3), dtype=tf.float16)
  backbone = tf.keras.applications.ResNet50(
      include_top=False,
      pooling='avg',
      input_tensor=image_input,
  )
  # Freeze the backbone; only the Dense head created below stays trainable.
  backbone.trainable = False
  logits = tf.keras.layers.Dense(num_classes)(backbone.output)
  # The softmax is applied in a separate layer pinned to float32.
  softmax_layer = tf.keras.layers.Activation('softmax', dtype=tf.float32)
  predictions = softmax_layer(logits)
  return tf.keras.Model(inputs=backbone.input, outputs=predictions)

if __name__ == "__main__":
  # Timing run for the XLA + mixed-precision configuration set up in this cell.
  batch_size, epochs, num_classes = 32, 2, 10

  (x_train,y_train),(x_test,y_test) = tf.keras.datasets.cifar10.load_data()

  train_dataset = data_generator(x_train, y_train, batch_size)
  test_dataset = data_generator(x_test, y_test, batch_size)

  # Adam wrapped in a loss-scaling optimizer for float16 training.
  optimizer = tf.keras.mixed_precision.LossScaleOptimizer(tf.optimizers.Adam())

  # Despite its name, `xception` holds whatever `model()` builds (ResNet50-based).
  xception = model(num_classes)
  xception.compile(
      loss='sparse_categorical_crossentropy',
      optimizer=optimizer,
  )

  # The datasets repeat forever, so explicit step counts delimit each epoch.
  train_steps = len(x_train) // batch_size
  validation_steps = len(x_test) // batch_size

  start_time = time()
  xception.fit(
      train_dataset,
      epochs=epochs,
      steps_per_epoch=train_steps,
      validation_data=test_dataset,
      validation_steps=validation_steps,
  )
  end_time = time()

  elapsed = end_time - start_time
  print("Time with xla and mpt with fp16", elapsed)

Epoch 1/2
Epoch 2/2
Time with xla and mpt with fp16 44.07524275779724
