In this assignment we implement a simple diffusion model that generates MNIST digits with the help of Hugging Face's Diffusers library. We will use the MNIST dataset to train a simple model and then use the trained model to generate new images. We will also use the Diffusers library to generate images from a random noise vector.

In [None]:
# install the hugging face diffusers library
!pip install diffusers

Now we create a training configuration for the model.

In [None]:
from dataclasses import dataclass

# TODO adapt these parameters such that they work for your setup
@dataclass
class TrainingConfig:
    """Hyper-parameters shared by the training and sampling cells.

    Every attribute carries a type annotation so that ``@dataclass`` turns it
    into a real field: the values appear in ``repr``, take part in ``==`` and
    can be overridden per run, e.g. ``TrainingConfig(num_epochs=1)``.
    """

    image_size: int = 32  # the generated image resolution
    num_channels: int = 1  # the number of channels in the generated image
    train_batch_size: int = 5
    eval_batch_size: int = 5  # how many images to sample during evaluation
    num_epochs: int = 50
    learning_rate: float = 1e-4
    output_dir: str = "samples"  # NOTE(review): not read by any of the cells shown here


config = TrainingConfig()

Next, we create the MNIST dataset and a dataloader for it.

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
# MNIST digits are 28x28, but the model is built for and sampled at
# config.image_size. Resizing here keeps the training resolution equal to the
# resolution the sampling cells generate. ToTensor scales pixels to [0, 1].
mnist_transform = transforms.Compose([
    transforms.Resize(config.image_size),
    transforms.ToTensor(),
])
mnist_dataset = torchvision.datasets.MNIST(root='mnist_data', train=True, download=True, transform=mnist_transform)
train_dataloader = torch.utils.data.DataLoader(mnist_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=1)

Moreover, we create the actual U-Net model:

In [None]:
from diffusers import UNet2DModel

# Noise-prediction U-Net. Its output is compared with the sampled noise by an
# MSE loss in the training loop, so input and output channel counts are equal.
model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=config.num_channels,  # input channels (1 for grayscale MNIST, 3 for RGB)
    out_channels=config.num_channels,  # predicted noise has as many channels as the input
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128),  # the number of output channels for each UNet block
    down_block_types=(
        # Plain ResNet downsampling blocks without attention
        # (use "AttnDownBlock2D" for spatial self-attention).
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        # Plain ResNet upsampling blocks without attention
        # (use "AttnUpBlock2D" for spatial self-attention).
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

Now we create the main train loop:

In [None]:
from tqdm.auto import tqdm
import os
import torch.nn.functional as F
import torch
from PIL import Image

# Length of the noise schedule; train_loop samples training timesteps from [0, num_train_timesteps).
num_train_timesteps = 1000
# NOTE(review): this module-level tensor is not read by train_loop, which
# assigns its own per-batch `timesteps`.
timesteps = torch.tensor([50], dtype=torch.long)

optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)

def train_loop(config, model, forward_diffusion, optimizer, train_dataloader):
    """Train ``model`` to predict the noise that ``forward_diffusion`` adds.

    For each batch, one timestep in ``[0, num_train_timesteps)`` is drawn per
    image. The clean images are noised with ``forward_diffusion``, and the model
    is optimised with the MSE between its prediction and the true noise.

    Args:
        config: Provides ``num_epochs``.
        model: ``torch.nn.Module`` called as ``model(x, timesteps, return_dict=False)[0]``.
        forward_diffusion: Callable ``(clean_images, noise, timesteps) -> noisy_images``.
        optimizer: Optimizer over ``model``'s parameters.
        train_dataloader: Sized iterable of batches whose first element is the image tensor.
    """
    # Move every batch to the model's device so training also works after model.to("cuda").
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    model.train()

    global_step = 0
    for epoch in range(config.num_epochs):
        # The context manager closes the bar at the end of every epoch.
        with tqdm(total=len(train_dataloader)) as progress_bar:
            progress_bar.set_description(f"Epoch {epoch}")

            for batch in train_dataloader:
                clean_images = batch[0].to(device)
                # Target noise with the images' shape, dtype and device.
                noise = torch.randn_like(clean_images)

                # One random training timestep per image
                # (num_train_timesteps is the notebook-level schedule length).
                timesteps = torch.randint(
                    0, num_train_timesteps, (clean_images.shape[0],), device=device
                )

                # Forward diffusion: noise each image to its sampled timestep.
                noisy_images = forward_diffusion(clean_images, noise, timesteps)

                # Predict the noise residual and regress it onto the true noise.
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss.item(), step=global_step)
                global_step += 1

Next you will implement the forward diffusion process:

