In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# from matplotlib import pyplot as plt
# %matplotlib inline

In [2]:
# Prepare datasets
data_path = '../Chapter 7/data-unversioned/'

# One preprocessing pipeline shared by both splits: convert to a tensor, then
# normalize each of the three channels with the given mean / std.
cifar10_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465),
                         std=(0.2470, 0.2435, 0.2616)),
])

# download=False: the data is expected to already be present under data_path.
transformed_cifar10 = datasets.CIFAR10(
    data_path, train=True, download=False, transform=cifar10_transform)
transformed_cifar10_val = datasets.CIFAR10(
    data_path, train=False, download=False, transform=cifar10_transform)

# Keep only CIFAR-10 labels 0 and 2 and renumber them 0 and 1, so the new
# label indexes straight into class_names.
label_map = {0: 0, 2: 1}
class_names = ['airplane', 'bird']

tr_cifar2 = [
    (img, label_map[label])
    for img, label in transformed_cifar10
    if label in label_map
]
tr_cifar2_val = [
    (img, label_map[label])
    for img, label in transformed_cifar10_val
    if label in label_map
]

In [3]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Training on device: {device}')

Training on device: cuda


# 测试: 卷积, 池化

In [7]:
# nn.Conv2d consumes a (batch, channels, height, width) tensor.
# 3 input channels (RGB) -> 16 output channels, 3x3 kernel; padding=1 adds a
# one-pixel border so the output keeps the input's height and width.
conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)

img, _ = tr_cifar2[0]
img_batch = img.unsqueeze(0)  # add the leading batch dimension

# Overwrite the random initialization with constant 1/9 weights and zero bias;
# no_grad() keeps these in-place edits out of autograd.
with torch.no_grad():
    conv.weight.fill_(1.0 / 9.0)
    conv.bias.zero_()

output = conv(img_batch)
img_batch.shape, output.shape

(torch.Size([1, 3, 32, 32]), torch.Size([1, 16, 32, 32]))

In [6]:
# 2x2 max pooling window: halves height and width, keeps the channel count.
pool = nn.MaxPool2d(kernel_size=2)
pool_output = pool(output)
output.shape, pool_output.shape

(torch.Size([1, 16, 32, 32]), torch.Size([1, 16, 16, 16]))

# 训练我们的 convnet

#### 构建自己的 nn.Module 子类

(当你需要使用预设模块没有的操作时)

```
nn.MaxPool2d 的输出是形状为 (N, C, H, W) 的多维张量, 不能直接送入 nn.Linear, 需要先 reshape 成 (N, C*H*W) 的二维张量. 新版本的 PyTorch 提供了 nn.Flatten 来完成这一步, 但下面我们假设使用的是旧版本, 因此在 forward 中手动 reshape...
```

In [4]:
class MyNet(nn.Module):
    """Small convnet for two-class classification of 3x32x32 images.

    Architecture: conv(3->16) -> tanh -> 2x2 max-pool -> conv(16->8) -> tanh
    -> 2x2 max-pool -> flatten (8*8*8 = 512 features) -> linear(512->32)
    -> tanh -> linear(32->2). The output is raw logits (no softmax).

    Args:
        dtype: parameter dtype for every layer. When omitted, the notebook's
            rule applies: ``torch.half`` if the notebook-level ``device`` is
            ``torch.device('cuda')``, otherwise ``torch.float``.

    NOTE(review): on CUDA the default is pure float16 parameters; no loss
    scaling or mixed-precision handling is visible in this notebook.
    """

    def __init__(self, dtype=None):
        super().__init__()
        if dtype is None:
            # Single place for the dtype rule that was repeated on each layer.
            dtype = torch.half if device == torch.device('cuda') else torch.float
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1, dtype=dtype)
        self.conv2 = nn.Conv2d(16, 8, kernel_size=3, padding=1, dtype=dtype)
        self.lin1 = nn.Linear(8*8*8, 32, dtype=dtype)
        self.lin2 = nn.Linear(32, 2, dtype=dtype)

    def forward(self, input):
        """Return logits of shape (batch, 2) for a (batch, 3, 32, 32) input.

        Raises:
            RuntimeError: if the spatial size is not 32x32, because the
                flattened feature count no longer matches ``lin1``.
        """
        out = F.max_pool2d(torch.tanh(self.conv1(input)), 2)
        out = F.max_pool2d(torch.tanh(self.conv2(out)), 2)
        if out.dim() == 3:
            # Unbatched (C, H, W) input: treat it as a batch of one.
            out = out.unsqueeze(0)
        # Flatten per sample. reshape(-1, 512) would silently fold extra
        # spatial positions into the batch dimension for wrongly sized inputs;
        # keeping the batch size fixed turns that into a shape error in lin1.
        out = out.reshape(out.shape[0], -1)
        out = torch.tanh(self.lin1(out))
        out = self.lin2(out)
        return out
# NOTE(review): this instance is rebound by `model = MyNet().to(device=device)`
# in the training cell, so it appears to serve only as a construction check.
model = MyNet()

#### 编写训练循环, 并在 GPU 上训练

In [5]:
import datetime

