In [None]:
%%capture
!pip install git+https://github.com/openai/CLIP.git
!wget https://cdn2.thecatapi.com/images/c1_w1J682.jpg

In [None]:
!nvidia-smi -L

In [None]:
%%capture
!rm -rf /content/project-omega/ *.py
!git clone https://github.com/Mainakdeb/project-omega.git
!cp /content/project-omega/*.py /content/

In [None]:
from language_model import clip_encode_images, clip_encode_text, get_clip_loss
from nca import ca_model, to_rgb

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms.functional as F
from torchvision import transforms
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from PIL import Image
import numpy as np
# tnrange drives the progress bars in the rendering cells.
# NOTE(review): tqdm_notebook is not used by the cells shown here.
from tqdm import tqdm_notebook, tnrange
import imageio

from base64 import b64encode
import matplotlib.pyplot as plt
import requests
import io
import os
from IPython.display import  HTML, clear_output

# Set before importing moviepy. Presumably this makes moviepy use the
# `ffmpeg` binary on PATH; confirm against the installed moviepy version.
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter
# NOTE(review): duplicate of the `import imageio` above; harmless but redundant.
import imageio
# Colab-only helper used to download the rendered video / GIF.
from google.colab import files

In [None]:
# Model, optimizer, LR schedule and sample pool for pool-based NCA training.
ca = ca_model(chn=12, hidden_n=128).to(device)

opt = torch.optim.Adam(ca.parameters(), 2e-3)
# Multiply the learning rate by 0.4 at each milestone (200, 700, 800, 900).
# The training loop steps the scheduler once per iteration.
lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [200, 700, 800, 900], 0.4)
loss_log = []
with torch.no_grad():
  # Pool of seed states that training batches are sampled from and written
  # back to (presumably n=256 states of spatial size sz=128).
  pool = ca.seed(n=256, sz=128).to(device)

# Make tensors created without an explicit device land on the GPU.
# Only do this when CUDA is actually in use: called unconditionally, it breaks
# the CPU fallback chosen when `device` was defined.
if device.type == 'cuda':
  torch.set_default_tensor_type('torch.cuda.FloatTensor')
batch_size = 4

# Optional pool augmentation. It is currently empty, i.e. an identity transform.
transform = transforms.Compose([
    # transforms.RandomHorizontalFlip(p=0.5),
    # transforms.RandomVerticalFlip(p=0.5),
    # transforms.RandomRotation(degrees=30),
])

In [None]:
import matplotlib.pyplot as plt
import torch.utils.checkpoint  # make sure the checkpoint submodule is loaded
plt.rcParams["figure.figsize"] = (10, 10)

# Text prompt that the CLIP loss pulls the generated images towards.
clip_prompt = "underwater bioluminescence"

for i in range(1000):
  with torch.no_grad():
    # Sample a batch of states from the pool (no gradient through sampling).
    batch_idx = np.random.choice(len(pool), batch_size, replace=False)

    x = pool[batch_idx]
    if i%8 == 0:
      # Periodically replace one sample with a fresh seed, so the CA keeps
      # learning to grow from the initial state and not only to persist.
      x[:1] = ca.seed(1).to(device)
  step_n = np.random.randint(32, 96)
  # With the (default) reentrant checkpoint implementation, a checkpointed
  # segment back-propagates into the model parameters only if one of its
  # *inputs* requires grad. `x` was built under no_grad, so without this,
  # every segment except the last would contribute no gradient.
  x.requires_grad_(True)
  x = torch.utils.checkpoint.checkpoint_sequential([ca]*step_n, 16, x)
  imgs = to_rgb(x)

  # Penalize state values outside [-1, 1].
  overflow_loss = (x-x.clamp(-1.0, 1.0)).abs().sum()
  loss = get_clip_loss(clip_prompt, imgs) + overflow_loss

  with torch.no_grad():
    loss.backward()
    for p in ca.parameters():
      # Normalize each parameter's gradient; skip parameters that got none.
      if p.grad is not None:
        p.grad /= (p.grad.norm()+1e-8)
    opt.step()
    opt.zero_grad()
    lr_sched.step()
    pool[batch_idx] = x                # write the evolved states back to the pool

    loss_log.append(loss.item())
    if i%10==0:
      clear_output(True)
      # Assumes to_rgb returns (B, 3, H, W), as the rendering cells'
      # to_rgb(x[0]).permute(1, 2, 0) implies. That layout is exactly what
      # make_grid wants; only the final grid is moved to channels-last, so H
      # and W are no longer swapped in the preview. Clamp to [0, 1] instead of
      # relying on imshow's clipping.
      preview = to_rgb(x).detach().cpu().clamp(0, 1)
      grid = torchvision.utils.make_grid(preview, nrow=4).permute(1, 2, 0).numpy()
      f, axarr = plt.subplots(2,1)
      axarr[0].imshow(grid)
      axarr[1].plot(loss_log[-50:], alpha=0.8)
      plt.show()
      # get_last_lr() reports the rate actually in use. get_lr() is a
      # scheduler-internal method and can report a re-decayed value at milestones.
      print('\rstep_n:', len(loss_log),
        ' loss:', loss.item(),
        ' overflow loss: ', overflow_loss.item(),
        ' lr:', lr_sched.get_last_lr()[0], end='')

In [None]:
class VideoWriter:
  """Context manager that streams frames into a video file via moviepy's ffmpeg writer.

  The ffmpeg writer is created lazily from the first frame's size. float32 and
  float64 frames are clipped to [0, 1] and scaled to uint8; other dtypes pass
  through unchanged. 2-D (grayscale) frames are repeated to 3 channels. On exit
  the file is closed and, if it is the default ``_autoplay.mov``, shown inline.

  Args:
    filename: output path. Leaving the default enables auto-display on exit.
    fps: frames per second.
    **kw: extra keyword arguments forwarded to ``FFMPEG_VideoWriter``.
  """

  def __init__(self, filename='./_autoplay.mov', fps=60.0, **kw):
    self.writer = None
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    """Append one frame shaped (H, W) or (H, W, 3)."""
    img = np.asarray(img)
    if self.writer is None:
      # The first frame fixes the video size (ffmpeg takes (width, height)).
      h, w = img.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
    if img.dtype in [np.float32, np.float64]:
      img = np.uint8(img.clip(0, 1)*255)
    if len(img.shape) == 2:
      img = np.repeat(img[..., None], 3, -1)
    self.writer.write_frame(img)

  def close(self):
    """Finalize the file if a writer was ever created."""
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    # Compare normalized paths. As a raw string, the default './_autoplay.mov'
    # never equalled '_autoplay.mov', so auto-display was dead code.
    if os.path.normpath(self.params['filename']) == '_autoplay.mov':
      self.show()

  def show(self, **kw):
    """Close the file and display it inline with moviepy's ipython_display."""
    self.close()
    fn = self.params['filename']
    display(mvp.ipython_display(fn, **kw))

def zoom(img, scale=4):
  """Nearest-neighbour upscale: repeat each element `scale` times along axes 0 and 1."""
  taller = np.repeat(img, scale, 0)
  return np.repeat(taller, scale, 1)

In [None]:
# Render a single-seed rollout to video and download it.
vid_infer_size=256
video_frames = 600     # number of frames written to the video
steps_per_frame = 5    # CA updates between consecutive frames
with VideoWriter() as vid, torch.no_grad():
  # Explicit .to(device) so this does not depend on the default tensor type.
  x = ca.seed(1, vid_infer_size).to(device)
  for k in tnrange(video_frames, leave=False):
    for _ in range(steps_per_frame):
      x[:] = ca(x)
    # (3, H, W) -> (H, W, 3), then clip to [0, 1] and scale to uint8.
    img = to_rgb(x[0]).permute(1, 2, 0).cpu()
    img = np.uint8(img.clip(0, 1)*255)
    vid.add(img)

from google.colab import files
# Download the path the writer actually used, instead of re-typing its default.
files.download(os.path.normpath(vid.params['filename']))

In [None]:
# Roll the CA out once more and keep every frame in memory for the GIF.
infer_sz=256
gif_frames = 600       # number of frames collected
steps_per_frame = 10   # CA updates between consecutive frames
# Store clipped uint8 frames (~118 MB) rather than raw float64 (~944 MB).
# Clipping here also matters for correctness: imageio converts float frames to
# uint8 by min-max rescaling each frame whose values fall outside [0, 1], which
# shifts colours from frame to frame.
gif_arr=np.zeros((gif_frames, infer_sz, infer_sz, 3), dtype=np.uint8)
with torch.no_grad():
  # Explicit .to(device) so this does not depend on the default tensor type.
  x = ca.seed(1, infer_sz).to(device)
  for k in tnrange(gif_frames, leave=False):
    for _ in range(steps_per_frame):
      x[:] = ca(x)
    # (3, H, W) -> (H, W, 3), then clip to [0, 1] and scale to uint8.
    img = to_rgb(x[0]).permute(1, 2, 0).cpu()
    gif_arr[k] = np.uint8(img.clip(0, 1).numpy()*255)
    # plt.ims

In [None]:
# Encode the collected frames as an animated GIF and download it.
# NOTE: GIF stores frame delays in 1/100 s, so the encoder rounds 60 fps to
# the nearest representable delay. GIF is also limited to a 256-colour palette
# per frame, so despite the file name the output is not strictly lossless.
with imageio.get_writer('lossless.gif', mode='I', fps=60) as writer:
    for out in gif_arr:
        writer.append_data(out)
# Leaving the `with` block already closed the writer; no explicit close() needed.
files.download('lossless.gif')