In [21]:
import warnings

import flax.linen as nn
import jax
import jax.numpy as jnp


class ResidualMLPBlock(nn.Module):
  """Single residual layer computing ``x + relu(Dense(x))``.

  ``__call__`` takes two arguments and returns a pair, which lets the block
  serve as the body of ``nn.scan`` (carry in, carry out).
  """

  @nn.compact
  def __call__(self, x, _):
    """Applies one residual update.

    Args:
      x: Input array. The Dense layer's output width is taken from
        ``x.shape[-1]``, so the trailing dimension is preserved.
      _: Ignored second positional argument.

    Returns:
      A ``(x + relu(Dense(x)), None)`` tuple.
    """
    # Width-preserving projection, so the skip connection adds matching shapes.
    width = x.shape[-1]
    residual = nn.relu(nn.Dense(features=width)(x))
    return x + residual, None

class ResidualMLP(nn.Module):
  """Stack of rematerialized ``ResidualMLPBlock`` layers executed with ``nn.scan``.

  Attributes:
    n_layers: Number of times the block is applied (the scan length).
    remat_policy: Name of an attribute of ``jax.checkpoint_policies`` to use
      as the rematerialization policy. If no attribute with that name exists,
      a warning is emitted and ``policy=None`` is passed to ``nn.remat``.
  """
  n_layers: int = 4
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x):
    """Runs ``x`` through the scanned residual stack.

    Args:
      x: Input array, passed to the scan as the carry.

    Returns:
      The final carry after ``n_layers`` block applications.
    """
    policy = getattr(jax.checkpoint_policies, self.remat_policy, None)
    if policy is None:
      # The fallback to ``policy=None`` is kept for backward compatibility,
      # but a misspelled policy name must not go unnoticed: it changes which
      # intermediates are saved vs. recomputed.
      warnings.warn(
        f"Unknown remat_policy {self.remat_policy!r}: jax.checkpoint_policies "
        "has no such attribute; falling back to policy=None."
      )

    block = nn.remat(
      ResidualMLPBlock,
      # The block only ever runs inside the scan below.
      prevent_cse=False,
      # Assuming Flax's convention that ``self`` is index 0, index 2 is the
      # block's unused ``_`` argument (always ``None`` here).
      static_argnums=(2,),
      policy=policy,
    )

    ScanMLP = nn.scan(
      block,
      # Parameters get a leading axis, one slice per scan step.
      variable_axes={'params': 0},
      variable_broadcast=False,
      # A separate 'params' RNG per step, so steps are initialised independently.
      split_rngs={'params': True},
      length=self.n_layers)
    x, _ = ScanMLP()(x, None)
    return x

In [22]:
# Demo configuration: a single example with four features, fixed PRNG seed.
batch_size, feature_dim = 1, 4
rng = jax.random.key(42)

# Build the model and an all-ones input of shape (batch_size, feature_dim).
model = ResidualMLP(n_layers=4)
x = jnp.ones(shape=(batch_size, feature_dim))

# Initialise the variables from the input, then run one forward pass.
variables = model.init(rng, x)
y = model.apply(variables, x)