We used the following notebook to re-implement the DDPM architecture:

https://colab.research.google.com/drive/1AZ2_BAwXrU8InE_qAE9cFZ0lsIO5a_xp?usp=sharing

## Standard import

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Subset
from tqdm.notebook import tqdm
from torchvision.transforms import Compose, ToTensor, Lambda, Resize
from torchvision.models import inception_v3
from scipy.linalg import sqrtm
import zipfile
import shutil
import os
import pandas as pd

## Model architecture

In [4]:
def sinusoidal_embedding(n, d):
    """
    Returns the sinusoidal positional embedding matrix.

    Even columns hold sin(position * w_k) and odd columns hold cos(position * w_k),
    where w_k = exp(-2k * ln(10000) / d).

    Args:
        n (int): Length of sequence.
        d (int): Dimension of the embedding (may be even or odd).

    Returns:
        torch.Tensor: Sinusoidal positional embedding matrix of shape (n, d).
    """
    positions = torch.arange(0, n).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d, 2) * (-np.log(10000.0) / d))
    angles = positions * div_term  # (n, ceil(d / 2))
    embedding = torch.zeros(n, d)
    embedding[:, 0::2] = torch.sin(angles)
    # When d is odd there is one odd-indexed column fewer than even-indexed ones,
    # so drop the last frequency for the cosine half instead of failing on a shape mismatch.
    embedding[:, 1::2] = torch.cos(angles[:, : d // 2])
    return embedding


In [5]:
class double_conv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    padding=1 on the 3x3 convolutions keeps the spatial size unchanged; only the
    channel count goes from in_ch to out_ch.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage maps out_ch -> out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class down_layer(nn.Module):
    """Encoder stage: 2x2 max-pooling (stride 2) followed by a double_conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x):
        pooled = self.pool(x)
        return self.conv(pooled)

class up(nn.Module):
    """Upsample the deeper feature map and concatenate it with the skip connection."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up_scale = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2)

    def forward(self, x1, x2):
        # x1: skip features (bs, out_ch, w1, h1); x2: deeper features (bs, in_ch, w2, h2)
        upsampled = self.up_scale(x2)  # (bs, out_ch, 2*w2, 2*h2)

        # Pad the upsampled map so that it matches the skip connection's spatial size.
        gap_rows = x1.shape[2] - upsampled.shape[2]
        gap_cols = x1.shape[3] - upsampled.shape[3]
        left, top = gap_cols // 2, gap_rows // 2
        upsampled = F.pad(upsampled, [left, gap_cols - left, top, gap_rows - top])

        # Channel order is (upsampled, skip) -> (bs, 2*out_ch, w1, h1)
        return torch.cat([upsampled, x1], dim=1)

class up_layer(nn.Module):
    """Decoder stage: upsample + skip concatenation, then a double_conv.

    Requires in_ch == 2 * out_ch so that the concatenated tensor has in_ch channels.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = up(in_ch, out_ch)
        self.conv = double_conv(in_ch, out_ch)

    def forward(self, x1, x2):
        # x1: skip features (bs, out_ch, w1, h1); x2: deeper features (bs, in_ch, w2, h2)
        merged = self.up(x1, x2)  # (bs, 2*out_ch, w1, h1)
        return self.conv(merged)  # (bs, out_ch, w1, h1)

In [6]:
class UNet(nn.Module):
    """U-Net noise predictor conditioned on the diffusion timestep.

    Four down-sampling stages (64 -> 1024 channels) are mirrored by four up-sampling
    stages with skip connections. Before every stage, a small per-stage MLP projects the
    frozen sinusoidal timestep embedding to the stage's input channel count and the
    result is added to the stage input.
    """

    def __init__(self, in_channels=1, n_steps=1000, time_emb_dim=100):
        super().__init__()

        # Encoder
        self.conv1 = double_conv(in_channels, 64)
        self.down1 = down_layer(64, 128)
        self.down2 = down_layer(128, 256)
        self.down3 = down_layer(256, 512)
        self.down4 = down_layer(512, 1024)

        # Decoder
        self.up1 = up_layer(1024, 512)
        self.up2 = up_layer(512, 256)
        self.up3 = up_layer(256, 128)
        self.up4 = up_layer(128, 64)
        self.last_conv = nn.Conv2d(64, in_channels, 1)

        # Frozen sinusoidal lookup table, one row per timestep.
        self.time_embed = nn.Embedding(n_steps, time_emb_dim)
        self.time_embed.weight.data = sinusoidal_embedding(n_steps, time_emb_dim)
        self.time_embed.requires_grad_(False)

        # One projection per stage; the output width equals that stage's input channels.
        self.te1 = self._make_te(time_emb_dim, in_channels)
        self.te2 = self._make_te(time_emb_dim, 64)
        self.te3 = self._make_te(time_emb_dim, 128)
        self.te4 = self._make_te(time_emb_dim, 256)
        self.te5 = self._make_te(time_emb_dim, 512)
        self.te1_up = self._make_te(time_emb_dim, 1024)
        self.te2_up = self._make_te(time_emb_dim, 512)
        self.te3_up = self._make_te(time_emb_dim, 256)
        self.te4_up = self._make_te(time_emb_dim, 128)

    def _make_te(self, dim_in, dim_out):
        """Build the Linear -> SiLU -> Linear projection for one stage."""
        layers = [nn.Linear(dim_in, dim_out), nn.SiLU(), nn.Linear(dim_out, dim_out)]
        return nn.Sequential(*layers)

    def forward(self, x, t):  # x: (bs, in_channels, w, h); t: timestep indices
        batch = x.shape[0]
        emb = self.time_embed(t)

        def conditioned(features, projection):
            # Reshape the projected embedding to (bs, C, 1, 1) so it broadcasts over space.
            return features + projection(emb).reshape(batch, -1, 1, 1)

        enc1 = self.conv1(conditioned(x, self.te1))               # (bs, 64, w, h)
        enc2 = self.down1(conditioned(enc1, self.te2))            # (bs, 128, w/2, h/2)
        enc3 = self.down2(conditioned(enc2, self.te3))            # (bs, 256, w/4, h/4)
        enc4 = self.down3(conditioned(enc3, self.te4))            # (bs, 512, w/8, h/8)
        bottleneck = self.down4(conditioned(enc4, self.te5))      # (bs, 1024, w/16, h/16)

        dec1 = self.up1(enc4, conditioned(bottleneck, self.te1_up))  # (bs, 512, w/8, h/8)
        dec2 = self.up2(enc3, conditioned(dec1, self.te2_up))        # (bs, 256, w/4, h/4)
        dec3 = self.up3(enc2, conditioned(dec2, self.te3_up))        # (bs, 128, w/2, h/2)
        dec4 = self.up4(enc1, conditioned(dec3, self.te4_up))        # (bs, 64, w, h)

        return self.last_conv(dec4)  # (bs, in_channels, w, h)

In [9]:
class DDPM(nn.Module):
    """Denoising diffusion wrapper around a noise-prediction network.

    Holds a linear beta schedule and implements the forward (noising) process,
    the network call, and one ancestral sampling step of the reverse process.
    """

    def __init__(self, network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=None) -> None:
        """
        Args:
            network (torch.nn.Module): Called as network(x, t); returns the predicted noise.
            num_timesteps (int): Number of diffusion steps.
            beta_start (float): First value of the linear beta schedule.
            beta_end (float): Last value of the linear beta schedule.
            device: Device holding the schedule tensors. Defaults to CUDA when
                available, otherwise CPU.
        """
        super().__init__()
        if device is None:
            # Resolved at call time, so defining the class no longer requires a
            # module-level `device` variable to exist beforehand.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.num_timesteps = num_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps, dtype=torch.float32).to(device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.network = network
        self.device = device
        self.sqrt_alphas_cumprod = self.alphas_cumprod ** 0.5  # used in add_noise
        self.sqrt_one_minus_alphas_cumprod = (1 - self.alphas_cumprod) ** 0.5  # used in add_noise and step
        # Reverse-step coefficients depend only on the schedule: compute them once
        # here instead of on every call to step().
        self._inv_sqrt_alphas = 1 / self.alphas ** 0.5
        self._eps_coef = (1 - self.alphas) / self.sqrt_one_minus_alphas_cumprod

    def add_noise(self, x_start, x_noise, timesteps):
        """Forward process: sqrt(alpha_bar_t) * x_start + sqrt(1 - alpha_bar_t) * x_noise.

        x_start and x_noise are (bs, n_c, w, h); timesteps is (bs,).
        """
        s1 = self.sqrt_alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)  # (bs, 1, 1, 1) for broadcasting
        s2 = self.sqrt_one_minus_alphas_cumprod[timesteps].reshape(-1, 1, 1, 1)
        return s1 * x_start + s2 * x_noise

    def reverse(self, x, t):
        """Return the network's estimate of the noise contained in x at timestep t."""
        return self.network(x, t)

    def step(self, model_output, timestep, sample):
        """One reverse sampling step.

        Args:
            model_output (torch.Tensor): Predicted noise, same shape as `sample`.
            timestep: Python int, 0-d tensor, or 1-d tensor with one entry
                (shared by the batch) or one entry per sample.
            sample (torch.Tensor): Current noisy sample, (bs, n_c, w, h).

        Returns:
            torch.Tensor: The sample for the previous timestep. Fresh Gaussian noise
            scaled by sqrt(beta_t) is added wherever the timestep is > 0.
        """
        t = torch.as_tensor(timestep, device=self.betas.device).reshape(-1)
        coef_first_t = self._inv_sqrt_alphas[t].reshape(-1, 1, 1, 1)
        coef_eps_t = self._eps_coef[t].reshape(-1, 1, 1, 1)
        pred_prev_sample = coef_first_t * (sample - coef_eps_t * model_output)

        needs_noise = t > 0
        if not needs_noise.any():
            # Final step: return the mean without drawing any random numbers.
            return pred_prev_sample

        noise = torch.randn_like(model_output)
        sigma_t = (self.betas[t] ** 0.5).reshape(-1, 1, 1, 1)
        # The mask zeroes the noise for entries at t == 0 when timesteps are batched.
        return pred_prev_sample + needs_noise.reshape(-1, 1, 1, 1) * sigma_t * noise

## Training pipeline

In [2]:
def show_and_save_images(images, title="", save_path="output.png"):
    """
    Displays images as subplots in a roughly square grid and saves the figure.

    Args:
        images (list): Non-empty list of (C, H, W) tensors; values are clipped to [0, 1].
        title (str): Title of the figure.
        save_path (str): File path to save the figure.

    Raises:
        ValueError: If `images` is empty.
    """
    # detach()/cpu() make the conversion work for tensors that track gradients
    # or live on an accelerator.
    images = [np.clip(im.detach().cpu().permute(1, 2, 0).numpy(), 0, 1) for im in images]

    num_images = len(images)
    if num_images == 0:
        raise ValueError("show_and_save_images() needs at least one image")

    rows = int(np.sqrt(num_images))
    cols = (num_images + rows - 1) // rows  # Ensure all images fit in the grid

    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so flatten() always works.
    fig, axes = plt.subplots(rows, cols, figsize=(8, 8), squeeze=False)
    flat_axes = axes.flatten()

    # Hide the frame of every cell up front so unused cells stay blank.
    for ax in flat_axes:
        ax.axis('off')

    for ax, image in zip(flat_axes, images):
        ax.imshow(image)

    fig.suptitle(title, fontsize=30)
    plt.savefig(save_path)


In [3]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [10]:
def training_loop(model, dataloader, optimizer, num_epochs, num_timesteps, gradient_accumulation_steps=1, device=device, checkpoint_path='ddpm.pt'):
    """Training loop for DDPM.

    For each batch, samples Gaussian noise and a random timestep per image, noises the
    batch with model.add_noise and minimises the MSE between model.reverse's prediction
    and the sampled noise.

    Args:
        model: Object exposing add_noise(x, noise, t), reverse(x, t) and state_dict().
        dataloader: Iterable of batches whose first element is the image tensor.
        optimizer (torch.optim.Optimizer): Optimizer over the model parameters.
        num_epochs (int): Number of passes over `dataloader`.
        num_timesteps (int): Timesteps are drawn uniformly from [0, num_timesteps).
        gradient_accumulation_steps (int): Micro-batches accumulated per optimizer step.
        device: Device the batches, noise and timesteps are moved to.
        checkpoint_path (str): Where model.state_dict() is saved after every epoch.

    Returns:
        list[float]: The MSE loss of every processed batch (before the division by
        gradient_accumulation_steps that is applied only for the backward pass).
    """

    global_step = 0
    losses = []

    for epoch in range(num_epochs):
        model.train()
        progress_bar = tqdm(total=len(dataloader))
        progress_bar.set_description(f"Epoch {epoch}")

        accumulated_steps = 0

        # try/finally so the progress bar is closed even if a batch raises.
        try:
            for batch in dataloader:
                batch = batch[0].to(device)
                noise = torch.randn(batch.shape).to(device)
                timesteps = torch.randint(0, num_timesteps, (batch.shape[0],)).long().to(device)

                noisy = model.add_noise(batch, noise, timesteps)
                noise_pred = model.reverse(noisy, timesteps)
                loss = F.mse_loss(noise_pred, noise)

                # Read the loss once, before scaling, so the reported value does not
                # shrink with gradient_accumulation_steps.
                loss_value = loss.detach().item()

                # Scale only the gradient contribution of this micro-batch.
                (loss / gradient_accumulation_steps).backward()

                accumulated_steps += 1

                if accumulated_steps % gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                losses.append(loss_value)
                progress_bar.update(1)
                progress_bar.set_postfix(loss=loss_value, step=global_step)

            # Perform a final optimization step for the remaining accumulated gradients
            if accumulated_steps % gradient_accumulation_steps != 0:
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1
        finally:
            progress_bar.close()

        torch.save(model.state_dict(), checkpoint_path)

    return losses


## Dataset loading and preprocessing

In [None]:
!mkdir data_faces && wget https://s3-us-west-1.amazonaws.com/udacity-dlnfd/datasets/celeba.zip

In [12]:
# Unpack the downloaded archive into data_faces/.
with zipfile.ZipFile("celeba.zip", mode="r") as archive:
    archive.extractall(path="data_faces/")

In [13]:
batch_size = 128

# Training preprocessing: resize every image to 64x64, then convert it to a tensor.
transform = Compose([Resize((64, 64)), ToTensor()])

# ImageFolder yields (image, label) pairs read from the extracted archive.
celeba_dataset = datasets.ImageFolder(root='data_faces/', transform=transform)

In [14]:
# Shuffled mini-batches of the training images.
celeba_loader = DataLoader(celeba_dataset, batch_size=batch_size, shuffle=True)

## Training model

In [None]:
# Hyper-parameters of the UNet-based DDPM run.
learning_rate = 1e-3
num_epochs = 10
num_timesteps = 1000

network = UNet(in_channels=3)
network.to(device)
model = DDPM(network, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Keep the per-batch losses returned by training_loop: as the last bare expression of
# the cell, the whole list would otherwise be echoed as cell output and then lost.
losses = training_loop(model, celeba_loader, optimizer, num_epochs, num_timesteps, device=device)

## Sampling images from the target distribution

In [17]:
def generate_image(ddpm, sample_size, channel, size, device=None, mid_timestep=500):
    """
    Generate images from Gaussian noise by running the full reverse process.

    Args:
        ddpm (torch.nn.Module): DDPM model exposing num_timesteps, reverse() and step().
        sample_size (int): Number of samples to generate (all in one batch).
        channel (int): Number of channels.
        size (int): Height and width of the generated images.
        device: Device to run the generation on. Defaults to `ddpm.device` when the
            model has that attribute, otherwise CUDA if available, else CPU.
        mid_timestep (int): Timestep after which an intermediate snapshot of the
            batch is stored.

    Returns:
        tuple: (frames, frames_mid) - lists of (channel, size, size) CPU tensors holding
        the final samples and the snapshot taken at `mid_timestep`. `frames_mid` is
        empty when `mid_timestep` is never visited.
    """
    if device is None:
        device = getattr(ddpm, "device", None)
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    frames_mid = []
    ddpm.eval()

    with torch.no_grad():
        sample = torch.randn(sample_size, channel, size, size).to(device)

        # Walk the chain backwards: T-1, T-2, ..., 0.
        for t in tqdm(range(ddpm.num_timesteps - 1, -1, -1)):
            time_tensor = torch.full((sample_size,), t, dtype=torch.long).to(device)
            residual = ddpm.reverse(sample, time_tensor)
            # All samples share the same timestep, so a single entry is passed to step().
            sample = ddpm.step(residual, time_tensor[0], sample)

            if t == mid_timestep:
                # Iterating a tensor yields one slice per entry of the batch dimension.
                frames_mid = list(sample.detach().cpu())

        frames = list(sample.detach().cpu())

    return frames, frames_mid


In [None]:
# Sample 10000 RGB images of size 32x32; generate_image takes the target device as
# its fifth argument, which the call was missing.
generated, generated_mid = generate_image(model, 10000, 3, 32, device)

In [None]:
# Plot every generated sample in one grid and save it under the default file name.
show_and_save_images(generated, title="Final result")

In [None]:
from google.colab import files

# Download the 'ddpm.pt' checkpoint file through Colab's file helper.
# NOTE(review): google.colab is presumably only importable inside a Colab runtime - confirm.
files.download('ddpm.pt')

In [None]:
# Restore the saved weights onto the current device and switch to inference mode.
checkpoint = torch.load('ddpm.pt', map_location=device)
model.load_state_dict(checkpoint)
model.eval()

## Measuring performance

In [None]:
def _activation_statistics(batches, model, device):
    """Return (mean, unbiased covariance) of `model`'s outputs over an iterable of image batches.

    Assumes model(images) returns a 2-D (num_images, num_features) tensor.
    """
    model.eval()
    with torch.no_grad():
        features = torch.cat([model(images.to(device)) for images in batches], dim=0)
    mu = torch.mean(features, dim=0)
    centered = features - mu
    sigma = torch.matmul(centered.T, centered) / (features.size(0) - 1)
    return mu, sigma


def calculate_activation_statistics(loader, model, device):
    """Feature statistics for a loader that yields sequences whose first item is the image batch.

    Args:
        loader: Iterable of batches; `batch[0]` is the image tensor.
        model (torch.nn.Module): Feature extractor, put into eval mode.
        device: Device the images are moved to before the forward pass.

    Returns:
        tuple: (mu, sigma) - feature mean and unbiased feature covariance.
    """
    return _activation_statistics((batch[0] for batch in loader), model, device)


def calculate_activation_statistics_gen(loader, model, device):
    """Feature statistics for a loader that yields bare image tensors.

    Same contract as calculate_activation_statistics, except that every batch is the
    image tensor itself rather than a sequence containing it.
    """
    return _activation_statistics(loader, model, device)

def calculate_frechet_distance(mu1, sigma1, mu2, sigma2):
    """Frechet distance between the Gaussians N(mu1, sigma1) and N(mu2, sigma2).

    Computes ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrt(sigma1 @ sigma2)).

    Args:
        mu1, mu2: Mean vectors (torch tensors or array-likes) of shape (d,).
        sigma1, sigma2: Covariance matrices (torch tensors or array-likes) of shape (d, d).

    Returns:
        float: The Frechet distance.
    """
    eps = 1e-6

    def as_array(value):
        # scipy's sqrtm returns NumPy arrays, so do the whole computation in NumPy
        # (float64) rather than mixing torch tensors and arrays.
        if hasattr(value, "detach"):
            value = value.detach().cpu().numpy()
        return np.asarray(value, dtype=np.float64)

    mu1, mu2 = as_array(mu1), as_array(mu2)
    sigma1, sigma2 = as_array(sigma1), as_array(sigma2)

    diff = mu1 - mu2
    covmean = sqrtm(sigma1 @ sigma2)
    if not np.isfinite(covmean).all():
        # A (near-)singular product can make the square root blow up; retry with a
        # small ridge added to both covariances.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset) @ (sigma2 + offset))

    # Numerical error can leave a small imaginary component; only the real part is meaningful.
    covmean = np.real(covmean)

    fid = diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * np.trace(covmean)
    return float(fid)

def compute_fid_score(real_loader, generated_loader):
    """Frechet distance between real and generated images in ResNet-18 output space.

    NOTE(review): the conventional FID uses Inception-v3 features; this function uses
    the outputs of an ImageNet-pretrained ResNet-18, so its values are presumably not
    comparable with published FID numbers - confirm this is intended.

    Args:
        real_loader: Loader whose batches carry the image tensor as their first item.
        generated_loader: Loader yielding bare image tensors.

    Returns:
        float: The Frechet distance between the two sets of network outputs.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load resnet18 model pre-trained on ImageNet. It is reached through the
    # `torchvision` package import because no bare `models` name is imported in this file.
    net = torchvision.models.resnet18(weights='IMAGENET1K_V1').to(device)
    net.eval()

    # Compute statistics for real and generated images
    real_mu, real_sigma = calculate_activation_statistics(real_loader, net, device)
    generated_mu, generated_sigma = calculate_activation_statistics_gen(generated_loader, net, device)

    # Compute FID score on the CPU
    fid_score = calculate_frechet_distance(
        real_mu.to('cpu'),
        real_sigma.to('cpu'),
        generated_mu.to('cpu'),
        generated_sigma.to('cpu'),
    )

    return fid_score


In [None]:
# Create the dataloaders for the generated and the real images.

# Transformations for the real images: resize to the same 32x32 square as the generated
# samples. A bare Resize(32) only sets the shorter side and keeps the aspect ratio,
# which yields non-square images for non-square inputs.
preprocess = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])


generated_loader = DataLoader(generated, batch_size, shuffle=True)

# take only the first 10000 images from the real dataset
celeba_dataset = datasets.ImageFolder('data_faces/', transform=preprocess)
celeba_dataset = torch.utils.data.Subset(celeba_dataset, list(range(10000)))
real_loader = DataLoader(celeba_dataset, batch_size, shuffle=True)


fid_score = compute_fid_score(real_loader, generated_loader)

In [None]:
# Display the score computed in the previous cell.
fid_score

## Comparison with baseline

In [None]:
class LinearMLP(nn.Module):
    """Fully connected baseline noise predictor operating on flattened images.

    The timestep argument of forward() is accepted for interface compatibility with
    the UNet but is not used.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, image_size=64):
        """
        Args:
            in_channels (int): Number of image channels.
            hidden_channels (int): Width of the two hidden layers.
            out_channels (int): Unused; the output always has the input's shape.
            image_size (int): Height and width of the (square) input images.
        """
        super().__init__()

        # Calculate the size of the flattened input
        flattened_size = image_size * image_size * in_channels

        # Define linear layers
        self.linear1 = nn.Linear(flattened_size, hidden_channels)
        self.linear2 = nn.Linear(hidden_channels, hidden_channels)
        self.linear3 = nn.Linear(hidden_channels, flattened_size)

        # Define activation function
        self.activation = nn.ReLU()

    def forward(self, x, t):
        # Remember the input shape so the output can be restored to it; the shape
        # is no longer hard-coded to (3, 64, 64).
        input_shape = x.shape

        # Flatten the input image (reshape also handles non-contiguous inputs)
        x = x.reshape(input_shape[0], -1)

        # Pass through linear layers with activation functions
        x = self.activation(self.linear1(x))
        x = self.activation(self.linear2(x))
        x = self.linear3(x)

        # Reshape back to image shape
        return x.view(input_shape)

In [None]:
# Hyper-parameters of the MLP baseline run (same values as the UNet run).
learning_rate = 1e-3
num_epochs = 10
num_timesteps = 1000

network_baseline = LinearMLP(in_channels = 3, hidden_channels = 100, out_channels = 3)
network_baseline.to(device)
model_baseline = DDPM(network_baseline, num_timesteps, beta_start=0.0001, beta_end=0.02, device=device)
model_baseline.train()
optimizer = torch.optim.Adam(model_baseline.parameters(), lr=learning_rate)

# NOTE: training_loop writes its checkpoint to 'ddpm.pt' after every epoch, so this run
# overwrites the checkpoint saved by the UNet training cell.
# Keep the per-batch losses instead of letting the notebook echo and discard them.
losses_baseline = training_loop(model_baseline, celeba_loader, optimizer, num_epochs, num_timesteps, device=device)

In [None]:
# generate_image returns (frames, frames_mid) and takes the device as fifth argument.
# LinearMLP's layers are sized for 64x64 inputs, so the samples are drawn at that size.
# NOTE(review): real_loader resizes real images to 32 px, so the baseline score compares
# images of different resolutions - confirm whether the samples should be downscaled first.
generated_baseline, generated_baseline_mid = generate_image(model_baseline, 10000, 3, 64, device)

In [None]:
# Transformations for images
preprocess = transforms.Compose([transforms.Resize(32), transforms.ToTensor()])

# Loader over the baseline's generated samples
generated_loader = DataLoader(generated_baseline, batch_size=batch_size, shuffle=True)

# Score the baseline against the same real-image loader as before
fid_score_baseline = compute_fid_score(real_loader, generated_loader)

In [None]:
# Ratio of the UNet model's score to the MLP baseline's score.
r = fid_score/fid_score_baseline
print(r)

## Interpolation of visage

In [None]:
def interpolate_visages(ddpm, visage1, visage2, num_interpolation_steps, device):
    """
    Interpolate between two visages.

    The network's noise predictions for both images at timestep 0 are blended linearly,
    and each blend is passed through one ddpm.step call at timestep 0.

    Args:
        ddpm (torch.nn.Module): DDPM model exposing reverse() and step().
        visage1 (torch.Tensor): First visage, (C, H, W) or already batched (1, C, H, W).
        visage2 (torch.Tensor): Second visage, same layout as visage1.
        num_interpolation_steps (int): Number of interpolation steps between the visages.
        device (str): Device to run the interpolation on.

    Returns:
        list: num_interpolation_steps + 2 images of shape (C, H, W) on the CPU; the first
        and last entries correspond to blend weights 0 and 1.
    """
    ddpm.eval()

    def as_batch(image):
        # Accept both unbatched and batched images; unsqueezing an already batched
        # image would produce a 5-D tensor.
        image = image.to(device)
        return image if image.dim() == 4 else image.unsqueeze(0)

    batch1 = as_batch(visage1)
    batch2 = as_batch(visage2)
    t0 = torch.tensor([0]).to(device)

    interpolated_visages = []

    with torch.no_grad():
        # Encode visages into noise representations
        noise1 = ddpm.reverse(batch1, t0)
        noise2 = ddpm.reverse(batch2, t0)

        # Interpolate between the noise representations
        for i in range(num_interpolation_steps + 2):  # Including endpoints
            alpha = i / (num_interpolation_steps + 1)
            interpolated_noise = alpha * noise2 + (1 - alpha) * noise1

            # Decode interpolated noise representation back into image.
            # NOTE(review): the step always starts from visage1, so visage2 influences
            # the result only through the blended noise - confirm this is intended.
            interpolated_visage = ddpm.step(interpolated_noise, t0, batch1)
            interpolated_visages.append(interpolated_visage.squeeze(0).detach().cpu())

    return interpolated_visages


# Dataset items are (image, label) pairs. Keep the images as (C, H, W) tensors:
# that is the layout ToPILImage accepts, and interpolate_visages adds the batch dimension.
visage1 = celeba_dataset[0][0]  # First visage
visage2 = celeba_dataset[1][0]  # Second visage

# Interpolate between the two visages
num_interpolation_steps = 5
interpolated_visages = interpolate_visages(model, visage1, visage2, num_interpolation_steps, device)

# Plot the original visages and the interpolated visages
to_pil = transforms.ToPILImage()
num_panels = num_interpolation_steps + 2
plt.figure(figsize=(15, 5))

# Plot the first visage
plt.subplot(1, num_panels, 1)
plt.imshow(to_pil(visage1))
plt.title('Visage 1')
plt.axis('off')

# Plot the last visage
plt.subplot(1, num_panels, num_panels)
plt.imshow(to_pil(visage2))
plt.title('Visage 2')
plt.axis('off')

# Plot the interpolated visages. interpolate_visages also returns the two endpoint
# blends (weights 0 and 1); only the interior ones are drawn so that they fill
# panels 2 .. num_panels - 1 without running past the grid.
for i, interpolated_visage in enumerate(interpolated_visages[1:-1]):
    plt.subplot(1, num_panels, i + 2)
    # Clamp to [0, 1] before the conversion to an 8-bit image.
    plt.imshow(to_pil(interpolated_visage.clamp(0, 1)))
    plt.title(f'Interpolation {i+1}')
    plt.axis('off')

plt.show()
