In [1]:
# 20200414 : Better performance with tf.function

import tensorflow as tf

import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  """Context manager asserting that the wrapped block raises `error_class`.

  When the block raises `error_class` (or a subclass), the exception is
  swallowed after a notice and a short traceback are printed. Any other
  exception propagates to the caller untouched.

  Args:
    error_class: Exception type the wrapped block is expected to raise.

  Raises:
    Exception: If the wrapped block completes without raising anything.
  """
  try:
    yield
  except error_class:
    print('Caught expected exception \n  {}:'.format(error_class))
    # limit=2 keeps the printed traceback to two stack entries.
    traceback.print_exc(limit=2)
  else:
    # Reaching here means the block ran to completion, which is the failure case.
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))


In [2]:
@tf.function
def add(a, b):
  """Returns `a + b`, with the function wrapped by `tf.function`."""
  total = a + b
  return total

# Call the wrapped function with two 2x2 tensors of ones.
add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

<tf.Tensor: id=14, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

In [3]:
# Differentiate through the tf.function `add` using a gradient tape.
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
# Gradient of `result` with respect to `v`; the recorded output is 1.0.
tape.gradient(result, v)

<tf.Tensor: id=37, shape=(), dtype=float32, numpy=1.0>

In [4]:
@tf.function
def dense_layer(x, w, b):
  """Multiplies `x` by `w` with `tf.matmul`, then adds `b` through `add`.

  Shows that one `tf.function` can call another (`add`).
  """
  projected = tf.matmul(x, w)
  return add(projected, b)

# x is 3x2, w is 2x2 and b has 2 entries; the recorded result is a 3x2 tensor of threes.
dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: id=63, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

In [5]:
# Functions are polymorphic

@tf.function
def double(a):
  """Returns `a + a` and announces each trace on stdout.

  In the recorded run the Python `print` appears once per new input dtype.
  """
  print("Tracing with", a)
  doubled = a + a
  return doubled

# Call `double` with int32, float32 and string tensors in turn; each result
# is followed by a blank line.
print(double(tf.constant(1)), end="\n\n")
print(double(tf.constant(1.1)), end="\n\n")
print(double(tf.constant("a")), end="\n\n")

Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)



In [6]:
print("Obtaining concrete trace")
# Ask explicitly for the trace that handles string tensors (shape left unspecified).
double_strings = double.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
print("Executing traced function")
# The concrete function is called positionally and then by the keyword `a`.
print(double_strings(tf.constant("a")))
print(double_strings(a=tf.constant("b")))
print("Using a concrete trace with incompatible types will throw an error")
# An int32 tensor does not match the string trace, so this call raises.
with assert_raises(tf.errors.InvalidArgumentError):
  double_strings(tf.constant(1))

Obtaining concrete trace
Tracing with Tensor("a:0", dtype=string)
Executing traced function
tf.Tensor(b'aa', shape=(), dtype=string)
tf.Tensor(b'bb', shape=(), dtype=string)
Using a concrete trace with incompatible types will throw an error
Caught expected exception 
  <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>:


Traceback (most recent call last):
  File "<ipython-input-1-af2cf3b407e8>", line 12, in assert_raises
    yield
  File "<ipython-input-6-5351d0a2eda2>", line 8, in <module>
    double_strings(tf.constant(1))
tensorflow.python.framework.errors_impl.InvalidArgumentError: cannot compute __inference_double_90 as input #0(zero-based) was expected to be a string tensor but is a int32 tensor [Op:__inference_double_90]


In [7]:
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.int32),))
def next_collatz(x):
  """Returns the next Collatz term for each entry of `x`.

  Even entries map to `x // 2` and odd entries to `3 * x + 1`. The input
  signature declares a 1-D int32 tensor of any length.
  """
  print("Tracing with", x)
  is_even = x % 2 == 0
  halved = x // 2
  tripled_plus_one = 3 * x + 1
  return tf.where(is_even, halved, tripled_plus_one)

# A 1-D int32 tensor matches the declared signature.
print(next_collatz(tf.constant([1, 2])))
# We specified a 1-D tensor in the input signature, so this should fail.
with assert_raises(ValueError):
  next_collatz(tf.constant([[1, 2], [3, 4]]))


