# Randomness in Jax

## Lesson Goals:

By the end of this lesson, you will be able to explain the difference between randomness in jax and numpy. You will also be able to explain why jax handles randomness the way it does. 

## Core Concepts:

- `rng`
- `vmap`

## Further Resources

[Jax Official Random Number Post](https://jax.readthedocs.io/en/latest/random-numbers.html)


In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
from jax import random
import numpy as np

# RNG in Numpy

Notice how each call to `np.random.randn` generates a new number? This is because numpy keeps an internal (global) state, which advances with every call. Note how we can "reset" this state by calling `np.random.seed` again, which is why every outer iteration below prints the same four numbers.


In [None]:
import numpy as np

# Re-seed at the start of each outer iteration, then draw four numbers
for i in range(3):
    print(f"Iteration: {i}")
    np.random.seed(42)
    
    for j in range(4):
        print(np.random.randn(1))
    print("")


# RNG In Jax

In keeping with Jax's philosophy of purity (and because it opens up some capabilities we'll touch on later), Jax's RNG system requires an explicit "key" to be passed in with every call. Being explicit about where the randomness comes from makes results easier to reproduce across machines, which matters especially for large models that are trained across multiple GPUs.

P.S. It's **probably** less of an issue nowadays, but look up "reproducible pytorch" and "reproducible tensorflow" to see a slew of Stack Overflow posts about the hoops you have to jump through to get reproducible runs across machines.

## Generating random numbers in Jax

So how do we do this? Well, it looks very similar to the numpy code above! We pass our original `key` into `jax.random.split`, which returns a new `key` and a `subkey`. We use the `subkey` to draw the random number and keep the new `key` for the next split.

In [None]:
import jax
import jax.numpy as jnp


# Generate random numbers using the key
for i in range(3):
    key = jax.random.PRNGKey(0)
    print(f"Iteration: {i}")
    
    for j in range(4):
        key, subkey = jax.random.split(key)
        # Use the subkey to generate new random numbers
        new_random_numbers = jax.random.normal(subkey, shape=(1,))
        print(new_random_numbers)

    print("")


## Understanding the code:

```python
key = jax.random.PRNGKey(0)

...

key, subkey = jax.random.split(key)
```

We must first create the initial key via `jax.random.PRNGKey` (think of it like calling `np.random.seed`). We then derive new keys from the original key via `jax.random.split`.

## RNG Design

To maintain purity, Jax's random keys are immutable: drawing a number never changes the key, so reusing a key gives you the same number again. To generate new numbers, you need to split the key. Splitting ties the state of our random generator to a value that can be repeated more easily (think of it as setting the seed at every RNG call).
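As a minimal sketch of what immutability means in practice (the variable names are only illustrative), the snippet below draws twice from the same key and then once from a key produced by `jax.random.split`:

```python
import jax

key = jax.random.PRNGKey(0)

# Drawing from the same key twice gives the same value:
# the draw does not mutate the key.
a = jax.random.normal(key, shape=(1,))
b = jax.random.normal(key, shape=(1,))
same_key_same_value = bool((a == b).all())

# split returns new keys and leaves the original key untouched.
new_key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, shape=(1,))
different_after_split = bool((a != c).all())

print(same_key_same_value, different_after_split)
```

This is why the loop earlier rebinds `key` on every iteration: without the split, each draw would repeat the previous one.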



## Why bother?

It's clear that the two examples above behave similarly - we were able to reproduce the results by calling `np.random.seed` and `jax.random.PRNGKey` in between runs. However, in `numpy`, it's common to only set the seed at the start of the script. Very rarely do you see people resetting the seed in the middle of the program.

### Multiprocess/multithreaded problem

In fact, that raises another interesting problem: what happens in a multi-threaded or multi-process program?

In a single-process, single-threaded program, resetting the numpy seed is relatively easy to reason about (though not always). With multiple threads it gets harder: the threads share the hidden global state, so the numbers each one sees depend on the order in which the threads happen to run. In `jax`, this stays simple, because we have to pass the key around and the randomness is made explicit.
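One way to picture the explicit-key approach is to hand each worker its own key. The sketch below is only an illustration: `worker` is a hypothetical function, and `vmap` stands in for real threads or processes. The point is that each worker's output depends only on the key it was given, not on when it runs.

```python
import jax
import jax.numpy as jnp


def worker(key):
    # Hypothetical worker: draws three samples from its own key.
    return jax.random.normal(key, shape=(3,))


key = jax.random.PRNGKey(0)
num_workers = 4

# One independent key per worker, all derived from the same root key.
worker_keys = jax.random.split(key, num_workers)

# vmap applies worker once per key; the result has shape (num_workers, 3).
samples = jax.vmap(worker)(worker_keys)

# Re-running a single worker on its own key reproduces its row.
again = worker(worker_keys[2])
print(jnp.allclose(samples[2], again))
```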

### Control flow and misdirection

If you introduce control flow or loops in your functions, identifying where and when the seed is reset can be a headache.
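To make this concrete, here is a small sketch with two hypothetical functions, `noisy_step_np` and `noisy_step_jax`. In the numpy version, a branch silently consumes some of the global state, so the same seed yields a different result. In the Jax version, each draw has its own key, so the branch cannot affect the returned value.

```python
import jax
import numpy as np


def noisy_step_np(use_extra_draw):
    # The branch advances numpy's hidden global state as a side effect.
    if use_extra_draw:
        np.random.randn()
    return np.random.randn()


np.random.seed(42)
np_without_branch = noisy_step_np(False)
np.random.seed(42)
np_with_branch = noisy_step_np(True)


def noisy_step_jax(key, use_extra_draw):
    # Each draw gets its own key, so the branch has no hidden effect.
    extra_key, out_key = jax.random.split(key)
    if use_extra_draw:
        jax.random.normal(extra_key)
    return jax.random.normal(out_key)


key = jax.random.PRNGKey(42)
jax_without_branch = float(noisy_step_jax(key, False))
jax_with_branch = float(noisy_step_jax(key, True))

print(np_without_branch == np_with_branch)    # the branch changed the result
print(jax_without_branch == jax_with_branch)  # the branch did not matter
```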

# P.S. You're (probably) using numpy's random module incorrectly

Look online, and you'll see a slew of code that looks like the following:

```python
import numpy as np
np.random.seed(42)

np.random.standard_normal(10)
```

but even the [official documentation](https://numpy.org/doc/stable/reference/random/legacy.html) calls this the "legacy" way of generating random numbers.

You're now supposed to use

```python
import numpy as np

rng = np.random.default_rng()
rng.standard_normal(10)
```

which creates a `Generator` object and ties the randomness to the state of a variable (as opposed to some global state). Tying the randomness to the state of a variable makes it easier to reason about in a multi-threaded scenario, but doesn't "fix" the core issue.
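As a sketch of what this looks like with several workers, the snippet below gives each one its own `Generator`. The seed value `42` and the use of `np.random.SeedSequence(...).spawn` are assumptions added for illustration; they are not part of the snippet above.

```python
import numpy as np


def make_worker_rngs(seed, num_workers):
    # Derive one child seed per worker from a single root seed,
    # then build an independent Generator from each child.
    children = np.random.SeedSequence(seed).spawn(num_workers)
    return [np.random.default_rng(child) for child in children]


# Each worker draws from its own generator, not from global state.
draws = [rng.standard_normal(2) for rng in make_worker_rngs(42, 3)]

# Rebuilding the generators from the same root seed reproduces every draw.
draws_again = [rng.standard_normal(2) for rng in make_worker_rngs(42, 3)]
reproducible = all(np.array_equal(x, y) for x, y in zip(draws, draws_again))
print(reproducible)
```

Note the resemblance to Jax: spawning child seeds plays a role similar to `jax.random.split`, except that each `Generator` still carries mutable state that advances with every draw.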
