<h1>Save and Load PyTorch Models</h1>

<h3 style="color: yellow;">There are two ways to save the model: (a) saving the entire model with <span style="color: red;">torch.save(model, 'model.pth')</span> or (b) saving only the model parameters, which is the preferred method, using <span style="color: red;">torch.save(model.state_dict(), 'model_weights.pth')</span>.</h3>

<h3 style="color: yellow;">Beyond just the model parameters, it's possible to save attributes such as the number of epochs, loss values, optimizer states, and more.</h3>

<h3 style="color: yellow;">There are three primary methods to keep in mind: (a) <span style="color: red;">torch.save(arg, PATH)</span>, (b) <span style="color: red;">torch.load(PATH)</span>, and (c) <span style="color: red;">model.load_state_dict(arg)</span>.</h3>

<h3 style="color: yellow;"><span style="color: red;">torch.save(arg, PATH)</span> can use tensors, models, or dictionaries as parameters for saving. It employs the Python Pickle module to serialize and save objects.</h3>


In [30]:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self,n_feat, *args, **kwargs):
        super(Model,self).__init__(*args, **kwargs)
        self.n_feat=n_feat
        self.linear=nn.Linear(n_feat,1)
    def forward(self,x):
        x=torch.sigmoid(self.linear(x))
        return x


# Lazy method for saving and loading model
import torch
import torch.nn as nn

class DummyModel(nn.Module):
    def __init__(self, n_input_features, *args, **kwargs):
        super(DummyModel, self).__init__()
        self.linear=nn.Linear(n_input_features,1)

    def forward(self, x):
        x=torch.sigmoid(self.linear(x))
        return x

model = DummyModel(3)

torch.save(model,'lazy_model.pth')
model=torch.load('lazy_model.pth')
model.eval()

DummyModel(
  (linear): Linear(in_features=3, out_features=1, bias=True)
)

In [31]:
for parm in model.parameters():
    print(parm)


Parameter containing:
tensor([[-0.4445,  0.0294, -0.0880]], requires_grad=True)
Parameter containing:
tensor([0.2566], requires_grad=True)


In [32]:
# The recommended method for saving and loading model by using "state_dict"
import torch
import torch.nn as nn

class DummyModel1(nn.Module):
    def __init__(self,n_input_features):
        super(DummyModel1, self).__init__()
        self.liner=nn.Linear(n_input_features,1)
    def forward(self, x):
        x=torch.sigmoid(self.linear(x))
        return x
    
# Saving the model
model = DummyModel1(3)
for param in model.parameters():
    print(param)
    
torch.save(model.state_dict(), 'modle.pth')        
torch.save(model.state_dict(), 'model_state_dict.pth')

# Loading the model
model.load_state_dict(torch.load('model_state_dict.pth'))
model.eval()  # It's a good practice to call eval() after loading to set dropout and batch normalization layers to evaluation mode.
# During training, BatchNorm uses the batch's mean and variance. However, during inference, it uses the moving average of the mean and variance, which was computed during training
# During inference (or evaluation), we want to utilize the full network without dropping any units, so dropout is turned off.


Parameter containing:
tensor([[0.0770, 0.5616, 0.5434]], requires_grad=True)
Parameter containing:
tensor([-0.5496], requires_grad=True)


DummyModel1(
  (liner): Linear(in_features=3, out_features=1, bias=True)
)

In [33]:
# We loaded the parameters of the same model architecture into a new model instance.
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[0.0770, 0.5616, 0.5434]], requires_grad=True)
Parameter containing:
tensor([-0.5496], requires_grad=True)


In [36]:
# Print state_dict
print(model.state_dict())

OrderedDict([('liner.weight', tensor([[ 0.3238, -0.1527, -0.4669]])), ('liner.bias', tensor([-0.3910]))])


In [37]:
# Saving the checkpoint during training

import torch
import torch.nn as nn

