In [21]:
using Flux
using MLDatasets
using Statistics
using Random
using Optimisers

# 设置随机种子以确保结果可重复
Random.seed!(42)

# 1. 加载MNIST数据集
println("加载MNIST数据集...")
train_data = MLDatasets.MNIST(:train)
test_data = MLDatasets.MNIST(:test)

# 获取训练和测试数据
X_train = reshape(train_data.features, :, 60000)
y_train = Flux.onehotbatch(train_data.targets, 0:9)
X_test = reshape(test_data.features, :, 10000)
y_test = Flux.onehotbatch(test_data.targets, 0:9)

# 2. 数据预处理
println("验证原始训练数据：")
println("X_train shape: ", size(X_train), ", type: ", eltype(X_train))
println("X_train min: ", minimum(X_train), ", max: ", maximum(X_train))
println("验证原始测试数据：")
println("X_test shape: ", size(X_test), ", type: ", eltype(X_test))
println("X_test min: ", minimum(X_test), ", max: ", maximum(X_test))

# 仅转换类型，不重复归一化
X_train = Float32.(X_train)
X_test = Float32.(X_test)

# 验证预处理后的数据
println("预处理后训练数据：")
println("X_train shape: ", size(X_train), ", type: ", eltype(X_train))
println("X_train min: ", minimum(X_train), ", max: ", maximum(X_train))
println("预处理后测试数据：")
println("X_test shape: ", size(X_test), ", type: ", eltype(X_test))
println("X_test min: ", minimum(X_test), ", max: ", maximum(X_test))

# 3. 构建简化的MLP模型（移除softmax）
model = Chain(
    Dense(784, 128, relu, init=Flux.glorot_uniform),
    Dense(128, 64, relu, init=Flux.glorot_uniform),
    Dense(64, 10, init=Flux.glorot_uniform)
)

# 4. 定义损失函数和优化器
loss(x, y) = Flux.crossentropy(softmax(model(x)), y)  # 显式softmax
opt = Optimisers.Adam(0.01)

# 设置优化器状态
state = Optimisers.setup(opt, model)

# 5. 训练模型
println("开始训练模型...")
num_epochs = 10
batch_size = 128
for epoch in 1:num_epochs
    # 打乱训练数据
    perm = randperm(60000)
    X_train = X_train[:, perm]
    y_train = y_train[:, perm]
    
    # 按小批量训练
    total_loss = 0.0
    num_batches = 0
    for i in 1:batch_size:60000
        x_batch = X_train[:, i:min(i+batch_size-1, 60000)]
        y_batch = y_train[:, i:min(i+batch_size-1, 60000)]
        
        # 调试：检查批次数据和模型输出
        if i == 1
            println("Batch 1: x_batch shape: ", size(x_batch), ", y_batch shape: ", size(y_batch))
            println("x_batch min: ", minimum(x_batch), ", max: ", maximum(x_batch))
            output = model(x_batch)
            println("Model output shape: ", size(output), ", min: ", minimum(output), ", max: ", maximum(output))
        end
        
        # 计算梯度和损失
        (l, grads) = Flux.withgradient(model) do m
            Flux.crossentropy(softmax(m(x_batch)), y_batch)
        end
        
        # 检查梯度和损失
        if grads[1] === nothing
            println("警告：梯度为Nothing，Batch $i")
            continue
        elseif isnan(l)
            println("警告：损失为NaN，Batch $i")
            continue
        end
        
        # 更新模型参数
        state, model = Optimisers.update(state, model, grads[1])
        
        total_loss += l
        num_batches += 1
    end
    
    # 计算训练集和测试集的平均损失
    train_loss = total_loss / num_batches
    test_loss = loss(X_test, y_test)
    println("Epoch $epoch: 训练损失 = $train_loss, 测试损失 = $test_loss")
end

# 6. 测试模型
println("测试模型...")
y_pred = softmax(model(X_test))
y_pred_labels = [argmax(y_pred[:, i]) - 1 for i in 1:10000]
y_true_labels = test_data.targets
accuracy = mean(y_pred_labels .== y_true_labels)
println("测试集准确率: $accuracy")

加载MNIST数据集...
验证原始训练数据：
X_train shape: (784, 60000), type: Float32
X_train min: 0.0, max: 1.0
验证原始测试数据：
X_test shape: (784, 10000), type: Float32
X_test min: 0.0, max: 1.0
预处理后训练数据：
X_train shape: (784, 60000), type: Float32
X_train min: 0.0, max: 1.0
预处理后测试数据：
X_test shape: (784, 10000), type: Float32
X_test min: 0.0, max: 1.0
开始训练模型...
Batch 1: x_batch shape: (784, 128), y_batch shape: (10, 128)
x_batch min: 0.0, max: 1.0
Model output shape: (10, 128), min: -1.038578, max: 1.2625971
Epoch 1: 训练损失 = 0.23097359501063697, 测试损失 = 0.12836263
Batch 1: x_batch shape: (784, 128), y_batch shape: (10, 128)
x_batch min: 0.0, max: 1.0
Model output shape: (10, 128), min: -21.276485, max: 20.874569
Epoch 2: 训练损失 = 0.11736340255641353, 测试损失 = 0.13810232
Batch 1: x_batch shape: (784, 128), y_batch shape: (10, 128)
x_batch min: 0.0, max: 1.0
Model output shape: (10, 128), min: -32.37522, max: 30.033476
Epoch 3: 训练损失 = 0.09900698495178875, 测试损失 = 0.11113643
Batch 1: x_batch shape: (784, 128), y_batch 