# tf.function
Source: https://www.tensorflow.org/alpha/tutorials/eager/tf_function

In [5]:
!pip install -q tensorflow==2.0.0-alpha0
import tensorflow as tf

In [2]:
# A function is like an op

@tf.function
def add(a, b):
  return a + b

add(tf.ones([2, 2]), tf.ones([2, 2]))  #  [[2., 2.], [2., 2.]]

<tf.Tensor: id=16, shape=(2, 2), dtype=float32, numpy=
array([[2., 2.],
       [2., 2.]], dtype=float32)>

A tf.function you define is just like a core TensorFlow operation: you can execute it eagerly, you can use it in a graph, it has gradients, etc.

In [3]:
# Functions have gradients

@tf.function
def add(a, b):
  return a + b

v = tf.Variable(1.0)
with tf.GradientTape() as tape:
  result = add(v, 1.0)
tape.gradient(result, v)

<tf.Tensor: id=44, shape=(), dtype=float32, numpy=1.0>

In [4]:
# You can use functions inside functions

@tf.function
def dense_layer(x, w, b):
  return add(tf.matmul(x, w), b)

dense_layer(tf.ones([3, 2]), tf.ones([2, 2]), tf.ones([2]))

<tf.Tensor: id=74, shape=(3, 2), dtype=float32, numpy=
array([[3., 3.],
       [3., 3.],
       [3., 3.]], dtype=float32)>

## Polymorphism

tf.function tries to be as generic as a Python function.
You can call Python functions with all sorts of signatures, and Python will usually do something reasonable. tf.function does this type of polymorphism for you even though the underlying TensorFlow graphs it generates are specific to the particular types in its signature.

You can call a function with arguments of different types to see what is happening.

In [6]:
# Functions are polymorphic

@tf.function
def add(a):
  return a + a

print("add 1", add(1))
print("add 1.1", add(1.1))
print("add string tensor", add(tf.constant("a")))
c = add.get_concrete_function(tf.TensorSpec(shape=None, dtype=tf.string))
c(a=tf.constant("a"))  # aa

add 1 tf.Tensor(2, shape=(), dtype=int32)
add 1.1 tf.Tensor(2.2, shape=(), dtype=float32)
add string tensor tf.Tensor(b'aa', shape=(), dtype=string)


<tf.Tensor: id=104, shape=(), dtype=string, numpy=b'aa'>

In [7]:
# Functions can be faster than eager code, for graphs with many small ops

import timeit
conv_layer = tf.keras.layers.Conv2D(100, 3)

@tf.function
def conv_fn(image):
  return conv_layer(image)

image = tf.zeros([1, 200, 200, 100])
# warm up
conv_layer(image); conv_fn(image)
print("Eager conv:", timeit.timeit(lambda: conv_layer(image), number=10))
print("Function conv:", timeit.timeit(lambda: conv_fn(image), number=10))
print("Note how there's not much difference in performance for convolutions")

lstm_cell = tf.keras.layers.LSTMCell(10)

@tf.function
def lstm_fn(input, state):
  return lstm_cell(input, state)

input = tf.zeros([10, 10])
state = [tf.zeros([10, 10])] * 2
# warm up
lstm_cell(input, state); lstm_fn(input, state)
print("eager lstm:", timeit.timeit(lambda: lstm_cell(input, state), number=10))
print("function lstm:", timeit.timeit(lambda: lstm_fn(input, state), number=10))

Eager conv: 3.617294017908095
Function conv: 2.0063116591468
Note how there's not much difference in performance for convolutions
eager lstm: 0.14929981073746657
function lstm: 0.012622805108856383


## State in tf.function

In [8]:
# Automatic control dependencies

a = tf.Variable(1.0)
b = tf.Variable(2.0)

@tf.function
def f(x, y):
  a.assign(y * b)
  b.assign_add(x * a)
  return a + b

f(1.0, 2.0)  # 10.0

<tf.Tensor: id=1610, shape=(), dtype=float32, numpy=10.0>

## Variables

In [7]:
@tf.function
def f(x):
    v = tf.Variable(1.0)
    v.assign_add(x)
    return v

f(1.) # Note: BROKEN, will throw exception


ValueError: tf.function-decorated function tried to create variables on non-first call.

In [None]:
# Non-ambiguous code is ok though

v = tf.Variable(1.0)

@tf.function
def f(x):
  return v.assign_add(x)

f(1.0)  # 2.0
f(2.0)  # 4.0

## Control flow and autograph

In [8]:
# Simple loop

@tf.function
def f(x):
    while tf.reduce_sum(x) > 1:
        tf.print(x)
        x = tf.tanh(x)
    return x

