This tutorial is largely based on the official TensorFlow quickstart for experts [`tutorial`](https://www.tensorflow.org/tutorials/quickstart/advanced). The official website provides links to all relevant documentation.

For a more complete implementation including visualizations, feel free to reference the code used [`here`](https://www.kaggle.com/code/amyjang/tensorflow-mnist-cnn-tutorial).

If you are interested in the math side of things, talk to any one of the leads or check out this [`link`](https://calvinfeng.gitbook.io/machine-learning-notebook/supervised-learning/old-stuff/mnist_tutorial).

In [None]:
# Use pip install to get the packages if you don't already have them
import tensorflow as tf

# Keras builds on TensorFlow and provides high-level building blocks (layers, models, losses)
# It saves you some manual computations
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

In [None]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Normalizing the pixels so they range from 0 to 1
x_train, x_test = x_train / 255.0, x_test / 255.0

In [None]:
# Each input image is 2-D (28 x 28), but Conv2D expects every image to have
# the shape (height, width, channels)
# Add a channels dimension of size 1 (the images are grayscale)
x_train = x_train[..., tf.newaxis].astype("float32")
x_test = x_test[..., tf.newaxis].astype("float32")

In [None]:
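# Shuffle the training data so the batches differ from epoch to epoch;
# this keeps the model from picking up on the order of the examples.
# Batching groups the examples so each step processes 32 images at once.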
train_ds = tf.data.Dataset.from_tensor_slices(
    (x_train, y_train)).shuffle(10000).batch(32)

test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(32)
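
To get a feel for what `.batch(32)` produces, here is a rough sketch in plain Python (no TensorFlow needed). It assumes the standard MNIST split of 60,000 training and 10,000 test images, and `batch_layout` is a hypothetical helper written just for this illustration.

```python
import math

# Assumed sizes: the standard MNIST split
NUM_TRAIN = 60_000
NUM_TEST = 10_000
BATCH_SIZE = 32  # matches .batch(32) above


def batch_layout(num_examples, batch_size):
    """Return (number of batches, size of the final batch)."""
    num_batches = math.ceil(num_examples / batch_size)
    # The final batch holds whatever is left over after the full batches
    last_batch = num_examples - (num_batches - 1) * batch_size
    return num_batches, last_batch


train_batches, train_last = batch_layout(NUM_TRAIN, BATCH_SIZE)
test_batches, test_last = batch_layout(NUM_TEST, BATCH_SIZE)

print(f"train: {train_batches} batches, last batch has {train_last} images")
print(f"test:  {test_batches} batches, last batch has {test_last} images")
```

The test set does not divide evenly by 32, so its final batch is smaller than the rest. By default `.batch()` keeps that smaller batch rather than dropping it.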

In [None]:
# The following should be structured similarly to the PyTorch tutorial
# Feel free to add more layers but be mindful of input / output shapes
class MyModel(Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
        self.flatten = Flatten()

        # Dense layers in TensorFlow are analogous to linear layers in PyTorch
        self.d1 = Dense(128, activation='relu')
        # This is the output layer
        self.d2 = Dense(10)

    # Analogous to forward()
    def call(self, x):
        x = self.conv1(x)
        x = self.flatten(x)
        x = self.d1(x)
        return self.d2(x)

# Create an instance of the model
model = MyModel()
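
If you do add layers, it helps to work out the shapes by hand first. The sketch below traces the shapes through the model above in plain Python, assuming 28 x 28 x 1 inputs and the Keras `Conv2D` defaults (stride 1, `padding='valid'`). The helper functions are hypothetical and exist only for this illustration.

```python
def conv2d_valid_output(size, kernel_size, stride=1):
    """Output height/width of a convolution with 'valid' padding."""
    return (size - kernel_size) // stride + 1


def conv2d_params(kernel_size, in_channels, filters):
    """Weights (k * k * in_channels per filter) plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(in_features, units):
    """One weight per (input, unit) pair plus one bias per unit."""
    return in_features * units + units


side = conv2d_valid_output(28, kernel_size=3)  # conv1 shrinks 28 x 28 to side x side
flat = side * side * 32                        # Flatten: side * side * 32 filters

params = {
    "conv1": conv2d_params(3, in_channels=1, filters=32),
    "d1": dense_params(flat, 128),
    "d2": dense_params(128, 10),
}
total = sum(params.values())

print(f"conv1 output: {side} x {side} x 32, flattened to {flat}")
for name, count in params.items():
    print(f"{name}: {count:,} parameters")
print(f"total: {total:,} parameters")
```

Almost all of the parameters sit in the first dense layer, because it connects every flattened conv output to 128 units.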

In [None]:
# Define the optimizer and the loss function
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

optimizer = tf.keras.optimizers.Adam()
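
Here is roughly what the loss computes for a single example, written in plain Python. This is a sketch, not the Keras implementation: `from_logits=True` means the model outputs raw scores, so the loss applies a softmax first, and "sparse" means the label is an integer class index instead of a one-hot vector.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities that sum to 1."""
    largest = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - largest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_categorical_crossentropy(label, logits):
    """Negative log-probability that the model assigns to the true class."""
    return -math.log(softmax(logits)[label])


# A model that has learned nothing scores all 10 digits equally
uniform_loss = sparse_categorical_crossentropy(3, [0.0] * 10)

# A model that strongly favors the correct digit (index 3)
confident_logits = [0.0, 0.0, 0.0, 5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
confident_loss = sparse_categorical_crossentropy(3, confident_logits)

print(f"uniform guess:   {uniform_loss:.4f}")
print(f"confident guess: {confident_loss:.4f}")
```

A uniform guess over 10 classes gives a loss of ln(10), about 2.30, which is a handy sanity check for the first few training steps. With its default settings, the Keras loss object averages this per-example value over the batch.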

In [None]:
# Define some metrics
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
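
These metric objects keep a running tally across batches until they are reset. A minimal pure-Python sketch of that idea for the mean loss is below; `RunningMean` is a hypothetical stand-in written for illustration and is not how Keras implements `Mean`.

```python
class RunningMean:
    """Accumulates values and reports their average so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        # Called once per batch with that batch's loss
        self.total += value
        self.count += 1

    def result(self):
        # Average of everything seen since the last reset
        return self.total / self.count if self.count else 0.0

    def reset(self):
        # Called at the start of each epoch so epochs don't mix
        self.total = 0.0
        self.count = 0


epoch_loss = RunningMean()
for batch_loss in [0.9, 0.6, 0.3]:  # made-up per-batch losses
    epoch_loss.update(batch_loss)

mean_before_reset = epoch_loss.result()
epoch_loss.reset()
mean_after_reset = epoch_loss.result()

print(f"mean over 3 batches: {mean_before_reset:.2f}")
print(f"after reset: {mean_after_reset:.2f}")
```

This is why the training loop resets every metric at the top of each epoch: without the reset, the reported loss and accuracy would blend in results from earlier epochs.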

In [None]:
# Implement training
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        # training=True is only needed if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(images, training=True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    train_loss(loss)
    train_accuracy(labels, predictions)

# Testing code
@tf.function
def test_step(images, labels):
    # training=False is only needed if there are layers with different
    # behavior during training versus inference (e.g. Dropout).
    predictions = model(images, training=False)
    t_loss = loss_object(labels, predictions)

    test_loss(t_loss)
    test_accuracy(labels, predictions)

Woah you say, what is the @ sign above the function header?

Sigh, welcome to the confusing world of Python decorators.

The canonical explanation is that decorators add functionality to functions without you needing to rewrite the code inside that function.

In reality, they are kind of difficult to wrap your head around. 

This [`video`](https://www.youtube.com/watch?v=MYAEv3JoenI&ab_channel=howCode) is a very quick introduction to what they do.

For a more in-depth explanation, see this [`video`](https://www.youtube.com/watch?v=r7Dtus7N4pI&t=2s&ab_channel=Kite).
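
If you would rather see one in action, here is a tiny decorator in plain Python. It is only a sketch: `count_calls` and `add` are made-up names for this example. It wraps a function so that every call is counted, without touching the code inside the function.

```python
import functools


def count_calls(func):
    """Decorator that counts how many times func is called."""

    @functools.wraps(func)  # keep the original function's name and docstring
    def wrapper(*args, **kwargs):
        wrapper.calls += 1
        return func(*args, **kwargs)

    wrapper.calls = 0
    return wrapper


@count_calls
def add(a, b):
    return a + b

# The @ line above is shorthand for: add = count_calls(add)

result = add(2, 3)
print(result, add.calls)
```

`@tf.function` follows the same pattern: it takes your Python function and hands back a wrapped version, in this case one that TensorFlow compiles into a graph so it runs faster.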

In [None]:
# Feel free to play around with different numbers of epochs
EPOCHS = 5

for epoch in range(EPOCHS):
    # Reset the metrics at the start of each epoch
    # (reset_state() replaces the older, deprecated reset_states())
    train_loss.reset_state()
    train_accuracy.reset_state()
    test_loss.reset_state()
    test_accuracy.reset_state()

    for images, labels in train_ds:
        train_step(images, labels)

    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    print(
        f'Epoch {epoch + 1}, '
        f'Loss: {train_loss.result()}, '
        f'Accuracy: {train_accuracy.result() * 100}, '
        f'Test Loss: {test_loss.result()}, '
        f'Test Accuracy: {test_accuracy.result() * 100}'
    )