In [1]:
# TensorFlow 2.x removed the `tensorflow.examples` module, so the dataset-loading code below was rewritten to use `tf.keras.datasets.mnist` instead.
import tensorflow as tf
import numpy as np

# Load the MNIST digits through the Keras datasets API.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()


def _prepare_images(images):
    """Reshape to (N, 28, 28, 1), cast to float32 and scale pixel values by 1/255."""
    return images.reshape(-1, 28, 28, 1).astype('float32') / 255.0


# Preprocessing: channel-last float images and one-hot labels over 10 classes.
x_train, x_test = _prepare_images(x_train), _prepare_images(x_test)
y_train, y_test = (tf.keras.utils.to_categorical(labels, 10) for labels in (y_train, y_test))

# Hyperparameters.
learning_rate = 1e-4   # Adam step size
keep_prob_rate = 0.7   # probability of KEEPING a unit; the Dropout layer uses 1 - this
max_epoch = 700        # number of training-loop iterations (one mini-batch each, not full epochs)

# 定义计算准确率的函数
def compute_accuracy(model, x, y):
    y_pre = model.predict(x, verbose=0)
    correct_prediction = np.equal(np.argmax(y_pre, 1), np.argmax(y, 1))
    accuracy = np.mean(correct_prediction.astype('float32'))
    return accuracy

# 定义模型
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32, (7, 7), activation='relu', input_shape=(28, 28, 1), padding='same'),
        tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
        tf.keras.layers.Conv2D(64, (5, 5), activation='relu', padding='same'),
        tf.keras.layers.MaxPooling2D((2, 2), padding='same'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(1024, activation='relu'),
        tf.keras.layers.Dropout(1 - keep_prob_rate),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model

# 创建并训练模型
model = create_model()

for i in range(max_epoch):
    # 随机选择100个样本进行训练
    indices = np.random.choice(len(x_train), 100, replace=False)
    batch_xs = x_train[indices]
    batch_ys = y_train[indices]
    
    model.train_on_batch(batch_xs, batch_ys)
    
    if i % 100 == 0:
        acc = compute_accuracy(model, x_test[:1000], y_test[:1000])
        print(f"Epoch: {i}, Accuracy: {acc:.4f}")

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch: 0, Accuracy: 0.1630
Epoch: 100, Accuracy: 0.8730
Epoch: 200, Accuracy: 0.9170
Epoch: 300, Accuracy: 0.9400
Epoch: 400, Accuracy: 0.9510
Epoch: 500, Accuracy: 0.9590
Epoch: 600, Accuracy: 0.9650
