# Deep Learning: Image Generation using DDPM

Authors: Jakub Borek, Bartosz Dybowski

## Install dependencies

In [None]:
# Install the third-party packages imported below, plus the Kaggle CLI that the
# dataset download cell invokes.
# NOTE(review): tqdm and Pillow are imported later but not listed here;
# presumably they arrive as transitive dependencies — confirm.
!pip install torch torchvision diffusers accelerate scipy scikit-image matplotlib imageio kaggle

## Import dependencies

In [None]:
import os
import json
import random
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image, make_grid
from torchvision.models import inception_v3
from diffusers import UNet2DModel, DDPMScheduler
from PIL import Image
from tqdm import tqdm
import glob
import numpy as np
from scipy.linalg import sqrtm
from skimage.transform import resize
import matplotlib.pyplot as plt
import imageio

## Config

In [None]:
# Dataset
dataset_choice = "cats"  # one of: "cats", "cats_and_dogs"

# Device: use the GPU when PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Model / training parameters
image_size = 64  # side length images are resized to
batch_size = 128
num_workers = 4  # DataLoader worker processes
num_epochs = 100
learning_rate = 1e-4
early_stopping_patience = 3  # epochs without FID improvement before stopping

# Output locations. The directory name follows the selected dataset so that
# runs on different datasets do not overwrite each other's checkpoints and
# logs ("cats" still resolves to "ddpm_outputs_cats").
save_dir = f"ddpm_outputs_{dataset_choice}"
os.makedirs(save_dir, exist_ok=True)
checkpoint_path = os.path.join(save_dir, "ddpm_model.pt")
fid_log = []  # one FID score per finished epoch, appended by the training loop

## Set seed

In [None]:
def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN flags.

    Args:
        seed: Value handed to every random number generator.
    """
    # Seeding order: Python -> NumPy -> torch (default) -> torch (all CUDA devices).
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Disable cuDNN autotuning and request deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


set_seed(123)

## Choose dataset and download

In [None]:
with open("kaggle.json", "r") as f:
    kaggle_token = json.load(f)

os.makedirs(os.path.expanduser("~/.kaggle"), exist_ok=True)
with open(os.path.expanduser("~/.kaggle/kaggle.json"), "w") as f:
    json.dump(kaggle_token, f)
os.chmod(os.path.expanduser("~/.kaggle/kaggle.json"), 0o600)

if dataset_choice == "cats":
    dataset_path = "./cat-dataset/cats/Data"
    if not os.path.exists(dataset_path):
        !kaggle datasets download -d borhanitrash/cat-dataset
        !unzip -q cat-dataset.zip -d cat-dataset

elif dataset_choice == "cats_and_dogs":
    dataset_path = "dogs-vs-cats/train"
    if not os.path.exists(dataset_path):
        !kaggle competitions download -c dogs-vs-cats
        !unzip -q dogs-vs-cats.zip
        !unzip -q train.zip -d dogs-vs-cats

print("Dataset path:", dataset_path)
print("Files in dataset:", len(os.listdir(dataset_path)))

## Define dataset

In [None]:
class ImageFolderDataset(Dataset):
    """Unlabelled image dataset over the files directly inside one folder.

    Only regular files whose extension (compared case-insensitively) is listed
    in ``extensions`` are indexed, so sub-directories and stray files such as
    archives or metadata cannot crash ``__getitem__``. Paths are sorted so the
    index -> file mapping does not depend on filesystem listing order.

    Args:
        image_folder: Directory that directly contains the image files.
        transform: Optional callable applied to each loaded RGB PIL image.
        extensions: Accepted file extensions, each including the leading dot.
    """

    def __init__(self, image_folder, transform=None, extensions=(".jpg", ".jpeg", ".png")):
        allowed = {ext.lower() for ext in extensions}
        self.image_paths = sorted(
            path
            for path in glob.glob(os.path.join(image_folder, "*"))
            if os.path.isfile(path) and os.path.splitext(path)[1].lower() in allowed
        )
        self.transform = transform

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if configured."""
        # The context manager closes the file handle as soon as the pixel data
        # has been copied into the converted image.
        with Image.open(self.image_paths[idx]) as source:
            img = source.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.image_paths)

