In [None]:
import numpy as np

# Single seed shared by the NumPy and JAX examples so results are reproducible.
SEED = 1234

# Randomness in `numpy`

**Note**: With a fixed seed, consecutive calls to `generate_np_weights` always produce the same sequence of arrays, in the same order, on every run.

In [None]:
# Seed NumPy's global generator once; every later np.random draw advances this shared state.
np.random.seed(SEED)

def generate_np_weights():
    """Draw a 3x2 array of standard-normal samples from NumPy's global generator.

    Each call consumes the module-level ``np.random`` state, so the values
    returned depend on how many draws have happened since the last seeding.
    """
    weights_shape = (3, 2)
    return np.random.normal(loc=0, scale=1, size=weights_shape)


generate_np_weights()  # 1st call: first draw after seeding

In [None]:
generate_np_weights()  # 2nd call: continues the same sequence, so the values differ from the 1st call

# Statefulness issue illustrated:

Refer to the accompanying blog post for the full explanation.

In [None]:
def new_generate_np_weights():
    """Re-seed NumPy's global generator with ``SEED`` and draw a 3x2 normal array.

    The seeding is not private to this function: it resets the shared
    ``np.random`` state, which also affects every draw made afterwards.
    """
    np.random.seed(SEED)  # resets the global state, not a function-local one
    weights_shape = (3, 2)
    return np.random.normal(loc=0, scale=1, size=weights_shape)


new_generate_np_weights()  # 1st reset call: re-seeding reproduces the values of the 1st call

## Continuing our original random sequence

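Wait, what just happened? Calling `new_generate_np_weights` re-seeded NumPy's *global* generator, so the original sequence was rewound: the next call to `generate_np_weights` repeats the 2nd call's values instead of producing a fresh 3rd draw.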

In [None]:
generate_np_weights()  # "3rd" call: with cells run in order, the re-seed above makes this repeat the 2nd call

# Simple JAX examples

This is just here for illustration. It's rather trivial, really.

In [None]:
import jax
import jax.numpy as jnp
import jax.random as jrandom

In [None]:
# Root PRNG key built from SEED; JAX randomness is driven by keys passed explicitly to each call.
main_key = jrandom.PRNGKey(SEED)
def generate_jax_weights(_key):
    """Draw a 3x2 array of standard-normal samples determined solely by ``_key``.

    The same key always yields the same values; pass a fresh subkey to get
    a different draw.
    """
    weights_shape = (3, 2)
    return jrandom.normal(_key, weights_shape)

In [None]:
# Display the raw key value before any split.
main_key

In [None]:
# Split rather than reuse: rebind main_key for later splits and draw with the fresh subkey.
main_key, subkey = jrandom.split(main_key)

generate_jax_weights(subkey)

In [None]:
# Reusing the same subkey reproduces the same values as the previous call.
generate_jax_weights(subkey)

In [None]:
# Splitting again yields a new subkey, so this draw differs from the earlier ones.
main_key, subkey = jrandom.split(main_key)

generate_jax_weights(subkey)