**S01P07_random_numbers.ipynb**

Arz

2024 APR 06 (SAT)

reference:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# random numbers

## NumPy: PRNG

In [4]:
np.random.random()

0.35283432881870713

In [5]:
np.random.seed(0)
rng_state = np.random.get_state()

In [6]:
# print(rng_state)
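NumPy keeps its PRNG state in a hidden global object that every call reads and then updates. The small sketch below uses the legacy `np.random` functions from the cells above plus `np.random.set_state` (not used in the original). It shows two consequences: consecutive calls return different numbers, and restoring a saved state replays the same number.

```python
import numpy as np

# Seed NumPy's global (hidden) PRNG state.
np.random.seed(0)
a = np.random.random()  # first draw advances the global state
b = np.random.random()  # second draw comes from the updated state

# Snapshot the current global state, draw, then restore it and draw again.
state = np.random.get_state()
c = np.random.random()
np.random.set_state(state)  # rewind the global state to the snapshot
d = np.random.random()

print(a != b)  # True: consecutive calls differ because the state was mutated
print(c == d)  # True: restoring the state reproduces the same number
```

This implicit, mutable state is what JAX avoids by making the key an explicit argument.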

## JAX: PRNG

In [7]:
key = random.key(0)

key

Array((), dtype=key<fry>) overlaying:
[0 0]
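The repr above shows a typed key (`dtype=key<fry>`, where `fry` refers to the threefry implementation) "overlaying" two raw `uint32` words. The sketch below peeks at those words. It uses `random.key_data`, `random.wrap_key_data`, and the legacy `random.PRNGKey`, none of which appear in the original notebook. It assumes a JAX version recent enough to provide `random.key`, which the cell above already uses.

```python
import jax.numpy as jnp
from jax import random

# A typed key array: a scalar (shape ()) whose dtype is key<fry>.
key = random.key(0)
print(key.shape, key.dtype)

# key_data exposes the raw uint32 words the key overlays.
raw = random.key_data(key)
print(raw)  # the [0 0] seen in the repr above

# The legacy API returns the raw uint32 array directly.
legacy = random.PRNGKey(0)
print(legacy)

# wrap_key_data converts raw words back into a typed key that samples identically.
rewrapped = random.wrap_key_data(raw)
print(jnp.array_equal(random.normal(rewrapped, (1,)), random.normal(key, (1,))))
```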

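The same key always produces the same result: sampling functions such as `random.normal` do not modify the key, so calling them again with that key repeats the draw.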

In [8]:
print(random.normal(key, shape=(1,)))
print(key)

print(random.normal(key, shape=(1,)))
print(key)

[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
[-0.20584226]
Array((), dtype=key<fry>) overlaying:
[0 0]
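The same point can be checked with assertions instead of by eye. This sketch treats sampling as a pure function of the key (and shape): reusing a key repeats the draw, the key's raw data is unchanged afterwards, and a fresh `random.key(0)` reproduces the numbers. `random.key_data` is an addition not used in the original cells.

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)
x1 = random.normal(key, shape=(3,))
x2 = random.normal(key, shape=(3,))

# Reusing a key repeats the draw exactly.
print(jnp.array_equal(x1, x2))

# The key itself is not mutated by sampling.
print(random.key_data(key))

# A freshly created key with the same seed reproduces the same numbers.
x3 = random.normal(random.key(0), shape=(3,))
print(jnp.array_equal(x1, x3))
```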


To generate different random numbers, split the key into new keys with `random.split` and pass a fresh subkey to each sampling call.

In [9]:
print("key before split:", key)
print("  number:", random.normal(key, shape=(1,)), "\n")

key, subkey = random.split(key)
print("key after split:", key)
print("subkey:", subkey)
print("  number:", random.normal(subkey, shape=(1,)))

key before split: Array((), dtype=key<fry>) overlaying:
[0 0]
  number: [-0.20584226] 

key after split: Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
subkey: Array((), dtype=key<fry>) overlaying:
[2718843009 1272950319]
  number: [-1.2515389]
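A few properties of `random.split` are implicit in the output above. The sketch below makes them explicit, assuming the default two-way split. The result is a key array of shape `(2,)`. Splitting is itself deterministic, so splitting the same key twice gives the same children. The two children sample differently from each other and from the parent. Raw words are compared via `random.key_data`, an addition not in the original, because the check is done on the underlying `uint32` data rather than on typed keys.

```python
import jax.numpy as jnp
from jax import random

key = random.key(0)

# By default, split returns a key array of shape (2,).
keys = random.split(key)
print(keys.shape)

# split is deterministic: splitting the same key again yields identical keys.
again = random.split(key)
print(jnp.array_equal(random.key_data(keys), random.key_data(again)))

# The two children are distinct from each other and from the parent.
new_key, subkey = keys
n_parent = random.normal(key, (1,))
n_new = random.normal(new_key, (1,))
n_sub = random.normal(subkey, (1,))
print(n_parent, n_new, n_sub)
```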


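Propagate the new `key` and split it again whenever more random numbers are needed; each split yields a fresh subkey to sample from. (These demo cells also sample directly from `key` before splitting it, only to show that it gives a different number. In real code, follow the rule of thumb from the reference above and never reuse a key, so a key passed to a sampler should not also be split.)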

In [10]:
print("key before split:", key)
print("  number:", random.normal(key, shape=(1,)), "\n")

key, subkey = random.split(key)
print("key after split:", key)
print("subkey:", subkey)
print("  number:", random.normal(subkey, shape=(1,)))

key before split: Array((), dtype=key<fry>) overlaying:
[4146024105  967050713]
  number: [0.14389051] 

key after split: Array((), dtype=key<fry>) overlaying:
[2384771982 3928867769]
subkey: Array((), dtype=key<fry>) overlaying:
[1278412471 2182328957]
  number: [-0.58665055]
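The propagate-and-split pattern is usually written as a loop that carries `key` forward and consumes one subkey per draw. Below is a minimal sketch; `draw_sequence` is a hypothetical helper name, not part of JAX. It also shows why the pattern is reproducible: starting again from the same seed replays the same sequence.

```python
from jax import random

def draw_sequence(key, n):
    """Hypothetical helper: draw n scalars, splitting the key before each draw."""
    samples = []
    for _ in range(n):
        key, subkey = random.split(key)  # carry `key` forward, consume `subkey`
        samples.append(float(random.normal(subkey)))
    return key, samples  # return the carry key so the caller can keep going

_, run1 = draw_sequence(random.key(0), 3)
_, run2 = draw_sequence(random.key(0), 3)

print(run1)
print(run1 == run2)  # same seed -> same sequence
```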


### how to produce more than one subkey at a time

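Example: three subkeys. `random.split(key, 4)` returns four keys; the first becomes the new `key` and the other three are unpacked into `subkeys`.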

In [11]:
key, *subkeys = random.split(key, 4)
print("key after split:", key)
print("subkeys:", subkeys, "\n")

for subkey in subkeys:
    print(random.normal(subkey, shape=(1,)))

key after split: Array((), dtype=key<fry>) overlaying:
[1594945422 1369375165]
subkeys: [Array((), dtype=key<fry>) overlaying:
[2931675882 1444655455], Array((), dtype=key<fry>) overlaying:
[2994431502 1854917485], Array((), dtype=key<fry>) overlaying:
[2303906914 4183882777]] 

[-0.37533438]
[0.98645043]
[0.14553197]
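Instead of unpacking the subkeys into a Python list and looping, the split result can be kept as a key array and the sampler mapped over it. The sketch below uses `jax.vmap` and indexing on key arrays, neither of which appears in the original notebook. It checks that the vectorized draws match the per-subkey loop.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(0)

# Keep all four keys in one key array instead of a Python list.
keys = random.split(key, 4)
key, subkeys = keys[0], keys[1:]  # new carry key + key array of shape (3,)

# Loop version, as in the cell above: one call per subkey.
loop_draws = jnp.stack([random.normal(k, shape=(1,)) for k in subkeys])

# Vectorized version: map the same sampler over the key array.
vmap_draws = jax.vmap(lambda k: random.normal(k, shape=(1,)))(subkeys)

print(vmap_draws.shape)                      # (3, 1)
print(jnp.allclose(loop_draws, vmap_draws))  # same numbers from one mapped call
```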
