## Imports

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
from PIL import Image
import IPython
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"
from IPython import display
from pathlib import Path

## Get Image

In [None]:
img_num = 200  # position (in the image-folder listing) of the emoji used as the target

In [None]:
# Folder holding the emoji PNGs; os.path.join keeps the path valid on every OS.
img_folder = os.path.join("emoji-data-master", "emoji-data-master", "img-apple-64")
# os.listdir returns entries in arbitrary, platform-dependent order; sorting makes
# img_num select the same file on every machine.
enteries = sorted(os.listdir(img_folder))

In [None]:
# Load the selected emoji. For PNG files mpimg.imread returns floats in [0, 1];
# presumably RGBA for these emoji files (rgba_to_rgb below relies on that).
img = mpimg.imread(os.path.join(img_folder, enteries[img_num]))
plt.imshow(img)

## Convert image from RGBA to RGB

In [None]:
def rgba_to_rgb(img):
    """Alpha-composite an image onto a white background and rescale it.

    Parameters
    ----------
    img : numpy.ndarray
        Image of shape (H, W, 4) or (H, W, 3), assumed to hold floats in
        [0, 1] (as ``mpimg.imread`` returns for PNG files). Inputs without an
        alpha channel are treated as fully opaque.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape (H, W, 3), min-max normalised to [0, 1].
        A uniform image (zero value range) is scaled by 1/255 instead of
        producing NaNs from a division by zero.
    """
    # Converting to RGBA gives RGB inputs an opaque alpha band, so split()[3]
    # always exists.
    foreground = Image.fromarray(np.uint8(img * 255)).convert("RGBA")
    # Size the white canvas from the input rather than assuming 64x64.
    background = Image.new("RGB", foreground.size, (255, 255, 255))
    # Using the alpha band as paste mask composites the emoji over white.
    background.paste(foreground, mask=foreground.split()[3])
    rgb = np.asarray(background, dtype=np.float64)
    lo, hi = rgb.min(), rgb.max()
    if hi == lo:
        return rgb / 255.0
    return (rgb - lo) / (hi - lo)

In [None]:
# Swap the RGBA emoji for its white-background RGB version, then preview it.
img = rgba_to_rgb(img)
plt.imshow(X=img)

## State to image

In [None]:
def state_to_image(state, num_channels):
    """Render the first three state channels as an image array.

    Parameters
    ----------
    state : torch.Tensor
        Shape (num_channels, 1, H, W).
    num_channels : int
        Number of channels in ``state``.

    Returns
    -------
    torch.Tensor
        Shape (W, H, 3). The spatial axes come out swapped relative to
        ``state`` because of ``permute(3, 2, 1, 0)``. Training and display
        both go through this function, so the orientation is consistent.
    """
    # (C, 1, H, W) -> (W, H, 1, C)
    image = state.permute(3, 2, 1, 0)
    # Drop the singleton axis: (W, H, 1, C) -> (W, H, C); size taken from the
    # tensor instead of a hard-coded 64.
    image = torch.reshape(image, (image.shape[0], image.shape[1], num_channels))
    # Channels 0..2 are shown as RGB.
    return image.narrow(2, 0, 3)

## Build Neural Network

In [None]:
class big_NN(nn.Module):
    """Per-cell update rule: Linear -> LeakyReLU -> Linear, scaled by 0.01.

    Maps a (N, 3 * num_channels) batch of perception vectors to a
    (N, num_channels) batch of state increments.
    """

    def __init__(self, num_channels):
        super().__init__()
        perceived = 3 * num_channels
        # Creation order (fc1 before fc2) fixes the order of random init draws.
        self.fc1 = nn.Linear(perceived, perceived)
        self.fc2 = nn.Linear(perceived, num_channels)
        # Activation module; the historical name "fc3" is kept for compatibility.
        self.fc3 = nn.LeakyReLU()

    def forward(self, x):
        hidden = self.fc3(self.fc1(x))
        # Small output scale keeps per-step increments small.
        return 0.01 * self.fc2(hidden)

## Initial states

