# Mine: layer_id as carry

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
feature_dim = 768
layers = 12
rng = jax.random.key(42)
x = jnp.ones((batch_size, feature_dim))

class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, layer_id, x):
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = x+h
    return layer_id+1, o
    # out = {}
    # out['layer_id'] = layer_id
    # out['o'] = o
    # return layer_id+1, o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x):

    remat_block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    )

    scan_block = nn.scan(
      remat_block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)(
        feature_dim=self.feature_dim,)

    init_layer_id = 0
    final_layer_id, y = scan_block(init_layer_id, x)
    return final_layer_id, y
  

model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
layer_id, y = model.apply(variables, x)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"layer_id: {layer_id}")
print(f"y shape: {y.shape}")

RuntimeError: Unable to initialize backend 'tpu': UNKNOWN: TPU initialization failed: open(/dev/accel2): Operation not permitted: Operation not permitted; Couldn't open device: /dev/accel2; [/dev/accel2]  (set JAX_PLATFORMS='' to automatically choose an available backend)

# Mine: x and layer id as carry

In [5]:
import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
feature_dim = 768
layers = 12
rng = jax.random.key(42)
x = jnp.ones((batch_size, feature_dim))

class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, carry, inputs):
    layer_id = carry
    x, deterministic = inputs
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = x+h
    carry = layer_id+1
    return carry, o
    # out = {}
    # out['layer_id'] = layer_id
    # out['o'] = o
    # return layer_id+1, o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):

    remat_block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      static_argnums=(-1,),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    )

    scan_block = nn.scan(
      remat_block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)(
        feature_dim=self.feature_dim,)

    init_layer_id = 0
    carry = init_layer_id
    inputs = (x, deterministic)
    final_carry, y = scan_block(carry, inputs)
    final_layer_id = final_carry
    print(f"final_layer_id = {final_layer_id}")
    return y
  

model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
y = model.apply(variables, x, deterministic=True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")

variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (12, 768), 'kernel': (12, 768, 768)}}}}
layer_id: 12
y shape: (32, 768)


# Mine: x, layer id as carry; y, out as output

In [12]:
import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
feature_dim = 768
layers = 12
rng = jax.random.key(42)
x = jnp.ones((batch_size, feature_dim))

class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, carry, inputs):
    layer_id = carry
    x, deterministic = inputs
    
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = x+h
    
    out = {}
    out['o'] = o
    carry = layer_id+1
    return carry, (o,out)
    # return layer_id+1, o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):

    remat_block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      # static_argnums=(-1,),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    )

    scan_block = nn.scan(
      remat_block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)(
        feature_dim=self.feature_dim,)

    init_layer_id = 0
    carry = init_layer_id
    inputs = (x, deterministic)
    final_carry, output = scan_block(carry, inputs)
    final_layer_id = final_carry
    y, out = output
    print(f"final_layer_id = {final_layer_id}")
    return y, out
  

model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
y = model.apply(variables, x, deterministic=True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")

> [0;32m/mnt/vlm-pd/miniconda3/envs/vlm/lib/python3.10/site-packages/flax/linen/transforms.py[0m(420)[0;36mwrapped_fn[0;34m()[0m
[0;32m    418 [0;31m  [0;32mdef[0m [0mwrapped_fn[0m[0;34m([0m[0mself[0m[0;34m,[0m [0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    419 [0;31m    [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 420 [0;31m    [0mstate[0m [0;34m=[0m [0mself[0m[0;34m.[0m[0m_state[0m[0;34m.[0m[0mexport[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    421 [0;31m[0;34m[0m[0m
[0m[0;32m    422 [0;31m    [0;31m# make a scope-function to transform[0m[0;34m[0m[0;34m[0m[0m
[0m


AttributeError: 'int' object has no attribute '_state'

# Big Vision

In [3]:
class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, x, deterministic=True):
    out = {}
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = out['o'] = x+h
    return o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):
    block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      static_argnums=(2,),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),)
    
    ScanMLP = nn.scan(
      block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)
    
    x, scan_out = ScanMLP(
      feature_dim=self.feature_dim
    )(x,deterministic)
    return x, scan_out


model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
y, scan_out = model.apply(variables, x, True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")
print(f"scan_out shapes: {jax.tree.map(lambda x: x.shape, scan_out)}")

variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (12, 768), 'kernel': (12, 768, 768)}}}}
y shape: (32, 768)
scan_out shapes: {'o': (12, 32, 768)}
