<br>

<div align=center><font color=maroon size=6><b>Saving and Loading Models</b></font></div>

<br>

<font size=4><b>References:</b></font>
1. PyTorch official tutorials: <a href="https://pytorch.org/tutorials/index.html" style="text-decoration:none;">WELCOME TO PYTORCH TUTORIALS</a>
    * <a href="https://pytorch.org/tutorials/beginner/saving_loading_models.html" style="text-decoration:none;">Saving and Loading Models</a>

<br>

**Author**: <a href="https://github.com/MatthewInkawhich" style="text-decoration:none;"><b>Matthew Inkawhich</b></a>

<br>
<br>

This document provides solutions to a variety of use cases regarding the saving and loading of PyTorch models. Feel free to read the whole document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core functions to be familiar with:

* <a href="https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save" style="text-decoration:none;"><font size=3 color=maroon>torch.save</font></a>: Saves a serialized object to disk. This function uses Python’s <a href="https://docs.python.org/3/library/pickle.html" style="text-decoration:none;">pickle</a> utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.


* <a href="https://pytorch.org/docs/stable/torch.html?highlight=torch%20load#torch.load" style="text-decoration:none;"><font size=3 color=maroon>torch.load</font></a>: Uses <a href="https://docs.python.org/3/library/pickle.html" style="text-decoration:none;">pickle</a>’s unpickling facilities to deserialize pickled object files to memory. This function also lets you choose the device onto which the data is loaded (see section `Saving & Loading Model Across Devices`).


* <a href="https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict" style="text-decoration:none;"><font size=3 color=maroon>torch.nn.Module.load_state_dict</font></a>: Loads a model’s parameter dictionary using a deserialized state_dict. For more information on state_dict, see section `What is a state_dict?`.

**Contents:** (omitted)

<br>
<br>
<br>

# What is a state_dict?

In PyTorch, the learnable parameters (i.e. weights and biases) of a `torch.nn.Module` model are contained in the model’s parameters (accessed with `model.parameters()`).

A state_dict is simply a Python dictionary object that maps each layer to its parameter tensor. <font color=maroon>Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (batchnorm’s running_mean) have entries in the model’s *state_dict*.</font>

<font color=maroon>Optimizer objects (torch.optim) also have a state_dict</font>, which contains information about the <b>optimizer’s state</b>, as well as the <b>hyperparameters</b> used.

Because *state_dict* objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

<br>

## Example:

Let’s take a look at the state_dict from the simple model used in the <a href="https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py" style="text-decoration:none;">Training a classifier</a> tutorial.

In [1]:
import torch.nn as nn
import torch.nn.functional as F  # provides F.relu used in forward()
import torch.optim as optim

In [3]:
# Define model
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

print()
    
# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]


<br>
<br>
<br>

# Saving & Loading Model <font color=maroon><b>for Inference</b></font>

<br>
<br>

## Save/Load <font style="color:red;font-size:110%">`state_dict` (Recommended)</font>

<br>

### Save

<br>

### Load

<br>

<div class="alert alert-block alert-info">

<font size=3 color=red><b>Note: </b></font>

The 1.6 release of PyTorch switched `torch.save` to use a new zipfile-based file format. `torch.load` still retains the ability to load files in the old format. If for any reason you want `torch.save` to use the old format, pass the kwarg `_use_new_zipfile_serialization=False`.

</div>

When saving a model for inference, it is only necessary to save the trained model’s learned parameters. Saving the model’s `state_dict` with the `torch.save()` function will <font color=maroon>give you the most flexibility for restoring the model later, which is why it is the recommended method for saving models.</font>

<font size=3 color=maroon>A common PyTorch convention is to save models using either a **`.pt`** or **`.pth`** file extension.</font>

<font color=maroon size=3>Remember that</font> you must call `model.eval()` to set dropout and batch normalization layers to `evaluation mode` before running inference. Failing to do this will yield inconsistent inference results.
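The save and load steps described above can be sketched as follows. This is a minimal illustration, not the tutorial's original cell: the `make_model` helper, the small `nn.Sequential` architecture, and the file name `model_weights.pt` are all hypothetical stand-ins for your own model class and path.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical stand-in for your own model class.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = make_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_weights.pt")  # hypothetical file name

    # Save: serialize only the learned parameters and buffers.
    torch.save(model.state_dict(), path)

    # Load: rebuild the architecture, then restore the parameters.
    restored = make_model()
    state_dict = torch.load(path)         # deserialize the dictionary first
    restored.load_state_dict(state_dict)  # then pass the dictionary, not the path

restored.eval()  # put dropout and batch normalization layers in evaluation mode

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(same_output)
```

Because only tensors are stored, the loading side needs the model definition in code but is not tied to how the class was imported when the file was written.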

<div class="alert alert-block alert-info">

<font size=3 color=red><b>Note: </b></font>

Notice that the `load_state_dict()` function takes a dictionary object, NOT a path to a saved object. This means that you must deserialize the saved ***state_dict*** before you pass it to the `load_state_dict()` function. For example, you CANNOT load using `model.load_state_dict(PATH)`.

</div>

<div class="alert alert-block alert-info">

<font size=3 color=red><b>Note: </b></font>

If you only plan to keep the best performing model (according to the acquired validation loss), don’t forget that `best_model_state = model.state_dict()` returns a reference to the state and not its copy! You ***`must serialize`*** `best_model_state` or use `best_model_state = deepcopy(model.state_dict())`; otherwise your `best_model_state` will keep getting updated by the subsequent training iterations. As a result, the final model state will be the state of the overfitted model.

</div>
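The reference-versus-copy behaviour in the note above can be checked with a small experiment. The sketch below uses a hypothetical one-layer model and simulates a "later training step" with an in-place update; the variable names are illustrative only.

```python
from copy import deepcopy

import torch
import torch.nn as nn

model = nn.Linear(2, 1)  # hypothetical tiny model

reference_state = model.state_dict()           # tensors share storage with the model
snapshot_state = deepcopy(model.state_dict())  # independent copy of the tensors

# Simulate a later training step by changing the weights in place.
with torch.no_grad():
    model.weight.add_(1.0)

reference_follows_model = torch.equal(reference_state["weight"], model.weight)
snapshot_is_frozen = not torch.equal(snapshot_state["weight"], model.weight)
print(reference_follows_model, snapshot_is_frozen)
```

The plain `state_dict()` result tracks the updated weights, while the deep copy keeps the values from the moment it was taken.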

<br>
<br>

## Save/Load <font style="color:red;font-size:110%">Entire Model</font>

<br>

### Save

<br>

### Load

<br>

This save/load process uses the most intuitive syntax and involves the least amount of code. <font color=maroon>Saving a model in this way will save the entire module using Python’s <a href="https://docs.python.org/3/library/pickle.html" style="text-decoration:none;">pickle</a> module</font>.

The **disadvantage** of this approach is that <font color=maroon>the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. This is because pickle does not save the model class itself. Rather, it saves a path to the file containing the class, which is used during load time. Because of this, your code <b>can break in various ways when used in other projects or after refactors</b>.</font>

<font size=3 color=maroon>A common PyTorch convention is to save models using either a **`.pt`** or **`.pth`** file extension.</font>

<font color=maroon size=3>Remember that</font> you must call `model.eval()` to set dropout and batch normalization layers to `evaluation mode` before running inference. Failing to do this will yield inconsistent inference results.
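A minimal sketch of saving and loading an entire module is shown below. It uses a built-in `nn.Sequential` so that the pickled class reference always resolves; the file name is hypothetical. One assumption to note: newer PyTorch releases restrict `torch.load` to plain tensors and containers by default, so the sketch passes `weights_only=False` explicitly, which should only be done for files you trust.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Built-in layers only, so pickle can always find the classes again at load time.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "entire_model.pt")  # hypothetical file name

    # Save: pickle the whole module object, not just its state_dict.
    torch.save(model, path)

    # Load: unpickling arbitrary classes must be allowed explicitly.
    restored = torch.load(path, weights_only=False)

restored.eval()  # evaluation mode before inference

x = torch.randn(3, 4)
with torch.no_grad():
    same_output = torch.equal(model(x), restored(x))
print(type(restored).__name__, same_output)
```

With a custom model class, the same call works only while that class remains importable from the same module path, which is the fragility described above.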

<br>
<br>

## Export/Load Model in <font style="color:red;font-size:110%">TorchScript Format</font>

One common way to do inference with a trained model is to use TorchScript, an intermediate representation of a PyTorch model that can be run in Python as well as in a high performance environment like C++. <font size=4 color=red>TorchScript is actually the recommended model format for <b>scaled inference and deployment</b>.</font>

<div class="alert alert-block alert-info">

<font size=3 color=red><b>Note: </b></font>

Using the TorchScript format, you will be able to load the exported model and run inference without defining the model class.

</div>

<br>

### Export

<br>

### Load

<font color=maroon size=3>Remember that</font> you must call `model.eval()` to set dropout and batch normalization layers to `evaluation mode` before running inference. Failing to do this will yield inconsistent inference results.
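The export and load steps can be sketched as follows, assuming the model is convertible with `torch.jit.script`. The small `nn.Sequential` model and the file name `model_scripted.pt` are hypothetical placeholders.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model_scripted.pt")  # hypothetical file name

    # Export: compile the module to TorchScript and write it to disk.
    scripted = torch.jit.script(model)
    scripted.save(path)

    # Load: no Python class definition is needed on the loading side.
    loaded = torch.jit.load(path)

loaded.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    outputs_match = torch.allclose(model(x), loaded(x))
print(outputs_match)
```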

For more information on TorchScript, feel free to visit the dedicated <a href="https://pytorch.org/tutorials/beginner/Intro_to_TorchScript_tutorial.html" style="text-decoration:none;">tutorials</a>. You will get familiar with the tracing conversion and learn how to run a TorchScript module in a <a href="https://pytorch.org/tutorials/advanced/cpp_export.html" style="text-decoration:none;">C++ environment</a>.

<br>
<br>
<br>

# Saving & Loading a <font color=maroon><b>General Checkpoint</b></font> for Inference and/or Resuming Training

<br>

### Save

<br>

###  Load

<br>
<br>

When saving a general checkpoint, to be used <font color=maroon size=3>for either inference or resuming training</font>, you must save more than just the model’s ***`state_dict`***.

It is important to also save:

* the optimizer’s ***`state_dict`***, as this contains buffers and parameters that are updated as the model trains.

Other items that you may want to save are: 
* the `epoch` you left off on, 
* the `latest recorded training loss`, 
* external **`torch.nn.Embedding`** layers, etc.

<font color=maroon size=3>As a result, such a checkpoint is often 2~3 times larger than the model alone.</font>

To save multiple components, organize them in a dictionary and use `torch.save()` to serialize the dictionary. A common PyTorch convention is <font color=maroon size=3>to save these checkpoints using the **`.tar`** file extension.</font>

<br>

<font color=maroon size=3>Remember that</font> you must call `model.eval()` to set dropout and batch normalization layers to `evaluation mode` before running inference. Failing to do this will yield inconsistent inference results.

If you wish to resume training, call `model.train()` to ensure these layers are in training mode.
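A general checkpoint along these lines might look like the sketch below. The dictionary keys (`epoch`, `model_state_dict`, `optimizer_state_dict`, `loss`), the epoch value, and the file name `checkpoint.tar` are illustrative choices, not names required by PyTorch.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)  # hypothetical tiny model
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# One training step so that the optimizer holds momentum buffers worth saving.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.tar")  # hypothetical file name

    # Save: bundle every component into one dictionary.
    torch.save(
        {
            "epoch": 5,  # illustrative value
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss.item(),
        },
        path,
    )

    # Load: initialize model and optimizer first, then restore their states.
    new_model = nn.Linear(4, 2)
    new_optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
    checkpoint = torch.load(path)
    new_model.load_state_dict(checkpoint["model_state_dict"])
    new_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    epoch = checkpoint["epoch"]
    last_loss = checkpoint["loss"]

new_model.train()  # or new_model.eval() when the checkpoint is used for inference
print(epoch, last_loss)
```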

<br>
<br>
<br>

# Saving <font color=maroon><b>Multiple Models in One File</b></font>

<br>

### Save

<br>

### Load

<br>

When saving a model comprised of multiple `torch.nn.Modules`, such as a GAN, a sequence-to-sequence model, or an ensemble of models, you <font color=maroon>follow the same approach as when you are saving a general checkpoint</font>. In other words, save a dictionary of each model’s ***state_dict*** and corresponding optimizer. As mentioned before, you can save any other items that may aid you in resuming training by simply appending them to the dictionary.

A common PyTorch convention is <font color=maroon size=3>to save these checkpoints using the **`.tar`** file extension.</font>

To load the models, first initialize the models and optimizers, then load the dictionary locally using `torch.load()`. From here, you can easily access the saved items by simply querying the dictionary as you would expect.

<font color=maroon size=3>Remember that</font> you must call `model.eval()` to set dropout and batch normalization layers to `evaluation mode` before running inference. Failing to do this will yield inconsistent inference results.

If you wish to resume training, call `model.train()` to ensure these layers are in training mode.
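The same dictionary pattern extends to several modules. In the sketch below, `model_a` and `model_b` stand in for, say, a generator and a discriminator; all names, keys, and the file name are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical pair of models, each with its own optimizer.
model_a, model_b = nn.Linear(4, 2), nn.Linear(2, 1)
optimizer_a = optim.SGD(model_a.parameters(), lr=0.01)
optimizer_b = optim.SGD(model_b.parameters(), lr=0.02)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "models.tar")  # hypothetical file name

    # Save: one dictionary holding every state_dict.
    torch.save(
        {
            "model_a_state_dict": model_a.state_dict(),
            "model_b_state_dict": model_b.state_dict(),
            "optimizer_a_state_dict": optimizer_a.state_dict(),
            "optimizer_b_state_dict": optimizer_b.state_dict(),
        },
        path,
    )

    # Load: initialize the models and optimizers, then query the dictionary.
    new_a, new_b = nn.Linear(4, 2), nn.Linear(2, 1)
    new_opt_a = optim.SGD(new_a.parameters(), lr=0.1)
    new_opt_b = optim.SGD(new_b.parameters(), lr=0.1)
    checkpoint = torch.load(path)
    new_a.load_state_dict(checkpoint["model_a_state_dict"])
    new_b.load_state_dict(checkpoint["model_b_state_dict"])
    new_opt_a.load_state_dict(checkpoint["optimizer_a_state_dict"])
    new_opt_b.load_state_dict(checkpoint["optimizer_b_state_dict"])

# The restored optimizers carry the saved hyperparameters, not the ones
# they were constructed with.
restored_lrs = (new_opt_a.param_groups[0]["lr"], new_opt_b.param_groups[0]["lr"])
print(restored_lrs)
```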

<br>
<br>
<br>

# <font color=maroon><b>Warmstarting Model</b> Using Parameters from a Different Model</font>

<br>

### Save

<br>

### Load

<br>

<font size=3 color=maroon>Partially loading a model or loading a partial model are common scenarios when <font size=4><b>transfer learning</b></font> or <b>training a new complex model</b>.</font>

<font size=3> Leveraging trained parameters, even if only a few are usable, will help to warmstart the training process and hopefully help your model converge much faster than training from scratch.</font>

<font size=3>Whether you are loading from a partial ***state_dict***, which is missing some keys, or loading a ***state_dict*** with more keys than the model that you are loading into, you can set the `strict` argument to **False** in the `load_state_dict()` function <font color=maroon>to ignore non-matching keys</font>.</font>

<font size=3 color=maroon>If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the ***state_dict*** that you are loading to match the keys in the model that you are loading into.</font>
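Both techniques can be sketched with in-memory state dictionaries; in practice the source `state_dict` would usually come from `torch.load`. The layer names `features`, `head`, `classifier`, and `backbone` are hypothetical.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Source model whose trained parameters we want to reuse.
model_a = nn.Sequential(OrderedDict([
    ("features", nn.Linear(4, 8)),
    ("head", nn.Linear(8, 2)),
]))

# Target model: same "features" layer, but a differently named output layer.
model_b = nn.Sequential(OrderedDict([
    ("features", nn.Linear(4, 8)),
    ("classifier", nn.Linear(8, 3)),
]))

# strict=False skips keys that exist on only one side and reports them.
result = model_b.load_state_dict(model_a.state_dict(), strict=False)
print(result.missing_keys)     # keys model_b expected but did not receive
print(result.unexpected_keys)  # keys in the state_dict that model_b lacks

# Renaming keys: load the "features.*" parameters into a layer called "backbone".
model_c = nn.Sequential(OrderedDict([("backbone", nn.Linear(4, 8))]))
renamed = OrderedDict(
    (key.replace("features.", "backbone."), value)
    for key, value in model_a.state_dict().items()
    if key.startswith("features.")
)
model_c.load_state_dict(renamed)
```

Note that `strict=False` only tolerates missing or extra keys. A key that is present on both sides with a different tensor shape still raises an error, so such entries have to be removed or renamed beforehand.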

<br>
<br>
<br>

# Saving & Loading Model <font color=maroon><b>Across Devices</b></font>

<br>

## Save on GPU, Load on CPU

### Save

<br>

### Load

<br>

When loading a model on a CPU that was trained with a GPU, pass `torch.device('cpu')` to the `map_location` argument in the `torch.load()` function. In this case, the storages underlying the tensors are dynamically remapped to the CPU device using the `map_location` argument.
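A sketch of this case follows. So that it also runs on a machine without a GPU, the "training" device falls back to the CPU when CUDA is unavailable; that fallback, the tiny model, and the file name are assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Fallback to CPU is only here so the sketch runs without a GPU.
train_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(train_device)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "gpu_model.pt")  # hypothetical file name

    # Save (on the training machine).
    torch.save(model.state_dict(), path)

    # Load (on the CPU-only machine): remap all tensor storages to the CPU.
    cpu_device = torch.device("cpu")
    state_dict = torch.load(path, map_location=cpu_device)

cpu_model = nn.Linear(4, 2)
cpu_model.load_state_dict(state_dict)

loaded_devices = {tensor.device.type for tensor in state_dict.values()}
print(loaded_devices)
```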

<br>
<br>

## Save on GPU, Load on GPU

### Save

<br>

### Load

<br>

When loading a model on a GPU that was trained and saved on GPU, simply convert the initialized `model` to a CUDA optimized model using `model.to(torch.device('cuda'))`. 

Also, <font color=maroon>be sure to use the `.to(torch.device('cuda'))` function on all model inputs to prepare the data for the model.</font> 

<font color=maroon><b>Note that</b></font> calling `my_tensor.to(device)` returns a new copy of my_tensor on GPU. It does NOT overwrite `my_tensor`. Therefore, remember to manually overwrite tensors: `my_tensor = my_tensor.to(torch.device('cuda'))`.
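The sketch below illustrates these steps, including the difference between moving a module and moving a tensor. The CPU fallback exists only so the example runs without a GPU; the model and file name are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Fallback to CPU is only here so the sketch runs without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 2).to(device)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)
    state_dict = torch.load(path)

restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)
restored.to(device)  # modules are moved in place; no reassignment needed

inputs = torch.randn(3, 4)
inputs = inputs.to(device)  # tensors are not moved in place, so reassign

with torch.no_grad():
    outputs = restored(inputs)
print(outputs.device.type, tuple(outputs.shape))
```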

<br>
<br>

## Save on CPU, Load on GPU

### Save

<br>

### Load

<br>

When loading a model on a GPU that was trained and saved on CPU, set the ``map_location`` argument in the `torch.load()` function to ***cuda:device_id***. This loads the model to a given GPU device. 

Next, be sure to call `model.to(torch.device('cuda'))` to convert the model’s parameter tensors to CUDA tensors. 

Finally, be sure to use the `.to(torch.device('cuda'))` function on all model inputs to prepare the data for the CUDA optimized model. 


<font color=maroon><b>Note that</b></font> calling `my_tensor.to(device)` returns a new copy of my_tensor on GPU. It does NOT overwrite `my_tensor`. Therefore, remember to manually overwrite tensors: `my_tensor = my_tensor.to(torch.device('cuda'))`.
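These three steps might be written as in the sketch below. The device string `cuda:0` is one possible choice of `device_id`, and the code falls back to the CPU when no GPU is present so that it stays runnable; both are assumptions for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # trained and saved on the CPU

# "cuda:0" is an example device id; the CPU fallback keeps the sketch runnable.
target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "cpu_model.pt")  # hypothetical file name
    torch.save(model.state_dict(), path)

    # Step 1: load the tensors directly onto the target device.
    state_dict = torch.load(path, map_location=target)

restored = nn.Linear(4, 2)
restored.load_state_dict(state_dict)
restored.to(target)  # Step 2: convert the model's parameters to the target device

inputs = torch.randn(3, 4).to(target)  # Step 3: move the inputs as well
with torch.no_grad():
    outputs = restored(inputs)
print(outputs.device.type)
```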

<br>
<br>
<br>

# Saving **`torch.nn.DataParallel`** Models

### Save

<br>

### Load

<br>

`torch.nn.DataParallel` is a model wrapper that enables parallel GPU utilization. To save a `DataParallel` model generically, save the `model.module.state_dict()`. This way, you have the flexibility to load the model any way you want to any device you want.
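The effect of saving `model.module.state_dict()` rather than the wrapper's own `state_dict()` can be seen in the key names. The sketch below uses a hypothetical one-layer model and file name, and loads with `map_location="cpu"` so the result does not depend on which devices are present.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # hypothetical tiny model
parallel_model = nn.DataParallel(model)

# The wrapper prefixes every key with "module."; the inner module does not.
wrapped_keys = list(parallel_model.state_dict().keys())
plain_keys = list(parallel_model.module.state_dict().keys())
print(wrapped_keys, plain_keys)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "dp_model.pt")  # hypothetical file name

    # Save: store the inner module's parameters under their plain names.
    torch.save(parallel_model.module.state_dict(), path)

    # Load: works with an unwrapped model on any device.
    restored = nn.Linear(4, 2)
    restored.load_state_dict(torch.load(path, map_location="cpu"))
```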

<br>
<br>
<br>