In [12]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.datasets import mnist
import numpy as np

**Preparing the image data**


In [2]:
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()


def preprocess_images(images):
    """Flatten each image into a vector and rescale pixel values to [0, 1].

    Args:
        images: array whose first axis indexes samples, e.g. MNIST's
            (num_samples, 28, 28) grid of 0-255 pixel values.

    Returns:
        float32 array of shape (num_samples, pixels_per_image) with values
        divided by 255.
    """
    # Infer the sample count from the data instead of hard-coding 60000/10000,
    # so the same code works for any subset of the dataset.
    flattened = images.reshape((len(images), -1))
    return flattened.astype("float32") / 255


# Train and test sets must get identical preprocessing, so both go through
# the same helper rather than two copy-pasted pipelines.
train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)

**The network architecture**

In [3]:
model = keras.Sequential()
# Hidden layer: 512 ReLU units applied to each flattened 28*28-pixel input.
model.add(layers.Dense(512, activation="relu"))
# Output layer: a softmax over 10 units, one probability per digit class.
model.add(layers.Dense(10, activation="softmax"))

**Inspecting the model weights**

In [4]:
# The layers were declared without an input shape, so their weights only exist
# after the model has seen data: run one small batch through it to build them.
model(train_images[:32])
weights_and_bias = model.trainable_weights
for weight in weights_and_bias:
    print(weight.name, weight.shape)

dense/kernel:0 (784, 512)
dense/bias:0 (512,)
dense_1/kernel:0 (512, 10)
dense_1/bias:0 (10,)


**Running one training step**

In [5]:
def one_training_step(model, images_batch, labels_batch):
    """Run the forward pass, backpropagate, and update weights for one mini-batch.

    Args:
        model: Keras model; its outputs are fed to sparse categorical
            cross-entropy with default settings, so they are expected to be
            class probabilities (the notebook's model ends in a softmax).
        images_batch: a slice of the preprocessed training images.
        labels_batch: the labels matching ``images_batch``.

    Returns:
        The scalar mean loss over the batch, measured before the update.
    """
    with tf.GradientTape() as tape:
        # Everything computed inside the tape is recorded for differentiation.
        probabilities = model(images_batch)
        loss_value = tf.reduce_mean(
            tf.keras.losses.sparse_categorical_crossentropy(labels_batch, probabilities)
        )
    trainable = model.trainable_weights
    grads = tape.gradient(loss_value, trainable)
    # `update_weights` is looked up in the global scope at call time, so the
    # most recently executed definition of it is the one that runs.
    update_weights(grads, trainable)
    return loss_value

In [6]:
learning_rate = 1e-3

def update_weights(gradients, weights):
    """Manual SGD: move every weight one step against its gradient.

    Each variable is updated in place as ``w -= g * learning_rate``, where
    ``learning_rate`` is read from the global scope at call time. Pairs are
    formed positionally with ``zip``, so the shorter sequence bounds the loop.

    Note: the following cell redefines ``update_weights`` on top of a Keras
    optimizer; once it has run, that later definition is the one used.
    """
    for grad, var in zip(gradients, weights):
        step = grad * learning_rate  # descent step for this variable
        var.assign_sub(step)

In [7]:
from tensorflow.keras import optimizers

# Keras SGD optimizer driven by the same `learning_rate` constant as the manual
# update in the previous cell, so changing that constant affects the optimizer
# actually used for training instead of a silently duplicated literal.
optimizer = optimizers.SGD(learning_rate=learning_rate)

def update_weights(gradients, weights):
    """Apply one optimizer step to ``weights`` using ``gradients``.

    This shadows the manual ``update_weights`` from the previous cell; from
    here on, ``one_training_step`` routes its updates through the global
    ``optimizer``.

    Args:
        gradients: gradients aligned position-by-position with ``weights``
            (as produced by ``tape.gradient``).
        weights: the trainable variables to update.
    """
    optimizer.apply_gradients(zip(gradients, weights))

**A batch generator**

In [8]:
import math

