In [1]:
import torch

from dataclasses import dataclass
from lightning.fabric.utilities import AttributeDict
from tqdm import tqdm

import lightning as L

from diffusers import DDPMScheduler
from datasets import load_dataset
from torchvision import transforms
from diffusers import UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup

In [4]:
@dataclass
class Arguments:
    """Hyper-parameters for the diffusion training run.

    Field order matters: it defines the positional order accepted by the
    dataclass-generated ``__init__``.
    """

    # reproducibility
    seed: int = 42
    # images are resized to image_size x image_size
    image_size: int = 32
    # optimisation
    train_batch_size: int = 16
    learning_rate: float = 1e-4
    lr_warmup_fraction: float = 0.1  # share of all optimizer steps used for LR warm-up
    # noise level the clean images are first brought to in train_step;
    # -1 means the clean images are used as-is (valid range is [-1, T - 2])
    start_timestep: int = 900
    # schedule length
    num_epochs: int = 10
    sample_prediction: bool = False  # NOTE(review): not read by any cell shown here


args: Arguments = Arguments()

In [5]:
# Seed the global RNGs (Lightning helper) so runs are repeatable.
L.seed_everything(args.seed)

# Everything is auto-detected except precision, which is pinned to full fp32.
fabric = L.Fabric(precision="32-true", strategy="auto", devices="auto", accelerator="auto")

# Start Fabric (spawns/initialises processes as required by the chosen strategy).
fabric.launch()

Seed set to 42


In [6]:
# Training split of the Smithsonian butterflies subset, cached locally under ./cache.
train_dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train", cache_dir="./cache")

