In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa

from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import numpy as np

from resnet import MNIST_ResNet
from train_utils import Model_Trainer

In [3]:
# Fetch the MNIST train and test splits together with the dataset metadata.
mnist_splits, ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    with_info=True,
    as_supervised=True,
    shuffle_files=True,
)
ds_train, ds_test = mnist_splits


def normalize_img(image, label):
    """Cast `image` to float32 and divide by 255; `label` is passed through unchanged."""
    image_as_float = tf.cast(image, tf.float32)
    scaled_image = image_as_float / 255.
    return scaled_image, label

# Training pipeline: normalise, cache, shuffle with a buffer covering the whole
# split (reshuffled on every iteration), then group into batches of 128.
ds_train = (
    ds_train
    .map(normalize_img)
    .cache()
    .shuffle(ds_info.splits['train'].num_examples, reshuffle_each_iteration=True)
    .batch(128)
)

# Test pipeline: normalise, batch, then cache the batched result (no shuffling).
ds_test = (
    ds_test
    .map(normalize_img)
    .batch(128)
    .cache()
)

In [28]:
class MNIST_ResNet(tf.keras.Model):
    """ResNet-style image classifier.

    The network is a strided 5x5 convolution stem (conv -> batch norm -> ReLU ->
    max-pool), followed by the configured residual stages, global average
    pooling and a dense softmax layer with `n_classes` outputs.
    """

    def __init__(self, input_shape, residual_block_params, n_classes=10, filter_n_0=64):
        '''
            Parameters
            ----------
                input_shape - tuple of int
                    Dimensions of a single input image. Accepted for interface
                    compatibility only; it is not used here because the layers
                    infer their shapes the first time the model is called.
                residual_block_params - list of dict
                    One dict per residual stage giving the keyword arguments of
                    `make_residual_block_layer` (# filters, # residual blocks, stride).
                    e.g.
                        [{'n_filters': 16, 'block_depth':3, 'stride': 1}, {'n_filters': 32, 'block_depth':3, 'stride': 1}]
                n_classes - int
                    Output dimension of the final softmax layer.
                filter_n_0 - int
                    Number of filters in the stem convolution that precedes the residual stages.
        '''

        super(MNIST_ResNet, self).__init__()

        self.n_classes = n_classes

        # Layers are referenced through `tf.keras.layers` because the notebook
        # only imports `tensorflow as tf`; the bare names were undefined.
        self.conv_1 = tf.keras.layers.Conv2D(filters=filter_n_0, kernel_size=(5, 5), strides=2, padding="same")
        self.batch_norm_1 = tf.keras.layers.BatchNormalization()
        self.maxpool_1 = tf.keras.layers.MaxPool2D(pool_size=(3, 3), strides=2, padding="same")

        # Build one residual stage per parameter dictionary.
        self.residual_blocks = [
            make_residual_block_layer(**param_dict)
            for param_dict in residual_block_params
        ]

        self.avgpool = tf.keras.layers.GlobalAveragePooling2D()
        # The activation is given by name; passing a layer class here is not a valid activation.
        self.softmax = tf.keras.layers.Dense(units=self.n_classes, activation="softmax")

    def call(self, x, training=None):
        """Run the forward pass and return per-class probabilities.

        Parameters
        ----------
            x - tensor
                Batch of input images.
            training - bool or None
                Forwarded to the batch-normalisation layers and residual stages.
        """
        x = self.conv_1(x)
        x = self.batch_norm_1(x, training=training)
        x = tf.nn.relu(x)
        x = self.maxpool_1(x)

        # Pass through each residual stage in order.
        for residual_block in self.residual_blocks:
            x = residual_block(x, training=training)

        x = self.avgpool(x)
        output = self.softmax(x)

        return output
    
