<h1>Save and Load Pytorch Model</h1>

<h3 style="color: yellow;">Saving the model can be achieved in two ways: (a) by saving the entire model (<span style="color: red;">torch.save(model, 'model.pth')</span>), or (b) by saving only the model parameters, which is the recommended method (<span style="color: red;">torch.save(model.state_dict(), 'model_weights.pth')</span>).</h3>

<h3 style="color: yellow;">In addition to the model parameters, we can also save attributes such as epochs, loss, optimizer states, and more.</h3>


In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision
from torchvision.datasets import MNIST
import torchvision.transforms as transforms

  warn(


In [2]:
# Cuda
device= torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [3]:
# Simple CNN
class SCNN(nn.Module):
    def __init__(self,n_channels,n_classes):
        super(SCNN,self).__init__()
        self.n_channels=n_channels
        self.n_classes=n_classes
        self.conv1=nn.Conv2d(in_channels=self.n_channels,out_channels=8,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.max_pool=nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.conv2=nn.Conv2d(in_channels=8,out_channels=16,kernel_size=(3,3),stride=(1,1),padding=(1,1))
        self.fc=nn.Linear(16*7*7, self.n_classes)
        
    def forward(self,x):
        out=self.max_pool(F.relu(self.conv1(x)))
        out=self.max_pool(F.relu(self.conv2(out)))
        out=out.reshape(out.shape[0],-1)
        out=self.fc(out)
        return out

        

In [10]:
# Constants and hyperparameters
IN_CHANNELS=1
NUM_CLASSES=10
LR=0.001
BATCH_SIZE=128
EPOCHS=5

In [11]:

# Datset and dataloader
train_dataset=MNIST(root='dataset/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset=MNIST(root='dataset/',train=False,transform=transforms.ToTensor(),download=True)
train_loader=DataLoader(dataset=train_dataset,batch_size=BATCH_SIZE,shuffle=True)
test_loader=DataLoader(dataset=test_dataset,batch_size=BATCH_SIZE,shuffle=True)

In [12]:
# Model initialization
model=SCNN(IN_CHANNELS,NUM_CLASSES).to(device)
loss=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=LR)

In [13]:
# training loop
def save_checkpoint(state,file_path='my_checkpoint.pth.tar'):
    print('Saving checkpoint')
    torch.save(state,file_path)

for epoch in range(EPOCHS):
    checkpoint={'state_dict':model.state_dict(),'optimizer':optimizer.state_dict()}
    if epoch%2==0:  # Save every two epochs
        save_checkpoint(checkpoint)
        
        
        
    for i, (data,labels) in enumerate(train_loader):
        data=data.to(device)
        labels=labels.to(device)
        prediction=model(data)
        loss_=loss(prediction,labels)
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        
        if i%20==0:
            print(f'{epoch}/{EPOCHS} | step: {i}/ {len(train_loader)} | loss: {loss_.item():.4f}')   
        


Saving checkpoint
0/5 | step: 0/ 469 | loss: 2.3037
0/5 | step: 20/ 469 | loss: 2.0658


0/5 | step: 40/ 469 | loss: 1.4196
0/5 | step: 60/ 469 | loss: 0.7532
0/5 | step: 80/ 469 | loss: 0.5039
0/5 | step: 100/ 469 | loss: 0.5457
0/5 | step: 120/ 469 | loss: 0.3653
0/5 | step: 140/ 469 | loss: 0.4000
0/5 | step: 160/ 469 | loss: 0.3868
0/5 | step: 180/ 469 | loss: 0.2893
0/5 | step: 200/ 469 | loss: 0.2670
0/5 | step: 220/ 469 | loss: 0.2309
0/5 | step: 240/ 469 | loss: 0.1280
0/5 | step: 260/ 469 | loss: 0.2090
0/5 | step: 280/ 469 | loss: 0.2064
0/5 | step: 300/ 469 | loss: 0.2114
0/5 | step: 320/ 469 | loss: 0.2540
0/5 | step: 340/ 469 | loss: 0.1429
0/5 | step: 360/ 469 | loss: 0.2036
0/5 | step: 380/ 469 | loss: 0.1336
0/5 | step: 400/ 469 | loss: 0.1237
0/5 | step: 420/ 469 | loss: 0.0952
0/5 | step: 440/ 469 | loss: 0.2018
0/5 | step: 460/ 469 | loss: 0.1668
1/5 | step: 0/ 469 | loss: 0.2461
1/5 | step: 20/ 469 | loss: 0.1150
1/5 | step: 40/ 469 | loss: 0.0900
1/5 | step: 60/ 469 | loss: 0.1047
1/5 | step: 80/ 469 | loss: 0.0783
1/5 | step: 100/ 469 | loss: 0.1295
1

In [14]:
# check accuracy on training and test to see how good our model is
def check_accuracy(loader,model):
    if loader.dataset.train:
        print('Checking accuracy on training data')
    else:
        print('Checking accuracy on test data')
    num_correct=0
    num_samples=0
    model.eval()
    
    with torch.no_grad():
        for data,label in loader:
            data=data.to(device)
            label=label.to(device)
            prediction=model(data)
            _,pred=prediction.max(1)
            num_correct+=(pred==label).sum()
            num_samples+=pred.size(0)
        print(f'Obtained {num_correct}/{num_samples} with accuracy: '
            f'{float(num_correct)/float(num_samples)*100:.2f}')
    model.train()
    
    
check_accuracy(train_loader,model)
check_accuracy(test_loader,model)
            

Checking accuracy on training data


Obtained 59008/60000 with accuracy: 98.35
Checking accuracy on test data
Obtained 9819/10000 with accuracy: 98.19


<h1>Loading the saved checkpoint .....</h1>


In [15]:
# Model initialization
model=SCNN(IN_CHANNELS,NUM_CLASSES).to(device)
loss=nn.CrossEntropyLoss()
optimizer=optim.Adam(model.parameters(),lr=LR)

In [16]:
# Training loop

load_model=True

def load_checkpoint(checkpoint):
    print('Loading checkpoint..')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    
if load_model:
    load_checkpoint(torch.load('my_checkpoint.pth.tar'))    


for epoch in range(EPOCHS):
    checkpoint={'state_dict':model.state_dict(),'optimizer':optimizer.state_dict()}
    if epoch%2==0:  # Save every two epochs
        save_checkpoint(checkpoint)
        
        
        
    for i, (data,labels) in enumerate(train_loader):
        data=data.to(device)
        labels=labels.to(device)
        prediction=model(data)
        loss_=loss(prediction,labels)
        optimizer.zero_grad()
        loss_.backward()
        optimizer.step()
        
        if i%20==0:
            print(f'{epoch}/{EPOCHS} | step: {i}/ {len(train_loader)} | loss: {loss_.item():.4f}')   
        


Loading checkpoint..
Saving checkpoint
0/5 | step: 0/ 469 | loss: 0.0264
0/5 | step: 20/ 469 | loss: 0.1308
0/5 | step: 40/ 469 | loss: 0.0552
0/5 | step: 60/ 469 | loss: 0.0406
0/5 | step: 80/ 469 | loss: 0.1265
0/5 | step: 100/ 469 | loss: 0.0795
0/5 | step: 120/ 469 | loss: 0.0361
0/5 | step: 140/ 469 | loss: 0.0317
0/5 | step: 160/ 469 | loss: 0.1040
0/5 | step: 180/ 469 | loss: 0.0540
0/5 | step: 200/ 469 | loss: 0.0936
0/5 | step: 220/ 469 | loss: 0.0170
0/5 | step: 240/ 469 | loss: 0.0578
0/5 | step: 260/ 469 | loss: 0.0617
0/5 | step: 280/ 469 | loss: 0.0465
0/5 | step: 300/ 469 | loss: 0.0837
0/5 | step: 320/ 469 | loss: 0.0685
0/5 | step: 340/ 469 | loss: 0.0447
0/5 | step: 360/ 469 | loss: 0.1036
0/5 | step: 380/ 469 | loss: 0.0657
0/5 | step: 400/ 469 | loss: 0.0842
0/5 | step: 420/ 469 | loss: 0.0689
0/5 | step: 440/ 469 | loss: 0.0327
0/5 | step: 460/ 469 | loss: 0.1045
1/5 | step: 0/ 469 | loss: 0.0757
1/5 | step: 20/ 469 | loss: 0.0607
1/5 | step: 40/ 469 | loss: 0.1234