In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras.models import Model
import seaborn as sns

# Report which TensorFlow build this kernel is running.
print(tf.__version__)

# Require TensorFlow to report '/device:GPU:0'; any other answer (including
# an empty string) stops the notebook here with a SystemError.
device_name = tf.test.gpu_device_name()
if device_name == '/device:GPU:0':
  print('Found GPU at: {}'.format(device_name))
else:
  raise SystemError('GPU device not found')


In [None]:
#%% Load data
DIM_LATENT = 25  # passed to AutoEncoder as the width of the latent vector
EPOCHS = 100     # number of epochs handed to `fit`


def _prepare_images(images):
  """Pad, add a channel axis, and rescale a stack of 2-D images.

  Args:
    images: array indexed as (sample, row, column).

  Returns:
    float32 array of shape (sample, row + 4, column + 4, 1): each image is
    zero-padded by 2 pixels on every side, given a trailing axis of length 1,
    and divided by 255.
  """
  padded = np.pad(images, ((0, 0), (2, 2), (2, 2)), 'constant')
  with_channel = padded[..., np.newaxis]
  return with_channel.astype('float32') / 255.


# Only the images are kept; the label arrays returned by the loader are discarded.
(x_train_raw, _), (x_test_raw, _) = fashion_mnist.load_data()

x_train = _prepare_images(x_train_raw)
x_test = _prepare_images(x_test_raw)


In [None]:
#%% define model
class AutoEncoder(Model):
  """Convolutional autoencoder made of an `encoder` and a `decoder` Sequential.

  The encoder takes (32, 32, channels) inputs, applies three stride-2
  convolution stages (32, 64, 128 filters), flattens, and projects through a
  256-unit Dense layer down to `dim_latent` values. The decoder expands a
  latent vector to a (4, 4, 256) tensor, applies three stride-2 transposed
  convolution stages (128, 64, 32 filters), and finishes with a 1x1 transposed
  convolution to `channels` outputs followed by a sigmoid.

  Args:
    dim_latent: number of units in the encoder's final Dense layer.
    channels: channel count of the input images and of the reconstruction.
  """

  @staticmethod
  def _conv_stage(conv_cls, filters):
    """Return [conv, batch-norm, ReLU] for one kernel-4, stride-2, 'same' stage.

    `conv_cls` is the layer class to instantiate (Conv2D or Conv2DTranspose).
    """
    return [
      conv_cls(filters, 4, padding='same', strides=2),
      layers.BatchNormalization(axis=-1),
      layers.Activation("relu"),
    ]

  def __init__(self, dim_latent, channels):
    super().__init__()

    # Encoder: input -> 3 downsampling stages -> dense bottleneck.
    encoder_stack = [layers.Input(shape=(32, 32, channels))]
    for filters in (32, 64, 128):
      encoder_stack.extend(self._conv_stage(layers.Conv2D, filters))
    encoder_stack.extend([
      layers.Flatten(),
      layers.Dense(256),
      layers.Activation("relu"),
      layers.Dense(dim_latent),
    ])
    self.encoder = tf.keras.Sequential(encoder_stack)

    # Decoder: dense expansion -> (4, 4, 256) -> 3 upsampling stages -> 1x1 output.
    decoder_stack = [
      layers.Dense(256*4*4),
      layers.Activation("relu"),
      layers.Reshape((4, 4, 256)),
    ]
    for filters in (128, 64, 32):
      decoder_stack.extend(self._conv_stage(layers.Conv2DTranspose, filters))
    decoder_stack.extend([
      layers.Conv2DTranspose(channels, 1, padding='valid', strides=1),
      layers.Activation("sigmoid"),
    ])
    self.decoder = tf.keras.Sequential(decoder_stack)

  def call(self, x):
    """Encode `x` to the latent space and return its reconstruction."""
    return self.decoder(self.encoder(x))

# Instantiate the single-channel autoencoder and compile it with the Adam
# optimizer (selected by name) and a binary cross-entropy reconstruction loss.
autoencoder = AutoEncoder(channels=1, dim_latent=DIM_LATENT)
opt = "adam"
autoencoder.compile(
    loss=losses.BinaryCrossentropy(),
    optimizer=opt,
)



