In [None]:
import torch
from diffusers import StableDiffusionPipeline
from transformers import CLIPProcessor, CLIPModel, BlipModel, BlipProcessor
from PIL import Image
import numpy as np
from tqdm.notebook import tqdm
from torchvision.transforms import functional as TF
import torchvision
from torchvision.transforms import (CenterCrop, Compose, InterpolationMode, Normalize, Resize)

In [4]:
import os

# Configure PyTorch's CUDA allocator through its environment variable.
# NOTE(review): presumably meant to reduce memory fragmentation during the
# optimisation loops; this cell runs after `import torch`, so confirm the
# setting is still picked up before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

In [5]:
def compute_clip_loss(clip_model, tokenizer, prompt, image):
    """Return a CLIP-based loss that gets smaller as `image` matches `prompt`.

    The loss is ``100 - mean(cosine_similarity) * exp(logit_scale)``, where the
    cosine similarity is taken between every image embedding and every text
    embedding.

    Args:
        clip_model: Model exposing `get_image_features`, `get_text_features`
            and a `logit_scale` tensor.
        tokenizer: Callable that encodes `prompt`; the notebook passes a
            `CLIPProcessor` here. Its result must support `.to(device)` and
            `**` unpacking.
        prompt: Text (or list of texts) to compare the image against.
        image: Image tensor handed straight to `get_image_features`.
            Assumes it is already resized/normalised by the caller.

    Returns:
        A scalar tensor; differentiable with respect to `image`.
    """
    image_features = clip_model.get_image_features(image)

    # Pass the prompt by keyword. A processor (unlike a bare tokenizer) takes
    # both images and text, and the position of `text` in its signature is not
    # something to rely on; the keyword form is unambiguous and matches
    # compute_blip_loss.
    prompt_token = tokenizer(
        text=prompt, return_tensors="pt", padding=True, max_length=77, truncation=True
    ).to(image.device)
    text_features = clip_model.get_text_features(**prompt_token)

    # L2-normalise along the last axis so the matrix product below is a
    # cosine similarity.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    return 100 - (image_features @ text_features.T).mean() * clip_model.logit_scale.exp()

def compute_blip_loss(blip_model, tokenizer, prompt, image):
    """Score `image` against `prompt` with a BLIP-style model and return a loss.

    Computes ``100 - mean(cosine_similarity) * exp(logit_scale)`` between the
    model's image and text embeddings, so a lower value means a closer match.

    Args:
        blip_model: Model exposing `get_image_features`, `get_text_features`
            and a `logit_scale` tensor.
        tokenizer: Callable invoked with `text=prompt`; its result must support
            `.to(device)` and `**` unpacking.
        prompt: Text (or list of texts) to compare against.
        image: Image tensor passed directly to `get_image_features`.

    Returns:
        A scalar tensor.
    """
    def _unit(features):
        # Scale each row to unit L2 norm along the last axis.
        return features / features.norm(dim=-1, keepdim=True)

    image_embedding = _unit(blip_model.get_image_features(image))

    encoded_prompt = tokenizer(
        text=prompt,
        return_tensors="pt",
        padding=True,
        max_length=77,
        truncation=True,
    ).to(image.device)
    text_embedding = _unit(blip_model.get_text_features(**encoded_prompt))

    similarity = (image_embedding @ text_embedding.T).mean()
    return 100 - similarity * blip_model.logit_scale.exp()

