In [7]:
import math
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds

from tensorflow import keras
from keras import layers
import numpy as np
import diffusion
from diffusion import DiffusionModel
import unet
from vqvae import VQVAETrainer
from train_vqvae import create_image_dataset

In [8]:
# ---- data -------------------------------------------------------------------
dataset_repetitions = 5
num_epochs = 2000  # train for at least 50 epochs for good results

# Pixel-space image size (height, width) and the latent grid size (height, width).
# NOTE(review): the latent size presumably matches the VQ-VAE encoder output — confirm.
image_height, image_width = 128, 256
latent_height, latent_width = 32, 64

# ---- diffusion / sampling ---------------------------------------------------
# Step counts consumed by the diffusion module (not visible here).
diffusion_steps = 5
plot_diffusion_steps = 20
min_signal_rate, max_signal_rate = 0.02, 0.95

# ---- architecture -----------------------------------------------------------
# Two names for the same width; both are kept because downstream code may read
# either. NOTE(review): confirm whether both are actually needed.
embedding_dims = emb_size = 32

# `widths` and `attention_levels` have one entry per level (3 levels here).
widths = [32, 64, 96]
attention_levels = [0, 1, 0]
block_depth = 2
latent_dim = 16

# ---- optimization -----------------------------------------------------------
batch_size = 64
ema = 0.999
learning_rate, weight_decay = 0.001, 0.0001

In [9]:
# Build the training dataset and compute the variance over every image in it;
# `data_variance` is handed to the VQ-VAE trainer when the model is built.
# Use the config-cell `batch_size` instead of repeating the literal 64.
img_dset = create_image_dataset("combined_rgb.tfrecord", batch_size=batch_size, buffer_size=640)

# NOTE: this materializes the whole dataset in memory at once — fine for a
# small dataset, but revisit (e.g. a streaming variance) if it grows.
scaled_images = np.concatenate(list(img_dset.as_numpy_iterator()), axis=0)
data_variance = np.var(scaled_images)

In [10]:
class PlotImagesCallback(keras.callbacks.Callback):
    """Hand a few dataset images to the model's plotting routine after each epoch.

    At the end of every epoch one batch is drawn from ``img_dset``, a random
    subset of distinct images is selected from it, and the subset is passed to
    ``model.plot_images`` together with the epoch number and grid layout.

    Args:
        model: object exposing ``plot_images(batch_images=..., epoch=...,
            num_rows=..., num_cols=...)``.
        img_dset: iterable (e.g. a batched ``tf.data.Dataset``) whose items are
            image tensors with the batch on axis 0.
        num_rows: grid rows forwarded to ``model.plot_images``.
        num_cols: grid columns forwarded to ``model.plot_images``.
        num_samples: number of distinct images to pick per epoch
            (default 2, the previously hard-coded value).
    """

    def __init__(self, model, img_dset, num_rows, num_cols, num_samples=2):
        super().__init__()
        self.model = model
        self.img_dset = img_dset
        self.num_rows = num_rows
        self.num_cols = num_cols
        self.num_samples = num_samples

    def on_epoch_end(self, epoch, logs=None):
        """Select random images from the first batch and plot them for ``epoch``."""
        # A fresh iterator yields the first batch; whether it differs between
        # epochs depends on how `img_dset` was built (shuffling) — not visible here.
        batch_images = next(iter(self.img_dset))

        # Pick *distinct* indices: np.random.choice defaults to replace=True,
        # which could plot the same image twice. Clamp to the batch size so a
        # batch smaller than `num_samples` does not raise.
        batch_len = batch_images.shape[0]
        k = min(self.num_samples, batch_len)
        random_indices = np.random.choice(batch_len, size=k, replace=False)
        random_images = tf.gather(batch_images, random_indices)

        # Plot images before forward diffusion and after reverse diffusion.
        self.model.plot_images(
            batch_images=random_images,
            epoch=epoch,
            num_rows=self.num_rows,
            num_cols=self.num_cols,
        )

In [11]:
# Load the pretrained VQ-VAE that the diffusion model is built on.
# Use the config-cell `latent_dim` rather than repeating the literal 16 so the
# two values cannot drift apart.
vqvae = VQVAETrainer(data_variance, latent_dim=latent_dim, num_embeddings=128)
vqvae.vqvae.load_weights("vqvae_weights.h5")

# Create the diffusion model.
model = DiffusionModel(widths, block_depth, attention_levels, vqvae)

plot_images_callback = PlotImagesCallback(model, img_dset, num_rows=2, num_cols=2)

# Compile with AdamW and a mean-absolute-error loss.
# tensorflow_addons supplies AdamW for TensorFlow < 2.9; newer versions also ship
# keras.optimizers.experimental.AdamW.
# NOTE(review): confirm both implementations apply `weight_decay` the same way
# before swapping one for the other.
import tensorflow_addons as tfa

model.compile(
    optimizer=tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    loss=keras.losses.mean_absolute_error,
)

# Keep the weights with the lowest `n_loss`. No validation data is passed to
# `fit`, so this tracks the training value of that metric.
checkpoint_path = "checkpoints/diffusion_model"
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
    monitor="n_loss",
    mode="min",
    save_best_only=True,
)

# Train, plotting sample images after every epoch; keep the History for inspection.
history = model.fit(
    img_dset,
    epochs=num_epochs,
    callbacks=[plot_images_callback, checkpoint_callback],
)

1
2
2
1
Epoch 1/2000
     19/Unknown - 11s 325ms/step - n_loss: 0.73842
Epoch 2/2000
Epoch 3/2000
Epoch 4/2000
Epoch 5/2000
Epoch 6/2000
Epoch 7/2000
Epoch 8/2000
Epoch 9/2000
Epoch 10/2000
Epoch 11/2000
Epoch 12/2000
Epoch 13/2000
Epoch 14/2000
Epoch 15/2000
Epoch 16/2000
Epoch 17/2000
Epoch 18/2000
Epoch 19/2000
Epoch 20/2000
Epoch 21/2000
Epoch 22/2000
Epoch 23/2000
Epoch 24/2000
Epoch 25/2000
Epoch 26/2000
Epoch 27/2000
Epoch 28/2000
Epoch 29/2000
Epoch 30/2000
Epoch 31/2000
Epoch 32/2000
Epoch 33/2000
Epoch 34/2000
Epoch 35/2000
Epoch 36/2000
Epoch 37/2000
Epoch 38/2000
Epoch 39/2000
Epoch 40/2000
Epoch 41/2000
Epoch 42/2000
Epoch 43/2000
Epoch 44/2000
Epoch 45/2000
Epoch 46/2000
Epoch 47/2000
Epoch 48/2000
Epoch 49/2000
Epoch 50/2000
Epoch 51/2000
Epoch 52/2000
Epoch 53/2000

KeyboardInterrupt: 