# Generating random numbers with JaX
*Revision: 0.1 (23/4/25)*

### Introduction and purpose

In the Bayesian machine learning course, we often rely on random numbers for various computations. If $p(\mathbf{x})$ is some distribution of interest and we can obtain a number of samples from $p$, i.e. $\mathbf{x}^{(i)} \sim p(\mathbf{x})$ for $i = 1, 2, \dots, S$, then we can use the samples to estimate most properties of $p(\mathbf{x})$. For example, we can use a *Monte Carlo* estimator to estimate the expected value of a random variable with distribution $p(\mathbf{x})$, i.e.

$$\mathbb{E}\left[\mathbf{x}\right] \approx \hat{\mathbf{x}} = \frac{1}{S}\sum_{i=1}^S \mathbf{x}^{(i)},$$

assuming the expected value exists. In almost every week of the course, we have used random sampling to summarize posterior and posterior predictive distributions. Random sampling also plays a key role in several of the inference algorithms discussed in the course (MCMC, variational inference, etc.).
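Here is a minimal sketch of the estimator above. Purely for illustration, it assumes that $p(x)$ is the one-dimensional Gaussian $\mathcal{N}(2, 0.5^2)$, so the exact mean is known. With $S = 10\,000$ samples, the sample average should land close to $2$. The `key` handling used here is explained in detail below.

```python
import jax
import jax.numpy as jnp

# Illustrative target (an assumption for this sketch): p(x) = N(mu, sigma^2)
mu, sigma = 2.0, 0.5
S = 10_000  # number of Monte Carlo samples

key = jax.random.PRNGKey(0)
# Draw S samples x^(i) ~ p(x) by shifting and scaling standard normal draws
samples = mu + sigma * jax.random.normal(key, shape=(S,))

# Monte Carlo estimate of E[x]: the average of the samples
x_hat = jnp.mean(samples)
print(f'Monte Carlo estimate: {x_hat:.3f} (exact mean: {mu})')
```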

`JaX` has built-in support for sampling from many common distributions (e.g. Gaussian, Beta and Gamma). However, random number generation in `JaX` differs from other frameworks/packages: `JaX` requires us to pass an explicit `state` (the `key`) to the random number generator *every single time* we want it to generate random numbers. Explaining the motivation behind this design, and covering all technical and practical aspects of random number generation in `JaX`, is beyond the scope of the course and this note. If you are interested, you can read more here https://docs.jax.dev/en/latest/random-numbers.html and here https://docs.jax.dev/en/latest/jep/263-prng.html#prng-design-jep (but this is by no means necessary to achieve the learning objectives of the Bayesian Machine Learning course). Instead of aiming to cover every aspect of random number generation in `JaX`, we will look at examples of how to use random numbers in `JaX` and at a couple of common mistakes.

### Random numbers from a random number generator are not actually random

First, we have to remind ourselves that computers generally cannot generate *truly* random numbers. When we use `jax.random.normal`, `np.random.normal` or similar, we do not get truly random samples from a Gaussian distribution. Instead, we get the output of a *deterministic* algorithm designed to generate *pseudorandom numbers* that *appear* and *behave* as if they were truly random samples from a Gaussian distribution.

The algorithms behind `jax.random.normal` and `np.random.normal` depend on an initial `state`, i.e. the `seed` or `key`. They are *100% deterministic* in the sense that they produce the same output when given the same initial `state`. In `JaX`, the state of the random number generator is usually called the `key`, so we will use the two terms (key & state) interchangeably in this note.
In `numpy`, we typically specify an initial `state` (the seed) for the random number generator. `numpy` then automatically updates the `state` internally to make sure we get different numbers every time we call the random number generator.

For example, the cell below generates two batches of samples from a $\mathcal{N}(0, 1)$-distribution and resets the seed before each batch. Both batches are therefore identical. As long as the seed is kept fixed, running the cell multiple times produces the exact same output every time:

```python
import numpy as np

# specify seed
seed = 0

# generate a batch of random numbers
np.random.seed(seed)
print('First batch of random numbers:')
for i in range(3):
    x = np.random.normal()
    print(f'{x:4.3f}')

# generate another batch of random numbers with same seed
np.random.seed(seed)
print('\nSecond batch of random numbers:')
for i in range(3):
    x = np.random.normal()
    print(f'{x:4.3f}')
```

```
First batch of random numbers:
1.764
0.400
0.979

Second batch of random numbers:
1.764
0.400
0.979
```
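To make `numpy`'s hidden state visible, the legacy function `np.random.get_state()` returns the state of the global generator as a tuple. This small sketch checks two things. First, the state changes after a single draw, which is why consecutive calls give different numbers. Second, resetting the seed restores the state exactly.

```python
import numpy as np

np.random.seed(0)
# legacy tuple: ('MT19937', keys, pos, has_gauss, cached_gaussian)
state_before = np.random.get_state()

np.random.normal()  # drawing a number advances the internal state
state_after = np.random.get_state()

# The internal state has changed without us touching it explicitly
changed = (state_before[2] != state_after[2]) or not np.array_equal(state_before[1], state_after[1])
print('State changed after one draw:', changed)

# Re-seeding restores the initial state exactly
np.random.seed(0)
state_reset = np.random.get_state()
restored = np.array_equal(state_before[1], state_reset[1]) and state_before[2] == state_reset[2]
print('State restored by re-seeding:', restored)
```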


Being able to specify the initial random `state` is very convenient, because it ensures reproducibility and makes debugging simpler.

Let's now see a similar example in `JaX`:

```python
import jax.numpy as jnp
import jax

# specify seed
seed = 0

# generate a batch of random numbers
key = jax.random.PRNGKey(seed)
print('First batch of random numbers:')
for i in range(3):
    x = jax.random.normal(key)
    print(f'{x:4.3f}')

# generate another batch of random numbers with same seed
key = jax.random.PRNGKey(seed)
print('\nSecond batch of random numbers:')
for i in range(3):
    x = jax.random.normal(key)
    print(f'{x:4.3f}')
```

```
First batch of random numbers:
1.623
1.623
1.623

Second batch of random numbers:
1.623
1.623
1.623
```


Clearly, this is not what we want. The problem is that we did not update/change the `state/key` of the random number generator, and when we provide `JaX` with the same `key`, we get the same number.

Instead, we should have done the following:


```python
import jax.numpy as jnp
import jax

# specify seed
seed = 0

print('First batch of random numbers:')
key = jax.random.PRNGKey(seed)
for i in range(3):
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey)
    print(f'{x:4.3f}')

print('\nSecond batch of random numbers:')
key = jax.random.PRNGKey(seed)
for i in range(3):
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey)
    print(f'{x:4.3f}')
```

```
First batch of random numbers:
-2.442
-1.257
-1.388

Second batch of random numbers:
-2.442
-1.257
-1.388
```


Now we get the desired behavior: we get a different number every time we call `jax.random.normal`, but we can still reproduce the sequence of random numbers when desired.

A `state/key` created with `jax.random.PRNGKey` is essentially just an array of two unsigned 32-bit integers:

```python
print('key', key)
```

```
key [ 683029726 1624662641]
```


which together control the `state` of the random number generator. In the example above, we used the function `jax.random.split` to iteratively update the `state` of the random number generator. The function `split` takes one `state/key` and splits it into two new `states/keys`:



```python
new_key1, new_key2 = jax.random.split(key)

print('key', key)
print('new_key1', new_key1)
print('new_key2', new_key2)
```

```
key [ 683029726 1624662641]
new_key1 [1113701576 1346130448]
new_key2 [1539457558  118255239]
```
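This quick sketch confirms these properties for the legacy `jax.random.PRNGKey` keys used throughout this note. A key is a `uint32` array of shape `(2,)`. `split` is itself deterministic, so splitting the same key twice gives the same children. The two children differ from each other and from the parent. The checks are illustrative only.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
print(key.shape, key.dtype)  # a legacy key is two unsigned 32-bit integers

# split is deterministic: splitting the same key twice gives the same children
a1, b1 = jax.random.split(key)
a2, b2 = jax.random.split(key)
print('split is deterministic:', bool(jnp.array_equal(a1, a2) and jnp.array_equal(b1, b2)))

# the two children differ from each other and from the parent key
print('children differ:', not bool(jnp.array_equal(a1, b1)))
print('children differ from parent:', not bool(jnp.array_equal(a1, key) or jnp.array_equal(b1, key)))
```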


To summarize, using the following pattern

```python
key = jax.random.PRNGKey(seed)
for i in range(3):
    key, subkey = jax.random.split(key)
    x = jax.random.normal(subkey)
```

ensures that we never generate random numbers based on the same `state/key`, and hence, we get the desired behavior.
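A common mistake is to split the key but forget to overwrite `key` with the new key. The deliberately buggy loop below discards the first output of `split`. As a result, every iteration splits the *same* key, gets the *same* `subkey`, and produces the same number. The second loop shows the fix.

```python
import jax

key = jax.random.PRNGKey(0)
xs = []
for i in range(3):
    # BUG: the updated key is thrown away (assigned to _), so `key` never changes
    _, subkey = jax.random.split(key)
    xs.append(float(jax.random.normal(subkey)))
print(xs)  # three identical numbers

# Fix: reassign `key` in every iteration
key = jax.random.PRNGKey(0)
xs_fixed = []
for i in range(3):
    key, subkey = jax.random.split(key)
    xs_fixed.append(float(jax.random.normal(subkey)))
print(xs_fixed)  # three different numbers
```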

Instead of splitting the `key` in every iteration, we could also have generated all 3 `keys` at the beginning. A single call to `jax.random.split` with `num=3` returns an array containing 3 `keys`:

```python
import jax.numpy as jnp
import jax

# specify seed
seed = 0

print('First batch of random numbers:')
key = jax.random.PRNGKey(seed)
keys = jax.random.split(key, num=3)
for i in range(3):
    x = jax.random.normal(keys[i])
    print(f'{x:4.3f}')

print('\nSecond batch of random numbers:')
key = jax.random.PRNGKey(seed)
keys = jax.random.split(key, num=3)
for i in range(3):
    x = jax.random.normal(keys[i])
    print(f'{x:4.3f}')
```

```
First batch of random numbers:
1.004
-2.442
1.296

Second batch of random numbers:
1.004
-2.442
1.296
```


Notice that the numbers here differ from those in the example above. The reason is that `jax.random.split(key, num=3)` produces a different set of keys than repeatedly calling `jax.random.split(key)`. Importantly, both implementations yield reproducible code.
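To check both claims directly, here is a sketch with two small helper functions. `iterative_keys` and `batch_keys` are hypothetical names introduced here, not part of `JaX`. Each helper collects the keys that one scheme would use for $n$ samples.

```python
import jax
import jax.numpy as jnp

def iterative_keys(seed, n):
    """Hypothetical helper: subkeys obtained by splitting the key in every iteration."""
    key = jax.random.PRNGKey(seed)
    subkeys = []
    for _ in range(n):
        key, subkey = jax.random.split(key)
        subkeys.append(subkey)
    return jnp.stack(subkeys)

def batch_keys(seed, n):
    """Hypothetical helper: keys obtained from a single call to split with num=n."""
    return jax.random.split(jax.random.PRNGKey(seed), num=n)

# Each scheme is reproducible: same seed -> same keys
print(bool(jnp.array_equal(iterative_keys(0, 3), iterative_keys(0, 3))))
print(bool(jnp.array_equal(batch_keys(0, 3), batch_keys(0, 3))))

# ...but the two schemes produce different sets of keys
print(bool(jnp.array_equal(iterative_keys(0, 3), batch_keys(0, 3))))
```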

Updating/splitting the `key` in every iteration like


```python
key = jax.random.PRNGKey(seed)
for i in range(3):
    key, subkey = jax.random.split(key)
```

has the advantage that the updated `key` can safely be split into multiple keys in every iteration without any risk of getting duplicate keys. For example, suppose we are implementing an iterative algorithm where we need to sample several different random variables in every iteration. We can then safely use the following pattern:


```python
key = jax.random.PRNGKey(seed)
for i in range(3):
    key, key_x, key_y, key_z = jax.random.split(key, num=4)
    x = jax.random.normal(key_x)
    y = jax.random.normal(key_y)
    z = jax.random.normal(key_z)
    print(f'Iteration {i}')
    print(f'(x,y,z) = ({x:4.3f}, {y:4.3f}, {z:4.3f})\n')
```

```
Iteration 0
(x,y,z) = (-2.442, 1.296, -0.622)

Iteration 1
(x,y,z) = (-1.257, -0.744, 0.340)

Iteration 2
(x,y,z) = (-1.388, -1.251, -0.343)
```
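As a sanity check on the claim that this pattern never reuses a key, the sketch below records every key produced by the loop over 100 iterations. That includes the keys passed to `jax.random.normal` and the key carried to the next iteration. It then counts the distinct keys. This is an empirical check for one seed, not a proof.

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
used = []
for i in range(100):
    key, key_x, key_y, key_z = jax.random.split(key, num=4)
    used.extend([key_x, key_y, key_z])  # keys consumed by sampling in this iteration
    used.append(key)                    # key carried over to the next iteration

all_keys = np.asarray(jnp.stack(used))              # shape (400, 2), dtype uint32
n_unique = np.unique(all_keys, axis=0).shape[0]     # number of distinct rows (keys)
print(f'{n_unique} distinct keys out of {all_keys.shape[0]}')
```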



Finally, it is worth mentioning that we can also generate a sequence of $N$ random numbers from a single key by passing the `shape` argument:

```python
import jax

N = 3

key = jax.random.PRNGKey(123)
x = jax.random.normal(key, shape=(N, ))
print(f'{N} random numbers using a single key:', x)
```

```
3 random numbers using a single key: [1.6359469  0.8408094  0.02212393]
```
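Suppose the keys have already been created with `jax.random.split(key, num=N)`, as in the array-of-keys approach above. The Python loop over `keys[i]` can then also be replaced by `jax.vmap`, which maps `jax.random.normal` over the leading axis of the key array. In the minimal sketch below, we expect the vectorized result to match the loop element by element, since each output depends only on its own key.

```python
import jax
import jax.numpy as jnp

N = 3
keys = jax.random.split(jax.random.PRNGKey(123), num=N)  # array of N keys, shape (N, 2)

# Loop version: one scalar draw per key
x_loop = jnp.array([jax.random.normal(keys[i]) for i in range(N)])

# Vectorized version: vmap applies jax.random.normal to each key in the batch
x_vmap = jax.vmap(lambda k: jax.random.normal(k))(keys)

print(x_loop)
print(x_vmap)
```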


### Example: Three different implementations of a Monte Carlo estimator

We conclude with a small example showing three ways to implement the random number generation when estimating the mean of $X^2 + Y^2$, where $X \sim \mathcal{N}(1, 2^2)$ and $Y \sim \mathcal{N}(0, 1)$, using $N = 1000$ samples:

```python
import jax.numpy as jnp
from jax import random

N = 1000

############################################################
# implementation 1: update the key iteratively
############################################################
key = random.PRNGKey(1)
values = []
for i in range(N):
    key, subkey_x, subkey_y = random.split(key, num=3)
    x = 1 + 2*random.normal(subkey_x)
    y = random.normal(subkey_y)
    values.append(x**2 + y**2)
print(f'Monte Carlo estimator for the mean of X^2 + Y^2: {jnp.mean(jnp.array(values)):4.3f}')

############################################################
# implementation 2: prepare list of keys beforehand
############################################################
key = random.PRNGKey(1)
key_x, key_y = random.split(key, num=2)
keys_x = random.split(key_x, num=N)
keys_y = random.split(key_y, num=N)
values = []
for i in range(N):
    x = 1 + 2*random.normal(keys_x[i])
    y = random.normal(keys_y[i])
    values.append(x**2 + y**2)
print(f'Monte Carlo estimator for the mean of X^2 + Y^2: {jnp.mean(jnp.array(values)):4.3f}')


############################################################
# implementation 3: vectorized implementation
############################################################
key = random.PRNGKey(1)
key_x, key_y = random.split(key, num=2)
x = 1 + 2*random.normal(key_x, shape=(N, ))
y = random.normal(key_y, shape=(N, ))
print(f'Monte Carlo estimator for the mean of X^2 + Y^2: {jnp.mean(x**2 + y**2):4.3f}')
```

```
Monte Carlo estimator for the mean of X^2 + Y^2: 6.138
Monte Carlo estimator for the mean of X^2 + Y^2: 6.225
Monte Carlo estimator for the mean of X^2 + Y^2: 5.878
```
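The three estimates differ because each implementation consumes different keys, but all of them should be close to the exact value. We can work out that value by hand using $\mathbb{E}[Z^2] = \mathrm{Var}[Z] + \mathbb{E}[Z]^2$ and the Gaussian identity $\mathrm{Var}[Z^2] = 4\mu^2\sigma^2 + 2\sigma^4$ for $Z \sim \mathcal{N}(\mu, \sigma^2)$. Since the code draws $X$ and $Y$ independently, the variances add:

$$\mathbb{E}\left[X^2 + Y^2\right] = \left(\mathrm{Var}[X] + \mathbb{E}[X]^2\right) + \left(\mathrm{Var}[Y] + \mathbb{E}[Y]^2\right) = (4 + 1) + (1 + 0) = 6,$$

$$\mathrm{Var}\left[X^2 + Y^2\right] = \left(4 \cdot 1^2 \cdot 2^2 + 2 \cdot 2^4\right) + \left(4 \cdot 0^2 \cdot 1^2 + 2 \cdot 1^4\right) = 48 + 2 = 50.$$

The Monte Carlo standard error with $N = 1000$ samples is therefore $\sqrt{50/1000} \approx 0.224$. All three estimates above (6.138, 6.225 and 5.878) lie within roughly one standard error of the exact value. The standard error can also be estimated from the samples themselves. Here is a sketch that reuses the vectorized implementation:

```python
import jax
import jax.numpy as jnp

N = 1000
key_x, key_y = jax.random.split(jax.random.PRNGKey(1), num=2)
x = 1 + 2 * jax.random.normal(key_x, shape=(N,))  # X ~ N(1, 2^2)
y = jax.random.normal(key_y, shape=(N,))          # Y ~ N(0, 1)
f = x**2 + y**2

estimate = jnp.mean(f)
# Estimated Monte Carlo standard error: sample standard deviation / sqrt(N)
std_error = jnp.std(f, ddof=1) / jnp.sqrt(N)
print(f'estimate = {estimate:.3f} +/- {std_error:.3f} (exact value: 6)')
```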
