### Import package


In [None]:
import torch
import torch.nn as nn

### 3 DIFFERENT METHODS TO REMEMBER:

- torch.save(arg, PATH) # can be model, tensor, or dictionary
- torch.load(PATH)
- model.load_state_dict(arg) # a method of the model (or optimizer), not of torch


### Model


In [None]:
class Model(nn.Module):
    """Minimal model: one linear layer followed by a sigmoid.

    Args:
        n_input_features: number of input features fed to the linear layer.
    """

    def __init__(self, n_input_features):
        super().__init__()
        # Single output unit; the attribute name determines the state-dict keys.
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        """Return the sigmoid of the linear layer's output for ``x``."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

### 1. Save entire model


In [None]:
# Destination for the serialized model.
FILE = "model.pth"

model = Model(n_input_features=6)
# train...

# Passing the module itself (not its state dict) pickles the whole object.
torch.save(model, FILE)

#### Load model


In [6]:
FILE = "model.pth"
# The file contains a fully pickled nn.Module, not just tensors. From
# PyTorch 2.6 on, torch.load defaults to weights_only=True, which refuses to
# unpickle arbitrary classes, so the full-object load must opt out explicitly.
# NOTE: unpickling can execute arbitrary code -- only load files you trust.
# The Model class must also be defined in the session before loading.
model = torch.load(FILE, weights_only=False)
# Switch to evaluation mode before using the model for inference.
model.eval()

# Show the restored parameters.
for param in model.parameters():
    print(param)

Parameter containing:
tensor([[ 0.0664, -0.2261, -0.1058,  0.1396, -0.1496, -0.0285]],
       requires_grad=True)
Parameter containing:
tensor([0.3581], requires_grad=True)


### 2. Save only the state dict (preferred way)


In [7]:
# Destination for the serialized parameters.
FILE = "model.pth"

model = Model(n_input_features=6)
# train...

# Save only the state dict returned by the model, not the model object itself.
torch.save(model.state_dict(), FILE)

#### Load model


In [8]:
# Parameters of the model whose state dict was saved, for comparison.
for param in model.parameters():
    print(param)

# The file holds only a state dict, so build a fresh instance with the same
# architecture and load the saved parameters into it.
loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE))
loaded_model.eval()

# Print the *restored* model's parameters (the original cell re-printed
# `model` here), so the two printouts actually verify the save/load round trip.
for param in loaded_model.parameters():
    print(param)

Parameter containing:
tensor([[ 0.1215, -0.2023, -0.3699, -0.3362, -0.3926, -0.0041]],
       requires_grad=True)
Parameter containing:
tensor([0.0792], requires_grad=True)
Parameter containing:
tensor([[ 0.1215, -0.2023, -0.3699, -0.3362, -0.3926, -0.0041]],
       requires_grad=True)
Parameter containing:
tensor([0.0792], requires_grad=True)


### 3. Checkpoint way (save model and optimizer state together)


In [11]:
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
print(optimizer.state_dict())

# --- Save: bundle the epoch, model state and optimizer state in one dict ---
FILE = "checkpoint.pth"
checkpoint = {
    "epoch": 90,
    "model_state": model.state_dict(),
    "optim_state": optimizer.state_dict(),
}
print(optimizer.state_dict())
torch.save(checkpoint, FILE)

# --- Load: create fresh objects, then restore their saved state ---
# lr=0 is only a placeholder; loading the optimizer state replaces it with the
# saved learning rate, which the final print shows.
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(), lr=0)

checkpoint = torch.load(FILE)
epoch = checkpoint["epoch"]
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optim_state"])

model.eval()
# - or -
# model.train()

print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}
{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'params': [0, 1]}]}
