# DDPM with cats dataset

## import nesessary modules 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os

from torch.utils.tensorboard import SummaryWriter
from modules.dcgan import Discriminator, Generator, initialize_weights

## set agnostic

In [2]:
# agnostic code
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## load data set

In [3]:
# set image size
img_size = 64

# Set the path to the dataset
dataset_path = 'dataset/cats/'

# Set the number of images to transform
NUM_IMAGES = 15747

In [4]:
# Get the list of image filenames
image_filenames = os.listdir(dataset_path)

# Define the image transformations
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    # transforms.RandomHorizontalFlip(),
    transforms.ToTensor(), # convert PIL image to tensor and scales data into [0,1] 
    # transforms.Lambda(lambda t: (t * 2) - 1) # Scale between [-1, 1] 
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]), # Scale between [-1, 1] by (input[channel] - mean[channel]) / std[channel]
])

# Create a list to store the transformed images
transformed_images = []

# Iterate over the first num_images filenames and transform the corresponding images
for i, filename in enumerate(image_filenames[:NUM_IMAGES]):
    # Load the image
    img_path = os.path.join(dataset_path, filename)
    image = Image.open(img_path)

    # Apply the transformations
    transformed_image = transform(image)

    # Append the transformed image to the list
    transformed_images.append(transformed_image)

# Convert the list of transformed images to a PyTorch tensor
transformed_images = torch.stack(transformed_images)

print(f'Loaded data: {transformed_images.shape}')

Loaded data: torch.Size([15747, 3, 64, 64])


## split batch size 

In [5]:
# set batch size
batch_size = 16

data_loader = DataLoader(transformed_images, batch_size=batch_size, shuffle=True, drop_last=True)

data_iter = iter(data_loader)
print(next(data_iter).shape)

torch.Size([16, 3, 64, 64])


## import the model

base on DDPM and unet papers 
https://arxiv.org/pdf/1505.04597v1.pdf

In [6]:
from modules.ddpm import Diffusion

from modules.modules import UNet

## Set hyperparameter before training iteration

In [7]:
# base on the paper
LEARNING_RATE = 1e-4  #0.0001
BATCH_SIZE = 16
IMAGE_SIZE = 64
CHANNELS_IMG = 3
LATENT_DIM = 100

NUM_EPOCHS = 200



## Train model and save weight and log 

## training process
parameter base on DDPM paper 

In [23]:
import logging
import time

model = UNet().to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
mse = nn.MSELoss()
diffusion = Diffusion(img_size=IMAGE_SIZE, device=device)
writer = SummaryWriter(os.path.join("logs/cats","DDPM"))

l = len(data_loader)

time_use = 0

for epoch in range(NUM_EPOCHS):

    # Shuffle the dataset at the beginning of each epoch
    data_loader = DataLoader(transformed_images, batch_size=batch_size, shuffle=True, drop_last=True)

    lossMean = 0

    # logging.info(f"Starting epoch {epoch}:")

    # use time for time measurement
    start_time = time.time()

    model.train()
    for batch_idx, images in enumerate(data_loader):
        images = images.to(device)
        t = diffusion.sample_timesteps(images.shape[0]).to(device)
        x_t, noise = diffusion.noise_images(images, t)
        predicted_noise = model(x_t, t)
        loss = mse(noise, predicted_noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # sum lose
        lossMean += loss

    end_time = time.time()
    epoch_time = end_time - start_time

    time_use += epoch_time 

    epoch_index = epoch+1

    # calculate mean value
    lossMean = lossMean / len(data_loader.dataset)

    print(
        f"Epoch [{epoch_index}/{NUM_EPOCHS}] Batch {batch_idx+1}/{len(data_loader)} Using Time: {epoch_time:.4f}\
            Loss: {lossMean:.4f}"
    )

    # tensorboard
    writer.add_scalar("MSE", lossMean, global_step=epoch+1)
    writer.add_scalar("traing time", time_use, epoch_index)

    torch.save(model.state_dict(), os.path.join("weights", "cats", "DDPM", f"{epoch_index}.pt"))

    model.eval()
    with torch.no_grad():
        x = diffusion.sample(model, n=32).type(dtype=torch.float32)
        img_grid = torchvision.utils.make_grid(x[:32], normalize=True)
        writer.add_image("All/Generate Images", img_grid, global_step=epoch_index)


06:53:26 - INFO: Starting epoch 0:


Epoch [1/200] Batch 984/984 Using Time: 255.3108            Loss: 0.0069


06:57:41 - INFO: Sampling 32 new images....
999it [02:46,  6.00it/s]
07:00:28 - INFO: Starting epoch 1:


# Measure FID

use this implementation: https://github.com/mseitzer/pytorch-fid/tree/master

## define FID measurement function 

In [None]:
import subprocess
import torch
from torchvision.utils import save_image
import re

# Create a function to run the FID script
def run_fid(real_path, gen_path, epoch):
    command = ["python", "-m", "pytorch_fid", real_path, gen_path]

    output = subprocess.run(command, capture_output=True, text=True)

    # Extract the FID score using regular expressions
    output = output.stdout
    fid_score_match = re.search(r"FID:\s+(\d+\.\d+)", output)
    
    if fid_score_match:
        fid_score = float(fid_score_match.group(1))
        print("FID score:", fid_score)

        # Write the FID score to the log file
        # gen_log(log_path="logs/cats/fid/DCGAN.log", message=f"Epoch: {epoch}, FID score: {fid_score}")
        
    else:
        print("FID score not found in the output.")

    return fid_score

def FID_measure(model, sample_size=100*8, batch_size=8, device="cpu", real_path="dataset/cats/", gen_path="generated_images", epoch=int):
    # Set the model to evaluation mode
    model.eval()

    # Create a folder for generated images if it doesn't exist
    os.makedirs(gen_path, exist_ok=True)
    torch.manual_seed(2023)

    # Generate images
    with torch.no_grad():
        for i in range(0, sample_size, batch_size):
            
            # Random value from the normal distribution
            z = torch.randn(batch_size, LATENT_DIM, 1, 1).to(device)

            # Random from uniform distribution in range [-2, 2]
            # z = torch.FloatTensor(batch_size, LATENT_DIM, 1, 1).uniform_(-2, 2).to(device)

            # Generate images
            gen_imgs = model(z)

            # Save images
            for j in range(batch_size):
                save_image(gen_imgs[j], f"{gen_path}/{i+j}.png", normalize=True)
        
    # Measure FID score between the real and generated images
    fid_score = run_fid(real_path, gen_path, epoch)
    return fid_score


## Load weight of model to measure FID score

In [None]:
def load_weight( generator, discriminator=None, weight_path="", index=int):

    # Load the saved weights
    generator.load_state_dict(torch.load(f'{weight_path}/G{index}.pt'))
    # discriminator.load_state_dict(torch.load(f'{weight_path}/D{index}.pt'))
    

In [None]:
import glob
from pathlib import Path
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(f'logs/cats/DCGAN/')

# Specify the directory path
directory_path = Path('weights/cats/DCGAN/')

# Get the list of files ending with ".pth"
# file_list = list(directory_path.glob('*.pt'))


gen = Generator(LATENT_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)

sample_size = 10000

for i in range(NUM_EPOCHS):
    index_weight = i+1

    load_weight(generator=gen, weight_path=directory_path, index=index_weight)

    print('Epoch: ',index_weight)
    fid_score = FID_measure(gen, sample_size, batch_size=16, device=f"{device}", gen_path=f"gen_image/cats/DCGAN/{index_weight}/", epoch=index_weight)   
    
    writer.add_scalar("Metrics/FID Score", fid_score, index_weight)
