In [9]:
from tensorflow import keras
from tensorflow.keras import layers, Model
import tensorflow as tf

A supervised-learning training step ends up looking like this:

In [2]:
def train_step(inputs, targets):
    """Run one supervised training step on a single batch.

    Generic pattern: forward pass recorded on a GradientTape, loss
    computation, backward pass, then a weight update.

    Reads the globals ``model``, ``loss_fn`` and ``optimizer``, which must be
    defined before this function is called.
    """
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    # The tape method is `gradient` (singular).
    gradients = tape.gradient(loss, model.trainable_weights)
    # apply_gradients expects (gradient, variable) pairs, in that order.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))

In [3]:
from tensorflow.keras.datasets import mnist
def get_mnist_model():
    """Build the baseline MNIST classifier as a functional keras.Model.

    Layers: 784-feature input -> Dense(512, relu) -> Dropout(0.5)
    -> Dense(10, softmax).
    """
    pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probs = layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(pixels, class_probs)

In [4]:
# Load MNIST, flatten each 28x28 image into a 784-feature vector and rescale
# pixel values by 1/255.
(images, labels), (test_images, test_labels) = mnist.load_data()
# -1 lets NumPy infer the sample count instead of hard-coding 60000 / 10000.
images = images.reshape((-1, 28 * 28)).astype("float32") / 255
test_images = test_images.reshape((-1, 28 * 28)).astype("float32") / 255
# Hold out the first 10,000 training samples as a validation split.
train_images, val_images = images[10000:], images[:10000]
train_labels, val_labels = labels[10000:], labels[:10000]

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [5]:
# Global training state read by the hand-written train/test steps:
# the model, its loss, its optimizer, the metrics to report, and a
# running mean of the per-batch loss values.
model = get_mnist_model()
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.RMSprop()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss_tracking_metric = keras.metrics.Mean()  # averages the per-batch losses

In [6]:
def train_step(inputs, targets):
    """Run one training step on a batch and report the running metric values.

    Reads the globals ``model``, ``loss_fn``, ``optimizer``, ``metrics`` and
    ``loss_tracking_metric``.

    Returns:
        dict mapping each metric's name, plus ``"loss"``, to its current result.
    """
    # Forward pass under the tape; training=True runs layers such as Dropout
    # in training mode.
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)

    # Backward pass and update, restricted to the trainable weights.
    trainable = model.trainable_weights
    grads = tape.gradient(loss, trainable)
    optimizer.apply_gradients(zip(grads, trainable))

    # Fold this batch into every metric and collect the running values.
    logs = {}
    for m in metrics:
        m.update_state(targets, predictions)
        logs[m.name] = m.result()

    loss_tracking_metric.update_state(loss)
    logs["loss"] = loss_tracking_metric.result()
    return logs

We will need to reset the state of our metrics at the start of each epoch and before running evaluation. Here’s a utility function to do it.

In [7]:
def reset_metrics():
    """Clear the accumulated state of every metric, then of the loss mean.

    Call at the start of each epoch and before an evaluation pass.
    """
    for tracker in [*metrics, loss_tracking_metric]:
        tracker.reset_state()

We can now lay out our complete training loop. Note that we use a tf.data.Dataset
object to turn our NumPy data into an iterator that iterates over the data in batches of
size 32.

In [10]:
# Wrap the NumPy training arrays in a tf.data pipeline of 32-sample batches.
training_dataset = tf.data.Dataset.from_tensor_slices(
    (train_images, train_labels)
).batch(32)

epochs = 3
for epoch in range(epochs):
    # Each epoch starts from empty metric accumulators.
    reset_metrics()
    for inputs_batch, targets_batch in training_dataset:
        logs = train_step(inputs_batch, targets_batch)
    # After the last batch, `logs` holds the epoch's running averages.
    print(f"Results at the end of epoch {epoch}")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")


Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9132
...loss: 0.2899
Results at the end of epoch 1
...sparse_categorical_accuracy: 0.9532
...loss: 0.1602
Results at the end of epoch 2
...sparse_categorical_accuracy: 0.9633
...loss: 0.1274


And here’s the evaluation loop: a simple for loop that repeatedly calls a test_step()
function, which processes a single batch of data. The test_step() function is just a subset of the logic of train_step(). It omits the code that deals with updating the weights
of the model—that is to say, everything involving the GradientTape and the optimizer.

In [13]:
@tf.function
def test_step(inputs, targets):
    """Evaluate one batch without touching the weights.

    Same as train_step minus the GradientTape and optimizer update; compiled
    to a graph by @tf.function.

    Returns:
        dict mapping "val_"-prefixed metric names, plus "val_loss", to their
        current results.
    """
    # training=False runs layers such as Dropout in inference mode.
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)

    logs = {}
    for m in metrics:
        m.update_state(targets, predictions)
        logs["val_" + m.name] = m.result()

    loss_tracking_metric.update_state(loss)
    logs["val_loss"] = loss_tracking_metric.result()
    return logs

In [14]:
# Evaluate on the held-out split in batches of 32, starting from clean metrics.
val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels)).batch(32)
reset_metrics()
for inputs_batch, targets_batch in val_dataset:
    logs = test_step(inputs_batch, targets_batch)

