In [11]:
import tensorflow as tf
from keras import layers
from keras.datasets import mnist
from keras.layers import Dropout
from tensorflow import keras

import plotly.express as px

In [12]:
def get_mnist_model():
    inputs = keras.Input(shape=(28*28,))
    features = layers.Dense(512, activation='relu')(inputs)
    features = Dropout(0.2)(features)
    outputs = layers.Dense(10, activation='softmax')(features)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

In [13]:
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images.reshape((60000, 28*28)).astype("float32") / 255
test_images = test_images.reshape((10000, 28*28)).astype("float32") / 255

train_images, val_images = train_images[10000:], train_images[:10000]
train_labels, val_labels = train_labels[10000:], train_labels[:10000]

In [14]:
model = get_mnist_model()

In [15]:
loss_fn = keras.losses.SparseCategoricalCrossentropy()
optimizer = keras.optimizers.Adam()
metrics = [keras.metrics.SparseCategoricalAccuracy()]
loss_tracking_metrics = keras.metrics.Mean()

In [16]:
def train_step(inputs, targets):
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        loss = loss_fn(targets, predictions)
    gradients = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    logs ={}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs[metric.name] = metric.result()

    loss_tracking_metrics.update_state(loss)
    logs['loss'] = loss_tracking_metrics.result()
    return logs

In [17]:
def reset_metrics():
    for metric in metrics:
        metric.reset_states()
    loss_tracking_metrics.reset_states()

In [18]:
training_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
training_dataset = training_dataset.batch(32)
epochs = 3
for epoch in range(epochs):
    reset_metrics()
    for input_batch, targets_batch in training_dataset:
        logs = train_step(input_batch, targets_batch)
    print(f"Results at the end of epoch {epoch}")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")



2022-04-08 00:03:58.471523: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 156800000 exceeds 10% of free system memory.
2022-04-08 00:03:58.558391: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 156800000 exceeds 10% of free system memory.


Results at the end of epoch 0
...sparse_categorical_accuracy: 0.9257
...loss: 0.2493


2022-04-08 00:04:20.729567: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 156800000 exceeds 10% of free system memory.


Results at the end of epoch 1
...sparse_categorical_accuracy: 0.9670
...loss: 0.1108


2022-04-08 00:04:42.979819: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 156800000 exceeds 10% of free system memory.


Results at the end of epoch 2
...sparse_categorical_accuracy: 0.9763
...loss: 0.0758


In [22]:
@tf.function
def test_step(inputs, targets):
    predictions = model(inputs, training=False)
    loss = loss_fn(targets, predictions)
    logs = {}
    for metric in metrics:
        metric.update_state(targets, predictions)
        logs["val_" + metric.name] = metric.result()
        loss_tracking_metrics.update_state(loss)
        logs["val_loss"] = loss_tracking_metrics.result()
        return logs

    val_dataset = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
    val_dataset = val_dataset.batch(32)
    reset_metrics()
    for inputs_batch, targets_batch in val_dataset:
        logs = test_step(inputs_batch, targets_batch)
    print("Evaluation results:")
    for key, value in logs.items():
        print(f"...{key}: {value:.4f}")


In [23]:
test_step(val_images, val_labels)

{'val_sparse_categorical_accuracy': <tf.Tensor: shape=(), dtype=float32, numpy=0.9741143>,
 'val_loss': <tf.Tensor: shape=(), dtype=float32, numpy=0.07584991>}