# Coding: LeNet for MNIST

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader

# Seed PyTorch's global random number generator up front so that model weight
# initialization (and other draws from the default generator) repeat across runs.
SEED = 42
torch.manual_seed(SEED)

# Define the LeNet architecture
class LeNet(nn.Module):
    """LeNet-style CNN for single-channel 28x28 images with 10 output classes.

    Layer-by-layer shapes (channels x height x width):
        conv1 5x5, 1->6 channels, ReLU  : 1x28x28 -> 6x24x24
        max-pool 2x2, stride 2          : 6x24x24 -> 6x12x12
        conv2 5x5, 6->16 channels, ReLU : 6x12x12 -> 16x8x8
        max-pool 2x2, stride 2          : 16x8x8  -> 16x4x4
        flatten                         : 16x4x4  -> 256
        fc1 + ReLU                      : 256 -> 120
        fc2 + ReLU                      : 120 -> 84
        fc3                             : 84  -> 10 (raw scores, no softmax)
    """

    def __init__(self):
        super().__init__()
        # Keep the parameterized layers created in this order and under these
        # names: the order fixes how the seeded RNG is consumed during weight
        # initialization, and the names become the keys of state_dict().
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # 16 channels x 4 x 4 spatial positions left after the second pooling.
        self.fc1 = nn.Linear(in_features=16 * 4 * 4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        # Parameter-free modules reused by both conv blocks and the MLP head.
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Compute class scores for a batch of images.

        Args:
            x: image batch of shape (N, 1, H, W). The fully connected head
                expects a 16x4x4 feature map after the second pooling stage,
                which 28x28 inputs produce.

        Returns:
            Tensor of shape (N, 10) holding un-normalized class scores.
        """
        # Single-channel 1 -> 6: first conv block.
        # 1x28x28 -> 6x24x24
        x = self.relu(self.conv1(x))
        # 6x24x24 -> 6x12x12
        x = self.maxpool(x)

        # Multi-channel 6 -> 16: second conv block.
        # 6x12x12 -> 16x8x8
        x = self.relu(self.conv2(x))
        # 16x8x8 -> 16x4x4
        x = self.maxpool(x)

        # MLP head. Flatten every dimension except the batch one. Unlike
        # x.view(x.size(0), -1), torch.flatten also handles an empty batch
        # (N == 0, where the -1 is ambiguous) and non-contiguous tensors.
        x = torch.flatten(x, start_dim=1)
        # 256 -> 120
        x = self.relu(self.fc1(x))
        # 120 -> 84
        x = self.relu(self.fc2(x))
        # 84 -> 10
        return self.fc3(x)

# Data loading and preprocessing
def load_data(batch_size=64):
    """Build DataLoaders for the MNIST training and test splits.

    Both splits are stored under ``./data`` (downloaded if missing) and share
    the same preprocessing. Only the training loader shuffles its samples.

    Args:
        batch_size: number of samples per batch for both loaders.

    Returns:
        ``(train_loader, test_loader)``.
    """
    # ToTensor converts a PIL image / numpy array into a CHW float tensor and
    # scales pixel values from [0, 255] to [0, 1]. Normalize then subtracts
    # 0.1307 and divides by 0.3081, the MNIST pixel mean/std, so inputs come out
    # roughly zero-mean with unit standard deviation.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Create both datasets first (training split, then test split)...
    datasets_by_split = {
        is_train: torchvision.datasets.MNIST(
            root='./data',
            train=is_train,
            download=True,
            transform=preprocess,
        )
        for is_train in (True, False)
    }

    # ...then one loader per split; shuffling is enabled only for training.
    train_loader, test_loader = (
        DataLoader(
            datasets_by_split[is_train],
            batch_size=batch_size,
            shuffle=is_train,
        )
        for is_train in (True, False)
    )
    return train_loader, test_loader

# Training
def train(model, train_loader, criterion, optimizer, device):
    """Run one training epoch over ``train_loader``.

    Args:
        model: network to optimize; switched to training mode here.
        train_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
            It is assumed to return the *mean* loss over the batch (the
            ``nn.CrossEntropyLoss`` default); that mean is re-weighted by
            batch size below.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the loss averaged over every training sample,
        and the percentage of samples whose arg-max prediction matched the
        label. Both are measured on each batch before its parameter update.

    Raises:
        ZeroDivisionError: if ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        # Gradients accumulate across backward() calls, so clear them each batch.
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Weight each batch's mean loss by its size so a short final batch is not
        # over-counted: the result is a true per-sample average, consistent with
        # how accuracy is computed.
        loss_sum += loss.item() * batch_size
        correct += outputs.argmax(dim=1).eq(labels).sum().item()
        total += batch_size

    avg_loss = loss_sum / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

# Evaluation
def evaluate(model, test_loader, criterion, device):
    """Measure loss and accuracy of ``model`` on ``test_loader``.

    The model is put in eval mode and gradient tracking is disabled; no
    optimizer step is taken.

    Args:
        model: network to evaluate.
        test_loader: iterable of ``(images, labels)`` batches.
        criterion: loss function called as ``criterion(outputs, labels)``.
            It is assumed to return the *mean* loss over the batch (the
            ``nn.CrossEntropyLoss`` default); that mean is re-weighted by
            batch size below.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracy)``: the loss averaged over every sample, and the
        percentage of samples whose arg-max prediction matched the label.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight by batch size so a short final batch is not over-counted;
            # the result is a true per-sample average, like the accuracy.
            loss_sum += loss.item() * batch_size
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            total += batch_size

    avg_loss = loss_sum / total
    accuracy = 100. * correct / total
    return avg_loss, accuracy

def main():
    """Train LeNet on MNIST, print per-epoch metrics, and save the weights.

    Uses CUDA when available, otherwise the CPU. After each epoch it prints the
    training loss/accuracy from ``train`` and the test loss/accuracy from
    ``evaluate``. The final ``state_dict`` is written to ``lenet_mnist.pth``
    in the current working directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    # Hyperparameters.
    batch_size, learning_rate, num_epochs = 64, 0.001, 10

    train_loader, test_loader = load_data(batch_size)

    # Model, loss function and optimizer.
    model = LeNet().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    print('Starting training...')
    separator = '-' * 50
    for epoch_no in range(1, num_epochs + 1):
        train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
        test_loss, test_acc = evaluate(model, test_loader, criterion, device)

        print(f'Epoch [{epoch_no}/{num_epochs}]:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
        print(separator)

    # Persist only the learned parameters, not the whole module object.
    torch.save(model.state_dict(), 'lenet_mnist.pth')
    print('Training completed and model saved!')


if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()


Using device: cpu


100.0%
100.0%
100.0%
100.0%


Starting training...
Epoch [1/10]:
Train Loss: 0.2401, Train Acc: 92.56%
Test Loss: 0.0754, Test Acc: 97.57%
--------------------------------------------------
Epoch [2/10]:
Train Loss: 0.0700, Train Acc: 97.76%
Test Loss: 0.0465, Test Acc: 98.47%
--------------------------------------------------
Epoch [3/10]:
Train Loss: 0.0513, Train Acc: 98.40%
Test Loss: 0.0470, Test Acc: 98.57%
--------------------------------------------------
Epoch [4/10]:
Train Loss: 0.0385, Train Acc: 98.78%
Test Loss: 0.0395, Test Acc: 98.72%
--------------------------------------------------
Epoch [5/10]:
Train Loss: 0.0334, Train Acc: 98.94%
Test Loss: 0.0405, Test Acc: 98.69%
--------------------------------------------------
Epoch [6/10]:
Train Loss: 0.0288, Train Acc: 99.08%
Test Loss: 0.0357, Test Acc: 98.86%
--------------------------------------------------
Epoch [7/10]:
Train Loss: 0.0229, Train Acc: 99.26%
Test Loss: 0.0374, Test Acc: 99.06%
--------------------------------------------------
Epoch 