In [1]:
# let's just use a custom training loop to train mnist. instead of our normal procedure

In [2]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow.keras.datasets import mnist #to import our dataset
from tensorflow.keras.models import Model # imports our type of network
from tensorflow.keras.layers import Dense, Flatten, Input,Lambda # imports our layers we want to use

from tensorflow.keras.losses import categorical_crossentropy #loss function
from tensorflow.keras.optimizers import Adam, SGD #optimisers
from tensorflow.keras.utils import to_categorical #some function for data preparation

#from sklearn.decomposition import PCA

In [3]:
batch_size = 128
num_classes = 10
epochs = 20

# input image dimensions
img_rows, img_cols = 28, 28

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()


x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
print("y_train before categorical", np.shape(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)
print("y_train after categorical", np.shape(y_train))


x_train shape: (60000, 28, 28)
60000 train samples
10000 test samples
y_train before categorical (60000,)
y_train after categorical (60000, 10)


In [4]:
inputs = Input(shape=(28,28))
x = Flatten()(inputs)
x = Dense(128,activation='relu')(x)
y = Dense(10,activation='softmax')(x)


model1= Model(inputs,outputs=y)
opt = Adam(learning_rate=0.0001)
model1.compile(loss='categorical_crossentropy',optimizer=opt)

model1.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 128)               100480    
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1290      
Total params: 101,770
Trainable params: 101,770
Non-trainable params: 0
_________________________________________________________________


In [5]:
model1.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,validation_data=(x_test,y_test))

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x2be2085da30>

In [6]:
# now let's do the same with a custom training loop

inputs = Input(shape=(28,28))
x = Flatten()(inputs)
x = Dense(128,activation='relu')(x)
y = Dense(10,activation='softmax')(x)


model2= Model(inputs,outputs=y)

# Instantiate an optimizer.
optimizer = Adam(learning_rate=1e-4)
# Instantiate a loss function.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)



In [7]:
# Prepare the training dataset.
batch_size = 64

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=len(x_train)).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
val_dataset = val_dataset.batch(batch_size)

In [8]:
train_dataset

<BatchDataset shapes: ((None, 28, 28), (None, 10)), types: (tf.float32, tf.float32)>

In [9]:
epochs = 20
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):

        # Open a GradientTape to record the operations run
        # during the forward pass, which enables auto-differentiation.
        with tf.GradientTape() as tape:

            # Run the forward pass of the layer.
            # The operations that the layer applies
            # to its inputs are going to be recorded
            # on the GradientTape.
            logits = model2(x_batch_train, training=True)  # Logits for this minibatch

            # Compute the loss value for this minibatch.
            loss_value = loss_fn(y_batch_train, logits)

        # Use the gradient tape to automatically retrieve
        # the gradients of the trainable variables with respect to the loss.
        grads = tape.gradient(loss_value, model2.trainable_weights)

        # Run one step of gradient descent by updating
        # the value of the variables to minimize the loss.
        optimizer.apply_gradients(zip(grads, model2.trainable_weights))

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %s samples" % ((step + 1) * batch_size))


Start of epoch 0
Training loss (for one batch) at step 0: 2.3075
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.8714
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.7691
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.7263
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.6369
Seen so far: 51264 samples

Start of epoch 1
Training loss (for one batch) at step 0: 1.6366
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5516
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.6456
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.6091
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.5941
Seen so far: 51264 samples

Start of epoch 2
Training loss (for one batch) at step 0: 1.6324
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5578
Seen so far: 12864 samples
Training loss (for one batch) at step

In [10]:
# this was relatively slow, let's add an additional line of code to accelerate this with @tf.function

inputs = Input(shape=(28,28))
x = Flatten()(inputs)
x = Dense(128,activation='relu')(x)
y = Dense(10,activation='softmax')(x)


model3= Model(inputs,outputs=y)


train_acc_metric = tf.keras.metrics.CategoricalAccuracy()
val_acc_metric = tf.keras.metrics.CategoricalAccuracy()

@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model3(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model3.trainable_weights)
    optimizer.apply_gradients(zip(grads, model3.trainable_weights))
    train_acc_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    val_logits = model3(x, training=False)
    val_acc_metric.update_state(y, val_logits)

In [11]:
import time

epochs = 5
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        with tf.GradientTape() as tape:
            logits = model3(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, logits)
        grads = tape.gradient(loss_value, model3.trainable_weights)
        optimizer.apply_gradients(zip(grads, model3.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y_batch_train, logits)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        val_logits = model3(x_batch_val, training=False)
        # Update val metrics
        val_acc_metric.update_state(y_batch_val, val_logits)
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 2.3102
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5699
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.5915
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.5798
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.5846
Seen so far: 51264 samples
Training acc over epoch: 0.8692
Validation acc: 0.9143
Time taken: 17.53s

Start of epoch 1
Training loss (for one batch) at step 0: 1.5955
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5433
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.5789
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.6146
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.5960
Seen so far: 51264 samples
Training acc over epoch: 0.9156
Validation acc: 0.9213
Time taken: 16.42s

Start of epoch 2
Training loss (for one batch) at step 

In [12]:
epochs = 5
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    # Iterate over the batches of the dataset.
    for step, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        loss_value = train_step(x_batch_train, y_batch_train)

        # Log every 200 batches.
        if step % 200 == 0:
            print(
                "Training loss (for one batch) at step %d: %.4f"
                % (step, float(loss_value))
            )
            print("Seen so far: %d samples" % ((step + 1) * batch_size))

    # Display metrics at the end of each epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch
    train_acc_metric.reset_states()

    # Run a validation loop at the end of each epoch.
    for x_batch_val, y_batch_val in val_dataset:
        test_step(x_batch_val, y_batch_val)

    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))


Start of epoch 0
Training loss (for one batch) at step 0: 1.5449
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5294
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.5340
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.5458
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.5145
Seen so far: 51264 samples
Training acc over epoch: 0.9367
Validation acc: 0.9380
Time taken: 6.69s

Start of epoch 1
Training loss (for one batch) at step 0: 1.4987
Seen so far: 64 samples
Training loss (for one batch) at step 200: 1.5480
Seen so far: 12864 samples
Training loss (for one batch) at step 400: 1.5496
Seen so far: 25664 samples
Training loss (for one batch) at step 600: 1.5148
Seen so far: 38464 samples
Training loss (for one batch) at step 800: 1.5158
Seen so far: 51264 samples
Training acc over epoch: 0.9399
Validation acc: 0.9394
Time taken: 3.99s

Start of epoch 2
Training loss (for one batch) at step 0: