# Saving and Loading in PyTorch

We have mentioned previously that we can load models that were pretrained on some other task and utilize those models for transfer learning. But there are many more reasons to save your PyTorch parameters. You might for example need to interrupt the training process and would like continue from your current checkpoint without restarting the whole training process at a later point. Whatever your reasons, it is essential to know what you need to save/load and how you can accomplish that.

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
import torchvision.transforms as T
from torchvision.datasets import MNIST

## torch.save() and torch.load()

We will ease in and save a couple of simple tensors, before we discuss how we can save whole dictionaries of parameters. 

Leat's assume we have a simple tensor of ones, that we would like to save.

In [2]:
t = torch.ones(3, 3)
print(t)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


PyTorch `torch.save` function to deal with such a task. The function takes two arguments: `obj` which is the object we would like to save and `f`, which in our case is going to be a path.

In [3]:
torch.save(t, '../temp/t.pt')

Essentially the save function uses the python pickle module to serialize an object. We save the object using the `pt` extension. This is a common convention that PyTorch practicioners use and we will use this convention throughout this block.

Similarly we can use the `torch.load()` function to load objects.

In [4]:
new_t = torch.load('../temp/t.pt')
print(new_t)

tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])


## Saving weights and dictionaries

When we deal with neural networks, the amount of information that we need to save increases. To demonstrate that we will pretrain a convolutional neural network and save all the parameters we need to continue training at a later time.

We use the basic convolutional neural network from the last section as our basis, therefore there is nothing new in terms of the model architecture.

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

import torchvision.transforms as T
from torchvision.datasets import MNIST

We set some the usual required parameters.

In [6]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE=32

And create the necessary datasets and dataloaders.

In [7]:
train_val_dataset = MNIST(root="../datasets/", train=True, download=True, transform=T.ToTensor())
test_dataset = MNIST(root="../datasets/", train=False, download=False, transform=T.ToTensor())

In [8]:
train_idxs, val_idxs = train_test_split(
                                range(len(train_val_dataset)),
                                test_size=0.1)

In [9]:
labels = train_val_dataset.targets.numpy()
indices = list(range(len(train_val_dataset)))
train_idxs, val_idxs = train_test_split(indices,
                                              test_size=0.1,
                                              stratify=labels)

In [10]:
train_dataset = Subset(train_val_dataset, train_idxs)
val_dataset = Subset(train_val_dataset, val_idxs)

In [11]:
train_dataloader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4, drop_last=True)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4, drop_last=False)

In [12]:
class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=2, padding=0),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        
        shape = self.conv_layers(torch.zeros(1, 1, 28, 28)).shape
        num_features = torch.prod(torch.tensor(shape)).item()
        
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_features, 100),
            nn.ReLU(),
            nn.Linear(100, 10)
        )
        
    def get_features(self, input_features, layer=0):
        with torch.inference_mode():
            x = input_features
            for i in range(layer+1):
                x = self.conv_layers[i](x)
            return x
        
    def forward(self, features):
        features = self.conv_layers(features)
        features = self.fc_layers(features)
        return features

In [13]:
def track_performance(dataloader, model, criterion):
    # switch to evaluation mode
    model.eval()
    num_samples = 0
    num_correct = 0
    loss_sum = 0
    
    # no need to calculate gradients
    with torch.inference_mode():
        for batch_idx, (features, labels) in enumerate(dataloader):
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)
            logits = model(features)
            
            predictions = logits.max(dim=1)[1]
            num_correct += (predictions == labels).sum().item()
            
            loss = criterion(logits, labels)
            loss_sum += loss.cpu().item()
            num_samples += len(features)
    
    # we return the average loss and the accuracy
    return loss_sum/num_samples, num_correct/num_samples

In [14]:
def train(num_epochs, train_dataloader, val_dataloader, model, criterion, optimizer, scheduler=None):
    history = {"train_loss": [], "val_loss": [], "train_acc": [], "val_acc": []}
    model.to(DEVICE)
    for epoch in range(num_epochs):
        for batch_idx, (features, labels) in enumerate(train_dataloader):
            model.train()
            features = features.to(DEVICE)
            labels = labels.to(DEVICE)
            
            # Empty the gradients
            optimizer.zero_grad()
            
            # Forward Pass
            logits = model(features)
            
            # Calculate Loss
            loss = criterion(logits, labels)
            
            # Backward Pass
            loss.backward()
            
            # Gradient Descent
            optimizer.step()
            
        train_loss, train_acc = track_performance(train_dataloader, model, criterion)
        val_loss, val_acc = track_performance(val_dataloader, model, criterion)
        
        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)

    
        print(f'Epoch: {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.5f} | Val Loss: {val_loss:.5f} | Train Acc: {train_acc:.3f} | Val Acc: {val_acc:.3f}')
        
        if scheduler:
          scheduler.step(val_loss)
    return history            
            

