In [4]:
import jax
import jax.numpy as jnp
import numpy as np
from jax.random import PRNGKey
import haiku as hk

In [6]:
# Fixed seed for reproducible runs; every key used below is derived from it.
SEED = 42
rng_key = PRNGKey(SEED)

In [8]:
class HkRandom2(hk.Module):
    """Draw a boolean Bernoulli sample with the same shape as the input.

    Each element is True with probability ``1.0 - rate``. The randomness comes
    from ``hk.next_rng_key()``, so the surrounding ``hk.transform`` must be
    called with an rng.
    """

    def __init__(self, rate=0.5):
        super().__init__()
        # Probability that a sampled element is False.
        self.rate = rate

    def __call__(self, x):
        true_prob = 1.0 - self.rate
        sample_key = hk.next_rng_key()
        sample = jax.random.bernoulli(sample_key, true_prob, shape=x.shape)
        return sample


class HkRandomNest(hk.Module):
    """Draw two Bernoulli samples: one via a nested ``HkRandom2``, one directly.

    Used to inspect how Haiku hands out rng keys to a module and its child.
    Each sample has the shape of ``x`` and is True with probability
    ``1.0 - rate``.

    Returns:
        A tuple ``(p1, p2)``: ``p1`` from the nested module, ``p2`` from this
        module's own key. Both are also printed.
    """

    def __init__(self, rate=0.5):
        super().__init__()
        self.rate = rate
        # Forward ``rate`` so the nested module samples with the same
        # probability instead of silently using HkRandom2's default.
        self._another_random_module = HkRandom2(rate=rate)

    def __call__(self, x):
        # Take this module's key before calling the child; the child then
        # calls hk.next_rng_key() itself, so the order determines which key
        # each sample gets.
        key2 = hk.next_rng_key()
        p1 = self._another_random_module(x)
        p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)
        print(f'Bernoullis are  : {p1, p2}')
        # Previously nothing was returned, so apply() always produced None.
        return p1, p2

# HkRandomNest pulls keys with hk.next_rng_key(), so apply() must keep its
# `rng` argument; hk.without_apply_rng() would remove it, which is why the
# transform below is used unwrapped.
def _nest_forward(inputs):
    return HkRandomNest()(inputs)


forward = hk.transform(_nest_forward)

x = jnp.array(1.0)
params = forward.init(rng_key, x=x)

# The same rng_key is passed on every iteration, so each apply() call
# reproduces exactly the same samples.
for i in range(10):
    print(f'\n Iteration {i+1}')
    prediction = forward.apply(params, x=x, rng=rng_key)

Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 1
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 2
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 3
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 4
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 5
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 6
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 7
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 8
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 9
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 10
Bernoullis are  : (DeviceArray(True, dt