In [None]:
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from keras import layers
#from google.colab import drive
#drive.mount('/content/drive')


In [None]:
# ---- data -------------------------------------------------------------------
dataset_repetitions = 5  # NOTE(review): not referenced by the visible pipeline
num_epochs = 500  # train for at least 50 epochs for good results
image_height, image_width = 32, 64

# ---- sampling step counts ---------------------------------------------------
# KID = Kernel Inception Distance, see related section
#kid_image_size = 75
kid_diffusion_steps = 5
plot_diffusion_steps = 20

# ---- noise schedule: range of the signal rate (cos of the diffusion angle) ---
min_signal_rate, max_signal_rate = 0.02, 0.95

# ---- architecture -----------------------------------------------------------
embedding_dims = 32  # width of the sinusoidal noise embedding
emb_size = 32
num_classes = 12

widths = [32, 64, 96]
block_depth = 2
attention_levels = [0, 1, 0]  # one self-attention flag per entry in `widths`

# ---- optimization -----------------------------------------------------------
batch_size = 16
ema = 0.999
learning_rate, weight_decay = 1e-3, 1e-4

In [None]:
def parse_images(example, img_height=image_height, img_width=image_width):
    """Decode one serialized tf.train.Example into a binarised grayscale image.

    The record must contain string features "image" (PNG bytes) and "path"
    (parsed but not used here). The defaults bind the config values that
    exist when this cell runs.

    Returns:
        float32 tensor of shape (img_height, img_width, 1) whose pixels are
        exactly 0.0 or 1.0.
    """
    schema = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "path": tf.io.FixedLenFeature([], tf.string),
    }
    record = tf.io.parse_single_example(example, schema)

    pixels = tf.io.decode_png(record["image"], channels=1)
    pixels = tf.image.resize(tf.cast(pixels, tf.float32) / 255.0, [img_height, img_width])
    # Every non-zero pixel -- including faint ones created by resize
    # interpolation -- becomes 1.0; exact zeros stay 0.0.
    return tf.where(pixels != 0, 1.0, 0.0)

