# Diffusing Neural Cellular Automata

## Import

In [1]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.dwconv_perceive import DWConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.state import state_from_rgba_to_rgb, state_to_rgba
from cax.core.update.nca_update import NCAUpdate
from cax.utils.image import get_emoji
from flax import nnx

## Configuration

In [None]:
# --- Reproducibility -------------------------------------------------------
seed = 0
key = jax.random.PRNGKey(seed)  # explicit JAX key, split again in the training loop
rngs = nnx.Rngs(seed)  # NNX RNG container handed to the perceive/update modules

# --- Model -----------------------------------------------------------------
channel_size = 16  # channels per cell; the last 4 are overwritten with the noisy target in init_state
num_kernels = 3  # perception size passed to NCAUpdate is num_kernels * channel_size
hidden_size = 128  # passed to NCAUpdate as the single-entry tuple (hidden_size,)
cell_dropout_rate = 0.5  # forwarded to NCAUpdate

# --- Training --------------------------------------------------------------
batch_size = 8  # number of initial states sampled per train step
num_steps = 128  # CA steps unrolled per rollout inside the loss
learning_rate = 1e-3  # initial value of the learning-rate schedule

# --- Target image ----------------------------------------------------------
emoji = "🦎"
target_size = 40
target_padding = 16

## Dataset

In [3]:
# Render the configured emoji; this image is the target used by the loss and by init_state.
target = get_emoji(
	emoji,
	size=target_size,
	padding=target_padding,
)
mediapy.show_image(target)

## Init state

In [6]:
def add_noise(target, alpha, key):
	"""Blend `target` with standard-normal noise and clip the result to [0, 1].

	Args:
		target: Array to corrupt.
		alpha: Noise weight; the result is `(1 - alpha) * target + alpha * noise`.
		key: JAX PRNG key used to draw the noise.

	Returns:
		Array with the same shape as `target`, clipped to the range [0, 1].
	"""
	gaussian = jax.random.normal(key, target.shape)
	blended = (1 - alpha) * target + alpha * gaussian
	return jnp.clip(blended, 0.0, 1.0)


def init_state(key):
	"""Build one initial CA state whose last four channels hold a noisy copy of `target`.

	A noise weight `alpha` is drawn from `jax.random.uniform` and the global
	`target` is blended with standard-normal noise as
	`(1 - alpha) * target + alpha * noise`. All other channels start at zero.

	NOTE(review): unlike `add_noise`, the blend is not clipped to [0, 1] here;
	confirm that this is intended.

	Args:
		key: JAX PRNG key; split into one key for `alpha` and one for the noise.

	Returns:
		Array of shape `target.shape[:2] + (channel_size,)`.
	"""
	alpha_key, noise_key = jax.random.split(key)
	alpha = jax.random.uniform(alpha_key)
	gaussian = jax.random.normal(noise_key, target.shape)
	corrupted = (1 - alpha) * target + alpha * gaussian

	# Assumes `target` has 4 channels (presumably RGBA) so that it fits the last-4 slot.
	grid = jnp.zeros(target.shape[:2] + (channel_size,))
	return grid.at[..., -4:].set(corrupted)

## Model

In [8]:
# Keep this construction order: both modules draw from the same stateful `rngs`.
perceive = DWConvPerceive(channel_size, rngs)
update = NCAUpdate(
	channel_size,
	num_kernels * channel_size,  # perception size
	(hidden_size,),
	rngs,
	cell_dropout_rate=cell_dropout_rate,
)

In [9]:
# Repeat the [identity, gradient] kernels once per channel along the last axis, then
# insert a singleton axis in the second-to-last position before installing the result
# as the depthwise-convolution kernel.
kernel = jnp.concatenate(
	[identity_kernel(ndim=2), grad_kernel(ndim=2)] * channel_size,
	axis=-1,
)[..., None, :]
perceive.dwconv.kernel = nnx.Param(kernel)

In [10]:
# Assemble the cellular automaton from the perceive and update modules built above.
ca = CA(perceive, update)

In [11]:
# Collect every nnx.Param of the CA and report the total number of scalar entries.
params = nnx.state(ca, nnx.Param)
print("Number of params:", sum(leaf.size for leaf in jax.tree_util.tree_leaves(params)))

Number of params: 8768


## Train

### Optimizer

In [12]:
# Linear learning-rate decay from `learning_rate` to 10% of it over 2,000 transition steps.
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=0.1 * learning_rate,
	transition_steps=2_000,
)

# Only nnx.Param variables whose path contains "update" are optimised
# (presumably the NCAUpdate module; confirm against the CA attribute names).
update_params = nnx.All(nnx.Param, nnx.PathContains("update"))

# Gradient clipping by global norm, followed by Adam on the schedule above.
optimizer = nnx.Optimizer(
	ca,
	optax.chain(
		optax.clip_by_global_norm(1.0),
		optax.adam(learning_rate=lr_sched),
	),
	wrt=update_params,
)

### Loss

In [21]:
def mse(state):
	"""Return the mean squared error between the RGBA view of `state` and the global `target`."""
	residual = state_to_rgba(state) - target
	return jnp.mean(jnp.square(residual))

In [22]:
@nnx.jit
def loss_fn(ca, state):
	"""Roll out `ca` for `num_steps` steps on every state in the batch and return the MSE loss.

	Args:
		ca: The cellular automaton to evaluate.
		state: Batch of initial states, mapped over its leading axis.

	Returns:
		Scalar loss from `mse` on the rolled-out batch.
	"""
	rollout = nnx.vmap(lambda cell_state: ca(cell_state, num_steps=num_steps))
	return mse(rollout(state))

### Train step

In [23]:
@nnx.jit
def train_step(ca, optimizer, key):
	"""Sample a batch of initial states, take one optimiser step and return the loss.

	Args:
		ca: The cellular automaton being trained.
		optimizer: Optimiser whose `update` is called with the gradients.
		key: JAX PRNG key, split into `batch_size` keys for `init_state`.

	Returns:
		The loss computed on the sampled batch.
	"""
	init_keys = jax.random.split(key, batch_size)
	batch = jax.vmap(init_state)(init_keys)

	# Differentiate with respect to the `update_params` subset of the first argument only.
	grad_fn = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, update_params))
	loss, grad = grad_fn(ca, batch)
	optimizer.update(grad)

	return loss

### Main loop

In [None]:
# Run 8,192 training steps with a fresh PRNG key each, logging the loss every 128 steps.
for step in range(8_192):
	key, subkey = jax.random.split(key)
	loss = train_step(ca, optimizer, subkey)
	if step % 128 != 0:
		continue
	print(f"Step {step}: loss = {loss}")

## Visualize

In [24]:
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, 8)

# Roll 8 fresh initial states out for twice the training horizon, keeping every step.
state = nnx.vmap(
	lambda initial: ca(initial, num_steps=2 * num_steps, all_steps=True)
)(jax.vmap(init_state)(keys))

mediapy.show_videos(
	state_from_rgba_to_rgb(state),
	width=128,
	height=128,
	codec="gif",
)

In [25]:
# Show the last recorded step of each rollout.
mediapy.show_images(
	state_to_rgba(state[:, -1]),
	width=128,
	height=128,
)