## TensorFlow Playground

### TensorFlow *tf.train.latest_checkpoint* test

@Date    : Nov-27-20 01:17

@Author  : Kelly Hwong (dianhuangkan@gmail.com)

In [1]:
import tensorflow as tf
tf.__version__

'2.1.0'

In [2]:
class Net(tf.keras.Model):
  """A simple linear model."""

  def __init__(self):
    super(Net, self).__init__()
    self.l1 = tf.keras.layers.Dense(5)

  def call(self, x):
    return self.l1(x)

In [3]:
net = Net()

In [4]:
def toy_dataset():
  inputs = tf.range(10.)[:, None]
  labels = inputs * 5. + tf.range(5.)[None, :]
  return tf.data.Dataset.from_tensor_slices(
    dict(x=inputs, y=labels)).repeat().batch(2)

In [5]:
def train_step(net, example, optimizer):
  """Trains `net` on `example` using `optimizer`."""
  with tf.GradientTape() as tape:
    output = net(example['x'])
    loss = tf.reduce_mean(tf.abs(output - example['y']))
  variables = net.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))
  return loss

In [17]:
opt = tf.keras.optimizers.Adam(0.1)
dataset = toy_dataset()
iterator = iter(dataset)
ckpt = tf.train.Checkpoint(step=tf.Variable(1), optimizer=opt, net=net, iterator=iterator)
manager = tf.train.CheckpointManager(ckpt, os.path.join(".", "tf_ckpts"), max_to_keep=3)

In [18]:
def train_and_checkpoint(net, manager):
  ckpt.restore(manager.latest_checkpoint)
  if manager.latest_checkpoint:
    print("Restored from {}".format(manager.latest_checkpoint))
  else:
    print("Initializing from scratch.")

  for _ in range(50):
    example = next(iterator)
    loss = train_step(net, example, opt)
    ckpt.step.assign_add(1)
    if int(ckpt.step) % 10 == 0:
      save_path = manager.save()
      print("Saved checkpoint for step {}: {}".format(int(ckpt.step), save_path))
      print("loss {:1.2f}".format(loss.numpy()))

In [19]:
train_and_checkpoint(net, manager)

Restored from .\tf_ckpts\ckpt-5
Saved checkpoint for step 60: .\tf_ckpts\ckpt-6
loss 0.84
Saved checkpoint for step 70: .\tf_ckpts\ckpt-7
loss 0.93
Saved checkpoint for step 80: .\tf_ckpts\ckpt-8
loss 0.30
Saved checkpoint for step 90: .\tf_ckpts\ckpt-9
loss 0.41
Saved checkpoint for step 100: .\tf_ckpts\ckpt-10
loss 0.27


In [20]:
import tempfile
tmpdir = tempfile.mkdtemp()
ckpt_dir = os.path.join(tmpdir, "ckpts")
os.makedirs(ckpt_dir)
print(f"ckpt_dir: {ckpt_dir}")

ckpt_dir: C:\Users\KELLYH~1\AppData\Local\Temp\tmpjuu3ywaq\ckpts


In [21]:
ckpt_dir

'C:\\Users\\KELLYH~1\\AppData\\Local\\Temp\\tmpjuu3ywaq\\ckpts'

In [22]:
# net.save(os.path.join(ckpt_dir, "easy_checkpoint"))
net.save_weights(os.path.join(ckpt_dir, "easy_checkpoint"))
os.listdir(ckpt_dir)

['checkpoint',
 'easy_checkpoint.data-00000-of-00002',
 'easy_checkpoint.data-00001-of-00002',
 'easy_checkpoint.index']

In [13]:
tf.train.latest_checkpoint(checkpoint_dir=os.path.join(".", "tf_ckpts"))

'.\\tf_ckpts\\ckpt-5'

In [25]:
ret = tf.train.latest_checkpoint(checkpoint_dir="ckpt_dir")

In [28]:
type(ret)

NoneType

In [29]:
ret == None

True