# Growing Neural Cellular Automata using Evolutionary Strategies

In this notebook we train the [growing NCA](https://distill.pub/2020/growing-ca/) using evolutionary strategies and the [EvoJAX](https://github.com/google/evojax) library.


An extensive explanation of the problem and of the CA model can be found in [this article](https://distill.pub/2020/growing-ca/); this notebook focuses mainly on explaining the training with EvoJAX.

In [None]:
!pip install evojax

In [None]:
# @title Importing libraries

import os
import io
import PIL
import PIL.Image, PIL.ImageDraw, PIL.ImageFont, PIL.ImageOps
import base64
import matplotlib.pylab as pl
import numpy as np
from IPython.display import display
from functools import partial
from IPython.display import Image, HTML, clear_output
import matplotlib.animation as animation
import time

import einops
import requests



#jax
import jax
import jax.numpy as jp
from jax import grad, jit, vmap, pmap
from jax.example_libraries import optimizers
import jax.random as jr
from jax.flatten_util import ravel_pytree
import flax.linen as fnn
from jax.sharding import PartitionSpec as P



#evojax
import evojax
from evojax.algo import PGPE



In [None]:
#@title Imports and Notebook Utilities
# Name of the ffmpeg executable to use.
# NOTE(review): presumably read by the ffmpeg-based writer behind VideoWriter
# below — confirm which library consumes FFMPEG_BINARY.
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
# Clear this cell's output area.
clear_output()

def np2pil(a):
  """Convert an image array to a PIL image.

  float32/float64 arrays are treated as intensities in [0, 1]: they are
  clipped to that range and scaled to uint8. Any other dtype is passed to
  PIL unchanged.
  """
  if a.dtype not in [np.float32, np.float64]:
    return PIL.Image.fromarray(a)
  as_bytes = np.uint8(np.clip(a, 0, 1) * 255)
  return PIL.Image.fromarray(as_bytes)

def imwrite(f, a, fmt=None):
  """Save an image array to a path or to an open binary file object.

  Args:
    f: output path, or a writable binary file-like object. For a path, the
      format is taken from the file extension ('jpg' is mapped to 'jpeg').
    a: image array; converted to a PIL image with np2pil.
    fmt: PIL format name. Only used when `f` is a file object; when `f` is a
      path it is replaced by the extension-derived format.
  """
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # Context manager so the file is flushed and closed; the handle used to
    # be opened and never closed.
    with open(f, 'wb') as fp:
      np2pil(a).save(fp, fmt, quality=95)
  else:
    np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  """Encode an image array to in-memory bytes.

  Arrays of shape (h, w, 4) are always encoded as PNG so the alpha channel is
  kept; otherwise `fmt` is used.
  """
  arr = np.asarray(a)
  has_alpha = arr.ndim == 3 and arr.shape[-1] == 4
  if has_alpha:
    fmt = 'png'
  buffer = io.BytesIO()
  imwrite(buffer, arr, fmt)
  return buffer.getvalue()

def im2url(a, fmt='jpeg'):
  """Return a base64 `data:` URL for an image array.

  imencode silently switches (h, w, 4) RGBA arrays to PNG. The same rule is
  applied here so the URL's MIME type matches the encoded bytes; previously
  an RGBA image was labelled as JPEG.

  Args:
    a: image array.
    fmt: requested encoding for non-RGBA arrays (e.g. 'jpeg', 'png').

  Returns:
    A string of the form 'data:image/<fmt>;base64,<payload>'.
  """
  a = np.asarray(a)
  if len(a.shape) == 3 and a.shape[-1] == 4:
    fmt = 'png'
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  # MIME types are conventionally lowercase (e.g. image/png).
  return 'data:image/' + fmt.lower() + ';base64,' + base64_byte_string


def imshow(a, fmt='jpeg', display_id=None):
  """Display an image array inline in the notebook (encoded via imencode)."""
  image_bytes = imencode(a, fmt)
  display(Image(data=image_bytes), display_id=display_id)

def tile2d(a, w=None):
  """Arrange a stack of images into a single 2-D grid image.

  Args:
    a: stack of images, indexed as (n, tile_h, tile_w, ...).
    w: number of grid columns; defaults to ceil(sqrt(n)).

  Returns:
    Array of shape (rows * tile_h, w * tile_w, ...). Images fill the grid
    row by row; the last row is padded with zero images.
  """
  imgs = np.asarray(a)
  n_imgs = len(imgs)
  cols = int(np.ceil(np.sqrt(n_imgs))) if w is None else w
  tile_h, tile_w = imgs.shape[1:3]
  # Number of blank tiles needed to complete the last row.
  n_blank = -n_imgs % cols
  imgs = np.pad(imgs, [(0, n_blank)] + [(0, 0)] * (imgs.ndim - 1), 'constant')
  rows = len(imgs) // cols
  grid = imgs.reshape((rows, cols) + imgs.shape[1:])
  # (rows, cols, tile_h, tile_w, ...) -> (rows, tile_h, cols, tile_w, ...)
  grid = grid.swapaxes(1, 2)
  return grid.reshape((tile_h * rows, tile_w * cols) + grid.shape[4:])

def zoom(img, scale=4):
  """Nearest-neighbour upscale: repeat every entry `scale` times along axes 0 and 1."""
  taller = np.repeat(img, scale, axis=0)
  return np.repeat(taller, scale, axis=1)

class VideoWriter:
  """Write frames to a video file one at a time through FFMPEG_VideoWriter.

  The underlying writer is created lazily on the first frame, so the video
  size comes from that frame. Usable as a context manager; the writer is
  closed on exit.

  NOTE(review): FFMPEG_VideoWriter is not imported anywhere in the visible
  notebook (presumably moviepy.video.io.ffmpeg_writer.FFMPEG_VideoWriter), so
  add() raises NameError until it is imported. No cell here uses this class.
  """

  def __init__(self, filename, fps=30.0, **kw):
    # No writer until the first frame reveals the frame size.
    self.writer = None
    # Keyword arguments forwarded to FFMPEG_VideoWriter.
    self.params = {'filename': filename, 'fps': fps, **kw}

  def add(self, img):
    """Append one frame; floats in [0, 1] become uint8, 2-D frames become RGB."""
    frame = np.asarray(img)
    if self.writer is None:
      height, width = frame.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(width, height), **self.params)
    if frame.dtype in [np.float32, np.float64]:
      frame = np.uint8(frame.clip(0, 1) * 255)
    if frame.ndim == 2:
      # Greyscale -> 3 identical channels.
      frame = np.repeat(frame[..., None], 3, -1)
    self.writer.write_frame(frame)

  def close(self):
    """Close the underlying writer if one was ever created."""
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()

In [None]:
#@title Cellular Automata Parameters
CHANNEL_N = 16        # Number of CA state channels
TARGET_PADDING = 4   # Number of pixels used to pad the target image border
TARGET_SIZE = 40  # Max side length (px) of the downloaded target (load_image default)
BATCH_SIZE = 1  # Number of seeds rolled out per population member
CELL_FIRE_RATE = 0.5  # Per-step probability that a cell applies its update (CellsUpdate default)

# Emoji whose Noto PNG is used as the growth target.
TARGET_EMOJI = "🦎"

In [None]:
#@title CA Model and Utilities

def load_image(url, max_size=TARGET_SIZE):
  """Download an image and return it as a premultiplied-alpha RGBA float array.

  Args:
    url: HTTP(S) URL of the image.
    max_size: the image is shrunk (aspect ratio preserved, never enlarged) so
      that neither side exceeds this many pixels.

  Returns:
    float32 array of shape (h, w, 4) with values in [0, 1], RGB multiplied
    by alpha.

  Raises:
    requests.HTTPError: if the server answers with an error status.
  """
  r = requests.get(url, timeout=30)
  # Fail with a clear HTTP error instead of handing an error page to PIL.
  r.raise_for_status()
  img = PIL.Image.open(io.BytesIO(r.content))
  # Force 4 channels: RGB, greyscale or palette images would otherwise break
  # the alpha premultiplication below.
  img = img.convert('RGBA')
  img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
  img = np.float32(img)/255.0
  # premultiply RGB by Alpha
  img[..., :3] *= img[..., 3:]
  return img

def load_emoji(emoji):
  """Download the 128px Noto emoji PNG for `emoji` as a premultiplied RGBA array.

  Noto file names are 'emoji_u' followed by the code points in lowercase hex,
  each padded to at least 4 digits and joined with '_'. The U+FE0F variation
  selector is dropped (e.g. '❤️' -> emoji_u2764.png). The previous
  `hex(ord(emoji))` form raised TypeError for multi-code-point emoji and
  produced unpadded names such as 'a9' for code points below U+1000.
  NOTE(review): FE0F handling inside ZWJ sequences should be checked against
  the noto-emoji repository for any specific emoji used.
  """
  code = '_'.join('%04x' % ord(ch) for ch in emoji if ch != '\ufe0f')
  url = 'https://github.com/googlefonts/noto-emoji/blob/main/png/128/emoji_u%s.png?raw=true'%code
  return load_image(url)


def visualize(frames, namefile, size, negate=0):
  """Render a sequence of image frames as an animated GIF.

  Args:
    frames: iterable of image arrays displayable by imshow (values clipped to
      [0, 1] before display).
    namefile: output path of the GIF.
    size: figure width and height in inches.
    negate: if truthy, each frame is shown as 1 - frame.
  """
  fig, ax = pl.subplots(figsize=(size, size))

  def _hide_ticks():
    ax.set_xticks([])
    ax.set_yticks([])

  _hide_ticks()

  def animate(frame):
    ax.clear()
    _hide_ticks()
    shown = 1 - frame if negate else frame
    return [ax.imshow(np.clip(shown, 0, 1))]

  anim = animation.FuncAnimation(fig, animate, frames=frames, interval=200, blit=True)

  # Write the GIF at 10 frames per second, then release the figure.
  anim.save(namefile, writer=animation.PillowWriter(fps=10))
  pl.close()


def to_rgba(x):
  """Return the first four (RGBA) channels of a state or image (last axis)."""
  return x[..., 0:4]

def to_alpha(x):
  """Return the alpha channel (index 3) as a size-1 last axis, clipped to [0, 1]."""
  alpha = x[..., 3:4]
  return alpha.clip(0.0, 1.0)

def to_rgb(x):
  """Composite a premultiplied-alpha state over a white background.

  With RGB already multiplied by alpha, blending over white is 1 - alpha + rgb.
  """
  premultiplied_rgb = x[..., 0:3]
  return (1.0 - to_alpha(x)) + premultiplied_rgb

def get_living_mask(x):
  """Boolean mask of cells that are alive or next to a live cell.

  A cell counts as living when the largest alpha value (channel 3) in its 3x3
  neighbourhood exceeds 0.1. CellsUpdate uses this to zero out dead cells.

  Args:
    x: state indexed as (batch, h, w, channels).

  Returns:
    Boolean array with a single trailing channel.
  """
  neighbourhood_max = fnn.max_pool(
      x[:, :, :, 3:4], window_shape=(3, 3), strides=(1, 1), padding='SAME')
  return neighbourhood_max > 0.1

def make_seed(size, n=1):
  """Return `n` square seed states of shape (n, size, size, CHANNEL_N).

  All zeros except the centre cell, whose channels from index 3 onward
  (alpha and hidden channels) are set to 1.
  """
  seeds = np.zeros((n, size, size, CHANNEL_N), dtype=np.float32)
  mid = size // 2
  seeds[:, mid, mid, 3:] = 1.0
  return seeds

def depthwise_conv2d(x, kernel, strides, padding):
  """Apply the same small filter bank to every channel of `x` independently.

  Args:
    x: input indexed as (batch, h, w, channels).
    kernel: filters with a single input channel in jax.lax.conv's default
      OIHW layout, i.e. (n_filters, 1, kh, kw) as built by perceive().
    strides, padding: forwarded to jax.lax.conv.

  Returns:
    Array (batch, h', w', n_filters * channels). Output channels are grouped
    filter-major: index f * channels + c holds filter f applied to channel c.
  """
  batch, height, width, channels = x.shape
  # Fold channels into the batch so each one becomes its own 1-channel NCHW
  # image: (b, h, w, c) -> (b*c, 1, h, w), batch-major.
  planes = jp.transpose(x, (0, 3, 1, 2)).reshape(batch * channels, 1, height, width)
  out = jax.lax.conv(planes, kernel, strides, padding)  # (b*c, f, h', w')
  n_filters, out_h, out_w = out.shape[1:]
  out = out.reshape(batch, channels, n_filters, out_h, out_w)
  # (b, c, f, h', w') -> (b, h', w', f, c) -> (b, h', w', f*c)
  out = jp.transpose(out, (0, 3, 4, 2, 1))
  return out.reshape(batch, out_h, out_w, n_filters * channels)

def perceive(x, angle=0.0):
  """Perception step: identity plus rotated Sobel gradients, for every channel.

  Three 3x3 filters are built: identity, and the Sobel x/y pair (scaled by
  1/8) rotated by `angle` radians. depthwise_conv2d applies each one to every
  channel of `x` with 'SAME' padding, so the result has 3x as many channels.
  """
  one_hot = np.float32([0, 1, 0])
  identity = np.outer(one_hot, one_hot)
  sobel_x = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
  sobel_y = sobel_x.T
  cos_a = jp.cos(angle)
  sin_a = jp.sin(angle)
  filters = jp.stack(
      [identity,
       cos_a * sobel_x - sin_a * sobel_y,
       sin_a * sobel_x + cos_a * sobel_y], 0)
  # Insert the single input-channel axis: (3, 3, 3) -> (3, 1, 3, 3).
  kernel = filters[:, None, :, :]
  return depthwise_conv2d(x, kernel, (1, 1), 'SAME')


class CellsUpdate(fnn.Module):
  """One stochastic update step of the neural cellular automaton.

  Each call perceives every cell's 3x3 neighbourhood, computes a residual
  update with a per-cell MLP (two 1x1 convolutions), applies it to a random
  subset of cells, and zeroes cells that are not living both before and
  after the update.
  """
  # Number of state channels per cell (width of the update's output).
  channel_n: int = CHANNEL_N
  # Default per-cell probability of applying the update on a step.
  fire_rate: float = CELL_FIRE_RATE

  def setup(self):
    """Initializes all parameters.

    The MLP is stored as `self.dmodel`; that attribute name is the key its
    weights appear under in the params pytree, so renaming it changes the
    parameter structure (and the flattened vector layout used by the solver).
    """
    self.dmodel = fnn.Sequential([
        fnn.Conv(features=128, kernel_size=(1,1)),
        fnn.relu,
        # Zero-initialised output kernel: with flax's default bias init
        # (presumably zeros — confirm), an untrained model produces a zero update.
        fnn.Conv(features=self.channel_n, kernel_size=(1,1),
                 kernel_init=jax.nn.initializers.zeros)
    ])

  def __call__(self, x, key, fire_rate=None, angle=0.0, step_size=1.0):
    """Apply one CA step.

    Args:
      x: CA state indexed as (batch, h, w, channel_n).
      key: PRNG key used to sample which cells fire.
      fire_rate: per-cell update probability; defaults to self.fire_rate.
      angle: rotation (radians) of the Sobel perception filters.
      step_size: scale applied to the update before it is added.

    Returns:
      The updated state, same shape as x, with non-living cells zeroed.
    """
    # Cells living (or adjacent to a living cell) before the update.
    pre_life_mask = get_living_mask(x)

    y = perceive(x, angle)
    dx = self.dmodel(y)*step_size
    if fire_rate is None:
      fire_rate = self.fire_rate
    # Each cell fires independently with probability fire_rate (one mask
    # value per cell, broadcast over channels).
    update_mask_f32 = (jr.uniform(key, x[:, :, :, :1].shape) <= fire_rate).astype(jp.float32)
    x += dx * update_mask_f32

    # A cell survives only if it is living both before and after the update.
    post_life_mask = get_living_mask(x)
    life_mask_f32 = (pre_life_mask & post_life_mask).astype(jp.float32)
    return x * life_mask_f32

In [None]:

def visualize_batch(x0, x, step_i):
  """Show a batch of initial states (top row) above the final states (bottom row).

  Args:
    x0: batch of initial CA states, indexed as (batch, h, w, channels).
    x: batch of CA states after the rollout, same layout as x0.
    step_i: training iteration; currently unused, kept for API compatibility.
  """
  # zoom() repeats along axes 0 and 1, so it must be applied per image. On
  # the whole (batch, h, w, c) array it duplicated the batch axis and
  # stretched only the height (the old `[::2]` merely undid the batch
  # duplication), distorting every image's aspect ratio.
  vis0 = np.hstack([zoom(img, 2) for img in np.asarray(to_rgb(x0))])
  vis1 = np.hstack([zoom(img, 2) for img in np.asarray(to_rgb(x))])
  vis = np.vstack([vis0, vis1])
  imshow(vis)

def plot_loss(loss_log):
  """Scatter-plot the log10 loss history.

  Args:
    loss_log: sequence of fitness values, i.e. negated losses (the training
      loop passes its best-fitness history); they are negated back before
      taking log10.
  """
  losses = -np.asarray(loss_log)
  pl.figure(figsize=(10, 4))
  pl.title('Loss history (log10)')
  pl.plot(np.log10(losses), '.', alpha=0.1)
  pl.show()


In [None]:
#@title Load emoji

# Download the target image (premultiplied RGBA float array) and preview it
# at 2x, composited over white.
target_img = load_emoji(TARGET_EMOJI)
imshow(zoom(to_rgb(target_img), 2), fmt='png')

## Initialization of the training

Evolutionary Strategies (ES) typically don't require the learning-rate schedulers used in gradient-based methods.
However, ES does require the parameters to be formatted according to the specific library's (here, EvoJAX's) requirements.

