In [27]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.datasets import mnist

In [28]:
# Load MNIST through Keras; the call yields an (images, labels) pair for the
# training split followed by one for the test split.
(x_trn, y_trn), (x_tst, y_tst) = mnist.load_data()

In [29]:
# Input / output dimensions: every image is flattened to 28 * 28 values and
# labels are one-hot encoded over `num_classes` classes.
num_classes = 10
num_features = 784  # 28 * 28

# Optimisation hyper-parameters.
batch_size = 32
lr = 1e-2
epochs = 1000        # number of mini-batches consumed by train_data.take(epochs)
display_epoch = 100  # metrics are reported every `display_epoch` mini-batches

In [30]:
# Images: cast to float32, flatten each sample to `num_features` values and
# rescale the pixel intensities by 1/255.
x_train = np.asarray(x_trn, dtype=np.float32).reshape(-1, num_features) / 255.0
x_test = np.asarray(x_tst, dtype=np.float32).reshape(-1, num_features) / 255.0

# Labels: integer class ids -> one-hot rows of width `num_classes`.
y_train = tf.one_hot(y_trn, depth=num_classes)
y_test = tf.one_hot(y_tst, depth=num_classes)

In [31]:
# Endless, shuffled stream of (image, one-hot label) mini-batches.
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .repeat()
    .shuffle(buffer_size=5000)
    .batch(batch_size)
    .prefetch(buffer_size=1)
)

In [32]:
# Metric history filled in by train(): one entry per reporting step.
all_losses, all_accuracy, all_epochs = [], [], []

