In [25]:
import numpy as np
import torch
import torch.nn.functional as F

import datasets, diffusers, torchvision

from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm


# Choose the compute backend once for the whole notebook.
# Preference order: Apple-silicon GPU ("mps"), then NVIDIA GPU ("cuda"), then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [27]:
# Load the pretrained DDPM pipeline and move it to the selected device.
# The trailing ';' suppresses the pipeline repr in the cell output.
image_generation_pipeline = diffusers.DDPMPipeline.from_pretrained(pretrained_model_name_or_path = "johnowhitaker/sd-class-wikiart-from-bedrooms", )
image_generation_pipeline.to(device);

# Create new scheduler and set num inference steps.
# DDIMScheduler.step() refuses to run until set_timesteps() has been called, and
# scheduler.timesteps only holds the sampling schedule afterwards, so this call
# is required before the guidance loop below can iterate.
scheduler = diffusers.DDIMScheduler.from_pretrained("johnowhitaker/sd-class-wikiart-from-bedrooms")
scheduler.set_timesteps(num_inference_steps=40)

Fetching 4 files: 100%|██████████| 4/4 [00:24<00:00,  6.09s/it]
Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00,  6.78it/s]


## Color Guidance

In [None]:
def color_loss(images, target_color=(0.1, 0.9, 0.5)):
    """Mean absolute distance between every pixel of `images` and one RGB colour.

    Args:
        images: image batch laid out as (b, c, h, w) with values in (-1, 1).
        target_color: (R, G, B) components in (0, 1). The default,
            (0.1, 0.9, 0.5), is a light teal.

    Returns:
        A scalar tensor: the absolute error averaged over all pixels,
        channels and images.
    """
    # Bring the colour onto the images' device and into the same (-1, 1) range
    rgb = torch.tensor(target_color).to(images.device)
    rgb = rgb * 2 - 1

    # Add batch / height / width axes so the colour broadcasts against (b, c, h, w)
    rgb = rgb[None, :, None, None]

    return (images - rgb).abs().mean()

In [None]:
# Multiplier applied to the colour loss before it is differentiated:
# larger values push the samples harder towards the target colour.
guidance_loss_scale = 40  # Explore changing this to 5, or 100

# Start from a batch of 8 pure-noise images
x = torch.randn(8, 3, 256, 256).to(device)

for i, t in tqdm(enumerate(scheduler.timesteps)):

    # Let the scheduler scale the current sample before it is fed to the UNet
    model_input = scheduler.scale_model_input(x, t)

    # The noise prediction itself is not differentiated, so skip the autograd graph
    with torch.no_grad():
        noise_pred = image_generation_pipeline.unet(model_input, t)["sample"]

    # From here on gradients w.r.t. x are needed
    x = x.detach().requires_grad_()

    # Scheduler's estimate of the fully denoised image for the current x
    x0 = scheduler.step(noise_pred, t, x).pred_original_sample

    # Scaled guidance loss, reported every 10th step
    loss = guidance_loss_scale * color_loss(x0)
    if i % 10 == 0:
        print(i, "loss:", loss.item())

    # Move x against the loss gradient
    (loss_grad,) = torch.autograd.grad(loss, x)
    cond_grad = -loss_grad
    x = x.detach() + cond_grad

    # Regular denoising step, starting from the nudged sample
    x = scheduler.step(noise_pred, t, x).prev_sample

# View the output: tile the batch, map (-1, 1) -> (0, 1) and convert to an 8-bit image
grid = torchvision.utils.make_grid(x, nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray((im * 255).numpy().astype(np.uint8))

## Advanced Guidance: CLIP (Text & Image) Guidance

In [29]:
# Start the CLIP-guidance section from a freshly loaded pipeline on the selected device.
# The trailing ';' suppresses the pipeline repr in the cell output.
image_generation_pipeline = DDPMPipeline.from_pretrained("johnowhitaker/sd-class-wikiart-from-bedrooms")
image_generation_pipeline.to(device);

# Fresh DDIM scheduler built from the same checkpoint's configuration
scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path="johnowhitaker/sd-class-wikiart-from-bedrooms")

Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00,  2.45it/s]


In [30]:
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
import requests

# Load CLIP together with the processor that prepares raw text / images for it
clip        = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor   = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# Sanity check: embed one caption and one sample photo and look at the embedding shapes
text_descriptions  = [' Red Rose (still life), red flower painting ']
url         = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Fail fast on network problems instead of hanging or decoding an HTTP error page
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
images       = Image.open(response.raw)

# Keep the inputs on whatever device the model currently lives on
inputs  = processor(text= text_descriptions, images= images, return_tensors="pt", padding=True).to(clip.device)

# Inspection only, so no autograd graph is needed
with torch.no_grad():
    outputs = clip(**inputs)

# A bare expression would not be displayed here because it is not the last
# statement of the cell, so print the shapes explicitly.
print("text_embeds:", tuple(outputs.text_embeds.shape), "image_embeds:", tuple(outputs.image_embeds.shape))

