To do machine learning in TensorFlow, you are likely to need to define, save, and restore a model.

A model is, abstractly:

- A function that computes something on tensors (a forward pass)
- Some variables that can be updated in response to training

**Defining models and layers**

Most models are made of layers. Layers are functions with a known mathematical structure that can be reused and have trainable variables. In TensorFlow, most high-level implementations of layers and models, such as Keras or Sonnet, are built on the same foundational class: tf.Module.

Here's an example of a very simple tf.Module that operates on a scalar tensor:

In [2]:
import tensorflow as tf
from datetime import datetime

%load_ext tensorboard

In [3]:
class SimpleModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.a_variable = tf.Variable(5.0, name="train_me")
    self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me")
  def __call__(self, x):
    return self.a_variable * x + self.non_trainable_variable

simple_module = SimpleModule(name="simple")

simple_module(tf.constant(5.0))

<tf.Tensor: shape=(), dtype=float32, numpy=30.0>

Modules and, by extension, layers are deep-learning terminology for "objects": they have internal state and methods that use that state.

There is nothing special about __call__ beyond making the object behave like a Python callable; you can invoke your models through whatever methods you wish.
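As a minimal sketch of this point, the module below exposes its computation through an ordinary method instead of __call__. The class name `Shift` and the method name `forward` are hypothetical, chosen only for illustration.

```python
import tensorflow as tf

# Hypothetical module: `Shift` and `forward` are illustrative names only.
class Shift(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.offset = tf.Variable(1.0, name="offset")

  # An ordinary method works just as well as __call__.
  def forward(self, x):
    return x + self.offset

shift = Shift(name="shift")
result = shift.forward(tf.constant(2.0))
print("forward result:", result)
```

The variable is still collected by tf.Module; only the way the computation is invoked changes.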

You can set the trainability of variables on and off for any reason, including freezing layers and variables during fine-tuning.
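To make the effect of freezing concrete, here is a small sketch using tf.GradientTape. The module name `TwoVariables` and its attribute names are assumptions made for this example. By default the tape only watches trainable variables, so the frozen variable should receive no gradient.

```python
import tensorflow as tf

# Hypothetical module with one trainable and one frozen variable.
class TwoVariables(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.scale = tf.Variable(5.0, name="scale")
    self.frozen_offset = tf.Variable(5.0, trainable=False, name="frozen_offset")

  def __call__(self, x):
    return self.scale * x + self.frozen_offset

module = TwoVariables(name="two_variables")

with tf.GradientTape() as tape:
  y = module(tf.constant(3.0))

# Ask for gradients with respect to both variables.
scale_grad, offset_grad = tape.gradient(y, [module.scale, module.frozen_offset])

print("scale gradient:", scale_grad)
print("frozen_offset gradient:", offset_grad)
print("trainable variables:", [v.name for v in module.trainable_variables])
```

The gradient for the frozen variable comes back as None, which is why an optimizer driven by `trainable_variables` leaves it untouched.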

When you subclass tf.Module, any tf.Variable or tf.Module instances assigned to the object's properties are automatically collected. This allows you to save and load variables, and also to create collections of tf.Modules.

In [4]:
# All trainable variables
print("trainable variables:", simple_module.trainable_variables)
# Every variable
print("all variables:", simple_module.variables)

trainable variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>,)
all variables: (<tf.Variable 'train_me:0' shape=() dtype=float32, numpy=5.0>, <tf.Variable 'do_not_train_me:0' shape=() dtype=float32, numpy=5.0>)


In [5]:
# Define a dense (linear) layer
class Dense(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')
  def __call__(self, x):
    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

class SequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = Dense(in_features=3, out_features=3)
    self.dense_2 = Dense(in_features=3, out_features=2)

  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a model!
my_model = SequentialModule(name="the_model")

# Call it, with random results
print("Model results:", my_model(tf.constant([[2.0, 2.0, 2.0]])))

Model results: tf.Tensor([[0.         0.66619265]], shape=(1, 2), dtype=float32)


A tf.Module instance automatically and recursively collects any tf.Variable or tf.Module instances assigned to it. This allows you to manage collections of tf.Modules with a single model instance, and to save and load whole models.

In [6]:
print("Submodules:", my_model.submodules)

for var in my_model.variables:
  print(var, "\n")

Submodules: (<__main__.Dense object at 0x7fd7bf17e828>, <__main__.Dense object at 0x7fd7c08fc198>)
<tf.Variable 'b:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)> 

<tf.Variable 'w:0' shape=(3, 3) dtype=float32, numpy=
array([[-1.565403  ,  0.5565351 ,  0.96435696],
       [-0.7753218 , -0.07704509, -0.37318504],
       [ 0.58896524,  0.9412106 ,  2.4840398 ]], dtype=float32)> 

<tf.Variable 'b:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)> 

<tf.Variable 'w:0' shape=(3, 2) dtype=float32, numpy=
array([[ 2.0976584 ,  1.7208703 ],
       [-1.401734  ,  0.5614856 ],
       [ 0.46375453, -0.15108119]], dtype=float32)> 
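One practical use of this recursive collection is counting parameters across every submodule. The sketch below redefines a dense layer and a two-layer model so it is self-contained; the name `TwoLayer` is hypothetical.

```python
import tensorflow as tf

class Dense(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def __call__(self, x):
    return tf.nn.relu(tf.matmul(x, self.w) + self.b)

# Hypothetical container module holding two Dense submodules.
class TwoLayer(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.dense_1 = Dense(in_features=3, out_features=3)
    self.dense_2 = Dense(in_features=3, out_features=2)

  def __call__(self, x):
    return self.dense_2(self.dense_1(x))

model = TwoLayer(name="two_layer")

# `variables` walks the submodules, so one expression covers the whole model.
total_parameters = sum(v.shape.num_elements() for v in model.variables)
print("total parameters:", total_parameters)
```

Here the total is (3*3 + 3) + (3*2 + 2) = 20, gathered without touching the submodules directly.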



**Create variables**

By deferring variable creation to the first time the module is called with a specific input shape, you do not need to specify the input size up front.

In [7]:
class FlexibleDenseModule(tf.Module):
  # Note: No need for `in_features`
  def __init__(self, out_features, name=None):
    super().__init__(name=name)
    self.is_built = False
    self.out_features = out_features

  def __call__(self, x):
    # Create variables on first call.
    if not self.is_built:
      self.w = tf.Variable(
        tf.random.normal([x.shape[-1], self.out_features]), name='w')
      self.b = tf.Variable(tf.zeros([self.out_features]), name='b')
      self.is_built = True

    y = tf.matmul(x, self.w) + self.b
    return tf.nn.relu(y)

# Used in a module
class MySequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = FlexibleDenseModule(out_features=3)
    self.dense_2 = FlexibleDenseModule(out_features=2)

  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

my_model = MySequentialModule(name="the_model")
print("Model results:", my_model(tf.constant([[2.0, 2.0, 2.0]])))

Model results: tf.Tensor([[0.        3.4368594]], shape=(1, 2), dtype=float32)


This flexibility is why TensorFlow layers often only need to specify the shape of their outputs, such as in tf.keras.layers.Dense, rather than both the input and output size.
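The same deferred creation can be observed with tf.keras.layers.Dense. This is a rough sketch, assuming a standard TensorFlow install with Keras available: the layer is given only its output size, and its weights appear after the first call.

```python
import tensorflow as tf

# Only the output size (units) is given; the input size is not known yet.
layer = tf.keras.layers.Dense(units=2)
weights_before = len(layer.weights)

# The first call sees an input with 3 features and builds the weights.
output = layer(tf.ones([1, 3]))
kernel_shape = tuple(layer.kernel.shape)

print("weights before first call:", weights_before)
print("kernel shape after first call:", kernel_shape)
print("output shape:", tuple(output.shape))
```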

**Saving weights**

You can save a [tf.Module](https://www.tensorflow.org/api_docs/python/tf/Module) as both a [checkpoint](https://www.tensorflow.org/guide/checkpoint) and a [SavedModel](https://www.tensorflow.org/guide/saved_model).

Checkpoints are just the weights (that is, the values of the set of variables inside the module and its submodules).

In [None]:
chkp_path = "my_checkpoint"
checkpoint = tf.train.Checkpoint(model=my_model)
checkpoint.write(chkp_path)

In [None]:
tf.train.list_variables(chkp_path)

In [None]:
new_model = MySequentialModule()
new_checkpoint = tf.train.Checkpoint(model=new_model)
new_checkpoint.restore("my_checkpoint")

# Should be the same result as above
new_model(tf.constant([[2.0, 2.0, 2.0]]))
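The write and restore steps can also be checked end to end in a self-contained way. The sketch below uses a hypothetical one-variable module, `Scale`, and a temporary directory so that it does not depend on the cells above; these names and the path are assumptions for illustration.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical single-variable module used only for this round trip.
class Scale(tf.Module):
  def __init__(self, value, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(value, name="w")

  def __call__(self, x):
    return self.w * x

# Write the weights of a source module to a temporary checkpoint prefix.
source = Scale(3.0)
path = os.path.join(tempfile.mkdtemp(), "scale_checkpoint")
tf.train.Checkpoint(model=source).write(path)

# Restore into a fresh module that starts with a different value.
target = Scale(0.0)
tf.train.Checkpoint(model=target).restore(path)

restored_value = float(target.w.numpy())
print("restored value:", restored_value)
print("checkpoint entries:", tf.train.list_variables(path))
```

Only the variable values travel through the checkpoint; the `Scale` class itself still has to be available in Python to rebuild the module.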

**Saving functions**

TensorFlow can run models without the original Python objects, as seen in [TensorFlow Serving](https://www.tensorflow.org/tfx) and [TensorFlow Lite](https://www.tensorflow.org/lite), and even when you download a trained model from [TensorFlow Hub](https://www.tensorflow.org/hub).

TensorFlow needs to know how to do the computations described in Python, but without the original code. To do this, you can make a graph, which is described in the previous guide.

This graph contains operations, or ops, that implement the function.

You can define a graph in the model above by adding the @tf.function decorator to indicate that this code should run as a graph.

In [9]:
class MySequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = Dense(in_features=3, out_features=3)
    self.dense_2 = Dense(in_features=3, out_features=2)

  @tf.function
  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a model with a graph!
my_model = MySequentialModule(name="the_model")
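To see the ops that a @tf.function records, you can trace it for a given input signature and list the operations in the resulting graph. This is a small sketch on a hypothetical standalone function, `affine`, not on the model above.

```python
import tensorflow as tf

# Hypothetical function used only to illustrate tracing.
@tf.function
def affine(x):
  return 2.0 * x + 1.0

# Trace the function for float32 vectors of any length.
concrete = affine.get_concrete_function(
  tf.TensorSpec(shape=[None], dtype=tf.float32))

# Each entry is the type of one op recorded in the traced graph.
op_types = [op.type for op in concrete.graph.get_operations()]
print("ops in graph:", op_types)
```

The list should include a multiplication and an addition op alongside placeholders and constants; the exact op names can vary between TensorFlow versions.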

**Visualize the graph**

The original guide walks through visualizing this graph: https://www.tensorflow.org/guide/intro_to_modules#saving_functions

**Creating a SavedModel**

The recommended way of sharing completely trained models is to use SavedModel. SavedModel contains both a collection of functions and a collection of weights.

You can save the model you just made.

In [10]:
tf.saved_model.save(my_model, "the_saved_model")



INFO:tensorflow:Assets written to: the_saved_model/assets




In [11]:
# Inspect the SavedModel in the directory
!ls -l the_saved_model

total 20
drwxr-xr-x 2 root root 4096 Jan  2 19:58 assets
-rw-r--r-- 1 root root 9689 Jan  2 19:58 saved_model.pb
drwxr-xr-x 2 root root 4096 Jan  2 19:58 variables


In [12]:
# The variables/ directory contains a checkpoint of the variables 
!ls -l the_saved_model/variables

total 8
-rw-r--r-- 1 root root 408 Jan  2 19:58 variables.data-00000-of-00001
-rw-r--r-- 1 root root 356 Jan  2 19:58 variables.index
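As a sketch of the full round trip, a SavedModel can be loaded back with tf.saved_model.load and called without the original class instance. The module name `Doubler`, the input signature, and the temporary export directory below are assumptions made for this example.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical module; the input signature fixes what the saved function accepts.
class Doubler(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    self.factor = tf.Variable(2.0, name="factor")

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
  def __call__(self, x):
    return self.factor * x

export_dir = os.path.join(tempfile.mkdtemp(), "doubler_saved_model")
tf.saved_model.save(Doubler(name="doubler"), export_dir)

# The loaded object carries the traced function and the saved weights.
reloaded = tf.saved_model.load(export_dir)
result = reloaded(tf.constant([1.0, 2.0])).numpy().tolist()

print("directory contents:", sorted(os.listdir(export_dir)))
print("reloaded result:", result)
```

The reloaded object is not an instance of `Doubler`; it is a TensorFlow-internal object that exposes the saved function and variables.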
