# Denoising Diffusion Probabilistic Models

In [None]:
import sys

# Colab-only setup
if "google.colab" in sys.modules:
    print("Running in Google Colab. Setting up repo")

    !git clone https://github.com/MatthiasCr/Diffusion-Models-Assignment.git
    %cd Diffusion-Models-Assignment

In [None]:
import os
from PIL import Image
import torch
import clip
import open_clip
import wandb
import urllib.request
import fiftyone as fo
import fiftyone.brain as fob
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import numpy as np
from scipy.linalg import sqrtm
import matplotlib.pyplot as plt

from utils import UNet_utils, ddpm_utils

# Use the GPU when PyTorch can see one, otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

## Part 1

In [None]:
# Noise schedule: T values spaced linearly from B_start to B_end, handed to DDPM.
T = 400
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)

IMG_SIZE = 32
IMG_CH = 3
BATCH_SIZE = 128
INPUT_SIZE = (IMG_CH, IMG_SIZE, IMG_SIZE)
GUIDANCE_WEIGHT = 2.0
# Width of the CLIP ViT-B/32 text embedding used as the U-Net's conditioning vector.
# Defined before the U-Net so the model can be built from it.
CLIP_FEATURES = 512

ddpm = ddpm_utils.DDPM(B, device)

# Initialize the U-Net model from the shared constants so its image shape and
# conditioning width cannot drift from INPUT_SIZE / CLIP_FEATURES.
model = UNet_utils.UNet(
    T,
    img_ch=IMG_CH,
    img_size=IMG_SIZE,
    down_chs=(256, 256, 512),
    t_embed_dim=8,
    c_embed_dim=CLIP_FEATURES,
).to(device)

# Load the CLIP model (on the same device as the U-Net) to get text embeddings from prompts.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_model.eval()

In [None]:
# Fetch the pretrained U-Net weights (once) and load them into `model`.
weights_path = "weights/model.pth"
url = "https://github.com/MatthiasCr/Diffusion-Models-Assignment/releases/download/v1/model_weights.pth"

os.makedirs(os.path.dirname(weights_path), exist_ok=True)

if not os.path.exists(weights_path):
    print("Downloading pretrained weights...")
    # Download under a temporary name and rename only after success, so an
    # interrupted download never leaves a truncated file at `weights_path`
    # that later runs would mistake for a complete checkpoint.
    partial_path = weights_path + ".part"
    urllib.request.urlretrieve(url, partial_path)
    os.replace(partial_path, weights_path)

# The checkpoint comes from the network: weights_only=True refuses to unpickle
# arbitrary objects and only accepts tensors/state-dict containers.
model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
model.eval()

In [None]:
# List of text prompts to generate images for (one generated image per prompt).
# Every entry needs a trailing comma: adjacent string literals are otherwise
# silently concatenated into a single prompt.
text_prompts = [
    "A photo of a red rose",
    "An image of a red rose",
    "A picture of a red rose",
    "A red rose",
    "A rose with red petals",
    "A purple rose",
    "A yellow rose",
    "A blue rose",

    "A photo of a white daisy",
    "An image of a white daisy",
    "A picture of a white daisy",
    "A white daisy",
    "A daisy that is white",
    "A yellow daisy",
    "A red daisy",

    "A photo of a yellow sunflower",
    "An image of a yellow sunflower",
    "A picture of a yellow sunflower",
    "A yellow sunflower",
    "A sunflower",
    "A sunflower with orange petals",

    "An orange tulip",
    "A rose tulip",
    "A purple tulip",
    "A photo of a white orchid",
    "A rose orchid",
    "A photo of a purple flower",
    "A photo of a blue flower",
]

In [None]:
# Register a forward hook on the `down2` layer of the U-Net model so its
# activations can be reused as per-image embeddings later on.
embeddings_storage = {}

def get_embedding_hook(name):
    """Return a forward hook that stores the layer's detached output under `name`.

    Each forward pass overwrites the previous entry, so after sampling the dict
    holds only the output of the most recent call to the hooked layer.
    """
    def hook(model, input, output):
        embeddings_storage[name] = output.detach()
    return hook

# Re-running this cell would otherwise stack one more identical hook on
# `down2` every time; remove the handle left by a previous run first.
_previous_down2_hook = globals().get("down2_hook_handle")
if _previous_down2_hook is not None:
    _previous_down2_hook.remove()
