#M.Lytova, M.Spanner, I.Tamblyn. Deep learning and high harmonic generation (2020)
##Codes for Section II.D : Autoencoders and latent space visualization 
###see also (for beta=1) https://github.com/emnajaoua/beta_variational_autoencoders/blob/master/disentangled_vae%20(1).ipynb

##Headers and constants

In [None]:
import numpy as np
import tensorflow as tf
from keras.layers import Input, Dense, Dropout, Lambda
from keras import backend as K
from keras.models import Model
from keras.optimizers import Adam, Nadam
from keras import objectives
from keras.losses import mean_squared_error
from keras.callbacks import TensorBoard
import argparse
import matplotlib.pyplot as plt
import time

In [None]:
PI = 3.14159265359

# Sampling grid: x_n_points equally spaced nodes on [0, 100]; each node feeds
# one unit of the input layer.
x_n_points = 512
x_n = np.linspace(0, 100, num=x_n_points)

# VAE hyper-parameters
batch_size = 128
original_dim = x_n_points            # width of the input and output layers
input_shape = (original_dim,)
intermediate_dim1, intermediate_dim2, intermediate_dim3 = 128, 64, 16
latent_dim = 2                       # size of the latent space
epsilon_std = 1                      # std of the noise used when sampling z

# Data-set sizes, expressed as whole multiples of the batch size.
n_train = 4000 * batch_size
n_test = 1000 * batch_size

##Training set generation

In [None]:
# Training set
# Column 0 of the uniform draws becomes the frequency, column 1 the amplitude.
all2_train = np.random.rand(n_train, 2)
w0_train = all2_train[:, 0] * 0.5 + 0.5    # vector of random frequencies in [0.5, 1)
A0_train = all2_train[:, 1] * 0.5 + 0.5    # vector of random amplitudes in [0.5, 1)

# Row i is A0_train[i] * sin(w0_train[i] * x_n).  The outer product replaces
# the per-sample Python loop; sin and the amplitude scaling are then applied
# in place so that only one (n_train, x_n_points) array is ever allocated.
y_train = np.outer(w0_train, x_n)
np.sin(y_train, out=y_train)
y_train *= A0_train[:, np.newaxis]

##Testing set generation

In [None]:
# Testing set
# Column 0 of the uniform draws becomes the frequency, column 1 the amplitude.
all2_test = np.random.rand(n_test, 2)
w0_test = all2_test[:, 0] * 0.5 + 0.5    # vector of random frequencies in [0.5, 1)
A0_test = all2_test[:, 1] * 0.5 + 0.5    # vector of random amplitudes in [0.5, 1)

# Row i is A0_test[i] * sin(w0_test[i] * x_n).  The outer product replaces
# the per-sample Python loop; sin and the amplitude scaling are then applied
# in place so that only one (n_test, x_n_points) array is ever allocated.
y_test = np.outer(w0_test, x_n)
np.sin(y_test, out=y_test)
y_test *= A0_test[:, np.newaxis]


##Latent space vectors, encoder and decoder
512 $\to$ 128 $\to$ 64 $\to$ 16 $\to$ 2 $\to$ 16 $\to$ 64 $\to$ 128 $\to$ 512

In [None]:
#Generate the latent representation vectors 
def sampling(args):
    """Draw a latent sample with the reparameterization trick.

    Parameters
    ----------
    args : sequence of two tensors
        ``(z_mean, z_log_var)``: the mean and the log-variance produced by
        the two encoder heads.

    Returns
    -------
    tensor
        ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` is
        normal noise with mean 0 and standard deviation ``epsilon_std``.
    """
    z_mean, z_log_var = args
    # Derive the noise shape from the incoming tensor instead of the global
    # batch_size, so batches of any size (e.g. a final partial batch or a
    # predict call with another batch size) are handled.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim),
                              mean=0., stddev=epsilon_std)
    # The second input is a log-variance (the KL term of the loss uses
    # exp(z_log_var) as the variance), so the standard deviation is
    # exp(0.5 * z_log_var), not exp(z_log_var).
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [None]:
# Encoder
# Three tanh Dense layers narrow the input down to intermediate_dim3 units;
# two Dense heads without activation then emit the latent mean and
# log-variance, and the Lambda layer samples z from them.

