In [None]:
from dataclasses import dataclass

import torch
import torchvision
from PIL import Image
from tqdm.auto import tqdm

# TODO adapt these parameters such that they work for your setup
@dataclass
class TrainingConfig:
    """Hyper-parameters for training and sampling the diffusion model.

    Every field carries a type annotation so that ``@dataclass`` registers it
    as a real field; without annotations they were plain class attributes and
    ``TrainingConfig(num_epochs=3)`` raised ``TypeError``.
    """

    x_size: int = 28  # side length (pixels) of the square generated images
    num_channels: int = 1  # number of channels of the generated images
    train_batch_size: int = 5
    eval_batch_size: int = 2  # how many images to sample during evaluation
    num_epochs: int = 10
    learning_rate: float = 1e-4
    output_dir: str = "samples"  # NOTE(review): not referenced by any visible cell


config = TrainingConfig()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import torchvision.transforms as transforms
# MNIST training split; ToTensor yields float tensors scaled to [0, 1].
mnist_dataset = torchvision.datasets.MNIST(
    root='datasets/mnist',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
# Shuffled mini-batches of (images, labels) for training.
train_dataloader = torch.utils.data.DataLoader(
    mnist_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=1,
)

In [None]:
def get_alpha(t):
    """Return the signal coefficient alpha(t) of the linear noise schedule.

    alpha(t) = 1 - t, so t = 0 keeps the clean signal (alpha = 1) and t = 1
    leaves only noise (alpha = 0). Works element-wise on numbers and tensors.
    """
    signal_fraction = 1 - t
    return signal_fraction

In [None]:
def forward_diffusion(clean_x, noise, t):
    """Noise a clean batch to timestep ``t`` in closed form.

    Returns sqrt(alpha) * clean_x + sqrt(1 - alpha) * noise with
    alpha = get_alpha(t).

    Args:
        clean_x: batch of clean samples (batch dimension first).
        noise: tensor broadcast-compatible with ``clean_x``.
        t: tensor of timesteps, either a scalar or one value per sample;
           callers in this notebook pass values in [0, 1].
    Returns:
        The noisy batch.
    """
    alpha = get_alpha(t).to(clean_x.device)
    # Append singleton dims so alpha broadcasts over every non-batch axis.
    trailing_ones = [1] * (clean_x.dim() - 1)
    alpha = alpha.reshape(*alpha.shape, *trailing_ones)

    signal_scale = alpha.sqrt()
    noise_scale = (1 - alpha).sqrt()
    return signal_scale * clean_x + noise_scale * noise

In [None]:
def save_and_show(batch, name, nrow=1):
    """Tile ``batch`` into an image grid, write it to ``name`` and display it.

    Args:
        batch: images to tile, e.g. a (N, C, H, W) tensor.
        name: path of the image file to write.
        nrow: number of images per grid row.
    """
    # Imported locally so this helper does not depend on the notebook's
    # top-level imports.
    from PIL import Image

    grid = torchvision.utils.make_grid(batch, nrow=nrow)
    # torchvision's grid writer is `save_image` (there is no `save_x`).
    torchvision.utils.save_image(grid, name)
    # `display` is the IPython/Jupyter rich-display builtin.
    display(Image.open(name))

# Visualise the forward process: one grid row per noise level t = 0.0, 0.1, ..., 1.0,
# one column per image of a single training batch (same noise in every row).
sample_batch = next(iter(train_dataloader))[0]
noise = torch.randn_like(sample_batch)
noise_levels = [
    forward_diffusion(sample_batch, noise, torch.tensor(level / 10))
    for level in range(11)
]

save_and_show(torch.cat(noise_levels), f'forward_diffusion.png', nrow=noise_levels[0].shape[0])

In [None]:
from diffusers import UNet2DModel

# Noise-prediction UNet. Channel counts come from the config so that they stay
# consistent with the noise shape used by `evaluate` (config.num_channels).
model = UNet2DModel(
    sample_size=config.x_size,  # the target image resolution
    in_channels=config.num_channels,  # number of input channels
    out_channels=config.num_channels,  # predicted noise must match the input's channels (MSE vs. noise)
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(64, 128, 128),  # the number of output channels for each UNet block
    down_block_types=(
        "DownBlock2D",  # regular ResNet downsampling block (no attention)
        "DownBlock2D",  # regular ResNet downsampling block (no attention)
        "DownBlock2D",  # regular ResNet downsampling block (no attention)
    ),
    up_block_types=(
        "UpBlock2D",  # regular ResNet upsampling block (no attention)
        "UpBlock2D",  # regular ResNet upsampling block (no attention)
        "UpBlock2D",  # regular ResNet upsampling block (no attention)
    ),
)

In [None]:
import torch.nn.functional as F
import torch

# Adam over all UNet parameters with the configured learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=config.learning_rate)

def train_loop(config, model, forward_diffusion, optimizer, train_dataloader, device,
               log_steps=(0, 100, 200, 1000, 5000, 10000, 20000)):
    """Train ``model`` to predict the noise added by ``forward_diffusion``.

    For every batch one timestep t in (0, 1] is drawn per sample, the clean
    images are noised with ``forward_diffusion`` and the model is optimised
    with the MSE between predicted and true noise. At each global step listed
    in ``log_steps`` a grid of x0-reconstructions of a fixed batch is saved
    and displayed.

    Args:
        config: provides ``num_epochs``.
        model: noise predictor called as ``model(x, t, return_dict=False)``.
        forward_diffusion: callable ``(clean_x, noise, t) -> noisy_x``.
        optimizer: optimizer over ``model``'s parameters.
        train_dataloader: yields batches whose first element is the images.
        device: device to train on.
        log_steps: global step indices at which reconstructions are saved.
    """
    model.to(device)
    log_steps = frozenset(log_steps)
    global_step = 0
    # Fixed batch and fixed noise so reconstructions are comparable over time.
    sample_batch = next(iter(train_dataloader))[0].to(device)
    test_noise = torch.randn_like(sample_batch)

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader))
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            if global_step in log_steps:
                _save_reconstructions(
                    model, forward_diffusion, sample_batch, test_noise,
                    f'reconstruction_{epoch}_{step}.png',
                )

            clean_xs = batch[0].to(device)
            noise = torch.randn_like(clean_xs)
            # One timestep per sample: 1 - U[0, 1) lies in (0, 1].
            t = (1.0 - torch.rand(clean_xs.shape[0])).to(device)

            # Forward diffusion, then regress the model onto the added noise.
            noisy_xs = forward_diffusion(clean_xs, noise, t)
            noise_pred = model(noisy_xs, t, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)

            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            progress_bar.update(1)
            progress_bar.set_postfix(loss=loss.detach().item(), step=global_step)
            global_step += 1

        # Release the per-epoch bar instead of leaving it dangling.
        progress_bar.close()


