In [None]:
import numpy as np
from jax import random

## Random numbers in NumPy

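Before turning to JAX, we look at how NumPy generates pseudorandom numbers. A pseudorandom number generator (PRNG) suited to JAX should be:

1. **reproducible**,
2. **parallelizable**, and
3. **vectorisable**.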



In [None]:
# Reset NumPy's global (legacy) PRNG so every following draw is reproducible.
np.random.seed(seed=0)

In [None]:
def print_truncated_random_state(max_chars=460):
  """Print the beginning of NumPy's global random state.

  The full state returned by `np.random.get_state()` is long. To avoid
  spamming the outputs, only the first `max_chars` characters of its string
  form are printed, followed by ' ...'. Reading the state does not draw any
  numbers, so calling this function leaves the state unchanged.

  Args:
    max_chars: how many leading characters of the state's string form to
      print. Defaults to 460, the previously hard-coded value.
  """
  full_random_state = np.random.get_state()
  print(str(full_random_state)[:max_chars], '...')

In [None]:
# Inspect the global PRNG state right after seeding it with `np.random.seed(0)`.
print_truncated_random_state()

In [None]:
# Reading the state draws no numbers, so this prints the same state as the
# previous cell.
print_truncated_random_state()

In [None]:
# Drawing a single number advances the global state, so the printout below
# differs from the one in the previous cells.
_ = np.random.uniform()  # The value itself is discarded; only the state update matters here.
print_truncated_random_state()

In [None]:
# No seed in this cell: the draws continue from the current global state.
three_draws = np.random.uniform(size=3)
print(three_draws)

In [None]:
# Draw three scalars one at a time from a freshly seeded global PRNG.
np.random.seed(seed=0)
one_by_one = [np.random.uniform() for _ in range(3)]
print("individually:", np.stack(one_by_one))

In [None]:
# Same seed, but draw all three values in one vectorised call.
np.random.seed(seed=0)
batch_draws = np.random.uniform(size=3)
print("all at once: ", batch_draws)

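## Random numbers in JAX

NumPy's global, stateful PRNG makes results depend on the order in which functions draw from it. In `foo` below, the result is reproducible only because Python evaluates `bar()` before `baz()`. JAX would like to be free to compile and run functions such as `bar` and `baz` in parallel, so it cannot rely on a hidden global state whose outcome depends on call order.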

In [None]:
# Reseed the global PRNG so the `foo()` example below starts from a known state.
np.random.seed(seed=0)

In [None]:
def bar():
  """Return one draw from NumPy's global uniform PRNG."""
  return np.random.uniform()


def baz():
  """Return one draw from NumPy's global uniform PRNG."""
  return np.random.uniform()


def foo():
  """Combine one `bar()` draw with twice one `baz()` draw.

  Both helpers read the same global state, so the result depends on the
  order of the calls: `bar()` consumes the first draw and `baz()` the second.
  """
  bar_draw = bar()
  baz_draw = baz()
  return bar_draw + 2 * baz_draw

In [None]:
# The value is reproducible only because `bar()` always runs before `baz()`.
foo_result = foo()
print(foo_result)

In [None]:
# In JAX, randomness state is an explicit value (a key) built from an integer seed.
key = random.PRNGKey(seed=40)
print(key)

In [None]:
# Calling a JAX random function twice with the same key yields the same sample.
for _repeat in range(2):
  print(random.normal(key))

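JAX random functions are pure: passing the same key always produces the same sample, which is why the two `random.normal(key)` calls above print identical values. The rule of thumb is therefore: never reuse a key (unless you want identical outputs).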

In [None]:
# Walk through one key split step by step. `key` must already exist; it is
# created earlier in this notebook with `random.PRNGKey`.
print("Old key: ", key)

# Split the old key into `new_key` (carried forward) and `subkey` (used once).
new_key, subkey = random.split(key)
del key  # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)  # Consume the subkey for exactly one draw.

# Print the relationship: old key -> (new key, subkey -> sample).
print(r"    \---SPLIT --> new key   ", new_key)
print(r"             \--> new subkey", subkey, "--> normal", normal_sample)

del subkey  # The subkey is also discarded after use.

# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.

key = new_key  # If we wanted to do this again, we would use new_key as the key.

`random.split()` is a deterministic function that converts one key into several keys that are independent of each other (in the pseudorandom sense).

In [None]:
# Rebinding `key` to the first output replaces the old key, so it cannot be reused by accident.
key, subkey = random.split(key, num=2)

In [None]:
# One key to carry forward plus 42 subkeys to consume.
split_keys = random.split(key, num=43)
key, forty_two_subkeys = split_keys[0], list(split_keys[1:])

In [None]:
# Three scalar draws, each from its own subkey of PRNGKey(42).
key = random.PRNGKey(42)
subkeys = random.split(key, 3)
sequence = np.stack(list(map(random.normal, subkeys)))
print("individually:", sequence)

# Same starting key, but a single vectorised draw of shape (3,).
key = random.PRNGKey(42)
batched_normals = random.normal(key, shape=(3,))
print("all at once: ", batched_normals)