Imports

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import numpy as np
import torchvision
from torchvision import transforms
import diffusers
from diffusers import DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
import sys
from PIL import Image
sys.path.append("../")

import utils


Training Config

In [1]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for the DDPM training run.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field. Without annotations the decorator silently ignores the
    attributes, which means values cannot be overridden at construction time
    (``TrainingConfig(num_epochs=10)`` would raise ``TypeError``) and
    ``dataclasses.asdict`` / ``repr`` would show nothing.
    """

    run_name: str = "DDPM-FFHQ-TEST"
    # Number of diffusion timesteps; passed to
    # DDPMScheduler(num_train_timesteps=...) in the setup cell. The attribute
    # name ("timestamps") is kept as-is because it is referenced by that name.
    num_train_timestamps: int = 1000
    dataset: str = "FFHQ"  # "CIFAR10" / "FFHQ"
    # Generated image resolution. The dataloader cell overwrites this to match
    # the selected dataset (128 for FFHQ, 32 for CIFAR10).
    image_size: int = 32
    train_batch_size: int = 6
    eval_batch_size: int = 6  # how many images to sample during evaluation
    num_epochs: int = 50
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4  # AdamW learning rate
    lr_warmup_steps: int = 500  # linear warmup before cosine decay
    # Presumably the epoch intervals at which the training loop writes sample
    # images / model checkpoints — confirm in utils.ddpm_train_loop.
    save_image_epochs: int = 1
    save_model_epochs: int = 1
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    # Local output directory / model name; the visualization cell reads sample
    # images from "<output_dir>/samples/". Possibly also the HF Hub repo name —
    # confirm in utils.
    output_dir: str = 'ddpm-ffhq-128'
    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook
    seed: int = 0
    # NOTE(review): dropout rates presumably applied to the UNet's down/mid/up
    # blocks by utils.get_default_unet(config) — confirm there.
    down_dropout: float = 0.0
    mid_dropout: float = 0.0
    up_dropout: float = 0.0
    # NOTE(review): consumed by utils; their exact semantics are defined there.
    bayesian_avg_samples: int = 1
    bayesian_avg_range: tuple = (0, 1000)  # tuple (immutable) so it is a legal dataclass default

config = TrainingConfig()

Create Dataloader

In [None]:
# Augment with random horizontal flips, convert to a tensor in [0, 1], then
# Normalize(mean=0.5, std=0.5) maps pixel values to [-1, 1].
image_transforms = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# Select the dataset and keep config.image_size in sync with its resolution.
if config.dataset == "FFHQ":
    # FFHQ 128x128 thumbnails, loaded through the project's dataset wrapper.
    dataset = utils.ffhq_Dataset("../dataset/ffhq/thumbnails128x128/", image_transforms)
    config.image_size = 128
elif config.dataset == "CIFAR10":
    # torchvision's CIFAR10 takes the augmentation pipeline via `transform=`.
    # NOTE(review): CIFAR10 yields (image, label) tuples; confirm that
    # utils.ddpm_train_loop handles that item format as well as FFHQ's.
    dataset = torchvision.datasets.CIFAR10(root="../dataset/", download=True, transform=image_transforms)
    config.image_size = 32
else:
    raise ValueError("Invalid Dataset supplied")

train_loader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=4, pin_memory=True)

Setup for Training

In [None]:
# Optional checkpoint to resume from: a saved DDPMPipeline directory
# (e.g. config.output_dir, presumably written by the training loop — confirm
# in utils). Leave as None to train a freshly initialised UNet.
resume_from = None

if resume_from is None:
    model = utils.get_default_unet(config)
else:
    # Load the UNet *before* building the optimizer. The optimizer keeps
    # references to the parameters it is given, so rebinding `model` afterwards
    # would leave it updating the old, discarded network.
    # `from_pretrained` is a classmethod; no throwaway pipeline instance needed.
    model = diffusers.DDPMPipeline.from_pretrained(resume_from).unet

optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

# Linear warmup followed by cosine decay over len(train_loader) * num_epochs
# scheduler steps (assumes the train loop steps the LR scheduler once per batch).
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(len(train_loader) * config.num_epochs),
)
noise_scheduler = DDPMScheduler(num_train_timesteps=config.num_train_timestamps)

Start Training

In [5]:
from accelerate import notebook_launcher

# Positional arguments forwarded to utils.ddpm_train_loop, in the order it expects.
args = (
    config,
    model,
    noise_scheduler,
    optimizer,
    train_loader,
    lr_scheduler,
)

# One process per visible CUDA device. NOTE(review): with no GPU this is 0;
# check how your accelerate version's notebook_launcher handles that.
gpu_count = torch.cuda.device_count()
notebook_launcher(utils.ddpm_train_loop, args, num_processes=gpu_count)

Visualize Outputs

In [None]:
import glob
import os

# Display the most recent sample image written during training.
sample_images = sorted(glob.glob(f"{config.output_dir}/samples/*.png"))
if not sample_images:
    # Fail with a clear message instead of a bare IndexError on an empty list.
    raise FileNotFoundError(
        f"No sample images found in {config.output_dir}/samples/ - run the training cell first."
    )
# Choose by modification time rather than by name, so "latest" does not depend
# on zero-padded filenames (lexicographic order puts "10.png" before "2.png").
latest_sample = max(sample_images, key=os.path.getmtime)
Image.open(latest_sample)