<a href="https://colab.research.google.com/github/Saroramath/MachineLearning/blob/main/CustomTrainingLoopMNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook illustrates custom training loops in TensorFlow/Keras. While many models can be trained successfully with the high-level `.fit()` method, more advanced setups such as *generative adversarial networks* (Part 6 of this course) and *deep reinforcement learning* (Part 9 of this course) require custom training loops. We illustrate the technique here on our simple MNIST model.

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import time

In [None]:
# Load the raw data
(ds_train, ds_val), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

In [None]:
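# Build the training/validation pipelines
def normalize_img(image, label):
  # Scale pixel values from [0, 255] to [0, 1]. Labels stay integer class ids
  # (no one-hot encoding), as expected by the sparse loss and metric used below.
  return tf.cast(image, tf.float32) / 255., label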

ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.experimental.AUTOTUNE)

ds_val = ds_val.map(
    normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
ds_val = ds_val.cache()
ds_val = ds_val.batch(128)
ds_val = ds_val.prefetch(tf.data.experimental.AUTOTUNE)
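
As a quick sanity check on the pipeline above, the number of mini-batches per epoch follows from the dataset sizes and the batch size. The sketch below assumes the standard MNIST split sizes (60,000 training and 10,000 test images); the variable names are illustrative and not part of the notebook.

```python
import math

# Assumed sizes of the standard MNIST splits
num_train_examples = 60_000
num_val_examples = 10_000
batch_size = 128  # same value as in the pipeline above

# batch() keeps the final, smaller batch by default, so we round up
train_batches = math.ceil(num_train_examples / batch_size)
val_batches = math.ceil(num_val_examples / batch_size)

# Size of the last (partial) training batch
last_batch_size = num_train_examples - (train_batches - 1) * batch_size

print(train_batches, val_batches, last_batch_size)
```

Under these assumptions an epoch has 469 training batches, so a loop that logs every 100 batches reports batches 0, 100, 200, 300 and 400.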

In [None]:
# Create the model
inp = tf.keras.layers.Input(shape=(28,28,))
b = tf.keras.layers.Flatten()(inp)
b = tf.keras.layers.Dense(128, activation='relu')(b)
out = tf.keras.layers.Dense(10, activation='softmax')(b)

model = tf.keras.models.Model(inp, out)

# Summary of the model
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________
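
The parameter counts in the summary can be reproduced by hand: a dense layer with `n_in` inputs and `n_out` outputs has `n_in * n_out` weights plus `n_out` biases. A short check, where the helper `dense_params` is a hypothetical name introduced only for this illustration:

```python
# Layer sizes taken from the model summary above
input_units = 28 * 28   # Flatten turns each 28x28 image into 784 values
hidden_units = 128
output_units = 10

def dense_params(n_in, n_out):
  """Number of weights (n_in * n_out) plus one bias per output unit."""
  return n_in * n_out + n_out

hidden_params = dense_params(input_units, hidden_units)
output_params = dense_params(hidden_units, output_units)
total_params = hidden_params + output_params

print(hidden_params, output_params, total_params)
```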


In [None]:
# Define the loss function and optimizer (but we do not use the model.compile method)
loss = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.SGD()

# Use accuracy as the metric (separate instances for training and validation)
train_metric = tf.keras.metrics.SparseCategoricalAccuracy()
validation_metric = tf.keras.metrics.SparseCategoricalAccuracy()
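
To make the two quantities defined above concrete: for integer labels and predicted class probabilities, sparse categorical cross-entropy is the negative log of the probability assigned to the true class, averaged over the mini-batch under the default reduction. Sparse categorical accuracy is the fraction of examples whose most probable class matches the label. The following pure-Python sketch uses made-up numbers and ignores the small clipping that Keras applies to probabilities for numerical stability.

```python
import math

# Toy mini-batch: 3 examples, 4 classes (made-up probabilities)
probs = [
    [0.7, 0.1, 0.1, 0.1],
    [0.2, 0.5, 0.2, 0.1],
    [0.1, 0.2, 0.3, 0.4],
]
labels = [0, 2, 3]  # integer class ids, not one-hot vectors

# Cross-entropy: negative log-probability of the true class, then the mean
per_example_loss = [-math.log(p[y]) for p, y in zip(probs, labels)]
batch_loss = sum(per_example_loss) / len(per_example_loss)

# Accuracy: compare the arg-max class with the label
predicted = [max(range(len(p)), key=p.__getitem__) for p in probs]
accuracy = sum(int(a == b) for a, b in zip(predicted, labels)) / len(labels)

print("loss = {:.4f}, accuracy = {:.4f}".format(batch_loss, accuracy))
```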

In [None]:
# This is a custom training loop
epochs = 10

for epoch in range(epochs):
  print("Working on epoch {} out of {} epochs:\n".format(epoch, epochs))
  start_time = time.time()

  # Now iterate over all batches
  for batch_nr, (x_batch, y_batch) in enumerate(ds_train):

    # Record operations of the forward pass in a GradientTape
    with tf.GradientTape() as tape:

      # Evaluate the model on the current mini-batch
      y_batch_predict = model(x_batch, training=True)

      # Loss over the current mini-batch
      batch_loss = loss(y_batch, y_batch_predict)

    # Compute the gradient of the loss function (using backpropagation)
    gradients = tape.gradient(batch_loss, model.trainable_weights)

    # Take one step of gradient descent
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

    # Update the metric
    train_metric.update_state(y_batch, y_batch_predict)

    # Record loss value every 100 batches
    if batch_nr % 100 == 0:
      print("Current training loss for batch {} is {: 2.4f}.".format(batch_nr,
                                                              batch_loss))
          
  # Accuracy at the end of the epoch
  current_acc = train_metric.result()
  print("\nTraining accuracy: {: .4f}".format(current_acc))

  # Reset the metric at epoch's end
  # (reset_state() replaces the deprecated reset_states(); needs TF >= 2.5)
  train_metric.reset_state()

  # Validation accuracy at the end of the epoch
  for x_batch_val, y_batch_val in ds_val:
    y_batch_val_predict = model(x_batch_val, training=False)
    validation_metric.update_state(y_batch_val, y_batch_val_predict)
  val_acc = validation_metric.result()
  validation_metric.reset_state()  # replaces the deprecated reset_states()
  print("Validation accuracy: {: .4f}".format(val_acc))
  print("Time to complete: {: 2.4f}\n".format(time.time()-start_time))

Working on epoch 0 out of 10 epochs:

Current training loss for batch 0 is  2.4852.
Current training loss for batch 100 is  1.5337.
Current training loss for batch 200 is  1.0508.
Current training loss for batch 300 is  0.8573.
Current training loss for batch 400 is  0.7124.

Training accuracy:  0.7278
Validation accuracy:  0.8584
Time to complete:  8.4231

Working on epoch 1 out of 10 epochs:

Current training loss for batch 0 is  0.6740.
Current training loss for batch 100 is  0.5568.
Current training loss for batch 200 is  0.5856.
Current training loss for batch 300 is  0.4551.
Current training loss for batch 400 is  0.4806.

Training accuracy:  0.8662
Validation accuracy:  0.8882
Time to complete:  2.9368

Working on epoch 2 out of 10 epochs:

Current training loss for batch 0 is  0.5966.
Current training loss for batch 100 is  0.5237.
Current training loss for batch 200 is  0.4966.
Current training loss for batch 300 is  0.4819.
Current training loss for batch 400 is  0.4431.

...
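
Each pass through the inner loop above is one step of gradient descent: record the forward pass on a tape, differentiate the loss with respect to the weights, and move the weights against the gradient. Here is a minimal sketch on a one-parameter problem, with a made-up loss $(w-1)^2$ and a hand-written update standing in for the optimizer. The names `w` and `learning_rate` are illustrative only.

```python
import tensorflow as tf

w = tf.Variable(3.0)   # single trainable parameter
learning_rate = 0.1    # illustrative value

with tf.GradientTape() as tape:
  loss_value = (w - 1.0) ** 2   # forward pass, recorded by the tape

grad = tape.gradient(loss_value, w)   # d/dw (w - 1)^2 = 2 * (w - 1)
w.assign_sub(learning_rate * grad)    # plain gradient descent update

print(float(grad), float(w.numpy()))
```

The notebook's `tf.keras.optimizers.SGD()` applies the same update rule to every trainable weight, using its default learning rate of 0.01 and no momentum.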

The training loop above uses TensorFlow 2's default *eager execution*, which evaluates each operation immediately instead of first building a graph of the computation. While this is great for debugging, it incurs significant overhead: each operation is dispatched from Python one at a time, so TensorFlow cannot apply the optimizations that become available once the whole computation is compiled into a graph.

We can compile the time-critical parts of the loop into a static graph by wrapping them in functions decorated with `@tf.function`.
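
A function decorated with `@tf.function` is traced the first time it is called with a given input signature: the Python body runs once to build the graph, and later calls with the same shapes and dtypes reuse that graph. The sketch below makes this visible through a Python side effect. The function `scaled_sum` and the list `traced_shapes` are made up for this illustration, and the behavior shown assumes TensorFlow 2.x with default `tf.function` settings.

```python
import tensorflow as tf

traced_shapes = []  # filled by a Python side effect, so only during tracing

@tf.function
def scaled_sum(x):
  traced_shapes.append(x.shape.as_list())  # runs while the graph is built
  return tf.reduce_sum(x) * 2.0

a = scaled_sum(tf.constant([1.0, 2.0]))       # first call: trace, then run
b = scaled_sum(tf.constant([3.0, 4.0]))       # same shape and dtype: graph reused
c = scaled_sum(tf.constant([1.0, 2.0, 3.0]))  # new shape: traced again

print(float(a), float(b), float(c))
print(traced_shapes)
```

The same reasoning explains why progress messages are best printed from the surrounding Python loop: a Python `print` placed inside a traced function would fire only while tracing, not on every call.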

In [None]:
# Define each mini-batch gradient descent step as a separate function
@tf.function
def train_step(x_batch, y_batch):

  # Record operations of the forward pass in a GradientTape
  with tf.GradientTape() as tape:

    # Evaluate the model on the current mini-batch
    y_batch_predict = model(x_batch, training=True)

    # Loss over the current mini-batch
    batch_loss = loss(y_batch, y_batch_predict)

  # Compute the gradient of the loss function (using backpropagation)
  gradients = tape.gradient(batch_loss, model.trainable_weights)

  # Take one step of gradient descent
  optimizer.apply_gradients(zip(gradients, model.trainable_weights))

  # Update the metric
  train_metric.update_state(y_batch, y_batch_predict)

  return batch_loss

# Define the validation step as a separate function
@tf.function
def val_step(x_batch_val, y_batch_val):
  y_batch_val_predict = model(x_batch_val, training=False)
  validation_metric.update_state(y_batch_val, y_batch_val_predict)
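
Both step functions above update a Keras metric object instead of returning an accuracy. Such metrics are stateful: they accumulate counts across mini-batches until they are reset, which is why the training loops reset them at the end of every epoch. The class below is a hypothetical pure-Python analogue written only to show this behavior. Unlike the Keras metric, it takes predicted class ids rather than probabilities.

```python
class RunningAccuracy:
  """Hypothetical stand-in for a stateful accuracy metric."""

  def __init__(self):
    self.correct = 0
    self.total = 0

  def update_state(self, labels, predictions):
    # Accumulate counts instead of storing per-batch accuracies
    self.correct += sum(int(y == p) for y, p in zip(labels, predictions))
    self.total += len(labels)

  def result(self):
    return self.correct / self.total if self.total else 0.0

  def reset(self):
    self.correct = 0
    self.total = 0

metric = RunningAccuracy()
metric.update_state([1, 2, 3, 4], [1, 2, 0, 4])  # 3 of 4 correct
metric.update_state([5, 6], [5, 0])              # 1 of 2 correct
epoch_accuracy = metric.result()                 # 4 of 6 over both batches
metric.reset()                                   # start the next epoch fresh
after_reset = metric.result()

print(epoch_accuracy, after_reset)
```

Because counts are accumulated, the result is weighted by batch size, so a smaller final batch does not distort the epoch accuracy.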

In [None]:
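# Exact same training loop as before, now using the compiled versions of the
# time-critical components. The model is not re-initialized, so training
# continues from the weights learned in the eager run above.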
epochs = 10

for epoch in range(epochs):
  print("Working on epoch {} out of {} epochs:\n".format(epoch, epochs))
  start_time = time.time()

  # Now iterate over all batches
  for batch_nr, (x_batch, y_batch) in enumerate(ds_train):

    # Take one step of gradient descent
    batch_loss = train_step(x_batch, y_batch)

    # Record loss value every 100 batches
    if batch_nr % 100 == 0:
      print("Current training loss for batch {} is {: 2.4f}.".format(batch_nr,
                                                              batch_loss))
          
  # Accuracy at the end of the epoch
  current_acc = train_metric.result()
  print("\nTraining accuracy: {: .4f}".format(current_acc))

  # Reset the metric at epoch's end
  # (reset_state() replaces the deprecated reset_states(); needs TF >= 2.5)
  train_metric.reset_state()

  # Validation accuracy at the end of the epoch
  for x_batch_val, y_batch_val in ds_val:
    val_step(x_batch_val, y_batch_val)

  val_acc = validation_metric.result()
  validation_metric.reset_state()  # replaces the deprecated reset_states()
  print("Validation accuracy: {: .4f}".format(val_acc))
  print("Time to complete: {: 2.4f}\n".format(time.time()-start_time))

Working on epoch 0 out of 10 epochs:

Current training loss for batch 0 is  0.2973.
Current training loss for batch 100 is  0.2498.
Current training loss for batch 200 is  0.1861.
Current training loss for batch 300 is  0.3251.
Current training loss for batch 400 is  0.3829.

Training accuracy:  0.9204
Validation accuracy:  0.9255
Time to complete:  0.9808

Working on epoch 1 out of 10 epochs:

Current training loss for batch 0 is  0.3378.
Current training loss for batch 100 is  0.2113.
Current training loss for batch 200 is  0.2382.
Current training loss for batch 300 is  0.2174.
Current training loss for batch 400 is  0.2888.

Training accuracy:  0.9226
Validation accuracy:  0.9272
Time to complete:  0.6724

Working on epoch 2 out of 10 epochs:

Current training loss for batch 0 is  0.2366.
Current training loss for batch 100 is  0.1104.
Current training loss for batch 200 is  0.3008.
Current training loss for batch 300 is  0.2348.
Current training loss for batch 400 is  0.4214.

...
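
The effect of graph compilation can be read off the per-epoch timings in the two logs above. The short calculation below uses only the logged values. The first epoch of each run probably includes one-off costs (filling the dataset cache in the eager run, tracing the functions in the compiled run), so the second epoch is likely the fairer comparison. Absolute timings depend on the hardware.

```python
# Per-epoch wall-clock times in seconds, copied from the two logs above
eager_times = {0: 8.4231, 1: 2.9368}
graph_times = {0: 0.9808, 1: 0.6724}

# Ratio of eager to compiled time for each logged epoch
speedups = {epoch: eager_times[epoch] / graph_times[epoch]
            for epoch in eager_times}

for epoch, ratio in speedups.items():
  print("Epoch {}: compiled loop is {:.1f}x faster".format(epoch, ratio))
```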