In [None]:
# Enable 64-bit (double precision) floats; JAX defaults to 32-bit.
# `from jax import config` binds the same config object as the older
# `from jax.config import config` spelling, but also works on recent JAX
# releases where the `jax.config` submodule no longer exists.
from jax import config
config.update("jax_enable_x64", True)

import numpy as np
import jax.numpy as jnp

from jax import grad, jit, vmap
from jax import random


# Introduction to JAX

The purpose of this notebook is to understand how JAX can help with analysing dynamical systems.

It seems that JAX by default uses [32-bit single-precision floating point](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision); for now, however, let us keep 64-bit as the default. This is achieved with the following two lines of code:
```python
from jax.config import config
config.update("jax_enable_x64", True)
```

JAX provides the functions that NumPy typically provides. Using
```python
import jax.numpy as jnp
```
we can now use `jnp.` wherever we previously would have used `np.`.

JAX also has a just-in-time compiler, which seems to work similarly to Numba: functions are annotated (decorated) to have them compiled. (**TODO:** understand the differences from Numba better.)


## Follow [tutorial](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)


In [None]:
# JAX random numbers are driven by an explicit PRNG key instead of global state.
key = random.PRNGKey(0)
# Draw ten standard-normal samples from that key.
x = random.normal(key, shape=(10,))
print(x)


In [None]:
# Benchmark a (size x size) float32 matrix product with JAX.
# Note: this rebinds `x` from the previous cell and draws from the same `key`.
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
# block_until_ready() waits for JAX's asynchronous dispatch to finish, so the
# measurement covers the computation itself and not only the dispatch.
%timeit jnp.dot(x, x.T).block_until_ready()  # runs on JAX's default device (a GPU only if one is available)


In [None]:
import numpy as np

In [None]:
# NumPy baseline for the same matrix product.
# NOTE(review): assuming top-to-bottom execution, `x` is the JAX array created
# in the benchmark cell above, so this timing presumably also includes
# converting it to a NumPy array on each call — confirm that this is intended.
%timeit np.dot(x, x.T)

In [None]:
def sum_logistic(x):
    """Sum the logistic sigmoid ``1 / (1 + exp(-x))`` over every element of ``x``."""
    sigmoid = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(sigmoid)

x_small = jnp.arange(3.0)
# grad() returns a function that evaluates the gradient of the scalar output
# with respect to the (first) argument.
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))

In [None]:
# Evaluate the gradient at a second, hand-picked point (all entries are floats).
derivative_fn(jnp.array([2.0, 5.0, 2.0, -1.0]))

In [None]:
# JAX PRNG keys should not be reused: drawing `mat` and `batched_x` from the
# very same key would sample both from the same random stream, so the two
# arrays would not be independent. Derive one subkey per array instead.
mat_key, batch_key = random.split(key)
mat = random.normal(mat_key, (150, 100))        # matrix applied to every vector
batched_x = random.normal(batch_key, (10, 100))  # batch of 10 vectors of length 100

def apply_matrix(v):
    """Return the product of the module-level matrix ``mat`` with ``v``."""
    return jnp.dot(mat, v)

In [None]:
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()


In [None]:
@jit
def vmap_batched_apply_matrix(v_batched):
    return vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()


In [None]:
def make_a_func():
    """Build a zero-argument function that returns the local value ``a`` (5) it closes over."""
    a = 5

    def return_a():
        # `a` is looked up in make_a_func's scope, not in the global namespace.
        return a

    return return_a

bla = make_a_func()
bla()

In [None]:
from jax import lax

In [None]:
def norm(X):
    """Standardise ``X`` along axis 0: subtract the mean, then divide by the standard deviation."""
    centered = X - X.mean(axis=0)
    return centered / centered.std(axis=0)

# Compiled counterpart of `norm`, used for comparison with the eager version.
norm_compiled = jit(norm)



In [None]:
# Fixed seed so the comparison data is reproducible.
np.random.seed(1701)
X = jnp.array(np.random.rand(10000, 10))
# The eager and the jit-compiled normalisation should agree to within 1e-13.
np.allclose(norm(X), norm_compiled(X), atol=1e-13)

In [None]:
# Time the eager function against its jit-compiled counterpart on the same input.
# block_until_ready() makes each measurement wait for the result to be computed.
%timeit norm(X).block_until_ready()
%timeit norm_compiled(X).block_until_ready()


In [None]:
# Inspect the `impl` attribute of lax's `abs_p` object
# (presumably the implementation rule of the abs primitive — TODO confirm).
lax.abs_p.impl

In [None]:
# Binding a *global* `a` does not affect the closure: `bla` reads the `a`
# that is local to make_a_func, so this still evaluates to 5.
a = 7
bla()

In [None]:
# Display x_small again (created with jnp.arange(3.) in the gradient example).
x_small

In [None]:
# Machine epsilon of float64 as reported by jnp.finfo.
jnp.finfo(jnp.float64).eps

In [None]:
import numpy as np

def init_mlp_params(layer_widths):
    """Create randomly initialised parameters for a fully connected network.

    Args:
        layer_widths: sequence of layer sizes; each consecutive pair
            ``(n_in, n_out)`` defines one layer.

    Returns:
        A list with one dict per layer, holding ``weights`` of shape
        ``(n_in, n_out)`` (standard-normal draws from NumPy's global random
        state, scaled by ``sqrt(2 / n_in)``) and ``biases`` of shape
        ``(n_out,)`` filled with ones.
    """
    def make_layer(n_in, n_out):
        scale = np.sqrt(2 / n_in)
        return dict(
            weights=np.random.normal(size=(n_in, n_out)) * scale,
            biases=np.ones(shape=(n_out,)),
        )

    # Pair every width with its successor; zip stops at the shorter sequence.
    return [
        make_layer(n_in, n_out)
        for n_in, n_out in zip(layer_widths, layer_widths[1:])
    ]

params = init_mlp_params([1, 128, 128, 1])

In [None]:
import jax

In [None]:
# Show the shape of every array in the parameter pytree.
# jax.tree_util.tree_map is used instead of the top-level jax.tree_map alias,
# which is deprecated and no longer available in recent JAX releases.
jax.tree_util.tree_map(lambda x: x.shape, params)

In [None]:
# Multiplying a list by an int repeats it: this is [3, 4] repeated four times.
2 * [3,4] *2

In [None]:
# Build a 1-D array holding the integers 0..5 from a range.
np.array(range(6))

In [None]:
# np.concatenate needs operands with at least one dimension. A list of plain
# ints is treated as a sequence of zero-dimensional scalars, so this call
# raises ValueError. Catch and show the error so the notebook still runs from
# top to bottom; the exception is kept in `concat_error` for inspection.
try:
    np.concatenate(list(range(6)), axis=1)
except ValueError as err:
    concat_error = err
    print(f"ValueError: {err}")