# Mine: the scan carry is `layer_id`; `x` is broadcast to every layer

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn


class ResidualMLPBlock(nn.Module):
  """Single residual step ``x + relu(Dense(x))`` that also increments a counter.

  The call signature is ``(layer_id, x) -> (layer_id + 1, out)``, i.e. the
  first argument/return slot holds the counter and the second holds the data.
  """
  # Width of the Dense layer output; it must be broadcast-compatible with `x`
  # for the skip-connection addition below.
  feature_dim: int = 3

  @nn.compact
  def __call__(self, layer_id, x):
    """Apply the block.

    Args:
      layer_id: counter value coming in; returned incremented by one.
      x: input array fed to the Dense layer and to the skip connection.

    Returns:
      A pair ``(layer_id + 1, x + relu(Dense(x)))``.
    """
    dense = nn.Dense(features=self.feature_dim)
    branch = nn.relu(dense(x))
    summed = x + branch
    next_layer_id = layer_id + 1
    return next_layer_id, summed

class ResidualMLP(nn.Module):
  """Applies `n_layers` rematerialized `ResidualMLPBlock`s through `nn.scan`.

  NOTE(review): the value threaded through the scan as the carry is the
  integer layer counter (starting at 0). `x` is passed with
  `in_axes=nn.broadcast`, so the same `x` is handed to every iteration and the
  block outputs are returned as scan outputs, which is why `y` has a leading
  axis of length `n_layers`. Confirm this is intended rather than feeding each
  layer's output into the next layer.
  """
  # Number of scan iterations (one parameter slice per iteration).
  n_layers: int = 4
  # Forwarded to every ResidualMLPBlock.
  feature_dim: int = 3
  # Attribute name looked up on `jax.checkpoint_policies`.
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x):
    """Run the scanned stack on `x`.

    Args:
      x: input array, broadcast to every scanned block.

    Returns:
      A pair ``(final_layer_id, y)``: the counter after the last iteration and
      the scan outputs of the blocks.
    """
    # An unrecognised policy name silently yields policy=None here.
    checkpoint_policy = getattr(jax.checkpoint_policies, self.remat_policy, None)
    remat_block = nn.remat(
      ResidualMLPBlock,
      policy=checkpoint_policy,
      prevent_cse=False,
    )

    # Parameters are stacked along axis 0 and each iteration gets its own
    # 'params' RNG, so every layer is initialised independently.
    scanned_cls = nn.scan(
      remat_block,
      length=self.n_layers,
      in_axes=nn.broadcast,
      split_rngs={'params': True},
      variable_axes={'params': 0},
    )

    stack = scanned_cls(feature_dim=self.feature_dim)
    start_layer_id = 0
    final_layer_id, y = stack(start_layer_id, x)
    return final_layer_id, y
  
# Demo run: initialise a small model on an all-ones batch and print the
# parameter shapes, the final layer counter and the output shape.
batch_size = 1
feature_dim = 3
layers = 2
rng = jax.random.key(42)

# Pass feature_dim explicitly: the input below is built from the same variable,
# and the residual add inside the block needs the two widths to agree. Relying
# on the model's default would break as soon as feature_dim is changed here.
model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
x = jnp.ones((batch_size, feature_dim))
variables = model.init(rng, x)
layer_id, y = model.apply(variables, x)
print(f"variable shapes: {jax.tree.map(lambda leaf: leaf.shape, variables)}")
print(f"layer_id: {layer_id}")
print(f"y shape: {y.shape}")

variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (2, 3), 'kernel': (2, 3, 3)}}}}
layer_id: 2
y shape: (2, 1, 3)


# Big Vision: the scan carry is `x`; per-layer outputs are collected in `scan_out`

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn


class ResidualMLPBlock(nn.Module):
  """Single residual step ``x + relu(Dense(x))`` that also reports its output.

  The call signature is ``(x, deterministic) -> (out, {'o': out})``: the first
  return slot is the updated activation, the second a dict holding the same
  array under the key ``'o'``.
  """
  # Width of the Dense layer output; it must be broadcast-compatible with `x`
  # for the skip-connection addition below.
  feature_dim: int = 3

  @nn.compact
  def __call__(self, x, deterministic=True):
    """Apply the block.

    Args:
      x: input array fed to the Dense layer and to the skip connection.
      deterministic: accepted but not used anywhere in this block.

    Returns:
      A pair ``(o, out)`` where ``o = x + relu(Dense(x))`` and
      ``out == {'o': o}``.
    """
    dense = nn.Dense(features=self.feature_dim)
    branch = nn.relu(dense(x))
    updated = x + branch
    return updated, {'o': updated}

class ResidualMLP(nn.Module):
  """Applies `n_layers` rematerialized `ResidualMLPBlock`s through `nn.scan`.

  Here `x` is the value threaded through the scan as the carry, so each
  iteration receives the previous iteration's output. `deterministic` is
  passed with `in_axes=nn.broadcast`, and the per-iteration `out` dicts are
  returned as scan outputs with a leading axis of length `n_layers`.
  """
  # Number of scan iterations (one parameter slice per iteration).
  n_layers: int = 4
  # Forwarded to every ResidualMLPBlock.
  feature_dim: int = 3
  # Attribute name looked up on `jax.checkpoint_policies`.
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):
    """Run the scanned stack on `x`.

    Args:
      x: input array; initial value of the scan carry.
      deterministic: flag forwarded unchanged to every block.

    Returns:
      A pair ``(x, scan_out)``: the carry after the last iteration and the
      collected per-iteration block outputs.
    """
    # An unrecognised policy name silently yields policy=None here.
    checkpoint_policy = getattr(jax.checkpoint_policies, self.remat_policy, None)
    remat_block = nn.remat(
      ResidualMLPBlock,
      policy=checkpoint_policy,
      # Index 2 is `deterministic` in the block's (self, x, deterministic) call.
      static_argnums=(2,),
      prevent_cse=False,
    )

    # Parameters are stacked along axis 0 and each iteration gets its own
    # 'params' RNG, so every layer is initialised independently.
    scanned_cls = nn.scan(
      remat_block,
      length=self.n_layers,
      in_axes=nn.broadcast,
      split_rngs={'params': True},
      variable_axes={'params': 0},
    )

    stack = scanned_cls(feature_dim=self.feature_dim)
    x, scan_out = stack(x, deterministic)
    return x, scan_out
  
# Demo run: initialise a small model on an all-ones batch and print the
# parameter shapes, the final output shape and the per-layer output shapes.
batch_size = 1
feature_dim = 3
layers = 2
rng = jax.random.key(42)

# Pass feature_dim explicitly: the input below is built from the same variable,
# and the residual add inside the block needs the two widths to agree. Relying
# on the model's default would break as soon as feature_dim is changed here.
model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
x = jnp.ones((batch_size, feature_dim))
variables = model.init(rng, x)
y, scan_out = model.apply(variables, x, True)
print(f"variable shapes: {jax.tree.map(lambda leaf: leaf.shape, variables)}")
print(f"y shape: {y.shape}")
print(f"scan_out shapes: {jax.tree.map(lambda leaf: leaf.shape, scan_out)}")

variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (2, 3), 'kernel': (2, 3, 3)}}}}
y shape: (1, 3)
scan_out shapes: {'o': (2, 1, 3)}
