In [None]:
# For tips on running notebooks in Google Colab, see
# https://docs.pytorch.org/tutorials/beginner/colab
%matplotlib inline


Save and Load the Model
=======================

In this section we will look at how to persist model state by saving it,
loading it back, and running predictions with the restored model.


In [1]:
import torch
import torchvision.models as models

Saving and Loading Model Weights
================================

PyTorch models store their learned parameters in an internal state
dictionary called `state_dict`. It can be persisted with the
`torch.save` function:


In [2]:
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(model.state_dict(), 'model_weights.pth')
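# Saves only the parameter dictionary, not the model class definition or structure code. This is the approach the official docs recommend as more robust.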

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth


100%|██████████| 528M/528M [00:05<00:00, 92.4MB/s]
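
To make the contents of a `state_dict` concrete, the sketch below inspects one for a tiny network. The `nn.Sequential` model here is a hypothetical stand-in, chosen only so the example runs without downloading VGG16; the same calls apply to any `nn.Module`.

```python
import torch
from torch import nn

# Hypothetical stand-in for VGG16: two linear layers with a ReLU in between.
tiny_model = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    nn.Linear(8, 2),
)

state = tiny_model.state_dict()

# Each entry maps a parameter name to its tensor. ReLU has no parameters,
# so index 1 does not appear among the keys.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

Because the dictionary holds only names and tensors, the file written by `torch.save` carries no information about the model class itself.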


To load model weights, first create an instance of the same model, then
load the parameters using the `load_state_dict()` method.

In the code below, we set `weights_only=True` to limit the functions
executed during unpickling to only those necessary for loading weights.
Using `weights_only=True` is considered a best practice when loading
weights.


In [3]:
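model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model
# This builds an empty VGG16 with the same structure (randomly initialized parameters); the weights from the file then overwrite them.

model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
# torch.load(...) reads back the saved state_dict; load_state_dict(...) matches each parameter by name and copies it into the model.

model.eval()
# Switches the model to evaluation/inference mode: Dropout stops dropping units at random and BatchNorm uses its learned statistics (mean/variance), so inference results are stable and reproducible.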


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1
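
As a rough check of the save-and-restore cycle, the sketch below repeats it on a small model and compares outputs before and after. The `make_model` helper, the layer sizes, and the file name are assumptions made for illustration; they are not part of the tutorial.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical small architecture standing in for VGG16.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


torch.manual_seed(0)
trained = make_model()  # pretend these weights were trained
x = torch.randn(3, 4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tiny_weights.pth")
    torch.save(trained.state_dict(), path)

    restored = make_model()  # fresh instance with different random weights
    result = restored.load_state_dict(torch.load(path, weights_only=True))

trained.eval()
restored.eval()
with torch.no_grad():
    same_output = torch.allclose(trained(x), restored(x))

print(result.missing_keys, result.unexpected_keys)
print(same_output)
```

The value returned by `load_state_dict()` lists any missing or unexpected keys, which is a convenient place to look first when a checkpoint does not match the architecture.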

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>

<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">

<p>Be sure to call the <code>model.eval()</code> method before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.</p>

</div>


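Dropout

-   During training: randomly drops a fraction of the unit outputs as a
    form of regularization.
-   During evaluation/inference: random dropping should be turned off so
    that the full network produces the output.
-   If you forget `model.eval()`: every forward pass drops a different
    random set of units, so the same input can produce different outputs.


BatchNorm (batch normalization)

-   During training: normalizes with the mean/variance of the current
    mini-batch and updates the running mean/variance (running stats).
-   During evaluation/inference: should use the running stats accumulated
    during training, which keeps the output stable.
-   If you forget `model.eval()`: the layer keeps normalizing with the
    statistics of each inference batch (which is often very small, even a
    single sample), so outputs fluctuate with the batch and results become
    less stable.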
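
The difference between the two modes can be observed directly. The following is a minimal sketch, assuming a hypothetical toy network that contains both a `BatchNorm1d` and a `Dropout` layer; the layer sizes and batch size are arbitrary.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy network containing both layer types discussed above.
net = nn.Sequential(nn.Linear(16, 64), nn.BatchNorm1d(64), nn.Dropout(p=0.5))
x = torch.randn(8, 16)

with torch.no_grad():
    net.train()
    train_a, train_b = net(x), net(x)  # a new dropout mask on each call

    net.eval()
    stats_before = net[1].running_mean.clone()
    eval_a, eval_b = net(x), net(x)    # dropout off, running stats frozen
    stats_after = net[1].running_mean.clone()

print(torch.equal(train_a, train_b))            # almost surely False in train mode
print(torch.equal(eval_a, eval_b))              # True in eval mode
print(torch.equal(stats_before, stats_after))   # True: eval does not update stats
```

Note that `torch.no_grad()` only disables gradient tracking; it does not switch layers to evaluation mode, which is why the explicit `net.eval()` call is still needed.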




Saving and Loading Models with Shapes
=====================================

When loading model weights, we needed to instantiate the model class
first (`models.vgg16()` followed by `load_state_dict()` above), because
the class defines the structure of the network. We might want to save
the structure of this class together with the model, in which case we
can pass `model` (and not `model.state_dict()`) to the saving function.
This saves the entire model object, not only its weights:



In [4]:
torch.save(model, 'model.pth')
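# Serializes the whole object with pickle. The advantage is that you do not have to rebuild an instance of the same architecture by hand before loading (but see the risk notes later in this section).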


We can then load the model as demonstrated below.

As described in [Saving and loading
torch.nn.Modules](https://pytorch.org/docs/main/notes/serialization.html#saving-and-loading-torch-nn-modules),
saving `state_dict` is considered the best practice. However, below we
use `weights_only=False` because this involves loading the model, which
is a legacy use case for `torch.save`.


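In [None]:
# Loading an entire model object requires full unpickling, so the weights-only mode cannot be used here.

In [5]:
model = torch.load('model.pth', weights_only=False)
# Loads the entire model object from model.pth (full unpickling allowed).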
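
The same whole-object round trip can be tried on a small model without the VGG16 download. This is a sketch under assumptions: the model below is a hypothetical `nn.Sequential` built only from `torch.nn` classes, so the class definitions that pickle needs are importable wherever PyTorch is installed.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical small model made only of torch.nn classes.
small = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "small_model.pth")
    torch.save(small, path)                        # pickles the whole object
    loaded = torch.load(path, weights_only=False)  # only for files you trust

# No constructor call was needed: the structure came back with the weights.
print(type(loaded).__name__)
print(torch.equal(small[0].weight, loaded[0].weight))
```

A model defined in your own module would load the same way only if that module can be imported under the same name at load time.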

<div style="background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px"><strong>NOTE:</strong></div>

<div style="background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px">

<p>This approach uses the Python <a href="https://docs.python.org/3/library/pickle.html">pickle</a> module when serializing the model, so it relies on the actual class definition being available when the model is loaded.</p>

<p>If you rename the class or its module path, move the code to a different location, or load under a substantially different version, <code>torch.load</code> may fail because it cannot find the class. Unpickling also carries a security risk: do not load files from untrusted sources.</p>

</div>
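
One way to keep enough information to rebuild the network while still loading with `weights_only=True` is to store the constructor arguments next to the `state_dict` as plain data. The sketch below illustrates the idea; the `build_mlp` factory, the `config` keys, and the file name are hypothetical and stand in for your own model code.

```python
import os
import tempfile

import torch
from torch import nn


def build_mlp(in_features, hidden, out_features):
    # Hypothetical factory; in a real project this would be your model class.
    return nn.Sequential(
        nn.Linear(in_features, hidden), nn.ReLU(), nn.Linear(hidden, out_features)
    )


config = {"in_features": 4, "hidden": 8, "out_features": 2}
original = build_mlp(**config)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    # Store plain data only: constructor arguments plus the weights.
    torch.save({"config": config, "state_dict": original.state_dict()}, path)
    checkpoint = torch.load(path, weights_only=True)

# Rebuild the structure from code, then fill in the saved weights.
rebuilt = build_mlp(**checkpoint["config"])
rebuilt.load_state_dict(checkpoint["state_dict"])
rebuilt.eval()

print(checkpoint["config"])
```

The structure still comes from code you control, so renaming or moving the class does not break old files as long as the saved arguments remain meaningful.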



Related Tutorials
=================

-   [Saving and Loading a General Checkpoint in
    PyTorch](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html)
-   [Tips for loading an nn.Module from a
    checkpoint](https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html?highlight=loading%20nn%20module%20from%20checkpoint)
