# Growing Unsupervised Neural Cellular Automata

## Installation

You will need Python 3.10 or later, and a working JAX installation. For example, you can install JAX with:

In [2]:
# %pip install -U "jax[cuda12]"

Then, install CAX from PyPI:

In [3]:
# %pip install -U "cax[examples]"

## Import

In [1]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.depthwise_conv_perceive import DepthwiseConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.state import state_to_alive
from cax.core.update.nca_update import NCAUpdate
from cax.nn.pool import Pool
from cax.nn.vae import Encoder
from datasets import load_dataset
from flax import nnx
from tqdm.auto import tqdm

## Configuration

In [2]:
# --- Reproducibility -------------------------------------------------------
seed = 0

# --- Data / encoder ----------------------------------------------------------
n_classes = 10
spatial_dims = (28, 28)  # MNIST image height and width
features = (1, 32, 32)  # encoder feature sizes (first entry presumably the input channels — confirm)
latent_size = n_classes  # latent width tied to the number of MNIST classes

# --- Cell state and update network -------------------------------------------
channel_size = 32  # channels per cell; the last channel is the "alive" channel
num_kernels = 3  # perception kernels per channel (identity + gradient kernels)
hidden_size = 256
cell_dropout_rate = 0.5

# --- Training ----------------------------------------------------------------
pool_size = 1_024
batch_size = 8
num_steps = 64
learning_rate = 1e-3

# Explicit PRNG key for jax.random calls and an nnx RNG container for module init.
key = jax.random.key(seed)
rngs = nnx.Rngs(seed)

2024-09-25 16:11:38.808347: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


## Dataset

In [3]:
ds = load_dataset("ylecun/mnist")


def _split_images(split):
    """Return the images of `split` as float32 scaled by 1/255, with a trailing channel axis."""
    pixels = jnp.array(ds[split]["image"], dtype=jnp.float32) / 255
    return pixels[..., None]


image_train = _split_images("train")
image_test = _split_images("test")

mediapy.show_images(image_train[:8], width=128, height=128)

## Init state

In [6]:
def init_state(key):
    """Create a seed state and draw a random training target for it.

    The state is all zeros except for a single cell at the centre of the grid,
    whose last channel (the "alive" channel) is set to 1.

    Args:
        key: JAX PRNG key used to sample the target index.

    Returns:
        `(state, target_index)`:
        - `state` has shape `(H, W, channel_size)`, with `(H, W)` taken from
          `image_train`.
        - `target_index` is a scalar index into `image_train`.
    """
    state_shape = image_train.shape[1:3] + (channel_size,)
    mid = tuple(size // 2 for size in state_shape[:-1])

    # Set the "alive" cell once. The last channel index (-1 == channel_size - 1)
    # was previously written twice.
    state = jnp.zeros(state_shape).at[mid + (channel_size - 1,)].set(1.0)

    target_index = jax.random.choice(key, image_train.shape[0])
    return state, target_index

## Model

In [10]:
# Perception: depthwise convolution over every channel of the cell state.
perceive = DepthwiseConvPerceive(channel_size, rngs)

# The update network sees the perceived features (num_kernels per channel)
# concatenated with the latent code of size latent_size.
update_in_features = num_kernels * channel_size + latent_size
update = NCAUpdate(
    channel_size,
    update_in_features,
    (hidden_size,),
    rngs,
    cell_dropout_rate=cell_dropout_rate,
)

# Image encoder producing a latent vector of size latent_size.
encoder = Encoder(spatial_dims, features, latent_size, rngs)


class UnsupervisedCA(CA):
    """Cellular automaton conditioned on a latent code inferred from a target image.

    Extends the base ``CA`` (perceive + update) with an ``encoder``. The encoder
    maps a target image to a latent vector, which callers pass to the CA as its
    conditioning input.
    """

    encoder: Encoder

    def __init__(self, perceive, update, encoder):
        super().__init__(perceive, update)
        self.encoder = encoder

    def encode(self, target):
        """Sample a latent code for ``target`` using the encoder's reparameterization."""
        mu, log_var = self.encoder(target)
        return self.encoder.reparameterize(mu, log_var)

In [11]:
# Fixed perception filters: identity plus gradient kernels, replicated once per channel.
base_kernel = jnp.concatenate([identity_kernel(ndim=2), grad_kernel(ndim=2)], axis=-1)
kernel = jnp.tile(base_kernel, channel_size)  # repeat along the last axis
kernel = kernel[..., None, :]  # insert a size-1 axis before the output axis
perceive.depthwise_conv.kernel = nnx.Param(kernel)

In [12]:
ca = UnsupervisedCA(perceive=perceive, update=update, encoder=encoder)

In [8]:
# Count every trainable-type parameter (including the fixed perception kernel).
params = nnx.state(ca, nnx.Param)
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))
print("Number of params:", num_params)

