# 1D-ARC Neural Cellular Automata

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, to install JAX with CUDA 12 support:

In [None]:
# Install JAX with the CUDA 12 extra; choose a different extra (or plain `jax`) for other backends.
%pip install -U "jax[cuda12]"

Then install CAX, with its `examples` extras, directly from its GitHub repository:

In [None]:
# Install CAX plus the optional dependencies used by the examples, straight from Git.
%pip install "cax[examples] @ git+https://github.com/879f4cf7/cax.git"

## Import

In [1]:
import json
import os

import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.core.ca import CA
from cax.core.perceive.conv_perceive import ConvPerceive
from cax.core.perceive.kernels import grad_kernel, identity_kernel
from cax.core.update.residual_update import ResidualUpdate
from flax import nnx
from tqdm.auto import tqdm

## Configuration

In [None]:
# Global seed: used for both the JAX PRNG key and the NNX parameter RNG streams.
seed = 0

# Cellular automaton architecture.
num_spatial_dims = 1  # 1D lattice of cells
channel_size = 64  # channels per cell
num_kernels = 2  # perception kernels per channel (identity + gradient, set further down)
hidden_layer_sizes = (256,)  # hidden layer sizes of the update network
cell_dropout_rate = 0.0  # passed to ResidualUpdate; 0.0 means no cell dropout

num_embeddings = 10  # 10 colors in total
features = 3  # width of the learned per-color embedding (learned, not literal RGB)

batch_size = 16
num_steps = 64  # CA steps per rollout
learning_rate = 1e-3

ds_size = 96  # every sequence is center-padded to this length

key = jax.random.key(seed)
rngs = nnx.Rngs(seed)

## Dataset

In [None]:
# Fetch the 1D-ARC dataset into ./1D-ARC (read via ds_path below).
!git clone https://github.com/khalil-research/1D-ARC.git

In [4]:
# Dataset root: one sub-directory per task, each containing JSON task files.
ds_path = "./1D-ARC/dataset"


def process(input, output, size=None):
	"""Center-pad a 1D-ARC input/output pair to a common length.

	Args:
		input: Input grid as nested lists; 1D-ARC grids are expected to be a
			single row (e.g. ``[[0, 1, 1, 0]]``).
		output: Output grid with the same shape as ``input``.
		size: Target length after padding. Defaults to the global ``ds_size``.

	Returns:
		An int32 array of shape ``(2, size)``: row 0 is the padded input and
		row 1 the padded output. Padding uses color 0 and is split evenly, with
		the extra cell (for an odd amount of padding) on the right.

	Raises:
		ValueError: If a grid is not a single row, the two grids differ in
			shape, or the sequence is longer than ``size``.
	"""
	if size is None:
		size = ds_size

	# atleast_1d keeps a 1x1 grid from collapsing to a 0-d scalar after squeeze.
	input = jnp.atleast_1d(jnp.squeeze(jnp.asarray(input, dtype=jnp.int32)))
	output = jnp.atleast_1d(jnp.squeeze(jnp.asarray(output, dtype=jnp.int32)))

	if input.ndim != 1:
		raise ValueError(f"expected a single-row grid, got shape {input.shape}")
	if input.shape != output.shape:
		raise ValueError(f"input shape {input.shape} does not match output shape {output.shape}")

	pad_size = size - input.size
	if pad_size < 0:
		raise ValueError(f"sequence of length {input.size} does not fit in size {size}")
	pad_left = pad_size // 2
	pad_right = pad_size - pad_left

	input_padded = jnp.pad(input, (pad_left, pad_right))
	output_padded = jnp.pad(output, (pad_left, pad_right))

	return jnp.stack([input_padded, output_padded])


# Load every task file. Each example holds the first three train pairs and the
# first test pair, each center-padded to ds_size -> shape (4, 2, ds_size).
# Directory listings are sorted so that sample order (and therefore the seeded
# train/test split below) is reproducible across machines.
ds = []
tasks = []
for task_name in sorted(os.listdir(ds_path)):
	task_path = os.path.join(ds_path, task_name)
	if not os.path.isdir(task_path):
		continue  # skip stray files next to the task directories
	for task_file in sorted(os.listdir(task_path)):
		if not task_file.endswith(".json"):
			continue
		with open(os.path.join(task_path, task_file), encoding="utf-8") as f:
			data = json.load(f)
		# Fixed layout: three demonstration pairs followed by the query pair.
		pairs = (data["train"][0], data["train"][1], data["train"][2], data["test"][0])
		tasks.append(task_name)
		ds.append(jnp.stack([process(pair["input"], pair["output"]) for pair in pairs]))
ds = jnp.stack(ds)

# Map task names to integer ids. Sorting makes the mapping independent of
# Python's per-process string-hash randomization (set iteration order).
unique_tasks = sorted(set(tasks))
task_to_index = {task: index for index, task in enumerate(unique_tasks)}
tasks = jnp.array([task_to_index[task] for task in tasks], dtype=jnp.int32)

In [5]:
# Shuffle samples and their task ids with ONE explicit index permutation so they
# stay aligned by construction, then hold out the last 10% as the test split.
key, subkey = jax.random.split(key)
perm = jax.random.permutation(subkey, ds.shape[0])
ds = ds[perm]
tasks = tasks[perm]

split = int(0.9 * ds.shape[0])

train_ds = ds[:split]
train_tasks = tasks[:split]

test_ds = ds[split:]
test_tasks = tasks[split:]

## Init state

In [6]:
def init_state_with_sample(ca, sample, key):
	"""Build the initial CA state for one task sample.

	Args:
		ca: Model exposing ``embed_input``, which maps color ids to embedding vectors.
		sample: Int array of shape ``(4, 2, ds_size)``: three demonstration
			(input, output) pairs followed by the query (input, output) pair.
		key: PRNG key used to choose the order in which the three demonstration
			pairs are written into the context channels.

	Returns:
		``(state, target)``. ``state`` has shape ``(ds_size, channel_size)``.
		``target`` is the query output row ``sample[-1, -1]``.

	Channel layout of ``state`` (``d`` = embedding width):
		``[0, d)`` holds the embedded query input.
		``[d, 7d)`` holds the context, i.e. the three demonstration pairs as
		(input, output) embeddings, concatenated in the chosen order.
		All remaining channels are zero.
	"""
	demo_1, demo_2, demo_3, (input_embed, _) = ca.embed_input(sample)

	# One (ds_size, 2 * d) block per demonstration: input embedding, then output embedding.
	demos = [jnp.concatenate([pair[0], pair[1]], axis=-1) for pair in (demo_1, demo_2, demo_3)]

	# All six orderings of the three demonstrations; one is drawn uniformly at
	# random per call (order augmentation). The tuple order fixes which ordering
	# a given key selects.
	orders = ((0, 1, 2), (0, 2, 1), (1, 0, 2), (2, 0, 1), (1, 2, 0), (2, 1, 0))
	contexts = jnp.stack([jnp.concatenate([demos[i] for i in order], axis=-1) for order in orders])
	context = jax.random.choice(key, contexts)

	# Derive channel offsets from the embedding width instead of hard-coding 3 / 18,
	# so changing `features` does not silently misalign the layout.
	width = input_embed.shape[-1]
	state = jnp.zeros((ds_size, channel_size))
	state = state.at[..., :width].set(input_embed)
	state = state.at[..., width : width + context.shape[-1]].set(context)
	return state, sample[-1, -1]


def init_state(ca, key):
	"""Sample one augmented training example and build its initial CA state.

	Augmentations:
		The whole sample is mirrored along the cell axis with probability 0.5.
		The non-background colors ``1 .. num_embeddings - 1`` are randomly
		permuted; color 0 stays fixed.

	Returns:
		``(state, target)`` as produced by ``init_state_with_sample``.
	"""
	key_sample, key_flip, key_perm, key_init = jax.random.split(key, 4)

	# Sample one example from the training split.
	sample = jax.random.choice(key_sample, train_ds)

	# Mirror every row along the spatial axis half of the time.
	flip = jax.random.bernoulli(key_flip, p=0.5)
	sample = jnp.where(flip, jnp.flip(sample, axis=-1), sample)

	# Randomly relabel colors 1..N-1 while keeping background color 0 fixed.
	shuffled_colors = jax.random.permutation(key_perm, jnp.arange(1, num_embeddings, dtype=jnp.int32))
	color_perm = jnp.concatenate([jnp.zeros((1,), dtype=jnp.int32), shuffled_colors])
	sample = color_perm[sample]

	return init_state_with_sample(ca, sample, key_init)


def init_state_test(ca, key):
	"""Sample one held-out example (no augmentation) and build its initial CA state.

	Returns:
		``(state, target)`` as produced by ``init_state_with_sample``.
	"""
	key_sample, key_init = jax.random.split(key)

	# Sample one example from the test split.
	sample = jax.random.choice(key_sample, test_ds)

	return init_state_with_sample(ca, sample, key_init)

## Model

In [7]:
# NOTE: perceive, update and embed_input all draw initial parameters from the shared
# `rngs` stream in this order; presumably reordering them changes the initialization.

# Perception: conv with kernel size 3, grouped per input channel
# (feature_group_count=channel_size), yielding num_kernels outputs per channel.
# Its kernel is overwritten with fixed filters in a later cell.
perceive = ConvPerceive(
	channel_size=channel_size,
	perception_size=num_kernels * channel_size,
	rngs=rngs,
	kernel_size=(3,),
	feature_group_count=channel_size,
)
# Update: CAX ResidualUpdate mapping the perception vector back to channel_size
# channels through hidden layers of sizes `hidden_layer_sizes`.
update = ResidualUpdate(
	num_spatial_dims=num_spatial_dims,
	channel_size=channel_size,
	perception_size=num_kernels * channel_size,
	hidden_layer_sizes=hidden_layer_sizes,
	rngs=rngs,
	cell_dropout_rate=cell_dropout_rate,
)
# Learned embedding table: color id in [0, num_embeddings) -> `features`-dim vector.
embed_input = nnx.Embed(num_embeddings=num_embeddings, features=features, rngs=rngs)


class ARCNCA(CA):
	"""CAX cellular automaton extended with a learned color-embedding table.

	`embed_input` is used by the state-initialization helpers to encode the
	query input and the demonstration pairs into the initial state.
	"""

	embed_input: nnx.Embed  # color id -> embedding vector

	def __init__(self, perceive, update, embed_input):
		super().__init__(perceive, update)
		self.embed_input = embed_input

In [8]:
# Fixed perception filters: the CAX identity and gradient kernels side by side,
# repeated once per channel along the last axis so every channel group of the
# grouped conv gets both filters. A singleton input-feature axis is inserted
# just before the output axis.
base_kernels = [identity_kernel(ndim=num_spatial_dims), grad_kernel(ndim=num_spatial_dims)]
kernel = jnp.tile(jnp.concatenate(base_kernels, axis=-1), channel_size)[..., None, :]
perceive.conv.kernel = nnx.Param(kernel)

In [9]:
# Assemble the model; `perceive` already carries the fixed kernel assigned above.
ca = ARCNCA(perceive, update, embed_input)

In [10]:
# Count trainable parameters. Bound to `param_state`, not `params`: the optimizer
# cell reuses the global `params` as an nnx filter for training, and re-running
# this cell afterwards would otherwise replace that filter with a State object.
param_state = nnx.state(ca, nnx.Param)
print("Number of params:", sum(leaf.size for leaf in jax.tree.leaves(param_state)))

Number of params: 49886


## Train

In [11]:
num_train_steps = 100000
# Learning rate decays linearly from learning_rate to 0.1 * learning_rate over the
# first 10% of training and then holds at 0.1 * learning_rate
# (optax.linear_schedule keeps end_value after transition_steps).
lr_sched = optax.linear_schedule(
	init_value=learning_rate, end_value=0.1 * learning_rate, transition_steps=num_train_steps // 10
)

# Clip the global gradient norm to 1.0, then apply Adam with the schedule above.
optimizer = optax.chain(
	optax.clip_by_global_norm(1.0),
	optax.adam(learning_rate=lr_sched),
)

# Filter selecting the variables to optimize; `train_step` also uses this global
# as its differentiation filter. Uncomment the second filter to freeze the fixed
# perception kernel as well.
params = nnx.All(
	nnx.Param,
	# nnx.Not(nnx.PathContains("perceive"))
)
optimizer = nnx.Optimizer(ca, optimizer, wrt=params)

### Loss

In [12]:
def ce(state, output):
	"""Mean cross-entropy between the state's color logits and target colors.

	The last 10 channels of ``state`` are read as per-color logits; ``output``
	holds the integer target color of each cell. The per-cell losses are
	averaged over all remaining axes (cells and, if present, batch).
	"""
	color_logits = state[..., -10:]
	per_cell_loss = optax.softmax_cross_entropy_with_integer_labels(color_logits, output)
	return per_cell_loss.mean()

In [13]:
@nnx.jit
def loss_fn(ca, key):
	"""Training loss for one random batch.

	Samples ``batch_size`` augmented training examples, runs the CA for
	``num_steps`` steps on each (vectorized over the batch) and returns the
	mean per-cell cross-entropy of the final state against the query outputs.
	"""
	keys = jax.random.split(key, batch_size)
	state, output = jax.vmap(init_state, in_axes=(None, 0))(ca, keys)

	# RNG state gets a per-example axis (split below), so any stochastic layers
	# (e.g. cell dropout, if enabled) differ per example; all other model state
	# is broadcast (None) and therefore shared across the batch.
	state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
	state = nnx.split_rngs(splits=batch_size)(
		nnx.vmap(
			lambda ca, state: ca(state, num_steps=num_steps),
			in_axes=(state_axes, 0),
		)
	)(ca, state)

	loss = ce(state, output)
	return loss

### Train step