## Transform images

In [None]:
# Preprocessing pipeline: resize to image_size x image_size, convert to a
# tensor, then normalise every channel with mean 0.5 and std 0.5.
# NOTE(review): after the square Resize, CenterCrop(image_size) should leave
# the image unchanged — confirm whether an aspect-preserving Resize(image_size)
# was intended instead.
transform = transforms.Compose(
    [
        transforms.Resize((image_size, image_size)),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

## Create dataset and dataloader

In [None]:
# NOTE(review): valid_ext is defined here but not passed to the dataset or
# used anywhere else in the visible code — confirm whether filtering by
# extension was intended.
valid_ext = {".jpg", ".png", ".jpeg"}

# Wrap the image folder and feed it to the training loop in shuffled batches.
dataset = ImageFolderDataset(dataset_path, transform)
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
)

## Define model

In [None]:
# Noise-prediction U-Net: three resolution levels, with attention in the
# middle down block and the middle up block.
model = UNet2DModel(
    sample_size=image_size,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 256, 256),
    down_block_types=(
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
    ),
    norm_num_groups=32,
    act_fn="silu",
)
model = model.to(device)

# Diffusion schedule with 1000 training timesteps, and the optimizer over all
# model parameters.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

## Define image generation function

In [None]:
def generate_images(model, n_samples=100, seed=None):
    """Sample images by running the full reverse diffusion chain.

    Args:
        model: Noise-prediction network, called as ``model(x, timesteps).sample``.
        n_samples: Number of images generated in a single batch.
        seed: Optional seed. When given, all sampling noise is drawn from a
            private ``torch.Generator`` so the result is reproducible without
            resetting the global RNG that training (noise, timesteps,
            shuffling) draws from. When ``None`` the global RNG is used.

    Returns:
        Tensor of shape ``(n_samples, 3, image_size, image_size)`` on
        ``device``, clamped to [-1, 1] and rescaled to [0, 1].
    """
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device).manual_seed(seed)

    # Leaves the model in eval mode; callers that keep training must switch
    # back with model.train().
    model.eval()
    latents = torch.randn(
        (n_samples, 3, image_size, image_size),
        generator=generator,
        device=device,
    )
    with torch.no_grad():
        for t in reversed(range(noise_scheduler.config.num_train_timesteps)):
            ts = torch.full((n_samples,), t, device=device, dtype=torch.long)
            noise_pred = model(latents, ts).sample
            # The scheduler draws its per-step noise from the same generator.
            latents = noise_scheduler.step(
                noise_pred, t, latents, generator=generator
            ).prev_sample
    images = (latents.clamp(-1, 1) + 1) / 2
    return images

## Define FID calculation function

In [None]:
def calculate_fid(real_imgs, fake_imgs):
    """Compute the Frechet Inception Distance between two image batches.

    Both batches are resized to 299x299 and passed through a pretrained
    Inception-v3 whose classification layer is replaced by an identity, so the
    mean/covariance statistics are taken on the pooled features feeding that
    layer rather than on the class logits.

    Args:
        real_imgs: Float tensor ``(N, 3, H, W)`` of reference images.
        fake_imgs: Float tensor ``(M, 3, H, W)`` of generated images.
            Both batches must use the same value range.
            NOTE(review): with ``transform_input=False`` the torchvision
            weights presumably expect inputs scaled to [-1, 1] — confirm.

    Returns:
        The FID as a Python float.

    Raises:
        ValueError: If either batch holds fewer than two images, because a
            covariance cannot be estimated from a single sample.
    """
    if real_imgs.shape[0] < 2 or fake_imgs.shape[0] < 2:
        raise ValueError(
            "calculate_fid needs at least two images per batch to estimate a covariance."
        )

    inception = inception_v3(pretrained=True, transform_input=False)
    inception.fc = nn.Identity()  # expose pooled features instead of logits
    inception = inception.to(device)
    inception.eval()

    def get_activations(imgs):
        """Return Inception features for ``imgs`` as a float64 NumPy array."""
        with torch.no_grad():
            resized = nn.functional.interpolate(
                imgs.to(device), size=(299, 299), mode='bilinear'
            )
            feats = inception(resized)
        # float64 keeps the covariance / matrix square root numerically stable.
        return feats.detach().cpu().numpy().astype(np.float64)

    real_act = get_activations(real_imgs)
    fake_act = get_activations(fake_imgs)

    mu1, sigma1 = real_act.mean(axis=0), np.cov(real_act, rowvar=False)
    mu2, sigma2 = fake_act.mean(axis=0), np.cov(fake_act, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2)

    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # The product can be (near-)singular when there are fewer samples than
        # feature dimensions; retry with a small ridge on both covariances.
        offset = np.eye(sigma1.shape[0]) * 1e-6
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # sqrtm may return tiny imaginary parts from round-off; keep the real part.
    covmean = np.real(covmean)

    fid = ssdiff + np.trace(sigma1) + np.trace(sigma2) - 2 * np.trace(covmean)
    return float(fid)

