# Saving, loading and running model predictions in PyTorch

In [4]:
import torch
import torchvision.models as models

PyTorch models keep their learned parameters in an internal state dictionary called `state_dict`. We can persist these parameters to disk by passing `model.state_dict()` to `torch.save`.

In [5]:
# Instantiate VGG16 initialised with the 'IMAGENET1K_V1' pretrained weights,
# then write only its parameters (the state_dict) to disk.
# (This cell could be wrapped in a helper function.)
model = models.vgg16(weights='IMAGENET1K_V1')
torch.save(obj=model.state_dict(), f='model_weights.pth')

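To load model weights, we first create an instance of the same model class and then load the saved parameters into it with `load_state_dict()`. This is necessary because the class defines the structure of the network, while the saved file contains only the parameter values.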

In [6]:
# load model (can convert into a function)
# Create a fresh VGG16 instance first so the network structure exists,
# then copy the saved parameters into it.
model = models.vgg16()
# weights_only=True restricts unpickling to tensors and other plain data,
# which is the recommended, safer way to read a state_dict file.
model.load_state_dict(torch.load('model_weights.pth', weights_only=True))

# Make sure we call this before inferencing:
# it sets dropout and batch normalisation layers to evaluation mode
# so that predictions are consistent.
model.eval()

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

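We can also save the model's class structure together with the trained parameters by passing the `model` object itself (rather than `model.state_dict()`) to `torch.save`. This approach serialises the model with Python's pickle module, so the class definition must be available when the file is loaded again.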

In [7]:
# Save the whole model object (structure + parameters), then load it back.
torch.save(model, 'model.pth')
# A fully pickled model contains arbitrary Python objects, so weights_only=False
# must be passed explicitly: PyTorch >= 2.6 defaults to weights_only=True and
# would refuse to load this file.
# NOTE: unpickling can execute arbitrary code - only load files you trust.
model = torch.load('model.pth', weights_only=False)