In [1]:
import tensorflow as tf

In [2]:
# Load the MNIST train/test splits; each split is an (images, labels) pair.
mnist_train, mnist_test = tf.keras.datasets.mnist.load_data()
train_image, train_label = mnist_train
test_image, test_label = mnist_test

In [3]:
# Show the shapes of the training and test image arrays.
train_image.shape, test_image.shape

((60000, 28, 28), (10000, 28, 28))

In [4]:
# Append a trailing axis of size 1 to each image array: (N, 28, 28) -> (N, 28, 28, 1).
train_image = tf.expand_dims(train_image, axis=-1)
test_image = tf.expand_dims(test_image, axis=-1)

In [5]:
# Convert pixels to float32 and divide by 255.
# NOTE: this cell rebinds its own inputs, so run it only once per kernel session.
train_image = tf.cast(train_image, tf.float32) / 255.0
test_image = tf.cast(test_image, tf.float32) / 255.0

In [6]:
# Labels are converted to int64 tensors.
train_label = tf.cast(train_label, dtype=tf.int64)
test_label = tf.cast(test_label, dtype=tf.int64)

In [7]:
# Wrap each (image, label) pair of tensors in a tf.data pipeline that
# yields one example per element.
dataset_train = tf.data.Dataset.from_tensor_slices(
    (train_image, train_label)
)
dataset_test = tf.data.Dataset.from_tensor_slices(
    (test_image, test_label)
)

In [8]:
# Inspect the per-example datasets (the repr lists element shapes and dtypes).
dataset_train, dataset_test

(<TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.int64)>,
 <TensorSliceDataset shapes: ((28, 28, 1), ()), types: (tf.float32, tf.int64)>)

In [9]:
# Training data: shuffle with a 60000-element buffer, then batch by 128.
# Test data: batch by 128 only, no shuffling.
# repeat(1) means a single pass over the data per iteration of the dataset.
dataset_train = (
    dataset_train
    .shuffle(60000)
    .batch(128)
    .repeat(1)
)
dataset_test = (
    dataset_test
    .batch(128)
    .repeat(1)
)

In [10]:
# Inspect the batched datasets (the repr lists element shapes and dtypes).
dataset_train, dataset_test

(<RepeatDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>,
 <RepeatDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float32, tf.int64)>)

In [11]:
# Small convolutional classifier: two 3x3 ReLU convolutions, a global max
# pool, and a 10-way softmax output. Height and width are left unspecified
# in input_shape; only the single input channel is fixed.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(
        filters=16,
        kernel_size=(3, 3),
        activation='relu',
        input_shape=(None, None, 1),
    ),
    tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.GlobalMaxPooling2D(),
    tf.keras.layers.Dense(units=10, activation='softmax'),
])

In [12]:
# Adam optimizer, constructed with the library's default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

In [13]:
# Loss used by train_step/test_step, called as loss_func(label, pred) with
# integer class labels and the model's softmax outputs.
loss_func = tf.keras.losses.SparseCategoricalCrossentropy()

In [14]:
# Running metrics, accumulated batch by batch and reset by train() each epoch.
# The loss metrics average the per-batch loss values; the accuracy metrics
# compare integer labels against the model's predictions.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_acc')

In [15]:
def train_step(model, image, label):
    """Run one optimization step on a single batch and record training metrics.

    Uses the module-level ``loss_func`` and ``optimizer``, and accumulates
    into the module-level ``train_loss`` and ``train_acc`` metrics.

    Args:
        model: The Keras model to train; its trainable variables are updated.
        image: A batch of input images, as yielded by ``dataset_train``.
        label: The matching batch of integer class labels.
    """
    with tf.GradientTape() as tape:
        # training=True so that mode-dependent layers (e.g. Dropout,
        # BatchNormalization) behave correctly if they are ever added.
        pred = model(image, training=True)
        loss_step = loss_func(label, pred)
    grads = tape.gradient(loss_step, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    # update_state only accumulates; the running result is read in train().
    train_loss.update_state(loss_step)
    train_acc.update_state(label, pred)

In [16]:
def test_step(model, image, label):
    """Evaluate the model on a single batch and record test metrics.

    No gradients are computed and no variables are updated. Uses the
    module-level ``loss_func`` and accumulates into the module-level
    ``test_loss`` and ``test_acc`` metrics.

    Args:
        model: The Keras model to evaluate.
        image: A batch of input images, as yielded by ``dataset_test``.
        label: The matching batch of integer class labels.
    """
    # Explicit inference mode, mirroring the training/inference split that
    # a custom loop has to manage by hand.
    pred = model(image, training=False)
    loss_step = loss_func(label, pred)
    # update_state only accumulates; the running result is read in train().
    test_loss.update_state(loss_step)
    test_acc.update_state(label, pred)

In [17]:
import datetime

In [18]:
# Timestamp (YYYYmmdd-HHMMSS, local time) used to give each run its own log directory.
cur_time = format(datetime.datetime.now(), '%Y%m%d-%H%M%S')

In [19]:
# Each run gets its own timestamped directory under ./logs/gradient_tape/,
# with separate train/ and test/ subdirectories so TensorBoard can overlay
# the two curves. The separator after 'gradient_tape' is required; without
# it the timestamp is glued onto the directory name itself.
train_log_dir = './logs/gradient_tape/' + cur_time + '/train'
test_log_dir = './logs/gradient_tape/' + cur_time + '/test'
train_writer = tf.summary.create_file_writer(train_log_dir)
test_writer = tf.summary.create_file_writer(test_log_dir)

In [20]:
def train(num):
    """Train the module-level ``model`` for ``num`` epochs, evaluating after each.

    For every epoch this runs ``train_step`` over ``dataset_train``, then
    ``test_step`` over ``dataset_test``, writes the epoch's loss and accuracy
    to the train/test summary writers, and prints a one-line summary.
    A ``*`` is printed after every evaluated test batch as a progress marker.

    Args:
        num: Number of epochs to run. Epochs are numbered from 1 to ``num``,
            and that number is used as the summary step.
    """
    epoch_metrics = (train_acc, train_loss, test_acc, test_loss)

    for epoch in range(1, num+1):
        # Reset at the start of the epoch (not the end) so that statistics
        # left over from an interrupted earlier run cannot leak into this one.
        for metric in epoch_metrics:
            metric.reset_states()

        for image, label in dataset_train:
            train_step(model, image, label)
        with train_writer.as_default():
            tf.summary.scalar('loss', train_loss.result(), step=epoch)
            tf.summary.scalar('acc', train_acc.result(), step=epoch)

        for image, label in dataset_test:
            test_step(model, image, label)
            print('*', end=' ')
        with test_writer.as_default():
            tf.summary.scalar('loss', test_loss.result(), step=epoch)
            tf.summary.scalar('acc', test_acc.result(), step=epoch)

        print('Epoch %d : loss is %f , accuracy is %f , test_loss is %f , test_accuracy is %f .'
                  % (epoch, train_loss.result(), train_acc.result()*100, test_loss.result(), test_acc.result()*100))

In [21]:
# Train and evaluate for 10 epochs.
train(num=10)

* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Epoch 1 : loss is 1.416043 , accuracy is 57.665001 , test_loss is 0.818973 , test_accuracy is 74.070000 .
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Epoch 2 : loss is 0.678305 , accuracy is 78.636665 , test_loss is 0.539241 , test_accuracy is 83.080002 .
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Epoch 3 : loss is 0.526124 , accuracy is 83.350006 , test_loss is 0.474566 , test_accuracy is 84.400002 .
* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * Epoch 4 : loss is 0.454493 , accuracy is 85.368332