# CLIP Guided Diffusion (Text-to-Image)

By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images. 

Produces 1024x1024 px images by applying an optional 4x super-resolution (ESRGAN) step, added by Thomash, to the 256x256 diffusion output. The run is somewhat slower when this step is turned on.

In [None]:
# Text prompt for the run; it is later assigned to `prompt`.
text_input = 'pollination by Ernst Haeckel' #@param {type: "string"}
# When True, every saved frame is passed through upscale() (4x super-resolution).
super_resolution = True   #@param {type: "boolean"}
# Directory that receives the progress frames and the rendered video.
output_path = "/content/output"


In [None]:
#@title Upscale images/video frames

import os.path as osp
import glob
import cv2
import numpy as np
import torch

import requests
import imageio
import requests
import warnings
import gdown


# Lazily-initialised ESRGAN state shared by every upscale() call.
loaded_upscale_model = False
upscale_device = None
upscale_model = None
def upscale(path):
  """Upscale the image at ``path`` 4x with ESRGAN and overwrite it in place.

  On the first call the ESRGAN repository is cloned into ``./ESRGAN`` (unless
  that directory already exists), the ``RRDB_ESRGAN_x4.pth`` weights are
  downloaded into it, and the model is cached in the module-level globals so
  later calls only run inference.

  Args:
    path: Path of the image file to read; the upscaled result is written back
      to the same path as a JPEG with quality 70.

  Raises:
    FileNotFoundError: If ``path`` cannot be read as an image.
    subprocess.CalledProcessError: If cloning the ESRGAN repository fails.
  """
  global loaded_upscale_model, upscale_device, upscale_model

  if not loaded_upscale_model:
    # Local imports: plain Python replaces the `!git clone` / `%cd` magics so
    # the function also works outside IPython.
    import os
    import subprocess

    print("Loading superresolution model")

    repo_dir = 'ESRGAN'
    if not os.path.isdir(repo_dir):
      subprocess.run(
          ['git', 'clone', 'https://github.com/xinntao/ESRGAN', repo_dir],
          check=True)

    # RRDBNet_arch is imported from, and the weights are stored in, the
    # repository directory; always restore the working directory afterwards,
    # even when the download or the model load fails.
    previous_dir = os.getcwd()
    os.chdir(repo_dir)
    try:
      import RRDBNet_arch as arch
      print("Downloading Super-Resolution model")
      model_path = 'RRDB_ESRGAN_x4.pth'
      print ('Downloading RRDB_ESRGAN_x4.pth')
      gdown.download('https://drive.google.com/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene', model_path, quiet=False)

      warnings.filterwarnings("ignore")

      # Fall back to the CPU instead of crashing on runtimes without a GPU.
      upscale_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

      model = arch.RRDBNet(3, 3, 64, 23, gc=32)
      # map_location lets the checkpoint load regardless of the device it was
      # saved from.
      model.load_state_dict(torch.load(model_path, map_location=upscale_device), strict=True)
      model.eval()
      upscale_model = model.to(upscale_device)

      print('Model path {:s}. \nTesting...'.format(model_path))
    finally:
      os.chdir(previous_dir)
    # Only mark the model as loaded once every setup step has succeeded.
    loaded_upscale_model = True

  img = cv2.imread(path, cv2.IMREAD_COLOR)
  if img is None:
    # cv2.imread signals failure by returning None rather than raising.
    raise FileNotFoundError(f"Could not read image: {path}")

  # Scale to [0, 1], reverse the channel order (BGR -> RGB) and move the
  # channel axis first before adding the batch axis.
  img = img * 1.0 / 255
  img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
  img_LR = img.unsqueeze(0).to(upscale_device)

  print("4x upscaling", path)
  with torch.no_grad():
      # squeeze(0) drops only the batch axis, so size-1 spatial axes survive.
      output = upscale_model(img_LR).squeeze(0).float().cpu().clamp_(0, 1).numpy()

  # Undo the channel reordering and convert back to 8-bit pixel values.
  output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
  output = (output * 255.0).round().astype(np.uint8)
  cv2.imwrite(path, output, [int(cv2.IMWRITE_JPEG_QUALITY), 70])
  print("Done upscaling")


In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
# Switch to the working directory, create the output folder, and check the GPU status
%cd /content
# -p: no error if the directory already exists; $output_path is expanded from the Python variable.
!mkdir -p $output_path
!nvidia-smi

In [None]:
# Install dependencies

# Fetch the CLIP and guided-diffusion sources into the current directory.
!git clone https://github.com/openai/CLIP
!git clone https://github.com/openai/guided-diffusion
# Editable (-e) installs of the two cloned checkouts.
!pip install -e ./CLIP
!pip install -e ./guided-diffusion

In [None]:
# Download the diffusion model