Number of params: 2530832


## Train

### Pool

In [9]:
def get_pool(key):
    """Create the training pool of `pool_size` seed states with random targets.

    Args:
        key: JAX PRNG key; split into one key per pool entry for `init_state`.

    Returns:
        A `Pool` with two entries:
        - `"state"`, shape `(pool_size, H, W, channel_size)`.
        - `"target_index"`, shape `(pool_size,)`.
    """
    keys = jax.random.split(key, pool_size)
    state, target_index = jax.vmap(init_state)(keys)
    return Pool.create({"state": state, "target_index": target_index})


key, subkey = jax.random.split(key)
pool = get_pool(subkey)

### Optimizer

In [10]:
def get_optimizer(ca):
    """Build the optimizer for ``ca``.

    Gradients are clipped to a global norm of 1.0. They then go to Adam, whose
    learning rate decays linearly from ``learning_rate`` to
    ``0.1 * learning_rate`` over 50,000 steps.

    Only ``nnx.Param`` leaves whose path contains ``update`` or ``encoder`` are
    optimized, so the hand-set perception kernel is not trained.

    Returns:
        ``(optimizer, grad_params)``. ``grad_params`` is the parameter filter that
        must also be used as the ``DiffState`` when taking gradients.
    """
    lr_sched = optax.linear_schedule(
        init_value=learning_rate, end_value=0.1 * learning_rate, transition_steps=50_000
    )

    optimizer = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(learning_rate=lr_sched),
    )

    grad_params = nnx.All(
        nnx.Param, nnx.Any(nnx.PathContains("update"), nnx.PathContains("encoder"))
    )
    optimizer = nnx.Optimizer(ca, optimizer, wrt=grad_params)

    return optimizer, grad_params


# `train_step` and the training loop read these two names as globals.
optimizer, grad_params = get_optimizer(ca)

### Loss

In [11]:
def mse(state, target):
    """Mean squared error between the alive channel of ``state`` and ``target``."""
    residual = state_to_alive(state) - target
    return jnp.square(residual).mean()

