<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="300" height="300" align="center"/><br>

Welcome to another JAX tutorial. So far, we have discussed **DeviceArray** and **Pure Functions** in JAX. Today, we will dive
into another important concept: **`Pseudo Random Number Generation`** in JAX. We have all used `random numbers` in libraries
like `numpy`, `scikit-learn`, `TensorFlow`, `PyTorch`, etc. We will see why PRNGs, as implemented in `numpy`, are not good enough for JAX and how JAX tries to overcome those limitations.

As usual, if you haven't gone through the previous tutorials, I highly suggest going through them. Here are the links:

1. [TF_JAX_Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
2. [TF_JAX_Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)
3. [TF_JAX_Tutorials - Part 3](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3)
4. [TF_JAX_Tutorials - Part 4 (JAX and DeviceArray)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray)
5. [TF_JAX_Tutorials - Part 5 (Pure Functions in JAX)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-5-pure-functions-in-jax/)


Let's start with the basic concepts of PRNGs first, and then we will look into the implementation differences between `numpy` and JAX.

# PRNG - An Introduction

Before we look into random number generation, here are a few questions that you should ask:<br>
1. What is PRNG?
2. What is PRNG used for?
3. Why should you care about PRNG?


Let's tackle those questions one by one and try to understand the `why` and `how` of the things that we are going to learn in this tutorial.

## What is PRNG?

If we go by the definition, then **`Pseudo Random Number Generation`** is a process of generating a sequence of random numbers **algorithmically** such that the properties of the generated random numbers approximate the properties of a sequence of random numbers sampled from an appropriate distribution. And when we say `random`, it means that the probability of predicting this sequence is no better than a random guess.

Although we are concerned about randomness here, pseudo random number generation isn't *truly* a random process. Why? Because the sequence is fully determined by the initial value or initial state provided to the algorithm. The **algorithm** used to generate these sequences of random numbers is known as a **`Pseudo Random Number Generator`**.
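
To make the "determined by the initial state" part concrete, here is a toy linear congruential generator. It is only a sketch for illustration: the function name `lcg` and the constants are my own choice, and this is not the algorithm that numpy or JAX uses.

```python
def lcg(seed, n, a=1664525, c=1013904223, m=2**32):
    """Return n pseudo-random integers from a toy linear congruential generator."""
    state = seed
    values = []
    for _ in range(n):
        # The next state depends only on the previous state
        state = (a * state + c) % m
        values.append(state)
    return values


run_1 = lcg(seed=42, n=5)
run_2 = lcg(seed=42, n=5)
run_3 = lcg(seed=43, n=5)

print("Same seed, same sequence:          ", run_1 == run_2)
print("Different seed, different sequence:", run_1 != run_3)
```

The numbers look random, but the whole sequence is fixed the moment you pick the seed.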


## What is PRNG used for?

PRNGs have a lot of use cases, and among the most interesting ones are `Cryptography`, `Simulations`, `Games`, `Data Science and Machine Learning` (of course), etc. You might have noticed that *most* people set the **`seed`** in their Data Science and Machine Learning workflow. The **seed** is the initial value mentioned above!


## Why should you care about PRNG?

Although there are tons of use cases of PRNG, I would keep this very specific to Data Science and Machine Learning workflow. When we set a seed, what we try to solve is the `reproducibility` issue. Although `reproducibility` depends on a lot of things, I will use the term very loosely in this context. 

We deal with random states in Machine Learning work more often than we think. For example, splitting a dataset into training and validation sets, sampling the weights of a hidden layer from a given distribution in a neural network, sampling a noise vector from a Gaussian distribution, etc. So, when we say `reproducible` in this context, what we mean is that no matter how many times I run the same process, I should get the same sequence of random numbers. That's why setting a seed becomes important. 

**Note:** To say it again, setting a `seed` doesn't solve the reproducibility crisis of a workflow; it's just a first step towards it.

Let's take an example to clear the point about reproducibility!

In [1]:
import numpy as np
from joblib import Parallel, delayed

import jax
from jax import jit
import jax.numpy as jnp

%config IPCompleter.use_jedi = False

# Random Numbers in Numpy

In [2]:
# If I set the seed, would I get the same sequence of random numbers every time?

for i in range(10):
    # Set initial value by providing a seed value
    seed = 0 
    np.random.seed(seed)
    
    # Generate a random integer from a range of [0, 5)
    random_number = np.random.randint(0, 5)
    print(f"Seed: {seed} -> Random number generated: {random_number}")

Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4
Seed: 0 -> Random number generated: 4


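Let's take a slightly more complex example. We will take an array and sample two smaller arrays from it. Keep in mind that `np.random.choice` samples with replacement by default, so this is not a strict partition of the original array.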

In [3]:
# Array of 10 values
array = np.arange(10)

for i in range(5):
    # Set initial value by providing a seed value
    seed = 1234
    np.random.seed(seed)
    
    # Choose array1 and array2 indices
    train_indices = np.random.choice(array, size=8)
    valid_indices = np.random.choice(array, size=2)
    
    # Split the array into two sets
    train_array = array[train_indices]
    valid_array = array[valid_indices]
    
    print(f"Iteration: {i+1}  Seed value: {seed}\n")
    print(f"First array: {train_array}  Second array: {valid_array}")
    print("="*50)
    print("")

Iteration: 1  Seed value: 1234

First array: [3 6 5 4 8 9 1 7]  Second array: [9 6]

Iteration: 2  Seed value: 1234

First array: [3 6 5 4 8 9 1 7]  Second array: [9 6]

Iteration: 3  Seed value: 1234

First array: [3 6 5 4 8 9 1 7]  Second array: [9 6]

Iteration: 4  Seed value: 1234

First array: [3 6 5 4 8 9 1 7]  Second array: [9 6]

Iteration: 5  Seed value: 1234

First array: [3 6 5 4 8 9 1 7]  Second array: [9 6]



<div class="alert alert-warning"> <b>Note: </b>What we saw above is the <b>legacy</b> way to generate a sequence of random numbers in numpy. It uses the legacy generator provided by numpy, <i><a href="https://numpy.org/doc/stable/reference/random/legacy.html">RandomState(...)</a></i>, which is still the most widely used one. There is another function, <i>np.random.default_rng()</i> (the preferred way as per the docs), that uses the default BitGenerator for generating random sequences.
</div><br><br>

Let's repeat the above example with `default_rng(...)` as well. Because this is a different RNG, we should expect a different sequence here.

In [4]:
# Same example but with a different kind of random number generator
for i in range(5):
    # Set initial value by providing a seed value
    seed = 1234
    rng = np.random.default_rng(seed)
    
    # Choose array1 and array2 indices
    train_indices = np.random.choice(array, size=8)
    valid_indices = np.random.choice(array, size=2)
    
    # Split the array into two sets
    train_array = array[train_indices]
    valid_array = array[valid_indices]
    
    print(f"Iteration: {i+1}  Seed value: {seed}\n")
    print(f"First array: {train_array}  Second array: {valid_array}")
    print("="*50)
    print("")

Iteration: 1  Seed value: 1234

First array: [8 0 5 0 9 6 2 0]  Second array: [5 2]

Iteration: 2  Seed value: 1234

First array: [6 3 7 0 9 0 3 2]  Second array: [3 1]

Iteration: 3  Seed value: 1234

First array: [3 1 3 7 1 7 4 0]  Second array: [5 1]

Iteration: 4  Seed value: 1234

First array: [5 9 9 4 0 9 8 8]  Second array: [6 8]

Iteration: 5  Seed value: 1234

First array: [6 3 1 2 5 2 5 6]  Second array: [7 4]
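
In the cell above, the sequences change from one iteration to the next even though the seed is fixed. That is because the cell creates `rng` but still samples with `np.random.choice(...)`, which draws from the global legacy generator and never touches `rng`. Here is a minimal sketch of the same loop that samples from `rng` instead (the `results` list is my own addition, used only to compare iterations):

```python
import numpy as np

array = np.arange(10)
seed = 1234
results = []

for i in range(5):
    # Re-create the generator from the same seed in every iteration
    rng = np.random.default_rng(seed)

    # Draw from `rng` itself, not from the global `np.random` state
    train_indices = rng.choice(array, size=8)
    valid_indices = rng.choice(array, size=2)

    train_array = array[train_indices]
    valid_array = array[valid_indices]
    results.append((train_array, valid_array))

    print(f"Iteration: {i+1}  First array: {train_array}  Second array: {valid_array}")
```

With this change every iteration should print the same pair of arrays, since each one starts from a fresh generator seeded with `1234`.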



# Numpy PRNG: Pros and Cons

We saw a few examples of how you can generate pseudo random numbers in numpy. But I am pretty sure that most of us overlook the `pros` and `cons` of these approaches. Today isn't that day. We will dive into the pros and cons right away.


## Pros

1. Setting a global seed is easy from most end users' perspective. You set it once and you are done.
2. With the new generator and **[SeedSequence](https://numpy.org/doc/stable/reference/random/bit_generators/generated/numpy.random.SeedSequence.html#numpy.random.SeedSequence)**, it is possible to produce repeatable pseudo-random numbers across multiple processes (local or distributed).
3. **Sequential Equivalent Guarantee**: One of the good things about random number generation in `numpy` is that it provides a sequential equivalent guarantee. What does that mean? It means that whether you sample a vector of `n` elements at once, or sample `n` elements one at a time, the final sequence will always be the same. Let's see this one in action.

In [5]:
# Set the seed
seed = 1234
np.random.seed(seed)

# Sample a vector of size 10 
array1 = np.random.randint(0, 10, size=10)

# Sample 10 elements one at a time
np.random.seed(seed)
array2 = np.stack([np.random.randint(0, 10) for _ in range(10)])

print(f"Sampled all at once    => {array1}")
print(f"Sampled one at a time  => {array2}")

Sampled all at once    => [3 6 5 4 8 9 1 7 9 6]
Sampled one at a time  => [3 6 5 4 8 9 1 7 9 6]


## Cons

1. Global state is bad for reproducibility: Global state is problematic, especially if you are going to implement some sort of concurrency in your code. That's why the original way of setting a global seed in numpy isn't encouraged anymore.
2. With a shared global state, it’s hard to reason about how it’s being used and updated across different threads, processes, and devices, and it’s very easy to screw up when the details of entropy production and consumption are hidden from the end-user.
3. The **Mersenne Twister PRNG** used in most of the Python and numpy code has several [initialization issues](https://dl.acm.org/doi/10.1145/1276927.1276928).
4. `SeedSequence` makes it easy to get a reproducible sequence of random numbers when concurrency is involved, but it still can't be used for JAX (we will see exactly why later!).
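
To see how hidden consumption of a global state can bite, here is a small sketch. The `helper` function is hypothetical; it stands in for any library call that silently draws from `np.random` behind your back.

```python
import numpy as np


def helper():
    # Hypothetical library call that silently draws from the global state
    return np.random.uniform()


np.random.seed(0)
without_helper = np.random.uniform(size=3)

np.random.seed(0)
helper()  # consumes one value from the global state
with_helper = np.random.uniform(size=3)

print("Without helper call:", without_helper)
print("With helper call:   ", with_helper)
```

The seed is identical in both runs, yet the second run is shifted by one value because something else consumed entropy first. Nothing in the calling code tells you that this happened.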

Let's take an example of `SeedSequence` as well before we move to the JAX PRNG design.

In [6]:
def get_sequence(seed, size=5):
    rng = np.random.default_rng(seed)
    array = np.arange(10)
    return rng.choice(array, size=size)

In [7]:
# Instantiate SeedSequence
seed = 1234
ss = np.random.SeedSequence(seed)

# Spawn 2 child seed sequences
child_seeds = ss.spawn(2)

# Run the function a few times in parallel to check if we get
# same RNG sequence
for i in range(5):
    res = []
    for child_seed in child_seeds:
        res.append(delayed(get_sequence)(child_seed))
    res = Parallel(n_jobs=2)(res)
    print(f"Iteration: {i+1} Sequences: {res}")
    print("="*70)

Iteration: 1 Sequences: [array([4, 5, 4, 2, 5]), array([7, 7, 7, 5, 1])]
Iteration: 2 Sequences: [array([4, 5, 4, 2, 5]), array([7, 7, 7, 5, 1])]
Iteration: 3 Sequences: [array([4, 5, 4, 2, 5]), array([7, 7, 7, 5, 1])]
Iteration: 4 Sequences: [array([4, 5, 4, 2, 5]), array([7, 7, 7, 5, 1])]
Iteration: 5 Sequences: [array([4, 5, 4, 2, 5]), array([7, 7, 7, 5, 1])]


# Random Numbers in JAX

RNG in JAX is very different from RNG in numpy. A question that naturally comes to mind is this: Why would the JAX team implement a whole new PRNG in JAX when they could have just reused the same codebase from numpy? ¯\_(ツ)_/¯

Let's take a few examples to answer that question

The order in which numpy code executes is enforced by Python. Let's say `A` and `B` are two functions whose return values are added together and assigned to `C`. So, the code looks like this: `C = A() + B()`

In [8]:
# Global seed
np.random.seed(1234)

def A():
    return np.random.choice(["a", "A"])

def B():
    return np.random.choice(["b", "B"])

for i in range(2):
    C = A() + B()
    print(f"Iteration: {i+1}  C: {C}")

Iteration: 1  C: AB
Iteration: 2  C: aB


Here the execution has a defined order. `A()` is always called before `B()`. But if you do the same thing in JAX (although JAX doesn't allow string type, this is just for the sake of an example) and `jit` it, then you don't know whether `A()` will be called first or `B()` will be called first. Why?

1. XLA will execute them in whichever order is most efficient, which is not necessarily the order in which they appear in the code. Remember `tf.control_dependencies(...)` that we used to use in the old days? Nothing wrong with TensorFlow, it's just a way to instruct the compiler about ordering.
2. If you force the order of execution, then it contradicts the philosophy of JAX that if two computations are independent of each other, then their execution can be parallelized.
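
We can't easily make XLA reorder things on demand, but we can mimic the effect in plain numpy by swapping the call order by hand. This is only a sketch of the idea: `A` and `B` here return floats so that the swap is visible.

```python
import numpy as np


def A():
    return np.random.uniform()


def B():
    return np.random.uniform()


# Order 1: A runs first, then B
np.random.seed(1234)
a1 = A()
b1 = B()

# Order 2: B runs first, then A (same seed)
np.random.seed(1234)
b2 = B()
a2 = A()

print(f"A first -> A: {a1:.4f}  B: {b1:.4f}")
print(f"B first -> A: {a2:.4f}  B: {b2:.4f}")
```

With a global state, whichever function runs first gets the first number, so `A` and `B` simply swap their values when the order changes.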


This looks like a crisis. How? If you use a global state (as in numpy), you won't be able to infer which function was called first, hence the sequence of generated random numbers is irreproducible. What's the solution then?

## RNG Design in JAX

To make sure that we can parallelize computations and still get reproducible results, JAX applies two rules:
1. Don't depend on the global seed for generating random sequences.
2. Random functions should explicitly consume a state (seed). This ensures that these functions reproduce the same result when given the same seed. This can have some weird effects as well, which we will see in a moment.

Let's take a few examples of how `state` is passed to random functions in JAX.

**Note:** When people say `state`, `seed`, or `key` in the context of PRNG, they usually mean the same thing. JAX uses the words `key` and `subkey` more often than the word `seed`. To stay consistent with the docs, we will use the same terminology here.

In [9]:
from jax import random

In [10]:
# Define a state
seed = 1234
key = random.PRNGKey(seed)
key

DeviceArray([   0, 1234], dtype=uint32)

So, a `key` is nothing but a `DeviceArray` of shape `(2, )`. This key is then passed to random functions. **Random functions consume the state but don't change it**, meaning if you keep passing the same key to the same function, it will always return the same output.

Because functions don't change the state, every time we call a new random function, we need to pass a new key. How is the new key generated? By splitting the original key. Take a look at the example below.

In [11]:
# Passing the original key to a random function
random_integers = random.randint(key=key, minval=0, maxval=10, shape=[5])
print(random_integers)

[2 4 9 9 4]
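
Since the function consumes the key without updating it, calling it a second time with the same key should give exactly the same values. A quick sketch to check this (the variable names are mine):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1234)

# Two calls, same key, same arguments
first_call = random.randint(key=key, minval=0, maxval=10, shape=(5,))
second_call = random.randint(key=key, minval=0, maxval=10, shape=(5,))

print(first_call)
print(second_call)
print("Identical:", bool(jnp.array_equal(first_call, second_call)))
```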


In [12]:
# What if we want to call another function?
# Don't use the same key. Split the original key, and then pass it
print("Original key: ", key)

# Split the key. By default the number of splits is set to 2
# You can specify explicitly how many splits you want to do
key, subkey = random.split(key, num=2)

print("New key: ",  key)
print("Subkey: ", subkey)

Original key:  [   0 1234]
New key:  [2113592192 1902136347]
Subkey:  [603280156 445306386]


In [13]:
# Call another random function with the new key
random_floats = random.normal(key=key, shape=(5,), dtype=jnp.float32)
print(random_floats)

[ 5.2179128e-01  1.4659788e-03 -5.9906763e-01 -3.9343226e-01
 -1.9224551e+00]


**Note:** Although we are calling them `key` and `subkey`, both are states and you can pass either of them to any random function or even the `split` function
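
A common way to use this in practice is to keep one key around only for splitting and to spend a fresh subkey on every draw. Here is a minimal sketch of that pattern; `draw_batches` is a hypothetical helper, not something provided by JAX.

```python
from jax import random


def draw_batches(seed, num_batches=3):
    key = random.PRNGKey(seed)
    batches = []
    for _ in range(num_batches):
        # Keep `key` for future splits, spend `subkey` on this draw
        key, subkey = random.split(key)
        batches.append(random.normal(key=subkey, shape=(2,)))
    return batches


run_a = draw_batches(seed=0)
run_b = draw_batches(seed=0)

for i, (a, b) in enumerate(zip(run_a, run_b)):
    print(f"Batch {i+1}: {a}  same in second run: {bool((a == b).all())}")
```

Every draw gets its own key, and since splitting depends only on the starting key, running the whole thing again gives the same batches.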

# JAX PRNG: Pros and Cons

Now that we have seen the design of PRNG in JAX and how it is implemented and consumed, it is time to discuss the `pros` and `cons` of this approach. **After all, anything and everything has pros and cons**


## Pros

1. The JAX PRNG is a counter-based PRNG design and it uses the **Threefry hash function**. This design allows JAX to escape the sequential execution order constraint, so everything can be vectorized and parallelized without giving up reproducibility.
2. Every random function consumes the state but doesn't change it, so the `key` doesn't have to be returned from the function either.
3. The `split` method is **deterministic**. So, if you start with a given key and split it into `n` keys in your code, you can be assured that every time you run the code, you will get the same splits. We will see an example of this right away.
4. You can generate `n` keys from a key in a single go and keep passing them around.

In [14]:
# Splitting is deterministic!

for i in range(5):
    key = random.PRNGKey(1234)
    print(f"Iteration: {i+1}\n")
    print(f"Original key: {key}")
    key, subkey = random.split(key)
    print(f"First subkey: {key}")
    print(f"Second subkey: {subkey}")
    print("="*50)
    print("")

Iteration: 1

Original key: [   0 1234]
First subkey: [2113592192 1902136347]
Second subkey: [603280156 445306386]

Iteration: 2

Original key: [   0 1234]
First subkey: [2113592192 1902136347]
Second subkey: [603280156 445306386]

Iteration: 3

Original key: [   0 1234]
First subkey: [2113592192 1902136347]
Second subkey: [603280156 445306386]

Iteration: 4

Original key: [   0 1234]
First subkey: [2113592192 1902136347]
Second subkey: [603280156 445306386]

Iteration: 5

Original key: [   0 1234]
First subkey: [2113592192 1902136347]
Second subkey: [603280156 445306386]



In [15]:
# You can generate multiple keys at one go with one split
key = random.PRNGKey(111)
print(f"Original key: {key}\n")

subkeys = random.split(key, num=5)

for i, subkey in enumerate(subkeys):
    print(f"Subkey no: {i+1}  Subkey: {subkey}")

Original key: [  0 111]

Subkey no: 1  Subkey: [2149343144 3788759061]
Subkey no: 2  Subkey: [1263116805 2203640444]
Subkey no: 3  Subkey: [ 260051842 2161001049]
Subkey no: 4  Subkey: [ 450316230 2080109636]
Subkey no: 5  Subkey: [2532194002 3516360950]
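
Because each subkey is independent of the others, a batch of subkeys can also be consumed in one vectorized call. Here is a small sketch using `jax.vmap`; it assumes the array-style keys shown above, where a batch of keys has shape `(5, 2)`.

```python
import jax
from jax import random

key = random.PRNGKey(111)
subkeys = random.split(key, num=5)

# Each subkey produces its own row of 3 samples
samples = jax.vmap(lambda k: random.normal(key=k, shape=(3,)))(subkeys)

print(samples.shape)  # (5, 3)
print(samples)
```

No function here needs to know what the others have drawn, which is what makes this kind of parallel sampling safe.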


## Cons

1. The new PRNG design in JAX is only possible if we give up the **Sequential Equivalent Guarantee**. Why? Because that property is incompatible with vectorization, and vectorization is a priority for JAX.
2. This is not a con as such, but it is something an end-user can easily forget. Two things to consider here:
    * If you call a function again and again with the same key, you will **always** get the same output. Suppose you want to sample 5 random numbers from a uniform distribution, one per call. If you pass the same key to your sampling function every time, you will end up with 5 duplicate numbers.
    * If you pass the same key to different functions, in some cases you will get highly correlated results. The end-user should always split the key before passing it to anything that uses a random function.
    
Let's take an example of each to clarify these points

In [16]:
# No more Sequential Equivalent Guarantee unlike numpy

key = random.PRNGKey(1234)
random_integers_1 = random.randint(key=key, minval=0, maxval=10, shape=(5,))

key = random.PRNGKey(1234)
key, *subkeys = random.split(key, 5)
random_integers_2 = []

for subkey in subkeys:
    num = random.randint(key=subkey, minval=0, maxval=10, shape=(1,))
    random_integers_2.append(num)

random_integers_2 = np.stack(random_integers_2, axis=-1)[0]

print("Generated all at once: ", random_integers_1)
print("Generated sequentially: ", random_integers_2)

Generated all at once:  [2 4 9 9 4]
Generated sequentially:  [1 5 8 7]
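
One detail about the cell above: `random.split(key, 5)` returns five keys, and the unpacking `key, *subkeys` keeps only four of them as subkeys, which is why the sequential run has four values instead of five. Here is a sketch of a like-for-like comparison that uses five subkeys (the variable names are mine):

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(1234)
all_at_once = random.randint(key=key, minval=0, maxval=10, shape=(5,))

# Split into exactly five subkeys, one per sampled element
subkeys = random.split(random.PRNGKey(1234), num=5)
one_at_a_time = jnp.stack(
    [random.randint(key=k, minval=0, maxval=10, shape=()) for k in subkeys]
)

print("Generated all at once:  ", all_at_once)
print("Generated one at a time:", one_at_a_time)
```

Both arrays now have five elements, but there is still no reason to expect them to match, since JAX doesn't offer the sequential equivalent guarantee.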


In [17]:
# Possible highly correlated outputs. 
# Not a very good example but serves the demonstration purpose

def sampler1(key):
    return random.uniform(key=key, minval=0, maxval=1, shape=(2,))

def sampler2(key):
    return 2 * random.uniform(key=key, minval=0, maxval=1, shape=(2,))

key = random.PRNGKey(0)
sample_1 = sampler1(key=key)
sample_2 = sampler2(key=key)

print("First sample: ", sample_1)
print("Second sample: ", sample_2)

First sample:  [0.21629536 0.8041241 ]
Second sample:  [0.43259072 1.6082482 ]


Whattttttt!!! Let's try that in numpy now!

In [18]:
def sampler1():
    return np.random.uniform(low=0, high=1, size=(2,))

def sampler2():
    return 2 * np.random.uniform(low=0, high=1, size=(2,))

np.random.seed(0)
sample_1 = sampler1()
sample_2 = sampler2()

print("First sample: ", sample_1)
print("Second sample: ", sample_2)

First sample:  [0.5488135  0.71518937]
Second sample:  [1.20552675 1.08976637]


You see that in JAX code, the outputs of two samplers were highly correlated while in numpy code we didn't get that perfect correlation. **Lesson?** Unless you want the same outputs, never reuse a `key` by passing it to different random functions in JAX. **Always split the key!**
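
Here is a minimal sketch of the fix, with the same two samplers as in the JAX cell above. Splitting into three keys is just one reasonable choice: one key to keep for later and one subkey per sampler.

```python
from jax import random


def sampler1(key):
    return random.uniform(key=key, minval=0, maxval=1, shape=(2,))


def sampler2(key):
    return 2 * random.uniform(key=key, minval=0, maxval=1, shape=(2,))


key = random.PRNGKey(0)

# One key to keep for later, one subkey per sampler
key, subkey1, subkey2 = random.split(key, num=3)

sample_1 = sampler1(key=subkey1)
sample_2 = sampler2(key=subkey2)

print("First sample: ", sample_1)
print("Second sample: ", sample_2)
```

Each sampler now draws from its own key, so the second sample is no longer just the first one multiplied by 2.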

That's it for Part 6! We will dive into other important concepts in the next tutorial. Stay tuned!

# References
1. https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#rngs-and-state
2. https://github.com/google/jax/blob/main/design_notes/prng.md
3. https://numpy.org/neps/nep-0019-rng-policy.html
4. https://albertcthomas.github.io/good-practices-random-number-generators/
5. https://courses.physics.illinois.edu/phys466/fa2016/lnotes/PRNG.pdf