In [None]:
def initial_state(num_channels, size=64, device="cuda"):
    """Return the seed state: all zeros except a single seed cell.

    Parameters
    ----------
    num_channels : int
        Number of state channels C.
    size : int, default 64
        Height and width of the square grid.
    device : str or torch.device, default "cuda"
        Where to allocate the tensor.

    Returns
    -------
    torch.Tensor
        Shape (num_channels, 1, size, size). Channels 0-3 of the cell at
        ((size - 1) // 2, (size - 1) // 2), i.e. (31, 31) for the default
        size, are set to 1. Everything else is 0.
    """
    state = torch.zeros((num_channels, 1, size, size), device=device)
    centre = (size - 1) // 2
    # A single vectorised write replaces the old C x size x size Python loop.
    # Channel 3 is what alive_masking reads; channels 0-2 are rendered as RGB.
    state[:4, 0, centre, centre] = 1.
    return state

## Perception

In [None]:
def perception(state):
    """Return each cell's state stacked with its Sobel x / y gradients.

    Parameters
    ----------
    state : torch.Tensor
        Shape (C, 1, H, W). Channels sit on the batch axis so that one
        single-input-channel 3x3 kernel is applied to every channel.

    Returns
    -------
    torch.Tensor
        Shape (3C, 1, H, W): ``[state, grad_x, grad_y]`` concatenated on
        dim 0. Gradients use circular padding, so the grid wraps around.
    """
    # Build kernels on the state's own device/dtype (no forced .cuda()).
    sobel_x = torch.tensor([[-1., 0., 1.],
                            [-2., 0., 2.],
                            [-1., 0., 1.]], dtype=state.dtype, device=state.device)
    sobel_y = sobel_x.t()
    # conv2d weight layout: (out_channels=1, in_channels=1, 3, 3)
    kernel_x = sobel_x.reshape(1, 1, 3, 3)
    kernel_y = sobel_y.reshape(1, 1, 3, 3)

    # Wrap-around padding keeps the output H x W; "circular" takes no fill value.
    padded = F.pad(state, (1, 1, 1, 1), mode="circular")

    grad_x = F.conv2d(padded, kernel_x, stride=1, padding=0)
    grad_y = F.conv2d(padded, kernel_y, stride=1, padding=0)

    return torch.cat((state, grad_x, grad_y), 0)

## Stochastic Update

In [None]:
def stochastic_update(state, state_update):
    """Scatter per-cell updates back onto the grid and randomly drop some.

    Parameters
    ----------
    state : torch.Tensor
        Current state, shape (C, 1, H, W). Only its spatial size is read.
    state_update : torch.Tensor
        Shape (H * W, C): one row per cell, in the order produced by the
        training/inference loops' ``conv_state.permute(3, 2, 1, 0)``
        flattening (row index = w * H + h).

    Returns
    -------
    torch.Tensor
        Shape (C, 1, H, W). Each element is kept where a uniform draw
        exceeds 0.1 (probability ~0.9) and zeroed otherwise.
    """
    height, width = state.shape[-2], state.shape[-1]
    channels = state_update.shape[-1]
    # Exact inverse of the callers' flattening: rows are (w, h)-ordered, so
    # reshape to (W, H, 1, C) and permute back to (C, 1, H, W). A plain
    # transpose + reshape would put cell (h, w)'s update at (w, h).
    grid_update = state_update.reshape(width, height, 1, channels).permute(3, 2, 1, 0)
    # Independent keep-mask per element, drawn directly on the update's device.
    keep = torch.rand(grid_update.shape, device=state_update.device) > .1
    return torch.mul(grid_update, keep)

## Alive Masking

In [None]:
def alive_masking(state, state_update, alive_threshold=.1):
    """Apply updates only around living cells; leave all other cells unchanged.

    A cell is "living" when channel 3 exceeds ``alive_threshold``. A 3x3
    max-pool spreads that, so a cell is updated when any cell in its 3x3
    neighbourhood is living. The mask is computed from the pre-update state.
    Unlike ``perception``, the max-pool does not wrap around the borders.

    Parameters
    ----------
    state : torch.Tensor
        Shape (C, 1, H, W) with C >= 4.
    state_update : torch.Tensor
        Same shape as ``state``.
    alive_threshold : float, default 0.1
        Channel-3 value above which a cell counts as living.

    Returns
    -------
    torch.Tensor
        ``state + state_update`` where alive, ``state`` elsewhere.
    """
    # state[3] has shape (1, H, W): max_pool2d treats it as one unbatched channel.
    neighbourhood_max = F.max_pool2d(state[3], kernel_size=3, stride=1, padding=1)
    # (1, H, W) mask broadcasts against (C, 1, H, W).
    alive = neighbourhood_max > alive_threshold
    return torch.where(alive, state + state_update, state)