def _save_reconstructions(model, forward_diffusion, clean_x, noise, name):
    """Save one grid row of x0-estimates per noise level t = 0.0, 0.1, ..., 1.0."""
    reconstructions = []
    # Visualisation only: no autograd graph is needed.
    with torch.no_grad():
        for level in range(11):
            t = torch.tensor(level / 10)
            alpha = get_alpha(t).to(clean_x)
            noisy_x = forward_diffusion(clean_x, noise, t)
            noise_pred = model(noisy_x, t.to(clean_x.device), return_dict=False)[0]
            # Invert x_t = sqrt(a) * x0 + sqrt(1 - a) * eps for x0. At t = 1 alpha
            # is 0, so clamp the divisor to avoid inf/NaN pixels in the last row.
            x0_estimate = (noisy_x - torch.sqrt(1 - alpha) * noise_pred) / torch.sqrt(alpha).clamp_min(1e-3)
            reconstructions.append(x0_estimate)
    save_and_show(torch.cat(reconstructions, 0), name, nrow=clean_x.shape[0])

In [None]:
# Run the full training procedure with the objects defined above.
train_loop(
    config=config,
    model=model,
    forward_diffusion=forward_diffusion,
    optimizer=optimizer,
    train_dataloader=train_dataloader,
    device=device,
)

