In [123]:
from math import e
import random
from base64 import b64encode
import numpy
import torch
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
from huggingface_hub import notebook_login

# For video display:
from IPython.display import HTML
from matplotlib import pyplot as plt
from pathlib import Path
from PIL import Image
from torch import autocast
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer, logging
import os

# Seed every RNG this notebook draws from: torch (noise in noise_latent) and
# Python's `random` (timestep sampling inside diffuser_classifier).
SEED = 1
torch.manual_seed(SEED)
random.seed(SEED)

# Suppress some unnecessary warnings when loading the CLIPTextModel
logging.set_verbosity_error()

# Set device
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

## Initialize Components

We first load the components we need: the VAE and UNet from Stable Diffusion, the CLIP tokenizer and text encoder, and an LMS noise scheduler.

In [124]:
# Each model is moved to `torch_device` as soon as it is loaded.

# Autoencoder: maps images to latents and decodes latents back into image space.
vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(torch_device)

# CLIP tokenizer and text encoder that turn prompts into text embeddings.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(torch_device)

# UNet that predicts the noise in a latent, conditioned on text embeddings.
unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet").to(torch_device)

# LMS noise scheduler: scaled-linear beta schedule over 1000 training timesteps.
scheduler = LMSDiscreteScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
)

In [125]:
# Latent scaling constant: pil_to_latent multiplies sampled VAE latents by it,
# and latents_to_pil multiplies by its reciprocal before decoding.
DIFFUSION_SCALE_FACTOR: float = 0.18215

## Utility Functions

In [126]:
def pil_to_latent(input_im):
    """Encode one PIL image into a scaled VAE latent with a leading batch dimension of 1.

    The result has shape (1, C, H', W'), e.g. (1, 4, 64, 64) for a 512x512 image.

    Non-RGB images (e.g. RGBA or palette PNGs) are converted to RGB first.
    Otherwise ToTensor would produce a 4-channel tensor, and the Stable
    Diffusion VAE encoder takes 3-channel RGB input.

    Args:
        input_im: a PIL image.

    Returns:
        ``DIFFUSION_SCALE_FACTOR`` times a latent sampled from the VAE posterior.
    """
    if input_im.mode != "RGB":
        input_im = input_im.convert("RGB")
    # ToTensor maps 8-bit pixels to [0, 1]; "* 2 - 1" rescales them to [-1, 1].
    pixels = tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device) * 2 - 1
    with torch.no_grad():
        latent = vae.encode(pixels)
    return DIFFUSION_SCALE_FACTOR * latent.latent_dist.sample()

def latents_to_pil(latents):
    """Decode a batch of scaled latents into a list of PIL images, one per batch element."""
    # Undo the scaling that pil_to_latent applied.
    unscaled = (1 / DIFFUSION_SCALE_FACTOR) * latents
    with torch.no_grad():
        decoded = vae.decode(unscaled).sample
    # Map decoder output from [-1, 1] to [0, 1], clipping anything outside.
    unit_range = (decoded / 2 + 0.5).clamp(0, 1)
    # NCHW -> NHWC, then to 8-bit pixel arrays.
    channels_last = unit_range.detach().cpu().permute(0, 2, 3, 1).numpy()
    pixel_arrays = (channels_last * 255).round().astype("uint8")
    return list(map(Image.fromarray, pixel_arrays))

### Noise Latents

Given a timestep index $t$, we sample Gaussian noise with the same shape as the latent and let the scheduler add it to the latent at the noise level (sigma) of timestep $t$.

In [127]:
def noise_latent(latent, timestep_idx, noise=None):
    """Noise ``latent`` to the level of ``scheduler.timesteps[timestep_idx]``.

    Args:
        latent: the clean latent tensor.
        timestep_idx: index into ``scheduler.timesteps``.
        noise: optional noise tensor shaped like ``latent``. When omitted, fresh
            standard-normal noise is drawn with ``torch.randn_like``. Passing it
            lets a caller reuse the same noise.

    Returns:
        The noisy latent produced by ``scheduler.add_noise``.
    """
    if noise is None:
        noise = torch.randn_like(latent)

    timestep = scheduler.timesteps[timestep_idx]

    # add_noise takes the timestep(s) as a tensor, so wrap the single timestep in a 1-element tensor.
    return scheduler.add_noise(latent, noise, timesteps=torch.tensor([timestep]))

