In [None]:
import numpy as np
from numpy.random import default_rng
import scipy.stats as st
import matplotlib.pyplot as plt

In [None]:
# Let's start with a homemade RNG.
# It will be of the linear congruential generator (LCG) variety.

In [None]:
def rng(size=10, seed=1, m=2**32, a=1664525, c=1013904223):
    """Draw `size` pseudo-random numbers with a linear congruential generator.

    The internal state follows the recurrence
    ``x_{k+1} = (a * x_k + c) % m`` starting from ``x_0 = seed``; every
    new state is divided by ``m`` before being stored.

    Parameters
    ----------
    size : int
        Number of samples to generate.
    seed : int
        Initial state ``x_0`` (it is not itself part of the output).
    m, a, c : int
        Modulus, multiplier and increment of the recurrence.

    Returns
    -------
    numpy.ndarray
        1-D array of length `size` holding ``x_k / m`` for k = 1..size.
    """
    samples = np.zeros(size)
    state = seed
    for position in range(size):
        # Advance the recurrence first, then store the scaled state.
        state = (a * state + c) % m
        samples[position] = state / m
    return samples

In [None]:
# Draw 10,000 samples from the home-made generator, seeded with 42.
u = rng(size=10000, seed=42)

In [None]:
# Histogram of the samples in `u`, split into 100 bins.
fig, ax = plt.subplots()
ax.hist(u, bins=100)
ax.set(ylabel='Bin counts', title="Histogram");

In [None]:
# Let's use NumPy's built-in generator

In [None]:
# NumPy Generator seeded with 42, then 10,000 draws from its `random()` method.
bi_rng = default_rng(seed=42)
w = bi_rng.random(size=10000)

In [None]:
# Histogram of the samples in `w`, split into 100 bins.
fig1, ax1 = plt.subplots()
ax1.hist(w, bins=100)
ax1.set(ylabel='Bin counts', title="Histogram");

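NumPy's legacy random API (e.g. `np.random.seed`, `np.random.uniform`) relies on a global PRNG state. However, JAX cannot use a global PRNG state!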

In [None]:
# Seed NumPy's global PRNG state; foo() and bar() below both draw from it.
np.random.seed(0)


def foo():
    """Return one np.random.uniform() draw from the global PRNG state."""
    return np.random.uniform()


def bar():
    """Return one np.random.uniform() draw from the global PRNG state."""
    return np.random.uniform()


def fooba():
    """Return 3*foo() + 2*bar(), with foo() evaluated before bar()."""
    # The draws share one global state, so the call order fixes the result.
    foo_draw = foo()
    bar_draw = bar()
    return 3 * foo_draw + 2 * bar_draw


print(fooba())

To make the result here reproducible, we have to enforce a specific order of execution (as NumPy does). But JAX needs to be able to parallelize foo() and bar().

To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a key (random key is the same thing as random seed).

In [None]:
# Install jax here if not already installed
# !pip install --upgrade -q "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

In [None]:
from jax import random

# Build a JAX PRNG key from the integer seed 42 and show its contents.
key = random.PRNGKey(seed=42)

print(key)

In [None]:
# Draw twice with the *same* key and print each result on its own line.
print(*(random.normal(key) for _ in range(2)), sep="\n")

In JAX, in order to generate different and independent samples, you must split() the key yourself whenever you want to call a random function. split() is a deterministic function that converts one key into several independent (in the pseudorandomness sense) keys.

In [None]:
# Split the key into two new keys.
key1, key2 = random.split(key, num=2)

In [None]:
# Show the two keys obtained from random.split(key).
print(f"Key 1: {key1}; Key 2: {key2}")

Final note: NumPy provides a sequential equivalence guarantee, meaning that sampling N numbers one at a time and sampling a vector of N numbers at once produce the same pseudo-random sequence. JAX does not provide such a sequential equivalence guarantee.