In [None]:
def forward_diffusion(clean_images, noise, timesteps, num_train_timesteps=1000,
                      beta_start=1e-4, beta_end=0.02):
    """Noise ``clean_images`` to the given timesteps (DDPM forward process).

    Uses the closed form of q(x_t | x_0) with a linear beta schedule:

        alpha_bar_t = prod_{s <= t} (1 - beta_s)
        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise

    Args:
        clean_images: Batch of clean images, batch dimension first.
        noise: Gaussian noise with the same shape as ``clean_images``.
        timesteps: Integer timestep(s) in ``[0, num_train_timesteps)``, one per
            image, or a single scalar applied to the whole batch.
        num_train_timesteps: Length of the noise schedule.
        beta_start: First beta of the linear schedule.
        beta_end: Last beta of the linear schedule. The defaults follow
            Ho et al. (2020) and must match the schedule used for sampling.

    Returns:
        The noisy images ``x_t``, with the shape, dtype and device of ``clean_images``.
    """
    betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float64)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

    t = torch.as_tensor(timesteps, dtype=torch.long).cpu()
    alpha_bar = alphas_cumprod[t].to(device=clean_images.device, dtype=clean_images.dtype)
    # Shape (batch, 1, 1, ...) so the per-image factor broadcasts over all other dims.
    alpha_bar = alpha_bar.reshape(-1, *([1] * (clean_images.dim() - 1)))

    return alpha_bar.sqrt() * clean_images + (1.0 - alpha_bar).sqrt() * noise

Now we can train the model:

In [None]:
# Run the full training loop with the objects created in the previous cells.
train_loop(
    config=config,
    model=model,
    forward_diffusion=forward_diffusion,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
)

Now we create a helper function for evaluating the model:

In [None]:
def evaluate(config, img_name, reverse_diffusion_process, model, timesteps):
    """Sample a batch of images from pure noise and save them as one grid image.

    Args:
        config: Provides ``eval_batch_size``, ``num_channels`` and ``image_size``.
        img_name: Path of the image file to write.
        reverse_diffusion_process: Callable ``(model, noise, timesteps) -> images``.
        model: Noise-prediction model handed to the sampler.
        timesteps: Number of denoising steps handed to the sampler.
    """
    # Create the starting noise on the model's device so sampling also works on a GPU.
    first_param = next(model.parameters(), None) if hasattr(model, "parameters") else None
    device = first_param.device if first_param is not None else torch.device("cpu")
    noise = torch.randn(
        [config.eval_batch_size, config.num_channels, config.image_size, config.image_size],
        device=device,
    )

    # Sampling needs no gradients; without no_grad a sampler that calls the
    # model once per step would keep an autograd graph across every step.
    with torch.no_grad():
        images = reverse_diffusion_process(model, noise, timesteps)

    # Tile the batch into a single image and write it to disk.
    image_grid = torchvision.utils.make_grid(images.detach().cpu())
    torchvision.utils.save_image(image_grid, img_name)

Now you will implement the reverse diffusion process with DDPM:

In [None]:
def reverse_diffusion_ddpm(model, noise, num_timesteps, num_train_timesteps=1000,
                           beta_start=1e-4, beta_end=0.02):
    """Generate images from ``noise`` with ancestral DDPM sampling (Ho et al., 2020).

    The sampler visits ``num_timesteps`` evenly spaced training timesteps, from
    ``num_train_timesteps - 1`` down to 0 (every timestep when the two counts are
    equal). At each step the model's noise prediction gives the posterior mean.
    Fresh Gaussian noise with the posterior variance is added on every step
    except the last.

    Args:
        model: Noise-prediction network called as ``model(x, t, return_dict=False)[0]``,
            where ``t`` is a long tensor of shape ``(batch,)``.
        noise: Starting sample x_T, batch dimension first.
        num_timesteps: Number of denoising steps, ``1 <= num_timesteps <= num_train_timesteps``.
        num_train_timesteps: Length of the training noise schedule.
        beta_start: First beta of the linear training schedule.
        beta_end: Last beta of the linear training schedule. Must match the
            schedule the model was trained with.

    Returns:
        The generated images, with the same shape as ``noise``.

    Raises:
        ValueError: If ``num_timesteps`` is outside ``[1, num_train_timesteps]``.
    """
    num_timesteps = int(num_timesteps)
    if not 1 <= num_timesteps <= num_train_timesteps:
        raise ValueError(
            f"num_timesteps must be in [1, {num_train_timesteps}], got {num_timesteps}"
        )

    betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float64)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
    # Descending, evenly spaced training timesteps (distinct because num_timesteps <= T).
    step_ts = (
        torch.linspace(num_train_timesteps - 1, 0, num_timesteps, dtype=torch.float64)
        .round()
        .long()
        .tolist()
    )

    x = noise
    with torch.no_grad():
        for i, t in enumerate(step_ts):
            is_last = i + 1 == len(step_ts)
            alpha_bar_t = alphas_cumprod[t]
            alpha_bar_prev = 1.0 if is_last else alphas_cumprod[step_ts[i + 1]]
            # Effective beta between two visited timesteps; equals beta_t when no step is skipped.
            beta_t = 1.0 - alpha_bar_t / alpha_bar_prev

            t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
            eps = model(x, t_batch, return_dict=False)[0]

            # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(1 - beta_t).
            x = (x - beta_t / (1.0 - alpha_bar_t) ** 0.5 * eps) / (1.0 - beta_t) ** 0.5
            if not is_last:
                # Posterior variance beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t).
                sigma = (beta_t * (1.0 - alpha_bar_prev) / (1.0 - alpha_bar_t)) ** 0.5
                x = x + sigma * torch.randn_like(x)
    return x