In [6]:
def optimize_noise(pipe, metric_model, tokenizer, prompt, noise, iterations=50, lr=0.05, compute_metric_loss=compute_clip_loss):
    """Optimise a latent tensor so that its VAE decoding scores well for `prompt`.

    Every iteration decodes `noise` with `pipe.vae`, maps the result to the
    [0, 1] range, resizes it to 224x224, normalises it per channel and takes
    one Adam step on `compute_metric_loss(metric_model, tokenizer, prompt, images)`.
    Only the VAE decoder is used here; the UNet is never run.

    Args:
        pipe: Pipeline object exposing `vae.decode(...)`.
        metric_model: Scoring model forwarded to `compute_metric_loss`.
        tokenizer: Text encoder/processor forwarded to `compute_metric_loss`.
        prompt: Text the decoded image should match.
        noise: Tensor to optimise. It is modified in place: gradient tracking
            is switched on and Adam updates its values.
        iterations: Number of optimisation steps.
        lr: Adam learning rate.
        compute_metric_loss: Callable
            `(metric_model, tokenizer, prompt, images) -> scalar loss tensor`.

    Returns:
        The same `noise` tensor object, after optimisation.
    """
    noise.requires_grad_(True)
    optimizer = torch.optim.Adam([noise], lr=lr)

    # Read the latent scaling factor from the VAE config when it is available
    # so the function also works with VAEs that use a different value; fall
    # back to the previously hard-coded 0.18215 otherwise.
    scaling_factor = getattr(getattr(pipe.vae, "config", None), "scaling_factor", 0.18215)

    # Per-channel mean/std normalisation (these are the values commonly
    # published for OpenAI CLIP preprocessing). The images are already exactly
    # 224x224 after the interpolate call below, so no further resize or
    # centre-crop step is needed.
    normalize = Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    tqdm_bar = tqdm(range(iterations))
    for _ in tqdm_bar:
        optimizer.zero_grad()

        # Undo the latent scaling before decoding.
        latents = 1 / scaling_factor * noise
        decoded_latents = pipe.vae.decode(latents).sample
        # The decoded sample is shifted/scaled and clamped into [0, 1].
        generated_images = (decoded_latents / 2 + 0.5).clamp(0, 1)

        images = torch.nn.functional.interpolate(generated_images, size=(224, 224), mode="bilinear", align_corners=False)
        images = normalize(images)
        torch.cuda.empty_cache()

        loss = compute_metric_loss(metric_model, tokenizer, prompt, images)
        torch.cuda.empty_cache()
        tqdm_bar.set_description(f"loss: {loss.item()}")

        loss.backward()
        optimizer.step()

    return noise


In [None]:
# Load SD-Turbo from its "fp16" weight files and move the pipeline to the GPU.
# `torch_dtype` is deliberately left unset (a float16 override was tried and
# disabled). NOTE(review): presumably the weights are therefore held in the
# library's default dtype - confirm.
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/sd-turbo", variant="fp16").to("cuda")

In [8]:
def freeze_params(params):
    """Turn off gradient tracking for every parameter in `params`.

    Args:
        params: Iterable of tensors/parameters (for example the result of
            `module.parameters()`); each one is modified in place.
    """
    for tensor in params:
        tensor.requires_grad_(False)

# Freeze the weights of the three pipeline components, in the same order as
# before: VAE, UNet, text encoder.
for _component in (pipe.vae, pipe.unet, pipe.text_encoder):
    freeze_params(_component.parameters())

In [None]:
# CLIP scorer: load the model, move it to the GPU and freeze its weights.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_model = clip_model.to("cuda")
freeze_params(clip_model.parameters())

# Despite its name, `tokenizer` holds the full CLIPProcessor for the same
# checkpoint.
tokenizer = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# BLIP scorer: load the model, move it to the GPU and freeze its weights.
blip_model = BlipModel.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = blip_model.to("cuda")
freeze_params(blip_model.parameters())

# Matching processor for the same checkpoint; used to encode the prompt.
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

In [10]:
def set_seed(n):
    """Seed the random number generators used by the notebook's libraries.

    Seeds Python's `random`, NumPy's global generator and PyTorch (CPU and
    CUDA), and switches cuDNN to deterministic, non-benchmarking mode.

    Args:
        n: Integer seed.
    """
    import random  # local import: the notebook's import cell does not provide it

    random.seed(n)
    # NumPy's global seed only accepts values in [0, 2**32); wrap the seed so
    # every value that torch accepts keeps working here too.
    np.random.seed(n % 2**32)

    torch.manual_seed(n)
    torch.cuda.manual_seed(n)
    torch.cuda.manual_seed_all(n)

    # Trade cuDNN autotuning for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [11]:
def save_image(name, noise_matrix, prompt):
    """Run the global `pipe` on `prompt` from the given latents and save the image.

    Args:
        name: Destination passed to the image's `save` method.
        noise_matrix: Tensor passed to the pipeline as `latents`.
        prompt: Text prompt passed to the pipeline.
    """
    # No gradients are needed for plain generation.
    with torch.no_grad():
        result = pipe(prompt=prompt, latents=noise_matrix)
    # Only the first generated image is kept.
    result.images[0].save(name)

