In [5]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras.optimizers.schedules import PolynomialDecay

from resnet import MNIST_ResNet
from train_utils import Model_Trainer

import matplotlib.pyplot as plt
import numpy as np

In [7]:
# Fetch the MNIST train/test splits as (image, label) pairs, together with the
# dataset metadata object that later cells read shapes, class count and split
# sizes from.
(ds_train, ds_test), ds_info = tfds.load(
    name='mnist',
    split=['train', 'test'],
    as_supervised=True,
    with_info=True,
    shuffle_files=True,
)


# Normalize Image to be floats in range [0,1]
def normalize_img(image, label):
    """Cast the image to float32 and divide by 255; the label is returned untouched.

    The result lies in [0, 1] assuming the incoming pixels are 8-bit values
    (0-255) — TODO confirm against the dataset's image dtype.
    """
    scaled_image = tf.cast(image, tf.float32) / 255.
    return scaled_image, label

# Training pipeline: normalise, cache the normalised examples, shuffle with a
# buffer as large as the whole training split (reshuffled on every pass),
# then group into batches of 128.
ds_train = (
    ds_train
    .map(normalize_img)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples, reshuffle_each_iteration=True)
    .batch(128)
)

# Test pipeline: normalise, batch, then cache the already-batched data
# (no shuffling is applied to the evaluation split).
ds_test = ds_test.map(normalize_img).batch(128).cache()

In [8]:
# Dimensions taken from the dataset metadata.
image_dim = ds_info.features['image'].shape
n_classes = ds_info.features['label'].num_classes

# Architecture settings.
filter_n_0 = 16  # forwarded to MNIST_ResNet as filter_n_0
block_depth = 2  # same block_depth for every residual stage below
# k is a width scaling parameter - when we down sample (1 < stride), increase number of features by a factor of 2k
# NOTE(review): k is not read when the stage list is built; the widths 16 and 32
# below are literals — confirm whether they should be derived from k.
k = 1

# One dict per residual stage, built from (n_filters, stride) pairs.
residual_block_params = [
    {'n_filters': n_filters, 'block_depth': block_depth, 'stride': stride}
    for n_filters, stride in [(16, 1), (32, 2)]
]

# block_depth per stage, plus one.
depth = 1 + len(residual_block_params) * block_depth

resnet = MNIST_ResNet(
    n_classes=n_classes,
    input_shape=image_dim,
    filter_n_0=filter_n_0,
    residual_block_params=residual_block_params,
)

# model_id should identify the model set-up,
# i.e. effective depth, width scaling, initial layer width.
# This must be an f-string: without the prefix the {depth}, {k} and
# {filter_n_0} placeholders are kept as literal text, so every configuration
# would be handed the same identifier.
model_id = f'convolution_depth_{depth}_width_scale_{k}_filters_{filter_n_0}'


# Learning-rate schedule: polynomial decay (power 0.5) from the starting rate
# down to the end rate over `decay_steps` steps.
starter_learning_rate = 1e-3
end_learning_rate = 1e-4
decay_steps = 500
lr_schedule = PolynomialDecay(
    initial_learning_rate=starter_learning_rate,
    decay_steps=decay_steps,
    end_learning_rate=end_learning_rate,
    power=0.5,
)

# The trainer receives the model, the schedule (as its `lr`) and the model id.
trainer = Model_Trainer(model=resnet, lr=lr_schedule, model_id=model_id)

n_epochs = 5
steps = 0  # total number of train_step calls across all epochs

# Report formats, defined once rather than on every iteration.
batch_report = 'Epoch {}, Batch {} - Train Loss: {:.4f}, Train Accuracy: {:.4f}'
epoch_report = 'Epoch {} - Train Loss: {:.4f}, Train Accuracy: {:.4f}, Test Loss: {:.4f}, Test Accuracy: {:.4f}'

for epoch in range(n_epochs):
    epoch_number = epoch + 1  # 1-based epoch index used in the printed reports

    # NOTE(review): presumably clears the trainer's running metrics at the
    # start of each epoch — confirm in train_utils.
    trainer.reset()
    batch_count = 0

    for images, labels in ds_train:
        train_loss, train_accuracy = trainer.train_step(images, labels)

        # Print progress on every 25th batch, starting with batch 0.
        if not batch_count % 25:
            print(batch_report.format(epoch_number, batch_count, train_loss, train_accuracy))

        batch_count += 1
        steps += 1

    # Run every test batch; the values printed below are those returned by
    # the last test_step call of the sweep.
    for images, labels in ds_test:
        test_loss, test_accuracy = trainer.test_step(images, labels)

    print(epoch_report.format(epoch_number, train_loss, train_accuracy, test_loss, test_accuracy))

    # End-of-epoch bookkeeping delegated to the trainer.
    trainer.model_checkpoint()
    trainer.log_metrics()

Epoch 1, Batch 0 - Train Loss: 2.3207, Train Accuracy: 8.5938
Epoch 1, Batch 25 - Train Loss: 2.1428, Train Accuracy: 31.5805
Epoch 1, Batch 50 - Train Loss: 1.9781, Train Accuracy: 52.5889
Epoch 1, Batch 75 - Train Loss: 1.8558, Train Accuracy: 65.6044
Epoch 1, Batch 100 - Train Loss: 1.7810, Train Accuracy: 72.7336
Epoch 1, Batch 125 - Train Loss: 1.7296, Train Accuracy: 77.4616
Epoch 1, Batch 150 - Train Loss: 1.6941, Train Accuracy: 80.5878
Epoch 1, Batch 175 - Train Loss: 1.6682, Train Accuracy: 82.8569
Epoch 1, Batch 200 - Train Loss: 1.6478, Train Accuracy: 84.6743
Epoch 1, Batch 225 - Train Loss: 1.6311, Train Accuracy: 86.1069
Epoch 1, Batch 250 - Train Loss: 1.6181, Train Accuracy: 87.2012
Epoch 1, Batch 275 - Train Loss: 1.6070, Train Accuracy: 88.1425
Epoch 1, Batch 300 - Train Loss: 1.5976, Train Accuracy: 88.9457
Epoch 1, Batch 325 - Train Loss: 1.5895, Train Accuracy: 89.6137
Epoch 1, Batch 350 - Train Loss: 1.5821, Train Accuracy: 90.2444
Epoch 1, Batch 375 - Train Loss

In [11]:
# Summarise the network trained above. `resnet` is the model built and trained
# in this notebook; the previous name `base_model` is not defined in any
# visible cell, so the call failed with NameError on a fresh kernel.
resnet.summary()

Model: "mnist__res_net_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d_42 (Conv2D)           multiple                  416       
_________________________________________________________________
batch_normalization_42 (Batc multiple                  64        
_________________________________________________________________
max_pooling2d_3 (MaxPooling2 multiple                  0         
_________________________________________________________________
sequential_9 (Sequential)    (None, 14, 14, 16)        14304     
_________________________________________________________________
sequential_10 (Sequential)   (None, 7, 7, 32)          52320     
_________________________________________________________________
global_average_pooling2d_3 ( multiple                  0         
_________________________________________________________________
flatten_3 (Flatten)          multiple             