In [3]:
@nnx.jit
def loss_fn(ca, state, target, key):
    """Grow a batch of states toward their targets and score one random late step each.

    Args:
        ca: The ``UnsupervisedCA`` being trained.
        state: Batch of initial cell states; axis 0 is the batch axis.
        target: Batch of target images aligned with ``state`` along axis 0.
        key: PRNG key used to choose which step of each rollout is scored.

    Returns:
        ``(loss, state)``:
        - ``loss`` is the MSE at the sampled steps.
        - ``state`` holds the grown state at that step for each sample.
    """
    target_enc = ca.encode(target)
    # Derive the split count from the actual batch instead of the global
    # `batch_size`, so the function works for any batch size.
    num_samples = state.shape[0]

    # Give each vmapped sample its own slice of the CA's RNG streams (axis 0),
    # and broadcast everything else (the parameters) across the batch.
    state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
    trajectory = nnx.split_rngs(splits=num_samples)(
        nnx.vmap(
            lambda ca, state, target_enc: ca(state, target_enc, num_steps=num_steps, all_steps=True),
            in_axes=(state_axes, 0, 0),
        )
    )(ca, state, target_enc)

    # With all_steps=True, axis 1 presumably indexes rollout steps — confirm.
    # Score a random step from the second half of each rollout.
    step_index = jax.random.randint(key, (num_samples,), num_steps // 2, num_steps)
    state = trajectory[jnp.arange(num_samples), step_index]

    loss = mse(state, target)
    return loss, state

NameError: name 'nnx' is not defined

### Train step

In [7]:
@nnx.jit
def train_step(ca, optimizer, pool, key):
    """Run one pool-based training step.

    Steps:
    1. Sample a batch from ``pool``.
    2. Replace the sample with the highest loss by a fresh seed state with a
       newly drawn target.
    3. Take one optimizer step on ``loss_fn``.
    4. Write the grown states back into the pool.

    Returns:
        ``(loss, pool)``: the batch loss and the updated pool.
    """
    sample_key, init_state_key, loss_key = jax.random.split(key, 3)

    # Sample from pool
    pool_index, batch = pool.sample(sample_key, batch_size=batch_size)
    current_state = batch["state"]
    current_target_index = batch["target_index"]
    current_target = image_train[current_target_index]

    # Sort by descending loss so the worst sample ends up at index 0.
    sort_index = jnp.argsort(
        jax.vmap(mse)(current_state, current_target), descending=True
    )
    pool_index = pool_index[sort_index]
    current_state = current_state[sort_index]
    current_target_index = current_target_index[sort_index]

    # Replace the worst sample with a fresh seed and a newly drawn target.
    new_state, new_target_index = init_state(init_state_key)
    current_state = current_state.at[0].set(new_state)
    current_target_index = current_target_index.at[0].set(new_target_index)
    current_target = image_train[current_target_index]

    # Unsupervised objective: no labels are used. `loss_fn` takes
    # (ca, state, target, key). The previous call passed an extra, undefined
    # `label_train` lookup.
    (loss, current_state), grad = nnx.value_and_grad(
        loss_fn, has_aux=True, argnums=nnx.DiffState(0, grad_params)
    )(ca, current_state, current_target, loss_key)
    optimizer.update(grad)

    pool = pool.update(pool_index, {"state": current_state, "target_index": current_target_index})
    return loss, pool

NameError: name 'nnx' is not defined

### Main loop

In [43]:
num_train_steps = 8_192
print_interval = 128

pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")
losses = []

for i in pbar:
    key, subkey = jax.random.split(key)
    loss, pool = train_step(ca, optimizer, pool, subkey)
    losses.append(loss)

    # Refresh the progress-bar average every `print_interval` steps and on the last step.
    is_report_step = i % print_interval == 0 or i == num_train_steps - 1
    if not is_report_step:
        continue
    recent_losses = losses[-print_interval:]
    avg_loss = sum(recent_losses) / len(recent_losses)
    pbar.set_postfix({"Average Loss": f"{avg_loss:.6f}"})

## Visualize

In [47]:
num_samples = 8  # number of digits grown side by side

key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_samples)
state, image_index = jax.vmap(init_state)(keys)

target = image_train[image_index]
target_enc = ca.encode(target)

# One RNG split per sample. The count must match the number of states, not the
# training `batch_size`.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
state = nnx.split_rngs(splits=num_samples)(
    nnx.vmap(
        lambda ca, state, target_enc: ca(state, target_enc, num_steps=2 * num_steps, all_steps=True),
        in_axes=(state_axes, 0, 0),
    )
)(ca, state, target_enc)

mediapy.show_images(target, width=128, height=128)
mediapy.show_videos(
    jnp.squeeze(state_to_alive(state)), width=128, height=128, codec="gif"
)

### Interpolation

In [37]:
# Pick two *distinct* training digits. Their latent codes are the interpolation
# endpoints. `jax.random.choice` samples with replacement by default, which could
# select the same digit twice.
key, subkey = jax.random.split(key)
image_index = jax.random.choice(subkey, image_train.shape[0], shape=(2,), replace=False)
image = image_train[image_index]

target_enc = ca.encode(image)

In [None]:
num_frames = 8  # number of interpolation points between the two latent codes

# Linear blend of the two latent codes: alpha = 0 gives target_enc[1] and
# alpha = 1 gives target_enc[0].
alphas = jnp.linspace(0.0, 1.0, num_frames)[:, None]
target_encs = alphas * target_enc[0] + (1 - alphas) * target_enc[1]

key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, num_frames)
state, _ = jax.vmap(init_state)(keys)

# One RNG split per frame. The count must match the number of states, not the
# training `batch_size`.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
state = nnx.split_rngs(splits=num_frames)(
    nnx.vmap(
        lambda ca, state, target_enc: ca(state, target_enc, num_steps=num_steps),
        in_axes=(state_axes, 0, 0),
    )
)(ca, state, target_encs)

mediapy.show_images(state_to_alive(state), width=128, height=128)