def training_loop(n_epochs, optimizer, model, loss_fn, train_loader):
    """Train `model` in place and log the mean training loss.

    Args:
        n_epochs: number of full passes over `train_loader`.
        optimizer: optimizer already bound to the model's parameters.
        model: the network to train; it is switched to training mode.
        loss_fn: callable taking (outputs, labels) and returning a scalar loss.
        train_loader: iterable yielding (imgs, labels) batches.

    Returns:
        None. Progress is printed for epochs 1-3 and every 10th epoch.

    Raises:
        ValueError: if `train_loader` yields no batches in an epoch.
    """
    # Take device and dtype from the model itself so every batch matches its
    # parameters, rather than relying on a notebook-level global.
    ref_param = next(model.parameters(), None)
    model.train()

    for epoch in range(1, n_epochs + 1):
        loss_train = 0.0
        n_batches = 0  # counted by hand so loaders without __len__ also work

        for imgs, labels in train_loader:
            if ref_param is not None:
                imgs = imgs.to(device=ref_param.device, dtype=ref_param.dtype)
                labels = labels.to(device=ref_param.device)

            outputs = model(imgs)
            loss = loss_fn(outputs, labels)

            # Clear old gradients, backpropagate, then update the parameters.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_train += loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError('train_loader produced no batches; nothing to train on')

        if epoch <= 3 or epoch % 10 == 0:
            print('{} Epoch {}, Training loss {}'.format(
                datetime.datetime.now(), epoch, loss_train / n_batches))

In [6]:
# Mini-batches of 64 samples, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=tr_cifar2,
    batch_size=64,
    shuffle=True,
)

# The model's parameters have to live on the same device as the batches.
model = MyNet().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

training_loop(
    n_epochs=200,
    optimizer=optimizer,
    model=model,
    loss_fn=loss_fn,
    train_loader=train_loader,
)

2022-08-11 13:40:52.178821 Epoch 1, Training loss 0.580199417794586
2022-08-11 13:40:52.493490 Epoch 2, Training loss 0.4934875099522293
2022-08-11 13:40:52.813486 Epoch 3, Training loss 0.4729688122014331
2022-08-11 13:40:54.979368 Epoch 10, Training loss 0.3388360867834395
2022-08-11 13:40:58.081469 Epoch 20, Training loss 0.2944872424860669
2022-08-11 13:41:01.167114 Epoch 30, Training loss 0.26366954244625795
2022-08-11 13:41:04.263241 Epoch 40, Training loss 0.24216378570362262
2022-08-11 13:41:07.355047 Epoch 50, Training loss 0.22185190771795382
2022-08-11 13:41:10.364079 Epoch 60, Training loss 0.20581404570561307
2022-08-11 13:41:13.290759 Epoch 70, Training loss 0.1924928312848328
2022-08-11 13:41:16.163686 Epoch 80, Training loss 0.18051866665007962
2022-08-11 13:41:19.053011 Epoch 90, Training loss 0.16270592561952626
2022-08-11 13:41:21.866783 Epoch 100, Training loss 0.1502137396745621
2022-08-11 13:41:24.697841 Epoch 110, Training loss 0.13771591672472133
2022-08-11 13:4

#### 测量精度

In [7]:
# Evaluation loaders: same batch size for both splits, no shuffling.
eval_loader_options = dict(batch_size=64, shuffle=False)
train_loader = torch.utils.data.DataLoader(tr_cifar2, **eval_loader_options)
val_loader = torch.utils.data.DataLoader(tr_cifar2_val, **eval_loader_options)

def validate(model, train_loader, val_loader):
    """Print the classification accuracy of `model` on both splits.

    Args:
        model: network returning per-class scores of shape (batch, classes).
        train_loader: iterable of (imgs, labels) batches, reported as 'train'.
        val_loader: iterable of (imgs, labels) batches, reported as 'val'.

    Returns:
        None. One ``Accuracy <name>: <value>`` line is printed per split; an
        empty split is reported as ``nan`` instead of raising.
    """
    # Take device and dtype from the model itself so every batch matches its
    # parameters, rather than relying on a notebook-level global.
    ref_param = next(model.parameters(), None)

    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = model.training
    model.eval()
    try:
        for name, loader in [('train', train_loader), ('val', val_loader)]:
            correct, total = 0, 0
            with torch.no_grad():  # no gradients needed for scoring
                for imgs, labels in loader:
                    if ref_param is not None:
                        imgs = imgs.to(device=ref_param.device, dtype=ref_param.dtype)
                        labels = labels.to(device=ref_param.device)

                    outputs = model(imgs)
                    # The index of the largest score is the predicted class.
                    _, predicted = torch.max(outputs, dim=1)
                    total += labels.shape[0]
                    correct += int((predicted == labels).sum())
            accuracy = correct / total if total else float('nan')
            print(f'Accuracy {name}: {accuracy:.2f}')
    finally:
        model.train(was_training)

# Report accuracy of the trained model on the training and validation splits.
validate(model=model, train_loader=train_loader, val_loader=val_loader)

Accuracy train: 0.99
Accuracy val: 0.89


# 保存/加载模型参数

In [9]:
# SAVE
# torch.save does not create missing directories, so make sure the target
# folder exists before writing the parameters.
checkpoint_path = './pretrained/birds_airplanes.pt'
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), checkpoint_path)

In [None]:
# LOAD
# Left commented out: reference snippet for restoring the parameters written
# by the SAVE cell into a freshly constructed MyNet.
# NOTE(review): tensors saved from a CUDA model are restored onto CUDA by
# default; pass map_location=device to torch.load if the checkpoint may be
# opened on a machine without a GPU.
# loaded_model = MyNet()
# loaded_model.load_state_dict(torch.load('./pretrained/birds_airplanes.pt'))