# -N turns on wget's time-stamping, so an up-to-date local copy is not fetched again.
!wget -N 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'

In [None]:
# Imports

import math
import sys

from IPython import display
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')

import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults

In [None]:
# Define necessary functions

class MakeCutouts(nn.Module):
    """Draw random square crops from an image batch and resize them.

    Each crop's side length lies between ``min(H, W, cut_size)`` and
    ``min(H, W)``; every crop is resized to ``cut_size`` x ``cut_size`` with
    bilinear interpolation.

    Args:
        cut_size: Side length of every returned crop.
        cutn: Number of crops drawn per forward pass.
        cut_pow: Exponent applied to the uniform sample that picks the crop
            size; values above 1 favour smaller crops.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

    def _random_cutout(self, input, smallest, largest):
        """Return one randomly sized and placed crop, resized to ``cut_size``."""
        height, width = input.shape[2:4]
        fraction = torch.rand([]) ** self.cut_pow
        side = int(fraction * (largest - smallest) + smallest)
        # Horizontal offset is drawn before the vertical one.
        left = torch.randint(0, width - side + 1, ())
        top = torch.randint(0, height - side + 1, ())
        crop = input[:, :, top:top + side, left:left + side]
        return F.interpolate(crop, (self.cut_size, self.cut_size),
                             mode='bilinear', align_corners=False)

    def forward(self, input):
        """Return ``cutn`` crops of ``input`` concatenated along dimension 0.

        ``input`` is indexed as (batch, channels, height, width).
        """
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        return torch.cat([
            self._random_cutout(input, smallest, largest)
            for _ in range(self.cutn)
        ])


def spherical_dist_loss(x, y):
    """Return ``2 * arcsin(||x_hat - y_hat|| / 2) ** 2`` along the last axis.

    ``x_hat`` and ``y_hat`` are ``x`` and ``y`` normalised to unit length over
    the last dimension. The arcsine term is half the angle between the two
    unit vectors, so the result is half the squared angle.
    """
    unit_x = F.normalize(x, dim=-1)
    unit_y = F.normalize(y, dim=-1)
    chord = (unit_x - unit_y).norm(dim=-1)
    half_angle = chord.div(2).arcsin()
    return half_angle.pow(2).mul(2)


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The last two axes are extended by one replicated row and column, squared
    forward differences along both axes are summed, and the result is averaged
    over dimensions 1-3, giving one value per batch element.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    horizontal = padded[..., :-1, 1:] - anchor
    vertical = padded[..., 1:, :-1] - anchor
    return (horizontal.pow(2) + vertical.pow(2)).mean(dim=[1, 2, 3])


In [None]:
# Model settings

# Start from the library defaults and override the entries needed for the
# 256x256 unconditional checkpoint used in this notebook.
model_config = model_and_diffusion_defaults()
model_config.update(
    attention_resolutions='32, 16, 8',
    class_cond=False,
    diffusion_steps=1000,
    rescale_timesteps=True,
    timestep_respacing='1000',
    image_size=256,
    learn_sigma=True,
    noise_schedule='linear',
    num_channels=256,
    num_head_channels=64,
    num_res_blocks=2,
    resblock_updown=True,
    use_fp16=True,
    use_scale_shift_norm=True,
)

In [None]:
# Load models

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): model_config enables 'use_fp16'; half precision may not work on
# the CPU fallback — confirm before relying on it.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Build the diffusion model and sampler from the settings dict, then load the
# downloaded checkpoint (deserialised on the CPU first).
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
# Freeze every weight, switch to eval mode and move the model to the device.
model.requires_grad_(False).eval().to(device)
# Re-enable gradients for parameters whose name contains 'qkv', 'norm' or 'proj'.
# NOTE(review): presumably required for gradients to reach the input inside
# cond_fn — confirm against guided-diffusion's checkpointing.
for name, param in model.named_parameters():
    if 'qkv' in name or 'norm' in name or 'proj' in name:
        param.requires_grad_()
if model_config['use_fp16']:
    model.convert_to_fp16()

# Load CLIP (first element of the value returned by clip.load), frozen and in
# eval mode, and read the input resolution of its visual encoder.
clip_model = clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device)
clip_size = clip_model.visual.input_resolution
# Per-channel mean/std normalisation applied to the cutouts before CLIP encodes them.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])


## Settings for this run:

In [None]:
# Text that CLIP encodes as the guidance target.
prompt = text_input
# Number of images generated in parallel (first dimension of the sample shape).
batch_size = 1
# Weight of the CLIP distance term in the guidance loss.
clip_guidance_scale = 1000
# Weight of the total variation term in the guidance loss.
tv_scale = 100
# Number of random cutouts drawn per guidance step.
cutn = 16
# Seed passed to torch.manual_seed; set to None to skip seeding.
seed = 0

### Actually do the run...

In [None]:
# Seed torch's RNG so its random draws start from a fixed state (skipped when seed is None).
if seed is not None:
    torch.manual_seed(seed)

# Tokenize the prompt, encode it with CLIP and cast the embedding with .float().
text_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()

# Cutout sampler producing `cutn` crops at CLIP's input resolution.
make_cutouts = MakeCutouts(clip_size, cutn)

# Timestep index read by cond_fn and decremented by the sampling loop;
# starts at the last timestep.
cur_t = diffusion.num_timesteps - 1

def cond_fn(x, t, y=None):
    """Return the negated gradient of the guidance loss with respect to ``x``.

    The loss is the CLIP spherical distance between cutouts of a blended image
    estimate and the text embedding (weighted by ``clip_guidance_scale``) plus
    a total variation term (weighted by ``tv_scale``).

    Args:
        x: Current sample batch handed over by the sampler.
        t: Timestep tensor from the sampler; unused, because the global
            ``cur_t`` supplies the timestep instead.
        y: Optional value forwarded to the model as ``model_kwargs['y']``.
    """
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        batch = x.shape[0]
        timesteps = torch.ones([batch], device=device, dtype=torch.long) * cur_t
        prediction = diffusion.p_mean_variance(
            model, x, timesteps, clip_denoised=False, model_kwargs={'y': y})

        # Blend the predicted clean image with the current sample, weighted by
        # the schedule's sqrt(1 - alpha_cumprod) at the current timestep.
        noise_weight = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
        blended = prediction['pred_xstart'] * noise_weight + x * (1 - noise_weight)

        # (v + 1) / 2 before CLIP normalisation — assumes samples live in [-1, 1].
        cutouts = make_cutouts(blended.add(1).div(2))
        embeds = clip_model.encode_image(normalize(cutouts)).float()
        embeds = embeds.view([cutn, batch, -1])
        clip_loss = spherical_dist_loss(embeds, text_embed.unsqueeze(0)).mean(0).sum()
        smoothness_loss = tv_loss(blended).sum()

        total = clip_loss * clip_guidance_scale + smoothness_loss * tv_scale
        gradient = torch.autograd.grad(total, x)[0]
        return -gradient

# Use the DDIM sampling loop when the respacing string asks for it, otherwise
# the default progressive sampler.
sample_fn = (
    diffusion.ddim_sample_loop_progressive
    if model_config['timestep_respacing'].startswith('ddim')
    else diffusion.p_sample_loop_progressive
)

# Progressive samplers return an iterable of intermediate results; the shape
# is (batch, 3 channels, side, side).
samples = sample_fn(
    model,
    (batch_size, 3) + (model_config['image_size'],) * 2,
    clip_denoised=False,
    model_kwargs={},
    cond_fn=cond_fn,
    progress=False,
)

for i, sample in enumerate(samples):
    # Keep the global timestep used by cond_fn in step with the sampler.
    cur_t -= 1
    # Frames are written on every fourth step and on the final step only.
    if i % 4 != 0 and cur_t != -1:
        continue
    print()
    for j, image in enumerate(sample['pred_xstart']):
        filename = f'{output_path}/progress_{j:01}_{i:05}.jpg'
        # (v + 1) / 2 clamped to [0, 1] before conversion to a PIL image.
        TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
        tqdm.write(f'Step {i}, output {j}:')
        print(f"step_{i}", file=sys.stderr)
        if super_resolution:
            # Overwrites the saved frame with its 4x upscaled version.
            upscale(filename)
        #display.display(display.Image(filename))


In [None]:
#@title Render Video

# Destination of the final video inside the output directory.
out_file=output_path+"/video.mp4"
#!rm /content/*.mp4
# `ls -t` lists newest first, so the first line is the most recently written frame.
last_frame=!ls -t $output_path/*.jpg | head -1
last_frame = last_frame[0]
# Copy the newest frame to 0000.jpg.
# NOTE(review): presumably so the finished image sorts first and opens the video — confirm.
!cp -v $last_frame $output_path/0000.jpg

# Pass 1: encode every .jpg in the output directory at 10 fps as H.264 (yuv420p).
!ffmpeg  -r 10 -i $output_path/%*.jpg -y -c:v libx264 -profile:v high -pix_fmt yuv420p -level:v 4.0 /tmp/vid_no_audio.mp4
# Pass 2: copy the video stream and add an AAC track generated from the anullsrc source.
!ffmpeg -i /tmp/vid_no_audio.mp4 -f lavfi -i anullsrc -c:v copy -c:a aac -shortest -y "$out_file"
#ffmpeg -i input.mp4 -f lavfi -c:v copy -c:a aac -shortest output.mp4
#!cp -v /tmp/video.mp4 "$out_file"
#!rm /content/taming-transformers/*.png
print("Written", out_file)
!sleep 2
