Base model of the architecture


The base model consists of a VQ-VAE and a PixelCNN, as introduced in "Neural Discrete Representation Learning" (van den Oord et al., NeurIPS 2017).

The code in this notebook is mainly adapted from the notebook created by Amélie Royer.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_probability as tfp
import tensorflow as tf
from tensorflow.python.keras.saving.save import load_model
from matplotlib.colors import Normalize
import matplotlib.cm as cm
import random
import os
#os.chdir('C:/Users/ayata/Desktop/Vorlesungen/TNS/VQVAE')


Hyperparameters

In [None]:
# Hyperparameters
NUM_LATENT_K = 20                 # Number of codebook entries
NUM_LATENT_D = 64                 # Dimension of each codebook entry
BETA = 1.0                        # Weight for the commitment loss

# Shape of a single preprocessed image (height, width, channels).
# This used to be `x_train.shape[1:]`, but `x_train` is only created later,
# inside the training loop, so evaluating this cell first raised a NameError.
# The literal matches the `keras.Input(shape=(28, 28, 1))` used by the models.
INPUT_SHAPE = (28, 28, 1)
SIZE = None                       # Spatial size of latent embedding
                                  # will be set dynamically in `build_vqvae`

VQVAE_BATCH_SIZE = 128            # Batch size for training the VQVAE
VQVAE_NUM_EPOCHS = 20             # Number of epochs
VQVAE_LEARNING_RATE = 3e-4        # Learning rate
VQVAE_LAYERS = [16, 32]           # Number of filters for each layer in the encoder

PIXELCNN_BATCH_SIZE = 128         # Batch size for training the PixelCNN prior
PIXELCNN_NUM_EPOCHS = 10          # Number of epochs
PIXELCNN_LEARNING_RATE = 3e-4     # Learning rate
PIXELCNN_NUM_BLOCKS = 12          # Number of Gated PixelCNN blocks in the architecture
PIXELCNN_NUM_FEATURE_MAPS = 32    # Width of each PixelCNN block

IMG_CLASSIFIER='MNIST'            # Trained classifier used to test the accuracy of the reconstructed images (MNIST/FashionMNIST)
CATEGORY='0-4'                    # Category of numbers to train the model on ('0-9', '0-4', '5-9')


VQ-VAE

In [None]:
class VectorQuantizer(layers.Layer):
    """Keras layer that snaps each input vector to its nearest codebook entry.

    The codebook is a trainable variable of shape
    ``(embedding_dim, num_embeddings)``, i.e. one code per column. Calling the
    layer also registers the VQ loss (codebook loss plus ``beta`` times the
    commitment loss) through ``add_loss``.
    """

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, **kwargs):
        """Create the layer and its uniformly initialised codebook.

        Args:
            num_embeddings: Number of codebook entries.
            embedding_dim: Length of each codebook entry.
            beta: Weight of the commitment loss.
            **kwargs: Forwarded to ``layers.Layer``.
        """
        super().__init__(**kwargs)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        # As per the paper, this parameter is best kept between [0.25, 2].
        self.beta = beta

        # The codebook that the inputs are quantized against.
        codebook_shape = (self.embedding_dim, self.num_embeddings)
        initializer = tf.random_uniform_initializer()
        self.embeddings = tf.Variable(
            initial_value=initializer(shape=codebook_shape, dtype="float32"),
            trainable=True,
            name="embeddings_vqvae",
        )

    def call(self, x):
        """Quantize ``x`` and return it with straight-through gradients.

        The last axis of ``x`` is treated as the embedding axis; every other
        axis is flattened for the codebook lookup and restored afterwards.
        """
        # Remember the incoming shape, then view the input as a list of vectors.
        original_shape = tf.shape(x)
        vectors = tf.reshape(x, [-1, self.embedding_dim])

        # Look up the nearest code of every vector via a one-hot selection.
        nearest = self.get_code_indices(vectors)
        selection = tf.one_hot(nearest, self.num_embeddings)
        codes = tf.matmul(selection, self.embeddings, transpose_b=True)
        codes = tf.reshape(codes, original_shape)

        # VQ loss from the original paper: the commitment term moves the input
        # towards the (frozen) codes, the codebook term moves the codes towards
        # the (frozen) input. See
        # https://keras.io/guides/making_new_layers_and_models_via_subclassing/
        # for how per-layer losses are collected.
        commitment_term = tf.reduce_mean((tf.stop_gradient(codes) - x) ** 2)
        commitment_loss = self.beta * commitment_term
        codebook_loss = tf.reduce_mean((codes - tf.stop_gradient(x)) ** 2)
        self.add_loss(commitment_loss + codebook_loss)

        # Straight-through estimator: the forward value is the code, the
        # gradient flows to `x` as if the layer were the identity.
        return x + tf.stop_gradient(codes - x)

    def get_code_indices(self, flattened_inputs):
        """Return, for each row of ``flattened_inputs``, the index of the closest code.

        Closeness is the squared Euclidean distance, expanded as
        ``|x|^2 + |e|^2 - 2 * x.e`` so that it can be computed with one matmul.
        """
        dot_products = tf.matmul(flattened_inputs, self.embeddings)
        input_sq_norms = tf.reduce_sum(flattened_inputs ** 2, axis=1, keepdims=True)
        code_sq_norms = tf.reduce_sum(self.embeddings ** 2, axis=0)
        distances = input_sq_norms + code_sq_norms - 2 * dot_products

        # The smallest distance in each row picks the code.
        return tf.argmin(distances, axis=1)


