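Adapted from https://github.com/uygarkurt/DDPM-Image-Generation/blob/main/DDPM_Image_Generartion.ipynb. This version sweeps three UNet sizes, three learning rates and two optimizers. For each run it records the training loss and FID after every epoch.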

In [1]:
import copy
import json
import os
import random
import time
import timeit

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import UNet2DModel, DDPMScheduler, DDPMPipeline
from diffusers.optimization import get_cosine_schedule_with_warmup
from PIL import Image
from torch.optim import AdamW
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm.notebook import tqdm

In [2]:
# --- experiment configuration ---------------------------------------------
RANDOM_SEED = 42
IMG_SIZE = 128                     # images are resized to IMG_SIZE x IMG_SIZE
DATASET_PERCENT = 0.1              # part of the dataset folder name
BATCH_SIZE = 16
LEARNING_RATE = 1e-4               # base learning rate of the sweep
NUM_EPOCHS = 30
NUM_GENERATE_IMAGES = 9
NUM_TIMESTEPS = 1000
MIXED_PRECISION = "fp16"
GRADIENT_ACCUMULATION_STEPS = 1

# --- reproducibility ------------------------------------------------------
# Seed every RNG the notebook touches (Python, NumPy, torch CPU and CUDA).
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(RANDOM_SEED)
del _seed_fn

# For bit-exact cuDNN results (slower), also set:
#   torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
# The images live locally in an `imagefolder`-style directory tree whose
# name encodes the image size and the dataset percentage.
local_dataset_path = f"data/square{IMG_SIZE}_random{str(DATASET_PERCENT)}/"
dataset = load_dataset(path="imagefolder", data_dir=local_dataset_path)

Resolving data files:   0%|          | 0/30312 [00:00<?, ?it/s]

In [4]:
# `load_dataset` returns a split dict (DatasetDict); keep only the "train" split.
# The guard makes the cell idempotent: re-running it after `dataset` is already
# the train split is a no-op instead of a failing lookup.
if isinstance(dataset, dict):
    dataset = dataset["train"]

In [5]:
# Training-time image pipeline:
#   1. resize to a fixed IMG_SIZE x IMG_SIZE square,
#   2. random horizontal flip as augmentation,
#   3. convert to a tensor and rescale with Normalize(mean=0.5, std=0.5),
#      i.e. x -> (x - 0.5) / 0.5.
_preprocess_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
preprocess = transforms.Compose(_preprocess_steps)

In [6]:
def transform(examples):
    """On-the-fly batch transform for the dataset.

    Converts every PIL image in ``examples["image"]`` to RGB, runs it through
    ``preprocess`` and returns the results under the ``"images"`` key.
    """
    processed = []
    for pil_image in examples["image"]:
        processed.append(preprocess(pil_image.convert("RGB")))
    return {"images": processed}


# Apply lazily at access time rather than materialising preprocessed copies.
dataset.set_transform(transform)

In [7]:
# Reshuffle every epoch. With shuffle=False every epoch visits the minibatches
# in the same fixed order, which is undesirable for SGD-style training.
train_dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

In [8]:

def _make_unet(layers_per_block, block_out_channels):
    """Build an RGB UNet2DModel on `device` with attention at down-block index 4.

    The down path is all ``DownBlock2D`` except index 4 (``AttnDownBlock2D``).
    The up path is all ``UpBlock2D`` except the mirrored position
    ``len(block_out_channels) - 1 - 4`` (``AttnUpBlock2D``).
    """
    num_levels = len(block_out_channels)
    attn_down_idx = 4
    attn_up_idx = num_levels - 1 - attn_down_idx
    down_types = tuple(
        "AttnDownBlock2D" if level == attn_down_idx else "DownBlock2D"
        for level in range(num_levels)
    )
    up_types = tuple(
        "AttnUpBlock2D" if level == attn_up_idx else "UpBlock2D"
        for level in range(num_levels)
    )
    unet = UNet2DModel(
        sample_size=IMG_SIZE,
        in_channels=3,
        out_channels=3,
        layers_per_block=layers_per_block,
        block_out_channels=block_out_channels,
        down_block_types=down_types,
        up_block_types=up_types,
    )
    return unet.to(device)


# Built in the same order as before (small, mid, big) so weight init draws
# from the RNG identically.
model_small = _make_unet(1, (64, 64, 128, 128, 256))
model_mid = _make_unet(2, (128, 128, 256, 256, 512, 512))
model_big = _make_unet(3, (128, 256, 256, 512, 512, 1024, 1024))

models = [model_small, model_mid, model_big]

In [9]:
# Noise scheduler shared by training (`add_noise`) and sampling (DDPMPipeline).
noise_scheduler = DDPMScheduler(
    num_train_timesteps=NUM_TIMESTEPS,
)

In [10]:
def sample_image_generation(model, noise_scheduler, num_generate_images, random_seed, num_timesteps):
    """Sample images with a DDPM pipeline and plot them in a near-square grid.

    NOTE: a later cell redefines `sample_image_generation` with an extra `device`
    argument, which returns tensors instead of plotting. Once that cell has run,
    this plotting version is shadowed.

    Args:
        model: the UNet. It is unwrapped through the global `accelerator`, so an
            `Accelerator` must exist before calling this.
        noise_scheduler: scheduler handed to the pipeline.
        num_generate_images: number of images to sample and plot (>= 1). The grid
            is sized to fit, so 9 gives 3x3 as before and larger counts also work.
        random_seed: seed of the sampling generator.
        num_timesteps: number of denoising steps.

    Raises:
        ValueError: if `num_generate_images` is smaller than 1.
    """
    if num_generate_images < 1:
        raise ValueError(f"num_generate_images must be >= 1, got {num_generate_images}")

    pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

    # A private CPU generator draws the same sequence that torch.manual_seed(seed)
    # would give. Unlike torch.manual_seed, it does not reset the global RNG that
    # training also draws from.
    generator = torch.Generator().manual_seed(random_seed)
    images = pipeline(
        batch_size=num_generate_images,
        generator=generator,
        num_inference_steps=num_timesteps,
    ).images

    # Smallest near-square grid that fits every image (a fixed 3x3 grid failed for >9).
    ncols = int(np.ceil(np.sqrt(num_generate_images)))
    nrows = int(np.ceil(num_generate_images / ncols))

    fig = plt.figure()
    for position, image in enumerate(images, start=1):
        fig.add_subplot(nrows, ncols, position)
        plt.imshow(image)
    plt.show()

In [11]:
import torch
import torch.nn as nn
import numpy as np
from scipy.linalg import sqrtm
from torchvision.models import inception_v3
from torchvision.transforms import ToTensor, Resize, Normalize, Compose
from torch.utils.data import DataLoader, TensorDataset

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Fréchet distance between two multivariate Gaussians N(mu1, sigma1), N(mu2, sigma2).

        d^2 = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2))

    Args:
        mu1, mu2: mean vectors of equal length.
        sigma1, sigma2: covariance matrices matching the mean length.
        eps: diagonal offset added to both covariances, only if the matrix square
            root of the raw product is not finite. This happens with near-singular
            covariances, e.g. when there are fewer samples than feature dimensions.

    Returns:
        The distance as a Python float.
    """
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)

    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # Regularise and retry instead of silently propagating NaN/inf into the score.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))
    if np.iscomplexobj(covmean):
        # Numerical error can leave tiny imaginary parts; the true root is real.
        covmean = covmean.real

    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean))

def get_activations(model, dataloader, device, key, resize_to=299):
    """Run every batch of images through `model` and stack the outputs.

    Args:
        model: feature extractor (the InceptionV3 network in this notebook). It is
            switched to eval mode here.
        dataloader: yields batches indexable by `key`.
        device: device the images are moved to before the forward pass.
        key: batch key or index holding the image tensor. This is "images" for the
            training loader and 0 for a TensorDataset.
        resize_to: side length that images are bilinearly resized to before the
            forward pass. Inception features for FID are defined on 299x299 inputs.
            Pass None to feed images at their native size.

    Returns:
        np.ndarray of the model outputs concatenated along axis 0. Presumably this
        is one feature row per image, assuming the model returns (batch, features).

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    model.eval()
    activations = []

    # A single no_grad scope around the whole loop; no autograd graph is needed.
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Calculating activations", leave=False):
            images = batch[key].to(device)
            if resize_to is not None and tuple(images.shape[-2:]) != (resize_to, resize_to):
                images = torch.nn.functional.interpolate(
                    images, size=(resize_to, resize_to), mode="bilinear", align_corners=False
                )
            activations.append(model(images).cpu().numpy())

    if not activations:
        raise ValueError("dataloader yielded no batches; cannot compute activations")
    return np.concatenate(activations, axis=0)

def calculate_statistics(activations):
    """Return the per-feature mean vector and the feature covariance matrix.

    Rows of `activations` are observations and columns are features, hence
    ``rowvar=False`` for the covariance.
    """
    return (
        np.mean(activations, axis=0),
        np.cov(activations, rowvar=False),
    )

def sample_image_generation(model, noise_scheduler, num_generate_images, random_seed, num_timesteps, device, max_batch_size=32):
    """Sample images with a DDPM pipeline and return them as one tensor for FID.

    Args:
        model: the UNet. It is unwrapped through the global `accelerator`, so an
            `Accelerator` must exist before calling this.
        noise_scheduler: scheduler handed to the pipeline.
        num_generate_images: total number of images to sample (>= 1).
        random_seed: seed of the sampling generator.
        num_timesteps: number of denoising steps.
        device: device of the returned tensor.
        max_batch_size: images sampled per pipeline call. This keeps large
            requests from allocating one huge batch. None or 0 means a single call.

    Returns:
        float32 tensor (num_generate_images, C, H, W) on `device`, in [-1, 1].
        The training images use the same range (ToTensor + Normalize(0.5, 0.5)),
        but here no random flip or other augmentation is applied.

    Raises:
        ValueError: if `num_generate_images` is smaller than 1.
    """
    if num_generate_images < 1:
        raise ValueError(f"num_generate_images must be >= 1, got {num_generate_images}")

    pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)
    pipeline.set_progress_bar_config(leave=False)

    # Private CPU generator: the result is reproducible, and unlike
    # torch.manual_seed it does not reset the global RNG used by the training loop.
    generator = torch.Generator().manual_seed(random_seed)

    chunk = max_batch_size if max_batch_size else num_generate_images
    pieces = []
    remaining = num_generate_images
    while remaining > 0:
        batch_size = min(chunk, remaining)
        # output_type="np" gives NHWC floats in [0, 1], with no PIL round-trip.
        images_np = pipeline(
            batch_size=batch_size,
            generator=generator,
            num_inference_steps=num_timesteps,
            output_type="np",
        ).images
        pieces.append(torch.from_numpy(images_np).float())
        remaining -= batch_size

    images = torch.cat(pieces).permute(0, 3, 1, 2).contiguous()  # NHWC -> NCHW
    return (images * 2.0 - 1.0).to(device)                       # [0, 1] -> [-1, 1]

# State reused across calculate_fid calls:
#   ("inception", device) -> the Inception feature extractor on that device
#   ("real", device)      -> (dataloader, (mu, sigma)) of the real images
_FID_CACHE = {}


def _load_inception(device):
    """Return the InceptionV3 feature extractor (fc replaced by Identity) on `device`, loaded once per device."""
    cache_key = ("inception", str(device))
    if cache_key not in _FID_CACHE:
        try:
            from torchvision.models import Inception_V3_Weights
            inception = inception_v3(weights=Inception_V3_Weights.IMAGENET1K_V1, transform_input=False)
        except ImportError:
            # torchvision < 0.13 only supports the legacy `pretrained=` flag.
            inception = inception_v3(pretrained=True, transform_input=False)
        inception.fc = nn.Identity()  # return pooled features instead of class logits
        _FID_CACHE[cache_key] = inception.to(device)
    return _FID_CACHE[cache_key]


def _real_statistics(inception, dataloader_real, device, use_cache):
    """Return (mu, sigma) of real-image features, reused while `dataloader_real` is the same object."""
    cache_key = ("real", str(device))
    cached = _FID_CACHE.get(cache_key) if use_cache else None
    if cached is not None and cached[0] is dataloader_real:
        return cached[1]
    real_activations = get_activations(inception, dataloader_real, device, key="images")
    stats = calculate_statistics(real_activations)
    if use_cache:
        # Store the loader itself so the `is` check cannot match a recycled object id.
        _FID_CACHE[cache_key] = (dataloader_real, stats)
    return stats


def calculate_fid(model, dataloader_real, num_generated_images, noise_scheduler, random_seed, num_timesteps, device, cache_real_stats=True):
    """Compute the FID between the real images and freshly sampled ones.

    The Inception network is loaded once per device. By default the real-image
    statistics are computed once per `dataloader_real` object instead of on every
    call. This removes a full pass over the training set per epoch, and it keeps
    the reference statistics fixed: the training transform includes a random flip,
    so recomputing them made FID values drift between epochs.

    Args:
        model: the UNet to sample from.
        dataloader_real: loader of real images; batches carry an "images" key.
        num_generated_images: number of images to sample for the comparison.
        noise_scheduler: scheduler used for sampling.
        random_seed: seed of the sampling generator.
        num_timesteps: number of denoising steps.
        device: device for Inception and the generated images.
        cache_real_stats: set False to recompute the real-image statistics.

    Returns:
        The FID score, as returned by calculate_frechet_distance.
    """
    inception = _load_inception(device)

    mu_real, sigma_real = _real_statistics(inception, dataloader_real, device, cache_real_stats)

    generated_images = sample_image_generation(model, noise_scheduler, num_generated_images, random_seed, num_timesteps, device)
    generated_dataloader = DataLoader(TensorDataset(generated_images), batch_size=32, shuffle=False)
    # A TensorDataset batch is a list, so the image tensor is at index 0.
    generated_activations = get_activations(inception, generated_dataloader, device, key=0)
    mu_generated, sigma_generated = calculate_statistics(generated_activations)

    return calculate_frechet_distance(mu_real, sigma_real, mu_generated, sigma_generated)


# The global `transform` is intentionally NOT rebound to `preprocess` here.
# `transform` is the dataset.set_transform callback defined earlier, and nothing
# after this point reads a global `transform`. Rebinding it only shadowed that
# callback with a different object.




In [12]:
# ignore UserWarning
import warnings
warnings.filterwarnings("ignore")

In [13]:
# Hyperparameter sweep over model size x learning rate x optimizer.
# Each run trains a deep copy of the untrained base model, so every configuration
# starts from the same initial weights. Previously each run kept training the
# weights left behind by the previous one.
# NOTE(review): len(train_dataloader) is the number of *batches*, not images; it
# is used as the number of generated images per FID evaluation. Confirm intent.
NUM_GENERATE_IMAGES_FID = len(train_dataloader)
SWEEP_LEARNING_RATES = [LEARNING_RATE / 10, LEARNING_RATE, LEARNING_RATE * 10]
SWEEP_OPTIMIZERS = {"Adam": AdamW, "SGD": SGD}
LR_WARMUP_STEPS = 500


def _train_one_epoch(model, dataloader, optimizer, lr_scheduler, accelerator):
    """Run one DDPM epoch (MSE between predicted and added noise) and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader, position=0, desc="BATCHES", leave=False):
        clean_images = batch["images"].to(device)
        # Draw noise and timesteps directly on the images' device (no CPU round trip).
        noise = torch.randn_like(clean_images)
        timesteps = torch.randint(
            0,
            noise_scheduler.config.num_train_timesteps,
            (clean_images.shape[0],),
            device=clean_images.device,
        )
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        with accelerator.accumulate(model):
            noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)
            accelerator.backward(loss)
            # Clip only when the accumulated gradients are about to be applied.
            if accelerator.sync_gradients:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        running_loss += loss.item()
        num_batches += 1
    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError.
    return running_loss / max(num_batches, 1)


def _save_run(save_dir, accelerator, model, optimizer, lr_scheduler, metadata):
    """Write model/optimizer/scheduler state and metadata.json into `save_dir`."""
    os.makedirs(save_dir, exist_ok=True)
    # Unwrap in case accelerate wrapped the model (e.g. DDP), so the keys are plain UNet keys.
    torch.save(accelerator.unwrap_model(model).state_dict(), os.path.join(save_dir, "model.pth"))
    torch.save(optimizer.state_dict(), os.path.join(save_dir, "optimizer.pth"))
    torch.save(lr_scheduler.state_dict(), os.path.join(save_dir, "lr_scheduler.pth"))
    torch.save(noise_scheduler, os.path.join(save_dir, "noise_scheduler.pth"))
    # The metadata used to be built but never written to disk.
    with open(os.path.join(save_dir, "metadata.json"), "w") as fh:
        json.dump(metadata, fh, indent=2, default=str)


for model_idx, base_model in enumerate(models):
    for learning_rate in SWEEP_LEARNING_RATES:
        for optimizer_name, optimizer_cls in SWEEP_OPTIMIZERS.items():
            run_tag = f"model{model_idx}_{optimizer_name}_lr{learning_rate:g}"
            print(f"=== {run_tag} ===")

            model = copy.deepcopy(base_model)
            optimizer = optimizer_cls(model.parameters(), lr=learning_rate)
            lr_scheduler = get_cosine_schedule_with_warmup(
                optimizer=optimizer,
                num_warmup_steps=LR_WARMUP_STEPS,
                num_training_steps=len(train_dataloader) * NUM_EPOCHS,
            )

            accelerator = Accelerator(
                mixed_precision=MIXED_PRECISION,
                gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
            )
            # Prepare the raw loader under a new name. Rebinding `train_dataloader`
            # would re-wrap an already-prepared loader on every following run.
            model, optimizer, prepared_dataloader, lr_scheduler = accelerator.prepare(
                model, optimizer, train_dataloader, lr_scheduler
            )

            training_loss = []
            frechet_inception_distance = []
            start = timeit.default_timer()

            for epoch in tqdm(range(NUM_EPOCHS), position=0, leave=True, desc="EPOCHS"):
                train_loss = _train_one_epoch(model, prepared_dataloader, optimizer, lr_scheduler, accelerator)
                training_loss.append(train_loss)

                fid_score = float(
                    calculate_fid(
                        model,
                        train_dataloader,
                        NUM_GENERATE_IMAGES_FID,
                        noise_scheduler,
                        RANDOM_SEED,
                        NUM_TIMESTEPS,
                        device,
                    )
                )
                frechet_inception_distance.append(fid_score)

                train_learning_rate = lr_scheduler.get_last_lr()[0]

                print("-" * 30)
                print(f"Train Loss EPOCH: {epoch+1}: {train_loss:.4f}")
                print(f"Train Learning Rate EPOCH: {epoch+1}: {train_learning_rate}")
                print(f"FID Score EPOCH: {epoch+1}: {fid_score:.4f}")
                print("-" * 30)

            stop = timeit.default_timer()
            print(f"Training Time: {stop-start:.2f}s")

            # `timestamp`, not `time`: assigning to `time` shadowed the module, so
            # `time.strftime` crashed on the second run of the sweep.
            timestamp = time.strftime("%Y-%m-%d_%H-%M-%S")
            metadata = {
                "MODEL_INDEX": model_idx,
                "MODEL_CONFIG": dict(base_model.config),
                "OPTIMIZER": optimizer_name,
                "IMG_SIZE": IMG_SIZE,
                "BATCH_SIZE": BATCH_SIZE,
                # The learning rate actually used by this run (not the base constant).
                "LEARNING_RATE": learning_rate,
                "NUM_EPOCHS": NUM_EPOCHS,
                "NUM_GENERATE_IMAGES": NUM_GENERATE_IMAGES,
                "NUM_GENERATE_IMAGES_FID": NUM_GENERATE_IMAGES_FID,
                "NUM_TIMESTEPS": NUM_TIMESTEPS,
                "MIXED_PRECISION": MIXED_PRECISION,
                "GRADIENT_ACCUMULATION_STEPS": GRADIENT_ACCUMULATION_STEPS,
                "RANDOM_SEED": RANDOM_SEED,
                "TRAINING_TIME_S": stop - start,
                "losses": training_loss,
                "fid_scores": frechet_inception_distance,
                "dataset": f"square{IMG_SIZE}_random{str(DATASET_PERCENT)}",
            }
            # The run tag in the folder name keeps sweep runs identifiable.
            _save_run(
                os.path.join("models", f"{timestamp}_{run_tag}"),
                accelerator,
                model,
                optimizer,
                lr_scheduler,
                metadata,
            )

EPOCHS:   0%|          | 0/30 [00:00<?, ?it/s]

BATCHES:   0%|          | 0/1895 [00:00<?, ?it/s]