<a href="https://colab.research.google.com/github/ADMoreau/High_dimensional_cellular_automata/blob/main/4D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q plotly==4.9
!pip install -q kaleido

import plotly.io as pio
# Default pixel size for static images exported with fig.write_image (kaleido).
# The attribute names must be spelled exactly `default_height`/`default_width`;
# a misspelled name is just a new, ignored attribute.
pio.kaleido.scope.default_height = 300
pio.kaleido.scope.default_width = 300

[K     |████████████████████████████████| 12.9MB 12.3MB/s 
[K     |████████████████████████████████| 79.9MB 61kB/s 
[?25h

In [2]:
import numpy as np
import sys 
from skimage.transform import resize
import matplotlib.pyplot as plt
import matplotlib.pylab as pl
import json 
from IPython.display import Image, HTML, clear_output
import plotly.graph_objects as go
from math import sqrt    
import itertools

In [297]:
def distance_dimension(xyz0 = [], xyz1 = []):
    """Euclidean distance between two points with the same number of coordinates.

    Works for any dimensionality (3-D voxels, 4-D points, ...).

    Args:
        xyz0: sequence of coordinates of the first point (not modified).
        xyz1: sequence of coordinates of the second point (not modified).

    Returns:
        The distance as a float.

    Raises:
        ValueError: if the two points have a different number of coordinates.
    """
    if len(xyz0) != len(xyz1):
        raise ValueError('points have different dimensionality: %d vs %d'
                         % (len(xyz0), len(xyz1)))
    return sqrt(sum((a - b) ** 2 for a, b in zip(xyz0, xyz1)))

def voxels_figure(figure = 'sphere', position = [0,0,0], radius = 1):
    """Return the integer voxel coordinates of a cube or a sphere.

    Candidate voxels cover the box ``position[i] - radius .. position[i] + radius``
    (inclusive) on each of the three axes and are enumerated with z outermost and
    x innermost.

    Args:
        figure: ``'cube'`` keeps every voxel of the box; ``'sphere'`` keeps only
            voxels whose Euclidean distance to ``position`` is strictly less than
            ``radius``.
        position: centre as ``[x, y, z]`` integers (not modified).
        radius: integer half-width of the box / radius of the sphere.

    Returns:
        A list of ``[x, y, z]`` lists.

    Raises:
        ValueError: if ``figure`` is neither ``'cube'`` nor ``'sphere'``.
    """
    if figure not in ('cube', 'sphere'):
        raise ValueError("figure must be 'cube' or 'sphere', got %r" % (figure,))

    # Inclusive range per axis (range stop is exclusive, hence the +1).
    xs, ys, zs = (range(position[axis] - radius, position[axis] + radius + 1)
                  for axis in range(3))

    voxels = []
    for z, y, x in itertools.product(zs, ys, xs):
        # Short-circuit: cubes keep every voxel, spheres filter by distance.
        if figure == 'cube' or distance_dimension(xyz0=[x, y, z], xyz1=position) < radius:
            voxels.append([x, y, z])
    return voxels

# Edge length of the cubic grid the voxels are drawn into (also the length of
# every axis of the 4-D sequence built below).
size = 7
# Voxel coordinates of a sphere centred at (3, 2, 2) with a requested radius of 2.
voxels = voxels_figure(position=[3, 2, 2], radius=2, figure='sphere')

In [298]:
# Draw the voxels into a cube of zeros. Only every third voxel is painted (with a
# random value in [0, 1)), which keeps the pattern sparse.
sphere = np.zeros(shape=(size, size, size))
painted = voxels[::3]
# Negative coordinates would silently wrap around in NumPy indexing and too-large
# ones raise a bare IndexError, so validate up front with a clear message.
outside = [v for v in painted if not all(0 <= c < size for c in v)]
assert not outside, 'voxels outside the %d^3 grid, e.g. %r' % (size, outside[:5])
for vx, vy, vz in painted:
  sphere[vx, vy, vz] = np.random.random_sample()

In [22]:
# Build the temporal sequence. Starting from `sphere`, every painted voxel moves by
# 0.1 per frame and reverses direction once it is >= 1.0 (or <= 0.0). Values are
# not clipped, so they can overshoot those bounds by up to one step.
# The evolution runs on a copy of `sphere`, so re-running this cell reproduces the
# same `seq` instead of continuing from an already-evolved state.
frame = sphere.copy()
growth = np.ones(shape=(len(voxels)))
seq = np.zeros(shape=(size, size, size, size))
seq[0, ...] = frame.copy()
# NOTE(review): only frames 0-4 are filled; frames 5..size-1 stay all-zero, and the
# seq[-1] plot below therefore shows an empty frame -- confirm this is intended.
for time in range(1, 5):
  for i in range(0, len(voxels), 3):
    vx, vy, vz = voxels[i]
    if frame[vx, vy, vz] >= 1.0: growth[i] = -1.0
    if frame[vx, vy, vz] <= 0.0: growth[i] = 1.0
    frame[vx, vy, vz] += .1 * growth[i]
  seq[time, ...] = frame.copy()

In [None]:
# Simple visualisation with plotly: render frame 0 of the sequence as a volume.
X, Y, Z = np.mgrid[0:size:1, 0:size:1, 0:size:1]
fig = go.Figure(
    go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=seq[0, ...].flatten(),
        isomin=0.1,
        isomax=1.0,
        # 1.0 draws opaque isosurfaces; lower values let inner surfaces show through.
        opacity=1.0,
        # Many isosurfaces give a smoother volume rendering at a higher render cost.
        surface_count=600,
    )
)
#fig.write_image("temporal_original_first.png")
fig.show()

In [None]:
# Render the last time index of the sequence on the same grid.
# NOTE(review): the sequence cell only fills the first frames, so seq[-1] may be an
# all-zero frame -- confirm which frame was meant to be shown.
fig = go.Figure(
    go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=seq[-1, ...].flatten(),
        isomin=0.1,
        isomax=1.0,
        # 1.0 draws opaque isosurfaces; lower values let inner surfaces show through.
        opacity=1.0,
        # Many isosurfaces give a smoother volume rendering at a higher render cost.
        surface_count=600,
    )
)
#fig.write_image("temporal_original_last.png")
fig.show()

In [103]:
# Alpha channel: True wherever the sequence holds a positive value.
alpha = np.greater(seq, 0.0)

In [300]:
# Append the alpha mask as channel 1 (channel 0 keeps the values), turning seq into
# (time, x, y, z, 2). The assert stops an accidental re-run from stacking twice.
assert seq.ndim == 4 and alpha.shape == seq.shape, \
    'seq already has a channel axis (or alpha is stale); re-run the cells that build them'
seq = np.stack((seq, alpha.astype(seq.dtype)), axis=-1)

In [287]:
#@title Cellular Automata Parameters
CHANNEL_N = 16           # Number of CA state channels
BATCH_SIZE = 64          # Samples per training step
CELL_FIRE_RATE = 0.5     # Probability that a cell applies its update in one step
POOL_SIZE = 1024         # Number of states kept in the training sample pool
# Train from (and write back to) the sample pool when True; start every batch from
# fresh seeds when False. The training loop reads this flag, so it must be defined.
USE_PATTERN_POOL = True

In [271]:
import tensorflow as tf
from tensorflow import keras

@tf.function
def maxpool4d(data, frame_window=1):
    """3x3x3 max pool applied slice by slice along axis 2 of a rank-6 tensor.

    ``data`` is channels-last with four spatial axes. For every index ``j`` along
    axis 2, the slice ``data[:, :, j]`` (remaining spatial axes 1, 3, 4) is
    max-pooled with a 3x3x3 window, stride 1 and SAME padding. The pooled slices
    are stacked back on axis 2, so the output has the same shape as ``data``.

    Args:
        data: rank-6 tensor with a static size along axis 2.
        frame_window: odd window length along axis 2. The default 1 does not pool
            across axis 2 (the historical behaviour); 3 gives a full 3x3x3x3
            neighbourhood. Out-of-range neighbours are ignored (SAME-style).

    Raises:
        ValueError: if ``frame_window`` is not a positive odd integer.
    """
    if frame_window < 1 or frame_window % 2 == 0:
        raise ValueError('frame_window must be a positive odd integer, got %r' % (frame_window,))

    depth = data.shape[2]
    pooled = [tf.nn.max_pool3d(data[:, :, j], 3, [1, 1, 1, 1, 1], 'SAME')
              for j in range(depth)]

    half = frame_window // 2
    if half:
        # Max over neighbouring slices, clipped at the borders.
        pooled = [tf.reduce_max(tf.stack(pooled[max(0, j - half):j + half + 1]), axis=0)
                  for j in range(depth)]

    return tf.stack(pooled, axis=2)

#from https://dev.to/meseta/advent-of-code-day-17-using-more-tensorflow-and-4d-convolution-in-python-3ifd

@tf.function
def conv4d(data, conv_filt):
    """Valid (unpadded) 4-D cross-correlation assembled from 3-D convolutions.

    Shapes (channels last):
        data:      (batch, s1, s2, s3, s4, in_channels)
        conv_filt: (k1, k2, k3, k4, in_channels, out_channels)

    Axis 2 of ``data`` is the "frame" axis and is paired with kernel axis 3. Each
    input slice ``data[:, :, j]`` is convolved (VALID) with kernel slice
    ``conv_filt[:, :, :, i]`` over the remaining axes (data 1/3/4 against kernel
    0/1/2), and the contributions are summed into output frame ``j - i``.

    Returns:
        Tensor of shape (batch, s1-k1+1, s2-k4+1, s3-k2+1, s4-k3+1, out_channels).
    """
    n_frames = data.shape[2]          # input frames (data axis 2)
    k_frames = conv_filt.shape[3]     # kernel frames (kernel axis 3)
    out_frames = n_frames - k_frames + 1
    results = [None] * out_frames

    for i in range(k_frames):
        kernel_3d = conv_filt[:, :, :, i]
        for j in range(n_frames):
            # Cross-correlation: input frame j meets kernel frame i in output frame j - i.
            out_frame = j - i
            if out_frame < 0 or out_frame >= out_frames:
                continue

            frame_conv3d = tf.nn.convolution(data[:, :, j], kernel_3d)

            if results[out_frame] is None:
                results[out_frame] = frame_conv3d
            else:
                results[out_frame] += frame_conv3d

    return tf.stack(results, axis=2)

class Conv4D(keras.layers.Layer):
    """
    4-D convolution layer (channels last) built on ``conv4d``.

    From : https://stackoverflow.com/questions/60782034/how-to-create-a-keras-layer-to-do-a-4d-convolutions-conv4d

    Args:
        filters: number of output channels.
        kernel_size: 4 spatial kernel sizes (a list is converted to a tuple).
        padding: stored but NOT consulted by ``call``: (3,3,3,3) and (5,5,5,5)
            kernels are always zero-padded so the spatial shape is preserved;
            any other kernel size is applied without padding.
        kernel_initializer: initializer for the kernel (name, instance or class),
            resolved with ``keras.initializers.get``.
        activation: activation applied to the output (resolved with
            ``keras.activations.get``).
        bias: whether to add a learned bias.
    """
    def __init__(self, filters, kernel_size, padding='VALID', kernel_initializer='glorot_uniform', activation=None, bias=True, **kwargs):
        super(Conv4D, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = tuple(kernel_size)
        self.padding = padding
        self.is_bias = bias
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.activation = keras.activations.get(activation)

    # channels last
    def build(self, input_shape):
        spatial_dims = len(self.kernel_size)
        all_dims = len(input_shape)
        assert all_dims == spatial_dims + 2  # spatial dimensions + batch size + channels

        # (spatial1, ..., spatialN, input_channels, output_channels)
        kernel_shape = self.kernel_size + (input_shape[-1], self.filters)
        # Broadcasts over batch and spatial axes.
        bias_shape = (1,) * (all_dims - 1) + (self.filters,)

        # Use the configured initializer (it was previously ignored in favour of a
        # hard-coded 'uniform', which defeated e.g. zero-initialised output layers).
        self.kernel = self.add_weight(name='kernel',
                                      shape=kernel_shape,
                                      initializer=self.kernel_initializer,
                                      trainable=True)
        if self.is_bias:
            self.bias = self.add_weight(name='bias',
                                        shape=bias_shape,
                                        initializer='zeros',
                                        trainable=True)
        self.built = True

    def call(self, inputs):
        # Zero-pad so 3x3x3x3 and 5x5x5x5 kernels keep the spatial shape.
        if self.kernel_size == (3, 3, 3, 3):
            inputs = tf.pad(inputs, [(0, 0), (1, 1), (1, 1), (1, 1), (1, 1), (0, 0)])
        if self.kernel_size == (5, 5, 5, 5):
            inputs = tf.pad(inputs, [(0, 0), (2, 2), (2, 2), (2, 2), (2, 2), (0, 0)])
        results = conv4d(inputs, self.kernel)
        if self.is_bias:
            results += self.bias
        if self.activation is not None:
            return self.activation(results)
        return results

In [None]:
from tensorflow.keras.layers import Conv3D

def get_living_mask(x):
  """Boolean mask of cells whose 3x3x3 neighbourhood has alpha (channel 1) > 0.1.

  Indexes ``x`` as a rank-5 tensor and returns a mask with a single trailing channel.
  """
  neighbourhood_alpha = tf.nn.max_pool3d(x[:, :, :, :, 1:2], ksize=3,
                                         strides=[1, 1, 1, 1, 1], padding='SAME')
  return neighbourhood_alpha > 0.1

def make_seed(size, n=1, dim=2, channel_n=None, live_from=3):
  """Return ``n`` all-zero grids with one live cell in the centre.

  Args:
    size: edge length of every spatial axis.
    n: number of seeds (leading batch axis).
    dim: number of spatial axes; the default 2 reproduces the original 2-D seed.
    channel_n: channels per cell; defaults to the global ``CHANNEL_N``.
    live_from: first channel set to 1.0 in the centre cell (channels
      ``live_from:``). The default 3 matches the original; note that the 3-D/4-D
      models in this notebook read alpha from channel 1.

  Returns:
    float32 array of shape ``(n,) + (size,) * dim + (channel_n,)``.
  """
  if channel_n is None:
    channel_n = CHANNEL_N
  x = np.zeros([n] + [size] * dim + [channel_n], np.float32)
  centre = (size // 2,) * dim
  x[(slice(None),) + centre + (slice(live_from, None),)] = 1.0
  return x

class CAModel(tf.keras.Model):
  """Neural cellular automaton on a 3-D (``dim=3``) or 4-D (``dim=4``) grid.

  One call performs a single update step:
  1. every cell perceives its neighbourhood through fixed identity and Sobel
     filters (``perceive``);
  2. the per-cell network ``dmodel`` turns that into a state delta;
  3. the delta is applied to a random ``fire_rate`` fraction of cells;
  4. cells not "alive" both before and after the update are zeroed. Alive means
     max-pooled alpha (channel 1) > 0.1, via ``get_living_mask`` for 3-D and
     ``maxpool4d`` for 4-D.

  Args:
    channel_n: number of state channels per cell.
    fire_rate: probability that a cell applies its update in a step.
    dim: number of spatial axes; must be 3 or 4.
    layers: width of the hidden per-cell layer of ``dmodel``.

  Raises:
    ValueError: if ``dim`` is not 3 or 4.
  """

  def __init__(self, channel_n=CHANNEL_N, fire_rate=CELL_FIRE_RATE, dim=4, layers=128):
    super().__init__()
    if dim not in (3, 4):
      raise ValueError('dim must be 3 or 4, got %r' % (dim,))
    self.channel_n = channel_n
    self.fire_rate = fire_rate
    self.dim = dim

    if self.dim == 3:
      self.dmodel = tf.keras.Sequential([
            Conv3D(layers, 1, activation=tf.nn.relu, padding='SAME', kernel_initializer='normal'),
            Conv3D(self.channel_n, 1, activation=None,
                kernel_initializer=tf.zeros_initializer),
      ])

      self(tf.zeros([1, 3, 3, 3, channel_n]))  # dummy call to build the model

    if self.dim == 4:
      # Fixed 4-D identity / Sobel convolutions; perceive() installs their kernels.
      # None of them is meant to be trained (previously only identity and dx were
      # frozen because of a copy-paste slip).
      self.identity_conv = Conv4D(1, (3,3,3,3), bias=False)
      self.dx_conv = Conv4D(1, (3,3,3,3), bias=False)
      self.dy_conv = Conv4D(1, (3,3,3,3), bias=False)
      self.dz_conv = Conv4D(1, (3,3,3,3), bias=False)
      self.da_conv = Conv4D(1, (3,3,3,3), bias=False)
      for fixed_conv in (self.identity_conv, self.dx_conv, self.dy_conv,
                         self.dz_conv, self.da_conv):
        fixed_conv.trainable = False

      self.dmodel = tf.keras.Sequential([
            Conv4D(layers, (1, 1, 1, 1), activation=tf.nn.relu, padding='SAME'),
            Conv4D(self.channel_n, (1, 1, 1, 1), activation=None,
                kernel_initializer=tf.zeros_initializer),
      ])

      self(tf.zeros([8, 7, 7, 7, 7, channel_n]))  # dummy call to build the model


  @tf.function
  def perceive(self, x, angle=0.0):
    """Filter every state channel independently with identity and Sobel kernels.

    Args:
      x: state tensor with a static shape -- rank 5 ``(batch, a, b, c, chan)``
        for ``dim=3`` or rank 6 ``(batch, a, b, c, d, chan)`` for ``dim=4``.
      angle: unused (filter rotation is not implemented).

    Returns:
      Tensor with the same spatial shape and ``4 * chan`` (3-D) or ``5 * chan``
      (4-D) channels, grouped filter-major: all identity channels first, then all
      channels of each derivative filter in turn.
    """
    smooth = np.float32([1, 2, 1])
    # 3-D Sobel stencils (before normalisation by 8): dx varies along kernel axis 1,
    # dy along axis 0 and dz along axis 2; the other two axes are weighted [1, 2, 1].
    dx = np.float32([[[-1,-2,-1],[0,0,0],[1,2,1]],
                     [[-2,-4,-2],[0,0,0],[2,4,2]],
                     [[-1,-2,-1],[0,0,0],[1,2,1]]])
    dy = np.float32([[[-1,-2,-1],[-2,-4,-2],[-1,-2,-1]],
                     [[0,0,0],[0,0,0],[0,0,0]],
                     [[1,2,1],[2,4,2],[1,2,1]]])
    # Third slab equals the first, as for dx and dy (it previously had a typo).
    dz = np.float32([[[-1,0,1],[-2,0,2],[-1,0,1]],
                     [[-2,0,2],[-4,0,4],[-2,0,2]],
                     [[-1,0,1],[-2,0,2],[-1,0,1]]])

    if self.dim == 3:
      identify = np.zeros((3, 3, 3), np.float32)
      identify[1, 1, 1] = 1.0
      # (3, 3, 3, in=1, out=4): identity followed by the normalised Sobel filters.
      kernel = tf.stack([identify, dx / 8.0, dy / 8.0, dz / 8.0], -1)[:, :, :, None, :]
      batch, a, b, c, chan = x.shape
      # Fold channels into the batch (transposing first, so each folded sample is
      # exactly one channel of one batch element) and filter them independently.
      flat = tf.reshape(tf.transpose(x, [0, 4, 1, 2, 3]), [batch * chan, a, b, c, 1])
      out = tf.nn.conv3d(flat, kernel, strides=[1, 1, 1, 1, 1], padding='SAME')
      out = tf.reshape(out, [batch, chan, a, b, c, 4])
      return tf.reshape(tf.transpose(out, [0, 2, 3, 4, 5, 1]), [batch, a, b, c, 4 * chan])

    # dim == 4: extend each 3-D stencil along a new leading axis weighted [1, 2, 1]
    # and normalise by 8. The result differentiates along axis 2 (dx4), 1 (dy4), 3 (dz4).
    identify = np.zeros((3, 3, 3, 3), np.float32)
    identify[1, 1, 1, 1] = 1.0
    dx4 = smooth[:, None, None, None] * dx[None] / 8.0
    dy4 = smooth[:, None, None, None] * dy[None] / 8.0
    dz4 = smooth[:, None, None, None] * dz[None] / 8.0
    # dx4 with permuted axes, so it differentiates along axis 0.
    da4 = np.transpose(dx4, axes=[2, 3, 0, 1])

    self.identity_conv.kernel = tf.reshape(identify, [3, 3, 3, 3, 1, 1])
    self.dx_conv.kernel = tf.reshape(dx4, [3, 3, 3, 3, 1, 1])
    self.dy_conv.kernel = tf.reshape(dy4, [3, 3, 3, 3, 1, 1])
    self.dz_conv.kernel = tf.reshape(dz4, [3, 3, 3, 3, 1, 1])
    self.da_conv.kernel = tf.reshape(da4, [3, 3, 3, 3, 1, 1])

    batch, a, b, c, d, chan = x.shape
    # Fold channels into the batch. Transposing first is essential: reshaping
    # (batch, ..., chan) straight to (batch * chan, ..., 1) mixes space and channels.
    flat = tf.reshape(tf.transpose(x, [0, 5, 1, 2, 3, 4]), [batch * chan, a, b, c, d, 1])

    def unfold(t):
      # (batch * chan, a, b, c, d, 1) -> (batch, a, b, c, d, chan)
      return tf.transpose(tf.reshape(t, [batch, chan, a, b, c, d]), [0, 2, 3, 4, 5, 1])

    fixed_convs = (self.identity_conv, self.dx_conv, self.dy_conv, self.dz_conv, self.da_conv)
    return tf.concat([unfold(conv(flat)) for conv in fixed_convs], axis=-1)

  @tf.function
  def call(self, x, fire_rate=None, angle=0.0, step_size=1.0):
    """Run one stochastic CA update step; returns a tensor shaped like ``x``."""
    pre_life_mask = self._living_mask(x)

    y = self.perceive(x, angle)
    dx = self.dmodel(y)*step_size
    if fire_rate is None:
      fire_rate = self.fire_rate
    # Each cell applies its update with probability fire_rate.
    update_mask = tf.random.uniform(tf.shape(x[..., :1])) <= fire_rate

    x += dx * tf.cast(update_mask, tf.float32)

    post_life_mask = self._living_mask(x)
    life_mask = pre_life_mask & post_life_mask
    return x * tf.cast(life_mask, tf.float32)

  def _living_mask(self, x):
    """Cells whose max-pooled alpha (channel 1) exceeds 0.1."""
    if self.dim == 3:
      return get_living_mask(x)
    return maxpool4d(x[..., 1:2]) > 0.1


# Build a default (4-D) model and print the layers of its per-cell update network.
CAModel(dim=4).dmodel.summary()

In [177]:
from google.protobuf.json_format import MessageToDict
from tensorflow.python.framework import convert_to_constants


class SamplePool:
  """Fixed-size pool of equally long arrays ("slots") that can be sampled in batches.

  ``SamplePool(x=array)`` stores each keyword argument as a NumPy attribute.
  ``sample(n)`` returns a child pool holding ``n`` randomly chosen rows (without
  replacement), and ``commit()`` writes the child's rows back into the parent.
  """

  def __init__(self, *, _parent=None, _parent_idx=None, **slots):
    self._parent = _parent
    self._parent_idx = _parent_idx
    self._slot_names = slots.keys()
    self._size = None
    for name, values in slots.items():
      if self._size is None:
        self._size = len(values)
      # All slots must have the same length.
      assert self._size == len(values)
      setattr(self, name, np.asarray(values))

  def sample(self, n):
    """Return a child pool with ``n`` distinct random rows of every slot."""
    picked = np.random.choice(self._size, n, False)
    subset = {}
    for name in self._slot_names:
      subset[name] = getattr(self, name)[picked]
    return SamplePool(_parent=self, _parent_idx=picked, **subset)

  def commit(self):
    """Copy this child's slot contents back into the parent rows it was sampled from."""
    parent, rows = self._parent, self._parent_idx
    for name in self._slot_names:
      getattr(parent, name)[rows] = getattr(self, name)

def export_model(ca, base_fn, loss):
  """Save ``ca``'s weights under ``base_fn`` and the loss history to 'temporal_loss.npy'.

  The loss file name is fixed, so every call overwrites the previous history file.
  """
  weights_path = base_fn
  ca.save_weights(weights_path)
  np.save(file='temporal_loss.npy', arr=loss)

def plot_loss(loss_log):
  """Plot the training loss history as dots on a log10 scale."""
  log_loss = np.log10(loss_log)
  pl.figure(figsize=(10, 4))
  pl.title('Loss history')
  pl.plot(log_loss, '.', alpha=1.0)
  pl.show()

In [131]:
# Persist the (value, alpha) target sequence for later runs.
np.save(file='temporal_seq.npy', arr=seq)

In [306]:
# Training target: the stacked (value, alpha) sequence as a float32 tensor.
pad_target = tf.convert_to_tensor(seq, dtype=tf.float32)
a, h, w, d = pad_target.shape[:4]
# Seed state: all zeros except the centre cell, whose channels 1.. are set to 1.
seed = np.zeros((a, h, w, d, CHANNEL_N), dtype=np.float32)
seed[a // 2, h // 2, w // 2, d // 2, 1:] = 1.0

def loss_f(x):
  """Mean squared error between channels 0-1 of ``x`` and ``pad_target``.

  Averages over the last five axes, so a batched input yields one loss per
  sample and an unbatched state yields a scalar.
  """
  diff = x[..., 0:2] - pad_target
  return tf.math.reduce_mean(tf.square(diff), axis=[-5, -4, -3, -2, -1])

ca = CAModel(dim=4, layers=128)

loss_log = []

# Adam with a piecewise-constant learning rate: lr, then lr*0.1 past step 2000
# and lr*0.01 past step 5000.
lr = 2e-4
lr_sched = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[2000, 5000], values=[lr, lr * 0.1, lr * 0.01])
trainer = tf.keras.optimizers.Adam(learning_rate=lr_sched)

# Loss of the bare seed, as a reference value.
loss0 = loss_f(seed).numpy()
# Pool of POOL_SIZE copies of the seed that training batches are drawn from.
pool = SamplePool(x=np.repeat(seed[None, ...], POOL_SIZE, axis=0))

In [None]:
@tf.function
def train_step(x):
  """Run the CA for 32-63 random steps on batch ``x`` and take one optimizer step.

  Each gradient is normalised to unit norm before it is applied. Only trainable
  variables are updated; variables the loss does not depend on (gradient ``None``)
  are skipped instead of crashing ``tf.norm``.

  Returns:
    (final states, mean loss).
  """
  iter_n = tf.random.uniform([], 32, 64, tf.int32)
  with tf.GradientTape() as tape:
    for _ in tf.range(iter_n):
      x = ca(x)
    loss = tf.reduce_mean(loss_f(x))
  variables = ca.trainable_weights
  grads = tape.gradient(loss, variables)
  grads_and_vars = [(grad / (tf.norm(grad) + 1e-8), var)
                    for grad, var in zip(grads, variables) if grad is not None]
  trainer.apply_gradients(grads_and_vars)
  return x, loss

while len(loss_log) < 10001:
  if not USE_PATTERN_POOL:
    x0 = np.repeat(seed[None, ...], BATCH_SIZE, 0)
  else:
    # Sample from the pool, order it worst-loss first, and reset the worst sample
    # to the seed so the model keeps practising growth from scratch.
    batch = pool.sample(BATCH_SIZE)
    x0 = batch.x
    loss_rank = np.argsort(loss_f(x0).numpy())[::-1]
    x0 = x0[loss_rank]
    x0[:1] = seed

  x, loss = train_step(x0)

  if USE_PATTERN_POOL:
    # Copy the evolved states into the sampled slots and push them to the pool.
    batch.x[:] = x
    batch.commit()

  step_i = len(loss_log)
  loss_log.append(loss.numpy())

  if step_i % 100 == 0:
    clear_output()
    plot_loss(loss_log)
    export_model(ca, '%04d' % step_i, loss_log)

  print('\r step: %d, log loss: %f' % (len(loss_log), np.log10(loss)), end='')

In [None]:
# Mean squared error of the value channel for the first four samples of the final batch.
for i in range(4):
  generated = x[i, ...].numpy()[..., 0]
  mse = np.mean(np.square(generated - seq[..., 0]))
  print(mse)

In [277]:
# Value channel of the first sample of the final batch.
generated = x[0, ...].numpy()[..., 0]

In [None]:
#visualize outputs

# Render the first frame (index 0 along axis 0) of the generated value channel.
# The coordinate grid is derived from the frame's own shape so x/y/z/value always
# have the same length (a hard-coded 5x5x5 grid did not match the size^3 frames).
first_frame = generated[0, ...]
X, Y, Z = np.mgrid[0:first_frame.shape[0], 0:first_frame.shape[1], 0:first_frame.shape[2]]

fig = go.Figure(data=go.Volume(
    x=X.flatten(),
    y=Y.flatten(),
    z=Z.flatten(),
    value=first_frame.flatten(),
    isomin=0.1,
    isomax=1.0,
    opacity=1.0,
    surface_count=600,
    ))
#fig.write_image("temporal_generated_first.png")
fig.show()