inputs = Input(shape = input_shape, name='encoder_input')

x = Dense(intermediate_dim1, activation='tanh')(inputs)
x = Dense(intermediate_dim2, activation='tanh')(x)
x = Dense(intermediate_dim3, activation='tanh')(x)

z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# The model exposes all three tensors; index 2 is the sampled z.
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')

# summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that would emit a trailing "None").
encoder.summary()

In [None]:
# Decoder
# Mirror of the encoder: three tanh Dense layers widen the latent vector back
# up to intermediate_dim1 units, followed by a tanh output layer of
# original_dim units.

latent_inputs = Input(shape=(latent_dim,), name='z_sampling')

x = Dense(intermediate_dim3, activation='tanh')(latent_inputs)
x = Dense(intermediate_dim2, activation='tanh')(x)
x = Dense(intermediate_dim1, activation='tanh')(x)

outputs = Dense(original_dim, activation='tanh')(x)

decoder = Model(latent_inputs, outputs, name='decoder')

# summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that would emit a trailing "None").
decoder.summary()

In [None]:
# Chain encoder and decoder end to end: the decoder consumes the third
# encoder output (the sampled z), not the mean or the log-variance.
encoded = encoder(inputs)
outputs = decoder(encoded[2])

vae = Model(inputs=inputs, outputs=outputs, name='vae_mlp')

In [None]:
# Reconstruction term: per-sample mean squared error between the input and
# the VAE output, multiplied by original_dim.
reconstruction_loss = mean_squared_error(inputs, outputs) * original_dim

# KL term: -0.5 * sum(1 + log_var - mean^2 - exp(log_var)) over the latent
# axis, built from the encoder's z_mean and z_log_var tensors.
kl_loss = -0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var),
                       axis=-1)

# Total loss is the batch mean of the two terms; it is attached to the model
# with add_loss, so compile() is called without a loss argument.
vae_loss = K.mean(reconstruction_loss + kl_loss)

vae.add_loss(vae_loss)

opt = Nadam(lr=0.0001)
vae.compile(optimizer=opt)

# summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that would emit a trailing "None").
vae.summary()

##Training

In [None]:
# Time the whole training run with a performance counter.
tic = time.perf_counter()

# Only inputs are passed to fit(): the loss was attached to the model with
# add_loss, so no target array is given (validation targets are None too).
# The returned History object is read later by plot_losses().
history = vae.fit(y_train, 
                epochs=150,
                batch_size=batch_size,
                shuffle=True,
                validation_data=(y_test, None))

toc = time.perf_counter()
print(f"Execution time {toc - tic:0.4f} seconds")

def plot_losses():
    """Plot log10 of the training and validation loss per epoch.

    Reads the module-level ``history`` object returned by ``vae.fit`` and
    draws the ``'loss'`` curve in blue and the ``'val_loss'`` curve in red.
    """
    plt.figure(figsize=(8, 4))
    # Same order as the legend entries below: training first, then validation.
    for key, colour in (('loss', 'blue'), ('val_loss', 'red')):
        curve = np.log10(history.history[key])
        plt.plot(curve, color=colour)
    plt.title('Model loss')
    plt.ylabel('loss')
    plt.xlabel('epoch')
    plt.legend(['train', 'test'], loc='upper right')
    plt.show()

##Training and validation losses

In [None]:
# Draw log10 of the training and validation losses recorded in `history`.
plot_losses()

##Prediction

In [None]:
# Encode the test set and keep only the first encoder output (the latent
# means).  Note that this rebinds z_mean from the encoder tensor to an array.
encoder_outputs = encoder.predict(y_test, batch_size=batch_size)
z_mean = encoder_outputs[0]
# Decode the latent means back to signals for comparison with y_test.
decoded_sin = decoder.predict(z_mean)

## Drawing the latent space

