In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Data acquisition and preprocessing
class MNISTLoader():
    """Loads MNIST through Keras and serves random training mini-batches.

    Attributes set by the constructor:
        train_data / test_data: float32 images scaled by 1/255 with a trailing
            channel axis appended.
        train_label / test_label: int32 class labels.
        num_train_data / num_test_data: number of samples in each split.
    """

    def __init__(self):
        (raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = tf.keras.datasets.mnist.load_data()
        # MNIST images are uint8 by default (values 0-255). Scale them to floats
        # in [0, 1] and append a trailing axis that acts as the colour channel.
        print(raw_train_x.shape)
        self.train_data = (raw_train_x.astype(np.float32) / 255.0)[..., np.newaxis]    # [60000, 28, 28, 1]
        print(self.train_data.shape)
        self.test_data = (raw_test_x.astype(np.float32) / 255.0)[..., np.newaxis]      # [10000, 28, 28, 1]
        self.train_label = raw_train_y.astype(np.int32)    # [60000]
        self.test_label = raw_test_y.astype(np.int32)      # [10000]
        self.num_train_data = self.train_data.shape[0]
        self.num_test_data = self.test_data.shape[0]

    def get_batch(self, batch_size):
        """Return `batch_size` randomly chosen training images and their labels.

        Indices are drawn independently with `np.random.randint`, so the same
        sample can appear more than once in a batch.
        """
        sample_count = self.train_data.shape[0]
        picks = np.random.randint(0, sample_count, batch_size)
        return self.train_data[picks, :], self.train_label[picks]

In [3]:
# Training hyperparameters; later cells read these names.
num_epochs = 1
batch_size = 100
learning_rate = 0.001

# Multilayer perceptron assembled layer by layer: flatten the input, one
# 100-unit ReLU hidden layer, a 10-unit output layer, then softmax so the
# model emits class probabilities.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(units=100, activation=tf.nn.relu))
model.add(tf.keras.layers.Dense(units=10))
model.add(tf.keras.layers.Softmax())

# Constructing the loader fetches MNIST and prints the train-data shapes.
data_loader = MNISTLoader()
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

(60000, 28, 28)
(60000, 28, 28, 1)


In [4]:
# Total number of optimization steps: full batches per epoch times the number of epochs.
num_batches = int(data_loader.num_train_data // batch_size * num_epochs)
for batch_index in range(num_batches):
    # Draw a random mini-batch of images and integer labels.
    X, y = data_loader.get_batch(batch_size)
    # Record only the forward pass and the loss on the tape; logging and the
    # optimizer step happen after the context closes.
    with tf.GradientTape() as tape:
        y_pred = model(X)
        loss = tf.keras.losses.sparse_categorical_crossentropy(y_true=y, y_pred=y_pred)
        loss = tf.reduce_mean(loss)
    print("batch %d: loss %f" % (batch_index, loss.numpy()))
    # Differentiate with respect to the trainable weights only: a non-trainable
    # variable would produce a None gradient that the optimizer cannot apply.
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(grads_and_vars=zip(grads, model.trainable_variables))

batch 0: loss 2.388603
batch 1: loss 2.293674
batch 2: loss 2.186231
batch 3: loss 2.086780
batch 4: loss 2.045081
batch 5: loss 1.998972
batch 6: loss 1.872199
batch 7: loss 1.832567
batch 8: loss 1.783183
batch 9: loss 1.751624
batch 10: loss 1.685742
batch 11: loss 1.533611
batch 12: loss 1.443546
batch 13: loss 1.487695
batch 14: loss 1.356239
batch 15: loss 1.409653
batch 16: loss 1.215031
batch 17: loss 1.214518
batch 18: loss 1.089001
batch 19: loss 1.206711
batch 20: loss 1.163121
batch 21: loss 0.981417
batch 22: loss 1.039429
batch 23: loss 1.138160
batch 24: loss 0.872527
batch 25: loss 0.863287
batch 26: loss 0.899236
batch 27: loss 0.783926
batch 28: loss 0.875778
batch 29: loss 0.787400
batch 30: loss 0.737174
batch 31: loss 0.700911
batch 32: loss 0.807517
batch 33: loss 0.714060
batch 34: loss 0.688951
batch 35: loss 0.769232
batch 36: loss 0.544162
batch 37: loss 0.709931
batch 38: loss 0.692744
batch 39: loss 0.493657
batch 40: loss 0.650259
batch 41: loss 0.660416
ba

batch 357: loss 0.265880
batch 358: loss 0.228265
batch 359: loss 0.348219
batch 360: loss 0.176772
batch 361: loss 0.242972
batch 362: loss 0.268927
batch 363: loss 0.217994
batch 364: loss 0.376223
batch 365: loss 0.203373
batch 366: loss 0.271851
batch 367: loss 0.169179
batch 368: loss 0.436952
batch 369: loss 0.354993
batch 370: loss 0.495745
batch 371: loss 0.203798
batch 372: loss 0.171820
batch 373: loss 0.192341
batch 374: loss 0.369319
batch 375: loss 0.151433
batch 376: loss 0.113618
batch 377: loss 0.187428
batch 378: loss 0.260494
batch 379: loss 0.216251
batch 380: loss 0.143532
batch 381: loss 0.290597
batch 382: loss 0.182007
batch 383: loss 0.230474
batch 384: loss 0.438971
batch 385: loss 0.264026
batch 386: loss 0.166633
batch 387: loss 0.204959
batch 388: loss 0.215557
batch 389: loss 0.168520
batch 390: loss 0.173369
batch 391: loss 0.201375
batch 392: loss 0.167552
batch 393: loss 0.175329
batch 394: loss 0.209475
batch 395: loss 0.333485
batch 396: loss 0.182565


In [5]:
# Streaming accuracy metric: accumulates correct/total counts across batches.
sparse_categorical_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
# Ceiling division so a trailing partial batch is still evaluated when the
# test-set size is not a multiple of batch_size.
num_batches = (data_loader.num_test_data + batch_size - 1) // batch_size
for batch_index in range(num_batches):
    # Slicing past the end of the array simply yields a shorter final batch.
    start_index, end_index = batch_index * batch_size, (batch_index + 1) * batch_size
    # Call the model directly in inference mode; `model.predict` is intended
    # for large inputs and adds per-call overhead when used inside a loop.
    y_pred = model(data_loader.test_data[start_index: end_index], training=False)
    sparse_categorical_accuracy.update_state(y_true=data_loader.test_label[start_index: end_index], y_pred=y_pred)
print("test accuracy: %f" % sparse_categorical_accuracy.result())

test accuracy: 0.945500
