### Writing custom Training loops from scratch with Keras


Reference: [Writing a training loop from scratch](https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch)

In [10]:
import tensorflow as tf
import keras
from keras import layers
import numpy as np

Calling a model inside a `GradientTape` scope enables you to retrieve the gradients of the `trainable weights` of the layer with respect to a loss value. Using an `optimizer` instance, you can use these gradients to `update` these `variables` (which you can retrieve using `model.trainable_weights`).



In [11]:
BATCH_SIZE = 64
EPOCHS = 2
LEARNING_RATE = 1e-3

In [12]:
# Create basic Model
def get_basic_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = layers.Dense(64, activation="relu")(inputs)
    x2 = layers.Dense(64, activation="relu")(x1)
    outputs = layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model

In [13]:
basic_model = get_basic_model()

Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_4 (Dense)             (None, 64)                50240     
                                                                 
 dense_5 (Dense)             (None, 64)                4160      
                                                                 
 predictions (Dense)         (None, 10)                650       
                                                                 
Total params: 55050 (215.04 KB)
Trainable params: 55050 (215.04 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [14]:
# Instantiate an optimizer.
optimizer = keras.optimizers.SGD(learning_rate=LEARNING_RATE)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

### Load and Prepare Dataset

In [15]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve 10000 samples for validation
x_val = x_train[-10000:]
y_val = y_train[-10000:]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(BATCH_SIZE)

### Training Loop

* We open a for loop that iterates over epochs
* For each epoch, we open a for loop that iterates over the dataset, in batches

In [16]:
for epoch in range(EPOCHS):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            ''' Run the forward pass of the layer.
            The operations that the layer applies to its inputs are going to be
            recorded on the GradientTape.'''
            logits = basic_model(x_batch_train, training=True)  # ts for this minibatch
            # Compute the loss
            loss = loss_fn(y_batch_train, logits)

        # Retrieve the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss, basic_model.trainable_weights)
        # Note that we used trainable_weights, not trainable_variables as used in train_step method of custom models.

        # Run one step of gradient descent by updating the value of the variables
        # to minimize the loss
        optimizer.apply_gradients(zip(grads, basic_model.trainable_weights))

        if step % 200 == 0:
            print("Training Loss at step %d: %.4f" % (step, float(loss)))
            print("Seen so far: %s samples" % ((step + 1) * BATCH_SIZE))


Start of epoch 0
Training Loss at step 0: 133.8576
Seen so far: 64 samples
Training Loss at step 200: 1.2480
Seen so far: 12864 samples
Training Loss at step 400: 0.7482
Seen so far: 25664 samples
Training Loss at step 600: 0.8331
Seen so far: 38464 samples

Start of epoch 1
Training Loss at step 0: 1.2908
Seen so far: 64 samples
Training Loss at step 200: 0.8234
Seen so far: 12864 samples
Training Loss at step 400: 0.8052
Seen so far: 25664 samples
Training Loss at step 600: 0.7437
Seen so far: 38464 samples


### Handling Metrics

* Instantiate the metric at the start of the loop
* Call `metric.update_state()` after each batch
* Call `metric.result()` when you need to display the current value of the metric
* Call `metric.reset_states()` when you need to clear the state of the metric (typically at the end of an epoch)

In [17]:
basic_model_2 = get_basic_model()
# Instantiate an optimizer to train the model.
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
# Instantiate a loss function.
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

Model: "model_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_6 (Dense)             (None, 64)                50240     
                                                                 
 dense_7 (Dense)             (None, 64)                4160      
                                                                 
 predictions (Dense)         (None, 10)                650       
                                                                 
Total params: 55050 (215.04 KB)
Trainable params: 55050 (215.04 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


### Prepare metrics

In [18]:
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [20]:
import time

for epoch in range(EPOCHS):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            ''' Run the forward pass of the layer.
            The operations that the layer applies to its inputs are going to be
            recorded on the GradientTape.'''
            logits = basic_model_2(x_batch_train, training=True)  # ts for this minibatch
            # Compute the loss
            loss = loss_fn(y_batch_train, logits)

        # Retrieve the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss, basic_model_2.trainable_weights)
        # Note that we used trainable_weights, not trainable_variables as used in train_step method of custom models.

        # Run one step of gradient descent by updating the value of the variables
        # to minimize the loss
        optimizer.apply_gradients(zip(grads, basic_model_2.trainable_weights))

        # Update training metric
        train_acc_metric.update_state(y_batch_train, logits)

        if step % 200 == 0:
            print("Training Loss at step %d: %.4f" % (step, float(loss)))
            print("Seen so far: %s samples" % ((step + 1) * BATCH_SIZE))

    # Display metrics at the end of each epoch
    train_acc = train_acc_metric.result()
    print("Training Accuracy over %d.epoch: %.4f" % (epoch, float(train_acc),))

    # ==== Reset training metrics at the end of the each epoch ====
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = basic_model_2(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)

    val_acc = val_acc_metric.result()
    # ==== Reset training metrics at the end of the each epoch ====
    val_acc_metric.reset_state()

    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))