In [None]:
# Persist the whole pickled module (architecture + weights), not just a state_dict.
torch.save(obj=model, f='unet.pt')

In [None]:
# Reload the fully pickled module. Since PyTorch 2.6 `torch.load` defaults to
# weights_only=True, which refuses to unpickle an nn.Module, so opt out
# explicitly. Only do this for checkpoints you created yourself: unpickling
# can execute arbitrary code.
model = torch.load('unet.pt', map_location=torch.device('cpu'), weights_only=False).to(device)

In [None]:
def evaluate(config, img_name, reverse_diffusion_process, model, timesteps, device, train_timesteps):
    """Sample ``config.eval_batch_size`` images from Gaussian noise and save a grid.

    Args:
        config: provides ``eval_batch_size``, ``num_channels`` and ``x_size``.
        img_name: path of the image file to write.
        reverse_diffusion_process: sampler called as
            ``reverse_diffusion_process(model, noise, timesteps, device, train_timesteps)``.
            NOTE(review): the ``reverse_diffusion_ddim`` variants in this notebook
            take at most four positional arguments, so they cannot be passed here
            unchanged.
        model, timesteps, device, train_timesteps: forwarded to the sampler.
    Returns:
        The samples produced by ``reverse_diffusion_process``.
    """
    # Start from standard Gaussian noise shaped like a batch of images.
    noise = torch.randn([config.eval_batch_size, config.num_channels, config.x_size, config.x_size])
    xs = reverse_diffusion_process(model, noise, timesteps, device, train_timesteps)

    x_grid = torchvision.utils.make_grid(xs)
    # torchvision's grid writer is `save_image` (there is no `save_x`).
    torchvision.utils.save_image(x_grid, img_name)
    return xs

In [None]:
def get_next_x(model, x_current, t_current, t_next):
    """Take one deterministic DDIM step from time ``t_current`` to ``t_next``.

    The model's noise prediction eps at ``t_current`` gives the clean estimate
    x0 = (x - sqrt(1 - a_cur) * eps) / sqrt(a_cur), which is re-noised to
    ``t_next``: x_next = sqrt(a_next) * x0 + sqrt(1 - a_next) * eps.
    Used both for denoising (t_next < t_current) and inversion (t_next > t_current).

    Args:
        model: noise predictor called as ``model(x, t, return_dict=False)``.
        x_current: current batch (batch dimension first).
        t_current, t_next: Python floats or tensors (scalar or one value per sample).
    Returns:
        The batch at ``t_next``, with no autograd graph attached.
    """
    # as_tensor avoids the copy-construct warning when a tensor is passed in.
    alpha_current = get_alpha(torch.as_tensor(t_current)).to(x_current)
    alpha_next = get_alpha(torch.as_tensor(t_next)).to(x_current)
    # Append singleton dims so the alphas broadcast over the non-batch axes.
    for _ in range(x_current.dim() - 1):
        alpha_current = alpha_current.unsqueeze(-1)
        alpha_next = alpha_next.unsqueeze(-1)

    # Sampling is pure inference: do not build an autograd graph per step.
    with torch.no_grad():
        noise_pred_current = model(x_current, t_current, return_dict=False)[0]

    x0_estimate = (x_current - torch.sqrt(1 - alpha_current) * noise_pred_current) / torch.sqrt(alpha_current)
    return torch.sqrt(alpha_next) * x0_estimate + torch.sqrt(1 - alpha_next) * noise_pred_current

