In [1]:
from dataclasses import dataclass
from tqdm.auto import tqdm
from PIL import Image
import torch
import torchvision

# TODO adapt these parameters such that they work for your setup
@dataclass
class TrainingConfig:
    """Hyper-parameters shared by the training and sampling cells.

    Every attribute carries a type annotation so ``@dataclass`` turns it into a real
    field with a default. Without annotations the decorator generated no fields, and
    ``TrainingConfig(train_batch_size=64)`` raised ``TypeError`` instead of overriding
    the default.
    """

    image_size: int = 28  # the generated image resolution
    num_channels: int = 1  # the number of channels in the generated image
    train_batch_size: int = 5
    eval_batch_size: int = 2  # how many images to sample during evaluation
    num_epochs: int = 10
    learning_rate: float = 1e-4
    output_dir: str = "samples"  # NOTE(review): not referenced by the cells shown; files go to the CWD


config = TrainingConfig()

# Number of discrete noise levels; time is normalized as t = k / num_train_timesteps in [0, 1].
num_train_timesteps = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
import torchvision.transforms as transforms
# MNIST training split as tensors; downloaded into ./mnist_data if it is not already there.
mnist_dataset = torchvision.datasets.MNIST(
    root='mnist_data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
# Shuffled mini-batches of (image, label) pairs, sized by the config cell.
train_dataloader = torch.utils.data.DataLoader(
    dataset=mnist_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=1,
)

In [3]:
def get_alpha(t):
    """Signal-retention coefficient of the linear noise schedule: alpha(t) = 1 - t.

    ``t`` is normalized time (0 keeps the clean image, 1 is pure noise). Works on
    Python numbers and elementwise on tensors.
    """
    signal_fraction = 1 - t
    return signal_fraction

In [7]:
def forward_diffusion(clean_x, noise, t):
    """Noise clean images to time ``t``: x_t = sqrt(a) * x_0 + sqrt(1 - a) * noise, a = get_alpha(t).

    Args:
        clean_x: clean image batch, presumably (B, C, H, W).
        noise: noise tensor with the same shape as ``clean_x``.
        t: normalized time in [0, 1] (0 returns ``clean_x``, 1 returns ``noise``). May be a
            Python float or 0-d tensor (one level for the whole batch), a (B,) tensor (one
            level per sample), or any tensor already broadcastable against ``clean_x``.

    Returns:
        The noisy batch, same shape as ``clean_x``.
    """
    # as_tensor also accepts Python floats, which the old `.to(...)` call could not.
    alpha = torch.as_tensor(get_alpha(t), dtype=clean_x.dtype, device=clean_x.device)
    if alpha.dim() == 1:
        # A per-sample (B,) schedule must line up with the batch axis, not the last (width)
        # axis, so reshape it to (B, 1, ..., 1) before broadcasting.
        alpha = alpha.reshape(-1, *([1] * (clean_x.dim() - 1)))
    return torch.sqrt(alpha) * clean_x + torch.sqrt(1 - alpha) * noise

In [9]:
def save_and_show(batch, name, nrow=1):
    """Tile ``batch`` into a grid, write it to ``name`` and return the saved image.

    Args:
        batch: image batch accepted by ``torchvision.utils.make_grid``
            (presumably (B, C, H, W) floats in [0, 1]).
        name: output path; the file format follows the extension.
        nrow: number of images per grid row (the default 1 stacks the batch vertically).

    Returns:
        The saved grid as an in-memory ``PIL.Image``. Leave the call as the last
        expression of a cell to display it.
    """
    image_grid = torchvision.utils.make_grid(batch, nrow=nrow)
    torchvision.utils.save_image(image_grid, name)
    # The previous version discarded Image.open(...), so nothing was ever shown. Copy the
    # pixels into memory so the file handle is closed before returning.
    with Image.open(name) as saved_image:
        return saved_image.copy()

# Visualize the forward process on one batch with a fixed noise draw:
# every 100th step, from the clean images (t = 0) to pure noise (t = 1).
sample_batch = next(iter(train_dataloader))[0]
noise = torch.randn_like(sample_batch)
save_and_show(sample_batch, 'original.png')
for i in range(0, num_train_timesteps + 1, 100):
    current_batch = forward_diffusion(sample_batch, noise, torch.tensor(i / num_train_timesteps))
    save_and_show(current_batch, f'forward_diffusion_{i}.png')

In [10]:
from diffusers import UNet2DModel

# Noise-prediction network: given a noisy batch and its time t, it predicts the added noise
# (same shape as the input; see the MSE-against-noise loss in the training loop).
model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=config.num_channels,  # 1 for grayscale MNIST, 3 for RGB images
    out_channels=config.num_channels,  # predicted noise has the same channels as the input
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128),  # the number of output channels for each UNet block
    down_block_types=(
        # Plain ResNet downsampling blocks: none of these uses self-attention
        # (AttnDownBlock2D would add it).
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        # Plain ResNet upsampling blocks, mirroring the down path
        # (AttnUpBlock2D would add self-attention).
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

In [None]:
import torch.nn.functional as F
import torch

# Adam over every UNet parameter, using the learning rate from the config cell.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=config.learning_rate,
)