down2_hook_handle = model.down2.register_forward_hook(get_embedding_hook('down2'))


def sample_flowers(text_list):
    """Generate flower images conditioned on the CLIP embeddings of `text_list`.

    Sampling uses `ddpm_utils.sample_w` with the single guidance weight
    GUIDANCE_WEIGHT. Returns the pair `sample_w` returns (final samples and
    its stored intermediate samples).
    """
    # Inference only: don't build an autograd graph through the CLIP text encoder.
    with torch.no_grad():
        text_tokens = clip.tokenize(text_list).to(device)
        # .float(): presumably CLIP can return half-precision features (e.g. on GPU).
        c = clip_model.encode_text(text_tokens).float()
    x_gen, x_gen_store = ddpm_utils.sample_w(model, ddpm, INPUT_SIZE, T, c, device, w_tests=[GUIDANCE_WEIGHT])
    return x_gen, x_gen_store


generated_images, _ = sample_flowers(text_prompts)
# `down2` output captured from the last U-Net forward pass made during sampling.
extracted_embeddings = embeddings_storage['down2']

In [None]:
def show_tensor_image(image):
    """Plot `image[0]` with matplotlib.

    The tensor is shifted from [-1, 1] to [0, 1] with (t + 1) / 2, clamped into
    [0, 1], moved to the CPU and converted to a PIL image before plotting.
    """
    to_displayable = transforms.Compose(
        [
            transforms.Lambda(lambda x: (x + 1) / 2),
            transforms.Lambda(lambda x: x.clamp(min=0, max=1)),
            transforms.ToPILImage(),
        ]
    )
    first_tensor = image[0].detach().cpu()
    plt.imshow(to_displayable(first_tensor))


# Tile every generated sample into a single grid image and display it.
grid = make_grid(generated_images.cpu())
show_tensor_image([grid])
plt.show()

## Part 2

In [None]:
import functools


@functools.lru_cache(maxsize=4)
def _load_open_clip(model_name, pretrained):
    """Create an OpenCLIP model, its preprocess transform and tokenizer once per (model, weights) pair."""
    model, _, preprocess = open_clip.create_model_and_transforms(model_name, pretrained=pretrained)
    # OpenCLIP models are created in training mode; switch to inference behaviour.
    model.eval()
    tokenizer = open_clip.get_tokenizer(model_name)
    return model, preprocess, tokenizer


def calculate_clip_score(image_path, text_prompt, model_name='ViT-B-32', pretrained='laion2b_s34b_b79k'):
    """Return the cosine similarity between the CLIP embeddings of an image file and a prompt.

    Args:
        image_path: Path of the image to score.
        text_prompt: Text the image is compared against.
        model_name: OpenCLIP architecture name.
        pretrained: OpenCLIP pretrained-weights tag for `model_name`.

    Returns:
        A Python float in [-1, 1]. This is the raw cosine similarity, not the
        "100 * max(cos, 0)" scaling some CLIPScore definitions use.
    """
    # Cached, so scoring many images doesn't re-create the model every call.
    model, preprocess, tokenizer = _load_open_clip(model_name, pretrained)

    # Preprocess inputs; the context manager closes the image file handle.
    with Image.open(image_path) as img:
        image = preprocess(img.convert("RGB")).unsqueeze(0)
    text = tokenizer([text_prompt])

    # Compute features and similarity
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        # Normalize features so the dot product is a cosine similarity.
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        score = (image_features @ text_features.T).item()
    return score

In [None]:
def calculate_fid(real_embeddings, gen_embeddings, eps=1e-6):
    """Compute the Fréchet distance between two sets of embeddings.

    FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2)), where mu/S are the
    per-set mean and covariance of the embeddings.

    Args:
        real_embeddings: Array-like of shape (N, D), e.g. InceptionV3 pool
            features (D = 2048) of real images.
        gen_embeddings: Array-like of shape (M, D) for generated images.
        eps: Diagonal offset added to both covariances if the matrix square
            root is not finite (near-singular covariances).

    Returns:
        The FID as a NumPy float.
    """
    real = np.asarray(real_embeddings, dtype=np.float64)
    gen = np.asarray(gen_embeddings, dtype=np.float64)

    # Calculate mean and covariance. atleast_2d: np.cov returns a 0-d array for
    # a single feature column, which sqrtm rejects.
    mu1, sigma1 = real.mean(axis=0), np.atleast_2d(np.cov(real, rowvar=False))
    mu2, sigma2 = gen.mean(axis=0), np.atleast_2d(np.cov(gen, rowvar=False))

    # Calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2) ** 2)

    # Calculate sqrt of product of covariances
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular product: regularize both covariances and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # Numerical error can leave a tiny imaginary component; discard it.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Final FID calculation
    fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return fid

