In [1]:
import tensorflow as tf

Notes based on the TensorFlow tutorial "Better performance with tf.function":
https://www.tensorflow.org/tutorials/customization/performance

In [2]:
@tf.function
def while_tf_cond():
    """Count a scalar tensor down from 5 and return it.

    The loop condition (`counter > 0`) is computed from a tensor, and every
    pass through the body rebinds `counter` to a freshly computed tensor.

    Returns:
        The tensor left over once the condition `counter > 0` no longer holds.
    """
    counter = tf.constant(5)
    while counter > 0:
        # Each iteration produces a new tensor; tensors are not mutated in place.
        counter = counter - 1
    return counter

This is a valid function: it returns a tensor,
even though a new tensor is created inside the loop on every iteration.

In [None]:
@tf.function
def while_tf_true_tf_break(x):
    """Decrement `x` one step at a time until it equals zero, then return it.

    The `while` condition is a constant-true tensor, so the loop can only
    terminate through the inner `break`.

    Args:
        x: Starting value. The loop only ends when `x == 0` becomes true, so a
            value that never reaches exactly zero by repeated `- 1` steps may
            loop indefinitely.

    Returns:
        `x` as it is when the `break` fires.
    """
    while tf.constant(True):  # tf true: the loop condition is a tensor
        if x == 0:  # leave the loop once x has reached zero
            break
        x -= 1
    return x


For `while` loops, the loop condition must depend on a tensor.
The inner `break` or `return` conditions can then be either Python values or tensors.

In [None]:
@tf.function
def tf_for_py_break():
    """Sum the indices produced by `tf.range(5)`, stopping early at index 3.

    Returns:
        The running total accumulated before the loop exits.
    """
    x = 0
    for i in tf.range(5):  # tf for: the iterable is a tensor
        if i == 3:  # stop before adding index 3
            break
        x += i
    return x

This is the correct form for `for` loops: the iterable is a tensor (`tf.range`).

In [3]:
# To accumulate results from a dynamically unrolled loop, use tf.TensorArray.

# Dimensions of the toy input tensor: [batch_size, seq_len, feature_size].
batch_size, seq_len, feature_size = 2, 3, 4

def rnn_step(inp, state):
  """Toy recurrence: the new state is the current input added to the old state.

  Args:
    inp: Input for the current step.
    state: State carried over from the previous step.

  Returns:
    The result of `inp + state`.
  """
  next_state = inp + state
  return next_state

@tf.function
def dynamic_rnn(rnn_step, input_data, initial_state):
  """Apply `rnn_step` along the time axis and return the state after every step.

  Args:
    rnn_step: Callable `(inp, state) -> new_state`, invoked once per time step
      with that step's slice of the input and the previous state.
    input_data: Rank-3 tensor laid out as [batch, time, features].
    initial_state: State handed to the first `rnn_step` call.

  Returns:
    The per-step states stacked along the time axis and transposed back to
    batch-major layout, i.e. [batch, time, features].
  """
  # [batch, time, features] -> [time, batch, features]
  input_data = tf.transpose(input_data, [1, 0, 2])

  # Prefer the static sequence length. When the time dimension is unknown at
  # trace time (e.g. an input_signature with a None dimension) the static
  # shape entry is None, so fall back to the runtime length instead.
  max_seq_len = input_data.shape[0]
  if max_seq_len is None:
    max_seq_len = tf.shape(input_data)[0]

  # Results of a tensor-driven loop are accumulated in a TensorArray.
  # The element dtype is fixed to float32, so rnn_step must return float32.
  states = tf.TensorArray(tf.float32, size=max_seq_len)
  state = initial_state
  for i in tf.range(max_seq_len):
    state = rnn_step(input_data[i], state)
    # write() returns the updated TensorArray; it must be rebound each step.
    states = states.write(i, state)
  # [time, batch, features] -> [batch, time, features]
  return tf.transpose(states.stack(), [1, 0, 2])
  
# Demo: feed uniform random inputs of shape [batch_size, seq_len, feature_size]
# and an all-zero initial state. No seed is set, so the values differ per run.
dynamic_rnn(rnn_step,
            tf.random.uniform([batch_size, seq_len, feature_size]),
            tf.zeros([batch_size, feature_size]))

<tf.Tensor: id=86, shape=(2, 3, 4), dtype=float32, numpy=
array([[[0.7342559 , 0.669428  , 0.08792293, 0.78824925],
        [0.9890313 , 1.3537512 , 0.2650727 , 1.0636464 ],
        [1.8808205 , 1.8170239 , 0.6934947 , 1.0850322 ]],

       [[0.8474921 , 0.78943443, 0.6176113 , 0.60328925],
        [1.6546186 , 0.82788   , 0.6411681 , 0.947973  ],
        [1.8728838 , 1.6841171 , 1.191936  , 1.3726848 ]]], dtype=float32)>

**CRITICAL:** As with `tf.cond`, `tf.while_loop` also comes with a number of subtleties.

- Since a loop can execute 0 times, all tensors used downstream of the `while_loop` **must be initialized** above the loop.
- The **shape**/**dtypes** of all loop variables must stay consistent with each iteration.


In [7]:
@tf.function
def concat_with_padding():
  """Fill the first three rows of a [5, 10] zero tensor with ones, one row per step.

  Each iteration rebuilds the tensor from three pieces: the rows already
  filled, one new row of ones, and zero padding for the remainder, so the
  total is always `row + 1 + (4 - row) = 5` rows.

  Returns:
    The [5, 10] float tensor after three iterations.
  """
  padded = tf.zeros([5, 10])
  for row in tf.range(3):
    filled_rows = padded[:row]
    ones_row = tf.ones([1, 10])
    zero_tail = tf.zeros([4 - row, 10])
    padded = tf.concat([filled_rows, ones_row, zero_tail], axis=0)
    # The slice and padding sizes depend on the loop tensor, so restate the
    # static shape to keep the loop variable's shape the same every iteration.
    padded.set_shape([5, 10])
  return padded

# Demo: call the graph function so the notebook displays the resulting tensor.
concat_with_padding()

<tf.Tensor: id=284, shape=(5, 10), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>