In [None]:
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [None]:
from algorithms.utils import MNISTLoader, ZeroMetric
from algorithms import softmax, mlp, cnn, auto_encoder

In [None]:
# --- Hyperparameters -------------------------------------------------------
num_epochs = 100
batch_size = 256
learning_rate = 1e-4

# --- Objective and optimizer -----------------------------------------------
# Sparse variant of cross-entropy: labels are passed as class indices rather
# than one-hot vectors.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# --- Running metrics -------------------------------------------------------
# Loss trackers average whatever values are fed to them; accuracy trackers
# compare labels against predictions. One pair per data split.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')

train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [None]:
# Choose the model to train: keep exactly one `model = ...` line active and
# leave the alternatives commented out.
# NOTE(review): the trailing percentages look like accuracies recorded from
# earlier runs — confirm which split and settings they refer to.
model = softmax.Softmax() # 92.48%
# model = mlp.MLP() # 97.60%
# model = cnn.CNN() # 99.20%
# model = auto_encoder.AutoEncoder(1000)
# Loader that supplies batches via `batch_loader(...)`.
data_loader = MNISTLoader()

In [None]:
# Pull a single batch using `batch_loader`'s default arguments so a sample can
# be inspected; X is treated as the inputs and y as the matching labels.
X, y = next(data_loader.batch_loader())

In [None]:
# Show one sample from the fetched batch, with its label as the figure title.
idx = 19
fig, ax = plt.subplots()
# squeeze() removes any size-1 axes so imshow receives a plain 2-D image.
ax.imshow(X[idx].squeeze(), cmap='gray')
ax.set_title(y[idx])
plt.show()

In [None]:
@tf.function
def train_on_batch(X_batch, y_batch):
    """Run one optimization step on a single batch.

    Performs a forward pass through the module-level ``model``, computes the
    loss with ``loss_object``, back-propagates, and applies the gradients with
    ``optimizer``. Also feeds the batch loss and predictions into the
    ``train_loss`` and ``train_accuracy`` running metrics.

    Args:
        X_batch: Batch of model inputs.
        y_batch: Labels for ``X_batch``, passed to the loss as ``y_true``.

    Returns:
        The loss value for this batch.
    """
    # Differentiate only with respect to trainable variables. The tape does
    # not track non-trainable ones, so asking for their gradients yields None
    # entries that the optimizer would have to warn about and discard.
    trainable_vars = model.trainable_variables

    with tf.GradientTape() as tape:
        y_pred = model(X_batch)
        loss = loss_object(y_true=y_batch, y_pred=y_pred)
        # Collapse to a scalar; this is a no-op if the loss object has already
        # reduced over the batch.
        loss = tf.reduce_mean(loss)
    grads = tape.gradient(loss, trainable_vars)
    optimizer.apply_gradients(grads_and_vars=zip(grads, trainable_vars))

    train_loss(loss)
    train_accuracy(y_batch, y_pred)
    return loss

@tf.function
def test_on_batch(X_batch, y_batch):
    """Evaluate the module-level ``model`` on one batch without updating it.

    Computes predictions and the loss, then feeds them into the ``test_loss``
    and ``test_accuracy`` running metrics.

    Args:
        X_batch: Batch of model inputs.
        y_batch: Labels for ``X_batch``.

    Returns:
        The loss value produced by ``loss_object`` for this batch.
    """
    predictions = model(X_batch)
    batch_loss = loss_object(y_true=y_batch, y_pred=predictions)

    # Accumulate into the running evaluation metrics.
    test_loss(batch_loss)
    test_accuracy(y_batch, predictions)

    return batch_loss

# Epoch driver: for each epoch, iterate the 'train' batches, then the 'test'
# batches, and finish with a one-line summary of the running metrics.
for epoch in range(num_epochs):

    # Clear every accumulator so the figures below describe this epoch only.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

    # Training
    train_batches = data_loader.batch_loader(batch_size=batch_size, data_type='train')
    for batch_index, (X_batch, y_batch) in enumerate(train_batches):
        loss = train_on_batch(X_batch, y_batch)
        template = '[Training] Epoch {}, Batch {}/{}, Loss: {}, Accuracy: {:.2%} '
        progress = template.format(epoch+1,
                                   batch_index,
                                   data_loader.train_size // batch_size,
                                   loss,
                                   train_accuracy.result())
        # The carriage return lets the next update overwrite this line.
        print(progress, end='\r')

    # Testing
    test_batches = data_loader.batch_loader(batch_size=batch_size, data_type='test')
    for batch_index, (X_batch, y_batch) in enumerate(test_batches):
        loss = test_on_batch(X_batch, y_batch)
        template = '[Testing] Epoch {}, Batch {}/{}, Loss: {}, Accuracy: {:.2%} '
        progress = template.format(epoch+1,
                                   batch_index,
                                   data_loader.test_size // batch_size,
                                   loss,
                                   test_accuracy.result())
        print(progress, end='\r')

    # End-of-epoch summary, printed with a normal newline so it stays visible.
    template = 'Epoch {}, Loss: {}, Accuracy: {:.2%}, Test Loss: {}, Test Accuracy: {:.2%} '
    summary = template.format(epoch+1,
                              train_loss.result(),
                              train_accuracy.result(),
                              test_loss.result(),
                              test_accuracy.result())
    print(summary)