# DDPM with cats dataset

## import necessary modules 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

from torch.utils.tensorboard import SummaryWriter

## set agnostic

In [2]:
# agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## load data set

In [3]:
# set image size
img_size = 64

# Set the path to the dataset
dataset_path = 'dataset/cats/'

# Set the number of images to transform
NUM_IMAGES = 15747

In [4]:
# Get the list of image filenames
image_filenames = os.listdir(dataset_path)

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # convert PIL image to tensor and scales data into [0,1] 
    # transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] 
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # Scale between [-1, 1] by (input[channel] - mean[channel]) / std[channel]
])

# Create a list to store the transformed images
transformed_images = []

# Iterate over the first num_images filenames and transform the corresponding images
for i, filename in enumerate(image_filenames[:NUM_IMAGES]):
    # Load the image
    img_path = os.path.join(dataset_path, filename)
    image = Image.open(img_path)

    # Apply the transformations
    transformed_image = transform(image)

    # Append the transformed image to the list
    transformed_images.append(transformed_image)

# Convert the list of transformed images to a PyTorch tensor
transformed_images = torch.stack(transformed_images)

print(f'Loaded data: {transformed_images.shape}')

Loaded data: torch.Size([15747, 3, 64, 64])


## split batch size 

In [5]:
# set batch size
batch_size = 16

data_loader = DataLoader(transformed_images, batch_size=batch_size, shuffle=True, drop_last=True)

data_iter = iter(data_loader)
print(next(data_iter).shape)

torch.Size([16, 3, 64, 64])


## import the model

base on DDPM and unet papers 
https://arxiv.org/pdf/1505.04597v1.pdf

In [6]:
from modules.ddpm import Diffusion

from modules.modules import UNet

## Set hyperparameter before training iteration

In [7]:
# base on the paper
LEARNING_RATE = 1e-4  #0.0001
# LEARNING_RATE = 1e-5  #0.00001

BATCH_SIZE = 16
IMAGE_SIZE = 64
CHANNELS_IMG = 3
NUM_EPOCHS = 100



## Train model and save weight and log 

## training process
parameter base on DDPM paper 

In [8]:
import time

model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=IMAGE_SIZE, device=device)
writer = SummaryWriter(os.path.join("logs/cats","DDPM"))

l = len(data_loader)

time_use = 0

