In [1]:
import tensorflow as tf
from datetime import datetime as dt
import numpy as np
from matplotlib import pyplot as plt

random_seed = 42
tf.random.set_seed(random_seed)
np.random.seed(random_seed)

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [2]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.cifar10.load_data()

def preprocess_images(images):
    images = images.reshape((images.shape[0], 32, 32, 3)).astype('float32') / 255.
    return images

train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

length = train_images.shape[1]
width = train_images.shape[2]
channels = train_images.shape[3]

train_size = train_images.shape[0]
test_size = test_images.shape[0]
batch_size = 128

train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(train_size).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(test_size).batch(batch_size)

In [3]:
class Sampling(tf.keras.layers.Layer):
    @tf.function
    def call(self, inputs, training=False):
        z_mean, z_log_var = inputs
        batch = tf.shape(z_mean)[0]
        dim = tf.shape(z_mean)[1]
        epsilon = tf.keras.backend.random_normal(shape=(batch, dim))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
LR_SCHEDULE = [
    (25, 0.0001),
    (50, 0.00005),
    (75, 0.00002),
    (100, 0.00001),
]

def lr_schedule(epoch, lr):
    if epoch < LR_SCHEDULE[0][0] or epoch > LR_SCHEDULE[-1][0]:
        return lr
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr

In [None]:
class CustomAAECallbacks(tf.keras.callbacks.Callback):
    def __init__(self, X, schedule, patience=0):
        super(CustomAAECallbacks, self).__init__()
        # Immagini per la ricostruzione
        self.X = X
        self.patience = patience
        self.schedule = schedule

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.stopped_epoch = 0
        self.ae_best = np.Inf
        self.gen_best = np.Inf
        self.dc_best = np.Inf

    def on_epoch_begin(self, epoch, logs=None):
        ae_lr = float(tf.keras.backend.get_value(self.model.ae_optimizer.learning_rate))
        gen_lr = float(tf.keras.backend.get_value(self.model.gen_optimizer.learning_rate))
        dc_lr = float(tf.keras.backend.get_value(self.model.dc_optimizer.learning_rate))
        scheduled_ae_lr = self.schedule(epoch, ae_lr)
        scheduled_dc_lr = self.schedule(epoch, dc_lr)
        scheduled_gen_lr = self.schedule(epoch, gen_lr)
        tf.keras.backend.set_value(self.model.ae_optimizer.lr, scheduled_ae_lr)
        tf.keras.backend.set_value(self.model.gen_optimizer.lr, scheduled_gen_lr)
        tf.keras.backend.set_value(self.model.dc_optimizer.lr, scheduled_dc_lr)
        print("\nEpoch %05d: Learning rate is %6.4f." % (epoch, scheduled_ae_lr))

    def on_epoch_end(self, epoch, logs=None):
        # Save 10 reconstructions every 10 epochs
        if epoch%10==0:
            self.model.save_10_reconstructions(self.X, epoch)

        current_ae_loss = logs.get("val_ae_loss")
        current_gen_loss = logs.get("val_gen_loss")
        current_dc_loss = logs.get("val_dc_loss")
        if np.less(current_ae_loss, self.ae_best) or np.less(current_gen_loss, self.gen_best) or np.less(current_dc_loss, self.dc_best):
            self.ae_best = current_ae_loss
            self.dc_best = current_dc_loss
            self.gen_best = current_gen_loss
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
        
    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d: early stopping" % (self.stopped_epoch + 1))


In [4]:
def get_encoder(input_shape, latent_dim, leaky_alpha, filters, kernel_size, strides, dense_units):
    inputs = tf.keras.Input(shape=input_shape)

    x = tf.keras.layers.Conv2D(filters=filters[0], kernel_size=kernel_size[0], strides=strides[0], padding='same', kernel_regularizer='l2')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    x = tf.keras.layers.Conv2D(filters=filters[1], kernel_size=kernel_size[1], strides=strides[1], padding='same', kernel_regularizer='l2')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    x = tf.keras.layers.Conv2D(filters=filters[2], kernel_size=kernel_size[2], strides=strides[2], padding='same', kernel_regularizer='l2')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    x = tf.keras.layers.Conv2D(filters=filters[3], kernel_size=kernel_size[3], strides=strides[3], padding='same', kernel_regularizer='l2')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    flatten = tf.keras.layers.Flatten()(x)
    x = tf.keras.layers.Dense(dense_units[0], activation='relu', kernel_regularizer='l2')(flatten)
    x = tf.keras.layers.BatchNormalization()(x)
    z_mean = tf.keras.layers.Dense(latent_dim)(x)
    z_log_var = tf.keras.layers.Dense(latent_dim)(x)

    z = Sampling()([z_mean, z_log_var])

    model = tf.keras.Model(inputs, z, name="Encoder")
    return model

