In [48]:
import jax
import jax.numpy as jnp
import flax.linen as nn


class ResidualMLPBlock(nn.Module):
  """Single residual layer shaped for use as an ``nn.scan`` body.

  Computes ``x + relu(Dense(x))`` with the Dense width taken from the last
  dimension of ``x``, and increments the scalar ``carry`` by one.

  Returns:
    ``(carry + 1, x + relu(Dense(x)))``.
  """

  @nn.compact
  def __call__(self, carry, x):
    width = x.shape[-1]
    activated = nn.relu(nn.Dense(features=width)(x))
    next_carry = carry + 1
    return next_carry, x + activated

class ResidualMLP(nn.Module):
  """Applies ``n_layers`` rematerialized ``ResidualMLPBlock``s via ``nn.scan``.

  Per-layer parameters are stacked along a leading axis of length ``n_layers``
  (``variable_axes={'params': 0}``), and each layer gets its own init RNG
  (``split_rngs={'params': True}``).

  NOTE(review): ``x`` is passed with ``in_axes=nn.broadcast``, so every layer
  receives the *same* input and the per-layer outputs are stacked along a new
  leading axis. The layers are NOT chained into a sequential residual stack.
  If a sequential stack was intended, thread the activation through the scan
  carry instead (that changes the return shape). Confirm intent.

  Attributes:
    n_layers: number of scan iterations, i.e. number of separately
      parameterized blocks.
    remat_policy: name of a policy in ``jax.checkpoint_policies`` (e.g.
      ``"nothing_saveable"``, ``"dots_saveable"``). It must name a policy
      itself, not a policy factory such as ``save_only_these_names``.
  """
  n_layers: int = 4
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x):
    """Runs the scanned blocks on ``x``.

    Args:
      x: input array; its last dimension sets each Dense layer's width.

    Returns:
      ``(carry, out)``. ``carry`` is the layer counter after the scan (equal to
      ``n_layers``). ``out`` stacks each layer's output along a new leading
      axis of length ``n_layers``.

    Raises:
      ValueError: if ``remat_policy`` does not name an attribute of
        ``jax.checkpoint_policies``.
    """
    policy = getattr(jax.checkpoint_policies, self.remat_policy, None)
    if policy is None:
      # Without this check, a misspelled name would silently fall back to
      # JAX's default remat policy.
      raise ValueError(
          f"Unknown remat_policy {self.remat_policy!r}: expected the name of "
          "an attribute of jax.checkpoint_policies, e.g. 'nothing_saveable'.")

    # No static_argnums: the block's __call__ has no hashable config
    # arguments. Flax counts static_argnums including `self`, so the previous
    # static_argnums=(2,) marked the array `x` as static.
    block = nn.remat(
        ResidualMLPBlock,
        # Safe (and cheaper) because nn.scan already prevents CSE across
        # iterations.
        prevent_cse=False,
        policy=policy,
    )

    scan_mlp = nn.scan(
        block,
        variable_axes={'params': 0},
        split_rngs={'params': True},
        in_axes=nn.broadcast,
        length=self.n_layers)

    # The carry is a plain layer counter starting at 0.
    carry, x = scan_mlp()(0, x)
    return carry, x

In [49]:
# Smoke test: build the model, initialise parameters, run one forward pass,
# check the output structure, and display every parameter leaf's shape.
batch_size = 1
feature_dim = 3
layers = 2
rng = jax.random.key(42)

model = ResidualMLP(n_layers=layers)
x = jnp.ones((batch_size, feature_dim))
variables = model.init(rng, x)
y = model.apply(variables, x)

# The model returns (layer counter, outputs). Because x is broadcast to every
# scanned layer, the outputs are stacked along a leading axis of length `layers`.
assert int(y[0]) == layers
assert y[1].shape == (layers, batch_size, feature_dim)

# Scanned params also carry a leading axis of length `layers`.
jax.tree.map(lambda leaf: leaf.shape, variables)

{'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (2, 3),
    'kernel': (2, 3, 3)}}}}