# VGGNet
Training example using the CIFAR-10 dataset

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Create the TFDS builder for the 'cifar10' dataset, then download and
# prepare it so splits can be read from it afterwards.
data_builder = tfds.builder('cifar10')
data_builder.download_and_prepare()

In [3]:
# Dataset metadata: number of label classes and the size of each split.
dataset_info = data_builder.info
num_classes = dataset_info.features['label'].num_classes
num_train = dataset_info.splits['train'].num_examples
num_val = dataset_info.splits['test'].num_examples

# The TEST split serves as the validation set in this notebook.
train_dataset = data_builder.as_dataset(split=tfds.Split.TRAIN)
val_dataset = data_builder.as_dataset(split=tfds.Split.TEST)

print('# for train : %d' % num_train)
print('# for valid : %d' % num_val)

# for train : 50000
# for valid : 10000


In [4]:
# Model input size as [height, width, channels]; prepare_data_fn resizes
# every image to this shape.
input_shape = [224, 224, 3]

# Training hyper-parameters: samples per batch and maximum number of epochs.
batch_size, num_epochs = 32, 300

## Create model

In [5]:
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

import functools

In [6]:
def _conv3x3(filters):
    """Return a Conv2D factory preset to 3x3 kernels, 'same' padding and ReLU.

    The returned partial still accepts further Conv2D keyword arguments
    (e.g. `name=`) at call time.
    """
    return functools.partial(
        Conv2D,
        filters=filters,
        kernel_size=(3, 3),
        padding='same',
        activation='relu',
    )


# One factory per channel width used by the VGG blocks.
Conv3_64 = _conv3x3(64)
Conv3_128 = _conv3x3(128)
Conv3_256 = _conv3x3(256)
Conv3_512 = _conv3x3(512)

In [7]:
def VGGNet(input_shape, num_classes, model_type: int = 16):
    """Build a VGG-16 or VGG-19 classifier with the Keras functional API.

    The network is five convolutional blocks (each followed by 2x2 max
    pooling), then Flatten, two 4096-unit ReLU Dense layers and a softmax
    Dense layer with `num_classes` units.

    Args:
        input_shape: Shape of one input sample (no batch axis), forwarded to
            `Input(shape=...)`.
        num_classes: Number of units of the final softmax layer.
        model_type: 16 or 19. With 19, blocks 3, 4 and 5 each get a fourth
            convolution layer.

    Returns:
        A `Model` named 'VGG16' or 'VGG19'.

    Raises:
        ValueError: If `model_type` is neither 16 nor 19.
    """
    # Fail fast instead of silently building VGG16 for an unsupported depth.
    if model_type not in (16, 19):
        raise ValueError(
            'model_type must be 16 or 19, got %r' % (model_type,))
    is_vgg19 = model_type == 19

    # (conv factory, number of conv layers in the VGG16 layout) per block.
    block_specs = (
        (Conv3_64, 2),
        (Conv3_128, 2),
        (Conv3_256, 3),
        (Conv3_512, 3),
        (Conv3_512, 3),
    )

    inputs = Input(shape=input_shape, name='Input')
    x = inputs
    for block_idx, (conv_factory, num_convs) in enumerate(block_specs, start=1):
        # VGG19 adds one extra convolution to blocks 3-5 only.
        if is_vgg19 and block_idx >= 3:
            num_convs += 1
        for conv_idx in range(1, num_convs + 1):
            x = conv_factory(name='block%d_conv%d' % (block_idx, conv_idx))(x)
        x = MaxPooling2D(pool_size=2, padding='same',
                         name='block%d_pool' % block_idx)(x)

    x = Flatten(name='flatten')(x)
    x = Dense(4096, activation='relu', name='fc1')(x)
    x = Dense(4096, activation='relu', name='fc2')(x)
    # NOTE: the layer name 'ouput' (sic) is kept unchanged so the layer names
    # of the built model stay exactly as before.
    outputs = Dense(num_classes, activation='softmax', name='ouput')(x)

    # Name the model through the public constructor argument rather than by
    # assigning the private `_name` attribute afterwards.
    return Model(inputs=inputs, outputs=outputs,
                 name='VGG19' if is_vgg19 else 'VGG16')

In [8]:
def VGG16(input_shape, num_classes):
    """Build the 16-layer variant; thin wrapper around VGGNet(model_type=16)."""
    return VGGNet(input_shape, num_classes, model_type=16)

def VGG19(input_shape, num_classes):
    """Build the 19-layer variant; thin wrapper around VGGNet(model_type=19)."""
    return VGGNet(input_shape, num_classes, model_type=19)

In [9]:
# Input shape with a leading batch axis of unspecified size.
batch_input_shape = tf.TensorShape([None, *input_shape])

In [10]:
# Instantiate the 16-layer variant and print its layer-by-layer summary.
model = VGG16(input_shape=input_shape, num_classes=num_classes)
model.build(input_shape=batch_input_shape)
model.summary()

Model: "VGG16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 Input (InputLayer)          [(None, 224, 224, 3)]     0         
                                                                 
 block1_conv1 (Conv2D)       (None, 224, 224, 64)      1792      
                                                                 
 block1_conv2 (Conv2D)       (None, 224, 224, 64)      36928     
                                                                 
 block1_pool (MaxPooling2D)  (None, 112, 112, 64)      0         
                                                                 
 block2_conv1 (Conv2D)       (None, 112, 112, 128)     73856     
                                                                 
 block2_conv2 (Conv2D)       (None, 112, 112, 128)     147584    
                                                                 
 block2_pool (MaxPooling2D)  (None, 56, 56, 128)       0     

