In [1]:
import torch
import torch.nn as nn

### 3 DIFFERENT METHODS TO REMEMBER:
 - torch.save(arg, PATH) # can be model, tensor, or dictionary
 - torch.load(PATH)
 - model.load_state_dict(arg) # a method of the model (nn.Module), not a torch function

### 2 DIFFERENT WAYS OF SAVING
- 1) lazy way: save whole model  
torch.save(model, PATH)  

>model class must be defined somewhere  
model = torch.load(PATH)
model.eval()

- 2) recommended way: save only the state_dict
torch.save(model.state_dict(), PATH)  

>model must be created again with parameters  
model = Model(*args, **kwargs)  
model.load_state_dict(torch.load(PATH))  
model.eval()  

In [2]:
class Model(nn.Module):
    """Minimal logistic-regression style network: one linear unit plus a sigmoid.

    Args:
        n_input_features: number of input features fed to the linear layer.
    """

    def __init__(self, n_input_features):
        super().__init__()
        # The attribute name `linear` is what prefixes the state_dict keys
        # ("linear.weight", "linear.bias").
        self.linear = nn.Linear(in_features=n_input_features, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and squash the result with a sigmoid."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

In [4]:
# Build a fresh model with 6 input features and print its initial weight and
# bias, so they can be compared with the values seen after saving and reloading.
model = Model(n_input_features=6)
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[-0.1755, -0.2687,  0.1158,  0.3417, -0.3840,  0.2510]],
       requires_grad=True)
Parameter containing:
tensor([-0.2257], requires_grad=True)


In [5]:
# Path of the serialized file; the later save/load cells reuse this name.
FILE = "model.pth"
# "Lazy" way: serialize the whole model object rather than only its state_dict.
torch.save(model, FILE)

In [6]:
# Load the whole model object back; the `Model` class must already be defined
# in this kernel for this to work.
# NOTE(review): newer PyTorch releases reportedly default torch.load to
# weights_only=True, which would reject a fully pickled module - confirm
# against the installed version before relying on this cell.
model = torch.load(FILE)
model.eval()

# Print the parameters to check they match the values shown before saving.
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[-0.1755, -0.2687,  0.1158,  0.3417, -0.3840,  0.2510]],
       requires_grad=True)
Parameter containing:
tensor([-0.2257], requires_grad=True)


In [30]:
# Recommended way: persist only the state_dict. FILE is reused, so this
# overwrites the whole-model file written by the earlier save cell.
torch.save(model.state_dict(), FILE)

In [31]:
# Recommended workflow: re-create the architecture with the same constructor
# arguments, then copy the saved tensors into it.
loded_model = Model(n_input_features=6)
loded_model.load_state_dict(torch.load(FILE))
loded_model.eval()

# Inspect the *reloaded* model. Iterating over `model` here would only re-print
# the object the weights were saved from and verify nothing about the load.
for param in loded_model.parameters():
    print(param)

Parameter containing:
tensor([[-0.1271, -0.3877,  0.3984,  0.1534,  0.1466,  0.0637]],
       requires_grad=True)
Parameter containing:
tensor([-0.3831], requires_grad=True)


In [41]:
# Display the parameter-name -> tensor mapping; the keys are derived from the
# module attribute name ("linear.weight", "linear.bias").
model.state_dict()

OrderedDict([('linear.weight',
              tensor([[-0.1271, -0.3877,  0.3984,  0.1534,  0.1466,  0.0637]])),
             ('linear.bias', tensor([-0.3831]))])

In [42]:
# Optimizers carry their own state_dict (hyperparameters per param group plus
# any per-parameter state), which can be checkpointed alongside the model.
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Last expression of the cell: show the optimizer's state_dict.
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.01,
   'momentum': 0,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [2690817806128, 2690817805624]}]}

In [44]:
# Bundle everything needed to resume training into a single dictionary:
# the epoch counter plus the model and optimizer state_dicts.
checkPoint = {
    "epoch": 90,  # example value; no training loop is run in this notebook
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
# torch.save accepts a plain dictionary as well as a model or tensor.
torch.save(checkPoint, "checkPoint.pth")

In [50]:
# Read the checkpoint dictionary back from disk.
loded_checkPoint = torch.load("checkPoint.pth")

# Epoch stored in the checkpoint.
epoch = loded_checkPoint["epoch"]

# Re-create the model with the same constructor arguments, then restore its weights.
model = Model(n_input_features=6)
model.load_state_dict(loded_checkPoint["model_state"])

# lr=0 is only a placeholder: load_state_dict replaces the param-group
# hyperparameters with the saved ones (the optimizer state shown afterwards
# reports lr 0.01).
optimizer = torch.optim.SGD(model.parameters(), lr=0)
optimizer.load_state_dict(loded_checkPoint["optim_state"])

In [46]:
# Display the raw checkpoint dictionary exactly as it is read back from disk.
torch.load("checkPoint.pth")

{'epoch': 90,
 'model_state': OrderedDict([('linear.weight',
               tensor([[-0.1271, -0.3877,  0.3984,  0.1534,  0.1466,  0.0637]])),
              ('linear.bias', tensor([-0.3831]))]),
 'optim_state': {'state': {},
  'param_groups': [{'lr': 0.01,
    'momentum': 0,
    'dampening': 0,
    'weight_decay': 0,
    'nesterov': False,
    'params': [2690817806128, 2690817805624]}]}}

In [53]:
# Inspect the optimizer after restoring it from the checkpoint: the reported lr
# should be the saved value, not the lr=0 passed to the constructor.
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.01,
   'momentum': 0,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [2692527542920, 2692530951584]}]}