# `PRNGKey` vs global seeding

We follow the JAX [PRNG design note](https://github.com/google/jax/blob/master/design_notes/prng.md), which explains why JAX passes explicit, splittable keys around instead of relying on a global seed.
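A minimal sketch of the difference, assuming only NumPy and `jax.random`. NumPy's global seed drives a hidden state that advances with every call. A JAX key is an explicit value: reusing it reproduces the same sample, and fresh randomness comes from `split`.

```python
import numpy as np
import jax

# Global seeding (NumPy): a hidden, mutable state advances on every call.
np.random.seed(0)
a = np.random.normal()
b = np.random.normal()  # differs from `a`: the global state has moved on

# Explicit keys (JAX): a sample is a pure function of the key.
key = jax.random.PRNGKey(0)
x1 = jax.random.normal(key)
x2 = jax.random.normal(key)  # same key, so exactly the same sample

# For new randomness, split the key and never reuse a consumed subkey.
key, subkey = jax.random.split(key)
x3 = jax.random.normal(subkey)
```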


## Beginner
### Prerequisites
- NumPy

### Imports

```python
import numpy as np
from jax.random import normal, uniform, PRNGKey, split
```

### Questions:

#### Q1: 
Write a sampler for the exponential distribution with rate $\lambda > 0$: $p_\lambda(x) = \lambda e^{-\lambda x}$ for $x \ge 0$.
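One possible solution, sketched with the inverse-CDF method: the CDF is $F(x) = 1 - e^{-\lambda x}$, so $x = -\log(1-u)/\lambda$ with $u \sim \mathcal{U}[0, 1)$ is exponentially distributed with rate $\lambda$. The name `exponential_sampler` and its `lam` parameter are illustrative. JAX also provides `jax.random.exponential` (rate 1), whose output divided by $\lambda$ can serve as a cross-check.

```python
import jax
import jax.numpy as jnp

def exponential_sampler(key, N, lam=1.0):
    # Inverse-CDF method: if U ~ Uniform[0, 1), then -log(1 - U) / lam ~ Exp(lam).
    u = jax.random.uniform(key, (N,))
    return -jnp.log1p(-u) / lam  # log1p(-u) = log(1 - u), finite since u < 1

key = jax.random.PRNGKey(0)
samples = exponential_sampler(key, 100_000, lam=2.0)
# The sample mean should be close to 1 / lam = 0.5.
```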

#### Q2:
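Implement the following Monte Carlo integrator, which estimates $\int f(x)\,p(x)\,dx$ by averaging `fun` over `N` samples drawn from $p$ with `sampler`: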

```python
def mc_integrator(fun, sampler, N, key):
    # fun: callable(x), the integrand
    # sampler: callable(key, N), draws N samples from p
    # N: int, the number of samples
    # key: a jax.random.PRNGKey
    ...
```
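A minimal sketch, assuming `fun` acts on a single sample and `sampler(key, N)` returns `N` samples stacked along the first axis. The estimator is the sample mean $\frac{1}{N}\sum_i f(x_i)$ with $x_i \sim p$. The `exp_sampler` helper is hypothetical and only there for the usage example.

```python
import jax
import jax.numpy as jnp

def mc_integrator(fun, sampler, N, key):
    # Draw N samples x_i ~ p, then average fun(x_i); vmap applies fun per sample.
    xs = sampler(key, N)
    return jnp.mean(jax.vmap(fun)(xs))

# Hypothetical sampler for Exp(lam) via the inverse CDF.
def exp_sampler(key, N, lam=2.0):
    return -jnp.log1p(-jax.random.uniform(key, (N,))) / lam

# E[x^2] under Exp(2) is exactly 2 / lam**2 = 0.5.
estimate = mc_integrator(lambda x: x ** 2, exp_sampler, 100_000, jax.random.PRNGKey(1))
```

The error of such an estimate shrinks like $1/\sqrt{N}$, independently of the dimension of $x$.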
    

## Intermediate
### Prerequisites
- Beginner randomness
- Beginner loops
- Beginner if-else

### Imports

```python
from jax.lax import scan, while_loop
import jax.numpy as jnp
from jax.random import normal
from jax.scipy.stats.norm import pdf
```

### Questions:

#### Q1: 
Implement von Neumann's acceptance-rejection method. Here is a plain NumPy version to start from:

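```python
import numpy as np

def acceptance_rejection(target_lik, proposal_lik, proposal_sampler, c, N):
    # Draw N samples from the target density f using proposals y ~ g,
    # where c is chosen so that f(y) <= c * g(y) for every y.
    res = []
    while len(res) < N:
        y = proposal_sampler()  # candidate y ~ g
        u = np.random.rand()    # u ~ Uniform[0, 1)
        lik_ratio = target_lik(y) / (c * proposal_lik(y))
        if u <= lik_ratio:      # accept with probability f(y) / (c g(y))
            res.append(y)
    return np.stack(res)
```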
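One way to port this loop to JAX, sketched under two assumptions not in the original: the densities use `jnp` operations (no Python `if`), and `proposal_sampler` takes a key. A single sample comes from a `lax.while_loop` that keeps proposing until acceptance, and `vmap` over `N` split keys produces the batch. The helper names and the Beta(2, 2) example are illustrative choices.

```python
import jax
import jax.numpy as jnp
from jax import lax

def ar_single(key, target_lik, proposal_lik, proposal_sampler, c):
    # Propose y, draw u, accept if u <= f(y) / (c g(y)); repeat until accepted.
    def cond(state):
        _, _, accepted = state
        return jnp.logical_not(accepted)

    def body(state):
        key, _, _ = state
        key, k_y, k_u = jax.random.split(key, 3)
        y = proposal_sampler(k_y)
        u = jax.random.uniform(k_u)
        return key, y, u <= target_lik(y) / (c * proposal_lik(y))

    key, k_init = jax.random.split(key)
    y0 = jnp.zeros_like(proposal_sampler(k_init))  # placeholder fixing shape/dtype
    _, y, _ = lax.while_loop(cond, body, (key, y0, jnp.array(False)))
    return y

def acceptance_rejection_jax(key, target_lik, proposal_lik, proposal_sampler, c, N):
    keys = jax.random.split(key, N)
    return jax.vmap(
        lambda k: ar_single(k, target_lik, proposal_lik, proposal_sampler, c))(keys)

# Illustrative target: Beta(2, 2), density 6x(1 - x) on [0, 1], whose maximum is 1.5,
# with a Uniform(0, 1) proposal, so c = 1.5 is valid.
target = lambda x: 6.0 * x * (1.0 - x)
proposal = lambda x: 1.0
sampler = lambda k: jax.random.uniform(k)
samples = acceptance_rejection_jax(jax.random.PRNGKey(0), target, proposal, sampler, 1.5, 10_000)
```

Under `vmap`, the batched loop keeps running until the slowest lane accepts, so lanes that finish early sit idle; that is the price of a fixed-shape output.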

Then test it by sampling uniformly from the unit disk, using a standard bivariate normal proposal:

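```python
def target_lik(y):
    # Uniform density on the unit disk: 1 / area = 1 / pi inside, 0 outside.
    if y[0] ** 2 + y[1] ** 2 < 1:
        return 1 / np.pi
    return 0.0

def proposal_lik(y):
    # Standard bivariate normal density.
    return np.exp(-(y[0] ** 2 + y[1] ** 2) / 2) / (2 * np.pi)

def proposal_sampler():
    # One draw from the standard bivariate normal.
    return np.random.randn(2)

# Smallest valid c: f / g = 2 exp(r^2 / 2) <= 2 exp(1/2) on the unit disk.
c = 2 * np.exp(0.5)

samples = acceptance_rejection(target_lik, proposal_lik, proposal_sampler, c, 1000)
```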
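As a sanity check on the constant: with the normalised densities $f(y) = 1/\pi$ on the disk and $g(y) = e^{-r^2/2}/(2\pi)$ with $r = \lVert y \rVert$, the ratio is $f/g = 2e^{r^2/2} \le 2e^{1/2}$ on the disk. So $c = 2\sqrt{e} \approx 3.30$ is the smallest valid choice, and about $1/c \approx 30\%$ of proposals should be accepted. The vectorised NumPy sketch below checks both claims empirically; the seed and sample size are arbitrary.

```python
import numpy as np

def target_lik(y):
    # Uniform density on the unit disk (vectorised over rows of y).
    return np.where(np.sum(y ** 2, axis=-1) < 1, 1 / np.pi, 0.0)

def proposal_lik(y):
    # Standard bivariate normal density (vectorised over rows of y).
    return np.exp(-np.sum(y ** 2, axis=-1) / 2) / (2 * np.pi)

c = 2 * np.exp(0.5)

rng = np.random.default_rng(0)
y = rng.standard_normal((200_000, 2))
ratio = target_lik(y) / (c * proposal_lik(y))  # must never exceed 1 for a valid c
u = rng.uniform(size=len(y))
acceptance_rate = (u < ratio).mean()           # should be close to 1 / c
```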

#### Q2:
Implement your own parallel version of `jax.lax.associative_scan` using JAX primitives.
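A sketch of one classic parallel approach, the Hillis-Steele inclusive scan. This is not the scheme `jax.lax.associative_scan` itself uses, which is work-efficient. It takes $\lceil\log_2 n\rceil$ steps; at the step with offset $d$, every element $i \ge d$ is combined with element $i - d$. The function assumes `fn` is associative and vectorised over the leading axis; `hillis_steele_scan` is a hypothetical name.

```python
import jax.numpy as jnp
from jax import lax

def hillis_steele_scan(fn, xs):
    # Inclusive scan: out[i] = xs[0] (fn) xs[1] (fn) ... (fn) xs[i].
    n = xs.shape[0]
    offset = 1
    while offset < n:  # log2(n) rounds; n is a static shape, so this unrolls under jit
        # Combine element j (left operand) with element j + offset, in parallel.
        combined = fn(xs[:-offset], xs[offset:])
        # The first `offset` elements are already complete prefixes.
        xs = jnp.concatenate([xs[:offset], combined], axis=0)
        offset *= 2
    return xs

xs = jnp.arange(1, 11)
prefix_sums = hillis_steele_scan(jnp.add, xs)  # matches jnp.cumsum: 1, 3, 6, ..., 55
reference = lax.associative_scan(jnp.add, xs)
```

It performs $O(n \log n)$ combinations versus $O(n)$ for a work-efficient scan, trading extra work for a very simple data-parallel step. Keeping the earlier element as the left operand preserves order for non-commutative operators.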

## Advanced
### Prerequisites
- Intermediate randomness
- Beginner vectorisation

### Questions:

#### Q1: 
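Noting that in von Neumann's acceptance-rejection method each proposal is accepted with probability $\mathbb{P}\left(u < \frac{f(y)}{c\,g(y)}\right) = \frac{1}{c}$ (for a normalised target density $f$ and proposal density $g$), make it more efficient on average.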
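One possible speed-up, sketched for the unit-disk example. Since each proposal is accepted with probability $1/c$, draw about $c$ times as many proposals as samples still needed in one vectorised batch, keep the accepted ones, and repeat only for the shortfall. The function names and the `oversample` factor are illustrative assumptions. On the disk the acceptance probability simplifies to $e^{(r^2-1)/2}$.

```python
import math
import jax
import jax.numpy as jnp

c = 2.0 * math.exp(0.5)  # sup of f/g on the unit disk (uniform target, N(0, I) proposal)

def accept_prob(y):
    # f(y) / (c g(y)) with f = 1/pi on the disk and g the standard bivariate normal.
    r2 = jnp.sum(y ** 2, axis=-1)
    f = jnp.where(r2 < 1.0, 1.0 / jnp.pi, 0.0)
    g = jnp.exp(-r2 / 2.0) / (2.0 * jnp.pi)
    return f / (c * g)

def batched_disk_sampler(key, n, oversample=1.2):
    chunks, total = [], 0
    while total < n:
        key, k_prop, k_unif = jax.random.split(key, 3)
        # About 1/c of proposals survive, so request ~c times the shortfall.
        m = math.ceil(oversample * c * (n - total))
        y = jax.random.normal(k_prop, (m, 2))
        u = jax.random.uniform(k_unif, (m,))
        accepted = y[u < accept_prob(y)]  # boolean mask: runs eagerly, not under jit
        chunks.append(accepted)
        total += accepted.shape[0]
    return jnp.concatenate(chunks, axis=0)[:n]

samples = batched_disk_sampler(jax.random.PRNGKey(0), 10_000)
```

With `oversample` slightly above 1, a single round usually suffices, so the Python-level loop runs only a handful of times.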

#### Q2: 
Implement the Metropolis-Hastings algorithm using JAX primitives (in particular, try to use `scan`). Here is a plain NumPy version to start from:

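```python
import numpy as np

def mh(f, n=1000):
    # Random-walk Metropolis on R^2: f(x, y) is an unnormalised target density.
    # Proposals add Uniform(-1, 1) noise per coordinate; the proposal is
    # symmetric, so the Hastings correction cancels.
    arr = np.empty((n, 2))
    x, y = np.random.uniform(-1, 1, 2)
    p = f(x, y)
    for i in range(n):
        eps_x, eps_y = np.random.uniform(-1, 1, 2)
        x_new, y_new = x + eps_x, y + eps_y
        p_new = f(x_new, y_new)
        # Accept with probability min(1, p_new / p); testing u * p <= p_new
        # avoids dividing by zero when the current density p is 0.
        if np.random.rand() * p <= p_new:
            x, y, p = x_new, y_new, p_new
        arr[i] = x, y
    return arr
```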
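A sketch of the same random-walk sampler with `lax.scan`. Unlike the NumPy version, it assumes the target is given as a log-density `log_f` of a length-2 array. Working in log space replaces the ratio test with $\log u \le \log f(x') - \log f(x)$. One split key per step is the input sequence to `scan`, and the stacked per-step outputs form the chain. The names `mh_scan` and `step` are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def mh_scan(log_f, key, n=1000, step=1.0):
    # log_f: log of an unnormalised target density on R^2 (assumed interface).
    key, k_init = jax.random.split(key)
    x0 = jax.random.uniform(k_init, (2,), minval=-1.0, maxval=1.0)

    def step_fn(carry, k):
        x, logp = carry
        k_eps, k_u = jax.random.split(k)
        # Symmetric random-walk proposal: Uniform(-step, step) noise per coordinate.
        x_new = x + jax.random.uniform(k_eps, (2,), minval=-step, maxval=step)
        logp_new = log_f(x_new)
        # Accept with probability min(1, f(x_new) / f(x)), tested in log space.
        accept = jnp.log(jax.random.uniform(k_u)) <= logp_new - logp
        x = jnp.where(accept, x_new, x)
        logp = jnp.where(accept, logp_new, logp)
        return (x, logp), x  # new carry, and the state recorded for this step

    keys = jax.random.split(key, n)
    _, chain = lax.scan(step_fn, (x0, log_f(x0)), keys)
    return chain

# Usage: a standard bivariate normal target, known only up to a constant.
chain = mh_scan(lambda x: -0.5 * jnp.sum(x ** 2), jax.random.PRNGKey(0), n=20_000)
```

Carrying `logp` avoids re-evaluating the target at the current state on every step, mirroring how the NumPy version keeps `p`.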