# CLIP is all you need (c) [crumb](https://twitter.com/aicrumb)
Greatly inspired by [this tweet](https://twitter.com/aicrumb/status/1448351059957764096/photo/1) and all the CLIP guided approaches.

What if we directly optimize the raw image tensor using CLIP instead of tuning a generator network or its inputs? 
Just like style transfer algos were doing 5 years ago :D

by [sxela](https://github.com/Sxela)

this notebook's repo: [github](https://github.com/Sxela/CLIPguidedRGB)

tip me: [paypal](http://paypal.me/sx3la)


In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson, Alexander Spirin

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

# Vanilla CLIP guided RGB
Slow but fancy

In [None]:
#installation. run once
#clone the OpenAI CLIP repository into ./CLIP, then pip-install it in editable mode (-e);
#-qq keeps pip's output quiet
!git clone https://github.com/openai/CLIP
!pip install -e ./CLIP -qq

In [None]:
#imports. run once or after restart
%cd CLIP
import torch
from torchvision.transforms import *
import clip
import PIL
import torch.nn.functional as F
from tqdm.notebook import trange
#device string reused by the cells below; this notebook targets CUDA
device='cuda'
#load CLIP ViT-B/32 without JIT, keep the first element of what clip.load returns (the model),
#put it in eval mode, turn off gradients for its weights and move it to `device`
model = clip.load('ViT-B/32',jit=False)[0].eval().requires_grad_(False).to(device)

In [None]:
#define functions. run once or after restart
def get_sizes(sz, min_sz=32):
  """Return the ascending list of sizes produced by repeatedly halving `sz`.

  Starting at `sz`, the size is halved for as long as it is even and still
  larger than `min_sz`; every intermediate size is kept.

  Note that the check happens *before* halving, so the smallest entry can end
  up below `min_sz` (e.g. get_sizes(100, 32) -> [25, 50, 100]).

  Args:
    sz: the largest (final) size.
    min_sz: halving stops once the current size is <= this value.

  Returns:
    A sorted list that always contains `sz` itself.
  """
  szs = [sz]
  # stop on the first size that is small enough or can no longer be halved exactly
  while sz > min_sz and sz % 2 == 0:
    sz //= 2
    szs.append(sz)
  return sorted(szs)

def make_crop(img, ratio, max_cut=224, min_cut=0.2):
  """Cut one randomly placed square patch out of `img` and run it through `f`.

  The patch side is `ratio` times the shorter of the two trailing dimensions,
  raised to at least `min_cut` and capped by both `max_cut` and that shorter
  dimension.

  Args:
    img: 4-D tensor; the patch is sliced from dims 2 and 3.
    ratio: fraction of the shorter side to use as the patch side.
    max_cut: upper bound for the patch side.
    min_cut: lower bound for the patch side; a value below 1 is treated as a
      fraction of the shorter side.

  Returns:
    Whatever the module-level transform `f` returns for the patch.
  """
  dim2, dim3 = img.shape[2:]
  short_side = min(dim2, dim3)
  if min_cut < 1:
    # fractional bound -> convert to an absolute size
    min_cut = int(short_side * min_cut)

  side = max(ratio * short_side, min_cut)
  side = int(min(side, max_cut, short_side))

  # the offset along dim 2 is drawn first, then the one along dim 3
  start2 = int(torch.rand(1) * (dim2 - side))
  start3 = int(torch.rand(1) * (dim3 - side))

  patch = img[:, :, start2:start2 + side, start3:start3 + side]
  return f(patch)

def get_crops(img, ratios, max_cut, min_cut):
  """Make one crop of `img` per entry of `ratios` and concatenate them.

  Each entry of `ratios` is converted to a Python number with `.item()` and
  passed to make_crop together with `max_cut` and `min_cut`; the results are
  joined with torch.cat (along the first dimension).
  """
  patches = []
  for r in ratios:
    patches.append(make_crop(img, r.item(), max_cut, min_cut))
  return torch.cat(patches)

def show_img(t):
    """Show the first image of the batch `t` in the notebook output.

    The tensor is moved to channels-last layout, scaled with x*127.5+128,
    clamped to 0..255 and converted to uint8 before being handed to PIL.
    """
    pixels = t.permute(0, 2, 3, 1) * 127.5 + 128
    pixels = pixels.clamp(0, 255).to(torch.uint8)
    first = pixels[0].cpu().numpy()
    display(PIL.Image.fromarray(first, 'RGB'))

def range_loss(input):
    """Per-sample mean squared distance of `input` from the interval [-1, 1].

    Values inside [-1, 1] contribute nothing; the mean is taken over
    dims 1, 2 and 3, leaving one value per batch element.
    """
    #taken from this colab https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA#scrollTo=YHOj78Yvx8jP
    overshoot = input - torch.clamp(input, min=-1, max=1)
    return torch.square(overshoot).mean(dim=[1, 2, 3])

def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The input is padded by one replicated row/column at the bottom/right so
    every pixel has a right and a lower neighbour; the squared differences to
    those neighbours are summed and averaged over dims 1, 2 and 3.
    """
    #taken from this colab https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA#scrollTo=YHOj78Yvx8jP
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    base = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - base
    dy = padded[..., 1:, :-1] - base
    return (dx.pow(2) + dy.pow(2)).mean(dim=[1, 2, 3])

def fit(t, size, steps=1000, ncut=8, max_sz=224, min_sz=32, use_weighted_ratios=True):
  """Optimise an image tensor towards the CLIP text embedding at one resolution.

  `t` is resized to `size` x `size` (bicubic), turned into a fresh leaf tensor
  and optimised with Adam. Every step builds one full-ratio crop plus `ncut`
  random-ratio crops, encodes them with CLIP and sums, per crop, the weighted
  sqrt(MSE) between the normalised crop embedding and the prompt embedding
  `y`, plus the range / total-variation penalties.

  Reads the module-level names `model`, `criterion`, `y`, `lr`, `device`,
  `range_loss_w`, `tv_loss_w` and `loss_lerp`; they must be defined before
  calling.

  Args:
    t: image tensor to start from; assumed to be 4-D (it is passed to
      F.interpolate with a 2-D target size).
    size: side length in pixels to optimise at.
    steps: number of optimiser steps.
    ncut: number of random crops per step (a ratio-1 crop is always added).
    max_sz: forwarded to get_crops as the maximum crop size.
    min_sz: forwarded to get_crops as the minimum crop size.
    use_weighted_ratios: weight each crop's CLIP loss by ratio / sum(ratios)
      when True, by 1 otherwise.

  Returns:
    The optimised image tensor.
  """
  z2 = F.interpolate(t, (size,size), mode='bicubic')
  t = z2.detach().clone().requires_grad_(True)
  show_img(t)
  opt=torch.optim.Adam([t],lr=lr)
  # The running average has to live outside the loop: resetting it every step
  # makes the loss_lerp smoothing a no-op.
  loss_avg = None
  for i in trange(steps):
    opt.zero_grad()

    # first ratio is fixed at 1, the others are drawn one by one on the CPU
    ratios = [torch.ones(1)]
    for _ in range(ncut):
      ratios.append(torch.rand(1))
    ratios = torch.cat(ratios).to(device)
    crops = get_crops(t, ratios, max_sz, min_sz)
    weighted_ratios = ratios/ratios.sum() if use_weighted_ratios else torch.ones_like(ratios)

    embeds = model.encode_image(crops)

    loss = 0.
    for embed, ratio in zip(embeds, weighted_ratios):
      x = F.normalize(embed, dim=-1)
      loss+=torch.sqrt(criterion(x, y))*ratio

    # NOTE(review): the penalties are computed on the crops returned by
    # get_crops (i.e. after the `f` transform), not on `t` itself - confirm
    # this is intended.
    for crop in crops:
      loss+=range_loss(crop[None,...]).sum()*range_loss_w
      loss+=tv_loss(crop[None,...]).sum()*tv_loss_w

    loss.backward()
    opt.step()

    # detach so the display-only average does not keep autograd history alive
    loss_value = loss.detach()
    loss_avg = loss_value if loss_avg is None else (loss_avg*loss_lerp+loss_value*(1-loss_lerp))
    if i % 100 == 0:
      print(loss_avg.item())
    if i % 500 == 0:
      show_img(t)
  show_img(t)
  return t

#distance used between a crop embedding and the prompt embedding (fit takes its square root)
criterion = torch.nn.MSELoss()
#transform applied to every crop by make_crop before it is encoded by CLIP:
#  1. resize to 224
#  2. map values with (x+1)/2 and clamp to [0,1]
#  3. randomly convert to grayscale with probability 0.2
#  4. add gaussian noise scaled by 0.01
f=Compose([Resize(224),
          Lambda(lambda x:torch.clamp((x+1)/2,0,1)),
          RandomGrayscale(p=.2),
          Lambda(lambda x: x+torch.randn_like(x)*0.01)])


In [None]:
#set parameters and train
prompt = 'a landscape containing knights riding on the horizon by Greg Rutkowski' #text prompt
seed = 0 
torch.manual_seed(seed) #seed torch's random number generator

tv_loss_w = 0 #increase to reduce image noise. 0.01 is a good start
range_loss_w = 0 #increase to reduce image burn. 150 is a good start 

szs = get_sizes(1024, 64); print(szs) #getting sizes
steps = [2000]*len(szs) #getting number of steps per size
#NOTE: the loop below zips szs, steps, cuts, max_szs and min_szs and therefore stops at the
#shortest list - keep len(cuts) equal to len(szs) when changing the sizes
cuts = [8,8,8,16,24] #number of image cuts for CLIP loss per image size

#max_sz 64 and min_sz 0.2 produce highly detailed crisp abstract patterns with lots of objects
#max_sz 224 and min_sz 48 produce blurry image with fewer objects (you can experiment with those)
max_szs=[64]*len(szs) #max cut size (pixels)
min_szs=[0.2]*len(szs) #min cut size (pixel or image size ratio)

lr=1e-2 #Adam learning rate read by fit
loss_lerp = 0.6 #used for display only

#encode the prompt once; y is the normalized text embedding every crop is compared against
encoded_prompt = model.encode_text(clip.tokenize(prompt).to(device))
y = F.normalize(encoded_prompt, dim=-1)

#init image, can be replaced with a photo
#random tensor at the smallest size in szs
z=torch.rand((1,3,szs[0],szs[0]),device=device,requires_grad=True)

#coarse-to-fine: fit resizes the previous result to the next size and optimizes it there
for size, step, cut, max_sz, min_sz in zip(szs, steps, cuts, max_szs, min_szs):
  print(size, step, cut)
  z = fit(z, size, steps=step, ncut=cut, max_sz=max_sz, min_sz=min_sz)

# CLIP guided Point Cloud (experimental)

In [None]:
#installation. run once
#taken from https://github.com/facebookresearch/pytorch3d/blob/main/docs/tutorials/render_colored_points.ipynb
#CLIP is cloned and installed here as well - presumably so this section can be run without the first one
!git clone https://github.com/openai/CLIP
!pip install -e ./CLIP -qq

import os
import sys
import torch
#install PyTorch3D only if it cannot be imported already
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("1.9") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        # The wheel folder name is built from the Python minor version, the CUDA version
        # without dots and characters 0, 2 and 4 of the torch version string.
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{torch.__version__[0:5:2]}"
        ])
        !pip install pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        # NVIDIA CUB 1.10.0 is downloaded and unpacked first, and CUB_HOME is pointed at it
        # before the pip build from git starts.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'

In [None]:
#imports. run once or after restart
%cd CLIP
import torch
from torchvision.transforms import *
import clip
import PIL
import torch.nn.functional as F
from tqdm.notebook import trange
#device string reused by the cells below; this notebook targets CUDA
device='cuda'
#load CLIP ViT-B/32 without JIT, keep the first element of what clip.load returns (the model),
#put it in eval mode, turn off gradients for its weights and move it to `device`
model = clip.load('ViT-B/32',jit=False)[0].eval().requires_grad_(False).to(device)

#taken from https://github.com/facebookresearch/pytorch3d/blob/main/docs/tutorials/render_colored_points.ipynb

import os
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# Util function for loading point clouds
import numpy as np

# Data structures and functions for rendering
from pytorch3d.structures import Pointclouds
from pytorch3d.vis.plotly_vis import AxisArgs, plot_batch_individually, plot_scene
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVOrthographicCameras, 
    PointsRasterizationSettings,
    PointsRenderer,
    PulsarPointsRenderer,
    PointsRasterizer,
    AlphaCompositor,
    NormWeightedCompositor
)

from google.colab import files

In [None]:
#define functions. run once or after restart
def get_sizes(sz, min_sz=32):
  """Return the ascending list of sizes produced by repeatedly halving `sz`.

  Starting at `sz`, the size is halved for as long as it is even and still
  larger than `min_sz`; every intermediate size is kept.

  Note that the check happens *before* halving, so the smallest entry can end
  up below `min_sz` (e.g. get_sizes(100, 32) -> [25, 50, 100]).

  Args:
    sz: the largest (final) size.
    min_sz: halving stops once the current size is <= this value.

  Returns:
    A sorted list that always contains `sz` itself.
  """
  szs = [sz]
  # stop on the first size that is small enough or can no longer be halved exactly
  while sz > min_sz and sz % 2 == 0:
    sz //= 2
    szs.append(sz)
  return sorted(szs)

def make_crop(img, ratio, max_cut=224, min_cut=0.2):
  """Cut one randomly placed square patch out of `img` and run it through `f`.

  The patch side is `ratio` times the shorter of the two trailing dimensions,
  raised to at least `min_cut` and capped by both `max_cut` and that shorter
  dimension.

  Args:
    img: 4-D tensor; the patch is sliced from dims 2 and 3.
    ratio: fraction of the shorter side to use as the patch side.
    max_cut: upper bound for the patch side.
    min_cut: lower bound for the patch side; a value below 1 is treated as a
      fraction of the shorter side.

  Returns:
    Whatever the module-level transform `f` returns for the patch.
  """
  dim2, dim3 = img.shape[2:]
  short_side = min(dim2, dim3)
  if min_cut < 1:
    # fractional bound -> convert to an absolute size
    min_cut = int(short_side * min_cut)

  side = max(ratio * short_side, min_cut)
  side = int(min(side, max_cut, short_side))

  # the offset along dim 2 is drawn first, then the one along dim 3
  start2 = int(torch.rand(1) * (dim2 - side))
  start3 = int(torch.rand(1) * (dim3 - side))

  patch = img[:, :, start2:start2 + side, start3:start3 + side]
  return f(patch)

def get_crops(img, ratios, max_cut, min_cut):
  """Make one crop of `img` per entry of `ratios` and concatenate them.

  Each entry of `ratios` is converted to a Python number with `.item()` and
  passed to make_crop together with `max_cut` and `min_cut`; the results are
  joined with torch.cat (along the first dimension).
  """
  patches = []
  for r in ratios:
    patches.append(make_crop(img, r.item(), max_cut, min_cut))
  return torch.cat(patches)

def show_img(vertex, rgb, renderer):
    """Render the point cloud and show the first rendered image in the notebook.

    The render is scaled by 255, clamped to 0..255 and converted to uint8
    before being handed to PIL.
    """
    frame = render(vertex, rgb, renderer)[0]
    pixels = (frame * 255).clamp(0, 255).detach().cpu().numpy().astype('uint8')
    display(PIL.Image.fromarray(pixels, 'RGB'))

def save_render(render, path):
    """Write the first image of the rendered batch `render` to `path`.

    The image is scaled by 255, clamped to 0..255 and converted to uint8;
    PIL.Image.save is given `path` unchanged.
    """
    frame = render[0] * 255
    pixels = frame.clamp(0, 255).detach().cpu().numpy().astype('uint8')
    image = PIL.Image.fromarray(pixels, 'RGB')
    image.save(path)

def range_loss(input):
    """Per-sample mean squared distance of `input` from the interval [-1, 1].

    Values inside [-1, 1] contribute nothing; the mean is taken over
    dims 1, 2 and 3, leaving one value per batch element.
    """
    #taken from this colab https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA#scrollTo=YHOj78Yvx8jP
    overshoot = input - torch.clamp(input, min=-1, max=1)
    return torch.square(overshoot).mean(dim=[1, 2, 3])

def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The input is padded by one replicated row/column at the bottom/right so
    every pixel has a right and a lower neighbour; the squared differences to
    those neighbours are summed and averaged over dims 1, 2 and 3.
    """
    #taken from this colab https://colab.research.google.com/drive/1QBsaDAZv8np29FPbvjffbE1eytoJcsgA#scrollTo=YHOj78Yvx8jP
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    base = padded[..., :-1, :-1]
    dx = padded[..., :-1, 1:] - base
    dy = padded[..., 1:, :-1] - base
    return (dx.pow(2) + dy.pow(2)).mean(dim=[1, 2, 3])

def get_img_loss(img, ncut=8, max_sz=224, min_sz=32, use_weighted_ratios=True):
    """Compute the CLIP-guided loss of a rendered image batch.

    One ratio-1 crop plus `ncut` random-ratio crops are taken from `img`,
    encoded with CLIP, and for each crop the weighted sqrt(MSE) between its
    normalised embedding and the prompt embedding `y` is accumulated, followed
    by the range / total-variation penalties.

    Reads the module-level names `model`, `criterion`, `y`, `device`,
    `range_loss_w` and `tv_loss_w`.

    Args:
      img: channels-last image batch (it is permuted to channels-first here).
      ncut: number of random crops (a ratio-1 crop is always added).
      max_sz: forwarded to get_crops as the maximum crop size.
      min_sz: forwarded to get_crops as the minimum crop size.
      use_weighted_ratios: weight each crop's CLIP loss by ratio / sum(ratios)
        when True, by 1 otherwise.

    Returns:
      The accumulated loss.
    """
    # first ratio is fixed at 1, the others are drawn one by one on the CPU
    ratios = [torch.ones(1)]
    img = img.permute(0,3,1,2)
    for _ in range(ncut):
        ratios.append(torch.rand(1))
    ratios = torch.cat(ratios).to(device)
    crops = get_crops(img, ratios, max_sz, min_sz)

    loss = 0.
    weighted_ratios = ratios/ratios.sum() if use_weighted_ratios else torch.ones_like(ratios)
    embeds = model.encode_image(crops)

    for embed, ratio in zip(embeds, weighted_ratios):
        x = F.normalize(embed, dim=-1)
        loss+=torch.sqrt(criterion(x, y))*ratio

    # NOTE(review): the penalties are computed on the crops returned by
    # get_crops (i.e. after the `f` transform), not on `img` itself - confirm
    # this is intended.
    for crop in crops:
        loss+=range_loss(crop[None,...]).sum()*range_loss_w
        loss+=tv_loss(crop[None,...]).sum()*tv_loss_w

    return loss

def render(var_h, rgb, renderer):
  """Render a single point cloud and keep the first three output channels.

  `var_h` is used as the points and `rgb` as the per-point features of a
  one-element Pointclouds batch, which is passed to `renderer`.
  """
  cloud = Pointclouds(points=[var_h], features=[rgb])
  rendered = renderer(cloud)
  return rendered[..., :3]

def fit(var_h, rgb, steps=1000, ncut=8, max_sz=224, min_sz=32, use_weighted_ratios=True, renderer=None):
  """Optimise point positions and colours so the renders match the CLIP prompt.

  Adam updates `rgb` (learning rate `lr`) and `var_h` (learning rate `lr_h`).
  When `rotation_step` is non-zero, a new camera / renderer is built every
  `rotation_step` steps with the azimuth advanced by one degree, and the image
  rendered on that step is saved to /content/out/. When `rotation_step` is 0
  the renderer passed in is used throughout and no frames are saved.

  Reads the module-level names `lr`, `lr_h`, `rotation_step`, `device`,
  `raster_settings` and `loss_lerp`.

  Args:
    var_h: tensor of point positions being optimised.
    rgb: tensor of per-point features being optimised.
    steps: number of optimiser steps.
    ncut, max_sz, min_sz, use_weighted_ratios: forwarded to get_img_loss.
    renderer: renderer used for the initial preview; with rotation enabled it
      is replaced on step 0, otherwise it is used for every step.

  Returns:
    The tuple (var_h, rgb).
  """

  show_img(var_h, rgb, renderer)
  opt=torch.optim.Adam([{'params': rgb,'lr': lr}, {'params': var_h,'lr': lr_h}])
  # rotation_step == 0 means "no rotation"; it must never reach a % or // below
  rotating = rotation_step != 0
  loss_avg = None
  for i in trange(steps):
    opt.zero_grad()

    new_view = rotating and i % rotation_step == 0
    if new_view:
      #rotate the camera every rotation_step steps
      # azimuth is wrapped at 360 so the orbit has no jump when it completes a full turn
      R, T = look_at_view_transform(20, 10, (i//rotation_step)%360)
      cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)
      renderer = PointsRenderer(
        rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
        compositor=NormWeightedCompositor(background_color=(0,0,0)))

    img = render(var_h, rgb, renderer)
    loss = get_img_loss(img, ncut=ncut, max_sz=max_sz, min_sz=min_sz, use_weighted_ratios=use_weighted_ratios)

    loss.backward()
    opt.step()

    # detach so the display-only average does not keep autograd history alive
    loss_value = loss.detach()
    loss_avg = loss_value if loss_avg is None else (loss_avg*loss_lerp+loss_value*(1-loss_lerp))

    if i % 100 == 0:
      print('loss_avg', loss_avg.item())
    if i % 300 == 0:
      show_img(var_h, rgb, renderer)
    if new_view:
      # one frame per camera position (rendered before this step's update)
      save_render(img, f'/content/out/{i//rotation_step:05d}.jpg')

  show_img(var_h, rgb, renderer)
  return var_h, rgb

#distance used between a crop embedding and the prompt embedding (get_img_loss takes its square root)
criterion = torch.nn.MSELoss()
#transform applied to every crop by make_crop before it is encoded by CLIP:
#  1. resize to 224
#  2. clamp values to [0,1]
#  3. randomly convert to grayscale with probability 0.2
#  4. add gaussian noise scaled by 0.01
f=Compose([Resize(224),
          Lambda(lambda x:torch.clamp(x,0,1)),
          RandomGrayscale(p=.2),
          Lambda(lambda x: x+torch.randn_like(x)*0.01)])

loss_lerp = 0.6 #used for display only

In [None]:
#create the folder fit saves frames into and remove whatever is already in it
!mkdir /content/out
!rm -rf /content/out/*

#set parameters and train
prompt = '8bit pokemon #pixelart' #text prompt
rotation_step = 100 #rotate the camera for 1 degree every N steps. Set to 0 for no rotation.
size = 256 #image size
step = 2000 #number of steps

seed = 0 
torch.manual_seed(seed) #seed torch's random number generator

tv_loss_w = 0.002 #increase to reduce image noise. 0.01 is a good start
range_loss_w = 150 #increase to reduce image burn. 150 is a good start 

lr = 1e-2 #Adam learning rate for the point colors (rgb)
lr_h = 1e-2 #Adam learning rate for the point positions (var_h)

#max_sz 64 and min_sz 0.2 produce highly detailed crisp abstract patterns with lots of objects
#max_sz 224 and min_sz 48 produce blurry image with fewer objects (you can experiment with those)
max_sz=224 #max cut size (pixels)
min_sz=64 #min cut size (pixel or image size ratio)
cut = 8 #number of image cuts for CLIP loss

sz = 64 #point cloud density
#point positions: sz*sz*sz points with 3 coordinates each, drawn with randn and divided by 3
var_h = torch.randn((sz*sz*sz,3)).cuda().div(3.).requires_grad_(True)
#per-point features: 4 values per point (render keeps only the first 3 channels of the rendered image)
rgb = torch.randn((var_h.shape[0],4)).cuda().requires_grad_(True)

#encode the prompt once; y is the normalized text embedding every crop is compared against
encoded_prompt = model.encode_text(clip.tokenize(prompt).to(device))
y = F.normalize(encoded_prompt, dim=-1)

#set up camera angle
R, T = look_at_view_transform(20, 10, 0)
cameras = FoVOrthographicCameras(device=device, R=R, T=T, znear=0.01)
#set up camera image size
raster_settings = PointsRasterizationSettings(
      image_size=size, 
      radius = 0.003,
      points_per_pixel = 1
  )
  #set up camera
renderer = PointsRenderer(
      rasterizer=PointsRasterizer(cameras=cameras, raster_settings=raster_settings),
      compositor=NormWeightedCompositor(background_color=(0,0,0)))

var_h, rgb = fit(var_h, rgb, steps=step, ncut=cut, max_sz=max_sz, min_sz=min_sz, renderer=renderer)

#turn the saved frames into an mp4 with ffmpeg and download it through google.colab
video_name = f"/content/video-{prompt.replace(' ','_')}_size{size}_maxsz{max_sz}.mp4"
!ffmpeg -pattern_type glob -i '/content/out/*.jpg' {video_name}
files.download(video_name)