### Saving and Loading Models
(Author: Matthew Inkawhich)

When it comes to saving and loading models, there are three core functions to be familiar with:

- `torch.save`: Saves a serialized object to disk. This function uses Python's `pickle` utility for serialization. Models, tensors, and dictionaries of all kinds of objects can be saved using this function.
- `torch.load`: Uses `pickle`'s unpickling facilities to deserialize pickled object files into memory. It also lets you choose the device to load the data onto, through its `map_location` argument.
- `torch.nn.Module.load_state_dict`: Loads a model's parameter dictionary using a deserialized `state_dict`.
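As a minimal illustration of the first two functions, the sketch below round-trips a small dictionary through `torch.save` and `torch.load`. The file name `example.pt` and the use of a temporary directory are assumptions made only so the snippet is self-contained.

```python
import os
import tempfile

import torch

# A dictionary mixing a tensor and a plain Python value; torch.save pickles it
obj = {"weights": torch.arange(6.0).reshape(2, 3), "step": 10}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "example.pt")  # hypothetical file name
    torch.save(obj, path)
    # map_location chooses the device the stored tensors are restored onto
    restored = torch.load(path, map_location="cpu")

print(restored["step"])
print(torch.equal(restored["weights"], obj["weights"]))
```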

#### state_dict

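A `state_dict` is simply a Python dictionary. In PyTorch, the learnable parameters (weights and biases) of a model are stored in `model.parameters()`, and the `state_dict` maps each layer's parameter names to the corresponding tensors. Only layers with learnable parameters (convolutional layers, linear layers, etc.) and registered buffers (such as BatchNorm's `running_mean`) have entries in a model's `state_dict`. Models are not the only objects with a `state_dict`: optimizers have one too, holding the optimizer's hyperparameters and the internal state it accumulates during training.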
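To see which layers actually contribute entries, the sketch below builds a small hypothetical `nn.Sequential` network (not `TheModelClass`) and prints its `state_dict`. The pooling layer has no parameters or buffers, so it contributes nothing, while BatchNorm contributes both learnable parameters and running-statistics buffers.

```python
import torch
import torch.nn as nn

net = nn.Sequential(
    nn.Conv2d(1, 2, 3),  # learnable: weight, bias
    nn.BatchNorm2d(2),   # learnable weight/bias plus running-statistics buffers
    nn.MaxPool2d(2),     # no parameters or buffers, so no entries
)

# Keys are "<submodule name>.<tensor name>"; Sequential names submodules 0, 1, 2
for name, tensor in net.state_dict().items():
    print(name, tuple(tensor.shape))
```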

```python
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu in TheModelClass.forward
import torch.optim as optim
```

```python
class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize the model and the optimizer
model = TheModelClass()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print the model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print the optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])
```

```
Model's state_dict:
conv1.weight 	 torch.Size([6, 3, 5, 5])
conv1.bias 	 torch.Size([6])
conv2.weight 	 torch.Size([16, 6, 5, 5])
conv2.bias 	 torch.Size([16])
fc1.weight 	 torch.Size([120, 400])
fc1.bias 	 torch.Size([120])
fc2.weight 	 torch.Size([84, 120])
fc2.bias 	 torch.Size([84])
fc3.weight 	 torch.Size([10, 84])
fc3.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {}
param_groups 	 [{'nesterov': False, 'dampening': 0, 'weight_decay': 0, 'lr': 0.001, 'params': [139829746445696, 139829746445984, 139829746445120, 139829746445840, 139829746445912, 139829746446056, 139829746446128, 139829746446200, 139829746446272, 139829746151496], 'momentum': 0.9}]
```
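The optimizer's `state` entry is empty in the output above because no optimization step has been taken yet. A minimal sketch, assuming a toy `nn.Linear` model in place of `TheModelClass`, shows SGD with momentum filling in a `momentum_buffer` for each parameter after one `step()`. In recent PyTorch versions these entries are keyed by each parameter's position in the optimizer (0, 1, ...).

```python
import torch
import torch.nn as nn
import torch.optim as optim

toy = nn.Linear(3, 1)  # stand-in model with two parameters: weight and bias
opt = optim.SGD(toy.parameters(), lr=0.001, momentum=0.9)

before = opt.state_dict()["state"]
print(before)  # nothing accumulated yet

# One toy training step on random data
loss = toy(torch.randn(4, 3)).pow(2).mean()
loss.backward()
opt.step()

after = opt.state_dict()["state"]
print(sorted(after.keys()))
print({k: list(v.keys()) for k, v in after.items()})
```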


#### Saving and Loading Model Parameters

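```python
PATH = "model.pt"  # example file path; any location works
torch.save(model.state_dict(), PATH)
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()  # put Dropout/BatchNorm layers into inference mode before evaluating
```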

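By convention, model files are saved with a `.pt` or `.pth` file extension. Note that `load_state_dict` takes a dictionary object, not a path, so the saved file must first be deserialized with `torch.load`, and the target model must be constructed with the same architecture before its parameters can be loaded. Call `model.eval()` before running inference so that layers such as Dropout and BatchNorm switch to evaluation behavior.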
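The sketch below runs the whole save/load cycle end to end with a small stand-in model (the `make_model` helper and the `weights.pth` file name are hypothetical). It also shows why `eval()` matters: Dropout is active in training mode, so the two models' outputs only become deterministic and comparable once both are in evaluation mode.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_model():
    # Hypothetical stand-in architecture; it must be identical at save and load time
    return nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 2))

trained = make_model()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")
    torch.save(trained.state_dict(), path)

    fresh = make_model()  # freshly initialized with different random weights
    fresh.load_state_dict(torch.load(path))
    fresh.eval()  # disable Dropout for inference

x = torch.randn(3, 4)
trained.eval()
print(torch.equal(fresh(x), trained(x)))
```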

#### Saving and Loading the Entire Model

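```python
torch.save(model, PATH)
# PyTorch >= 2.6 defaults to weights_only=True, which refuses full modules; only load trusted files
model = torch.load(PATH, weights_only=False)
model.eval()
```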

The disadvantage of this approach is that the serialized data is bound to the specific classes and the exact directory structure used when the model was saved. This is because pickle does not save the model class itself; rather, it saves a path to the file containing the class, which is used at load time. As a result, your code can break in various ways when it is used in other projects or after refactors.

#### Saving and Loading a Checkpoint

```python
torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': loss,
    # ...
}, PATH)
# A common convention is to save such checkpoints with a .tar file extension.
```

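```python
model = TheModelClass()
# Rebuild the optimizer on the new model's parameters, with the same type and settings as when saving
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()   # for inference
# - or -
model.train()  # to resume training
```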
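The checkpoint workflow can be sketched end to end as follows, assuming a toy `nn.Linear` model, a hypothetical `build` helper, and random data in place of a real training loop. The model and optimizer are constructed first and only then have their state restored; the optimizer's momentum buffers come back along with the weights, so training can resume where it stopped.

```python
import os
import tempfile

import torch
import torch.nn as nn
import torch.optim as optim

def build():
    # Hypothetical helper: the same architecture and optimizer settings every time
    net = nn.Linear(3, 1)
    opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    return net, opt

net, opt = build()
for epoch in range(2):  # a couple of toy training steps on random data
    opt.zero_grad()
    loss = net(torch.randn(8, 3)).pow(2).mean()
    loss.backward()
    opt.step()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.tar")
    torch.save({
        "epoch": epoch,
        "model_state_dict": net.state_dict(),
        "optimizer_state_dict": opt.state_dict(),
        "loss": loss.item(),  # store a plain float rather than a graph-attached tensor
    }, path)

    new_net, new_opt = build()  # construct first, then restore state into them
    checkpoint = torch.load(path)
    new_net.load_state_dict(checkpoint["model_state_dict"])
    new_opt.load_state_dict(checkpoint["optimizer_state_dict"])
    start_epoch = checkpoint["epoch"] + 1
    new_net.train()  # resuming training

print(start_epoch)
```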

#### Saving Multiple Models in One File

```python
torch.save({
    'modelA_state_dict': modelA.state_dict(),
    'modelB_state_dict': modelB.state_dict(),
    'optimizerA_state_dict': optimizerA.state_dict(),
    'optimizerB_state_dict': optimizerB.state_dict(),
    # ...
}, PATH)
```

```python
modelA = TheModelAClass()
modelB = TheModelBClass()
# Placeholder optimizer classes; each must be built on its own model's parameters
optimizerA = TheOptimizerAClass(modelA.parameters())
optimizerB = TheOptimizerBClass(modelB.parameters())
checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])
```

#### Warmstarting a Model Using Parameters from a Different Model

```python
torch.save(modelA.state_dict(), PATH)
modelB = TheModelBClass()
modelB.load_state_dict(torch.load(PATH), strict=False)
```

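Besides reloading its own previously trained parameters, a model can be initialized with parameters taken from a different model. Because a `state_dict` is just a dictionary, any entry whose key matches a key in the target model can be loaded. The loaded `state_dict` may be missing some of the target's layers or contain extra ones; passing `strict=False` tells `load_state_dict` to ignore these non-matching keys. Note that `strict=False` does not relax shape checks: a matching key whose tensor has a different shape still raises an error. If the weights are compatible but the keys don't match, you can rename the keys manually so that they line up.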
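The sketch below combines both ideas: it renames keys from a source model whose layer is called `backbone` (a hypothetical name) to the target's naming, drops any entry whose shape would not fit, and then loads with `strict=False`. The return value of `load_state_dict` reports which target keys were left at their initial values.

```python
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical source model whose only layer is named "backbone"
source = nn.Sequential(OrderedDict(backbone=nn.Linear(4, 8)))
# Hypothetical target: same first layer (named "0") plus a new head (named "2")
target = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))

# 1. Rename keys so they follow the target's naming ("backbone.*" -> "0.*")
renamed = {k.replace("backbone.", "0.", 1): v for k, v in source.state_dict().items()}

# 2. strict=False tolerates missing/extra keys but not shape mismatches,
#    so keep only entries the target has with exactly the same shape
target_sd = target.state_dict()
compatible = {k: v for k, v in renamed.items()
              if k in target_sd and v.shape == target_sd[k].shape}

result = target.load_state_dict(compatible, strict=False)
print("missing:", result.missing_keys)
print("unexpected:", result.unexpected_keys)
```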

#### Saving & Loading a Model Across Devices

```python
# Save on GPU, load on CPU
torch.save(model.state_dict(), PATH)
device = torch.device('cpu')
model = TheModelClass()
model.load_state_dict(torch.load(PATH, map_location=device))
```

```python
# Save on GPU, load on GPU
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.to(device)
# Input tensors must be moved to the GPU as well, e.g. input = input.to(device)
```

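```python
# Save on CPU, load on GPU
torch.save(model.state_dict(), PATH)
device = torch.device("cuda")
model = TheModelClass()
# map_location places the saved tensors on the GPU as they are loaded
model.load_state_dict(torch.load(PATH, map_location="cuda"))
model.to(device)  # move the model's own parameters to the GPU
```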
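When the same script may run on machines with or without a GPU, a common pattern is to choose the device at runtime and pass it to `map_location`. The sketch below assumes a toy `nn.Linear` model and a hypothetical `weights.pt` file name; it falls back to the CPU when CUDA is unavailable.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Use a GPU if one is present, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2)
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pt")
    torch.save(model.state_dict(), path)

    restored = nn.Linear(4, 2)
    # map_location remaps every stored tensor onto `device` while loading
    state = torch.load(path, map_location=device)
    restored.load_state_dict(state)
    restored.to(device)

x = torch.randn(1, 4, device=device)  # inputs must live on the same device as the model
print(restored(x).device.type == device.type)
```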

#### Saving a `torch.nn.DataParallel` Model

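`torch.nn.DataParallel` wraps the original model and stores it as `model.module`, so every key in the wrapper's own `state_dict` starts with `module.`. Saving `model.module.state_dict()` stores the unwrapped keys instead, which lets you load the weights into a plain model on any device.

```python
torch.save(model.module.state_dict(), PATH)
# The saved weights can be loaded into an unwrapped model on whatever device you want
```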
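To see the `module.` prefix directly, the sketch below wraps a toy `nn.Linear` in `nn.DataParallel` (on a CPU-only machine the wrapper should simply hold the module without replicating it) and compares the two sets of keys. The `strip_module_prefix` helper is hypothetical; it shows one way to load a `state_dict` that was saved from the wrapper rather than from `model.module`.

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 2)
wrapped = nn.DataParallel(net)

print(list(wrapped.state_dict().keys()))         # keys carry a "module." prefix
print(list(wrapped.module.state_dict().keys()))  # unwrapped keys

def strip_module_prefix(state_dict):
    """Hypothetical helper: remove a leading "module." from every key."""
    prefix = "module."
    return {(k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in state_dict.items()}

# A plain (unwrapped) model accepts the weights once the prefix is removed
plain = nn.Linear(4, 2)
plain.load_state_dict(strip_module_prefix(wrapped.state_dict()))
```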