def get_decoder(input_shape, latent_dim, leaky_alpha, filters, kernel_size, strides, dense_units, stride_reduction):
    inputs = tf.keras.Input(shape=(latent_dim,))

    x = tf.keras.layers.Dense(dense_units[0], activation='relu', kernel_regularizer='l2')(inputs)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Dense(int(input_shape[0]*input_shape[1]*filters[3]/stride_reduction**2), activation='relu', kernel_regularizer='l2')(x)
    
    reshaped = tf.keras.layers.Reshape((int(input_shape[0]/stride_reduction), int(input_shape[1]/stride_reduction), filters[3]))(x)
    x = tf.keras.layers.Conv2DTranspose(filters=filters[3], kernel_size=kernel_size[3], strides=strides[3], padding='same', kernel_regularizer='l2')(reshaped)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    x = tf.keras.layers.Conv2DTranspose(filters=filters[2], kernel_size=kernel_size[2], strides=strides[2], padding='same', kernel_regularizer='l2')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    x = tf.keras.layers.Conv2DTranspose(filters=filters[1], kernel_size=kernel_size[1], strides=strides[1], padding='same', kernel_regularizer='l2')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    x = tf.keras.layers.Conv2DTranspose(filters=filters[0], kernel_size=kernel_size[0], strides=strides[0], padding='same', kernel_regularizer='l2')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.LeakyReLU(leaky_alpha)(x)

    image = tf.keras.layers.Conv2DTranspose(filters=input_shape[2], kernel_size=kernel_size[0], strides=1, padding='same', kernel_regularizer='l2')(x)

    model = tf.keras.Model(inputs, image, name="Decoder")
    return model

def get_discriminator(latent_dim, discriminator_units):
    inputs = tf.keras.Input(shape=(latent_dim,))

    x = tf.keras.layers.Dense(discriminator_units[0], activation='relu', kernel_regularizer='l2')(inputs)
    x = tf.keras.layers.Dense(discriminator_units[1], activation='relu', kernel_regularizer='l2')(x)

    vote = tf.keras.layers.Dense(1)(x)

    model = tf.keras.Model(inputs, vote, name="Discriminator")
    return model


    