In [None]:
def get_encoder(latent_dim=16):
    """Build the convolutional encoder for 28x28x1 images.

    Two stride-2 ReLU convolutions (32 and 64 filters) are followed by a 1x1
    convolution that projects to ``latent_dim`` channels.

    Args:
        latent_dim: Number of channels of the encoder output.

    Returns:
        A ``keras.Model`` named ``"encoder"``.
    """
    image = keras.Input(shape=(28, 28, 1))

    downsample_1 = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")
    downsample_2 = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")
    to_latent = layers.Conv2D(latent_dim, 1, padding="same")

    latents = to_latent(downsample_2(downsample_1(image)))
    return keras.Model(image, latents, name="encoder")


def get_decoder(latent_dim=16):
    """Build the transposed-convolution decoder that mirrors ``get_encoder``.

    Args:
        latent_dim: Number of channels of the latent input. It is forwarded to
            ``get_encoder`` so that the decoder input shape always matches the
            encoder output; previously the encoder default was used no matter
            what value was passed here.

    Returns:
        A ``keras.Model`` named ``"decoder"`` producing a single-channel image.
    """
    # A throw-away encoder is built only to read off its output shape.
    latent_shape = get_encoder(latent_dim).output.shape[1:]
    latent_inputs = keras.Input(shape=latent_shape)
    x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(
        latent_inputs
    )
    x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
    decoder_outputs = layers.Conv2DTranspose(1, 3, padding="same")(x)
    return keras.Model(latent_inputs, decoder_outputs, name="decoder")
    

#Standalone VQVAE model
def get_vqvae(latent_dim=16, num_embeddings=64):
    """Assemble encoder -> vector quantizer -> decoder into one Keras model.

    Args:
        latent_dim: Channel count of the latent map and length of each code.
        num_embeddings: Number of entries in the codebook.

    Returns:
        A ``keras.Model`` named ``"vq_vae"`` mapping 28x28x1 images to their
        reconstructions.

    Note:
        Both arguments used to be overwritten with the hard-coded values 16 and
        64 at the top of this function, so callers could not change the model
        size. Weights saved by that version always contain a 64-entry codebook
        and only load into a model built with ``num_embeddings=64``.
    """
    vq_layer = VectorQuantizer(num_embeddings, latent_dim, name="vector_quantizer")
    encoder = get_encoder(latent_dim)
    decoder = get_decoder(latent_dim)
    inputs = keras.Input(shape=(28, 28, 1))
    encoder_outputs = encoder(inputs)
    quantized_latents = vq_layer(encoder_outputs)
    reconstructions = decoder(quantized_latents)
    return keras.Model(inputs, reconstructions, name="vq_vae")


# Build a VQ-VAE with the default arguments and print its layer summary.
get_vqvae().summary()

Wrapping up the training loop inside VQVAETrainer