### Denoise Latent

After noising a latent at timestep $t$, we try to denoise it.
We first use the `unet` to predict the added noise, once unconditionally and once conditioned on the text embeddings, and combine the two predictions with classifier-free guidance.
We then take one denoising step from the noisy latent toward a less noisy one.

In [128]:
# Classifier-free guidance strength used by denoise_latent_conditioned.
guidance_scale = 7.5

def denoise_latent_conditioned(latent, timestep_idx, text_embeddings):
    """Take one classifier-free-guided denoising step from ``latent`` at ``timestep_idx``.

    The step moves from ``scheduler.sigmas[timestep_idx]`` to the next, smaller
    sigma. Past the end of the schedule it moves to 0.

    The step is computed explicitly as a first-order (Euler) update, the same
    update an LMS step of order 1 makes. ``LMSDiscreteScheduler.step`` is a
    stateful multistep solver that keeps derivatives from earlier calls. Calling
    it for unrelated (class, timestep) pairs would let one evaluation leak into
    the next. The explicit step makes every call independent of call order.

    Args:
        latent: noisy latent; the code below indexes ``[0]``, so it assumes a
            leading batch dim of 1 -- TODO confirm for other callers.
        timestep_idx: index into ``scheduler.timesteps`` / ``scheduler.sigmas``.
        text_embeddings: ``cat([uncond, cond])`` embeddings, matching the
            duplicated latent batch.

    Returns:
        The first batch element of the denoised latent (batch dim dropped).
    """
    sigma = float(scheduler.sigmas[timestep_idx])
    # Next (less noisy) sigma; 0 once we step past the end of the schedule.
    if timestep_idx + 1 < len(scheduler.sigmas):
        sigma_next = float(scheduler.sigmas[timestep_idx + 1])
    else:
        sigma_next = 0.0
    timestep = scheduler.timesteps[timestep_idx]

    # Duplicate the latent so one UNet pass yields both unconditional and conditional predictions.
    latent_model_input = torch.cat([latent] * 2)
    # Same input scaling as LMSDiscreteScheduler.scale_model_input: x / sqrt(sigma^2 + 1).
    latent_model_input = latent_model_input / ((sigma ** 2 + 1) ** 0.5)

    # predict the noise residual
    with torch.no_grad():
        noise_pred = unet(latent_model_input, timestep, encoder_hidden_states=text_embeddings)["sample"]

    # perform guidance
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Euler step x_t -> x_{t-1}. For epsilon prediction, the ODE derivative
    # (x - x0_pred) / sigma equals the predicted noise.
    prev_sample = latent + (sigma_next - sigma) * noise_pred

    return prev_sample[0]

## Text Embeddings

We generate text embeddings by first tokenizing the prompt/class name with the CLIP tokenizer and then encoding it with the CLIP text encoder.
We also prepend the embeddings of an empty prompt (the unconditional embeddings), which classifier-free guidance needs.

In [129]:
# Default single-prompt batch size (kept for existing references; the
# embedding helper now derives its batch size from the prompt itself).
batch_size = 1