In [None]:
# Train the model to reproduce its input: the images are used as both the
# features and the targets, and the test images provide the validation data.
hist = autoencoder.fit(
    x=x_train,
    y=x_train,
    batch_size=32,
    epochs=EPOCHS,
    shuffle=True,
    validation_data=(x_test, x_test),
)

In [None]:
# Print the layer-by-layer summary of each half: encoder first, then decoder.
for submodel in (autoencoder.encoder, autoencoder.decoder):
    submodel.summary()

In [None]:
# Plot every series recorded in `hist.history` (one line per key) against the
# epoch index, and label the figure so it can be read on its own.
ax = pd.DataFrame(hist.history).plot(title="Training history")
ax.set_xlabel("epoch")
ax.set_ylabel("loss");  # trailing ';' keeps the Text repr out of the cell output

In [None]:
# show some images

size=10000  # rows pushed through the network per forward pass


def _apply_in_chunks(model, data, chunk_size):
    """Run `model` over `data` in consecutive slices and stack the outputs.

    Args:
        model: callable Keras model/layer whose result exposes `.numpy()`.
        data: array whose first axis indexes samples.
        chunk_size: maximum number of samples passed to `model` at once.

    Returns:
        A NumPy array with the per-chunk outputs concatenated along axis 0.
    """
    outputs = [
        # training=False requests inference behaviour from the layers
        # (relevant for the BatchNormalization layers in both sub-models).
        model(data[start:start + chunk_size], training=False).numpy()
        for start in range(0, data.shape[0], chunk_size)
    ]
    return np.concatenate(outputs, axis=0)


# Encode the training images, then decode the resulting latent vectors.
encoded_imgs = _apply_in_chunks(autoencoder.encoder, x_train, size)
decoded_imgs = _apply_in_chunks(autoencoder.decoder, encoded_imgs, size)


# Top row: first `n` training images; bottom row: their reconstructions.
n = 10
# squeeze=False keeps `axes` two-dimensional even when n == 1.
fig, axes = plt.subplots(2, n, figsize=(20, 4), squeeze=False)
for i in range(n):
    panels = (("original", x_train[i]), ("reconstructed", decoded_imgs[i]))
    for ax, (title, img) in zip(axes[:, i], panels):
        # Pass the colormap per image instead of calling plt.gray(), which
        # would also change the default colormap for every later figure.
        ax.imshow(np.squeeze(img), cmap="gray")
        ax.set_title(title)
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)


In [None]:
# let's take a look at the latent space
# Latent dimensions are shown five at a time, one PairGrid per group, for at
# most four groups; only the first 10000 encoded samples are plotted.
for i in range(int(np.minimum(np.ceil(encoded_imgs.shape[1]/5), 4))):
    first_dim = i*5
    group = encoded_imgs[0:10000, first_dim:first_dim+5]
    # Name the columns after the actual latent indices. Without this every
    # grid would label its axes 0..4 regardless of which dimensions it shows.
    # The last group may be narrower than five columns, hence group.shape[1].
    latent_dims = range(first_dim, first_dim + group.shape[1])
    g = sns.PairGrid(pd.DataFrame(group, columns=latent_dims), despine=True)
    g = g.map_upper(sns.scatterplot, s=1.0)
    g = g.map_lower(sns.histplot, cmap="viridis", stat="density", bins=25)
    g = g.map_diag(sns.histplot, bins=100)
    g.fig.suptitle(f"Latent Space Distribution {i+1}")
    # Leave room above the grid so the suptitle does not overlap the top row.
    g.fig.subplots_adjust(top=0.95)

In [None]:
# save model (intentionally disabled; uncomment the lines below to write the trained autoencoder to disk)
# path = "/content/drive/MyDrive/neurips21/"
#model_name = "autoencoder_fashion_mnist"
#tf.keras.models.save_model(autoencoder, path+model_name)