In [None]:
def reverse_diffusion_ddim(model, noise, num_timesteps = 100):
    """Deterministically denoise ``noise`` with DDIM steps and save progress grids.

    NOTE(review): a later cell redefines ``reverse_diffusion_ddim`` with an extra
    ``device`` parameter; after Run All, that later definition is the one in effect.

    Starts at t = (num_timesteps - 1) / num_timesteps and calls ``get_next_x``
    down to t = 0. Every ``num_timesteps / 10`` steps the current sample and
    the model's x0 estimate are recorded. Together with the final sample, they
    are saved to ``z_to_x.png`` and ``z_to_x2.png``.

    Returns:
        The final denoised batch.
    """
    noisy_xs_list = []
    x_pred_list = []
    current_x = torch.clone(noise)
    batch_size = current_x.shape[0]

    # Iterate the tqdm wrapper itself so the bar actually advances.
    progress_bar = tqdm(range(num_timesteps - 1, 0, -1))
    for timestep in progress_bar:
        progress_bar.set_description(f"T {timestep} / {num_timesteps}")
        # Per-sample integer timesteps on the samples' own device.
        timesteps = torch.full([batch_size], timestep, dtype=torch.int32, device=current_x.device)
        t_current = timesteps / num_timesteps
        if timestep % (num_timesteps / 10) == 0:
            noisy_xs_list.append(torch.clone(current_x))
            with torch.no_grad():
                noise_pred = model(current_x, timestep / num_timesteps, return_dict=False)[0]
            alpha = get_alpha(t_current).to(current_x)[:, None, None, None]
            # x0 estimate implied by the current noise prediction.
            x_pred_list.append((current_x - torch.sqrt(1 - alpha) * noise_pred) / torch.sqrt(alpha))

        current_x = get_next_x(model, current_x, t_current, (timesteps - 1) / num_timesteps)

    noisy_xs_list.append(torch.clone(current_x))
    x_pred_list.append(torch.clone(current_x))
    save_and_show(torch.cat(noisy_xs_list, 0), f'z_to_x.png', nrow=current_x.shape[0])
    save_and_show(torch.cat(x_pred_list, 0), f'z_to_x2.png', nrow=current_x.shape[0])
    return current_x

