# Growing Neural Cellular Automata

This is a reimplementation in PyTorch of the notebook that contains code to reproduce experiments and figures for the ["Growing Neural Cellular Automata"](http://distill.pub/2020/growing-ca) article.

It is essentially the same code, with the TensorFlow implementations swapped for their PyTorch equivalents. It is probably full of bugs because of that and because of my coding skills, LOL.

Copyright 2020 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

In [None]:
#@title Imports and Notebook Utilities
# Cell 1 Pytorch

import os
import io
import PIL.Image, PIL.ImageDraw # Making images
import base64 # Convert Images to base64 string
import zipfile
import json 
import requests
import numpy as np
import matplotlib.pylab as pl
import glob # Regex like library

import torch

from IPython.display import Image, HTML, clear_output 
import tqdm # Progress bar

import os
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
clear_output()

def np2pil(a):
  """Convert a NumPy array to a PIL image.

  float32/float64 input is clipped to [0, 1] and scaled to uint8; any other
  dtype is handed to PIL unchanged.
  """
  is_float = a.dtype in [np.float32, np.float64]
  pixels = np.uint8(np.clip(a, 0, 1)*255) if is_float else a
  return PIL.Image.fromarray(pixels)

def imwrite(f, a, fmt=None):
  """Encode array `a` as an image and write it to `f`.

  Args:
    f: destination; either a filesystem path (str) or a writable binary
      file-like object.
    a: array-like image data, converted with np.asarray.
    fmt: image format name. Ignored when `f` is a path: the format is then
      taken from the file extension ('jpg' is mapped to 'jpeg').
  """
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # Own the handle we open so it is flushed and closed deterministically.
    with open(f, 'wb') as fh:
      np2pil(a).save(fh, fmt, quality=95)
  else:
    # Caller-supplied stream: write into it but leave closing to the caller.
    np2pil(a).save(f, fmt, quality=95)

def imencode(a, fmt='jpeg'):
  """Return the encoded image bytes for array `a`.

  A 3-D array whose last axis has 4 entries is always encoded as PNG,
  whatever `fmt` says; everything else uses `fmt`.
  """
  arr = np.asarray(a)
  has_four_channels = arr.ndim == 3 and arr.shape[-1] == 4
  buf = io.BytesIO()
  imwrite(buf, arr, 'png' if has_four_channels else fmt)
  return buf.getvalue()

def im2url(a, fmt='jpeg'):
  """Return a base64 `data:` URL holding the encoded image `a`.

  Args:
    a: array-like image data.
    fmt: requested image format; overridden to 'png' for 4-channel input.
  """
  a = np.asarray(a)
  # imencode switches 4-channel input to PNG; mirror that here so the MIME
  # type in the URL matches the bytes that were actually produced.
  if len(a.shape) == 3 and a.shape[-1] == 4:
    fmt = 'png'
  encoded = imencode(a, fmt)
  base64_byte_string = base64.b64encode(encoded).decode('ascii')
  return 'data:image/' + fmt.upper() + ';base64,' + base64_byte_string

def imshow(a, fmt='jpeg'):
  """Encode `a` with imencode and show it through IPython's `display`."""
  payload = imencode(a, fmt)
  display(Image(data=payload))

def tile2d(a, w=None):
  """Arrange a stack of equally sized tiles into one 2-D grid image.

  Args:
    a: array-like of shape [n, tile_h, tile_w, ...].
    w: number of tiles per row; defaults to ceil(sqrt(n)).

  Returns:
    Array of shape [tile_h*rows, tile_w*w, ...]; unused grid cells are
    zero-filled.
  """
  tiles = np.asarray(a)
  if w is None:
    w = int(np.ceil(np.sqrt(len(tiles))))
  tile_h, tile_w = tiles.shape[1:3]
  # Append zero tiles so the count becomes a multiple of the row width.
  missing = (w - len(tiles)) % w
  pad_spec = [(0, missing)] + [(0, 0)] * (tiles.ndim - 1)
  tiles = np.pad(tiles, pad_spec, 'constant')
  rows = len(tiles) // w
  grid = tiles.reshape([rows, w] + list(tiles.shape[1:]))
  # [rows, w, th, tw, ...] -> [rows, th, w, tw, ...] so rows/columns merge.
  grid = np.moveaxis(grid, 2, 1)
  return grid.reshape([tile_h * rows, tile_w * w] + list(grid.shape[4:]))

def zoom(img, scale=4):
  """Upscale `img` by repeating every element `scale` times along axes 0 and 1."""
  for axis in (0, 1):
    img = np.repeat(img, scale, axis)
  return img

class VideoWriter:
  """Context manager that streams frames to a video file via FFMPEG_VideoWriter.

  The underlying writer is created lazily on the first `add` call, using that
  frame's height and width as the video size.
  """

  def __init__(self, filename, fps=30.0, **kw):
    # Keyword arguments forwarded to FFMPEG_VideoWriter once the size is known.
    self.params = {'filename': filename, 'fps': fps, **kw}
    self.writer = None

  def add(self, img):
    """Append one frame; float frames are clipped to [0, 1] and scaled to uint8."""
    frame = np.asarray(img)
    if self.writer is None:
      height, width = frame.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(width, height), **self.params)
    if frame.dtype in [np.float32, np.float64]:
      frame = np.uint8(frame.clip(0, 1)*255)
    if frame.ndim == 2:
      # Replicate a single-channel frame across three channels.
      frame = np.repeat(frame[..., None], 3, -1)
    self.writer.write_frame(frame)

  def close(self):
    """Close the underlying writer if one was ever created."""
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()

