In [8]:
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader

In [15]:
class LeNet5(nn.Module):
    """LeNet-5-style convolutional classifier for small images (e.g. CIFAR-10).

    The flatten size ``16 * 5 * 5`` below is only correct for 32x32 inputs, so
    inputs must have shape (N, in_channels, 32, 32). ``forward`` returns raw
    logits of shape (N, num_classes); softmax is left to the loss function.

    Args:
        in_channels: number of input image channels (3 for RGB CIFAR-10).
        num_classes: number of output logits (10 for CIFAR-10).
    """

    def __init__(self, in_channels: int = 3, num_classes: int = 10) -> None:
        super().__init__()

        self.model = nn.Sequential(
            # (N, C, 32, 32) -> conv -> (N, 6, 28, 28) -> pool -> (N, 6, 14, 14)
            nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=0),
            # Without a nonlinearity here, conv + avg-pool + conv + avg-pool
            # would collapse into a single linear map.
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # (N, 6, 14, 14) -> conv -> (N, 16, 10, 10) -> pool -> (N, 16, 5, 5)
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            nn.Flatten(),  # (N, 16 * 5 * 5)
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (N, num_classes) for a batch ``x``."""
        return self.model(x)


In [9]:
# CIFAR-10 preprocessing shared by train and test so both see identical inputs.
# Resize pins the 32x32 size that LeNet5's flatten layer requires.
tf_cifar = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])
tf_train = tf_cifar
tf_test = tf_cifar

BATCH_SIZE = 32

# Keep the Dataset objects under their own names instead of overwriting them
# with the DataLoaders below.
cifar_train_ds = datasets.CIFAR10('cifar', True, transform=tf_train, download=True)
cifar_train = DataLoader(cifar_train_ds, batch_size=BATCH_SIZE, shuffle=True)

cifar_test_ds = datasets.CIFAR10('cifar', False, transform=tf_test, download=True)
# Evaluation does not need shuffling; a fixed order keeps results reproducible.
cifar_test = DataLoader(cifar_test_ds, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [21]:
def train_cifar(
    net: nn.Module,
    data: DataLoader,
    epochs: int = 10,
    lr: float = 0.001,
    log_every: int = 10,
) -> None:
    """Train ``net`` in place on ``data`` using Adam and cross-entropy loss.

    Args:
        net: model returning raw (un-softmaxed) logits for each batch.
        data: DataLoader yielding ``(x, y)`` batches whose targets ``y`` are
            accepted by ``nn.CrossEntropyLoss`` (typically class indices).
        epochs: number of full passes over ``data``.
        lr: Adam learning rate.
        log_every: print the current batch loss every ``log_every`` batches;
            ``0`` or ``None`` disables logging.
    """
    net.train()
    # CrossEntropyLoss applies softmax internally, so net must output raw logits.
    lossfn = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(net.parameters(), lr=lr)
    # Send each batch to wherever the model's parameters live (CPU or GPU).
    device = next(net.parameters()).device
    for epoch in range(epochs):
        for batch, (x, y) in enumerate(data):
            x, y = x.to(device), y.to(device)
            loss = lossfn(net(x), y)

            adam.zero_grad()
            loss.backward()
            adam.step()
            if log_every and batch % log_every == 0:
                # '\r' rewrites the same console line instead of flooding the output.
                print(f"\repoch {epoch} ===========> batch {batch} loss {loss.item()}", end='\r')


In [16]:
# Seed torch's global RNG so the network's weight initialisation is
# reproducible across "Restart & Run All".
SEED = 0
torch.manual_seed(SEED)
net = LeNet5()

In [23]:
# Train the freshly built network on the CIFAR-10 training loader using
# train_cifar's default settings; the model is updated in place.
train_cifar(net=net, data=cifar_train)