Key Concepts for ES Optimizers:
1. **Population**: The population is a set of parameters tested during each ES iteration. For a model with 10 parameters and a population size of 100, the population has a dimension of (100, 10).
Each element in the population represents a set of parameters to be evaluated.
2. **Fitness Function**: Unlike gradient descent, ES aims to maximize a fitness function. Each population element is evaluated using this function.
3. **Solver Algorithm**: The fitness values of the current population are used to create the next population; how exactly this happens depends on the specific algorithm. This Colab uses [PGPE](https://people.idsia.ch/~juergen/icann2008sehnke.pdf) (Policy Gradients with Parameter-based Exploration), which estimates a gradient-like update direction from the population's fitness values and thus mimics gradient descent. Many other options exist.


---


The ES optimizers provided by EvoJAX expose three main API elements:

1. **Initializer**, in which the type of algorithm is specified together with all of its hyperparameters (such as the population size and the dimension of each population element).
2. **solver.ask()**: returns a population of parameter sets as an array of shape (population size, parameter size). This is why the model parameters must be flattened before applying evolutionary strategies.
3. **solver.tell(fitness)**: once the population elements have been evaluated, their fitness values are passed back to the solver, which uses them to update its own parameters. **Important**: the fitness vector should be one-dimensional with shape (population size,), where each element is the fitness of the corresponding population element.


To initialize training, we define the fitness function *fitness_f*. To use evolutionary strategies, the parameters must be flattened (for the reason explained above). The *ravel_pytree* function does this; it also returns an *unravel_pytree* function, which is used within the fitness function to reconstruct the original parameter shapes.

In [None]:
#@title Initialize Training { vertical-output: true}

# Pad the target image by TARGET_PADDING pixels on each side so the grid is
# larger than the pattern.
p = TARGET_PADDING
pad_target = jp.pad(target_img, [(p, p), (p, p), (0, 0)])
h, w = pad_target.shape[:2]
# Seed: all-zero grid with a single cell in the centre whose channels from
# index 3 onward (alpha and hidden channels) are 1.
seed = np.zeros([h, w, CHANNEL_N], np.float32)
seed[h//2, w//2, 3:] = 1.0
# Batch of BATCH_SIZE identical seeds: (BATCH_SIZE, h, w, CHANNEL_N).
x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)

def target_loss_f(x):
  """Per-sample mean squared error between the RGBA channels of `x` and pad_target.

  The mean is taken over the height, width and channel axes, leaving one
  loss value per leading (batch) index.
  """
  diff = to_rgba(x) - pad_target
  return jp.square(diff).mean((-2, -3, -1))


#Initialization of the cell model
ca_update = CellsUpdate()
# Fixed PRNG seed; the order of the splits below determines every later key.
key = jr.PRNGKey(1)
k1, k2, key = jr.split(key, 3)


# k1 is the init RNG; k2 is passed as the module's `key` argument (the
# fire-mask RNG) for the init forward pass on a batch of one seed.
params = ca_update.init(k1, seed[None,:], k2)

#flatten all the parameters
# params_flat is the 1-D vector the ES solver works with; unravel_pytree maps
# such a vector back to the params pytree.
params_flat, unravel_pytree = ravel_pytree(params)


@jit
def fitness_f(params_flat, x, key):
  """Roll out the CA for one flattened parameter vector and score it.

  Runs `iter_n` CA steps from `x` and adds the mean target loss every
  `compute_loss_every` steps (after steps 30, 60 and 90).

  Args:
    params_flat: 1-D parameter vector, as produced by ravel_pytree.
    x: initial CA state for ca_update.apply (the callers pass one
      (BATCH_SIZE, h, w, CHANNEL_N) seed batch per population member).
    key: PRNG key driving the stochastic cell updates.

  Returns:
    (fitness, x_final): fitness is the negated accumulated loss (ES
    maximizes fitness); x_final is the CA state after `iter_n` steps.
  """
  iter_n = 90
  compute_loss_every = 30
  params = unravel_pytree(params_flat)

  loss = 0.0

  def scan_f(carry, i):
    loss, x, key = carry
    # Split off a fresh subkey for this step and carry the other half
    # forward. Previously the carried key itself was used for sampling, so
    # the same key fed both this step's fire mask and the next split.
    key, step_key = jr.split(key)
    x = ca_update.apply(params, x, step_key)

    loss = jax.lax.cond(
        (i + 1) % compute_loss_every == 0,
        lambda loss, x: loss + target_loss_f(x).mean(),
        lambda loss, x: loss,
        loss, x)

    return (loss, x, key), 0

  (loss, x, key), _ = jax.lax.scan(
      scan_f, (loss, x, key), jp.arange(0, iter_n, dtype=jp.int32))

  return -loss, x

# Advance the global PRNG key.
k1, key = jr.split(key, 2)
# Preview the padded target at 2x.
imshow(zoom(to_rgb(pad_target), 2), fmt='png')

In [None]:
# Multi-device evaluation (only used when training on TPU).

def p_fitness_fn(ca_params, x0, key):
  """Evaluate a device-split population: pmap over devices, vmap within each.

  All arguments carry a leading device axis, followed by the per-device
  population axis consumed by v_fitness_fn. Returns (fitness, x) with the
  same two leading axes.
  """
  return pmap(v_fitness_fn)(ca_params, x0, key)


@jit
def v_fitness_fn(ca_params, x0, key):
  """Evaluate fitness_f for every population member in one vectorised call.

  Each argument has a leading population axis; returns (fitness, x) with
  the same leading axis.
  """
  return vmap(fitness_f)(ca_params, x0, key)


If a TPU is used for training, the population is split into equal chunks and each chunk is assigned to its own XLA device (via `pmap`) to increase performance. This approach showed strong linear scaling (for the same problem size, doubling the computational resources halves the computation time) and constant weak scaling (doubling both the problem size and the computational resources — XLA devices — keeps the computation time approximately constant).

In [None]:
#@title Hyperparameters of the optimizer

# Base scale: the PGPE learning rates and initial stdev below are fractions of it.
# For popsize=128
lr=0.08

# For popsize=256
#lr=0.1

# param_size is the total number of parameters in the flattened parameter vector.
param_size=len(params_flat)
# Number of candidates per generation (rows returned by solver.ask()).
popsize= 128

# Optimizer the solver uses for its centre update (passed to PGPE).
optimizer='adam'
center_learning_rate=0.05*lr
stdev_learning_rate=0.01*lr
init_stdev=0.1*lr

# The training loop runs training_iterations + 1 generations (i = 0..training_iterations).
training_iterations = 50000
# Refresh the batch visualisation and loss plot every print_every generations.
print_every=100


if 'tpu' in [device.platform for device in jax.devices()]:
  # Number of XLA devices the population is split across. pmap (used by
  # p_fitness_fn) maps over local devices, so use the actual count rather
  # than a hard-coded 8, which fails on hosts exposing fewer TPU devices.
  # NOTE(review): assumes a single-host TPU runtime (local == global devices).
  TPU = jax.local_device_count()
  assert popsize % TPU == 0, (
      'popsize (%d) must be divisible by the number of TPU devices (%d)'
      % (popsize, TPU))

  # Initialize the TPU mesh for memory sharding.
  # This distributes the population across XLA devices to improve performance.
  # axis_names must be a sequence of names: ('x',), not the bare string 'x'.
  mesh = jax.make_mesh((TPU,), ('x',))
  sharding = jax.sharding.NamedSharding(mesh, P('x'))


  @jax.jit
  def repeat_for_tpu(tensor):
      """Replicate the CA state once per population member, grouped by device.

      Maps a tensor of shape S (the training loop passes
      (batch, x, y, channel)) to (TPU, popsize // TPU, *S): the leading axis
      is consumed by pmap, the second by vmap.
      """
      expanded = jax.numpy.expand_dims(tensor, axis=0)  # (1, *S)
      tiled = jax.numpy.repeat(expanded, popsize, axis=0)  # (popsize, *S)
      reshaped = tiled.reshape((TPU, popsize // TPU) + tensor.shape)
      return reshaped


@jax.jit
def repeat_for_gpu(tensor):
    """Replicate the CA state once per population member (GPU/CPU path).

    Maps (batch, x, y, channel) to (popsize, batch, x, y, channel): every
    population member gets its own identical copy of the initial state for
    the vmap in v_fitness_fn.
    """
    return jax.numpy.repeat(tensor[None], popsize, axis=0)


#Initializer of the solver
# PGPE solver over the flattened parameter vector: each ask() yields popsize
# candidates of length param_size; tell() consumes their fitness values.
solver = PGPE(
    pop_size=popsize,
    param_size=param_size,
    optimizer=optimizer,
    center_learning_rate=center_learning_rate,
    stdev_learning_rate=stdev_learning_rate,
    # Solver's own RNG seed (separate from the JAX key used for CA rollouts).
    seed=1,
    init_stdev=init_stdev,
)


In [None]:
#@title Training Loop

#init variables
time_ = 0.0
time_step = 0.0
i = 0

start_time = time.time()
# Best (highest) fitness of each generation; plotted by plot_loss.
best_fitness = []

# The device platform cannot change while the loop runs, so check it once.
on_tpu = 'tpu' in [device.platform for device in jax.devices()]

x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)

if on_tpu:
  x0_tiled = repeat_for_tpu(x0)
else:
  x0_tiled = repeat_for_gpu(x0)


while i<training_iterations+1:

  k1, key = jr.split(key)

  # Get a new population from the solver: (popsize, param_size).
  v_params = solver.ask()

  # One PRNG key per population member.
  key, skey = jr.split(key)
  skey_tiled = jr.split(skey, popsize)

  if on_tpu:
    # Split the population into TPU equal chunks: leading axis for pmap
    # (one chunk per device), second axis for vmap within a device.
    pop_per_tpu = popsize // TPU
    v_params = jax.numpy.reshape(v_params, (TPU, pop_per_tpu, -1))
    v_params = jax.device_put(v_params, sharding)
    skey_tiled = jax.numpy.reshape(skey_tiled, (TPU, pop_per_tpu, -1))

    #compute the fitness of the parameters
    fitness, x = p_fitness_fn(v_params, x0_tiled, skey_tiled)
    # Back to a flat (popsize,) vector, in the same order as solver.ask().
    fitness = einops.rearrange(fitness, 't p -> (t p)')

  #GPU/CPU
  else:
    fitness, x = v_fitness_fn(v_params, x0_tiled, skey_tiled)

  best_fitness.append(float(fitness.max().item()))

  #update the solver
  solver.tell(fitness)

  if i % print_every == 0:
    if on_tpu:
      x = einops.rearrange(x, 't p b w h c -> (t p) b w h c')
    best_fitness_arg = jp.argmax(fitness)
    clear_output()
    visualize_batch(x0, x[best_fitness_arg], i)
    plot_loss(best_fitness)

  time_ = time.time() - start_time
  time_step = time_/(i+1)

  print('\r popsize: %d, step: %d, fitness: %.4f , loss: %.4f , total_time: %.2f , time_step: %.3f'%(popsize, i, best_fitness[-1], -best_fitness[-1],  time_, time_step), end='')

  i+=1



In [None]:
# Roll out the trained CA from a fresh seed and save its growth as a GIF.
# Use the solver's best solution rather than solver.ask()[0]: ask() samples a
# new, noise-perturbed population (and records it inside the solver), so its
# first row is a random candidate, not the trained parameters.
params = unravel_pytree(solver.best_params)
x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)
x = x0
frames = []
for _ in range(100):
  k1, key = jr.split(k1, 2)
  x = ca_update.apply(params, x, key)
  frames.append(zoom(to_rgb(x)[0],4))

visualize(frames, 'lizard.gif', 10)
Image('lizard.gif')