In [None]:
#@title Cellular Automata Parameters
CHANNEL_N = 16        # Number of CA state channels
CHANNEL_H = 128       # Number of channels in the hidden layer (default `channel_h` of CAModel)
TARGET_PADDING = 16   # Number of pixels used to pad the target image border
TARGET_SIZE = 40      # Default `max_size` passed to the thumbnail step in load_image
BATCH_SIZE = 8        # Number of samples drawn per training step
POOL_SIZE = 1024      # Number of seed states the sample pool is initialised with
CELL_FIRE_RATE = 0.5  # Default `fire_rate` of CAModel

TARGET_EMOJI = "😀" #@param {type:"string"}

# The experiment type selects, by index, the pool/damage settings below.
EXPERIMENT_TYPE = "Regenerating" #@param ["Growing", "Persistent", "Regenerating"]
EXPERIMENT_MAP = {"Growing":0, "Persistent":1, "Regenerating":2}
EXPERIMENT_N = EXPERIMENT_MAP[EXPERIMENT_TYPE]

USE_PATTERN_POOL = [0, 1, 1][EXPERIMENT_N]  # Truthy when training samples from the pool
DAMAGE_N = [0, 0, 3][EXPERIMENT_N]  # Number of patterns to damage in a batch

In [None]:
#@title CA Model and Utilities { form-width: "401px" }

from torch.nn.modules.module import register_module_forward_hook
#from tensorflow.keras.layers import Conv2D
from torch import conv2d