def genenerate_text_embeddings(prompt: "str | list[str]"):
    """Encode ``prompt`` with CLIP and prepend matching unconditional embeddings.

    Args:
        prompt: a prompt string, or a list of prompt strings.

    Returns:
        ``torch.cat([uncond_embeddings, text_embeddings])`` along dim 0.
        The unconditional half comes from empty prompts, one per text prompt,
        so the two halves line up for classifier-free guidance.
    """
    text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    max_length = text_input.input_ids.shape[-1]
    # One empty prompt per tokenized prompt, so both halves have the same batch size.
    num_prompts = text_input.input_ids.shape[0]
    uncond_input = tokenizer(
        [""] * num_prompts, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    return torch.cat([uncond_embeddings, text_embeddings])

# Correctly spelled alias; the original (misspelled) name is kept for existing callers.
generate_text_embeddings = genenerate_text_embeddings

In [130]:
# w_t := exp(-rate * t), with rate = 7 by default
def timestep_weight(t: float, rate: float = 7.0) -> float:
    """Exponentially decaying weight for a normalized diffusion time ``t``.

    Args:
        t: normalized time; the classifier draws it from U(0, 1), so it is a float.
        rate: decay rate (7 reproduces the original exp(-7t) weighting).

    Returns:
        ``e ** (-rate * t)``; equals 1.0 at t = 0.
    """
    return e ** (-rate * t)

In [131]:
# NOTE(review): min_scores, cutoff_pval and sd_model are not read by any cell in
# this notebook; kept in case other code references them.
min_scores = 20
# Default number of (t, noise) samples drawn per classification.
max_scores = 2000
cutoff_pval = 2 * e**-3
sd_model = 'CompVis/stable-diffusion-v1-4'

# weight function used
# score = w_t * l2_loss

def diffuser_classifier(image_path: str, classes: list[str], num_samples: int = max_scores):
    """Zero-shot classify an image by how well each class prompt helps denoise it.

    Each of ``num_samples`` rounds does the following:
    1. Draw a normalized time ``t ~ U(0, 1)``.
    2. Noise the image latent at the matching scheduler timestep.
    3. Denoise that same noisy latent once per class prompt.
    4. Record ``timestep_weight(t) * mean((clean_latent - denoised_latent) ** 2)``
       for each class.
    The class with the lowest mean weighted loss is predicted.

    NOTE(review): the Diffusion Classifier formulation compares predicted noise with
    the true noise; this notebook instead compares the one-step denoised latent with
    the clean latent. Confirm this is intended.

    Args:
        image_path: path of the image file to classify.
        classes: candidate class names, each used directly as a text prompt.
        num_samples: number of noise samples to average over (default ``max_scores``).

    Returns:
        ``[]`` when ``classes`` is empty, ``[1.0]`` when there is a single class,
        otherwise the class name with the lowest mean weighted loss.

    Raises:
        ValueError: if ``num_samples`` < 1 with two or more classes.
    """
    k = len(classes)

    if k == 0:
        return []
    if k == 1:
        return [1.0]
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")

    # Loop-invariant work is done once instead of once per sample.
    with Image.open(image_path) as img:
        # PNGs are often RGBA/palette images; the VAE takes 3-channel RGB.
        image_latent = pil_to_latent(img.convert("RGB"))
    class_embeddings = {y_i: genenerate_text_embeddings(y_i) for y_i in classes}
    num_steps = len(scheduler.timesteps)

    scores = {y_i: [] for y_i in classes}
    for _ in tqdm(range(num_samples)):
        # t ~ U([0, 1]); t -> 0 is the clean end of the diffusion process.
        t = random.uniform(0, 1)
        # scheduler.timesteps runs from most to least noisy, so index 0 is t = 1.
        # Mapping with (1 - t) makes exp(-7t) favour low-noise samples as intended.
        timestep_idx = int((1 - t) * (num_steps - 1))
        w_t = timestep_weight(t)

        # x_t ~ q(x_t|x); the same noisy latent is scored against every class.
        x_t = noise_latent(
            latent=image_latent,
            timestep_idx=timestep_idx,
        )

        for y_i, embeddings in class_embeddings.items():
            denoised_latent = denoise_latent_conditioned(
                latent=x_t,
                timestep_idx=timestep_idx,
                text_embeddings=embeddings,
            )
            # Weighted L2 loss, stored as a Python float (not a device tensor).
            squared_loss = torch.mean((image_latent - denoised_latent) ** 2)
            scores[y_i].append(w_t * squared_loss.item())

    means = {y_i: sum(losses) / len(losses) for y_i, losses in scores.items()}
    min_class = min(means, key=means.get)

    print(f"label: {min_class}, loss:", means)
    return min_class

In [None]:
# NOTE(review): "god" looks like a possible typo for "dog"; confirm the intended label set.
diffuser_classifier(
    image_path="bear.png",
    classes=["apple", "god", "cat", "bear", "duck"],
)