In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf
import os
tfds.disable_progress_bar()

## Download dataset

In [3]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)



In [4]:
info

tfds.core.DatasetInfo(
    name='mnist',
    version=1.0.0,
    description='The MNIST database of handwritten digits.',
    urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

In [5]:
mnist_train, mnist_test = datasets['train'], datasets['test']

## Define distribution strategy

In [6]:
strategy = tf.distribute.MirroredStrategy()





INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [7]:
strategy.num_replicas_in_sync

1

In [8]:
num_train = info.splits['train'].num_examples
num_test = info.splits['test'].num_examples

In [9]:
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [10]:
def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.
    return image, label

In [11]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

## Creat the model

In [25]:
with strategy.scope():
    model = tf.keras.models.Sequential([
        tf.keras.layers.Conv2D(filters=32, kernel_size=3, activation='relu', input_shape=(28,28,1)),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10)
    ])
    
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  optimizer=tf.keras.optimizers.Adam(),
                  metrics=['accuracy'])


In [26]:
checkpoint_dir = '../logs/train_checkpoint'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

In [27]:
def decay(epoch):
    if epoch < 3:
        return 0.001
    elif epoch >= 3 and epoch < 7:
        return 0.0001
    else:
        0.00001

In [28]:
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print(f'\nLearn rate for {epoch+1} is {model.optimizer.lr.numpy():.6f}')

In [33]:
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='../logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    tf.keras.callbacks.EarlyStopping(monitor='loss', min_delta=1e-4, patience=3),
    PrintLR()
]

## Train and Evaluate

In [34]:
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
    938/Unknown - 21s 23ms/step - loss: 0.0204 - accuracy: 0.9928
Learn rate for 1 is 0.001000
Epoch 2/12
Learn rate for 2 is 0.001000
Epoch 3/12
Learn rate for 3 is 0.001000
Epoch 4/12
Learn rate for 4 is 0.000100
Epoch 5/12
Learn rate for 5 is 0.000100
Epoch 6/12
Learn rate for 6 is 0.000100
Epoch 7/12
Learn rate for 7 is 0.000100


ValueError: The output of the "schedule" function should be float.