In [10]:
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras import layers, models, datasets, callbacks
import tensorflow.keras.backend as K

# from notebooks.utils import display

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
IMAGE_SIZE = 32
CHANNELS = 1
BATCH_SIZE = 100
BUFFER_SIZE = 1000
VALIDATION_SPLIT = 0.2
EMBEDDING_DIM = 2
EPOCHS = 3

In [12]:
# Load the data
(x_train, y_train), (x_test, y_test) = datasets.fashion_mnist.load_data()

In [13]:
# Preprocess the data


def preprocess(imgs):
    """
    Normalize and reshape the images
    """
    imgs = imgs.astype("float32") / 255.0
    imgs = np.pad(imgs, ((0, 0), (2, 2), (2, 2)), constant_values=0.0)
    imgs = np.expand_dims(imgs, -1)
    return imgs


x_train = preprocess(x_train)
x_test = preprocess(x_test)

In [14]:
x_train.shape

(60000, 32, 32, 1)

In [24]:
# encoder 

encoder_input = layers.Input(shape = (IMAGE_SIZE,IMAGE_SIZE,CHANNELS),name='encoder_input')

x = layers.Conv2D(32, (3,3),strides=2,activation='relu',padding='same')(encoder_input)
x = layers.Conv2D(64, (3, 3), strides=2, activation="relu", padding="same")(x)
x = layers.Conv2D(128, (3, 3), strides=2, activation="relu", padding="same")(x)

shape_before_flatenning = K.int_shape(x)[1:]
print(shape_before_flatenning)

x = layers.Flatten()(x)
encoder_output = layers.Dense(EMBEDDING_DIM, name="encoder_output")(x)
encoder = models.Model(encoder_input, encoder_output)
encoder.summary()


(4, 4, 128)
Model: "model_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_input (InputLayer)  [(None, 32, 32, 1)]       0         
                                                                 
 conv2d_12 (Conv2D)          (None, 16, 16, 32)        320       
                                                                 
 conv2d_13 (Conv2D)          (None, 8, 8, 64)          18496     
                                                                 
 conv2d_14 (Conv2D)          (None, 4, 4, 128)         73856     
                                                                 
 flatten_4 (Flatten)         (None, 2048)              0         
                                                                 
 encoder_output (Dense)      (None, 2)                 4098      
                                                                 
Total params: 96770 (378.01 KB)
Trainable param

In [26]:
# decoder 
decoder_input = layers.Input(shape= (EMBEDDING_DIM,),name='decoder_input')
x = layers.Dense(np.prod(shape_before_flatenning))(decoder_input)
x = layers.Reshape(shape_before_flatenning)(x)
x = layers.Conv2DTranspose(128,(3,3),strides=2,activation='relu',padding='same')(x)
x = layers.Conv2DTranspose(64,(3,3),strides=2,activation='relu',padding='same')(x)
x = layers.Conv2DTranspose(32,(3,3),strides=2,activation='relu',padding='same')(x)

decoder_output = layers.Conv2D(CHANNELS,(3,3),strides=1,
    activation="sigmoid", # we need output between 0-1
    padding="same",
    name="decoder_output")(x)

decoder = models.Model(decoder_input,decoder_output)

decoder.summary()


