In [None]:
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from datasets import load_dataset
from diffusers import DDIMScheduler, DDPMPipeline
from matplotlib import pyplot as plt
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm

# Choose the compute backend in order of preference: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

## Learning is very difficult, but cats save the day.

In [None]:
# Load the pretrained cat DDPM pipeline and move it onto the selected device.
checkpoint_id = "google/ddpm-cat-256"
image_pipe = DDPMPipeline.from_pretrained(checkpoint_id)
image_pipe.to(device)

## Let's generate 4 cat images

In [None]:
# Call the pipeline four times with its default settings and keep the
# first image returned by each call.
cat_image_list = [image_pipe().images[0] for _ in range(0, 4)]
n_images = len(cat_image_list)

# A single row with one column per generated image (width 20, height 5).
fig, axs = plt.subplots(1, n_images, figsize=(20, 5))

# Pair every axis with its image; hide the axes for a cleaner display.
for ax, cat_image in zip(axs, cat_image_list):
    ax.imshow(cat_image)
    ax.axis("off")

plt.tight_layout()
plt.show()

Running all 1000 steps really takes a very long time. Let's try using only 400 steps.

In [None]:
# Generate five images, each sampled with 400 inference steps; only the
# first image of every pipeline call is kept.
cat_image_list = []
while len(cat_image_list) < 5:
    cat_image_list.append(image_pipe(num_inference_steps=400).images[0])

n_images = len(cat_image_list)

# One row, `n_images` columns (width 20, height 5).
fig, axs = plt.subplots(nrows=1, ncols=n_images, figsize=(20, 5))

# Fill the row left to right, hiding the axes on every panel.
for column in range(n_images):
    panel = axs[column]
    panel.imshow(cat_image_list[column])
    panel.axis("off")

plt.tight_layout()
plt.show()

## What does the model do during different inference steps

Replace the scheduler to see the model behavior

In [None]:
# Build a DDIMScheduler from the same checkpoint, configure it for 400
# inference steps, and install it on the pipeline in place of the default one.
scheduler = DDIMScheduler.from_pretrained("google/ddpm-cat-256")
num_inference_steps = 400
scheduler.set_timesteps(num_inference_steps=num_inference_steps)
image_pipe.scheduler = scheduler

# Random starting point (batch of 4 images, 3 channels, 256x256 resolution)
x = torch.randn(4, 3, 256, 256).to(device)

# `enumerate` has no length, so tqdm needs `total` to show progress and an ETA.
for i, t in tqdm(enumerate(image_pipe.scheduler.timesteps), total=len(image_pipe.scheduler.timesteps)):
    # Prepare model input
    model_input = image_pipe.scheduler.scale_model_input(x, t)
    # Get the prediction from the UNet (predict the noise); no gradients are needed for sampling
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]
    # Let the scheduler compute the updated sample from the prediction
    scheduler_output = image_pipe.scheduler.step(noise_pred, t, x)
    # Update x for the next iteration
    x = scheduler_output.prev_sample
    # Show intermediate results every 100 steps and on the final step
    if i % 100 == 0 or i == len(image_pipe.scheduler.timesteps) - 1:
        fig, axs = plt.subplots(1, 2, figsize=(12, 5))
        # Left panel: the current `x` (partially denoised images)
        grid = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0)
        axs[0].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)  # clip to [-1, 1], then rescale to [0, 1] for display
        axs[0].set_title(f"Current x (step {i})")
        # Right panel: the scheduler's estimate of the fully denoised images
        pred_x0 = scheduler_output.pred_original_sample
        if pred_x0 is not None:  # Not all schedulers support pred_original_sample
            grid = torchvision.utils.make_grid(pred_x0, nrow=4).permute(1, 2, 0)
            axs[1].imshow(grid.cpu().clip(-1, 1) * 0.5 + 0.5)
            axs[1].set_title(f"Predicted denoised images (step {i})")
        else:
            axs[1].set_title("Predicted denoised images (unavailable)")
        plt.show()


## Use 100 cat images of the same breed to fine-tune the model

In [None]:
dataset_name = "timm/oxford-iiit-pet"
dataset = load_dataset(dataset_name, split="train")

# Choose the cat breed you want to focus on
target_breed = "british_shorthair"  # Change this to any breed name in the dataset
breed_names = dataset.features["label"].names  # All class names, indexed by label id

# Fail early, listing the valid names, if the requested breed is not a class of this dataset
if target_breed not in breed_names:
    raise ValueError(f"Breed '{target_breed}' not found in dataset. Available breeds: {breed_names}")

# Integer label id of the target breed
breed_index = breed_names.index(target_breed)

# Keep only the examples of that breed. `input_columns="label"` passes just the
# label value to the predicate, so the image column does not have to be read
# and decoded for every example merely to compare an integer.
filtered_dataset = dataset.filter(lambda label: label == breed_index, input_columns="label")

print(f"Number of images for {target_breed}: {len(filtered_dataset)}")



image_size = 256  # @param
batch_size = 4  # @param