In [11]:
# model = VGG19(input_shape, num_classes)
# model.build(input_shape=batch_input_shape)
# model.summary()

## Prepare training dataset

In [12]:
def prepare_data_fn(features, input_shape, augment=False):
    """Turn one dataset element into an `(image, label)` pair sized for the model.

    Args:
        features: Mapping with an 'image' and a 'label' entry.
        input_shape: Target [height, width, channels]; converted to a tensor.
        augment: If True, apply a random horizontal flip, random brightness
            and saturation changes (clipped to [0, 1]), then resize by a
            random factor in [1, 1.4) and randomly crop back to
            `input_shape`. If False, only resize to the target height/width.

    Returns:
        Tuple `(image, label)`; the image has been passed through
        `tf.image.convert_image_dtype(..., tf.float32)`.
    """
    target_shape = tf.convert_to_tensor(input_shape)

    label = features['label']
    image = tf.image.convert_image_dtype(features['image'], tf.float32)

    if not augment:
        # Deterministic path: plain resize to the target height and width.
        return tf.image.resize(image, target_shape[:2]), label

    # Photometric augmentations; clip so values stay within [0, 1].
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    image = tf.clip_by_value(image, 0.0, 1.0)

    # Scale jitter: one shared random factor enlarges both sides (the [1]
    # factor broadcasts over [height, width]); the crop then restores the
    # exact target shape.
    scale = tf.random.uniform([1], minval=1., maxval=1.4, dtype=tf.float32)
    scaled_hw = tf.cast(tf.cast(target_shape[:2], tf.float32) * scale, tf.int32)
    image = tf.image.resize(image, scaled_hw)
    image = tf.image.random_crop(image, target_shape)
    return image, label

In [13]:
import functools

# Bind the target shape and the augmentation switch so the resulting callables
# take a single dataset element, as Dataset.map expects.
prepare_data_fn_for_train = functools.partial(prepare_data_fn,
                                              input_shape=input_shape,
                                              augment=True)
prepare_data_fn_for_val = functools.partial(prepare_data_fn,
                                            input_shape=input_shape,
                                            augment=False)

# Training pipeline.
# - shuffle() comes before repeat() so each pass over the data is shuffled on
#   its own and samples from different epochs are not mixed in the buffer.
# - repeat() is unbounded: training length is controlled by steps_per_epoch
#   and epochs in model.fit. A finite repeat(num_epochs) yields fewer batches
#   than num_epochs * ceil(num_train / batch_size) whenever batch_size does
#   not divide num_train, so the input would run out before the last epoch.
# - AUTOTUNE lets tf.data pick the map parallelism and prefetch depth.
train_dataset = (
    train_dataset
    .shuffle(10000)
    .repeat()
    .map(prepare_data_fn_for_train, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation pipeline: no shuffling or augmentation; repeated so that
# validation_steps batches are always available.
val_dataset = (
    val_dataset
    .repeat()
    .map(prepare_data_fn_for_val, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

## Training

In [16]:
import tensorflow_addons as tfa

In [18]:
# SGDW optimizer from TensorFlow Addons with Nesterov momentum.
# Alternative that was tried: optimizer = tf.keras.optimizers.Adam()
sgdw_config = dict(
    learning_rate=0.01,
    momentum=0.9,
    weight_decay=0.0005,
    nesterov=True,
)
optimizer = tfa.optimizers.SGDW(**sgdw_config)

# Top-1 and top-5 accuracy for integer (sparse) labels.
accuracy_metric = tf.metrics.SparseCategoricalAccuracy(name='acc')
top5_accuracy_metric = tf.metrics.SparseTopKCategoricalAccuracy(name='top5_acc', k=5)

In [19]:
# Sparse cross-entropy matches the integer labels and the softmax output.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=optimizer,
    metrics=[accuracy_metric, top5_accuracy_metric],
)

In [20]:
import os

In [21]:
# Directory receiving both the TensorBoard logs and the per-epoch checkpoints.
model_dir = './models/vggnet'
checkpoint_path = os.path.join(model_dir, 'weights-epoch{epoch:02d}.h5')

# Stop once val_loss has not improved for 8 epochs and roll back to the best
# weights seen so far.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=8,
    restore_best_weights=True,
)

tensorboard_cb = tf.keras.callbacks.TensorBoard(
    log_dir=model_dir,
    histogram_freq=0,
    write_graph=True,
)

# One checkpoint file per epoch; the epoch number is filled into the file name.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path)

callbacks = [early_stopping_cb, tensorboard_cb, checkpoint_cb]

In [None]:
import math

# Number of batches needed to cover each split once (last batch may be partial).
train_steps_per_epoch = math.ceil(num_train / batch_size)
val_steps_per_epoch = math.ceil(num_val / batch_size)

history = model.fit(
    train_dataset,
    validation_data=val_dataset,
    steps_per_epoch=train_steps_per_epoch,
    validation_steps=val_steps_per_epoch,
    epochs=num_epochs,
    callbacks=callbacks,
    verbose=1,
)

Epoch 1/300
  54/1563 [>.............................] - ETA: 8:12 - loss: 2.3022 - acc: 0.1042 - top5_acc: 0.5023