In [7]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Dataset / architecture constants
num_classes = 10
input_shape = (28, 28, 1)

# Fetch MNIST, which Keras returns already split into train and test tuples
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Convert to float32, rescale by 1/255 (targets the [0, 1] range), and append a
# trailing channel axis so each image has shape (28, 28, 1)
x_train = (x_train.astype("float32") / 255)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255)[..., np.newaxis]
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Turn the integer class vectors into binary (one-hot) class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

# Assemble the classifier layer by layer: two Conv2D + MaxPooling2D stages,
# then Flatten, Dropout, and a softmax Dense layer with one unit per class.
model = keras.Sequential()
model.add(keras.Input(shape=input_shape))
model.add(layers.Conv2D(filters=32, kernel_size=3, activation="relu"))
model.add(layers.MaxPooling2D(pool_size=2))
model.add(layers.Conv2D(filters=64, kernel_size=3, activation="relu"))
model.add(layers.MaxPooling2D(pool_size=2))
model.add(layers.Flatten())
model.add(layers.Dropout(rate=0.5))
model.add(layers.Dense(units=num_classes, activation="softmax"))

# Print the layer-by-layer architecture table
model.summary()

# Custom training step: move the weights along the unit-norm mean-gradient direction
@tf.function
def train_step(model, x_batch, y_batch, learning_rate=0.001):
    """Apply one normalized-gradient update to ``model`` in place.

    The update direction ``d`` is the mean over the batch of the per-sample
    gradients of the categorical cross-entropy loss, rescaled to unit L2 norm
    taken jointly over all trainable variables. Each variable is then moved
    by ``learning_rate * d``.

    The mean of per-sample gradients equals the gradient of the mean
    per-sample loss, so the batch goes through a single forward/backward pass
    rather than one pass per sample.

    Args:
        model: Callable model exposing ``trainable_variables``; it is called
            with ``training=True``.
        x_batch: Batch of inputs, with samples along axis 0.
        y_batch: Targets aligned with ``x_batch``, in the format expected by
            ``categorical_crossentropy``.
        learning_rate: Step length along the unit-norm direction.

    Returns:
        None. The model's variables are updated as a side effect. The call is
        a no-op when the batch is statically empty, when the model has no
        trainable variables, or when no variable receives a gradient.
    """
    variables = model.trainable_variables

    # Skip the step when there is nothing to learn from or nothing to update.
    # The static shape is an int or None, so this is a plain Python check.
    if not variables or x_batch.shape[0] == 0:
        return

    with tf.GradientTape() as tape:
        # Forward pass over the whole batch
        predictions = model(x_batch, training=True)
        per_sample_loss = tf.keras.losses.categorical_crossentropy(y_batch, predictions)
        # Differentiating the mean loss yields the mean per-sample gradient
        mean_loss = tf.reduce_mean(per_sample_loss)

    grads = tape.gradient(mean_loss, variables)

    # Keep every gradient paired with its own variable. Dropping None entries
    # from a flat list and then slicing by position would shift later slices
    # onto the wrong variables.
    grads_and_vars = [
        (tf.convert_to_tensor(grad), var)
        for grad, var in zip(grads, variables)
        if grad is not None
    ]
    if not grads_and_vars:
        return  # No usable gradient: skip this batch

    # Joint L2 norm over all gradients, i.e. the norm of the concatenated
    # flattened gradient vector
    norm = tf.linalg.global_norm([grad for grad, _ in grads_and_vars])

    # Step each variable along its slice of the normalized direction d
    for grad, var in grads_and_vars:
        # divide_no_nan returns 0 when norm == 0, so an all-zero gradient
        # leaves the weights untouched instead of filling them with NaN
        direction = tf.math.divide_no_nan(grad, norm)
        var.assign_sub(learning_rate * direction)

# Custom training loop
batch_size = 128
epochs = 15

# Number of full batches per epoch. The loop below also processes the trailing
# partial batch, so it may run one more step than this.
num_batches = x_train.shape[0] // batch_size

# Compile once, before training. In this cell compilation only serves
# model.evaluate() (loss + accuracy metric); the weight updates are applied by
# train_step() directly, so re-compiling every epoch is redundant work.
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

# Training loop
for epoch in range(epochs):
    print(f"Epoch {epoch + 1}/{epochs}")
    # Walk the training set in consecutive mini-batches (the last may be smaller)
    for batch_start in range(0, x_train.shape[0], batch_size):
        batch_end = batch_start + batch_size
        x_batch = x_train[batch_start:batch_end]
        y_batch = y_train[batch_start:batch_end]
        train_step(model, x_batch, y_batch)

    # Evaluate at the end of every epoch.
    # NOTE: these "validation" figures are computed on the test split
    # (x_test, y_test); no separate validation set is held out here.
    val_loss, val_acc = model.evaluate(x_test, y_test, verbose=0)
    print(f"Validation loss: {val_loss:.4f}, Validation accuracy: {val_acc:.4f}")

# Final evaluation of the trained model on the test data
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)
print(f"Test loss: {test_loss:.4f}, Test accuracy: {test_acc:.4f}")



x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples
Model: "sequential_5"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_10 (Conv2D)          (None, 26, 26, 32)        320       
                                                                 
 max_pooling2d_10 (MaxPooli  (None, 13, 13, 32)        0         
 ng2D)                                                           
                                                                 
 conv2d_11 (Conv2D)          (None, 11, 11, 64)        18496     
                                                                 
 max_pooling2d_11 (MaxPooli  (None, 5, 5, 64)          0         
 ng2D)                                                           
                                                                 
 flatten_5 (Flatten)         (None, 1600)              0         
                                               