In [1]:
import functools
import pickle
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple

import flax
import haiku as hk
import jax
import jax.numpy as jnp
import jax.tree_util as tree
import jraph
import matplotlib.pyplot as plt
import networkx as nx
import numpy as onp
import optax

In [2]:
class MLP(hk.Module):
  """A multi-layer perceptron: stacked ``hk.Linear`` layers with ReLU between them.

  Every entry of ``features`` is the output width of one linear layer. ReLU
  follows every layer except the last, whose output is returned un-activated.
  """

  def __init__(self, features: Sequence[int], name: Optional[str] = None):
    """Store the layer widths.

    Args:
      features: output widths of the successive linear layers (must be
        non-empty), e.g. ``[10, 4, 3]``.
      name: optional Haiku module name; ``None`` keeps Haiku's default naming.

    Raises:
      ValueError: if ``features`` is empty.
    """
    super().__init__(name=name)
    # Fail early with a clear message instead of an IndexError on
    # features[-1] the first time the module is called.
    if len(features) == 0:
      raise ValueError('MLP requires at least one layer width in `features`.')
    self.features = features

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Run ``x`` through the layers and return the last layer's output."""
    *hidden_sizes, output_size = self.features
    layers = []
    for size in hidden_sizes:
      layers.extend([hk.Linear(size), jax.nn.relu])
    # The final projection has no activation.
    layers.append(hk.Linear(output_size))
    return hk.Sequential(layers)(x)

def train(net, data, seed: int = 42):
  """Initialise ``net``'s parameters on ``data``.

  This is currently a placeholder for a real training loop: it only runs
  ``net.init`` and returns the freshly initialised parameters.

  Args:
    net: an object exposing ``init(rng, data)``, e.g. the result of
      ``hk.transform`` or ``hk.multi_transform``.
    data: example input used to trace the network and infer parameter shapes.
    seed: integer seed for the ``jax.random.PRNGKey`` passed to ``net.init``
      (defaults to 42, the previously hard-coded value).

  Returns:
    Whatever ``net.init`` returns (the initial parameters).
  """
  rng = jax.random.PRNGKey(seed)
  params = net.init(rng, data)
  # TODO: the actual parameter-update loop goes here.
  return params

In [3]:
# Module-level switches read by HackyLogging. They are looked up in the global
# namespace on every call, so rebinding them from another cell takes effect on
# the next (un-jitted) apply.
logging_flag = False
logging_list = []

class HackyLogging(hk.Module):
  """Two stacked MLPs that can record the intermediate activation.

  When the global ``logging_flag`` is truthy, the output of the first MLP is
  appended to the global ``logging_list``. This is a Python side effect: under
  ``jax.jit`` it would only run while tracing and would append tracers rather
  than concrete arrays, so it is only meaningful for un-jitted ``apply`` calls.
  """

  def __init__(self, features: Sequence[int], name: Optional[str] = None):
    """
    Args:
      features: layer widths used for both inner ``MLP`` modules.
      name: optional Haiku module name; ``None`` keeps Haiku's default naming.
    """
    super().__init__(name=name)
    self.features = features

  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
    """Apply both MLPs, logging the first one's output if enabled."""
    hidden_layer = MLP(self.features)(x)
    if logging_flag:
      logging_list.append(hidden_layer)
    return MLP(self.features)(hidden_layer)

In [None]:
# Initialise with logging disabled, then run un-jitted applies with logging
# enabled to capture the hidden activations.
# Start from an empty list so re-running this cell does not show stale entries.
logging_list.clear()
data_rng = onp.random.default_rng(0)  # seeded so printed values are reproducible
net = hk.without_apply_rng(hk.transform(lambda x: HackyLogging([10,4,3])(x)))
logging_flag = False
params = train(net, data_rng.random((100, 3)))
print(f'with logging disabled nothing gets logged: {logging_list}')
logging_flag = True
out = net.apply(params, data_rng.random((1, 3)))
out = net.apply(params, data_rng.random((1, 3)))
print(f'with logging enabled we log (but do not jit): {logging_list}')
# .clear() empties the existing list in place. Rebinding `logging_list = []`
# would also work, because HackyLogging looks the name up in the module globals
# on every call (just as it does for `logging_flag` above). .clear() also
# empties the list for any other reference still pointing at it.
logging_list.clear()
print(f'now we have an empty list: {logging_list}')
out = net.apply(params, data_rng.random((1, 3)))
print(f'and we can fill it again: {logging_list}')

In [None]:
# A less hacky way to expose intermediate activations is hk.multi_transform,
# which lets several apply functions share a single set of parameters.
# For more details see https://dm-haiku.readthedocs.io/en/latest/api.html#haiku.multi_transform
def whitebox_mlp(layers):
  """Build an ``hk.multi_transform`` factory exposing two views of one network.

  Args:
    layers: layer widths forwarded to each of the two inner ``MLP`` modules.

  Returns:
    A zero-argument function for ``hk.multi_transform``. When called it returns
    ``(init, (full_mlp, middle_mlp))``:
      * ``full_mlp`` chains both MLPs;
      * ``middle_mlp`` runs only the first MLP, exposing the hidden activation;
      * ``init`` runs the full stack so that every parameter gets created.
  """
  def _whitebox_mlp():
    first = MLP(layers)
    second = MLP(layers)

    def middle_mlp(x):
      return first(x)

    def full_mlp(x):
      return second(middle_mlp(x))

    def init(x):
      # Tracing the whole stack creates the parameters of both sub-modules.
      return full_mlp(x)

    apply_fns = (full_mlp, middle_mlp)
    return init, apply_fns
  return _whitebox_mlp

# NOTE(review): older Haiku releases' hk.without_apply_rng did not accept the
# result of hk.multi_transform; newer releases may. TODO: confirm for the
# installed version before dropping the explicit `None` rng arguments below.
whitebox_net = hk.multi_transform(whitebox_mlp([10,4,3]))
example_rng = onp.random.default_rng(0)  # seeded so outputs are reproducible
params = whitebox_net.init(jax.random.PRNGKey(42), example_rng.random((100, 3)))
# `apply` is a tuple in the order returned by whitebox_mlp.
full_mlp_apply, middle_mlp_apply = whitebox_net.apply
example_in = example_rng.random((100, 3))
# The apply functions still take an rng argument, so we explicitly pass None
# to indicate the network does not use randomness. Both calls share `params`:
# `hidden` is the first MLP's output and `out` is the output of the full stack.
hidden = middle_mlp_apply(params, None, example_in)
out = full_mlp_apply(params, None, example_in)

# To train this network, `train` should be adapted to take `full_mlp_apply`
# and `params` directly rather than the `whitebox_net` object.