In [2]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
def drop_path(x, drop_prob: float = 0.0, deterministic: bool = False):
    """Drop paths (Stochastic Depth) per sample.
    
    This is an implementation of the DropPath function as described in the
    paper "Deep Networks with Stochastic Depth" (https://arxiv.org/abs/1603.09382).
    
    Args:
        x: input tensor
        drop_prob: probability of dropping a path
        deterministic: if True, the drop mask will be all ones (no dropout)
    
    Returns:
        Output tensor after applying drop path.
    """
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    
    rng = jax.random.PRNGKey(0)  # You might want to pass this as an argument for better randomness
    random_tensor = jax.random.bernoulli(rng, p=keep_prob, shape=shape)
    random_tensor = jnp.asarray(random_tensor, dtype=x.dtype)
    random_tensor = random_tensor / jnp.where(keep_prob > 0, keep_prob, 1.)

    output = jnp.where(deterministic, x, x * random_tensor)
    return output

# Mine: layer_id as carry

In [2]:
import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
feature_dim = 768
layers = 12
rng = jax.random.key(42)
x = jnp.ones((batch_size, feature_dim))

class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, layer_id, x):
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = x+h
    return layer_id+1, o
    # out = {}
    # out['layer_id'] = layer_id
    # out['o'] = o
    # return layer_id+1, o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x):

    remat_block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    )

    scan_block = nn.scan(
      remat_block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)(
        feature_dim=self.feature_dim,)

    init_layer_id = 0
    final_layer_id, y = scan_block(init_layer_id, x)
    return final_layer_id, y
  

model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
layer_id, y = model.apply(variables, x)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"layer_id: {layer_id}")
print(f"y shape: {y.shape}")

variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (12, 768), 'kernel': (12, 768, 768)}}}}
layer_id: 12
y shape: (12, 32, 768)


# Mine: x and layer id as carry

In [3]:
import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
feature_dim = 768
layers = 12
rng = jax.random.key(42)
x = jnp.ones((batch_size, feature_dim))

class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, carry, inputs):
    layer_id = carry
    x, deterministic = inputs
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = x+h
    carry = layer_id+1
    return carry, o
    # out = {}
    # out['layer_id'] = layer_id
    # out['o'] = o
    # return layer_id+1, o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):

    remat_block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      static_argnums=(-1,),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    )

    scan_block = nn.scan(
      remat_block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)(
        feature_dim=self.feature_dim,)

    init_layer_id = 0
    carry = init_layer_id
    inputs = (x, deterministic)
    final_carry, y = scan_block(carry, inputs)
    final_layer_id = final_carry
    print(f"final_layer_id = {final_layer_id}")
    return y
  

model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
y = model.apply(variables, x, deterministic=True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")

final_layer_id = 12
final_layer_id = 12
variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (12, 768), 'kernel': (12, 768, 768)}}}}
y shape: (12, 32, 768)


# Mine: x, layer id as carry; y, out as output

In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn

batch_size = 32
feature_dim = 768
layers = 12
rng = jax.random.key(42)
x = jnp.ones((batch_size, feature_dim))

class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, carry, inputs):
    layer_id = carry
    x, deterministic = inputs
    
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = x+h
    
    out = {}
    out['o'] = o
    carry = layer_id+1
    return carry, (o,out)
    # return layer_id+1, o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):

    remat_block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      # static_argnums=(-1,),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),
    )

    scan_block = nn.scan(
      remat_block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)(
        feature_dim=self.feature_dim,)

    init_layer_id = 0
    carry = init_layer_id
    inputs = (x, deterministic)
    final_carry, output = scan_block(carry, inputs)
    final_layer_id = final_carry
    y, out = output
    return y, out

model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
print(f"Initializing model")
variables = model.init(rng, x)
print(f"Applying model")
y, out = model.apply(variables, x, deterministic=True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")
print(f"out: {jax.tree.map(lambda x: x.shape, out)}")

Initializing model
Applying model
variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (12, 768), 'kernel': (12, 768, 768)}}}}
y shape: (12, 32, 768)
out: {'o': (12, 32, 768)}


# Mine: block

In [6]:
class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  drop_path_rate: float = 0.1
  @nn.compact
  def __call__(self, x, deterministic=True, layer_id=-1, total_layers=-1):
    out = {}
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    x = out['o'] = x+h
    assert layer_id < total_layers, f"layer_id={layer_id} total_layers={total_layers}"
    drop_path_rate = self.drop_path_rate * layer_id / total_layers
    x = out['drop_path'] = drop_path(x, drop_path_rate, deterministic)
    return x, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):
    block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      static_argnums=(2,4),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),)
    
    ScanMLP = nn.scan(
      block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)
    
    x, scan_out = ScanMLP(
      feature_dim=self.feature_dim
    )(x,deterministic,0,self.n_layers)
    return x, scan_out


model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
y, scan_out = model.apply(variables, x, True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")
print(f"scan_out shapes: {jax.tree.map(lambda x: x.shape, scan_out)}")

ConcretizationTypeError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function new_fun at /mnt/vlm-pd/miniconda3/envs/vlm/lib/python3.10/site-packages/jax/_src/ad_checkpoint.py:393 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[3].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html


# Big Vision

In [3]:
class ResidualMLPBlock(nn.Module):
  feature_dim: int = 3
  @nn.compact
  def __call__(self, x, deterministic=True):
    out = {}
    h = nn.Dense(features=self.feature_dim)(x)
    h = nn.relu(h)
    o = out['o'] = x+h
    return o, out

class ResidualMLP(nn.Module):
  n_layers: int = 4
  feature_dim: int = 3
  remat_policy: str = "nothing_saveable"

  @nn.compact
  def __call__(self, x, deterministic=True):
    block = nn.remat(
      ResidualMLPBlock, 
      prevent_cse=False,
      static_argnums=(2,),
      policy=getattr(jax.checkpoint_policies, self.remat_policy, None),)
    
    ScanMLP = nn.scan(
      block, 
      variable_axes={'params': 0},
      split_rngs={'params': True},
      in_axes=nn.broadcast,
      length=self.n_layers)
    
    x, scan_out = ScanMLP(
      feature_dim=self.feature_dim
    )(x,deterministic)
    return x, scan_out


model = ResidualMLP(n_layers=layers, feature_dim=feature_dim)
variables = model.init(rng, x)
y, scan_out = model.apply(variables, x, True)
print(f"variable shapes: {jax.tree.map(lambda x: x.shape, variables)}")
print(f"y shape: {y.shape}")
print(f"scan_out shapes: {jax.tree.map(lambda x: x.shape, scan_out)}")

variable shapes: {'params': {'ScanCheckpointResidualMLPBlock_0': {'Dense_0': {'bias': (12, 768), 'kernel': (12, 768, 768)}}}}
y shape: (32, 768)
scan_out shapes: {'o': (12, 32, 768)}
