# Pseudorandom Numbers in JAX

https://jax.readthedocs.io/en/latest/random-numbers.html

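In JAX, pseudorandom number generation (through `jax.random`, not `jax.numpy`) requires explicitly passing a PRNG key, created here with `random.PRNGKey`. Unlike NumPy, JAX keeps no global random state, so the key is a required argument of every sampler. Calling a function such as `random.normal` without one raises an error.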
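A minimal sketch of what an explicit key means in practice. The seed `0` and the names `a` and `b` are arbitrary choices for illustration. Because the output depends only on the key (plus shape and dtype), passing the same key twice reproduces the same value. Omitting the key fails like any missing positional argument.

```python
from jax import random

key = random.PRNGKey(0)

# Same key in, same value out: there is no hidden global state to advance.
a = random.normal(key)
b = random.normal(key)
print(a == b)  # True

# `key` is a required positional argument, so leaving it out is a TypeError.
try:
    random.normal()
except TypeError as err:
    print("TypeError:", err)
```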

```python
from jax import random

key = random.PRNGKey(12)

for i in range(3):
  new_key, subkey = random.split(key)
  del key  # The old key is consumed by split() -- we must never use it again.

  val = random.normal(subkey)
  del subkey  # The subkey is consumed by normal().

  print(f"draw {i}: {val}")
  key = new_key  # new_key is safe to use in the next iteration
```

```text
draw 0: 0.735927939414978
draw 1: 1.810187578201294
draw 2: 0.11553369462490082
```
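The loop splits one key per iteration. As a sketch, assuming the same `random.split(key, num)` API, several subkeys can also be produced in a single call, keeping the first as the carried key. The second half shows why the loop deletes consumed keys: reusing one subkey gives identical "random" values. The names `draws`, `k`, `x1` and `x2` are illustrative only.

```python
from jax import random

key = random.PRNGKey(12)

# Split into 4 keys at once: carry the first forward, use the other 3 for draws.
key, *subkeys = random.split(key, num=4)
draws = [float(random.normal(k)) for k in subkeys]
print(draws)

# Anti-pattern: reusing a consumed key repeats the same value.
k = subkeys[0]
x1 = float(random.normal(k))
x2 = float(random.normal(k))
print(x1 == x2)  # True
```

Splitting with `num` is just a batched form of the loop's pattern. Every draw still gets a fresh subkey, and the carried `key` is never passed to a sampler directly.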