In [None]:
class VQVAETrainer(keras.models.Model):
    """Keras model wrapper that owns a VQ-VAE and its custom training step.

    The wrapped network is available as ``self.vqvae``. Three running means are
    tracked: the total loss, the reconstruction loss and the VQ loss.
    """

    def __init__(self, train_variance, latent_dim=16, num_embeddings=128, **kwargs):
        """Build the wrapped VQ-VAE and the loss trackers.

        Args:
            train_variance: Value the mean squared reconstruction error is
                divided by in ``train_step``.
            latent_dim: Forwarded to ``get_vqvae``.
            num_embeddings: Forwarded to ``get_vqvae``.
            **kwargs: Forwarded to ``keras.models.Model``.
        """
        super().__init__(**kwargs)
        self.train_variance = train_variance
        self.latent_dim = latent_dim
        self.num_embeddings = num_embeddings

        self.vqvae = get_vqvae(self.latent_dim, self.num_embeddings)

        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.vq_loss_tracker = keras.metrics.Mean(name="vq_loss")

    @property
    def metrics(self):
        """The three loss trackers, in the order total, reconstruction, VQ."""
        trackers = [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.vq_loss_tracker,
        ]
        return trackers

    def train_step(self, x):
        """Run one optimisation step on the batch ``x`` and report the running losses."""
        with tf.GradientTape() as tape:
            # Forward pass through the VQ-VAE.
            reconstructions = self.vqvae(x)

            # Variance-normalised MSE plus the losses registered by the layers.
            squared_error = (x - reconstructions) ** 2
            reconstruction_loss = tf.reduce_mean(squared_error) / self.train_variance
            total_loss = reconstruction_loss + sum(self.vqvae.losses)

        # Backpropagation through the wrapped model only.
        weights = self.vqvae.trainable_variables
        grads = tape.gradient(total_loss, weights)
        self.optimizer.apply_gradients(zip(grads, weights))

        # Update the running means.
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.vq_loss_tracker.update_state(sum(self.vqvae.losses))

        # Values logged by `fit`.
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "vqvae_loss": self.vq_loss_tracker.result(),
        }

In [None]:
# A filter that selects the category of digits to train/test the model on
def data_cetegory(category, x_data, y_data):
    """Keep only the samples whose label belongs to ``category``.

    Args:
        category: ``'0-9'`` (keep everything), ``'0-4'`` or ``'5-9'``.
        x_data: Array of samples, indexed along its first axis.
        y_data: Array of integer labels, one per sample.

    Returns:
        The tuple ``(x, y)`` restricted to the requested labels. For ``'0-9'``
        the inputs are returned unchanged.

    Raises:
        ValueError: If ``category`` is not one of the supported names.
    """
    if category == '0-9':
        return x_data, y_data

    labels_by_category = {
        '0-4': (0, 1, 2, 3, 4),
        '5-9': (5, 6, 7, 8, 9),
    }
    if category not in labels_by_category:
        raise ValueError(
            "Unknown category %r; expected '0-9', '0-4' or '5-9'" % (category,)
        )

    # The mask is applied to the arrays passed in. The '5-9' branch used to
    # index the global `x_train`, which returned training images even when the
    # test set was being filtered.
    keep = np.isin(y_data, labels_by_category[category])
    return x_data[keep], y_data[keep]

Train 10 VQ-VAEs, each with a different fixed seed (to make the plots comparable)

In [None]:
#lists for storing data out of the VQ-VAE

img_classifications=[]
val_loss_all=[]
val_loss_all_FashionMNIST=[]
val_loss_all_outDist=[]

#Load the pretrained MNIST/FashionMNIST classifier once; it does not depend on the seed
if IMG_CLASSIFIER == 'MNIST':
    #pretrained classifier on mnist with accuracy 98%
    classifier = load_model('classifier.h5')
else:
    #pretrained classifier on fashion mnist with accuracy 91%
    classifier = load_model('classifierFashionMNIST.h5')

