In [176]:
import torch
from copy import deepcopy
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, pmap
from jax.ops import index_update
import numpy as np

In [186]:
class CA:
    def __init__(self, initial_state, update_rule):
        self.size = len(initial_state.shape)
        self.previous_state = jnp.array(initial_state, dtype=int)
        self.current_state = jnp.array(initial_state, dtype=int)
        self.size = initial_state.shape
        self.update_rule = update_rule
        
    
    def init_compile(self):
        self.timeStep = jit(vmap(self.timeStep))
        self.evolve = jit(self.evolve)
        
    def timeStep(self):
        for i in range(self.size[0]):
            for j in range(self.size[1]):
                self.current_state = self.current_state.at[i,j].set(self.update_rule(self.previous_state[i-1:i+2, j-1:j+2]))
        
        self.previous_state = jnp.array(self.current_state, dtype=int)
        
        return self.current_state
        
    
    def evolve(self, t):
        for i in range(t):
            self.timeStep(current_state)
    

def kernelToNeighbourhood(kernel):
    center = [i//2 for i in kernel.shape]
    
    neighbourhood = []
    for i in range(kernel.shape[0]):
        for j in range(kernel.shape[1]):
            if kernel[i, j] == 1:
                neighbourhood.append([i-center[0], j-center[1]])
            
    return jnp.array(neighbourhood)

In [188]:
def update_rule(state):
    return jnp.sum(state) % 3

kernel = torch.ones((3,3), dtype=int)
kernel[1,1] = 0
neighbourhood = kernelToNeighbourhood(kernel)

ca = CA(jnp.ones((100, 10), dtype=int), jit(update_rule))

#ca.init_compile()

In [189]:
ca.timeStep()

DeviceArray([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
             [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
          

In [169]:
ca.current_state

Traced<ShapedArray(int32[3,3])>with<DynamicJaxprTrace(level=0/1)>