<a href="https://colab.research.google.com/github/MaschinenNah/MachineLearningKursCdV/blob/main/AutoEncoderStarter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import os

# Connect to the Colab TPU when one is attached. Colab only sets
# COLAB_TPU_ADDR on TPU runtimes, so indexing os.environ directly raises
# KeyError on CPU/GPU runtimes; fall back to the default strategy instead.
tpu_address = os.environ.get('COLAB_TPU_ADDR')

if tpu_address:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + tpu_address)

    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))

    strategy = tf.distribute.TPUStrategy(resolver)
else:
    # Keep both names defined so later cells can reference them safely.
    resolver = None
    print("COLAB_TPU_ADDR is not set; using the default distribution strategy.")

    strategy = tf.distribute.get_strategy()

KeyError: ignored

In [None]:
import numpy as np

from keras.layers import Input, Conv2D, Flatten, Dense, Conv2DTranspose, Reshape, Lambda, Activation, BatchNormalization, LeakyReLU, Dropout
from keras.models import Model
from keras import backend as K
from keras.optimizers import Adam
from keras.datasets import mnist

import matplotlib.pyplot as plt

# Load the MNIST images; the labels are discarded because the model below
# is trained to reproduce its own input.
(x_train, _), (x_test, _) = mnist.load_data()

# Convert to float32 and divide by 255, then append a trailing channel axis
# so each sample matches the (28, 28, 1) encoder input.
x_train = (x_train.astype("float32") / 255)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255)[..., np.newaxis]

# Size of the latent vector produced by the encoder and consumed by the decoder.
N_PARAMS = 20

In [None]:
# Encoder: maps a 28x28x1 image to an N_PARAMS-dimensional latent vector.
encoder_input = Input(shape=(28, 28, 1))

# Each stage is Conv2D -> LeakyReLU -> BatchNormalization -> Dropout.
# The tuples are (filters, strides) for the successive stages.
x = encoder_input
for n_filters, stride in [(32, 1), (64, 2), (64, 2), (64, 1)]:
    x = Conv2D(filters=n_filters,
               kernel_size=(3, 3),
               strides=stride,
               padding='same')(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)
    x = Dropout(rate=0.25)(x)

# Feature-map shape without the batch axis; the decoder uses it to undo
# the flattening step.
shape_before_flatten = K.int_shape(x)[1:]

x = Flatten()(x)
encoder_output = Dense(N_PARAMS)(x)

encoder = Model(encoder_input, encoder_output)

encoder.summary()

In [None]:
# Decoder: maps an N_PARAMS-dimensional latent vector back to an image.
decoder_input = Input((N_PARAMS,))

# Expand the latent vector and restore the feature-map shape the encoder
# had right before it flattened.
x = Dense(np.prod(shape_before_flatten))(decoder_input)
x = Reshape(shape_before_flatten)(x)

# Each stage is Conv2DTranspose -> LeakyReLU -> BatchNormalization -> Dropout.
# The tuples are (filters, strides) for the successive stages.
for n_filters, stride in [(64, 1), (64, 2), (32, 2)]:
    x = Conv2DTranspose(filters=n_filters,
                        kernel_size=(3, 3),
                        strides=stride,
                        padding='same')(x)
    x = LeakyReLU()(x)
    x = BatchNormalization()(x)
    x = Dropout(rate=0.25)(x)

# Final single-channel layer; the sigmoid bounds every output value to (0, 1).
x = Conv2DTranspose(filters=1,
                    kernel_size=(3, 3),
                    strides=1,
                    padding='same')(x)
decoder_output = x = Activation('sigmoid')(x)

decoder = Model(decoder_input, decoder_output)

decoder.summary()

In [None]:
# Chain encoder and decoder into one end-to-end model: the decoder is applied
# to the encoder's latent output, so both sub-models share their weights
# with the combined model.
model_input = encoder_input
model_output = decoder(encoder_output)
model = Model(inputs=model_input, outputs=model_output)
model.summary()

In [None]:
# `lr` is a deprecated alias that newer Keras releases reject;
# `learning_rate` is the supported argument name.
optimizer = Adam(learning_rate=0.0005)


def r_loss(y_true, y_pred):
    """Return the per-sample mean squared reconstruction error.

    The squared difference between target and prediction is averaged over
    axes 1, 2 and 3, leaving one loss value per sample in the batch.

    Args:
        y_true: Batch of target images (4-D tensor).
        y_pred: Batch of reconstructed images, same shape as ``y_true``.

    Returns:
        Tensor holding one mean squared error value per sample.
    """
    return K.mean(K.square(y_true - y_pred), axis=[1, 2, 3])


model.compile(loss=r_loss, optimizer=optimizer)

In [None]:
# Train on the first 2000 training images and validate on the first 2000
# test images. Input and target are the same array because the model is
# trained to reconstruct its input.
train_subset = x_train[:2000]
val_subset = x_test[:2000]

model.fit(x=train_subset,
          y=train_subset,
          batch_size=64,
          epochs=20,
          shuffle=True,
          validation_data=(val_subset, val_subset))

In [None]:
# Reconstruction loss over the full test set (input doubles as the target).
model.evaluate(x=x_test, y=x_test)

In [None]:
# Show randomly chosen test images (top row) above their reconstructions
# (bottom row).
n_to_show = 10
example_idx = np.random.choice(np.arange(len(x_test)), n_to_show)
example_images = x_test[example_idx]

# Encode to latent vectors, then decode them back to images.
z_points = encoder.predict(example_images)
reconst_images = decoder.predict(z_points)

fig = plt.figure(figsize=(15, 3))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Row 0 holds the originals, row 1 the reconstructions; subplot indices are
# 1-based and run row by row.
# Optional: the latent vector z_points[col] could be annotated under each
# original with ax.text(...).
for row, images in enumerate((example_images, reconst_images)):
    for col in range(n_to_show):
        ax = fig.add_subplot(2, n_to_show, row * n_to_show + col + 1)
        ax.axis('off')
        # squeeze() drops the channel axis so imshow receives a 2-D array.
        ax.imshow(images[col].squeeze(), cmap='gray_r')