In [5]:
class AAE(tf.keras.Model):
    def __init__(self, input_shape, latent_dim, leaky_alpha, filters, kernel_size, strides, dense_units, discriminator_units, base_lr, max_lr, step_size, gen_coef, batch_size):
        super(AAE, self).__init__()

        self.batch_size = batch_size
        # Calculate the stride factor of downsampling
        self.stride_reduction = 1
        for i, stride in enumerate(strides):
            self.stride_reduction = self.stride_reduction * stride
        
        # Latent dimension
        self.latent_dim = latent_dim
        # Define losses and accuracies
        self.mse = tf.keras.losses.MeanSquaredError()
        self.cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        self.dc_accuracy = tf.keras.metrics.BinaryAccuracy()
        # Define the learning rates for cyclic learning rate (not used)
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.gen_coef = gen_coef

        # Encoder Net
        self.encoder = get_encoder(input_shape, latent_dim, leaky_alpha, filters, kernel_size, strides, dense_units)

        # Decoder Net
        self.decoder = get_decoder(input_shape, latent_dim, leaky_alpha, filters, kernel_size, strides, dense_units, self.stride_reduction)

        # Discriminator Net
        self.discriminator = get_discriminator(latent_dim, discriminator_units)

    def compile(self, ae_opt, dc_opt, gen_opt):
        super(AAE, self).compile()
        # Set optimizers
        self.ae_optimizer = ae_opt
        self.gen_optimizer = gen_opt
        self.dc_optimizer = dc_opt
        # Set loss functions
        self.ae_loss_fn = tf.keras.losses.MeanSquaredError()
        self.binCe_loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
        # Set metrics and accuracies
        self.dc_acc_fn = tf.keras.metrics.BinaryAccuracy(name='dc_accuracy')
        self.ae_metrics = tf.keras.metrics.MeanSquaredError(name='ae_loss')
        self.dc_metrics = tf.keras.metrics.BinaryCrossentropy(from_logits=True, name='dc_loss')
        self.gen_metrics = tf.keras.metrics.BinaryCrossentropy(from_logits=True, name='gen_loss')
        # Compile internal models
        self.encoder.compile()
        self.decoder.compile()
        self.discriminator.compile()

    # Define the metrics
    @property
    def metrics(self):
        return [self.dc_acc_fn, self.ae_metrics, self.dc_metrics, self.gen_metrics]

    # Encoding function
    def encode(self, x, training=False):
        return self.encoder(x, training=training)

    # Decoding function
    def decode(self, z, apply_sigmoid=False, training=False):
        logits = self.decoder(z, training=training)
        if apply_sigmoid:
            probs = tf.sigmoid(logits)
            return probs
        return logits
    
    # Save 10 reconstructions
    def save_10_reconstructions(self, X, epoch):
        fig, ax = plt.subplots(2,5,figsize=(15,5))
        images = X[:10,:,:,:].reshape((10,32,32,3))
        z = self.encode(images)
        image_reconstruction = self.decode(z, apply_sigmoid=True).numpy()
        for i in range(10):
            ax.ravel()[i].imshow(image_reconstruction[i,:])
            ax.ravel()[i].axis(False)
        fig.savefig(f'reconstructed_epoch_{epoch+1:03d}.png',bbox_inches='tight')

    @tf.function
    def test_step(self, batch_x):
        generated_noise = tf.random.normal([self.batch_size, self.latent_dim], mean=0.0, stddev=1.0)
        z_generated = self.encode(batch_x)
        X_reconstructed = self.decode(z_generated, apply_sigmoid=True)
                
        dc_fake = self.discriminator(z_generated)
        dc_real = self.discriminator(generated_noise)

        self.ae_metrics.update_state(batch_x, X_reconstructed)
        self.dc_metrics.update_state(tf.concat([tf.zeros_like(dc_fake), tf.ones_like(dc_real)], axis=0), tf.concat([dc_fake, dc_real], axis=0))
        self.gen_metrics.update_state(tf.ones_like(dc_fake), dc_fake)
        self.dc_acc_fn.update_state(tf.concat([tf.zeros_like(dc_fake), tf.ones_like(dc_real)], axis=0), tf.concat([dc_fake, dc_real], axis=0))
        return {m.name: m.result() for m in self.metrics}


    # Function for the train step
    @tf.function
    def train_step(self, batch_x):
        # Autoencoder training
        with tf.GradientTape() as ae_tape:
            z_generated = self.encode(batch_x, training=True)
            X_reconstructed = self.decode(z_generated, apply_sigmoid=True, training=True)
            ae_loss = self.ae_loss_fn(batch_x, X_reconstructed)

        # Apply the gradients
        ae_grads = ae_tape.gradient(ae_loss, self.encoder.trainable_variables + self.decoder.trainable_variables)
        self.ae_optimizer.apply_gradients(zip(ae_grads, self.encoder.trainable_variables + self.decoder.trainable_variables))

        # Discriminator training with normal prior
        generated_noise = tf.random.normal([self.batch_size, self.latent_dim], mean=0.0, stddev=1.0)
        with tf.GradientTape() as dc_tape:
            encoder_output = self.encode(batch_x, training=False)
            dc_fake = self.discriminator(encoder_output, training=True)
            dc_real = self.discriminator(generated_noise, training=True)

            real_loss = self.binCe_loss_fn(tf.ones_like(dc_real), dc_real)
            fake_loss = self.binCe_loss_fn(tf.zeros_like(dc_fake), dc_fake)
            dc_loss = real_loss + fake_loss

            dc_acc = self.dc_acc_fn(tf.concat([tf.ones_like(dc_real), tf.zeros_like(dc_fake)], axis=0),
                        tf.concat([dc_real, dc_fake], axis=0))

        # Apply the gradients
        dc_grads = dc_tape.gradient(dc_loss, self.discriminator.trainable_variables)
        self.dc_optimizer.apply_gradients(zip(dc_grads, self.discriminator.trainable_variables))

        # Generator training (Encoder)
        with tf.GradientTape() as gen_tape:
            encoder_output = self.encode(batch_x, training=True)
            dc_fake = self.discriminator(encoder_output, training=False)
            gen_loss = self.binCe_loss_fn(tf.ones_like(dc_fake),dc_fake)

        # Apply the gradients
        gen_grads = gen_tape.gradient(gen_loss, self.encoder.trainable_variables)
        self.gen_optimizer.apply_gradients(zip(gen_grads, self.encoder.trainable_variables))

        # Update the metrics
        self.ae_metrics.update_state(batch_x, X_reconstructed)
        self.dc_metrics.update_state(tf.concat([tf.zeros_like(dc_fake), tf.ones_like(dc_real)], axis=0), tf.concat([dc_fake, dc_real], axis=0))
        self.gen_metrics.update_state(tf.ones_like(dc_fake), dc_fake)
        self.dc_acc_fn.update_state(tf.concat([tf.zeros_like(dc_fake), tf.ones_like(dc_real)], axis=0), tf.concat([dc_fake, dc_real], axis=0))
        return {m.name: m.result() for m in self.metrics}
    



