**S02P05_tutorial_pseudo_random_numbers_in_jax.ipynb**

Arz

2024 APR 11 (THU)

reference:
https://jax.readthedocs.io/en/latest/jax-101/05-random-numbers.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# random numbers in NumPy

In [13]:
np.random.seed(0)

In [14]:
def print_truncated_random_state(n):
    full_random_state = np.random.get_state()
    print(str(full_random_state)[:n], "...")

In [15]:
print_truncated_random_state(77)

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044 ...


The state is updated by each call to a random function.

In [16]:
print_truncated_random_state(77)

_ = np.random.uniform()  # a random function

print_truncated_random_state(77)

('MT19937', array([         0,          1, 1812433255, 1900727105, 1208447044 ...
('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660 ...
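Because the state is global, the only way to replay a draw is to snapshot and restore that global state. Here is a minimal sketch using NumPy's legacy `np.random.get_state()` / `np.random.set_state()`. The variable names are illustrative.

```python
import numpy as np

np.random.seed(0)
saved_state = np.random.get_state()  # snapshot of the global MT19937 state

first = np.random.uniform()   # consumes part of the global state
second = np.random.uniform()  # the state has moved on, so this value differs

np.random.set_state(saved_state)  # rewind the global state to the snapshot
replayed = np.random.uniform()    # reproduces `first`

print(first, second, replayed)
```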


In [19]:
np.random.seed(0)  # initialize the random state
print(np.random.uniform(size=3))

[0.5488135  0.71518937 0.60276338]


## sequential equivalent guarantee

Generating numbers one by one or all at once yields the same sequence.

In [20]:
np.random.seed(0)
print("one by one :", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(0)
print("all at once:", np.random.uniform(size=3))

one by one : [0.5488135  0.71518937 0.60276338]
all at once: [0.5488135  0.71518937 0.60276338]
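The guarantee is not limited to single draws. Any way of chunking the calls reads the same underlying stream. The sketch below splits 5 draws into chunks of 2 and 3 and compares them with a single call of size 5. The chunk sizes are arbitrary choices.

```python
import numpy as np

np.random.seed(0)
all_at_once = np.random.uniform(size=5)

np.random.seed(0)
chunks = [np.random.uniform(size=2), np.random.uniform(size=3)]  # same stream, different chunking
in_chunks = np.concatenate(chunks)

print(np.array_equal(all_at_once, in_chunks))
```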


# random numbers in JAX

JAX's random number generator is designed to be:

- 1) reproducible
- 2) parallelizable
- 3) vectorizable
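Here is a small sketch of requirement 1, plus the property that underpins 2 and 3. A JAX sampler is a pure function of its key, so the same key gives the same numbers. The function can also be traced by `jax.jit` because there is no hidden global state to update. The seed `0` and the shape `(3,)` are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax import random

key = random.key(0)

# Reproducible: the same key always produces the same numbers.
a = random.normal(key, shape=(3,))
b = random.normal(key, shape=(3,))

# The sampler depends only on its key argument, so it can be jit-compiled;
# no shared state is touched, which is what makes parallel calls safe.
sample = jax.jit(lambda k: random.normal(k, shape=(3,)))
c = sample(key)

print(a, b, c)
```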

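## NumPy assumes execution in order

The random state is global and changes with every call to a NumPy random function. As a result, the value of `h()` below depends on the order in which `f()` and `g()` are evaluated, and any scheme that reorders or parallelizes those calls would change the result.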

In [21]:
np.random.seed(0)

def f(): return np.random.uniform()
def g(): return np.random.uniform()

def h(): return f() + 2*g()

In [22]:
print(h())

1.9791922366721637
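To make the order dependence concrete, this sketch evaluates the same formula twice from the same seed. The first pass calls `f()` before `g()`, as `h()` does. The second pass calls them the other way round. It reuses the notebook's `f`/`g` definitions.

```python
import numpy as np

def f(): return np.random.uniform()
def g(): return np.random.uniform()

# Order used by h(): f() draws first, g() second.
np.random.seed(0)
a = f()
b = g()
in_order = a + 2 * b

# Same formula, but g() is evaluated before f().
np.random.seed(0)
b = g()
a = f()
swapped = a + 2 * b

print(in_order, swapped)  # the two results differ
```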


## JAX introduces *key*

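- the random state is not global
- a JAX random function explicitly consumes a state associated with a key

- `random.key(seed)` creates a key from an integer seed: a typed key array with shape `()` and dtype `key<fry>`, which wraps two `uint32` values (see the output below).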
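Here is a sketch of how to inspect what a key wraps. It assumes a recent JAX that provides `random.key` and `random.key_data`, and the default threefry2x32 implementation. It also compares against the older `random.PRNGKey`, which returns the raw `uint32` data directly.

```python
import jax.numpy as jnp
from jax import random

typed_key = random.key(7)         # new-style typed key: a scalar with a key<...> dtype
raw = random.key_data(typed_key)  # the underlying uint32 words

legacy_key = random.PRNGKey(7)    # legacy raw key: a plain uint32 array

print(typed_key.shape, raw.shape, raw.dtype)  # () (2,) uint32
print(legacy_key.shape, legacy_key.dtype)     # (2,) uint32
print(jnp.array_equal(raw, legacy_key))       # True
```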

In [29]:
key = random.key(7)

print(key)

Array((), dtype=key<fry>) overlaying:
[0 7]


A random function call does not change the random state: as long as the key stays the same, the output stays the same.

In [30]:
print(random.normal(key))

print(random.normal(key))

1.0114812
1.0114812


### rule of thumb: ⚠️ never reuse keys

note: feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.

- use **split()** to generate different and independent samples.
    - it is a good idea to follow the JAX convention:
        - keep **key** to feed **split()** later and generate more randomness.
        - feed **subkey** to a random function.
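The convention above usually shows up as a loop. Each iteration splits the carried `key`, consumes the fresh `subkey` exactly once, and never reuses either for sampling. This is a minimal sketch. The loop length and seed are arbitrary.

```python
from jax import random

key = random.key(7)
samples = []
for _ in range(3):
    key, subkey = random.split(key)               # keep `key` for future splits
    samples.append(float(random.normal(subkey)))  # consume `subkey` once

print(samples)  # three different values
```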

In [31]:
print("key before split:", key)
print("  number:", random.normal(key, shape=(1,)), "\n")

new_key, subkey = random.split(key)
del key  # discard used key, never use it again 
print("key after split:", new_key)
print("subkey:", subkey)
print("  number:", random.normal(subkey, shape=(1,)))

# note: you don't actually need to `del` keys. it's just for emphasis.

key before split: Array((), dtype=key<fry>) overlaying:
[0 7]
  number: [1.0114812] 

key after split: Array((), dtype=key<fry>) overlaying:
[966301609 989289821]
subkey: Array((), dtype=key<fry>) overlaying:
[1948237315 1058547403]
  number: [-1.4622003]
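Splitting is itself deterministic. Splitting the same parent twice gives the same children, and the parent is not modified. This is why reproducibility survives `split()`. The sketch below checks this through `random.key_data` and assumes a recent JAX that provides it.

```python
import jax.numpy as jnp
from jax import random

parent = random.key(7)
k1, s1 = random.split(parent)
k2, s2 = random.split(parent)  # split the same parent a second time

data = random.key_data
same_children = bool(jnp.array_equal(data(k1), data(k2)) and jnp.array_equal(data(s1), data(s2)))
children_differ = not bool(jnp.array_equal(data(k1), data(s1)))
parent_unchanged = bool(jnp.array_equal(data(parent), data(random.key(7))))

print(same_children, children_differ, parent_unchanged)  # True True True
```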


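### how to produce more than one subkey at a time

e.g., 3 subkeys: `random.split(key, 4)` returns 4 keys; the first is kept as the new `key` and the remaining 3 are used as subkeys.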

In [33]:
key = random.key(7)

In [34]:
key, *subkeys = random.split(key, 4)
print("key after split:", key)
print("subkeys:", subkeys, "\n")

for subkey in subkeys:
    print(random.normal(subkey, shape=(1,)))

key after split: Array((), dtype=key<fry>) overlaying:
[  51277378 1628829261]
subkeys: [Array((), dtype=key<fry>) overlaying:
[1440439266  395909871], Array((), dtype=key<fry>) overlaying:
[3088387524 4291721531], Array((), dtype=key<fry>) overlaying:
[3731608162 3705585371]] 

[0.3386509]
[0.23955461]
[-0.07911882]
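With multiple subkeys, the order-dependent `h()` from the NumPy section can be rewritten so that each function receives its own key. The sketch below is an adaptation, not code from the reference. `h_swapped` is a hypothetical helper that evaluates `g` before `f`. Because each result depends only on its key, the evaluation order no longer matters.

```python
from jax import random

def f(key): return random.uniform(key)
def g(key): return random.uniform(key)

def h(key):
    key_f, key_g = random.split(key)  # one independent subkey per function
    return f(key_f) + 2 * g(key_g)

def h_swapped(key):  # hypothetical: same formula, g() evaluated first
    key_f, key_g = random.split(key)
    b = g(key_g)
    a = f(key_f)
    return a + 2 * b

key = random.key(0)
print(h(key), h_swapped(key))  # identical values
```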


## sequential equivalent guarantee does not hold

JAX does not provide this guarantee, because doing so would interfere with vectorization on SIMD hardware (requirement 3: vectorizable).

In [36]:
key = random.key(7)
subkeys = random.split(key, 3)
random_numbers = np.stack([random.normal(subkey) for subkey in subkeys])
print("one by one :", random_numbers)

key = random.key(7)
print("all at once:", random.normal(key, shape=(3,)))

one by one : [-0.4787308  -0.15271705  0.47495216]
all at once: [-0.75546646 -0.18615817 -0.11654735]
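The "one by one" pattern does not have to run as a Python loop. `jax.vmap` can batch `random.normal` over the split keys and produce the same values in a single vectorized call. This sketch reuses the seed and split count from the cell above.

```python
import jax
import numpy as np
from jax import random

key = random.key(7)
subkeys = random.split(key, 3)

one_by_one = np.stack([random.normal(k) for k in subkeys])
vectorized = jax.vmap(random.normal)(subkeys)  # one call, batched over the keys

print(np.allclose(one_by_one, vectorized))  # True
```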
