## Writing checkpoints manually

The persistent state of a TensorFlow model is stored in `tf.Variable` objects. These can be constructed directly, but are often created through high-level APIs like `tf.keras.layers`.

The easiest way to manage variables is by attaching them to Python objects, then referencing those objects. Subclasses of `tf.train.Checkpoint`, `tf.keras.layers.Layer`, and `tf.keras.Model` automatically track variables assigned to their attributes. The following example constructs a simple linear model, then writes checkpoints which contain values for all of the model's variables.

In [1]:
import tensorflow as tf

In [2]:
class Net(tf.keras.Model):
    """A simple linear model."""
    
    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Dense(5)
        
    def call(self, x):
        return self.l1(x)

Although it's not the focus of this guide, the example needs data and an optimization step to be executable. The model will train on slices of an in-memory dataset.

In [3]:
def toy_dataset():
    inputs = tf.range(10.)[:, None]
    labels = inputs * 5. + tf.range(5.)[None, :]
    return tf.data.Dataset.from_tensor_slices(
        dict(x=inputs, y=labels)).repeat(10).batch(2)

In [4]:
def train_step(net, example, optimizer):
    """Trains `net` on `example` using `optimizer`."""
    with tf.GradientTape() as tape:
        output = net(example['x'])
        loss = tf.reduce_mean(tf.abs(output - example['y']))
    variables = net.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

The following training loop creates an instance of the model and of an optimizer, then gathers them into a `tf.train.Checkpoint` object. It calls the training step in a loop on each batch of data, and periodically writes checkpoints to disk.

In [5]:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")
    
for example in toy_dataset():
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))

Initializing from scratch.
Saved checkpoint for step 10: ./tf_ckpts\ckpt-1
loss 26.75
Saved checkpoint for step 20: ./tf_ckpts\ckpt-2
loss 20.17
Saved checkpoint for step 30: ./tf_ckpts\ckpt-3
loss 13.61
Saved checkpoint for step 40: ./tf_ckpts\ckpt-4
loss 7.15
Saved checkpoint for step 50: ./tf_ckpts\ckpt-5
loss 2.11


The preceding snippet will randomly initialize the model variables when it first runs. After the first run it will resume training from where it left off:

In [6]:
opt = tf.keras.optimizers.Adam(0.1)
net = Net()
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net)
manager = tf.train.CheckpointManager(ckpt, './tf_ckpts', max_to_keep=3)
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
else:
    print("Initializing from scratch.")

for example in toy_dataset():
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
        save_path = manager.save()
        print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
        print("loss {:1.2f}".format(loss.numpy()))

Restored from ./tf_ckpts\ckpt-5
Saved checkpoint for step 60: ./tf_ckpts\ckpt-6
loss 1.16
Saved checkpoint for step 70: ./tf_ckpts\ckpt-7
loss 1.23
Saved checkpoint for step 80: ./tf_ckpts\ckpt-8
loss 0.69
Saved checkpoint for step 90: ./tf_ckpts\ckpt-9
loss 0.59
Saved checkpoint for step 100: ./tf_ckpts\ckpt-10
loss 0.32
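The resume-on-restart behaviour can also be seen in isolation with a sketch that tracks only a step counter. This is an illustration rather than the training loop above: `run_steps` and the temporary directory are hypothetical, and `CheckpointManager.restore_or_initialize()` stands in for the explicit restore-then-check logic used in the cells above.

```python
import tempfile

import tensorflow as tf


def run_steps(ckpt_dir, num_steps):
    """Hypothetical helper: counts `num_steps` steps, resuming from `ckpt_dir`."""
    step = tf.Variable(0, dtype=tf.int64)
    ckpt = tf.train.Checkpoint(step=step)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)
    # Returns the restored checkpoint path, or None if nothing was restored.
    restored_from = manager.restore_or_initialize()
    for _ in range(num_steps):
        step.assign_add(1)
    manager.save()
    return restored_from, int(step.numpy())


demo_dir = tempfile.mkdtemp()  # assumed scratch location, not './tf_ckpts'
first_restored, first_step = run_steps(demo_dir, 5)
second_restored, second_step = run_steps(demo_dir, 5)
print(first_restored, first_step)    # nothing to restore on the first run
print(second_restored, second_step)  # second run continues from the saved step
```

The second call builds a fresh `tf.Variable(0)`, yet ends at step 10 because the counter is restored before the loop starts.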


The `tf.train.CheckpointManager` object deletes old checkpoints. In the example above, it is configured to keep only the three most recent checkpoints.

In [7]:
print(manager.checkpoints) # list the three remaining checkpoints

['./tf_ckpts\\ckpt-8', './tf_ckpts\\ckpt-9', './tf_ckpts\\ckpt-10']


These paths, e.g. `'./tf_ckpts/ckpt-10'`, are not files on disk. Instead they are prefixes for an `index` file and one or more data files which contain the variable values. These prefixes are grouped together in a single `checkpoint` file (`'./tf_ckpts/checkpoint'`) where the `CheckpointManager` saves its state.

In [8]:
import os
os.listdir("./tf_ckpts")

['checkpoint',
 'ckpt-10.data-00000-of-00002',
 'ckpt-10.data-00001-of-00002',
 'ckpt-10.index',
 'ckpt-8.data-00000-of-00002',
 'ckpt-8.data-00001-of-00002',
 'ckpt-8.index',
 'ckpt-9.data-00000-of-00002',
 'ckpt-9.data-00001-of-00002',
 'ckpt-9.index']
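A small sketch can confirm that a checkpoint path is a prefix rather than a file. It writes a single-variable checkpoint to a temporary directory (an assumed scratch location, with hypothetical names such as `demo_manager`) and then checks what actually exists on disk.

```python
import os
import tempfile

import tensorflow as tf

demo_dir = tempfile.mkdtemp()  # assumed scratch location
demo_ckpt = tf.train.Checkpoint(v=tf.Variable(3.0))
demo_manager = tf.train.CheckpointManager(demo_ckpt, demo_dir, max_to_keep=2)
prefix = demo_manager.save()  # a path ending in 'ckpt-1'

# The prefix itself is not a file; the index and data files share it.
prefix_is_file = os.path.exists(prefix)
index_exists = os.path.exists(prefix + ".index")
files = sorted(os.listdir(demo_dir))
print(prefix_is_file, index_exists)
print(files)
```

The number of data shards may differ from the listing above, since it depends on how the variables are sharded when saving.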

<a id="loading_mechanics"/>
## Loading mechanics

TensorFlow matches variables to checkpointed values by traversing a directed graph with named edges, starting from the object being loaded. Edge names typically come from attribute names in objects, for example the `"l1"` in `self.l1 = tf.keras.layers.Dense(5)`. `tf.train.Checkpoint` uses its keyword argument names, as in the `"step"` in `tf.train.Checkpoint(step=...)`.
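To see how edge names turn into checkpoint keys, here is a minimal sketch. It uses `tf.Module` subclasses (`Inner` and `Outer` are hypothetical names) instead of Keras layers, on the assumption that `tf.Module` tracks attribute assignments in the same way, and prints the keys that `tf.train.list_variables` reports.

```python
import os
import tempfile

import tensorflow as tf


class Inner(tf.Module):
    def __init__(self):
        super().__init__()
        self.weight = tf.Variable([1.0, 2.0])  # edge named "weight"


class Outer(tf.Module):
    def __init__(self):
        super().__init__()
        self.child = Inner()  # edge named "child"


# Keyword argument names ("model", "step") become the edges from the root.
root = tf.train.Checkpoint(model=Outer(), step=tf.Variable(0))
path = root.save(os.path.join(tempfile.mkdtemp(), "edges"))

keys = [name for name, _ in tf.train.list_variables(path)]
for key in keys:
    print(key)
```

Each key is the sequence of edge names from the root to a variable, so the nested variable appears under `model/child/weight/...`.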

The dependency graph for the training loop example looks like this:

![Visualization of the dependency graph for the example training loop](http://tensorflow.org/images/guide/whole_checkpoint.svg)

In this graph, the optimizer is shown in red, regular variables in blue, and optimizer slot variables in orange. The other nodes, for example the one representing the `tf.train.Checkpoint`, are black.

Slot variables are part of the optimizer's state, but are created for a specific variable. For example, the `'m'` edges above correspond to momentum, which the Adam optimizer tracks for each variable. Slot variables are only saved in a checkpoint if the variable and the optimizer would both be saved, which is why these edges are dashed.

Calling `restore()` on a `tf.train.Checkpoint` object queues the requested restorations, restoring variable values as soon as there's a matching path from the `Checkpoint` object. For example, we can load just the bias from the model we defined above by reconstructing one path to it through the network and the layer.

In [9]:
to_restore = tf.Variable(tf.zeros([5]))
print(to_restore.numpy()) # all zeros
fake_layer = tf.train.Checkpoint(bias=to_restore)
fake_net = tf.train.Checkpoint(l1=fake_layer)
new_root = tf.train.Checkpoint(net=fake_net)
status = new_root.restore(tf.train.latest_checkpoint('./tf_ckpts/'))
print(to_restore.numpy())

[0. 0. 0. 0. 0.]
[3.0906975 2.115607  2.7918575 2.8857708 4.059075 ]


The dependency graph for these new objects is a much smaller subgraph of the larger checkpoint we wrote above. It includes only the bias and a save counter that `tf.train.Checkpoint` uses to number checkpoints.

![Visualization of a subgraph for the bias variable](http://tensorflow.org/images/guide/partial_checkpoint.svg)

`restore()` returns a status object, which has optional assertions. All of the objects we've created in our new `Checkpoint` have been restored, so `status.assert_existing_objects_matched()` passes.

In [10]:
status.assert_existing_objects_matched()

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x1beab4439b0>

There are many objects in the checkpoint which haven't matched, including the layer's kernel and the optimizer's variables. `status.assert_consumed()` only passes if the checkpoint and the program match exactly, and would raise an exception here.
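The difference between the two assertions can be reproduced with a self-contained sketch. It saves two variables, restores only one of them, and assumes that the failure surfaces as an `AssertionError`; the names `full`, `partial` and `partial_status` are illustrative only.

```python
import os
import tempfile

import tensorflow as tf

# Save a checkpoint containing two variables, "a" and "b".
full = tf.train.Checkpoint(a=tf.Variable(1.0), b=tf.Variable(2.0))
path = full.save(os.path.join(tempfile.mkdtemp(), "status_demo"))

# Restore into an object graph that only has "a".
a_only = tf.Variable(0.0)
partial = tf.train.Checkpoint(a=a_only)
partial_status = partial.restore(path)

# Passes: every object created in this program was found in the checkpoint.
partial_status.assert_existing_objects_matched()

# Fails: "b" is in the checkpoint but has no counterpart in this program.
try:
    partial_status.assert_consumed()
    fully_consumed = True
except AssertionError:
    fully_consumed = False

# Declares the partial load intentional, silencing incomplete-restore warnings.
partial_status.expect_partial()
print(a_only.numpy(), fully_consumed)
```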

### Delayed restorations

`Layer` objects in TensorFlow may delay the creation of variables until their first call, when input shapes are available. For example, the shape of a `Dense` layer's kernel depends on both the layer's input and output shapes, so the output shape passed as a constructor argument is not enough on its own to create the variable. Since calling a `Layer` also reads the variable's value, a restore must happen between the variable's creation and its first use.

To support this idiom, `tf.train.Checkpoint` queues restores which don't yet have a matching variable.

In [11]:
delayed_restore = tf.Variable(tf.zeros([1, 5]))
print(delayed_restore.numpy())
fake_layer.kernel = delayed_restore
print(delayed_restore.numpy())

[[0. 0. 0. 0. 0.]]
[[4.5674243 4.8244634 4.8828235 5.0211086 4.982023 ]]
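The same queuing can be sketched end to end with an object that builds its variable lazily. `LazyScale` is a hypothetical `tf.Module` written for this illustration, not a TensorFlow class; it creates its variable on first call, much as a `Dense` layer does.

```python
import os
import tempfile

import tensorflow as tf


class LazyScale(tf.Module):
    """Hypothetical module that creates its variable on first call."""

    def __init__(self):
        super().__init__()
        self.scale = None

    def __call__(self, x):
        if self.scale is None:
            # Created only once the input shape is known.
            self.scale = tf.Variable(tf.ones([x.shape[-1]]))
        return x * self.scale


# Build, modify, and save one instance.
trained = LazyScale()
trained(tf.ones([1, 3]))
trained.scale.assign([2.0, 3.0, 4.0])
path = tf.train.Checkpoint(layer=trained).save(
    os.path.join(tempfile.mkdtemp(), "lazy_demo"))

# Restore before the variable exists; the restoration is queued.
fresh = LazyScale()
tf.train.Checkpoint(layer=fresh).restore(path)
scale_before_call = fresh.scale  # still None

# The first call creates the variable, which then receives the saved value.
out = fresh(tf.ones([1, 3]))
print(scale_before_call, fresh.scale.numpy(), out.numpy())
```

The first call already multiplies by the restored values rather than by the initial ones, because the queued restore runs when the variable is assigned to the attribute.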


### Manually inspecting checkpoints

`tf.train.list_variables` lists the checkpoint keys and shapes of variables in a checkpoint. Checkpoint keys are paths in the graph displayed above.

In [12]:
tf.train.list_variables(tf.train.latest_checkpoint('./tf_ckpts/'))

[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('net/l1/bias/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/bias/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE', [5]),
 ('net/l1/kernel/.ATTRIBUTES/VARIABLE_VALUE', [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/m/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('net/l1/kernel/.OPTIMIZER_SLOT/optimizer/v/.ATTRIBUTES/VARIABLE_VALUE',
  [1, 5]),
 ('optimizer/beta_1/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/beta_2/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/decay/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/iter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('optimizer/learning_rate/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('save_counter/.ATTRIBUTES/VARIABLE_VALUE', []),
 ('step/.ATTRIBUTES/VARIABLE_VALUE', [])]
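Beyond listing keys and shapes, the stored values can be read directly by key. The following sketch writes a small checkpoint to a temporary location (an assumption made to keep it self-contained) and reads one value back with `tf.train.load_checkpoint`, without building any matching Python objects.

```python
import os
import tempfile

import tensorflow as tf

demo = tf.train.Checkpoint(weights=tf.Variable([[1.0, 2.0], [3.0, 4.0]]))
path = demo.save(os.path.join(tempfile.mkdtemp(), "inspect_demo"))

reader = tf.train.load_checkpoint(path)
key = "weights/.ATTRIBUTES/VARIABLE_VALUE"
shape = reader.get_variable_to_shape_map()[key]
value = reader.get_tensor(key)  # returned as a NumPy array
print(shape, value.tolist())
```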

### List and dictionary tracking

As with direct attribute assignments like `self.l1 = tf.keras.layers.Dense(5)`, assigning lists and dictionaries to attributes will track their contents.

In [13]:
save = tf.train.Checkpoint()
save.listed = [tf.Variable(1.)]
save.listed.append(tf.Variable(2.))
save.mapped = {'one': save.listed[0]}
save.mapped['two'] = save.listed[1]
save_path = save.save('./tf_list_example')

restore = tf.train.Checkpoint()
v2 = tf.Variable(0.)
assert 0. == v2.numpy() # not restored yet
restore.mapped = {'two': v2}
restore.restore(save_path)
assert 2. == v2.numpy()

You may notice wrapper objects for lists and dictionaries. These wrappers are checkpointable versions of the underlying data structures. Just like attribute-based loading, these wrappers restore a variable's value as soon as it's added to the container.

In [14]:
restore.listed = []
print(restore.listed) # ListWrapper([])
v1 = tf.Variable(0.)
restore.listed.append(v1)
assert 1. == v1.numpy()

ListWrapper([])


The same tracking is automatically applied to subclasses of `tf.keras.Model`, and may be used for example to track lists of layers.
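As a rough sketch of tracking a list of sub-objects, the example below stores a list in an attribute and inspects the resulting checkpoint keys. It deliberately uses `tf.Module` rather than `tf.keras.Model` to stay minimal, on the assumption that list tracking behaves the same way; `Scale` and `Stack` are hypothetical classes.

```python
import os
import tempfile

import tensorflow as tf


class Scale(tf.Module):
    def __init__(self, factor):
        super().__init__()
        self.factor = tf.Variable(factor)

    def __call__(self, x):
        return x * self.factor


class Stack(tf.Module):
    def __init__(self, factors):
        super().__init__()
        # Assigning a list to an attribute tracks each element by its index.
        self.blocks = [Scale(f) for f in factors]

    def __call__(self, x):
        for block in self.blocks:
            x = block(x)
        return x


stack = Stack([2.0, 3.0])
path = tf.train.Checkpoint(model=stack).save(
    os.path.join(tempfile.mkdtemp(), "stack_demo"))
keys = [name for name, _ in tf.train.list_variables(path)]
for key in keys:
    print(key)
```

The list index becomes an edge name, so the variables appear under `model/blocks/0/...` and `model/blocks/1/...`.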