## Assignment 03

### Part 1: Image Generation and Embedding Extraction

In [1]:
import torch
from torch import clip
from utils import UNet_utils, ddpm_utils

In [2]:
# Use GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
IMG_SIZE = 32
IMG_CH = 3
BATCH_SIZE = 128
INPUT_SIZE = (IMG_CH, IMG_SIZE, IMG_SIZE)

T = 400
B_start = 0.0001
B_end = 0.02
B = torch.linspace(B_start, B_end, T).to(device)
ddpm = ddpm_utils.DDPM(B, device)

clip_model, clip_preprocess = clip.load("ViT-B/32")
clip_model.eval()
CLIP_FEATURES = 512

AttributeError: 'builtin_function_or_method' object has no attribute 'load'

In [None]:
# Initialize the UNet model identical to the one in notebook 05.
model = UNet_utils.UNet(
    T, IMG_CH, IMG_SIZE, down_chs=(256, 256, 512), t_embed_dim=8, c_embed_dim=CLIP_FEATURES
)

# Load the pre-trained model weights
model.load_state_dict(torch.load('path_to_your_model.pth'))

model.eval()

In [None]:
# List of Text Prompts to genereate Images From
text_prompts = [
    "A photo of a red rose",
    "A photo of a white daisy",
    "A photo of a yellow sunflower",
]

In [None]:
# --- Embedding Extraction using Hooks ---
embeddings_storage = {}

def get_embedding_hook(name):
    def hook(model, input, output):
        embeddings_storage[name] = output.detach()
    return hook

# Register hooks to capture embeddings from the "down2" layer of the U-Net model
model.down2.register_forward_hook(get_embedding_hook("down2"))

In [None]:
def sample_flowers(text_list):
    text_tokens = clip.tokenize(text_list).to(device)
    c = clip_model.encode_text(text_tokens).float()
    x_gen, x_gen_store = ddpm_utils.sample_w(model, ddpm, INPUT_SIZE, T, c, device)
    return x_gen, x_gen_store

In [None]:
generated_images, _ = sample_flowers(text_prompts)

extracted_embeddings = embeddings_storage['down2']

### Part 2: Evaluation with CLIP Score and Frechet Inception Distance (FID)

#### Metric Calculation

In [None]:
import torch
import open_clip
from PIL import Image
import numpy as np
from scipy.linalg import sqrtm

##### CLIP Score Calculation

In [5]:
def calculate_clip_score(image_path, text_prompt):
    # Load model
    model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')

    # Preprocess inputs
    image = preprocess(Image.open(image_path)).unsqueeze(0)

    tokenizer = open_clip.get_tokenizer('ViT-B-32')
    text = tokenizer([text_prompt])

    # Compute features and similarity
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
        
        # Normalize features
        image_features /= image_features.norm(dim=-1, keepdim=True)
        text_features /= text_features.norm(dim=-1, keepdim=True)

        # Calculate dot product
        score = (image_features @ text_features.T).item()

    return score

##### Frechet Inception Distance (FID) Calculation

In [None]:
def calculate_fid(real_embeddings, gen_embeddings):
    # Calculate mean and covariance
    mu1, sigma1 = real_embeddings.mean(axis=0), np.cov(real_embeddings,rowvar=False)
    mu2, sigma2 = gen_embeddings.mean(axis=0), np.cov(gen_embeddings, rowvar=False)
    
    # Calculate sum squared difference between means
    ssdiff = np.sum((mu1 - mu2)**2)

    # Calculate sqrt of product of covariances
    covmean = sqrtm(sigma1.dot(sigma2))

    # Handle numerical errors
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # Final FID calculation
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

    return fid

In [None]:
# Calculate CLIP scores
clip_scores = []
for i, img in enumerate(generated_images):
    score = calculate_clip_score(clip_model, img.unsqueeze(0), text_prompts[i], device)
    clip_scores.append(score)
average_clip_score = sum(clip_scores) / len(clip_scores)
print(f"Average CLIP Score: {average_clip_score}")

# Calculate FID score
# Load real TF-Flowers dataset images here (not shown)
real_images = ...
fid_score = calculate_fid(real_images, generated_images)
print(f"FID Score: {fid_score}")


ModuleNotFoundError: No module named 'open_clip'

### Part 3: Embedding Analysis with FiftyOne Brain

In [4]:
import fiftyone as fo
import fiftyone.brain as fob

In [None]:
# Create new FiftyOne dataset
dataset = fo.Dataset(name="generated_flowers_with_embeddings")

In [None]:
# For each image, create a fiftyone.Sample and add the following metadata:
# - The file path to the saved image.
# - The text prompt (as a `fo.Classification` label).
# - The CLIP score (as a custom field).
# - The extracted U-Net embedding (as a custom field).

from matplotlib import transforms


for i, img in enumerate(generated_images):
    # Save image to disk
    img_path = f"generated_image_{i}.png"
    # Assuming img is a PIL Image or can be converted to one
    img_pil = transforms.ToPILImage()(img.cpu())
    img_pil.save(img_path)
    
    # Create sample
    sample = fo.Sample(
        filepath=img_path,
        prompt=fo.Classification(label=text_prompts[i]),
        clip_score=clip_scores[i],
        unet_embedding=extracted_embeddings[i].cpu().numpy()
    )
    dataset.add_sample(sample)

# Now that the dataset is populated, use FiftyOne Brain to analyze the embeddings.
brain = fob.EmbeddingBrain()
brain.compute_embeddings(dataset, "unet_embedding")
view = brain.cluster()
session = fo.launch_app(view=view)

In [None]:
# Compute uniqueness of the dataset based on U-Net embeddings
fob.compute_uniqueness(dataset)

# Compute representativeness of the dataset based on U-Net embeddings
fob.compute_representativeness(dataset, embeddings="unet_embedding")

session = fo.launch_app(dataset)

### Part 4: Logging with Weights & Biases

In [None]:
import wandb
wandb.login()

In [None]:
run = wandb.init(project="diffusion_model_assessment_v2")

# Log your hyperparameters (e.g., guidance weight `w`, number of steps `T`).
wandb.config.update({
    "guidance_weight": w,
    "num_steps": T,
})

# Log evaluation metrics (CLIP Score and FID).
wandb.log({
    "average_clip_score": average_clip_score,
    "fid_score": fid_score,
})

# Create wandb Table and log results
table = wandb.Table(columns=["image", "prompt", "clip_score", "uniqueness", "representativeness"])
for sample in dataset:
    img = wandb.Image(sample.filepath)
    prompt = sample.prompt.label
    clip_score = sample.clip_score
    uniqueness = sample.metadata["fiftyone"]["uniqueness"]
    representativeness = sample.metadata["fiftyone"]["representativeness"]
    
    table.add_data(img, prompt, clip_score, uniqueness, representativeness)
wandb.log({"generated_flowers_table": table})

# Finish the wandb run
run.finish()