class ResidualBlock(tf.keras.layers.Layer):
    """Two 3x3 conv + batch-norm layers wrapped by a shortcut connection.

    The shortcut is the identity when the block keeps both the spatial size and
    the channel count of its input. Otherwise it is a strided 1x1 convolution
    followed by batch normalisation, so that both branches have the same shape
    when they are added.
    """

    def __init__(self, n_filters, stride=1):
        """
        Parameters
        ----------
            n_filters - int
                Number of filters of both 3x3 convolutions (the block's output channels).
            stride - int
                Stride of the first convolution and of the projection shortcut.
        """
        super(ResidualBlock, self).__init__()
        self.n_filters = n_filters
        self.stride = stride

        self.conv_1 = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3, 3), strides=stride, padding="same")
        self.batch_norm_1 = tf.keras.layers.BatchNormalization()

        self.conv_2 = tf.keras.layers.Conv2D(filters=n_filters, kernel_size=(3, 3), strides=1, padding="same")
        self.batch_norm_2 = tf.keras.layers.BatchNormalization()

        # Chosen in build(), once the number of input channels is known.
        # None means "identity shortcut".
        self.downsample = None

    def build(self, input_shape):
        """Select the shortcut type from the stride and the input channel count."""
        # Assumes channels-last inputs (the Conv2D default data format).
        in_channels = tf.TensorShape(input_shape)[-1]

        # A projection is required whenever the main branch changes the spatial
        # size (stride != 1) OR the number of channels; deciding on the stride
        # alone makes the residual add fail for stride-1 blocks that widen.
        if self.stride != 1 or in_channels != self.n_filters:
            self.downsample = tf.keras.Sequential([
                tf.keras.layers.Conv2D(filters=self.n_filters, kernel_size=(1, 1), strides=self.stride),
                tf.keras.layers.BatchNormalization(),
            ])
        else:
            self.downsample = None

        super(ResidualBlock, self).build(input_shape)

    def call(self, inputs, training=None, **kwargs):
        """Return relu(shortcut(inputs) + F(inputs)), F being conv-bn-relu-conv-bn."""
        if self.downsample is None:
            residual = inputs
        else:
            # Forward `training` so the shortcut's batch norm follows the same mode.
            residual = self.downsample(inputs, training=training)

        x = self.conv_1(inputs)
        x = self.batch_norm_1(x, training=training)
        x = tf.nn.relu(x)

        x = self.conv_2(x)
        x = self.batch_norm_2(x, training=training)

        output = tf.nn.relu(tf.keras.layers.add([residual, x]))
        return output
    
def make_residual_block_layer(n_filters, block_depth, stride=1):
    """Build one residual stage: `block_depth` ResidualBlocks sharing `n_filters`.

    Parameters
    ----------
        n_filters - int
            Number of filters used by every block in the stage.
        block_depth - int
            Total number of residual blocks in the stage (must be >= 1).
        stride - int
            Stride of the first block; every following block uses stride 1.

    Returns
    -------
        tf.keras.Sequential containing the blocks in order.

    Raises
    ------
        ValueError if `block_depth` is smaller than 1.
    """
    if block_depth < 1:
        raise ValueError("block_depth must be >= 1, got {}".format(block_depth))

    res_block = tf.keras.Sequential()
    res_block.add(ResidualBlock(n_filters, stride=stride))

    # The first block is already added, so only block_depth - 1 remain;
    # looping block_depth times would produce one block too many.
    for _ in range(block_depth - 1):
        res_block.add(ResidualBlock(n_filters, stride=1))

    return res_block
    
    
