In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
from algorithms.utils import MNISTLoader, ZeroMetric
from algorithms import sigmoid, mlp, cnn, auto_encoder

In [None]:
# Reproducibility: seed TensorFlow (weight initialisers) and NumPy before anything
# random is created. NOTE(review): assumes MNISTLoader draws from NumPy's global RNG
# if it shuffles -- confirm.
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)

# Hyper-parameters shared by the layer-wise pretraining and the classifier runs.
batch_size = 256
num_epochs = 100  # epochs for the final end-to-end fine-tuning run
learning_rate = 1e-4
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# The classifier ends in a Softmax layer, so the loss consumes probabilities
# (from_logits=False, the default) rather than logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

# Running averages, reset at the start of every epoch by train_model.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [None]:
class AutoEncoder(tf.keras.layers.Layer):
    """A single shallow auto-encoder: ReLU Dense encoder followed by ReLU Dense decoder.

    Both Dense sub-layers carry an L2 kernel regularizer of strength ``l2_const``.
    The decoder is sized from the last dimension of the input seen at build time,
    so ``call`` returns a tensor whose last dimension matches its input.

    Args:
        units: width of the hidden (embedding) layer.
        l2_const: coefficient of the L2 kernel penalty on both sub-layers.
    """

    def __init__(self, units, l2_const=1e-4):
        super().__init__()
        self.units = units
        self.l2_const = l2_const

    def _regularized_relu_dense(self, width):
        """Create a ReLU Dense layer with ``width`` outputs and this layer's L2 kernel penalty."""
        penalty = tf.keras.regularizers.l2(self.l2_const)
        return tf.keras.layers.Dense(width, activation='relu', kernel_regularizer=penalty)

    def build(self, input_shape):
        # Sub-layers are created lazily so the decoder can mirror the input width.
        # The encoder is still created first, so Keras auto-generated names are unchanged.
        self.encoder = self._regularized_relu_dense(self.units)
        self.decoder = self._regularized_relu_dense(input_shape[-1])

    def encode(self, inputs):
        """Map ``inputs`` to the hidden embedding; the layer must already be built."""
        return self.encoder(inputs)

    def decode(self, embedding):
        """Map a hidden ``embedding`` back to input space; the layer must already be built."""
        return self.decoder(embedding)

    def call(self, inputs):
        """Reconstruct ``inputs`` by encoding and then decoding them."""
        return self.decode(self.encode(inputs))

class StackedAutoEncoder:
    """A stack of ``AutoEncoder`` layers.

    ``encode`` applies the encoders bottom-up and ``decode`` applies the matching
    decoders top-down. ``stop_at`` is used as a Python slice bound on
    ``auto_encoders``: ``stop_at=k`` uses only the first ``k`` layers, and ``None``
    uses all of them.
    """

    def __init__(self, units_list):
        self.units_list = units_list
        # One AutoEncoder per requested hidden width, in list order.
        self.auto_encoders = [AutoEncoder(units) for units in units_list]

    def encode(self, inputs, stop_at=None):
        """Push ``inputs`` through the first ``stop_at`` encoders, lowest layer first."""
        hidden = inputs
        for layer in self.auto_encoders[:stop_at]:
            hidden = layer.encode(hidden)
        return hidden

    def decode(self, embedding, stop_at=None):
        """Invert ``encode``: apply the first ``stop_at`` decoders, highest layer first."""
        reconstruction = embedding
        for layer in self.auto_encoders[:stop_at][::-1]:
            reconstruction = layer.decode(reconstruction)
        return reconstruction

    @property
    def variables(self):
        """Variables of every layer (encoders and decoders), concatenated in layer order."""
        return [var for layer in self.auto_encoders for var in layer.variables]

    def __call__(self, inputs, stop_at=None):
        """Encode and then decode ``inputs`` through the first ``stop_at`` layers."""
        return self.decode(self.encode(inputs, stop_at), stop_at)

