In [1]:
''' 3 DIFFERENT METHODS TO REMEMBER:
 - torch.save(arg, PATH) # can be model, tensor, or dictionary  
 - torch.load(PATH)
 - model.load_state_dict(state_dict) # a method of nn.Module, not a torch.* function
'''

''' 2 DIFFERENT WAYS OF SAVING

# 1) lazy way: save whole model

torch.save(model, PATH)

# model class must be defined somewhere
model = torch.load(PATH)
model.eval()

# 2) recommended way: save only the state_dict

torch.save(model.state_dict(), PATH) # this saves the parameters of the model

# model must be created again with parameters
model = Model(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
'''

' 2 DIFFERENT WAYS OF SAVING\n\n# 1) lazy way: save whole model\n\ntorch.save(model, PATH)\n\n# model class must be defined somewhere\nmodel = torch.load(PATH)\nmodel.eval()\n\n# 2) recommended way: save only the state_dict\n\ntorch.save(model.state_dict(), PATH) # this saves the parameters of the model\n\n# model must be created again with parameters\nmodel = Model(*args, **kwargs)\nmodel.load_state_dict(torch.load(PATH))\nmodel.eval()\n'

In [2]:
import torch
import torch.nn as nn

# -------Lazy Option: save the whole model object----------

In [3]:
class Model(nn.Module):
    """Single-layer logistic-regression model: ``sigmoid(Linear(x))``.

    Args:
        n_input_features: number of input features per sample; the linear
            layer maps them to a single output unit.
    """

    def __init__(self, n_input_features: int) -> None:
        # Zero-argument super() finds the class through __class__ instead of
        # looking up the global name ``Model``, which later cells of this
        # notebook rebind to new class objects.
        super().__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid-squashed scores in [0, 1]; last dimension has size 1."""
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your model...

In [5]:
FILE = "model.pth"  # destination for the pickled whole-model save below

In [6]:
# Pickle the entire nn.Module object (not just its tensors) to FILE.
torch.save(obj=model, f=FILE)

In [7]:
# Re-declared on purpose: a whole-model file is restored by unpickling, which
# needs the model's class to be defined in the loading session (as it would
# have to be in a fresh kernel).
class Model(nn.Module):
    """Single-layer logistic-regression model: ``sigmoid(Linear(x))``.

    Args:
        n_input_features: number of input features per sample; the linear
            layer maps them to a single output unit.
    """

    def __init__(self, n_input_features: int) -> None:
        # Zero-argument super() does not depend on the global name ``Model``,
        # which this notebook rebinds several times.
        super().__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid-squashed scores in [0, 1]; last dimension has size 1."""
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

In [8]:
FILE = 'model.pth'
# This file holds a pickled nn.Module, not only tensors. Since PyTorch 2.6
# torch.load() defaults to weights_only=True, which rejects arbitrary classes,
# so opt out explicitly. weights_only=False runs the full pickle machinery:
# only load files from a source you trust.
model = torch.load(FILE, weights_only=False)
model.eval()

Model(
  (linear): Linear(in_features=6, out_features=1, bias=True)
)

In [9]:
# Show each learnable tensor of the reloaded model: the Linear weight is
# registered before its bias, so it prints first.
for p in model.parameters():
    print(p)

Parameter containing:
tensor([[-0.0966,  0.0871,  0.1954,  0.2351,  0.3141, -0.2997]],
       requires_grad=True)
Parameter containing:
tensor([0.3980], requires_grad=True)


# ----------Preferred Option: save only the state_dict -----------

In [11]:
class Model(nn.Module):
    """Single-layer logistic-regression model: ``sigmoid(Linear(x))``.

    Args:
        n_input_features: number of input features per sample; the linear
            layer maps them to a single output unit.
    """

    def __init__(self, n_input_features: int) -> None:
        # Zero-argument super() does not depend on the global name ``Model``,
        # which this notebook rebinds several times.
        super().__init__()
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return sigmoid-squashed scores in [0, 1]; last dimension has size 1."""
        y_pred = torch.sigmoid(self.linear(x))
        return y_pred

model = Model(n_input_features=6)
# train your model...

In [12]:
FILE = "model.pth"  # same path as before: the state_dict save overwrites the whole-model file

In [13]:
# Persist only the name -> tensor mapping; the architecture is rebuilt in code.
torch.save(obj=model.state_dict(), f=FILE)

In [14]:
# Rebuild the architecture, then fill in the saved tensors. A state_dict holds
# only tensors, so the restricted weights_only=True unpickler suffices (and is
# the safe choice for files of unknown origin).
loaded_model = Model(n_input_features=6)
loaded_model.load_state_dict(torch.load(FILE, weights_only=True))
loaded_model.eval()

Model(
  (linear): Linear(in_features=6, out_features=1, bias=True)
)

In [16]:
# Show the restored tensors; they should match model.state_dict() printed below.
for p in loaded_model.parameters():
    print(p)

Parameter containing:
tensor([[-0.0609, -0.1224,  0.3977, -0.4039, -0.2630,  0.4061]],
       requires_grad=True)
Parameter containing:
tensor([-0.1798], requires_grad=True)


In [17]:
# state_dict() maps qualified parameter names (e.g. "linear.weight") to tensors.
print(f"{model.state_dict()}")

OrderedDict([('linear.weight', tensor([[-0.0609, -0.1224,  0.3977, -0.4039, -0.2630,  0.4061]])), ('linear.bias', tensor([-0.1798]))])


In [18]:
learning_rate = 0.01
# The optimizer's state_dict stores hyperparameters per param group and refers
# to the parameters by integer index rather than by tensor.
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}


In [19]:
# Bundle everything needed to resume training into one dict and save it.
checkpoint = dict(
    epoch=90,
    model_state=model.state_dict(),
    optim_state=optimizer.state_dict(),
)
torch.save(checkpoint, "checkpoint.pth")

In [20]:
# Restore training state from disk. Everything below reads from
# `loaded_checkpoint`; reading the in-memory `checkpoint` of the previous cell
# would ignore the file just loaded and fail on a fresh kernel. The checkpoint
# holds only tensors and plain Python values, so weights_only=True suffices.
loaded_checkpoint = torch.load("checkpoint.pth", weights_only=True)
epoch = loaded_checkpoint["epoch"]

# Fresh objects; lr=0 is a placeholder that optimizer.load_state_dict replaces
# with the saved value (0.01), as the printed state_dict shows.
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(model.parameters(),lr=0)

model.load_state_dict(loaded_checkpoint["model_state"]) # load all params into the model
optimizer.load_state_dict(loaded_checkpoint["optim_state"])
print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'foreach': None, 'differentiable': False, 'fused': None, 'params': [0, 1]}]}