In [12]:
def optimize_using_blip(noise, prompt):
    """Optimise `noise` for `prompt` with the BLIP model as the scoring metric.

    Thin wrapper around `optimize_noise` using the notebook-level `pipe`,
    `blip_model` and `blip_processor`, 50 iterations and a learning rate of
    0.01.

    Returns:
        Whatever `optimize_noise` returns.
    """
    blip_settings = dict(
        metric_model=blip_model,
        tokenizer=blip_processor,
        compute_metric_loss=compute_blip_loss,
        iterations=50,
        lr=0.01,
    )
    return optimize_noise(pipe=pipe, prompt=prompt, noise=noise, **blip_settings)

In [13]:
def optimize_using_clip(noise, prompt):
    """Optimise `noise` for `prompt` with the CLIP model as the scoring metric.

    Thin wrapper around `optimize_noise` using the notebook-level `pipe`,
    `clip_model` and `tokenizer`, 50 iterations and a learning rate of 0.01.
    The loss function is left at `optimize_noise`'s default.

    Returns:
        Whatever `optimize_noise` returns.
    """
    return optimize_noise(
        pipe,
        clip_model,
        tokenizer,
        prompt,
        noise,
        iterations=50,
        lr=0.01,
    )

In [22]:
def test_with_blip(prompt):
    """Generate a baseline and a BLIP-optimised image for `prompt`.

    Writes "<prompt> - normal.png" and "<prompt> - optimized with blip.png",
    both starting from the same seeded latent tensor.
    """
    set_seed(42)
    latent_shape = (1, pipe.unet.config.in_channels, 64, 64)
    noise = torch.randn(latent_shape, device="cuda", requires_grad=True)

    # Save the baseline first: optimize_noise updates the tensor it is given.
    save_image(f"{prompt} - normal.png", noise, prompt)

    tuned_noise = optimize_using_blip(noise, prompt)
    save_image(f"{prompt} - optimized with blip.png", tuned_noise, prompt)

In [14]:
def test_with_clip(prompt):
    """Generate a baseline and a CLIP-optimised image for `prompt`.

    Writes "<prompt> - normal.png" and "<prompt> - optimized with clip.png",
    both starting from the same seeded latent tensor.
    """
    set_seed(42)
    latent_shape = (1, pipe.unet.config.in_channels, 64, 64)
    noise = torch.randn(latent_shape, device="cuda", requires_grad=True)

    # Save the baseline first: optimize_noise updates the tensor it is given.
    save_image(f"{prompt} - normal.png", noise, prompt)

    tuned_noise = optimize_using_clip(noise, prompt)
    save_image(f"{prompt} - optimized with clip.png", tuned_noise, prompt)

In [15]:
# Text prompts used by the evaluation loops below; each one is passed to
# test_with_blip / test_with_clip and also becomes part of the output file names.
all_prompts = [
    "a red cow and a brown dog",
    "a red book and a brown cat",
    "a black dog and a brown cat",
    "a red bowl and a blue chair",
    "a green cup and a blue vase",
    "a blue dove and a white sky",
    "a brown car and a red giraffe",
    "a green bench and a blue book",
    "a gold bench and a green clock",
    "a blue cake and a brown giraffe",
    "a white candle and a red holder",
    "a blue elephant and a brown vase",
    "a yellow apple and a green banana",
    "a green tomato and a red cucumber",
    "a silver watch and a gold bracelet",
    "a brown giraffe and a red suitcase",
    "a red backpack and a blue suitcase",
    "a green smoothie and a purple straw",
    "a pink elephant and a brown giraffe",
    "a white cat is inside a black toilet",
    "a white bathroom has a red towel on the bar",
    "a black truck has a red dog in the drivers chair",
    "a white kitchen counter with a big, brown bowl on it",
    "the kitchen has silver cabinets and a brown refrigerator",
    "a woman wearing a black shirt and white vest looks into a mirror while holding a camera",
]

In [27]:
# Run the BLIP-guided experiment for every prompt. Each call writes a baseline
# image and a BLIP-optimised image to the working directory.
for prompt in tqdm(all_prompts, desc="all prompts"):
    test_with_blip(prompt)

all prompts:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [16]:
# Run the CLIP-guided experiment for every prompt. Each call writes a baseline
# image and a CLIP-optimised image to the working directory.
for prompt in tqdm(all_prompts, desc="all prompts"):
    test_with_clip(prompt)

all prompts:   0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]