Tracing with Tensor("x:0", shape=(None,), dtype=int32)
tf.Tensor([4 1], shape=(2,), dtype=int32)
Caught expected exception 
  <class 'ValueError'>:


Traceback (most recent call last):
  File "<ipython-input-1-af2cf3b407e8>", line 12, in assert_raises
    yield
  File "<ipython-input-7-9939c82c1507>", line 9, in <module>
    next_collatz(tf.constant([[1, 2], [3, 4]]))
ValueError: Python inputs incompatible with input_signature:
  inputs: (
    tf.Tensor(
[[1 2]
 [3 4]], shape=(2, 2), dtype=int32))
  input_signature: (
    TensorSpec(shape=(None,), dtype=tf.int32, name=None))


In [8]:
def train_one_step():
  """Placeholder for a single training step; intentionally does nothing."""
  return None

@tf.function
def train(num_steps):
  """Calls `train_one_step` once for every index in `tf.range(num_steps)`.

  The Python `print` reports the `num_steps` value seen while tracing.
  """
  print("Tracing with num_steps = {}".format(num_steps))
  step_indices = tf.range(num_steps)
  for _ in step_indices:
    train_one_step()

# Python integer arguments: the recorded output shows one trace per distinct value.
train(num_steps=10)
train(num_steps=20)

Tracing with num_steps = 10
Tracing with num_steps = 20


In [9]:
# Tensor arguments: the recorded output shows a single trace serving both calls.
train(num_steps=tf.constant(10))
train(num_steps=tf.constant(20))

Tracing with num_steps = Tensor("num_steps:0", shape=(), dtype=int32)


In [10]:
@tf.function
def f(x):
  """Contrasts Python `print` with `tf.print` inside a `tf.function`.

  In the recorded run, "Traced with" appears once per distinct argument value,
  while "Executed with" appears on every call.
  """
  print("Traced with", x)
  tf.print("Executed with", x)

# In the recorded output the repeated f(1) call prints only "Executed with";
# f(2) prints both lines.
f(1)
f(1)
f(2)

Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2


In [11]:
# Collects every value handed to `side_effect`.
external_list = []

def side_effect(x):
  """Announces itself on stdout, then records `x` in `external_list`."""
  print('Python side effect')
  external_list.extend([x])

@tf.function
def f(x):
  """Invokes `side_effect` through `tf.py_function`.

  The recorded output shows the Python side effect running on each of the
  three calls.
  """
  # Tout=[] passes an empty list of output dtypes.
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# Each of the three calls must have appended exactly one entry.
assert len(external_list) == 3
# .numpy() call required because py_function casts 1 to tf.constant(1)
assert external_list[0].numpy() == 1

Python side effect
Python side effect
Python side effect


In [12]:
# Variable that accumulates whatever `buggy_consume_next` adds to it.
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  """Deliberately buggy: advances a Python iterator inside a `tf.function`.

  The recorded output shows `external_var` staying at 0 on every call, i.e.
  the first value drawn from the iterator is reused rather than a new one
  being consumed.
  """
  # NOTE(review): `next(iterator)` is plain Python, so it presumably runs only
  # while tracing - confirm against the tf.function documentation.
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

# Plain Python iterator handed to the tf.function three times.
iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)


Value of external_var: 0
Value of external_var: 0
Value of external_var: 0


In [13]:
def measure_graph_size(f, *args):
  """Prints how many nodes the graph traced for `f(*args)` contains.

  Args:
    f: Object exposing `get_concrete_function` and `__name__`.
    *args: Arguments forwarded to `f.get_concrete_function`.
  """
  concrete_fn = f.get_concrete_function(*args)
  node_count = len(concrete_fn.graph.as_graph_def().node)
  rendered_args = ', '.join(str(arg) for arg in args)
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, rendered_args, node_count))

@tf.function
def train(dataset):
  """Accumulates `|y - x|` over every `(x, y)` pair in `dataset`.

  Returns:
    The accumulated loss, starting from an integer constant 0.
  """
  loss = tf.constant(0)
  # Iterating here is what `measure_graph_size` inspects: the recorded node
  # counts grow with the length of a Python list but not with a Dataset.
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

