In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms.functional as F
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from PIL import Image
import numpy as np
from tqdm import tqdm_notebook, tnrange
import imageio

from base64 import b64encode
import matplotlib.pyplot as plt
import requests
import io
import os
from IPython.display import  HTML, clear_output
import matplotlib.pylab as pl

os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

In [None]:
# Pretrained VGG16 in two forms:
#   * `vgg16`: the convolutional trunk only, used by `calc_styles` for
#     Gram-matrix style features (kept on the CPU, as before).
#   * `vgg16_model`: the full classifier, used by `class_loss` to score how
#     strongly the CA output activates an ImageNet class. Placed on `device`
#     (not hard-coded `.cuda()`) so the notebook also runs without a GPU.
# Both act as fixed feature extractors, so they are frozen and put in eval mode:
#   - requires_grad_(False): gradients still flow *through* them to the input
#     images, but their weights no longer accumulate .grad each step (the
#     optimizer only owns the CA parameters, so those grads were never cleared).
#   - eval(): disables the classifier's Dropout so the class loss is deterministic.
vgg16 = models.vgg16(pretrained=True).features.eval().requires_grad_(False)
vgg16_model = models.vgg16(pretrained=True).to(device).eval().requires_grad_(False)

def calc_styles(imgs, model=None, style_layers=(1, 6, 11, 18, 25)):
  """Compute Gram matrices of feature activations for style comparison.

  Args:
    imgs: NCHW image batch with 3 colour channels, assumed to be RGB in
      [0, 1] (the ImageNet mean/std below are per-channel RGB statistics).
    model: sequential feature extractor to run; defaults to the module-level
      `vgg16` trunk. Must live on the same device as `imgs`.
    style_layers: indices of layers in `model` whose outputs each contribute
      one Gram matrix. Only layers up to `max(style_layers)` are evaluated.

  Returns:
    List of tensors of shape (batch, channels, channels), one per distinct
    index in `style_layers`, ordered by layer index. Each is normalised by
    the spatial size h*w of that layer's output.
  """
  if model is None:
    model = vgg16
  # ImageNet normalisation. Built on the input's device so GPU batches work
  # (previously these were always CPU tensors).
  mean = torch.tensor([0.485, 0.456, 0.406], device=imgs.device)[:, None, None]
  std = torch.tensor([0.229, 0.224, 0.225], device=imgs.device)[:, None, None]
  x = (imgs - mean) / std
  wanted = set(style_layers)
  grams = []
  for i, layer in enumerate(model[:max(style_layers) + 1]):
    x = layer(x)
    if i in wanted:
      h, w = x.shape[-2:]
      y = x.clone()  # workaround for pytorch in-place modification bug(?)
      grams.append(torch.einsum('bchw, bdhw -> bcd', y, y) / (h * w))
  return grams

def style_loss(grams_x, grams_y):
  """Sum, over paired layers, of the mean squared Gram-matrix difference.

  Entries are paired positionally; extra entries in the longer list are
  ignored. Returns 0.0 when there is nothing to compare.
  """
  per_layer = ((gx - gy).square().mean() for gx, gy in zip(grams_x, grams_y))
  return sum(per_layer, 0.0)

def to_nchw(img):
  """Reorder an HWC image or NHWC batch into NCHW layout.

  The input is first converted with `torch.as_tensor`; a 3-D input gains a
  leading batch axis of size 1. The result is a permuted view of that tensor.
  """
  t = torch.as_tensor(img)
  batched = t.unsqueeze(0) if t.dim() == 3 else t
  return batched.permute(0, 3, 1, 2)

def class_loss(imgs, class_idx):
    """Negative classifier logit for `class_idx`, one value per image.

    Args:
        imgs: NCHW RGB batch, values assumed in [0, 1] (e.g. the output of
            `to_rgb`). Resized to 224x224 before classification.
        class_idx: index of the output logit of `vgg16_model` to maximise.

    Returns:
        Tensor of shape (batch,): minus the raw logit, so minimising it pushes
        the images to score higher for `class_idx`.
    """
    resized = F.resize(imgs, size=(224, 224))
    # torchvision's pretrained VGG16 expects ImageNet-normalised input (the
    # same statistics `calc_styles` applies). Feeding raw [0, 1] pixels scores
    # the images under an input distribution the classifier never saw.
    mean = torch.tensor([0.485, 0.456, 0.406], device=resized.device)[:, None, None]
    std = torch.tensor([0.229, 0.224, 0.225], device=resized.device)[:, None, None]
    logits = vgg16_model((resized - mean) / std)
    return -logits[:, class_idx]

