# DM Haiku

## Basics
https://dm-haiku.readthedocs.io/en/latest/notebooks/basics.html

#### intro

In [1]:
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np

In [2]:
class MyLinear1(hk.Module):
  """Minimal dense layer computing `x @ w + b` with Haiku-managed parameters."""

  def __init__(self, output_size, name=None):
    """Records the layer width.

    Args:
      output_size: size of the last axis of the layer's output.
      name: optional module name, forwarded to `hk.Module`.
    """
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    """Applies the affine map to `x`; the input width is read from `x.shape[-1]`."""
    in_features = x.shape[-1]
    out_features = self.output_size
    # Weight initializer is scaled by 1/sqrt(fan_in); the bias starts at ones.
    weight_init = hk.initializers.TruncatedNormal(1. / np.sqrt(in_features))
    weights = hk.get_parameter(
        "w",
        shape=[in_features, out_features],
        dtype=x.dtype,
        init=weight_init,
    )
    bias = hk.get_parameter("b", shape=[out_features], dtype=x.dtype, init=jnp.ones)
    return jnp.dot(x, weights) + bias

In [3]:
def _forward_fn_linear1(x):
  """Builds a two-output `MyLinear1` and applies it to `x`."""
  return MyLinear1(output_size=2)(x)

# `hk.transform` yields a transformed object exposing `init` and `apply`.
forward_linear1 = hk.transform(_forward_fn_linear1)

In [5]:
# A PRNG key drives parameter initialization; the dummy input is passed to
# `init` so the module can derive its parameter shapes from `x.shape`.
rng_key = jax.random.PRNGKey(42)
dummy_x = jnp.array([[1., 2., 3.]])

params = forward_linear1.init(rng=rng_key, x=dummy_x)
print(params)

{'my_linear1': {'w': DeviceArray([[-0.30350366,  0.5123803 ],
             [ 0.08009139, -0.3163005 ],
             [ 0.60566676,  0.5820702 ]], dtype=float32), 'b': DeviceArray([1., 1.], dtype=float32)}}


In [6]:
sample_x = jnp.array([[1., 2., 3.]])
sample_x_2 = jnp.array([[4., 5., 6.], [7., 8., 9.]])

# The forward pass is non-stochastic, so two calls with the same input give
# identical outputs; `apply` does not actually need the rng key here.
output_1, output_2 = (
    forward_linear1.apply(params=params, x=sample_x, rng=rng_key)
    for _ in range(2)
)

# A different (batched) input for comparison.
output_3 = forward_linear1.apply(params=params, x=sample_x_2, rng=rng_key)

print(
    f'Output 1 : {output_1}',
    f'Output 2 (same as output 1): {output_2}',
    f'Output 3 : {output_3}',
    sep='\n',
)

Output 1 : [[2.6736794 2.62599  ]]
Output 2 (same as output 1): [[2.6736794 2.62599  ]]
Output 3 : [[3.819336  4.9589844]
 [4.965576  7.2924805]]


In [7]:
# `hk.without_apply_rng` wraps the transformed function so that `apply` no
# longer takes an `rng` argument, which suits this deterministic network.
forward_without_rng = hk.without_apply_rng(hk.transform(_forward_fn_linear1))
params = forward_without_rng.init(rng=rng_key, x=sample_x)
output = forward_without_rng.apply(x=sample_x, params=params)
# Report the value computed by this cell (`output`), not a result left over
# from an earlier cell.
print(f'Output without random key in forward pass \n {output}')

Output without random key in forward pass 
 [[2.6736794 2.62599  ]]


In [8]:
# Add 1 to every leaf of the parameter tree. `jax.tree_util.tree_map` is the
# stable spelling; the top-level `jax.tree_map` alias is deprecated and has
# been removed from recent JAX releases.
mutated_params = jax.tree_util.tree_map(lambda x: x+1., params)
print(f'Mutated params \n : {mutated_params}')
# The same pure `apply` function is reused with the modified parameters.
mutated_output = forward_without_rng.apply(x=sample_x, params=mutated_params)
print(f'Output with mutated params \n {mutated_output}')

Mutated params 
 : {'my_linear1': {'b': DeviceArray([2., 2.], dtype=float32), 'w': DeviceArray([[0.69649637, 1.5123804 ],
             [1.0800914 , 0.6836995 ],
             [1.6056668 , 1.5820701 ]], dtype=float32)}}
Output with mutated params 
 [[9.673679 9.62599 ]]


#### Stateful: `hk.get_state` & `hk.transform_with_state`
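Haiku state lets a function carry internal values (e.g. a counter) across calls: read them with `hk.get_state`, update them with `hk.set_state`, and transform the function with `hk.transform_with_state`.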

In [9]:
def stateful_f(x):
  """Returns `x + multiplier * counter` and stores `counter + 1` as the new state."""
  step = hk.get_state("counter", shape=[], dtype=jnp.int32, init=jnp.ones)
  scale = hk.get_parameter('multiplier', shape=[1,], dtype=x.dtype, init=jnp.ones)
  # The result below uses the counter value read above, not the incremented one.
  hk.set_state("counter", step + 1)
  return x + scale * step

# `transform_with_state` makes `init` return (params, state) and `apply`
# return (output, new_state); the state is threaded through explicitly.
stateful_forward = hk.without_apply_rng(hk.transform_with_state(stateful_f))
sample_x = jnp.array([[5., ]])
params, state = stateful_forward.init(x=sample_x, rng=rng_key)

separator = '##########'
print(f'Initial params:\n{params}\nInitial state:\n{state}')
print(separator)
for i in range(3):
  # Feed the state returned by the previous call back into the next one.
  output, state = stateful_forward.apply(params, state, x=sample_x)
  print(f'After {i+1} iterations:\nOutput: {output}\nState: {state}')
  print(separator)

Initial params:
{'~': {'multiplier': DeviceArray([1.], dtype=float32)}}
Initial state:
{'~': {'counter': DeviceArray(1, dtype=int32)}}
##########
After 1 iterations:
Output: [[6.]]
State: {'~': {'counter': DeviceArray(2, dtype=int32)}}
##########
After 2 iterations:
Output: [[7.]]
State: {'~': {'counter': DeviceArray(3, dtype=int32)}}
##########
After 3 iterations:
Output: [[8.]]
State: {'~': {'counter': DeviceArray(4, dtype=int32)}}
##########


#### Built-in nets, modules

In [10]:
# See: https://dm-haiku.readthedocs.io/en/latest/api.html#common-modules

class MyModuleCustom(hk.Module):
  """Composite module: a built-in `hk.nets.MLP` followed by the hand-written `MyLinear1`."""

  def __init__(self, output_size=2, name='custom_linear'):
    """Creates the two sub-modules.

    Args:
      output_size: output width of the final `MyLinear1` layer.
      name: module name, forwarded to `hk.Module`.
    """
    super().__init__(name=name)
    self._internal_linear_1 = hk.nets.MLP(
        output_sizes=[2, 3],
        name='hk_internal_linear',
    )
    self._internal_linear_2 = MyLinear1(
        output_size=output_size,
        name='old_linear',
    )

  def __call__(self, x):
    """Runs `x` through the MLP and then through the linear layer."""
    hidden = self._internal_linear_1(x)
    return self._internal_linear_2(hidden)

def _custom_forward_fn(x):
  """Instantiates `MyModuleCustom` with its defaults and applies it to `x`."""
  return MyModuleCustom()(x)

custom_forward_without_rng = hk.without_apply_rng(hk.transform(_custom_forward_fn))
# NOTE: `sample_x` and `rng_key` are whatever earlier cells last assigned.
params = custom_forward_without_rng.init(rng=rng_key, x=sample_x)
params

{'custom_linear/~/hk_internal_linear/~/linear_0': {'w': DeviceArray([[ 1.5159501 , -0.23353338]], dtype=float32),
  'b': DeviceArray([0., 0.], dtype=float32)},
 'custom_linear/~/hk_internal_linear/~/linear_1': {'w': DeviceArray([[-0.22075887, -0.2737596 ,  0.5931483 ],
               [ 0.78180665,  0.72626317, -0.6860752 ]], dtype=float32),
  'b': DeviceArray([0., 0., 0.], dtype=float32)},
 'custom_linear/~/old_linear': {'w': DeviceArray([[ 0.28584382,  0.31626165],
               [ 0.23357746, -0.4827032 ],
               [-0.14647584, -0.71857005]], dtype=float32),
  'b': DeviceArray([1., 1.], dtype=float32)}}

#### RNGs

In [14]:
class HkRandom2(hk.Module):
  """Stochastic module that returns a Bernoulli sample shaped like its input."""

  def __init__(self, rate=0.5):
    """
    Args:
      rate: the sample is drawn with probability `1.0 - rate`.
    """
    super().__init__()
    self.rate = rate

  def __call__(self, x):
    """Pulls a key from Haiku's key sequence, prints it and samples with it."""
    key1 = hk.next_rng_key()
    print("key1", key1)
    keep_prob = 1.0 - self.rate
    return jax.random.bernoulli(key1, keep_prob, shape=x.shape)


class HkRandomNest(hk.Module):
  """Stochastic module that samples locally and also via a nested `HkRandom2`."""

  def __init__(self, rate=0.5):
    """
    Args:
      rate: the local sample is drawn with probability `1.0 - rate`. The nested
        `HkRandom2` is built with its own default rate.
    """
    super().__init__()
    self.rate = rate
    self._another_random_module = HkRandom2()

  def __call__(self, x):
    """Prints the keys in use, draws both samples and returns them.

    Returns:
      Tuple `(p1, p2)`: the nested module's sample followed by this module's
      own sample, each with the shape of `x`.
    """
    key2 = hk.next_rng_key()
    print("key2",key2)
    # This call pulls one more key from the sequence purely to display it.
    print("next_key2", hk.next_rng_key())
    p1 = self._another_random_module(x)
    p2 = jax.random.bernoulli(key2, 1.0 - self.rate, shape=x.shape)
    print(f'Bernoullis are  : {p1, p2}')
    # Return the samples so callers of `apply` receive them instead of None.
    return p1, p2

def _random_nest_fn(x):
  """Instantiates `HkRandomNest` and applies it to `x`."""
  return HkRandomNest()(x)

# Stochastic modules cannot be wrapped with hk.without_apply_rng(): `apply`
# must receive an rng key.
forward = hk.transform(_random_nest_fn)

x = jnp.array(1.)
params = forward.init(rng_key, x=x)
# The same `rng_key` is passed on every iteration.
for i in range(5):
  print(f'\n Iteration {i+1}')
  prediction = forward.apply(params, x=x, rng=rng_key)

key2 [255383827 267815257]
next_key2 [3923418436 1366451097]
key1 [1371681402 3011037117]
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 1
key2 [255383827 267815257]
next_key2 [3923418436 1366451097]
key1 [1371681402 3011037117]
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 2
key2 [255383827 267815257]
next_key2 [3923418436 1366451097]
key1 [1371681402 3011037117]
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 3
key2 [255383827 267815257]
next_key2 [3923418436 1366451097]
key1 [1371681402 3011037117]
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 4
key2 [255383827 267815257]
next_key2 [3923418436 1366451097]
key1 [1371681402 3011037117]
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))

 Iteration 5
key2 [255383827 267815257]
next_key2 [3923418436 1366451097]
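key1 [1371681402 3011037117]
Bernoullis are  : (DeviceArray(True, dtype=bool), DeviceArray(False, dtype=bool))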

## Limitations of Nesting JAX Transforms and Haiku
  
https://dm-haiku.readthedocs.io/en/latest/notebooks/transforms.html


TL;DR: A JAX transform inside of a `hk.transform` is likely to transform a side-effecting function, which will result in an `UnexpectedTracerError`. This page describes two ways to get around this.  

Once a Haiku network has been transformed to a pair of pure functions using hk.transform, it’s possible to freely combine these with any JAX transformations like jax.jit, jax.grad, jax.scan and so on.

<table class="docutils align-default">
  <colgroup>
  <col style="width: 35%">
  <col style="width: 24%">
  <col style="width: 41%">
  </colgroup>
  <thead>
  <tr class="row-odd"><th class="head"><p>What?</p></th>
  <th class="head"><p>Works?</p></th>
  <th class="head"><p>Example</p></th>
  </tr>
  </thead>
  <tbody>
  <tr class="row-even"><td><p>vmapping outside a hk.transform</p></td>
  <td><p style="color: green;">✔ yes!</p></td>
  <td><p>jax.vmap(hk.transform(hk.nets.ResNet50))</p></td>
  </tr>
  <tr class="row-odd"><td><p>vmapping inside a hk.transform</p></td>
  <td><p style="color: red;">✖ unexpected tracer error</p></td>
  <td><p>hk.transform(jax.vmap(hk.nets.ResNet50))</p></td>
  </tr>
  <tr class="row-even"><td><p>vmapping a nested hk.transform (without lift)</p></td>
  <td><p style="color: red;">✖ inner state is not registered</p></td>
  <td><p>hk.transform(jax.vmap(hk.transform(hk.nets.ResNet50)))</p></td>
  </tr>
  <tr class="row-odd"><td><p>vmapping a nested hk.transform (with lift)</p></td>
  <td><p style="color: green;">✔ yes!</p></td>
  <td><p>hk.transform(jax.vmap(hk.lift(hk.transform(hk.nets.ResNet50))))</p></td>
  </tr>
  <tr class="row-even"><td><p>using hk.vmap</p></td>
  <td><p style="color: green;">✔ yes!</p></td>
  <td><p>hk.transform(hk.vmap(hk.nets.ResNet50))</p></td>
  </tr>
  </tbody>
</table>

In [15]:
def net(x):
  """Multiplies a Haiku-managed 2x2 parameter `w` with `x`.

  Even inside a `hk.transform` this function is side-effecting, because
  `hk.get_parameter` (like `hk.next_rng_key`) touches Haiku's state.
  """
  weights = hk.get_parameter("w", (2, 2), init=jnp.ones)
  return weights @ x

def eval_shape_net(x):
  """Demonstrates the failure: applies `jax.eval_shape` to the side-effecting `net`."""
  # A raw JAX transform over a side-effecting function...
  _shape = jax.eval_shape(net, x)
  # ...leads to an UnexpectedTracerError.
  return net(x)

init, _ = hk.transform(eval_shape_net)
# The exception is the point of this cell, so catch it and report it.
try:
  init(
      jax.random.PRNGKey(666),
      jnp.ones((2, 2)),
  )
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")

UnexpectedTracerError: applied JAX transform to side effecting function


These examples use `jax.eval_shape`, but any higher-order JAX function (e.g. `jax.vmap`, `jax.scan`, `jax.while_loop`, …) would behave the same way.

In [16]:
def net(w, x):
  """Pure function (no side effects): returns the matrix product `w @ x`."""
  product = w @ x
  return product

def eval_shape_net(x):
  """Fetches the parameter first, then shape-evaluates and runs the pure `net`.

  Returns:
    Tuple of (result of `jax.eval_shape`, actual output of `net`).
  """
  weights = hk.get_parameter("w", (3, 2), init=jnp.ones)
  # `net` is side-effect free now, so a plain JAX transform can be applied to it.
  shape_info = jax.eval_shape(net, weights, x)
  result = net(weights, x)
  return shape_info, result

x = jnp.ones((2, 3))
key = jax.random.PRNGKey(777)

init, apply = hk.transform(eval_shape_net)
params = init(key, x)
# The last expression is displayed as the cell output.
apply(params, key, x)

(ShapeDtypeStruct(shape=(3, 3), dtype=float32),
 DeviceArray([[2., 2., 2.],
              [2., 2., 2.],
              [2., 2., 2.]], dtype=float32))

In [18]:
def eval_shape_net(x):
  """Demonstrates the failure with a built-in module.

  The `hk.get_parameter` calls live inside `hk.nets.MLP`, so they cannot be
  pulled out of the function passed to `jax.eval_shape`.
  """
  mlp = hk.nets.MLP([300, 100])
  shape_info = jax.eval_shape(mlp, x)
  return shape_info, mlp(x)

init, _ = hk.transform(eval_shape_net)
# The exception is the point of this cell, so catch it and report it.
try:
  init(
      jax.random.PRNGKey(666),
      jnp.ones((2, 2)),
  )
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")

UnexpectedTracerError: applied JAX transform to side effecting function


In [20]:
# Solution 1: `hk.lift`
def eval_shape_net(x):
  """Solution 1: nest a `hk.transform` and lift its `init` into the outer transform.

  Returns:
    Tuple of (output of calling the MLP directly, result of `jax.eval_shape`).
  """
  mlp = hk.nets.MLP([300, 100])                  # still side-effecting
  inner_init, inner_apply = hk.transform(mlp)    # nested transform
  # Register the inner parameters in the outer module scope under the name "inner".
  inner_params = hk.lift(inner_init, name="inner")(hk.next_rng_key(), x)
  # `inner_apply` is a functionally pure function, so it can be transformed.
  shape_info = jax.eval_shape(inner_apply, inner_params, hk.next_rng_key(), x)
  result = mlp(x)
  return result, shape_info


init, apply = hk.transform(eval_shape_net)
params = init(
    jax.random.PRNGKey(777),
    jnp.ones((100, 100)),
)
# The last expression is displayed as the cell output.
apply(
    params,
    jax.random.PRNGKey(777),
    jnp.ones((100, 100)),
)

(DeviceArray([[-0.21376152,  0.19059503, -0.29319692, ...,  0.7637194 ,
                0.47994688, -0.7069051 ],
              [-0.21376152,  0.19059503, -0.29319692, ...,  0.7637194 ,
                0.47994688, -0.7069051 ],
              [-0.21376152,  0.19059503, -0.29319692, ...,  0.7637194 ,
                0.47994688, -0.7069051 ],
              ...,
              [-0.21376152,  0.19059503, -0.29319692, ...,  0.7637194 ,
                0.47994688, -0.7069051 ],
              [-0.21376152,  0.19059503, -0.29319692, ...,  0.7637194 ,
                0.47994688, -0.7069051 ],
              [-0.21376152,  0.19059503, -0.29319692, ...,  0.7637194 ,
                0.47994688, -0.7069051 ]], dtype=float32),
 ShapeDtypeStruct(shape=(100, 100), dtype=float32))

In [21]:
# Solution 2: Haiku's jax transforms
def eval_shape_net(x):
  """Solution 2: use Haiku's own `hk.eval_shape` instead of `jax.eval_shape`.

  Returns:
    Tuple of (output of the MLP, result of `hk.eval_shape`).
  """
  mlp = hk.nets.MLP([300, 100])          # still side-effecting
  # hk.eval_shape threads the Haiku state through the transform for you.
  shape_info = hk.eval_shape(mlp, x)
  result = mlp(x)
  return result, shape_info


init, apply = hk.transform(eval_shape_net)
params = init(
    jax.random.PRNGKey(777),
    jnp.ones((100, 100)),
)
# `out` holds the pair returned by `eval_shape_net`.
out = apply(
    params,
    jax.random.PRNGKey(777),
    jnp.ones((100, 100)),
)