# Same pair repeated 2 and 10 times.
small_data = [(1, 1)] * 2
big_data = [(1, 1)] * 10
# Python lists: the recorded node counts are 8 and 32.
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

# tf.data.Dataset inputs: the recorded node count is 5 for both sizes.
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

train([(1, 1), (1, 1)]) contains 8 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph
train(<DatasetV1Adapter shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 5 nodes in its graph


In [14]:
# Automatic control dependencies

# Two variables that `f` below reads and updates.
a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  """Assigns `y * b` to `a`, adds `x * a` to `b`, then returns `a + b`.

  The recorded result for `f(1.0, 2.0)` is 10.0, which matches running the
  three statements in the order they are written.
  """
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

# a becomes 2.0 * 2.0 = 4.0, b becomes 2.0 + 1.0 * 4.0 = 6.0, so the sum is 10.0.
f(1.0, 2.0)  # 10.0

<tf.Tensor: id=417, shape=(), dtype=float32, numpy=10.0>

In [15]:
@tf.function
def f(x):
  """Creates a variable inside the function body, adds `x` to it and returns it.

  In the recorded run this raises ValueError: "tf.function-decorated function
  tried to create variables on non-first call."
  """
  v = tf.Variable(1.0)
  v.assign_add(x)
  return v

# The variable creation inside `f` is rejected with a ValueError.
with assert_raises(ValueError):
  f(1.0)

Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Caught expected exception 
  <class 'ValueError'>:


Traceback (most recent call last):
  File "<ipython-input-1-af2cf3b407e8>", line 12, in assert_raises
    yield
  File "<ipython-input-15-73e410646579>", line 8, in <module>
    f(1.0)
ValueError: in converted code:

    <ipython-input-15-73e410646579>:3 f  *
        v = tf.Variable(1.0)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/variables.py:260 __call__
        return cls._variable_v2_call(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/variables.py:254 _variable_v2_call
        shape=shape)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/variables.py:65 getter
        return captured_getter(captured_previous, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py:413 invalid_creator_scope
        "tf.function-decorated function tried to create "

    ValueError: tf.function-decorated function tried to create variables on non-first call.



In [16]:
# Created once, outside the function, so `f` only updates it.
v = tf.Variable(1.0)

@tf.function
def f(x):
  """Adds `x` to the module-level variable `v` and returns the result of the update."""
  updated = v.assign_add(x)
  return updated

# `v` starts at 1.0: adding 1.0 gives 2.0, then adding 2.0 gives 4.0.
print(f(1.0))  # 2.0
print(f(2.0))  # 4.0

tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)


In [17]:
class C:
  """Empty attribute holder; instances carry only what is assigned to them."""

# `v` starts out unset; `g` fills it in on first use.
obj = C()
obj.v = None

@tf.function
def g(x):
  """Adds `x` to `obj.v`, creating that variable only if it does not exist yet."""
  # The `is None` guard means the variable is created at most once.
  if obj.v is None:
    obj.v = tf.Variable(1.0)
  return obj.v.assign_add(x)

# `obj.v` starts at 1.0: adding 1.0 gives 2.0, then adding 2.0 gives 4.0.
print(g(1.0))  # 2.0
print(g(2.0))  # 4.0

tf.Tensor(2.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)


In [18]:
# Holds the two variables `fn` creates the first time it finds the list empty.
state = []
@tf.function
def fn(x):
  """Returns `state[0] * x * state[1]`, creating both variables on first use.

  The first variable is initialised to `2.0 * x` and the second to three
  times the first.
  """
  if not state:
    state.append(tf.Variable(2.0 * x))
    # The second variable's initial value depends on the first one.
    state.append(tf.Variable(state[0] * 3.0))
  return state[0] * x * state[1]

# Recorded results: 2 * 1 * 6 = 12 for the first call, 2 * 3 * 6 = 36 for the second.
print(fn(tf.constant(1.0)))
print(fn(tf.constant(3.0)))


tf.Tensor(12.0, shape=(), dtype=float32)
tf.Tensor(36.0, shape=(), dtype=float32)


In [19]:
# Simple loop

@tf.function
def f(x):
  """Applies `tf.tanh` to `x` repeatedly while the sum of `x` exceeds 1.

  Prints `x` with `tf.print` at the start of every iteration and returns the
  final value.
  """
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

# Start from five random uniform values; the input is unseeded, so printed values vary per run.
f(tf.random.uniform([5]))

[0.979072094 0.218417168 0.910327792 0.157579541 0.31989789]
[0.75266397 0.215008914 0.721289635 0.156288058 0.309414566]
[0.636735559 0.211755827 0.617707491 0.155027851 0.299904376]
[0.562672734 0.208646506 0.54952985 0.153797701 0.291225106]
[0.50995785 0.20567061 0.500167727 0.152596429 0.283262]
[0.46991235 0.202818871 0.462249041 0.151422918 0.275921404]
[0.438128471 0.200082839 0.431915462 0.15027611 0.269126058]
[0.412091911 0.197454914 0.40692082 0.149154991 0.262811422]
[0.390247464 0.194928154 0.385855049 0.148058623 0.256923258]
[0.371573567 0.192496195 0.367781401 0.146986127 0.251415491]
[0.355367303 0.190153271 0.352049589 0.145936623 0.246248782]
[0.341126919 0.187894076 0.338192 0.144909322 0.24138923]
[0.328483045 0.185713693 0.325862318 0.143903449 0.236807495]
[0.31715703 0.183607668 0.314797968 0.142918259 0.232478008]
[0.306934029 0.181571856 0.304795653 0.141953066 0.228378445]
[0.297645271 0.179602429 0.295695156 0.141007185 0.224489078]
[0.289156199 0.177695855

<tf.Tensor: id=674, shape=(5,), dtype=float32, numpy=
array([0.24499726, 0.16586162, 0.24390934, 0.13406505, 0.19920303],
      dtype=float32)>

In [20]:
# Undecorated copy of the loop above, kept as a plain Python function so its
# AutoGraph conversion can be printed below.
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

# Show the code AutoGraph generates for `f`.
print(tf.autograph.to_code(f))

def tf__f(x):
  do_return = False
  retval_ = ag__.UndefinedReturnValue()
  with ag__.FunctionScope('f', 'f_scope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as f_scope:

    def get_state():
      return ()

    def set_state(_):
      pass

    def loop_body(x):
      ag__.converted_call(tf.print, f_scope.callopts, (x,), None, f_scope)
      x = ag__.converted_call(tf.tanh, f_scope.callopts, (x,), None, f_scope)
      return x,

    def loop_test(x):
      return ag__.converted_call(tf.reduce_sum, f_scope.callopts, (x,), None, f_scope) > 1
    x, = ag__.while_stmt(loop_test, loop_body, get_state, set_state, (x,), ('x',), ())
    do_return = True
    retval_ = f_scope.mark_return_value(x)
  do_return,
  return ag__.retval(retval_)



In [21]:
def test_tf_cond(f, *args):
  g = f.get_concrete_function(*args).graph
  if any(node.name == 'cond' for node in g.as_graph_def().node):
    print("{}({}) uses tf.cond.".format(
        f.__name__, ', '.join(map(str, args))))
  else:
    print("{}({}) executes normally.".format(
        f.__name__, ', '.join(map(str, args))))

  print("  result: ",f(*args).numpy())

In [22]:
@tf.function
def dropout(x, training=True):
  """Applies `tf.nn.dropout` with rate 0.5 when `training` is true.

  Args:
    x: Input passed to `tf.nn.dropout`.
    training: Python bool or boolean tensor selecting whether dropout runs.

  Returns:
    `x` after dropout, or `x` unchanged when `training` is false.
  """
  if training:
    x = tf.nn.dropout(x, rate=0.5)
  return x

In [23]:
# Python bool for `training`: the recorded output reports "executes normally".
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), True)

dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), True) executes normally.
  result:  [2. 0. 0. 0. 0. 2. 0. 2. 2. 2.]