In [15]:
model = Model()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       factor=0.1,
                                                       mode='max',
                                                       patience=3,
                                                       verbose=True)

In [16]:
history = train(10, train_dataloader, val_dataloader, model, criterion, optimizer, scheduler)

Epoch:  1/10 | Train Loss: 0.00411 | Val Loss: 0.00421 | Train Acc: 0.959 | Val Acc: 0.954
Epoch:  2/10 | Train Loss: 0.00336 | Val Loss: 0.00369 | Train Acc: 0.967 | Val Acc: 0.963
Epoch:  3/10 | Train Loss: 0.00466 | Val Loss: 0.00598 | Train Acc: 0.957 | Val Acc: 0.945
Epoch:  4/10 | Train Loss: 0.00276 | Val Loss: 0.00364 | Train Acc: 0.974 | Val Acc: 0.967
Epoch:  5/10 | Train Loss: 0.00349 | Val Loss: 0.00421 | Train Acc: 0.966 | Val Acc: 0.960
Epoch:  6/10 | Train Loss: 0.00343 | Val Loss: 0.00434 | Train Acc: 0.966 | Val Acc: 0.959
Epoch:  7/10 | Train Loss: 0.00330 | Val Loss: 0.00429 | Train Acc: 0.969 | Val Acc: 0.961
Epoch 00007: reducing learning rate of group 0 to 1.0000e-03.
Epoch:  8/10 | Train Loss: 0.00145 | Val Loss: 0.00254 | Train Acc: 0.985 | Val Acc: 0.976
Epoch:  9/10 | Train Loss: 0.00125 | Val Loss: 0.00240 | Train Acc: 0.987 | Val Acc: 0.978
Epoch: 10/10 | Train Loss: 0.00120 | Val Loss: 0.00242 | Train Acc: 0.988 | Val Acc: 0.978


The model, optimizer and scheduler implement a dictionary called `state_dict`. Given that dictionary, we can reconstruct our objects. The model state_dict for example has to contain the layers and the corresponding weights and biases.

In [17]:
for key, value in model.state_dict().items():
    print(f'{key:>20}: {value.size()}')

conv_layers.0.weight: torch.Size([16, 1, 2, 2])
  conv_layers.0.bias: torch.Size([16])
conv_layers.3.weight: torch.Size([32, 16, 2, 2])
  conv_layers.3.bias: torch.Size([32])
conv_layers.6.weight: torch.Size([64, 32, 2, 2])
  conv_layers.6.bias: torch.Size([64])
  fc_layers.1.weight: torch.Size([100, 256])
    fc_layers.1.bias: torch.Size([100])
  fc_layers.3.weight: torch.Size([10, 100])
    fc_layers.3.bias: torch.Size([10])


All we have to do is to save those dictionaries and restore them at a later point. The `torch.save` function can save dictionaries without any trouble, so that is what we are going to do.

In [18]:
torch.save({'weights': model.state_dict(),
            'optim': optimizer.state_dict(),
            'scheduler': scheduler.state_dict()},
            f='../temp/testparams.pt')

Now let' assume at a later point we would like to load the old state of our model for inference.

We load the pt file.

In [19]:
params = torch.load('../temp/testparams.pt')

We get the model state from the dictionary.

In [20]:
model_state = params['weights']

We create a brand new model with randomized weights.

In [21]:
new_model = Model().to(DEVICE)

And finally we load the state dictionary into the new model.

In [22]:
new_model.load_state_dict(model_state)

<All keys matched successfully>

The accuracy on the test data is above 97%.

In [23]:
track_performance(test_dataloader, new_model, criterion)

(0.0021452212269392474, 0.9786)

Torchvision provides relatively easy access to pretrained computer vision models, so that you don't have to load the dictionaries yourself, but sooner or later you will have to deal with the process of saving and loading parameters.