## Define interpolation

In [None]:
def interpolate_and_generate(model, z1, z2, steps=8):
    """Denoise latents taken along the straight line from ``z1`` to ``z2``.

    ``steps + 2`` latents are built (evenly spaced weights from 0 to 1, so both
    end points are included) and each one is run on its own through every
    reverse diffusion step.

    NOTE(review): a linear blend of two Gaussian noise tensors has a smaller
    norm than either end point; confirm whether spherical interpolation was
    intended.

    Args:
        model: Noise-prediction network, called as ``model(x, timesteps).sample``.
        z1: Starting latent without a batch dimension.
        z2: Ending latent with the same shape as ``z1``.
        steps: Number of intermediate latents between the two end points.

    Returns:
        Tuple ``(stacked, frames)``: ``frames`` is the list of generated image
        tensors (each with a leading batch dimension of 1, rescaled to [0, 1])
        and ``stacked`` is their concatenation along dimension 0.
    """
    model.eval()
    num_timesteps = noise_scheduler.config.num_train_timesteps
    blend_weights = torch.linspace(0, 1, steps + 2)

    frames = []
    with torch.no_grad():
        for alpha in blend_weights:
            sample = (z1 * (1 - alpha) + z2 * alpha).unsqueeze(0).to(device)
            for t in reversed(range(num_timesteps)):
                timestep = torch.tensor([t], device=device)
                predicted_noise = model(sample, timestep).sample
                sample = noise_scheduler.step(predicted_noise, t, sample).prev_sample
            frames.append((sample.clamp(-1, 1) + 1) / 2)
    return torch.cat(frames), frames

## Train

In [None]:
# Training loop: one pass over the dataloader per epoch, followed by a
# checkpoint, example grids, an FID evaluation and the early-stopping check.
best_fid = float('inf')
epochs_without_improvement = 0
loss_log = []  # mean training loss per epoch

# Read the step count from the scheduler config, like the sampling functions do.
num_train_timesteps = noise_scheduler.config.num_train_timesteps

