<a href="https://colab.research.google.com/github/afiaka87/clip-guided-diffusion/blob/main/colab_clip_guided_diff_hq.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Generate images from text prompts with CLIP-guided diffusion

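By Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses OpenAI's 256x256 unconditional ImageNet diffusion model (https://github.com/openai/guided-diffusion) together with CLIP (https://github.com/openai/CLIP): at every sampling step, CLIP scores the partially denoised image against the text prompt, and the gradient of that score steers the diffusion model toward images that match the prompt.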

In [None]:
# @title Licensed under the MIT License

# Copyright (c) 2021 Katherine Crowson

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

In [None]:
#@title Check the GPU status
# Print which NVIDIA GPU (if any) this runtime was given. The model-loading
# cell falls back to the CPU when CUDA is unavailable, which is very slow.

!nvidia-smi

In [None]:
#@title Install dependencies

!git clone https://github.com/openai/CLIP
!git clone https://github.com/openai/guided-diffusion
!pip install -e ./CLIP
!pip install -e ./guided-diffusion
%pip install kornia

In [None]:
#@title Download the pretrained 256px diffusion checkpoint. Rerun if download fails.
# `--continue` makes wget resume a partially downloaded file when the cell is
# rerun. The model-loading cell reads the file from the working directory.

!wget --continue 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'

In [None]:
#@title Imports

import math
import os
import random
import re
import sys
import urllib.request

from IPython import display
from PIL import Image
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision.transforms import functional as TF
from tqdm.notebook import tqdm

sys.path.append('./CLIP')
sys.path.append('./guided-diffusion')

import clip
from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
import kornia.augmentation as K

In [None]:
#@title Transforms
# Each flag enables one kornia augmentation inside `MakeCutouts`, applied in the
# order listed here to every batch of cutouts. The flags are read when a
# `MakeCutouts` instance is constructed, so re-run the cell that creates
# `make_cutouts` after changing them.
color_jitter = True #@param{type: 'boolean'}
color_jitter_two = False #@param{type: 'boolean'}
elastic_transform = False #@param{type: 'boolean'}
sharpness = False #@param{type: 'boolean'}
gaussian_noise = False #@param{type: 'boolean'}
perspective = False #@param{type: 'boolean'}
rotation = False #@param{type: 'boolean'}
affine = False #@param{type: 'boolean'}
thin_plate_spline = False #@param{type: 'boolean'}
crop = False #@param{type: 'boolean'}
erasing = True #@param{type: 'boolean'}
resizedcrop = False #@param{type: 'boolean'}
solarize = False #@param{type: 'boolean'}




In [None]:
#@title `MakeCutouts` implementation

class MakeCutouts(nn.Module):
    """Cut `cutn` random square crops out of an image batch and resize them for CLIP.

    Each crop's side length is drawn between min(H, W, cut_size) and
    min(H, W), where H and W are dims 2 and 3 of the input; `cut_pow` > 1
    biases the draw toward smaller crops. Every crop is resized to
    `cut_size` x `cut_size`. The crops are concatenated along dim 0 and passed
    through the kornia augmentations selected by the module-level flags from
    the "Transforms" cell. Those flags are read once, at construction time.
    """

    def __init__(self, cut_size, cutn, cut_pow=1.):
        super().__init__()
        self.cut_size = cut_size
        self.cutn = cutn
        self.cut_pow = cut_pow

        square = (cut_size, cut_size)
        # (enabled, factory) pairs in application order. A factory is only
        # called when its flag is truthy, so disabled augmentations are never built.
        # Parametrization of the augmentations and new augmentations taken from
        # <https://github.com/nerdyrodent/VQGAN-CLIP>, thanks to @nerdyrodent.
        candidates = (
            (color_jitter, lambda: K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.05, hue=0.05, p=0.5)),
            (color_jitter_two, lambda: K.ColorJitter(hue=0.1, saturation=0.1, p=0.7)),
            (elastic_transform, lambda: K.RandomElasticTransform(p=0.7)),
            (sharpness, lambda: K.RandomSharpness(sharpness=0.4, p=0.7)),
            (gaussian_noise, lambda: K.RandomGaussianNoise(mean=0.0, std=1., p=0.7)),
            (perspective, lambda: K.RandomPerspective(distortion_scale=0.7, p=0.7)),
            (rotation, lambda: K.RandomRotation(degrees=15, p=0.7)),
            (affine, lambda: K.RandomAffine(degrees=15, translate=0.1, p=0.7, padding_mode='border')),
            (thin_plate_spline, lambda: K.RandomThinPlateSpline(scale=0.3, same_on_batch=False, p=0.7)),
            (crop, lambda: K.RandomCrop(size=square, p=0.5)),
            (erasing, lambda: K.RandomErasing((.1, .4), (.3, 1/.3), same_on_batch=True, p=0.7)),
            (resizedcrop, lambda: K.RandomResizedCrop(size=square, scale=(0.1, 1), ratio=(0.75, 1.333), cropping_mode='resample', p=1.0)),
            (solarize, lambda: K.RandomSolarize(0.01, 0.01, p=0.7)),
        )
        self.augs = nn.Sequential(*[build() for enabled, build in candidates if enabled])

    def forward(self, input):
        height, width = input.shape[2:4]
        largest = min(width, height)
        smallest = min(width, height, self.cut_size)
        target = (self.cut_size, self.cut_size)

        crops = []
        for _ in range(self.cutn):
            # Same RNG call order as before: size, then x offset, then y offset.
            side = int(torch.rand([]) ** self.cut_pow * (largest - smallest) + smallest)
            left = torch.randint(0, width - side + 1, ())
            top = torch.randint(0, height - side + 1, ())
            window = input[:, :, top:top + side, left:left + side]
            crops.append(F.interpolate(window, target, mode='bilinear', align_corners=False))

        return self.augs(torch.cat(crops))

In [None]:
#@title Utility Functions
def spherical_dist_loss(x, y):
    """Squared great-circle distance between `x` and `y` on the unit sphere.

    Both inputs are L2-normalized along their last dimension, and the two
    tensors broadcast against each other. For unit vectors separated by an
    angle theta the chord length is 2*sin(theta/2), so the result
    2*arcsin(chord/2)**2 equals theta**2 / 2. The last dimension is reduced away.
    """
    x_unit = F.normalize(x, dim=-1)
    y_unit = F.normalize(y, dim=-1)
    half_chord = (x_unit - y_unit).norm(dim=-1) / 2
    return 2 * torch.asin(half_chord) ** 2


def tv_loss(input):
    """L2 total variation loss, as in Mahendran et al.

    The bottom row and right column are replicate-padded, so forward
    differences at those borders are zero. Squared horizontal and vertical
    differences are then summed and averaged over dims 1-3, which gives one
    value per item of a 4-D input.
    """
    padded = F.pad(input, (0, 1, 0, 1), 'replicate')
    anchor = padded[..., :-1, :-1]
    horizontal = padded[..., :-1, 1:] - anchor
    vertical = padded[..., 1:, :-1] - anchor
    return (horizontal.pow(2) + vertical.pow(2)).mean([1, 2, 3])

# from github.com/mehdidc/feed_forward_vqgan_clip
def tokenize_lines(line_separated, out="tokenized.pkl"):
    """Tokenize each line of `line_separated` with CLIP and save the token tensor.

    Every line (as split by `str.splitlines`) becomes one row of CLIP token
    ids. Lines longer than CLIP's context length are truncated instead of
    raising (`truncate=True`). The tensor is written to `out` with
    `torch.save`; the `.pkl` default is only a filename, and the format is
    torch's own. The tensor is also returned, so callers need not reload it.

    Note: these are token ids, not CLIP embeddings. Pass them through
    `clip_model.encode_text` to get embeddings.
    """
    tokens = clip.tokenize(line_separated.splitlines(), truncate=True)
    torch.save(tokens, out)
    return tokens

_IMAGENET_LABELS_URL = "https://gist.githubusercontent.com/yrevar/942d3a0ac09ec9e5eb3a/raw/238f720ff059c1f82f368259d1ca4ffa5dd8f9f5/imagenet1000_clsidx_to_labels.txt"
_IMAGENET_LABELS_FILE = 'imagenet1000_clsidx_to_labels.txt'


def _read_imagenet_label_lines(label_lines=None):
    """Return the raw lines of the ImageNet class-index -> label listing.

    If `label_lines` is given it is used as-is (useful offline or in tests).
    Otherwise the listing is downloaded once into the working directory and
    read from there.
    """
    if label_lines is not None:
        return list(label_lines)
    if not os.path.exists(_IMAGENET_LABELS_FILE):
        # Download under a temporary name first, so an interrupted transfer never
        # leaves a truncated file that later runs would treat as complete.
        partial = _IMAGENET_LABELS_FILE + '.part'
        urllib.request.urlretrieve(_IMAGENET_LABELS_URL, partial)
        os.replace(partial, _IMAGENET_LABELS_FILE)
    with open(_IMAGENET_LABELS_FILE, encoding='utf-8') as f:
        return f.read().splitlines()


def _imagenet_label_synonyms(line):
    """Split a line such as ` 1: 'goldfish, Carassius auratus',` into its synonyms.

    Removes the leading `{`/`<index>: ` prefix, any trailing commas or closing
    `}`, and the surrounding single or double quotes. Returns the non-empty,
    whitespace-trimmed names; a line without a label yields [].
    """
    body = re.sub(r'^\s*\{?\s*\d+:\s*', '', line)
    body = body.strip().rstrip(',}').strip().strip('\'"')
    return [name.strip() for name in body.split(',') if name.strip()]


def tokenize_imagenet(label_lines=None):
    """Return one cleaned caption string per ImageNet class.

    Despite the name, this returns plain strings; pass them to `clip.tokenize`
    to get token ids. For classes with several synonyms the second synonym is
    used, and for single-name classes the only one.
    NOTE(review): the preference for the second synonym is kept as-is; its
    rationale is not documented.

    Args:
        label_lines: optional iterable of raw label-file lines. When None, the
            listing is downloaded (once) and read from disk.
    """
    captions = []
    for line in _read_imagenet_label_lines(label_lines):
        synonyms = _imagenet_label_synonyms(line)
        if synonyms:
            captions.append(synonyms[1] if len(synonyms) > 1 else synonyms[0])
    return captions


def imagenet_class_line_to_caption(label_lines=None):
    """Pick a random ImageNet class and return it as a text prompt.

    Single-name classes return that name. Classes with several synonyms
    return all of them joined by spaces.

    Args:
        label_lines: optional iterable of raw label-file lines. When None, the
            listing is downloaded (once) and read from disk.

    Raises:
        ValueError: if no line contains a class label.
    """
    candidates = [
        synonyms
        for synonyms in map(_imagenet_label_synonyms, _read_imagenet_label_lines(label_lines))
        if synonyms
    ]
    if not candidates:
        raise ValueError('no ImageNet class labels found')
    synonyms = random.choice(candidates)
    print("prompt will be ignored and a random class will be chosen instead. Set `random_imagenet_class` to False and re-run this cell if that is not desired.")
    if len(synonyms) == 1:
        return synonyms[0]
    return " ".join(synonyms)

In [None]:
#@title Model settings
#@markdown `diffusion_steps` - Total number of steps for diffusion. Increase for slower runtime but greater quality.
diffusion_steps = 1000 #@param {type: 'integer'}

#@markdown `timestep_respacing` - less than or equal to `diffusion_steps`.  
#@markdown Map sampling to a lower number of steps. Decrease for faster runtimes with decrease in quality. 
#@markdown (optional) use `ddim` sampling.
timestep_respacing = "ddim250" #@param ["25", "50", "100", "250", "500", "1000", "ddim25", "ddim100", "ddim250", "ddim500"] {allow-input: true}

# NOTE(review): the checkpoint was presumably trained with a 1000-step linear
# noise schedule. Changing `diffusion_steps` changes that schedule; prefer
# shortening sampling with `timestep_respacing`.

# These settings must describe the architecture stored in
# `256x256_diffusion_uncond.pt`. The loading cell calls `load_state_dict`
# (strict by default), which fails on a mismatch.
model_config = model_and_diffusion_defaults()
model_config.update({
    'attention_resolutions': '32, 16, 8',
    'class_cond': False,
    'diffusion_steps': diffusion_steps,
    'rescale_timesteps': True,
    'timestep_respacing': timestep_respacing,
    'image_size': 256,
    'learn_sigma': True,
    'noise_schedule': 'linear',
    'num_channels': 256,
    'num_head_channels': 64,
    'num_res_blocks': 2,
    'resblock_updown': True,
    # Use half precision only when a CUDA device exists. The loading cell falls
    # back to the CPU otherwise, and an fp16 UNet there fails on ops without
    # half-precision CPU kernels. The loading cell skips `convert_to_fp16` when False.
    'use_fp16': torch.cuda.is_available(),
    'use_scale_shift_norm': True,
})
#@title Load `guided-diffusion` and `clip` models
#@markdown - `ViT-B/32` is quite good.
#@markdown - `RN50x16` uses a ton of VRAM and is only slightly better than `ViT-B/16`
#@markdown - `ViT-B/16` is a tad slower but higher quality.
clip_model_name = 'RN50x4' #@param ["ViT-B/16", "ViT-B/32", "RN50", "RN101", "RN50x4", "RN50x16"]
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Build the diffusion UNet and sampler from `model_config`, load the
# downloaded checkpoint onto the CPU, then move the model to `device`.
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load('256x256_diffusion_uncond.pt', map_location='cpu'))
# Freeze every parameter and switch to eval mode...
model.requires_grad_(False).eval().to(device)
# ...then re-enable grad only on parameters whose names contain 'qkv', 'norm'
# or 'proj'. NOTE(review): the reason is not visible in this notebook
# (presumably a workaround inherited from the upstream notebook, related to
# backpropagating through the UNet in `cond_fn`). Confirm before removing.
for name, param in model.named_parameters():
    if 'qkv' in name or 'norm' in name or 'proj' in name:
        param.requires_grad_()
# Cast to half precision only when the config asks for it.
if model_config['use_fp16']:
    model.convert_to_fp16()

# CLIP stays frozen; it only embeds the prompt and the image cutouts.
clip_model = clip.load(clip_model_name, jit=False)[0].eval().requires_grad_(False).to(device)
# Input side length of CLIP's image encoder; cutouts are resized to this size.
clip_size = clip_model.visual.input_resolution
# Per-channel mean/std used by CLIP's own image preprocessing.
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                 std=[0.26862954, 0.26130258, 0.27577711])



In [None]:
#@title Settings for the run
prompt = 'Psychedelic Rock Album Cover' #@param {type: 'string'}
#@markdown `Or...`
random_imagenet_class = False #@param {type: 'boolean'}
# When enabled, the prompt above is replaced by a randomly chosen ImageNet class name.
if random_imagenet_class:
    prompt = imagenet_class_line_to_caption()
#@markdown 
#@markdown Increasing `cutn` seems to help for certain prompts but has diminishing returns for many. Uses more VRAM.
cutn =  32 #@param {type: 'integer'}
batch_size = 1 #@param {type: 'integer'}
clip_guidance_scale = 1000 #@param {type: 'integer'}
tv_scale = 100 #@param {type: 'integer'}
seed = 0 #@param {type: 'integer'}

# Fail fast here instead of deep inside the sampling loop, where cutn / batch_size
# are used to reshape CLIP embeddings and to size the initial noise.
assert cutn > 0, 'cutn must be a positive integer'
assert batch_size > 0, 'batch_size must be a positive integer'

print(f"Using prompt: '{prompt}'")
print(f"batch size: {batch_size}, clip_guidance_scale: {clip_guidance_scale}, tv_scale: {tv_scale}, cutn: {cutn}, seed: {seed}")

In [None]:
#@title Actually do the run
# Seed torch's global RNG so the random draws (cutout sizes/offsets and,
# presumably, the sampler's noise) are reproducible for a given `seed`.
if seed is not None:
    torch.manual_seed(seed)

# Embed the prompt once; every guidance step compares image cutouts against it.
text_embed = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()

# Random square crops (plus the enabled kornia augmentations) resized to CLIP's input size.
make_cutouts = MakeCutouts(clip_size, cutn)

# Index of the diffusion step being sampled. It starts at the last timestep;
# the sampling loop below decrements it after every yielded sample, and
# `cond_fn` reads it to choose the matching noise level.
cur_t = diffusion.num_timesteps - 1

def cond_fn(x, t, y=None):
    """CLIP guidance: gradient that steers the current sample toward `text_embed`.

    The sampler calls this with the current noisy batch `x`. `t` is accepted
    for the sampler's calling convention but is not used; the timestep comes
    from the global `cur_t` instead. `y` is forwarded to the model as the `y`
    keyword argument. Returns the negated gradient, with respect to `x`, of the
    weighted CLIP spherical-distance loss plus total-variation loss.
    """
    with torch.enable_grad():
        x = x.detach().requires_grad_()
        n = x.shape[0]
        my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
        out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=True, model_kwargs={'y': y})
        # Blend the model's denoised estimate (`pred_xstart`) with the noisy input:
        # weight `fac` on the estimate, `1 - fac` on `x`.
        fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
        x_in = out['pred_xstart'] * fac + x * (1 - fac)
        # Shift from [-1, 1] to [0, 1] (assumed image range; TODO confirm),
        # take cutouts, and apply CLIP's normalization.
        clip_in = normalize(make_cutouts(x_in.add(1).div(2)))
        # make_cutouts concatenates its crops along dim 0, so regroup to (cutn, n, dim).
        image_embeds = clip_model.encode_image(clip_in).float().view([cutn, n, -1])
        # Distance of every cutout to the prompt embedding, averaged over cutouts.
        dists = spherical_dist_loss(image_embeds, text_embed.unsqueeze(0))
        losses = dists.mean(0)
        tv_losses = tv_loss(x_in)
        loss = losses.sum() * clip_guidance_scale + tv_losses.sum() * tv_scale
        # Negated so the returned value points toward lower loss.
        return -torch.autograd.grad(loss, x)[0]

# A respacing string such as "ddim250" selects the DDIM sampler; anything
# else uses `p_sample_loop_progressive`. Both are consumed below as iterators.
if model_config['timestep_respacing'].startswith('ddim'):
    sample_fn = diffusion.ddim_sample_loop_progressive
else:
    sample_fn = diffusion.p_sample_loop_progressive

# Sample shape is (batch_size, 3, image_size, image_size); `cond_fn` supplies CLIP guidance.
samples = sample_fn(
    model,
    (batch_size, 3, model_config['image_size'], model_config['image_size']),
    clip_denoised=True,
    model_kwargs={},
    cond_fn=cond_fn,
    progress=True,
)

# Keep `cur_t` in step with the sampler (it reaches -1 after the final
# sample). Show the current `pred_xstart` every 100 iterations and at the end.
for i, sample in enumerate(samples):
    cur_t -= 1
    if i % 100 == 0 or cur_t == -1:
        print()
        for j, image in enumerate(sample['pred_xstart']):
            # One file per batch item, overwritten on every display step.
            filename = f'progress_{j:05}.png'
            TF.to_pil_image(image.add(1).div(2).clamp(0, 1)).save(filename)
            tqdm.write(f'Step {i}, output {j}:')
            display.display(display.Image(filename))
