<a href="https://colab.research.google.com/github/TaiDuc1001/Flower-Diffusion-Model/blob/main/Flower_(64x64)_Generator_with_Diffusion_UNET.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## WandB and dependencies

In [None]:
!pip install wandb --quiet
!wandb login

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras import utils, callbacks, metrics, losses, activations

import wandb
import tarfile
import numpy as np
import os
import matplotlib.pyplot as plt
import math

## Extract

In [None]:
ZIP_PATH = "/content/drive/MyDrive/Datasets/102flowers/102flowers.tgz"
# makedirs(exist_ok=True) is idempotent, so re-running this cell is safe.
os.makedirs("/content/data", exist_ok=True)

with tarfile.open(ZIP_PATH, "r") as tar:
    # The "data" extraction filter (PEP 706; Python 3.12+ and recent 3.8-3.11 security
    # releases) rejects absolute paths, ".." components and special files in the archive.
    # Fall back to plain extraction on interpreters that predate it.
    if hasattr(tarfile, "data_filter"):
        tar.extractall("/content/data", filter="data")
    else:
        tar.extractall("/content/data")

In [None]:
# Folder holding the extracted images (created by the extraction cell above).
ROOT_DIR = os.path.join("/content/data", "jpg")
# File names in ROOT_DIR, in os.listdir order; useful for quick sanity checks.
images_names = [entry for entry in os.listdir(ROOT_DIR)]

## Config hyperparameters

In [None]:
# All tunable hyperparameters in one place; W&B records them with the run.
hyperparameters = dict(
    ARCHITECTURE="UNET",
    IMAGE_SIZE=64,
    BATCH_SIZE=64,
    REPETITION=5,            # passes over the image set per Keras epoch (see train.repeat)
    NOISE_EMBEDDING_SIZE=32,
    LEARNING_RATE=1e-3,
    WEIGHT_DECAY=1e-4,
    EMA=0.99,                # decay of the exponential moving average of network weights
    PLOT_DIFFUSION_STEP=20,  # reverse-diffusion steps for the per-epoch sample images
    EPOCHS=50,
)
wandb.init(project="Flower_Generator_with_Diffusion_UNET", config=hyperparameters)
# From here on `config` is the W&B config object, read with attribute access (config.BATCH_SIZE).
config = wandb.config

## Initialize dataset

In [None]:
# Unlabeled stream of single images resized to IMAGE_SIZE x IMAGE_SIZE.
# batch_size=None yields one image at a time so batching can happen after repeat() later.
train_dataset = utils.image_dataset_from_directory(
    directory=ROOT_DIR,
    labels=None,
    batch_size=None,
    image_size=(config.IMAGE_SIZE,) * 2,
    interpolation="bilinear",
    shuffle=True,
    seed=42,
)

In [None]:
def preprocess_image(image):
    """Cast `image` to float32 and divide by 255, mapping 0-255 pixel values into [0, 1]."""
    as_float = tf.cast(image, "float32")
    return as_float / 255.0

# Scale images in parallel (order is preserved), repeat the image set REPETITION times per
# epoch, use fixed-size batches (drop_remainder), and prefetch so input preparation
# overlaps with the training step.
train = (
    train_dataset
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .repeat(config.REPETITION)
    .batch(config.BATCH_SIZE, drop_remainder=True)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Sanity-check the input pipeline: show the first image of one training batch.
batch = train.take(1).get_single_element().numpy()
fig, ax = plt.subplots()
ax.imshow(batch[0])
ax.axis("off")

## Diffusion

### Diffusion schedules

In [None]:
def linear_diffusion_schedule(diffusion_times):
    """DDPM-style linear beta schedule, returned as (signal_rates, noise_rates).

    beta rises linearly from 1e-4 at t=0 to 2e-2 at t=1, and alpha_bar is the running
    product of (1 - beta) along axis 0 of `diffusion_times`. The function returns
    sqrt(alpha_bar) and sqrt(1 - alpha_bar).

    NOTE(review): because of the cumulative product, the result is only meaningful when
    `diffusion_times` is the full, sorted 1-D grid of timesteps (as in the plotting
    cell). With a batch of random times it would multiply across unrelated samples.
    """
    min_beta, max_beta = 1e-4, 2e-2
    betas = min_beta + diffusion_times * (max_beta - min_beta)
    alpha_bars = tf.math.cumprod(1 - betas)
    return tf.sqrt(alpha_bars), tf.sqrt(1 - alpha_bars)

def cosine_diffusion_schedule(diffusion_times):
    """Cosine schedule: signal = cos(t*pi/2), noise = sin(t*pi/2), elementwise.

    Returns (signal_rates, noise_rates) shaped like `diffusion_times`; their squares sum to 1.
    """
    angles = diffusion_times * math.pi / 2
    return tf.cos(angles), tf.sin(angles)

def offset_cosine_diffusion_schedule(
    diffusion_times, min_signal_rate=2e-2, max_signal_rate=0.95
):
    """Cosine schedule whose signal rate runs from `max_signal_rate` down to `min_signal_rate`.

    The angle is interpolated linearly between acos(max_signal_rate) at t=0 and
    acos(min_signal_rate) at t=1. With the defaults the signal rate never reaches
    exactly 1 or 0. With min_signal_rate=0 and max_signal_rate=1 it is mathematically
    the same as `cosine_diffusion_schedule`.

    Args:
        diffusion_times: tensor of times in [0, 1] (any shape).
        min_signal_rate: signal rate at t=1 (float).
        max_signal_rate: signal rate at t=0 (float).

    Returns:
        (signal_rates, noise_rates), each shaped like `diffusion_times`, with
        signal_rates**2 + noise_rates**2 == 1.
    """
    start_angle = tf.acos(max_signal_rate)
    end_angle = tf.acos(min_signal_rate)
    diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)
    return tf.cos(diffusion_angles), tf.sin(diffusion_angles)

In [None]:
T = 1000
# Times 0/T, 1/T, ..., (T-1)/T, evaluated under each schedule for the comparison plot.
diffusion_times = tf.constant([step / T for step in range(T)])
(linear_signal_rates, linear_noise_rates) = linear_diffusion_schedule(diffusion_times)
(cosine_signal_rates, cosine_noise_rates) = cosine_diffusion_schedule(diffusion_times)
(offset_cosine_signal_rates, offset_cosine_noise_rates) = offset_cosine_diffusion_schedule(
    diffusion_times
)

In [None]:
# signal_rate**2 is alpha_bar_t: the share of the clean image's variance left at time t.
fig, ax = plt.subplots()
for schedule_name, schedule_signal_rates in (
    ("linear", linear_signal_rates),
    ("cosine", cosine_signal_rates),
    ("offset_cosine", offset_cosine_signal_rates),
):
    ax.plot(diffusion_times, schedule_signal_rates**2, linewidth=1.5, label=schedule_name)

ax.set_xlabel("t/T", fontsize=12)
ax.set_ylabel(r"$\bar{\alpha_t}$ (signal)", fontsize=12)
ax.legend()
plt.show()

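=> The cosine schedules remove signal gradually across the whole range of t, whereas the linear schedule drives $\bar{\alpha}_t$ close to 0 well before t = T, spending many late steps on almost pure noise; this is why cosine schedules usually train better. \
Note: the offset cosine curve looks flatter at both ends because its signal rate runs from `max_signal_rate = 0.95` down to `min_signal_rate = 0.02` instead of from 1 to 0. With those bounds set to 1 and 0, the two cosine curves would be identical.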

### Sinusoidal embedding (positional encoding in Transformer)

In [None]:
def sinusoidal_embedding(x, embedding_size=None):
    """Map scalar noise levels to sinusoidal feature vectors (Transformer-style encoding).

    Args:
        x: tensor of noise levels whose last axis has size 1, e.g. shape (batch, 1, 1, 1)
            as fed by the UNet's `noise_variances` input.
        embedding_size: number of output features; defaults to
            config.NOISE_EMBEDDING_SIZE. An odd value yields embedding_size - 1 features.

    Returns:
        Tensor with the leading shape of `x` and a last axis holding the sines of all
        frequencies followed by the cosines.
    """
    if embedding_size is None:
        embedding_size = config.NOISE_EMBEDDING_SIZE
    # embedding_size // 2 frequencies spaced geometrically from 1 to 1000.
    frequencies = tf.exp(tf.linspace(
        tf.math.log(1.0),
        tf.math.log(1000.0),
        embedding_size // 2,
    ))
    angular_speeds = 2.0 * math.pi * frequencies
    # Concatenate on the last axis so inputs of any rank work, not only rank 4.
    return tf.concat(
        [tf.sin(angular_speeds * x), tf.cos(angular_speeds * x)], axis=-1
    )

In [None]:
# Embed 100 evenly spaced noise variances in [0, 1) with a single vectorized call.
noise_levels = np.arange(0, 1, 0.01)
embeddings = sinusoidal_embedding(noise_levels.reshape(-1, 1, 1, 1)).numpy()
# (num_levels, 1, 1, dims) -> (dims, num_levels): one row per embedding dimension.
embedding_array = np.transpose(embeddings[:, 0, 0, :])

fig, ax = plt.subplots()
# Draw the matrix once, with the same colormap the colorbar describes; aspect="auto"
# keeps the 32x100 matrix from being squashed into a thin strip.
heatmap = ax.imshow(
    embedding_array,
    cmap="coolwarm",
    interpolation="nearest",
    origin="lower",
    aspect="auto",
)
ax.set_xticks(np.arange(0, 100, 10), labels=np.round(np.arange(0, 1, 0.1), 1))
ax.set_xlabel("Noise variance", fontsize=8)
ax.set_ylabel("Embedding dimension", fontsize=8)
fig.colorbar(heatmap, ax=ax, orientation="horizontal", label="embedding values")
plt.show()

### Architecture

In [None]:
# Reset Keras' global state before (re)building the UNet in this cell, so re-running
# the cell does not accumulate state from earlier builds.
K.clear_session()
def ResidualBlock(width):
    """Return a callable applying a pre-normalized, two-convolution residual block.

    The callable maps a feature tensor (indexed as rank-4, channels at axis 3) to one
    with `width` channels. The main path is BatchNorm (no learned center/scale), then a
    3x3 conv with swish, then a 3x3 conv; it is added to the shortcut, which is the input
    itself or a 1x1 conv projection when the channel count differs from `width`.
    """
    def apply(x):
        channels_in = x.shape[3]
        # Project the shortcut only when the channel count has to change.
        shortcut = x if channels_in == width else layers.Conv2D(width, kernel_size=1)(x)
        h = layers.BatchNormalization(center=False, scale=False)(x)
        h = layers.Conv2D(
            width, kernel_size=3, padding="same", activation=activations.swish
        )(h)
        h = layers.Conv2D(width, kernel_size=3, padding="same")(h)
        return layers.Add()([h, shortcut])

    return apply


def DownBlock(width, block_depth):
    """Return a callable for one encoder stage of the UNet.

    The callable takes `[x, skips]` and applies `block_depth` ResidualBlocks of `width`
    channels. It appends each block's output to `skips` (mutating the list in place so
    the matching UpBlock can pop it), then halves the spatial size with 2x2 average
    pooling and returns the pooled tensor.
    """
    def apply(inputs):
        features, skip_stack = inputs
        for _ in range(block_depth):
            features = ResidualBlock(width)(features)
            skip_stack.append(features)
        return layers.AveragePooling2D(pool_size=2)(features)

    return apply


def UpBlock(width, block_depth):
    """Return a callable for one decoder stage of the UNet.

    The callable takes `[x, skips]` and doubles the spatial size with bilinear
    upsampling. Then, `block_depth` times, it concatenates the most recently pushed skip
    tensor (popped from `skips`, mutating the list) and applies a ResidualBlock of
    `width` channels. Returns the resulting tensor.
    """
    def apply(inputs):
        features, skip_stack = inputs
        features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        for _ in range(block_depth):
            skip = skip_stack.pop()
            features = layers.Concatenate()([features, skip])
            features = ResidualBlock(width)(features)
        return features

    return apply

# Inputs: the noisy image and its noise variance (DiffusionModel.denoise passes noise_rates**2).
noisy_images = layers.Input(shape=(config.IMAGE_SIZE, config.IMAGE_SIZE, 3))
x = layers.Conv2D(32, kernel_size=1)(noisy_images)

noise_variances = layers.Input(shape=(1, 1, 1))
noise_embedding = layers.Lambda(sinusoidal_embedding)(noise_variances)
# Tile the 1x1 embedding over the full image grid so it can be concatenated as extra channels.
noise_embedding = layers.UpSampling2D(
    size=config.IMAGE_SIZE, interpolation="nearest"
)(noise_embedding)
x = layers.Concatenate()([x, noise_embedding])

# Encoder -> bottleneck -> decoder; `skips` carries encoder features to the decoder.
skips = []
for width in (32, 64, 96):
    x = DownBlock(width, block_depth=2)([x, skips])
for _ in range(2):
    x = ResidualBlock(128)(x)
for width in (96, 64, 32):
    x = UpBlock(width, block_depth=2)([x, skips])

# Zero-initialised output conv, so the untrained network predicts zero noise.
x = layers.Conv2D(3, kernel_size=1, kernel_initializer="zeros")(x)

unet = models.Model([noisy_images, noise_variances], x, name="unet")

### Diffusion Class

In [None]:
class DiffusionModel(models.Model):
    """Denoising diffusion model with a deterministic (DDIM-style) sampler.

    Holds:
      * `network`: the trainable noise predictor, called as network([noisy_images, noise_variances]);
      * `ema_network`: an exponential moving average of `network`, used for sampling;
      * `normalizer`: a Normalization layer that must be adapted to the training images;
      * `diffusion_schedule`: callable mapping times in [0, 1] to (signal_rates, noise_rates).
    """

    def __init__(self, network=None, diffusion_schedule=None):
        """
        Args:
            network: Keras model predicting noise from [noisy_images, noise_variances].
                Defaults to the module-level `unet`.
            diffusion_schedule: callable(diffusion_times) -> (signal_rates, noise_rates).
                Defaults to `offset_cosine_diffusion_schedule`.
        """
        super().__init__()
        self.normalizer = layers.Normalization()
        self.network = unet if network is None else network
        # clone_model builds new, freshly initialised layers; copy the weights so the EMA
        # network starts identical to the network it tracks instead of from random weights.
        self.ema_network = models.clone_model(self.network)
        self.ema_network.set_weights(self.network.get_weights())
        self.diffusion_schedule = (
            offset_cosine_diffusion_schedule if diffusion_schedule is None else diffusion_schedule
        )

    def compile(self, **kwargs):
        """Compile as usual and (re)create the running-mean tracker for the noise loss."""
        super().compile(**kwargs)
        self.noise_loss_tracker = metrics.Mean(name="n_loss")

    @property
    def metrics(self):
        # Listing the tracker here lets Keras reset it at the start of every epoch.
        return [self.noise_loss_tracker]

    def denormalize(self, images):
        """Undo `self.normalizer` (x * std + mean) and clip to the displayable [0, 1] range."""
        images = self.normalizer.mean + self.normalizer.variance**0.5 * images
        images = tf.clip_by_value(images, 0.0, 1.0)
        return images

    def denoise(self, noisy_images, signal_rates, noise_rates, training):
        """Predict the noise in `noisy_images` and recover the clean-image estimate.

        Uses `network` when training and `ema_network` otherwise. The network is
        conditioned on the noise variance (noise_rates**2).

        Returns:
            (pred_noises, pred_images), where
            pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates.
        """
        network = self.network if training else self.ema_network
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        """Run the reverse process from t=1 to t=0 in `diffusion_steps` equal steps.

        Args:
            initial_noise: batch of normalized-space noise images (batch, H, W, 3).
            diffusion_steps: number of reverse steps (int >= 1).

        Returns:
            The clean-image estimate from the final step (still normalized).

        Raises:
            ValueError: if diffusion_steps < 1, since no prediction would be produced.
        """
        if diffusion_steps < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps}")
        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps
        current_images = initial_noise
        for step in range(diffusion_steps):
            diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size
            signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                current_images, signal_rates, noise_rates, training=False
            )
            # Re-noise the clean estimate to the next, smaller time using the *predicted*
            # noise (sqrt(alpha_bar_{t-1}) and sqrt(1 - alpha_bar_{t-1})); no fresh noise is
            # sampled, so the sampler is deterministic.
            next_diffusion_times = diffusion_times - step_size
            next_signal_rates, next_noise_rates = self.diffusion_schedule(next_diffusion_times)
            current_images = next_signal_rates * pred_images + next_noise_rates * pred_noises
        return pred_images

    def generate(self, num_images, diffusion_steps, initial_noise=None):
        """Sample `num_images` images in [0, 1], drawing fresh Gaussian noise unless given."""
        if initial_noise is None:
            initial_noise = tf.random.normal(
                shape=(num_images, config.IMAGE_SIZE, config.IMAGE_SIZE, 3)
            )
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        generated_images = self.denormalize(generated_images)
        return generated_images

    def train_step(self, images):
        """One optimisation step: noise the batch at random times and learn to predict the noise."""
        images = self.normalizer(images, training=True)
        # Derive shapes from the batch itself rather than config.BATCH_SIZE, so a smaller
        # batch (e.g. without drop_remainder) would also work.
        noises = tf.random.normal(tf.shape(images))
        diffusion_times = tf.random.uniform(
            shape=(tf.shape(images)[0], 1, 1, 1), minval=0.0, maxval=1.0
        )
        signal_rates, noise_rates = self.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * images + noise_rates * noises
        with tf.GradientTape() as tape:
            pred_noises, _ = self.denoise(
                noisy_images, signal_rates, noise_rates, training=True
            )
            # Loss functions such as losses.mean_absolute_error reduce only the last axis.
            # Reduce explicitly so the gradient is taken of a scalar mean, not an implicit sum.
            noise_loss = tf.reduce_mean(self.loss(noises, pred_noises))
        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))
        self.noise_loss_tracker.update_state(noise_loss)

        # EMA update over all weights (trainable and non-trainable, e.g. BatchNorm statistics).
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(config.EMA * ema_weight + (1 - config.EMA) * weight)
        return {m.name: m.result() for m in self.metrics}

### Callbacks (sample images and loss logging to W&B)

In [None]:
def display(images, n=8, title="Generated Images"):
    """Log the first `n` images to Weights & Biases as one list under the key `title`.

    NOTE(review): this name shadows IPython's built-in `display` for the rest of the notebook.
    """
    wandb.log({title: [wandb.Image(image) for image in images[:n]]})

class ImageGenerator(callbacks.Callback):
    """Keras callback that samples images from the model after every epoch and logs them to W&B."""

    def __init__(self, num_images):
        # Initialise the base-class state that Keras' callback machinery expects.
        super().__init__()
        self.num_images = num_images

    def on_epoch_end(self, epoch, logs=None):
        generated_images = self.model.generate(
            num_images=self.num_images,
            diffusion_steps=config.PLOT_DIFFUSION_STEP,
        ).numpy()
        # Pass n explicitly; display() otherwise keeps only its default of 8 images.
        display(generated_images, n=self.num_images)

class NoiseLossCallback(callbacks.Callback):
    """Forward the epoch-end `logs` dict to a W&B-like logger object via its `.log` method."""

    def __init__(self, wandb):
        super().__init__()
        # Stored on the instance so any object with a `.log(dict)` method can be injected.
        self.wandb = wandb

    def on_epoch_end(self, epoch, logs=None):
        if logs is None:
            return
        self.wandb.log(logs)

# One instance of each callback; both are passed to ddm.fit() in the training cell.
noise_loss_callback = NoiseLossCallback(wandb=wandb)
image_generator = ImageGenerator(num_images=8)

## Main training

In [None]:
ddm = DiffusionModel()
# The normalizer needs one pass over the images. `train` repeats the data
# config.REPETITION times, so adapting on it would decode every image that many times
# for essentially the same per-channel mean/variance.
ddm.normalizer.adapt(train_dataset.map(preprocess_image).batch(config.BATCH_SIZE))

In [None]:
# Train with AdamW (decoupled weight decay) on the MAE between true and predicted noise.
# `ddm` is already built, so the global K.clear_session() reset is not repeated here.
# compile() creates a fresh optimizer, so its state needs no manual zeroing.
ddm.compile(
    optimizer=optimizers.AdamW(
        learning_rate=config.LEARNING_RATE, weight_decay=config.WEIGHT_DECAY
    ),
    loss=losses.mean_absolute_error,
)
# use_multiprocessing only affects generator/Sequence inputs, not a tf.data.Dataset, so it is omitted.
ddm.fit(
    train,
    epochs=config.EPOCHS,
    callbacks=[image_generator, noise_loss_callback],
)

## Analysis

In [None]:
# Sample with a growing number of reverse steps. Re-seeding before each call makes every
# run start from the same initial noise, so only the step count differs.
for diffusion_steps in (1, 2, 3, 4, 5, 20, 100):
    tf.random.set_seed(42)
    generated_images = ddm.generate(num_images=8, diffusion_steps=diffusion_steps).numpy()
    display(generated_images, title="Changes through diffusion steps")

In [None]:
# HDF5 model saving supports only Functional/Sequential models and raises
# NotImplementedError for a subclassed model such as DiffusionModel. Save the tracked
# variables (network, ema_network, normalizer state) in TensorFlow checkpoint format instead.
# Restore with `DiffusionModel().load_weights("/content/ddm_64x64_50e")`.
ddm.save_weights("/content/ddm_64x64_50e")