<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="300" height="300" align="center"/><br>

Welcome to another JAX tutorial. I hope you all have been enjoying the JAX Tutorials so far. If you haven't gone through the previous tutorials, I highly suggest going through them. Here are the links:

1. [TF_JAX_Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
2. [TF_JAX_Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)
3. [TF_JAX_Tutorials - Part 3](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3)
4. [TF_JAX_Tutorials - Part 4 (JAX and DeviceArray)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray)
5. [TF_JAX_Tutorials - Part 5 (Pure Functions in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-5-pure-functions-in-jax/)
6. [TF_JAX_Tutorials - Part 6 (PRNG in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-6-prng-in-jax/)
7. [TF_JAX_Tutorials - Part 7 (JIT in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-7-jit-in-jax)
8. [TF_JAX_Tutorials - Part 8 (Vmap and Pmap)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-8-vmap-pmap)


Today, we are going to look into another important concept: automatic differentiation. We have already seen [automatic differentiation in TensorFlow](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3). The idea of automatic differentiation is pretty similar across frameworks, but IMO JAX does it better than all of them.
One question people often ask me is: if the framework already takes care of autodiff, why should I bother learning about it?

It is important to learn these concepts because if you want to implement something derived from them, or something that isn't provided out of the box, you need a clear understanding of the underlying mechanism first.

**Note:** I have already covered the fundamentals of automatic differentiation in [this](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3) notebook, so I will skip those details here and work directly on examples.

In [1]:
import time
import numpy as np

import jax
import jax.numpy as jnp
from jax import random
from jax import make_jaxpr
from jax import vmap, pmap, jit
from jax import grad, value_and_grad
from jax.test_util import check_grads


%config IPCompleter.use_jedi = False

# Gradients

I will try to include all the examples we saw in the [TensorFlow Autodiff Notebook](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3) so that you can compare them side by side and learn the differences between them.

The `grad` function in JAX computes gradients. Since the basic idea behind JAX is function composition, `grad` also takes a callable as input and returns a callable. So whenever we want to compute gradients, we first pass a callable to `grad`. Let's take an example to make this clearer.

In [2]:
def product(x, y):
    z = x * y
    return z


x = 3.0
y = 4.0

z = product(x, y)

print(f"Input Variable x: {x}")
print(f"Input Variable y: {y}")
print(f"Product z: {z}\n")

# dz / dx
dx = grad(product, argnums=0)(x, y)
print(f"Gradient of z wrt x: {dx}")

# dz / dy
dy = grad(product, argnums=1)(x, y)
print(f"Gradient of z wrt y: {dy}")

Input Variable x: 3.0
Input Variable y: 4.0
Product z: 12.0

Gradient of z wrt x: 4.0
Gradient of z wrt y: 3.0


Let us break down the above example and try to understand the gradients calculation step by step.

1. We have a function named `product(...)` that takes two positional arguments as input and returns their product
2. We pass the `product(...)` function to `grad` to compute the gradients. The `argnums` argument tells `grad` to differentiate the function with respect to the `i`-th positional argument. Hence we pass `0` and `1` to compute the gradients with respect to `x` and `y`, respectively.
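`argnums` also accepts a tuple of positions. As a minimal sketch (redefining the same `product` function so the cell is self-contained), passing `argnums=(0, 1)` returns a tuple with one gradient per selected argument, in the same order as `argnums`:

```python
import jax


def product(x, y):
    return x * y


# argnums=(0, 1) differentiates wrt both positional arguments at once;
# the result is a tuple ordered like argnums: (dz/dx, dz/dy)
dx, dy = jax.grad(product, argnums=(0, 1))(3.0, 4.0)
print(dx, dy)  # 4.0 3.0
```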

You can also compute the value of the function and its gradients in one go, using the `value_and_grad(...)` function.

In [3]:
z, dx = value_and_grad(product, argnums=0)(x, y)
print("Product z:", z)
print(f"Gradient of z wrt x: {dx}")

Product z: 12.0
Gradient of z wrt x: 4.0


# Jaxprs and `grad`

Since function transformations in JAX compose, we can pass the function returned by `grad` to `make_jaxpr` to see what is going on behind the scenes. Let's take an example.

In [4]:
# Differentiating wrt first positional argument `x`
print("Differentiating wrt x")
print(make_jaxpr(grad(product, argnums=0))(x, y))


# Differentiating wrt second positional argument `y`
print("\nDifferentiating wrt y")
print(make_jaxpr(grad(product, argnums=1))(x, y))

Differentiating wrt x
{ lambda  ; a b.
  let _ = mul a b
      c = mul 1.0 b
  in (c,) }

Differentiating wrt y
{ lambda  ; a b.
  let _ = mul a b
      c = mul a 1.0
  in (c,) }


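Notice that in each case, the first line `_ = mul a b` is the forward computation, whose result is discarded. The argument we differentiate with respect to is then replaced by the constant `1.0` (the seed for the scalar output), so the gradient reduces to the other argument: `c = mul 1.0 b` gives `dz/dx = y`, and `c = mul a 1.0` gives `dz/dy = x`.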
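We can make the role of that `1.0` explicit with `jax.vjp`, which returns the output of a function together with a "pullback" that maps an output cotangent to one cotangent per input. A minimal sketch: seeding the pullback with ones shaped like the (scalar) output reproduces what `grad` computes.

```python
import jax
import jax.numpy as jnp


def product(x, y):
    return x * y


# Forward pass + a pullback function for the reverse pass
z, pullback = jax.vjp(product, 3.0, 4.0)

# Seeding with 1.0 (ones_like the scalar output) is what grad does internally
dx, dy = pullback(jnp.ones_like(z))
print(z, dx, dy)  # 12.0 4.0 3.0
```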

# Stopping gradient computation

Sometimes we do not want gradients to flow through some of the variables involved in a computation. In that case, we need to tell JAX explicitly which variables to treat as constants, using `jax.lax.stop_gradient`. We will look into more complex examples of this later on, but for now, I will modify our `product(...)` function so that gradients do not flow through `y`.

In [5]:
# Modified product function. Explicitly stopping the
# flow of the gradients through `y`
def product_stop_grad(x, y):
    z = x * jax.lax.stop_gradient(y)
    return z

In [6]:
# Differentiating wrt y. This should return 0
grad(product_stop_grad, argnums=1)(x, y)

DeviceArray(0., dtype=float32)
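To make the behaviour concrete, here is a short sketch reusing `product_stop_grad`: `stop_gradient` leaves the forward value untouched and only blocks the derivative through `y`, so the gradient with respect to `x` is unaffected.

```python
import jax


def product_stop_grad(x, y):
    # y is treated as a constant by autodiff, but its value is still used
    return x * jax.lax.stop_gradient(y)


x, y = 3.0, 4.0
value = product_stop_grad(x, y)                      # forward value unchanged
dx = jax.grad(product_stop_grad, argnums=0)(x, y)    # gradient still flows through x
dy = jax.grad(product_stop_grad, argnums=1)(x, y)    # ... but not through y
print(value, dx, dy)  # 12.0 4.0 0.0
```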

# Gradients per sample

`grad` (which uses reverse mode) is only defined for functions that output a scalar. This is exactly the situation when backpropagating from a loss value to update the parameters of a machine learning model: the loss is always a scalar. But what if your function returns a batch and you want the gradient for each sample in that batch?

These things are pretty straightforward in JAX (thanks to `vmap` and `pmap`). In the next example, we will perform these steps in order:

1. Write a function that takes an input and applies `tanh` on the input. This function is written in a way that works on a single example (Remember the philosophy behind `vmap` from the last tutorial?)
2. We will check if we can compute the gradients on a single example
3. We will then pass a batch of inputs and compute the gradients for the whole batch

In [7]:
def activate(x):
    """Applies tanh activation."""
    return jnp.tanh(x)


# Check if we can compute the gradients for a single example
grads_single_example = grad(activate)(0.5)
print("Gradient for a single input x=0.5: ", grads_single_example)


# Now we will generate a batch of random inputs, and will pass
# those inputs to our activate function. And we will also try to
# calculate the grads on the same batch in the same way as above

# Always use the PRNG
key = random.PRNGKey(1234)
x = random.normal(key=key, shape=(5,))
activations = activate(x)

print("\nTrying to compute gradients on a batch")
print("Input shape: ", x.shape)
print("Output shape: ", activations.shape)

try:
    grads_batch = grad(activate)(x)
    print("Gradients for the batch: ", grads_batch)
except Exception as ex:
    print(type(ex).__name__, ex)

Gradient for a single input x=0.5:  0.7864477

Trying to compute gradients on a batch
Input shape:  (5,)
Output shape:  (5,)
TypeError Gradient only defined for scalar-output functions. Output had shape: (5,).


So what's the solution then? Well, `vmap` and `pmap` are the solution to almost everything. Let's see it in action.

In [8]:
grads_batch = vmap(grad(activate))(x)
print("Gradients for the batch: ", grads_batch)

Gradients for the batch:  [0.48228705 0.45585024 0.99329686 0.0953269  0.8153717 ]


Let's break down all the modifications we did above to achieve the desired results.

1. `grad(activate)(...)` works for a single example
2. Wrapping it in `vmap` maps the function over a batch dimension (axis 0 by default) of the inputs and outputs

It's that simple to go from a single example to a batch and vice versa: write the function for one example, then use `vmap`. Let's see what the `jaxpr` for this transformation looks like.

In [9]:
make_jaxpr(vmap(grad(activate)))(x)

{ lambda  ; a.
  let b = tanh a
      c = sub 1.0 b
      d = mul 1.0 c
      e = mul d b
      f = add_any d e
  in (f,) }
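In practice, "per-sample gradients" usually means gradients of a per-example loss with respect to shared model parameters. A minimal sketch, assuming a hypothetical linear model `predict` and squared-error `loss_single` (neither is from the original notebook): `in_axes=(None, 0, 0)` tells `vmap` not to map over the parameters `w`, and to map over axis 0 of the inputs and targets.

```python
import jax
import jax.numpy as jnp
from jax import random


# Hypothetical linear model and per-example loss, for illustration only
def predict(w, x):
    return jnp.dot(x, w)


def loss_single(w, x, y):
    # Squared error for ONE example -> scalar output, so grad applies
    return (predict(w, x) - y) ** 2


key_w, key_x, key_y = random.split(random.PRNGKey(0), 3)
w = random.normal(key_w, (3,))       # shared parameters
xs = random.normal(key_x, (8, 3))    # batch of 8 inputs
ys = random.normal(key_y, (8,))      # batch of 8 targets

# grad wrt w (argnums=0 by default), mapped over the batch only
per_example_grads = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))(w, xs, ys)
print(per_example_grads.shape)  # (8, 3)


def mean_loss(w):
    return jnp.mean(jax.vmap(loss_single, in_axes=(None, 0, 0))(w, xs, ys))


# Averaging the per-example gradients recovers the gradient of the mean loss
matches = jnp.allclose(per_example_grads.mean(axis=0), jax.grad(mean_loss)(w), atol=1e-5)
print(matches)
```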

# Composition of other transformations

`grad` can be composed with any other transformation. We already saw `vmap` combined with `grad`; let's now apply `jit` to the above transformation to make it more efficient.

In [10]:
jitted_grads_batch = jit(vmap(grad(activate)))

for _ in range(3):
    start_time = time.time()
    print("Gradients for the batch: ", jitted_grads_batch(x))
    print(f"Time taken: {time.time() - start_time:.2f} seconds")
    print("="*50)
    print()

Gradients for the batch:  [0.48228705 0.45585027 0.99329686 0.09532695 0.8153717 ]
Time taken: 0.02 seconds

Gradients for the batch:  [0.48228705 0.45585027 0.99329686 0.09532695 0.8153717 ]
Time taken: 0.00 seconds

Gradients for the batch:  [0.48228705 0.45585027 0.99329686 0.09532695 0.8153717 ]
Time taken: 0.00 seconds
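The first call is noticeably slower because it includes tracing and XLA compilation; later calls reuse the compiled function. Also, JAX dispatches work asynchronously, so a more careful timing loop excludes the warm-up call and calls `.block_until_ready()` on each result so the timer only stops once the computation has actually finished. A minimal sketch (the number of runs is an arbitrary choice):

```python
import time

import jax
import jax.numpy as jnp
from jax import random


def activate(x):
    return jnp.tanh(x)


jitted_grads_batch = jax.jit(jax.vmap(jax.grad(activate)))
x = random.normal(random.PRNGKey(1234), shape=(5,))

# Warm-up call: triggers tracing + compilation, excluded from the timing
jitted_grads_batch(x).block_until_ready()

n_runs = 100
start = time.perf_counter()
for _ in range(n_runs):
    # Wait for the asynchronously dispatched computation to finish
    out = jitted_grads_batch(x).block_until_ready()
elapsed = time.perf_counter() - start
print(f"Average time per call: {elapsed / n_runs * 1e6:.1f} microseconds")
```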



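# Validating gradients with finite differences
Often we want to verify our gradients against finite differences to double-check that everything is right. Because this is such a common sanity check when working with derivatives, JAX provides a convenient function, `check_grads`, that compares the derivatives computed by autodiff against finite-difference approximations up to a given order (by default in both forward and reverse mode). Note that below we pass `jitted_grads_batch` itself, so `check_grads` verifies the derivatives of our gradient function, i.e. the second derivatives of `activate`. Let's take a look.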

In [11]:
try:
    check_grads(jitted_grads_batch, (x,),  order=1)
    print("Gradients match the finite-difference approximation")
except Exception as ex:
    print(type(ex).__name__, ex)

Gradients match the finite-difference approximation
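For intuition, the kind of comparison `check_grads` performs can be approximated by hand with a central difference, (f(x + eps) - f(x - eps)) / (2 * eps). The sketch below does this for `activate` at a single point; the step size and tolerance are assumptions chosen for float32, not the values `check_grads` uses internally.

```python
import jax
import jax.numpy as jnp


def activate(x):
    return jnp.tanh(x)


x0 = 0.5
eps = 1e-3  # step size: much smaller values amplify float32 rounding error

# Central finite-difference approximation of f'(x0)
fd = (activate(x0 + eps) - activate(x0 - eps)) / (2 * eps)
# Derivative computed by autodiff
ad = jax.grad(activate)(x0)

print(fd, ad)
print(jnp.abs(fd - ad) < 1e-3)  # True
```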


# Higher Order Gradients

The `grad` function takes a callable as input and returns another callable. We can apply `grad` to the returned function again and again to compute derivatives of any order. Let's take an example to see it in action, using our `activate(...)` function.

In [12]:
x = 0.5

print("First order derivative: ", grad(activate)(x))
print("Second order derivative: ", grad(grad(activate))(x))
print("Third order derivative: ", grad(grad(grad(activate)))(x))

First order derivative:  0.7864477
Second order derivative:  -0.726862
Third order derivative:  -0.5652091
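These values can be checked against the closed-form derivatives of tanh. Writing $t = \tanh(x)$, the first three derivatives are $1 - t^2$, $-2t(1 - t^2)$ and $(6t^2 - 2)(1 - t^2)$. A short sketch comparing them with the nested `grad` calls:

```python
import jax
import jax.numpy as jnp


def activate(x):
    return jnp.tanh(x)


x = 0.5
t = jnp.tanh(x)

# Closed-form 1st, 2nd and 3rd derivatives of tanh, in terms of t = tanh(x)
analytic = [1 - t**2, -2 * t * (1 - t**2), (6 * t**2 - 2) * (1 - t**2)]

results = []
f = activate
for expected in analytic:
    f = jax.grad(f)  # one more level of differentiation each iteration
    results.append(bool(jnp.allclose(f(x), expected, atol=1e-5)))

print(results)  # [True, True, True]
```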


# Gradients and numerical stability

Underflow and overflow are common problems, especially when computing gradients. We will take an example (this one is straight from the JAX docs, and it's a pretty good one) to illustrate how we can run into numerical instability and how JAX helps you overcome it.

In [13]:
# An example of a mathematical operation in your workflow
def log1pexp(x):
    """Implements log(1 + exp(x))"""
    return jnp.log(1. + jnp.exp(x))

What happens when you compute the gradients for some value?

In [14]:
# This works fine
print("Gradients for a small value of x: ", grad(log1pexp)(5.0))

Gradients for a small value of x:  0.9933072


In [15]:
# But what about for very large values of x for which the
# exponent operation will explode
print("Gradients for a large value of x: ", grad(log1pexp)(500.0))

Gradients for a large value of x:  nan


Woah! What just happened? Let's break it down to understand what the expected output is and what is going on behind the scenes in JAX that made it return `nan`. We know that the derivative of the above function can be written like this:

<div align="center">$\frac{\mathrm{d}}{\mathrm{d}x}\log(1 + e^{x}) = \frac{e^{x}}{1 + e^{x}}$</div><br>

For very large values of x, you would expect the derivative to be 1, since $\frac{e^{x}}{1 + e^{x}}$ approaches 1; but when we combined `grad` with our implementation, it returned `nan`. To get more insight, we can break down the gradient computation by looking at the `jaxpr` of the transformation.

In [16]:
make_jaxpr(grad(log1pexp))(500.0)

{ lambda  ; a.
  let b = exp a
      c = add b 1.0
      _ = log c
      d = div 1.0 c
      e = mul d b
  in (e,) }

If you take a closer look, you will notice that the computation is equivalent to this:

<div align="center">$\frac{1}{1 + e^{x}} \cdot e^{x}$</div><br>
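To see where the `nan` comes from, we can evaluate the two factors of that product at x = 500 in float32 (JAX's default floating-point dtype), in the same order as the jaxpr above:

```python
import jax.numpy as jnp

x = jnp.float32(500.0)

exp_x = jnp.exp(x)          # overflows: the float32 maximum is about 3.4e38
inv = 1.0 / (1.0 + exp_x)   # 1 / inf -> 0.0
print(exp_x, inv, inv * exp_x)  # inf 0.0 nan
```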

For large values of x, $e^{x}$ overflows to `inf` in float32, so $\frac{1}{1 + e^{x}}$ becomes `0`, and the product `0 * inf` evaluates to `nan`, as we saw above. A human knows how to compute the gradient correctly in this case, but JAX doesn't: it simply applies the standard differentiation rule of each primitive. So how do we tell JAX that our function should be differentiated in the way we want? We can achieve this using the `custom_jvp` or `custom_vjp` functions in JAX. Let's see it in action.

In [17]:
from jax import custom_jvp

@custom_jvp
def log1pexp(x):
    """Implements log(1 + exp(x))"""
    return jnp.log(1. + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    """Tells JAX to differentiate the function in the way we want."""
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # This is where we define the correct way to compute gradients
    ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot
    return ans, ans_dot

In [18]:
# Let's now compute the gradients for large values
print("Gradients for a large value of x: ", grad(log1pexp)(500.0))

Gradients for a large value of x:  1.0


In [19]:
# What about the Jaxpr?
make_jaxpr(grad(log1pexp))(500.0)

{ lambda  ; a.
  let _ = custom_jvp_call_jaxpr[ fun_jaxpr={ lambda  ; a.
                                             let b = exp a
                                                 c = add b 1.0
                                                 d = log c
                                             in (d,) }
                                 jvp_jaxpr_thunk=<function _memoize.<locals>.memoized at 0x7f99201a3200>
                                 num_consts=0 ] a
      b = exp a
      c = add b 1.0
      d = div 1.0 c
      e = sub 1.0 d
      f = mul e 1.0
  in (f,) }

Let's break down the steps we did to achieve the expected results.

1. We decorated our `log1pexp(...)` function with `custom_jvp`, which lets us supply our own Jacobian-vector product (forward-mode) rule. JAX derives the reverse-mode rule used by `grad` from this JVP rule automatically.
2. We then defined `log1pexp_jvp(...)`, which specifies how the derivative should be computed. Focus on this line of code in that function: `ans_dot = (1 - 1/(1 + jnp.exp(x))) * x_dot`. All we are doing is rearranging the derivative as follows:

<div align="center">$\frac{\mathrm{d}}{\mathrm{d}x}\log(1 + e^{x}) = 1 - \frac{1}{1 + e^{x}}$</div><br>

For large x, `jnp.exp(x)` still overflows to `inf`, but now `1/(1 + inf)` is `0` and the result is `1 - 0 = 1` instead of `0 * inf = nan`.

3. We decorated `log1pexp_jvp(...)` with `log1pexp.defjvp` to register it as the JVP rule for `log1pexp`, so JAX uses our function instead of differentiating the body of `log1pexp` primitive by primitive.
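Two caveats are worth checking. First, `custom_jvp` only changes how the derivative is computed: the primal value `log1pexp(500.0)` still evaluates `jnp.log(1. + jnp.exp(x))` and overflows to `inf`. Second, for this particular function JAX already ships a stable implementation, `jax.nn.softplus`, which is built on `jnp.logaddexp`. The sketch below compares the two; `jax.nn.softplus` is a swapped-in alternative, not part of the approach shown above.

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp


@custom_jvp
def log1pexp(x):
    """Implements log(1 + exp(x))"""
    return jnp.log(1. + jnp.exp(x))


@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents
    ans = log1pexp(x)
    # Stable form of the derivative: 1 - 1 / (1 + e^x)
    ans_dot = (1 - 1 / (1 + jnp.exp(x))) * x_dot
    return ans, ans_dot


x = 500.0
# Custom rule: stable gradient, but the forward value still overflows
print(log1pexp(x), jax.grad(log1pexp)(x))  # inf 1.0
# Built-in softplus: stable value and stable gradient
print(jax.nn.softplus(x), jax.grad(jax.nn.softplus)(x))  # 500.0 1.0
```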


That's it for this tutorial, folks! Many things were left out of this tutorial on purpose. For example, we didn't cover forward mode and reverse mode in detail because they are beyond the scope of this notebook. If you want to understand those concepts, I highly suggest going through this [Advanced Autodiff Documentation](https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html). There is one last fundamental JAX concept remaining in this series, which we will cover in the upcoming tutorial.

# References

1. https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html
2. https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html