Just After eXecution (JAX) is a library developed at Google for numerical computing. It's quite similar to NumPy, but has a number of advantages over traditional NumPy, such as:

- JIT Compilation
- Automatic Differentiation
- GPU/TPU support
- Auto-vectorization
- Better Random Numbers
- Supported by Keras (3.0)

and so on.

JAX is a world of its own and can't possibly be covered properly here, so this is really just an introduction. Let's begin by importing it:

In [7]:
import jax

## JAX's NumPy

JAX uses its own version of NumPy. It can be imported as:

In [8]:
import jax.numpy as jnp

Luckily, most of the syntax is the same, so we don't have to relearn NumPy. For example:

In [9]:
a = jnp.ones((1,10))
a

Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

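Depending on your setup, running this may also print a warning that no GPU/TPU was found and JAX is falling back to the CPU. It looks weird, but it just means JAX couldn't find the accelerator it looks for by default. The warning can be suppressed by explicitly selecting the CPU platform: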

`jax.config.update('jax_platform_name','cpu')`

In [10]:
jax.config.update('jax_platform_name','cpu')
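
If you want to double-check which backend JAX ended up using, here is a quick sketch with `jax.default_backend()` and `jax.devices()`. The exact output depends on your machine, so none is shown:

```python
import jax

# Name of the platform JAX uses by default, e.g. 'cpu', 'gpu' or 'tpu'.
backend = jax.default_backend()
# The devices available on that platform.
devices = jax.devices()

print(backend)
print(devices)
```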

Coming back to our NumPy and JAX's NumPy equivalence, we can see a few other examples too:

In [11]:
b = jnp.zeros((2,3))
b

Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [12]:
c = jnp.arange(1,20,2)
c

Array([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19], dtype=int32)

### Differences

So, luckily, the syntax is mostly the same. Now let's see where the two differ.

#### 1. Datatype

As expected, JAX arrays and NumPy arrays are different types, as we can verify here:

In [13]:
type(a)

jaxlib.xla_extension.ArrayImpl

NumPy arrays, on the other hand, are simply `ndarray`s:

In [14]:
import numpy
d = numpy.array([1,3,4])

type(d)

numpy.ndarray
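
Since the two are different types, it is handy to know how to convert between them. A minimal sketch, assuming the standard `jnp.asarray` and `numpy.asarray` conversions:

```python
import numpy
import jax
import jax.numpy as jnp

d = numpy.array([1, 3, 4])

# NumPy -> JAX: produces a JAX array (placed on JAX's default device).
d_jax = jnp.asarray(d)

# JAX -> NumPy: produces a regular ndarray with the same values.
d_back = numpy.asarray(d_jax)

print(isinstance(d_jax, jax.Array))       # True
print(isinstance(d_back, numpy.ndarray))  # True
print(d_back)                             # [1 3 4]
```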

#### 2. Mutability

If you recall, NumPy arrays are mutable. For example,

In [15]:
d[2] = 5

d

array([1, 3, 5])

Now try to do the same with a JAX array:

In [16]:
c[4] = -1

c

TypeError (traceback omitted)

> **Note:** Don't let the error message mislead you. It suggests `x = x.at[idx].set(y)`, but that creates a new (modified) copy of the array rather than changing it in place. Mutation is not allowed in JAX arrays. Period!
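
Here is a minimal sketch of that functional-update style, recreating the array `c` from above. Both `.at[idx].set()` and `.at[idx].add()` return a new array and leave the original untouched:

```python
import jax.numpy as jnp

c = jnp.arange(1, 20, 2)

# .at[idx].set(value) returns a NEW array with the update applied.
c_new = c.at[4].set(-1)
# .at[idx].add(value) works the same way, adding instead of replacing.
c_added = c.at[0].add(100)

print(c)        # unchanged: element 4 is still 9
print(c_new)    # element 4 replaced by -1
print(c_added)  # element 0 is now 101
```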

#### 3. Initialization from Python Collections

NumPy and JAX also differ in how arrays are initialized from Python collections. As a refresher, here is NumPy's behaviour:

In [18]:
listA = [1, 2, 3]
setB = {2,3,5}
tupleC = (1,3,5)

In [19]:
a = numpy.array(listA)
b = numpy.array(setB)
c = numpy.array(tupleC)

print(a, type(a))
print(b, type(b))
print(c, type(c))

[1 2 3] <class 'numpy.ndarray'>
{2, 3, 5} <class 'numpy.ndarray'>
[1 3 5] <class 'numpy.ndarray'>


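All three are accepted and all of them are `ndarray`s. Note, however, that the set is not unpacked: NumPy wraps the whole set in a 0-dimensional array of dtype `object`, which is why it prints with braces. Now observe JAX's behaviour: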

In [21]:
a = jnp.array(listA)
b = jnp.array(setB)
c = jnp.array(tupleC)

print(a, type(a))
print(b, type(b))
print(c, type(c))

TypeError (traceback omitted)

**Conclusion:** JAX doesn't allow creating an array from a set. That is intuitive: a set has no defined element order, so there is no obvious way to lay its elements out in an array.
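
If you do need a JAX array from a set, one option is to turn it into an ordered sequence first. A minimal sketch; sorting is just one choice, used here to make the element order explicit:

```python
import jax.numpy as jnp

setB = {2, 3, 5}

# Convert the unordered set into an ordered list first (ascending here).
b = jnp.array(sorted(setB))
print(b)  # [2 3 5]
```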

---
## Autograd

The growing use of gradient-based optimizers (especially the SGD family) means we need a proper framework for calculating derivatives. NumPy only offers [finite-difference approximations](https://numpy.org/doc/stable/reference/generated/numpy.gradient.html) computed from sampled values, whereas JAX provides proper automatic differentiation.

Let's start by importing it.

In [22]:
from jax import grad

### `grad()`

First things first: to calculate the derivative of a function, we use `grad()`. It takes a function and returns a new function that computes its derivative.

Its syntax is:

**`<gradFunc> = grad(<function>)`**

#### First/Normal Derivatives

Let's test it for a function – say, $f(x) = e^x$.

Since $e^x$ is our stubborn guy (it is its own derivative), it makes for a quick test.

In [23]:
gradientExp = grad(jnp.exp)

Having created the gradient (itself a function), we can now evaluate it. Let's try $x=2$:

In [24]:
gradientExp(2.0)

Array(7.389056, dtype=float32, weak_type=True)

To confirm further, we can compare the function and its derivative side by side:

In [25]:
for x in jnp.linspace(-2.0,2.0,10):
  print("Exponential of ", x, " is: ", jnp.exp(x), " and its derivative's value (for x) is: ", gradientExp(x))

Exponential of  -2.0  is:  0.13533528  and its derivative's value (for x) is:  0.13533528
Exponential of  -1.5555556  is:  0.21107209  and its derivative's value (for x) is:  0.21107209
Exponential of  -1.1111112  is:  0.32919297  and its derivative's value (for x) is:  0.32919297
Exponential of  -0.6666667  is:  0.5134171  and its derivative's value (for x) is:  0.5134171
Exponential of  -0.22222227  is:  0.8007374  and its derivative's value (for x) is:  0.8007374
Exponential of  0.22222227  is:  1.2488489  and its derivative's value (for x) is:  1.2488489
Exponential of  0.66666675  is:  1.9477342  and its derivative's value (for x) is:  1.9477342
Exponential of  1.1111112  is:  3.037732  and its derivative's value (for x) is:  3.037732
Exponential of  1.5555556  is:  4.737718  and its derivative's value (for x) is:  4.737718
Exponential of  2.0  is:  7.389056  and its derivative's value (for x) is:  7.389056
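
For contrast, here is a sketch of what NumPy's `numpy.gradient` gives on the same ten points. It only sees sampled values, so it has to approximate the derivative with finite differences, while `grad()` differentiates `jnp.exp` exactly (up to `float32` rounding). `jax.vmap` is used here just to evaluate the gradient at all points at once instead of looping:

```python
import numpy
import jax
import jax.numpy as jnp

xs = numpy.linspace(-2.0, 2.0, 10)
exact = numpy.exp(xs)  # the true derivative of e^x is e^x itself

# NumPy: finite differences on sampled values
# (central differences inside, one-sided differences at the two ends by default).
fd = numpy.gradient(numpy.exp(xs), xs)

# JAX: autodiff of jnp.exp, evaluated at every point at once with vmap.
ad = jax.vmap(jax.grad(jnp.exp))(jnp.asarray(xs, dtype=jnp.float32))

fd_err = numpy.max(numpy.abs(fd - exact))
ad_err = numpy.max(numpy.abs(numpy.asarray(ad) - exact))
print("max finite-difference error:", fd_err)
print("max autodiff error:        ", ad_err)
```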


Let's try another example:

$$f(x) = 3x^2+9$$

$$f'(x) = 6x$$

In [26]:
func = lambda x: 3*x*x+9

gradFunc = grad(func)

**Note:** If we pass an integer to the gradient function, JAX throws an error, as `grad()` expects floating-point inputs. For example:

In [27]:
gradFunc(2)

TypeError (traceback omitted)

The correct way to do it would be to pass it as a float.

In [29]:
gradFunc(2.0)

Array(12., dtype=float32, weak_type=True)
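
Since `grad()` requires the function to return a single scalar, passing a whole array of points to `gradFunc` would fail. A minimal sketch of evaluating $f'(x) = 6x$ at several points at once with `jax.vmap` (one of the auto-vectorization tools mentioned at the start):

```python
import jax
import jax.numpy as jnp
from jax import grad

func = lambda x: 3*x*x + 9
gradFunc = grad(func)

xs = jnp.arange(5.0)  # float points 0., 1., 2., 3., 4.

# vmap maps gradFunc over the leading axis, calling it once per scalar element.
slopes = jax.vmap(gradFunc)(xs)
print(slopes)  # f'(x) = 6x, i.e. 0, 6, 12, 18, 24
```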

#### Higher Derivatives

Similarly, we can calculate higher derivatives too. For example, the second derivative of the above function is always $6$. We can verify it:

In [30]:
doubleDerivativeFunc = grad(gradFunc)

doubleDerivativeFunc(20.0)

Array(6., dtype=float32, weak_type=True)

So we can simply nest `grad()` calls to calculate higher derivatives.
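
As a quick sketch, nesting `grad()` three times on $\sin(x)$ should follow the cycle $\sin \to \cos \to -\sin \to -\cos$; at $x = 0$ that gives $1$, $0$ and $-1$:

```python
import jax.numpy as jnp
from jax import grad

d1 = grad(jnp.sin)  # cos(x)
d2 = grad(d1)       # -sin(x)
d3 = grad(d2)       # -cos(x)

x = 0.0
print(d1(x), d2(x), d3(x))  # approximately 1, 0 and -1
```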

#### Multivariate Functions

Similarly, we can extend this to multivariate functions as well. Let's quickly review the math first.

Suppose we have a function,

$$f(a,b) = \sin(a) + \cos(b)$$

Since $f$ depends on two variables, we have to specify which one we differentiate with respect to. These are called partial derivatives:

$$\frac{\partial f}{\partial a} = \cos(a)$$

Since $\cos(b)$ does not change with $a$, it contributes $0$ to this derivative. Likewise, $\sin(a)$ vanishes when we differentiate with respect to $b$:

$$\frac{\partial f}{\partial b} = -\sin(b)$$

---
Now let's calculate it in JAX:

In [31]:
func_ab = lambda a,b: jnp.sin(a)+jnp.cos(b)  # For simplicity, assume both a and b are in radians.

df_da = grad(func_ab)

Since the function takes two arguments, we have to pass both $a$ and $b$ to `df_da` as well.

Let's test it at $(a,b) = (\pi, 2\pi)$:

In [32]:
df_da(jnp.pi, 2*jnp.pi)

Array(-1., dtype=float32, weak_type=True)

This matches $\cos(\pi) = -1$. But how did JAX know that we wanted the derivative with respect to $a$?

By default, `grad()` differentiates with respect to the first argument of the function. To take the partial derivative with respect to another argument (like $b$), we specify `argnums=<x>`.

> **Note:** Here `x` is the argument's position, counted from 0: the 2nd argument is `1`, the nth is `n-1`, and so on.

Let's calculate $\frac{\partial f}{\partial b}$:

In [33]:
df_db = grad(func_ab,argnums=1)
df_db(jnp.pi, 2*jnp.pi)

Array(-1.7484555e-07, dtype=float32, weak_type=True)

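The answer may look weird, but it's just an artifact of floating-point arithmetic: $2\pi$ cannot be represented exactly in `float32`, so $\sin(2\pi)$ comes out as a tiny nonzero number instead of exactly $0$. If we round it to 5 or 6 decimal places, the answer becomes $0$: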

In [34]:
round(df_db(jnp.pi, 2*jnp.pi),6)

Array(-0., dtype=float32, weak_type=True)
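
`argnums` also accepts a tuple of positions, in which case `grad()` returns a tuple with one partial derivative per listed argument. A minimal sketch that computes both partial derivatives of `func_ab` in one go:

```python
import jax.numpy as jnp
from jax import grad

func_ab = lambda a, b: jnp.sin(a) + jnp.cos(b)

# One grad() call, two partial derivatives: (df/da, df/db).
df_dab = grad(func_ab, argnums=(0, 1))

da, db = df_dab(jnp.pi, 2*jnp.pi)
print(da)  # cos(pi) = -1
print(db)  # -sin(2*pi): ~0, up to float32 rounding
```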

---

### Multivariate Differentiation

So far, we have only differentiated scalar functions with respect to scalar variables, which is of limited use. If the job of autodiff were only to differentiate scalar functions, a small set of differentiation rules would do; I am sure you could write one yourself on some bored weekend.

The real power of autodiff lies in multivariate and vector-valued functions. For that, we need to review a little (a very tiny bit) of linear algebra. Please don't run away at its mention; I will make sure we go through it before moving on to code.

#### Gradient

In the previous example, we had to create a separate function for each partial derivative. That approach doesn't scale. For example, consider multivariate linear regression:

$$y' = b + a_1x_1 + a_2x_2 + \dots + a_nx_n$$

Now, its partial derivatives will be:

$$\frac{\partial y'}{\partial a_1} = \dots$$
$$\frac{\partial y'}{\partial a_2} = \dots$$
$$ \vdots $$
$$\frac{\partial y'}{\partial a_n} = \dots$$

**Note:** We are skipping $\frac{\partial y'}{\partial b}$ here for simplicity, but it is computed as well.

...

Since the number of parameters can be arbitrary, we can write the model more compactly in vectorized form:

$$y' = \mathbf{a} \cdot \mathbf{x} + b$$

**Shouldn't there be a way of doing the same for the partial derivatives as well?**

Luckily, there is. We can collect all the partial derivatives into a single vector, known as the **gradient**. It is denoted by the nabla ($\nabla$) symbol.

$$ \nabla y' = \begin{bmatrix}
 \frac{\partial y'}{\partial a_1} \\
 \frac{\partial y'}{\partial a_2} \\
 \vdots \\
 \frac{\partial y'}{\partial a_n}
\end{bmatrix}
$$

> **Gradient** descent gets its name from this same vector: each step moves the parameters in the direction opposite to the gradient.
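
Jumping ahead slightly to JAX, here is a minimal sketch for the regression model above with three features. The input, weights and bias are made-up values for illustration, and `argnums=(0, 1)` asks `grad()` for the partial derivatives with respect to both arguments. Differentiating $y'$ with respect to the weight vector $\mathbf a$ returns $\mathbf x$ itself (since $\frac{\partial y'}{\partial a_i} = x_i$), and with respect to $b$ returns $1$:

```python
import jax.numpy as jnp
from jax import grad

# Made-up input with three features.
x = jnp.array([2.0, -1.0, 0.5])

def y_pred(a, b):
    # Vectorized linear model: y' = a . x + b
    return jnp.dot(a, x) + b

a = jnp.array([0.1, 0.2, 0.3])  # made-up weights
b = 1.0                         # made-up bias

# Differentiate w.r.t. both a (argument 0) and b (argument 1).
grad_a, grad_b = grad(y_pred, argnums=(0, 1))(a, b)
print(grad_a)  # same values as x: 2, -1, 0.5
print(grad_b)  # 1.0
```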

---

Coming back to JAX, we don't need to do anything special: if a function's argument is an array, `grad()` returns the full gradient, i.e. an array of the same shape holding every partial derivative. Hence the name **`grad()`**. (The function itself must still return a single scalar.)

To evaluate the gradient, we pass a vector the same way we passed scalar values earlier. For example, consider the following function (it's not a regression model):

$$y' = x_0 - 2\cos(x_1) + 4x_2^2$$

In [74]:
def FuncA(x):
  return x[0]-2*jnp.cos(x[1])+4*x[2]*x[2]

In [75]:
gradientVec = grad(FuncA)


We can evaluate it at any point in 3D space. Say, at $(1, \frac{\pi}{2}, -1)$:

In [83]:
x = jnp.array((1.0,0.5*jnp.pi,-1.0))
gradientVec(x)

Array([ 1.,  2., -8.], dtype=float32)

Similarly, we can evaluate it at other points too. A gradient can be computed at any point in $\mathcal R^n$ where the function is differentiable.
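
To double-check the result, we can compare it with the gradient derived by hand, $\nabla y' = (1,\ 2\sin(x_1),\ 8x_2)$. A small sketch; the test points other than $(1, \frac{\pi}{2}, -1)$ are arbitrary:

```python
import jax.numpy as jnp
from jax import grad

def FuncA(x):
    return x[0] - 2*jnp.cos(x[1]) + 4*x[2]*x[2]

gradientVec = grad(FuncA)

def analytic_grad(x):
    # Gradient derived by hand for y' = x0 - 2cos(x1) + 4x2^2
    return jnp.array([1.0, 2*jnp.sin(x[1]), 8*x[2]])

points = [
    jnp.array([1.0, 0.5*jnp.pi, -1.0]),
    jnp.array([0.0, 1.0, 2.0]),     # arbitrary test point
    jnp.array([-3.0, -0.7, 0.25]),  # arbitrary test point
]

for p in points:
    auto = gradientVec(p)
    manual = analytic_grad(p)
    print(auto, manual, jnp.allclose(auto, manual))
```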

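Interestingly, **it works on Python's built-in collections too** (except sets, of course): JAX treats tuples and lists as containers of inputs and returns the gradient in the same structure. Let's try a tuple and a list to round it off.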

In [85]:
tupleD = (1.0,2.0,0.0)
gradientVec(tupleD)

(Array(1., dtype=float32, weak_type=True),
 Array(1.8185948, dtype=float32, weak_type=True),
 Array(0., dtype=float32, weak_type=True))

In [87]:
listE = [0.0,jnp.pi,0.0]
gradientVec(listE)

[Array(1., dtype=float32, weak_type=True),
 Array(-1.7484555e-07, dtype=float32, weak_type=True),
 Array(0., dtype=float32, weak_type=True)]

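**Caution for Dictionaries:** It works on dictionaries too, but the function must then access its input by key instead of by position. `FuncA` uses `x[0]`, `x[1]` and `x[2]`, so passing a dictionary keyed by `"x0"`, `"x1"` and `"x2"` fails with a `KeyError`: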

In [88]:
dictionaryF = {"x0":0.5,"x1":-jnp.pi,"x2":0.5}
gradientVec(dictionaryF)

KeyError (traceback omitted)
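
A minimal sketch of the dictionary version: the hypothetical `FuncA_dict` is the same function as `FuncA`, but reads its inputs by key. The gradient then comes back as a dictionary with the same keys:

```python
import jax.numpy as jnp
from jax import grad

# Hypothetical variant of FuncA that reads its inputs by key instead of by position.
def FuncA_dict(x):
    return x["x0"] - 2*jnp.cos(x["x1"]) + 4*x["x2"]*x["x2"]

dictionaryF = {"x0": 0.5, "x1": -jnp.pi, "x2": 0.5}

grads = grad(FuncA_dict)(dictionaryF)
print(grads)
# By hand: dy'/dx0 = 1, dy'/dx1 = 2sin(x1) (~0 at -pi), dy'/dx2 = 8*x2 = 4
```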

#### Jacobian

The gradient applies to functions that map $\mathcal R^n \to \mathcal R$.

There are also functions that map $\mathcal R^n \to \mathcal R^m$. For such functions, the gradient generalizes to a whole matrix of partial derivatives, with one row per output and one column per input.
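
JAX can compute this matrix too. A minimal sketch with a made-up function $g: \mathcal R^3 \to \mathcal R^2$, $g(x) = (x_0x_1,\ \sin(x_2))$, using `jax.jacfwd` (forward mode) and `jax.jacrev` (reverse mode), which should give the same $2 \times 3$ matrix:

```python
import jax
import jax.numpy as jnp

# Hypothetical example function g: R^3 -> R^2, g(x) = (x0 * x1, sin(x2)).
def g(x):
    return jnp.stack([x[0] * x[1], jnp.sin(x[2])])

x = jnp.array([1.0, 2.0, 0.0])

J_fwd = jax.jacfwd(g)(x)  # forward-mode Jacobian
J_rev = jax.jacrev(g)(x)  # reverse-mode Jacobian

print(J_fwd.shape)  # (2, 3): one row per output, one column per input
print(J_fwd)
# By hand: [[x1, x0, 0], [0, 0, cos(x2)]] = [[2, 1, 0], [0, 0, 1]] at this x.
```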

--To be continued--