In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np

# Load the MNIST digit images and their integer class labels.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixels to [0, 1] as float32. Dividing the integer image arrays by a
# Python float directly promotes them to float64, which doubles memory for
# data that Keras layers compute on in float32 by default.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Flatten every image into a 784-element vector to match the dense network input.
x_train = x_train.reshape(-1, 784)
x_test = x_test.reshape(-1, 784)

# One-hot targets for the CategoricalCrossentropy loss; the integer labels
# stay available for the sparse loss and for accuracy checks.
y_train_oh = tf.one_hot(y_train, depth=10)
y_test_oh = tf.one_hot(y_test, depth=10)

def create_model(input_dim=784, hidden_units=(512, 256, 128), num_classes=10):
    """Build an uncompiled fully connected classifier that outputs raw logits.

    The defaults reproduce the original fixed architecture
    (784 -> 512 -> 256 -> 128 -> 10).

    Args:
        input_dim: Length of each flattened input vector.
        hidden_units: Width of each hidden ReLU layer, in order.
        num_classes: Number of output units (one logit per class).

    Returns:
        A `Sequential` model. The final layer has no activation, so pair it
        with a loss configured with `from_logits=True`.
    """
    hidden_layers = [layers.Dense(units, activation='relu') for units in hidden_units]
    return models.Sequential([
        layers.Input(shape=(input_dim,)),
        *hidden_layers,
        layers.Dense(num_classes),
    ])


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


Training with tf.GradientTape

In [3]:
model = create_model()
optimizer = tf.keras.optimizers.Adam()
# The model emits raw logits, so the loss is told to apply softmax itself.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# Shuffle before batching so each epoch visits the samples in a new order
# (tf.data reshuffles on every iteration by default). Without this, every
# epoch replays the same batches in the same order, unlike `model.fit`,
# which shuffles the training data by default.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train_oh))
    .shuffle(buffer_size=len(x_train))
    .batch(64)
)

for epoch in range(5):
    print(f"Epoch {epoch+1}")
    for x_batch, y_batch in train_dataset:
        # Record the forward pass so the loss can be differentiated.
        with tf.GradientTape() as tape:
            logits = model(x_batch, training=True)
            loss = loss_fn(y_batch, logits)
        grads = tape.gradient(loss, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

    # Evaluate on the full test set at the end of every epoch.
    test_logits = model(x_test, training=False)
    test_preds = tf.argmax(test_logits, axis=1)
    # The labels are a NumPy integer array; cast them to the prediction dtype
    # explicitly instead of relying on implicit conversion inside `==`.
    test_labels = tf.cast(y_test, test_preds.dtype)
    acc = tf.reduce_mean(tf.cast(tf.equal(test_preds, test_labels), tf.float32))
    print(f"Test Accuracy: {acc.numpy():.4f}")


Epoch 1
Test Accuracy: 0.9552
Epoch 2
Test Accuracy: 0.9663
Epoch 3
Test Accuracy: 0.9631
Epoch 4
Test Accuracy: 0.9709
Epoch 5
Test Accuracy: 0.9664


Training with model.fit

In [4]:
# Fresh network with the same architecture, trained through Keras' built-in loop.
model2 = create_model()

# The sparse loss consumes integer labels directly; the network emits raw logits.
sparse_ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model2.compile(loss=sparse_ce, optimizer='adam', metrics=['accuracy'])

# 10% of the training data is held out for per-epoch validation.
model2.fit(
    x=x_train,
    y=y_train,
    batch_size=64,
    epochs=5,
    validation_split=0.1,
)

# `evaluate` returns the loss followed by the compiled metrics.
test_loss, test_acc = model2.evaluate(x=x_test, y=y_test)
print(f"Test Accuracy: {test_acc:.4f}")


Epoch 1/5
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 12ms/step - accuracy: 0.8802 - loss: 0.3901 - val_accuracy: 0.9637 - val_loss: 0.1084
Epoch 2/5
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 11ms/step - accuracy: 0.9730 - loss: 0.0893 - val_accuracy: 0.9795 - val_loss: 0.0734
Epoch 3/5
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 12ms/step - accuracy: 0.9828 - loss: 0.0547 - val_accuracy: 0.9688 - val_loss: 0.1052
Epoch 4/5
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 12ms/step - accuracy: 0.9866 - loss: 0.0421 - val_accuracy: 0.9757 - val_loss: 0.0854
Epoch 5/5
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 12ms/step - accuracy: 0.9895 - loss: 0.0330 - val_accuracy: 0.9773 - val_loss: 0.0907
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 3ms/step - accuracy: 0.9719 - loss: 0.0970
Test Accuracy: 0.9772
