# Loop from scratch

In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np

## Using the GradientTape: a first end-to-end example

In [2]:
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28*28)
x_test = x_test.reshape(-1, 28*28)

In [3]:
x_val, y_val = x_train[:10000], y_train[:10000]
x_train, y_train = x_train[10000:], y_train[10000:]

In [5]:
inputs = keras.Input(shape=(784,), name='digits')
x1 = layers.Dense(64, activation='relu')(inputs)
x2 = layers.Dense(64, activation='relu')(inputs)
outputs =layers.Dense(10, name='predictions')(x2)
model = keras.Model(inputs=inputs, outputs=outputs)

In [6]:
optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_f = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

batch_size = 64

In [7]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)

In [8]:
epochs = 2
for epoch in range(epochs):
    print('Start epoch %d' % epoch)
    
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            
            loss_value = loss_f(y_batch_train, logits)
            
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        
        if step % 200:
            print("Training loss (for one batch) at step %d: %.4f" % (step, float(loss_value)))
            print("Seen so far: %s samples" % ((step + 1) * batch_size))
            

Start epoch 0
Training loss (for one batch) at step 1: 232.9317
Seen so far: 128 samples
Training loss (for one batch) at step 2: 155.1023
Seen so far: 192 samples
Training loss (for one batch) at step 3: 95.5584
Seen so far: 256 samples
Training loss (for one batch) at step 4: 39.2264
Seen so far: 320 samples
Training loss (for one batch) at step 5: 18.6083
Seen so far: 384 samples
Training loss (for one batch) at step 6: 10.5849
Seen so far: 448 samples
Training loss (for one batch) at step 7: 13.4401
Seen so far: 512 samples
Training loss (for one batch) at step 8: 7.9432
Seen so far: 576 samples
Training loss (for one batch) at step 9: 7.5258
Seen so far: 640 samples
Training loss (for one batch) at step 10: 6.6949
Seen so far: 704 samples
Training loss (for one batch) at step 11: 6.2490
Seen so far: 768 samples
Training loss (for one batch) at step 12: 3.3532
Seen so far: 832 samples
Training loss (for one batch) at step 13: 3.1668
Seen so far: 896 samples
Training loss (for one b

## Low-level handling of metrics

Let's add metrics monitoring to this basic loop.

You can readily reuse the built-in metrics (or custom ones you wrote) in such training loops written from scratch. Here's the flow:

Instantiate the metric at the start of the loop
- Call metric.update_state() after each batch
- Call metric.result() when you need to display the current value of the metric
- Call metric.reset_states() when you need to clear the state of the metric (typically at the end of an epoch)

Let's use this knowledge to compute SparseCategoricalAccuracy on validation data at the end of each epoch:

In [9]:
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu", name="dense_1")(inputs)
x = layers.Dense(64, activation="relu", name="dense_2")(x)
outputs = layers.Dense(10, name="predictions")(x)
model = keras.Model(inputs=inputs, outputs=outputs)

optimizer = keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [10]:
import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 129.7590
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.3273
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.9911
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.7046
Seen so far: 38464 samples
Training acc over epoch: 0.7297
Validation acc: 0.8133
Time taken: 5.11s

Start of epoch 1
Training loss (for one batch) at step 0: 0.7668
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.6095
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.2927
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.4104
Seen so far: 38464 samples
Training acc over epoch: 0.8377
Validation acc: 0.8635
Time taken: 4.45s


## Speeding-up your training step with tf.function

In [11]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

In [12]:
@tf.function
def test_step(x, y):
    val_logits = model(x, training=False)
    val_acc_metric.update_state(y, val_logits)

In [13]:
import time

epochs = 2
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 0.2368
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.5075
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.4994
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.4527
Seen so far: 38464 samples
Training acc over epoch: 0.8669
Validation acc: 0.8790
Time taken: 1.25s

Start of epoch 1
Training loss (for one batch) at step 0: 0.2381
Seen so far: 64 samples
Training loss (for one batch) at step 200: 0.4733
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 0.4445
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 0.5088
Seen so far: 38464 samples
Training acc over epoch: 0.8843
Validation acc: 0.8901
Time taken: 0.91s


# Low-level handling of losses tracked by the model

In [14]:
class ActivityRegularizationLayer(layers.Layer):
    def call(self, inputs):
        self.add_loss(1e-2 * tf.reduce_sum(inputs))
        return inputs

In [15]:
inputs = keras.Input(shape=(784,), name="digits")
x = layers.Dense(64, activation="relu")(inputs)
# Insert activity regularization as a layer
x = ActivityRegularizationLayer()(x)
x = layers.Dense(64, activation="relu")(x)
outputs = layers.Dense(10, name="predictions")(x)

model = keras.Model(inputs=inputs, outputs=outputs)

In [16]:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        # Add any extra losses created during the forward pass.
        loss_value += sum(model.losses)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value