def clip_loss(image_embeds, text_embeds):
    """Mean squared great-circle distance between image and text embeddings.

    Every image embedding is compared with every text embedding on the unit
    hypersphere and the squared geodesic distances are averaged.

    Args:
        image_embeds: tensor of shape (n_images, dim). A tensor that already
            has a pairwise axis, (n_images, 1, dim), is accepted as well.
        text_embeds: tensor of shape (n_texts, dim). A tensor that already
            has a pairwise axis, (1, n_texts, dim), is accepted as well.

    Returns:
        Scalar tensor, differentiable w.r.t. both inputs.
    """
    # 2-D inputs get a pairwise axis so that (n_images, 1, dim) - (1, n_texts, dim)
    # broadcasts to every image/text combination.
    if image_embeds.dim() == 2:
        image_embeds = image_embeds.unsqueeze(1)
    if text_embeds.dim() == 2:
        text_embeds = text_embeds.unsqueeze(0)

    # The chord -> arc conversion below is only valid for unit vectors
    image_embeds = F.normalize(image_embeds, dim=-1)
    text_embeds = F.normalize(text_embeds, dim=-1)

    # Half the chord length is sin(angle / 2); clamp so rounding error on
    # (near-)opposite vectors cannot push arcsin outside its domain.
    half_chord = image_embeds.sub(text_embeds).norm(dim=-1).div(2).clamp(max=1.0)
    dists = half_chord.arcsin().pow(2).mul(2)  # Squared Great Circle Distance
    return dists.mean()

**Pay attention to `repeat_guidance_counter`**

- `repeat_guidance_counter` is the number of times the CLIP guidance gradient is computed and averaged at each denoising step. It is a knob for trading off stability against the strength of CLIP's influence on the generated image.
- Higher values of `repeat_guidance_counter` can give smoother gradients and potentially better image quality, at the cost of more computation time per step.

In [None]:
GUIDANCE_PROMPT = "Red Rose (still life), red flower painting"  # @param

# Explore changing this
guidance_scale          = 8  # @param
repeat_guidance_counter = 4  # @param  (guidance-gradient evaluations averaged per denoising step)

if repeat_guidance_counter < 1:
    raise ValueError("repeat_guidance_counter must be at least 1")

# More steps -> more time for the guidance to have an effect
scheduler.set_timesteps(50)

# CLIP preprocessing built from tensor transforms so the gradient can flow from
# the loss back to x (CLIPProcessor is a data-preparation step for raw images,
# not a differentiable layer). The random crop / flip shows CLIP a different
# view on every guidance repeat, so averaging the repeats smooths the gradient.
clip_view_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        # Normalisation statistics of the openai/clip-vit-base-patch32 preprocessor
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

# CLIP has to live on the same device as the samples it scores
clip.to(device)

# The prompt never changes, so tokenise it once outside the sampling loop
text_inputs = processor(text=[GUIDANCE_PROMPT], return_tensors="pt", padding=True).to(device)

x = torch.randn(4, 3, 256, 256).to(device)  # RAM usage is high, you may want only 1 image at a time

for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):

    model_input = scheduler.scale_model_input(x, t)

    # predict the noise residual
    with torch.no_grad():
        noise_pred = image_generation_pipeline.unet(model_input, t)["sample"]

    cond_grad = 0

    for _ in range(repeat_guidance_counter):

        # Set requires grad on x
        x = x.detach().requires_grad_()

        # Get the predicted x0:
        x0 = scheduler.step(noise_pred, t, x).pred_original_sample

        # EXTERNAL CLIP MODEL BEING USED HERE
        # Map x0 from (-1, 1) to (0, 1) before applying CLIP's normalisation
        pixel_values = clip_view_transforms(x0 * 0.5 + 0.5)
        outputs      = clip(pixel_values=pixel_values, **text_inputs)

        # Calculate loss. The explicit pairwise axes give (n_images, 1, dim)
        # against (1, n_texts, dim), i.e. every image is scored against the prompt.
        loss = clip_loss(outputs.image_embeds.unsqueeze(1), outputs.text_embeds.unsqueeze(0)) * guidance_scale

        # Get gradient (divide by the number of repeats since we want the average;
        # dividing by the loop index would divide by zero on the first pass)
        cond_grad -= torch.autograd.grad(loss, x)[0] / repeat_guidance_counter

    if i % 25 == 0:
        print("Step:", i, ", Guidance loss:", loss.item())

    # Modify x based on this gradient.
    # alphas_cumprod is indexed by the diffusion timestep t, not by the loop
    # counter i. sqrt(alpha_bar_t) cancels the 1 / sqrt(alpha_bar_t) factor that
    # predicting x0 from x introduces into the gradient.
    alpha_bar = scheduler.alphas_cumprod[t]
    x = x.detach() + cond_grad * alpha_bar.sqrt()  # Note the additional scaling factor here!

    # Now step with scheduler
    x = scheduler.step(noise_pred, t, x).prev_sample


# View the output: tile the batch, map (-1, 1) -> (0, 1) and convert to an 8-bit image
grid = torchvision.utils.make_grid(x.detach(), nrow=4)
im = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
Image.fromarray(np.array(im * 255).astype(np.uint8))