In [1]:
import torch
from torchvision.io import read_image, ImageReadMode, write_video
import torchvision.transforms as T
from random import randint
from IPython.display import clear_output
import numpy as np
import pylab as pl

from src import *

# NOTE(review): the model repr printed by the model-loading cell shows
# Conv2d(48, 128) / Conv2d(128, 16), i.e. 16 channels -- confirm N_CHANNELS
# matches the channel count of the pretrained checkpoints.
N_CHANNELS = 15        # Number of CA state channels
TARGET_PADDING = 8     # Number of pixels used to pad the target image border
TARGET_SIZE = 40       # Size of the target emoji
# NOTE(review): padding is added once here; if `pad` (from src) pads each side
# by TARGET_PADDING, the padded target would be TARGET_SIZE + 2*TARGET_PADDING
# -- confirm against src.
IMAGE_SIZE = TARGET_PADDING+TARGET_SIZE
BATCH_SIZE = 16        # can be much bigger
POOL_SIZE = 512
CELL_FIRE_RATE = 0.5
N_EPOCHS = 100
SEED = 42              # Seed for the torch and numpy RNGs

torch.backends.cudnn.benchmark = True # Lets cuDNN auto-tune conv algorithms (faster)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs this notebook draws from (np.random.exponential, torch.rand_like, ...)
# so Restart & Run All is reproducible. cudnn.benchmark may still introduce
# GPU-side nondeterminism.
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Load the target emoji as float RGBA, resize it to TARGET_SIZE x TARGET_SIZE,
# convert it with RGBAtoFloat, preview it, then move it to `device`.
target = RGBAtoFloat(
    T.Resize((TARGET_SIZE, TARGET_SIZE))(
        read_image("images/firework.png", ImageReadMode.RGB_ALPHA).float()
    )
)
imshow(target)
target = target.to(device)

In [3]:
# Two CAs with the same architecture: `growing` seeds the sample pool and the
# video, `regenerating` is the model fine-tuned in this notebook.
growing, regenerating = NeuralCA(N_CHANNELS), NeuralCA(N_CHANNELS)

# Restore the pretrained weights, mapping the tensors straight onto `device`.
growing.load_state_dict(
    torch.load("Pretrained_models/firework_growing.pt", map_location=device))
regenerating.load_state_dict(
    torch.load("Pretrained_models/firework_regenerating.pt", map_location=device))

growing.to(device)
regenerating.to(device)

NeuralCA(
  (layers): Sequential(
    (0): Conv2d(48, 128, kernel_size=(1, 1), stride=(1, 1))
    (1): ReLU()
    (2): Conv2d(128, 16, kernel_size=(1, 1), stride=(1, 1))
  )
)

In [None]:
# Create the sample pool.

def distribution(avg,min,max):
  """Draw a step count from an exponential distribution.

  Samples x ~ Exponential(scale=avg). When x lies strictly between `min` and
  `max` it is truncated to an int; otherwise `min` is returned.
  """
  sample = np.random.exponential(avg)
  out_of_range = sample >= max or sample <= min
  return min if out_of_range else int(sample)

def apply_virus(image: torch.Tensor, probability=0.1) -> torch.Tensor:
    """Move a random subset of channel -2 into channel -1, in place.

    A per-pixel boolean mask is drawn with P(True) = `probability`. Where the
    mask is True, channel -1 takes the value of channel -2 and channel -2 is
    zeroed. Where it is False, channel -1 is zeroed and channel -2 is kept.

    Args:
        image: tensor indexed as image[channel, ...]; the mask has the shape of
            image[0] (presumably (C, H, W) input -- TODO confirm with callers).
        probability: per-pixel probability of being moved to channel -1.

    Returns:
        The same tensor object, modified in place.
    """
    mask = torch.rand_like(image[0]) < probability
    # Order matters: channel -2 must be read before it is masked below.
    image[-1] = image[-2] * mask
    image[-2] = image[-2] * ~mask
    return image

def virus_generator(n_images: int, nsteps=50, probability=0.1) -> torch.Tensor:
  """Grow `n_images` patterns with the `growing` CA and infect each one.

  Each pattern starts from make_seed(1, N_CHANNELS, IMAGE_SIZE, device), is
  evolved for `nsteps` steps by `growing`, and is then passed through
  apply_virus(..., probability).

  Returns:
      Tensor of shape (n_images, N_CHANNELS, IMAGE_SIZE, IMAGE_SIZE), allocated
      by torch.zeros on its default device (normally the CPU).
  """
  out = torch.zeros(n_images, N_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
  # Pure sampling: no autograd graph is needed while evolving the CA.
  with torch.no_grad():
    for i in range(n_images):
      # Use the notebook-wide `device` (with CPU fallback) for the seed.
      grown = growing.evolve(make_seed(1, N_CHANNELS, IMAGE_SIZE, device), nsteps)[0].detach()
      out[i] = apply_virus(grown, probability)
  return out.detach()

def generator(n_images):
  """Sample-pool generator: build `n_images` infected patterns.

  A single step count is drawn with distribution(30, 5, 60) and shared by the
  whole batch; each pattern is then infected with probability 0.1.
  """
  return virus_generator(n_images, distribution(30, 5, 60), probability=0.1)

# Build the pool on the notebook's `device` instead of hard-coding 'cuda'.
# str(device) gives 'cuda' on GPU (identical to the previous value) and 'cpu' otherwise.
pool = SamplePool(POOL_SIZE, transform=None, device=str(device), generator=generator)

imshow(pool[10])

In [None]:
# Training setup for the regenerating CA.
# `target` is the RGBA emoji already loaded, resized and moved to `device` by
# the earlier "Imports the target emoji" cell, so it is reused here rather
# than loaded a second time.
optimizer = torch.optim.Adam(regenerating.parameters(), lr=2e-3)
# NOTE(review): the meaning of the trailing 16 is defined by CustomLoss in src
# (not visible here) -- confirm it matches the model's channel count.
criterion = CustomLoss(pad(target, TARGET_PADDING), torch.nn.MSELoss, 16)
# Multiply the learning rate by 0.3 at the 40th and 80th scheduler.step() call
# (presumably epochs, depending on how train_CA steps the scheduler).
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40, 80], gamma=0.3)

In [None]:
# Fine-tune the regenerating CA on batches drawn from `pool`.
regenerating.train_CA(
    optimizer, criterion, pool,
    kind="regenerating",
    n_epochs=N_EPOCHS,
    batch_size=BATCH_SIZE,
    n_max_losses=BATCH_SIZE // 4,
    skip_update=1,
    scheduler=scheduler,
)

In [None]:
# Plot the training curve of the regenerating CA.
fig, ax = pl.subplots()
ax.plot(regenerating.losses)
ax.set(title="Regenerating CA training loss", xlabel="Epochs", ylabel="Loss")
pl.show()

In [None]:
# Render the video: `growing` runs for `switch` steps, then `regenerating` is
# run for 200 steps, given `initial_state` and `video` from the first call.
switch = 3  # number of steps before switching from growing to regenerating
video, initial_state = make_video(growing, switch)
_ = make_video(
    regenerating,
    200,
    initial_state,
    fname=f"switch{switch}.mkv",
    initial_video=video,
)