In [None]:
# List the GPUs visible to this runtime; later cells call .cuda() and need one.
!nvidia-smi -L

In [None]:
! pip install git+https://github.com/openai/CLIP.git

In [None]:
!wget https://cdn2.thecatapi.com/images/c1_w1J682.jpg


In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms.functional as F

from PIL import Image
import numpy as np
from tqdm import tqdm_notebook, tnrange
import imageio

from base64 import b64encode
import matplotlib.pyplot as plt
import requests
import io
import os
from IPython.display import  HTML, clear_output
import matplotlib.pylab as pl

os.environ['FFMPEG_BINARY'] = 'ffmpeg'
import moviepy.editor as mvp
from moviepy.video.io.ffmpeg_writer import FFMPEG_VideoWriter

In [None]:
import clip
clip.available_models()

# Load the ViT-B/32 CLIP checkpoint, move it to the GPU and switch it to eval mode.
model, preprocess = clip.load("ViT-B/32")
model.cuda().eval()

# Architecture facts read straight from the loaded model.
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

total_params = sum(int(np.prod(p.shape)) for p in model.parameters())
print("Model parameters:", f"{total_params:,}")
for label, value in (
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(label, value)

In [None]:
def clip_encode_text(text_str):
    """Tokenize `text_str` with CLIP and return the text embedding(s).

    The tokens are moved to the GPU and passed through the global `model`'s
    text encoder; the encoder output is returned as-is (no normalization).
    """
    tokens = clip.tokenize(text_str).cuda()
    return model.encode_text(tokens)
# Smoke test: encode a throw-away prompt without tracking gradients.
with torch.no_grad():
    text_features=clip_encode_text("frogs hi")

def clip_encode_images(image_input):
    """Resize `image_input` to 224x224 and return its CLIP image embedding.

    `image_input` is handed to torchvision's functional `resize` and the result
    is passed through the global `model`'s image encoder. No other
    preprocessing is applied here.
    """
    target_size = (224, 224)
    return model.encode_image(F.resize(image_input, size=target_size))

# plt.imread returns an (H, W, C) array. permute(2, 0, 1) converts it to the
# (C, H, W) layout expected downstream; the previous permute(2, 1, 0) produced
# (C, W, H), i.e. a transposed picture.
image_input=torch.tensor(plt.imread("/content/c1_w1J682.jpg")).cuda().permute(2,0,1)
with torch.no_grad():
    # Add the batch dimension before encoding.
    image_features=clip_encode_images(image_input.unsqueeze(dim=0))

# Displayed as the cell output: shape of the batched input.
image_input.unsqueeze(dim=0).shape

In [None]:
def get_clip_loss(text_str, images_tensor, loss_func):
    """Return the negated mean CLIP similarity between a prompt and a batch of images.

    Args:
        text_str: prompt (anything `clip_encode_text` accepts).
        images_tensor: image batch; it is moved to the GPU before encoding.
        loss_func: unused. Kept only so existing callers keep working.

    Returns:
        Scalar tensor `-mean(image_features . text_features)`. Minimizing it
        pushes the images towards the prompt.

    The features are intentionally left un-normalized (raw dot product, not
    cosine similarity), matching the previous behaviour of this function.
    """
    image_features = clip_encode_images(images_tensor.cuda())
    text_features = clip_encode_text(text_str)

    # One dot product per (image, text) pair. The earlier version tiled the text
    # features and built a B x B matrix whose rows were all identical; its mean
    # is the same as the mean of this (B, n_text) matrix.
    similarity = image_features @ text_features.T
    return -torch.mean(similarity)

# plt.imread returns (H, W, C); permute(2, 0, 1) gives (C, H, W).
# The previous permute(2, 1, 0) produced (C, W, H), i.e. a transposed picture.
s=torch.tensor(plt.imread("/content/c1_w1J682.jpg")).permute(2,0,1)
print(s.unsqueeze(dim=0).shape)
get_clip_loss(text_str="image of a brown horse", images_tensor=s.unsqueeze(dim=0), loss_func = nn.MSELoss())


In [None]:
# 3x3 kernels used by the CA perception step. All three are float32 so they can
# be stacked and convolved together without relying on dtype promotion
# (the Sobel kernel used to be created from Python ints, i.e. int64).
sobel_filter = torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
identity_filter = torch.tensor([[0, 0, 0], [0, 1, 0], [0, 0, 0],],dtype=torch.float32)
laplacian_filter = torch.tensor([[1.0,2.0,1.0], [2.0,-12,2.0], [1.0,2.0,1.0]], dtype=torch.float32)

def perception(x):
  """Apply the identity, Sobel, transposed-Sobel and Laplacian kernels to every channel of `x`.

  `x` is unpacked as (batch, channels, height, width) by `perchannel_conv`;
  the result has four times as many channels as `x`.

  Each kernel is cast to `x`'s dtype and moved to `x`'s device first, so the
  convolution works no matter where the module-level kernels were created
  (they are built before the notebook switches its default tensor type to CUDA).
  """
  kernels = (identity_filter, sobel_filter, sobel_filter.T, laplacian_filter)
  filters = torch.stack([k.to(device=x.device, dtype=x.dtype) for k in kernels])
  return perchannel_conv(x, filters)

class CA(torch.nn.Module):
  """Neural cellular automaton: perception followed by two 1x1 convolutions.

  Args:
    chn: number of state channels per cell.
    hidden_n: width of the hidden 1x1-conv layer.
  """

  def __init__(self, chn=12, hidden_n=96):
    super().__init__()
    self.chn = chn
    # perception() yields 4 feature maps per state channel, hence chn*4 inputs.
    self.w1 = torch.nn.Conv2d(chn*4, hidden_n, 1)
    self.w2 = torch.nn.Conv2d(hidden_n, chn, 1, bias=False)
    # Zero-initialised output layer: a freshly built CA leaves the state unchanged.
    self.w2.weight.data.zero_()

  def forward(self, x, update_rate=0.5):
    """Run one CA step on state `x` and return the new state.

    Each cell receives its update with probability `update_rate`; the random
    mask has one value per cell and is shared across that cell's channels.
    """
    y = perception(x)
    y = self.w2(torch.relu(self.w1(y)))
    b, c, h, w = y.shape
    # Create the mask on the update's own device/dtype. Relying on the global
    # default tensor type breaks when it differs from where `x` lives.
    update_mask = (torch.rand(b, 1, h, w, device=y.device, dtype=y.dtype)+update_rate).floor()
    return x+y*update_mask

  def seed(self, n, sz=128):
    """Return `n` random-normal initial states of shape (n, chn, sz, sz).

    The states are created on the same device as the model parameters.
    """
    return torch.randn(n, self.chn, sz, sz, device=self.w1.weight.device)
    
# The CA instance that the optimizer, training loop and video cell below operate on.
ca = CA()

def to_rgb(x):
  """Take the first three channels of `x` (channel axis is third from last) and add 0.5."""
  rgb = x[..., :3, :, :]
  return rgb + 0.5

# Count the trainable scalars of a default-configured CA.
param_n = 0
for p in CA().parameters():
  param_n += p.numel()
print('CA param count:', param_n)

def perchannel_conv(x, filters):
  """Convolve every channel of `x` with every filter, using wrap-around borders.

  x:       [batch, channels, h, w]
  filters: [filter_n, h, w]; the one-pixel circular padding keeps the spatial
           size unchanged for 3x3 filters.
  Returns a tensor reshaped to [batch, -1, h, w], with the filter responses of
  each input channel stored next to each other.
  """
  batch, channels, height, width = x.shape
  # Fold channels into the batch axis so each one is convolved independently.
  planes = x.reshape(batch*channels, 1, height, width)
  wrapped = torch.nn.functional.pad(planes, [1, 1, 1, 1], 'circular')
  kernels = filters.unsqueeze(1)
  responses = torch.nn.functional.conv2d(wrapped, kernels)
  return responses.reshape(batch, -1, height, width)

In [48]:
def imread(url, max_size=None, mode=None):
  """Load an image from a URL or a local path as a float32 array scaled to [0, 1].

  Args:
    url: 'http(s)://...' address or anything `PIL.Image.open` accepts.
    max_size: if given, shrink the image in place so neither side exceeds it
      (aspect ratio is kept).
    mode: if given, convert to this PIL mode (e.g. 'RGB') before the array conversion.

  Raises:
    requests.HTTPError: when the download returns an error status.
  """
  if url.startswith(('http:', 'https:')):
    # wikimedia requires a user agent
    headers = {
      "User-Agent": "Requests in Colab/0.0 (https://colab.research.google.com/; no-reply@google.com) requests/0.0"
    }
    r = requests.get(url, headers=headers)
    # Fail on 4xx/5xx here instead of letting PIL choke on an HTML error page.
    r.raise_for_status()
    f = io.BytesIO(r.content)
  else:
    f = url
  img = Image.open(f)
  if max_size is not None:
    # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
    img.thumbnail((max_size, max_size), Image.LANCZOS)
  if mode is not None:
    img = img.convert(mode)
  img = np.float32(img)/255.0
  return img


def imshow(a, fmt='jpeg'):
  """Display array `a` inline as an image, encoded via `imencode` in format `fmt`."""
  # Local import: the module-level name `Image` is PIL's, not IPython's.
  from IPython.display import Image as IPyImage, display
  # Wrap the encoded bytes in an IPython Image; display() on raw bytes only
  # prints their repr.
  display(IPyImage(data=imencode(a, fmt)))

def imencode(a, fmt='jpeg'):
  """Encode array `a` into image bytes in format `fmt`.

  A 3-D array whose last axis has length 4 is always written as PNG,
  whatever `fmt` says.
  """
  arr = np.asarray(a)
  four_channel = len(arr.shape) == 3 and arr.shape[-1] == 4
  buf = io.BytesIO()
  imwrite(buf, arr, 'png' if four_channel else fmt)
  return buf.getvalue()

def np2pil(a):
  """Convert array `a` to a PIL image.

  float32/float64 arrays are clipped to [0, 1] and scaled to uint8 first;
  arrays of any other dtype are passed to PIL unchanged.
  """
  is_float = a.dtype in [np.float32, np.float64]
  pixels = np.uint8(np.clip(a, 0, 1)*255) if is_float else a
  return Image.fromarray(pixels)

def imwrite(f, a, fmt=None):
  """Write array `a` as an image to `f`.

  Args:
    f: a file path (str) or an open binary file-like object.
    a: image array; converted with `np2pil`.
    fmt: image format. Ignored when `f` is a path: the format is then taken
      from the file extension ('jpg' is mapped to 'jpeg').
  """
  a = np.asarray(a)
  if isinstance(f, str):
    fmt = f.rsplit('.', 1)[-1].lower()
    if fmt == 'jpg':
      fmt = 'jpeg'
    # We opened the file, so we close it too (it used to be leaked).
    with open(f, 'wb') as out:
      np2pil(a).save(out, fmt, quality=95)
  else:
    # Caller-owned stream: write into it but leave it open.
    np2pil(a).save(f, fmt, quality=95)


In [49]:
# style_url = 'https://upload.wikimedia.org/wikipedia/commons/thumb/0/04/Tempera%2C_charcoal_and_gouache_mountain_painting_by_Nicholas_Roerich.jpg/301px-Tempera%2C_charcoal_and_gouache_mountain_painting_by_Nicholas_Roerich.jpg'
# style_img = imread(style_url, max_size=128)
# with torch.no_grad():
#   target_style = calc_styles(to_nchw(style_img))
# # imshow(style_img)

# plt.imshow(style_img)

In [None]:
# Move everything the training loop touches onto the GPU *before* the optimizer
# and the sample pool are created. Previously the default tensor type was
# switched to CUDA only afterwards, so on a fresh kernel `ca` and `pool` stayed
# on the CPU while tensors created later landed on the GPU.
torch.set_default_tensor_type('torch.cuda.FloatTensor')
ca.cuda()

opt = torch.optim.Adam(ca.parameters(), 4e-3)
# Multiply the learning rate by 0.4 at each of these iteration counts.
lr_sched = torch.optim.lr_scheduler.MultiStepLR(opt, [200,700,800,900], 0.4)
loss_log = []
with torch.no_grad():
  # Pool of 256 CA states that the training loop samples from and writes back to.
  pool = ca.seed(256).cuda()

batch_size=2
# Debugging aid: makes autograd report the forward op behind a NaN/inf gradient,
# at a noticeable speed cost.
torch.autograd.set_detect_anomaly(True)

In [None]:
# Train the CA so that its RGB channels match the text prompt under CLIP.
for i in range(1000):
  with torch.no_grad():
    # Sample a batch of states from the pool; every 8th iteration the first
    # one is replaced by a fresh random seed.
    batch_idx = np.random.choice(len(pool), batch_size, replace=False)

    x = pool[batch_idx]
    if i%8 == 0:
      x[:1] = ca.seed(1)
  # Unroll the CA for a random number of steps, checkpointed in 16 segments.
  step_n = np.random.randint(64, 96)
  x = torch.utils.checkpoint.checkpoint_sequential([ca]*step_n, 16, x)
  imgs = to_rgb(x)

  # Penalise state values outside [-1, 1].
  overflow_loss = (x-x.clamp(-1.0, 1.0)).abs().sum()
  loss = get_clip_loss("Fire and Water", imgs, nn.MSELoss()) + overflow_loss
  with torch.no_grad():
    loss.backward()
    for p in ca.parameters():
      p.grad /= (p.grad.norm()+1e-8)   # normalize gradients
    opt.step()
    opt.zero_grad()
    lr_sched.step()
    pool[batch_idx] = x                # update pool

    loss_log.append(loss.item())
    if i%10==0:
      clear_output(True)
      pl.plot(loss_log, alpha=0.8)
      # The loss is (-similarity + overflow) and can go <= 0, which a plain
      # log axis cannot draw; symlog handles both signs.
      pl.yscale('symlog')
      pl.show()
      # (C, H, W) -> (H, W, C) for matplotlib. The previous permute([0, 3, 2, 1])
      # swapped H and W, showing a transposed version of what CLIP scores.
      preview = to_rgb(x)[0].permute(1, 2, 0).clamp(0.0, 1.0)
      plt.imshow(preview.cpu().detach().numpy())
      plt.show()
      # get_last_lr() is the supported way to read the current LR; get_lr()
      # is deprecated for this and can be off at milestone steps.
      print('\rstep_n:', len(loss_log),
        ' loss:', loss.item(),
        ' overflow loss: ', overflow_loss.item(),
        ' lr:', lr_sched.get_last_lr()[0], end='')

In [None]:
# NOTE(review): `class_loss` is not defined in any cell visible in this notebook, so this
# cell raises NameError on a fresh "Run All" unless it is defined elsewhere - confirm or remove.
# Random GPU batch of two 3x512x512 tensors used as a dummy input (rebinds the global `imgs`).
imgs=torch.randn(2,3,512,512).cuda()
class_loss(imgs, 5) 

In [59]:
class VideoWriter:
  """Context manager that appends frames to a video file via moviepy's FFMPEG writer.

  Args:
    filename: output path. When its base name is '_autoplay.mp4' the video is
      displayed inline as soon as the `with` block exits.
    fps: frames per second.
    **kw: passed on to `FFMPEG_VideoWriter`.
  """

  def __init__(self, filename='./_autoplay.mp4', fps=60.0, **kw):
    # The ffmpeg writer is created lazily, once the first frame's size is known.
    self.writer = None
    self.params = dict(filename=filename, fps=fps, **kw)

  def add(self, img):
    """Append one frame. The first frame fixes the video size."""
    img = np.asarray(img)
    if self.writer is None:
      h, w = img.shape[:2]
      self.writer = FFMPEG_VideoWriter(size=(w, h), **self.params)
    # Float frames are clipped to [0, 1] and converted to uint8.
    if img.dtype in [np.float32, np.float64]:
      img = np.uint8(img.clip(0, 1)*255)
    # 2-D frames are repeated along a new last axis to get 3 channels.
    if len(img.shape) == 2:
      img = np.repeat(img[..., None], 3, -1)
    self.writer.write_frame(img)

  def close(self):
    """Close the underlying writer, if one was ever created."""
    if self.writer:
      self.writer.close()

  def __enter__(self):
    return self

  def __exit__(self, *kw):
    self.close()
    # Compare the base name only: the default filename is './_autoplay.mp4',
    # which never equalled the bare '_autoplay.mp4' this check used to test
    # for, so the auto-display never triggered.
    if os.path.basename(self.params['filename']) == '_autoplay.mp4':
      self.show()

  def show(self, **kw):
      """Close the writer and display the written file inline; `kw` goes to moviepy's `ipython_display`."""
      self.close()
      fn = self.params['filename']
      display(mvp.ipython_display(fn, **kw))

def zoom(img, scale=4):
  """Enlarge `img` by repeating every element `scale` times along axes 0 and 1."""
  for axis in (0, 1):
    img = np.repeat(img, scale, axis)
  return img

In [None]:
# Roll out the trained CA from one random 256x256 seed and record every step as a frame.
steps_per_frame = 1
with VideoWriter() as vid, torch.no_grad():
  x = ca.seed(1, 256)
  for k in tnrange(600, leave=False):
    for _ in range(steps_per_frame):
      x[:] = ca(x)
    # (C, H, W) -> (H, W, C) on the CPU, as the writer expects image-shaped frames.
    frame = to_rgb(x[0]).permute(1, 2, 0).cpu()
    #vid.add(zoom(frame, 16))
    vid.add(frame)

In [None]:
# Colab-only: trigger a browser download of the video written by the cell above.
# NOTE(review): `google.colab` is presumably unavailable outside Colab - confirm before running elsewhere.
from google.colab import files
files.download('_autoplay.mp4')