In [None]:
runpod 
pip install genaibook

In [None]:
import torch
from diffusers import DDPMPipeline
from genaibook.core import get_device

# `get_device()` (from genaibook) picks the compute device — presumably a GPU when
# one is available, otherwise the CPU.
device = get_device()

# Download the pretrained `google/ddpm-celebahq-256` pipeline and move it to `device`.
image_pipe = DDPMPipeline.from_pretrained("google/ddpm-celebahq-256")
image_pipe.to(device)

# Run the pipeline's full denoising loop for a single image and display it.
image_pipe(batch_size=1).images[0]

In [None]:
from genaibook.core import plot_noise_and_denoise
# Start the reverse process from pure Gaussian noise: a batch of 4 RGB 256x256 tensors.
image = torch.randn(4, 3, 256, 256).to(device)

# Denoise in 100 steps rather than the scheduler's full training schedule.
image_pipe.scheduler.set_timesteps(num_inference_steps=100)
steps_to_run = image_pipe.scheduler.timesteps
last_step_index = len(steps_to_run) - 1

for i, t in enumerate(steps_to_run):
    # The UNet predicts the noise present in `image` at timestep `t`.
    with torch.inference_mode():
        noise_pred = image_pipe.unet(image, t).sample
    # The scheduler uses that prediction to produce the slightly less noisy sample.
    scheduler_output = image_pipe.scheduler.step(noise_pred, t, image)
    image = scheduler_output.prev_sample

    # Visualise every 10th step plus the final one (plot_noise_and_denoise comes from
    # genaibook; presumably it shows the current sample next to the predicted clean image).
    should_plot = (i % 10 == 0) or (i == last_step_index)
    if should_plot:
        plot_noise_and_denoise(scheduler_output, i)

In [None]:
from datasets import load_dataset
# Training split of the Smithsonian butterflies subset from the Hugging Face Hub.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")

from torchvision import transforms

# Training resolution. This must agree with the UNet below (`sample_size=64`) and with
# the 64x64 noise used when sampling from the trained model. The previous value (384)
# trained the model on images 6x larger per side than it was later asked to generate,
# at a far higher memory/compute cost.
image_size = 64
preprocess = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),  # Resize
        transforms.RandomHorizontalFlip(),  # Randomly flip (data augmentation)
        transforms.ToTensor(),  # Convert to tensor (0, 1)
        transforms.Normalize([0.5], [0.5]),  # Map to (-1, 1)
    ]
)

In [None]:
def transform(examples):
    """Run `preprocess` over every entry of a batch's "image" column.

    `examples` is the batch mapping handed over by `dataset.set_transform`; only its
    "image" column is read (presumably PIL images — confirm against the dataset).
    Returns a dict whose single "images" key holds the list of preprocessed tensors.
    """
    preprocessed = list(map(preprocess, examples["image"]))
    return {"images": preprocessed}

# Register `transform` so it is applied lazily, whenever rows are accessed.
dataset.set_transform(transform)

batch_size = 16
# Shuffled mini-batches of preprocessed images for inspection/training.
train_dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [None]:
from genaibook.core import show_images
from diffusers import DDPMScheduler

# Fetch one batch first — it has to exist before it can be displayed.
batch = next(iter(train_dataloader))
# Undo the Normalize([0.5], [0.5]) mapping, (-1, 1) -> (0, 1), for display.
show_images(batch["images"][:8] * 0.5 + 0.5)

scheduler = DDPMScheduler(
    num_train_timesteps=1000, beta_start=0.001, beta_end=0.02
)

# Eight evenly spaced timesteps over [0, 999] to show the noise level increasing.
timesteps = torch.linspace(0, 999, 8).long()
x = batch["images"][:8]
# DDPM's forward process adds *Gaussian* noise: use randn_like, not rand_like
# (which samples uniformly from [0, 1) and gives a wrong, biased visualisation).
noise = torch.randn_like(x)
noised_x = scheduler.add_noise(x, noise, timesteps)
show_images((noised_x * 0.5 + 0.5).clip(0, 1))

