# Saving and Loading Models

In this notebook, I'll show you how to save and load models with PyTorch. This matters because you'll often want to load a previously trained model to make predictions or to continue training on new data.

```python
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torchvision import datasets, transforms

import helper
import fc_model
```

```python
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
# Download and load the training data
trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Download and load the test data
testset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
```

Here we can see one of the images.

# Train a network

To keep this notebook concise, I moved the model architecture and training code from the last part into a module called `fc_model`. With it imported, we can create a fully-connected network with `fc_model.Network` and train it with `fc_model.train`. Once the model is trained, I'll use it to demonstrate how to save and load models.
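The source of `fc_model` isn't included here. As a rough sketch, assuming only the structure the notebook prints when the model is inspected (a `ModuleList` of hidden `Linear` layers, an `output` layer, and `Dropout(p=0.5)`), a compatible `Network` class might look like the following. The forward pass (ReLU and dropout after each hidden layer, log-softmax output to pair with `nn.NLLLoss`) is an assumption, not the module's actual code.

```python
import torch
from torch import nn
import torch.nn.functional as F


class Network(nn.Module):
    """Hypothetical stand-in for fc_model.Network; the real module isn't shown."""

    def __init__(self, input_size, output_size, hidden_layers, drop_p=0.5):
        super().__init__()
        # One Linear layer per consecutive pair of sizes, e.g. 784->512->256->128.
        sizes = [input_size] + list(hidden_layers)
        self.hidden_layers = nn.ModuleList(
            [nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.output = nn.Linear(sizes[-1], output_size)
        self.dropout = nn.Dropout(p=drop_p)

    def forward(self, x):
        # Assumed forward pass: ReLU, then dropout, after every hidden layer.
        for layer in self.hidden_layers:
            x = self.dropout(F.relu(layer(x)))
        # Log-probabilities, which is the input nn.NLLLoss expects.
        return F.log_softmax(self.output(x), dim=1)


model = Network(784, 10, [512, 256, 128])
log_probs = model(torch.randn(2, 784))
print(log_probs.shape)  # torch.Size([2, 10])
```

Because the submodules are registered in the same order and under the same names as the printed model, this sketch produces the same `state_dict` keys.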

```python
# Create the network, define the criterion and optimizer
model = fc_model.Network(784, 10, [512, 256, 128])
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```

```python
fc_model.train(model, trainloader, testloader, criterion, optimizer, epochs=2)
```

```
Epoch: 1/2..  Training Loss: 1.663..  Test Loss: 0.966..  Test Accuracy: 0.657
Epoch: 1/2..  Training Loss: 1.046..  Test Loss: 0.779..  Test Accuracy: 0.706
Epoch: 1/2..  Training Loss: 0.880..  Test Loss: 0.763..  Test Accuracy: 0.710
Epoch: 1/2..  Training Loss: 0.829..  Test Loss: 0.657..  Test Accuracy: 0.752
Epoch: 1/2..  Training Loss: 0.728..  Test Loss: 0.646..  Test Accuracy: 0.762
Epoch: 1/2..  Training Loss: 0.738..  Test Loss: 0.611..  Test Accuracy: 0.767
Epoch: 1/2..  Training Loss: 0.696..  Test Loss: 0.579..  Test Accuracy: 0.778
Epoch: 1/2..  Training Loss: 0.704..  Test Loss: 0.590..  Test Accuracy: 0.778
Epoch: 1/2..  Training Loss: 0.651..  Test Loss: 0.559..  Test Accuracy: 0.787
Epoch: 1/2..  Training Loss: 0.630..  Test Loss: 0.539..  Test Accuracy: 0.802
Epoch: 1/2..  Training Loss: 0.639..  Test Loss: 0.535..  Test Accuracy: 0.806
Epoch: 1/2..  Training Loss: 0.657..  Test Loss: 0.549..  Test Accuracy: 0.793
Epoch: 1/2..  Training Loss: 0.625..  Test Loss: 0.5
```

## Saving and loading networks

As you can imagine, it's impractical to train a network every time you need to use it. Instead, we can save a trained network and load it later, either to continue training or to make predictions.

A PyTorch model's parameters (and any registered buffers) are stored in its `state_dict`, an ordered dictionary that maps each parameter name to its tensor. Printing the keys shows that it contains a weight matrix and a bias vector for each of our layers.
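A quick way to confirm what ends up in a state dict is to print each entry's shape. This sketch uses a small hypothetical `nn.Sequential` model rather than `fc_model.Network`; the layer sizes are illustrative.

```python
import torch
from torch import nn

# Hypothetical small network; ReLU and Dropout have no learnable parameters.
net = nn.Sequential(
    nn.Linear(784, 128),
    nn.ReLU(),
    nn.Dropout(p=0.5),
    nn.Linear(128, 10),
)

state = net.state_dict()
for name, tensor in state.items():
    # nn.Linear stores weight as (out_features, in_features) and bias as (out_features,)
    print(name, tuple(tensor.shape))
```

It should list `0.weight`, `0.bias`, `3.weight` and `3.bias` only, because the ReLU and dropout layers contribute no entries. The tensors returned by `state_dict()` share memory with the live parameters, so take a copy (for example with `copy.deepcopy`) if you need a snapshot that won't change as training continues.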

```python
print(model)
print(model.state_dict().keys())
```

```
Network(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
  (output): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'hidden_layers.2.weight', 'hidden_layers.2.bias', 'output.weight', 'output.bias'])
```


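The simplest approach is to save the state dict directly with `torch.save`, here to a file named `checkpoint.pth`:

```python
torch.save(model.state_dict(), 'checkpoint.pth')
```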

`torch.load` then reads the state dict back from the file:

```python
state_dict = torch.load('checkpoint.pth')
```

Finally, `model.load_state_dict` copies the loaded tensors into the network:

```python
model.load_state_dict(state_dict)
```

```
<All keys matched successfully>
```
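The same save/load round trip can be checked end to end. This is a minimal sketch using a single hypothetical `nn.Linear` layer and an in-memory `io.BytesIO` buffer in place of `checkpoint.pth`; `map_location="cpu"` is an optional addition that lets tensors saved from a GPU load on a CPU-only machine.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)
source = nn.Linear(784, 10)

# torch.save accepts a file-like object as well as a path; the buffer
# stands in for 'checkpoint.pth' here.
buffer = io.BytesIO()
torch.save(source.state_dict(), buffer)
buffer.seek(0)  # rewind so torch.load reads from the start

state_dict = torch.load(buffer, map_location="cpu")

# A freshly initialized layer with the same shapes receives the saved values.
target = nn.Linear(784, 10)
result = target.load_state_dict(state_dict)
print(result.missing_keys, result.unexpected_keys)  # [] []

same = all(torch.equal(source.state_dict()[k], target.state_dict()[k]) for k in state_dict)
print(same)  # True
```

The value returned by `load_state_dict` is what the notebook displays as `<All keys matched successfully>`; its `missing_keys` and `unexpected_keys` lists are empty when every key lines up.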

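Loading a state dict only works if the model has exactly the same architecture as the checkpoint; otherwise `load_state_dict` raises an error. To rebuild the model reliably, save the architecture information in the checkpoint along with the state dict:

```python
checkpoint = {'input_size': 784,
              'output_size': 10,
              'hidden_layers': [each.out_features for each in model.hidden_layers],
              'state_dict': model.state_dict()}
torch.save(checkpoint, 'checkpoint.pth')
```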
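To see why the architecture is stored with the weights, here is a sketch of what happens when a state dict is loaded into a network of a different shape. `make_mlp` is a hypothetical helper and the layer sizes are illustrative; the key point is that `load_state_dict` raises a `RuntimeError` on a size mismatch, even though the parameter names agree.

```python
import torch
from torch import nn


def make_mlp(sizes):
    """Hypothetical helper: Linear layers for the given sizes, ReLU in between."""
    layers = []
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(n_in, n_out), nn.ReLU()]
    return nn.Sequential(*layers[:-1])  # no ReLU after the final layer


saved = make_mlp([784, 512, 256, 10]).state_dict()

# Same parameter names ("0.weight", "2.weight", ...), but different shapes.
mismatched = make_mlp([784, 400, 200, 10])
error = None
try:
    mismatched.load_state_dict(saved)
except RuntimeError as err:
    error = err
    print(str(err).splitlines()[0])

# Rebuilding from the original sizes, which the checkpoint now records, succeeds.
rebuilt = make_mlp([784, 512, 256, 10])
rebuilt.load_state_dict(saved)
```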

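A helper function can then rebuild the network from the saved sizes and load the weights into it:

```python
def load_checkpoint(filepath):
    checkpoint = torch.load(filepath)
    # Rebuild the network with the saved architecture, then load its weights
    model = fc_model.Network(checkpoint['input_size'],
                             checkpoint['output_size'],
                             checkpoint['hidden_layers'])
    model.load_state_dict(checkpoint['state_dict'])

    return model
```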
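`load_checkpoint` restores only the weights, which is enough for making predictions. To continue training from where you stopped, one common approach is to store the optimizer's `state_dict()` in the checkpoint as well, since Adam keeps per-parameter moment estimates. The following is a sketch under that assumption; `build_model`, the `epoch` key, and the small layer sizes are hypothetical.

```python
import io

import torch
from torch import nn, optim


def build_model():
    # Hypothetical small network with dropout, standing in for fc_model.Network.
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 3))


torch.manual_seed(0)
model = build_model()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# One dummy training step so Adam has per-parameter state to save.
loss = model(torch.randn(16, 4)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Hypothetical checkpoint layout: model weights plus optimizer state.
buffer = io.BytesIO()
torch.save({"epoch": 1,
            "state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict()}, buffer)
buffer.seek(0)
checkpoint = torch.load(buffer, map_location="cpu")

restored = build_model()
restored.load_state_dict(checkpoint["state_dict"])
restored_optimizer = optim.Adam(restored.parameters(), lr=0.001)
restored_optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Switch off dropout for predictions; call restored.train() before resuming training.
restored.eval()
x = torch.randn(5, 4)
with torch.no_grad():
    out1, out2 = restored(x), restored(x)
print(torch.equal(out1, out2))  # True: no dropout in eval mode
```

Because the network contains dropout, call `eval()` before making predictions and `train()` before resuming training; loading a state dict does not change the model's mode.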

Loading the checkpoint gives a network with the same layers and parameter names as the original:

```python
model2 = load_checkpoint('checkpoint.pth')
print(model2)
print(model2.state_dict().keys())
```

```
Network(
  (hidden_layers): ModuleList(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): Linear(in_features=512, out_features=256, bias=True)
    (2): Linear(in_features=256, out_features=128, bias=True)
  )
  (output): Linear(in_features=128, out_features=10, bias=True)
  (dropout): Dropout(p=0.5, inplace=False)
)
odict_keys(['hidden_layers.0.weight', 'hidden_layers.0.bias', 'hidden_layers.1.weight', 'hidden_layers.1.bias', 'hidden_layers.2.weight', 'hidden_layers.2.bias', 'output.weight', 'output.bias'])
```
