# Saving and Loading Models

In this bite-sized notebook, we'll go over how to save and load models. In general, the process is the same as for any PyTorch module.

```python
# Comment this line out if you are not running in Colab
!pip install gpytorch
```

```python
import math
import torch
import gpytorch
from matplotlib import pyplot as plt
```

## Saving a Simple Model

First, we define a GP model that we'd like to save. It is the same model as in our
<a href="../01_Exact_GPs/Simple_GP_Regression.ipynb">Simple GP Regression</a> tutorial.

```python
# 100 evenly spaced inputs in [0, 1] with noisy sine-wave targets
train_x = torch.linspace(0, 1, 100)
train_y = torch.sin(train_x * (2 * math.pi)) + torch.randn(train_x.size()) * 0.2
```

```python
# We will use the simplest form of GP model, exact inference
class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = ExactGPModel(train_x, train_y, likelihood)
```

### Change Model State

To demonstrate model saving, we change the hyperparameters below from their default values. For more information on what is happening here, see our tutorial notebook on <a href="Hyperparameters.ipynb">Initializing Hyperparameters</a>.

```python
model.covar_module.outputscale = 1.2
model.covar_module.base_kernel.lengthscale = 2.2
```

### Getting Model State

To get the full state of a GPyTorch model, simply call `state_dict` as you would on any PyTorch model. Note that the state dict contains **raw** parameter values. This is because the raw values are the actual `torch.nn.Parameter` objects that GPyTorch learns; the hyperparameters you set (such as `outputscale`) are obtained from them through constraint transforms. Again, see our notebook on hyperparameters for more information.

```python
model.state_dict()
```

```
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]]))])
```
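To see how these raw values relate to the hyperparameters set earlier, here is a small sketch. It assumes GPyTorch's default `Positive` constraint for the output scale and lengthscale, which maps a raw value r to softplus(r) = log(1 + e^r). Inverting softplus for 1.2 and 2.2 should reproduce the raw values printed above. The sketch uses plain PyTorch, and `inverse_softplus` is a hypothetical helper written for illustration. Within GPyTorch, the same mapping is available through each constraint object, for example `model.covar_module.raw_outputscale_constraint.transform(...)`.

```python
import torch
import torch.nn.functional as F

def inverse_softplus(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: inverse of softplus(r) = log(1 + exp(r)), i.e. r = log(exp(x) - 1)
    return torch.log(torch.expm1(x))

# Hyperparameter values assigned earlier: outputscale = 1.2, lengthscale = 2.2
actual = torch.tensor([1.2, 2.2])

raw = inverse_softplus(actual)
print(raw)  # tensor([0.8416, 2.0826])

# Applying softplus maps the raw values back to the hyperparameters
recovered = F.softplus(raw)
print(torch.allclose(recovered, actual))  # True
```

Storing the unconstrained raw value lets the optimizer move freely over all real numbers while the hyperparameter itself always stays positive.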

### Saving Model State

The state dictionary above contains all of the model's trainable parameters, including those of the likelihood. Therefore, we can save it to a file as follows:

```python
torch.save(model.state_dict(), 'model_state.pth')
```
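Since the GP's state dict already includes the likelihood (note the `likelihood.noise_covar.raw_noise` key), saving `model.state_dict()` alone is enough to restore the model. If you also want to resume training later, a common PyTorch pattern is to bundle the optimizer state into a single checkpoint dictionary. The sketch below assumes a plain `torch.nn.Linear` as a hypothetical stand-in for the GP model. The checkpoint keys (`"model"`, `"optimizer"`, `"iteration"`) are arbitrary names chosen for illustration, and an in-memory `io.BytesIO` buffer replaces a file path.

```python
import io
import torch

# Hypothetical stand-in for the GP model: any torch.nn.Module is saved the same way
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

# Take one optimization step so the optimizer has internal state (Adam moments) to save
loss = model(torch.ones(4, 1)).pow(2).sum()
loss.backward()
optimizer.step()

checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "iteration": 1,  # hypothetical bookkeeping field
}

# An in-memory buffer keeps the sketch self-contained; a path such as 'checkpoint.pth' works too
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)
restored = torch.load(buffer)

# Rebuild the objects first, then load their saved states
new_model = torch.nn.Linear(1, 1)
new_model.load_state_dict(restored["model"])
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.1)
new_optimizer.load_state_dict(restored["optimizer"])
```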

### Loading Model State

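Next, we load this state into a new model and demonstrate that the parameters were restored correctly. The state dict does not include the training data, so the new model is constructed with the same `train_x` and `train_y`.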

```python
state_dict = torch.load('model_state.pth')
model = ExactGPModel(train_x, train_y, likelihood)  # Create a new GP model

model.load_state_dict(state_dict)
```

```
<All keys matched successfully>
```

```python
model.state_dict()
```

```
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.8416)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[2.0826]]))])
```
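The `<All keys matched successfully>` message is how the return value of `load_state_dict` prints when no keys are missing or unexpected. With the default `strict=True`, any mismatch raises a `RuntimeError` instead, which is why the new model must have the same structure as the saved one. The sketch below illustrates both cases, plus `strict=False`, using small `torch.nn.Sequential` modules as hypothetical stand-ins for GP models; `make_source` is a hypothetical helper.

```python
import torch

def make_source():
    # Hypothetical three-layer module standing in for a saved model
    return torch.nn.Sequential(torch.nn.Linear(1, 2), torch.nn.ReLU(), torch.nn.Linear(2, 1))

state_dict = make_source().state_dict()

# Same architecture: every key matches
result = make_source().load_state_dict(state_dict)
print(result)  # <All keys matched successfully>

# Different architecture: the strict default raises a RuntimeError
smaller = torch.nn.Sequential(torch.nn.Linear(1, 2))
try:
    smaller.load_state_dict(state_dict)
    strict_error = ""
except RuntimeError as err:
    strict_error = str(err)
print("Unexpected key(s)" in strict_error)  # True

# strict=False loads the keys that match and reports the rest
partial = smaller.load_state_dict(state_dict, strict=False)
print(partial.missing_keys)     # []
print(partial.unexpected_keys)  # ['2.weight', '2.bias']
```

Checking `missing_keys` and `unexpected_keys` after a non-strict load is a cheap way to confirm that only the intended parts of a model were restored.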

## A More Complex Example

Next, we demonstrate the same principle on a more complex exact GP that includes a simple feed-forward neural network feature extractor as part of the model.

```python
class GPWithNNFeatureExtractor(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood):
        super(GPWithNNFeatureExtractor, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        
        self.feature_extractor = torch.nn.Sequential(
            torch.nn.Linear(1, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
            torch.nn.Linear(2, 2),
            torch.nn.BatchNorm1d(2),
            torch.nn.ReLU(),
        )
    
    def forward(self, x):
        x = self.feature_extractor(x)
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

# initialize likelihood and model
likelihood = gpytorch.likelihoods.GaussianLikelihood()
model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)
```

### Getting Model State

In the next cell, we once again print the model state via `model.state_dict()`. The state is now substantially larger, because it also includes the neural network's parameters and BatchNorm statistics. Nevertheless, saving and loading work exactly as before.
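Not every entry in a state dict is a trainable parameter. `BatchNorm1d` layers also register buffers (`running_mean`, `running_var`, `num_batches_tracked`), and `state_dict()` includes these alongside the parameters. The sketch below uses the first three layers of the feature extractor as a standalone illustration and checks that, for this module, the state dict keys are exactly the union of `named_parameters()` and `named_buffers()`.

```python
import torch

# First three layers of the feature extractor, as a standalone module
feature_extractor = torch.nn.Sequential(
    torch.nn.Linear(1, 2),
    torch.nn.BatchNorm1d(2),
    torch.nn.ReLU(),
)

param_keys = {name for name, _ in feature_extractor.named_parameters()}
buffer_keys = {name for name, _ in feature_extractor.named_buffers()}
state_keys = set(feature_extractor.state_dict().keys())

print(sorted(param_keys))   # ['0.bias', '0.weight', '1.bias', '1.weight']
print(sorted(buffer_keys))  # ['1.num_batches_tracked', '1.running_mean', '1.running_var']
print(state_keys == param_keys | buffer_keys)  # True
```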

```python
model.state_dict()
```

```
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('feature_extractor.0.weight', tensor([[0.4711],
                      [0.6515]])),
             ('feature_extractor.0.bias', tensor([-0.5939,  0.0134])),
             ('feature_extractor.1.weight', tensor([1., 1.])),
             ('feature_extractor.1.bias', tensor([0., 0.])),
             ('feature_extractor.1.running_mean', tensor([0., 0.])),
             ('feature_extractor.1.running_var', tensor([1., 1.])),
             ('feature_extractor.1.num_batches_tracked', tensor(0)),
             ('feature_extractor.3.weight', tensor([[ 0.7055,  0.3596],
                      [-0.3409, -0.4008]])),
             ('feature_extractor.3.bias', tensor([0.0920, 0.3870])),
             ('feature_extractor.4.weight', tensor([1., 1.])),
             ...
```

```python
torch.save(model.state_dict(), 'my_gp_with_nn_model.pth')
state_dict = torch.load('my_gp_with_nn_model.pth')
model = GPWithNNFeatureExtractor(train_x, train_y, likelihood)
model.load_state_dict(state_dict)
```

```
<All keys matched successfully>
```
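Because the BatchNorm buffers are part of the state dict, a loaded model reproduces the saved model's running statistics. These statistics are only used in evaluation mode, so a loaded model should be switched to `.eval()` before making predictions; for the full GP, the usual `model.eval()` call also applies recursively to the feature extractor. The sketch below checks this round trip on a standalone copy of the feature extractor, with `make_extractor` as a hypothetical helper. It feeds inputs of shape `(100, 1)`, since `torch.nn.Linear(1, 2)` expects a trailing feature dimension.

```python
import io
import torch

torch.manual_seed(0)

def make_extractor():
    # Hypothetical helper: same layers as the feature extractor above
    return torch.nn.Sequential(
        torch.nn.Linear(1, 2), torch.nn.BatchNorm1d(2), torch.nn.ReLU(),
        torch.nn.Linear(2, 2), torch.nn.BatchNorm1d(2), torch.nn.ReLU(),
    )

# Inputs of shape (100, 1): Linear(1, 2) expects a trailing feature dimension
x = torch.linspace(0, 1, 100).unsqueeze(-1)

extractor = make_extractor()
extractor.train()
extractor(x)  # a forward pass in train mode updates the BatchNorm running statistics

# Save and reload the state dict through an in-memory buffer
buffer = io.BytesIO()
torch.save(extractor.state_dict(), buffer)
buffer.seek(0)
restored = make_extractor()
restored.load_state_dict(torch.load(buffer))

# Running statistics are only used in eval mode, so switch both modules before comparing
extractor.eval()
restored.eval()
with torch.no_grad():
    same = torch.allclose(extractor(x), restored(x))
print(same)  # True
```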

```python
model.state_dict()
```

```
OrderedDict([('likelihood.noise_covar.raw_noise', tensor([0.])),
             ('mean_module.constant', tensor([0.])),
             ('covar_module.raw_outputscale', tensor(0.)),
             ('covar_module.base_kernel.raw_lengthscale', tensor([[0.]])),
             ('feature_extractor.0.weight', tensor([[0.4711],
                      [0.6515]])),
             ('feature_extractor.0.bias', tensor([-0.5939,  0.0134])),
             ('feature_extractor.1.weight', tensor([1., 1.])),
             ('feature_extractor.1.bias', tensor([0., 0.])),
             ('feature_extractor.1.running_mean', tensor([0., 0.])),
             ('feature_extractor.1.running_var', tensor([1., 1.])),
             ('feature_extractor.1.num_batches_tracked', tensor(0)),
             ('feature_extractor.3.weight', tensor([[ 0.7055,  0.3596],
                      [-0.3409, -0.4008]])),
             ('feature_extractor.3.bias', tensor([0.0920, 0.3870])),
             ('feature_extractor.4.weight', tensor([1., 1.])),
             ...
```