In [None]:
import haiku as hk
import jax
import jax.numpy as jnp

# Limitations of using JAX transforms inside of networks




**TL;DR:** Using a JAX transform inside of a `hk.transform` is likely to transform a side-effecting function, which will result in an `UnexpectedTracerError`. To get around this, either use [the Haiku versions of these JAX functions](https://dm-haiku.readthedocs.io/en/latest/api.html#jax-fundamentals), or use [lift](https://dm-haiku.readthedocs.io/en/latest/api.html#lift) to nest one `hk.transform` inside another.

Once a Haiku network has been transformed into a pure function using `hk.transform`, you can freely combine it with JAX transformations such as `jax.jit`, `jax.grad`, `jax.lax.scan` and so on.

If you want to use JAX transformations **inside** of a `hk.transform`, however, you need to be more careful. It's possible, but most functions inside the `hk.transform` boundary are still side-effecting and cannot safely be transformed by JAX.
This is a common cause of `UnexpectedTracerError`s in code using Haiku. These errors indicate that a JAX transform was applied to a side-effecting function (for more information on this error, see https://jax.readthedocs.io/en/latest/_modules/jax/_src/errors.html#UnexpectedTracerError).
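The same failure mode can be reproduced without Haiku. In this sketch, the illustrative `side_effecting` function stashes an intermediate value in a global list while it is being traced by `jax.jit`; using that stored tracer after tracing has finished raises the error:

```python
import jax
import jax.numpy as jnp

leaked = []  # illustrative global list: appending to it is a side effect

def side_effecting(x):
  y = jnp.sin(x)
  leaked.append(y)  # under jax.jit, y is a tracer rather than a concrete array
  return y

jax.jit(side_effecting)(jnp.ones(3))

raised = False
try:
  jnp.cos(leaked[0])  # the tracer is used after the jit trace has finished
except jax.errors.UnexpectedTracerError:
  raised = True
print("UnexpectedTracerError raised:", raised)
```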

An example with `jax.eval_shape`:


In [None]:
def net(x): # inside of a hk.transform this is still side-effecting
  w = hk.get_parameter("w", (2, 2), init=jnp.ones)
  return w @ x

def eval_shape_net(x):
  output_shape = jax.eval_shape(net, x) # eval_shape on side-effecting function
  return net(x)                         # UnexpectedTracerError!

init, _ = hk.transform(eval_shape_net)
init(jax.random.PRNGKey(666), jnp.ones((2, 2)))

The error points to `get_parameter`, the operation that makes `net` a side-effecting function. The side effect in this case is the creation of a parameter, which is stored in the Haiku state. Similarly, you would get an error using `next_rng_key()`, which advances the RNG state held by Haiku.

The example above uses `jax.eval_shape`, but the same applies to any higher-order JAX function (e.g. `jax.vmap`, `jax.lax.scan`, `jax.lax.while_loop`, ...).

You could rewrite the `get_parameter` example to create the parameter outside of the `eval_shape` transformation, making `net` a pure function by threading the parameter through explicitly as an argument:

In [None]:
def net(w, x): # no side effects!
  return w @ x

def eval_shape_net(x):
  w = hk.get_parameter("w", (3, 2), init=jnp.ones)
  output_shape = jax.eval_shape(net, w, x) # net is now side-effect free
  return output_shape, net(w, x)

key = jax.random.PRNGKey(777)
x = jnp.ones((2, 3))
init, apply = hk.transform(eval_shape_net)
params = init(key, x)
apply(params, key, x)

However, that's not always possible! Consider the following code, which calls a Haiku module (`hk.nets.MLP`) whose implementation is outside of our control. This module internally calls `get_parameter` and `next_rng_key` (the latter through its random weight initializers).



In [None]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])
  output_shape = jax.eval_shape(net, x)
  return net(x)

init, _ = hk.transform(eval_shape_net)
try:
  init(jax.random.PRNGKey(666), jnp.ones((2, 2)))
except jax.errors.UnexpectedTracerError:
  print("UnexpectedTracerError: applied JAX transform to side effecting function")

To work around this, Haiku provides wrapped versions of JAX transforms under the `haiku` namespace, such as `hk.grad`, `hk.vmap` and `hk.eval_shape`. See https://dm-haiku.readthedocs.io/en/latest/api.html#jax-fundamentals for a full list of available functions.

These wrappers apply the JAX function to a functionally pure version of the Haiku function, doing the explicit state threading shown above for you.

In [None]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])         # still side-effecting
  output_shape = hk.eval_shape(net, x)  # hk.eval_shape threads through the Haiku state for you
  out = net(x)
  return out, output_shape


init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))

Another option is to use `hk.experimental.lift`, which allows you to nest a `hk.transform` inside of an outer `hk.transform`. Lift registers any parameters initialized by the wrapped `init` in the outer transform's parameters. Because the nested transform turns the network into a pure `(init, apply)` pair, the Haiku state becomes explicit (as `params`), so you decide how it is passed into the JAX transform.

In [None]:
def eval_shape_net(x):
  net = hk.nets.MLP([300, 100])    # still side-effecting
  init, apply = hk.transform(net)  # nested transform
  params = hk.experimental.lift(init)(hk.next_rng_key(), x)  # register parameters in the outer transform
  output_shape = jax.eval_shape(apply, params, hk.next_rng_key(), x)  # apply is a functionally pure function!
  out = net(x)
  return out, output_shape


init, apply = hk.transform(eval_shape_net)
params = init(jax.random.PRNGKey(777), jnp.ones((100, 100)))
apply(params, jax.random.PRNGKey(777), jnp.ones((100, 100)))