In [None]:
def train_auto_encoder(stacked_auto_encoder, data_loader, train_layer=-1, num_epochs=5,
                       optimizer=None):
    """Train one layer of ``stacked_auto_encoder`` to reconstruct its own input.

    Each training batch is flattened to ``(batch, features)`` and pushed through the
    encoders *below* ``train_layer`` outside the gradient tape, so those layers stay
    frozen. Only ``auto_encoders[train_layer]`` is updated. Its loss is the mean
    squared reconstruction error plus the layer's L2 kernel penalties.

    Args:
        stacked_auto_encoder: the ``StackedAutoEncoder`` holding the layer to train.
        data_loader: provides ``batch_loader(batch_size=..., data_type='train')``
            yielding ``(X_batch, y_batch)`` pairs, and ``train_size``.
        train_layer: index into ``stacked_auto_encoder.auto_encoders``. The layers
            below it must already be built (i.e. trained first).
        num_epochs: number of passes over the training split.
        optimizer: optimizer for this layer. Defaults to a fresh
            ``Adam(learning_rate=learning_rate)``. Using a new optimizer per layer
            avoids sharing one optimizer across disjoint variable sets, which newer
            Keras optimizers reject. It also restarts Adam's bias correction for the
            new layer.
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    auto_encoder = stacked_auto_encoder.auto_encoders[train_layer]
    num_batches = data_loader.train_size // batch_size

    @tf.function
    def train_on_batch(X_batch):
        with tf.GradientTape() as tape:
            reconstruction = auto_encoder(X_batch)
            loss = tf.reduce_mean(tf.square(reconstruction - X_batch))
            # kernel_regularizer only records its penalties in `auto_encoder.losses`.
            # A custom training loop must add them itself, or `l2_const` has no effect.
            if auto_encoder.losses:
                loss += tf.add_n(auto_encoder.losses)
        grads = tape.gradient(loss, auto_encoder.trainable_variables)
        optimizer.apply_gradients(zip(grads, auto_encoder.trainable_variables))
        return loss

    for epoch in range(num_epochs):
        for batch_index, (X_batch, _) in enumerate(
                data_loader.batch_loader(batch_size=batch_size, data_type='train')):
            # Flatten each image to a vector, then encode it with the frozen lower layers.
            X_batch = X_batch.reshape(X_batch.shape[0], -1)
            X_batch = stacked_auto_encoder.encode(X_batch, stop_at=train_layer)
            loss = train_on_batch(X_batch)
            template = '[Training] Layer {}, Epoch {}, Batch {}/{}, Loss: {:.6f} '
            print(template.format(train_layer,
                                  epoch + 1,
                                  batch_index,
                                  num_batches,
                                  float(loss)),
                  end='\r')

In [None]:
# Greedy layer-wise pretraining. Layer k learns to reconstruct the codes produced by
# layers 0..k-1, so the lower layers must be trained (and thereby built) first.
stacked_auto_encoder = StackedAutoEncoder([1024, 512, 128])
pretrain_epochs = (5, 10, 15)  # epochs for layer 0, 1 and 2 respectively
for layer_index, layer_epochs in enumerate(pretrain_epochs):
    train_auto_encoder(stacked_auto_encoder, data_loader,
                       train_layer=layer_index, num_epochs=layer_epochs)

In [None]:
# Compare one training image with its reconstruction through the first 1, 2 and 3
# auto-encoder layers (encode down to that depth, then decode back up).
img = data_loader.train_data[4]
img_2d = img.squeeze()
img_flat = img.reshape(1, -1)  # one flattened sample, as the encoders expect
depths = (1, 2, 3)
img_rebuilds = [stacked_auto_encoder(img_flat, stop_at=depth).numpy().reshape(img_2d.shape)
                for depth in depths]
img_rebuild_1, img_rebuild_2, img_rebuild_3 = img_rebuilds

panels = [('original', img_2d)] + [('rebuilt, depth {}'.format(depth), rebuilt)
                                   for depth, rebuilt in zip(depths, img_rebuilds)]
fig, axes = plt.subplots(1, len(panels), figsize=(3 * len(panels), 3))
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
plt.show()

In [None]:
def train_model(model, num_epochs=5, preprocess=lambda _: _, extra_variables=(), optimizer=None):
    """Train ``model`` on the loader's train split and evaluate it on the test split each epoch.

    Args:
        model: Keras model mapping a preprocessed batch to class probabilities.
        num_epochs: number of train-then-test passes.
        preprocess: applied to every raw ``X_batch`` before it reaches the model
            (identity by default).
        extra_variables: additional variables to optimise alongside
            ``model.trainable_variables``. Variables the model already tracks are
            ignored, so nothing is updated twice. Variables that receive no gradient
            (e.g. decoder weights) are skipped.
        optimizer: optimizer for this run. Defaults to a fresh
            ``Adam(learning_rate=learning_rate)``, so runs over different variable
            sets neither share nor clash over optimizer state.

    Side effects: updates the global ``train_*`` / ``test_*`` metrics and prints progress.
    The recorded train loss is the data loss only, so it is comparable with the test loss.
    """
    if optimizer is None:
        optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

    def _trainable_variables():
        """Model trainable variables plus ``extra_variables``, each listed once (by identity)."""
        unique, seen = [], set()
        for variable in list(model.trainable_variables) + list(extra_variables):
            if id(variable) not in seen:
                seen.add(id(variable))
                unique.append(variable)
        return unique

    @tf.function
    def train_on_batch(X_batch, y_batch):
        with tf.GradientTape() as tape:
            y_pred = model(X_batch, training=True)
            data_loss = tf.reduce_mean(loss_object(y_true=y_batch, y_pred=y_pred))
            # Regularizer penalties (e.g. kernel_regularizer) only accumulate in
            # `model.losses`. A custom loop must add them, or they are ignored.
            if model.losses:
                total_loss = data_loss + tf.add_n(model.losses)
            else:
                total_loss = data_loss
        # Collected after the forward pass so lazily-built models already own their variables.
        variables = _trainable_variables()
        grads = tape.gradient(total_loss, variables)
        # Variables with no path to the loss get `None`; don't hand those to the optimizer.
        optimizer.apply_gradients(
            [(grad, var) for grad, var in zip(grads, variables) if grad is not None])

        train_loss(data_loss)
        train_accuracy(y_batch, y_pred)
        return data_loss

    @tf.function
    def test_on_batch(X_batch, y_batch):
        y_pred = model(X_batch, training=False)
        t_loss = loss_object(y_batch, y_pred)

        test_loss(t_loss)
        test_accuracy(y_batch, y_pred)
        return t_loss

    def _run_split(step, data_type, num_batches, label, accuracy, epoch):
        """Feed every batch of ``data_type`` through ``step`` and print live progress."""
        template = '[{}] Epoch {}, Batch {}/{}, Loss: {:.4f}, Accuracy: {:.2%} '
        for batch_index, (X_batch, y_batch) in enumerate(
                data_loader.batch_loader(batch_size=batch_size, data_type=data_type)):
            loss = step(preprocess(X_batch), y_batch)
            print(template.format(label,
                                  epoch + 1,
                                  batch_index,
                                  num_batches,
                                  float(loss),
                                  float(accuracy.result())),
                  end='\r')

    for epoch in range(num_epochs):
        for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
            metric.reset_states()

        _run_split(train_on_batch, 'train', data_loader.train_size // batch_size,
                   'Training', train_accuracy, epoch)
        _run_split(test_on_batch, 'test', data_loader.test_size // batch_size,
                   'Testing', test_accuracy, epoch)

        template = 'Epoch {}, Loss: {:.4f}, Accuracy: {:.2%}, Test Loss: {:.4f}, Test Accuracy: {:.2%} '
        print(template.format(epoch + 1,
                              float(train_loss.result()),
                              float(train_accuracy.result()),
                              float(test_loss.result()),
                              float(test_accuracy.result())))

In [None]:
def encode_batch(X_batch):
    """Flatten a raw image batch and map it to the stacked auto-encoder's top-level code.

    The encoders run here, before the batch reaches ``train_model``'s gradient tape,
    so only the softmax classifier below is trained on these features.
    """
    X_batch = X_batch.reshape(X_batch.shape[0], -1)
    return stacked_auto_encoder.encode(X_batch)

# Softmax classifier over the frozen top-level code. Its input width is the last
# hidden size of the stack instead of a hard-coded number.
code_size = stacked_auto_encoder.units_list[-1]
softmax_layer = tf.keras.Sequential([
    tf.keras.layers.Dense(10),  # one output per class
    tf.keras.layers.Softmax()
])
softmax_layer.build(input_shape=(None, code_size))
softmax_layer.summary()
train_model(softmax_layer, num_epochs=20, preprocess=encode_batch)

In [None]:
def flatten_batch(X_batch):
    """Flatten every image in a raw batch into a single feature vector."""
    return X_batch.reshape(X_batch.shape[0], -1)

# End-to-end fine-tuning: flattened pixels -> pretrained encoders -> pretrained softmax layer.
input_dim = data_loader.train_data[0].size  # number of values in one image
inputs = tf.keras.Input(shape=(input_dim, ))
x = stacked_auto_encoder.encode(inputs)
outputs = softmax_layer(x)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
model.summary()

# Calling the encoders on a symbolic Input normally makes the functional model track
# their Dense layers already. Pass only encoder variables it does not track, and never
# the decoder weights, which have no gradient path to the classification loss.
tracked_ids = {id(v) for v in model.trainable_variables}
encoder_variables = [v
                     for layer in stacked_auto_encoder.auto_encoders
                     for v in layer.encoder.trainable_variables
                     if id(v) not in tracked_ids]

train_model(model, num_epochs=num_epochs, preprocess=flatten_batch,
            extra_variables=encoder_variables)