In [2]:
import numpy as np
import tensorflow as tf

tf.enable_eager_execution()
tfe = tf.contrib.eager

This snippet references: https://www.tensorflow.org/tutorials/eager/automatic_differentiation

- Every differentiable operation in TensorFlow has an associated gradient function (e.g., the gradient function of `tf.square(x)` is a function that returns `2.0 * x`).


- To compute the gradient of a user-defined function (composed of TensorFlow primitive operations), TensorFlow first "records" onto a tape all the operations applied to compute the function's output.


- It then uses that tape and the gradient functions associated with each primitive operation to compute the gradients of the user-defined function using reverse-mode differentiation.


- When it is inconvenient to wrap all computation in a function, the intermediate operations can instead be recorded inside an explicit `tf.GradientTape` context.
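
The record-then-replay idea in the bullets above can be sketched without TensorFlow. The block below is a toy illustration only, not TensorFlow's implementation: `Value`, `ToyTape`, and their methods are hypothetical names invented for this example, and only scalar `square` and `multiply` are supported.

```python
class Value:
    """Hypothetical scalar holder, standing in for a tensor."""
    def __init__(self, data):
        self.data = float(data)


class ToyTape:
    """Hypothetical tape: records each primitive op with its gradient function."""
    def __init__(self):
        self.ops = []  # (output, inputs, grad_fn), in execution order

    def square(self, x):
        out = Value(x.data ** 2)
        # Gradient function of square: upstream gradient times 2 * x.
        self.ops.append((out, (x,), lambda g: (g * 2.0 * x.data,)))
        return out

    def multiply(self, a, b):
        out = Value(a.data * b.data)
        # Gradient function of a * b: (g * b, g * a).
        self.ops.append((out, (a, b), lambda g: (g * b.data, g * a.data)))
        return out

    def gradient(self, target, source):
        grads = {id(target): 1.0}  # d(target)/d(target) = 1
        # Reverse mode: replay the recorded ops from last to first.
        for out, inputs, grad_fn in reversed(self.ops):
            upstream = grads.get(id(out))
            if upstream is None:
                continue  # this op does not feed into the target
            for inp, g in zip(inputs, grad_fn(upstream)):
                # Accumulate, since one value may feed several ops.
                grads[id(inp)] = grads.get(id(inp), 0.0) + g
        return grads.get(id(source))


tape = ToyTape()
x = Value(3.0)
y = tape.square(x)          # y = x^2
z = tape.multiply(y, x)     # z = x^3
print(tape.gradient(z, x))  # dz/dx = 3 * x^2, evaluated at x = 3
print(tape.gradient(z, y))  # dz/dy = x, evaluated at x = 3
```

Because `x` feeds both `square` and `multiply`, its gradient is the sum of two contributions, which is why `gradient` accumulates rather than overwrites.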

### 1. Gradient Tape Context Wrapper

In [13]:
x = tf.ones((2, 2))

# gradient tape context; persistent=True allows gradient() to be called more than once
with tf.GradientTape(persistent=True) as t:
    t.watch(x)  # x is a plain tensor, so ask the tape to trace operations on it
    y = tf.reduce_sum(x)
    z = tf.multiply(y, y)

# Use the gradient tape to compute the derivative of z with respect to y
dz_dy = t.gradient(z, y)
# Use the gradient tape to compute the derivative of z with respect to x
dz_dx = t.gradient(z, x)

print('dz/dy:',dz_dy.numpy(),'\n')
print('dz/dx:\n',dz_dx.numpy())

dz/dy: 8.0 

dz/dx:
 [[8. 8.]
 [8. 8.]]
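
One way to sanity-check the tape's result is a finite-difference approximation. The sketch below uses plain NumPy rather than TensorFlow; `f` and `numerical_gradient` are hypothetical helper names, and `f` simply re-implements `z = (sum of x)^2` from the cell above.

```python
import numpy as np


def f(x):
    """z = (sum of all entries of x) ** 2, mirroring the cell above."""
    y = np.sum(x)
    return y * y


def numerical_gradient(func, x, eps=1e-6):
    """Central-difference estimate of d func / d x, one entry at a time."""
    grad = np.zeros_like(x)
    for idx in np.ndindex(x.shape):
        step = np.zeros_like(x)
        step[idx] = eps
        grad[idx] = (func(x + step) - func(x - step)) / (2.0 * eps)
    return grad


x = np.ones((2, 2))
approx = numerical_gradient(f, x)
# By the chain rule, dz/dx_ij = 2 * y * 1 = 2 * sum(x) for every entry.
analytic = 2.0 * np.sum(x) * np.ones_like(x)
print(np.allclose(approx, analytic, atol=1e-4))
```

Every entry of `x` contributes to `y` with weight 1, so each entry receives the same gradient `2 * y`, which matches the uniform matrix printed above.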


Higher-order gradients can be computed by nesting one gradient tape context within another, as shown below:

In [17]:
x = tf.constant(1.0)  # Convert the Python 1.0 to a Tensor object

with tf.GradientTape() as t:
    t.watch(x)  # watch x on the outer tape as well, so dy_dx can be differentiated with respect to x
    with tf.GradientTape() as t2:
        t2.watch(x)
        y = x * x * x
    # Compute the gradient inside the 't' context manager, which means 
    # this gradient computation is differentiable under t
    dy_dx = t2.gradient(y, x)
    
d2y_dx2 = t.gradient(dy_dx, x)

print('dy/dx:', dy_dx.numpy())
print('d2y/dx2:', d2y_dx2.numpy())

dy/dx: 3.0
d2y/dx2: 6.0
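
As a check on the printed values, the same derivatives can be worked out by hand for $y = x^3$:

$$
\frac{dy}{dx} = 3x^2, \qquad \frac{d^2y}{dx^2} = 6x
$$

Evaluated at $x = 1$, these give $3$ and $6$, in agreement with the output above.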