In [6]:
base_lr = 0.0001
max_lr = 0.0025
step_size = 2 * np.ceil(train_images.shape[0] / batch_size)
epochs = 350

latent_dim = 192
alpha_leaky = 0.2
filters = [64,128,256,512]
kernel_size = [4,4,3,3]
strides = [2,2,2,2]
dense_units = [1000,300]
discriminator_units = [200, 200]
keep_prob = 0.5
gen_coef = 2.

steps_per_epoch = train_images.shape[0] / batch_size

aae = AAE((length, width, channels), latent_dim, alpha_leaky, filters, kernel_size, strides, dense_units, discriminator_units, base_lr, max_lr, step_size, gen_coef, batch_size)
aae.compile(tf.keras.optimizers.Adam(learning_rate = base_lr), tf.keras.optimizers.Adam(learning_rate=base_lr), tf.keras.optimizers.Adam(learning_rate=base_lr*gen_coef))

ResourceExhaustedError: failed to allocate memory [Op:Mul]

In [None]:
tensorboard_callback = tf.keras.callbacks.TensorBoard('logs', histogram_freq=1)
customAAECallback = CustomAAECallbacks(X=test_images[:10], schedule=lr_schedule, patience=10)

In [None]:
history = aae.fit(x=train_dataset, epochs=epochs, validation_data=test_dataset, callbacks=[tensorboard_callback, customAAECallback])

Epoch 1/350
Epoch 2/350
Epoch 3/350
Epoch 4/350
Epoch 5/350
Epoch 6/350
Epoch 7/350
Epoch 8/350
Epoch 9/350
Epoch 10/350
Epoch 11/350
Epoch 12/350
Epoch 13/350

In [None]:
aae.encoder.save('/content/drive/MyDrive/Blackboxes/enc_model')
aae.decoder.save('/content/drive/MyDrive/Blackboxes/dec_model')
aae.discriminator.save('/content/drive/MyDrive/Blackboxes/dc_model')

In [None]:
aae.encoder.save('enc_model')
aae.decoder.save('dec_model')
aae.discriminator.save('dc_model')

In [None]:
image = train_images[0,:,:,:].reshape((1,32,32,3))
z = aae.encode(image)
image_recon = aae.decode(z, apply_sigmoid=True)
plt.figure()
plt.imshow(image[0,:,:,:])
plt.show()
plt.figure()
plt.imshow(image_recon[0,:,:,:])
plt.show()

image = test_images[10,:,:,:].reshape((1,32,32,3))
z = aae.encode(image)
image_recon = aae.decode(z, apply_sigmoid=True)
plt.figure()
plt.imshow(image[0,:,:,:])
plt.show()
plt.figure()
plt.imshow(image_recon[0,:,:,:])
plt.show()

In [None]:
test_dc_accuracy, test_ae_loss, test_dc_loss, test_gen_loss = aae.evaluate(test_dataset)

In [None]:
z = np.random.randn(1,latent_dim)
image = aae.decode(z, apply_sigmoid=True).numpy()
image = image * 255.
image=image.astype('int16')
image.reshape((32,32,3))
plt.figure()
plt.imshow(image[0,:,:,:])
plt.show()
print(aae.discriminator(z))

In [None]:
def plot_latent(aae, data, labels):
    latent_data = aae.encode(data)
    plt.scatter(latent_data[:,0], latent_data[:,1], c=labels)

In [None]:
plot_latent(aae, train_images[0:1000], train_labels[0:1000].reshape((1000)))