<br>

<div align=center><font color=maroon size=6><b>Serialization semantics</b></font></div>

<font size=4><b>References:</b></font>
* <a href="https://pytorch.org/docs/stable/index.html" style="text-decoration:none;">Docs > PyTorch documentation</a>

    * **Notes**
        * Docs > <a href="https://pytorch.org/docs/stable/notes/serialization.html" style="text-decoration:none;">Serialization semantics</a>


<br>
<br>
<br>

<font size=3>This note describes how you can save and load PyTorch tensors and module states in Python, and how to serialize Python modules so they can be loaded in C++.</font>

In [2]:
import torch

<br>

# Saving and loading tensors

<br>

<font size=3>
<a href="https://pytorch.org/docs/master/generated/torch.save.html#torch.save" style="text-decoration:none;font-size:120%">torch.save()</a> and 
<a href="https://pytorch.org/docs/master/generated/torch.load.html#torch.load" style="text-decoration:none;font-size:120%">torch.load()</a> let you easily save and load tensors:</font>

In [3]:
t = torch.tensor([1., 2.])
torch.save(t, './model_weights/tensor.pt')
torch.load('./model_weights/tensor.pt')

tensor([1., 2.])
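<font size=3>Both functions also accept file-like objects in place of a path. The sketch below does a round trip through an in-memory `io.BytesIO` buffer, which is convenient for experiments that shouldn't depend on a `./model_weights` directory existing:</font>

```python
import io

import torch

t = torch.tensor([1., 2.])

# torch.save writes into any object with write()/flush(), here an in-memory buffer.
buffer = io.BytesIO()
torch.save(t, buffer)

buffer.seek(0)  # rewind so torch.load reads from the start
restored = torch.load(buffer)
print(restored)                  # tensor([1., 2.])
print(torch.equal(t, restored))  # True
```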

<br>

<font size=3 color=maroon>By convention, PyTorch files are typically written with a **`.pt`** or **`.pth`** extension.</font>

<font size=3>`torch.save()` and `torch.load()` <font color=maroon>use Python’s pickle by default</font>, so you can also save multiple tensors as part of Python objects like tuples, lists, and dicts:</font>

In [4]:
d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
torch.save(d, './model_weights/tensor_dict.pt')
torch.load('./model_weights/tensor_dict.pt')

{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

<br>

<font size=3>Custom data structures that include PyTorch tensors can also be saved if the data structure is pickle-able.</font>
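<font size=3>A minimal sketch of such a structure, using a hypothetical `Checkpoint` dataclass defined at module level so that pickle can locate the class again on load. One caveat: recent PyTorch releases make `torch.load()` default to `weights_only=True`, which rejects arbitrary classes, so the sketch opts out explicitly. Full unpickling can execute arbitrary code, so this should only be done for files you trust.</font>

```python
import io
from dataclasses import dataclass

import torch


@dataclass
class Checkpoint:  # hypothetical container, not a PyTorch type
    step: int
    weights: torch.Tensor
    history: list


ckpt = Checkpoint(step=3, weights=torch.ones(2), history=[torch.zeros(1)])

buffer = io.BytesIO()
torch.save(ckpt, buffer)  # pickles the dataclass together with its tensors
buffer.seek(0)

# weights_only=False runs the full unpickler so the custom class can be rebuilt.
restored = torch.load(buffer, weights_only=False)
print(restored.step, restored.weights)  # 3 tensor([1., 1.])
```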

<br>
<br>
<br>

# Saving and loading tensors preserves views

<font size=3>Saving tensors preserves their view relationships:</font>

In [5]:
numbers = torch.arange(1, 10)
evens = numbers[1::2]
evens

tensor([2, 4, 6, 8])

In [7]:
torch.save([numbers, evens], './model_weights/tensors.pt')
loaded_numbers, loaded_evens = torch.load('./model_weights/tensors.pt')
loaded_numbers, loaded_evens

(tensor([1, 2, 3, 4, 5, 6, 7, 8, 9]), tensor([2, 4, 6, 8]))

In [8]:
loaded_evens *= 2    # this also changes loaded_numbers
loaded_evens

tensor([ 4,  8, 12, 16])

In [9]:
loaded_numbers

tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

<br>

<font size=3>Behind the scenes, these tensors share the same “storage.” See <a href="https://pytorch.org/docs/master/tensor_view.html" style="text-decoration:none;font-size:120%">Tensor Views</a> for more on views and storage.</font>

<font size=3>When PyTorch saves tensors it saves their **storage objects** and **tensor metadata** separately. <font color=maroon>This is an implementation detail that may change in the future, but it typically saves space and lets PyTorch easily reconstruct the view relationships between the loaded tensors.</font> In the above snippet, for example, only a single storage is written to `tensors.pt`.</font>
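<font size=3>One way to confirm that the view relationship survives a round trip is to compare the storage pointers and the view metadata of the loaded tensors. This sketch assumes PyTorch 2.0 or later, where `Tensor.untyped_storage()` is available:</font>

```python
import io

import torch

numbers = torch.arange(1, 10)
evens = numbers[1::2]

buffer = io.BytesIO()
torch.save([numbers, evens], buffer)
buffer.seek(0)
loaded_numbers, loaded_evens = torch.load(buffer)

# Both loaded tensors point at one shared storage...
same_storage = (loaded_evens.untyped_storage().data_ptr()
                == loaded_numbers.untyped_storage().data_ptr())
# ...and the view metadata (storage offset and stride) was reconstructed.
print(same_storage, loaded_evens.storage_offset(), loaded_evens.stride())
# True 1 (2,)
```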

<font size=3>In some cases, however, <font color=maroon>saving the current **storage objects** may be unnecessary and create prohibitively large files</font>. In the following snippet a storage much larger than the saved tensor is written to a file:</font>

In [10]:
large = torch.arange(1, 1000)
small = large[0:5]
torch.save(small, './model_weights/small.pt')

loaded_small = torch.load('./model_weights/small.pt')
loaded_small.storage().size()

999

<font size=3>Instead of saving only the five values in the small tensor to `small.pt`, the 999 values in the storage it shares with *large* were saved and loaded.

When saving tensors with fewer elements than their storage objects, <font color=maroon>the size of the saved file can be reduced by first **cloning** the tensors</font>. **Cloning a tensor** produces a new tensor with a new storage object containing only the values in the tensor:
</font>

In [11]:
large = torch.arange(1, 1000)
small = large[0:5]
torch.save(small.clone(), './model_weights/small.pt')   # saves a clone of small

loaded_small = torch.load('./model_weights/small.pt')
loaded_small.storage().size()

5
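<font size=3>The difference is also visible in the number of bytes written. The sketch below measures it in memory with a hypothetical helper, `serialized_nbytes`. The exact totals include some container overhead, so only the gap between them (roughly the 994 extra `int64` values) is meaningful:</font>

```python
import io

import torch


def serialized_nbytes(obj):
    """Hypothetical helper: how many bytes torch.save writes for obj."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return len(buffer.getvalue())


large = torch.arange(1, 1000)  # 999 int64 values, about 8 KB of storage
small = large[0:5]

view_size = serialized_nbytes(small)           # writes the whole shared storage
clone_size = serialized_nbytes(small.clone())  # writes only the 5 values
print(view_size, clone_size)
```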

<font size=3 color=maroon>Since the cloned tensors are independent of each other, however, they have none of the view relationships the original tensors did. If both file size and view relationships are important when saving tensors smaller than their storage objects, then care must be taken to construct new tensors that minimize the size of their storage objects but still have the desired view relationships before saving.</font>
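<font size=3>A sketch of that approach, assuming the tensors to keep are two overlapping views of `large`: clone only the span the views cover, rebuild the same views on top of that compact base, and save the rebuilt views. The `untyped_storage()` call assumes PyTorch 2.0 or later.</font>

```python
import io

import torch

large = torch.arange(1, 1000)
first, second = large[0:5], large[2:7]  # overlapping views we want to keep

# Clone just the covered span (elements 0..6) and recreate the views on it.
base = large[0:7].clone()
compact_first, compact_second = base[0:5], base[2:7]

buffer = io.BytesIO()
torch.save([compact_first, compact_second], buffer)
buffer.seek(0)
loaded_first, loaded_second = torch.load(buffer)

print(loaded_first.untyped_storage().nbytes())  # 56 (7 int64 values)
loaded_first[2] = -1     # element 2 of the shared storage...
print(loaded_second[0])  # ...is also element 0 of the second view: tensor(-1)
```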

<br>
<br>
<br>

# Saving and loading torch.nn.Modules

See also: Tutorial: <a href="https://pytorch.org/tutorials/beginner/saving_loading_models.html" style="text-decoration:none;color:maroon;font-size:120%;">Saving and loading modules</a>

<br>

<font size=3>In PyTorch, a module’s state is frequently serialized using a ‘state dict.’ A module’s state dict contains all of its `parameters` and `persistent buffers`:</font>

In [12]:
bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
list(bn.named_parameters())

[('weight',
  Parameter containing:
  tensor([1., 1., 1.], requires_grad=True)),
 ('bias',
  Parameter containing:
  tensor([0., 0., 0.], requires_grad=True))]

In [13]:
list(bn.named_buffers())

[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

In [14]:
bn.state_dict()

OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])
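<font size=3>The word *persistent* matters: a buffer registered with `persistent=False` is still a buffer of the module but is left out of its state dict. A minimal sketch with a hypothetical `WithBuffers` module:</font>

```python
import torch


class WithBuffers(torch.nn.Module):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        # Persistent buffer (the default): included in the state dict.
        self.register_buffer("step", torch.tensor(0))
        # Non-persistent buffer: listed by named_buffers() but not saved.
        self.register_buffer("scratch", torch.zeros(2), persistent=False)


m = WithBuffers()
print(sorted(name for name, _ in m.named_buffers()))  # ['scratch', 'step']
print(sorted(m.state_dict().keys()))  # ['linear.bias', 'linear.weight', 'step']
```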

<br>

<font size=3>Instead of saving a module directly, it is recommended for compatibility reasons to save only its state dict. Modules even provide a method, <a href="https://pytorch.org/docs/master/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict" style="text-decoration:none;font-size:120%">load_state_dict()</a>, to restore their states from a state dict:</font>

In [15]:
torch.save(bn.state_dict(), './model_weights/bn.pt')
bn_state_dict = torch.load('./model_weights/bn.pt')
new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
new_bn.load_state_dict(bn_state_dict)

<All keys matched successfully>

<font size=3>Note that the state dict is first loaded from its file with `torch.load()` and the state then restored with `load_state_dict()`.</font>
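<font size=3>By default `load_state_dict()` is strict: the keys must match exactly, which is where the `<All keys matched successfully>` message comes from. A sketch of what happens when a key is missing, using a plain `torch.nn.Linear` and a deliberately incomplete state dict:</font>

```python
import torch

lin = torch.nn.Linear(4, 2)
partial = {"weight": torch.zeros(2, 4)}  # deliberately missing "bias"

# strict=True (the default) raises on missing or unexpected keys.
try:
    lin.load_state_dict(partial)
except RuntimeError as err:
    print("strict load failed:", type(err).__name__)

# strict=False loads the matching entries and reports the mismatches.
result = lin.load_state_dict(partial, strict=False)
print(result.missing_keys, result.unexpected_keys)  # ['bias'] []
```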

<br>

<font size=3 color=maroon>Even custom modules and modules containing other modules have state dicts and can use this pattern:</font>

In [16]:
class MyModule(torch.nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)


m = MyModule()
m.state_dict()

OrderedDict([('l0.weight',
              tensor([[ 0.4406,  0.0225,  0.2327,  0.2363],
                      [-0.0538, -0.0514, -0.4410, -0.4944]])),
             ('l0.bias', tensor([ 0.1049, -0.0507])),
             ('l1.weight', tensor([[ 0.2251, -0.6717]])),
             ('l1.bias', tensor([-0.3721]))])

In [17]:
torch.save(m.state_dict(), './model_weights/mymodule.pt')
m_state_dict = torch.load('./model_weights/mymodule.pt')
new_m = MyModule()
new_m.load_state_dict(m_state_dict)

<All keys matched successfully>
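<font size=3>To check that the restored module really matches the original, one can compare every tensor in the two state dicts and the outputs on the same input. The sketch repeats the `MyModule` architecture (written slightly more compactly) so it runs on its own:</font>

```python
import io

import torch


class MyModule(torch.nn.Module):  # same architecture as MyModule above
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        return self.l1(torch.nn.functional.relu(self.l0(input)))


m = MyModule()
buffer = io.BytesIO()
torch.save(m.state_dict(), buffer)
buffer.seek(0)

new_m = MyModule()  # freshly initialized, so its weights start out different
new_m.load_state_dict(torch.load(buffer))

# Same tensors in the state dict, therefore the same forward pass.
same_state = all(torch.equal(a, b) for a, b in
                 zip(m.state_dict().values(), new_m.state_dict().values()))
x = torch.randn(3, 4)
print(same_state, torch.equal(m(x), new_m(x)))  # True True
```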

<br>
<br>
<br>

# <font color=maroon>Serializing torch.nn.Modules and loading them in C++</font>

See also: <a href="https://pytorch.org/tutorials/advanced/cpp_export.html" style="text-decoration:none;color:maroon;font-size:120%;">Tutorial: Loading a TorchScript Model in C++</a>

<font size=3>ScriptModules can be serialized as a TorchScript program and loaded using <a href="https://pytorch.org/docs/master/generated/torch.jit.load.html#torch.jit.load" style="text-decoration:none;font-size:120%">torch.jit.load()</a>. This serialization encodes all the modules’ methods, submodules, parameters, and attributes, and it allows the serialized program to be loaded in C++ (i.e. without Python).</font>

<font size=3 color=maroon>The distinction between `torch.jit.save()` and `torch.save()` may not be immediately clear.</font> 

* <a href="https://pytorch.org/docs/master/generated/torch.save.html#torch.save" style="text-decoration:none;font-size:140%">torch.save()</a> saves Python objects with pickle. This is especially useful for prototyping, researching, and training. 


* <a href="https://pytorch.org/docs/master/generated/torch.jit.save.html#torch.jit.save" style="text-decoration:none;font-size:140%">torch.jit.save()</a>, on the other hand, serializes ScriptModules to a format that can be loaded in Python or C++. This is useful when saving and loading C++ modules or for running modules trained in Python with C++, a common practice when deploying PyTorch models.

<br>

<font size=4 color=maroon>To **script**, **serialize** and **load** a module in Python:</font>

In [18]:
scripted_module = torch.jit.script(MyModule())   # script
torch.jit.save(scripted_module, './model_weights/mymodule.pt')   # serialize
torch.jit.load('./model_weights/mymodule.pt')                    # load

RecursiveScriptModule(
  original_name=MyModule
  (l0): RecursiveScriptModule(original_name=Linear)
  (l1): RecursiveScriptModule(original_name=Linear)
)
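<font size=3>`torch.jit.save()` and `torch.jit.load()` also accept file-like objects. The sketch below (repeating the `MyModule` architecture so it is self-contained) round-trips a scripted module through memory, checks that the loaded program computes the same result, and prints the TorchScript code stored in the archive:</font>

```python
import io

import torch


class MyModule(torch.nn.Module):  # same architecture as MyModule above
    def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        return self.l1(torch.nn.functional.relu(self.l0(input)))


scripted = torch.jit.script(MyModule())

buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # serialize the TorchScript program
buffer.seek(0)
loaded = torch.jit.load(buffer)   # load it back, no Python class needed

x = torch.randn(3, 4)
print(torch.equal(scripted(x), loaded(x)))  # True
print(loaded.code)  # TorchScript source of forward()
```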

<br>

<font size=3><font color=maroon>**Traced modules** can also be saved with `torch.jit.save()`, with the caveat that only the traced code path is serialized.</font> The following example demonstrates this:</font>

In [19]:
# A module with control flow
class ControlFlowModule(torch.nn.Module):
    
    def __init__(self):
        super(ControlFlowModule, self).__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

    def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

In [22]:
traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
torch.jit.save(traced_module, './model_weights/controlflowmodule_traced.pt')
loaded = torch.jit.load('./model_weights/controlflowmodule_traced.pt')
loaded(torch.randn(2, 4))

tensor([[0.1782],
        [0.4974]], grad_fn=<AddmmBackward0>)

In [24]:
scripted_module = torch.jit.script(ControlFlowModule())   # scripting needs no example input
torch.jit.save(scripted_module, './model_weights/controlflowmodule_scripted.pt')
loaded = torch.jit.load('./model_weights/controlflowmodule_scripted.pt')
loaded(torch.randn(2, 4))

tensor(0)

<font size=3 color=maroon><br>
* The above module has an if statement that is **not triggered by the traced inputs**, and so is not part of the **traced module** and not serialized with it. 

    
* The **scripted module**, however, contains the if statement and is serialized with it. </font>

See the <a href="https://pytorch.org/docs/stable/jit.html" style="text-decoration:none;color:maroon;font-size:120%;">TorchScript documentation</a> for more on scripting and tracing.

<br>

<font size=3>Finally, to load the module in C++:</font>

<font size=5>
    
```c++
#include <torch/script.h>

torch::jit::script::Module module;
module = torch::jit::load("controlflowmodule_scripted.pt");
```
    
</font>

See the <a href="https://pytorch.org/cppdocs/" style="text-decoration:none;color:maroon;font-size:120%;">PyTorch C++ API documentation</a> for details about how to use PyTorch modules in C++.

<br>
<br>
<br>

# Saving and loading ScriptModules across PyTorch versions

<font size=3><font color=maroon>The PyTorch Team recommends saving and loading modules with the same version of PyTorch.</font> Older versions of PyTorch may not support newer modules, and newer versions may have removed or modified older behavior. These changes are explicitly described in PyTorch’s release notes, and modules relying on functionality that has changed may need to be updated to continue working properly. </font>


<font color=maroon size=3>In limited cases, detailed below, PyTorch will preserve the historic behavior of serialized ScriptModules so they do not require an update.</font>

## torch.div performing integer division

See <a href="https://pytorch.org/docs/master/notes/serialization.html#torch-div-performing-integer-division" style="text-decoration:none;color:maroon;font-size:120%;">here</a> for details.
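<font size=3>For comparison, a sketch of how `torch.div` behaves in current releases: it performs true division even on integer tensors, and integer-style division has to be requested explicitly through the `rounding_mode` argument:</font>

```python
import torch

a = torch.tensor([5, -5])
b = torch.tensor([2, 2])

true_div = torch.div(a, b)                          # floats: 2.5 and -2.5
trunc_div = torch.div(a, b, rounding_mode="trunc")  # round toward zero: 2 and -2
floor_div = torch.div(a, b, rounding_mode="floor")  # round down: 2 and -3
print(true_div, trunc_div, floor_div)
```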

<br>

## torch.full always inferring a float dtype

See <a href="https://pytorch.org/docs/master/notes/serialization.html#torch-full-always-inferring-a-float-dtype" style="text-decoration:none;color:maroon;font-size:120%;">here</a> for details.
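<font size=3>For comparison, a sketch of how `torch.full` behaves in current releases: when no `dtype` is given, the result's dtype is inferred from the fill value, and passing `dtype` explicitly pins it:</font>

```python
import torch

# Without dtype, the fill value's type decides the result's dtype.
print(torch.full((2,), 1).dtype)     # torch.int64
print(torch.full((2,), 1.0).dtype)   # torch.float32 (the default float dtype)
print(torch.full((2,), True).dtype)  # torch.bool

# An explicit dtype overrides the inference.
print(torch.full((2,), 1, dtype=torch.float32).dtype)  # torch.float32
```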

<br>
<br>
<br>