# Denoising diffusion probabilistic model (DDPM)

For our last notebook we will train a denoising diffusion probabilistic model (DDPM), arguably the first big success among diffusion-type models. Other variants exist, but they build on the same concept: gradually adding noise to an image while learning the inverse (denoising) process.

This notebook closely follows chapter 17 of the textbook, so it's a good idea to read the two side by side. Again, we present this purely as a tutorial, without any extra exercises.

First the usual imports:

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib
matplotlib.rcParams.update({'font.size': 22})
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

## Load the data

For this test we pick the usual MNIST images. Although they are very small, both training the diffusion model and generating new images are computationally heavy, so using a GPU is strongly recommended.

Optionally, you can try the Fashion MNIST data instead: uncomment the Fashion MNIST line below and comment out the handwritten digits line (otherwise the second assignment overwrites the first).

In [None]:
# Fashion MNIST data
#mnist_data = tf.keras.datasets.fashion_mnist.load_data()

# Handwritten digits
mnist_data = tf.keras.datasets.mnist.load_data()

# Unpack the data, scale pixels to [0, 1], and hold out 5,000 validation images
(X_train_full, y_train_full), (X_test, y_test) = mnist_data
X_train_full = X_train_full.astype(np.float32) / 255
X_test = X_test.astype(np.float32) / 255
X_train, X_valid = X_train_full[:-5000], X_train_full[-5000:]
y_train, y_valid = y_train_full[:-5000], y_train_full[-5000:]

## Variance scheduling

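Here we reproduce figure 17-21 from the textbook, which shows how the noise variance is scheduled. We plot $\bar{\alpha}_t$ for the cosine schedule and, for comparison, for a simple linear schedule.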

In [None]:
def variance_schedule(T, s=0.008, max_beta=0.999):
    t = np.arange(T + 1)
    f = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2
    alpha = np.clip(f[1:] / f[:-1], 1 - max_beta, 1)
    alpha = np.append(1, alpha).astype(np.float32)  # add α₀ = 1
    beta = 1 - alpha
    alpha_cumprod = np.cumprod(alpha)
    print('alpha:', alpha)
    return alpha, alpha_cumprod, beta  # αₜ , α̅ₜ , βₜ for t = 0 to T

def linear_schedule(T, max_beta=0.999):
    step_scale = 0.005  # βₜ grows linearly from step_scale / T up to step_scale
    t = np.arange(1, T + 1)  # steps 1 to T; α₀ = 1 is prepended below
    f = 1 - step_scale * t / T
    alpha = np.clip(f, 1 - max_beta, 1)
    alpha = np.append(1, alpha).astype(np.float32)  # add α₀ = 1
    beta = 1 - alpha
    alpha_cumprod = np.cumprod(alpha)
    return alpha, alpha_cumprod, beta  # αₜ , α̅ₜ , βₜ for t = 0 to T

np.random.seed(42)  # For reproducibility
T = 4000

alpha, alpha_cumprod, beta = variance_schedule(T)
lin_alpha, lin_alpha_cumprod, lin_beta = linear_schedule(T)

In [None]:
plt.figure(figsize=(8, 4))
plt.plot(alpha_cumprod, "b", label=r"$\bar{\alpha}_t$ (cosine)")
plt.plot(lin_alpha_cumprod, "r", label=r"$\bar{\alpha}_t$ (linear)")
plt.axis([0, T, 0, 1])
plt.xlabel(r"t")
plt.legend()
plt.show()

## Prepare data for training

Let's define the function for preparing a single batch of noisy training images. Thanks to the shortcut in equation 17-6 in the textbook, we don't need to add the noise sequentially up to a given time step $t$; we can compute the noisy image directly from the original one as $\sqrt{\bar{\alpha}_t}\,\mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t}\,\boldsymbol{\epsilon}$, where $\boldsymbol{\epsilon}$ is standard Gaussian noise. Hence we can just pick time steps at random.

In [None]:
def prepare_batch(X):

    # Scale the images to have pixel values from -1 to 1
    X = tf.cast(X[..., tf.newaxis], tf.float32) * 2 - 1
    X_shape = tf.shape(X)

    # Select time steps at random
    t = tf.random.uniform([X_shape[0]], minval=1, maxval=T + 1, dtype=tf.int32)

    # Apply the alpha (noise variance) schedule we computed before
    alpha_cm = tf.gather(alpha_cumprod, t)
    alpha_cm = tf.reshape(alpha_cm, [X_shape[0]] + [1] * (len(X_shape) - 1))

    # Sample noise from a normal distribution
    noise = tf.random.normal(X_shape)

    # Return the inputs (noisy images and time steps) and the target (the noise)
    return {
        "X_noisy": alpha_cm ** 0.5 * X + (1 - alpha_cm) ** 0.5 * noise,
        "time": t,
    }, noise
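
As a sanity check on the shortcut used in `prepare_batch`, the sketch below tracks how much signal and noise remain when noise is added one step at a time, and compares the result with the closed-form coefficients $\sqrt{\bar{\alpha}_t}$ and $1 - \bar{\alpha}_t$. It assumes the standard forward step $\mathbf{x}_t = \sqrt{\alpha_t}\,\mathbf{x}_{t-1} + \sqrt{\beta_t}\,\boldsymbol{\epsilon}_t$. The names `cosine_alphas`, `T_demo`, `alpha_demo` and `alpha_bar_demo` are illustrative and not part of the notebook.

```python
import numpy as np

def cosine_alphas(T, s=0.008, max_beta=0.999):
    """Cosine schedule as in variance_schedule, returning alpha_0..alpha_T."""
    t = np.arange(T + 1)
    f = np.cos((t / T + s) / (1 + s) * np.pi / 2) ** 2
    alpha = np.clip(f[1:] / f[:-1], 1 - max_beta, 1)
    return np.append(1.0, alpha)  # alpha_0 = 1 (no noise at t = 0)

T_demo = 50  # hypothetical small number of steps, just for this check
alpha_demo = cosine_alphas(T_demo)
alpha_bar_demo = np.cumprod(alpha_demo)

# Write x_t = signal_coef * x_0 + (noise with variance noise_var)
# and update both quantities one forward step at a time.
signal_coef, noise_var = 1.0, 0.0
signal_coefs, noise_vars = [signal_coef], [noise_var]
for t in range(1, T_demo + 1):
    signal_coef *= np.sqrt(alpha_demo[t])
    noise_var = alpha_demo[t] * noise_var + (1 - alpha_demo[t])
    signal_coefs.append(signal_coef)
    noise_vars.append(noise_var)

# Compare the step-by-step result with the closed-form shortcut
print(np.allclose(signal_coefs, np.sqrt(alpha_bar_demo)))
print(np.allclose(noise_vars, 1 - alpha_bar_demo))
```

Both comparisons should agree, which is why a single random time step per image is enough during training.
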

Here we create TensorFlow datasets containing training and validation images.

In [None]:
def prepare_dataset(X, batch_size=32, shuffle=False):
    ds = tf.data.Dataset.from_tensor_slices(X)
    if shuffle:
        ds = ds.shuffle(10_000)
    return ds.batch(batch_size).map(prepare_batch).prefetch(1)

train_set = prepare_dataset(X_train, batch_size=32, shuffle=True)
valid_set = prepare_dataset(X_valid, batch_size=32)

Before we get started with the training, let's plot a few original images, their noisy versions, and the noise that the model will have to predict.

In [None]:
def plot_multiple_images(images, n_cols=None, update=False):

    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1
    if images.shape[-1] == 1:
        images = images.squeeze(axis=-1)

    fig = plt.figure(figsize=(n_cols, n_rows))

    for index, image in enumerate(images):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image, cmap="binary")
        plt.axis("off")

    if update:
        clear_output(wait=True)
        display(fig)

In [None]:
def subtract_noise(X_noisy, time, noise):
    X_shape = tf.shape(X_noisy)
    alpha_cm = tf.gather(alpha_cumprod, time)
    alpha_cm = tf.reshape(alpha_cm, [X_shape[0]] + [1] * (len(X_shape) - 1))
    return (X_noisy - (1 - alpha_cm) ** 0.5 * noise) / alpha_cm ** 0.5

X_dict, Y_noise = list(train_set.take(1))[0]  # get the first batch
X_original = subtract_noise(X_dict["X_noisy"], X_dict["time"], Y_noise)

print("Original images")
plot_multiple_images(X_original[:8].numpy())
plt.show()
print("Time steps:", X_dict["time"].numpy()[:8])
print("Noisy images")
plot_multiple_images(X_dict["X_noisy"][:8].numpy())
plt.show()
print("Noise to predict")
plot_multiple_images(Y_noise[:8].numpy())
plt.show()

## Build the model

The model predicting the reverse diffusion process can basically be anything, as long as the input and output images have the same dimensions. For best results, though, we also feed it the time step number, so that the model can apply different transformations depending on how far along we are in the denoising process.

The original DDPM papers use a U-Net architecture, which we already know from the image segmentation model in notebook 11. Before we get to the U-Net part, let's encode the time step number. Following the original DDPM implementation, we use the same _sinusoidal encodings_ that were used in the transformer paper, _Attention Is All You Need_. Basically we build a fixed (i.e. not learned) embedding space using sine and cosine values of the time step number.

In [None]:
embed_size = 64

class TimeEncoding(tf.keras.layers.Layer):
    def __init__(self, T, embed_size, dtype=tf.float32, **kwargs):
        super().__init__(dtype=dtype, **kwargs)
        assert embed_size % 2 == 0, "embed_size must be even"
        p, i = np.meshgrid(np.arange(T + 1), 2 * np.arange(embed_size // 2))
        t_emb = np.empty((T + 1, embed_size))
        t_emb[:, ::2] = np.sin(p / 10_000 ** (i / embed_size)).T
        t_emb[:, 1::2] = np.cos(p / 10_000 ** (i / embed_size)).T
        self.time_encodings = tf.constant(t_emb.astype(self.dtype))

    def call(self, inputs):
        return tf.gather(self.time_encodings, inputs)
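
To see what the encoding table looks like, here is a minimal NumPy sketch of the same computation as in `TimeEncoding`, written with broadcasting instead of `np.meshgrid`. The function name `sinusoidal_table` and the small values `T=100`, `embed_size=8` are chosen only for illustration.

```python
import numpy as np

def sinusoidal_table(T, embed_size):
    """Return a (T + 1, embed_size) table of fixed sinusoidal encodings."""
    assert embed_size % 2 == 0, "embed_size must be even"
    positions = np.arange(T + 1)[:, np.newaxis]                # shape (T + 1, 1)
    even_dims = 2 * np.arange(embed_size // 2)[np.newaxis, :]  # shape (1, embed_size // 2)
    angles = positions / 10_000 ** (even_dims / embed_size)
    table = np.empty((T + 1, embed_size))
    table[:, ::2] = np.sin(angles)   # even columns hold the sine values
    table[:, 1::2] = np.cos(angles)  # odd columns hold the cosine values
    return table

table = sinusoidal_table(T=100, embed_size=8)
print(table.shape)  # one row per time step, one column per embedding dimension
print(table[0])     # time step 0: sin(0) = 0 and cos(0) = 1, alternating
```

Each time step gets its own row, and the low-index columns oscillate quickly while the high-index columns change slowly, so nearby time steps get similar but distinct encodings.
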

Now for the model itself. Recall that the U-Net architecture consists of a downsampling part followed by an upsampling part, with skip connections between each downsampling level and the corresponding upsampling level.

In [None]:
def build_diffusion_model():
    X_noisy = tf.keras.layers.Input(shape=[28, 28, 1], name="X_noisy")
    time_input = tf.keras.layers.Input(shape=[], dtype=tf.int32, name="time")
    time_enc = TimeEncoding(T, embed_size)(time_input)

    dim = 16
    Z = tf.keras.layers.ZeroPadding2D((3, 3))(X_noisy)
    Z = tf.keras.layers.Conv2D(dim, 3)(Z)
    Z = tf.keras.layers.BatchNormalization()(Z)
    Z = tf.keras.layers.Activation("relu")(Z)

    time = tf.keras.layers.Dense(dim)(time_enc)  # adapt time encoding
    Z = time[:, tf.newaxis, tf.newaxis, :] + Z  # add time data to every pixel

    skip = Z
    cross_skips = []  # skip connections across the down and up parts of the UNet

    for dim in (32, 64, 128):
        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.SeparableConv2D(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.SeparableConv2D(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        cross_skips.append(Z)
        Z = tf.keras.layers.MaxPooling2D(3, strides=2, padding="same")(Z)
        skip_link = tf.keras.layers.Conv2D(dim, 1, strides=2,
                                           padding="same")(skip)
        Z = tf.keras.layers.add([Z, skip_link])

        time = tf.keras.layers.Dense(dim)(time_enc)
        Z = time[:, tf.newaxis, tf.newaxis, :] + Z
        skip = Z

    for dim in (64, 32, 16):
        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        Z = tf.keras.layers.UpSampling2D(2)(Z)

        skip_link = tf.keras.layers.UpSampling2D(2)(skip)
        skip_link = tf.keras.layers.Conv2D(dim, 1, padding="same")(skip_link)
        Z = tf.keras.layers.add([Z, skip_link])

        time = tf.keras.layers.Dense(dim)(time_enc)
        Z = time[:, tf.newaxis, tf.newaxis, :] + Z
        Z = tf.keras.layers.concatenate([Z, cross_skips.pop()], axis=-1)
        skip = Z

    outputs = tf.keras.layers.Conv2D(1, 3, padding="same")(Z)[:, 2:-2, 2:-2]

    return tf.keras.Model(inputs=[X_noisy, time_input], outputs=[outputs])
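
The zero padding at the start and the cropping at the end make the spatial size divisible by 8, so that three rounds of halving and doubling line up with the cross skip connections. The sketch below traces the image side length through the model using plain arithmetic; it assumes the usual Keras output-size rules (a `"valid"` convolution with kernel size 3 removes 2 pixels, and a stride-2 `"same"` pooling layer rounds the halved size up). The helper name `trace_spatial_sizes` is hypothetical.

```python
import math

def trace_spatial_sizes(input_size=28, pad=3, kernel=3, n_levels=3, crop=2):
    """Follow the image side length through the padding, U-Net levels and crop."""
    sizes = [("input", input_size)]
    size = input_size + 2 * pad          # ZeroPadding2D((3, 3))
    sizes.append(("zero padding", size))
    size = size - (kernel - 1)           # Conv2D(dim, 3) with "valid" padding
    sizes.append(("first conv", size))
    for level in range(n_levels):        # MaxPooling2D(3, strides=2, padding="same")
        size = math.ceil(size / 2)
        sizes.append((f"down {level + 1}", size))
    for level in range(n_levels):        # UpSampling2D(2)
        size = size * 2
        sizes.append((f"up {level + 1}", size))
    size = size - 2 * crop               # [:, 2:-2, 2:-2]
    sizes.append(("crop", size))
    return sizes

for name, size in trace_spatial_sizes():
    print(f"{name:>12}: {size} x {size}")
```

The sizes go 28, 34, 32, 16, 8, 4 on the way down and 8, 16, 32, 28 on the way back up, so the output has the same shape as the input.
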

## Train the model

The loss function should compare, pixel by pixel, the true noise with the noise predicted by the model. We could choose the mean squared error (MSE), the mean absolute error (MAE, which according to the DDPM paper works better), or we could follow the textbook and use the Huber loss, which combines the two: it is quadratic for small errors and linear for large ones.

In [None]:

model = build_diffusion_model()
model.compile(
    loss=tf.keras.losses.Huber(),
    optimizer=tf.keras.optimizers.Nadam()
)

# Add a ModelCheckpoint callback
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "my_diffusion_model.keras",
    save_best_only=True
)

history = model.fit(
    train_set,
    validation_data=valid_set,
    epochs=50,
    callbacks=[checkpoint_cb]
)
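
To make the difference between the three loss choices concrete, here is a small NumPy sketch of the per-pixel terms. It assumes the default threshold `delta=1.0` of `tf.keras.losses.Huber`; the function `huber` and the example `errors` are made up for illustration.

```python
import numpy as np

def huber(error, delta=1.0):
    """Element-wise Huber loss: quadratic for |error| <= delta, linear beyond."""
    abs_error = np.abs(error)
    quadratic = 0.5 * error ** 2
    linear = delta * (abs_error - 0.5 * delta)
    return np.where(abs_error <= delta, quadratic, linear)

# Hypothetical per-pixel differences between true and predicted noise
errors = np.array([0.1, 0.5, 2.0, 5.0])
print("MSE terms:  ", errors ** 2)
print("MAE terms:  ", np.abs(errors))
print("Huber terms:", huber(errors))
```

For small errors the Huber terms behave like half the squared error, while for large errors they grow only linearly, so a few badly predicted pixels do not dominate the loss.
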



## Generate new images

With the trained model in place, we can finally generate new images.

We sample random values as a starting point, and then sequentially denoise them. Note that this takes a while, since we run the model once per time step, i.e. nearly 4000 times, to generate a single batch of images!

In [None]:
def generate(model, batch_size=32, show_step_interval=200):
    X = tf.random.normal([batch_size, 28, 28, 1])
    for t in range(T - 1, 0, -1):
        print(f"\rt = {t}", end=" ")  # extra code – show progress
        noise = (tf.random.normal if t > 1 else tf.zeros)(tf.shape(X))
        X_noise = model({"X_noisy": X, "time": tf.constant([t] * batch_size)})
        X = (
            1 / alpha[t] ** 0.5
            * (X - beta[t] / (1 - alpha_cumprod[t]) ** 0.5 * X_noise)
            + (1 - alpha[t]) ** 0.5 * noise
        )

        # Show the denoising process
        if (t % show_step_interval == 0) or (t == 1):
            plot_multiple_images(X.numpy(), 8, update=True)
            plt.show()

    return X

tf.random.set_seed(42)  # extra code – ensures reproducibility on the CPU
X_gen = generate(model)  # generated images
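
The update inside the loop can be easier to follow in isolation. Below is a NumPy sketch of a single reverse step with the same formula as in `generate`, applied to a toy array with a perfect noise prediction and no fresh noise added. The function `reverse_step` and the values of `alpha_t` and `alpha_bar_prev` are illustrative assumptions, not taken from the schedule above.

```python
import numpy as np

def reverse_step(x_t, eps_pred, alpha_t, alpha_bar_t, z):
    """One denoising step: remove part of the predicted noise, then add noise z."""
    beta_t = 1 - alpha_t
    mean = (x_t - beta_t / np.sqrt(1 - alpha_bar_t) * eps_pred) / np.sqrt(alpha_t)
    return mean + np.sqrt(beta_t) * z

rng = np.random.default_rng(0)
x_0 = rng.uniform(-1, 1, size=(4, 4))   # toy "image" with values in [-1, 1]
eps = rng.standard_normal((4, 4))       # the noise that was added
alpha_t, alpha_bar_prev = 0.98, 0.50    # made-up schedule values
alpha_bar_t = alpha_t * alpha_bar_prev

# Forward shortcut, then one reverse step with a perfect noise prediction
x_t = np.sqrt(alpha_bar_t) * x_0 + np.sqrt(1 - alpha_bar_t) * eps
x_prev = reverse_step(x_t, eps, alpha_t, alpha_bar_t, z=np.zeros_like(x_t))

# The signal part is restored to the previous step's level ...
residual = x_prev - np.sqrt(alpha_bar_prev) * x_0
# ... and the remaining noise has a smaller coefficient than before
noise_before = np.sqrt(1 - alpha_bar_t)
noise_after = (alpha_t - alpha_bar_t) / np.sqrt(alpha_t * (1 - alpha_bar_t))
print(np.allclose(residual, noise_after * eps))
print(noise_after < noise_before)
```

Each step therefore removes only a small part of the noise, which is why the full generation needs to walk through all the time steps.
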



Plot the final output:

In [None]:
plot_multiple_images(X_gen.numpy(), 8)
plt.show()

Note that the quality of the results depends strongly on how long we train the model, so if you want to improve it, try running for more epochs.