In [17]:
import numpy as np
import tensorflow as tf


In [18]:
# Encoder: compresses a 28x28 single-channel image into a 2-value code.
encoder_input = tf.keras.layers.Input(shape=(28, 28, 1))

# (filters, strides) of each 3x3, 'same'-padded convolution, in order.
# Every convolution is followed by a LeakyReLU activation.
encoder_conv_specs = [(32, 1), (64, 2), (64, 2), (128, 1)]

x = encoder_input
for n_filters, stride in encoder_conv_specs:
    conv_layer = tf.keras.layers.Conv2D(
        filters=n_filters, kernel_size=(3, 3), strides=stride, padding='same'
    )
    x = conv_layer(x)
    x = tf.keras.layers.LeakyReLU()(x)

# Record the feature-map shape without the batch axis; the decoder uses it
# to size its first Dense layer and to undo the flattening.
shape_before_flattening = tf.keras.backend.int_shape(x)[1:]
x = tf.keras.layers.Flatten()(x)

# Two output units: the code produced for each image.
encoder_output = tf.keras.layers.Dense(units=2)(x)

encoder_network = tf.keras.Model(inputs=encoder_input, outputs=encoder_output)


In [19]:
# Decoder: expands a 2-value code back into a single-channel feature map.
decoder_input = tf.keras.layers.Input(shape=(2,))

# Project the code up to as many units as the encoder's last feature map
# holds, then restore that feature map's shape.
x = tf.keras.layers.Dense(np.prod(shape_before_flattening))(decoder_input)
x = tf.keras.layers.Reshape(shape_before_flattening)(x)

# (filters, strides) of each 3x3, 'same'-padded transposed convolution that
# is followed by a LeakyReLU activation.
decoder_conv_t_specs = [(128, 1), (64, 2), (64, 2)]

for n_filters, stride in decoder_conv_t_specs:
    conv_t_layer = tf.keras.layers.Conv2DTranspose(
        filters=n_filters, kernel_size=(3, 3), strides=stride, padding='same'
    )
    x = conv_t_layer(x)
    x = tf.keras.layers.LeakyReLU()(x)

# Final layer collapses to one channel; the sigmoid bounds every output
# value to the range [0, 1].
conv_t_layer = tf.keras.layers.Conv2DTranspose(
    filters=1, kernel_size=(3, 3), strides=1, padding='same'
)
x = conv_t_layer(x)
x = tf.keras.layers.Activation('sigmoid')(x)

decoder_output = x

decoder_network = tf.keras.Model(inputs=decoder_input, outputs=decoder_output)

In [20]:
# Print the decoder's layer-by-layer summary (output shapes and parameter counts).
decoder_network.summary()

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 2)]               0         
_________________________________________________________________
dense_5 (Dense)              (None, 6272)              18816     
_________________________________________________________________
reshape_2 (Reshape)          (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_8 (Conv2DTr (None, 7, 7, 128)         147584    
_________________________________________________________________
leaky_re_lu_18 (LeakyReLU)   (None, 7, 7, 128)         0         
_________________________________________________________________
conv2d_transpose_9 (Conv2DTr (None, 14, 14, 64)        73792     
_________________________________________________________________
leaky_re_lu_19 (LeakyReLU)   (None, 14, 14, 64)        0   

In [21]:
# Chain the two halves into one end-to-end autoencoder: the decoder model is
# applied to the encoder's output tensor, so both share the encoder's input.
model_input = encoder_input
model_output = decoder_network(encoder_output)

network = tf.keras.Model(model_input, model_output)


def r_loss(y_true, y_pred):
    """Reconstruction loss: squared error averaged over axes 1 and 2.

    Args:
        y_true: Target tensor (the images the network should reproduce).
        y_pred: Tensor produced by the network; must be broadcast-compatible
            with ``y_true``.

    Returns:
        Tensor of mean squared differences with axes 1 and 2 reduced away.
        Any remaining axes are left for Keras to reduce.
    """
    return tf.keras.backend.mean(tf.keras.backend.square(y_true - y_pred), axis=[1,2])


# Use the `learning_rate` keyword: the older `lr` alias is deprecated and is
# rejected by newer Keras releases.
network.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss=r_loss)
network.summary()

Model: "model_9"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_5 (InputLayer)         [(None, 28, 28, 1)]       0         
_________________________________________________________________
conv2d_8 (Conv2D)            (None, 28, 28, 32)        320       
_________________________________________________________________
leaky_re_lu_14 (LeakyReLU)   (None, 28, 28, 32)        0         
_________________________________________________________________
conv2d_9 (Conv2D)            (None, 14, 14, 64)        18496     
_________________________________________________________________
leaky_re_lu_15 (LeakyReLU)   (None, 14, 14, 64)        0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 7, 7, 64)          36928     
_________________________________________________________________
leaky_re_lu_16 (LeakyReLU)   (None, 7, 7, 64)          0   

In [22]:
# Load the MNIST digits. Only `train_images` is used below; the labels and
# the test split are unpacked but not needed for reconstruction training.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# MNIST pixels are integers in [0, 255]. Scale by 255 (not 32) so the targets
# lie in [0, 1], the only range the decoder's final sigmoid can produce.
# The trailing axis adds the explicit channel dimension declared by the
# encoder's Input(shape=(28, 28, 1)).
X_train = tf.cast(train_images, tf.float32)[..., tf.newaxis] / 255.0

In [23]:
# Train the autoencoder to reproduce its own input: the same images are
# passed as both features and targets.
history = network.fit(
    x=X_train,
    y=X_train,
    batch_size=32,
    epochs=10,
)
# Show the metrics recorded for each epoch.
history.history

Epoch 1/10
Epoch 2/10

KeyboardInterrupt: 