class BatchGenerator:
    """Hand out consecutive, non-shuffled mini-batches from paired sequences.

    Attributes:
        index: start position of the next batch; advanced by ``batch_size``
            on every ``next()`` call.
        images, labels: the paired sequences being batched.
        batch_size: number of samples per batch (the last batch may be shorter).
        num_batches: batches needed to cover every sample once.
    """

    def __init__(self, images, labels, batch_size=128):
        assert len(images) == len(labels)
        self.index = 0
        self.images, self.labels = images, labels
        self.batch_size = batch_size
        # Round up so a trailing partial batch is still counted.
        self.num_batches = math.ceil(len(images) / batch_size)

    def next(self):
        """Return the next ``(images, labels)`` batch and advance the cursor.

        Calling this more than ``num_batches`` times does not raise; it yields
        empty slices.
        """
        # The slice captures the current bounds before the cursor moves on.
        window = slice(self.index, self.index + self.batch_size)
        self.index += self.batch_size
        return self.images[window], self.labels[window]

**The full training loop**

In [9]:
def fit(model, images, labels, epochs, batch_size=128):
    """Train ``model`` with the custom mini-batch loop.

    Each epoch, a fresh ``BatchGenerator`` walks the data in order (no
    shuffling), and ``one_training_step`` performs one weight update per
    batch. The batch loss is printed every 100 batches.

    Args:
        model: the Keras model to train; its weights are updated in place.
        images: training inputs, sliceable along the first axis.
        labels: labels aligned with ``images`` (consumed by sparse categorical
            cross-entropy, i.e. integer class ids).
        epochs: number of full passes over the data.
        batch_size: samples per mini-batch.
    """
    for epoch_counter in range(epochs):
        print(f"Epoch {epoch_counter}")
        # Forward batch_size explicitly; otherwise the generator silently
        # falls back to its own default and the argument has no effect.
        batch_generator = BatchGenerator(images, labels, batch_size=batch_size)
        for batch_counter in range(batch_generator.num_batches):
            images_batch, labels_batch = batch_generator.next()
            loss = one_training_step(model, images_batch, labels_batch)
            if batch_counter % 100 == 0:
                # float() turns the scalar loss tensor into a plain number for formatting.
                print(f"loss at batch {batch_counter}: {float(loss):.2f}")

**Fitting the model with the custom training loop**

In [10]:
# Train for 10 epochs in mini-batches of 128 samples; `model`'s weights are
# updated in place and the loss is printed every 100 batches.
# NOTE(review): the saved log below stops partway through epoch 7, so it is not
# the output of a completed 10-epoch run, and the execution counts in this
# notebook are out of order (the imports cell shows In [12]). Restart and run
# all cells before relying on these numbers or the accuracy reported below.
fit(model, train_images, train_labels, epochs=10, batch_size=128)

Epoch 0
loss at batch 0: 2.40
loss at batch 100: 2.18
loss at batch 200: 2.06
loss at batch 300: 1.98
loss at batch 400: 1.89
Epoch 1
loss at batch 0: 1.82
loss at batch 100: 1.74
loss at batch 200: 1.58
loss at batch 300: 1.56
loss at batch 400: 1.52
Epoch 2
loss at batch 0: 1.42
loss at batch 100: 1.41
loss at batch 200: 1.24
loss at batch 300: 1.26
loss at batch 400: 1.26
Epoch 3
loss at batch 0: 1.15
loss at batch 100: 1.17
loss at batch 200: 1.00
loss at batch 300: 1.05
loss at batch 400: 1.08
Epoch 4
loss at batch 0: 0.97
loss at batch 100: 1.00
loss at batch 200: 0.84
loss at batch 300: 0.91
loss at batch 400: 0.95
Epoch 5
loss at batch 0: 0.84
loss at batch 100: 0.87
loss at batch 200: 0.73
loss at batch 300: 0.81
loss at batch 400: 0.86
Epoch 6
loss at batch 0: 0.75
loss at batch 100: 0.78
loss at batch 200: 0.65
loss at batch 300: 0.74
loss at batch 400: 0.80
Epoch 7
loss at batch 0: 0.69
loss at batch 100: 0.71
loss at batch 200: 0.59
loss at batch 300: 0.68
loss at batch 40

**Evaluating the model on new data**

In [13]:
# Score the whole test set in one forward pass, then take the most probable
# class in each row as the predicted digit.
predictions = model(test_images).numpy()
predicted_labels = predictions.argmax(axis=1)
matches = predicted_labels == test_labels
print(f"accuracy: {matches.mean():.2f}")

accuracy: 0.87