def load_image(url, max_size=TARGET_SIZE):
    """Download an image and return it as a premultiplied-alpha float32 array.

    Args:
        url: address of the image to fetch.
        max_size: the image is shrunk in place (aspect ratio kept) so that
            neither side exceeds this many pixels.

    Returns:
        float32 array of shape [h, w, 4] with values in [0, 1]; the RGB
        channels are multiplied by alpha.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    r = requests.get(url)
    # Fail on HTTP errors here rather than letting PIL choke on an error page.
    r.raise_for_status()
    # Force RGBA so the alpha premultiplication below always has a 4th channel.
    img = PIL.Image.open(io.BytesIO(r.content)).convert('RGBA')
    # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img.thumbnail((max_size, max_size), PIL.Image.LANCZOS)
    img = np.float32(img)/255.0
    # premultiply RGB by Alpha
    img[..., :3] *= img[..., 3:]
    return img

def load_emoji(emoji):
    """Download the 128px Noto Emoji PNG for a single-code-point emoji.

    Args:
        emoji: one-character string; its code point selects the file name.

    Returns:
        Whatever load_image returns for the resolved URL.
    """
    code = hex(ord(emoji))[2:].lower()
    #url='https://emoji.slack-edge.com/T03ACRYJKPB/sgi/fe651d12ef124541.png' SGI 22 emoji
    # The template must be filled with the code point; requesting the raw
    # '%s' URL returns a non-image response that PIL cannot identify.
    url = 'https://github.com/googlefonts/noto-emoji/blob/main/png/128/emoji_u%s.png?raw=true' % code
    return load_image(url)


def to_rgba(x):
    """Return the first four entries along the last axis of `x`."""
    first_four = slice(None, 4)
    return x[..., first_four]

def to_alpha(x):
    """Return channel 3 of `x` (kept as a size-1 last axis), clamped to [0, 1].

    Accepts either a torch.Tensor or a NumPy array and returns the same kind:
    this file calls the colour helpers on both, and torch.clip rejects
    NumPy input.
    """
    alpha = x[..., 3:4]
    if isinstance(alpha, torch.Tensor):
        return torch.clamp(alpha, 0.0, 1.0)
    return np.clip(alpha, 0.0, 1.0)

def to_rgb(x):
    """Composite the first three channels of `x` over a white background.

    Assumes those channels are already premultiplied by alpha.
    """
    premultiplied = x[..., :3]
    alpha = to_alpha(x)
    white_share = 1.0 - alpha
    return white_share + premultiplied

def get_living_mask(x):
    """Mark cells that have an alpha value above 0.1 in their 3x3 neighbourhood.

    Args:
        x: state grid of shape [n, h, w, c] (channels last); channel 3 is alpha.

    Returns:
        Boolean tensor of shape [n, h, w, 1].
    """
    alpha = torch.as_tensor(x)[:, :, :, 3:4]
    # max_pool2d expects channels-first input, so move the channel axis to
    # position 1 for pooling and back afterwards. Padding of 1 with a 3x3
    # window keeps the spatial size (the TF 'SAME' behaviour).
    pooled = torch.nn.functional.max_pool2d(
        alpha.permute(0, 3, 1, 2), kernel_size=3, stride=1, padding=1)
    return pooled.permute(0, 2, 3, 1) > 0.1

def make_seed(size, n=1):
    """Build `n` all-zero float32 grids of shape [size, size, CHANNEL_N] in
    which channels 3 and above of the centre cell are set to 1."""
    seeds = np.zeros([n, size, size, CHANNEL_N], np.float32)
    mid = size // 2
    seeds[:, mid, mid, 3:] = 1.0
    return seeds


class CAModel(torch.nn.Module):
    """Neural cellular automaton update rule.

    State grids use the channels-last layout [n, h, w, channel_n] on the
    outside (matching the other helpers in this file); tensors are permuted
    to channels-first only around the convolution calls.
    """

    def __init__(self, channel_n=CHANNEL_N, fire_rate=CELL_FIRE_RATE, channel_h=CHANNEL_H):
        """
        Args:
            channel_n: number of state channels per cell.
            fire_rate: default probability that a cell applies its update.
            channel_h: number of channels in the hidden layer.
        """
        super().__init__()
        self.channel_n = channel_n
        self.channel_h = channel_h
        self.fire_rate = fire_rate

        # Per-cell update network: two 1x1 convolutions over the perception
        # vector (identity + two gradient filters per state channel).
        self.dmodel = torch.nn.Sequential(
            torch.nn.Conv2d(
                in_channels=3*self.channel_n,
                out_channels=self.channel_h,
                kernel_size=1,
            ),
            torch.nn.ReLU(),
            torch.nn.Conv2d(
                self.channel_h,
                self.channel_n,
                kernel_size=1,
                bias=False,
            ),
        )
        # Zero-initialise the output layer so a fresh model proposes no change.
        torch.nn.init.zeros_(self.dmodel[-1].weight)

    def perceive(self, x, angle=0.0):
        """Apply identity and rotated Sobel filters to every state channel.

        Args:
            x: tensor of shape [n, h, w, channel_n].
            angle: rotation, in radians, applied to the two gradient filters.

        Returns:
            Tensor of shape [n, h, w, 3*channel_n]; the three filter
            responses of each input channel are adjacent on the last axis.
        """
        identify = np.outer([0, 1, 0], [0, 1, 0])
        dx = np.outer([1, 2, 1], [-1, 0, 1]) / 8.0  # Sobel filter
        dy = dx.T
        c, s = np.cos(angle), np.sin(angle)
        filters = np.stack([identify, c*dx-s*dy, s*dx+c*dy]).astype(np.float32)
        kernel = torch.from_numpy(filters).to(device=x.device, dtype=x.dtype)
        # [3, 3, 3] -> [3*channel_n, 1, 3, 3]: with groups=channel_n every
        # input channel is convolved with its own copy of the three filters.
        kernel = kernel.repeat(self.channel_n, 1, 1)[:, None]
        y = torch.nn.functional.conv2d(
            x.permute(0, 3, 1, 2), kernel, stride=1, padding=1,
            groups=self.channel_n)
        return y.permute(0, 2, 3, 1)

    def forward(self, x, fire_rate=None, angle=0.0, step_size=1.0):
        """Run one CA step.

        Args:
            x: state grid [n, h, w, channel_n]; a tensor or anything
                torch.as_tensor accepts (e.g. a NumPy array).
            fire_rate: per-cell update probability; defaults to self.fire_rate.
            angle: rotation passed on to `perceive`.
            step_size: scale applied to the proposed update.

        Returns:
            Tensor with the same shape as `x`; cells that are not alive both
            before and after the update are zeroed.
        """
        x = torch.as_tensor(
            x, dtype=torch.float32, device=self.dmodel[0].weight.device)
        pre_life_mask = get_living_mask(x)

        y = self.perceive(x, angle)
        dx = self.dmodel(y.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)*step_size
        if fire_rate is None:
            fire_rate = self.fire_rate

        # One random draw per cell decides whether that cell updates.
        update_mask = torch.rand(x[:, :, :, :1].shape, device=x.device) <= fire_rate
        x = x + dx * update_mask.to(x.dtype)

        post_life_mask = get_living_mask(x)
        life_mask = pre_life_mask & post_life_mask

        return x * life_mask.to(x.dtype)

    def save_weights(self, path):
        """Write the model's state_dict to `path` with torch.save."""
        torch.save(self.state_dict(), path)

    def load_weights(self, path):
        """Load a state_dict previously written by `save_weights` from `path`."""
        self.load_state_dict(torch.load(path, map_location='cpu'))


