In [1]:
from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
# import PIL
import tensorflow as tf
import tensorflow_probability as tfp
# import tensorflow.contrib as tf_contrib
import time

# Make 'channels_last' the default data format for Keras layers that do not
# specify `data_format` themselves.
tf.keras.backend.set_image_data_format('channels_last')

In [2]:
### A simplified attention block
def hw_flatten(x):
    """Collapse every axis between the first and the last axis of `x` into one.

    The sizes of the first and last axes are read from the static shape of
    `x`; `tf.reshape` infers the size of the merged middle axis.
    """
    first_dim = x.shape[0]
    last_dim = x.shape[-1]
    return tf.reshape(x, shape=(first_dim, -1, last_dim))

class _ResidualScale(tf.keras.layers.Layer):
    """Return `gamma * branch + skip` for a trainable scalar `gamma`.

    `gamma` is created through `add_weight`, so it is owned by this layer and
    therefore shows up in the enclosing model's `trainable_weights`. It is
    initialised to zero, so the layer initially passes `skip` through
    (broadcast against `branch`).
    """

    def build(self, input_shape):
        self.gamma = self.add_weight(
            name="gamma", shape=(1,), initializer="zeros", trainable=True
        )
        super().build(input_shape)

    def call(self, inputs):
        branch, skip = inputs
        return self.gamma * branch + skip


def attention(x, channels=265):
    """Simplified attention block with a learned, zero-initialised gate.

    Three single-filter 1x1 convolutions of `x` produce `f`, `g` and `h`.
    The attention map is `softmax(g @ f^T)` (softmax over the last axis) and
    is applied to `h`; the result is projected to `channels` filters by
    another 1x1 convolution and combined with the input as `gamma * o + x`.

    Args:
        x: Input tensor (or Keras symbolic tensor).
        channels: Number of filters of the output projection.
            NOTE(review): the default 265 looks like it may have been meant
            as 256 - kept unchanged so existing callers behave the same.

    Returns:
        `gamma * o + x`, where the addition relies on broadcasting when the
        last dimension of `x` differs from `channels`.

    NOTE(review): the matmuls act on the last two axes of the convolution
    outputs as they are; the spatial axes are not flattened with
    `hw_flatten` first - confirm this is the intended attention layout.
    """

    def pointwise_conv(filters):
        # 1x1 convolution with bias; only the number of filters varies.
        return tf.keras.layers.Conv2D(
            filters=filters, kernel_size=1, strides=1, padding='same', use_bias=True
        )

    f = pointwise_conv(1)(x)
    g = pointwise_conv(1)(x)
    h = pointwise_conv(1)(x)

    s = tf.matmul(g, f, transpose_b=True)
    beta = tf.nn.softmax(s)  # attention map

    o = tf.matmul(beta, h)
    o = pointwise_conv(channels)(o)

    # The gate lives inside a Keras layer so that gamma is actually trained.
    return _ResidualScale()([o, x])

