# VAE

In [None]:
import os
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from keras import layers

import numpy as np

from modules.models    import VAE
from modules.layers    import SamplingLayer
from modules.callbacks import ImagesCallback
from modules.datagen   import MNIST


import matplotlib.pyplot as plt
import scipy.stats
import sys

import fidle

# Fidle bookkeeping: start the run tagged 'K3VAE2' and unpack the three
# values fidle.init hands back (run identifier, run directory, datasets dir).
(run_id,
 run_dir,
 datasets_dir) = fidle.init('K3VAE2')

# Class-level helper exposed by modules.models.VAE; presumably prints
# information about the custom model implementation - confirm in VAE.py.
VAE.about()

In [None]:
# --- Model -----------------------------------------------------------------
latent_dim = 6                  # dimension of the latent vector z
loss_weights = [1, 0.06]        # forwarded to VAE(...); presumably the weights of
                                # the reconstruction and KL terms - confirm in VAE.py

# --- Data / reproducibility --------------------------------------------------
# Not consumed in the cells shown here; presumably a dataset fraction and an
# RNG seed used by later data-loading cells - confirm.
scale, seed = 0.2, 123

# --- Training ------------------------------------------------------------------
batch_size, epochs, fit_verbosity = 64, 4, 1

### Encoder

In [None]:
# Convolutional encoder: 28x28x1 image -> (z_mean, z_log_var, z).
# With padding="same", the (filters, stride) spec below takes the spatial
# size 28 -> 28 -> 14 -> 7 -> 7 before flattening.
inputs = keras.Input(shape=(28, 28, 1))

x = inputs
for n_filters, stride in ((32, 1), (64, 2), (64, 2), (64, 1)):
    x = layers.Conv2D(
        n_filters, 3,
        strides=stride,
        padding="same",
        activation="relu",
    )(x)

# Small dense bottleneck feeding the two latent heads.
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)

# Two parallel heads of size latent_dim; SamplingLayer (modules.layers) draws z
# from them - presumably via the reparameterization trick, confirm in its source.
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
z = SamplingLayer()([z_mean, z_log_var])

encoder = keras.Model(inputs, [z_mean, z_log_var, z], name="encoder")
# Compiled without arguments; training is driven by the outer VAE model.
encoder.compile()




### Decoder

In [None]:
# Convolutional decoder: latent vector -> 28x28x1 image.
# The latent vector is projected to a 7x7x64 feature map, then upsampled by the
# stride-2 transposed convolutions (7 -> 14 -> 28) with padding="same".
inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(7 * 7 * 64, activation="relu")(inputs)
x = layers.Reshape((7, 7, 64))(x)

for n_filters, stride in ((64, 1), (64, 2), (32, 2)):
    x = layers.Conv2DTranspose(
        n_filters, 3,
        strides=stride,
        padding="same",
        activation="relu",
    )(x)

# One output channel; the sigmoid keeps pixel values in [0, 1].
outputs = layers.Conv2DTranspose(1, 3, padding="same", activation="sigmoid")(x)

decoder = keras.Model(inputs, outputs, name="decoder")
# Compiled without arguments; training is driven by the outer VAE model.
decoder.compile()

### VAE

VAE is a custom model with its own `train_step` (see VAE.py).

In [None]:
# Wrap the two sub-models in the custom VAE; loss_weights is forwarded
# unchanged (its interpretation is defined in VAE.py).
vae = VAE(encoder, decoder, loss_weights)

# Only an optimizer is supplied: no loss is passed to compile, which suggests
# the loss is computed inside VAE's custom train_step - confirm in VAE.py.
vae.compile(optimizer='adam')