# torch.nn.Sequential has no Keras-style .summary(); printing the module
# lists its layers instead.
print(CAModel().dmodel)

TypeError: ignored

# Training

In [None]:
#@title Train Utilities (SamplePool, Model Export, Damage)
#This chunk seems to be fine so far

from google.protobuf.json_format import MessageToDict
#from tensorflow.python.framework import convert_to_constants

class SamplePool:
  """A set of equally long named arrays from which sub-batches can be drawn
  and later written back."""

  def __init__(self, *, _parent=None, _parent_idx=None, **slots):
    """Store every keyword in `slots` as a NumPy array attribute.

    `_parent` / `_parent_idx` are set by `sample` on the batches it creates
    and record where `commit` has to write back to.
    """
    self._parent = _parent
    self._parent_idx = _parent_idx
    self._slot_names = slots.keys()
    self._size = None
    for name, values in slots.items():
      length = len(values)
      if self._size is None:
        self._size = length
      # All slots must have the same number of entries.
      assert self._size == length
      setattr(self, name, np.asarray(values))

  def sample(self, n):
    """Draw `n` distinct entries and return them as a child SamplePool."""
    idx = np.random.choice(self._size, n, False)
    picked = {}
    for name in self._slot_names:
      picked[name] = getattr(self, name)[idx]
    return SamplePool(**picked, _parent=self, _parent_idx=idx)

  def commit(self):
    """Write this batch's slots back into the parent at the sampled indices."""
    parent, where = self._parent, self._parent_idx
    for name in self._slot_names:
      getattr(parent, name)[where] = getattr(self, name)

