## Diffusion Models

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

from scipy.spatial.transform import Rotation

# Seed NumPy's global RNG so NumPy-based randomness is repeatable.
# NOTE(review): the random draws visible in this file go through tf.random,
# which this call does not seed -- confirm whether tf.random.set_seed is
# also wanted for reproducible runs.
np.random.seed(42)


In [6]:
# Load Fashion MNIST, rescale the pixels, and carve out a validation split

fashion_mnist = tf.keras.datasets.fashion_mnist.load_data()
(X_train_full, y_train_full), (X_test, y_test) = fashion_mnist

# Convert the images to float32 and divide by 255
X_train_full, X_test = (
    images.astype(np.float32) / 255 for images in (X_train_full, X_test)
)

# Hold out the last 5,000 training samples for validation
X_train, y_train = X_train_full[:-5000], y_train_full[:-5000]
X_valid, y_valid = X_train_full[-5000:], y_train_full[-5000:]

In [3]:
# Variance schedule

# This function returns the variance schedule: for every time step it gives
# alpha, the cumulative product of the alphas, and beta = 1 - alpha (the
# variance of the noise applied to the images at that step).

def variance_schedule(T, s=0.008, max_beta=0.999):
    """Build a cosine-shaped noise schedule covering time steps 0..T.

    Args:
        T: Number of diffusion steps; the returned arrays have T + 1 entries.
        s: Small offset added to the normalized time inside the cosine.
        max_beta: Upper bound on beta; alpha is clipped to at least
            ``1 - max_beta``.

    Returns:
        A tuple ``(alpha, alpha_cumprod, beta)`` of float32 arrays of length
        ``T + 1``. Index 0 is the noise-free step (alpha = 1, beta = 0),
        ``beta = 1 - alpha`` and ``alpha_cumprod`` is the running product of
        ``alpha``.
    """
    steps = np.arange(T + 1)
    angle = (steps / T + s) / (1 + s) * np.pi / 2
    f = np.cos(angle) ** 2

    # alpha_t is the ratio of consecutive f values, kept in [1 - max_beta, 1]
    ratios = f[1:] / f[:-1]
    bounded = np.minimum(np.maximum(ratios, 1 - max_beta), 1)

    # Prepend alpha_0 = 1 so the arrays can be indexed directly by time step
    alpha = np.concatenate([np.ones(1), bounded]).astype(np.float32)
    alpha_cumprod = np.cumprod(alpha)
    beta = 1 - alpha

    return alpha, alpha_cumprod, beta

# Number of diffusion steps. The schedule arrays below have T + 1 entries,
# indexed by time step, with index 0 being the noise-free step (alpha = 1).
T = 4000
alpha, alpha_cumprod, beta = variance_schedule(T)

In [5]:
# Adding noise to training data

# This function takes a batch of clean images 
# from the dataset and adds noise to them

def prepare_batch(X):
    """Turn a batch of clean images into a (model inputs, noise target) pair.

    Each image gets its own random time step and is blended with Gaussian
    noise according to the module-level ``alpha_cumprod`` schedule.

    Args:
        X: Batch of images without a channels axis.

    Returns:
        A tuple ``(inputs, noise)`` where ``inputs`` is a dict with keys
        ``"X_noisy"`` (the noised images) and ``"time"`` (the sampled time
        steps), and ``noise`` is the Gaussian noise that was mixed in.
    """
    # Append a channels axis and apply x * 2 - 1, which maps values in
    # [0, 1] to [-1, 1]
    X = tf.cast(X[..., tf.newaxis], tf.float32) * 2 - 1
    batch_shape = tf.shape(X)

    # One time step per image; maxval is exclusive, so t lies in 1..T
    t = tf.random.uniform(
        [batch_shape[0]], minval=1, maxval=T + 1, dtype=tf.int32
    )

    # Look up alpha_cumprod for every sampled step and reshape it to
    # [batch, 1, 1, ...] so it broadcasts over each whole image
    alpha_cm = tf.reshape(
        tf.gather(alpha_cumprod, t),
        [batch_shape[0]] + [1] * (len(batch_shape) - 1),
    )

    noise = tf.random.normal(batch_shape)

    # Blend signal and noise; the model is trained to recover `noise`
    # from the blended image and its time step
    signal_part = alpha_cm ** 0.5 * X
    noise_part = (1 - alpha_cm) ** 0.5 * noise
    inputs = {
        "X_noisy": signal_part + noise_part,
        "time": t,
    }
    return inputs, noise

In [9]:
# A function to subtract the predicted noise from an input

