# VGGNet
Training example using the CIFAR-100 dataset

In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
# Create the TFDS builder for CIFAR-100 and download/prepare its data on disk
# so that splits can be loaded with `as_dataset` below.
data_builder = tfds.builder("cifar100")
data_builder.download_and_prepare()

In [3]:
# Load the raw splits. NOTE(review): the official test split is used as the
# validation set here, so there is no separate held-out test set.
train_dataset = data_builder.as_dataset(split=tfds.Split.TRAIN)
val_dataset = data_builder.as_dataset(split=tfds.Split.TEST)

# Number of distinct label values, taken from the dataset metadata.
num_classes = data_builder.info.features["label"].num_classes

# Example counts per split, read from the same metadata.
num_train, num_val = (
    data_builder.info.splits[split_name].num_examples
    for split_name in ("train", "test")
)

print(f"# for train : {num_train:d}")
print(f"# for valid : {num_val:d}")

# for train : 50000
# for valid : 10000


In [4]:
# Model input size as [height, width, channels]; prepare_data_fn resizes
# (and, when augmenting, random-crops) every image to this shape.
input_shape = [224, 224, 3]

# Training hyper-parameters: examples per batch and passes over the train split.
batch_size, num_epochs = 32, 300

In [5]:
def prepare_data_fn(features, input_shape, augment=False):
    """Map one TFDS example dict to an ``(image, label)`` pair.

    The image is converted to float32 with ``tf.image.convert_image_dtype``
    and brought to the target size. With ``augment`` enabled, random
    photometric jitter and a random scale-then-crop are applied.

    Args:
        features: Example dict holding ``'image'`` and ``'label'`` entries.
        input_shape: Target ``[height, width, channels]``.
        augment: Apply random augmentation (intended for training data).

    Returns:
        ``(image, label)``. Without augmentation the image is resized to
        ``input_shape[:2]``. With augmentation it is randomly cropped to the
        full ``input_shape``.
    """
    target_shape = tf.convert_to_tensor(input_shape)
    label = features['label']
    # convert_image_dtype rescales integer pixel values into [0, 1] floats.
    image = tf.image.convert_image_dtype(features['image'], tf.float32)

    if not augment:
        return tf.image.resize(image, target_shape[:2]), label

    # Photometric jitter, then clip back into the [0, 1] range.
    image = tf.image.random_flip_left_right(image)
    image = tf.image.random_brightness(image, max_delta=0.1)
    image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
    image = tf.clip_by_value(image, 0.0, 1.0)

    # Scale jitter: enlarge by a random factor in [1, 1.4), then take a random
    # crop of exactly target_shape. The [1]-shaped factor broadcasts over
    # (height, width).
    scale = tf.random.uniform([1], minval=1., maxval=1.4, dtype=tf.float32)
    enlarged_hw = tf.cast(tf.cast(target_shape[:2], tf.float32) * scale, tf.int32)
    image = tf.image.resize(image, enlarged_hw)
    image = tf.image.random_crop(image, target_shape)
    return image, label

In [6]:
import functools

# Bind the fixed preprocessing arguments so each function can be passed
# directly to Dataset.map (which supplies only the example dict).
prepare_data_fn_for_train = functools.partial(prepare_data_fn,
                                             input_shape=input_shape,
                                             augment=True)
prepare_data_fn_for_val = functools.partial(prepare_data_fn,
                                           input_shape=input_shape,
                                           augment=False)

# Shuffle BEFORE repeat: shuffling after repeat lets the buffer mix elements of
# consecutive epochs, so an "epoch" of num_train / batch_size steps can see some
# images twice and miss others. With shuffle first, epoch boundaries stay
# clean, and reshuffle_each_iteration (default True) still gives a fresh order
# on every epoch.
train_dataset = (
    train_dataset
    .shuffle(10000)
    .repeat(num_epochs)
    .map(prepare_data_fn_for_train, num_parallel_calls=4)
    .batch(batch_size)
    .prefetch(1)
)

# Validation: deterministic preprocessing, repeated indefinitely.
# NOTE(review): the consumer presumably bounds this with a step count
# (e.g. validation_steps) — confirm.
val_dataset = (
    val_dataset
    .repeat()
    .map(prepare_data_fn_for_val, num_parallel_calls=4)
    .batch(batch_size)
    .prefetch(1)
)