class Sampling(tf.keras.layers.Layer):
    """Draw a latent vector z from (z_mean, z_log_var) via the reparameterisation trick."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Noise has the same (dynamic) 2-D shape as z_mean.
        n_rows = tf.shape(z_mean)[0]
        n_cols = tf.shape(z_mean)[1]
        noise = tf.keras.backend.random_normal(shape=(n_rows, n_cols))
        # exp(0.5 * log_var) turns the log-variance into a standard deviation.
        std_dev = tf.exp(0.5 * z_log_var)
        return z_mean + std_dev * noise


In [3]:
def Conv3D_Block(inp_shape):
    """Build the spatiotemporal feature extractor used for one encoder input.

    The model stacks four `ConvLSTM2D` layers (4, 8, 16 and 32 filters, each
    with stride 2 along both spatial axes and `return_sequences=True`),
    followed by a 1x1x1 `Conv3D`, `LeakyReLU`, 3-D max pooling, the
    `attention` block with a residual `Add`, and three `Dense` layers.

    Args:
        inp_shape: Shape of one input sample (without the batch axis), passed
            to `tf.keras.layers.Input`.

    Returns:
        An uncompiled `tf.keras.Model` mapping the input to the extracted
        features.
    """
    inp = tf.keras.layers.Input(shape=inp_shape)

    # Chain the four ConvLSTM2D layers: each one consumes the output of the
    # previous one, so all four take part in the model.
    # Batch normalization is deliberately not used between the layers.
    x = inp
    for n_filters in (4, 8, 16, 32):
        x = tf.keras.layers.ConvLSTM2D(
            filters=n_filters,
            kernel_size=(3),
            strides=(2, 2),
            padding="same",
            return_sequences=True,
            activation="relu",
            data_format='channels_last',  # same as the notebook-wide default
        )(x)

    # Reduce to a single channel and pool before the attention block.
    res1 = tf.keras.layers.Conv3D(filters=1, kernel_size=(1, 1, 1), padding="same")(x)
    res1 = tf.keras.layers.LeakyReLU(alpha=0.05)(res1)
    res1 = tf.keras.layers.MaxPooling3D(pool_size=2)(res1)

    # attention
    x = attention(res1, channels=265)
    # residual
    x = tf.keras.layers.Add()([res1, x])

    for units in (16, 32, 64):
        x = tf.keras.layers.Dense(units, activation='relu')(x)
    x = tf.keras.layers.LeakyReLU(alpha=0.05)(x)

    return tf.keras.Model(inputs=inp, outputs=x)

In [4]:
def encoder_model(latent_dim, inp_shape):
    """Build the two-input VAE encoder.

    Adapted from Laurence Moroney's Coursera course on VAEs:
    https://www.coursera.org/lecture/generative-deep-learning-with-tensorflow/sampling-layer-and-encoder-G2mJr

    Each of the two inputs goes through its own `Conv3D_Block`; the branch
    outputs are concatenated on the last axis, passed through three `Dense`
    layers, flattened and projected to `z_mean` and `z_log_var`, from which
    `Sampling` draws `z`.

    Args:
        latent_dim: Size of the latent vector.
        inp_shape: Sequence of two per-sample input shapes, one per branch.

    Returns:
        A `tf.keras.Model` named "encoder" with outputs
        `[z_mean, z_log_var, z]`.
    """
    demand_branch = Conv3D_Block(inp_shape[0])
    ex_f_branch = Conv3D_Block(inp_shape[1])

    features = tf.keras.layers.concatenate(
        [demand_branch.output, ex_f_branch.output], axis=-1
    )
    for units in (16, 32, 64):
        features = tf.keras.layers.Dense(units, activation='relu')(features)
    features = tf.keras.layers.Flatten()(features)

    z_mean = tf.keras.layers.Dense(latent_dim, name="z_mean")(features)
    z_log_var = tf.keras.layers.Dense(latent_dim, name="z_log_var")(features)
    z = Sampling()([z_mean, z_log_var])

    return tf.keras.Model(
        inputs=[demand_branch.input, ex_f_branch.input],
        outputs=[z_mean, z_log_var, z],
        name="encoder",
    )

In [5]:
def generator_model(latent_dim):
    """Build the decoder/generator that maps a latent vector to one frame.

    A `Dense` layer expands the latent vector to a 34x34x32 feature map,
    three stride-2 `Conv2DTranspose` layers upsample it, a 1x1 `Conv2D`
    reduces it to one channel, and the result is reshaped to
    (1, 272, 272, 1) before a final `Conv3DTranspose` and `LeakyReLU`.

    Args:
        latent_dim: Size of the latent input vector.

    Returns:
        A `tf.keras.Model` named "generator".
    """
    latent_inputs = tf.keras.Input(shape=(latent_dim,))

    feature_map = tf.keras.layers.Dense(34*34*32, activation='relu')(latent_inputs)
    feature_map = tf.keras.layers.Reshape(target_shape=(34, 34, 32))(feature_map)

    # Three upsampling stages with a decreasing number of filters.
    for n_filters in (32, 16, 8):
        feature_map = tf.keras.layers.Conv2DTranspose(
            filters=n_filters, kernel_size=(3), strides=2, padding="same", activation="relu"
        )(feature_map)

    frame = tf.keras.layers.Conv2D(filters=1, kernel_size=1, strides=1)(feature_map)
    # Add a leading length-1 axis so the 3-D transposed convolution can be applied.
    frame = tf.keras.layers.Reshape(target_shape=(1, 272, 272, 1))(frame)
    frame = tf.keras.layers.Conv3DTranspose(
        filters=1, kernel_size=(3), padding="same", activation="relu"
    )(frame)
    frame = tf.keras.layers.LeakyReLU(alpha=0.05)(frame)

    return tf.keras.Model(latent_inputs, outputs=frame, name="generator")

In [6]:
class CVAE(tf.keras.Model):
    """Variational autoencoder that wraps an encoder and a generator.

    The custom `train_step` optimises the sum of a binary cross-entropy
    reconstruction loss and the KL divergence of the approximate posterior,
    and reports running means of the total and of both terms.
    """

    def __init__(self, encoder, generator, **kwargs):
        """Store the sub-models and create the loss trackers.

        Args:
            encoder: Model called with a list of two inputs that returns
                `(z_mean, z_log_variance, z)`.
            generator: Model mapping `z` to a reconstruction.
            **kwargs: Forwarded to `tf.keras.Model`.
        """
        super(CVAE, self).__init__(**kwargs)
        self.encoder = encoder
        self.generator = generator
        self.total_loss_tracker = tf.keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = tf.keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = tf.keras.metrics.Mean(name="kl_loss")

    @property
    def metrics(self):
        """Trackers that Keras resets at the start of every epoch."""
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def train_step(self, data):
        """Run one optimisation step.

        Args:
            data: Indexable batch; `data[0]` and `data[1]` are the two encoder
                inputs and `data[0]` is also the reconstruction target.

        Returns:
            Dict with the running means under the keys "loss",
            "reconstruction_loss" and "kl_loss".
        """
        with tf.GradientTape() as tape:
            z_mean, z_log_variance, z = self.encoder([data[0], data[1]])
            reconstruction = self.generator(z)

            # binary_crossentropy averages over the last axis only, so the
            # result still has every other non-batch axis. Sum over all of
            # them (not just axes 1 and 2) to get one loss value per sample,
            # whatever the rank of the input.
            # NOTE(review): target and reconstruction are combined via
            # broadcasting if their shapes differ - confirm this is intended.
            per_element_bce = tf.keras.losses.binary_crossentropy(data[0], reconstruction)
            non_batch_axes = tf.range(1, tf.rank(per_element_bce))
            reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(per_element_bce, axis=non_batch_axes)
            )

            # KL divergence between N(z_mean, exp(z_log_variance)) and N(0, 1),
            # summed over the latent dimension and averaged over the batch.
            kl_loss = -0.5 * (1 + z_log_variance - tf.square(z_mean) - tf.exp(z_log_variance))
            kl_loss = tf.reduce_mean(tf.reduce_sum(kl_loss, axis=1))

            total_loss = reconstruction_loss + kl_loss

        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))

        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)

        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }


Note, it's common practice to avoid using batch normalization when training VAEs, since the additional stochasticity due to using mini-batches may aggravate instability on top of the stochasticity from sampling.

In [12]:
# Set TensorFlow's global random seed before anything random is created, so a
# fresh-kernel "Run All" draws the same random values every time.
SEED = 42
tf.random.set_seed(SEED)

# Model and data dimensions.
latent_dim = 4
batch_size = 1 # Paper: 32
channel = 1
height = 272
width = 272
depth = 6

# Both encoder inputs share the same per-sample shape.
inp_shape = [(depth, height, width, channel), (depth, height, width, channel)]

# Build the two sub-models and wrap them in the trainable VAE.
encoder = encoder_model(latent_dim, inp_shape)
generator = generator_model(latent_dim)
vae = CVAE(encoder, generator)
vae.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01))

In [13]:
# Synthetic stand-in input: normally distributed values with shape
# (batch_size, depth, height, width, channel).
train_data = tf.random.normal(shape = (batch_size, depth, height, width, channel))

In [14]:
# Smoke-test the encoder by feeding the same tensor to both of its inputs.
z_mean, z_log_variance, z = encoder([train_data, train_data])

In [15]:
# Stack the two encoder inputs along a new leading axis; CVAE.train_step
# reads them back as data[0] and data[1].
data = tf.stack([train_data, train_data], axis = 0)

In [16]:
# NOTE(review): train_step indexes data[0] and data[1], so this only works if
# a single batch contains both stacked slices - presumably it relies on fit()
# batching along the leading axis with its default batch size; confirm before
# passing `batch_size` here.
vae.fit(data, epochs=1)



<keras.callbacks.History at 0x2989183bc08>