## Parameters

In [None]:
# Hyper-parameters: C state channels per cell (0-2 rendered as RGB, 3 gates
# "alive"), number of training epochs, and CA steps unrolled per epoch.
num_channels, epochs, steps = 16, 1000, 150

# Creating the model before anything else keeps the random init draws unchanged.
model = big_NN(num_channels).cuda()

mse = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

images, loss_list = [], []
# Target image on the GPU as float32.
img = torch.tensor(img, dtype=torch.float32).cuda()

## Train model

In [None]:
for epoch in range(epochs):
    # initial_state already allocates on the GPU; no extra .cuda() needed.
    state = initial_state(num_channels)

    # Unroll the neural cellular automaton; gradients flow back through every step.
    for step in range(steps):
        conv_state = perception(state)
        # (3C, 1, H, W) -> one row of 3C perception features per cell.
        conv_state = torch.reshape(conv_state.permute(3, 2, 1, 0), (-1, num_channels * 3))
        state_update = model(conv_state)
        state_update = stochastic_update(state, state_update)
        state = alive_masking(state, state_update)

    temp_state = state_to_image(state, num_channels)

    # Compare the rendered RGB channels with the target and update the model.
    loss = mse(temp_state, img)
    loss_list.append(float(loss))
    optimizer.zero_grad()  # covers every model parameter
    # The graph is rebuilt each epoch, so it does not need to be retained.
    loss.backward()
    optimizer.step()
    print("Epoch: {}    loss: {}".format(epoch, loss_list[-1]))

In [None]:
plt.imshow(temp_state.detach().cpu().numpy())  # final rendered state of the last epoch

In [None]:
plt.imshow(img.cpu().numpy())  # target emoji for comparison

In [None]:
plt.plot(range(len(loss_list)), loss_list)  # training loss per epoch

## Run the trained cellular automaton (no further training)

In [None]:
images = []
state = initial_state(num_channels)
# Inference only: no_grad avoids building an autograd graph over all steps.
with torch.no_grad():
    for step in range(steps):
        conv_state = perception(state)
        conv_state = torch.reshape(conv_state.permute(3, 2, 1, 0), (-1, num_channels * 3))
        state_update = model(conv_state)
        state_update = stochastic_update(state, state_update)

        state = alive_masking(state, state_update)
        temp_state = state_to_image(state, num_channels)
        # Keep every frame for the GIF below.
        images.append(temp_state.cpu().numpy())

    loss = mse(temp_state, img)
# Report the step count of this run rather than a stale training epoch number.
print("Steps: {}    loss: {}".format(steps, loss.item()))
state = initial_state(num_channels)

## Convert list of images to GIF

In [None]:
# Clip to [0, 1] before scaling so out-of-range state values cannot wrap around
# in the uint8 cast, and use the full 0-255 range.
image_list = [Image.fromarray(np.uint8(np.clip(image, 0.0, 1.0) * 255.0)) for image in images]
# Upscale 10x (64 -> 640) so the frames are easy to see.
image_list = [image.resize((640, 640)) for image in image_list]
# Pillow's GIF writer reads the repeat count from ``loop``; ``loops`` is ignored.
image_list[0].save('growing_phase.gif', save_all=True, append_images=image_list[1:],
                   optimize=False, duration=40, loop=10)

## Play GIF

In [None]:
gifPath = Path("growing_phase.gif")
gif_bytes = gifPath.read_bytes()
# The file holds GIF data, so declare it as GIF. As the cell's last top-level
# expression, the Image is displayed without relying on nested display.
# NOTE(review): format='gif' needs an IPython release with GIF embedding support.
display.Image(data=gif_bytes, format='gif')