# DDPM with cats dataset

## Import necessary modules

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

from torch.utils.tensorboard import SummaryWriter
from modules.dcgan import Discriminator, Generator, initialize_weights

## Set up device-agnostic code

In [None]:
# Device-agnostic setup: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

## Load the dataset

In [None]:
# Dataset configuration used by the loading cell below.

# Directory that holds the image files
dataset_path = "dataset/cats/"

# Side length (in pixels) every image is resized to
img_size = 64

# Upper bound on how many files from the directory are transformed
NUM_IMAGES = 15_747

In [None]:
# Get the list of image filenames.
# os.listdir returns entries in arbitrary order and may include non-image
# files (e.g. OS metadata), so keep only known image extensions and sort the
# result to make the selected subset reproducible.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.webp')
image_filenames = sorted(
    name for name in os.listdir(dataset_path)
    if name.lower().endswith(IMAGE_EXTENSIONS)
)

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # convert PIL image to tensor and scale data into [0, 1]
    # Scale to [-1, 1] via (input[channel] - mean[channel]) / std[channel]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# Transform at most NUM_IMAGES images and collect the resulting tensors
transformed_images = []
for filename in image_filenames[:NUM_IMAGES]:
    img_path = os.path.join(dataset_path, filename)

    # The context manager closes the file handle once the image is decoded.
    with Image.open(img_path) as image:
        # Normalize above expects exactly 3 channels; convert so grayscale or
        # RGBA files do not break normalization or the final torch.stack.
        transformed_images.append(transform(image.convert("RGB")))

# torch.stack on an empty list fails with an unhelpful message; fail early instead.
if not transformed_images:
    raise FileNotFoundError(f"No images found in {dataset_path!r}")

# Convert the list of transformed images to a single PyTorch tensor
transformed_images = torch.stack(transformed_images)

print(f'Loaded data: {transformed_images.shape}')

## Split the data into batches

In [None]:
# Number of images per mini-batch
batch_size = 16

# Wrap the stacked image tensor in a shuffling loader; the trailing incomplete
# batch is dropped so every batch has exactly batch_size images.
data_loader = DataLoader(
    transformed_images,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Sanity check: pull one batch and show its shape
data_iter = iter(data_loader)
first_batch = next(data_iter)
print(first_batch.shape)

## Import the model

Based on the DDPM and U-Net papers:
https://arxiv.org/pdf/1505.04597v1.pdf

In [None]:
# Diffusion provides the timestep sampling, noising and image sampling used in
# the training cell; UNet is the network called there to predict the noise.
# Use absolute imports: a notebook has no parent package, so a relative
# import (`from .modules...`) raises ImportError.
from modules.ddpm import Diffusion, initialize_weights
from modules.unet import UNet

## Set hyperparameters before the training loop

In [None]:
# Training hyperparameters (based on the paper, per the original note)
NUM_EPOCHS = 200
LEARNING_RATE = 0.0001  # 1e-4

# Data / model dimensions
BATCH_SIZE, IMAGE_SIZE, CHANNELS_IMG = 16, 64, 3
LATENT_DIM = 100



## Train the model and save weights and logs

## Training process
Parameters are based on the DDPM paper.

In [None]:
import logging

# Import the progress-bar class itself; the bare `tqdm` module is not callable.
from tqdm import tqdm

# The root logger defaults to WARNING, which would hide the per-epoch messages.
logging.basicConfig(level=logging.INFO)

# Noise-prediction network, optimizer, loss, diffusion helper and TensorBoard writer
model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=IMAGE_SIZE, device=device)
writer = SummaryWriter(os.path.join("logs/cats", "DDPM"))

# torch.save does not create missing directories, so make sure the target exists.
weights_dir = os.path.join("weights", "cats", "DDPM")
os.makedirs(weights_dir, exist_ok=True)

# Batches per epoch, used to compute a monotonically increasing global step
num_batches = len(data_loader)

for epoch in range(NUM_EPOCHS):
    logging.info("Starting epoch %d:", epoch)

    model.train()
    pbar = tqdm(data_loader)
    # data_loader wraps a plain image tensor, so each batch is a single tensor
    # (there are no labels to unpack).
    for i, images in enumerate(pbar):
        images = images.to(device)

        # Pick a random timestep per image, noise the images to that timestep,
        # and train the model to predict the noise that was added.
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        pbar.set_postfix(MSE=loss_value)
        writer.add_scalar("MSE", loss_value, global_step=epoch * num_batches + i)

    # Checkpoint the weights after every epoch (files are numbered from 1)
    torch.save(model.state_dict(), os.path.join(weights_dir, f"{epoch+1}.pt"))

    # Generate a small batch of samples in eval mode without tracking gradients.
    # NOTE(review): the samples are not saved or logged yet; the value range
    # returned by diffusion.sample is not visible here, so confirm it before
    # writing them to TensorBoard or disk.
    model.eval()
    with torch.no_grad():
        sampled_images = diffusion.sample(model, n=8)

# Make sure all pending scalars are written to the event file
writer.flush()


# Measure FID

Use this implementation: https://github.com/mseitzer/pytorch-fid/tree/master

You need to import the model first  
and make sure to define the `gen_log` function.