In [None]:
%%capture
!pip install git+https://github.com/openai/CLIP.git

In [None]:
# List the GPUs visible to this runtime (quick sanity check before training).
!nvidia-smi -L

In [None]:
%%capture
# Start clean: remove any previous checkout and every *.py file in the current
# working directory. NOTE: the `*.py` glob deletes ALL Python files there, not
# only the ones copied in by this cell.
!rm -rf /content/project-omega/ *.py
# Fetch the project sources and copy its top-level modules into /content
# (presumably the notebook's working directory, so that the project imports in
# the next cell resolve -- TODO confirm).
!git clone https://github.com/Mainakdeb/project-omega.git
!cp /content/project-omega/*.py /content/

In [None]:
from language_model import clip_encode_images, clip_encode_text, get_clip_loss
from nca import ca_model, to_rgb
from video_utils import create_inference_video, show_video, create_inference_gif

In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms.functional as F
from torchvision import transforms
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

from PIL import Image
import numpy as np
from tqdm import tqdm_notebook, tnrange

import matplotlib.pyplot as plt
import requests
import io
import os
from IPython.display import clear_output

# Export FFMPEG_BINARY as the bare command name `ffmpeg`.
# Presumably read by the video/GIF helpers' backend -- TODO confirm which
# library consumes this variable.
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
from google.colab import files

In [None]:
# Training setup: model, optimizer, LR schedule, state pool and batch size.

# Cellular-automaton model from the project's `nca` module, moved to `device`.
ca = ca_model(chn=12, hidden_n=128).to(device)

# Adam with lr=2e-3; the scheduler multiplies the LR by 0.4 each time its step
# counter reaches one of the listed milestones.
opt = torch.optim.Adam(ca.parameters(), 2e-3)
lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [200,700,800,900], 0.4)
loss_log = []  # per-iteration loss values, appended to by the training loop

# Pool of 256 states produced by `ca.seed`; the training loop samples batches
# from it and writes the evolved states back.
with torch.no_grad():
  pool = ca.seed(n=256, sz=128).to(device)

# Make newly created tensors default to CUDA floats -- but only when a GPU is
# actually in use. `device` falls back to CPU, and forcing the CUDA tensor type
# on a CPU-only runtime would break the notebook.
if device.type == "cuda":
  torch.set_default_tensor_type('torch.cuda.FloatTensor')
batch_size=4

# Augmentation pipeline for the pool. It is currently empty, so applying it
# returns its input unchanged; the disabled candidates are kept for reference.
transform = transforms.Compose([# transforms.RandomHorizontalFlip(p=0.5),
                                # transforms.RandomVerticalFlip(p=0.5),
                                # transforms.RandomRotation(degrees=30),
                                # A.GridDistortion(.9)
                                ])

In [None]:
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (10, 10)

# Training loop: sample states from the pool, run the CA for a random number of
# steps, score the rendered images with the project's CLIP loss plus an
# overflow penalty, update the CA, and write the evolved states back.
PROMPT = "zebra skin"     # text prompt passed to get_clip_loss
NUM_TRAIN_STEPS = 500     # number of optimisation iterations
LOG_EVERY = 10            # refresh preview / status line every N iterations

for i in range(NUM_TRAIN_STEPS):
  with torch.no_grad():
    # Draw `batch_size` distinct pool entries (index array -> x is a copy).
    batch_idx = np.random.choice(len(pool), batch_size, replace=False)

    x = pool[batch_idx]
    if i%8 == 0:
      # Every 8th iteration, replace the first sample with a fresh seed state.
      x[:1] = ca.seed(1).to(device)

  # Apply the CA a random number of times (64..95), using checkpointing in 16
  # segments so intermediate activations are recomputed rather than stored.
  step_n = np.random.randint(64, 96)
  x = torch.utils.checkpoint.checkpoint_sequential([ca]*step_n, 16, x)
  imgs = to_rgb(x)

  # Penalise state values that leave the [-1, 1] range.
  overflow_loss = (x-x.clamp(-1.0, 1.0)).abs().sum()
  loss = get_clip_loss(PROMPT, imgs) + overflow_loss

  with torch.no_grad():
    # The graph was built above, so backward() still works inside no_grad.
    loss.backward()
    for p in ca.parameters():
      if p.grad is None:
        # Parameter did not take part in this step's graph; nothing to scale.
        continue
      p.grad /= (p.grad.norm()+1e-8)   # normalize gradients
    opt.step()
    opt.zero_grad()
    lr_sched.step()
    pool[batch_idx] = x                # update pool

    loss_log.append(loss.item())
    if i%LOG_EVERY == 0:
      clear_output(True)
      # Single permutation equivalent to the original permute([0, 3, 2, 1])
      # followed by permute(0, 3, 1, 2): the last two axes are swapped.
      preview = to_rgb(x).permute(0, 1, 3, 2).cpu()
      grid = torchvision.utils.make_grid(preview, nrow=4).cpu().detach()
      f, axarr = plt.subplots(2,1)
      axarr[0].imshow(grid.permute(1,2,0).numpy())
      axarr[1].plot(loss_log[-50:], alpha=0.8)
      plt.show()
      plt.close(f)                     # do not accumulate open figures
      # get_last_lr() is the supported way to read the scheduler's current LR;
      # get_lr() is only meant to be called from inside step().
      print('\rstep_n:', len(loss_log),
        ' loss:', loss.item(), 
        ' overflow loss: ', overflow_loss.item(),
        ' lr:', lr_sched.get_last_lr()[0], end='')
    
    # if i%5 == 0:
    #   pool=transform(pool)

In [None]:
# Render a video with the project's create_inference_video helper using the
# settings below, then pass the returned path to show_video so the result is
# the cell's output.
video_settings = dict(
  ca_model=ca,
  size=64,
  num_frames=300,
  steps_per_frame=20,
  filename="test_vid.mov",
)
vid_path = create_inference_video(**video_settings)
show_video(vid_path)

In [None]:
# Build a GIF with the project's create_inference_gif helper using the settings
# below, then hand the returned path to Colab's file-download helper.
gif_settings = dict(
  ca_model=ca,
  size=64,
  num_frames=120,
  steps_per_frame=10,
  fps=20,
  filename="test_gif.gif",
)
gif_path = create_inference_gif(**gif_settings)

files.download(gif_path)