#Initiating 10 different VQ-VAE models with fixed random seeds (The same seed used for the VQVAE with and without quantization)
for x in range(10):

    # Seed value
    # Using a fixed seed value at each loop
    seed_value = x

    # 1. Set the `PYTHONHASHSEED` environment variable at a fixed value
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # 2. Set the `python` built-in pseudo-random generator at a fixed value
    random.seed(seed_value)

    # 3. Set the `numpy` pseudo-random generator at a fixed value
    np.random.seed(seed_value)

    # 4. Set the `tensorflow` pseudo-random generator at a fixed value
    tf.random.set_seed(seed_value)

    #Load the MNIST dataset
    (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

    #filter training data according to selected category
    x_train, y_train = data_cetegory(CATEGORY, x_train, y_train)
    #filter testing data according to selected category
    x_test, y_test = data_cetegory(CATEGORY, x_test, y_test)

    #Preprocessing the data: add a channel axis and scale pixels to [-0.5, 0.5]
    x_train = np.expand_dims(x_train, -1)
    x_test = np.expand_dims(x_test, -1)
    x_train_scaled = (x_train / 255.0) - 0.5
    x_test_scaled = (x_test / 255.0) - 0.5

    data_variance = np.var(x_train / 255.0)
    test_variance = np.var(x_test / 255.0)

    #Compile and train the model.
    #The codebook size, epoch count and batch size come from the hyperparameter
    #cell (same values as the literals used before). `latent_dim` and the default
    #Adam learning rate are left as originally written, since NUM_LATENT_D and
    #VQVAE_LEARNING_RATE hold different values.
    vqvae_trainer = VQVAETrainer(data_variance, latent_dim=16, num_embeddings=NUM_LATENT_K)
    vqvae_trainer.compile(optimizer=keras.optimizers.Adam())
    history = vqvae_trainer.fit(
        x_train_scaled, epochs=VQVAE_NUM_EPOCHS, batch_size=VQVAE_BATCH_SIZE
    )
    trained_vqvae_model = vqvae_trainer.vqvae

    #Saving of weights of the model
    #trained_vqvae_model.save_weights('trainedVQVAE'+str(x))

    #Loading the saved model weights
    #NOTE(review): this replaces the weights trained just above with the stored
    #checkpoint and requires the 'trainedVQVAE<seed>' files to exist.
    trained_vqvae_model.load_weights('trainedVQVAE'+str(x))

    #Load the fixed 10 patterns of noise images
    noise_range = np.load('10noise_ranges'+str(x)+'.npy')

    #(optional) one figure per model with five rows: the first reconstructed test
    #image at every 20th noise step (0%, 20%, 40%, 60%, 80%). The figure used to be
    #re-created and re-saved at each of those steps, so the saved file only ever
    #contained a single panel. `linewidth` is a figure property, not a `savefig`
    #argument, so it is set here; the frame colour is still passed to `savefig`.
    fig, axs = plt.subplots(5, 1, figsize=(4, 20), linewidth=20)

    #add noise gradually from 0% to 100% to the test images
    for i in range(100):
        test_images = x_test_scaled + noise_range[i]
        #predict the output of the VQ-VAE using the noisy image
        reconstructions_test = trained_vqvae_model.predict(test_images)

        #an option to predict the classification of the VQ-VAE's output
        #img_classification= classifier.predict(reconstructions_test)

        #an option to collect the classifications
        #img_classifications+=[np.argmax(img_classification, axis=1)]

        #an option calculate the classification accuracy of the noisy image against its ground truth test image
        #val_loss_noisy=classifier.evaluate(test_images,y_test, verbose=0)

        #calculate the classification accuracy of the output image against its ground truth test image
        val_loss_noisy = classifier.evaluate(reconstructions_test, y_test, verbose=0)

        #collect the classification accuracies (element 1 of the evaluate result)
        val_loss_all.append(val_loss_noisy[1])

        if i % 20 == 0:
            ax = axs[i // 20]
            ax.imshow(reconstructions_test[0].squeeze() + 0.5, cmap='gray')
            ax.axis("off")

    fig.subplots_adjust(wspace=0, hspace=0.03)
    fig.savefig('59digitVQVAE-09digitTest_fixedSeed_noisePattern'+str(x)+'.png',bbox_inches='tight', dpi=300, edgecolor='#DFB920')
    plt.show()


#collect the classification accuracy of the VQ-VAEs output images against the test images
np.save('VQVAE_accuracies.npy', np.array(val_loss_all))

#an option to collect the classification accuracy of the noisy images against the test images
#np.save('noisyImg_accuracies', np.array(val_loss_all))

#an option to collect the discrete classifications of the VQ-VAE's output
#np.save('imgClassifications.npy', np.array(img_classifications))


In [None]:
#plot test images vs. reconstructed images (output images)
#NOTE: `show_subplot` is defined in the cell below; run that cell before this one.
#`test_images` and `reconstructions_test` are whatever the last iteration of the
#training/evaluation loop left behind.
for idx, (test_image, reconstruction) in enumerate(zip(test_images, reconstructions_test)):
    show_subplot(test_image, reconstruction, idx)
    

In [None]:
#plot reconstructed images vs. input images
def show_subplot(original, reconstructed, idx):
    """Show an input image next to its reconstruction and save the figure.

    Args:
        original: Input image; 0.5 is added before display to undo the
            centring applied during preprocessing.
        reconstructed: Model output for ``original``, displayed the same way.
        idx: Index appended to the output file name.

    The figure is written to
    ``'OutOfSample_FashionMNIST_ReconstructedImg<idx>.png'``, shown, and then
    closed so that calling this function in a loop does not accumulate open
    figures.
    """
    fig, (ax_original, ax_reconstructed) = plt.subplots(
        1, 2, figsize=(8, 4), gridspec_kw={'width_ratios': [1, 1]}
    )

    ax_original.imshow(original.squeeze() + 0.5, cmap='gray')
    ax_original.set_title("Original")
    ax_original.axis("off")

    ax_reconstructed.imshow(reconstructed.squeeze() + 0.5, cmap='gray')
    ax_reconstructed.set_title("Reconstructed")
    ax_reconstructed.axis("off")

    fig.savefig('OutOfSample_FashionMNIST_ReconstructedImg'+str(idx)+'.png',bbox_inches='tight')
    plt.show()
    plt.close(fig)

Figure 2 of the paper

In [None]:
a=np.load('VQVAE_MNIST_accuracies.npy')
b=np.load('autoencoder_MNIST_accuracies.npy')
c=np.load('noisytest_MNIST_accuracies.npy')

#view every result vector as rows of 100 values (one value per noise step),
#so that statistics along axis 0 are taken across the runs
a_runs = np.reshape(a, (-1, 100))
b_runs = np.reshape(b, (-1, 100))
c_runs = np.reshape(c, (-1, 100))

#average the results of the runs
a_mean = np.mean(a_runs, axis=0)
b_mean = np.mean(b_runs, axis=0)
c_mean = np.mean(c_runs, axis=0)

#sample standard deviation across the runs, drawn as a shaded band
a_std = np.std(a_runs, axis=0, ddof=1)
b_std = np.std(b_runs, axis=0, ddof=1)
c_std = np.std(c_runs, axis=0, ddof=1)

x=np.arange(0, 100, 1)

fig = plt.figure( dpi=300)
plt.rcParams["font.family"] = "Arial"
plt.errorbar(x,a_mean,  label= 'VQ-VAE output (with quantization)',color='#3F81C0')
plt.fill_between(x, a_mean-a_std, a_mean+a_std, alpha=0.3, facecolor='#3F81C0')
plt.errorbar(x,b_mean,  label= 'VQ-VAE output (without quantization)',color='#7AA35C')
plt.fill_between(x, b_mean-b_std, b_mean+b_std, alpha=0.3, facecolor='#7AA35C')
plt.errorbar(x,c_mean,  label= 'MNIST 0-9 input images', linestyle='--',color='#3F81C0')
plt.fill_between(x, c_mean-c_std, c_mean+c_std, alpha=0.08, facecolor='#3F81C0')

plt.hlines(0.1,0,100, linestyle='dotted', alpha=0.7, color='#3F81C0', label='MNIST chance level')

plt.ylim(0,1)
x = np.array([0,20,40,60,80,100])
y = np.array([0, 0.2, 0.4, 0.6, 0.8, 1])
ticks = ['0%', '20%', '40%', '60%', '80%', '100%']

plt.xticks(x, ticks, fontsize=14)
plt.yticks(y, ticks, fontsize=14)
plt.ylabel('Classification accuracy', fontsize=16)
plt.xlabel('Noise level', fontsize=16)

#order legend according to the plotted lines
#(the bare expression `matplotlib.colors.to_rgb` that stood here did nothing and
#raised a NameError, because the name `matplotlib` is never imported)
#`order` indexes into the handles returned below; changing the plotting calls
#above may change that sequence, so re-check the indices if you do.
handles, labels = plt.gca().get_legend_handles_labels()
order = [1,2,3,0]
plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order],loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, shadow=False, ncol=1, fontsize=16)
fig.savefig('fig2.tiff',bbox_inches='tight', dpi=300)
plt.show()