Start of epoch 0
Training Loss at step 0: 0.5148
Seen so far: 64 samples
Training Loss at step 200: 0.5658
Seen so far: 12864 samples
Training Loss at step 400: 0.4478
Seen so far: 25664 samples
Training Loss at step 600: 0.9974
Seen so far: 38464 samples
Training Accuracy over 0.epoch: 0.8645
Validation acc: 0.8775
Time taken: 21.73s

Start of epoch 1
Training Loss at step 0: 0.6932
Seen so far: 64 samples
Training Loss at step 200: 0.3110
Seen so far: 12864 samples
Training Loss at step 400: 0.5195
Seen so far: 25664 samples
Training Loss at step 600: 0.2024
Seen so far: 38464 samples
Training Accuracy over 1.epoch: 0.8946
Validation acc: 0.9015
Time taken: 22.21s


### Speeding-up the training step with tf.function

The default runtime in TensorFlow 2 is `eager execution`. As such, our training loop above executes eagerly.

This is great for debugging, but `graph compilation` has a definite `performance` advantage. Describing your computation as a `static graph` enables the framework to apply `global performance optimizations`. This is impossible when the framework is constrained to greedily execute one operation after another, with no knowledge of what comes next.

You can compile into a `static graph` any function that `takes tensors as input`. Just add a `@tf.function` decorator on it, like this:

In [23]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = basic_model_2(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, basic_model_2.trainable_weights)
    optimizer.apply_gradients(zip(grads, basic_model_2.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss

In [24]:
# for evaluation step
@tf.function
def test_step(x, y):
    val_logits = basic_model_2(x, training=False)
    val_acc_metric.update_state(y, val_logits)

In [25]:
import time

for epoch in range(EPOCHS):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Train step
        loss = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss))
            )
            print("Seen so far: %d samples" % ((step + 1) * BATCH_SIZE))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Evaluation step
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 0.4023
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.7137
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.1519
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.3010
Seen so far: 38464 samples
Training acc over epoch: 0.9085
Validation acc: 0.9124
Time taken: 2.99s

Start of epoch 1
Training loss (for one batch) at step 0: 0.4876
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.2370
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.1396
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.1095
Seen so far: 38464 samples
Training acc over epoch: 0.9194
Validation acc: 0.9136
Time taken: 2.31s


### Low-level handling of losses tracked by the model

Layers & models recursively `track` any `losses` created `during` the `forward pass` by layers that call `self.add_loss(value)`. The resulting list of scalar loss values are available via the property model.losses at the end of the forward pass.

If you want to be using these loss components, you should `sum` them and `add` them to the `main loss` in your training step.

Consider this layer, that creates an activity regularization loss:

In [26]:
@keras.saving.register_keras_serializable()
class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2*tf.reduce_sum(inputs))
        return inputs

### Build the basic model again with ActivityRegularizationLayer

In [27]:
def get_basic_model_with_activity_regularization_layer():
    inputs = keras.Input(shape=(784,), name="digits")
    x = layers.Dense(64, activation="relu")(inputs)
    # Insert activity regularization as a layer
    x = ActivityRegularizationLayer()(x)
    x = layers.Dense(64, activation="relu")(x)
    outputs = layers.Dense(10, name="predictions")(x)
    model = keras.Model(inputs=inputs, outputs=outputs)
    model.summary()
    return model

In [28]:
model_with_activity_regularization_layer = get_basic_model_with_activity_regularization_layer()

Model: "model_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense_8 (Dense)             (None, 64)                50240     
                                                                 
 activity_regularization_la  (None, 64)                0         
 yer (ActivityRegularizatio                                      
 nLayer)                                                         
                                                                 
 dense_9 (Dense)             (None, 64)                4160      
                                                                 
 predictions (Dense)         (None, 10)                650       
                                                                 