In [24]:
# Tensor for `training`: the recorded output reports "uses tf.cond".
test_tf_cond(dropout, tf.ones([10], dtype=tf.float32), tf.constant(True))

dropout(tf.Tensor([1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(10,), dtype=float32), tf.Tensor(True, shape=(), dtype=bool)) uses tf.cond.
  result:  [2. 0. 2. 0. 2. 2. 2. 0. 2. 2.]


In [25]:
@tf.function
def f(x):
  """Returns `x + 1.` when `x > 0`, otherwise `x - 1.`.

  Each branch contains a Python `print`, so the output reveals which branches
  were visited while tracing.
  """
  if x > 0:
    x = x + 1.
    print("Tracing `then` branch")
  else:
    x = x - 1.
    print("Tracing `else` branch")
  return x

# Negative Python float: only the `else` branch is traced in the recorded output.
f(-1.0).numpy()

Tracing `else` branch


-2.0

In [26]:
# Positive Python float: only the `then` branch is traced in the recorded output.
f(1.0).numpy()

Tracing `then` branch


2.0

In [27]:
# Tensor argument: both branches are traced in the recorded output.
f(tf.constant(1.0)).numpy()

Tracing `then` branch
Tracing `else` branch


2.0

In [28]:
@tf.function
def f():
  """Deliberately invalid: `x` is assigned only inside the tensor-conditioned `if`.

  In the recorded run, calling it raises ValueError.
  """
  if tf.constant(True):
    x = tf.ones([3, 3])
  return x

# Throws an error because both branches need to define `x`.
with assert_raises(ValueError):
  f()

Caught expected exception 
  <class 'ValueError'>:


Traceback (most recent call last):
  File "<ipython-input-1-af2cf3b407e8>", line 12, in assert_raises
    yield
  File "<ipython-input-28-3ae6a92ff50d>", line 9, in <module>
    f()
ValueError: in converted code:

    <ipython-input-28-3ae6a92ff50d>:3 f  *
        if tf.constant(True):
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/operators/control_flow.py:893 if_stmt
        basic_symbol_names, composite_symbol_names)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/operators/control_flow.py:931 tf_if_stmt
        error_checking_orelse)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/deprecation.py:507 new_func
        return func(*args, **kwargs)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/control_flow_ops.py:1174 cond
        return cond_v2.cond_v2(pred, true_fn, false_fn, name)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/cond_v2.py:91 cond_v2
        o

