<a href="https://colab.research.google.com/github/Rituraj003/llm-viz/blob/main/diffusion_eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [11]:
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
from diffusers import UNet2DModel, DDPMScheduler
from PIL import Image
import numpy as np

In [12]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_IMAGES = 10
NUM_TIMESTEPS_LIST = [10, 25, 50, 100]
MODEL_PATH = "google/ddpm-cifar10-32"
IMAGE_SIZE = 32
BATCH_SIZE = NUM_IMAGES
CIFAR_PATH = "./cifar10_data"

In [13]:
model = UNet2DModel.from_pretrained(MODEL_PATH).to(DEVICE)
model.eval()

scheduler = DDPMScheduler.from_pretrained(MODEL_PATH)

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [14]:
test_dataset = CIFAR10(root=CIFAR_PATH, train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) # No need to shuffle for evaluation

In [15]:
def evaluate(model, scheduler, num_images, num_timesteps, device):
    """Generates images from noise using the diffusion model and given number of steps.
       Returns the generated images and original images.
    """

    model.eval()

    # Load the first batch from the test data to compare against reconstructed images
    with torch.no_grad():
      test_images, _ = next(iter(test_loader))
      test_images = test_images[:num_images].to(device)  # Take the first num_images

    # Initial random noise
    noise = torch.randn((num_images, model.config.in_channels, model.config.sample_size, model.config.sample_size), device=device)

    # Set the timesteps for the scheduler
    scheduler.set_timesteps(num_timesteps)

    # Diffusion loop (Reverse Process)
    with torch.no_grad():
        for t in scheduler.timesteps:
            # 1. Predict noise residual
            model_output = model(noise, t).sample

            # 2. Compute previous image: x_t -> x_t-1
            noise = scheduler.step(model_output, t, noise).prev_sample

    # Scale the generated images back to [0, 1] range
    generated_images = (noise / 2 + 0.5).clamp(0, 1)

    # Scale original images back to [0, 1] range
    test_images = (test_images / 2 + 0.5).clamp(0, 1)

    return generated_images, test_images

In [16]:
def save_images(generated_images, original_images, num_timesteps, filename_prefix="diffusion_result"):
    """Saves generated images to files."""
    for i in range(generated_images.shape[0]):
        gen_img_np = generated_images[i].cpu().permute(1, 2, 0).numpy() # C, H, W -> H, W, C
        gen_img = Image.fromarray((gen_img_np * 255).astype(np.uint8))

        orig_img_np = original_images[i].cpu().permute(1, 2, 0).numpy()
        orig_img = Image.fromarray((orig_img_np * 255).astype(np.uint8))

        gen_img.save(f"{filename_prefix}_generated_{num_timesteps}_steps_{i}.png")
        orig_img.save(f"{filename_prefix}_original_{i}.png")

        print(f"Saved images {i} with {num_timesteps} steps.")


In [17]:
if __name__ == "__main__":
    for num_timesteps in NUM_TIMESTEPS_LIST:
        print(f"Evaluating with {num_timesteps} diffusion steps...")
        generated_images, original_images = evaluate(model, scheduler, NUM_IMAGES, num_timesteps, DEVICE)
        save_images(generated_images, original_images, num_timesteps)
    print("Evaluation complete.  Images saved.")

Evaluating with 10 diffusion steps...
Saved images 0 with 10 steps.
Saved images 1 with 10 steps.
Saved images 2 with 10 steps.
Saved images 3 with 10 steps.
Saved images 4 with 10 steps.
Saved images 5 with 10 steps.
Saved images 6 with 10 steps.
Saved images 7 with 10 steps.
Saved images 8 with 10 steps.
Saved images 9 with 10 steps.
Evaluating with 25 diffusion steps...
Saved images 0 with 25 steps.
Saved images 1 with 25 steps.
Saved images 2 with 25 steps.
Saved images 3 with 25 steps.
Saved images 4 with 25 steps.
Saved images 5 with 25 steps.
Saved images 6 with 25 steps.
Saved images 7 with 25 steps.
Saved images 8 with 25 steps.
Saved images 9 with 25 steps.
Evaluating with 50 diffusion steps...
Saved images 0 with 50 steps.
Saved images 1 with 50 steps.
Saved images 2 with 50 steps.
Saved images 3 with 50 steps.
Saved images 4 with 50 steps.
Saved images 5 with 50 steps.
Saved images 6 with 50 steps.
Saved images 7 with 50 steps.
Saved images 8 with 50 steps.
Saved images 9 w