Total params: 55050 (215.04 KB)
Trainable params: 55050 (21

### Training step with ActivityRegularizationLayer

In [29]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model_with_activity_regularization_layer(x, training=True)
        loss_value = loss_fn(y, logits)

        # Add any extra losses created during the forward pass.
        loss_value += sum(model_with_activity_regularization_layer.losses)

    grads = tape.gradient(loss_value, model_with_activity_regularization_layer.trainable_weights)
    optimizer.apply_gradients(zip(grads, model_with_activity_regularization_layer.trainable_weights))
    train_acc_metric.update_state(y, logits)

    return loss_value

### End-to-end GAN (Generative adversarial networks) example with training loop from scratch

A GAN training loop

1. **Train the discriminator**. - Sample a `batch of random points` in the `latent space`. - Turn the points into `fake images` via the "`generator`" model. - Get a `batch of real images` and `combine` them `with` the `generated images`. - Train the "discriminator" model to `classify generated` vs. `real images`.

2. **Train the generator**. - `Sample random points` in the `latent space`. - Turn the points into fake images via the "generator" network. - Get a `batch of real images` and `combine` them with the `generated images`. - Train the "generator" model to "fool" the discriminator and `classify` the `fake images as real`.

In [32]:
INPUT_SHAPE = (28, 28, 1)
LATENT_DIM = 128

In [33]:
def get_discriminator_model(input_shape):
    discriminator = keras.Sequential(
    [
        keras.Input(shape=input_shape),
        layers.Conv2D(64, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.Conv2D(128, (3, 3), strides=(2, 2), padding="same"),
        layers.LeakyReLU(alpha=0.2),
        layers.GlobalMaxPooling2D(),
        layers.Dense(1),
    ],
    name="discriminator",
    )
    discriminator.summary()
    return discriminator

In [34]:
def get_generator_model(latent_dim):
    generator = keras.Sequential(
        [
            keras.Input(shape=(latent_dim,)),
            # We want to generate 128 coefficients to reshape into a 7x7x128 map
            layers.Dense(7 * 7 * 128),
            layers.LeakyReLU(alpha=0.2),
            layers.Reshape((7, 7, 128)),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same"),
            layers.LeakyReLU(alpha=0.2),
            layers.Conv2D(1, (7, 7), padding="same", activation="sigmoid"),
        ],
        name="generator",
    )
    generator.summary()
    return generator

In [35]:
discriminator_model = get_discriminator_model(INPUT_SHAPE)
generator_model = get_generator_model(LATENT_DIM)

Model: "discriminator"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 14, 14, 64)        640       
                                                                 
 leaky_re_lu (LeakyReLU)     (None, 14, 14, 64)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 7, 7, 128)         73856     
                                                                 
 leaky_re_lu_1 (LeakyReLU)   (None, 7, 7, 128)         0         
                                                                 
 global_max_pooling2d (Glob  (None, 128)               0         
 alMaxPooling2D)                                                 
                                                                 
 dense_10 (Dense)            (None, 1)                 129       
                                                     

In [68]:
# Instantiate one optimizer for the discriminator and another for the generator.
d_optimizer = keras.optimizers.Adam(learning_rate=0.0003)
g_optimizer = keras.optimizers.Adam(learning_rate=0.0004)

# Instantiate a loss function.
loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)

In [69]:
@tf.function
def train_step(real_images):
    # Sample random points in the Latent space
    random_latent_vectors = tf.random.normal(shape=(BATCH_SIZE, LATENT_DIM))

    # Decode samples to fake images
    generated_images = generator_model(random_latent_vectors)

    # Combine generated images with real images
    combined_images = tf.concat([generated_images, real_images], axis=0)

    # Assemble labels discriminating real from fake images
    labels = tf.concat(
        [tf.ones((BATCH_SIZE, 1)), tf.zeros((real_images.shape[0], 1))], axis=0
    )

    # Add random noise to the labels !!!
    labels += 0.05 * tf.random.uniform(labels.shape)

    # Train the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator_model(combined_images)
        d_loss = loss_fn(labels, predictions)

    grads = tape.gradient(d_loss, discriminator_model.trainable_weights)
    d_optimizer.apply_gradients(zip(grads, discriminator_model.trainable_weights))

    # Sample random points in the latent space as we did before
    random_latent_vectors = tf.random.normal(shape=(BATCH_SIZE, LATENT_DIM))

    # Assemble labels that say "all real images"
    misleading_labels = tf.zeros((BATCH_SIZE, 1))

    # Train the generator
    # Do not update the weights of the discriminator
    with tf.GradientTape() as tape:
        predictions = discriminator_model(generator_model(random_latent_vectors))  # Not combined images, only generated
        g_loss = loss_fn(misleading_labels, predictions)

    grads = tape.gradient(g_loss, generator_model.trainable_weights)
    g_optimizer.apply_gradients(zip(grads, generator_model.trainable_weights))

    return d_loss, g_loss, generated_images

### Load and Prepare Dataset

In [70]:
(x_train, _), (x_test, _) = keras.datasets.mnist.load_data()
all_digits = np.concatenate([x_train, x_test])
all_digits = all_digits.astype("float32") / 255.0
all_digits = np.reshape(all_digits, (-1, 28, 28, 1))
dataset = tf.data.Dataset.from_tensor_slices(all_digits)
dataset = dataset.shuffle(buffer_size=1024).batch(BATCH_SIZE)

### Train the GAN model

In [71]:
import os

save_dir = "./"

for epoch in range(EPOCHS):
    print("\nStart epoch", epoch)

    for step, real_images in enumerate(dataset):
        # Train one bacth of real images
        d_loss, g_loss, generated_images = train_step(real_images)

        if step % 200 == 0:
            # Print metrics
            print("Discriminator loss at step %d: %.2f" % (step, d_loss))
            print("Adversarial loss at step %d: %.2f" % (step, g_loss))

            # Save one generated image
            img = keras.utils.array_to_img(generated_images[0] * 255.0, scale=False)
            img.save(os.path.join(save_dir, "generated_img" + str(step) + ".png"))

        # To limit execution time we stop after 10 steps.
        # Remove the lines below to actually train the model!
        if step > 10:
            break


Start epoch 0
Discriminator loss at step 0: 0.13
Adversarial loss at step 0: 1.88

Start epoch 1
Discriminator loss at step 0: 0.10
Adversarial loss at step 0: 2.28


### Reference

https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch