# Elementary Cellular Automata

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, on a machine with CUDA 12 you can install the GPU build of JAX with:

In [None]:
%pip install -U "jax[cuda12]"

Then, install CAX (with its `examples` extra) directly from the GitHub repository:

In [None]:
%pip install "cax[examples] @ git+https://github.com/879f4cf7/cax.git"

## Import

In [1]:
import jax.numpy as jnp
import mediapy
from cax.core.ca import CA
from cax.core.perceive.conv_perceive import ConvPerceive
from cax.core.update.elementary_update import ElementaryUpdate
from flax import nnx

## Configuration

In [None]:
# --- Reproducibility -------------------------------------------------------
seed = 42
rngs = nnx.Rngs(seed)  # RNG container handed to modules that need random init

# --- Lattice ---------------------------------------------------------------
channel_size = 1  # one value per cell
spatial_dims = (1024,)  # a single spatial axis of 1024 cells

# --- Rule ------------------------------------------------------------------
# Eight-character bit string; read as a binary number it equals 110.
wolfram_code = "01101110"  # Rule 110

# --- Rollout ---------------------------------------------------------------
num_steps = 512  # number of CA steps to simulate

## Init state

In [3]:
def init_state():
	"""Return the initial state: all zeros except the center cell, which is set to 1.

	The array has shape ``(*spatial_dims, channel_size)``; both values are read
	from the notebook's configuration cell.
	"""
	center = spatial_dims[0] // 2
	blank = jnp.zeros((*spatial_dims, channel_size))
	# .at[...].set(...) returns an updated copy; every channel of the center cell becomes 1.
	return blank.at[center].set(1.0)

## Model

In [4]:
# Update rule configured from the Wolfram code string chosen in the configuration cell.
update = ElementaryUpdate(wolfram_code=wolfram_code)

# Convolutional perception over a width-3 window, producing three perception
# values per channel. With feature_group_count == channel_size this is presumably
# a grouped (per-channel) convolution — confirm against ConvPerceive.
perceive = ConvPerceive(
	channel_size=channel_size,
	perception_size=3 * channel_size,
	kernel_size=(3,),
	feature_group_count=channel_size,
	rngs=rngs,
)

In [6]:
# One-hot taps over the 3-wide window, each stored as a (3, 1) column.
left_kernel = jnp.array([1.0, 0.0, 0.0])[:, None]
identity_kernel = jnp.array([0.0, 1.0, 0.0])[:, None]
right_kernel = jnp.array([0.0, 0.0, 1.0])[:, None]

# Lay the three columns side by side -> (3, 3) ...
kernel = jnp.concatenate((left_kernel, identity_kernel, right_kernel), axis=1)
# ... then insert a singleton middle axis -> (3, 1, 3).
kernel = kernel[:, None, :]
# Overwrite the kernel held by perceive.conv with the hand-built taps, wrapped as an nnx.Param.
perceive.conv.kernel = nnx.Param(kernel)

In [7]:
# Assemble the cellular automaton from the perceive and update modules built above.
ca = CA(perceive, update)

## Visualize

In [8]:
# Run the automaton for num_steps steps, starting from init_state().
# NOTE(review): all_steps=True presumably makes the call return every intermediate
# state rather than only the final one — confirm against CA.__call__.
state = ca(init_state(), num_steps=num_steps, all_steps=True)

In [9]:
# Display the rollout as one image.
# NOTE(review): presumably rows are time steps and columns are cells — confirm the layout returned by ca(...).
mediapy.show_image(state)