# Per-image pipeline: square resize, random horizontal flip, tensor conversion,
# then Normalize(mean=0.5, std=0.5) so ToTensor's [0, 1] range becomes [-1, 1].
preprocess = transforms.Compose([
    transforms.Resize(size=(args.image_size, args.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def transform(examples):
    """Batch transform for ``datasets``: turn every image into a preprocessed tensor.

    Each image is converted to RGB first, then passed through the module-level
    ``preprocess`` pipeline. The result dict has a single ``"images"`` key
    holding the outputs in input order.
    """
    images = []
    for pil_image in examples["image"]:
        rgb_image = pil_image.convert("RGB")
        images.append(preprocess(rgb_image))
    return {"images": images}

# Apply `transform` lazily whenever examples are accessed.
train_dataset.set_transform(transform)
num_training_samples = len(train_dataset)

# Shuffled mini-batches; pinned host memory speeds up copies to the GPU.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=args.train_batch_size, shuffle=True, pin_memory=True
)
# Measured before Fabric wraps the loader (single-process view).
num_training_batches = len(train_dataloader)

print(f"Number of training batches: {num_training_batches}")

Repo card metadata block was not found. Setting CardData to empty.


Number of training batches: 63


In [7]:
# set model
model = UNet2DModel(
    sample_size=args.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 64, 128, 128, 256, 256),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # a regular ResNet downsampling block
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",  # a ResNet downsampling block with spatial self-attention
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",  # a regular ResNet upsampling block
        "AttnUpBlock2D",  # a ResNet upsampling block with spatial self-attention (mirrors the down path)
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

# Count only parameters the optimizer will actually update (requires_grad),
# so the value matches the variable name even if some weights are frozen.
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_trainable_params}")

Number of parameters: 28447235


In [17]:
# Default DDPM scheduler. Its config provides num_train_timesteps, and its
# `alphas` / `add_noise` are used by the noising and training code.
noise_scheduler = DDPMScheduler()
print(f"Number of training timesteps: {noise_scheduler.config.num_train_timesteps}")

def add_noise_from_timestep(
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.IntTensor,
    alphas: "torch.Tensor | None" = None,
) -> torch.Tensor:
    """Diffuse samples from ``timesteps`` through the last diffusion step.

    ``DDPMScheduler.add_noise`` uses the cumulative product of alphas from
    step 0 up to ``t``. This function instead uses the *reversed* cumulative
    product ``a_t = prod_{s=t}^{T-1} alphas[s]``. Suppose it is applied to a
    sample already noised to level ``t - 1``. The result then has the same
    marginal as a sample noised directly to the final step ``T - 1``.

    Args:
        original_samples: batch of samples to noise, batch dimension first.
        noise: noise tensor with the same shape as ``original_samples``
            (standard normal in the training code).
        timesteps: scalar or per-sample integer timestep(s) in ``[0, T - 1]``.
        alphas: per-step alphas of length ``T``. Defaults to the global
            ``noise_scheduler.alphas``; pass it explicitly so the function
            does not depend on that global.

    Returns:
        ``sqrt(a_t) * original_samples + sqrt(1 - a_t) * noise``, with ``a_t``
        broadcast over all non-batch dimensions.
    """
    if alphas is None:
        alphas = noise_scheduler.alphas

    # flip -> cumprod -> flip: entry t becomes prod_{s=t}^{T-1} alphas[s]
    alphas_reversed_cumprod = torch.flip(torch.cumprod(torch.flip(alphas, [0]), dim=0), [0])
    alphas_reversed_cumprod = alphas_reversed_cumprod.to(
        device=original_samples.device, dtype=original_samples.dtype
    )
    # long indices are accepted for tensor indexing on every PyTorch version
    timesteps = timesteps.to(device=original_samples.device, dtype=torch.long)

    signal_rate = alphas_reversed_cumprod[timesteps].flatten()
    # shape (batch or 1, 1, ..., 1) so it broadcasts over every non-batch dim
    signal_rate = signal_rate.reshape(-1, *([1] * (original_samples.dim() - 1)))

    return signal_rate ** 0.5 * original_samples + (1 - signal_rate) ** 0.5 * noise

Number of training timesteps: 1000


In [9]:
# Adam over every model parameter at the configured base learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

# Cosine schedule with warm-up; the training loop steps it once per batch,
# so the schedule length is epochs * batches per epoch.
# NOTE(review): num_training_batches was measured before Fabric wrapped the
# dataloader; if the loader gets sharded across processes the per-process
# count is smaller and this schedule would be too long.
total_training_steps = args.num_epochs * num_training_batches
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=int(args.lr_warmup_fraction * args.num_epochs * num_training_batches),
    num_training_steps=total_training_steps,
)

In [10]:
# Hand the dataloader and the model/optimizer pair to Fabric; the returned
# wrappers replace the originals for the rest of the notebook.
train_dataloader = fabric.setup_dataloaders(train_dataloader)
model, optimizer = fabric.setup(model, optimizer)

In [18]:
def train_step(model, batch, start_timestep):
    """Compute the reconstruction loss for one batch.

    The clean images are first noised to ``start_timestep`` with the
    scheduler's forward process, or left clean when ``start_timestep == -1``.
    These "noisy images" are then diffused through the remaining steps with
    ``add_noise_from_timestep``, giving pseudo white-noise inputs. The model
    is trained to recover the noisy images from that input.

    Args:
        model: callable ``model(sample, timestep)`` returning an object with a
            ``.sample`` tensor shaped like the images (e.g. ``UNet2DModel``).
        batch: dict holding an ``"images"`` tensor.
        start_timestep: integer in ``[-1, T - 2]``, where ``T`` is
            ``noise_scheduler.config.num_train_timesteps``.

    Returns:
        Scalar MSE loss between the model output and the noisy images.

    Raises:
        AssertionError: if ``start_timestep`` is outside ``[-1, T - 2]``.
    """
    # -1 means "start from clean images". T - 2 is the largest value that
    # still leaves at least one step (start_timestep + 1 <= T - 1) to diffuse.
    max_start_timestep = noise_scheduler.config.num_train_timesteps - 2
    assert -1 <= start_timestep <= max_start_timestep, (
        f"start_timestep must be in the range [-1, {max_start_timestep}] "
        f"(inclusive), got {start_timestep}"
    )

    clean_images = batch["images"]
    device = clean_images.device

    # Bring the clean images to noise level `start_timestep`.
    if start_timestep == -1:
        noisy_images = clean_images
    else:
        noisy_images = noise_scheduler.add_noise(
            clean_images,
            torch.randn_like(clean_images),
            torch.tensor(start_timestep, device=device, dtype=torch.long),
        )

    # Diffuse from step start_timestep + 1 through the last step. For
    # start_timestep == -1 this is step 0, i.e. the full forward process.
    timestep = torch.tensor(start_timestep + 1, device=device, dtype=torch.long)
    epsilon = torch.randn_like(clean_images)
    pseudo_white_noise = add_noise_from_timestep(noisy_images, epsilon, timestep)

    # The model predicts the noisy images (not epsilon) from the pseudo white noise.
    predicted_noisy_images = model(pseudo_white_noise, timestep).sample

    # mse_loss(input=prediction, target=ground truth)
    return torch.nn.functional.mse_loss(predicted_noisy_images, noisy_images)


In [19]:
# train
model.train()

# Batches each process actually iterates per epoch (Fabric-wrapped loader).
batches_per_epoch = len(train_dataloader)
total_iterations = args.num_epochs * batches_per_epoch

# The context manager guarantees the progress bar is closed, even on error.
with tqdm(total=total_iterations, desc="Training") as progress_bar:
    for epoch in range(args.num_epochs):

        epoch_loss_sum = 0.0
        num_batches_seen = 0

        for batch in train_dataloader:

            optimizer.zero_grad()

            loss = train_step(model, batch, args.start_timestep)

            fabric.backward(loss)

            optimizer.step()

            lr_scheduler.step()

            # Single device-to-host sync per step.
            loss_value = loss.item()
            epoch_loss_sum += loss_value
            num_batches_seen += 1

            progress_bar.update(1)
            progress_bar.set_postfix({
                'epoch': f'{epoch+1}/{args.num_epochs}',
                'loss': f'{loss_value:.4f}',
                'lr': f'{lr_scheduler.get_last_lr()[0]:.6f}'
            })

        # Average over the batches actually processed; max() guards an empty loader.
        avg_train_loss_per_epoch = epoch_loss_sum / max(num_batches_seen, 1)
        progress_bar.write(
            f"Epoch {epoch+1}/{args.num_epochs}, Average train loss: {avg_train_loss_per_epoch:.4f}"
        )


Training:   0%|          | 0/630 [00:39<?, ?it/s]


UnboundLocalError: cannot access local variable 'alphas_reversed_cumprod' where it is not associated with a value