In [1]:
import torch 
from torch import nn
from torch.nn import functional as F

class Lenet5(nn.Module):
    """LeNet-5-style CNN: two conv + average-pool stages followed by a 3-layer MLP.

    The flatten size of the classifier head (16 * 5 * 5) assumes a 32x32
    spatial input: 32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.

    Args:
        in_channels: number of input image channels (default 3, i.e. RGB).
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels=3, num_classes=10):
        super().__init__()

        # Convolutional feature extractor.
        # NOTE(review): there is no nonlinearity anywhere in this stack; Conv2d
        # is affine and AvgPool2d is linear, so the whole unit collapses to a
        # single affine map. Inserting activations (e.g. nn.ReLU) would likely
        # improve accuracy, but it shifts the Sequential indices and therefore
        # the state_dict keys (conv_unit.2.* etc.), so it is intentionally
        # left unchanged here.
        self.conv_unit = nn.Sequential(
            # Stage 1: [b, C, 32, 32] => conv => [b, 6, 28, 28] => pool => [b, 6, 14, 14]
            nn.Conv2d(in_channels, 6, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            # Stage 2: [b, 6, 14, 14] => conv => [b, 16, 10, 10] => pool => [b, 16, 5, 5]
            nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0),
            nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
        )

        # Fully connected classifier head: [b, 16*5*5] => [b, num_classes].
        self.fc_unit = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, num_classes),
        )

    def forward(self, x):
        """Map a batch of images to unnormalized class logits.

        Args:
            x: tensor of shape [b, in_channels, 32, 32].

        Returns:
            Logits of shape [b, num_classes]. No softmax is applied, so pair
            this with nn.CrossEntropyLoss, which applies log-softmax itself.
        """
        # [b, C, 32, 32] => [b, 16, 5, 5]
        x = self.conv_unit(x)
        # Flatten everything except the batch dimension: [b, 16, 5, 5] => [b, 16*5*5].
        # torch.flatten does not need the batch size and, unlike view, does
        # not require a contiguous tensor.
        x = torch.flatten(x, 1)
        # [b, 16*5*5] => [b, num_classes]
        return self.fc_unit(x)




def main():
    """Smoke-test Lenet5: a random CIFAR-sized batch must yield one logit per class."""
    net = Lenet5()

    # Two random 3x32x32 inputs; only the output shape is being checked,
    # so skip building the autograd graph.
    sample = torch.randn(2, 3, 32, 32)
    with torch.no_grad():
        out = net(sample)

    # Fail loudly instead of relying on a reader to eyeball the printed shape.
    assert out.shape == (2, 10), f"unexpected output shape {tuple(out.shape)}"
    print(out.shape)

if __name__ == '__main__':
    main()

torch.Size([2, 10])


In [2]:
import torch 
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch import nn, optim

def main():
    batch_size = 32
    
    cifar_train = datasets.CIFAR10('cifar', train=True, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_train = DataLoader(cifar_train, batch_size=batch_size, shuffle=True)

    cifar_test = datasets.CIFAR10('cifar', train=False, transform=transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]), download=True)
    cifar_test = DataLoader(cifar_test, batch_size=batch_size, shuffle=True)

    # for x, label in iter(cifar_train):
    #     print('x:', x.shape, 'label:', label.shape)
    #     break

    device = torch.device('cuda')
    model = Lenet5()
    model.to(device)
    # 打印网络模型的数据结构
    # print(model)

    # 包含了 softmax 操作
    criteon = nn.CrossEntropyLoss()
    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for epoch in range(1000):
        # 将模型设置为训练模式
        model.train()
        for batchidx, (x, label) in enumerate(cifar_train):
            # x: [b, 3, 32, 32]
            # label: [b]
            x, label = x.to(device), label.to(device)
            logits = model(x)
            # logits: [b ,10]
            # label: [b]
            # loss: tensor scalar （一维tensor）
            loss = criteon(logits, label)
             # backprop
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(epoch, loss.item())

        # 将模型设置为测试模式
        model.eval()
        # with torch.no_grad() 包住的代码块不会计算梯度，即不会影响 model
        with torch.no_grad():
            # test
            total_correct = 0
            total_num = 0
            for x, label in cifar_test:
                # x: [b, 3, 32, 32]
                # label: [b]
                x, label = x.to(device), label.to(device)

                # [b, 10]
                logits = model(x)

                pred = logits.argmax(dim=1)
                # [b] vs [b] => scalar tensor（一维tensor）
                # 统计这个batch中预测正确的数量
                total_correct += torch.eq(pred, label).float().sum()
                total_num += x.size(0)
            # 正确率
            acc = total_correct / total_num
            print(epoch, acc)
        

if __name__ == '__main__':
    main()

Files already downloaded and verified
Files already downloaded and verified
0 1.4829308986663818
0 tensor(0.4497, device='cuda:0')
1 1.4972277879714966
1 tensor(0.4883, device='cuda:0')
2 0.8046582341194153
2 tensor(0.5100, device='cuda:0')
3 1.375645399093628
3 tensor(0.5183, device='cuda:0')
4 1.1264864206314087
4 tensor(0.5269, device='cuda:0')
5 1.2121236324310303
5 tensor(0.5407, device='cuda:0')
6 0.9914142489433289
6 tensor(0.5314, device='cuda:0')
7 0.946485698223114
7 tensor(0.5523, device='cuda:0')
8 0.964535653591156
8 tensor(0.5456, device='cuda:0')
9 0.8266449570655823
9 tensor(0.5505, device='cuda:0')
10 0.8220797777175903
10 tensor(0.5483, device='cuda:0')
11 0.8345103859901428
11 tensor(0.5459, device='cuda:0')
12 0.9445749521255493
12 tensor(0.5453, device='cuda:0')
13 0.8477845788002014
13 tensor(0.5534, device='cuda:0')
14 0.5807614326477051
14 tensor(0.5540, device='cuda:0')
15 0.8086283802986145
15 tensor(0.5474, device='cuda:0')
16 0.5396457314491272
16 tensor(0.5