class Model_Train:
    """Training wrapper for a TensorFlow model.

    Owns the loss, optimizer, running train/test metrics and a checkpoint for a
    predefined model, and exposes single-batch train and test steps.
    """

    def __init__(self, Model):
        """
        Parameters
        ----------
            Model - model instance
                Callable as `Model(images, training=...)` and exposing
                `n_classes` and `trainable_variables`.
        """
        self.lr = 5e-4
        self.n_classes = Model.n_classes

        self.model = Model
        self.init_loss()
        self.init_optimizer()

        # Used to save the parameters of the model at a given point of time.
        self.checkpoint = tf.train.Checkpoint(self.model)
        self.checkpoint_path = self.model.__class__.__name__ + "/training_checkpoints"

        # Placeholder for tracked gradient information. The original bare
        # `self.gradients` expression raised AttributeError; nothing fills
        # this list yet.
        self.gradients = []

    def init_loss(self):
        """Create the loss function and the train/test metrics tracked over training."""
        # The wrapped model ends in a softmax layer, so its outputs are
        # probabilities rather than logits.
        self.loss_function = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

        self.train_loss = tf.keras.metrics.Mean(name='train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
        self.train_confusion = tfa.metrics.MultiLabelConfusionMatrix(num_classes=self.n_classes, name='train_confusion_matrix')

        self.test_loss = tf.keras.metrics.Mean(name='test_loss')
        self.test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
        self.test_confusion = tfa.metrics.MultiLabelConfusionMatrix(num_classes=self.n_classes, name='test_confusion_matrix')

    def init_optimizer(self):
        """Create the Adam optimizer using the learning rate stored in `self.lr`."""
        self.optimizer = tf.keras.optimizers.Adam(learning_rate=self.lr)

    def _confusion_inputs(self, labels, predictions):
        """Convert sparse labels and class scores into one-hot int32 tensors.

        MultiLabelConfusionMatrix casts both of its inputs to int32, so float
        probabilities would be truncated to 0 and sparse labels have the wrong
        shape; both are therefore one-hot encoded first.
        """
        true_one_hot = tf.one_hot(tf.cast(labels, tf.int32), depth=self.n_classes, dtype=tf.int32)
        pred_one_hot = tf.one_hot(tf.argmax(predictions, axis=-1), depth=self.n_classes, dtype=tf.int32)
        return true_one_hot, pred_one_hot

    @tf.function
    def train_step(self, images, labels, track_gradient=False):
        """Take a single optimisation step on one batch of training data.

        `track_gradient` is accepted for interface compatibility and currently
        has no effect.

        Returns the running train loss, the running train accuracy (in percent)
        and the running train confusion matrix.
        """
        with tf.GradientTape() as gtape:
            predictions = self.model(images, training=True)
            loss = self.loss_function(labels, predictions)

        # The variables belong to the wrapped model, not to this wrapper.
        variables = self.model.trainable_variables
        gradients = gtape.gradient(loss, variables)
        self.optimizer.apply_gradients(zip(gradients, variables))

        # Track model performance.
        self.train_loss(loss)
        self.train_accuracy(labels, predictions)
        self.train_confusion(*self._confusion_inputs(labels, predictions))

        return self.train_loss.result(), self.train_accuracy.result()*100, self.train_confusion.result()

    @tf.function
    def test_step(self, images, labels):
        """Evaluate the model on one batch of test data.

        Returns the running test loss, the running test accuracy (in percent)
        and the running test confusion matrix.
        """
        # Call the model directly: `Model.predict` runs its own batching loop
        # and is not meant to be used inside a tf.function.
        predictions = self.model(images, training=False)
        test_loss = self.loss_function(labels, predictions)

        self.test_loss(test_loss)
        self.test_accuracy(labels, predictions)
        self.test_confusion(*self._confusion_inputs(labels, predictions))

        return self.test_loss.result(), self.test_accuracy.result()*100, self.test_confusion.result()

    @tf.function
    def reset(self):
        """Reset every train and test metric."""
        self.train_loss.reset_states()
        self.train_accuracy.reset_states()
        # The addons metric has no `reset()` method; use the same
        # `reset_states()` as the Keras metrics.
        self.train_confusion.reset_states()

        self.test_loss.reset_states()
        self.test_accuracy.reset_states()
        self.test_confusion.reset_states()

    def model_checkpoint(self):
        """Save a checkpoint of the model and return the path it was written to."""
        # Saves to self.checkpoint_path-{save_counter}. Every time
        # checkpoint.save is called, the save counter is increased.
        save_path = self.checkpoint.save(self.checkpoint_path)
        return save_path

In [6]:
# Model hyper-parameters: image shape and class count come from the dataset metadata.
image_dim = ds_info.features['image'].shape
n_classes = ds_info.features['label'].num_classes
filter_n_0 = 16

# Two residual stages that differ only in their filter count.
residual_block_params = [
    {'n_filters': stage_filters, 'block_depth': 2, 'stride': 1}
    for stage_filters in (16, 32)
]

# Build the network and hand it to the training wrapper.
base_model = MNIST_ResNet(
    filter_n_0=filter_n_0,
    n_classes=n_classes,
    residual_block_params=residual_block_params,
    input_shape=image_dim,
)
trainer = Model_Trainer(base_model)

n_epochs = 5
template = 'Epoch {} - Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'

for epoch in tqdm(range(n_epochs)):
    # Start every epoch with fresh metric state.
    trainer.reset()

    # One pass over the training batches; the values from the last step are kept.
    for images, labels in ds_train:
        train_loss, train_accuracy, train_confusion = trainer.train_step(images, labels)

    # One pass over the test batches.
    for images, labels in ds_test:
        test_loss, test_accuracy, test_confusion = trainer.test_step(images, labels)

    epoch_summary = (epoch, train_loss, train_accuracy, test_loss, test_accuracy)
    print(template.format(*epoch_summary))

    # Persist the weights and record this epoch's metrics.
    trainer.model_checkpoint()
    trainer.log_metrics(epoch)

In [8]:
# Display the model's trainable weights. As the captured error below states,
# weights only exist once the model has been called on inputs or `build()` has
# been called with an input shape, so run this after the model has been used.
base_model.trainable_variables

ValueError: Weights for model sequential have not yet been created. Weights are created when the Model is first called on inputs or `build()` is called with an `input_shape`.