##### Copyright 2020 The TensorFlow Authors.


In [1]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Better performance with tf.function

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/guide/function"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/function.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/docs/blob/master/site/en/guide/function.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/function.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

In TensorFlow 2, eager execution is turned on by default. The user interface is intuitive and flexible (running one-off operations is much easier and faster), but this can come at the expense of performance and deployability.

You can use `tf.function` to make graphs out of your programs. It is a transformation tool that creates Python-independent dataflow graphs out of your Python code. This will help you create performant and portable models, and it is required to use `SavedModel`.

This guide will help you conceptualize how `tf.function` works under the hood so you can use it effectively.

The main takeaways and recommendations are:

- Debug in eager mode, then decorate with `@tf.function`.
- Don't rely on Python side effects like object mutation or list appends.
- `tf.function` works best with TensorFlow ops; NumPy and Python calls are converted to constants.
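The last point is easy to miss, so here is a minimal sketch of it. It assumes NumPy is installed alongside TensorFlow, and the function names `numpy_noise` and `tf_noise` are illustrative only. The NumPy call runs once, in Python, while the function is traced, so every later call returns the same frozen value. The TensorFlow op runs each time.

```python
import numpy as np
import tensorflow as tf

@tf.function
def numpy_noise():
  # np.random.rand() runs in Python only while tracing; its result is
  # frozen into the graph as a constant.
  return tf.constant(np.random.rand(), dtype=tf.float32)

@tf.function
def tf_noise():
  # tf.random.uniform is a TensorFlow op, so it executes on every call.
  return tf.random.uniform([])

print(numpy_noise().numpy(), numpy_noise().numpy())  # the same value twice
print(tf_noise().numpy(), tf_noise().numpy())        # normally two different values
```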


## Setup

In [1]:
import tensorflow as tf

Define a helper function to demonstrate the kinds of errors you might encounter:

In [2]:
import traceback
import contextlib

# Some helper code to demonstrate the kinds of errors you might encounter.
@contextlib.contextmanager
def assert_raises(error_class):
  try:
    yield
  except error_class as e:
    print('Caught expected exception \n  {}:'.format(error_class))
    traceback.print_exc(limit=2)
  except Exception as e:
    raise e
  else:
    raise Exception('Expected {} to be raised but no error was raised!'.format(
        error_class))

## Basics

### Usage

A `Function` you define is just like a core TensorFlow operation: You can execute it eagerly; you can compute gradients; and so on.

In [3]:
@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

In [4]:
v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>

You can use `Function`s inside other `Function`s.

In [5]:
@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

`Function`s can be faster than eager code, especially for graphs with many small ops. But for graphs with a few expensive ops (like convolutions), you may not see much speedup.


In [6]:
import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")


Eager conv: 0.004598194733262062
Function conv: 0.0024671442806720734
Note how there's not much difference in performance for convolutions
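For contrast, the following sketch times a function made of many tiny elementwise ops, which is the case where graph execution tends to help most because eager mode pays Python dispatch overhead for every op. The function name `many_small_ops` and the loop count are arbitrary choices for illustration, and the actual timings depend heavily on your hardware and TensorFlow version, so treat them as indicative only.

```python
import timeit
import tensorflow as tf

def many_small_ops(x):
  # 100 iterations of two tiny elementwise ops each. This is a Python loop,
  # so inside tf.function it is unrolled into the graph at trace time.
  for _ in range(100):
    x = x * 1.0001 + 0.0001
  return x

graph_small_ops = tf.function(many_small_ops)

x = tf.ones([10])
graph_small_ops(x)  # warm up: trace once before timing
print("Eager small ops:", timeit.timeit(lambda: many_small_ops(x), number=10))
print("Function small ops:", timeit.timeit(lambda: graph_small_ops(x), number=10))
```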


### Tracing

Python's dynamic typing means that you can call functions with a variety of argument types, and Python can do something different in each scenario.

Yet, to create a TensorFlow Graph, static `dtypes` and shape dimensions are required. `tf.function` bridges this gap by wrapping a Python function to create a `Function` object. Based on its inputs, the `Function` selects the appropriate graph, retracing the Python function as necessary. Once you understand why and when tracing happens, it's much easier to use `tf.function` effectively!

You can call a `Function` with arguments of different types to see this polymorphic behavior in action.

In [7]:
@tf.function
def double(a):
  print("Tracing with", a)
  return a + a

print(double(tf.constant(1)))
print()
print(double(tf.constant(1.1)))
print()
print(double(tf.constant("a")))
print()


Tracing with Tensor("a:0", shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)

Tracing with Tensor("a:0", shape=(), dtype=float32)
tf.Tensor(2.2, shape=(), dtype=float32)

Tracing with Tensor("a:0", shape=(), dtype=string)
tf.Tensor(b'aa', shape=(), dtype=string)



Note that if you repeatedly call a `Function` with arguments of the same type (for tensors, the same dtype and shape), TensorFlow will reuse the previously traced graph, because the generated graph would be identical.

In [8]:
# This doesn't print 'Tracing with ...'
print(double(tf.constant("b")))

tf.Tensor(b'bb', shape=(), dtype=string)
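One way to see exactly when retracing happens is to record each trace with a Python side effect, since Python code runs only while tracing. The list `traces` and the function `square` below are illustrative names, not part of any API.

```python
import tensorflow as tf

traces = []  # a plain Python list used only as a trace log

@tf.function
def square(x):
  traces.append(x.dtype)  # Python side effect: runs only while tracing
  return x * x

square(tf.constant(2))    # int32 scalar: traces a new graph
square(tf.constant(3))    # int32 scalar again: reuses that graph
square(tf.constant(2.0))  # float32 scalar: new dtype, so it retraces
print(len(traces))  # 2
```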


So far, you've seen that `tf.function` creates a cached, dynamic dispatch layer over TensorFlow's graph tracing logic. To be more specific about the terminology:

- A `tf.Graph` is the raw, language-agnostic, portable representation of your computation.
- A `ConcreteFunction` is an eagerly-executing wrapper around a `tf.Graph`.
- A `Function` manages a cache of `ConcreteFunction`s and picks the right one for your inputs.
- `tf.function` wraps a Python function, returning a `Function` object.
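These layers can be inspected directly. The sketch below asks a `Function` for the `ConcreteFunction` that handles int32 scalars via `get_concrete_function` with a `tf.TensorSpec`, then looks at the underlying graph. The exact op names printed can vary between TensorFlow versions.

```python
import tensorflow as tf

@tf.function
def double(a):
  return a + a

# Function -> ConcreteFunction for one specific input signature.
concrete = double.get_concrete_function(tf.TensorSpec(shape=[], dtype=tf.int32))
# ConcreteFunction -> the tf.Graph it wraps.
graph = concrete.graph

print([op.type for op in graph.get_operations()])
print(concrete(tf.constant(21)))  # the ConcreteFunction can be called directly
```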


### Debugging

In general, debugging code is easier in eager mode than inside `tf.function`. You should ensure that your code executes error-free in eager mode before decorating with `tf.function`. To assist in the debugging process, you can call `tf.config.run_functions_eagerly(True)` to globally disable `tf.function`, and `tf.config.run_functions_eagerly(False)` to re-enable it.

When tracking down issues that only appear within `tf.function`, here are some tips:
- Plain old Python `print` calls only execute during tracing, helping you track down when your function gets (re)traced.
- `tf.print` calls will execute every time, and can help you track down intermediate values during execution.
- `tf.debugging.enable_check_numerics` is an easy way to track down where NaN and Inf values are created.
- `pdb` can help you understand what's going on during tracing. (Caveat: `pdb` will drop you into AutoGraph-transformed source code.)
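As a small illustration of `tf.config.run_functions_eagerly`, the sketch below uses `.numpy()`, which only works on eager tensors and therefore fails when the function is traced into a graph. With functions running eagerly, the decorated body executes as ordinary Python. The `try`/`finally` makes sure the global setting is restored; the function name `to_numpy` is illustrative.

```python
import tensorflow as tf

@tf.function
def to_numpy(x):
  # .numpy() is only available on eager tensors, so this line fails in
  # graph mode but works when functions run eagerly.
  return x.numpy()

tf.config.run_functions_eagerly(True)     # globally disable tf.function
try:
  value = to_numpy(tf.constant(5))
finally:
  tf.config.run_functions_eagerly(False)  # re-enable graph execution

print(value)  # 5
```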

### Python side effects

Python side effects like printing, appending to lists, and mutating globals only happen the first time you call a `Function` with a set of inputs. Afterwards, the traced `tf.Graph` is re-executed without running the Python code.

The general rule of thumb is to only use Python side effects to debug your traces. Otherwise, TensorFlow ops like `tf.Variable.assign`, `tf.print`, and `tf.summary` are the best way to ensure your code will be traced and executed by the TensorFlow runtime with each call.

In [9]:
@tf.function
def f(x):
  print("Traced with", x)
  tf.print("Executed with", x)

f(1)
f(1)
f(2)


Traced with 1
Executed with 1
Executed with 1
Traced with 2
Executed with 2
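The same contrast applies to state. In this sketch (the names are illustrative), a Python list records calls only during tracing, while `tf.Variable.assign_add` is a TensorFlow op that runs on every call.

```python
import tensorflow as tf

calls_seen_by_python = []
calls_seen_by_tf = tf.Variable(0)

@tf.function
def record_call():
  calls_seen_by_python.append(1)  # Python side effect: runs only while tracing
  calls_seen_by_tf.assign_add(1)  # TensorFlow op: runs on every call

for _ in range(3):
  record_call()

print(len(calls_seen_by_python))  # 1
print(calls_seen_by_tf.numpy())   # 3
```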


Many Python features, such as generators and iterators, rely on the Python runtime to keep track of state. In general, while these constructs work as expected in eager mode, they can behave unexpectedly inside a `Function`.

To give one example, advancing iterator state is a Python side effect and therefore only happens during tracing.

In [10]:
external_var = tf.Variable(0)
@tf.function
def buggy_consume_next(iterator):
  external_var.assign_add(next(iterator))
  tf.print("Value of external_var:", external_var)

iterator = iter([0, 1, 2, 3])
buggy_consume_next(iterator)
# This reuses the first value from the iterator, rather than consuming the next value.
buggy_consume_next(iterator)
buggy_consume_next(iterator)


Value of external_var: 0
Value of external_var: 0
Value of external_var: 0
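One way to avoid this, sketched below, is to iterate over a `tf.data.Dataset` instead of a Python list. Calling `next()` on an iterator created from a dataset becomes a graph op, so the iterator advances on every call rather than only during tracing. The names `total` and `consume_next` are illustrative.

```python
import tensorflow as tf

total = tf.Variable(0)

@tf.function
def consume_next(iterator):
  # next() on a tf.data iterator is traced into the graph, so it
  # fetches a new element each time the Function runs.
  total.assign_add(next(iterator))
  tf.print("Value of total:", total)

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
iterator = iter(dataset)
consume_next(iterator)  # total == 1
consume_next(iterator)  # total == 3
consume_next(iterator)  # total == 6
```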


Some iteration constructs are supported through AutoGraph. See the section on [AutoGraph Transformations](#autograph_transformations) for an overview.

If you would like to execute Python code during each invocation of a `Function`, `tf.py_function` is an escape hatch. The drawback of `tf.py_function` is that it's not portable or particularly performant, nor does it work well in distributed (multi-GPU, TPU) setups. Also, since `tf.py_function` has to be wired into the graph, it casts all inputs/outputs to tensors.

APIs like `tf.gather`, `tf.stack`, and `tf.TensorArray` can help you implement common looping patterns in native TensorFlow.

In [11]:
external_list = []

def side_effect(x):
  print('Python side effect')
  external_list.append(x)

@tf.function
def f(x):
  tf.py_function(side_effect, inp=[x], Tout=[])

f(1)
f(1)
f(1)
# The list append happens all three times!
assert len(external_list) == 3
# The list contains tf.constant(1), not 1, because py_function casts everything to tensors.
assert external_list[0].numpy() == 1


Python side effect
Python side effect
Python side effect


Another error you may encounter is a garbage-collected variable. Unlike normal Python functions, concrete functions only retain [WeakRefs](https://docs.python.org/3/library/weakref.html) to the variables they close over, so you must retain a reference to any variables.
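A common way to keep that reference alive is to store the variable as an attribute of a long-lived object, such as a `tf.Module`, and decorate a method with `tf.function`. The class name `Scale` is hypothetical; the point is only that the module instance owns the variable for as long as the module exists.

```python
import tensorflow as tf

class Scale(tf.Module):  # hypothetical example module
  def __init__(self, factor):
    super().__init__()
    # Storing the variable on the object keeps a strong reference to it,
    # so it is not garbage-collected while the module is alive.
    self.factor = tf.Variable(factor)

  @tf.function
  def __call__(self, x):
    return x * self.factor

scale = Scale(3.0)
print(scale(tf.constant(2.0)))  # tf.Tensor(6.0, shape=(), dtype=float32)
```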

## AutoGraph Transformations

AutoGraph is a library that is on by default in `tf.function`, and transforms a subset of Python eager code into graph-compatible TensorFlow ops. This includes control flow like `if`, `for`, and `while`.

TensorFlow ops like `tf.cond` and `tf.while_loop` continue to work, but control flow is often easier to write and understand when written in Python.

In [14]:
# Simple loop

@tf.function
def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

f(tf.random.uniform([5]))

[0.0218499899 0.252755761 0.476065755 0.882139921 0.701021075]
[0.0218465142 0.247507364 0.44308728 0.707489729 0.605015516]
[0.0218430385 0.242574126 0.416200221 0.609100282 0.540608883]
[0.0218395647 0.237925634 0.393724501 0.54349339 0.493448734]
[0.0218360927 0.233535469 0.374566644 0.495627761 0.456949353]
[0.0218326226 0.229380503 0.357979625 0.458671659 0.427594602]
[0.0218291543 0.225440428 0.343433201 0.429000974 0.403309137]
[0.021825688 0.221697286 0.330538958 0.40448609 0.382776827]
[0.0218222234 0.218135178 0.319004953 0.383780897 0.365116566]
[0.0218187608 0.214739949 0.308606893 0.365986437 0.349712849]
[0.0218153 0.211498931 0.299169183 0.350476116 0.336120844]
[0.0218118392 0.208400786 0.290552109 0.336797714 0.324009836]
[0.0218083803 0.205435291 0.282642871 0.324615538 0.313128114]
[0.0218049232 0.202593222 0.275349349 0.313674331 0.303280175]
[0.021801468 0.199866235 0.268595338 0.303776056 0.294311523]
[0.0217980146 0.197246775 0.2623173 0.2947644 0.286098272]
...

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([0.02178767, 0.18996745, 0.24584256, 0.27190816, 0.2650693 ],
      dtype=float32)>

If you're curious, you can inspect the code AutoGraph generates.

In [15]:
print(tf.autograph.to_code(f.python_function))

def tf__f(x):
    do_return = False
    retval_ = ag__.UndefinedReturnValue()
    with ag__.FunctionScope('f', 'fscope', ag__.ConversionOptions(recursive=True, user_requested=True, optional_features=(), internal_convert_user_code=True)) as fscope:

        def get_state():
            return (x,)

        def set_state(loop_vars):
            nonlocal x
            (x,) = loop_vars

        def loop_body():
            nonlocal x
            ag__.converted_call(tf.print, (x,), None, fscope)
            x = ag__.converted_call(tf.tanh, (x,), None, fscope)

        def loop_test():
            return (ag__.converted_call(tf.reduce_sum, (x,), None, fscope) > 1)
        ag__.while_stmt(loop_test, loop_body, get_state, set_state, ('x',), {})
        try:
            do_return = True
            retval_ = fscope.mark_return_value(x)
        except:
            do_return = False
            raise
    (do_return,)
    return ag__.retval(retval_)
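For comparison, here is a sketch of the same loop written directly against `tf.while_loop`, which is roughly what the generated `while_stmt` call dispatches to when the condition is a Tensor. The input values are arbitrary, and the `tf.print` call is dropped for brevity.

```python
import numpy as np
import tensorflow as tf

@tf.function
def autograph_loop(x):
  while tf.reduce_sum(x) > 1:  # Tensor condition: AutoGraph emits a while loop op
    x = tf.tanh(x)
  return x

@tf.function
def explicit_loop(x):
  # The same loop spelled out with tf.while_loop; loop_vars is a tuple,
  # so the body must return a tuple with the same structure.
  (x,) = tf.while_loop(
      cond=lambda x: tf.reduce_sum(x) > 1,
      body=lambda x: (tf.tanh(x),),
      loop_vars=(x,))
  return x

x0 = tf.constant([0.5, 0.9, 0.3])
print(autograph_loop(x0))
print(explicit_loop(x0))
```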



### Conditionals

AutoGraph will convert some `if <condition>` statements into the equivalent `tf.cond` calls. This substitution is made if `<condition>` is a Tensor. Otherwise, the `if` statement is executed as a Python conditional.

A Python conditional executes during tracing, so exactly one branch of the conditional will be added to the graph. Without AutoGraph, this traced graph would be unable to take the alternate branch if there is data-dependent control flow.

`tf.cond` traces and adds both branches of the conditional to the graph, dynamically selecting a branch at execution time. Tracing can have unintended side effects; see [AutoGraph tracing effects](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#effects-of-the-tracing-process) for more.

In [16]:
@tf.function
def fizzbuzz(n):
  for i in tf.range(1, n + 1):
    print('Tracing for loop')
    if i % 15 == 0:
      print('Tracing fizzbuzz branch')
      tf.print('fizzbuzz')
    elif i % 3 == 0:
      print('Tracing fizz branch')
      tf.print('fizz')
    elif i % 5 == 0:
      print('Tracing buzz branch')
      tf.print('buzz')
    else:
      print('Tracing default branch')
      tf.print(i)

fizzbuzz(tf.constant(5))
fizzbuzz(tf.constant(20))

Tracing for loop
Tracing fizzbuzz branch
Tracing fizz branch
Tracing buzz branch
Tracing default branch
1
2
fizz
4
buzz
1
2
fizz
4
buzz
fizz
7
8
fizz
buzz
11
fizz
13
14
fizzbuzz
16
17
fizz
19
buzz


See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#if-statements) for additional restrictions on AutoGraph-converted if statements.
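The sketch below puts both kinds of condition in one function. `use_abs` is a Python bool, so its `if` runs during tracing and only one branch enters each graph (calling with `True` and then `False` traces two graphs). The second `if` tests a Tensor, so AutoGraph converts it to `tf.cond`, and both branches must produce an `x` with the same dtype and shape. The function name is illustrative.

```python
import tensorflow as tf

@tf.function
def maybe_abs(x, use_abs):
  if use_abs:                 # Python bool: resolved at trace time
    x = tf.abs(x)
  if tf.reduce_sum(x) > 0:    # Tensor condition: converted to tf.cond
    x = x * 2.0
  else:
    x = x * 0.0
  return x

print(maybe_abs(tf.constant([-1.0, -2.0]), True))   # [2. 4.]
print(maybe_abs(tf.constant([-1.0, -2.0]), False))  # zeros
```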

### Loops

AutoGraph will convert some `for` and `while` statements into the equivalent TensorFlow looping ops, like `tf.while_loop`. If not converted, the `for` or `while` loop is executed as a Python loop.

This substitution is made in the following situations:

- `for x in y`: if `y` is a Tensor, convert to `tf.while_loop`. In the special case where `y` is a `tf.data.Dataset`, a combination of `tf.data.Dataset` ops is generated.
- `while <condition>`: if `<condition>` is a Tensor, convert to `tf.while_loop`.

A Python loop executes during tracing, adding additional ops to the `tf.Graph` for every iteration of the loop.

A TensorFlow loop traces the body of the loop, and dynamically selects how many iterations to run at execution time.  The loop body only appears once in the generated `tf.Graph`.

See the [reference documentation](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/autograph/g3doc/reference/control_flow.md#while-statements) for additional restrictions on AutoGraph-converted `for` and `while` statements.
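The difference shows up in graph size. In this sketch (helper and function names are illustrative), a loop over Python `range(n)` is unrolled at trace time, so its graph grows with `n` and each new `n` triggers a retrace. A loop over `tf.range(n)` becomes a single while loop, so one graph serves every value of `n`.

```python
import tensorflow as tf

def count_nodes(fn, *args):
  # Number of nodes in the graph of the ConcreteFunction for these args.
  return len(fn.get_concrete_function(*args).graph.as_graph_def().node)

@tf.function
def python_loop(x, n):
  for _ in range(n):     # n is a Python int: unrolled during tracing
    x = x + 1.0
  return x

@tf.function
def tf_loop(x, n):
  for _ in tf.range(n):  # n is a Tensor: converted to tf.while_loop
    x = x + 1.0
  return x

x = tf.constant(0.0)
print(count_nodes(python_loop, x, 3), count_nodes(python_loop, x, 10))
print(count_nodes(tf_loop, x, tf.constant(3)), count_nodes(tf_loop, x, tf.constant(10)))
```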

#### Looping over Python data

A common pitfall is to loop over Python/NumPy data within a `tf.function`. This loop will execute during the tracing process, adding a copy of your model to the `tf.Graph` for each iteration of the loop.

If you want to wrap the entire training loop in `tf.function`, the safest way to do this is to wrap your data as a `tf.data.Dataset` so that AutoGraph will dynamically unroll the training loop.

In [17]:
def measure_graph_size(f, *args):
  g = f.get_concrete_function(*args).graph
  print("{}({}) contains {} nodes in its graph".format(
      f.__name__, ', '.join(map(str, args)), len(g.as_graph_def().node)))

@tf.function
def train(dataset):
  loss = tf.constant(0)
  for x, y in dataset:
    loss += tf.abs(y - x) # Some dummy computation.
  return loss

small_data = [(1, 1)] * 3
big_data = [(1, 1)] * 10
measure_graph_size(train, small_data)
measure_graph_size(train, big_data)

measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: small_data, (tf.int32, tf.int32)))
measure_graph_size(train, tf.data.Dataset.from_generator(
    lambda: big_data, (tf.int32, tf.int32)))

train([(1, 1), (1, 1), (1, 1)]) contains 11 nodes in its graph
train([(1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1), (1, 1)]) contains 32 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph
train(<FlatMapDataset shapes: (<unknown>, <unknown>), types: (tf.int32, tf.int32)>) contains 8 nodes in its graph


When wrapping Python/NumPy data in a Dataset, be mindful of `tf.data.Dataset.from_generator` versus `tf.data.Dataset.from_tensors`. The former will keep the data in Python and fetch it via `tf.py_function`, which can have performance implications, whereas the latter will bundle a copy of the data as one large `tf.constant()` node in the graph, which can have memory implications.
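As a sketch of the in-graph route, the pairs below are converted to tensors with `tf.data.Dataset.from_tensor_slices`, a close relative of `from_tensors` that splits the leading dimension into separate elements (`from_tensors` would instead yield the whole array as a single element). The data values and the function name `train_on_dataset` are illustrative; the loop body is the same dummy computation as above.

```python
import tensorflow as tf

@tf.function
def train_on_dataset(dataset):
  loss = tf.constant(0)
  for x, y in dataset:     # a Dataset loop: AutoGraph keeps it dynamic
    loss += tf.abs(y - x)  # dummy computation
  return loss

pairs = [(1, 3)] * 10      # illustrative Python data
xs, ys = zip(*pairs)       # split the pairs into two sequences
dataset = tf.data.Dataset.from_tensor_slices((list(xs), list(ys)))
print(train_on_dataset(dataset))  # tf.Tensor(20, shape=(), dtype=int32)
```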

Reading data from files via `TFRecordDataset`, `CsvDataset`, etc. is the most effective way to consume data, as then TensorFlow itself can manage the asynchronous loading and prefetching of data, without having to involve Python. To learn more, see the [tf.data guide](../../guide/data).

#### Accumulating values in a loop

A common pattern is to accumulate intermediate values from a loop. Normally, this is accomplished by appending to a Python list or adding entries to a Python dictionary. However, as these are Python side effects, they will not work as expected in a dynamically unrolled loop. Use `tf.TensorArray` to accumulate results from a dynamically unrolled loop.
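The smallest version of this pattern looks like the sketch below: each iteration writes one value into the `tf.TensorArray`, and `stack()` turns the collected values into a single tensor at the end. Note that `write` returns a new `TensorArray`, which must be reassigned. The function name `squares` is illustrative.

```python
import tensorflow as tf

@tf.function
def squares(n):
  ta = tf.TensorArray(tf.int32, size=n)
  for i in tf.range(n):      # converted to a TensorFlow while loop
    ta = ta.write(i, i * i)  # write returns an updated TensorArray
  return ta.stack()

print(squares(tf.constant(4)))  # [0 1 4 9]
```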

In [18]:
batch_size = 2
seq_len = 3
feature_size = 4

def rnn_step(inp, state):
  return inp + state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])
  max_seq_len = input_data.shape[0]

  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    states = states.write(i, state)
  return tf.transpose(states.stack(), [1, 0, 2])
  
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

<tf.Tensor: shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.30274212, 0.54034984, 0.6057179 , 0.96152544],
        [1.1768875 , 0.95808554, 0.61519766, 1.0193876 ],
        [2.107791  , 0.9707072 , 1.2463995 , 1.6324626 ]],

       [[0.77548623, 0.43522775, 0.35147822, 0.8738022 ],
        [0.78578734, 1.138827  , 0.75858307, 1.2443646 ],
        [1.3380113 , 1.1977221 , 0.9884018 , 1.6696664 ]]], dtype=float32)>

## Further reading

To learn about how to export and load a `Function`, see the [SavedModel guide](../../guide/saved_model). To learn more about graph optimizations that are performed after tracing, see the [Grappler guide](../../guide/graph_optimization). To learn how to optimize your data pipeline and profile your model, see the [Profiler guide](../../guide/profiler).