In [20]:
import torch
from PIL import Image
from diffusers import DDPMScheduler
from datasets import load_dataset
from torchvision import transforms
import matplotlib.pyplot as plt
from diffusers import UNet2DModel
from diffusers.optimization import get_cosine_schedule_with_warmup
from TL.generation import TemperingGeneration

In [19]:
# --- Run configuration ------------------------------------------------------
# Seed torch's global RNG so that steps drawing from it (e.g. DataLoader
# shuffling, RandomHorizontalFlip) repeat across "Restart Kernel -> Run All".
seed = 0
torch.manual_seed(seed)

image_size = 32  # side length images are resized to; also the UNet sample_size
train_batch_size = 16  # images per DataLoader batch
T = 500  # number of training timesteps given to the DDPM noise scheduler
# NOTE(review): the epoch count is tied to the timestep count (T + 1);
# presumably required by the tempering procedure -- confirm against TL.generation.
num_epochs = T + 1
learning_rate = 1e-4  # learning rate passed to AdamW
lr_warmup_steps = 500  # warmup steps of the cosine learning-rate schedule

In [8]:
# Load the "train" split of the Smithsonian butterflies subset, keeping the
# downloaded files under ./cache.
dataset = load_dataset(
    path="huggan/smithsonian_butterflies_subset",
    cache_dir="./cache",
    split="train",
)

# Per-image pipeline: resize to image_size x image_size, randomly flip
# horizontally, convert to a tensor, then normalize with mean 0.5 / std 0.5.
preprocess = transforms.Compose([
    transforms.Resize(size=(image_size, image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

def transform(examples):
    """Turn a batch of dataset examples into preprocessed image tensors.

    Every entry of ``examples["image"]`` is converted to RGB and run through
    ``preprocess``; the results are returned as a list under the ``"images"`` key.
    """
    images = []
    for image in examples["image"]:
        rgb_image = image.convert("RGB")
        images.append(preprocess(rgb_image))
    return {"images": images}

# Register `transform` so it is applied on the fly when examples are accessed.
dataset.set_transform(transform)

# First preprocessed example, with a leading dimension of size 1 added.
sample_image = torch.unsqueeze(dataset[0]["images"], dim=0)

Repo card metadata block was not found. Setting CardData to empty.


In [14]:
# Shuffled mini-batches of train_batch_size examples; drop_last=True discards
# a final incomplete batch so every batch has the same size.
train_dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    drop_last=True,
    batch_size=train_batch_size,
)

In [13]:
# UNet with six blocks on each side; spatial self-attention is used only in
# the fifth down block and the second up block.
model = UNet2DModel(
    sample_size=image_size,  # target image resolution
    in_channels=3,  # RGB input
    out_channels=3,  # number of output channels
    layers_per_block=2,  # ResNet layers per UNet block
    block_out_channels=(32, 32, 64, 64, 128, 128),  # output channels of each UNet block
    # Four regular ResNet downsampling blocks, one downsampling block with
    # spatial self-attention, then one more regular block.
    down_block_types=("DownBlock2D",) * 4 + ("AttnDownBlock2D", "DownBlock2D"),
    # One regular ResNet upsampling block, one upsampling block with spatial
    # self-attention, then four more regular blocks.
    up_block_types=("UpBlock2D", "AttnUpBlock2D") + ("UpBlock2D",) * 4,
)

# Total number of scalar entries across all of the model's parameter tensors.
num_params = sum(map(torch.numel, model.parameters()))
print(f"Number of parameters: {num_params}")

Number of parameters: 7126275


In [16]:
# DDPM noise scheduler configured with T training timesteps.
noise_scheduler = DDPMScheduler(num_train_timesteps=T)

In [21]:
# AdamW over every model parameter at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

# Cosine learning-rate schedule with warmup; the schedule length is the number
# of batches in the dataloader multiplied by the number of epochs.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=lr_warmup_steps,
    num_training_steps=num_epochs * len(train_dataloader),
)

In [None]:
# Hand the model, optimizer, LR scheduler and noise scheduler to the project's
# TemperingGeneration (imported from TL.generation). Only the number of batches
# in train_dataloader is passed here, not the dataloader itself.
# NOTE(review): the result is not bound to a name; presumably the call itself
# starts the run, or the cell only displays the object -- confirm against
# TL.generation.
TemperingGeneration(
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    noise_scheduler=noise_scheduler,
    num_training_batches=len(train_dataloader),
)