# PyTorch quickstart

## 1. Import the required tools

`torch.utils.data.Dataset`: stores the dataset, i.e. the training samples and their corresponding labels

`torch.utils.data.DataLoader`: wraps a `Dataset` in an iterable over the samples
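As a minimal sketch of how the two classes fit together, assuming a hypothetical toy dataset named `SquaresDataset` (not part of PyTorch): a map-style `Dataset` implements `__len__` and `__getitem__`, and `DataLoader` calls them to assemble batches.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the scalar i, its label is i**2."""

    def __init__(self, n):
        self.x = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # shape [n, 1]
        self.y = self.x ** 2

    def __len__(self):
        # DataLoader uses this to know how many samples exist
        return len(self.x)

    def __getitem__(self, idx):
        # Return one (sample, label) pair
        return self.x[idx], self.y[idx]


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4)

print(len(dataset))  # 10
for xb, yb in loader:
    # Batches of 4, 4 and 2 samples; each tensor has shape [batch, 1]
    print(xb.shape, yb.shape)
```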

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda, Compose
import matplotlib.pyplot as plt

## 2. Load the dataset

PyTorch offers domain-specific libraries such as [TorchText](https://pytorch.org/text/stable/index.html), [TorchVision](https://pytorch.org/vision/stable/index.html), and [TorchAudio](https://pytorch.org/audio/stable/index.html), all of which include some commonly used datasets.

TorchVision contains datasets commonly used in image tasks, such as CIFAR and COCO. Full list: https://pytorch.org/docs/stable/torchvision/datasets.html

This quickstart uses the FashionMNIST dataset. Every TorchVision dataset accepts the following two arguments:
- `transform`: modifies the training samples
- `target_transform`: modifies the labels
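To illustrate what the two arguments do, here is a sketch of a hypothetical `ToyImageDataset` that applies `transform` to the sample and `target_transform` to the label inside `__getitem__`, assuming this mirrors the torchvision convention. With torchvision you would typically pass `ToTensor()` and, for example, a `Lambda` wrapping a one-hot function; the plain functions `to_float_image` and `one_hot` below are illustrative stand-ins.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset


class ToyImageDataset(Dataset):
    """Hypothetical dataset applying `transform` to the sample and
    `target_transform` to the label, like torchvision datasets do."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images          # uint8 tensor, shape [N, 28, 28]
        self.labels = labels          # list of integer class indices
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        image, label = self.images[idx], self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label


def to_float_image(img):
    # Add a channel dim and scale uint8 pixels to [0, 1] floats (similar in spirit to ToTensor)
    return img.unsqueeze(0).float() / 255.0


def one_hot(label):
    # Turn an integer class index into a length-10 float one-hot vector
    return F.one_hot(torch.tensor(label), num_classes=10).float()


images = torch.full((3, 28, 28), 255, dtype=torch.uint8)  # three all-white fake images
labels = [9, 0, 3]
ds = ToyImageDataset(images, labels, transform=to_float_image, target_transform=one_hot)

img, target = ds[0]
print(img.shape, img.dtype)  # torch.Size([1, 28, 28]) torch.float32
print(target)                # 1.0 at index 9, zeros elsewhere
```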


In [2]:
# Download training data from open datasets.
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
)

# Download test data from open datasets.
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)

Downloading http://202.194.1.97/cache/8/02/fashion-mnist.s3-website.eu-central-1.amazonaws.com/268366d35524891c045904ff5a612794/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://202.194.1.97/cache/6/02/fashion-mnist.s3-website.eu-central-1.amazonaws.com/0044051af8c06cfbe7d405eee650335e/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://202.194.1.97/cache/8/02/fashion-mnist.s3-website.eu-central-1.amazonaws.com/246c3e6b3ab8ed75eaf255fca08a2807/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Processing...
Done!


Every dataset in `torchvision.datasets` is a subclass of `torch.utils.data.Dataset`. To actually load the data, pass the `Dataset` instance to `torch.utils.data.DataLoader`, which wraps it in an iterable of batches that can be used directly in the training loop:

In [3]:
batch_size = 64

train_dataloader = DataLoader(training_data, batch_size=batch_size)
test_dataloader = DataLoader(test_data, batch_size=batch_size)

for X, y in train_dataloader:
    print(X.shape)
    print(y.shape)
    break

torch.Size([64, 1, 28, 28])
torch.Size([64])
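The shape `torch.Size([64, 1, 28, 28])` reads as `[batch, channels, height, width]`. The loaders above use the default `shuffle=False`; the sketch below, on hypothetical random data standing in for FashionMNIST, shows two other `DataLoader` options that are commonly set: `shuffle=True` (reshuffle every epoch, typical for training) and `drop_last=True` (discard the final incomplete batch).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for FashionMNIST: 100 random "images" with labels 0..9
X = torch.randn(100, 1, 28, 28)
y = torch.randint(0, 10, (100,))
dataset = TensorDataset(X, y)

batch_size = 64
# Reshuffles the sample order at the start of every pass over the data
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Keeps only full batches; the 36 leftover samples are dropped
full_batches_loader = DataLoader(dataset, batch_size=batch_size, drop_last=True)

print(len(train_loader))         # 2 batches: 64 + 36 samples
print(len(full_batches_loader))  # 1 batch of 64 samples

X_batch, y_batch = next(iter(train_loader))
print(X_batch.shape, y_batch.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])
```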


## 3. Build the model

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('using {} device.'.format(device))

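# Define a multilayer perceptron (MLP)
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        # Flatten each [1, 28, 28] image into a 784-dimensional vector
        self.flatten = nn.Flatten()
        self.linear_relu = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
            # Note: this final ReLU clamps the 10 class scores at 0, while
            # nn.CrossEntropyLoss expects raw logits; it is usually omitted
            nn.ReLU()
        )

    def forward(self, x):
        x = self.flatten(x)
        output = self.linear_relu(x)
        return output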

mlp = MLP().to(device)
print(mlp)

using cuda device.
MLP(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
    (5): ReLU()
  )
)
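The printed structure ends with `(5): ReLU()`, so every class score handed to `nn.CrossEntropyLoss` is clamped at zero, and scores that fall below zero receive no gradient through that ReLU. Since `CrossEntropyLoss` expects raw, unnormalized logits, the last `nn.Linear` is usually left without an activation. The sketch below builds both variants with a hypothetical helper `make_mlp` and compares their outputs on random input; it also counts the parameters, which are the same for both.

```python
import torch
from torch import nn

torch.manual_seed(0)


def make_mlp(final_relu):
    """Hypothetical helper: the same 784-512-512-10 MLP, optionally ending in nn.ReLU."""
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, 512),
        nn.ReLU(),
        nn.Linear(512, 512),
        nn.ReLU(),
        nn.Linear(512, 10),
    ]
    if final_relu:
        layers.append(nn.ReLU())
    return nn.Sequential(*layers)


with_relu = make_mlp(final_relu=True)
logits_only = make_mlp(final_relu=False)

x = torch.randn(64, 1, 28, 28)
with torch.no_grad():
    out_relu = with_relu(x)
    out_logits = logits_only(x)

print(out_relu.shape)                 # torch.Size([64, 10])
print((out_relu >= 0).all().item())   # True: every score is clamped at zero
print((out_logits < 0).any().item())  # without the final ReLU, negative logits can appear

# Weights + biases: 784*512+512 + 512*512+512 + 512*10+10
print(sum(p.numel() for p in logits_only.parameters()))  # 669706
```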


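## 4. Define the optimizer and loss function, then train and test the model

During training, the loop repeatedly reads batches from the **dataset**, uses the **model** to predict outputs for the training samples, computes the loss from the predictions and labels with the **loss function**, and updates the model parameters with the **optimizer**. The training function therefore takes the following four arguments: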
- dataloader
- model
- loss function
- optimizer

In [5]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(mlp.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    dataset_size = len(dataloader.dataset)
    # Switch to training mode (test() below puts the model in eval mode)
    model.train()

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Forward pass and loss computation
        predict = model(X)
        loss = loss_fn(predict, y)

        # Backpropagation: clear old gradients, compute new ones, update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{dataset_size:>5d}]")
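The `train` function performs one optimization step per batch. The same four ingredients can be exercised in isolation on a hypothetical toy problem (4-dimensional inputs, 3 classes, a single `nn.Linear` standing in for the MLP). `optimizer.zero_grad()` is needed because PyTorch accumulates gradients in `.grad` across `backward()` calls by default.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy problem: 32 samples, 4 features, 3 classes
X = torch.randn(32, 4)
y = torch.randint(0, 3, (32,))

model = nn.Linear(4, 3)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)


def loss_on_batch():
    # Evaluate the loss without recording gradients
    with torch.no_grad():
        return loss_fn(model(X), y).item()


initial_loss = loss_on_batch()
model.train()
for step in range(20):
    pred = model(X)            # forward pass
    loss = loss_fn(pred, y)    # scalar mean loss over the batch
    optimizer.zero_grad()      # clear gradients left from the previous step
    loss.backward()            # compute d(loss)/d(parameters)
    optimizer.step()           # apply the SGD update
final_loss = loss_on_batch()

print(f"loss before: {initial_loss:.4f}, after: {final_loss:.4f}")
```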

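The test function runs the model on the samples of the test set and computes metrics such as accuracy and average loss to evaluate model performance.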

In [8]:
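def test(dataloader, model):
    size = len(dataloader.dataset)
    model.eval()
    test_loss, correct = 0, 0
    # Disable gradient tracking during evaluation
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            # loss_fn returns the mean loss over the batch
            test_loss += loss_fn(pred, y).item()
            # Count predictions whose highest-scoring class matches the label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # Note: dividing the sum of per-batch mean losses by the number of samples
    # (not batches) gives roughly mean loss / batch_size, so "Avg loss" below is
    # about 1/64 of the training loss; divide by len(dataloader) for the batch average
    test_loss /= size
    correct /= size
    print(f"Test error:\n Accuracy: {100*correct:>0.1f}%, Avg loss:{test_loss:>8f} \n")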
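Because `nn.CrossEntropyLoss` averages over each batch by default (`reduction="mean"`), one way to get a true per-sample average loss, even when the last batch is smaller, is to weight each batch mean by its batch size before dividing by the dataset size. The sketch below assumes hypothetical random tensors standing in for model outputs (an identity "model") and checks the result against a one-shot `reduction="sum"` computation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical model outputs and labels: 100 samples, 10 classes -> batches of 64 and 36
logits_all = torch.randn(100, 10)
labels_all = torch.randint(0, 10, (100,))
loader = DataLoader(TensorDataset(logits_all, labels_all), batch_size=64)

loss_fn = nn.CrossEntropyLoss()  # reduction="mean": average over the batch
total_loss, correct = 0.0, 0
for logits, labels in loader:
    # Weight each batch mean by its size so the uneven last batch counts correctly
    total_loss += loss_fn(logits, labels).item() * len(labels)
    # A prediction is correct when the highest-scoring class equals the label
    correct += (logits.argmax(1) == labels).sum().item()

size = len(loader.dataset)
avg_loss = total_loss / size  # per-sample average loss
accuracy = correct / size

# Reference: the same per-sample average computed in one shot
reference = nn.CrossEntropyLoss(reduction="sum")(logits_all, labels_all).item() / size
print(f"avg loss {avg_loss:.6f} vs reference {reference:.6f}, accuracy {accuracy:.2%}")
```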

In [None]:
epochs = 5
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, mlp, loss_fn, optimizer)
    test(test_dataloader, mlp)
print("Done!")

Epoch 1
-------------------------------
loss: 1.574766  [    0/60000]
loss: 1.542595  [ 6400/60000]
loss: 1.444280  [12800/60000]
loss: 1.442885  [19200/60000]
loss: 1.501333  [25600/60000]
loss: 1.623133  [32000/60000]
loss: 1.499517  [38400/60000]
loss: 1.635447  [44800/60000]
loss: 1.387394  [51200/60000]
loss: 1.394289  [57600/60000]
Test error:
 Accuracy: 57.2%, Avg loss:0.022357 

Epoch 2
-------------------------------
loss: 1.490478  [    0/60000]
loss: 1.468123  [ 6400/60000]
loss: 1.353787  [12800/60000]
loss: 1.368516  [19200/60000]
loss: 1.430103  [25600/60000]
loss: 1.550336  [32000/60000]
loss: 1.432137  [38400/60000]
loss: 1.573277  [44800/60000]
loss: 1.324243  [51200/60000]
loss: 1.336770  [57600/60000]
Test error:
 Accuracy: 58.2%, Avg loss:0.021382 

Epoch 3
-------------------------------
loss: 1.420616  [    0/60000]
loss: 1.409062  [ 6400/60000]
loss: 1.280762  [12800/60000]
loss: 1.307961  [19200/60000]
loss: 1.374301  [25600/60000]
loss: 1.491992  [32000/60000]


## 5. Save the model

In [50]:
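import os

# torch.save does not create missing directories
os.makedirs('checkpoints', exist_ok=True)
torch.save(mlp.state_dict(), 'checkpoints/00_mlp.pth')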
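`state_dict()` holds only the parameter tensors, not the class definition, so loading requires rebuilding the same architecture first and then calling `load_state_dict`. A minimal round trip, assuming a hypothetical small model and an in-memory `io.BytesIO` buffer instead of a file on disk:

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))  # hypothetical small model

# Save only the parameters (state_dict) into an in-memory buffer
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

# Rebuild the same architecture, then load the weights; map_location="cpu"
# lets a checkpoint saved from a GPU be loaded on a CPU-only machine
restored = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
restored.load_state_dict(torch.load(buffer, map_location="cpu"))

same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)  # True
print(list(restored.state_dict().keys()))  # ['1.weight', '1.bias']
```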

## 6. Load and test the model

In [58]:
# Rebuild the architecture, then load the saved parameters onto the CPU
mlp = MLP()
mlp.load_state_dict(torch.load('checkpoints/00_mlp.pth', map_location='cpu'))

classes = [
    "T-shirt/top",
    "Trouser",
    "Pullover",
    "Dress",
    "Coat",
    "Sandal",
    "Shirt",
    "Sneaker",
    "Bag",
    "Ankle boot",
]

mlp.eval()

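x, y = test_data[0][0], test_data[0][1]

# Add a batch dimension ([1, 28, 28] -> [1, 1, 28, 28]) and skip gradient tracking
with torch.no_grad():
    pred = mlp(x.unsqueeze(0))

pred, y = classes[pred.argmax(1).item()], classes[y]

print('predict class: {}.'.format(pred), 'Label: {}'.format(y))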


predict class: Ankle boot. Label: Ankle boot
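The same steps extend to a batch of images, and `torch.softmax` turns the 10 logits into class probabilities. A single image of shape `[1, 28, 28]` happens to work without a batch dimension only because `nn.Flatten` (default `start_dim=1`) treats the channel dimension as the batch; an explicit `[N, 1, 28, 28]` batch avoids relying on that. This sketch assumes a hypothetical untrained stand-in model and random images, so the predicted classes themselves are meaningless; only the shapes and the procedure matter.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Hypothetical untrained stand-in for the MLP
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
model.eval()

classes = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
           "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

images = torch.rand(5, 1, 28, 28)  # 5 fake images, shape [N, C, H, W]
with torch.no_grad():  # no autograd bookkeeping during inference
    logits = model(images)                # [5, 10]
    probs = torch.softmax(logits, dim=1)  # each row sums to 1
    confidence, idx = probs.max(dim=1)    # top probability and its class index

for c, i in zip(confidence, idx):
    print(f"{classes[i.item()]}: {c.item():.1%}")
```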