In [None]:
import open_clip

# TODO: Calculate the CLIP score for each generated image against its prompt.

# You can use the `calculate_clip_score` function from the evaluation guide.

# TODO: Calculate the FID score for the set of generated images.

# You will need the `calculate_fid` function and the Inception model from the evaluation guide.

# You will also need to load the real TF-Flowers dataset to compare against.

## Part 3

In [None]:
n_generated_images = generated_images.shape[0]

# Select the embeddings corresponding to the guided pass. NOTE(review): this
# assumes the captured batch starts with the prompt-conditioned samples, one per
# generated image; confirm against how ddpm_utils.sample_w lays out its batch.
if extracted_embeddings.shape[0] < n_generated_images:
    raise ValueError(
        f"Captured {extracted_embeddings.shape[0]} down2 activations but "
        f"generated {n_generated_images} images"
    )
embeddings_guided_pass = extracted_embeddings[:n_generated_images]

# Flatten all dimensions except batch. reshape (unlike view) also works on
# non-contiguous tensors. Move to the CPU once so later per-row conversions
# don't each trigger a device transfer.
flattened_embeddings = embeddings_guided_pass.reshape(n_generated_images, -1).cpu()

In [None]:
# Save generated images as .png so they can be scored with CLIP and browsed in FiftyOne.
images_dir = "generated_images"
os.makedirs(images_dir, exist_ok=True)

image_filepaths = []
for i, img_tensor in enumerate(generated_images):
    img_filename = os.path.join(images_dir, f"generated_image_{i}.png")
    # Samples are assumed to lie in [-1, 1] (the display helper in this notebook
    # maps them with (t + 1) / 2). save_image expects [0, 1]; without rescaling,
    # every negative value would be clipped to black in the saved file.
    save_image(((img_tensor + 1) / 2).clamp(0, 1), img_filename)
    image_filepaths.append(img_filename)

In [None]:
# Build a FiftyOne dataset pairing each saved image with its prompt and U-Net embedding.
# overwrite=True: re-running this cell would otherwise fail because a dataset
# with this name already exists.
dataset = fo.Dataset(name="generated_flowers_with_embeddings", overwrite=True)

samples = []
for filepath, prompt, embedding in zip(
    image_filepaths[:n_generated_images],
    text_prompts[:n_generated_images],
    flattened_embeddings[:n_generated_images],
):
    sample = fo.Sample(filepath=filepath)
    # Stored as a Classification so the App can filter/group by prompt.
    sample["text_prompt"] = fo.Classification(label=prompt)
    sample["unet_embedding"] = embedding.tolist()
    samples.append(sample)

dataset.add_samples(samples)

In [None]:
# Compute uniqueness and representativeness from the same stored U-Net `down2`
# embeddings, so both scores describe the same feature space. Without
# `embeddings=`, compute_uniqueness would embed the images with FiftyOne's own
# default model instead.
fob.compute_uniqueness(dataset, embeddings="unet_embedding")
fob.compute_representativeness(dataset, embeddings="unet_embedding")

In [None]:
# Start the FiftyOne App without auto-opening it, then print the URL to open it manually.
session = fo.launch_app(dataset=dataset, auto=False)
print(session.url)

## Part 4

In [None]:
wandb.login()

# Record the generation hyperparameters as the run config so runs can be compared.
run = wandb.init(
    project="diffusion_model_assessment",
    config={
        "T": T,
        "B_start": B_start,
        "B_end": B_end,
        "guidance_weight": GUIDANCE_WEIGHT,
        "img_size": IMG_SIZE,
        "img_ch": IMG_CH,
        "num_prompts": len(text_prompts),
    },
)

# TODO: Log your evaluation metrics (CLIP Score and FID).

# TODO: Create a wandb.Table to log your results. The table should include:

# - The generated image.

# - The text prompt.

# - The CLIP score.

# - The uniqueness score.

# - The representativeness score.

run.finish()