# 17 Saving and Loading Models

In [1]:
import os

import torch
import torch.nn as nn

```python
torch.save(model.state_dict(), PATH)

model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
```

In [2]:
class Model(nn.Module):
    """One linear layer with a single output, passed through a sigmoid.

    Args:
        n_input: Number of input features accepted by the linear layer.
    """

    def __init__(self, n_input):
        super().__init__()
        self.linear = nn.Linear(in_features=n_input, out_features=1)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the sigmoid of the result."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

In [3]:
# Build a model expecting 6 input features; its weights are randomly initialised.
model = Model(n_input=6)

In [5]:
# "Lazy" save: pickle the entire module object (class reference + parameters).
FILE = 'models/saving.pth'
# torch.save does not create missing parent directories, so make sure
# `models/` exists; otherwise this cell fails on a fresh checkout.
os.makedirs(os.path.dirname(FILE), exist_ok=True)
torch.save(model, FILE)

In [6]:
# Load the whole pickled module back and switch it to evaluation mode.
# NOTE(review): recent PyTorch releases default torch.load to weights_only=True,
# in which case loading a full pickled module may need weights_only=False —
# confirm against the installed version.
model = torch.load(FILE).eval()

# Show every parameter tensor of the restored model.
for p in model.parameters():
    print(p)

Parameter containing:
tensor([[-0.3440,  0.1509, -0.0664, -0.3039,  0.3444, -0.3902]],
       requires_grad=True)
Parameter containing:
tensor([-0.1436], requires_grad=True)


In [7]:
# Recommended save: persist only the parameters (state_dict).
# This overwrites FILE, which previously held the full pickled model.
torch.save(obj=model.state_dict(), f=FILE)

In [8]:
# Rebuild the architecture, then copy the saved parameters into it.
loaded_model = Model(6)
loaded_model.load_state_dict(torch.load(FILE))

# Inspect the *restored* model (the original cell iterated over `model`,
# which never showed whether the state_dict was actually loaded).
for param in loaded_model.parameters():
    print(param)

Parameter containing:
tensor([[-0.3440,  0.1509, -0.0664, -0.3039,  0.3444, -0.3902]],
       requires_grad=True)
Parameter containing:
tensor([-0.1436], requires_grad=True)


In [9]:
# Optimizers carry their own state_dict (hyper-parameters and per-parameter state).
optimizer = torch.optim.SGD(params=model.parameters(), lr=0.01)
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.01,
   'momentum': 0,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'maximize': False,
   'foreach': None,
   'params': [0, 1]}]}

In [10]:
# Bundle the epoch counter with the model and optimizer state
# and write everything to a single checkpoint file.
checkpoint = dict(
    epoch=90,
    model_state=model.state_dict(),
    optim_state=optimizer.state_dict(),
)
torch.save(checkpoint, 'models/checkpoint90.pth')

In [11]:
# Read the checkpoint dictionary back from disk.
loaded_checkpoint = torch.load('models/checkpoint90.pth')
# Recover the epoch counter stored under the 'epoch' key.
epoch = loaded_checkpoint['epoch']
epoch

90

In [12]:
# Fresh model with the same architecture, then restore the checkpointed weights.
model = Model(n_input=6)
model.load_state_dict(state_dict=loaded_checkpoint['model_state'])

<All keys matched successfully>

In [13]:
# lr=0 is a throwaway value: load_state_dict replaces the param-group settings
# with the checkpointed ones, so the displayed state shows lr=0.01 again.
optimizer = torch.optim.SGD(params=model.parameters(), lr=0)
optimizer.load_state_dict(state_dict=loaded_checkpoint['optim_state'])
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.01,
   'momentum': 0,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'maximize': False,
   'foreach': None,
   'params': [0, 1]}]}