In [3]:
import os
import pickle
from tensorflow.keras.losses import MSE
from tensorflow.keras import Model
from tensorflow.keras.layers import Input, Conv2D, ReLU, BatchNormalization,Flatten, Dense, Reshape, Conv2DTranspose, Activation, Lambda
from tensorflow.keras import backend as K
from tensorflow.keras.optimizers import Adam
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
import pandas as pd



"""for 3D we will use Conv3D"""

#https://www.youtube.com/watch?v=A6mdOEPGM1E

#GOAL:
#Extract 2D slices and use the VAE
#Then tur the VAE into a 3d VAE and use the entire 3D information

data = np.load('3D_dataset.npy')
#The data contains the density field obtained by applying the Zel'dovich
#approximation on a 3D grid of particles and then interpolating the 
#particle distribution to a density field. 
#There are 32x32x32. Array shape is 8000,32,32,32. 
 
"""Autoencoder-the plot isn't symmetricalaround origin, how do we then
samplea point for generation? 
Some labels are represented over small areas, other over large ones,
so we have a lack of diversity. There are also gaps between coloured points,
so some generated images will be poor. 
Thats why we use VAE.
"""

#data disabling eager execution. The implementation that we will be given for this VAE woesnt
#work foreager execution- its a programming enviroment that tensorflow has that calculate and evaluate operation is possible
#before the graph is really built

tf.compat.v1.disable_eager_execution()


