In [1]:
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist
import tensorflow as tf
import numpy as np
import time
import sys

In [2]:
physical_devices = tf.config.experimental.list_physical_devices('GPU')
assert len(physical_devices) > 0, "Not enough GPU hardware devices available"
config = tf.config.experimental.set_memory_growth(physical_devices[0], True)

In [3]:
# building our model
def build_model(width, height, depth, classes):
    input_shape = (height, width, depth)
    chan_dim = -1
    
    model = Sequential([
        # CONV => RELU => BN => POOL layer set
        Conv2D(16, (3, 3), padding="same", input_shape=input_shape),
        Activation("relu"),
        BatchNormalization(axis=chan_dim),
        MaxPooling2D(pool_size=(2, 2)),
        
        # (CONV => RELU => BN) * 2 => POOL layer set
        Conv2D(32, (3, 3), padding="same"),
        Activation("relu"),
        BatchNormalization(axis=chan_dim),
        Conv2D(32, (3, 3), padding="same"),
        Activation("relu"),
        BatchNormalization(axis=chan_dim),
        MaxPooling2D(pool_size=(2, 2)),
        
        # (CONV => RELU => BN) * 3 => POOL layer set
        Conv2D(64, (3, 3), padding="same"),
        Activation("relu"),
        BatchNormalization(axis=chan_dim),
        Conv2D(64, (3, 3), padding="same"),
        Activation("relu"),
        BatchNormalization(axis=chan_dim),
        Conv2D(64, (3, 3), padding="same"),
        Activation("relu"),
        BatchNormalization(axis=chan_dim),
        MaxPooling2D(pool_size=(2, 2)),
        
        # first (and only) set of FC => RELU layers
        Flatten(),
        Dense(256),
        Activation("relu"),
        BatchNormalization(),
        Dropout(0.5),
        
        # softmax classifier
        Dense(classes),
        Activation("softmax")
    ])
    
    return model

Our model is representative of VGG-esque (inspired by variants of VGGNet).
Dropout of 50% to improve generalization.

In [4]:
def step(X, y):
    # keep track of our gradients
    with tf.GradientTape() as tape:
        # make a prediction using the model and then calculate the loss
        pred = model(X)
        loss = categorical_crossentropy(y, pred)
    
    # calculate the gradients using our tape and then update the model weights
    grads = tape.gradient(loss, model.trainable_variables)
    opt.apply_gradients(zip(grads, model.trainable_variables))  # optimizer to update the weights

## Lets prepare the data

In [5]:
EPOCHS = 25
BS = 64  # batch size
INIT_LR = 1e-3  # initial learning rate

In [13]:
((trainX, trainY), (testX, testY)) = mnist.load_data()

In [16]:
# add a channel dimension to every image in the dataset, then scale the pixel intensities to the range [0, 1]
trainX = np.expand_dims(trainX, axis=-1)
testX = np.expand_dims(testX, axis=-1)
trainX = trainX.astype("float32") / 255.0
testX = testX.astype("float32") / 255.0

In [17]:
# one-hot encoding the labels
trainY = to_categorical(trainY, 10)
testY = to_categorical(testY, 10)

In [9]:
trainY.shape

(60000, 10)

In [19]:
# build our model and init our optimizer
model = build_model(28, 28, 1, 10)
opt = Adam(lr=INIT_LR, decay=INIT_LR / EPOCHS)

In [11]:
# train our model with our GradientTape
# compute the number of batch updates per epoch
num_updates = int(trainX.shape[0] / BS)

# loop over the number of epochs
for epoch in range(0, EPOCHS):
    # show the current epoch number
    print("[INFO] starting epoch {}/{}...".format(epoch + 1, EPOCHS), end="")
    sys.stdout.flush()
    epoch_start = time.time()
    
    # loop over data in batch size increments
    for i in range(0, num_updates):
        # start and end slice indexes
        start = i * BS
        end = start + BS
        
        # take a step
        step(trainX[start: end], trainY[start: end])
        
    # show timing information for the epoch
    epoch_end = time.time()
    elapsed = (epoch_end - epoch_start) / 60.0
    print("took {:.4} minutes".format(elapsed))

[INFO] starting epoch 1/25...took 0.3437 minutes
[INFO] starting epoch 2/25...took 0.3046 minutes
[INFO] starting epoch 3/25...took 0.3222 minutes
[INFO] starting epoch 4/25...took 0.3158 minutes
[INFO] starting epoch 5/25...took 0.3329 minutes
[INFO] starting epoch 6/25...took 0.3286 minutes
[INFO] starting epoch 7/25...took 0.35 minutes
[INFO] starting epoch 8/25...took 0.348 minutes
[INFO] starting epoch 9/25...took 0.3278 minutes
[INFO] starting epoch 10/25...took 0.3141 minutes
[INFO] starting epoch 11/25...took 0.3058 minutes
[INFO] starting epoch 12/25...took 0.3251 minutes
[INFO] starting epoch 13/25...took 0.3014 minutes
[INFO] starting epoch 14/25...took 0.3026 minutes
[INFO] starting epoch 15/25...took 0.3023 minutes
[INFO] starting epoch 16/25...took 0.3038 minutes
[INFO] starting epoch 17/25...took 0.3031 minutes
[INFO] starting epoch 18/25...took 0.3028 minutes
[INFO] starting epoch 19/25...took 0.3051 minutes
[INFO] starting epoch 20/25...took 0.3022 minutes
[INFO] start

In [18]:
# in order to calculate accuracy using Keras' functions we first need
# to compile the model
model.compile(optimizer=opt, loss=categorical_crossentropy,
    metrics=["acc"])
# now that the model is compiled we can compute the accuracy
(loss, acc) = model.evaluate(testX, testY)
print("[INFO] test accuracy: {:.4f}".format(acc))

   32/10000 [..............................] - ETA: 1:25

InvalidArgumentError:  logits and labels must be broadcastable: logits_size=[32,10] labels_size=[320,10]
	 [[node loss/activation_7_loss/softmax_cross_entropy_with_logits (defined at <ipython-input-18-a98ed7df0446>:6) ]] [Op:__inference_distributed_function_7825829]

Function call stack:
distributed_function
