In [5]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import torch.nn as nn
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix


transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                     download=True, transform=transform)

train_data, val_data = train_test_split(trainset, test_size=0.5, random_state=42)  # 修正划分比例为更合理的0.2作为验证集

train_loader = torch.utils.data.DataLoader(train_data, batch_size=64,
                                           shuffle=True, )
val_loader = torch.utils.data.DataLoader(val_data, batch_size=64,
                                         shuffle=False, )
test_loader = torch.utils.data.DataLoader(testset, batch_size=64,
                                          shuffle=False, )

model = resnet18(pretrained=True)
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(model, train_loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    return epoch_loss, epoch_acc


def validate(model, val_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(val_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    val_loss = running_loss / len(val_loader)
    val_acc = 100. * correct / total
    precision = precision_score(all_targets, all_preds, average='macro')
    recall = recall_score(all_targets, all_preds, average='macro')
    f1 = f1_score(all_targets, all_preds, average='macro')
    conf_matrix = confusion_matrix(all_targets, all_preds)
    return val_loss, val_acc, precision, recall, f1, conf_matrix


def est(model, test_loader, device):  # 新增测试函数
    model.eval()
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    test_acc = 100. * correct / total
    precision = precision_score(all_targets, all_preds, average='macro')
    recall = recall_score(all_targets, all_preds, average='macro')
    f1 = f1_score(all_targets, all_preds, average='macro')
    conf_matrix = confusion_matrix(all_targets, all_preds)
    return test_acc, precision, recall, f1, conf_matrix


epochs = 5
for epoch in range(epochs):
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc, precision, recall, f1, conf_matrix = validate(model, val_loader, criterion, device)
    print(f'Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
          f'Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}')
    print("Confusion Matrix:")
    print(conf_matrix)

test_acc, precision, recall, f1, conf_matrix = est(model, test_loader, device)  # 在训练和验证循环结束后调用测试函数
print(f"Test Accuracy: {test_acc:.2f}%")
print(f"Test Precision: {precision:.4f}")
print(f"Test Recall: {recall:.4f}")
print(f"Test F1-score: {f1:.4f}")
print("Test Confusion Matrix:")
print(conf_matrix)



Epoch 1: Train Loss: 0.2568, Train Acc: 92.45%, Val Loss: 0.1030, Val Acc: 97.23%, Precision: 0.9731, Recall: 0.9718, F1-score: 0.9720
Confusion Matrix:
[[2949    0    0    0    4    0    1    1    1    0]
 [   1 3289   22    0   11    0    3    3    0    0]
 [  14    3 2860    2    9    0    3   25    3    1]
 [   0   10   41 2979    0    6    0   28    6    7]
 [   2    6    1    0 2903    0    9    3    3    3]
 [  15    4    0   16    2 2679   26    4   10   11]
 [  16    0    1    0    6    1 2992    0    6    0]
 [   1    5   16    1    5    0    0 3129    0    1]
 [  24   57   64   39    8    9    8   47 2592    8]
 [  11    4    0    3   89    3    0   73    5 2797]]
Epoch 2: Train Loss: 0.0953, Train Acc: 97.35%, Val Loss: 0.0866, Val Acc: 97.72%, Precision: 0.9771, Recall: 0.9773, F1-score: 0.9771
Confusion Matrix:
[[2908    0    5    2    1   26    7    0    2    5]
 [   0 3265    3    6    0    1    4   10   39    1]
 [   3   14 2884    4    0    0    0    3   11    1]
 [  