In [None]:
import sys
sys.path.append('..')
from NCA.trainer.NCA_trainer import NCA_Trainer
from Common.utils import load_emoji_sequence
from Common.eddie_indexer import index_to_data_nca_type
from NCA.trainer.data_augmenter_nca import DataAugmenter
from NCA.model.NCA_model import NCA
from einops import rearrange
import time
import jax
import jax.numpy as np
import optax
import matplotlib.pyplot as plt


## Define the hyperparameters

In [None]:
# Number of channels used by the NCA model
CHANNELS = 16
# Number of training iterations
TRAINING_STEPS = 1000
# Factor by which the input image is downsampled on load
DOWNSAMPLE = 4
# Number of NCA steps run between consecutive images of the data sequence
NCA_STEPS = 32

## Create a Neural Cellular Automata model
 - There are a few important parameters:
    - `N_CHANNELS` sets how many channels the model uses (here the `CHANNELS` constant defined above)
    - `KERNEL_STR` defines which spatial kernels the model uses (here `"ID"`, `"GRAD"` and `"LAP"`)
    - `ACTIVATION` accepts any scalar function, acting as the neural network nonlinearity
    - `PADDING` must be "CIRCULAR", "REFLECT", "REPLICATE" or "ZEROS" - this controls how the borders of the image are handled
    - `FIRE_RATE` must be between 0 and 1 - it is the probability of updating each pixel at each step

In [None]:
model = NCA(
    N_CHANNELS=CHANNELS,               # channel count (see CHANNELS above)
    KERNEL_STR=["ID", "GRAD", "LAP"],  # spatial kernels used by the model
    ACTIVATION=jax.nn.relu,            # nonlinearity of the update network
    PADDING="CIRCULAR",                # wrap around at the image borders
    FIRE_RATE=0.5,                     # probability each pixel updates on a given step
)


## Load Data
 - Here we load a single image from `demo_data/` and create an initial condition consisting of one seed pixel
 - `load_emoji_sequence` takes a list of strings like `["file_1","file_2",...]` and returns:
   
    an array of shape `[Batch, Timestep, Channels, Width, Height]`, where:
      - `Batch` is currently 1 - this matters more later if we want to train to multiple images at the same time
      - `Timestep` is the length of the input list
      - `Channels` is typically 3 or 4 for colour channels
      - `Width` and `Height` are for the image size

 - We also use the `DataAugmenter` class, defined in `NCA/trainer/data_augmenter_nca.py` (imported as `NCA.trainer.data_augmenter_nca`)
    - This has a few useful functions for modifying the data during training to produce better results
    - This also adds extra hidden channels to an image
    - By creating subclasses of `DataAugmenter` we can define what behaviour to apply to data during training
         - In this example we just pad the data with extra zeros around the boundary

In [None]:
data = load_emoji_sequence(["crab.png"], impath_emojis="demo_data/", downsample=DOWNSAMPLE)

# Initial condition: one all-zero frame per batch entry, with a single "seed"
# pixel (all channels set to 1) switched on at the centre of the image.
initial_condition = np.zeros(data[:, :1].shape)
W = initial_condition.shape[-2]
H = initial_condition.shape[-1]
# Seed every batch entry, not just the first, so the notebook keeps working
# if several images are loaded into the batch at once.
initial_condition = initial_condition.at[:, 0, :, W // 2, H // 2].set(1)

data = np.concatenate([initial_condition, data], axis=1)  # Join initial condition and data along the time axis
print(f"(Batch, Time, Channels, Width, Height): {data.shape}")

# Preview: timesteps laid out left-to-right, batch entries stacked top-to-bottom.
# Only the first 3 channels are drawn (as RGB).
fig, ax = plt.subplots()
ax.imshow(rearrange(data, "B T C W H -> (B W) (T H) C")[..., :3])
ax.set_title("Training sequence: seed frame, then target image(s)")
ax.axis("off")
plt.show()


class data_augmenter_subclass(DataAugmenter):
    """DataAugmenter that pads the training data once, before training starts.

    Only ``data_init`` is overridden; all other behaviour is inherited from
    ``DataAugmenter``.
    """

    # Border width passed to ``self.pad`` in ``data_init`` (per the notebook
    # text, this pads the data with extra zeros around the boundary).
    # A class attribute so it can be tuned without editing ``data_init``.
    PAD_WIDTH = 10

    def data_init(self, SHARDING=None):
        """Redefine how data is pre-processed before training.

        Loads the saved data, pads it by ``PAD_WIDTH`` and saves it back.

        Args:
            SHARDING: Not used here; presumably kept to match the parent
                class's ``data_init`` signature — TODO confirm.

        Returns:
            None. The result is stored via ``self.save_data``.
        """
        data = self.return_saved_data()
        data = self.pad(data, self.PAD_WIDTH)
        self.save_data(data)
        return None

## Define the trainer object
- The `NCA_Trainer` takes as input the `model`, the `data` and the `DataAugmenter` class itself (or a custom subclass) - pass the class, not an instance
    - It also takes a `model_filename`, used to name the saved output
- `NCA_Trainer` also logs many training statistics to TensorBoard; instructions for viewing them are below

In [None]:
trainer = NCA_Trainer(
    NCA_model=model,
    data=data,
    DATA_AUGMENTER=data_augmenter_subclass,  # the class itself, not an instance
    model_filename="test_grow_crab",         # also names the TensorBoard log directory
)

## Training
- Run the following code cell first, then follow these instructions to view how the training is progressing

### Evaluating training:
- In the terminal, run the following command:

`tensorboard --samples_per_plugin images=200 --logdir logs/test_grow_crab/train/`

- Where `test_grow_crab` is the model filename we supplied when defining the `NCA_Trainer`

- Then, open your browser and go to: `http://localhost:6006/`

In [None]:
# Train for TRAINING_STEPS iterations, running NCA_STEPS model steps between
# consecutive images of the data sequence.
trainer.train(
    iters=TRAINING_STEPS,
    t=NCA_STEPS,
)