And sample from it:

In [None]:
# Use every training timestep for DDPM sampling.
num_inference_timesteps = num_train_timesteps
evaluate(
    config,
    img_name='ddpm.png',
    reverse_diffusion_process=reverse_diffusion_ddpm,
    model=model,
    timesteps=num_inference_timesteps,
)
Image.open('ddpm.png')

# The generated MNIST digits are displayed here if the DDPM sampler is implemented correctly!

Now you will implement the reverse diffusion process with DDIM:

In [None]:
def reverse_diffusion_ddim(model, noise, num_timesteps, num_train_timesteps=1000,
                           beta_start=1e-4, beta_end=0.02):
    """Generate images from ``noise`` with deterministic DDIM sampling (Song et al., 2021, eta = 0).

    The sampler visits ``num_timesteps`` evenly spaced training timesteps, from
    ``num_train_timesteps - 1`` down to 0. At each step it predicts x_0 from the
    model's noise estimate, then jumps to the next visited timestep along the
    same noise direction without adding fresh noise. The same ``noise`` therefore
    always yields the same images.

    Args:
        model: Noise-prediction network called as ``model(x, t, return_dict=False)[0]``,
            where ``t`` is a long tensor of shape ``(batch,)``.
        noise: Starting sample x_T, batch dimension first.
        num_timesteps: Number of denoising steps, ``1 <= num_timesteps <= num_train_timesteps``.
        num_train_timesteps: Length of the training noise schedule.
        beta_start: First beta of the linear training schedule.
        beta_end: Last beta of the linear training schedule. Must match the
            schedule the model was trained with.

    Returns:
        The generated images, with the same shape as ``noise``.

    Raises:
        ValueError: If ``num_timesteps`` is outside ``[1, num_train_timesteps]``.
    """
    num_timesteps = int(num_timesteps)
    if not 1 <= num_timesteps <= num_train_timesteps:
        raise ValueError(
            f"num_timesteps must be in [1, {num_train_timesteps}], got {num_timesteps}"
        )

    betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float64)
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0).tolist()
    # Descending, evenly spaced training timesteps (distinct because num_timesteps <= T).
    step_ts = (
        torch.linspace(num_train_timesteps - 1, 0, num_timesteps, dtype=torch.float64)
        .round()
        .long()
        .tolist()
    )

    x = noise
    with torch.no_grad():
        for i, t in enumerate(step_ts):
            alpha_bar_t = alphas_cumprod[t]
            # After the final visited timestep the target is the clean image (alpha_bar = 1).
            alpha_bar_prev = alphas_cumprod[step_ts[i + 1]] if i + 1 < len(step_ts) else 1.0

            t_batch = torch.full((x.shape[0],), t, dtype=torch.long, device=x.device)
            eps = model(x, t_batch, return_dict=False)[0]

            # Predicted clean image, then a deterministic step to the previous timestep.
            x0_pred = (x - (1.0 - alpha_bar_t) ** 0.5 * eps) / alpha_bar_t ** 0.5
            x = alpha_bar_prev ** 0.5 * x0_pred + (1.0 - alpha_bar_prev) ** 0.5 * eps
    return x

And evaluate with it:

In [None]:
# Use every training timestep for DDIM sampling as well.
num_inference_timesteps = num_train_timesteps
evaluate(
    config,
    img_name='ddim.png',
    reverse_diffusion_process=reverse_diffusion_ddim,
    model=model,
    timesteps=num_inference_timesteps,
)
Image.open('ddim.png')

# The generated MNIST digits are displayed here if the DDIM sampler is implemented correctly!

In [None]:
# Bonus exercise:
# Compare different values of num_inference_timesteps and the two samplers (DDPM vs. DDIM), and discuss the results.