print("Evaluation results:")
for key in logs:
    print(f"...{key}: {logs[key]:.4f}")

Evaluation results:
...val_sparse_categorical_accuracy: 0.9688
...val_loss: 0.1194


What if you need a custom training algorithm, but you still want to leverage the
power of the built-in Keras training logic? There’s actually a middle ground between
fit() and a training loop written from scratch: you can provide a custom training
step function and let the framework do the rest.

You can do this by overriding the train_step() method of the Model class. This is
the function that is called by fit() for every batch of data. You will then be able to call
fit() as usual, and it will be running your own learning algorithm under the hood.

Here’s a simple example:

- We create a new class that subclasses keras.Model.

- We override the method train_step(self, data). Its contents are nearly identical to what we used in the previous section. It returns a dictionary mapping metric names (including the loss) to their current values.

- We implement a metrics property that tracks the model’s Metric instances. This enables the model to automatically call reset_state() on the model’s metrics at the start of each epoch and at the start of a call to evaluate(), so you don’t have to do it by hand.

In [15]:
# Loss used directly inside the custom train_step (it is not given to compile()).
loss_fn = keras.losses.SparseCategoricalCrossentropy()
# Mean metric that tracks the average of the per-batch losses.
loss_tracker = keras.metrics.Mean(name="loss")


In [28]:
class CustomModel(keras.Model):
    """keras.Model whose fit() runs a hand-written training step.

    The step uses the module-level ``loss_fn`` and ``loss_tracker`` objects,
    so compile() only needs an optimizer.
    """

    def train_step(self, data):
        """Run one optimization step on a batch and return the running loss.

        Args:
            data: the (inputs, targets) pair that fit() yields for each batch.

        Returns:
            dict with a single "loss" entry holding ``loss_tracker``'s result.
        """
        inputs, targets = data
        with tf.GradientTape() as tape:
            # Call self rather than a global `model`: the model is this instance.
            predictions = self(inputs, training=True)
            loss = loss_fn(targets, predictions)
        # Differentiate w.r.t. and update this instance's own weights. A global
        # `model` may be unbound or refer to a different model.
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        loss_tracker.update_state(loss)
        return {"loss": loss_tracker.result()}

    @property
    def metrics(self):
        """Metrics that Keras resets across epochs (and at the start of evaluate())."""
        return [loss_tracker]

We can now instantiate our custom model, compile it (we only pass the optimizer, since
the loss is already defined outside of the model), and train it using fit() as usual:

In [20]:
# NOTE: this plain functional model is not the one trained in the following
# cells; `model` is rebound to a CustomModel instance before compile()/fit().
model = get_mnist_model()

In [29]:
# Same layer stack as get_mnist_model(), but wrapped in CustomModel so that
# fit() runs the overridden train_step.
inputs = keras.Input(shape=(28 * 28,))
hidden = layers.Dense(512, activation="relu")(inputs)
hidden = layers.Dropout(0.5)(hidden)
outputs = layers.Dense(10, activation="softmax")(hidden)
model = CustomModel(inputs, outputs)

In [30]:
# Only an optimizer is compiled in: the loss comes from the module-level
# loss_fn that CustomModel.train_step calls directly.
model.compile(optimizer=keras.optimizers.RMSprop())
model.fit(x=train_images, y=train_labels, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x79dab46929e0>

After you’ve called compile(), you get access to the following:

- self.compiled_loss—The loss function you passed to compile().

- self.compiled_metrics—A wrapper for the list of metrics you passed, which allows you to call self.compiled_metrics.update_state() to update all of your metrics at once.

- self.metrics—The actual list of metrics you passed to compile(). Note that it also includes a metric that tracks the loss, similar to what we did manually with our loss_tracking_metric earlier.

We can thus write:

In [26]:
class CustomModel(keras.Model):
    """keras.Model whose train_step uses the loss and metrics passed to compile()."""

    def train_step(self, data):
        """Run one optimization step on a batch using the compiled loss/metrics.

        Args:
            data: the (inputs, targets) pair that fit() yields for each batch.

        Returns:
            dict mapping each name in ``self.metrics`` to its current result.
        """
        inputs, targets = data
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            # Apply the loss given to compile(). Passing self.losses also
            # folds in any layer regularization penalties.
            loss = self.compiled_loss(
                targets, predictions, regularization_losses=self.losses
            )
        # Differentiate w.r.t. and update this instance's own weights, not those
        # of whatever a global `model` name happens to point at. self.optimizer
        # is the optimizer passed to compile().
        gradients = tape.gradient(loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_weights))
        # Update every compiled metric in one call.
        self.compiled_metrics.update_state(targets, predictions)
        return {m.name: m.result() for m in self.metrics}

In [27]:
# Rebuild the baseline layer stack as the compile()-driven CustomModel.
inputs = keras.Input(shape=(28 * 28,))
hidden = layers.Dense(512, activation="relu")(inputs)
hidden = layers.Dropout(0.5)(hidden)
outputs = layers.Dense(10, activation="softmax")(hidden)
model = CustomModel(inputs, outputs)

# Loss and metrics are now supplied here and consumed inside train_step.
model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(x=train_images, y=train_labels, epochs=3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.src.callbacks.History at 0x79dab473d3c0>