f(tf.random.uniform([10]))

[0.0860520601 0.272780418 0.211995602 ... 0.188304186 0.558368802 0.296167374]
[0.0858402774 0.266210109 0.20887582 ... 0.186109632 0.5067662 0.287801325]
[0.0856300518 0.260094821 0.205890208 ... 0.183990225 0.46742174 0.280109912]
[0.085421361 0.25438422 0.203029424 ... 0.181941777 0.436113775 0.273006797]
[0.0852141902 0.249035433 0.200284928 ... 0.1799604 0.410417974 0.266420424]
[0.0850085169 0.244011715 0.197649121 ... 0.178042516 0.388827533 0.260290891]
[0.0848043337 0.239281386 0.195114955 ... 0.176184788 0.370349 0.254567593]
[0.0846016183 0.234816864 0.192676052 ... 0.174384132 0.354296923 0.249207422]
[0.0844003484 0.230594084 0.190326616 ... 0.172637701 0.340180755 0.244173452]
[0.0842005 0.226591989 0.188061267 ... 0.170942813 0.327638716 0.239433855]
[0.0840020701 0.222791955 0.185875118 ... 0.16929695 0.316397429 0.234960914]
[0.0838050395 0.219177485 0.183763623 ... 0.167697847 0.306245834 0.230730489]
[0.0836093873 0.21573393 0.181722671 ... 0.166143283 0.297017932 0.

[0.0692668408 0.107235432 0.10220506 ... 0.0991777405 0.114063218 0.108505964]
[0.0691562742 0.106826261 0.101850659 ... 0.0988538265 0.1135711 0.108082123]
[0.0690462291 0.106421739 0.10149993 ... 0.0985330716 0.113085307 0.107663207]
[0.0689367056 0.106021777 0.101152785 ... 0.0982154161 0.112605691 0.107249118]
[0.0688277 0.1056263 0.100809187 ... 0.0979008153 0.112132132 0.106839776]
[0.0687192231 0.105235212 0.100469068 ... 0.0975892246 0.111664496 0.106435098]
[0.0686112493 0.10484843 0.100132369 ... 0.0972805917 0.11120268 0.106034994]
[0.0685037822 0.104465894 0.0997990444 ... 0.0969748646 0.110746548 0.105639368]
[0.0683968216 0.104087517 0.0994690135 ... 0.096672006 0.110295981 0.105248146]
[0.0682903603 0.103713222 0.0991422459 ... 0.0963719711 0.109850883 0.104861222]
[0.0681843907 0.103342943 0.0988186747 ... 0.0960747153 0.109411128 0.10447856]
[0.0680789128 0.102976605 0.0984982699 ... 0.0957801938 0.108976625 0.104100041]
[0.0679739267 0.102614142 0.0981809497 ... 0.095

<tf.Tensor: id=127, shape=(10,), dtype=float32, numpy=
array([0.06766185, 0.1015493 , 0.09724709, 0.10587054, 0.10569899,
       0.10761058, 0.10842628, 0.09462864, 0.10728898, 0.10262614],
      dtype=float32)>

In [9]:
# If you're curious you can inspect the code autograph generates.
# It feels like reading assembly language, though.

def f(x):
  while tf.reduce_sum(x) > 1:
    tf.print(x)
    x = tf.tanh(x)
  return x

print(tf.autograph.to_code(f))

from __future__ import print_function

def tf__f(x):
  try:
    with ag__.function_scope('f'):
      do_return = False
      retval_ = None

      def loop_test(x_1):
        with ag__.function_scope('loop_test'):
          return ag__.gt(ag__.converted_call('reduce_sum', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x_1,), {}), 1)

      def loop_body(x_1):
        with ag__.function_scope('loop_body'):
          with ag__.utils.control_dependency_on_returns(ag__.converted_call('print', tf, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=ag__.Feature.ALL, internal_convert_user_code=True), (x_1,), {})):
            x, tf_1 = ag__.utils.alias_tensors(x_1, tf)
            x = ag__.converted_call(

In [10]:
@tf.function
def f(x):
  for i in range(10):  # Static python loop, we'll not convert it
    do_stuff()
  for i in tf.range(10):  # depends on a tensor, we'll convert it

SyntaxError: unexpected EOF while parsing (<ipython-input-10-1aa0802c291f>, line 5)

### Tensorflow data structure

In [11]:
@tf.function
def f(x):
  for i in tf.range(10):
    tf.print(i)
    tf.Assert(i < 10, ["a"])
    x += x
  return x

f(10)

0
1
2
3
4
5
6
7
8
9


<tf.Tensor: id=190, shape=(), dtype=int32, numpy=10240>