In [None]:
import torch
from torchvision.io import read_image, ImageReadMode, write_video
import torchvision.transforms as T
from random import randint
from IPython.display import clear_output
import numpy as np
import pylab as pl

from src import *

N_CHANNELS = 16        # Number of CA state channels
TARGET_PADDING = 16    # Number of pixels used to pad the target image border
TARGET_SIZE = 40       # Size (height and width, in pixels) of the target emoji
# NOTE(review): IMAGE_SIZE adds TARGET_PADDING only once. If `pad` (from src)
# pads every side by TARGET_PADDING, the padded target would be
# TARGET_SIZE + 2*TARGET_PADDING wide and would not match states of size
# IMAGE_SIZE. Confirm against src.pad / make_seed.
IMAGE_SIZE = TARGET_PADDING + TARGET_SIZE
cuda = torch.device('cuda')

# Seed the RNGs so Restart & Run All is reproducible:
#   - numpy drives the random step counts used to fill the sample pool;
#   - torch presumably drives weight init and any stochastic CA updates in src.
# Note that torch.backends.cudnn.benchmark = True (set in the training setup
# cell) can still introduce small run-to-run differences on the GPU.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Two CA models with the same channel count. In the video cells, `growing`
# is run first from a seed, and `regenerating` continues from the state it
# reaches.
growing = NeuralCA(N_CHANNELS)
regenerating = NeuralCA(N_CHANNELS)

# Checkpoints are mapped straight onto the GPU. torch.load unpickles the file,
# so only load checkpoints from trusted sources (recent torch versions accept
# weights_only=True for plain state dicts).
growing.load_state_dict(torch.load('Pretrained_models/firework_growing.pt', map_location=cuda))
regenerating.load_state_dict(torch.load('Pretrained_models/firework_regenerating.pt', map_location=cuda))

# Move the modules themselves to the GPU; the trailing ';' keeps the cell from
# dumping the module repr as its output.
growing.to(cuda)
regenerating.to(cuda);

# Without training the regenerating part
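Without any further training in this notebook, the regenerating model carries only the weights loaded from `Pretrained_models/firework_regenerating.pt`. In this state, the image fails to converge to, and persist in, its final form.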

In [None]:
switch = 5  # number of growing steps before the regenerating model takes over
# Run the growing CA for `switch` steps; make_video presumably returns the
# recorded frames together with the state reached at the end.
video, initial_state = make_video(growing, switch)
# Continue from that state with the regenerating CA (200 presumably being the
# step count), passing the growing frames along and writing switch<N>.mkv.
_ = make_video(regenerating, 200, initial_state,
               initial_video=video,
               fname=f"switch{switch}.mkv")

# Training the regenerating part
Even after training the regenerating model, the image still fails to converge to, and persist in, its final form.

In [None]:
def distribution(avg, maximum, minimum=0):
  """Draw an integer from an exponential distribution, falling back to ``minimum``.

  A sample is drawn with ``np.random.exponential(avg)`` (``avg`` is the scale,
  i.e. the mean). If the sample lies strictly between ``minimum`` and
  ``maximum`` it is truncated with ``int``; otherwise ``minimum`` itself is
  returned, including when the sample is at or above ``maximum``.
  """
  sample = np.random.exponential(avg)
  out_of_range = sample <= minimum or sample >= maximum
  return minimum if out_of_range else int(sample)

def growing_generator(n_images, avg_steps=25, max_steps=60, min_steps=5):
  """Build ``n_images`` partially grown CA states (used to fill the sample pool).

  Each image starts from a fresh ``make_seed(1, N_CHANNELS, IMAGE_SIZE, cuda)``
  and is evolved by the global ``growing`` model for a random number of steps
  (presumably what evolve's second argument means) drawn with
  ``distribution(avg_steps, max_steps, min_steps)``. The defaults reproduce
  the original hard-coded ``distribution(25, 60, 5)``.

  Args:
    n_images: number of states to generate.
    avg_steps, max_steps, min_steps: forwarded to ``distribution`` to pick
      each image's step count.

  Returns:
    A detached tensor stacking the generated states along a new dim 0.
  """
  states = []
  # Pool filling is pure inference. Without no_grad, every evolve() call would
  # record an autograd graph over all of its CA updates, only for .detach() to
  # throw it away. That wastes GPU memory and time.
  with torch.no_grad():
    for _ in range(n_images):
      seed = make_seed(1, N_CHANNELS, IMAGE_SIZE, cuda).detach()
      n_steps = distribution(avg_steps, max_steps, min_steps)
      # NOTE(review): [0] presumably drops the size-1 batch dim; confirm evolve's return shape.
      states.append(growing.evolve(seed, n_steps)[0])
  return torch.stack(states).detach()

# NOTE(review): the original author flagged this line as "strange things happen
# here" (translated from Italian). The cause isn't visible in this notebook.
# Building the pool presumably calls growing_generator, so inspect that first.
pool = SamplePool(
    512,                          # presumably the number of stored states
    generator=growing_generator,
    device='cuda',
    transform=None,
)

In [None]:
# Sanity check: render one pool entry (index 10) to see what growing_generator produced.
imshow(pool[10])

In [None]:
# Load the target emoji as RGBA, resize it to TARGET_SIZE x TARGET_SIZE,
# convert it with RGBAtoFloat and move it to the GPU.
target = RGBAtoFloat(
    T.Resize((TARGET_SIZE, TARGET_SIZE))(
        read_image("images/firework.png", ImageReadMode.RGB_ALPHA).float()
    )
).to(cuda)

# Let cuDNN benchmark and cache the fastest convolution algorithms (speeds up training).
torch.backends.cudnn.benchmark = True

# Only the regenerating model's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(regenerating.parameters(), lr=2e-3)
# NOTE(review): the meaning of the trailing 16 is defined by CustomLoss (src),
# possibly N_CHANNELS. Confirm.
criterion = CustomLoss(pad(target, TARGET_PADDING), torch.nn.MSELoss, 16)
# Multiply the learning rate by 0.3 at scheduler.step() counts 40 and 80.
# NOTE(review): training runs n_epochs=50, so the 80 milestone is only reached
# if train_CA steps the scheduler more than once per epoch. Confirm.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[40, 80], gamma=0.3)

In [None]:
# Train the regenerating CA using the pool, optimizer, loss and scheduler set up above.
regenerating.train_CA(
    optimizer,
    criterion,
    pool,
    n_epochs=50,
    batch_size=8,
    skip_update=1,
    scheduler=scheduler,
    kind="regenerating",
)

In [None]:
# Plot the loss history stored on the model (presumably appended by train_CA).
fig, ax = pl.subplots()
ax.plot(regenerating.losses)
ax.set(
    title="Regenerating CA training loss",
    xlabel="Epochs",  # NOTE(review): confirm train_CA records one loss per epoch
    ylabel="Loss",
)
pl.show()

In [None]:
# Render the video again, now with the trained regenerating model.
switch = 3  # number of growing steps before the regenerating model takes over
# Run the growing CA for `switch` steps; make_video presumably returns the
# recorded frames together with the state reached at the end.
video, initial_state = make_video(growing, switch)
# Continue from that state with the regenerating CA (200 presumably being the
# step count), passing the growing frames along and writing switch<N>.mkv.
_ = make_video(regenerating, 200, initial_state,
               initial_video=video,
               fname=f"switch{switch}.mkv")