### DiffusionCLIP


In [None]:

# core libs + diffusers + CLIP + metrics
!pip install torch torchvision diffusers accelerate \
            torchmetrics tqdm matplotlib ftfy \
            git+https://github.com/openai/CLIP.git


Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-9tncnr0p
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-9tncnr0p
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torchmetrics
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014

In [None]:
!pip install diffusers transformers ftfy torchmetrics
!pip install clean-fid --quiet

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from diffusers import UNet2DModel, DDPMScheduler
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
from torchvision.utils import make_grid
import os
from cleanfid import fid



In [None]:
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip
!unzip -q tiny-imagenet-200.zip
!rm tiny-imagenet-200.zip

In [14]:
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])

dataset = datasets.ImageFolder('/content/datasets/train', transform=transform)
train_len = int(0.8 * len(dataset))
val_len = len(dataset) - train_len
train_data, val_data = random_split(dataset, [train_len, val_len])

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False)


In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
unet = UNet2DModel(sample_size=64, in_channels=3, out_channels=3, layers_per_block=2, block_out_channels=(64, 128,128,256)).to(device)
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)

clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def compute_clip_loss(images, texts):
    inputs = clip_processor(text=texts, images=images, return_tensors="pt", padding=True).to(device)
    outputs = clip_model(**inputs)
    return 1 - F.cosine_similarity(outputs.image_embeds, outputs.text_embeds).mean()


fid_score = fid.compute_fid(
    real_path="/content/tiny-imagenet-200/val/images",  # adjust path if needed
    fake_path="generated",
    mode="clean"
)
print(f"FID Score: {fid_score:.2f}")


In [None]:
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-4)
epochs = 5
clip_text = "picture of an alien"

train_losses, val_losses = [], []

for epoch in range(epochs):
    unet.train()
    epoch_loss = 0
    for x, _ in tqdm(train_loader):
        x = x.to(device)
        t = torch.randint(0, noise_scheduler.num_train_timesteps, (x.size(0),), device=device).long()
        noise = torch.randn_like(x).to(device)
        noisy_x = noise_scheduler.add_noise(x, noise, t)
        noise_pred = unet(noisy_x, t).sample
        loss_diff = F.mse_loss(noise_pred, noise)

        clip_loss = compute_clip_loss(x, [clip_text]*x.size(0))
        loss = loss_diff + clip_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    train_losses.append(epoch_loss / len(train_loader))

    # Validation loss
    unet.eval()
    val_loss = 0
    with torch.no_grad():
        for x, _ in val_loader:
            x = x.to(device)
            t = torch.randint(0, noise_scheduler.num_train_timesteps, (x.size(0),), device=device).long()
            noise = torch.randn_like(x).to(device)
            noisy_x = noise_scheduler.add_noise(x, noise, t)
            noise_pred = unet(noisy_x, t).sample
            loss_diff = F.mse_loss(noise_pred, noise)
            clip_loss = compute_clip_loss(x, [clip_text]*x.size(0))
            loss = loss_diff + clip_loss
            val_loss += loss.item()
    val_losses.append(val_loss / len(val_loader))


  deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
  5%|▌         | 126/2500 [03:00<56:48,  1.44s/it]

In [None]:
plt.plot(train_losses, label="Train Loss")
plt.plot(val_losses, label="Val Loss")
plt.title("Loss vs Epochs")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()


In [None]:
def generate_images(n=32):
    unet.eval()
    samples = []
    for _ in range(n // 32):
        x = torch.randn(32, 3, 64, 64).to(device)
        for t in reversed(range(0, 1000, 100)):
            with torch.no_grad():
                model_output = unet(x, torch.tensor([t]*32).to(device)).sample
                x = noise_scheduler.step(model_output, t, x).prev_sample
        samples.append(x.cpu())
    return torch.cat(samples, dim=0)

samples = generate_images(32)
def show_image_grid(images, nrow=8, title="Generated Images"):
    grid = make_grid(images.clamp(0, 1), nrow=nrow)

    plt.figure(figsize=(12, 12))
    plt.imshow(grid.permute(1, 2, 0).cpu())
    plt.title(title, fontsize=16)
    plt.axis('off')
    plt.show()
    for i, img in enumerate(samples):
        grid = make_grid(images.clamp(0, 1), nrow=nrow)
# Example usage
show_image_grid(samples, title="CLIP-Guided Diffusion Results")
save_path = "generated"
os.makedirs(save_path, exist_ok=True)



In [None]:
clip_texts = ["picture of an alien"] * len(samples)
clip_loss_value = compute_clip_loss(samples, clip_texts).item()
print(f"CLIP Loss: {clip_loss_value:.4f}")
print(f"FID Score: {fid_score:.2f}")
