Just After eXecution (JAX) is a library developed by Google for numerical computation. It is quite similar to NumPy, but has a number of advantages over traditional NumPy, such as:

- JIT Compilation
- Automatic Differentiation
- GPU/TPU support
- Auto-vectorization
- Explicit, reproducible random number generation (PRNG keys)
- Supported as a backend by Keras (3.0)

and so on.

JAX is a world of its own and cannot be covered properly here, so this is only an introduction. Let's begin by importing it:

In [4]:
import jax

## JAX's NumPy

JAX ships its own implementation of the NumPy API. It can be imported as:

In [5]:
import jax.numpy as jnp

Luckily, most of the syntax is the same, so we don't have to relearn NumPy. For example:

In [6]:
a = jnp.ones((1,10))
a

Array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)

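Depending on your setup, the cell above may also print a warning that no GPU/TPU was found. It looks odd, but it just means that JAX could not find the accelerator it looks for by default and is falling back to the CPU. The warning can be suppressed by selecting the CPU platform explicitly: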

In [7]:
jax.config.update('jax_platform_name','cpu')
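
To check which backend JAX actually picked, a minimal sketch like the following may help. It assumes only that `jax` is installed; the printed values depend on your machine, so no output is shown.

```python
import jax

# Name of the backend JAX uses by default on this machine (for example "cpu").
backend = jax.default_backend()

# List of devices JAX can see on that backend.
devices = jax.devices()

print(backend)
print(devices)
```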

Coming back to the equivalence between NumPy and JAX's NumPy, here are a few more examples:

In [8]:
b = jnp.zeros((2,3))
b

Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [9]:
c = jnp.arange(1,20,2)
c

Array([ 1,  3,  5,  7,  9, 11, 13, 15, 17, 19], dtype=int32)
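
The equivalence can also be checked directly. The sketch below builds the same range with both libraries and compares the values; the variable names are illustrative only. It also hints at one difference visible in the outputs above: with default settings JAX uses 32-bit types, whereas NumPy's default integer width depends on the platform.

```python
import numpy
import jax.numpy as jnp

# Same call, two libraries.
x_np = numpy.arange(1, 20, 2)
x_jnp = jnp.arange(1, 20, 2)

# Convert the JAX array to NumPy and compare element-wise.
same_values = numpy.array_equal(x_np, numpy.asarray(x_jnp))
print(same_values)

# Familiar reductions work the same way.
mean_np = x_np.mean()
mean_jnp = jnp.mean(x_jnp)
print(mean_np, mean_jnp)

# Default dtypes may differ between the two libraries.
print(x_np.dtype, x_jnp.dtype)
```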

### Differences

So the syntax is largely the same. Now let's see where the two differ.

#### 1. Datatype

JAX and NumPy arrays are different Python types, as we can verify here:

In [11]:
type(a)

jaxlib.xla_extension.ArrayImpl

NumPy arrays, on the other hand, are plain `ndarray` objects:

In [12]:
import numpy
d = numpy.array([1,3,4])

type(d)

numpy.ndarray
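
Since the concrete class name (`ArrayImpl` above) is an implementation detail, a more robust check is probably `isinstance` against the public `jax.Array` type. The sketch below also shows converting between the two array kinds; it assumes a JAX version recent enough to expose `jax.Array`.

```python
import numpy
import jax
import jax.numpy as jnp

a = jnp.ones((1, 10))

# Check against the public JAX array type instead of the internal class.
is_jax_array = isinstance(a, jax.Array)
print(is_jax_array)

# JAX -> NumPy
a_np = numpy.asarray(a)
print(type(a_np))

# NumPy -> JAX
a_back = jnp.asarray(a_np)
print(isinstance(a_back, jax.Array))
```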

#### 2. Mutability

If you recall, NumPy arrays are mutable. For example:

In [13]:
d[2] = 5

d

array([1, 3, 5])

Now try the same on a JAX array:

In [14]:
c[4] = -1  # raises TypeError: JAX arrays do not support item assignment

c

TypeError: ignored

> **Note:** Don't let the error message mislead you. Although we can write `x = x.at[idx].set(y)`, that call returns a new (modified) copy of the array and leaves the original untouched. In-place mutation of JAX arrays is not allowed. Period!
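
The functional update mentioned in the note can be sketched as follows. The array mirrors `c` from the earlier cell; `c_new` is an illustrative name.

```python
import jax.numpy as jnp

c = jnp.arange(1, 20, 2)

# .at[...].set(...) returns a new array; c itself is not modified.
c_new = c.at[4].set(-1)

print(c[4])      # still the original value at index 4
print(c_new[4])  # the updated value in the new array
```

Rebinding the name (`c = c.at[4].set(-1)`) gives the feel of an in-place update, but the old array is still unchanged for anyone else holding a reference to it.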

#### 3. Initialization from Python Collections

NumPy and JAX also differ in how arrays are initialized from Python collections. As a refresher, here is NumPy's behaviour:

In [15]:
listA = [1, 2, 3]
setB = {2,3,5}
tupleC = (1,3,5)

In [18]:
a = numpy.array(listA)
b = numpy.array(setB)
c = numpy.array(tupleC)

print(a, type(a))
print(b, type(b))
print(c, type(c))

[1 2 3] <class 'numpy.ndarray'>
{2, 3, 5} <class 'numpy.ndarray'>
[1 3 5] <class 'numpy.ndarray'>


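All three calls are permitted and all of them return an `ndarray`. Note, however, that the set is not converted element by element: NumPy wraps it in a 0-dimensional array of dtype `object`, which is why it still prints as `{2, 3, 5}`. Now observe JAX's behaviour: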

In [21]:
a = jnp.array(listA)
b = jnp.array(setB)  # raises TypeError: a set is not accepted as input
c = jnp.array(tupleC)

print(a, type(a))
print(b, type(b))
print(c, type(c))

TypeError: ignored

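**Conclusion:** JAX doesn't allow array creation from a set. This is intuitive: a set is unordered, so there is no well-defined element order for the resulting array. Lists and tuples work as expected.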
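
If the data does arrive as a set, one possible workaround is to convert it to an ordered sequence first. The sketch below uses `sorted`, which is just one reasonable choice for fixing the order, not something JAX requires.

```python
import jax.numpy as jnp

setB = {2, 3, 5}

# Give the elements a well-defined order before building the array.
b = jnp.array(sorted(setB))

print(b, type(b))
```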

---
## Autograd

