In [47]:
import tensorflow as tf
class Dense(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
  @tf.function
  def __call__(self, x):
    y = tf.matmul(x,self.w)
    return tf.nn.sigmoid(y)

In [48]:
class SequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)

    self.dense_1 = Dense(in_features=3, out_features=3)
    self.dense_2 = Dense(in_features=3, out_features=1)
  @tf.function
  def __call__(self, x):
    x = self.dense_1(x)
    return self.dense_2(x)

# You have made a model!
my_model = SequentialModule(name="the_model")

# Call it, with random results
print("Model results:", my_model(tf.constant([[0.0288, -0.3256, 0.5925]])))

Model results: tf.Tensor([[0.17839792]], shape=(1, 1), dtype=float32)


In [49]:

chkp_path = "my_checkpoint"
checkpoint = tf.train.Checkpoint(model=my_model)
checkpoint.write(chkp_path)
tf.train.list_variables(chkp_path)

[('_CHECKPOINTABLE_OBJECT_GRAPH', []),
 ('model/dense_1/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 3]),
 ('model/dense_2/w/.ATTRIBUTES/VARIABLE_VALUE', [3, 1])]

In [50]:
# Set up logging.
%load_ext tensorboard
from datetime import datetime
stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = "logs/func/%s" % stamp
writer = tf.summary.create_file_writer(logdir)

# Create a new model to get a fresh trace
# Otherwise the summary will not see the graph.
new_model = SequentialModule()

# Bracket the function call with
# tf.summary.trace_on() and tf.summary.trace_export().
tf.summary.trace_on(graph=True, profiler=True)
# Call only one tf.function when tracing.
z = print(new_model(tf.constant([[2.0, 2.0, 2.0]])))
with writer.as_default():
  tf.summary.trace_export(
      name="my_func_trace",
      step=0,
      profiler_outdir=logdir)

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard
tf.Tensor([[0.48116893]], shape=(1, 1), dtype=float32)


In [53]:
%tensorboard --logdir logs

2.3.1


In [55]:
tf.saved_model.save(my_model, "the_saved_model")
# Inspect the in the directory

new_model = tf.saved_model.load("the_saved_model")
isinstance(new_model, SequentialModule)
print(my_model([[2.0, 2.0, 2.0]]))
print(my_model([[[2.0, 2.0, 2.0], [2.0, 2.0, 2.0]]]))

INFO:tensorflow:Assets written to: the_saved_model/assets
tf.Tensor([[0.4617024]], shape=(1, 1), dtype=float32)
tf.Tensor(
[[[0.4617024]
  [0.4617024]]], shape=(1, 2, 1), dtype=float32)
