# Train Diffusion Model

In [None]:
# If you are using Google Colab, install the dependencies first:
# %pip install -U diffusers datasets transformers accelerate ftfy pyarrow wandb pandas numpy

Let's import the libraries we'll be using and define a few convenience functions which we'll use later in the notebook:

In [None]:
from argparse import Namespace

import numpy as np
import pandas as pd

from matplotlib import pyplot as plt

from PIL import Image

from diffusers import DDPMScheduler
from diffusers import UNet2DModel

import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

from torchvision import transforms
import torchvision

import wandb

In [None]:
SEED = 1

# Pick the best available accelerator: CUDA, then Apple-silicon MPS, then CPU,
# so the notebook runs unchanged on Colab/Linux GPUs, Macs, or CPU-only machines.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")

# Seed the RNGs the notebook draws from (torch on all devices, numpy) so noise,
# timestep sampling and DataLoader shuffling repeat for a given SEED.
torch.manual_seed(SEED)
np.random.seed(SEED)

# Hyperparameters and run metadata.
# NOTE(review): evaluation_strategy / save_strategy do not appear to be read by the
# training loop in this notebook — confirm before relying on them.
CONFIG = Namespace(
    run_name='diffusion-debug',
    evaluation_strategy='epoch',
    save_strategy='epoch',
    model_name='cat-dataset-10K-baseline-model',
    per_device_train_batch_size=8,
    num_train_epochs=50,
    learning_rate=4e-4,
    seed=SEED,
    num_train_timesteps=1000,
    beta_schedule='squaredcos_cap_v2',
    train_limit=10000,
    )

In [None]:
# NOTE: These are from the HuggingFace Diffusion Model Class (see https://github.com/huggingface/diffusion-models-class/blob/main/unit1/01_introduction_to_diffusers.ipynb)

def show_images(x):
    """
    Tile a batch of images into a single grid and return it as a PIL image.

    ``x`` is a batch of image tensors in the (-1, 1) range produced by the
    preprocessing pipeline. Values are mapped back to (0, 1), clipped, scaled
    to 0-255 and truncated to 8-bit pixels.
    """
    grid = torchvision.utils.make_grid(x * 0.5 + 0.5)  # (-1, 1) -> (0, 1)
    # Channels-first tensor -> channels-last array, as PIL expects.
    pixels = grid.detach().cpu().permute(1, 2, 0).clip(0, 1) * 255
    return Image.fromarray(np.array(pixels).astype(np.uint8))


def make_grid(images, size=64):
    """
    Paste a list of PIL images side by side into one horizontal RGB strip.

    Every image is resized to ``size`` x ``size`` pixels and placed left to
    right, so the result is ``size * len(images)`` wide and ``size`` tall.
    """
    strip = Image.new("RGB", (len(images) * size, size))
    x_offset = 0
    for picture in images:
        strip.paste(picture.resize((size, size)), (x_offset, 0))
        x_offset += size
    return strip

In [None]:
def download_data(run):
    """
    Fetch the cat-image table from the ``cat_dataset:v0`` W&B artifact.

    The artifact is requested through ``run.use_artifact``. The function returns
    the artifact's ``huggan_cat_dataset`` entry.
    """
    artifact = run.use_artifact('cat_dataset:v0')
    return artifact.get("huggan_cat_dataset")

def get_df(table, is_test=False):
    """
    Convert a W&B table into a pandas DataFrame.

    Rows come from ``table.data`` and column names from ``table.columns``.
    ``is_test`` is accepted for call compatibility but is not used.
    """
    return pd.DataFrame(data=table.data, columns=table.columns)

In [None]:
# Start a W&B training run. Passing the hyperparameters as the run config records
# them alongside the logged losses, so the run can be reproduced later.
RUN = wandb.init(
    project='Cat-Generator',
    entity=None,
    job_type="training",
    name=CONFIG.run_name,
    config=vars(CONFIG),
)

# Pull the cat dataset table from the W&B artifact and turn it into a DataFrame.
WAND_TABLE = download_data(RUN)
DATASET_DF = get_df(WAND_TABLE)

In [None]:
IMAGE_SIZE: int = 32  # side length (pixels) images are resized to and the UNet is built for

## Create Dataset

In [None]:
# Image pipeline: PIL image -> IMAGE_SIZE x IMAGE_SIZE float tensor scaled to (-1, 1).
# The horizontal flip only yields fresh augmentation if the pipeline is applied
# each time a sample is drawn.
PREPROCESS = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),   # augmentation: random mirror
    transforms.ToTensor(),               # PIL -> float tensor in (0, 1)
    transforms.Normalize([0.5], [0.5]),  # (0, 1) -> (-1, 1)
])

class CatDataset(Dataset):
    """In-memory dataset of preprocessed cat images.

    Every image goes through ``PREPROCESS`` once, at construction time, and the
    resulting tensors are cached. That pipeline's random horizontal flip would
    otherwise be frozen into the cache, giving each image the same orientation in
    every epoch. To keep the augmentation, ``__getitem__`` re-applies a random
    horizontal flip on every access.
    """

    def __init__(self, images: list, config) -> None:
        """
        Args:
            images: sequence of objects exposing a PIL image as ``.image``
                (presumably W&B ``Image`` cells from the dataset table — confirm).
            config: run configuration; stored as ``self.config``.
        """
        self.images = [PREPROCESS(item.image.convert("RGB")) for item in images]
        self.config = config

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        """Return ``{'images': tensor}`` for one sample, randomly mirrored."""
        image = self.images[index]
        # Flip along the last (width) axis with probability 0.5. Combined with the
        # flip already baked in by PREPROCESS, each access gets an independent,
        # uniformly random orientation. torch.flip returns a copy, so the cached
        # tensor is not modified.
        if torch.rand(1).item() < 0.5:
            image = torch.flip(image, dims=[-1])
        return {'images': image}

# Optionally cap the number of training rows; a non-positive train_limit keeps them all.
dataframe = (
    DATASET_DF.iloc[:CONFIG.train_limit]
    if CONFIG.train_limit > 0
    else DATASET_DF
)

dataset = CatDataset(dataframe.image.values, CONFIG)

In [None]:
# Shuffled mini-batches; each batch is a dict whose 'images' entry holds the stacked image tensors.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=CONFIG.per_device_train_batch_size,
    shuffle=True,
)

In [None]:
# Create UNet Model
# RGB in, RGB out (the predicted noise), two ResNet layers per block.
# The first two down stages are plain ResNet blocks; the deeper two add spatial
# self-attention. The up path mirrors this in reverse order.
DOWN_BLOCKS = ("DownBlock2D", "DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")
UP_BLOCKS = ("AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D", "UpBlock2D")

MODEL = UNet2DModel(
    sample_size=IMAGE_SIZE,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(32, 64, 64, 128),  # more channels -> more parameters
    down_block_types=DOWN_BLOCKS,
    up_block_types=UP_BLOCKS,
)
MODEL.to(DEVICE)

## Create a Training Loop

In [None]:
# Set the noise scheduler
noise_scheduler = DDPMScheduler(
    num_train_timesteps=CONFIG.num_train_timesteps,
    beta_schedule=CONFIG.beta_schedule,
)

# Training loop
optimizer = torch.optim.AdamW(MODEL.parameters(), lr=CONFIG.learning_rate)

MODEL.train()
num_steps = 0  # global optimizer-step counter, used as the W&B step
for epoch in range(CONFIG.num_train_epochs):
    epoch_loss = []
    for batch in dataloader:
        clean_images = batch["images"].to(DEVICE)
        bs = clean_images.shape[0]

        # Sample Gaussian noise with the same shape, dtype and device as the images
        noise = torch.randn_like(clean_images)

        # Sample a random timestep for each image. Read it from the scheduler's
        # config; direct attribute access is deprecated in diffusers.
        timesteps = torch.randint(
            0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
        )

        # Add noise to the clean images according to the noise magnitude at each timestep
        noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

        # The model is trained to predict the added noise
        noise_pred = MODEL(noisy_images, timesteps, return_dict=False)[0]

        # Calculate the loss and backpropagate. backward() takes no argument for
        # a scalar loss; passing `loss` as the `gradient` argument would scale
        # every gradient by the current loss value.
        loss = F.mse_loss(noise_pred, noise)
        loss.backward()

        # Update the model parameters with the optimizer
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()  # single host sync per step
        epoch_loss.append(loss_value)
        RUN.log({'loss-1': loss_value}, commit=False, step=num_steps)
        num_steps += 1

    # Guard against an empty dataloader (e.g. train_limit selecting no rows)
    if epoch_loss:
        RUN.log({'epoch-loss': sum(epoch_loss) / len(epoch_loss)})

The per-step (`loss-1`) and per-epoch (`epoch-loss`) losses are logged to W&B. When you plot them there, you should see the model improve rapidly at first and then keep getting better at a slower rate. The slower improvement is easier to see with a log-scale y-axis.

## Check model training

In [None]:
# Start from pure Gaussian noise for 8 images and denoise them step by step.
sample = torch.randn(8, 3, IMAGE_SIZE, IMAGE_SIZE).to(DEVICE)

with torch.no_grad():
    for t in noise_scheduler.timesteps:
        # The model was trained to predict the noise present at timestep t
        residual = MODEL(sample, t).sample
        # One reverse-diffusion step: x_t -> x_{t-1}
        sample = noise_scheduler.step(residual, t, sample).prev_sample

# Display the resulting batch as one image grid (nearest-neighbour keeps pixels crisp).
show_images(sample).resize((8 * IMAGE_SIZE, IMAGE_SIZE), resample=Image.NEAREST)

## Save model to W&Bs

In [None]:
# Save the trained weights locally, then upload them to W&B as a model artifact.
torch.save(MODEL.state_dict(), 'model.pt')

model_art = wandb.Artifact(CONFIG.model_name, type='model')
model_art.add_file('model.pt')
RUN.log_artifact(model_art)

In [None]:
# Mark the W&B run started at the top of the notebook as finished.
RUN.finish()

# Scaling up with 🤗 Accelerate

This notebook was made for learning purposes, so we kept the code as minimal and clean as possible. As a result, we omitted several things you might want when training a larger model on much more data, such as multi-GPU support, logging of progress and example images, gradient checkpointing to support larger batch sizes, and automatic uploading of models. Thankfully, most of these features are available in the example training script [here](https://github.com/huggingface/diffusers/raw/main/examples/unconditional_image_generation/train_unconditional.py).

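You can download the file like so (run in a terminal, or prefix with `!` in a notebook cell):

`wget https://github.com/huggingface/diffusers/raw/main/examples/unconditional_image_generation/train_unconditional.py`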

**Exercise:** See if you can find training/model settings that give good results in as little time as possible, and share your findings with the community. Dig around in the script to see if you can understand the code, and ask for clarification on anything that looks confusing.

# Avenues for Further Exploration

Hopefully this has given you a taste of what you can do with the 🤗 Diffusers library! Some possible next steps:

- Try training an unconditional diffusion model on a new dataset - bonus points if you [create one yourself](https://huggingface.co/docs/datasets/image_dataset). You can find some great image datasets for this task in the [HugGAN organization](https://huggingface.co/huggan) on the Hub. Just make sure you downsample them if you don't want to wait a very long time for the model to train!
- Try out DreamBooth to create your own customized Stable Diffusion pipeline using either [this Space](https://huggingface.co/spaces/multimodalart/dreambooth-training) or [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
- Modify the training script to explore different UNet hyperparameters (number of layers, channels etc), different noise schedules etc.
- Check out the [Diffusion Models from Scratch](https://github.com/huggingface/diffusion-models-class/blob/main/unit1/02_diffusion_models_from_scratch.ipynb) notebook for a different take on the core ideas we've covered in this unit

Good luck, and stay tuned for Unit 2!