In [None]:
from diffusers import UNet2DModel

# UNet that the training loop below teaches to predict the noise added to RGB images.
model = UNet2DModel(
    sample_size=64,  # input size the model is configured for
    in_channels=3,  # RGB input
    # Channel width of each resolution level; wider levels mean a larger model.
    block_out_channels=(64, 128, 256, 512),
    # Plain blocks for the first two levels, blocks with self-attention for the last two.
    down_block_types=("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    # The up path mirrors the down path in reverse order.
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
).to(device)

In [None]:
# Sanity check: one forward pass of the (still untrained) model on the noised batch.
# A noise-prediction UNet must return a tensor with exactly the input's shape.
with torch.inference_mode():
    out = model(noised_x.to(device), timestep=timesteps.to(device)).sample
assert out.shape == noised_x.shape, (out.shape, noised_x.shape)
print(f"Input shape: {tuple(noised_x.shape)} -> output shape: {tuple(out.shape)}")

# Rebuild the dataloader with a larger batch size for training. The dataset's lazy
# `transform` (registered earlier via `dataset.set_transform`) is unchanged, so the
# identical `preprocess`/`transform` definitions do not need to be repeated here.
batch_size = 32

train_dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, shuffle=True
)

from torch.nn import functional as F
# Training hyper-parameters.
lr = 0.0001
num_epochs = 50

# AdamW over every UNet parameter.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

# One entry per optimisation step; filled in by the training loop.
losses = list()

In [None]:
# Make sure the model is in training mode (e.g. if this cell is re-run after a
# sampling cell switched it to eval mode).
model.train()

for epoch in range(num_epochs):
    for batch in train_dataloader:
        clean_images = batch["images"].to(device)
        # Gaussian noise generated directly with the batch's shape/dtype/device,
        # avoiding a CPU allocation plus host->device copy every step.
        noise = torch.randn_like(clean_images)
        # One random training timestep per image in the batch.
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (clean_images.shape[0],),
            device=device,
        ).long()
        # Forward diffusion: noise each clean image according to its timestep.
        noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
        # Objective: predict the noise that was added (MSE against the true noise).
        noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
        loss = F.mse_loss(noise_pred, noise)
        losses.append(loss.item())

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Mean over this epoch's steps, i.e. the last len(train_dataloader) recorded losses.
    steps_per_epoch = len(train_dataloader)
    avg_loss = sum(losses[-steps_per_epoch:]) / steps_per_epoch
    # `.5f` = five decimals; the previous `05f` was a zero-pad-to-width-5 typo.
    print(
        f"Finished epoch {epoch}. Average loss for this epoch: {avg_loss:.5f}"
    )

In [None]:
from matplotlib import pyplot as plt
# Two panels: the full loss curve, and a zoomed view that skips the early steps.
fig, (ax_full, ax_tail) = plt.subplots(1, 2, figsize=(12, 4))

ax_full.plot(losses)
ax_full.set(title="Training loss", xlabel="Training step", ylabel="MSE loss")

# Skip the first 400 steps so that later progress is easier to see.
tail_start = 400
ax_tail.plot(range(tail_start, len(losses)), losses[tail_start:])
ax_tail.set(
    title=f"Training loss from step {tail_start}",
    xlabel="Training step",
    ylabel="MSE loss",
);

# Generate 4 images with the trained model. Start from pure Gaussian noise at the
# resolution the model was trained on (`image_size`) rather than a hard-coded 64.
model.eval()
sample = torch.randn(4, 3, image_size, image_size, device=device)
show_images(sample.clip(-1, 1) * 0.5 + 0.5, nrows=1)

# Reverse diffusion. `set_timesteps` was never called on `scheduler`, so this iterates
# its default timesteps (presumably all num_train_timesteps, descending — confirm).
with torch.inference_mode():
    for t in scheduler.timesteps:
        noise_pred = model(sample, t).sample
        sample = scheduler.step(noise_pred, t, sample).prev_sample
show_images(sample.clip(-1, 1) * 0.5 + 0.5, nrows=1)