
# 模型保存以及加载等一系列问题

参考 [save, load](https://zhuanlan.zhihu.com/p/107203828)



## 1. torch.nn.Module.state_dict

Module.state_dict() 通过调用 self._save_to_state_dict() 将模型的 self._parameters 和 self._buffers 保存
进一个 OrderedDict 中.

In [2]:
import torch
import torch.nn as nn

import torch.nn.functional as F
import torch.optim as optim
class TheModelClass(nn.Module):
    """Small convolutional classifier used to demonstrate ``state_dict``.

    Two convolution + max-pool stages are followed by three linear layers.
    Sub-modules are registered in the order conv1, pool, conv2, fc1, fc2,
    fc3; the pooling layer has no parameters, so it contributes no entries
    to ``state_dict()``.
    """

    def __init__(self):
        super().__init__()
        # Convolutional stages: 3 -> 6 -> 16 channels, 5x5 kernels, sharing
        # one 2x2 / stride-2 max-pool layer.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        # Fully connected head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x):
        """Run the network on a batch of 3-channel images.

        The flatten step reshapes the conv output into rows of
        16 * 5 * 5 values, which is the conv output size for 32x32 inputs.
        Returns one row of 10 scores per flattened row.
        """
        features = x
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))
        flat = features.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(optimizer)

# Print model's state_dict: one line per entry, name followed by tensor size.
# state_dict() assembles a new mapping on every call, so it is fetched once
# and iterated with .items() instead of being re-called for every key.
print("Model's state_dict:")
for param_name, param_tensor in model.state_dict().items():
    print(param_name, "\t", param_tensor.size())

# Print optimizer's state_dict: one line per top-level entry.
print("Optimizer's state_dict:")
for var_name, var_value in optimizer.state_dict().items():
    print(var_name, "\t", var_value)
    

SGD (
Parameter Group 0
    dampening: 0
    lr: 0.001
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [2842111917296, 2842111915424, 2842111916216, 2842111917728, 2842111917152, 2842111917872, 2842111915928, 2842111915784, 2842111916360, 2842111914632]}]



## 2. torch.save


### source code

在 torch.serialization.save()里源码如下:

```
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    # obj is usually a dict; f is the destination, conventionally a `.pth` / `.pt`
    # binary file; pickle_module defaults to pickle.
    # "wb" is the mode handed to _with_file_like, and the lambda delegates the
    # actual serialization to _save.
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
```


其中`_with_file_like`源码如下:

```
def _with_file_like(f, mode, body):
    """
    Executes a body function with a file object for f, opening
    it in 'mode' if it is a string filename.

    Args:
        f: a path (str, Python 2 unicode, or Python 3 pathlib.Path), or any
            other object, which is handed to body unchanged.
        mode: mode string passed to open() when f is a path.
        body: callable invoked with the file object.

    Returns:
        Whatever body(f) returns.
    """
    # Tracks whether the file was opened here, so that the finally clause
    # only closes a file this function opened itself.
    new_fd = False
    # The version checks come first in each `and`, so `unicode` is only
    # evaluated on Python 2 and the pathlib.Path test only on Python 3.
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = open(f, mode)
    try:
        return body(f)
    finally:
        # Runs whether body returned or raised.
        if new_fd:
            f.close()
```

```
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    # NOTE(review): this looks like an abridged excerpt of `_save` (the helper
    # that `save` invokes through `_with_file_like`) rather than `save` itself
    # -- confirm against torch/serialization.py. `sys_info`, `persistent_id`
    # and `serialized_storages` are used below but not defined in this excerpt.

    # Three header records, each pickled separately before the object.
    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    # The object itself is written by a Pickler whose persistent_id hook is
    # replaced with `persistent_id`.
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    # Last record: the sorted keys of `serialized_storages`, then flush.
    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()

```

**可以看到, 整个保存过程其实就是利用 pickle 序列化工具来保存字典数据**

example

```
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),  # mapping of parameter name -> tensor, e.g. {'conv1.weight': Tensor, 'conv1.bias': Tensor}
            'optimizer_state_dict': optimizer.state_dict(), # the optimizer's state and the hyperparameters it uses are usually saved as well
            'loss': loss,
            ...
            }, PATH)
```