#@tf.function
def make_circle_masks(n, h, w):
  """Create `n` random filled-circle masks on an h-by-w grid.

  The grid spans [-1, 1] on both axes. Each circle gets a centre drawn
  uniformly from [-0.5, 0.5] per axis and a radius drawn uniformly from
  [0.1, 0.4].

  Returns:
    float32 tensor of shape [n, h, w]: 1.0 inside a circle, 0.0 outside.
  """
  x = torch.linspace(-1.0, 1.0, w)[None, None, :]
  y = torch.linspace(-1.0, 1.0, h)[None, :, None]
  # Tensor.uniform_ fills in place with samples from the given range.
  center = torch.empty(2, n, 1, 1).uniform_(-0.5, 0.5)
  r = torch.empty(n, 1, 1).uniform_(0.1, 0.4)
  x, y = (x-center[0])/r, (y-center[1])/r
  mask = (x*x+y*y < 1.0).to(torch.float32)
  return mask

def export_model(ca, base_fn):
  """Save a checkpoint of `ca` plus a JSON dump of its weights.

  The TensorFlow graph export this function was ported from has no PyTorch
  counterpart, so the JSON file holds the raw weights instead of a tfjs
  graph-model.

  Args:
    ca: a torch.nn.Module that also provides `save_weights(path)`.
    base_fn: path prefix; the checkpoint goes to `base_fn` and the JSON to
      `base_fn + '.json'`.
  """
  ca.save_weights(base_fn)

  # Nested Python lists are directly JSON-serialisable.
  weights = {
      name: tensor.detach().cpu().tolist()
      for name, tensor in ca.state_dict().items()
  }
  model_json = {
      'format': 'pytorch-state-dict',
      'weights': weights,
  }
  with open(base_fn+'.json', 'w') as f:
    json.dump(model_json, f)

def generate_pool_figures(pool, step_i):
  """Write a tiled image of the first 49 pool entries to train_log/.

  A 72-pixel-wide band along each edge of the tiled image is faded towards
  white before saving.

  Args:
    pool: object whose `x` attribute is a NumPy array of CA states.
    step_i: training step number, used in the output file name.
  """
  # to_rgb goes through torch ops, so hand it a tensor and convert back.
  rgb = to_rgb(torch.as_tensor(pool.x[:49])).numpy()
  tiled_pool = tile2d(rgb)
  fade = np.linspace(1.0, 0.0, 72)
  ones = np.ones(72)
  # Each line blends one edge band towards 1.0 (white), weighted by `fade`.
  tiled_pool[:, :72] += (-tiled_pool[:, :72] + ones[None, :, None]) * fade[None, :, None]
  tiled_pool[:, -72:] += (-tiled_pool[:, -72:] + ones[None, :, None]) * fade[None, ::-1, None]
  tiled_pool[:72, :] += (-tiled_pool[:72, :] + ones[:, None, None]) * fade[:, None, None]
  tiled_pool[-72:, :] += (-tiled_pool[-72:, :] + ones[:, None, None]) * fade[::-1, None, None]
  imwrite('train_log/%04d_pool.jpg'%step_i, tiled_pool)

