In [8]:
import torch
import torch.nn as nn

class model(nn.Module):
  """Small demo CNN: conv -> max-pool -> batch-norm -> linear classifier.

  Expects a batch of 3-channel images whose spatial size becomes 4x4 after
  the 3x3 convolution (no padding) and the 2x2 max-pool, e.g. 10x10 inputs,
  because ``fc1`` consumes 16 * 4 * 4 = 256 flattened features.
  Returns a tensor of shape (batch, 10).
  """

  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=0)
    self.pool1 = nn.MaxPool2d(2, 2)
    self.bn = nn.BatchNorm2d(16)
    # 256 = 16 channels * 4 * 4 spatial positions left after conv1 + pool1.
    self.fc1 = nn.Linear(256, 10)

  def forward(self, x):
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.bn(x)
    # nn.Linear acts on the last dimension only, so collapse (C, H, W) into a
    # single feature axis while keeping the batch axis.
    x = torch.flatten(x, 1)
    x = self.fc1(x)
    return x



In [12]:
# Seed PyTorch's RNG so the randomly initialised weights are reproducible.
torch.manual_seed(0)
my_model = model()

We can save either the whole model (architecture + weights) or just the weights (the `state_dict`).

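## Saving the architecture + weights
`torch.save` pickles the whole model object. Pickle stores only a reference to the model class, not its code, so the class definition must still be available (defined or importable) when the model is loaded.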

In [14]:
# Pickle the entire module object: class reference, parameters and buffers.
model_path = 'models'
torch.save(obj=my_model, f=model_path)

In [15]:
# Load the whole pickled model with torch.load(). No instance has to be
# created first, but the `model` class must still be defined/importable:
# pickle stores a reference to the class, not its code.
# weights_only=False is required to unpickle an arbitrary nn.Module (the
# default became True in PyTorch 2.6). This runs the full unpickler, so only
# load files you trust.
new_model = torch.load(model_path, weights_only=False)

## Saving just the model weights (`state_dict`), the recommended approach

In [21]:
# The state_dict is a Python dict mapping every layer parameter and registered
# buffer (e.g. BatchNorm's running statistics) to its tensor.
# Build it once instead of rebuilding the whole dict on every iteration.
model_state = my_model.state_dict()
for param, tensor in model_state.items():
  print(f'{param} shape {tensor.shape}')

conv1.weight shape torch.Size([16, 3, 3, 3])
conv1.bias shape torch.Size([16])
bn.weight shape torch.Size([16])
bn.bias shape torch.Size([16])
bn.running_mean shape torch.Size([16])
bn.running_var shape torch.Size([16])
bn.num_batches_tracked shape torch.Size([])
fc1.weight shape torch.Size([10, 256])
fc1.bias shape torch.Size([10])


In [26]:
# Persist only the tensors (parameters + buffers), not the Python object.
params_path = 'model_state_dict'
torch.save(obj=my_model.state_dict(), f=params_path)

In [29]:
# torch.load() reads the state_dict back, but only as a plain dict of tensors.
# To get a usable model we first build an instance of the same architecture
# and copy the tensors into it. load_state_dict() is strict by default, so the
# keys must match the model exactly.
my_new_model = model()
# weights_only=True restricts unpickling to tensors and primitive containers,
# which is all a state_dict needs, and avoids running arbitrary pickled code.
state_dict = torch.load(params_path, weights_only=True)
my_new_model.load_state_dict(state_dict)

<All keys matched successfully>

## Optimizers also have a state that we can save


In [64]:
# NOTE(review): betas=(0.9, 0.90) sets beta2 to 0.9; Adam's default beta2 is
# 0.999. Confirm this is intended and not a typo.
optimizer = torch.optim.Adam(my_model.parameters(), lr=0.003, betas=(0.9, 0.90), eps=1e-08, weight_decay=0, amsgrad=False)
# The optimizer state_dict has 'state' (per-parameter buffers, empty here
# because step() has not been called) and 'param_groups' (hyperparameters
# plus parameter indices). These entries are values, not tensors with shapes.
for key, value in optimizer.state_dict().items():
  print(f'{key}: {value}')

state    shape      {}
param_groups    shape      [{'lr': 0.003, 'betas': (0.9, 0.9), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [0, 1, 2, 3, 4, 5]}]


**Saving the entire optimizer object**

In [45]:
# Pickling the whole optimizer also serializes the parameter tensors it tracks.
optimizer_path = 'adam_optimizer'
torch.save(optimizer, optimizer_path)
# Load into a NEW name: the unpickled optimizer holds *copies* of the parameter
# tensors, not my_model's tensors, so stepping it would never update my_model.
# Keeping `optimizer` bound to my_model avoids that silent mismatch.
# weights_only=False is needed to unpickle a full Python object (the default
# became True in PyTorch 2.6); only do this for files you trust.
loaded_optimizer = torch.load(optimizer_path, weights_only=False)

**Saving the optimizer state_dict:**

In [48]:
# Store only the optimizer's state_dict: hyperparameters + per-parameter state.
optimizer_state_dict_path = "adam_optimizer_state_dict"
torch.save(obj=optimizer.state_dict(), f=optimizer_state_dict_path)


In [41]:
# A fresh Adam over the same parameters, using default hyperparameters;
# loading a saved state_dict into it replaces them with the saved values.
new_optimizer = torch.optim.Adam(params=my_model.parameters())

In [49]:
# Peek at the saved optimizer state_dict. It contains only primitive containers
# (and tensors once step() has run), so the restricted, safe unpickler suffices.
torch.load(optimizer_state_dict_path, weights_only=True)

{'param_groups': [{'amsgrad': False,
   'betas': (0.9, 0.9),
   'eps': 1e-08,
   'lr': 0.003,
   'params': [0, 1, 2, 3, 4, 5],
   'weight_decay': 0}],
 'state': {}}

In [62]:
# Load the saved optimizer state_dict into the fresh optimizer. This restores
# the saved hyperparameters (lr, betas, ...) and any per-parameter state.
new_optimizer.load_state_dict(torch.load(optimizer_state_dict_path, weights_only=True))

# Print each entry's value (these are not tensors, so there is no shape).
for key, value in new_optimizer.state_dict().items():
  print(f'{key}: {value}')

state    shape      {}
param_groups    shape      [{'lr': 0.003, 'betas': (0.9, 0.9), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [0, 1, 2, 3, 4, 5]}]


## `torch.save` can save any picklable Python object
We can save all our work in one checkpoint file (model, optimizer state, loss, epoch, ...).

In [65]:
# Bundle everything needed to resume training into a single checkpoint file.
# NOTE: 'model' stores the whole pickled module; saving my_model.state_dict()
# instead is the more portable convention.
path_for_all = 'all'
epoch = 8
loss = 0.022
torch.save(
    dict(
        epoch=epoch,
        model=my_model,
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    ),
    path_for_all,
)


In [66]:
# The checkpoint contains a pickled nn.Module, so the full unpickler is needed
# (weights_only=False; the default became True in PyTorch 2.6). Only load
# checkpoints you trust.
checkpoint = torch.load(path_for_all, weights_only=False)
model_to_continue_training = checkpoint['model']
# An optimizer must track the parameters of the model it trains. The existing
# `optimizer` tracks other tensors, not this restored model's parameters.
# Build a new one over the restored model, then load the saved optimizer
# state (hyperparameters + per-parameter buffers) into it.
optimizer = torch.optim.Adam(model_to_continue_training.parameters())
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

#continue training....
#evaluate.............

In [72]:
# Build the restored model's state_dict once rather than on every iteration.
restored_state = model_to_continue_training.state_dict()
for param, tensor in restored_state.items():
  print(f'{param}    shape      {tensor.shape}')

conv1.weight    shape      torch.Size([16, 3, 3, 3])
conv1.bias    shape      torch.Size([16])
bn.weight    shape      torch.Size([16])
bn.bias    shape      torch.Size([16])
bn.running_mean    shape      torch.Size([16])
bn.running_var    shape      torch.Size([16])
bn.num_batches_tracked    shape      torch.Size([])
fc1.weight    shape      torch.Size([10, 256])
fc1.bias    shape      torch.Size([10])


In [67]:
# Print each optimizer state_dict entry's value (these are not tensors, so
# there is no shape); build the dict once.
for key, value in optimizer.state_dict().items():
  print(f'{key}: {value}')

state    shape      {}
param_groups    shape      [{'lr': 0.003, 'betas': (0.9, 0.9), 'eps': 1e-08, 'weight_decay': 0, 'amsgrad': False, 'params': [0, 1, 2, 3, 4, 5]}]


In [69]:
# Loss value read back from the checkpoint.
loss

0.022

In [70]:
# Epoch number read back from the checkpoint.
epoch

8