In [43]:
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

In [71]:
mnist = tf.keras.datasets.mnist

# Load the train/test splits; each split is an (images, labels) pair.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values by 1/255 and keep them as float32. Dividing the loaded
# integer arrays by 255.0 directly would promote them to float64, which
# doubles their memory footprint and leaves Keras (float32 by default) to
# cast every batch back down.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

In [72]:
# Add a trailing channels dimension so the images can be fed to Conv2D.
# The rank check makes this cell safe to re-run: without it, every extra
# execution would append another size-1 axis to arrays that already have one.
# NOTE(review): assumes the freshly loaded images are rank 3
# (samples, height, width) -- confirm against mnist.load_data().
if x_train.ndim == 3:
    x_train = x_train[..., tf.newaxis]
if x_test.ndim == 3:
    x_test = x_test[..., tf.newaxis]

In [73]:
# Input-pipeline settings, named so they are easy to find and tune.
SHUFFLE_BUFFER_SIZE = 10000
BATCH_SIZE = 32

# Training pipeline: slice into (image, label) pairs, shuffle, then batch.
# Shuffling is only approximate if the dataset holds more elements than the
# shuffle buffer.
train_ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)

# Test pipeline: same slicing and batch size, but no shuffling.
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE)

In [74]:
class MyModel(Model):
    """Small convolutional classifier: Conv2D -> Flatten -> Dense -> Dense (softmax)."""

    def __init__(self):
        super(MyModel, self).__init__()
        # Layers are created in the same order in which `call` applies them.
        self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
        self.flatten = Flatten()
        self.d1 = Dense(units=128, activation='relu')
        self.d2 = Dense(units=10, activation='softmax')

    def call(self, x):
        """Run the forward pass on `x` and return the 10-unit softmax output."""
        features = self.flatten(self.conv1(x))
        hidden = self.d1(features)
        return self.d2(hidden)


# Build the single model instance used by the training and evaluation steps.
model = MyModel()

In [75]:
# The model's last layer applies softmax, so the loss receives probabilities
# rather than logits; `from_logits=False` is the default, spelled out here.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

In [76]:
# Metrics for measuring loss and accuracy. Each one accumulates over the
# batches it is called on until it is reset.

# Updated by the training step.
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

# Updated by the test step.
test_loss = tf.keras.metrics.Mean(name='test_loss')
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [77]:
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and record its loss and accuracy.

    Args:
        images: Batch of inputs passed to `model`.
        labels: Targets for `images`, passed to the loss and the accuracy metric.
    """
    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as grad_tape:
        probs = model(images)
        batch_loss = loss_object(labels, probs)

    # Fetched after the forward pass so that any variables the model creates
    # on its first call are included.
    weights = model.trainable_variables
    grads = grad_tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    # Fold this batch into the running training metrics.
    train_loss(batch_loss)
    train_accuracy(labels, probs)
 

In [78]:
@tf.function
def test_step(images, labels):
    """Evaluate one batch and record its loss and accuracy; no weights are updated.

    Args:
        images: Batch of inputs passed to `model`.
        labels: Targets for `images`, passed to the loss and the accuracy metric.
    """
    probs = model(images)

    # Fold this batch into the running test metrics.
    test_loss(loss_object(labels, probs))
    test_accuracy(labels, probs)

In [80]:
EPOCHS = 5

# The report format is the same every epoch, so it is defined once up front.
template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'

for epoch in range(EPOCHS):
    # One full pass over the training batches, updating the weights.
    for images, labels in train_ds:
        train_step(images, labels)

    # One full pass over the test batches, metrics only.
    for test_images, test_labels in test_ds:
        test_step(test_images, test_labels)

    # Epochs are reported 1-based; accuracies are shown as percentages.
    print(template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        test_loss.result(),
        test_accuracy.result() * 100,
    ))

    # Clear the accumulators so the next epoch's numbers start from zero.
    for metric in (train_loss, train_accuracy, test_loss, test_accuracy):
        metric.reset_states()

Epoch 1, Loss: 0.09387125074863434, Accuracy: 97.1630630493164, Test Loss: 0.049626827239990234, Test Accuracy: 98.4000015258789
Epoch 2, Loss: 0.022229477763175964, Accuracy: 99.32500457763672, Test Loss: 0.06010788679122925, Test Accuracy: 98.0999984741211
Epoch 3, Loss: 0.013563843443989754, Accuracy: 99.5616683959961, Test Loss: 0.05504993721842766, Test Accuracy: 98.41999816894531
Epoch 4, Loss: 0.009449983946979046, Accuracy: 99.68167114257812, Test Loss: 0.06335621327161789, Test Accuracy: 98.30999755859375
Epoch 5, Loss: 0.00841258279979229, Accuracy: 99.70832824707031, Test Loss: 0.06248476728796959, Test Accuracy: 98.48999786376953