Model: "model_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 decoder_input (InputLayer)  [(None, 2)]               0         
                                                                 
 dense_2 (Dense)             (None, 2048)              6144      
                                                                 
 reshape (Reshape)           (None, 4, 4, 128)         0         
                                                                 
 conv2d_transpose_1 (Conv2D  (None, 8, 8, 128)         147584    
 Transpose)                                                      
                                                                 
 conv2d_transpose_2 (Conv2D  (None, 16, 16, 64)        73792     
 Transpose)                                                      
                                                                 
 conv2d_transpose_3 (Conv2D  (None, 32, 32, 32)        1846

In [27]:
# Autoencoder 

autoencoder = models.Model(encoder_input,decoder(encoder_output))
autoencoder.summary()

Model: "model_6"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 encoder_input (InputLayer)  [(None, 32, 32, 1)]       0         
                                                                 
 conv2d_12 (Conv2D)          (None, 16, 16, 32)        320       
                                                                 
 conv2d_13 (Conv2D)          (None, 8, 8, 64)          18496     
                                                                 
 conv2d_14 (Conv2D)          (None, 4, 4, 128)         73856     
                                                                 
 flatten_4 (Flatten)         (None, 2048)              0         
                                                                 
 encoder_output (Dense)      (None, 2)                 4098      
                                                                 
 model_5 (Functional)        (None, 32, 32, 1)         2462

In [28]:
# Compile the autoencoder
autoencoder.compile(optimizer="adam", loss="binary_crossentropy")

In [30]:
# Create a model save checkpoint

model_checkpoint_callback = callbacks.ModelCheckpoint(
    filepath = './checkpoint',
    save_weights_only = False,
    save_freq = 'epoch',
    monitor = 'loss',
    model = "min",
    save_best_only = True,
    verbose = 0 
)
tensorboard_callback = callbacks.TensorBoard(log_dir = './logs')


In [31]:
autoencoder.fit(
    x_train,
    x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    shuffle=True,
    validation_data=(x_test, x_test),
    callbacks=[model_checkpoint_callback, tensorboard_callback],
)

Epoch 1/3


INFO:tensorflow:Assets written to: ./checkpoint/assets


Epoch 2/3


INFO:tensorflow:Assets written to: ./checkpoint/assets


Epoch 3/3


INFO:tensorflow:Assets written to: ./checkpoint/assets




<keras.src.callbacks.History at 0x173558f70>

In [32]:
# Save the final models
autoencoder.save("./models/autoencoder")
encoder.save("./models/encoder")
decoder.save("./models/decoder")

INFO:tensorflow:Assets written to: ./models/autoencoder/assets


INFO:tensorflow:Assets written to: ./models/autoencoder/assets






INFO:tensorflow:Assets written to: ./models/encoder/assets


INFO:tensorflow:Assets written to: ./models/encoder/assets






INFO:tensorflow:Assets written to: ./models/decoder/assets


INFO:tensorflow:Assets written to: ./models/decoder/assets


In [33]:
# Reconstruct using the autoencoder

n_to_predict = 5000
example_images = x_test[:n_to_predict]
example_labels = y_test[:n_to_predict]

In [34]:
predictions = autoencoder.predict(example_images)

print("Example real clothing items")
display(example_images)
print("Reconstructions")
display(predictions)

Example real clothing items


array([[[[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        ...,

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]],

        [[0.        ],
         [0.        ],
         [0.        ],
         ...,
         [0.        ],
         [0.        ],
         [0.        ]]],


       [[[0.        ],
         [0.        ],
         [0.  

Reconstructions


array([[[[5.86573378e-07],
         [2.04943351e-08],
         [1.47156300e-07],
         ...,
         [1.57079627e-10],
         [6.28633146e-10],
         [1.85924125e-06]],

        [[2.31734232e-09],
         [5.33234984e-11],
         [1.01102158e-08],
         ...,
         [5.36310015e-12],
         [4.71431678e-11],
         [1.31456943e-07]],

        [[1.28151232e-08],
         [2.33798203e-09],
         [2.52283115e-07],
         ...,
         [4.52844489e-07],
         [5.32804734e-08],
         [1.10749961e-05]],

        ...,

        [[3.40189299e-09],
         [2.83089496e-10],
         [5.82660071e-08],
         ...,
         [5.01419066e-08],
         [1.37367726e-08],
         [2.67128917e-06]],

        [[6.23709511e-08],
         [5.60824898e-09],
         [7.80456020e-08],
         ...,
         [1.53218394e-07],
         [1.67417671e-08],
         [7.18203410e-06]],

        [[7.52206915e-06],
         [6.52058077e-07],
         [5.62379637e-06],
         ...,
 