In [None]:
import tensorflow as tf

# Number of examples per batch, shared by the training and validation pipelines.
BATCH_SIZE = 64
# Number of elements tf.data buffers when shuffling the training set.
SHUFFLE_BUFFER_SIZE = 1024

# Load the MNIST train/test splits (images and integer labels) as NumPy arrays.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Scale pixel values into [0, 1]. Cast to float32 first: dividing the raw integer
# images by 255.0 would yield float64 arrays, which double the memory footprint and
# force the float32 Keras layers to cast every batch.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

# Training pipeline: shuffle, then group into batches of BATCH_SIZE.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(buffer_size=SHUFFLE_BUFFER_SIZE)
    .batch(BATCH_SIZE)
)

# Validation pipeline: batched only; the original order is kept.
valid_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(BATCH_SIZE)
)

'''
2. 自定义模型
    由于 MNIST 图像分类的任务比较简单，因此我们可以定义一个较为简单的模型，这里的模型结构包含四层：
    -Flatten 层：对二维数据进行展开；
    -第一个 Dense 层：包含 128 个神经元；
    -第二个 Dense 层：包含 64 个神经元；
    -最后一个 Dense 分类层；包含 10 个神经元，对应于我们的十个分类。
'''
class MyModel(tf.keras.Model):
    """Small fully connected classifier.

    Layer stack: Flatten -> Dense(128, relu) -> Dense(64, relu) -> Dense(10, softmax).
    The final layer applies softmax, so the model outputs 10 values per example.
    """

    def __init__(self):
        super().__init__()
        self.l1 = tf.keras.layers.Flatten()
        self.l2 = tf.keras.layers.Dense(128, activation='relu')
        self.l3 = tf.keras.layers.Dense(64, activation='relu')
        self.l4 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs, training=True):
        """Run the forward pass.

        Args:
            inputs: Batch of inputs; flattened before reaching the dense layers.
            training: Accepted as part of the call signature but not used by
                this method.

        Returns:
            The output of the final 10-unit softmax layer.
        """
        flattened = self.l1(inputs)
        hidden = self.l3(self.l2(flattened))
        return self.l4(hidden)


# Single model instance shared by the training and validation code.
model = MyModel()


In [None]:
# Loss: sparse categorical cross-entropy. from_logits=False is the default and is
# spelled out here because the model's last layer already applies softmax.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

# Optimizer: Adam with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

# Metric: tracks accuracy on the validation set.
val_acc = tf.keras.metrics.SparseCategoricalAccuracy()

In [None]:
epochs = 3
for epoch in range(epochs):
    print(f"Start Training epoch {epoch}")

    # Training phase: one optimizer step per batch.
    for batch_i, (x_batch_train, y_batch_train) in enumerate(train_dataset):
        # Run the forward pass inside the tape so the loss can be differentiated.
        with tf.GradientTape() as tape:
            outputs = model(x_batch_train, training=True)
            loss_value = loss_fn(y_batch_train, outputs)

        # Gradients of the loss with respect to every trainable weight.
        grads = tape.gradient(loss_value, model.trainable_weights)
        # Apply the update through the optimizer.
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Log the current batch loss every 100 batches.
        if batch_i % 100 == 0:
            print(f"Loss at batch {batch_i}: {float(loss_value):.4f}")

    # Validation phase: accumulate accuracy over the whole validation set.
    for x_batch_val, y_batch_val in valid_dataset:
        val_outputs = model(x_batch_val, training=False)
        val_acc.update_state(y_batch_val, val_outputs)
    print(f"Validation acc: {float(val_acc.result()):.4f}")

    # Clear the metric so the next epoch starts from zero.
    # reset_state() replaces the deprecated reset_states(), which is no longer
    # available on metrics in current Keras releases.
    val_acc.reset_state()