def visualize_batch(x0, x, step_i):
  """Save and display a before/after picture of one training batch.

  Args:
    x0: batch of states before the training step (NumPy array or tensor).
    x: batch of states after the training step (NumPy array or tensor).
    step_i: training step number, used in the output file name.
  """
  def _rgb_strip(batch):
    # Accept NumPy arrays as well as tensors that may still track gradients.
    batch = torch.as_tensor(batch).detach().cpu()
    return np.hstack(to_rgb(batch).numpy())

  vis0 = _rgb_strip(x0)
  vis1 = _rgb_strip(x)
  vis = np.vstack([vis0, vis1])
  imwrite('train_log/batches_%04d.jpg'%step_i, vis)
  print('batch (before/after):')
  imshow(vis)

def plot_loss(loss_log):
  """Show the base-10 logarithm of every logged loss as a scatter plot."""
  pl.figure(figsize=(10, 4))
  pl.title('Loss history (log10)')
  log_losses = np.log10(loss_log)
  pl.plot(log_losses, '.', alpha=0.1)
  pl.show()


In [None]:
#@title Choose Target Image { vertical-output: true}
#url = 'https://github.com/google-research/self-organising-systems/blob/master/assets/growing_ca/planaria2_48.png?raw=true'
#target_img = load_image(url, 48)

# Fetch the image for the configured emoji and preview it at 2x zoom.
target_img = load_emoji(TARGET_EMOJI)
imshow(zoom(to_rgb(target_img), 2), fmt='png')


UnidentifiedImageError: ignored

In [None]:
#@title Initialize Training { vertical-output: true}

p = TARGET_PADDING
pad_target = torch.nn.functional.pad(target_img, [(p, p), (p, p), (0, 0)])
h, w = pad_target.shape[:2]
seed = np.zeros([h, w, CHANNEL_N], np.float32)
seed[h//2, w//2, 3:] = 1.0

def loss_f(x):
  return torch.mean(torch.square(to_rgba(x)-pad_target), [-2, -3, -1])

ca = CAModel()

loss_log = []

lr = 2e-3
lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    [2000], [lr, lr*0.1]) # I dont know what would be the equivalent in torch
trainer = torch.optim.Adam(lr_sched, zero_grad(set_to_none=False)
)

loss0 = loss_f(seed).numpy()
pool = SamplePool(x=np.repeat(seed[None, ...], POOL_SIZE, 0))

