# Diffusion Model on 16x16 MNIST with UNet and Cosine Noise Schedule

In [1]:
import tensorflow as tf
from tensorflow.keras import layers, optimizers, models, losses, callbacks, datasets, metrics
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.preprocessing.image import smart_resize

In [2]:
IMG_SIZE = 16
NUM_CHANNELS = 1

In [3]:
# Load dataset
(x_train, _), (_, _) = datasets.mnist.load_data()
x_train = np.expand_dims(x_train, -1)

# Resize and convert to numpy array, rescale from 0 to 1
x_train = np.array([smart_resize(img, (IMG_SIZE, IMG_SIZE)) for img in x_train])/255.0

# Rescale 
x_train = (x_train.astype(np.float32) - 0.5) * 2.0


In [4]:
class Diffusion(models.Model):
    def __init__(self, img_size=16, img_channels=1, timesteps=200, time_emb_dim=64):
        super(Diffusion, self).__init__()
        self.img_size = img_size
        self.img_channels = img_channels
        self.timesteps = timesteps
        self.emb_dim = time_emb_dim
        self.total_loss = metrics.Mean(name='total_loss')
        self.unet = self.make_unet()
        self.beta, self.alpha, self.alpha_hat = self.cosine_beta_schedule()
        
    def conv_block(self, x, filters, kernel_size=3, activation='relu'):
        x = layers.Conv2D(filters, kernel_size, padding='same', activation=activation)(x)
        x = layers.Conv2D(filters, kernel_size, padding='same', activation=activation)(x)
        return x

    def make_unet(self):
        image_input = layers.Input(shape=(self.img_size, self.img_size, self.img_channels))
        t_input = layers.Input(shape=(), dtype=tf.int32)
    
        # Timestep embedding
        t_emb = self.get_timestep_embedding(t_input)
        t_emb = layers.Dense(128, activation='relu')(t_emb)
        t_emb = layers.Dense(self.img_size * self.img_size * 32)(t_emb)
        t_emb = layers.Reshape((self.img_size, self.img_size, 32))(t_emb)
    
        x = layers.Concatenate()([image_input, t_emb])
    
        # Downsampling
        c1 = self.conv_block(x, 32)
        p1 = layers.MaxPooling2D((2, 2))(c1)
    
        c2 = self.conv_block(p1, 64)
        p2 = layers.MaxPooling2D((2, 2))(c2)
    
        # Bottleneck
        bn = self.conv_block(p2, 128)
    
        # Upsampling
        u1 = layers.UpSampling2D((2, 2))(bn)
        u1 = layers.Concatenate()([u1, c2])
        c3 = self.conv_block(u1, 64)
    
        u2 = layers.UpSampling2D((2, 2))(c3)
        u2 = layers.Concatenate()([u2, c1])
        c4 = self.conv_block(u2, 32)
    
        output = layers.Conv2D(self.img_channels, 1, padding='same')(c4)
    
        return models.Model([image_input, t_input], output)

    def cosine_beta_schedule(self):
        steps = self.timesteps + 1
        x = np.linspace(0, self.timesteps, steps)
        alphas_cumprod = np.cos(((x / self.timesteps) + 0.008) / (1 + 0.008) * np.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        beta = np.clip(betas, 0.0001, 0.9999).astype(np.float32)
        alpha = 1.0 - beta
        alpha_hat = np.cumprod(alpha)
        return beta, alpha, alpha_hat

    def get_timestep_embedding(self, timesteps):
        half_dim = self.emb_dim // 2
        emb = np.log(10000) / (half_dim - 1)
        emb = tf.cast(timesteps, tf.float32)[:, None] * tf.exp(-emb * tf.range(half_dim, dtype=tf.float32)[None, :])
        emb = tf.concat([tf.sin(emb), tf.cos(emb)], axis=-1)
        return emb

    def add_noise(self, x, t):
        noise = tf.random.normal(shape=tf.shape(x))
        alpha_hat_t = tf.gather(self.alpha_hat, t)
        sqrt_alpha_hat = tf.sqrt(alpha_hat_t)[:, None, None, None]
        sqrt_one_minus_alpha_hat = tf.sqrt(1.0 - alpha_hat_t)[:, None, None, None]
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * noise, noise

    def compile(self, optimizer, loss):
        super().compile()
        self.optimizer = optimizer
        self.loss_fn = loss

    @property
    def metrics(self):
        return [self.total_loss]
        
    def train_step(self, data):
        x = data[0]

        # Make sure x is 4D
        if len(x.shape) != 4:
            x = tf.expand_dims(x, 0)

        batch_size = tf.shape(x)[0]
        t = tf.random.uniform((batch_size,), minval=0, maxval=self.timesteps, dtype=tf.int32)
        noisy_x, noise = self.add_noise(x, t)

        with tf.GradientTape() as tape:
            pred = self.unet([noisy_x, t], training=True)
            loss = self.loss_fn(noise, pred)

        grads = tape.gradient(loss, self.unet.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.unet.trainable_variables))
        self.total_loss.update_state(loss)

        return {x.name: x.result() for x in self.metrics}
        
    def sample(self, n=10):
        x = tf.random.normal((n, self.img_size, self.img_size, self.img_channels))
        for t in reversed(range(self.timesteps)):
            t_tensor = tf.constant([t] * n, dtype=tf.int32)
            alpha_t = tf.constant(self.alpha[t], dtype=tf.float32)
            alpha_hat_t = tf.constant(self.alpha_hat[t], dtype=tf.float32)
            beta_t = tf.constant(self.beta[t], dtype=tf.float32)
            pred_noise = self.unet([x, t_tensor], training=False)
            coef1 = 1 / tf.sqrt(alpha_t)
            coef2 = (1 - alpha_t) / tf.sqrt(1 - alpha_hat_t)
            mean = coef1 * (x - coef2 * pred_noise)
            if t > 0:
                z = tf.random.normal(shape=x.shape)
                x = mean + tf.sqrt(beta_t) * z
            else:
                x = mean
        return (x + 1.0) / 2.0  # Rescale to [0, 1]

In [7]:
diffusion = Diffusion()
optimizer = optimizers.legacy.Adam(1e-4)
mse = losses.MeanSquaredError()
diffusion.compile(optimizer, mse)


ValueError: A KerasTensor cannot be used as input to a TensorFlow function. A KerasTensor is a symbolic placeholder for a shape and dtype, used when constructing Keras Functional models or Keras Functions. You can only use it as input to a Keras layer or a Keras operation (from the namespaces `keras.layers` and `keras.ops`). You are likely doing something like:

```
x = Input(...)
...
tf_fn(x)  # Invalid.
```

What you should do instead is wrap `tf_fn` in a layer:

```
class MyLayer(Layer):
    def call(self, x):
        return tf_fn(x)

x = MyLayer()(x)
```


In [None]:
batch_size = 128
epochs = 3000

diffusion.fit(x_train, epochs = epochs, batch_size = batch_size)

In [None]:
samples = diffusion.sample()

plt.figure(figsize=(10, 2))
for i in range(samples.shape[0]):
    plt.subplot(1, samples.shape[0], i + 1)
    plt.imshow(samples[i, ..., 0], cmap='gray')
    plt.axis('off')
plt.suptitle('Generated Samples')
plt.tight_layout()
plt.show()