def _save_denoising_snapshot(model, forward_diffusion, sample_batch, test_noise, device, tag):
    """Save ``sample_batch`` noised at t = 0.0, 0.1, ..., 1.0 and the model's one-shot x0 estimates."""
    batch_size = sample_batch.shape[0]
    noisy_rows = []
    reconstructed_rows = []
    was_training = model.training
    model.eval()
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for noise_level in range(11):
            t = torch.full((batch_size,), noise_level / 10, device=device)
            # (B, 1, 1, 1) view so each sample's alpha broadcasts over (C, H, W).
            t_img = t[:, None, None, None]
            alpha = get_alpha(t_img)
            noisy_images = forward_diffusion(sample_batch, test_noise, t_img)
            noise_pred = model(noisy_images, t, return_dict=False)[0]
            # Invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps for x0. a = 0 at t = 1, so the
            # denominator is clamped to keep that row finite instead of inf/NaN.
            x0_estimate = (noisy_images - torch.sqrt(1 - alpha) * noise_pred) / torch.sqrt(alpha.clamp(min=1e-6))
            noisy_rows.append(noisy_images)
            reconstructed_rows.append(x0_estimate)
    model.train(was_training)
    save_and_show(torch.cat(noisy_rows, 0), f'noisy_images_{tag}.png', nrow=batch_size)
    save_and_show(torch.cat(reconstructed_rows, 0), f'reconstruction_{tag}.png', nrow=batch_size)


def train_loop(config, model, forward_diffusion, optimizer, train_dataloader, device, num_train_timesteps):
    """Train ``model`` to predict the noise added by ``forward_diffusion``.

    Each step draws t = k / num_train_timesteps with k uniform in [1, num_train_timesteps - 1]
    per image, noises the batch, and minimizes the MSE between predicted and true noise.
    At a few fixed global steps, noised/reconstructed grids of one fixed batch are saved.

    Args:
        config: provides ``num_epochs``.
        model: noise-prediction network called as ``model(x, t, return_dict=False)[0]``.
        forward_diffusion: ``(clean_x, noise, t) -> noisy_x``.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: yields batches whose element 0 is the image tensor.
        device: device to train on.
        num_train_timesteps: number of discrete noise levels used to sample t.
    """
    model.to(device)
    model.train()
    global_step = 0
    sample_batch = next(iter(train_dataloader))[0].to(device)
    save_and_show(sample_batch, 'reconstruction_original.png')
    test_noise = torch.randn_like(sample_batch)
    snapshot_steps = {0, 100, 200, 1000, 5000, 10000, 20000}

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            if global_step in snapshot_steps:
                _save_denoising_snapshot(model, forward_diffusion, sample_batch, test_noise, device, f'{epoch}_{step}')

            clean_images = batch[0].to(device)
            noise = torch.randn_like(clean_images)
            bs = clean_images.shape[0]

            # One normalized time per image, in [1/T, (T-1)/T].
            # NOTE(review): the UNet receives t in (0, 1) rather than an integer step; if its
            # sinusoidal embedding expects 0..T, scale t here AND in the sampler together.
            t = torch.randint(1, num_train_timesteps, (bs,), device=device) / num_train_timesteps

            # Forward diffusion; the (B, 1, 1, 1) view aligns per-sample alphas with the batch axis.
            noisy_images = forward_diffusion(clean_images, noise, t[:, None, None, None])

            # Predict the noise residual.
            noise_pred = model(noisy_images, t, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.item(), step=global_step)
            global_step += 1

        progress_bar.close()

In [None]:
# Run training with the objects built in the cells above (model/optimizer state is updated in place).
train_loop(
    config,
    model,
    forward_diffusion,
    optimizer,
    train_dataloader,
    device=device,
    num_train_timesteps=num_train_timesteps,
)

In [None]:
# Persist the whole pickled module (not just its state_dict) so the next cell can reload it directly.
torch.save(obj=model, f='unet.pt')

In [None]:
# Reload the trained UNet. unet.pt stores a fully pickled nn.Module, so weights_only=False is
# required: since PyTorch 2.6 torch.load defaults to weights_only=True and refuses to unpickle
# the model class. Unpickling can execute arbitrary code -- only load checkpoints you created.
model = torch.load('unet.pt', map_location=torch.device('cpu'), weights_only=False).to(device)

