In [1]:
import os 
import tensorflow as tf 
import tensorflow_datasets as tfds 
tfds.disable_progress_bar()

In [4]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)

In [5]:
datasets

{'test': <DatasetV1Adapter shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>,
 'train': <DatasetV1Adapter shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>}

In [6]:
mnist_train, mnist_test = datasets['train'], datasets['test']

In [7]:
mnist_train

<DatasetV1Adapter shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>

In [8]:
info

tfds.core.DatasetInfo(
    name='mnist',
    version=3.0.0,
    description='The MNIST database of handwritten digits.',
    homepage='http://yann.lecun.com/exdb/mnist/',
    features=FeaturesDict({
        'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
        'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
    }),
    total_num_examples=70000,
    splits={
        'test': 10000,
        'train': 60000,
    },
    supervised_keys=('image', 'label'),
    citation="""@article{lecun2010mnist,
      title={MNIST handwritten digit database},
      author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
      journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
      volume={2},
      year={2010}
    }""",
    redistribution_info=,
)

In [9]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


In [10]:
strategy.num_replicas_in_sync

1

In [11]:
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

In [12]:
BATCH_SIZE

64

In [18]:
def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255

    return image, label


In [19]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

In [22]:
next(iter(train_dataset))

(<tf.Tensor: shape=(64, 28, 28, 1), dtype=float32, numpy=
 array([[[[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         ...,
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]]],
 
 
        [[[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          ...,
          [0.],
          [0.],
          [0.]],
 
         [[

In [24]:
with strategy.scope():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.MaxPool2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer = tf.keras.optimizers.Adam(), 
        metrics=['accuracy']
    )

In [25]:
checkpoint_dir = './training_checkpoints'

checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt_{epoch}')

In [26]:
checkpoint_prefix

'./training_checkpoints/ckpt_{epoch}'

In [27]:
def decay(epoch):
    if epoch < 3:
        return 1e-3
    elif epoch >=3 and epoch < 7:
        return 1e-4
    else:
        return 1e-5

In [28]:
class PrintLR(tf.keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):
        print('\nLearning rate for epoch {} is {}'.format(
            epoch + 1,
            model.optimizer.lr.numpy()
        ))

In [29]:
callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_prefix,
        save_weights_only=True
    ),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]

In [30]:
model.fit(train_dataset, epochs=12, callbacks=callbacks)

Epoch 1/12
    938/Unknown - 10s 11ms/step - loss: 0.2062 - accuracy: 0.9398
Learning rate for epoch 1 is 0.0010000000474974513
Epoch 2/12
Learning rate for epoch 2 is 0.0010000000474974513
Epoch 3/12
Learning rate for epoch 3 is 0.0010000000474974513
Epoch 4/12
Learning rate for epoch 4 is 9.999999747378752e-05
Epoch 5/12
Learning rate for epoch 5 is 9.999999747378752e-05
Epoch 6/12
Learning rate for epoch 6 is 9.999999747378752e-05
Epoch 7/12
Learning rate for epoch 7 is 9.999999747378752e-05
Epoch 8/12
Learning rate for epoch 8 is 9.999999747378752e-06
Epoch 9/12
Learning rate for epoch 9 is 9.999999747378752e-06
Epoch 10/12
Learning rate for epoch 10 is 9.999999747378752e-06
Epoch 11/12
Learning rate for epoch 11 is 9.999999747378752e-06
Epoch 12/12
Learning rate for epoch 12 is 9.999999747378752e-06


<tensorflow.python.keras.callbacks.History at 0x7f1faa6dd2e8>

In [31]:
!ls {checkpoint_dir}

checkpoint		     ckpt_4.data-00000-of-00001
ckpt_10.data-00000-of-00001  ckpt_4.index
ckpt_10.index		     ckpt_5.data-00000-of-00001
ckpt_11.data-00000-of-00001  ckpt_5.index
ckpt_11.index		     ckpt_6.data-00000-of-00001
ckpt_12.data-00000-of-00001  ckpt_6.index
ckpt_12.index		     ckpt_7.data-00000-of-00001
ckpt_1.data-00000-of-00001   ckpt_7.index
ckpt_1.index		     ckpt_8.data-00000-of-00001
ckpt_2.data-00000-of-00001   ckpt_8.index
ckpt_2.index		     ckpt_9.data-00000-of-00001
ckpt_3.data-00000-of-00001   ckpt_9.index
ckpt_3.index


In [32]:
model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))

<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f1fa83017f0>

In [35]:
eval_loss, eval_acc = model.evaluate(eval_dataset)

    157/Unknown - 1s 5ms/step - loss: 0.0386 - accuracy: 0.9869

In [36]:
path = 'saved_model/'

In [37]:
tf.keras.experimental.export_saved_model(model, path)

AttributeError: module 'tensorflow_core.keras.experimental' has no attribute 'export_saved_model'