Figure 3 of the paper

In [None]:
    #04 fashion model - 04 test digits
    a=np.load('halfMNIST_inSample.npy')
    b=np.load('halfMNIST_outOfSample.npy')
    c=np.load('halfMNIST_outOfDist.npy')
    f=np.load('halfMNIST_noisyTest.npy')

    #04 fashion model - 59 test digits
    g=np.load('halfMNIST59_inSample.npy')
    h=np.load('halfMNIST04_outOfSample.npy')
    k=np.load('halfMNISTfashion04_outOfDist.npy')
    l=np.load('halfMNIST59_noisyTest.npy')

    #pool both halves of each condition into one flat vector
    a_concatenated=np.concatenate((a,g), axis=None)
    b_concatenated=np.concatenate((b,h), axis=None)
    c_concatenated=np.concatenate((c,k), axis=None)
    f_concatenated=np.concatenate((f,l), axis=None)

    #view every pooled vector as rows of 100 values (one value per noise step)
    a_runs = np.reshape(a_concatenated, (-1, 100))
    b_runs = np.reshape(b_concatenated, (-1, 100))
    c_runs = np.reshape(c_concatenated, (-1, 100))
    f_runs = np.reshape(f_concatenated, (-1, 100))

    #average the results of the runs
    a_mean = np.mean(a_runs, axis=0)
    b_mean = np.mean(b_runs, axis=0)
    c_mean = np.mean(c_runs, axis=0)
    f_mean = np.mean(f_runs, axis=0)

    #sample standard deviation across the runs, drawn as a shaded band
    a_std = np.std(a_runs, axis=0, ddof=1)
    b_std = np.std(b_runs, axis=0, ddof=1)
    c_std = np.std(c_runs, axis=0, ddof=1)
    f_std = np.std(f_runs, axis=0, ddof=1)

    x=np.arange(0, 100, 1)
    fig = plt.figure( dpi=300)
    plt.rcParams["font.family"] = "Arial"
    plt.errorbar(x,f_mean,  label= 'Digits 0-4  and 5-9 input images', linestyle='--',color='#3F81C0')
    plt.fill_between(x, f_mean-f_std, f_mean+f_std, alpha=0.08, facecolor='#3F81C0')

    plt.hlines(0.1,0,100, linestyle='dotted', alpha=0.7, color='#3F81C0', label='Digits chance level')
    #invisible entry, used only as a spacer in the two-column legend
    plt.hlines(0.01, 0,100, linestyle='None', alpha=0.7,color='white', label=' ')

    plt.errorbar(x,a_mean,  label= 'In-sample digit VQ-VAE',color='#3F81C0')
    plt.fill_between(x, a_mean-a_std, a_mean+a_std, alpha=0.25, facecolor='#3F81C0')
    plt.errorbar(x,b_mean,  label= 'Out-of-sample digit VQ-VAE',color='#7AA35C')
    plt.fill_between(x, b_mean-b_std, b_mean+b_std, alpha=0.25, facecolor='#7AA35C')
    plt.errorbar(x,c_mean,  label= 'Out-of-distribution fashion VQ-VAE',zorder=3,color='#DFB920')
    plt.fill_between(x, c_mean-c_std, c_mean+c_std, alpha=0.25, facecolor='#DFB920')

    plt.ylim(0,1)
    x = np.array([0, 20, 40, 60, 80, 100])
    y = np.array([0, 0.2, 0.4, 0.6, 0.8, 1])
    ticks = ['0%', '20%', '40%', '60%', '80%', '100%']
    plt.xticks(x, ticks, fontsize=14)
    plt.yticks(y, ticks, fontsize=14)
    plt.ylabel('Classification accuracy', fontsize=16)
    plt.xlabel('Noise level', fontsize=16)

    #order legend according to the plotted lines
    #`order` indexes into the handles returned below; changing the plotting calls
    #above may change that sequence, so re-check the indices if you do.
    handles, labels = plt.gca().get_legend_handles_labels()
    order = [2,0,1,3,4,5]
    plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order],loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, shadow=False, ncol=2)#, fontsize=14)
    #(the bare expression `matplotlib.colors.to_rgb` that stood here did nothing and
    #raised a NameError, because the name `matplotlib` is never imported)
    fig.savefig('fig3.tiff',bbox_inches='tight', dpi=300)
    plt.show()