In [33]:
class Model:
    """Small convolutional classifier assembled from raw TensorFlow ops.

    Layout: two (3x3 conv -> ReLU -> 2x2 max-pool) stages with 32 and 64
    filters, a 128-unit ReLU dense layer, and a 10-way softmax output.
    Weights are drawn from ``tf.initializers.RandomNormal()``; every bias
    starts at 0.1.
    """

    def __init__(self):
        init = tf.initializers.RandomNormal()

        def bias(width):
            # Bias vector of the given width, filled with 0.1.
            return tf.Variable(tf.constant(0.1, shape=[width]))

        # First convolution: 1 input channel -> 32 feature maps.
        self.W_conv1 = tf.Variable(init([3, 3, 1, 32]))
        self.b_conv1 = bias(32)

        # Second convolution: 32 -> 64 feature maps.
        self.W_conv2 = tf.Variable(init([3, 3, 32, 64]))
        self.b_conv2 = bias(64)

        # Dense layer fed by the flattened 7x7x64 pooled activations.
        self.W_fc1 = tf.Variable(init([7 * 7 * 64, 128]))
        self.b_fc1 = bias(128)

        # Output layer: 128 hidden units -> 10 classes.
        self.W_fc2 = tf.Variable(init([128, 10]))
        self.b_fc2 = bias(10)

    @staticmethod
    def _conv_relu(inputs, kernel, offset):
        """Stride-1 SAME convolution, plus bias, followed by ReLU."""
        conv = tf.nn.conv2d(inputs, kernel, strides=[1, 1, 1, 1], padding='SAME')
        return tf.nn.relu(tf.add(conv, offset))

    @staticmethod
    def _max_pool_2x2(inputs):
        """2x2 max-pooling with stride 2 and SAME padding."""
        return tf.nn.max_pool(inputs, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

    def forward(self, x):
        """Run the network on a batch and return the softmax probabilities.

        ``x`` is reshaped to ``[-1, 28, 28, 1]`` before the first
        convolution. Every intermediate activation is cached on the instance
        (``h_conv1``, ``h_pool1``, ``h_conv2``, ``h_pool2``, ``h_flat``,
        ``h_fc1``, ``output``) and is overwritten on each call.
        """
        images = tf.reshape(x, [-1, 28, 28, 1])

        self.h_conv1 = self._conv_relu(images, self.W_conv1, self.b_conv1)
        self.h_pool1 = self._max_pool_2x2(self.h_conv1)

        self.h_conv2 = self._conv_relu(self.h_pool1, self.W_conv2, self.b_conv2)
        self.h_pool2 = self._max_pool_2x2(self.h_conv2)

        self.h_flat = tf.reshape(self.h_pool2, [-1, 7 * 7 * 64])
        hidden_pre = tf.add(tf.matmul(self.h_flat, self.W_fc1), self.b_fc1)
        self.h_fc1 = tf.nn.relu(hidden_pre)

        logits = tf.add(tf.matmul(self.h_fc1, self.W_fc2), self.b_fc2)
        self.output = tf.nn.softmax(logits)

        return self.output

In [34]:
# Fresh network plus a plain SGD optimizer driven by the configured rate.
model = Model()
optimizer = tf.optimizers.SGD(learning_rate=lr)

In [35]:
def loss(y_pred, y_target):
    """Cross-entropy between predicted probabilities and one-hot targets.

    Args:
        y_pred: predicted class probabilities (e.g. the softmax output of
            ``Model.forward``).
        y_target: one-hot encoded labels; expected to match ``y_pred``
            element for element.

    Returns:
        Scalar tensor holding the cross-entropy summed over every element of
        the batch.
    """
    # A softmax output can underflow to exactly 0 in float32. log(0) is -inf
    # and 0 * -inf is NaN, which would poison the loss and every gradient, so
    # keep the probabilities strictly positive before taking the log.
    y_pred = tf.clip_by_value(y_pred, 1e-9, 1.0)
    # The reduction carries no axis, so it collapses the whole batch to one
    # scalar. The original wrapped this in tf.reduce_mean, which is a no-op on
    # a scalar; the batch-sum scale is kept on purpose so the gradient
    # magnitude (and therefore the effective learning rate) is unchanged.
    cross_entr = -tf.reduce_sum(y_target * tf.math.log(y_pred))
    return cross_entr

In [36]:
def accuracy(y_pred, y_target):
    """Fraction of rows whose arg-max along axis 1 agrees in both tensors.

    Args:
        y_pred: per-class scores or probabilities, one row per sample.
        y_target: one-hot (or otherwise arg-max comparable) targets.

    Returns:
        Scalar float32 tensor with the mean of the per-row matches.
    """
    predicted_cls = tf.argmax(y_pred, axis=1)
    expected_cls = tf.argmax(y_target, axis=1)
    hits = tf.cast(tf.equal(predicted_cls, expected_cls), tf.float32)
    return tf.reduce_mean(hits)

In [37]:
def optimization(x, y):
    """Apply one optimizer update to the module-level ``model``.

    Runs a forward pass on ``x`` under a gradient tape, evaluates ``loss``
    against ``y``, and lets the module-level ``optimizer`` apply the
    resulting gradients to the model's weights and biases. Returns nothing.
    """
    params = [
        model.W_conv1, model.W_conv2, model.W_fc1, model.W_fc2,
        model.b_conv1, model.b_conv2, model.b_fc1, model.b_fc2,
    ]
    with tf.GradientTape() as tape:
        batch_loss = loss(model.forward(x), y)
    grads = tape.gradient(batch_loss, params)
    optimizer.apply_gradients(zip(grads, params))
    

In [38]:
def train():
    """Train the module-level ``model`` on ``epochs`` mini-batches.

    After every ``display_epoch``-th update the current batch is re-evaluated;
    its loss and accuracy are printed and appended, together with the step
    number, to ``all_losses``, ``all_accuracy`` and ``all_epochs``.

    Returns:
        The module-level ``model`` that was updated in place.
    """
    for step, (batch_x, batch_y) in enumerate(train_data.take(epochs), start=1):
        optimization(batch_x, batch_y)

        # Only report on every `display_epoch`-th step.
        if step % display_epoch != 0:
            continue

        # Re-run the forward pass so the metrics reflect the updated weights.
        batch_pred = model.forward(batch_x)
        batch_loss = loss(batch_pred, batch_y)
        batch_acc = accuracy(batch_pred, batch_y)

        all_losses.append(batch_loss)
        all_accuracy.append(batch_acc)
        all_epochs.append(step)
        print("Epoch {0} Loss {1} Acc {2}".format(step, batch_loss, batch_acc))

    return model

In [39]:
# Fit the network; train() returns the module-level model it updated.
model = train()

# Score the whole test split with a single forward pass.
y_pr_test = model.forward(x_test)
acc_test = accuracy(y_pr_test, y_test)

# Report the test-set accuracy.
print("Test acc = {}".format(acc_test))

Epoch 100 Loss 6.588296890258789 Acc 0.9375
Epoch 200 Loss 1.652270793914795 Acc 1.0
Epoch 300 Loss 1.3469871282577515 Acc 1.0
Epoch 400 Loss 1.588355541229248 Acc 1.0
Epoch 500 Loss 0.3085606098175049 Acc 1.0
Epoch 600 Loss 0.35543015599250793 Acc 1.0
Epoch 700 Loss 0.13824251294136047 Acc 1.0
Epoch 800 Loss 0.31167519092559814 Acc 1.0
Epoch 900 Loss 0.11289635300636292 Acc 1.0
Epoch 1000 Loss 0.22764891386032104 Acc 1.0
Test acc = 0.9794999957084656