class VAE:
    """
    Deep convolutional variational autoencoder (VAE) with mirrored encoder
    and decoder components.

    Training needs the gradient of the loss with respect to every network
    parameter (backpropagation), but drawing a random latent point is not
    differentiable. The "reparameterization trick" (see ``_add_bottleneck``)
    rewrites the sample as ``mu + exp(log_variance / 2) * epsilon``, with
    ``epsilon`` drawn from a standard normal. This lets gradients flow
    through ``mu`` and ``log_variance``.

    Attributes:
        encoder, decoder, model: Keras models created by ``_build``.
        mu, log_variance: encoder bottleneck tensors, read by the KL loss.
        reconstruction_loss_weight: multiplier on the reconstruction term of
            the combined loss.
    """

    def __init__(self,
                 input_shape,
                 conv_filters,
                 conv_kernels,
                 conv_strides,
                 latent_space_dim):
        """Store the architecture hyper-parameters and build all models.

        Args:
            input_shape: shape of one sample without the batch axis,
                e.g. (28, 28, 1).
            conv_filters: tuple/list with the number of filters of each
                encoder conv layer, e.g. (2, 4, 8).
            conv_kernels: kernel size of each conv layer, e.g. (3, 5, 3).
            conv_strides: stride of each conv layer, e.g. (1, 2, 2).
            latent_space_dim: dimensionality of the latent space, e.g. 2.
        """
        self.input_shape = input_shape
        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.latent_space_dim = latent_space_dim
        self.reconstruction_loss_weight = 1000

        self.encoder = None
        self.decoder = None
        self.model = None
        # Assigned by _add_bottleneck while _build() runs.
        self.mu = None
        self.log_variance = None

        self._num_conv_layers = len(conv_filters)
        self._shape_before_bottleneck = None
        self._model_input = None

        self._build()

    def summary(self):
        """Print the Keras summaries of the encoder, decoder and full model."""
        self.encoder.summary()
        self.decoder.summary()
        self.model.summary()

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    def compile(self, learning_rate=0.0001):
        """Compile ``self.model`` with Adam and the combined VAE loss.

        The reconstruction and KL terms are also registered as metrics, so
        each can be monitored separately during training.
        """
        optimizer = Adam(learning_rate=learning_rate)
        self.model.compile(optimizer=optimizer,
                           loss=self._calculate_combined_loss,
                           metrics=[self._calculate_reconstruction_loss,
                                    self._calculate_kl_loss])

    def train(self, x_train, batch_size, num_epochs, validation_split=0.3):
        """Fit the autoencoder to reconstruct ``x_train`` and plot the losses.

        Args:
            x_train: training samples, used as both input and target.
            batch_size: mini-batch size passed to ``Model.fit``.
            num_epochs: number of training epochs.
            validation_split: fraction of ``x_train`` that Keras holds out
                for validation (default 0.3, the previously hard-coded value).

        Returns:
            The Keras ``History`` object returned by ``Model.fit``.
        """
        history = self.model.fit(x_train,
                                 x_train,
                                 validation_split=validation_split,
                                 batch_size=batch_size,
                                 epochs=num_epochs,
                                 shuffle=True)
        self._plot_history(history)
        return history

    @staticmethod
    def _plot_history(history, width=7):
        """Print the log of every history column and plot log(loss) per epoch.

        The figure uses a golden-ratio aspect (``width`` x ``width / phi``).
        """
        _, ax = plt.subplots(figsize=(width, 2. * width / (1 + np.sqrt(5))))

        log_history = np.log(pd.DataFrame(history.history))
        print(log_history)
        # "val_loss" only exists when training used a validation split.
        loss_columns = [c for c in ("loss", "val_loss")
                        if c in log_history.columns]
        log_loss = log_history.loc[:, loss_columns]

        # Draw onto the axes created above. Without ax=..., pandas opens a
        # new figure, and the labels and limits below would land on an
        # empty plot.
        log_loss.plot(ax=ax)

        ax.set_ylabel('log(NELBO)')
        ax.set_xlabel('Number of epochs')

        # Choose the y-range ignoring the first epoch, unless it is the only one.
        shown = log_loss.iloc[1:] if len(log_loss) > 1 else log_loss
        low, high = shown.values.min(), shown.values.max()
        # Margins are relative to magnitude, so they widen the range even
        # when the log-losses are negative. For positive values this matches
        # the former .99*min / 1.1*max limits.
        ax.set_ylim(low - 0.01 * abs(low), high + 0.1 * abs(high))
        plt.show()

    # ------------------------------------------------------------------
    # Persistence and inference
    # ------------------------------------------------------------------
    def save(self, save_folder="."):
        """Write parameters.pkl and weights.h5 into ``save_folder``.

        The folder is created first if it is missing.
        """
        self._create_folder_if_it_doesnt_exist(save_folder)
        self._save_parameters(save_folder)
        self._save_weights(save_folder)

    def load_weights(self, weights_path):
        """Load weights into ``self.model`` from ``weights_path``."""
        self.model.load_weights(weights_path)

    def reconstruct(self, images):
        """Encode and then decode ``images``.

        Returns:
            ``(reconstructed_images, latent_representations)``. The encoder
            output is a randomly *sampled* latent point, so repeated calls
            can return different results.
        """
        latent_representations = self.encoder.predict(images)
        reconstructed_images = self.decoder.predict(latent_representations)
        return reconstructed_images, latent_representations

    @classmethod
    def load(cls, save_folder="."):
        """Rebuild a model from the output of ``save``.

        Reads parameters.pkl and weights.h5 from ``save_folder``.

        Security: ``pickle.load`` can execute arbitrary code, so only load
        folders you trust.
        """
        parameters_path = os.path.join(save_folder, "parameters.pkl")
        with open(parameters_path, "rb") as f:
            parameters = pickle.load(f)
        # The parameters are stored in constructor order, so they are passed
        # positionally. Using cls(...) makes subclasses load as themselves.
        autoencoder = cls(*parameters)
        weights_path = os.path.join(save_folder, "weights.h5")
        autoencoder.load_weights(weights_path)
        return autoencoder

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------
    def _calculate_combined_loss(self, y_target, y_predicted):
        """Custom loss: weighted reconstruction loss plus the KL term."""
        reconstruction_loss = self._calculate_reconstruction_loss(y_target, y_predicted)
        kl_loss = self._calculate_kl_loss(y_target, y_predicted)
        combined_loss = self.reconstruction_loss_weight * reconstruction_loss \
            + kl_loss
        return combined_loss

    def _calculate_reconstruction_loss(self, y_target, y_predicted):
        """Return the per-sample mean squared error, summed over the batch.

        The mean is taken over every non-batch axis, so the same code works
        for 4-D (2-D image) and 5-D (3-D volume) tensors.
        """
        error = y_target - y_predicted
        non_batch_axes = list(range(1, K.ndim(y_predicted)))
        per_sample_mse = K.mean(K.square(error), axis=non_batch_axes)
        # NOTE(review): this sums over the batch, while _calculate_kl_loss
        # averages over it, so the effective reconstruction weight grows with
        # batch_size. Kept as-is to preserve the existing loss scale.
        return tf.reduce_sum(per_sample_mse)

    def _calculate_kl_loss(self, y_target, y_predicted):
        """Return the KL divergence of N(mu, exp(log_variance)) from N(0, I).

        The divergence is averaged over the batch. It acts as a distance
        that pulls the encoder's Gaussian towards the standard multivariate
        normal distribution. ``y_target`` and ``y_predicted`` are unused;
        they exist only to match the Keras loss/metric signature.
        """
        # Sum over axis 1 (the latent dimensions), then average over the batch.
        kl_loss = tf.reduce_mean(-0.5 * K.sum(1 + self.log_variance
                                              - tf.square(self.mu)
                                              - tf.exp(self.log_variance),
                                              axis=1))
        return kl_loss

    # ------------------------------------------------------------------
    # Saving helpers
    # ------------------------------------------------------------------
    def _create_folder_if_it_doesnt_exist(self, folder):
        """Create ``folder`` and any parents; do nothing if it already exists."""
        # exist_ok avoids the race between a separate exists() check and mkdir.
        os.makedirs(folder, exist_ok=True)

    def _save_parameters(self, save_folder):
        """Pickle the constructor arguments, in order, to parameters.pkl."""
        parameters = [
            self.input_shape,
            self.conv_filters,
            self.conv_kernels,
            self.conv_strides,
            self.latent_space_dim
        ]
        save_path = os.path.join(save_folder, "parameters.pkl")
        with open(save_path, "wb") as f:  # binary write mode
            pickle.dump(parameters, f)

    def _save_weights(self, save_folder):
        """Save the model weights to weights.h5 (HDF5, the Keras weights format)."""
        save_path = os.path.join(save_folder, "weights.h5")
        self.model.save_weights(save_path)

    # ------------------------------------------------------------------
    # Architecture
    # ------------------------------------------------------------------
    def _build(self):
        """Build the encoder, the decoder and the end-to-end model."""
        self._build_encoder()
        self._build_decoder()
        self._build_autoencoder()

    def _build_autoencoder(self):
        """Chain encoder -> decoder into ``self.model``."""
        model_input = self._model_input
        model_output = self.decoder(self.encoder(model_input))
        self.model = Model(model_input, model_output, name="autoencoder")

    def _build_decoder(self):
        """Build the decoder: dense -> reshape -> conv-transpose stack -> output."""
        decoder_input = self._add_decoder_input()
        dense_layer = self._add_dense_layer(decoder_input)
        reshape_layer = self._add_reshape_layer(dense_layer)
        conv_transpose_layers = self._add_conv_transpose_layers(reshape_layer)
        decoder_output = self._add_decoder_output(conv_transpose_layers)
        self.decoder = Model(decoder_input, decoder_output, name="decoder")

    def _add_decoder_input(self):
        """Return the decoder input layer, one latent vector per sample."""
        return Input(shape=(self.latent_space_dim,), name="decoder_input")

    def _add_dense_layer(self, decoder_input):
        """Add a dense layer with as many units as the pre-bottleneck tensor."""
        num_neurons = np.prod(self._shape_before_bottleneck)  # e.g. [1, 2, 4] -> 8
        dense_layer = Dense(num_neurons, name="decoder_dense")(decoder_input)
        return dense_layer

    def _add_reshape_layer(self, dense_layer):
        """Reshape the dense output back to the encoder's pre-flatten shape."""
        return Reshape(self._shape_before_bottleneck)(dense_layer)

    def _add_conv_transpose_layers(self, x):
        """Add conv transpose blocks, mirroring the encoder.

        The loop walks the conv layers in reverse order and stops before
        index 0; that layer is mirrored by ``_add_decoder_output``.
        """
        for layer_index in reversed(range(1, self._num_conv_layers)):
            x = self._add_conv_transpose_layer(layer_index, x)
        return x

    def _add_conv_transpose_layer(self, layer_index, x):
        """Add one block: Conv2DTranspose + ReLU + batch normalization."""
        layer_num = self._num_conv_layers - layer_index
        conv_transpose_layer = Conv2DTranspose(
            filters=self.conv_filters[layer_index],
            kernel_size=self.conv_kernels[layer_index],
            strides=self.conv_strides[layer_index],
            padding="same",
            name=f"decoder_conv_transpose_layer_{layer_num}"
        )
        x = conv_transpose_layer(x)
        x = ReLU(name=f"decoder_relu_{layer_num}")(x)
        x = BatchNormalization(name=f"decoder_bn_{layer_num}")(x)
        return x

    def _add_decoder_output(self, x):
        """Add the final Conv2DTranspose layer plus a sigmoid activation.

        The layer mirrors the first encoder layer (same kernel and stride).
        """
        conv_transpose_layer = Conv2DTranspose(
            filters=1,  # one output channel; assumes single-channel input
            kernel_size=self.conv_kernels[0],
            strides=self.conv_strides[0],
            # "same" pads so the spatial size is input size / stride.
            padding="same",
            name=f"decoder_conv_transpose_layer_{self._num_conv_layers}"
        )
        x = conv_transpose_layer(x)
        # NOTE(review): sigmoid limits outputs to (0, 1) even though the layer
        # is named "linear_layer"; confirm the training data is scaled into
        # that range.
        output_layer = Activation("sigmoid", name="linear_layer")(x)

        return output_layer

    def _build_encoder(self):
        """Build the encoder: conv stack -> Gaussian sampling bottleneck."""
        encoder_input = self._add_encoder_input()
        conv_layers = self._add_conv_layers(encoder_input)
        bottleneck = self._add_bottleneck(conv_layers)
        self._model_input = encoder_input
        self.encoder = Model(encoder_input, bottleneck, name="encoder")

    def _add_encoder_input(self):
        """Return the encoder input layer of shape ``self.input_shape``."""
        return Input(shape=self.input_shape, name="encoder_input")

    def _add_conv_layers(self, encoder_input):
        """Create all convolutional blocks of the encoder, in order."""
        x = encoder_input
        for layer_index in range(self._num_conv_layers):
            x = self._add_conv_layer(layer_index, x)
        return x

    def _add_conv_layer(self, layer_index, x):
        """Add a convolutional block to a graph of layers.

        The block is Conv2D + ReLU + batch normalization.
        """
        layer_number = layer_index + 1
        conv_layer = Conv2D(
            filters=self.conv_filters[layer_index],
            kernel_size=self.conv_kernels[layer_index],
            strides=self.conv_strides[layer_index],
            padding="same",
            name=f"encoder_conv_layer_{layer_number}"
        )
        x = conv_layer(x)
        x = ReLU(name=f"encoder_relu_{layer_number}")(x)
        x = BatchNormalization(name=f"encoder_bn_{layer_number}")(x)
        return x

    def _add_bottleneck(self, x):
        """Flatten the data and add the variational (Gaussian sampling) bottleneck.

        Two dense heads produce ``mu`` and ``log_variance``. A Lambda layer
        then samples ``z = mu + exp(log_variance / 2) * epsilon`` with
        ``epsilon ~ N(0, I)`` (the reparameterization trick).
        """
        self._shape_before_bottleneck = K.int_shape(x)[1:]
        x = Flatten()(x)
        self.mu = Dense(self.latent_space_dim, name="mu")(x)
        self.log_variance = Dense(self.latent_space_dim,
                                  name="log_variance")(x)

        def sample_point_from_normal_distribution(args):
            """Sample one latent point per row from N(mu, exp(log_variance))."""
            mu, log_variance = args
            # epsilon is drawn from a standard normal; scaling by sigma and
            # shifting by mu gives a sample from the encoder's Gaussian.
            epsilon = K.random_normal(shape=K.shape(mu), mean=0.,
                                      stddev=1.)
            sampled_point = mu + K.exp(log_variance / 2) * epsilon  # sigma = exp(log_var / 2)
            return sampled_point

        # A Lambda layer wraps the sampling function so it is part of the graph.
        x = Lambda(sample_point_from_normal_distribution,
                   name="encoder_output")([self.mu, self.log_variance])
        return x
###############################################################################################################################################################
"""
The statements executed by the top-level invocation (anropande)  of the inerpreter, either read from 
a script file or interactively, are considered part of a model called main
"""
if __name__ == "__main__":
    vae = VAE(
        input_shape=(32,32,1), 
        conv_filters=(32, 64,64,64), #filters are in same dim as input with same nr. of channels, but fewer rows and columns
        conv_kernels=(3,3,3,3),
        conv_strides=(1,2,2,1), #The amount of movement between applications of the filter to the input image.Default in 2D is (1,1) for the height and the width movement. 
        latent_space_dim=2
    )
    vae.summary()

Model: "encoder"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
encoder_input (InputLayer)      [(None, 32, 32, 1)]  0                                            
__________________________________________________________________________________________________
encoder_conv_layer_1 (Conv2D)   (None, 32, 32, 32)   320         encoder_input[0][0]              
__________________________________________________________________________________________________
encoder_relu_1 (ReLU)           (None, 32, 32, 32)   0           encoder_conv_layer_1[0][0]       
__________________________________________________________________________________________________
encoder_bn_1 (BatchNormalizatio (None, 32, 32, 32)   128         encoder_relu_1[0][0]             
____________________________________________________________________________________________