for epoch in range(num_epochs):
    model.train()  # generate_images() leaves the model in eval mode
    loop = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
    best_loss, worst_loss = float('inf'), -float('inf')
    best_img, worst_img = None, None  # batches with the lowest / highest loss
    epoch_loss_total = 0.0
    epoch_batches = 0

    for x in loop:
        x = x.to(device)
        noise = torch.randn_like(x)
        # One random timestep per image in the batch.
        t = torch.randint(0, num_train_timesteps, (x.size(0),), device=device).long()
        noisy_x = noise_scheduler.add_noise(x, noise, t)

        # The network is trained to predict the noise that was added.
        pred = model(noisy_x, t).sample
        loss = nn.functional.mse_loss(pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss_total += batch_loss
        epoch_batches += 1

        if batch_loss < best_loss:
            best_loss = batch_loss
            best_img = x
        if batch_loss > worst_loss:
            worst_loss = batch_loss
            worst_img = x

        loop.set_postfix(loss=batch_loss)

    if epoch_batches == 0:
        raise RuntimeError("The dataloader produced no batches; check that the dataset is not empty.")

    # Mean loss
    epoch_avg_loss = epoch_loss_total / epoch_batches
    loss_log.append(epoch_avg_loss)

    # Save checkpoint
    torch.save(model.state_dict(), checkpoint_path)
    print(f"Model saved to {checkpoint_path}")

    # Save best/worst image examples (training batches are normalised to
    # [-1, 1], so map them back to [0, 1] for saving)
    if best_img is not None:
        save_image((best_img + 1) / 2, os.path.join(save_dir, f"best_epoch{epoch+1}.png"))
    if worst_img is not None:
        save_image((worst_img + 1) / 2, os.path.join(save_dir, f"worst_epoch{epoch+1}.png"))

    # Evaluate FID and save
    real_batch = next(iter(dataloader))[:128].to(device)  # normalised to [-1, 1]
    fake_batch = generate_images(model, n_samples=128, seed=epoch)  # already in [0, 1]
    # Bring the generated images to the same [-1, 1] range as the real batch
    # so both sets are compared on equal footing.
    fid_score = calculate_fid(real_batch, fake_batch * 2 - 1)
    fid_log.append(fid_score)
    print(f"FID Epoch {epoch+1}: {fid_score:.2f}")

    save_image((real_batch[:100] + 1) / 2, os.path.join(save_dir, f"real_batch_epoch{epoch+1}.png"), nrow=10)
    # fake_batch is already in [0, 1]; rescaling it again would wash the grid out.
    save_image(fake_batch[:100], os.path.join(save_dir, f"fake_batch_epoch{epoch+1}.png"), nrow=10)

    # Early stopping
    if fid_score < best_fid:
        best_fid = fid_score
        epochs_without_improvement = 0
        torch.save(model.state_dict(), os.path.join(save_dir, "ddpm_best_fid.pt"))
        print("Best FID improved, model saved.")
    else:
        epochs_without_improvement += 1
        print(f"No FID improvement for {epochs_without_improvement} epoch(s).")

    if epochs_without_improvement >= early_stopping_patience:
        print("Early stopping triggered.")
        break

## Save results

In [None]:
def _save_curve(values, ylabel, title, filename, **line_kwargs):
    """Plot ``values`` against 1-based epoch numbers and save it under save_dir."""
    plt.figure()
    plt.plot(range(1, len(values)+1), values, **line_kwargs)
    plt.xlabel("Epoch")
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.savefig(os.path.join(save_dir, filename))
    plt.close()


# === SAVE INTERPOLATION ===
# Two random latents are the end points; 8 intermediate steps give 10 frames,
# laid out five per row.
z1 = torch.randn((3, image_size, image_size)).to(device)
z2 = torch.randn((3, image_size, image_size)).to(device)
interpolated, interpolated_list = interpolate_and_generate(model, z1, z2, steps=8)
save_image(interpolated, os.path.join(save_dir, "interpolation_final.png"), nrow=5)
print("Interpolation image saved.")

# === SAVE INTERPOLATION AS GIF ===
# Convert every frame to an (H, W, C) uint8 array.
# NOTE(review): imageio has interpreted `duration` in different units across
# versions — confirm the frame time against the installed version.
imgs = []
for frame in interpolated_list:
    frame_hwc = frame.squeeze().permute(1, 2, 0).cpu().numpy()
    imgs.append((np.clip(frame_hwc * 255, 0, 255)).astype(np.uint8))
imageio.mimsave(os.path.join(save_dir, "interpolation.gif"), imgs, duration=0.4)
print("Interpolation GIF saved.")

# === SAVE FID LOG ===
fid_txt = os.path.join(save_dir, "fid_scores.txt")
with open(fid_txt, "w") as f:
    f.writelines(f"Epoch {i+1}: FID = {score:.4f}\n" for i, score in enumerate(fid_log))

# === PLOT FID ===
_save_curve(fid_log, "FID Score", "FID over Epochs", "fid_plot.png")
print("FID plot saved.")

# === SAVE LOSS LOG ===
loss_txt = os.path.join(save_dir, "loss_log.txt")
with open(loss_txt, "w") as f:
    f.writelines(f"Epoch {i+1}: Loss = {l:.6f}\n" for i, l in enumerate(loss_log))

# === PLOT LOSS ===
_save_curve(loss_log, "Loss", "Loss over Epochs", "loss_plot.png", color='orange')
print("Loss plot saved.")