导入必要的库

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader


定义模型

In [None]:
class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier for small square images.

    Each stage is a 3x3 convolution (padding 1, so spatial size is kept),
    a ReLU and a 2x2 max-pool (which halves the spatial size). A single
    linear layer maps the flattened features to class logits.

    Args:
        num_classes: Number of output logits.
        in_channels: Channels of the input images (1 = grayscale, the default).
        input_size: Height/width of the square input images. The linear
            layer's width is derived from it, so it must match the data fed to
            ``forward``. Defaults to 28.
    """

    def __init__(self, num_classes=10, in_channels=1, input_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        # Two 2x2 pools with floor rounding: 28 -> 14 -> 7 for the default size.
        pooled = (input_size // 2) // 2
        self.fc = nn.Linear(64 * pooled * pooled, num_classes)

    def forward(self, x):
        """Map a ``(N, in_channels, input_size, input_size)`` batch to ``(N, num_classes)`` logits."""
        x = self.pool(self.relu(self.conv1(x)))
        x = self.pool(self.relu(self.conv2(x)))
        # Flatten per sample, keeping dim 0 as the batch. ``view(-1, 64*7*7)``
        # would silently re-split the batch when the spatial size is wrong;
        # with ``flatten(1)`` the linear layer raises a shape error instead.
        x = x.flatten(1)
        return self.fc(x)


数据加载和预处理

In [None]:
def get_dataloader(dataset_name, train=True, batch_size=64, root='./data'):
    """Build a DataLoader for one of the supported torchvision datasets.

    Images are resized to 28x28, converted to tensors and normalized from
    [0, 1] to [-1, 1]. RGB datasets (CIFAR10, SVHN) are first converted to
    a single grayscale channel, so every dataset yields 1x28x28 inputs.

    Args:
        dataset_name: One of 'MNIST', 'CIFAR10', 'FashionMNIST', 'SVHN'.
        train: Load the training split when True (the default); load the
            held-out test split when False.
        batch_size: Samples per batch.
        root: Directory the dataset is downloaded to / read from.

    Returns:
        A DataLoader that shuffles only when loading the training split.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    supported = ('MNIST', 'CIFAR10', 'FashionMNIST', 'SVHN')
    # Validate up front; an unknown name used to fall through every branch
    # and crash later with UnboundLocalError on ``dataset``.
    if dataset_name not in supported:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of {supported}")

    steps = []
    if dataset_name in ('CIFAR10', 'SVHN'):
        # Collapse RGB to one channel so all datasets share the same input shape.
        steps.append(transforms.Grayscale(num_output_channels=1))
    steps += [
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
    transform = transforms.Compose(steps)

    if dataset_name == 'SVHN':
        # SVHN selects its split with a string rather than a boolean flag.
        dataset = torchvision.datasets.SVHN(
            root=root, split='train' if train else 'test',
            download=True, transform=transform)
    else:
        dataset_classes = {
            'MNIST': torchvision.datasets.MNIST,
            'CIFAR10': torchvision.datasets.CIFAR10,
            'FashionMNIST': torchvision.datasets.FashionMNIST,
        }
        dataset = dataset_classes[dataset_name](
            root=root, train=train, download=True, transform=transform)

    return DataLoader(dataset, batch_size=batch_size, shuffle=train)


训练和微调函数

In [None]:
def train(model, dataloader, epochs, optimizer, criterion, device=None):
    """Run ``epochs`` passes of supervised training over ``dataloader``.

    Args:
        model: Network to optimize; it is switched to training mode.
        dataloader: Iterable of ``(images, labels)`` batches.
        epochs: Number of full passes over ``dataloader``.
        optimizer: Optimizer that updates ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Where each batch is moved. Defaults to the device holding the
            model's parameters, so the function does not rely on a
            module-level ``device`` variable defined in another cell.

    Returns:
        A list with the mean batch loss of each epoch (NaN for an epoch in
        which ``dataloader`` yielded no batches).
    """
    if device is None:
        device = next(model.parameters()).device
    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float('nan'))
    return epoch_losses

def finetune(model, dataloader, epochs, optimizer, criterion):
    """Fine-tune ``model`` on ``dataloader`` for ``epochs`` passes.

    The optimization loop is exactly the one in ``train``. This function used
    to be a verbatim copy of it, so it now delegates instead of duplicating.
    The separate name keeps the pretraining and fine-tuning phases distinct
    at the call site.

    Returns:
        Whatever ``train`` returns.
    """
    return train(model, dataloader, epochs, optimizer, criterion)


执行预训练和微调

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

PRETRAIN_DATASETS = ['MNIST', 'CIFAR10', 'FashionMNIST', 'SVHN']
PRETRAIN_EPOCHS = 5
FINETUNE_EPOCHS = 3

# Pretraining: train sequentially on each dataset, all sharing the same
# 10-way output layer.
# NOTE(review): the label indices presumably mean different things across
# these datasets, and FashionMNIST (the fine-tuning target) is also part of
# pretraining -- confirm both are intended.
for dataset_name in PRETRAIN_DATASETS:
    dataloader = get_dataloader(dataset_name)
    # Pass everything after ``epochs`` by keyword: a positional argument may
    # not follow a keyword argument.
    train(model, dataloader, epochs=PRETRAIN_EPOCHS, optimizer=optimizer, criterion=criterion)

# Fine-tuning on the target dataset, reusing the same optimizer (and learning rate).
target_dataloader = get_dataloader('FashionMNIST')
finetune(model, target_dataloader, epochs=FINETUNE_EPOCHS, optimizer=optimizer, criterion=criterion)


测试模型

In [None]:
def test(model, dataloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print(f'Accuracy of the network on the test images: {100 * correct / total}%')

# Evaluate on FashionMNIST's held-out *test* split. get_dataloader('FashionMNIST')
# with default arguments loads the training split, so measuring accuracy on it
# would overstate how well the model generalizes. Reuse that loader's transform
# so the preprocessing is identical to training.
fashion_transform = get_dataloader('FashionMNIST').dataset.transform
test_dataset = torchvision.datasets.FashionMNIST(
    root='./data', train=False, download=True, transform=fashion_transform)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)
test(model, test_dataloader)