# Per-image preprocessing, applied in order:
#   1. resize to image_size x image_size
#   2. random horizontal flip (torchvision's default probability)
#   3. convert to a tensor
#   4. normalize with mean 0.5 and std 0.5 (the preview code undoes this with `* 0.5 + 0.5`)
preprocess = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])


def transform(examples):
    """Turn a batch of dataset examples into preprocessed image tensors.

    Each entry of ``examples["image"]`` is converted to RGB and passed through
    ``preprocess``. The results are returned as a list under the ``"images"`` key.
    """
    images = []
    for raw_image in examples["image"]:
        rgb_image = raw_image.convert("RGB")
        images.append(preprocess(rgb_image))
    return {"images": images}


# Register `transform` on the filtered dataset so examples come out as preprocessed tensors.
filtered_dataset.set_transform(transform)

# Shuffled mini-batches of `batch_size` examples for fine-tuning.
train_dataloader = torch.utils.data.DataLoader(
    filtered_dataset,
    shuffle=True,
    batch_size=batch_size,
)

print("Previewing batch:")
batch = next(iter(train_dataloader))
grid = torchvision.utils.make_grid(batch["images"], nrow=4)
# Move channels last, clip to [-1, 1], then rescale to [0, 1] for display.
preview = grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5
plt.imshow(preview)

In [None]:
num_epochs = 5  # @param
lr = 5e-7  # @param
weight_decay = 1e-5
grad_accumulation_steps = 4  # @param

optimizer = torch.optim.AdamW(image_pipe.unet.parameters(), lr=lr, weight_decay=weight_decay)
# Start from clean gradients in case an earlier (interrupted) run left some behind.
optimizer.zero_grad()

# Per-step training loss (unscaled), kept for the epoch summary and the final plot.
losses = []
steps_per_epoch = len(train_dataloader)

for epoch in range(num_epochs):
    for step, batch in tqdm(enumerate(train_dataloader), total=steps_per_epoch):
        clean_images = batch["images"].to(device)
        # Sample noise to add to the images, directly on the images' device
        noise = torch.randn(clean_images.shape, device=clean_images.device)
        bs = clean_images.shape[0]

        # Sample a random timestep for each image.
        # Read the value from `scheduler.config`: accessing config fields as
        # direct scheduler attributes is deprecated in diffusers.
        timesteps = torch.randint(
            0,
            image_pipe.scheduler.config.num_train_timesteps,
            (bs,),
            device=clean_images.device,
        ).long()

        # Add noise to the clean images according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_images = image_pipe.scheduler.add_noise(clean_images, noise, timesteps)

        # Get the model prediction for the noise
        noise_pred = image_pipe.unet(noisy_images, timesteps, return_dict=False)[0]

        # Compare the prediction with the actual noise:
        # NB - trying to predict noise (eps) not (noisy_ims-clean_ims) or just (clean_ims)
        loss = F.mse_loss(noise_pred, noise)

        # Store for later plotting
        losses.append(loss.item())

        # Backpropagate. `backward()` must be called without arguments on a scalar
        # loss: passing the loss itself as `gradient` would multiply every gradient
        # by the loss value. Dividing by the accumulation length makes the
        # accumulated gradient an average over the window rather than a sum.
        (loss / grad_accumulation_steps).backward()

        # Gradient accumulation: update every `grad_accumulation_steps` batches, and
        # also on the last batch of the epoch so a trailing partial group is applied
        # instead of leaking into the next epoch or being dropped after the last one.
        if (step + 1) % grad_accumulation_steps == 0 or (step + 1) == steps_per_epoch:
            optimizer.step()
            optimizer.zero_grad()

    print(f"Epoch {epoch} average loss: {sum(losses[-len(train_dataloader):])/len(train_dataloader)}")

# Plot the loss curve:
plt.plot(losses)

## Generate some cat images with our fine-tuned model

In [None]:
# Sample a batch of 8 images from the fine-tuned UNet, starting from pure noise
# and reusing the `scheduler` (with its timesteps) configured earlier in the notebook.
x = torch.randn(8, 3, 256, 256).to(device)  # Batch of 8

# Iterate over the timesteps directly (the loop index was unused) and give tqdm
# an explicit total so the bar shows progress and an ETA.
for t in tqdm(scheduler.timesteps, total=len(scheduler.timesteps)):
    model_input = scheduler.scale_model_input(x, t)
    # No gradients are needed for sampling
    with torch.no_grad():
        noise_pred = image_pipe.unet(model_input, t)["sample"]
    x = scheduler.step(noise_pred, t, x).prev_sample

# Tile the batch into a grid (4 per row), clip to [-1, 1] and rescale to [0, 1] for display
grid = torchvision.utils.make_grid(x, nrow=4)
plt.imshow(grid.permute(1, 2, 0).cpu().clip(-1, 1) * 0.5 + 0.5)

### Save our model

In [None]:
# torch.save(image_pipe, target_breed+"_FT.pth")