In [None]:
import tensorflow as tf
import numpy as np

In [None]:
# Switch every visible GPU to on-demand memory allocation instead of letting
# TensorFlow reserve the device memory up front.
# NOTE(review): TensorFlow expects this to run before the GPUs are first
# used, so keep this cell ahead of any tensor/model creation.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, enable=True)
    

## Parameters

In [None]:
# Training hyperparameters (tune here; later cells read these names).
epochs = 16  # number of passes of the outer training loop
lr = 1e-3    # learning rate handed to the Adam optimizer

## Dataset

In [None]:
# Mini-batch size used by both the training and the test pipeline.
batch_size = 256

# Load MNIST as NumPy arrays; `path` names the local cache file Keras uses
# for the downloaded archive.
train_split, test_split = tf.keras.datasets.mnist.load_data(path='mnist.npz')
mnist_images, mnist_labels = train_split
x_test, y_test = test_split

# Sanity check: show the array shapes of both splits.
print(mnist_images.shape, mnist_labels.shape)
print(x_test.shape, y_test.shape)

In [None]:
# Build the tf.data pipelines. Images get a trailing channel axis, are scaled
# by 1/255 and cast to float32; labels are cast to int64 so they can be
# compared directly with tf.math.argmax output in the train/validation steps.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(mnist_labels, tf.int64))
)
test_dset = tf.data.Dataset.from_tensor_slices(
    (tf.cast(x_test[..., tf.newaxis] / 255.0, tf.float32),
     tf.cast(y_test, tf.int64))
)

# from_tensor_slices yields one element per entry along the leading axis, so
# the sample counts are simply the array lengths. Reading them from the
# arrays avoids iterating over (and materialising) every element of both
# datasets just to count them.
nsamples = len(mnist_images)
ntests = len(x_test)

# Both pipelines repeat indefinitely; consumers bound each pass with .take().
# NOTE(review): repeat() is applied before shuffle(), so the 10000-element
# shuffle buffer can mix samples across epoch boundaries -- confirm this is
# intended.
dataset = dataset.repeat().shuffle(10000).batch(batch_size)
test_dset = test_dset.repeat().batch(batch_size)

In [None]:
import matplotlib.pylab as plt

# Preview: draw the first 8 images of one training batch side by side,
# each titled with its label.
for x, y in dataset.take(1):
    print(x.shape, y.shape)
    for i in range(8):
        plt.subplot(1, 8, i + 1)  # one row, eight columns
        plt.imshow(x[i].numpy())
        plt.xticks([])
        plt.yticks([])
        plt.title(y[i].numpy(), fontsize=18)
plt.show()

## Model

In [None]:
# Convolutional classifier: two 3x3 conv layers, 2x2 max pooling, dropout,
# a 128-unit dense layer, dropout again, and a 10-way softmax output.
mnist_model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Dropout(rate=0.25),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=128, activation='relu'),
    tf.keras.layers.Dropout(rate=0.5),
    tf.keras.layers.Dense(units=10, activation='softmax'),
])

# The model emits probabilities (softmax), so the loss keeps its default of
# consuming probabilities rather than logits.
loss = tf.losses.SparseCategoricalCrossentropy()

opt = tf.optimizers.Adam(learning_rate=lr)

# Passed to checkpoint.save(...) in the training cell.
checkpoint_dir = './checkpoints/tf2_mnist'

# Tracks model weights and optimizer state together.
checkpoint = tf.train.Checkpoint(model=mnist_model, optimizer=opt)


In [None]:
# Show the model's predictions for the first 8 images of one training batch.
# Each title reads "predicted(label)".
for x, y in dataset.take(1):
    probs = mnist_model(x)
    pred = tf.math.argmax(probs, axis=1)
    for i in range(8):
        plt.subplot(1, 8, i + 1)  # one row, eight columns
        plt.imshow(x[i].numpy())
        plt.xticks([])
        plt.yticks([])
        caption = "%s(%s)" % (pred[i].numpy(), y[i].numpy())
        plt.title(caption, fontsize=18)
plt.show()

## Training 

In [None]:
@tf.function
def training_step(images, labels):
    """Apply one optimizer update on a batch and report loss and accuracy.

    Args:
        images: batch fed to ``mnist_model`` with ``training=True``.
        labels: class ids passed to ``loss`` and compared against the argmax
            of the model output along axis 1.

    Returns:
        ``(loss_value, accuracy)``: the batch loss and the mean of the
        element-wise prediction/label equality, cast to float32.
    """
    # Only the forward pass and the loss need to be recorded for gradients.
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    # Accuracy is bookkeeping only; it does not take part in the gradient.
    hits = tf.math.equal(tf.math.argmax(probs, axis=1), labels)
    accuracy = tf.math.reduce_mean(tf.cast(hits, tf.float32))

    weights = mnist_model.trainable_variables
    grads = tape.gradient(loss_value, weights)
    opt.apply_gradients(zip(grads, weights))
    return loss_value, accuracy

