In [30]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import MNIST

In [31]:
def load_data(data_root, train_batch_size, test_batch_size):
    mean, std_dev = [0.1307], [0.3081]
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std_dev)
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std_dev)
    ])
    trainset = MNIST(
        data_root, train=True, download=True,
        transform=transform_train
    )
    trainset = Subset(trainset, range(20000))
    testset = MNIST(
        data_root, train=False, download=True,
        transform=transform_test
    )
    train_loader = DataLoader(
        trainset, train_batch_size, shuffle=False,
        num_workers=2, pin_memory=True
    )
    test_loader = DataLoader(
        testset, test_batch_size, shuffle=False,
        num_workers=2, pin_memory=True
    )
    return train_loader, test_loader


class AvgMetric(object):
    def __init__(self):
        self.reset()

    def reset(self):
        self.counter = 0
        self.accumulator = 0.0
        self.value = 0.0

    def update(self, value, n):
        self.value = value
        self.counter += n
        self.accumulator += value

    def result(self):
        return self.accumulator / self.counter


class LeNet_300_100(nn.Module):
    def __init__(self):
        super(LeNet_300_100, self).__init__()
        self.linear0 = nn.Linear(28*28, 300)
        self.relu0 = nn.ReLU(inplace=True)
        self.linear1 = nn.Linear(300, 100)
        self.relu1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(100, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = self.relu0(self.linear0(x))
        x = self.relu1(self.linear1(x))
        x = self.linear2(x)
        return x


def train(epochs):
    loss_metric = AvgMetric()
    for epoch in range(epochs):
        for step, (inputs, targets) in enumerate(train_loader):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()  # backprop
            optimizer.step() # update params
            loss_metric.update(loss, inputs.shape[0])

            if (step+1) % 200 == 0:
                print("Epoch[{}/{}], step[{}/{}]:\tLoss={:.4f}".format(
                    epoch+1, epochs, step+1,
                    len(train_loader.dataset) // train_batch_size,
                    loss_metric.result()
                ))


train_batch_size, test_batch_size = 32, 16
data_root = "../../../../data/torchvision"
train_loader, test_loader = load_data(
    data_root, train_batch_size, test_batch_size
)
model = LeNet_300_100()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss_fn = nn.CrossEntropyLoss()
train(2)

print("==> Model's state_dict:")
for name, param in model.state_dict().items():
    print(name, param.shape, sep="\t")
print("\n\n==> Optimizer's state_dict:")
for name, param in optimizer.state_dict().items():
    print(name, param, sep="\n")

Epoch[1/2], step[200/625]:	Loss=0.0286
Epoch[1/2], step[400/625]:	Loss=0.0220
Epoch[1/2], step[600/625]:	Loss=0.0186
Epoch[2/2], step[200/625]:	Loss=0.0158
Epoch[2/2], step[400/625]:	Loss=0.0144
Epoch[2/2], step[600/625]:	Loss=0.0133
==> Model's state_dict:
linear0.weight	torch.Size([300, 784])
linear0.bias	torch.Size([300])
linear1.weight	torch.Size([100, 300])
linear1.bias	torch.Size([100])
linear2.weight	torch.Size([10, 100])
linear2.bias	torch.Size([10])


==> Optimizer's state_dict:
state
{0: {'momentum_buffer': tensor([[-0.0038, -0.0038, -0.0038,  ..., -0.0038, -0.0038, -0.0038],
        [ 0.0028,  0.0028,  0.0028,  ...,  0.0028,  0.0028,  0.0028],
        [-0.0012, -0.0012, -0.0012,  ..., -0.0012, -0.0012, -0.0012],
        ...,
        [-0.0041, -0.0041, -0.0041,  ..., -0.0041, -0.0041, -0.0041],
        [ 0.0001,  0.0001,  0.0001,  ...,  0.0001,  0.0001,  0.0001],
        [-0.0074, -0.0074, -0.0074,  ..., -0.0074, -0.0074, -0.0074]])}, 1: {'momentum_buffer': tensor([ 9.0397e-0

In [41]:
isinstance(torch.tensor([1, 2, 3]), torch.Tensor)

True

## Saving and Loading Models
PyTorch 中有两种方法可以对模型进行保存和加载；一种是`torch.save()`函数保存和`torch.load()`函数加载`state_dict`；另一种方式则是借助 Python 的 pickle 库来保存和加载整个模型；




### Saving and Loading the Entire Model
利用`torch.save()`和`torch.load()`函数便可实现对模型的加载和保存；这种方式简单且易于实现，然而缺点在于，由于 pickle 库并不保存模型类本身，而是保存定义了模型的类所在的文件的路径，即序列化的数据是与保存模型时特定的类和特定的目录结构，进而如果在其他项目中使用该模型或在重构代码或路径之后再进行模型导入时，代码可能会以多种方式中断；保存和加载的示例如下：

```python
model = LeNet_300_100()
ckpt_path = ".../entire_model.pt"
torch.save(net, ckpt_path)
model = torch.load(ckpt_path)
```
PyTorch 中约定使用`.pt`或`.pth`作为文件扩展名来保存模型；





### The `state_dict` in PyTorch
在 PyTorch 中，一个模型的`state_dict`是其内部每个层的名称映射至相应参数的字典；只有一个层的可学习参数或注册缓冲区 (如 BN 的 `running_mean`) 会记录在`state_dict`中；此外，优化器也含有`state_dict`，其包含了优化器的状态信息以及所使用的超参数；

```python
class LeNet_300_100(nn.Module):
    def __init__(self):
        super(LeNet_300_100, self).__init__()
        self.linear0 = nn.Linear(28*28, 300)
        self.relu0 = nn.ReLU(inplace=True)
        self.linear1 = nn.Linear(300, 100)
        self.relu1 = nn.ReLU(inplace=True)
        self.linear2 = nn.Linear(100, 10)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = self.relu0(self.linear0(x))
        x = self.relu1(self.linear1(x))
        x = self.linear2(x)
        return x

model = LeNet_300_100()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
for name, param in model.state_dict().items():
    print(name, "\t", param.shape)
""" ==>
linear0.weight	torch.Size([300, 784])
linear0.bias	torch.Size([300])
linear1.weight	torch.Size([100, 300])
linear1.bias	torch.Size([100])
linear2.weight	torch.Size([10, 100])
linear2.bias	torch.Size([10])
"""
```
利用`state_dict`加载和保存模型的优点在于的灵活性更高；而缺点在于所加载的`state_dict`为字典类型，进而只能适用于原代码所定义的模型，或参数名称与参数形状完全与原模型相同的情况，否则便无法将`state_dict`导入模型；




### Saving and Loading a model `state_dict`
对`state_dict`的保存依旧是使用`torch.save()`函数，加载则利用`torch.load()`方法加载`state_dict`，再利用模型的`load_state_dict`方法将`state_dict`导入至模型中；需要说明的是，由于仅对模型的`state_dict`进行了保存，进而加载得到的模型不含有任何与训练参数有关的信息，此时只能将模型用于推断、重新训练或微调；
```python
model = LeNet_300_100()
ckpt_path = ".../state_dict_model.pt"
torch.save(net.state_dict(), ckpt_path)  # save the model
model.load_state_dict(state_dict)        # load the model
```




### Saving and Loading for Resuming Training
如果需要加载模型并恢复之前的训练，应在保存检查点时将优化器的`state_dict`以及其他需要保存的参数连同模型的`state_dict`一同进行保存；具体而言，只需将这些要保存的对象组织成一个字典，再将其传递给`torch.save()`函数即可，示例如下：
```python
# save
model = LeNet_300_100()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
state_dict = {
    'epoch': 5,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': 0.4,
}
ckpt_path = "model.pt"
torch.save(state_dict, ckpt_path)

# load
model = LeNet_300_100()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
ckpt = torch.load(PATH)
model.load_state_dict(ckpt['model_state_dict'])
optimizer.load_state_dict(ckpt['optimizer_state_dict'])
epoch = ckpt['epoch']
loss = ckpt['loss']
```




### Saving and Loading Multiple Models in One File
对于 GAN、seq2seq 模型、集成模型等在保存时需要保存多个`nn.Module`的模型，只需将需要保存的模型的`state_dict`、相应优化器的`state_dict`、以及其他需要保存的参数编写入字典，再利用`torch.save()`进行保存即可；加载时只需利用键值进行索引即可；示例如下
```python
# ==> save
modelA = LeNet_300_100()
modelB = LeNet_300_100()
optimizerA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimizerB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
ckpt_path = ".../multimodel.pt"
state_dict = {
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict()
}
torch.save(state_dict, ckpt_path)

# ==> Load multiple models
modelA = LeNet_300_100()
modelB = LeNet_300_100()
optimModelA = optim.SGD(modelA.parameters(), lr=0.001, momentum=0.9)
optimModelB = optim.SGD(modelB.parameters(), lr=0.001, momentum=0.9)
ckpt = torch.load(ckpt_path)
modelA.load_state_dict(ckpt['modelA_state_dict'])
modelB.load_state_dict(ckpt['modelB_state_dict'])
optimizerA.load_state_dict(ckpt['optimizerA_state_dict'])
optimizerB.load_state_dict(ckpt['optimizerB_state_dict'])
```