class DummyModel1(nn.Module):
    def __init__(self,n_input_features):
        super(DummyModel1, self).__init__()
        self.liner=nn.Linear(n_input_features,1)
    def forward(self, x):
        x=torch.sigmoid(self.linear(x))
        return x
    
# Saving the model
model = DummyModel1(3)
LR=0.01
optimizer=torch.optim.SGD(model.parameters(),lr=LR)

for param in model.parameters():
    print(param)

print('')
print(optimizer.state_dict())   # We will notice her that the optimizer has its own state_dict


Parameter containing:
tensor([[-0.3572, -0.3858, -0.0631]], requires_grad=True)
Parameter containing:
tensor([-0.4825], requires_grad=True)

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}


In [20]:
# During the training process we can save the checkpoint at a certain epoch
checkpoint={
    'epoch':60,
    'model_state':model.state_dict(),
    'optim_state':optimizer.state_dict()    
}
torch.save(checkpoint,'checkpoint.pth')

In [38]:
# Load the checkpoint
loaded_checkpoint=torch.load('checkpoint.pth')
epoch=loaded_checkpoint['epoch']
epoch

60

In [40]:
# What if we want to continue the training process from the checkpoint?

model=DummyModel1(n_input_features=3)
optimizer=torch.optim.SGD(model.parameters(),lr=0) # LR is zero because we will load the optimizer state from the checkpoint to continue the training process
model.load_state_dict(loaded_checkpoint['model_state'])
optimizer.load_state_dict(loaded_checkpoint['optim_state'])

print(model.state_dict())
print(optimizer.state_dict())

OrderedDict([('liner.weight', tensor([[-0.0370,  0.1912,  0.5491]])), ('liner.bias', tensor([-0.5027]))])
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}


In [41]:
# If the training and saving processes were completed on a CPU, we can use the aforementioned implementation.
# If the training was done on a GPU and we want to load on a CPU, we should proceed as follows:

device=torch.device('cuda')
model.to(device)
torch.save(model.state_dict(),'model.pth')

#loading ...

device=torch.device('cpu')
model=DummyModel1(n_input_features=3)
model.load_state_dict(torch.load('model.pth',map_location=device))

<All keys matched successfully>

In [None]:
# If the training has completed on GPU, and want to load in a GPU, we follow:
device=torch.device('cuda')
model.to(device)
torch.save(model.state_dict(),'model.pth')

# Loading ...

model=DummyModel1(n_input_features=3)
model.load_state_dict(torch.load('model.pth'))
model.to(device)

In [None]:
# If the training has completed on CPU, and want to load in a GPU, we follow:
torch.save(model.state_dict(),'model.pth')

device=torch.device('cuda')
model=DummyModel(n_input_features=3)
model.load(torch.load('model.pth',map_location=device))
model.to(device)


In [46]:
# Save and load model during training
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