!mkdir -p train_log && rm -f train_log/*

In [None]:
#@title Training Loop {vertical-output: true}

#@tf.function
def train_step(x):
  iter_n = torch.distributions.uniform.Uniform([], 64, 96).to(int32)
  with tf.GradientTape() as g:
      # Dont know the equvalent in torch. Maybe forward?
    for i in torch.range(iter_n):
      x = ca(x)
    loss = torch.mean(loss_f(x), 1)
  grads = g.gradient(loss, ca.weights)
  grads = [g/(torch.norm(g)+1e-8) for g in grads]
  trainer.apply_gradients(zip(grads, ca.weights))
  return x, loss

for i in range(8000+1):
  if USE_PATTERN_POOL:
    batch = pool.sample(BATCH_SIZE)
    x0 = batch.x
    loss_rank = loss_f(x0).numpy().argsort()[::-1]
    x0 = x0[loss_rank]
    x0[:1] = seed
    if DAMAGE_N:
      damage = 1.0-make_circle_masks(DAMAGE_N, h, w).numpy()[..., None]
      x0[-DAMAGE_N:] *= damage
  else:
    x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)

  x, loss = train_step(x0)

  if USE_PATTERN_POOL:
    batch.x[:] = x
    batch.commit()

  step_i = len(loss_log)
  loss_log.append(loss.numpy())
  
  if step_i%10 == 0:
    generate_pool_figures(pool, step_i)
  if step_i%100 == 0:
    clear_output()
    visualize_batch(x0, x, step_i)
    plot_loss(loss_log)
    export_model(ca, 'train_log/%04d'%step_i)

  print('\r step: %d, log10(loss): %.3f'%(len(loss_log), np.log10(loss)), end='')

# Figures

In [None]:
#@title Training Progress (Checkpoints)

# Load the listed training checkpoints and render each model growing from a
# single seed cell, side by side, into one video.
models = []
for i in [4000]:
  ca = CAModel()
  ca.load_weights('train_log/%04d'%i)
  models.append(ca)

out_fn = 'train_steps_damage_%d.mp4'%DAMAGE_N
x = np.zeros([len(models), 72, 72, CHANNEL_N], np.float32)
x[..., 36, 36, 3:] = 1.0
# no_grad: inference only, and it lets the outputs be converted to NumPy.
with VideoWriter(out_fn) as vid, torch.no_grad():
  for i in tqdm.trange(1000):
    vis = np.hstack(to_rgb(torch.from_numpy(x)).numpy())
    vid.add(zoom(vis, 2))
    for ca, xk in zip(models, x):
      # xk is a view into x, so the assignment updates the grid in place.
      xk[:] = ca(torch.from_numpy(xk[None,...]))[0].numpy()
mvp.ipython_display(out_fn)

In [None]:
#@title Training Progress (Batches)
# Stitch the batch snapshots in train_log/ (in filename order) into a
# 10 fps video and show it inline.
frames = sorted(glob.glob('train_log/batches_*.jpg'))
mvp.ImageSequenceClip(frames, fps=10.0).write_videofile('batches.mp4')
mvp.ipython_display('batches.mp4')

In [None]:
#@title Pool Contents
# Stitch the first 80 pool snapshots in train_log/ (in filename order) into
# a 20 fps video and show it inline.
frames = sorted(glob.glob('train_log/*_pool.jpg'))[:80]
mvp.ImageSequenceClip(frames, fps=20.0).write_videofile('pool.mp4')
mvp.ipython_display('pool.mp4')

## Pretrained Models and Figures

### Up to this point I have tested the code.
### From here to the end is untested. -Ale

Please run the cell below to download pretrained models that are used to generate the subsequent figures. The figures generated after this are generated using the pretrained CAs.

In [None]:
!wget -O models.zip 'https://github.com/google-research/self-organising-systems/blob/master/assets/growing_ca/models.zip?raw=true'
!unzip -oq models.zip

# One character per emoji; the teaser cell loads one model for each of them.
EMOJI = '🦎😀💥👁🐠🦋🐞🕸🥨🎄'

def get_model(emoji='🦋', fire_rate=0.5, use_pool=1, damage_n=3, run=0,
              prefix='models/', output='model'):
  """Load one model (or its JSON export) from a directory of stored runs.

  Args:
    emoji: single-character emoji; its upper-case hex code point is part of
      the path.
    fire_rate: 0.5 or 1.0; selects which path naming scheme is used.
    use_pool, damage_n: only used in the path when fire_rate is 0.5.
    run: run index used in the path.
    prefix: directory prefix prepended to the path.
    output: 'model' to return a CAModel with loaded weights, 'json' to
      return the text of the matching '.json' file.
  """
  path = prefix
  assert fire_rate in [0.5, 1.0]
  if fire_rate==0.5:
    path += 'use_sample_pool_%d damage_n_%d '%(use_pool, damage_n)
  elif fire_rate==1.0:
    path += 'fire_rate_1.0 '
  code = hex(ord(emoji))[2:].upper()
  path += 'target_emoji_%s run_index_%d/08000'%(code, run)
  assert output in ['model', 'json']
  if output == 'model':
    ca = CAModel(channel_n=16, fire_rate=fire_rate)
    # NOTE(review): the downloaded archive presumably holds TensorFlow
    # checkpoints, which a torch module cannot read directly - confirm and
    # convert the weights if so.
    ca.load_weights(path)
    return ca
  elif output == 'json':
    # Close the file deterministically instead of leaking the handle.
    with open(path+'.json', 'r') as f:
      return f.read()

In [None]:
#@title Teaser
# Teaser video: grow one pattern per emoji (a new seed every 20 steps), sweep
# two erasing discs across the tiled grid, then fade the last frame to white.
models = [get_model(emoji, run=1) for emoji in EMOJI]
# no_grad: inference only, and it lets the outputs be converted to NumPy.
with VideoWriter('teaser.mp4') as vid, torch.no_grad():
  x = np.zeros([len(EMOJI), 64, 64, CHANNEL_N], np.float32)
  # grow
  for i in tqdm.trange(200):
    k = i//20
    if i%20==0 and k<len(EMOJI):
      x[k, 32, 32, 3:] = 1.0
    vid.add(zoom(tile2d(to_rgb(torch.from_numpy(x)).numpy(), 5), 2))
    for ca, xk in zip(models, x):
      xk[:] = ca(torch.from_numpy(xk[None,...]))[0].numpy()
  # damage
  mask = PIL.Image.new('L', (64*5, 64*2))
  draw = PIL.ImageDraw.Draw(mask)
  for i in tqdm.trange(400):
    cx, r = i*3-20, 6
    y1, y2 = 32+np.sin(i/5+np.pi)*8, 32+64+np.sin(i/5)*8
    draw.rectangle((0, 0, 64*5, 64*2), fill=0)
    draw.ellipse((cx-r, y1-r, cx+r, y1+r), fill=255)
    draw.ellipse((cx-r, y2-r, cx+r, y2+r), fill=255)
    # Cut the 2x5 tiled mask into ten 64x64 tiles, one per grid in x.
    x *= 1.0-(np.float32(mask).reshape(2, 64, 5, 64)
        .transpose([0, 2, 1, 3]).reshape(10, 64, 64, 1))/255.0
    if i<200 or i%2 == 0:
      vid.add(zoom(tile2d(to_rgb(torch.from_numpy(x)).numpy(), 5), 2))
    for ca, xk in zip(models, x):
      xk[:] = ca(torch.from_numpy(xk[None,...]))[0].numpy()
  # fade out
  last = zoom(tile2d(to_rgb(torch.from_numpy(x)).numpy(), 5), 2)
  for t in np.linspace(0, 1, 30):
    vid.add(last*(1.0-t)+t)

mvp.ipython_display('teaser.mp4', loop=True)

In [None]:
#@title Regeneration (trained without damage)
# Regeneration video: grow three patterns in five copies each, then at step
# 200 erase a different region from every copy and keep running the CA.
models = [get_model(ch, damage_n=0) for ch in '😀🦋🦎']
# no_grad: inference only, and it lets the outputs be converted to NumPy.
with VideoWriter('regen1.mp4') as vid, torch.no_grad():
  x = np.zeros([len(models), 5, 56, 56, CHANNEL_N], np.float32)
  cx, cy = 28, 28
  x[:, :, cy, cx, 3:] = 1.0
  for i in tqdm.trange(2000):
    if i == 200:
      # Copies 0-3 lose one half each; copy 4 loses a 16x16 centre square.
      x[:, 0, cy:] = x[:, 1, :cy] = 0
      x[:, 2, :, cx:] = x[:, 3, :, :cx] = 0
      x[:, 4, cy-8:cy+8, cx-8:cx+8] = 0
    vis = to_rgb(torch.from_numpy(x)).numpy()
    vis = np.vstack([np.hstack(row) for row in vis])
    vis = zoom(vis, 2)
    if (i < 400 and i%2==0) or i%8 == 0:
      vid.add(vis)
    if i == 200:
      # Hold the freshly damaged frame for 29 extra frames.
      for _ in range(29):
        vid.add(vis)
    for ca, row in zip(models, x):
      # row is a view into x, so the assignment updates the grid in place.
      row[:] = ca(torch.from_numpy(row)).numpy()

mvp.ipython_display('regen1.mp4')