### Install CLIP

In [None]:
%%capture
# Install OpenAI's CLIP package directly from its GitHub repository.
# %%capture (must stay on the first line of the cell) hides pip's output.
!pip install git+https://github.com/openai/CLIP.git

### Check for GPU

In [None]:
# List the GPUs visible to this runtime (training below assumes one is present).
!nvidia-smi -L

## Download Code

In [None]:
%%capture
!rm -rf /content/project-omega/ *.py
!git clone https://github.com/Mainakdeb/project-omega.git
!cp /content/project-omega/*.py /content/

## Imports

In [None]:
from language_model import clip_encode_images, clip_encode_text, get_clip_loss
from nca import ca_model, to_rgb
from video_utils import create_inference_video, show_video, create_inference_gif

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms.functional as F
from torchvision import transforms
# Use the first GPU when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Make newly created float tensors default to CUDA, presumably so tensors built
# without an explicit device (e.g. inside the project modules) land on the GPU
# alongside the model. Only do this on a CUDA runtime: a CUDA default tensor
# type is unusable without CUDA and would crash the CPU fallback chosen above.
if device.type == "cuda":
  torch.set_default_tensor_type('torch.cuda.FloatTensor')

from PIL import Image
import numpy as np
from tqdm import tqdm_notebook, tnrange
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (10, 10)  # default (width, height) in inches for figures created from here on
import requests
import io
import os
from IPython.display import clear_output

# Tell ffmpeg-based tooling to use the `ffmpeg` executable found on PATH.
# NOTE(review): presumably consumed by the video helpers in video_utils, which is
# already imported earlier in this cell; if that module reads FFMPEG_BINARY at
# import time, this assignment comes too late and should move above that import.
os.environ['FFMPEG_BINARY'] = 'ffmpeg'
from google.colab import files

## Set params

In [None]:
# ---- CA architecture ----
total_channels = 12    # number of state channels per cell (passed as `chn`)
hidden_filters = 128   # hidden width of the update network (passed as `hidden_n`)

# ---- optimisation ----
text_prompt = "Leopard Skin"            # text the CLIP loss steers the texture towards
learning_rate = 2e-3
batch_size = 4                          # pool samples drawn per training step
lr_decay_checkpoints = [300, 500, 850]  # MultiStepLR milestones (counted in lr_sched.step() calls)
decay_factor = 0.4                      # LR multiplier applied at each milestone
train_iterations = 1000

# ---- persistent sample pool ----
pool_size = 256
image_size = 128                        # spatial size handed to ca.seed as `sz`

# Model, optimizer and step-wise learning-rate schedule.
ca = ca_model(chn=total_channels, hidden_n=hidden_filters).to(device)
opt = torch.optim.Adam(ca.parameters(), lr=learning_rate)
lr_sched = torch.optim.lr_scheduler.MultiStepLR(
    opt, milestones=lr_decay_checkpoints, gamma=decay_factor)

# One entry per training iteration, used for the loss plot.
loss_log = []

# Initial pool of `pool_size` seed states.
pool = ca.seed(n=pool_size, sz=image_size).to(device)

# Placeholder for experimental pool augmentations; currently an empty
# (identity) pipeline. Candidates that were tried: RandomHorizontalFlip(p=0.5),
# RandomVerticalFlip(p=0.5), RandomRotation(degrees=30), and albumentations'
# GridDistortion(.9).
transform = transforms.Compose([])

## Train

In [None]:
# checkpoint_sequential lives in a submodule that `import torch` alone is not
# guaranteed to load.
import torch.utils.checkpoint

# Train the CA so that CLIP matches its RGB output to `text_prompt`.
# Each iteration samples a batch from the persistent pool, runs the CA for a
# random number of steps, takes one optimizer step, and writes the evolved
# states back into the pool.
for i in range(train_iterations):
  with torch.no_grad():
    # pick `batch_size` distinct pool indices
    batch_idx = np.random.choice(len(pool), batch_size, replace=False)

    # advanced indexing returns a copy, so editing `x` does not touch `pool`
    x = pool[batch_idx]

    # periodically replace one batch element with a fresh seed
    if i % 8 == 0:
      x[:1] = ca.seed(1, image_size).to(device)

  # number of CA updates for this iteration, drawn from [64, 96)
  step_n = np.random.randint(64, 96)

  # 1 forward pass: run the CA `step_n` times with gradient checkpointing
  # (16 segments) to bound activation memory. The input must require grad:
  # with reentrant checkpointing a segment whose input does not require grad
  # yields an output that does not either, so only the final, un-checkpointed
  # segment would contribute gradients to the CA parameters.
  x.requires_grad_(True)
  x = torch.utils.checkpoint.checkpoint_sequential([ca] * step_n, 16, x)

  # RGB view of the state for the CLIP loss (the original note says to_rgb
  # uses the first 3 of the 12 channels)
  imgs = to_rgb(x)

  # penalize state values outside [-1, 1]
  overflow_loss = (x - x.clamp(-1.0, 1.0)).abs().sum()

  # 2 compute loss
  loss = get_clip_loss(text_prompt, imgs) + overflow_loss

  # 3 clear stale gradients, 4 backpropagate
  opt.zero_grad()
  loss.backward()

  with torch.no_grad():
    # normalize each parameter tensor's gradient to (roughly) unit L2 norm;
    # skip parameters that received no gradient instead of crashing on None
    for p in ca.parameters():
      if p.grad is not None:
        p.grad /= (p.grad.norm() + 1e-8)

    # 5 step against the gradient, then advance the LR schedule
    opt.step()
    lr_sched.step()

    # write the evolved states back into the pool; doing this under no_grad
    # keeps `pool` out of the autograd graph (otherwise every iteration would
    # chain another graph node onto `pool`)
    pool[batch_idx] = x

  # log loss for the plot
  loss_log.append(loss.item())

  # periodically visualize the current batch and the recent loss curve
  if i % 10 == 0:
    clear_output(True)
    with torch.no_grad():
      preview = to_rgb(x).cpu()
    # assumes to_rgb returns (batch, 3, H, W), as make_grid expects — TODO confirm
    img_grid = torchvision.utils.make_grid(preview, nrow=4).permute(1, 2, 0).numpy()
    f, axarr = plt.subplots(2, 1)
    axarr[0].imshow(img_grid)
    axarr[0].set_title('batch')
    axarr[0].axis('off')
    axarr[1].plot(loss_log[-20:], alpha=0.8)
    axarr[1].set_xlabel(f'step:{len(loss_log)} loss:{loss.item()} overflow loss: {overflow_loss.item()}')
    plt.show()

In [None]:
# Render the trained CA into a video file and display it inline.
# The render size follows `image_size` (previously a hard-coded 128, the same
# value as the default `image_size`), so changing the training resolution
# updates the preview too.
vid_path = create_inference_video(ca_model=ca,
                                  size=image_size,
                                  num_frames=300,
                                  steps_per_frame=5,
                                  filename="test_vid.mov")
show_video(vid_path)

In [None]:
# Render the trained CA as an animated GIF and download it to the local
# machine via google.colab's `files.download`.
# The render size follows `image_size` (previously a hard-coded 128, the same
# value as the default `image_size`).
gif_path = create_inference_gif(ca_model=ca,
                                size=image_size,
                                num_frames=120,
                                steps_per_frame=10,
                                fps=20,
                                filename="test_gif.gif")

files.download(gif_path)