In [None]:
class CAModel(nn.Module):
    """Neural cellular automaton with a learned residual update rule.

    One step (`forward`):
      1. perceive: per-channel 3x3 convolution with fixed identity and
         x/y Sobel filters (each Sobel scaled by 1/8);
      2. update: a per-cell MLP (two 1x1 convolutions with a ReLU) maps the
         perception to a residual `dx`;
      3. stochastic update: `dx` is kept at a random fraction `fire_rate`
         of cell locations and zeroed elsewhere;
      4. life masking: cells whose 3x3 neighbourhood has no alpha (channel 3)
         value above 0.1, either before or after the update, are zeroed.

    Args:
        n_channels: number of state channels per cell (channel 3 is alpha).
        hidden_channels: width of the hidden layer of the update MLP.
        fire_rate: probability that a cell location applies its update.
        device: device for parameters and perception filters (default CPU).
    """

    def __init__(self, n_channels=16, hidden_channels=128, fire_rate=0.5, device=None):
        super().__init__()

        # Previously hard-coded to 0.5, silently ignoring the argument.
        self.fire_rate = fire_rate
        self.n_channels = n_channels
        self.device = device or torch.device("cpu")

        # Perceive step: identity plus x/y Sobel gradients.
        sobel_filter_ = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])
        scalar = 8.0

        sobel_filter_x = sobel_filter_ / scalar
        sobel_filter_y = sobel_filter_.t() / scalar
        identity_filter = torch.tensor(
            [
                [0, 0, 0],
                [0, 1, 0],
                [0, 0, 0],
            ],
            dtype=torch.float32,
        )
        filters = torch.stack(
            [identity_filter, sobel_filter_x, sobel_filter_y]
        )  # (3, 3, 3)
        filters = filters.repeat((n_channels, 1, 1))  # (3 * n_channels, 3, 3)
        # Registered as a buffer so module-level `.to()` / `.cuda()` / `.double()`
        # move it along with the parameters (a plain attribute would stay put).
        # Non-persistent, so state_dict keys are unchanged.
        self.register_buffer(
            "filters", filters[:, None, ...], persistent=False
        )  # (3 * n_channels, 1, 3, 3)

        # Update step
        self.update_module = nn.Sequential(
            nn.Conv2d(
                3 * n_channels,
                hidden_channels,
                kernel_size=1,  # (1, 1)
            ),
            nn.ReLU(),
            nn.Conv2d(
                hidden_channels,
                n_channels,
                kernel_size=1,
                bias=False,
            ),
        )

        # Zero-init the last layer so a fresh model produces dx == 0 ("do nothing").
        with torch.no_grad():
            self.update_module[2].weight.zero_()

        self.to(self.device)

    def perceive(self, x):
        """Apply [identity, sobel_x, sobel_y] to every channel; returns 3 * n_channels maps."""
        return nn.functional.conv2d(x, self.filters, padding=1, groups=self.n_channels)

    def update(self, x):
        """Map perception maps to a per-cell residual of n_channels values."""
        return self.update_module(x)

    @staticmethod
    def stochastic_update(x, fire_rate):
        """Zero `x` at cell locations not selected this step.

        One uniform draw per (batch, y, x) location; the location fires when
        the draw is <= `fire_rate`. The mask is broadcast over all channels.
        """
        # Sample directly on x's device instead of sampling on CPU and copying.
        mask = (torch.rand(x[:, :1, :, :].shape, device=x.device) <= fire_rate).to(torch.float32)
        return x * mask  # broadcasted over all channels

    @staticmethod
    def get_living_mask(x):
        """Boolean (N, 1, H, W) mask: True where any alpha in the 3x3 neighbourhood > 0.1."""
        return (
            nn.functional.max_pool2d(
                x[:, 3:4, :, :], kernel_size=3, stride=1, padding=1
            )
            > 0.1
        )

    def seed(self, n, sz=128):
        """Return n CPU float32 states of size sz x sz with a single seed cell.

        The centre cell has channels 3.. (alpha and hidden) set to 1; RGB and
        every other cell are 0. Callers move the result to the target device.
        """
        x = torch.zeros((n, self.n_channels, sz, sz), dtype=torch.float32)
        x[:, 3:, sz // 2, sz // 2] = 1
        return x

    def forward(self, x):
        """Run one CA step on an (N, n_channels, H, W) state and return the new state."""
        pre_life_mask = self.get_living_mask(x)

        y = self.perceive(x)
        dx = self.update(y)
        dx = self.stochastic_update(dx, fire_rate=self.fire_rate)

        x = x + dx

        # A cell survives only if it was alive both before and after the update.
        post_life_mask = self.get_living_mask(x)
        life_mask = (pre_life_mask & post_life_mask).to(torch.float32)

        return x * life_mask

In [None]:
def to_rgb(img_rgba):
    """Composite an RGBA state onto a white background.

    Takes an NCHW batch whose first four channels are RGBA. The colour is
    treated as already multiplied by alpha, so out = rgb + (1 - alpha), with
    alpha clamped to [0, 1] first and the output clamped to [0, 1].
    Returns an (N, 3, H, W) tensor.
    """
    colour = img_rgba[:, :3, ...]
    alpha = img_rgba[:, 3:4, ...].clamp(0, 1)
    background = 1.0 - alpha
    return (background + colour).clamp(0, 1)

In [None]:
# Training hyper-parameters.
LR = 2e-3
LR_MILESTONES = [200, 700, 800, 900]  # steps at which the LR is multiplied by LR_GAMMA
LR_GAMMA = 0.4
POOL_SIZE = 128
GRID_SIZE = 128
batch_size = 8

ca_model = CAModel(device=device).to(device)
optimizer = torch.optim.Adam(ca_model.parameters(), lr=LR)
lr_sched = torch.optim.lr_scheduler.MultiStepLR(optimizer, LR_MILESTONES, LR_GAMMA)
loss_log = []

# Sample pool: every entry starts as the single-cell seed; the training loop
# writes evolved states back into it.
with torch.no_grad():
  pool = ca_model.seed(n=POOL_SIZE, sz=GRID_SIZE).to(device)

In [None]:
plt.rcParams["figure.figsize"] = (15,5)

# Pool-based training: sample a batch of states, run the CA for a random
# number of steps, score the RGB result, update the model, and write the
# evolved states back into the pool.
for i in range(1000):
  with torch.no_grad():
    batch_idx = np.random.choice(len(pool), batch_size, replace=False)
    x = pool[batch_idx]
    if i%8 == 0:
      x[:1] = ca_model.seed(1)  # periodically re-inject a fresh seed
  # With the reentrant checkpoint implementation, gradients only reach the
  # parameters used inside checkpointed segments if the segment input itself
  # requires grad. Otherwise every segment except the last one contributes no
  # gradient to ca_model.
  x.requires_grad_(True)
  step_n = np.random.randint(32, 96)
  x = torch.utils.checkpoint.checkpoint_sequential([ca_model]*step_n, 16, x)
  imgs = to_rgb(x)
  # Penalise state values that leave [-1, 1].
  overflow_loss = (x-x.clamp(-1.0, 1.0)).abs().sum()
  loss = torch.mean(class_loss(imgs=imgs, class_idx=1))+overflow_loss
  with torch.no_grad():
    loss.backward()
    for p in ca_model.parameters():
      if p.grad is not None:
        p.grad /= (p.grad.norm()+1e-8)   # normalize gradients
    optimizer.step()
    optimizer.zero_grad()
    lr_sched.step()
    pool[batch_idx] = x                # update pool

    loss_log.append(loss.item())
    if i%5 == 0:
      clear_output(True)
      # NCHW -> NHWC for imshow (the previous [0, 3, 2, 1] order swapped H and W).
      imgs = to_rgb(x[:, :4]).permute(0, 2, 3, 1).cpu()
      fig, axarr = plt.subplots(1, 4)
      axarr[0].plot(loss_log[-50:], alpha=0.8)
      for ax, img in zip(axarr[1:], imgs):
        ax.imshow(img.detach().numpy())
      plt.show()
      # get_last_lr() reports the LR actually in use; get_lr() is deprecated
      # and re-applies gamma when called at a milestone step.
      print('\rstep_n:', len(loss_log),
        ' loss:', loss.item(),
        ' overflow loss: ', overflow_loss.item(),
        ' lr:', lr_sched.get_last_lr()[0], end='')