In [9]:
import tensorflow as tf

# Expose tf.keras under the short module-level name `keras`, which the rest of
# this notebook uses for every Keras call.
# Original note: this alias is a workaround for Pylance (presumably so the
# editor can resolve Keras symbols -- TODO confirm the exact issue).
keras = tf.keras

In [10]:
# Load MNIST and append a trailing channel axis so every image is (28, 28, 1).
# Pixels are only cast to float32 here; no rescaling is applied.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype("float32")
x_test = x_test.reshape(-1, 28, 28, 1).astype("float32")

# Hold out the last 10,000 training samples as the validation split.
x_train, x_val = x_train[:-10000], x_train[-10000:]
y_train, y_val = y_train[:-10000], y_train[-10000:]

# Training pipeline: shuffle through a 1024-element buffer, then batch.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=1024)
    .batch(batch_size)
)

# Validation pipeline: batched in stored order, no shuffling.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)

In [11]:
def make_classifier(input_shape=(28, 28, 1), hidden_units=64, num_classes=10):
    """Build an uncompiled fully connected image classifier.

    The network flattens the input image, applies two ReLU dense layers of
    equal width, and ends in a dense layer with no activation, so the model
    outputs raw logits (one per class).

    Args:
        input_shape: Shape of a single input image, without the batch axis.
            Defaults to (28, 28, 1).
        hidden_units: Width of each of the two hidden dense layers.
            Defaults to 64.
        num_classes: Number of output logits. Defaults to 10.

    Returns:
        A `keras.Model` mapping the "image" input to the "predictions" logits.
        Calling the function with no arguments yields the same architecture
        and layer names as before these parameters were introduced.
    """
    image_input = keras.layers.Input(shape=input_shape, name="image")
    features = keras.layers.Flatten()(image_input)
    features = keras.layers.Dense(
        hidden_units, activation="relu", name="dense_1"
    )(features)
    features = keras.layers.Dense(
        hidden_units, activation="relu", name="dense_2"
    )(features)
    # No activation on the last layer: downstream code receives logits.
    logits = keras.layers.Dense(num_classes, name="predictions")(features)
    return keras.Model(inputs=image_input, outputs=logits)


# Loss: sparse categorical cross-entropy computed on raw logits (the model's
# final dense layer has no activation).
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Optimizer: SGD with a learning rate of 0.001.
optimizer = keras.optimizers.SGD(learning_rate=0.001)

# Running accuracy trackers, kept separate for training and validation.
train_acc_metric = keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = keras.metrics.SparseCategoricalAccuracy()

In [12]:
# Build the module-level model instance; train_step and test_step read this
# `classifier` name as a global rather than taking the model as an argument.
classifier = make_classifier()

In [13]:
@tf.function
def train_step(x, y):
    """Apply one gradient update for a single batch and return its loss.

    Uses the module-level `classifier`, `loss_fn`, `optimizer` and
    `train_acc_metric`.

    Args:
        x: Batch of model inputs.
        y: Batch of target labels for `x`.

    Returns:
        The loss computed by `loss_fn` for this batch.
    """
    weights = classifier.trainable_weights
    with tf.GradientTape() as tape:
        # The forward pass and loss run inside the tape so they can be
        # differentiated afterwards.
        predictions = classifier(x, training=True)
        batch_loss = loss_fn(y, predictions)
    gradients = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    # Fold this batch into the running training accuracy.
    train_acc_metric.update_state(y, predictions)
    return batch_loss


@tf.function
def test_step(x, y):
    """Accumulate validation accuracy for one batch; no weights are updated.

    Runs the module-level `classifier` with `training=False` and feeds the
    result into `val_acc_metric`.

    Args:
        x: Batch of model inputs.
        y: Batch of target labels for `x`.
    """
    val_acc_metric.update_state(y, classifier(x, training=False))

In [14]:
import time

epochs = 3
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    # perf_counter() is monotonic, so the elapsed time below cannot be skewed
    # by wall-clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()

    # One pass over the training batches. train_step updates the weights and
    # the running training accuracy; loss_value keeps the last batch's loss.
    for x_batch_train, y_batch_train in train_dataset:
        loss_value = train_step(x_batch_train, y_batch_train)

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("------------------------------------")
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    print("------------------------------------")

    # Clear the accumulator so the next epoch's accuracy starts from zero.
    # reset_state() is the current API name; reset_states() is a deprecated
    # alias that newer Keras releases no longer provide.
    train_acc_metric.reset_state()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    # Read the validation accuracy before clearing it for the next epoch.
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    # The reported time covers both the training and the validation pass.
    print("Time taken: %.2fs" % (time.perf_counter() - start_time))
    print("------------------------------------")


Start of epoch 0
------------------------------------
Training acc over epoch: 0.7599
------------------------------------
Validation acc: 0.8375
Time taken: 1.38s
------------------------------------

Start of epoch 1
------------------------------------
Training acc over epoch: 0.8678
------------------------------------
Validation acc: 0.8830
Time taken: 0.84s
------------------------------------

Start of epoch 2
------------------------------------
Training acc over epoch: 0.8944
------------------------------------
Validation acc: 0.9043
Time taken: 0.83s
------------------------------------


In [15]:
# compile() is called with no arguments, so the optimizer and loss used by the
# custom training loop are not attached to the model here. Presumably this
# only marks the model as compiled so that save() does not warn about a
# missing training configuration -- TODO confirm.
classifier.compile()
# Persist the trained model under ./models/standard_classifier (the logged
# output shows an `assets` directory being written at that location).
classifier.save("./models/standard_classifier")



INFO:tensorflow:Assets written to: ./models/standard_classifier\assets


INFO:tensorflow:Assets written to: ./models/standard_classifier\assets