Figure 4 of the paper

In [None]:
    #04 digit model - 04 fashion test
    a=np.load('halfFashionMNIST_inSample.npy')
    b=np.load('halfFashionMNIST_outOfSample.npy')
    c=np.load('halfFashionMNIST_outOfDist.npy')
    e=np.load('halfFashionMNIST04_noisyTest.npy')

    #04 digit model - 59 fashion model
    #NOTE(review): `g` and `h` were assigned the other way round here, so the
    #in-sample curve pooled 'halfFashionMNIST04_outOfSample.npy' and the
    #out-of-sample curve pooled 'halfFashionMNIST59_inSample.npy'. The pairing
    #below follows the file names and the Figure 3 cell; please confirm against
    #how the files were produced.
    g=np.load('halfFashionMNIST59_inSample.npy')
    h=np.load('halfFashionMNIST04_outOfSample.npy')
    k=np.load('halfMNIST04_outOfDist.npy')
    l=np.load('halfFashionMNIST59_noisyTest.npy')

    #pool both halves of each condition into one flat vector
    a_concatenated=np.concatenate((a,g), axis=None)
    b_concatenated=np.concatenate((b,h), axis=None)
    c_concatenated=np.concatenate((c,k), axis=None)
    e_concatenated=np.concatenate((e,l), axis=None)

    #view every pooled vector as rows of 100 values (one value per noise step)
    a_runs = np.reshape(a_concatenated, (-1, 100))
    b_runs = np.reshape(b_concatenated, (-1, 100))
    c_runs = np.reshape(c_concatenated, (-1, 100))
    e_runs = np.reshape(e_concatenated, (-1, 100))

    #average the results of the runs
    a_mean = np.mean(a_runs, axis=0)
    b_mean = np.mean(b_runs, axis=0)
    c_mean = np.mean(c_runs, axis=0)
    e_mean = np.mean(e_runs, axis=0)

    #sample standard deviation across the runs, drawn as a shaded band
    a_std = np.std(a_runs, axis=0, ddof=1)
    b_std = np.std(b_runs, axis=0, ddof=1)
    c_std = np.std(c_runs, axis=0, ddof=1)
    e_std = np.std(e_runs, axis=0, ddof=1)

    x=np.arange(0, 100, 1)
    fig = plt.figure( dpi=300)

    plt.rcParams["font.family"] = "Arial"

    plt.errorbar(x,e_mean,  label= 'Fashion 0-4 and 5-9 input images', linestyle='--',color='#3F81C0')
    plt.fill_between(x, e_mean-e_std, e_mean+e_std, alpha=0.08, facecolor='#3F81C0')

    plt.hlines(0.1,0,100, linestyle='dotted', alpha=0.7, color='#3F81C0', label='Fashion chance level')
    #invisible entry, used only as a spacer in the two-column legend
    plt.hlines(0.01, 0,100, linestyle='None', alpha=0.7,color='white', label=' ')

    plt.errorbar(x,a_mean,  label= 'In-sample fashion VQ-VAE',color='#3F81C0')
    plt.fill_between(x, a_mean-a_std, a_mean+a_std, alpha=0.25, facecolor='#3F81C0')
    plt.errorbar(x,b_mean,  label= 'Out-of-sample fashion VQ-VAE',color='#7AA35C')
    plt.fill_between(x, b_mean-b_std, b_mean+b_std, alpha=0.25, facecolor='#7AA35C')
    plt.errorbar(x,c_mean,  label= 'Out-of-distribution digit VQ-VAE',zorder=3,color='#DFB920')
    plt.fill_between(x, c_mean-c_std, c_mean+c_std, alpha=0.25, facecolor='#DFB920')

    plt.ylim(0,1)
    x = np.array([0, 20, 40, 60, 80, 100])
    y = np.array([0, 0.2, 0.4, 0.6, 0.8, 1])
    ticks = ['0%', '20%', '40%', '60%', '80%', '100%']
    plt.xticks(x, ticks)
    plt.yticks(y, ticks)
    plt.ylabel('Classification accuracy')
    plt.xlabel('Noise level')

    #order legend according to the plotted lines
    #`order` indexes into the handles returned below; changing the plotting calls
    #above may change that sequence, so re-check the indices if you do.
    handles, labels = plt.gca().get_legend_handles_labels()
    order = [2,0,1,3,4,5]
    plt.legend([handles[idx] for idx in order],[labels[idx] for idx in order],loc='upper center', bbox_to_anchor=(0.5, -0.15), fancybox=True, shadow=False, ncol=2)
    #(the bare expression `matplotlib.colors.to_rgb` that stood here did nothing and
    #raised a NameError, because the name `matplotlib` is never imported)
    fig.savefig('fig4.tiff',bbox_inches='tight', dpi=300)
    plt.show()
    