In [1]:
import os
import tensorflow as tf

## Checkpoint files
When TensorFlow saves a checkpoint, it writes several accompanying files rather than a single file. It is therefore a good idea to place checkpoints in a dedicated subdirectory, so that all the related files stay together.

So let's start by creating a subdirectory called "checkpoints" and setting the path of the checkpoint file to "checkpoints/checkpoint.chk".

In [7]:
# Specify the name of the checkpoints directory
checkpoint_dir = "checkpoints"

# Create the directory if it does not already exist
os.makedirs(checkpoint_dir, exist_ok=True)

# Specify the path to the checkpoint file
checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.chk")
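
To make the "several accompanying files" concrete, here is a minimal sketch that lists every file sharing a checkpoint path as its prefix. The helper name `checkpoint_files` is hypothetical, and the suffixes used in the demonstration (`.index`, `.meta`, `.data-00000-of-00001`) are assumed to be what a TensorFlow 1.x saver typically writes in its default format. The demonstration uses empty placeholder files, not real checkpoints.

```python
import glob
import os
import tempfile


def checkpoint_files(checkpoint_prefix):
    """Return the sorted base names of files that start with the checkpoint path."""
    # glob.escape protects any special characters in the directory name
    pattern = glob.escape(checkpoint_prefix) + "*"
    return sorted(os.path.basename(path) for path in glob.glob(pattern))


# Demonstration with empty placeholder files (assumed suffixes, see above)
with tempfile.TemporaryDirectory() as tmp_dir:
    prefix = os.path.join(tmp_dir, "checkpoint.chk")
    for suffix in (".index", ".meta", ".data-00000-of-00001"):
        open(prefix + suffix, "w").close()
    # The saver also keeps a bookkeeping file named "checkpoint" in the directory;
    # it does not share the prefix, so it is not listed.
    open(os.path.join(tmp_dir, "checkpoint"), "w").close()
    print(checkpoint_files(prefix))
```

Because all of these files live next to each other, keeping them in their own subdirectory makes it easy to inspect or delete a checkpoint as a unit.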

## Saving and Restoring Operations
To save and restore checkpoints, we need to create a Saver object in the TensorFlow graph using tf.train.Saver(), which adds the save and restore operations to the graph.

In [10]:
# CREATE THE GRAPH
graph = tf.Graph()
with graph.as_default():
    tf_w1 = tf.Variable(tf.constant(1, shape=[2, 3]), name="weights_1")
    tf_w2 = tf.Variable(tf.constant(1, shape=[2, 3]), name="weights_2")
    update_vars = tf_w1.assign(tf_w1 + tf_w2) # update the value of w1

    # Create a Saver Object
    saver = tf.train.Saver(name="saver")
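
A Saver created without arguments, as above, covers every variable in the graph. The sketch below illustrates this by saving a copy of the same two variables to a temporary directory and reading the variable names back from the checkpoint with tf.train.list_variables. It assumes TensorFlow 1.14 or later (or 2.x), where the 1.x graph API is also reachable as tf.compat.v1; the helper name `saved_variable_names` is hypothetical.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TensorFlow >= 1.14 or 2.x, where the 1.x graph API is
# available under tf.compat.v1.
tf1 = tf.compat.v1


def saved_variable_names():
    """Hypothetical helper: save a small graph, then list what the checkpoint holds."""
    graph = tf.Graph()
    with graph.as_default():
        tf1.Variable(tf.constant(1, shape=[2, 3]), name="weights_1")
        tf1.Variable(tf.constant(1, shape=[2, 3]), name="weights_2")
        saver = tf1.train.Saver(name="saver")  # no var_list: covers all variables
        init = tf1.global_variables_initializer()

    with tempfile.TemporaryDirectory() as tmp_dir:
        checkpoint_file = os.path.join(tmp_dir, "checkpoint.chk")
        with tf1.Session(graph=graph) as session:
            session.run(init)
            saver.save(session, checkpoint_file)
        # list_variables returns (name, shape) pairs read from the checkpoint
        return [name for name, _ in tf.train.list_variables(checkpoint_file)]


print(saved_variable_names())
```

If the assumptions hold, the printed list should contain both weights_1 and weights_2, which are the names given to the variables when they were created.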

In [18]:
# RUN THE SESSION
with tf.Session(graph=graph) as session:
    # Restore variables from the checkpoint if it exists, otherwise initialize them
    if tf.train.checkpoint_exists(checkpoint_file):
        print("Restoring from file: ", checkpoint_file)
        saver.restore(session, checkpoint_file)
    else:
        print("Initializing from scratch")
        session.run(tf.global_variables_initializer())

    # RUN THE GRAPH - updating the variables
    session.run(update_vars)
    w1_val = session.run(tf_w1)
    print("Value of w1 after running: \n", w1_val)

    # Save a snapshot of the variables
    saver.save(session, checkpoint_file)

Restoring from file:  checkpoints/checkpoint.chk
INFO:tensorflow:Restoring parameters from checkpoints/checkpoint.chk
Value of w1 after running: 
 [[8 8 8]
 [8 8 8]]

The values are 8 rather than 2 because this cell had already been run several times: each run restores the previously saved variables and adds w2 (all ones) to w1 before saving again.