def subtract_noise(X_noisy, time, noise):
    """Undo the noising step given the noise that was (or is predicted to be) added.

    Inverts ``X_noisy = sqrt(a) * X + sqrt(1 - a) * noise``, where ``a`` is
    the module-level ``alpha_cumprod`` looked up at ``time``.

    Args:
        X_noisy: Batch of noisy images.
        time: Time step index for each image in the batch.
        noise: Noise to remove; same shape as ``X_noisy``.

    Returns:
        The de-noised images, same shape as ``X_noisy``.
    """
    noisy_shape = tf.shape(X_noisy)

    # Per-image alpha_cumprod, reshaped to [batch, 1, 1, ...] for broadcasting
    alpha_cm = tf.reshape(
        tf.gather(alpha_cumprod, time),
        [noisy_shape[0]] + [1] * (len(noisy_shape) - 1),
    )

    noise_part = (1 - alpha_cm) ** 0.5 * noise
    return (X_noisy - noise_part) / alpha_cm ** 0.5

In [7]:
# Prepare the dataset

# Applying the noise function to each batch in the dataset
def prepare_dataset(X, batch_size=32, shuffle=False):
    """Wrap an array of images in a tf.data pipeline that yields noised batches.

    Args:
        X: Images to slice into individual dataset elements.
        batch_size: Number of images per batch.
        shuffle: When true, shuffle the elements (buffer of 10,000) before
            batching.

    Returns:
        A ``tf.data.Dataset`` whose elements are the outputs of
        ``prepare_batch`` for each batch, prefetched one batch ahead.
    """
    dataset = tf.data.Dataset.from_tensor_slices(X)
    dataset = dataset.shuffle(buffer_size=10_000) if shuffle else dataset
    batched = dataset.batch(batch_size)
    noised = batched.map(prepare_batch)
    return noised.prefetch(1)


# Only the training pipeline is shuffled; the validation pipeline keeps the
# original sample order. Both use batches of 32 images.
train_set = prepare_dataset(X_train, batch_size=32, shuffle=True) 
valid_set = prepare_dataset(X_valid, batch_size=32)

In [10]:
# Building the diffusion model

# extra code – implements a custom time encoding layer

# Length of the time-encoding vector produced by TimeEncoding for each time
# step; TimeEncoding asserts that this value is even.
embed_size = 64