In [47]:
# Cuda
device= torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [48]:
# Simple CNN
class SCNN(nn.Module):
    def __init__(self,n_channels,n_classes):
        super(SCNN,self).__init__()
        self.n_channels=n_channels
        self.n_classes=n_classes
        self.conv1=nn.Conv2d(in_channels=self.n_channels,out_channels=8,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.max_pool=nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.conv2=nn.Conv2d(in_channels=8,out_channels=16,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.fc=nn.Linear(16*7*7, self.n_classes)
        
    def forward(self,x):
        out=self.max_pool(F.relu(self.conv1(x)))
        out=self.max_pool(F.relu(self.conv2(out)))
        out=out.reshape(out.shape[0],-1)
        out=self.fc(out)
        return out   

In [51]:
# Constants and hyperparameters
IN_CHANNELS=1
NUM_CLASSES=10
LR=0.001
BATCH_SIZE=128
EPOCHS=5


In [52]:

# Datset and dataloader
train_dataset=MNIST(root='dataset/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset=MNIST(root='dataset/',train=False,transform=transforms.ToTensor(),download=True)
train_loader=DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader=DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=True)

In [56]:
iter_=iter(train_loader)
images,labels=next(iter_)
images  [0].shape

torch.Size([1, 28, 28])

In [12]:
# Model initialization
model=SCNN(IN_CHANNELS,NUM_CLASSES).to(device)
loss=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=LR)

In [13]:
# training loop
def save_checkpoint(state,file_path='my_checkpoint.pth.tar'):
    print('Saving checkpoint')
    torch.save(state,file_path)


checkpoint={'model_state': model.state_dict(),'optimizer_state':optimizer.state_dict()}



for epoch in range(EPOCHS):
    checkpoint={'state_dict':model.state_dict(),'optimizer':optimizer.state_dict()}
    if epoch%2==0:  # Save every two epochs
        save_checkpoint(checkpoint)
        
    for i, (data,labels) in enumerate(train_loader):
        data=data.to(device)
        labels=labels.to(device)
        prediction=model(data)
        loss_=loss(prediction,labels)
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        
        if i%20==0:
            print(f'{epoch}/{EPOCHS} | step: {i}/ {len(train_loader)} | loss: {loss_.item():.4f}')   
        


Saving checkpoint
0/5 | step: 0/ 469 | loss: 2.3037
0/5 | step: 20/ 469 | loss: 2.0658


0/5 | step: 40/ 469 | loss: 1.4196
0/5 | step: 60/ 469 | loss: 0.7532
0/5 | step: 80/ 469 | loss: 0.5039
0/5 | step: 100/ 469 | loss: 0.5457
0/5 | step: 120/ 469 | loss: 0.3653
0/5 | step: 140/ 469 | loss: 0.4000
0/5 | step: 160/ 469 | loss: 0.3868
0/5 | step: 180/ 469 | loss: 0.2893
0/5 | step: 200/ 469 | loss: 0.2670
0/5 | step: 220/ 469 | loss: 0.2309
0/5 | step: 240/ 469 | loss: 0.1280
0/5 | step: 260/ 469 | loss: 0.2090
0/5 | step: 280/ 469 | loss: 0.2064
0/5 | step: 300/ 469 | loss: 0.2114
0/5 | step: 320/ 469 | loss: 0.2540
0/5 | step: 340/ 469 | loss: 0.1429
0/5 | step: 360/ 469 | loss: 0.2036
0/5 | step: 380/ 469 | loss: 0.1336
0/5 | step: 400/ 469 | loss: 0.1237
0/5 | step: 420/ 469 | loss: 0.0952
0/5 | step: 440/ 469 | loss: 0.2018
0/5 | step: 460/ 469 | loss: 0.1668
1/5 | step: 0/ 469 | loss: 0.2461
1/5 | step: 20/ 469 | loss: 0.1150
1/5 | step: 40/ 469 | loss: 0.0900
1/5 | step: 60/ 469 | loss: 0.1047
1/5 | step: 80/ 469 | loss: 0.0783
1/5 | step: 100/ 469 | loss: 0.1295
1

In [14]:
# check accuracy on training and test to see how good our model is
def check_accuracy(loader,model):
    if loader.dataset.train:
        print('Checking accuracy on training data')
    else:
        print('Checking accuracy on test data')
    num_correct=0
    num_samples=0
    model.eval()
    
    with torch.no_grad():
        for data,label in loader:
            data=data.to(device)
            label=label.to(device)
            prediction=model(data)
            _,pred=prediction.max(1)
            num_correct+=(pred==label).sum()
            num_samples+=pred.size(0)
        print(f'Obtained {num_correct}/{num_samples} with accuracy: '
            f'{float(num_correct)/float(num_samples)*100:.2f}')
    model.train()
    
    
check_accuracy(train_loader,model)
check_accuracy(test_loader,model)
            

Checking accuracy on training data


Obtained 59008/60000 with accuracy: 98.35
Checking accuracy on test data
Obtained 9819/10000 with accuracy: 98.19


In [None]:
# Return the checkpoint with the minimal loss

def save_checkpoint(state, file_path='my_checkpoint.pth.tar'):
    print('Saving checkpoint')
    torch.save(state, file_path)

min_loss = float('inf')  # Initialize with a large value. When you pass 'inf' to float(), it gives you the floating-point representation of positive infinity.

for epoch in range(EPOCHS):
    losses = []  # List to store batch-wise losses for an epoch
    
    for i, (data, labels) in enumerate(train_loader):
        data = data.to(device)
        labels = labels.to(device)
        predictions = model(data)
        
        loss_ = loss(predictions, labels)
        losses.append(loss_.item())  # Store the batch-wise loss
    
        
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        
        if i % 20 == 0:
            print(f'{epoch}/{EPOCHS} | step: {i}/{len(train_loader)} | loss: {loss_.item():.4f}')
    
    epoch_loss = sum(losses) / len(losses)  # Average loss for the epoch
    
    
    # Save the model with the minimal loss every two epochs
    if epoch % 2 == 0 and epoch_loss < min_loss:
        min_loss = epoch_loss
        checkpoint = {
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        save_checkpoint(checkpoint)


<h1> Saving and Loading the saved checkpoint .....</h1>


In [15]:
# Model initialization
model=SCNN(IN_CHANNELS,NUM_CLASSES).to(device)
loss=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=LR)

In [16]:
# Training loop

load_model=True

def load_checkpoint(checkpoint):
    print('Loading checkpoint..')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])    

    