In [None]:
# Input pipeline: decode -> shuffle individual images -> batch -> prefetch.
vehicles_dset = tf.data.TFRecordDataset("vehicles.tfrecord")
img_dset = (
    vehicles_dset
    .map(parse_images, num_parallel_calls=tf.data.AUTOTUNE)
    # Shuffle *before* batching: shuffling after .batch() only reorders whole
    # batches, so every batch would contain the same images in every epoch.
    .shuffle(1000)
    # drop_remainder keeps every batch at exactly `batch_size` images.
    .batch(batch_size, drop_remainder=True)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

In [None]:
# Sanity check of the preprocessing: show the first image of one batch.
iterator = iter(img_dset)
cur_image = next(iterator)
plt.imshow(cur_image[0].numpy(), cmap="gray")


In [None]:
def attention(qkv):
    """Scaled dot-product attention over flattened spatial positions.

    Args:
        qkv: sequence ``(q, k, v)`` where q and k are ``[bs, h*w, d]`` and
            v is ``[bs, h*w, C]`` (as built by ``spatial_attention``).

    Returns:
        ``[bs, h*w, C]`` tensor: for every query position, a softmax-weighted
        sum of the value vectors.
    """
    q, k, v = qkv
    # Scale the logits by 1/sqrt(d): without it their variance grows with the
    # projection width and the softmax saturates (Vaswani et al., 2017).
    scale = tf.math.rsqrt(tf.cast(tf.shape(k)[-1], q.dtype))
    s = tf.matmul(q, k, transpose_b=True) * scale  # [bs, h*w, h*w]
    beta = tf.nn.softmax(s, axis=-1)  # attention map, rows sum to 1
    o = tf.matmul(beta, v)  # [bs, h*w, C]
    return o

def spatial_attention(img):
    """Self-attention block over all spatial positions of a 4-D feature map.

    Args:
        img: Keras tensor of shape ``[bs, h, w, C]`` with static h, w and C.

    Returns:
        Tensor with the same shape as ``img``; the callers in this notebook
        add it to ``img`` as a residual.
    """
    height, width, filters = img.shape[1], img.shape[2], img.shape[3]
    # filters // 8 is 0 for inputs narrower than 8 channels, and a Conv2D with
    # 0 filters is invalid -- keep at least one query/key channel.
    key_dim = max(1, filters // 8)

    img = layers.BatchNormalization()(img)

    # 1x1 projections, then flatten the h x w grid into h*w tokens.
    q = layers.Conv2D(key_dim, kernel_size=1, padding="same")(img)
    k = layers.Conv2D(key_dim, kernel_size=1, padding="same")(img)
    v = layers.Conv2D(filters, kernel_size=1, padding="same")(img)
    k = layers.Reshape((height * width, key_dim))(k)
    q = layers.Reshape((height * width, key_dim))(q)
    v = layers.Reshape((height * width, filters))(v)

    out = layers.Lambda(attention)([q, k, v])
    out = layers.Reshape((height, width, filters))(out)

    # Output projection + normalization.
    out = layers.Conv2D(filters, kernel_size=1, padding="same")(out)
    out = layers.BatchNormalization()(out)

    return out

def sinusoidal_embedding(x):
    """Encode noise variances as sin/cos features at log-spaced frequencies.

    Uses ``embedding_dims // 2`` frequencies spaced geometrically between 1 and
    1000. Sines and cosines are concatenated along axis 3, so ``x`` must be
    4-D -- here the ``(bs, 1, 1, 1)`` noise-variance input of the U-Net.
    """
    log_low = tf.math.log(1.0)
    log_high = tf.math.log(1000.0)
    freqs = tf.exp(tf.linspace(log_low, log_high, embedding_dims // 2))
    phases = (2.0 * math.pi * freqs) * x
    return tf.concat([tf.sin(phases), tf.cos(phases)], axis=3)


def ResidualBlock(width):
    """Return a callable that applies a pre-normalized two-conv residual unit.

    The shortcut is the input itself when it already has ``width`` channels
    (axis 3), otherwise a 1x1 convolution projecting it to ``width`` channels.
    """

    def apply(x):
        shortcut = x if x.shape[3] == width else layers.Conv2D(width, kernel_size=1)(x)
        h = layers.BatchNormalization(center=False, scale=False)(x)
        h = layers.Conv2D(
            width, kernel_size=3, padding="same", activation=keras.activations.swish
        )(h)
        h = layers.Conv2D(width, kernel_size=3, padding="same")(h)
        return layers.Add()([h, shortcut])

    return apply


def DownBlock(width, block_depth, use_self_attention=False):
    """Return an encoder stage: ``block_depth`` residual units, then 2x average pooling.

    The returned callable takes ``[x, skips]``. It appends the output of each
    residual unit (after the optional self-attention residual) to ``skips`` in
    place, and returns the pooled tensor.
    """

    def apply(inputs):
        features, skip_list = inputs
        for _ in range(block_depth):
            features = ResidualBlock(width)(features)
            if use_self_attention:
                attended = spatial_attention(features)
                features = layers.Add()([features, attended])
            skip_list.append(features)
        return layers.AveragePooling2D(pool_size=2)(features)

    return apply


def UpBlock(width, block_depth, use_self_attention=False):
    """Return a decoder stage: 2x bilinear upsampling, then ``block_depth`` residual units.

    The returned callable takes ``[x, skips]``. Before each residual unit it
    pops the most recent tensor from ``skips`` (in place) and concatenates it
    channel-wise.
    """

    def apply(inputs):
        features, skip_list = inputs
        features = layers.UpSampling2D(size=2, interpolation="bilinear")(features)
        for _ in range(block_depth):
            features = layers.Concatenate()([features, skip_list.pop()])
            features = ResidualBlock(width)(features)
            if use_self_attention:
                attended = spatial_attention(features)
                features = layers.Add()([features, attended])
        return features

    return apply


def get_network(image_height, image_width, widths, block_depth, num_classes, emb_size, attention_levels, precomputed_embedding=False):
    """Build the noise-predicting residual U-Net.

    Args:
        image_height, image_width: spatial size of the single-channel input.
            Must be divisible by ``2 ** (len(widths) - 1)`` so skip tensors line
            up after pooling/upsampling.
        widths: channel width per level; the last entry is the bottleneck.
        block_depth: residual units per level.
        num_classes, emb_size, precomputed_embedding: accepted but not used by
            this implementation (kept for call compatibility).
        attention_levels: one truthy/falsy flag per entry in ``widths``
            enabling self-attention at that level.

    Returns:
        keras.Model mapping ``[noisy_images, noise_variances]`` to predicted
        noise of shape ``(image_height, image_width, 1)``.

    Raises:
        ValueError: if ``attention_levels`` is shorter than ``widths`` or the
            image size is not divisible by the total pooling factor.
    """
    if len(attention_levels) < len(widths):
        raise ValueError(
            f"attention_levels needs one flag per width: got {len(attention_levels)} for {len(widths)} widths"
        )
    pooling_factor = 2 ** (len(widths) - 1)
    if image_height % pooling_factor or image_width % pooling_factor:
        raise ValueError(
            f"image size ({image_height}, {image_width}) must be divisible by {pooling_factor}"
        )

    noisy_images = keras.Input(shape=(image_height, image_width, 1))
    noise_variances = keras.Input(shape=(1, 1, 1))

    # Embed the scalar noise variance and tile it over every pixel so it can be
    # concatenated channel-wise with the first feature map.
    e = layers.Lambda(sinusoidal_embedding)(noise_variances)
    e = layers.UpSampling2D(size=(image_height, image_width), interpolation="nearest")(e)

    x = layers.Conv2D(widths[0], kernel_size=1)(noisy_images)
    x = layers.Concatenate()([x, e])

    skips = []

    # Encoder: level i uses widths[i] / attention_levels[i]; DownBlock pushes skips.
    for level, width in enumerate(widths[:-1]):
        x = DownBlock(width, block_depth, bool(attention_levels[level]))([x, skips])

    # Bottleneck at the lowest resolution.
    bottleneck_level = len(widths) - 1
    for _ in range(block_depth):
        x = ResidualBlock(widths[-1])(x)
        if bool(attention_levels[bottleneck_level]):
            o = spatial_attention(x)
            x = layers.Add()([x, o])

    # Decoder mirrors the encoder (deepest level first) and pops the skips.
    for level in reversed(range(len(widths) - 1)):
        x = UpBlock(widths[level], block_depth, bool(attention_levels[level]))([x, skips])

    # Zero-initialised head: the untrained network predicts zero noise.
    x = layers.Conv2D(1, kernel_size=1, kernel_initializer="zeros", activation="linear")(x)

    return keras.Model([noisy_images, noise_variances], x, name="residual_unet")


In [None]:

class DiffusionModel(keras.Model):
    """Diffusion model that trains a U-Net to predict the noise in noisy images.

    ``self.network`` is updated by the optimizer; ``self.ema_network`` is an
    exponential-moving-average copy of its weights (decay ``ema``) used for
    evaluation and sampling. Images are standardized by ``self.normalizer``,
    which is expected to be ``adapt``-ed on the training data before training.
    """

    def __init__(self, image_height, image_width, widths, block_depth, attention_levels):
        super().__init__()

        # Stored so sampling uses this model's resolution instead of globals.
        self.image_height = image_height
        self.image_width = image_width
        self.attention_levels = attention_levels

        self.normalizer = layers.Normalization()
        self.network = get_network(image_height, image_width, widths, block_depth, num_classes, emb_size, attention_levels, precomputed_embedding=False)
        self.ema_network = keras.models.clone_model(self.network)

    def compile(self, **kwargs):
        """Compile as usual and create the loss trackers reported by fit/evaluate."""
        super().compile(**kwargs)

        self.noise_loss_tracker = keras.metrics.Mean(name="n_loss")
        self.image_loss_tracker = keras.metrics.Mean(name="i_loss")
        # KID tracking is disabled in this notebook.

    @property
    def metrics(self):
        # Listing the trackers here lets Keras reset them at each epoch.
        return [self.noise_loss_tracker, self.image_loss_tracker]

    def denormalize(self, images):
        """Undo ``self.normalizer`` and clip pixel values to [0, 1]."""
        images = self.normalizer.mean + images * self.normalizer.variance**0.5
        return tf.clip_by_value(images, 0.0, 1.0)

    def diffusion_schedule(self, diffusion_times):
        """Map diffusion times in [0, 1] to ``(noise_rates, signal_rates)``.

        Angles are interpolated linearly between acos(max_signal_rate) and
        acos(min_signal_rate); the rates are their sine and cosine, so
        noise_rates**2 + signal_rates**2 == 1.
        """
        start_angle = tf.acos(max_signal_rate)
        end_angle = tf.acos(min_signal_rate)

        diffusion_angles = start_angle + diffusion_times * (end_angle - start_angle)

        signal_rates = tf.cos(diffusion_angles)
        noise_rates = tf.sin(diffusion_angles)

        return noise_rates, signal_rates

    def denoise(self, noisy_images, noise_rates, signal_rates, training):
        """Predict the noise and recover the implied clean images.

        The trained network is used when ``training`` is true, the EMA network
        otherwise. The network is conditioned on the noise variance
        (``noise_rates**2``).
        """
        network = self.network if training else self.ema_network

        # No tensors are stored on `self` here: assigning graph tensors to
        # attributes such as Model.inputs/outputs inside tf.function is unsafe.
        pred_noises = network([noisy_images, noise_rates**2], training=training)
        pred_images = (noisy_images - noise_rates * pred_noises) / signal_rates

        return pred_noises, pred_images

    def reverse_diffusion(self, initial_noise, diffusion_steps):
        """Run deterministic reverse diffusion (sampling) from pure noise.

        Raises:
            ValueError: if ``diffusion_steps`` is smaller than 1.
        """
        if diffusion_steps < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps}")

        num_images = initial_noise.shape[0]
        step_size = 1.0 / diffusion_steps

        # At the first sampling step the "noisy image" is pure noise, but its
        # signal rate is assumed to be nonzero (min_signal_rate).
        next_noisy_images = initial_noise
        for step in range(diffusion_steps):
            noisy_images = next_noisy_images

            # Separate the current noisy image into its components
            # (EMA network, eval mode).
            diffusion_times = tf.ones((num_images, 1, 1, 1)) - step * step_size
            noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=False
            )

            # Remix the predicted components using the next signal/noise rates;
            # the result is the input to the next step.
            next_diffusion_times = diffusion_times - step_size
            next_noise_rates, next_signal_rates = self.diffusion_schedule(
                next_diffusion_times
            )
            next_noisy_images = (
                next_signal_rates * pred_images + next_noise_rates * pred_noises
            )

        return pred_images

    def generate(self, num_images, diffusion_steps):
        """Sample ``num_images`` images in [0, 1] using ``diffusion_steps`` steps."""
        initial_noise = tf.random.normal(
            shape=(num_images, self.image_height, self.image_width, 1)
        )
        generated_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        return self.denormalize(generated_images)

    def _noisy_batch(self, images):
        # Draw noise and diffusion times matching the *actual* batch shape
        # (not the global batch_size), then mix them with the images.
        noises = tf.random.normal(shape=tf.shape(images))
        diffusion_times = tf.random.uniform(
            shape=(tf.shape(images)[0], 1, 1, 1), minval=0.0, maxval=1.0
        )
        noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * images + noise_rates * noises
        return noises, noisy_images, noise_rates, signal_rates

    def train_step(self, images):
        """One optimization step on the noise loss, followed by the EMA update."""
        # Normalize images to unit standard deviation, like the noises.
        images = self.normalizer(images, training=True)
        noises, noisy_images, noise_rates, signal_rates = self._noisy_batch(images)

        with tf.GradientTape() as tape:
            # Train the network to separate noisy images into their components.
            pred_noises, pred_images = self.denoise(
                noisy_images, noise_rates, signal_rates, training=True
            )

            noise_loss = self.loss(noises, pred_noises)  # used for training
            image_loss = self.loss(images, pred_images)  # only used as metric

        gradients = tape.gradient(noise_loss, self.network.trainable_weights)
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        self.noise_loss_tracker.update_state(noise_loss)
        self.image_loss_tracker.update_state(image_loss)

        # Track the exponential moving averages of all weights.
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(ema * ema_weight + (1 - ema) * weight)

        # Report every tracked metric. "i_loss" must appear in the epoch logs
        # for callbacks that monitor it (e.g. ModelCheckpoint); slicing off the
        # last metric only makes sense when a KID tracker is last in the list.
        return {m.name: m.result() for m in self.metrics}

    def test_step(self, images):
        """Evaluate noise and image losses with the EMA network."""
        images = self.normalizer(images, training=False)
        noises, noisy_images, noise_rates, signal_rates = self._noisy_batch(images)

        pred_noises, pred_images = self.denoise(
            noisy_images, noise_rates, signal_rates, training=False
        )

        noise_loss = self.loss(noises, pred_noises)
        image_loss = self.loss(images, pred_images)

        self.image_loss_tracker.update_state(image_loss)
        self.noise_loss_tracker.update_state(noise_loss)

        # KID is disabled, so no images are sampled here: sampling costs
        # kid_diffusion_steps network calls per batch and the result was unused.
        return {m.name: m.result() for m in self.metrics}

    def plot_images(self, epoch=None, logs=None, num_rows=3, num_cols=6):
        """Plot a grid of generated images (usable as an on_epoch_end callback)."""
        generated_images = self.generate(
            num_images=num_rows * num_cols,
            diffusion_steps=plot_diffusion_steps,
        )

        plt.figure(figsize=(num_cols * 2.0, num_rows * 2.0))
        for row in range(num_rows):
            for col in range(num_cols):
                index = row * num_cols + col
                plt.subplot(num_rows, num_cols, index + 1)
                # Pass a 2-D (H, W) slice; not every matplotlib version accepts (H, W, 1).
                plt.imshow(generated_images[index, :, :, 0], cmap='gray')
                plt.axis("off")
        plt.tight_layout()
        plt.show()
        plt.close()


In [None]:
import tensorflow_addons as tfa
# create and compile the model
model = DiffusionModel(image_height, image_width, widths, block_depth, attention_levels)

# AdamW comes from tensorflow_addons here (needed below TensorFlow 2.9; newer
# versions also provide keras.optimizers.experimental.AdamW).
# Pixelwise mean absolute error is used as the loss.
model.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)

# Save the weights with the lowest *training* image loss, logged as "i_loss"
# (there is no validation set or KID metric in this notebook). ModelCheckpoint
# skips saving with a warning whenever "i_loss" is missing from the epoch logs.
checkpoint_path = "checkpoints/diffusion_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="i_loss",
    mode="min",
    save_best_only=True,
)

# calculate mean and variance of training dataset for normalization
model.normalizer.adapt(img_dset)

# run training and plot generated images after every epoch
model.fit(
    img_dset,
    epochs=num_epochs,
    callbacks=[
        keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images),
        checkpoint_callback,
    ],
)

In [None]:
# Continue training the same model for another `num_epochs` epochs, with the
# same per-epoch sample plots and best-weights checkpointing.
model.fit(
    img_dset,
    epochs=num_epochs,
    callbacks=[keras.callbacks.LambdaCallback(on_epoch_end=model.plot_images), checkpoint_callback],
)

In [None]:
# Restore the best checkpointed weights, then show a grid of generated samples.
model.load_weights(filepath=checkpoint_path)
model.plot_images()