# Saving and Loading Models

This document provides solutions to a variety of use cases regarding the saving and loading of PyTorch models. Feel free to read the whole document, or just skip to the code you need for a desired use case.

When it comes to saving and loading models, there are three core functions to be familiar with:

1. *torch.save*: Saves a serialized object to disk. This function uses Python’s *pickle* utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
2. *torch.load*: Uses pickle's unpickling facilities to deserialize pickled object files into memory. Its map_location argument also lets you choose the device to load the data onto (see Saving & Loading Model Across Devices).
3. *torch.nn.Module.load_state_dict*: Loads a model's parameter dictionary from a deserialized state_dict. For more information on state_dict, see What's a state_dict?
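The three functions can be exercised together in a few lines. The sketch below is illustrative only: it assumes a throwaway nn.Linear layer and an in-memory io.BytesIO buffer in place of a file path (torch.save and torch.load accept file-like objects as well as paths).

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# 1. torch.save: serialize a model's state_dict into an in-memory buffer
#    (a file path such as "model.pth" works the same way).
source = nn.Linear(4, 2)
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)

# 2. torch.load: deserialize it; map_location="cpu" places every tensor
#    on the CPU regardless of the device it was saved from.
buffer.seek(0)
state_dict = torch.load(buffer, map_location="cpu")

# 3. load_state_dict: copy the deserialized tensors into a fresh instance.
target = nn.Linear(4, 2)
target.load_state_dict(state_dict)

print(torch.equal(source.weight, target.weight))  # True
```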

```python
import torch
import torch.nn as nn
import torch.nn.functional as F  # used by TheModelClass.forward
import torch.optim as optim
```

```python
# Define model: a small CNN classifier.
# The 16 * 5 * 5 flatten size assumes 3x32x32 input images.
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten feature maps for the linear layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

## What's a state_dict?

A **state_dict** is simply a Python dictionary object that maps each parameter and buffer name in a model (for example conv1.weight) to its tensor. Optimizer objects (torch.optim) also have a **state_dict**, which contains information about the optimizer's state, as well as the hyperparameters used. Note that only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as batch normalization's running_mean) have entries in the model's state_dict.

Because state_dict objects are Python dictionaries, they can be easily saved, updated, altered, and restored, adding a great deal of modularity to PyTorch models and optimizers.

```python
# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```
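To see which layers actually contribute entries, the sketch below uses a hypothetical three-layer model rather than TheModelClass: a Linear layer (parameters), a ReLU (nothing to store), and a BatchNorm1d (parameters plus buffers). It also shows the two top-level keys of an optimizer's state_dict.

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Hypothetical model mixing layers with and without stored state.
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.BatchNorm1d(3))
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

model_keys = list(model.state_dict().keys())
print(model_keys)
# ['0.weight', '0.bias', '2.weight', '2.bias',
#  '2.running_mean', '2.running_var', '2.num_batches_tracked']
# ReLU (index 1) has no entries; BatchNorm adds its running statistics.

opt_state = optimizer.state_dict()
print(sorted(opt_state.keys()))                # ['param_groups', 'state']
print(opt_state["state"])                      # {} until the first optimizer.step()
print(opt_state["param_groups"][0]["params"])  # [0, 1, 2, 3]: parameters stored by index
```

Note that the optimizer refers to parameters by integer index rather than by name, so it should be restored into an optimizer built over the same parameters in the same order.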

## Saving & Loading for Inference

When saving a model for inference, it is only necessary to save the trained model's learned parameters. Saving the model's state_dict with the torch.save() function will give you the most flexibility for restoring the model later, which is why it is the recommended method for saving models.

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.
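The effect of model.eval() is easy to observe with dropout. This sketch assumes a hypothetical two-layer model with p=0.5 dropout: in training mode each forward pass zeroes a random subset of activations, while in evaluation mode dropout becomes a no-op and repeated passes agree.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical model: dropout is active in train() mode, disabled in eval() mode.
model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5))
x = torch.randn(2, 8)

model.train()
train_a, train_b = model(x), model(x)  # usually differ: different units are dropped

model.eval()
with torch.no_grad():
    eval_a, eval_b = model(x), model(x)

print(torch.equal(eval_a, eval_b))  # True: dropout is skipped in evaluation mode
```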

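```python
import os

# Make sure the target directory exists before saving
os.makedirs('data', exist_ok=True)
torch.save(model.state_dict(), 'data/infer-model.pth')
```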

```python
loaded_model = TheModelClass()
loaded_model.load_state_dict(torch.load('data/infer-model.pth'))
loaded_model.eval()
```
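A quick way to confirm a save/load round trip worked is to compare the outputs of the original and restored models on the same input, with both in evaluation mode. The sketch below assumes a hypothetical make_model() factory and an in-memory buffer instead of 'data/infer-model.pth'.

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

def make_model():
    # Hypothetical architecture; the same class must be rebuilt before loading.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(0.2), nn.Linear(8, 2))

trained = make_model()
buffer = io.BytesIO()  # stands in for a path such as 'data/infer-model.pth'
torch.save(trained.state_dict(), buffer)

buffer.seek(0)
restored = make_model()  # starts with different random weights...
restored.load_state_dict(torch.load(buffer, map_location="cpu"))  # ...now overwritten

# Both models in evaluation mode so dropout does not randomize the outputs.
trained.eval()
restored.eval()

x = torch.randn(3, 4)
with torch.no_grad():
    outputs_match = torch.equal(trained(x), restored(x))
print(outputs_match)  # True
```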

## Saving & Loading for Warmstarting

Partially loading a model or loading a partial model are common scenarios in transfer learning or when training a new complex model. Leveraging trained parameters, even if only a few are usable, helps warmstart the training process and can help your model converge faster than training from scratch.

Whether you are loading from a partial state_dict, which is missing some keys, or loading a state_dict with more keys than the model that you are loading into, you can set the strict argument to False in the load_state_dict() function to ignore non-matching keys.

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
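Renaming keys is ordinary dictionary manipulation. The sketch below assumes two hypothetical models whose layers have identical shapes but different attribute names (encoder versus backbone), and a hypothetical rename_prefix helper that rewrites the key prefix so a strict load succeeds.

```python
import torch
import torch.nn as nn

class Source(nn.Module):  # hypothetical pretrained model
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)

class Target(nn.Module):  # hypothetical new model with a different attribute name
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 4)

def rename_prefix(state_dict, old, new):
    """Hypothetical helper: copy state_dict, replacing the key prefix `old` with `new`."""
    return {
        (new + key[len(old):] if key.startswith(old) else key): value
        for key, value in state_dict.items()
    }

source_sd = Source().state_dict()  # keys: encoder.weight, encoder.bias
renamed = rename_prefix(source_sd, "encoder.", "backbone.")

target = Target()
target.load_state_dict(renamed)  # strict loading works once the keys line up
print(sorted(renamed))  # ['backbone.bias', 'backbone.weight']
```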

```python
modelB = TheModelClass()
modelB.load_state_dict(torch.load('data/infer-model.pth'), strict=False)
```
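With strict=False, load_state_dict returns a named tuple listing the missing_keys and unexpected_keys, which is worth checking so that a silently skipped layer does not go unnoticed. Note that strict=False only tolerates absent or extra keys: a key present in both dictionaries with a different tensor shape still raises a RuntimeError. The sketch below therefore assumes a hypothetical pretrained model whose 3-way head is filtered out before loading into a new model with a 5-way head.

```python
import torch
import torch.nn as nn

# Hypothetical models: same first layer, different output heads.
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
new_model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 5))

# Keep only the first layer's entries: a partial state_dict.
partial = {k: v for k, v in pretrained.state_dict().items() if k.startswith("0.")}

result = new_model.load_state_dict(partial, strict=False)
print(result.missing_keys)     # ['2.weight', '2.bias']: the new head keeps its init
print(result.unexpected_keys)  # []
```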

## Saving & Loading for Checkpointing

When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model's state_dict. It is important to also save the optimizer's state_dict, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external torch.nn.Embedding layers, etc.

To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension, although torch.save() does not depend on the extension (the example below uses .pth).

To load the items, first initialize the model and optimizer, then load the dictionary locally using torch.load(). From here, you can easily access the saved items by simply querying the dictionary as you would expect.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results. If you wish to resume training, call model.train() to ensure these layers are in training mode.
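The save and load steps described above can be wrapped in a pair of small functions. This is a sketch only: save_checkpoint and load_checkpoint are hypothetical helper names, the dictionary keys simply mirror the ones used in this notebook, and an in-memory buffer stands in for a checkpoint file.

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

def save_checkpoint(f, model, optimizer, epoch, loss):
    """Hypothetical helper: bundle everything needed to resume into one dict."""
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        f,
    )

def load_checkpoint(f, model, optimizer, device="cpu"):
    """Hypothetical helper: restore model and optimizer in place, return (epoch, loss)."""
    checkpoint = torch.load(f, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["loss"]

# Usage: initialize the model and optimizer first, then restore into them.
model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
buffer = io.BytesIO()
save_checkpoint(buffer, model, optimizer, epoch=25, loss=1.0)

buffer.seek(0)
new_model = nn.Linear(4, 2)
new_optimizer = optim.SGD(new_model.parameters(), lr=0.001, momentum=0.9)
epoch, loss = load_checkpoint(buffer, new_model, new_optimizer)
new_model.train()  # resuming training; call new_model.eval() for inference instead
print(epoch, loss)  # 25 1.0
```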

```python
epoch = 25
loss = 1.0

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    }, 'data/checkpoint.pth')
```

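```python
checkpoint = torch.load('data/checkpoint.pth')

cp_model = TheModelClass()
cp_model.load_state_dict(checkpoint['model_state_dict'])

# Recreate the optimizer with the same configuration that was used when saving
cp_optimizer = optim.SGD(cp_model.parameters(), lr=0.001, momentum=0.9)
cp_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']
loss = checkpoint['loss']

# Evaluation mode for inference; call cp_model.train() instead to resume training
cp_model.eval()
```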

```python
print(epoch, loss)
```
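Why the optimizer's state_dict belongs in the checkpoint becomes concrete with SGD momentum. The sketch below assumes a hypothetical one-layer regression model and uses deep copies of the two state_dicts as stand-ins for torch.save/torch.load. Resuming with the saved optimizer state reproduces the uninterrupted training step exactly; resuming with a freshly built optimizer discards the momentum buffer and produces different weights.

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

def sgd_step(model, optimizer, x, y):
    # One ordinary training step with a mean-squared-error loss.
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

x, y = torch.randn(16, 4), torch.randn(16, 1)

model = nn.Linear(4, 1)
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
sgd_step(model, optimizer, x, y)  # the optimizer now holds a momentum buffer

# "Checkpoint": deep copies, since state_dict tensors share memory with the live objects.
model_sd = copy.deepcopy(model.state_dict())
optim_sd = copy.deepcopy(optimizer.state_dict())

# Reference: keep training the original objects for one more step.
sgd_step(model, optimizer, x, y)

# Resume WITH the saved optimizer state.
resumed = nn.Linear(4, 1)
resumed.load_state_dict(model_sd)
resumed_opt = optim.SGD(resumed.parameters(), lr=0.1, momentum=0.9)
resumed_opt.load_state_dict(optim_sd)
sgd_step(resumed, resumed_opt, x, y)

# Resume WITHOUT it: the momentum buffer starts from scratch.
fresh = nn.Linear(4, 1)
fresh.load_state_dict(model_sd)
fresh_opt = optim.SGD(fresh.parameters(), lr=0.1, momentum=0.9)
sgd_step(fresh, fresh_opt, x, y)

print(torch.allclose(resumed.weight, model.weight))  # True
print(torch.allclose(fresh.weight, model.weight))    # False
```

The gap comes from SGD's update rule: the continued run steps by lr * (0.9 * g1 + g2), while the fresh optimizer steps by lr * g2 alone.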