for epoch in range(NUM_EPOCHS):

    lossMean = 0

    # use time for time measurement
    start_time = time.time()

    model.train()
    for batch_idx, images in enumerate(data_loader):
        images = images.to(device)
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # sum lose
        lossMean += loss

        # add log at first 
        if epoch == 0 and batch_idx == 0:
            writer.add_scalar("MSE", lossMean, global_step=epoch+1)
            writer.add_scalar("traing time", time_use, epoch_index)

    end_time = time.time()
    epoch_time = end_time - start_time

    time_use += epoch_time 

    epoch_index = epoch+1

    # calculate mean value
    lossMean = lossMean / len(data_loader.dataset)

    print(
        f"Epoch [{epoch_index}/{NUM_EPOCHS}] Batch {batch_idx+1}/{len(data_loader)} Using Time: {epoch_time:.4f}\
            Loss: {lossMean:.4f}"
    )

    # tensorboard
    writer.add_scalar("MSE", lossMean, global_step=epoch+1)
    writer.add_scalar("traing time", time_use, epoch_index)

    torch.save(model.state_dict(), os.path.join("weights", "cats", "DDPM", f"{epoch_index}.pt"))

    model.eval()
    with torch.no_grad():
        x = diffusion.sample(model, n=32).type(dtype=torch.float32)
        img_grid = torchvision.utils.make_grid(x[:32], normalize=True)
        writer.add_image("All/Generate Images", img_grid, global_step=epoch_index)


  return F.conv2d(input, weight, bias, self.stride,


Epoch [1/100] Batch 984/984 Using Time: 252.2996            Loss: 0.0066


999it [02:44,  6.07it/s]


Epoch [2/100] Batch 984/984 Using Time: 264.0821            Loss: 0.0031


999it [02:48,  5.91it/s]


Epoch [3/100] Batch 984/984 Using Time: 266.6037            Loss: 0.0027


999it [02:49,  5.88it/s]


Epoch [4/100] Batch 984/984 Using Time: 267.6915            Loss: 0.0025


999it [02:50,  5.87it/s]


Epoch [5/100] Batch 984/984 Using Time: 267.9238            Loss: 0.0024


999it [02:49,  5.88it/s]


Epoch [6/100] Batch 984/984 Using Time: 267.2298            Loss: 0.0023


999it [02:50,  5.86it/s]


Epoch [7/100] Batch 984/984 Using Time: 267.8265            Loss: 0.0022


999it [02:50,  5.87it/s]


Epoch [8/100] Batch 984/984 Using Time: 267.0831            Loss: 0.0021


999it [02:49,  5.88it/s]


Epoch [9/100] Batch 984/984 Using Time: 267.4691            Loss: 0.0021


999it [02:50,  5.85it/s]


Epoch [10/100] Batch 984/984 Using Time: 268.5200            Loss: 0.0021


999it [02:52,  5.79it/s]


Epoch [11/100] Batch 984/984 Using Time: 266.9136            Loss: 0.0021


999it [02:47,  5.97it/s]


Epoch [12/100] Batch 984/984 Using Time: 264.1223            Loss: 0.0019


999it [02:47,  5.97it/s]


Epoch [13/100] Batch 984/984 Using Time: 264.2991            Loss: 0.0020


999it [02:47,  5.96it/s]


Epoch [14/100] Batch 984/984 Using Time: 264.2016            Loss: 0.0019


999it [02:47,  5.95it/s]


Epoch [15/100] Batch 984/984 Using Time: 264.2225            Loss: 0.0019


999it [02:48,  5.94it/s]


Epoch [16/100] Batch 984/984 Using Time: 266.6485            Loss: 0.0020


999it [02:49,  5.89it/s]


Epoch [17/100] Batch 984/984 Using Time: 268.1866            Loss: 0.0019


999it [02:50,  5.87it/s]


Epoch [18/100] Batch 984/984 Using Time: 267.0108            Loss: 0.0018


999it [02:49,  5.89it/s]


Epoch [19/100] Batch 984/984 Using Time: 266.0725            Loss: 0.0019


999it [02:48,  5.92it/s]


Epoch [20/100] Batch 984/984 Using Time: 266.0160            Loss: 0.0018


999it [02:50,  5.86it/s]


Epoch [21/100] Batch 984/984 Using Time: 269.0245            Loss: 0.0019


999it [02:51,  5.82it/s]


Epoch [22/100] Batch 984/984 Using Time: 274.4324            Loss: 0.0019


999it [02:55,  5.68it/s]


Epoch [23/100] Batch 984/984 Using Time: 267.9397            Loss: 0.0018


999it [02:49,  5.88it/s]


Epoch [24/100] Batch 984/984 Using Time: 265.5469            Loss: 0.0019


999it [02:47,  5.98it/s]


Epoch [25/100] Batch 984/984 Using Time: 262.9820            Loss: 0.0018


999it [02:46,  5.99it/s]


Epoch [26/100] Batch 984/984 Using Time: 262.2998            Loss: 0.0019


999it [02:46,  5.99it/s]


Epoch [27/100] Batch 984/984 Using Time: 262.6237            Loss: 0.0019


999it [02:46,  6.01it/s]


Epoch [28/100] Batch 984/984 Using Time: 261.5062            Loss: 0.0018


999it [02:46,  6.01it/s]


Epoch [29/100] Batch 984/984 Using Time: 261.9300            Loss: 0.0019


999it [02:46,  6.02it/s]


Epoch [30/100] Batch 984/984 Using Time: 261.5533            Loss: 0.0018


999it [02:45,  6.02it/s]


Epoch [31/100] Batch 984/984 Using Time: 260.7541            Loss: 0.0018


999it [02:45,  6.03it/s]


Epoch [32/100] Batch 984/984 Using Time: 260.4207            Loss: 0.0019


999it [02:45,  6.04it/s]


Epoch [33/100] Batch 984/984 Using Time: 261.1978            Loss: 0.0018


999it [02:45,  6.04it/s]


Epoch [34/100] Batch 984/984 Using Time: 260.8551            Loss: 0.0018


999it [02:45,  6.05it/s]


Epoch [35/100] Batch 984/984 Using Time: 260.5156            Loss: 0.0018


999it [02:45,  6.04it/s]


Epoch [36/100] Batch 984/984 Using Time: 260.6349            Loss: 0.0019


999it [02:45,  6.03it/s]


Epoch [37/100] Batch 984/984 Using Time: 261.3359            Loss: 0.0018


999it [02:45,  6.02it/s]


Epoch [38/100] Batch 984/984 Using Time: 261.3278            Loss: 0.0018


999it [02:45,  6.03it/s]


Epoch [39/100] Batch 984/984 Using Time: 261.0544            Loss: 0.0018


999it [02:45,  6.05it/s]


Epoch [40/100] Batch 984/984 Using Time: 261.7021            Loss: 0.0018


999it [02:46,  5.99it/s]


Epoch [41/100] Batch 984/984 Using Time: 262.2472            Loss: 0.0018


999it [02:46,  5.99it/s]


Epoch [42/100] Batch 984/984 Using Time: 262.3739            Loss: 0.0018


999it [02:45,  6.03it/s]


Epoch [43/100] Batch 984/984 Using Time: 261.0179            Loss: 0.0017


999it [02:45,  6.04it/s]


Epoch [44/100] Batch 984/984 Using Time: 260.6831            Loss: 0.0018


999it [02:45,  6.04it/s]


Epoch [45/100] Batch 984/984 Using Time: 260.6659            Loss: 0.0018


999it [02:45,  6.04it/s]


Epoch [46/100] Batch 984/984 Using Time: 260.6827            Loss: 0.0018


999it [02:45,  6.03it/s]


Epoch [47/100] Batch 984/984 Using Time: 261.1944            Loss: 0.0018


999it [02:45,  6.03it/s]


Epoch [48/100] Batch 984/984 Using Time: 261.4617            Loss: 0.0018


999it [02:46,  6.01it/s]


Epoch [49/100] Batch 984/984 Using Time: 261.5864            Loss: 0.0017


999it [02:45,  6.02it/s]


Epoch [50/100] Batch 984/984 Using Time: 262.0904            Loss: 0.0018


999it [02:45,  6.02it/s]


Epoch [51/100] Batch 984/984 Using Time: 262.0709            Loss: 0.0017


999it [02:46,  6.00it/s]


Epoch [52/100] Batch 984/984 Using Time: 261.7911            Loss: 0.0018


999it [02:46,  6.01it/s]


Epoch [53/100] Batch 984/984 Using Time: 262.1972            Loss: 0.0017


999it [02:46,  6.00it/s]


Epoch [54/100] Batch 984/984 Using Time: 262.4919            Loss: 0.0018


999it [02:46,  5.99it/s]


Epoch [55/100] Batch 984/984 Using Time: 262.4278            Loss: 0.0018


999it [02:46,  5.99it/s]


Epoch [56/100] Batch 984/984 Using Time: 262.6061            Loss: 0.0017


999it [02:46,  5.98it/s]


Epoch [57/100] Batch 984/984 Using Time: 262.9850            Loss: 0.0018


999it [02:46,  5.98it/s]


Epoch [58/100] Batch 984/984 Using Time: 265.9537            Loss: 0.0018


999it [02:51,  5.84it/s]


Epoch [59/100] Batch 984/984 Using Time: 267.9586            Loss: 0.0017


999it [02:50,  5.86it/s]


Epoch [60/100] Batch 984/984 Using Time: 268.3881            Loss: 0.0017


999it [02:52,  5.80it/s]


Epoch [61/100] Batch 984/984 Using Time: 271.0644            Loss: 0.0018


999it [02:53,  5.75it/s]


Epoch [62/100] Batch 984/984 Using Time: 270.1576            Loss: 0.0017


999it [02:51,  5.81it/s]


Epoch [63/100] Batch 984/984 Using Time: 267.9355            Loss: 0.0017


999it [02:52,  5.78it/s]


KeyboardInterrupt: 

# Measure FID

use this implementation: https://github.com/mseitzer/pytorch-fid/tree/master

## define FID measurement function 

In [43]:
import subprocess
import torch
from torchvision.utils import save_image
import re

# Create a function to run the FID script
def run_fid(real_path, gen_path, epoch):
    command = ["python", "-m", "pytorch_fid", real_path, gen_path]

    output = subprocess.run(command, capture_output=True, text=True)

    # Extract the FID score using regular expressions
    output = output.stdout
    fid_score_match = re.search(r"FID:\s+(\d+\.\d+)", output)
    
    if fid_score_match:
        fid_score = float(fid_score_match.group(1))
        print("FID score:", fid_score)

        # Write the FID score to the log file
        # gen_log(log_path="logs/cats/fid/DCGAN.log", message=f"Epoch: {epoch}, FID score: {fid_score}")
        
    else:
        print("FID score not found in the output.")

    return fid_score

def FID_measure(model, sample_n=100*8, batch_size=8, device="cpu", real_path="dataset/cats/", gen_path="generated_images", epoch=int):
    # Set the model to evaluation mode
    model.eval()

    # Create a folder for generated images if it doesn't exist
    os.makedirs(gen_path, exist_ok=True)
    torch.manual_seed(2023)

    # Generate images
    with torch.no_grad():
        for i in range(0, sample_n, batch_size):
            
            x = diffusion.sample(model, n=batch_size).type(dtype=torch.float32)

            for j in range(batch_size):
                save_image(x[j], f"{gen_path}/{i+j}.png", normalize=True)
        
    # Measure FID score between the real and generated images
    fid_score = run_fid(real_path, gen_path, epoch)
    return fid_score


## Load weight of model to measure FID score

In [44]:
def load_weight( model, weight_path="", index=int):

    # Load the saved weights
    model.load_state_dict(torch.load(f'{weight_path}/{index}.pt'))
    

In [49]:
import glob
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'logs/cats/DDPM/')

# Specify the directory path
directory_path = Path('weights/cats/DDPM/')

# Get the list of files ending with ".pth"
# file_list = list(directory_path.glob('*.pt'))


model = UNet().to(device)

sample_size = 10000

for i in range(NUM_EPOCHS):
    index_weight = i+1

    load_weight(model, weight_path=directory_path, index=index_weight)

    print('Epoch: ',index_weight)
    fid_score = FID_measure(model, sample_size, batch_size=64, device=f"{device}", gen_path=f"gen_image/cats/DDPM/{index_weight}/", epoch=index_weight)   
    
    writer.add_scalar("Metrics/FID Score", fid_score, index_weight)


11:38:55 - INFO: Sampling 64 new images....


Epoch:  1


999it [06:15,  2.66it/s]
11:45:10 - INFO: Sampling 64 new images....
563it [03:33,  2.64it/s]


KeyboardInterrupt: 

# Count learnable parameter in model

In [3]:
model = UNet().to("cuda")


pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

In [4]:
print(pytorch_total_params)

22291587
