In [1]:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models

# Load the MNIST dataset.
mnist = datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Scale pixel values by 1/255 as float32. Dividing the integer image arrays
# directly by 255.0 would promote them to float64, doubling memory use for
# data that Keras (by default) computes on in float32 anyway.
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

# Append a trailing channel axis so each image is 28x28x1.
x_train = tf.reshape(x_train, [-1, 28, 28, 1])
x_test = tf.reshape(x_test, [-1, 28, 28, 1])

# One-hot encode the labels over 10 classes.
y_train = tf.one_hot(y_train, depth=10)
y_test = tf.one_hot(y_test, depth=10)

# CNN model definition.
class CNNModel(tf.keras.Model):
    """Small convolutional classifier: two conv/pool stages and a dense head.

    Layer stack, in call order:
      conv 7x7 (32 filters, ReLU) -> 2x2 max-pool (stride 2)
      conv 5x5 (64 filters, ReLU) -> 2x2 max-pool (stride 2)
      flatten -> dense 1024 (ReLU) -> dropout 0.3 -> dense 10 (softmax)

    All conv and pool layers use ``padding='same'``.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = layers.Conv2D(32, (7, 7), activation='relu', padding='same')
        self.pool1 = layers.MaxPooling2D((2, 2), strides=2, padding='same')

        self.conv2 = layers.Conv2D(64, (5, 5), activation='relu', padding='same')
        self.pool2 = layers.MaxPooling2D((2, 2), strides=2, padding='same')

        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(1024, activation='relu')
        self.dropout = layers.Dropout(0.3)  # corresponds to keep_prob_rate=0.7
        self.fc2 = layers.Dense(10, activation='softmax')

    def call(self, x, training=None):
        """Run the forward pass and return per-class softmax probabilities.

        Args:
            x: Batch of input images.
            training: Whether the call is in training mode. Forwarded to the
                dropout layer so it is only active while training. ``None``
                (the default) lets Keras decide from the calling context.

        Returns:
            The output of the final 10-unit softmax layer.
        """
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        # Pass the mode explicitly rather than relying on implicit propagation.
        x = self.dropout(x, training=training)
        return self.fc2(x)

# Instantiate the model, build it for 28x28 single-channel input of any batch
# size, and configure the optimizer, loss and metric.
model = CNNModel()
model.build((None, 28, 28, 1))
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Compute test accuracy (output kept to at most three decimal places).
def compute_accuracy():
    """Return the model's accuracy on the test split, rounded to 3 decimals.

    Reads the module-level ``model``, ``x_test`` and ``y_test``.
    """
    results = model.evaluate(x_test, y_test, verbose=0)
    accuracy = results[1]  # evaluate() gives the loss first, then the accuracy
    return round(accuracy, 3)

# Build the input pipeline: slice the tensors per example, shuffle with a
# 60000-element buffer, then group into batches.
batch_size = 100
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(60000)
train_dataset = train_dataset.batch(batch_size)

# Training. Note: despite its name, ``max_epoch`` is the number of single-batch
# updates performed, not the number of full passes over the data.
max_epoch = 2000
train_iter = iter(train_dataset.repeat())

# zip() stops as soon as the range is exhausted, so exactly ``max_epoch``
# batches are drawn from the endlessly repeating iterator.
for i, (batch_xs, batch_ys) in zip(range(max_epoch), train_iter):
    model.train_on_batch(batch_xs, batch_ys)

    # Report test accuracy on every 100th step, starting with step 0.
    if i % 100:
        continue
    print(compute_accuracy())


0.184
0.895
0.932
0.946
0.956
0.965
0.968
0.972
0.976
0.974
0.978
0.979
0.982
0.981
0.982
0.984
0.982
0.985
0.987
0.983
