# Autograd: automatic differentiation

In JAX, the `grad` function takes a function and returns a new function that computes its gradient.
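Here is a minimal sketch of that idea on a scalar function. The example function `f` is hypothetical and chosen only because its derivative is easy to check by hand: for $f(x) = x^3$ we expect $f'(2) = 3 \cdot 2^2 = 12$.

```python
import jax.numpy as jnp
from jax import grad

def f(x):
    # Hypothetical example: f(x) = x^3, so f'(x) = 3x^2
    return x ** 3

df = grad(f)     # df is a new Python function, not a number
print(df(2.0))   # evaluates f'(2.0) = 12.0
```

Note that `grad` does not evaluate anything on its own. It builds the derivative function, which you then call like any other function.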

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap, jacobian, value_and_grad
from jax import random

key = random.PRNGKey(0)

Create an array:

In [2]:
# Create a 2x2 float array. Unlike a PyTorch tensor with requires_grad=True,
# a JAX array carries no gradient state; gradients come from transforming functions
x = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
print(x)

[[1. 2.]
 [3. 4.]]


Apply an operation to the array:

In [3]:
# We can apply an operation like this, but unlike a PyTorch tensor,
# a JAX array is a plain value and records no history of the operations applied to it
y = x - 2
print(y)

[[-1.  0.]
 [ 1.  2.]]


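So how do we differentiate `y`? Since `y` is itself a matrix, differentiating it w.r.t. the matrix `x` means taking the derivative of every entry `y_ij` with respect to every entry `x_kl`, where the first index is the row and the second is the column. The result is a Jacobian of shape `(2, 2, 2, 2)`, with `J[i, j, k, l]` = ∂y_ij/∂x_kl. Because `grad` only handles scalar-valued functions, we compute this with the `jacobian` function instead.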
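To see why `grad` alone is not enough here, a short sketch: calling `grad` on the matrix-valued `subtract_2` raises a `TypeError`, because JAX only defines `grad` for functions whose output is a scalar. The `raised` flag is just a helper for this illustration.

```python
import jax.numpy as jnp
from jax import grad

def subtract_2(X):
    return jnp.subtract(X, 2)

x = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)

raised = False
try:
    grad(subtract_2)(x)
except TypeError as err:
    # subtract_2 returns a (2, 2) array, not a scalar, so grad rejects it
    raised = True
    print(err)
```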

In [4]:
def subtract_2(X):
    return jnp.subtract(X, 2)

jac_x = jacobian(subtract_2, argnums=0)(x)

In [5]:
print(jac_x)
print(jac_x.shape)

[[[[1. 0.]
   [0. 0.]]

  [[0. 1.]
   [0. 0.]]]


 [[[0. 0.]
   [1. 0.]]

  [[0. 0.]
   [0. 1.]]]]
(2, 2, 2, 2)


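Work it out by hand to check that this is correct: since $y_{ij} = x_{ij} - 2$ depends only on $x_{ij}$, we get $\partial y_{ij} / \partial x_{kl} = 1$ when $(k, l) = (i, j)$ and $0$ otherwise. So each slice `jac_x[i][j]` is a 2x2 matrix with a single 1 at position `(i, j)`: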

In [6]:
jac_x[0][0]

DeviceArray([[1., 0.],
             [0., 0.]], dtype=float32)

In [7]:
jac_x[0][1]

DeviceArray([[0., 1.],
             [0., 0.]], dtype=float32)
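The hand calculation can also be checked in one line, as a sketch: $J_{ijkl} = \delta_{ik}\delta_{jl}$ is exactly the 4x4 identity matrix reshaped to `(2, 2, 2, 2)`, because row-major flattening maps index `(i, j)` to `2*i + j`.

```python
import jax.numpy as jnp
from jax import jacobian

x = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
J = jacobian(lambda X: jnp.subtract(X, 2))(x)

# J[i, j, k, l] == 1 exactly when (i, j) == (k, l), i.e. a reshaped identity
expected = jnp.eye(4, dtype=jnp.float32).reshape(2, 2, 2, 2)
print(jnp.array_equal(J, expected))  # True
```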

In [8]:
## Let's do more operations

def self_mul_times_3(X):
    return X * X * 3

def loss(X):
    return jnp.mean(self_mul_times_3(subtract_2(X)))

a = loss(x)
print(a)

4.5


Because `loss` returns a scalar, we can now use the `grad` function to get the gradient w.r.t. `x`. The result has the same shape as `x`:

In [9]:
grad_x = grad(loss, argnums=0)(x)

print(grad_x)

[[-1.5  0. ]
 [ 1.5  3. ]]
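We can confirm this by hand. With four entries, $L(X) = \frac{1}{4}\sum_{ij} 3(x_{ij} - 2)^2$, so $\partial L / \partial x_{ij} = \frac{6}{4}(x_{ij} - 2) = 1.5\,(x_{ij} - 2)$. A minimal sketch comparing that formula with `grad`, with `loss` rewritten inline so the block runs on its own:

```python
import jax.numpy as jnp
from jax import grad

def loss(X):
    # Same as the notebook's loss: mean of 3 * (X - 2)^2
    return jnp.mean(3 * (X - 2) ** 2)

x = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
autodiff = grad(loss)(x)
by_hand = 1.5 * (x - 2)   # dL/dx_ij = 6 * (x_ij - 2) / 4
print(jnp.allclose(autodiff, by_hand))  # True
```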


We can also get both the value and the gradient from a single call with the `jax.value_and_grad` function:

In [10]:
value_a, grad_x = value_and_grad(loss, argnums=0)(x)
print(value_a)
print(grad_x)

4.5
[[-1.5  0. ]
 [ 1.5  3. ]]
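Since `value_and_grad` returns an ordinary function, it composes with the other transformations imported at the top, such as `jit`. A small sketch checking that the pair matches separate `loss` and `grad` calls, with and without compilation:

```python
import jax.numpy as jnp
from jax import grad, jit, value_and_grad

def loss(X):
    # Same as the notebook's loss: mean of 3 * (X - 2)^2
    return jnp.mean(3 * (X - 2) ** 2)

x = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)

# One call returns (loss(x), grad(loss)(x))
val, g = value_and_grad(loss)(x)
# The same transformed function, compiled with jit
val_jit, g_jit = jit(value_and_grad(loss))(x)

print(float(val), float(val_jit))  # 4.5 4.5
print(jnp.allclose(g, grad(loss)(x)), jnp.allclose(g, g_jit))
```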