In [29]:
@tf.function
def f(x, y):
  """Returns `y + 1.` when `bool(x)` is true, otherwise `y - 1.`.

  The condition goes through Python's `bool()`. In the recorded runs this
  works for Python bools and raises when `x` is a tensor.
  """
  if bool(x):
    y = y + 1.
    print("Tracing `then` branch")
  else:
    y = y - 1.
    print("Tracing `else` branch")
  return y

In [30]:
# Python True: the `then` branch is traced and the recorded result is 1.0.
f(True, 0).numpy()

Tracing `then` branch


1.0

In [31]:
# Python False: the `else` branch is traced and the recorded result is -1.0.
f(False, 0).numpy()

Tracing `else` branch


-1.0

In [32]:
# A tensor cannot be passed through `bool()` here; the recorded exception is
# OperatorNotAllowedInGraphError, which `assert_raises(TypeError)` catches.
with assert_raises(TypeError):
  f(tf.constant(True), 0.0)

Caught expected exception 
  <class 'TypeError'>:


Traceback (most recent call last):
  File "<ipython-input-1-af2cf3b407e8>", line 12, in assert_raises
    yield
  File "<ipython-input-32-427fe80db841>", line 2, in <module>
    f(tf.constant(True), 0.0)
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: in converted code:

    <ipython-input-29-772984717e3d>:3 f  *
        if bool(x):
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/autograph/impl/api.py:396 converted_call
        return py_builtins.overload_of(f)(*args)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:765 __bool__
        self._disallow_bool_casting()
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:531 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:518 _disallow_when_autograph_enabled
        " decorating it directly with @tf.function.".format(task))

   