In [None]:
def evaluate(config, img_name, reverse_diffusion_process, model, timesteps, device, train_timesteps):
    """Sample ``config.eval_batch_size`` images from Gaussian noise and save them as one grid.

    Args:
        config: provides ``eval_batch_size``, ``num_channels`` and ``image_size``.
        img_name: output path of the saved grid.
        reverse_diffusion_process: sampler called as ``(model, noise, timesteps, device)`` that
            returns the generated batch, e.g. ``reverse_diffusion_ddim``.
        model: noise-prediction network handed to the sampler.
        timesteps: number of sampling steps handed to the sampler.
        device: device the sampler runs on.
        train_timesteps: unused; kept so existing calls still work. The sampler operates on
            normalized time and takes no such argument (passing it raised TypeError).

    Returns:
        The generated image batch.
    """
    noise = torch.randn([config.eval_batch_size, config.num_channels, config.image_size, config.image_size])
    with torch.no_grad():  # pure inference: avoid building autograd graphs across all steps
        images = reverse_diffusion_process(model, noise, timesteps, device)
    image_grid = torchvision.utils.make_grid(images)
    torchvision.utils.save_image(image_grid, img_name)
    return images

In [None]:
def get_next_x(model, x_current, t_current, t_next):
    """Take one deterministic DDIM step from time ``t_current`` to the earlier time ``t_next``.

        x_next = sqrt(a_next) * (x_t / sqrt(a_t)
                                 + (sqrt((1 - a_next) / a_next) - sqrt((1 - a_t) / a_t)) * eps)

    where a = get_alpha(t) and eps is the model's noise prediction at (x_t, t_current).

    Args:
        model: noise-prediction network called as ``model(x, t, return_dict=False)[0]``.
        x_current: current noisy batch, presumably (B, C, H, W).
        t_current, t_next: normalized times; Python floats, 0-d tensors or (B,) tensors.
            ``t_current`` must be < 1 because get_alpha(1) = 0 appears in a denominator.

    Returns:
        The batch at time ``t_next``.
    """
    batch_size = x_current.shape[0]
    # Broadcast times to one value per sample. A Python float used to crash on the
    # [:, None, None, None] indexing below, and non-tensor timesteps are (as far as we know)
    # cast to long by diffusers' UNet2DModel, turning 0.99 into 0 -- verify for your version.
    t_current = torch.as_tensor(t_current, dtype=x_current.dtype, device=x_current.device).reshape(-1).expand(batch_size)
    t_next = torch.as_tensor(t_next, dtype=x_current.dtype, device=x_current.device).reshape(-1).expand(batch_size)

    alpha_current = get_alpha(t_current)[:, None, None, None]
    alpha_next = get_alpha(t_next)[:, None, None, None]
    with torch.no_grad():  # sampling only; equivalent to the former .detach() without the graph
        noise_pred_current = model(x_current, t_current, return_dict=False)[0]

    x_over_sqrt_alpha = x_current / alpha_current.sqrt()
    noise_coeff = ((1 - alpha_next) / alpha_next).sqrt() - ((1 - alpha_current) / alpha_current).sqrt()
    return alpha_next.sqrt() * (x_over_sqrt_alpha + noise_coeff * noise_pred_current)

In [None]:
def reverse_diffusion_ddim(model, noise, num_timesteps, device):
    """Generate images with the deterministic DDIM sampler, starting from ``noise``.

    Walks normalized time from (num_timesteps - 1) / num_timesteps down to 0 with
    ``num_timesteps - 1`` calls to ``get_next_x``. t = 1 is never used as the current time,
    which matters because get_alpha(1) = 0 and ``get_next_x`` divides by sqrt(alpha).
    About ten intermediate states, always including the final result, are saved as one
    collage to ``ddim_collage.png``.

    Args:
        model: noise-prediction network used by ``get_next_x`` (assumed to be on ``device``).
        noise: starting batch, presumably (B, C, H, W) standard Gaussian.
        num_timesteps: number of points discretizing [0, 1].
        device: device to sample on.

    Returns:
        The generated batch at t = 0.
    """
    current_image = noise.to(device)
    batch_size = current_image.shape[0]
    # Integer stride: the old float test `timestep % (num_timesteps / 10) == 0` could match
    # nothing (then torch.cat([]) crashed) and never captured the final image.
    snapshot_stride = max(1, num_timesteps // 10)
    snapshots = []

    with torch.no_grad():
        progress_bar = tqdm(range(num_timesteps - 1, 0, -1))
        for timestep in progress_bar:
            progress_bar.set_description(f"T {timestep}")
            # Per-sample tensor times: get_next_x indexes alpha as (B,) and the UNet needs a
            # tensor timestep (plain floats broke both).
            t_current = torch.full((batch_size,), timestep / num_timesteps, device=device)
            t_next = torch.full((batch_size,), (timestep - 1) / num_timesteps, device=device)
            current_image = get_next_x(model, current_image, t_current, t_next)
            if (timestep - 1) % snapshot_stride == 0:
                snapshots.append(current_image)

    if not snapshots:  # num_timesteps <= 1: no steps taken, show the input
        snapshots.append(current_image)
    save_and_show(torch.cat(snapshots, 0), 'ddim_collage.png', nrow=batch_size)
    return current_image