In [1]:
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt


# Pick the compute device once for the whole notebook: the GPU when CUDA is
# available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class CNN(torch.nn.Module):
    """Three-stage convolutional digit classifier for single-channel images.

    Each stage is Conv -> BatchNorm -> ReLU -> 2x2 MaxPool. For 28x28 inputs
    (MNIST) the spatial size goes 28 -> 14 -> 7 -> 3 (the last pool floors),
    leaving a 64 x 3 x 3 feature map. That map is flattened and fed through a
    256-unit hidden layer (ReLU + dropout) and a 10-way output layer.

    ``forward`` returns raw logits (no softmax), suitable for
    ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super(CNN, self).__init__()
        # Images are grayscale, so the first convolution takes one input channel.
        # A 5x5 kernel with padding=2 preserves the spatial size.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16,
                               kernel_size=5, stride=1, padding=2)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Third convolution: 3x3 kernel with padding=1 also preserves spatial size.
        self.conv3 = nn.Conv2d(in_channels=32, out_channels=64,
                               kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # 64 channels x 3 x 3 spatial positions after the three pools (28x28 input).
        self.fc1 = nn.Linear(in_features=64 * 3 * 3, out_features=256)
        # Without a nonlinearity here, fc1 and fc2 compose into a single linear
        # map and the hidden layer adds no capacity. ReLU has no parameters, so
        # previously saved state_dicts still load unchanged.
        self.relu4 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(in_features=256, out_features=10)  # one logit per digit 0-9

    def forward(self, x):
        """Compute class logits.

        Args:
            x: image batch of shape (batch, 1, H, W), where H and W must shrink
               to 3 after three 2x2 pools (e.g. 28x28 MNIST images). Other
               sizes make the flattened features mismatch ``fc1`` and raise.

        Returns:
            Tensor of shape (batch, 10) with unnormalized logits.
        """
        # Stage 1: convolution, batch norm, activation, pooling
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        # Stage 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        # Stage 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)

        # Flatten everything except the batch dimension. Unlike view(-1, 576),
        # this never silently merges features from different samples into one row.
        x = torch.flatten(x, 1)

        # Hidden fully connected layer with activation and dropout
        x = self.fc1(x)
        x = self.relu4(x)
        x = self.dropout(x)

        # Output layer
        x = self.fc2(x)
        return x


def get_data_loader(is_train, batch_size=15, shuffle=None):
    """Build a DataLoader over the MNIST training or test split.

    Images are converted to tensors with ``ToTensor``. The dataset is
    downloaded into ``mnist_data/`` if it is not already present.

    Args:
        is_train: True for the training split, False for the test split.
        batch_size: samples per batch (default 15, the previous hard-coded value).
        shuffle: whether to reshuffle each epoch. Defaults to ``is_train``:
            evaluation gains nothing from shuffling, and a fixed order makes
            test-set evaluation and plotted samples reproducible.

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches.
    """
    if shuffle is None:
        shuffle = is_train
    transform = transforms.Compose([transforms.ToTensor()])
    data_set = MNIST(root='mnist_data/', train=is_train, download=True, transform=transform)
    # Pinned host memory only speeds up host->GPU copies. Requesting it on a
    # CPU-only machine just emits a warning, so enable it only when CUDA exists.
    return DataLoader(data_set, batch_size=batch_size, shuffle=shuffle,
                      pin_memory=torch.cuda.is_available())

def evaluate(test_data, net, device=None):
    """Return the classification accuracy of ``net`` over ``test_data``.

    Switches ``net`` to eval mode and leaves it there. Callers that continue
    training must call ``net.train()`` again.

    Args:
        test_data: iterable of ``(inputs, labels)`` batches, e.g. a DataLoader.
        net: module mapping a batch of inputs to per-class scores of shape
            (batch, num_classes).
        device: where to move each batch. Defaults to the device of ``net``'s
            first parameter (CPU if it has none), so this function no longer
            relies on a ``device`` global defined in another notebook cell.

    Returns:
        Fraction of correctly classified samples, a float in [0, 1].

    Raises:
        ValueError: if ``test_data`` yields no samples.
    """
    if device is None:
        first_param = next(net.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in test_data:
            x, y = x.to(device), y.to(device)
            predicted = net(x).argmax(dim=1)  # index of the highest score = predicted class
            total += y.size(0)
            correct += (predicted == y).sum().item()
    if total == 0:
        raise ValueError("evaluate() received no samples")
    return correct / total

def _train_one_epoch(net, train_data, criterion, optimizer):
    """Run one optimization pass of ``net`` over ``train_data`` on ``device``."""
    # Re-enable training behaviour (dropout, batch-statistics BatchNorm);
    # evaluate() leaves the model in eval mode.
    net.train()
    for x, y in train_data:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(net(x), y)
        loss.backward()
        optimizer.step()


def _show_predictions(net, data):
    """Plot the first image of each of the first four batches with its predicted digit."""
    net.eval()
    fig, axes = plt.subplots(2, 2, figsize=(10, 5))
    with torch.no_grad():
        # zip stops once the four axes are used, so no extra batch is fetched.
        for ax, (x, _) in zip(axes.flat, data):
            x = x.to(device)
            prediction = net(x).argmax(dim=1)[0]  # prediction for the first sample only
            ax.imshow(x[0].view(28, 28).cpu().numpy(), cmap='gray')
            ax.set_title(f"Prediction: {prediction.item()}")
            ax.axis('off')
    plt.show()


def main(epochs=10, lr=0.001):
    """Train the CNN on MNIST, report test accuracy each epoch, then plot sample predictions.

    Args:
        epochs: number of passes over the training set (default 10).
        lr: Adam learning rate (default 0.001).
    """
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = CNN().to(device)
    # CrossEntropyLoss takes raw logits, which is what CNN.forward returns.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    # Accuracy of the untrained network: a baseline to compare training against.
    print("Initial accuracy:", evaluate(test_data, net))

    for epoch in range(epochs):
        _train_one_epoch(net, train_data, criterion, optimizer)
        print(f"Epoch: {epoch}, Accuracy: {evaluate(test_data, net)}")

    _show_predictions(net, test_data)

# Run the full train / evaluate / plot pipeline. Inside a Jupyter kernel
# __name__ is "__main__", so executing this cell starts training immediately.
if __name__ == '__main__':
    main()


Initial accuracy: 0.1245
Epoch: 0, Accuracy: 0.9793
Epoch: 1, Accuracy: 0.9858
Epoch: 2, Accuracy: 0.9887
Epoch: 3, Accuracy: 0.9906