class TimeEncoding(tf.keras.layers.Layer):
    """Layer mapping integer time steps to fixed sinusoidal encodings.

    A table with one row per time step in ``0..T`` is precomputed at
    construction: even columns hold sines and odd columns hold cosines of the
    time step divided by ``10_000 ** (2k / embed_size)``. Calling the layer
    gathers the rows for the given time steps.
    """

    def __init__(self, T, embed_size, dtype=tf.float32, **kwargs):
        """Precompute the ``(T + 1, embed_size)`` encoding table.

        Args:
            T: Largest time step that can be looked up.
            embed_size: Width of each encoding; must be even.
            dtype: Layer dtype, also used for the stored table.
            **kwargs: Forwarded to ``tf.keras.layers.Layer``.
        """
        super().__init__(dtype=dtype, **kwargs)
        assert embed_size % 2 == 0, "embed_size must be even"

        # Column vector of time steps against a row vector of even dimension
        # indices; broadcasting yields one angle per (time step, sin/cos pair)
        positions = np.arange(T + 1)[:, np.newaxis]
        even_dims = 2 * np.arange(embed_size // 2)[np.newaxis, :]
        angles = positions / 10_000 ** (even_dims / embed_size)

        table = np.empty((T + 1, embed_size))
        table[:, ::2] = np.sin(angles)
        table[:, 1::2] = np.cos(angles)
        self.time_encodings = tf.constant(table.astype(self.dtype))

    def call(self, inputs):
        """Return the encoding rows selected by the time steps in ``inputs``."""
        return tf.gather(self.time_encodings, inputs)
    
def build_diffusion_model():
    """Build the Keras noise-prediction model (a small U-Net-style network).

    The model takes two inputs, named ``"X_noisy"`` (28x28x1 noisy images) and
    ``"time"`` (one int32 time step per image), and outputs a single-channel
    image per sample. The time step is turned into a sinusoidal encoding by
    ``TimeEncoding`` (using the module-level ``T`` and ``embed_size``),
    projected by a ``Dense`` layer, and added to every pixel of the feature
    maps at each stage.

    Returns:
        An uncompiled ``tf.keras.Model`` with inputs ``[X_noisy, time]``.
    """
    X_noisy = tf.keras.layers.Input(shape=[28, 28, 1], name="X_noisy")
    time_input = tf.keras.layers.Input(shape=[], dtype=tf.int32, name="time")
    # Look up the fixed sinusoidal encoding for each sample's time step
    time_enc = TimeEncoding(T, embed_size)(time_input)

    dim = 16
    # Pad 28x28 -> 34x34; the 3x3 conv below uses Conv2D's default padding
    # and brings this to 32x32, which the final crop ([2:-2]) undoes
    Z = tf.keras.layers.ZeroPadding2D((3, 3))(X_noisy)
    Z = tf.keras.layers.Conv2D(dim, 3)(Z)
    Z = tf.keras.layers.BatchNormalization()(Z)
    Z = tf.keras.layers.Activation("relu")(Z)

    time = tf.keras.layers.Dense(dim)(time_enc)  # adapt time encoding
    Z = time[:, tf.newaxis, tf.newaxis, :] + Z  # add time data to every pixel

    skip = Z
    cross_skips = []  # skip connections across the down & up parts of the UNet

    # Downsampling path: two separable-conv blocks, then a stride-2 max pool,
    # with a residual link from the previous stage's output
    for dim in (32, 64, 128):
        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.SeparableConv2D(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.SeparableConv2D(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        # Saved before pooling, so it matches the resolution of the
        # corresponding upsampled stage later on
        cross_skips.append(Z)
        Z = tf.keras.layers.MaxPooling2D(3, strides=2, padding="same")(Z)
        # Strided 1x1 conv so the residual matches Z's size and channel count
        skip_link = tf.keras.layers.Conv2D(dim, 1, strides=2,
                                           padding="same")(skip)
        Z = tf.keras.layers.add([Z, skip_link])

        # Re-inject the time information at this stage's channel width
        time = tf.keras.layers.Dense(dim)(time_enc)
        Z = time[:, tf.newaxis, tf.newaxis, :] + Z
        skip = Z

    # Upsampling path: mirrors the downsampling path with transposed convs
    # and 2x upsampling
    for dim in (64, 32, 16):
        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        Z = tf.keras.layers.Activation("relu")(Z)
        Z = tf.keras.layers.Conv2DTranspose(dim, 3, padding="same")(Z)
        Z = tf.keras.layers.BatchNormalization()(Z)

        Z = tf.keras.layers.UpSampling2D(2)(Z)

        # Residual link: upsample the previous stage's output and project it
        # to `dim` channels with a 1x1 conv
        skip_link = tf.keras.layers.UpSampling2D(2)(skip)
        skip_link = tf.keras.layers.Conv2D(dim, 1, padding="same")(skip_link)
        Z = tf.keras.layers.add([Z, skip_link])

        time = tf.keras.layers.Dense(dim)(time_enc)
        Z = time[:, tf.newaxis, tf.newaxis, :] + Z
        # pop() returns the most recently saved feature map first, pairing
        # each up stage with the down stage of the same resolution
        Z = tf.keras.layers.concatenate([Z, cross_skips.pop()], axis=-1)
        skip = Z

    # Project to one channel, then crop 2 pixels from each border
    outputs = tf.keras.layers.Conv2D(1, 3, padding="same")(Z)[:, 2:-2, 2:-2]
    return tf.keras.Model(inputs=[X_noisy, time_input], outputs=[outputs])

In [None]:
# Build the noise-prediction model and train it for 10 epochs on the noised
# batches; the targets yielded by train_set/valid_set are the noise tensors
# returned by prepare_batch.
model = build_diffusion_model()
# Huber loss between predicted and target noise, Nadam optimizer
model.compile(loss=tf.keras.losses.Huber(), optimizer="nadam")
history = model.fit(train_set, validation_data=valid_set, epochs=10)

In [None]:
# Using the model to generate images

def generate(model, batch_size=32):
    """Generate images by running the reverse diffusion process.

    Starts from pure Gaussian noise and, for every time step from ``T`` down
    to 1, asks ``model`` for the noise it predicts in the current images and
    removes a schedule-dependent share of it, using the module-level
    ``alpha``, ``alpha_cumprod`` and ``beta`` arrays.

    Args:
        model: Callable taking ``{"X_noisy": ..., "time": ...}`` and returning
            the predicted noise, with the same shape as ``X_noisy``.
        batch_size: Number of images to generate.

    Returns:
        A tensor of shape ``[batch_size, 28, 28, 1]`` holding the generated
        images after all ``T`` denoising steps.
    """
    X = tf.random.normal([batch_size, 28, 28, 1])
    for t in range(T, 0, -1):
        # Fresh noise is re-injected at every step except the last one (t == 1)
        noise = (tf.random.normal if t > 1 else tf.zeros)(tf.shape(X))
        X_noise = model({"X_noisy": X, "time": tf.constant([t] * batch_size)})
        X = (
            1 / alpha[t] ** 0.5
            * (X - beta[t] / (1 - alpha_cumprod[t]) ** 0.5 * X_noise)
            + (1 - alpha[t]) ** 0.5 * noise
        )
    # Return only after the whole loop has run. The original returned from
    # inside the loop body, so just the first step (t == T) was ever applied.
    return X

# Sample one batch of images (default batch_size=32) from the trained model
X_gen = generate(model)