With the TensorFlow 2.0 release, we now have GradientTape, a context manager that makes it easier than ever to write custom training loops for both TensorFlow and Keras models, thanks to automatic differentiation.

Whether you’re a deep learning practitioner or a seasoned researcher, you should learn how to use GradientTape. It allows you to create custom training loops for models implemented in Keras’ easy-to-use API, giving you the best of both worlds. You just can’t beat that combination.

Automatic differentiation (also called computational differentiation) refers to a set of techniques that can automatically compute the derivative of a function by repeatedly applying the chain rule.



Automatic differentiation exploits the fact that every computer program, no matter how complicated, executes a sequence of elementary arithmetic operations (addition, subtraction, multiplication, division, etc.) and elementary functions (exp, log, sin, cos, etc.).

By applying the chain rule repeatedly to these operations, derivatives of arbitrary order can be computed automatically, accurately to working precision, and using at most a small constant factor more arithmetic operations than the original program.
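
To make the chain-rule idea concrete, here is a minimal sketch of forward-mode automatic differentiation using dual numbers in plain Python. The Dual class and the dual_sin and derivative helpers are hypothetical names introduced only for illustration. GradientTape itself performs reverse-mode differentiation, so this shows the principle rather than how TensorFlow implements it.

```python
import math


class Dual:
    """A value paired with its derivative (hypothetical helper for illustration)."""

    def __init__(self, value, deriv=0.0):
        self.value = value
        self.deriv = deriv

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # sum rule: (u + v)' = u' + v'
        return Dual(self.value + other.value, self.deriv + other.deriv)

    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        # product rule: (u * v)' = u' * v + u * v'
        return Dual(self.value * other.value,
                    self.deriv * other.value + self.value * other.deriv)

    __rmul__ = __mul__


def dual_sin(x):
    # chain rule for an elementary function: sin(u)' = cos(u) * u'
    return Dual(math.sin(x.value), math.cos(x.value) * x.deriv)


def derivative(f, x):
    # seed the input with derivative 1.0 and read the derivative off the output
    return f(Dual(x, 1.0)).deriv


def f(x):
    # f(x) = x^3 + sin(x), built only from elementary operations
    return x * x * x + dual_sin(x)


result = derivative(f, 3.0)
expected = 3 * 3.0 ** 2 + math.cos(3.0)  # f'(x) = 3x^2 + cos(x)
print(result, expected)
```

Each overloaded operator applies one differentiation rule, so the derivative of the whole expression falls out of evaluating it once.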


When implementing custom training loops with Keras and TensorFlow, you need to define, at a bare minimum, four components:

Component 1: The model architecture

Component 2: The loss function used when computing the model loss

Component 3: The optimizer used to update the model weights

Component 4: The step function that encapsulates the forward and backward pass of the network

We begin with our imports from TensorFlow 2.0 and NumPy.

If you inspect the imports carefully, you won’t see GradientTape; we can access it via tf.GradientTape. We will be using the MNIST dataset (mnist) for our example in this tutorial.

In [2]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import tensorflow as tf
import numpy as np
import time
import sys

Let’s go ahead and build our model using TensorFlow/Keras’ Sequential API:

In [3]:
def build_model(width, height, depth, classes):
	# initialize the input shape and channels dimension to be
	# "channels last" ordering
	inputShape = (height, width, depth)
	chanDim = -1

	# build the model using Keras' Sequential API
	model = Sequential([
		# CONV => RELU => BN => POOL layer set
		Conv2D(16, (3, 3), padding="same", input_shape=inputShape),
		Activation("relu"),
		BatchNormalization(axis=chanDim),
		MaxPooling2D(pool_size=(2, 2)),

		# (CONV => RELU => BN) * 2 => POOL layer set
		Conv2D(32, (3, 3), padding="same"),
		Activation("relu"),
		BatchNormalization(axis=chanDim),
		Conv2D(32, (3, 3), padding="same"),
		Activation("relu"),
		BatchNormalization(axis=chanDim),
		MaxPooling2D(pool_size=(2, 2)),

		# (CONV => RELU => BN) * 3 => POOL layer set
		Conv2D(64, (3, 3), padding="same"),
		Activation("relu"),
		BatchNormalization(axis=chanDim),
		Conv2D(64, (3, 3), padding="same"),
		Activation("relu"),
		BatchNormalization(axis=chanDim),
		Conv2D(64, (3, 3), padding="same"),
		Activation("relu"),
		BatchNormalization(axis=chanDim),
		MaxPooling2D(pool_size=(2, 2)),

		# first (and only) set of FC => RELU layers
		Flatten(),
		Dense(256),
		Activation("relu"),
		BatchNormalization(),
		Dropout(0.5),

		# softmax classifier
		Dense(classes),
		Activation("softmax")
	])

	# return the built model to the calling function
	return model

Here we define our build_model function used to construct the model architecture (Component #1 of creating a custom training loop). The function accepts the shape parameters for our data:

    width and height: The spatial dimensions of each input image
    depth: The number of channels for our images (1 for grayscale as in the case of MNIST or 3 for RGB color images)
    classes: The number of unique class labels in our dataset
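
As a rough check on the architecture, the sketch below traces how the spatial dimensions shrink through the three pooling layers for a 28x28x1 MNIST input. It assumes the Keras defaults used in build_model: "same" padding in each Conv2D (spatial size unchanged) and 2x2 max pooling with a stride of 2 and no padding (size halved, rounded down). The pooled_size helper is a hypothetical name used only for this illustration.

```python
def pooled_size(size, pool=2):
    # "valid" max pooling with stride == pool size: floor((size - pool) / pool) + 1
    return (size - pool) // pool + 1


height, width = 28, 28  # MNIST spatial dimensions
for filters in (16, 32, 64):
    # the "same"-padded convolutions keep height and width unchanged,
    # so only the pooling layer at the end of each set shrinks them
    height, width = pooled_size(height), pooled_size(width)
    print("after POOL ({} filters): {}x{}x{}".format(filters, height, width, filters))

# Flatten turns the final feature map into a single vector
flattened = height * width * 64
print("Flatten output length:", flattened)
```

Under these assumptions, the Dense(256) layer receives a vector of length 576.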


Let’s work on Components 2, 3, and 4:

In [4]:
def step(X, y):
	# keep track of our gradients
	with tf.GradientTape() as tape:
		# make a prediction using the model and then calculate the
		# loss
		pred = model(X)
		loss = categorical_crossentropy(y, pred)

	# calculate the gradients using our tape and then update the
	# model weights
	grads = tape.gradient(loss, model.trainable_variables)
	opt.apply_gradients(zip(grads, model.trainable_variables))

Our step function accepts training images X and their corresponding class labels y (in our example, MNIST images and labels).

Now let’s record the operations needed for our gradients (a fancy word for derivatives) by:

    Gathering predictions on our training data using our model: pred = model(X)
    Computing the loss (Component #2 of creating a custom training loop): loss = categorical_crossentropy(y, pred)

We then calculate our gradients using tape.gradient, passing in our loss and trainable variables: grads = tape.gradient(loss, model.trainable_variables)

We use our optimizer to update the model weights using the gradients (Component #3): opt.apply_gradients(zip(grads, model.trainable_variables))

The step function as a whole rounds out Component #4, encapsulating our forward and backward pass of data using our GradientTape and then updating our model weights.
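
The same pattern can be tried in isolation. The following is a minimal sketch, not the tutorial's model: it assumes a single softmax layer, random toy data, and a plain SGD optimizer (toy_model, toy_opt, and toy_step are hypothetical names). It also passes training=True to the model call and averages the per-sample losses, which are choices made for this sketch rather than part of the step function above.

```python
import numpy as np
import tensorflow as tf

# toy data: 8 samples, 4 features, 3 classes (random, for illustration only)
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4)).astype("float32")
y = tf.one_hot(rng.integers(0, 3, size=8), depth=3)

# Components 1-3: model, loss, optimizer
toy_model = tf.keras.Sequential([tf.keras.layers.Dense(3, activation="softmax")])
loss_fn = tf.keras.losses.categorical_crossentropy
toy_opt = tf.keras.optimizers.SGD(learning_rate=0.1)


# Component 4: the step function
def toy_step(X, y):
    with tf.GradientTape() as tape:
        # forward pass; training=True matters for layers such as
        # BatchNormalization and Dropout (this toy model has neither)
        pred = toy_model(X, training=True)
        loss = tf.reduce_mean(loss_fn(y, pred))

    # backward pass and weight update
    grads = tape.gradient(loss, toy_model.trainable_variables)
    toy_opt.apply_gradients(zip(grads, toy_model.trainable_variables))
    return float(loss)


# repeat the step on the same batch and watch the loss
losses = [toy_step(X, y) for _ in range(20)]
print("first loss: {:.4f}, last loss: {:.4f}".format(losses[0], losses[-1]))
```

Returning the loss from the step is optional, but it makes it easy to log progress from the training loop.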

With both our build_model and step functions defined, we’ll now prepare our data:

In [5]:
# initialize the number of epochs to train for, batch size, and
# initial learning rate
EPOCHS = 25
BS = 64
INIT_LR = 1e-3

# load the MNIST dataset
print("[INFO] loading MNIST dataset...")
((trainX, trainY), (testX, testY)) = mnist.load_data()

# add a channel dimension to every image in the dataset, then scale
# the pixel intensities to the range [0, 1]
trainX = np.expand_dims(trainX, axis=-1)
testX = np.expand_dims(testX, axis=-1)
trainX = trainX.astype("float32") / 255.0
testX = testX.astype("float32") / 255.0

# one-hot encode the labels
trainY = to_categorical(trainY, 10)
testY = to_categorical(testY, 10)

[INFO] loading MNIST dataset...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
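
To see what these preprocessing steps do without downloading MNIST, here is a small NumPy-only sketch on a fake batch of two 28x28 images. The one-hot encoding is written with np.eye as a stand-in for to_categorical, and the array names are hypothetical.

```python
import numpy as np

# two fake 28x28 grayscale "images" with integer pixel values in [0, 255]
images = np.random.randint(0, 256, size=(2, 28, 28), dtype=np.uint8)
labels = np.array([3, 7])

# add a trailing channel dimension: (2, 28, 28) -> (2, 28, 28, 1)
images = np.expand_dims(images, axis=-1)

# scale pixel intensities to the range [0, 1]
images = images.astype("float32") / 255.0

# one-hot encode the labels into 10 classes
one_hot = np.eye(10, dtype="float32")[labels]

print(images.shape, images.min(), images.max())
print(one_hot)
```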


With our data in hand and ready to go, we’ll build our model:

In [6]:
# build our model and initialize our optimizer
print("[INFO] creating model...")
model = build_model(28, 28, 1, 10)
opt = Adam(learning_rate=INIT_LR, decay=INIT_LR / EPOCHS)

[INFO] creating model...


We’re now ready to train our model with our GradientTape:

In [7]:
# compute the number of batch updates per epoch
numUpdates = int(trainX.shape[0] / BS)

# loop over the number of epochs
for epoch in range(0, EPOCHS):
	# show the current epoch number
	print("[INFO] starting epoch {}/{}...".format(
		epoch + 1, EPOCHS), end="")
	sys.stdout.flush()
	epochStart = time.time()

	# loop over the data in batch size increments
	for i in range(0, numUpdates):
		# determine starting and ending slice indexes for the current
		# batch
		start = i * BS
		end = start + BS

		# take a step
		step(trainX[start:end], trainY[start:end])

	# show timing information for the epoch
	epochEnd = time.time()
	elapsed = (epochEnd - epochStart) / 60.0
	print("took {:.4} minutes".format(elapsed))

[INFO] starting epoch 1/25...took 0.9399 minutes
[INFO] starting epoch 2/25...took 0.7805 minutes
[INFO] starting epoch 3/25...took 0.7648 minutes
[INFO] starting epoch 4/25...took 0.8133 minutes
[INFO] starting epoch 5/25...took 0.7757 minutes
[INFO] starting epoch 6/25...took 0.7758 minutes
[INFO] starting epoch 7/25...took 0.775 minutes
[INFO] starting epoch 8/25...took 0.7712 minutes
[INFO] starting epoch 9/25...took 0.782 minutes
[INFO] starting epoch 10/25...took 0.7765 minutes
[INFO] starting epoch 11/25...took 0.7596 minutes
[INFO] starting epoch 12/25...took 0.7576 minutes
[INFO] starting epoch 13/25...took 0.7569 minutes
[INFO] starting epoch 14/25...took 0.7557 minutes
[INFO] starting epoch 15/25...took 0.7568 minutes
[INFO] starting epoch 16/25...took 0.78 minutes
[INFO] starting epoch 17/25...took 0.7532 minutes
[INFO] starting epoch 18/25...took 0.7729 minutes
[INFO] starting epoch 19/25...took 0.7836 minutes
[INFO] starting epoch 20/25...took 0.7806 minutes
[INFO] starti
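
One detail of the training loop is worth spelling out: int(trainX.shape[0] / BS) rounds down, so any samples beyond the last full batch are skipped in every epoch. The sketch below works through the arithmetic for the 60,000 MNIST training images and a batch size of 64. It also shows one possible alternative (an assumption of this sketch, not part of the tutorial's code) that rounds up so the final partial batch is included.

```python
import math

num_samples = 60000  # size of the MNIST training split
BS = 64

# as in the training loop: only full batches are used
num_updates = int(num_samples / BS)
skipped = num_samples - num_updates * BS
print("full batches:", num_updates, "skipped samples:", skipped)

# alternative: round up and let the last slice be shorter
num_updates_all = math.ceil(num_samples / BS)
last_start = (num_updates_all - 1) * BS
last_batch_size = min(last_start + BS, num_samples) - last_start
print("batches incl. partial:", num_updates_all, "last batch size:", last_batch_size)
```

Because NumPy slicing clips at the end of the array, trainX[start:end] would return the shorter final batch without any extra bounds handling.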

Finally, we’ll calculate the loss and accuracy on the testing set:

In [8]:
# in order to calculate accuracy using Keras' functions we first need
# to compile the model
model.compile(optimizer=opt, loss=categorical_crossentropy,
	metrics=["acc"])

# now that the model is compiled we can compute the accuracy
(loss, acc) = model.evaluate(testX, testY)
print("[INFO] test accuracy: {:.4f}".format(acc))

[INFO] test accuracy: 0.9930
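
If compiling the model only to read off an accuracy feels heavy, the metric can also be computed by hand from the predictions. The following NumPy sketch uses made-up probabilities and labels in place of the model's predictions on testX and the one-hot testY. It illustrates the calculation and is not a replacement for model.evaluate, which also reports the loss.

```python
import numpy as np

# made-up softmax outputs for 4 samples and 3 classes
probs = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.3, 0.3, 0.4],
    [0.6, 0.3, 0.1],
])
# made-up one-hot ground-truth labels
truth = np.array([
    [1, 0, 0],
    [0, 1, 0],
    [0, 0, 1],
    [0, 1, 0],
])

# a prediction is correct when the most probable class matches the label
predicted = probs.argmax(axis=1)
actual = truth.argmax(axis=1)
accuracy = np.mean(predicted == actual)
print("accuracy: {:.4f}".format(accuracy))
```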


tf.GradientTape allows us to track TensorFlow computations and calculate gradients w.r.t. (with respect to) some given variables:

In [9]:
x = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(x)
    y = x**3

print(tape.gradient(y, x).numpy())

27.0


By default, GradientTape doesn’t track constants, so we must tell it to with tape.watch(variable).

Then we can perform some computation on the variables we are watching. The computation can be anything from cubing a value, x**3, to passing it through a neural network.

We calculate the gradients of a computation w.r.t. a variable with tape.gradient(target, sources). Note that tape.gradient returns an EagerTensor that you can convert to ndarray format with .numpy().

If at any point, we want to use multiple variables in our calculations, all we need to do is give tape.gradient a list or tuple of those variables. When we optimize Keras models, we pass model.trainable_variables as our variable list.
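
For instance, a minimal sketch with two variables might look like this (the names u, v, and z are placeholders chosen for the example):

```python
import tensorflow as tf

u = tf.Variable(2.0)
v = tf.Variable(3.0)

with tf.GradientTape() as tape:
    z = u * u * v  # z = u^2 * v

# passing a list of sources returns a list of gradients in the same order
dz_du, dz_dv = tape.gradient(z, [u, v])
print(dz_du.numpy(), dz_dv.numpy())  # dz/du = 2uv, dz/dv = u^2
```

The returned structure mirrors the structure of the sources, which is why the result can be zipped with model.trainable_variables when updating a Keras model.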

Automatically Watching Variables

If x were a trainable variable instead of a constant, there would be no need to tell the tape to watch it—GradientTape automatically watches all trainable variables.

In [10]:
x = tf.Variable(3.0, trainable=True)
with tf.GradientTape() as tape:
    y = x**3

print(tape.gradient(y, x).numpy())

27.0


watch_accessed_variables=False

If we don’t want GradientTape to watch all trainable variables automatically, we can set the tape’s watch_accessed_variables parameter to False.

Disabling watch_accessed_variables gives us fine control over which variables we want to watch.

If you have a lot of trainable variables and are not optimizing them all at once, you may want to disable watch_accessed_variables to protect yourself from mistakes:

In [11]:
x = tf.Variable(3.0, trainable=True)
with tf.GradientTape(watch_accessed_variables=False) as tape:
    y = x**3

print(tape.gradient(y, x))

None
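
With automatic watching disabled, variables can still be tracked by calling tape.watch on them explicitly. The following sketch (variable names are placeholders) watches only x, so the gradient with respect to w comes back as None:

```python
import tensorflow as tf

x = tf.Variable(3.0, trainable=True)
w = tf.Variable(2.0, trainable=True)

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(x)  # only x is tracked; w is ignored by the tape
    y = w * x ** 2

dy_dx, dy_dw = tape.gradient(y, [x, w])
print(dy_dx.numpy())  # dy/dx = 2 * w * x
print(dy_dw)          # w was never watched
```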


Higher-Order Derivatives

If you want to compute higher-order derivatives, you can use nested GradientTapes.

Computing higher-order derivatives is generally the only time you would want to compute gradients inside a GradientTape context. Otherwise, it will slow down computations, as the GradientTape is watching every computation done in the gradient:

In [12]:
x = tf.Variable(3.0, trainable=True)
with tf.GradientTape() as tape1:
    with tf.GradientTape() as tape2:
        y = x ** 3
    order_1 = tape2.gradient(y, x)
order_2 = tape1.gradient(order_1, x)

print(order_2.numpy())

18.0
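
The result can be checked by hand: for y = x^3 the first derivative is 3x^2 and the second is 6x, which gives 18 at x = 3. As a rough sanity check that does not depend on TensorFlow, the sketch below approximates the second derivative with a central finite difference (the step size h is an arbitrary choice for this example):

```python
def f(x):
    return x ** 3


x0 = 3.0
h = 1e-3  # arbitrary small step; accuracy depends on this choice

# central difference approximation of the second derivative
second = (f(x0 + h) - 2.0 * f(x0) + f(x0 - h)) / h ** 2
analytic = 6.0 * x0  # d^2/dx^2 of x^3 is 6x

print(second, analytic)
```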


persistent=True

Suppose we were to run the following, expecting it to print both gradients:

a = tf.Variable(6.0, trainable=True)
b = tf.Variable(2.0, trainable=True)
with tf.GradientTape() as tape:
    y1 = a ** 2
    y2 = b ** 3

print(tape.gradient(y1, a).numpy())
print(tape.gradient(y2, b).numpy())

But in reality, calling tape.gradient a second time will raise an error.

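This is because, by default, the GradientTape releases the resources it holds as soon as tape.gradient is called.

If we want to bypass this, we can set persistent=True. A persistent tape keeps its resources until it is garbage collected, so it is good practice to drop it (for example with del tape) once you are done: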

In [13]:
a = tf.Variable(6.0, trainable=True)
b = tf.Variable(2.0, trainable=True)
with tf.GradientTape(persistent=True) as tape:
    y1 = a ** 2
    y2 = b ** 3
print(tape.gradient(y1, a).numpy())
print(tape.gradient(y2, b).numpy())

12.0
12.0


stop_recording()

tape.stop_recording() temporarily pauses the tape’s recording. Operations executed inside a stop_recording block are not recorded, which reduces overhead.

In my opinion, in long functions, it is more readable to use stop_recording blocks multiple times to calculate gradients in the middle of a function than to calculate all the gradients at the end of a function.

For example, I prefer:

In [14]:
a = tf.Variable(6.0, trainable=True)
b = tf.Variable(2.0, trainable=True)
with tf.GradientTape(persistent=True) as tape:
    y1 = a ** 2
    with tape.stop_recording():
        print(tape.gradient(y1, a).numpy())
    
    y2 = b ** 3
    with tape.stop_recording():
        print(tape.gradient(y2, b).numpy())

12.0
12.0


Other Methods

Although I won’t go into detail here, GradientTape has a few other handy methods, including:

    .jacobian: “Computes the jacobian using operations recorded in context of this tape.”
    .batch_jacobian: “Computes and stacks per-example jacobians.”
    .reset: “Clears all information stored in this tape.”
    .watched_variables: “Returns variables watched by this tape in order of construction.”

All of the descriptions above are quoted from the GradientTape documentation.
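
As a brief illustration of the first of these methods, the sketch below follows the pattern shown in the GradientTape documentation. For an element-wise y = x * x, tape.gradient differentiates the sum of the elements of y, while tape.jacobian keeps one row per output element. A persistent tape is used here only so that both calls can be made on the same tape.

```python
import tensorflow as tf

x = tf.constant([1.0, 2.0])

with tf.GradientTape(persistent=True) as tape:
    tape.watch(x)  # x is a constant, so it must be watched explicitly
    y = x * x

grad = tape.gradient(y, x)  # shape (2,): gradient of sum(y)
jac = tape.jacobian(y, x)   # shape (2, 2): d y[i] / d x[j]
del tape  # drop the reference to the persistent tape

print(grad.numpy())
print(jac.numpy())
```

The Jacobian is diagonal here because each output element depends only on the matching input element.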

Advanced Uses
2.0 — Linear Regression

To start off the more advanced uses of GradientTape, let’s look at a classic “Hello World!” to ML: linear regression.

First, we start by defining a few essential variables and functions.

In [15]:
import numpy as np
import random

# Loss function
def loss(real_y, pred_y):
    return tf.abs(real_y - pred_y)

# Training data
x_train = np.asarray([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
y_train = np.asarray([i*10+5 for i in x_train]) # y = 10x+5

# Trainable variables
a = tf.Variable(random.random(), trainable=True)
b = tf.Variable(random.random(), trainable=True)

Then, we can go ahead and define our step function. The step function will be run every epoch to update the trainable variables, a and b.

In [18]:
def step(real_x, real_y):
    with tf.GradientTape() as tape:
        # Make prediction
        pred_y = a * real_x + b
        # Calculate loss
        reg_loss = loss(real_y, pred_y)
    
    # Calculate gradients
    a_gradients, b_gradients = tape.gradient(reg_loss, (a, b))

    # Update variables
    a.assign_sub(a_gradients * 0.001)
    b.assign_sub(b_gradients * 0.001)

In [19]:
for _ in range(100000):
    step(x_train, y_train)

print(f'y ≈ {a.numpy()}x + {b.numpy()}')

y ≈ 9.993399620056152x + 4.990865230560303
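
It helps to see what the tape computes here. Because loss returns one absolute error per training point, tape.gradient differentiates the sum of those errors. The NumPy sketch below derives the same gradients by hand for a single step, using fixed starting values for a and b (chosen arbitrarily for this example instead of random.random()):

```python
import numpy as np

x_train = np.arange(11, dtype="float64")   # 0, 1, ..., 10
y_train = 10.0 * x_train + 5.0             # y = 10x + 5

a, b = 0.5, 0.5     # arbitrary starting values for this sketch
lr = 0.001          # same step size as in the step function

pred_y = a * x_train + b
residual_sign = np.sign(y_train - pred_y)

# d/da sum(|y - (a*x + b)|) = sum(-sign(y - pred) * x)
a_grad = np.sum(-residual_sign * x_train)
# d/db sum(|y - (a*x + b)|) = sum(-sign(y - pred))
b_grad = np.sum(-residual_sign)

# one gradient descent update, mirroring assign_sub
a_new = a - lr * a_grad
b_new = b - lr * b_grad
print(a_grad, b_grad)
print(a_new, b_new)
```

With an absolute-error loss the gradients depend only on the signs of the residuals, so the update size does not shrink as the fit improves. That is likely one reason the learned values hover near, rather than exactly at, 10 and 5.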
