What is a state_dict in PyTorch
===============================
In PyTorch, the learnable parameters (i.e. weights and biases) of a
``torch.nn.Module`` model are contained in the model’s parameters
(accessed with ``model.parameters()``). A ``state_dict`` is simply a
Python dictionary object that maps each layer to its parameter tensor.

Introduction
------------

Note that only layers with learnable parameters (convolutional layers,
linear layers, etc.) and registered buffers (e.g. batchnorm’s ``running_mean``)
have entries in the model’s ``state_dict``. Optimizer objects
(``torch.optim``) also have a ``state_dict``, which contains information
about the optimizer’s state, as well as the hyperparameters used.
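
To make the distinction between learnable parameters and registered buffers concrete, the following minimal sketch inspects a single batch normalization layer. The layer size and the variable names are arbitrary choices for illustration.

```python
import torch.nn as nn

# A lone batchnorm layer holds both learnable parameters and registered buffers.
bn = nn.BatchNorm1d(4)

# named_parameters() lists only the learnable tensors.
param_names = [name for name, _ in bn.named_parameters()]

# state_dict() lists the learnable tensors plus the registered buffers.
state_keys = list(bn.state_dict().keys())

print(param_names)
print(state_keys)
```

The second list contains buffer entries such as ``running_mean`` and ``running_var`` that are absent from the first, which is why a ``state_dict`` is the right thing to save rather than ``model.parameters()``.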

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in Net.forward
import torch.optim as optim

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = Net()
print(net)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)


In [4]:
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum = 0.9)

In [5]:
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in net.state_dict():
    print(param_tensor, "\t", net.state_dict()[param_tensor].size())

print()

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])

Optimizer's state_dict:
state 	 {}
param_groups 	 [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]}]


# Saving and loading model for inference
## save and load state_dict (recommended)

A common PyTorch convention is to save models using either a .pt or .pth file extension.

Remember that you must call model.eval() to set dropout and batch normalization layers to evaluation mode before running inference. Failing to do this will yield inconsistent inference results.
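
As a small illustration of why the mode matters, the sketch below runs the same input through a dropout layer in training mode and in evaluation mode. The dropout probability and tensor size are assumptions chosen only for the demonstration.

```python
import torch
import torch.nn as nn

drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

# Training mode: elements are randomly zeroed and the survivors are rescaled.
drop.train()
train_out = drop(x)

# Evaluation mode: dropout is a no-op, so the output equals the input.
drop.eval()
eval_out = drop(x)

print("zeros in training mode:", int((train_out == 0).sum()))
print("identical in eval mode:", torch.equal(eval_out, x))
```

In training mode the output changes from call to call, which is the source of the inconsistent inference results mentioned above.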

In [11]:
PATH = 'save.pth'

In [12]:
# save
torch.save(net.state_dict(), PATH)

# load
net.load_state_dict(torch.load(PATH))
# net.eval()

<All keys matched successfully>

## save and load the entire model

Saving a model in this way will save the entire module using Python’s pickle module. The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model is saved. This is because pickle does not save the model class itself; it saves a path to the file containing the class, which is used at load time. As a result, your code can break in various ways when used in other projects or after refactors.

In [13]:
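# save
torch.save(net, PATH)

# load : Model class must be defined somewhere
# Note: since PyTorch 2.6, torch.load defaults to weights_only=True, so unpickling
# a full module needs weights_only=False (only do this for files you trust).
net = torch.load(PATH, weights_only=False)
# net.eval()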

## saving and loading a general checkpoint for inference and resuming training

When saving a general checkpoint, to be used for either inference or resuming training, you must save more than just the model’s state_dict. It is important to also save the optimizer’s state_dict, as this contains buffers and parameters that are updated as the model trains. Other items that you may want to save are the epoch you left off on, the latest recorded training loss, external torch.nn.Embedding layers, etc. As a result, such a checkpoint is often 2~3 times larger than the model alone.
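
The following minimal sketch shows the kind of optimizer state that would be lost if only the model were saved. It assumes SGD with momentum on a tiny linear model; the model size, learning rate, and random input are illustrative only.

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)
opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Before any update, the optimizer has no per-parameter state.
entries_before = len(opt.state_dict()["state"])

# One training step creates a momentum buffer for each parameter.
loss = model(torch.randn(8, 3)).pow(2).mean()
loss.backward()
opt.step()

entries_after = len(opt.state_dict()["state"])
print(entries_before, entries_after)
```

After the step, the optimizer's ``state`` holds one entry per parameter tensor (here the weight and the bias). Restoring these buffers lets training resume with the same momentum instead of restarting from zero.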

To save multiple components, organize them in a dictionary and use torch.save() to serialize the dictionary. A common PyTorch convention is to save these checkpoints using the .tar file extension.

If you wish to resume training, call model.train() to ensure that dropout and batch normalization layers are in training mode. For inference, call model.eval() instead.

In [21]:
# save
epoch = 1
loss = 0.0
torch.save({
            'epoch': epoch,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            }, PATH)

#load 
net = Net()
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum = 0.9)

checkpoint = torch.load(PATH)
net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

# net.eval()
# # - or -
# net.train()

## saving multiple models in one file

Use this approach when saving a model composed of multiple torch.nn.Modules, such as a GAN, a sequence-to-sequence model, or an ensemble of models.

In [27]:
# save
netA = Net()
optimizerA = optim.SGD(netA.parameters(), lr = 0.001, momentum = 0.9)  # optimize netA's own parameters
netB = Net()
optimizerB = optim.SGD(netB.parameters(), lr = 0.001, momentum = 0.9)  # optimize netB's own parameters
torch.save({
            'modelA_state_dict': netA.state_dict(),
            'modelB_state_dict': netB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            }, PATH)

# load
checkpoint = torch.load(PATH)
netA.load_state_dict(checkpoint['modelA_state_dict'])
netB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

# netA.eval()
# netB.eval()
# # - or -
# netA.train()
# netB.train()

## Warmstarting Model Using Parameters from a Different Model

Partially loading a model or loading a partial model are common scenarios when transfer learning or training a new complex model. Leveraging trained parameters, even if only a few are usable, will help to warmstart the training process and hopefully help your model converge much faster than training from scratch.

Whether you are loading from a partial state_dict, which is missing some keys, or loading a state_dict with more keys than the model that you are loading into, you can set the strict argument to False in the load_state_dict() function to ignore non-matching keys.
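
As a rough sketch of what non-strict loading reports, the example below loads a one-layer state_dict into a two-layer model. The layer names ``fc1`` and ``fc2`` and the layer sizes are hypothetical.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Source model has only fc1; the target model has fc1 and fc2.
small = nn.Sequential(OrderedDict([("fc1", nn.Linear(4, 4))]))
large = nn.Sequential(OrderedDict([
    ("fc1", nn.Linear(4, 4)),
    ("fc2", nn.Linear(4, 2)),
]))

# strict=False loads the matching keys and reports the rest instead of raising.
result = large.load_state_dict(small.state_dict(), strict=False)

print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```

The returned object lists which keys were skipped, so it is worth checking rather than discarding. Note that ``strict=False`` only relaxes key matching: a key that is present in both but has a different tensor shape still raises an error.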

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
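
One way to rename keys is to rebuild the dictionary before loading it. This is a minimal sketch; the prefixes ``encoder`` and ``backbone`` are hypothetical names used only to show the pattern.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Same layer shape, but registered under different names in each model.
source = nn.Sequential(OrderedDict([("encoder", nn.Linear(4, 2))]))
target = nn.Sequential(OrderedDict([("backbone", nn.Linear(4, 2))]))

# Rewrite each key so it matches the name the target model expects.
renamed = OrderedDict(
    (key.replace("encoder.", "backbone.", 1), value)
    for key, value in source.state_dict().items()
)

result = target.load_state_dict(renamed)
print(list(renamed.keys()))
```

Because every renamed key now matches, the default strict loading succeeds and no keys are reported as missing or unexpected.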

In [28]:
# save
torch.save(netA.state_dict(), PATH)

# load
netB.load_state_dict(torch.load(PATH), strict = False)

<All keys matched successfully>

## saving and loading models across devices

## Save on GPU, Load on CPU
When loading a model on a CPU that was trained with a GPU, pass torch.device('cpu') to the map_location argument in the torch.load() function. In this case, the storages underlying the tensors are dynamically remapped to the CPU device using the map_location argument.
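
To check where the loaded tensors end up, the sketch below saves a small state_dict to a temporary file and loads it with ``map_location``. The temporary path and the tiny linear model are assumptions for illustration; on a machine without a GPU the file is simply saved from the CPU.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(2, 2)

# Hypothetical scratch location for this demonstration.
path = os.path.join(tempfile.mkdtemp(), "weights.pth")
torch.save(model.state_dict(), path)

# map_location remaps every stored tensor to the requested device on load.
state = torch.load(path, map_location=torch.device("cpu"))
devices = {tensor.device.type for tensor in state.values()}
print(devices)
```

Inspecting the devices of the loaded tensors before calling ``load_state_dict()`` is a quick way to confirm the remapping did what was intended.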

In [32]:
# save
torch.save(netA.state_dict(), PATH)

# load
device = torch.device('cpu')
netA.load_state_dict(torch.load(PATH, map_location = device))

<All keys matched successfully>

## Save on GPU, Load on GPU

When loading a model on a GPU that was trained and saved on GPU, simply convert the initialized model to a CUDA optimized model using model.to(torch.device('cuda')). Also, be sure to use the .to(torch.device('cuda')) function on all model inputs to prepare the data for the model. Note that calling my_tensor.to(device) returns a new copy of my_tensor on GPU. It does NOT overwrite my_tensor. Therefore, remember to manually overwrite tensors: my_tensor = my_tensor.to(torch.device('cuda')).
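
The difference between moving a tensor and moving a module can be sketched as follows. The example falls back to the CPU when no GPU is available, so the device choice here is an assumption that keeps the snippet runnable anywhere.

```python
import torch
import torch.nn as nn

# Use the GPU when present; otherwise stay on the CPU so the code still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

my_tensor = torch.zeros(3)
moved = my_tensor.to(device)  # the result must be kept; my_tensor is unchanged

model = nn.Linear(3, 1)
returned = model.to(device)   # modules are moved in place and return themselves

print(my_tensor.device, moved.device)
print(returned is model)
```

The original ``my_tensor`` stays on the CPU regardless of the target device, whereas ``model.to(device)`` modifies the module itself, which is why the tensor needs reassignment and the model does not.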

In [30]:
# save 
torch.save(net.state_dict(), PATH)

# load 
device = torch.device("cuda")
net.load_state_dict(torch.load(PATH))
net.to(device)

## Save on CPU, Load on GPU

When loading a model on a GPU that was trained and saved on CPU, set the map_location argument in the torch.load() function to cuda:device_id. This loads the model to a given GPU device. Next, be sure to call model.to(torch.device('cuda')) to convert the model’s parameter tensors to CUDA tensors. Finally, be sure to use the .to(torch.device('cuda')) function on all model inputs to prepare the data for the CUDA optimized model.

In [31]:
# save 
torch.save(net.state_dict(), PATH)

# load 
device = torch.device("cuda")
net.load_state_dict(torch.load(PATH, map_location = "cuda:0"))  # Choose whatever GPU device number you want
net.to(device)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

## Saving torch.nn.DataParallel Models

torch.nn.DataParallel is a model wrapper that enables parallel GPU utilization. To save a DataParallel model generically, save the model.module.state_dict(). This way, you have the flexibility to load the model any way you want to any device you want.
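
The reason for saving ``model.module.state_dict()`` becomes clearer when comparing the keys of the wrapper with those of the wrapped module. This is a minimal sketch with an arbitrary linear layer.

```python
import torch
import torch.nn as nn

plain = nn.Linear(3, 2)
wrapped = nn.DataParallel(nn.Linear(3, 2))

# The wrapper prefixes every key with "module."; the inner module does not.
wrapped_keys = list(wrapped.state_dict().keys())
inner_keys = list(wrapped.module.state_dict().keys())
print(wrapped_keys)
print(inner_keys)

# The inner state_dict loads directly into a model that is not wrapped.
plain.load_state_dict(wrapped.module.state_dict())
```

Saving the inner state_dict keeps the checkpoint free of the ``module.`` prefix, so it can later be loaded into either a plain model or a newly wrapped one.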

In [35]:
# save
# torch.save(net.module.state_dict(), PATH)

# load
# Load to whatever device you want
