In [27]:
# custom training loop
# no custom error
# with and without metric

# https://www.tensorflow.org/guide/keras/writing_a_training_loop_from_scratch/

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
from tensorflow.keras.utils import to_categorical
import time


batch_size = 64
num_classes = 10
epochs = 10

# Prepare the training dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

# convert class vectors to binary class matrices
# Needed for categorical cross entropy
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
print("after to_categorical", np.shape(y_train))

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=len(x_train)).batch(batch_size)
print(train_dataset)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size)
print(train_dataset)


# simple mnist model
# 784 = 28x28
inputs = keras.Input(shape=(28,28), name="digits")
x0 = layers.Flatten()(inputs)
x1 = layers.Dense(64, activation="relu")(x0)
x2 = layers.Dense(64, activation="relu")(x1)
outputs = layers.Dense(10, activation='softmax', name="predictions")(x2)
model = keras.Model(inputs=inputs, outputs=outputs)


after to_categorical (60000, 10)
<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>
<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>


In [28]:
# Instantiate an optimizer.
optimizer = keras.optimizers.Adam(learning_rate=1e-4)
# Instantiate a loss function.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)


# training loop
"""
We open a for loop that iterates over epochs
For each epoch, we open a for loop that iterates over the dataset, in batches
For each batch, we open a GradientTape() scope
Inside this scope, we call the model (forward pass) and compute the loss
Outside the scope, we retrieve the gradients of the weights of the model with regard to the loss
Finally, we use the optimizer to update the weights of the model based on the gradients
"""
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))


Start of epoch 0
Training loss (for one batch) at step 0: 2.3050
Seen so far: 64 samples
Training loss (for one batch) at step 200: 2.0354
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.8303
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.6855
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.6320
Seen so far: 51264 samples

Start of epoch 1
Training loss (for one batch) at step 0: 1.6408
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5265
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.6103
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.5461
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.6101
Seen so far: 51264 samples

Start of epoch 2
Training loss (for one batch) at step 0: 1.5641
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5694
Seen so far: 12864 samples
Training loss (for one batch) at step

KeyboardInterrupt: 

In [None]:
# training loop with metric

"""
Instantiate the metric at the start of the loop
Call metric.update_state() after each batch
Call metric.result() when you need to display the current value of the metric
Call metric.reset_states() when you need to clear the state of the metric (typically at the end of an epoch)
"""

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

import time

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

In [None]:
""" with metric and seperate train_step & @tf.function """

# Prepare the metrics.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

import time

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print("Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))