In [None]:
def _latent_figure(z0, z1, radius, angle, values, cmap, title):
    """Draw one two-panel latent-space figure coloured by ``values``.

    The left panel scatters the points in Cartesian coordinates (z0, z1), the
    right panel in polar coordinates (radius, angle).  Returns the figure.
    """
    fig, (ax_xy, ax_polar) = plt.subplots(1, 2, figsize=(12, 4),
                                          constrained_layout=False)
    fig.suptitle(title, fontsize=16)

    points = ax_xy.scatter(z0, z1, c=values, cmap=cmap)
    fig.colorbar(points, ax=ax_xy)
    ax_xy.set_xlabel('$z_0$', fontsize=14)
    ax_xy.set_ylabel('$z_1$', fontsize=14)
    ax_xy.set_title('In Cartesian coordinates', fontsize=14)

    points = ax_polar.scatter(radius, angle, c=values, cmap=cmap)
    fig.colorbar(points, ax=ax_polar)
    ax_polar.set_xlabel('$r$', fontsize=14)
    ax_polar.set_ylabel(r'$\theta$', fontsize=14)
    ax_polar.set_title('In polar coordinates', fontsize=14)
    return fig


def plot_latent(z0, z1):
    """Show the latent space coloured by frequency and by amplitude.

    Parameters
    ----------
    z0, z1 : array-like
        First and second latent coordinate of every test sample.

    Two figures are produced, one coloured by the module-level ``w0_test``
    ('jet' colormap) and one by ``A0_test`` ('viridis' colormap); each shows
    the points in Cartesian and in polar coordinates.
    """
    # Polar coordinates are shared by both figures, so compute them once.
    radius = np.sqrt(z0**2 + z1**2)
    angle = np.arctan2(z1, z0)

    figures = [
        _latent_figure(z0, z1, radius, angle, w0_test, 'jet',
                       'Colorbar wrt $w_0$'),
        _latent_figure(z0, z1, radius, angle, A0_test, 'viridis',
                       'Colorbar wrt $A_0$'),
    ]
    plt.show()
    # Close both figures explicitly; a bare plt.close() only closes the
    # current (last created) one.
    for fig in figures:
        plt.close(fig)

In [None]:
# The two columns of z_mean are the two latent coordinates (latent_dim = 2).
plot_latent(z_mean[:, 0], z_mean[:,1])

##Function to draw the test and reconstructed examples

In [None]:
def plot_examples(i1, i2):
    """Compare two test signals with their reconstructions.

    Parameters
    ----------
    i1, i2 : int
        Row indices into the module-level ``y_test`` / ``decoded_sin`` arrays
        (and into ``w0_test`` / ``A0_test`` for the panel titles).

    The test points are drawn as blue dots and the decoded signal as a red
    line; ``i1`` goes in the upper panel, ``i2`` in the lower one.
    """
    plt.subplots(2, 1, figsize=(12, 8), constrained_layout=False)
    plt.suptitle('Examples: Test points and prediction', fontsize=16)
    for row, idx in enumerate((i1, i2), start=1):
        plt.subplot(2, 1, row)
        heading = ('w0 = ' + str(round(w0_test[idx], 2))
                   + ",  A = " + str(round(A0_test[idx], 2)))
        plt.title(heading, fontsize=16)
        plt.scatter(x_n, y_test[idx], color="blue", s=1)
        plt.plot(x_n, decoded_sin[idx], color="red", linewidth=1)
        # Only the bottom panel carries the x-axis label.
        if row == 2:
            plt.xlabel('$t$, fs', fontsize=16)
        plt.grid()
    plt.show()
    plt.close()

## Comparison of arbitrary $y_{test}$ (blue) and $y_{reconstructed}$ (red)

In [None]:
# np.random.randint excludes its upper bound, so n_test (not n_test-1) is
# passed to make every test index 0 .. n_test-1 eligible.
i_show1 = np.random.randint(0, n_test)
i_show2 = np.random.randint(0, n_test)

plot_examples(i_show1, i_show2)