if load_model:
    load_checkpoint(torch.load('my_checkpoint.pth.tar'))    


for epoch in range(EPOCHS):
    checkpoint={'state_dict':model.state_dict(),'optimizer':optimizer.state_dict()}
    if epoch%2==0:  # Save every two epochs
        save_checkpoint(checkpoint)
        
        
        
    for i, (data,labels) in enumerate(train_loader):
        data=data.to(device)
        labels=labels.to(device)
        prediction=model(data)
        loss_=loss(prediction,labels)
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        
        if i%20==0:
            print(f'{epoch}/{EPOCHS} | step: {i}/ {len(train_loader)} | loss: {loss_.item():.4f}')   
        


Loading checkpoint..
Saving checkpoint
0/5 | step: 0/ 469 | loss: 0.0264
0/5 | step: 20/ 469 | loss: 0.1308
0/5 | step: 40/ 469 | loss: 0.0552
0/5 | step: 60/ 469 | loss: 0.0406
0/5 | step: 80/ 469 | loss: 0.1265
0/5 | step: 100/ 469 | loss: 0.0795
0/5 | step: 120/ 469 | loss: 0.0361
0/5 | step: 140/ 469 | loss: 0.0317
0/5 | step: 160/ 469 | loss: 0.1040
0/5 | step: 180/ 469 | loss: 0.0540
0/5 | step: 200/ 469 | loss: 0.0936
0/5 | step: 220/ 469 | loss: 0.0170
0/5 | step: 240/ 469 | loss: 0.0578
0/5 | step: 260/ 469 | loss: 0.0617
0/5 | step: 280/ 469 | loss: 0.0465
0/5 | step: 300/ 469 | loss: 0.0837
0/5 | step: 320/ 469 | loss: 0.0685
0/5 | step: 340/ 469 | loss: 0.0447
0/5 | step: 360/ 469 | loss: 0.1036
0/5 | step: 380/ 469 | loss: 0.0657
0/5 | step: 400/ 469 | loss: 0.0842
0/5 | step: 420/ 469 | loss: 0.0689
0/5 | step: 440/ 469 | loss: 0.0327
0/5 | step: 460/ 469 | loss: 0.1045
1/5 | step: 0/ 469 | loss: 0.0757
1/5 | step: 20/ 469 | loss: 0.0607
1/5 | step: 40/ 469 | loss: 0.1234