# Jax's RNG

## Lesson Goals:

By the end of this lesson you will understand why `jax`'s [random module](https://jax.readthedocs.io/en/latest/jax.random.html) is structured the way it is. For additional (external) information read [Jax's Pseudorandom numbers](https://jax.readthedocs.io/en/latest/random-numbers.html) documentation.

## Core Concepts:

- `rng`
- `samplers`

## Concepts in Action:

TODO


**Note**

It's easy to conflate library-level determinism with GPU-level determinism. The way Jax generates random samples is deterministic. As the [random-numbers](https://jax.readthedocs.io/en/latest/random-numbers.html) documentation puts it:

> `jax.random.split()` is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys.

However, depending on how you've configured your system environment, your GPU might still introduce non-determinism. See [the response from one of Jax's creators to a question about determinism](https://github.com/google/jax/discussions/10674).
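
To make the library-level claim concrete, here is a minimal sketch (the seed `0` and the number of subkeys are arbitrary choices for illustration). It only checks that `jax.random.split()` is a pure function of its input key; it says nothing about GPU-level behaviour.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)  # arbitrary seed, chosen for illustration

# Splitting the same key twice produces exactly the same subkeys...
subkeys_a = jax.random.split(key, 3)
subkeys_b = jax.random.split(key, 3)
split_is_deterministic = bool(jnp.array_equal(subkeys_a, subkeys_b))

# ...while the subkeys within one split differ from each other.
subkeys_differ = not bool(jnp.array_equal(subkeys_a[0], subkeys_a[1]))

print(split_is_deterministic, subkeys_differ)
```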


# Foreword

## Numpy

In [lesson 2](./exe_02_jit.ipynb) we briefly discussed state and functional programming, and showed how stateful programs can lead to bugs or unexpected results (remember the shopping cart?).

Well, it turns out that when you run code like this:

```python
import numpy as np
np.random.seed(42)
np.random.rand()
```

you're actually relying on the global state of the program. This reliance on global state leads to a whole slew of problems, e.g. when you're distributing your program across multiple machines, processes or threads.
Randomness has been a pitfall for many a machine learner (see [The Pit of Reproducibility](#the-pit-of-reproducibility) for examples). In numpy, every call to a sampler changes the state of the world, so the value a call returns depends on every sampler call that ran before it. Once the order of those calls can vary, we can no longer guarantee that the same numbers will be generated by the same program.
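
The following sketch illustrates that order dependence. The helper `draw_pair` is a hypothetical name used only for this example, and the "extra" call stands in for any sampler call you don't control (a library, another thread, a reordered line of code).

```python
import numpy as np

def draw_pair():
    """Draw two numbers from numpy's global generator."""
    a = np.random.rand()
    b = np.random.rand()
    return a, b

np.random.seed(42)
first_a, first_b = draw_pair()

# Same seed, but one extra sampler call sneaks in before we draw.
np.random.seed(42)
_ = np.random.rand()
second_a, second_b = draw_pair()

# The hidden state advanced by one step, so every later draw shifted.
print(first_a == second_a)  # False
print(first_b == second_a)  # True
```

Nothing in the signature of `draw_pair` reveals that its result depends on what ran before it, which is what makes this kind of bug hard to spot.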


## Jax

How does Jax handle the issue of state, then? Well, it relies on:

1) explicitly passing in a random seed/key every time you sample (sounds cumbersome, but it's really not!)
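2) deriving new keys from existing ones with the deterministic `jax.random.split()`, rather than mutating hidden state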
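
A minimal sketch of what explicit key passing looks like in practice, assuming the sampler `jax.random.uniform` and an arbitrary seed of `42` (neither is prescribed by this lesson). It also uses `jax.random.split()`, quoted in the note earlier, to obtain a fresh key.

```python
import jax

key = jax.random.PRNGKey(42)  # arbitrary seed, chosen for illustration

# The key is an explicit argument: same key in, same sample out.
x1 = jax.random.uniform(key)
x2 = jax.random.uniform(key)

# For a different sample, derive a fresh key instead of reusing the old one.
key, subkey = jax.random.split(key)
x3 = jax.random.uniform(subkey)

print(float(x1) == float(x2))
print(float(x1) == float(x3))
```

Because the sampler reads no hidden state, the only way to change its output is to change the key you hand it.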

## The Pit of Reproducibility

- [TensorFlow results are not reproducible despite using tf.random.set_seed](https://stackoverflow.com/questions/75850086/tensorflow-results-are-not-reproducible-despite-using-tf-random-set-seed)
- [Stack Overflow answer about using an experimental TensorFlow function](https://stackoverflow.com/a/71311207/24169564), [tf.config.experimental.enable_op_determinism](https://www.tensorflow.org/versions/r2.8/api_docs/python/tf/config/experimental/enable_op_determinism)
- [Nvidia guide on PyTorch and TensorFlow Reproducibility](https://github.com/NVIDIA/framework-reproducibility)