In [None]:
def reverse_diffusion_ddim(model, noise, num_timesteps, device = None):
    """Generate images from ``noise`` with the deterministic DDIM sampler.

    NOTE(review): this redefines ``reverse_diffusion_ddim`` from the previous
    cell; this version is the one used by the cells below.

    Steps from t = (num_timesteps - 1) / num_timesteps down to t = 0 using the
    linear schedule ``get_alpha``. No fresh noise is injected during sampling
    (DDIM with eta = 0, not DDPM). Intermediate images every
    ``num_timesteps / 10`` steps, plus the final result, are saved to
    ``ddim_collage.png``.

    Args:
        model: noise predictor called as ``model(x, t, return_dict=False)``.
        noise: initial batch.
        num_timesteps: number of discretisation steps of [0, 1].
        device: where to run; defaults to CUDA when available, otherwise CPU
            (a hard 'cuda' default crashed on CPU-only machines).
    Returns:
        The generated batch (detached).
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    noisy_images_list = []
    current_image = noise.to(device)
    batch_size = current_image.shape[0]

    # Iterate the tqdm wrapper itself so the bar actually advances.
    progress_bar = tqdm(range(num_timesteps - 1, 0, -1))
    for t in progress_bar:
        progress_bar.set_description(f"T {t}")
        timesteps = torch.full([batch_size], t, dtype=torch.int32, device=current_image.device)
        with torch.no_grad():
            current_noise = model(current_image, timesteps / num_timesteps, return_dict=False)[0]

        # Keep the alphas on the samples' device/dtype; `noise` may live elsewhere.
        alpha_previous = get_alpha((timesteps - 1) / num_timesteps)[:, None, None, None].to(current_image)
        alpha = get_alpha(timesteps / num_timesteps)[:, None, None, None].to(current_image)

        # Estimate x0, then re-noise it to the previous timestep with the predicted noise.
        next_image = torch.sqrt(alpha_previous) * (current_image - torch.sqrt(1 - alpha) * current_noise) / torch.sqrt(alpha)
        next_image += torch.sqrt(1 - alpha_previous) * current_noise
        current_image = next_image.detach()
        if t % (num_timesteps / 10) == 0:
            noisy_images_list.append(current_image)

    noisy_images_list.append(current_image)
    save_and_show(torch.cat(noisy_images_list, 0), f'ddim_collage.png', nrow=current_image.shape[0])
    return current_image

In [None]:
# Draw one batch of samples from fresh Gaussian noise with 1000 DDIM steps.
samples = reverse_diffusion_ddim(
    model,
    torch.randn_like(sample_batch).to(device),
    num_timesteps=1000,
)

In [None]:
def forward_diffusion_ddim(model, x, num_timesteps = 100):
    """Invert clean samples ``x`` to latents with deterministic DDIM steps.

    Calls ``get_next_x`` from t = 0 up to t = (num_timesteps - 1) / num_timesteps,
    which is where the reverse samplers in this notebook start. Every
    ``num_timesteps / 10`` steps the current latent is recorded. Those, plus
    the final latent, are saved to ``x_to_z.png``.

    Returns:
        The inverted latent batch.
    """
    noisy_xs_list = []
    current_z = torch.clone(x)

    # DDIM inversion (increasing t); iterate the tqdm wrapper so the bar advances.
    progress_bar = tqdm(range(1, num_timesteps))
    for timestep in progress_bar:
        progress_bar.set_description(f"T {timestep} / {num_timesteps}")
        current_z = get_next_x(model, current_z, (timestep - 1) / num_timesteps, timestep / num_timesteps)
        if timestep % (num_timesteps / 10) == 0:
            noisy_xs_list.append(current_z)

    noisy_xs_list.append(current_z)
    save_and_show(torch.cat(noisy_xs_list, 0), f'x_to_z.png', nrow=current_z.shape[0])
    return current_z

In [None]:
# Sanity-check DDIM invertibility: image -> latent -> image should round-trip,
# so the printed mean absolute error should be small.
reconstruction = reverse_diffusion_ddim(
    model,
    forward_diffusion_ddim(model, sample_batch.to(device), 1000),
    1000,
)
print((reconstruction.cpu() - sample_batch).abs().mean())

In [None]:
def reverse_diffusion_ddim_conditioned(model, noise, downstream_network, num_timesteps = 100):
    """DDIM sampling from ``noise``; a conditioned variant in progress.

    TODO(review): ``downstream_network`` is accepted but never used, so this
    currently performs plain unconditioned DDIM sampling via ``get_next_x``.

    Samples every ``num_timesteps / 10`` steps, plus the final sample, are
    saved to ``z_to_x.png``.

    Returns:
        The final sample batch.
    """
    noisy_xs_list = []
    current_x = torch.clone(noise)

    # Iterate the tqdm wrapper itself so the bar actually advances.
    progress_bar = tqdm(range(num_timesteps - 1, 0, -1))
    for timestep in progress_bar:
        progress_bar.set_description(f"T {timestep} / {num_timesteps}")
        current_x = get_next_x(model, current_x, timestep / num_timesteps, (timestep - 1) / num_timesteps)
        if timestep % (num_timesteps / 10) == 0:
            noisy_xs_list.append(current_x)

    # Include the final sample, as the other samplers do. This also keeps
    # torch.cat from failing when no intermediate snapshot was taken.
    noisy_xs_list.append(current_x)
    save_and_show(torch.cat(noisy_xs_list, 0), f'z_to_x.png', nrow=current_x.shape[0])
    return current_x