# PRNG keys vs global seeding

We will be following this very neat [explanation](https://github.com/google/jax/blob/master/design_notes/prng.md).

## JAX imports

## Beginner
### Prerequisites
- NumPy

### Imports

```python
import numpy as np
from scipy.stats import norm as sp_norm
from jax.random import normal, uniform, key, split
```

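JAX takes a functional approach to randomness. Instead of relying on a hidden global seed, every sampling function receives an explicit `key`, and fresh randomness is obtained by `split`ting a key into new keys. Because no global state is mutated, results are reproducible across runs and well defined in parallel runs.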
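A minimal sketch of the resulting pattern: a function consumes a key, splits it, uses each subkey exactly once and hands a fresh key back to the caller. The name `draw_pair` is hypothetical and only serves as an illustration.

```python
from jax.random import key, split, uniform


def draw_pair(k):
    # Split once: the first key is returned for later use, the other two are consumed here.
    k, k1, k2 = split(k, 3)
    # Each consumed key is used exactly once, so the two draws are independent.
    return k, uniform(k1), uniform(k2)


k = key(0)
k, a1, b1 = draw_pair(k)  # the caller keeps threading the returned key
k, a2, b2 = draw_pair(k)  # new key in, new numbers out
```

Restarting from `key(0)` reproduces `a1` and `b1` exactly, which is what replaces global seeding in JAX.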

```python
jax_key = key(0)
print(jax_key)
```

```
Array((), dtype=key<fry>) overlaying:
[0 0]
```


By default, a JAX key wraps two unsigned 32-bit integers, the key format of the Threefry algorithm.
This can be overridden to use another implementation if needed (do not do this unless you know what you are doing).

```python
jax_key = key(0, impl="rbg")
print(jax_key)
```

```
Array((), dtype=key<rbg>) overlaying:
[0 0 0 0]
```
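The integers overlaid by a typed key can also be read programmatically with `jax.random.key_data`, which returns them as a plain `uint32` array. A quick sketch checking the word counts printed above, assuming the default implementation has not been changed through JAX's configuration:

```python
from jax.random import key, key_data

k_fry = key(0)              # default implementation (Threefry)
k_rbg = key(0, impl="rbg")  # alternative implementation

fry_words = key_data(k_fry)  # raw key words as a uint32 array
rbg_words = key_data(k_rbg)
print(fry_words.shape, fry_words.dtype)
print(rbg_words.shape, rbg_words.dtype)
```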


The key can then be passed to sampling functions, much as a `Generator` object is passed around in the newer NumPy/SciPy API:

```python
jax_key = key(0)
print(normal(jax_key))
```

```
-0.20584226
```


```python
np_gen = np.random.Generator(np.random.PCG64(0))
sp_norm.rvs(random_state=np_gen)
```

```
np.float64(0.1257302210933933)
```
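Both APIs also accept shape and range arguments. A short sketch drawing arrays rather than scalars, using `split` (discussed below) so that each JAX sampling call gets its own key; the NumPy lines show the corresponding `Generator` calls for comparison:

```python
import numpy as np
from jax.random import key, split, normal, uniform

k_norm, k_unif = split(key(0))  # one key per sampling call
z = normal(k_norm, shape=(2, 3))  # standard normal draws, shape (2, 3)
u = uniform(k_unif, shape=(4,), minval=-1.0, maxval=1.0)  # draws in [-1, 1)

# NumPy Generator equivalents: here the generator object carries the state.
rng = np.random.default_rng(0)
z_np = rng.standard_normal((2, 3))
u_np = rng.uniform(-1.0, 1.0, size=4)
```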

However, unlike NumPy's generator, a JAX key is immutable and is not updated after use: calling the function again with the same key returns the same value.

```python
print(sp_norm.rvs(random_state=np_gen))
print(sp_norm.rvs(random_state=np_gen))  # the values change

print(normal(jax_key))
print(normal(jax_key))  # the values remain the same
```

```
-0.1321048632913019
0.6404226504432821
-0.20584226
-0.20584226
```


To generate new keys, use the `split` function:

```python
jax_key, subkey = split(jax_key)
print(normal(subkey))
```

```
-1.2515389
```


The `split` function also takes an optional argument `num` (default 2) giving the number of keys to generate, which is useful for parallelisation or when running a loop.

```python
many_keys = split(jax_key, 42)
```
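A sketch of both uses: `jax.vmap` maps a sampler over the array of keys in one vectorised call, while a loop indexes the key array to get one key per iteration. Both should produce the same 42 draws.

```python
import jax
import jax.numpy as jnp
from jax.random import key, split, normal

many_keys = split(key(0), 42)          # array of 42 independent keys
samples = jax.vmap(normal)(many_keys)  # one standard normal per key, shape (42,)

# Loop form: index the key array to get one key per iteration.
loop_samples = jnp.stack([normal(many_keys[i]) for i in range(42)])
```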

### Questions:

#### Q1: 
Write a sampler for the exponential distribution with rate $\lambda > 0$: $p_\lambda(x) = \lambda e^{-\lambda x}$ for $x \ge 0$.
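One possible answer, sketched with the inverse-CDF method (an assumption about the intended technique). The CDF is $F(x) = 1 - e^{-\lambda x}$, so if $U \sim \mathcal{U}[0, 1)$ then $-\log(1 - U)/\lambda$ follows $p_\lambda$. The signature `(key, N)` is chosen to match the `sampler` argument of Q2; `exponential_sampler` and `lam` are hypothetical names. JAX also provides `jax.random.exponential`, which can serve as a cross-check.

```python
import jax.numpy as jnp
from jax.random import key, uniform


def exponential_sampler(k, N, lam=1.0):
    # U ~ U[0, 1), so 1 - U lies in (0, 1] and the log below is finite.
    u = uniform(k, shape=(N,))
    # Inverse CDF: F^{-1}(u) = -log(1 - u) / lam; log1p stays accurate for small u.
    return -jnp.log1p(-u) / lam


samples = exponential_sampler(key(0), 100_000, lam=2.0)
print(samples.mean())  # should be close to 1 / lam = 0.5
```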

#### Q2:
Implement the following Monte-Carlo integrator:

```python
def mc_integrator(fun, sampler, N, key):
    # fun: callable(x)
    # sampler: callable(key, N)
    # N: int, number of samples
    # key: JAX PRNG key
    ...
```
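A minimal sketch, assuming the integrator should estimate $\mathbb{E}_p[f(X)] = \int f(x)\,p(x)\,dx$ by averaging `fun` over `N` draws from `sampler`. `jax.vmap` is used so that `fun` only needs to handle a single `x`; `uniform_sampler` is a hypothetical helper for the usage example.

```python
import jax
import jax.numpy as jnp
import jax.random as jr


def mc_integrator(fun, sampler, N, key):
    # fun: callable(x), sampler: callable(key, N), N: int, key: JAX PRNG key
    xs = sampler(key, N)      # N draws from the sampling density p
    vals = jax.vmap(fun)(xs)  # fun need not be vectorised itself
    return jnp.mean(vals)     # sample mean estimates E_p[fun(X)]


# Example: E[X^2] for X ~ U[0, 1) is 1/3.
def uniform_sampler(k, n):
    return jr.uniform(k, shape=(n,))


est = mc_integrator(lambda x: x**2, uniform_sampler, 100_000, jr.key(0))
```

The error of such an estimate shrinks like $1/\sqrt{N}$, so `jnp.std(vals) / jnp.sqrt(N)` is a natural quantity to report alongside it.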
    

## Intermediate
### Prerequisites
- Beginner randomness
- Beginner loops
- Beginner if-else

### Imports

```python
from jax.lax import scan, while_loop
import jax.numpy as jnp
from jax.random import normal
from jax.scipy.stats.norm import pdf
```

### Questions:

#### Q1: 
Implement Von Neumann's acceptance-rejection method with JAX primitives, starting from the NumPy version below:

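```python
import numpy as np


def acceptance_rejection(target_lik, proposal_lik, proposal_sampler, c, N):
    # Draws N samples from the target density f = target_lik using the
    # proposal density g = proposal_lik, assuming f(y) <= c * g(y) for all y.
    res = []
    while len(res) < N:
        y = proposal_sampler()  # candidate drawn from g
        u = np.random.rand()    # u ~ U[0, 1)
        lik_ratio = target_lik(y) / (c * proposal_lik(y))
        if u <= lik_ratio:      # accept with probability f(y) / (c g(y))
            res.append(y)
    return np.stack(res)
```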

Test it by sampling uniformly from the unit disk, using a standard 2D Gaussian as the proposal:

```python
def target_lik(y):
    # Uniform density on the unit disk: 1 / area = 1 / pi inside, 0 outside.
    if y[0] ** 2 + y[1] ** 2 < 1:
        return 1 / np.pi
    return 0
```

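```python
def proposal_lik(y):
    # Standard 2D Gaussian density: exp(-|y|^2 / 2) / (2 pi).
    return np.exp(-(y[0] ** 2 + y[1] ** 2) / 2) / (2 * np.pi)


def proposal_sampler():
    return np.random.standard_normal(2)  # one draw from the proposal


c = 2 * np.exp(0.5)  # sup of target_lik / proposal_lik on the disk, at |y| = 1
```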
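One possible JAX port, sketched under a few assumptions: each sample is produced by its own `while_loop` that keeps proposing until a candidate is accepted, and `jax.vmap` runs `N` such loops over `N` split keys. The densities are rewritten with `jnp` operations and normalised ($1/\pi$ on the disk, $e^{-|y|^2/2}/(2\pi)$ for the proposal), which gives $c = 2e^{1/2}$. The names `sample_one` and `acceptance_rejection_jax` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.lax import while_loop
from jax.random import key, split, normal, uniform


def target_lik(y):
    # Uniform density on the unit disk: 1 / pi inside, 0 outside.
    return jnp.where(jnp.sum(y**2) < 1.0, 1.0 / jnp.pi, 0.0)


def proposal_lik(y):
    # Standard 2D Gaussian density.
    return jnp.exp(-jnp.sum(y**2) / 2) / (2 * jnp.pi)


c = 2 * jnp.exp(0.5)  # sup of target_lik / proposal_lik over the disk


def sample_one(k):
    # Loop state: (key, latest candidate, accepted flag).
    def cond(state):
        return ~state[2]  # keep going until a candidate is accepted

    def body(state):
        k, _, _ = state
        k, k_y, k_u = split(k, 3)
        y = normal(k_y, shape=(2,))  # candidate from the proposal
        u = uniform(k_u)
        accepted = u <= target_lik(y) / (c * proposal_lik(y))
        return k, y, accepted

    _, y, _ = while_loop(cond, body, (k, jnp.zeros(2), jnp.array(False)))
    return y


def acceptance_rejection_jax(k, N):
    # One independent rejection loop per key; vmap batches them together.
    return jax.vmap(sample_one)(split(k, N))


samples = acceptance_rejection_jax(key(0), 2000)
```

Under `vmap`, the batched loop runs until every lane has accepted, so its cost is set by the slowest of the `N` samples.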

#### Q2:
Implement your own parallel version of `associative_scan` using JAX primitives.
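A minimal sketch of one classic approach, the Hillis-Steele inclusive scan (not necessarily what `jax.lax.associative_scan` does internally). In round $r$ every element is combined with the one $2^r$ positions to its left, so $\lceil \log_2 n \rceil$ vectorised rounds suffice. It does $O(n \log n)$ work rather than the $O(n)$ of work-efficient variants such as Blelloch's scan, and it assumes `xs` is a single array rather than a pytree; `hillis_steele_scan` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def hillis_steele_scan(op, xs):
    # Inclusive scan along axis 0 in ceil(log2(n)) vectorised rounds.
    # Assumes `op` is associative and acts elementwise on stacked arrays.
    n = xs.shape[0]
    offset = 1
    while offset < n:  # Python loop: n is static, so this unrolls under jit
        # New element i = op(old element i - offset, old element i), left operand first.
        combined = op(xs[:-offset], xs[offset:])
        xs = jnp.concatenate([xs[:offset], combined])
        offset *= 2
    return xs


xs = jnp.arange(1, 11)
print(hillis_steele_scan(jnp.add, xs))
print(jax.lax.associative_scan(jnp.add, xs))  # reference
```

Because the left operand always comes first, non-commutative operations such as batched `jnp.matmul` also produce the correct prefix products.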

## Advanced
### Prerequisites
- Intermediate randomness
- Beginner vectorisation

### Questions:

#### Q1: 
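Noting that in Von Neumann's acceptance-rejection method each proposal is accepted with probability $\mathbb{P}\left(u \le \frac{f(y)}{c\,g(y)}\right) = \frac{1}{c}$ (for a normalised target density $f$ and proposal density $g$), make it more efficient on average.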
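A sketch of one way to exploit this, assuming the unit-disk example with normalised densities ($c = 2e^{1/2}$). Since about $c$ proposals are needed per accepted sample, a whole batch of roughly $c$ times the number of missing samples is proposed and tested at once with vectorised operations, and another round runs only if too few were accepted. Because the batch size changes between rounds, this is an ordinary Python loop outside `jit`; `acceptance_rejection_batched` and `safety` are hypothetical names.

```python
import math

import jax.numpy as jnp
from jax.random import key, split, normal, uniform

C = 2 * math.exp(0.5)  # sup of f / g for the unit-disk example


def target_lik(y):
    # Batched uniform density on the unit disk, y of shape (M, 2).
    return jnp.where(jnp.sum(y**2, axis=-1) < 1.0, 1.0 / jnp.pi, 0.0)


def proposal_lik(y):
    # Batched standard 2D Gaussian density.
    return jnp.exp(-jnp.sum(y**2, axis=-1) / 2) / (2 * jnp.pi)


def acceptance_rejection_batched(k, N, safety=1.2):
    chunks, n_have = [], 0
    while n_have < N:
        k, k_y, k_u = split(k, 3)
        # Each proposal is accepted with probability 1/C, so about
        # C * (N - n_have) proposals are needed; `safety` oversamples a little.
        M = math.ceil(safety * C * (N - n_have))
        y = normal(k_y, shape=(M, 2))
        u = uniform(k_u, shape=(M,))
        keep = u <= target_lik(y) / (C * proposal_lik(y))
        chunks.append(y[keep])  # boolean indexing works on concrete arrays
        n_have += int(keep.sum())
    return jnp.concatenate(chunks)[:N]


samples = acceptance_rejection_batched(key(0), 1000)
```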

#### Q2: 
Implement the Metropolis-Hastings algorithm using JAX primitives (in particular, try to use `scan`), starting from the NumPy version below:

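```python
import numpy as np


def mh(f, n=1000):
    # Random-walk Metropolis-Hastings on the plane; f(x, y) is an unnormalised density.
    arr = np.empty((n, 2))
    x, y = np.random.uniform(-1, 1, 2)  # initial state drawn from [-1, 1)^2
    p = f(x, y)
    for i in range(n):
        # Symmetric uniform step, so the Hastings correction cancels.
        eps_x, eps_y = np.random.uniform(-1, 1, 2)
        xi, yi = x + eps_x, y + eps_y
        pi = f(xi, yi)
        # Accept with probability min(1, f(proposal) / f(current)).
        if np.random.rand() <= pi / p:
            x, y = xi, yi
            p = pi
        arr[i] = x, y  # record the current state (repeated on rejection)
    return arr
```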
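One possible `scan`-based port, assuming `f` is written with `jnp` operations and is positive at the starting point. The proposal, the acceptance rule and the recorded chain mirror the NumPy `mh` above, while the explicit key argument replaces NumPy's global state; `mh_scan` is a hypothetical name. Branches become `jnp.where`, and each step consumes one key from a pre-split array.

```python
import jax.numpy as jnp
from jax.lax import scan
from jax.random import key, split, uniform


def mh_scan(f, k, n=1000):
    # f: unnormalised density f(x, y) written with jnp operations.
    k, k_init = split(k)
    z0 = uniform(k_init, shape=(2,), minval=-1.0, maxval=1.0)  # initial state

    def step(carry, k_step):
        z, p = carry
        k_eps, k_u = split(k_step)
        # Symmetric uniform step in [-1, 1)^2, as in the NumPy version.
        z_new = z + uniform(k_eps, shape=(2,), minval=-1.0, maxval=1.0)
        p_new = f(z_new[0], z_new[1])
        accept = uniform(k_u) <= p_new / p  # accept w.p. min(1, p_new / p)
        z = jnp.where(accept, z_new, z)
        p = jnp.where(accept, p_new, p)
        return (z, p), z  # carry the state, record it as the chain output

    _, chain = scan(step, (z0, f(z0[0], z0[1])), split(k, n))
    return chain


chain = mh_scan(lambda x, y: jnp.exp(-(x**2 + y**2) / 2), key(0), n=20_000)
```

For the standard 2D Gaussian used here, the chain's mean should be near 0 and its per-coordinate variance near 1.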