In [14]:
@nnx.jit
def train_step(ca, optimizer, key):
	"""Run one optimization step on a fresh batch and return its loss.

	Gradients are taken only with respect to the variables matched by the global
	`params` filter; `optimizer` then applies the update to `ca` in place.
	"""
	loss, grad = nnx.value_and_grad(loss_fn, argnums=nnx.DiffState(0, params))(ca, key)
	# NOTE(review): newer Flax releases changed this call to
	# `optimizer.update(model, grads)`; confirm against the installed version.
	optimizer.update(grad)
	return loss

### Main loop

In [15]:
def accuracy(ca, eval_ds, key=None):
	"""Fraction of examples whose predicted query output is exactly correct.

	Each example is run for ``num_steps`` CA steps. The final state's last 10
	channels are decoded with argmax and compared cell by cell with the target;
	an example counts as solved only if every cell matches.

	Args:
		ca: The model to evaluate.
		eval_ds: Stacked examples of shape ``(N, 4, 2, ds_size)``.
		key: PRNG key choosing the demonstration-pair order in the context
			(shared by all examples). Defaults to a fixed key, so repeated
			evaluations during training are directly comparable instead of
			depending on the global training key.

	Returns:
		Scalar accuracy in ``[0, 1]``.
	"""
	if key is None:
		key = jax.random.key(0)
	eval_size = eval_ds.shape[0]
	state, output = jax.vmap(init_state_with_sample, in_axes=(None, 0, None))(ca, eval_ds, key)

	# Per-example RNG state, shared parameters (same wiring as in training).
	# Only the final state is needed, so request it directly: `all_steps=True`
	# would materialize every intermediate step for the whole split
	# (num_steps times the memory) just to keep the last one.
	state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
	state = nnx.split_rngs(splits=eval_size)(
		nnx.vmap(
			lambda ca, state: ca(state, num_steps=num_steps),
			in_axes=(state_axes, 0),
		)
	)(ca, state)

	# Convert logits to symbols.
	prediction = jnp.argmax(state[..., -10:], axis=-1)

	# Successful only if all symbols of the prediction match.
	return jnp.mean(jnp.all(prediction == output, axis=-1))

In [None]:
print_interval = 128  # steps between metric refreshes; also the loss-averaging window

pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")
losses = []
for i in pbar:
	# Fresh key every step for batch sampling and augmentation.
	key, subkey = jax.random.split(key)
	loss = train_step(ca, optimizer, subkey)
	losses.append(loss)

	# Every `print_interval` steps (and on the final step), report the mean loss
	# over the most recent window plus exact-match accuracy on both splits.
	# Each accuracy call runs the CA over an entire split, so this is the
	# expensive part of a logging step.
	if i % print_interval == 0 or i == num_train_steps - 1:
		avg_loss = sum(losses[-print_interval:]) / len(losses[-print_interval:])
		test_acc = accuracy(ca, test_ds)
		train_acc = accuracy(ca, train_ds)
		pbar.set_postfix(
			{
				"Average Loss": f"{avg_loss:.3e}",
				"Test Accuracy": f"{test_acc:.2%}",
				"Train Accuracy": f"{train_acc:.2%}",
			}
		)

## Visualize

In [43]:
# Roll out the trained CA on a few random held-out examples, keeping every step.
n_examples = 8

key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, n_examples)
state, output = jax.vmap(init_state_test, in_axes=(None, 0))(ca, keys)

# Same vectorization as in training: per-example RNG state, shared parameters.
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})
state = nnx.split_rngs(splits=n_examples)(
	nnx.vmap(
		# all_steps=True keeps the state after every step, giving each example a
		# (step, cell, channel) trajectory to plot.
		lambda ca, state: ca(state, num_steps=num_steps, all_steps=True),
		in_axes=(state_axes, 0),
	)
)(ca, state)

In [44]:
# Decode each cell's color id as the argmax over its last 10 channels (color logits).
logits = state[..., -10:]
pred = logits.argmax(axis=-1)

# Display palette; dict insertion order defines the color id of each entry.
color_map = {
	"black": [0, 0, 0],
	"red": [255, 0, 0],
	"green": [0, 255, 0],
	"blue": [0, 0, 255],
	"yellow": [255, 255, 0],
	"purple": [128, 0, 128],
	"orange": [255, 165, 0],
	"pink": [255, 192, 203],
	"brown": [165, 42, 42],
	"gray": [128, 128, 128],
}
colors = list(color_map)

# Lookup table: row i holds the RGB triple for color id i.
color_lookup = jnp.array(list(color_map.values()))

# Map every predicted id to RGB and rescale to floats in [0, 1] for display.
state_rgb = color_lookup[pred].astype(jnp.float32) / 255.0

In [46]:
# One space-time image per example: one row per CA step, one column per cell.
mediapy.show_images(state_rgb, width=128, height=128)