@tf.function
def validation_step(images, labels):
    """Evaluate one batch without updating any weights.

    Args:
        images: batch fed to ``mnist_model`` with ``training=False``.
        labels: class ids passed to ``loss`` and compared against the argmax
            of the model output along axis 1.

    Returns:
        ``(loss_value, accuracy)``: the batch loss and the mean of the
        element-wise prediction/label equality, cast to float32.
    """
    probs = mnist_model(images, training=False)
    loss_value = loss(labels, probs)
    hits = tf.math.equal(tf.math.argmax(probs, axis=1), labels)
    accuracy = tf.math.reduce_mean(tf.cast(hits, tf.float32))
    return loss_value, accuracy

In [None]:
# Number of whole batches per pass over each split; the integer division
# drops any remainder smaller than one batch.
nstep, ntest_step = (count // batch_size for count in (nsamples, ntests))

print("Number of training step: %d" % nstep)
print("Number of validation step: %d" % ntest_step)

In [None]:
import time

# Per-epoch history. Every entry is stored as a plain Python float so the
# lists can be plotted and written out with NumPy without holding on to
# TensorFlow tensors.
metrics = {
    'train_acc': [],
    'valid_acc': [],
    'train_loss': [],
    'valid_loss': [],
    'time_per_epoch': [],
}

t0 = time.time()
for ep in range(epochs):
    tt0 = time.time()

    # --- Training: running means, each batch contributes value / nstep. ---
    training_loss = 0.0
    training_acc = 0.0
    for batch, (images, labels) in enumerate(dataset.take(nstep)):
        loss_value, acc = training_step(images, labels)
        training_loss += loss_value / nstep
        training_acc += acc / nstep

        if batch % 100 == 0:
            # NOTE(review): each save() call writes a new numbered checkpoint
            # and nothing here removes old ones -- consider
            # tf.train.CheckpointManager if disk usage matters.
            checkpoint.save(checkpoint_dir)
            print('Epoch - %d, step #%06d/%06d\tLoss: %.6f' % (ep, batch, nstep, loss_value))

    # --- Validation: same running-mean scheme over ntest_step batches. ---
    test_acc = 0.0
    test_loss = 0.0
    for images, labels in test_dset.take(ntest_step):
        loss_value, acc = validation_step(images, labels)
        test_acc += acc / ntest_step
        test_loss += loss_value / ntest_step
    tt1 = time.time()

    # Convert the accumulated tensors to Python floats once per epoch (not
    # per step, which would force a host sync on every batch).
    training_loss, training_acc = float(training_loss), float(training_acc)
    test_loss, test_acc = float(test_loss), float(test_acc)

    print('Epoch - %d, train Loss: %.6f, train Acc: %.6f, val loss: %.6f, val Acc: %.6f %.6f seconds'
          % (ep, training_loss, training_acc, test_loss, test_acc, tt1 - tt0))
    metrics['train_acc'].append(training_acc)
    metrics['train_loss'].append(training_loss)
    metrics['valid_acc'].append(test_acc)
    metrics['valid_loss'].append(test_loss)
    metrics['time_per_epoch'].append(tt1 - tt0)  # includes validation time

# Final checkpoint after the last epoch.
checkpoint.save(checkpoint_dir)
t1 = time.time()
print("Total training time: %s seconds" % (t1 - t0))

# One row per epoch; columns are
# train_acc, train_loss, valid_acc, valid_loss, time_per_epoch.
np.savetxt(
    "metrics.dat",
    np.array([
        metrics['train_acc'],
        metrics['train_loss'],
        metrics['valid_acc'],
        metrics['valid_loss'],
        metrics['time_per_epoch'],
    ]).transpose(),
)

In [None]:
import matplotlib.pylab as plt

# Learning curves: one figure per metric, training against validation,
# with the epoch index on the x axis.
for train_key, valid_key, axis_label in (
    ('train_loss', 'valid_loss', "loss"),
    ('train_acc', 'valid_acc', "Accuracy"),
):
    plt.plot(metrics[train_key], label='Training')
    plt.plot(metrics[valid_key], label="Validation")
    plt.ylabel(axis_label)
    plt.xlabel("Epochs")
    plt.legend()
    plt.show()

In [None]:
import matplotlib.pylab as plt

# Show the model's predictions for the first 8 images of one batch drawn
# from `dataset`. Each title reads "predicted(label)".
for x, y in dataset.take(1):
    print(x.shape, y.shape)
    probs = mnist_model(x)
    pred = tf.math.argmax(probs, axis=1)
    for i in range(8):
        plt.subplot(1, 8, i + 1)  # one row, eight columns
        plt.imshow(x[i].numpy())
        plt.xticks([])
        plt.yticks([])
        caption = "%s(%s)" % (pred[i].numpy(), y[i].numpy())
        plt.title(caption, fontsize=18)
plt.show()