In [1]:
import jax
import numpy as np
import jax.numpy as jnp

import haiku as hk

In [2]:
# Notebook-level PRNG key derived from seed 10.
# NOTE(review): none of the cells visible here read `key`; the parameter-init
# cell builds its own key from seed 42. Confirm a later cell uses it, else drop.
key = jax.random.PRNGKey(seed=10)

In [3]:
class MyLinear1(hk.Module):
    """A minimal fully-connected layer computing ``jnp.dot(x, w) + b``.

    Two parameters are created under this module's scope:
      * ``w`` with shape ``[x.shape[-1], output_size]``. By default it is drawn
        from a truncated normal with stddev ``1 / sqrt(x.shape[-1])``.
      * ``b`` with shape ``[output_size]``. By default it is initialised to ones.

    Args:
      output_size: Number of output features ``k``.
      name: Optional module name. When omitted, Haiku derives one from the class
        name (``my_linear1``, as seen in the printed params).
      w_init: Optional keyword-only initializer for ``w``. It is called as
        ``init(shape, dtype)``. Defaults to the fan-in scaled truncated normal
        described above.
      b_init: Optional keyword-only initializer for ``b``. Defaults to ``jnp.ones``.
    """

    def __init__(self, output_size, name=None, *, w_init=None, b_init=None):
        super().__init__(name=name)
        self.output_size = output_size
        # None means "use the default". The w default depends on the input
        # width, so it can only be built in __call__.
        self.w_init = w_init
        self.b_init = b_init if b_init is not None else jnp.ones

    def __call__(self, x):
        """Applies the affine map to ``x``.

        Args:
          x: Input array. Its last axis holds the ``j`` input features, and its
            dtype is also used for the parameters.

        Returns:
          ``jnp.dot(x, w) + b``. The last axis has size ``output_size``.
        """
        j, k = x.shape[-1], self.output_size
        w_init = self.w_init
        if w_init is None:
            # Fan-in scaling: stddev shrinks as the number of inputs grows.
            w_init = hk.initializers.TruncatedNormal(1. / np.sqrt(j))
        w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
        b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=self.b_init)
        return jnp.dot(x, w) + b
    
def _forward_fn_linear1(x):
    """Forward pass applying a 2-output ``MyLinear1`` to ``x``.

    Haiku modules must be created inside the transformed function, so the
    layer is built and applied here rather than at notebook top level.
    """
    return MyLinear1(output_size=2)(x)


# Pure (init, apply) pair wrapping the forward pass; `.init` is used below.
forward_linear1 = hk.transform(_forward_fn_linear1)

In [4]:
# Create the parameters by tracing the forward pass once on a (1, 3) dummy
# batch. MyLinear1 only reads the input's last-axis size and dtype to create
# `w` and `b`.
rng_key = jax.random.PRNGKey(42)
dummy_x = jnp.array([[1.0, 2.0, 3.0]])

params = forward_linear1.init(rng_key, dummy_x)
print(params)

FlatMap({
  'my_linear1': FlatMap({
                  'w': DeviceArray([[-0.30350366,  0.5123803 ],
                                    [ 0.08009139, -0.3163005 ],
                                    [ 0.60566676,  0.5820702 ]], dtype=float32),
                  'b': DeviceArray([1., 1.], dtype=float32),
                }),
})
