In [15]:
import jax
import jax.numpy as jnp
import flax

In [3]:
# Initialise JAX's multi-process runtime for this host (the devices below are TPU cores).
# NOTE(review): JAX's docs say this must be called once per process, before any other
# JAX computation; re-running this cell in a live kernel is expected to fail — restart
# the kernel instead. Confirm against the installed JAX version.
jax.distributed.initialize()

In [4]:
# List the devices visible to JAX; the captured output shows 4 TPU cores, all owned by
# process 0, which is why the pmap cells below operate over 4 devices.
jax.devices()

[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
 TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0),
 TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0),
 TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0)]

In [20]:
@jax.jit
def test(x, rng):
    rngs, subkey = jax.random.split(rng)
    y = jax.random.normal(subkey, x.shape)
    return y, rngs

x = jnp.ones((4, 1))
rngs = jax.random.PRNGKey(0)
print(rngs)

test(x, rngs)

[0 0]


(Array([[ 1.1378773 ],
        [-0.14331432],
        [-0.5915394 ],
        [ 0.7946691 ]], dtype=float32),
 Array([4146024105,  967050713], dtype=uint32))

In [21]:

@jax.pmap
def test(x, rng):
    """Per-device normal sampler that demonstrates the replicated-key pitfall.

    Each device splits its own copy of ``rng`` and draws standard-normal
    samples shaped like its shard of ``x``. Because ``rng`` is replicated
    (identical on every device), every device draws the *same* numbers — the
    captured output shows four identical rows.

    Args:
        x: per-device shard of the input; only its shape is used.
        rng: per-device PRNG key.

    Returns:
        ``(samples, next_rng)``, each stacked along the leading device axis.
    """
    next_rng, subkey = jax.random.split(rng)
    y = jax.random.normal(subkey, x.shape)
    return y, next_rng

# pmap maps over the leading axis, which must match the number of local devices;
# derive it from the runtime instead of hard-coding 4 so the cell runs on any host.
n_devices = jax.local_device_count()
x = jnp.ones((n_devices, 1))
rngs = jax.random.PRNGKey(0)
# One identical copy of the key per device (4 rows in the captured output).
rngs = flax.jax_utils.replicate(rngs)
print(rngs)
test(x, rngs)

[[0 0]
 [0 0]
 [0 0]
 [0 0]]
Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)> Traced<ShapedArray(uint32[2])>with<DynamicJaxprTrace(level=0/1)>


(Array([[-1.2515285],
        [-1.2515285],
        [-1.2515285],
        [-1.2515285]], dtype=float32),
 Array([[4146024105,  967050713],
        [4146024105,  967050713],
        [4146024105,  967050713],
        [4146024105,  967050713]], dtype=uint32))

In [40]:

@jax.pmap
def test(x, rng, indexes):
    """Per-device normal sampler that decorrelates devices via ``fold_in``.

    Every device receives the same ``rng`` and splits it identically, so the
    returned ``next_rng`` stays in sync across devices (see the identical rows
    in the captured output). The draw key is made device-specific by folding
    in that device's entry of ``indexes``, so the samples differ per device.

    Args:
        x: per-device shard of the input; only its shape is used.
        rng: per-device (replicated) PRNG key.
        indexes: per-device integer that must be unique per device.

    Returns:
        ``(samples, next_rng)``, each stacked along the leading device axis.
    """
    next_rng, subkey = jax.random.split(rng)
    subkey = jax.random.fold_in(subkey, indexes)
    y = jax.random.normal(subkey, x.shape)
    return y, next_rng

# Size the device axis from the runtime instead of hard-coding 4.
n_local = jax.local_device_count()
x = jnp.ones((n_local, 1))
rngs = jax.random.PRNGKey(0)
rngs = flax.jax_utils.replicate(rngs)
# Offset by the process index so that, in a multi-process run, devices on different
# hosts also fold in distinct values; with a single process this is 0..n_local-1.
indexes = jax.process_index() * n_local + jnp.arange(n_local)
print(rngs)
test(x, rngs, indexes)

[[0 0]
 [0 0]
 [0 0]
 [0 0]]


(Array([[ 0.32595065],
        [-0.6241612 ],
        [ 0.7319309 ],
        [-0.865844  ]], dtype=float32),
 Array([[4146024105,  967050713],
        [4146024105,  967050713],
        [4146024105,  967050713],
        [4146024105,  967050713]], dtype=uint32))

In [35]:
# Derive a new key by folding the integer 1 into the seed-0 key (compare with the
# per-device fold_in used in the pmap cell above).
jax.random.fold_in(key=jax.random.PRNGKey(0), data=1)

Array([ 928981903, 3453687069], dtype=uint32)

In [27]:

class MarkovState(flax.struct.PyTreeNode):
    """Base class for immutable state objects threaded through JAX code.

    Built on ``flax.struct.PyTreeNode``: subclasses declare their fields as
    annotated class attributes and are constructed like dataclasses.
    """


class RandomMarkovState(MarkovState):
    """Markov state that carries a PRNG key and hands out fresh subkeys."""

    # The carried PRNG key (e.g. created with ``jax.random.PRNGKey``).
    # ``jax.random.PRNGKey`` is a function, not a type, so annotate with ``jax.Array``.
    rng: jax.Array

    def get_random_key(self):
        """Split the carried key into an advanced key and a subkey.

        Returns:
            ``(new_state, subkey)``: a copy of this state whose ``rng`` has been
            advanced, and a fresh subkey for the caller to consume.
        """
        rng, subkey = jax.random.split(self.rng)
        # ``replace`` keeps the concrete subclass and any extra fields it
        # declares; rebuilding ``RandomMarkovState(rng)`` would drop them.
        return self.replace(rng=rng), subkey

In [32]:
# Two independent states seeded with 0 and 1; each call yields (advanced_state, subkey).
a, b = (RandomMarkovState(jax.random.PRNGKey(seed)) for seed in (0, 1))

tuple(state.get_random_key() for state in (a, b))

((RandomMarkovState(rng=Array([4146024105,  967050713], dtype=uint32)),
  Array([2718843009, 1272950319], dtype=uint32)),
 (RandomMarkovState(rng=Array([2441914641, 1384938218], dtype=uint32)),
  Array([3819641963, 2025898573], dtype=uint32)))