
# Saving and Loading Models, and Related Issues

References: [save, load](https://zhuanlan.zhihu.com/p/107203828), [module/optimizer.state_dict](https://zhuanlan.zhihu.com/p/84797438), [nn.Module](https://zhuanlan.zhihu.com/p/340453841)


## 1. torch.nn.Module.state_dict

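`Module.state_dict()` calls `self._save_to_state_dict()` on the module and on each of its submodules, collecting the model's `self._parameters` and `self._buffers` into an `OrderedDict`.

`Optimizer.state_dict()` returns a dict of the form `{'state': packed_state, 'param_groups': param_groups}`.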

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class TheModelClass(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
print(optimizer)

# Print model's state_dict
print("Model's state_dict:")
print(model.state_dict().keys())
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```

Output:

```
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.001
    momentum: 0.9
    nesterov: False
    weight_decay: 0
)
Model's state_dict:
odict_keys(['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [3176176701872, 3176176702016, 3176174379944, 3176176482632, 3176176482128, 3176176481696, 3176176482848, 3176176481552, 3176176482920, 3176176481048]}]
```
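
The model above has only parameters, and its optimizer has not taken a step yet, so two things stay hidden: buffers in the model's `state_dict`, and a non-empty optimizer `state`. The following is a minimal sketch of both, assuming a hypothetical toy network (`net`) that is not part of the original notes.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical toy network: a linear layer followed by batch normalization.
net = nn.Sequential(nn.Linear(4, 3), nn.BatchNorm1d(3))

# Parameters are trainable tensors; buffers (e.g. BatchNorm running statistics)
# are not trained, but both end up in state_dict().
param_names = [name for name, _ in net.named_parameters()]
buffer_names = [name for name, _ in net.named_buffers()]
state_keys = list(net.state_dict().keys())
print("parameters:", param_names)
print("buffers:   ", buffer_names)

# With momentum, SGD keeps one momentum buffer per parameter, but only
# creates it on the first call to step().
opt = optim.SGD(net.parameters(), lr=0.1, momentum=0.9)
entries_before = len(opt.state_dict()['state'])

net(torch.randn(8, 4)).sum().backward()
opt.step()
entries_after = len(opt.state_dict()['state'])
print("optimizer state entries:", entries_before, "->", entries_after)
```

This is why a checkpoint meant for resuming training usually stores the optimizer's `state_dict` alongside the model's.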



## 2. torch.save

Serializes an object (typically a dict) and saves it to a binary file.


### Source code

The source of `torch.serialization.save()` is as follows:

```
def save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL):
    # obj is usually a dict; f is the target binary file (conventionally named
    # with a .pth or .pt extension); pickle_module defaults to pickle.
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
```


The source of `_with_file_like` is as follows:

```
def _with_file_like(f, mode, body):
    """
    Executes a body function with a file object for f, opening
    it in 'mode' if it is a string filename.
    """
    new_fd = False
    if isinstance(f, str) or \
            (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
            (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
        new_fd = True
        f = open(f, mode)
    try:
        return body(f)
    finally:
        if new_fd:
            f.close()
```

The core of `_save` (abridged) is as follows:

```
def _save(obj, f, pickle_module, pickle_protocol):
    # Abridged: the definitions of sys_info, persistent_id and
    # serialized_storages are omitted here.
    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)
    pickler.persistent_id = persistent_id
    pickler.dump(obj)

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)
    f.flush()
```

**As the code shows, the whole save process simply uses the pickle serialization module to write out the dict data.**
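
Because `save` goes through `_with_file_like`, `f` does not have to be a path: an already open file-like object is used as-is. The sketch below shows this with an in-memory `io.BytesIO` buffer; the dict contents and variable names are made up for illustration.

```python
import io

import torch

# Illustrative payload: one tensor plus a plain Python integer.
payload = {'weight': torch.arange(6.0).reshape(2, 3), 'step': 7}

# torch.save accepts a writable file-like object instead of a filename.
buffer = io.BytesIO()
torch.save(payload, buffer)
num_bytes = buffer.tell()

# Rewind before reading the serialized bytes back.
buffer.seek(0)
restored = torch.load(buffer)

print(num_bytes > 0)
print(torch.equal(restored['weight'], payload['weight']), restored['step'])
```

Saving to a buffer is handy for tests or for sending a checkpoint over a network without touching the disk.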

#### Example

```
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),  # OrderedDict mapping names such as 'conv1.weight' to tensors
            'optimizer_state_dict': optimizer.state_dict(),  # the optimizer's state and hyperparameters are usually saved as well
            'loss': loss,
            ...
            }, PATH)
```



## 3. torch.load

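Deserialization: `torch.load('.pth')`

`load(f, map_location=None, pickle_module=pickle, **pickle_load_args)` loads a binary file written by `torch.save`. It uses Python's `pickle.load` to read the serialized file and deserialize it back into the original object, typically a dict.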
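
Putting `torch.save` and `torch.load` together, a checkpoint round trip might look like the sketch below. It assumes a tiny stand-in model and a temporary file path; `checkpoint_path`, the epoch number, and the layer sizes are illustrative and not taken from the original notes.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Train a stand-in model for one step so the optimizer has state to save.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# Save everything needed to resume training in a single dict.
checkpoint_path = os.path.join(tempfile.mkdtemp(), 'checkpoint.pth')
torch.save({
    'epoch': 3,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss.item(),  # store a plain float rather than a graph-attached tensor
}, checkpoint_path)

# Later: rebuild the same architecture, then restore weights and optimizer state.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
restored_model = nn.Linear(4, 2)
restored_optimizer = optim.SGD(restored_model.parameters(), lr=0.01, momentum=0.9)
restored_model.load_state_dict(checkpoint['model_state_dict'])
restored_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch'] + 1
```

The model and optimizer must be constructed before their `state_dict`s are loaded, since the checkpoint stores tensors and hyperparameters but not the Python classes.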

**About the `map_location` parameter**

`map_location` can be a callable, a `torch.device` object, a string containing a device tag, or a dict. Only the callable form receives two arguments, `storage` and `location`. The PyTorch documentation describes each case as follows:

* If `map_location` is a callable, it will be called once for each serialized storage with two arguments: `storage` and `location`. The `storage` argument will be the initial deserialization of the storage, residing on the CPU. Each serialized storage has a location tag associated with it which identifies the device it was saved from, and this tag is the second argument passed to `map_location`. The builtin location tags are `'cpu'` for CPU tensors and `'cuda:device_id'` (e.g. `'cuda:2'`) for CUDA tensors. `map_location` should return either `None` or a storage. If `map_location` returns a storage, it will be used as the final deserialized object, already moved to the right device. Otherwise, `torch.load()` will fall back to the default behavior, as if `map_location` wasn't specified.
* If `map_location` is a `torch.device` object or a string containing a device tag, it indicates the location where all tensors should be loaded.
* Otherwise, if `map_location` is a dict, it will be used to remap location tags appearing in the file (keys), to ones that specify where to put the storages (values).
* When you call `torch.load()` on a file which contains GPU tensors, those tensors will be loaded to GPU by default. You can call `torch.load(.., map_location='cpu')` and then `load_state_dict()` to avoid a GPU RAM surge when loading a model checkpoint.

#### Example:

        >>> torch.load('tensors.pt')
        # Load all tensors onto the CPU
        >>> torch.load('tensors.pt', map_location=torch.device('cpu'))
        # Load all tensors onto the CPU, using a function
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage)
        # Load all tensors onto GPU 1
        >>> torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1))
        # Map tensors from GPU 1 to GPU 0
        >>> torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
        # Load tensor from io.BytesIO object
        >>> with open('tensor.pt', 'rb') as f:
                buffer = io.BytesIO(f.read())
        >>> torch.load(buffer)
        # Load a module with 'ascii' encoding for unpickling
        >>> torch.load('module.pt', encoding='ascii')
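
The documentation examples above assume files and GPUs that may not exist on your machine. As a CPU-only sketch of the same idea, the snippet below saves a small dict to a temporary file and loads it with three forms of `map_location`; the file name and the `record_and_keep` helper are hypothetical.

```python
import os
import tempfile

import torch

path = os.path.join(tempfile.mkdtemp(), 'tensors.pt')
torch.save({'a': torch.ones(2), 'b': torch.zeros(3)}, path)

# 1) A string device tag: every tensor is loaded onto that device.
from_string = torch.load(path, map_location='cpu')

# 2) A torch.device object: same effect as the string form.
from_device = torch.load(path, map_location=torch.device('cpu'))

# 3) A callable: invoked for each serialized storage with (storage, location).
seen_tags = []

def record_and_keep(storage, location):
    # location is the tag of the device the storage was saved from.
    seen_tags.append(location)
    return storage  # returning the CPU storage keeps the tensor on the CPU

from_callable = torch.load(path, map_location=record_and_keep)

print(from_string['a'].device, from_device['b'].device, seen_tags)
```

Since the tensors were saved from the CPU, the callable only ever sees the `'cpu'` tag here; a checkpoint written on a GPU would report tags such as `'cuda:0'` instead.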


## 4. nn.Module.load_state_dict
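`model.load_state_dict(torch.load('.pth'))` loads the weights by calling the `_load_from_state_dict` function of **every submodule**, so that each one picks up the weights it needs.
`_load_from_state_dict` is the function that actually loads the parameters and buffers.
This also means that each module can define its own `_load_from_state_dict` to meet special requirements.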

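```
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
    # Abridged and simplified excerpt; shape checks and hooks are omitted.
    # Collect this module's own parameters and buffers
    local_name_params = itertools.chain(self._parameters.items(), self._buffers.items())
    local_state = {k: v.data for k, v in local_name_params if v is not None}

    # Load each parameter/buffer from the state_dict
    for name, param in local_state.items():
        key = prefix + name
        if key in state_dict:
            input_param = state_dict[key]
            if isinstance(input_param, Parameter):
                # backwards compatibility for serialized parameters:
                # take the underlying tensor data
                input_param = input_param.data

            # Copy the tensor data into the module in place
            try:
                param.copy_(input_param)
            except Exception:
                error_msgs.append(f'While copying the parameter named "{key}" ...')
        elif strict:
            missing_keys.append(key)
```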
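
The `strict`, `missing_keys` and `unexpected_keys` arguments above surface in the public API as the `strict` flag and the return value of `load_state_dict`. A small sketch, assuming a single `nn.Linear` layer and a deliberately mismatched dict (the `'extra'` key is made up):

```python
import torch
import torch.nn as nn

model = nn.Linear(2, 2)

# This dict lacks 'bias' and carries a key the model does not know about.
partial_state = {'weight': torch.ones(2, 2), 'extra': torch.zeros(1)}

# strict=False loads what matches and reports the rest instead of raising.
result = model.load_state_dict(partial_state, strict=False)
print(result.missing_keys)
print(result.unexpected_keys)

# With the default strict=True the same mismatch raises a RuntimeError.
try:
    model.load_state_dict(partial_state)
    strict_failed = False
except RuntimeError:
    strict_failed = True
print(strict_failed)
```

`strict=False` is convenient when loading a pretrained backbone into a model with a different head, but the returned keys are worth checking so that a typo in a key name does not go unnoticed.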


##### Using `_load_from_state_dict` to painlessly load weights from a migrated model [_load_from_state_dict](https://zhuanlan.zhihu.com/p/340453841)

```python
# A method to be defined on the detector's nn.Module subclass.
# get_root_logger and print_log are logging helpers from the OpenMMLab
# codebase this snippet was taken from.
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                          missing_keys, unexpected_keys, error_msgs):
    # override the _load_from_state_dict function
    # convert the backbone weights pre-trained in Mask R-CNN
    # use list(state_dict.keys()) to avoid
    # RuntimeError: OrderedDict mutated during iteration
    for key_name in list(state_dict.keys()):
        key_changed = True
        if key_name.startswith('backbone.'):
            new_key_name = f'img_backbone{key_name[8:]}'
        elif key_name.startswith('neck.'):
            new_key_name = f'img_neck{key_name[4:]}'
        elif key_name.startswith('rpn_head.'):
            new_key_name = f'img_rpn_head{key_name[8:]}'
        elif key_name.startswith('roi_head.'):
            new_key_name = f'img_roi_head{key_name[8:]}'
        else:
            key_changed = False
        if key_changed:
            logger = get_root_logger()
            print_log(
                f'{key_name} renamed to be {new_key_name}', logger=logger)
            state_dict[new_key_name] = state_dict.pop(key_name)
    super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                  strict, missing_keys, unexpected_keys,
                                  error_msgs)
```
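
To try the key-renaming idea without the detection codebase, here is a self-contained sketch of the same technique. `OldModel` and `NewModel` are hypothetical classes invented for this example; only the `backbone` to `img_backbone` rename mirrors the snippet above.

```python
import torch
import torch.nn as nn


class OldModel(nn.Module):
    """Stand-in for the pretrained model, with a submodule named 'backbone'."""
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 2)


class NewModel(nn.Module):
    """Same layer, but the submodule is now called 'img_backbone'."""
    def __init__(self):
        super().__init__()
        self.img_backbone = nn.Linear(4, 2)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        old_prefix = prefix + 'backbone.'
        new_prefix = prefix + 'img_backbone.'
        # Iterate over a copy of the keys because the dict is modified in the loop.
        for key in list(state_dict.keys()):
            if key.startswith(old_prefix):
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)
        super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                      missing_keys, unexpected_keys, error_msgs)


old_state = OldModel().state_dict()
new_model = NewModel()

# strict=True still succeeds because the keys are renamed before submodules load.
result = new_model.load_state_dict(old_state)
print(result.missing_keys, result.unexpected_keys)
```

The rename runs on the copy of the dict that `load_state_dict` works on, so the caller's `old_state` keeps its original keys.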
