# 读写Tensor
`torch.save` 使用 Python 的 `pickle` 将对象序列化后保存到磁盘；  
`torch.load` 使用 `pickle` 的反序列化功能把文件中的对象重新读回内存。

In [5]:
import torch
from torch import nn

# Persist a 3-element tensor of ones to disk.
x = torch.ones(3)
torch.save(obj=x, f='x.pt')

In [6]:
# Read the tensor back from disk; the trailing bare name lets the notebook display it.
x2 = torch.load(f='x.pt')
x2

tensor([1., 1., 1.])

In [7]:
# Several tensors can share one file: save them as a list and load the list back.
y = torch.zeros(4)
torch.save(obj=[x, y], f='xy.pt')
xy_list = torch.load(f='xy.pt')
xy_list

[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]

In [8]:
# A dict of tensors round-trips the same way; its keys are preserved on load.
torch.save(obj={'x': x, 'y': y}, f='xy_dict.pt')
xy_dict = torch.load(f='xy_dict.pt')
xy_dict

{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}

# 读写模型
PyTorch 中 `Module` 的可学习参数可以通过 `model.parameters()` 访问。`state_dict` 是一个从参数名映射到参数 tensor 的（有序）字典对象。

In [9]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear(3 -> 2) -> ReLU -> Linear(2 -> 1)."""

    def __init__(self):
        super().__init__()
        # Submodule registration order determines the state_dict key order
        # and the printed module layout, so keep hidden -> activation -> output.
        self.hidden = nn.Linear(3, 2)
        self.activation = nn.ReLU()
        self.output = nn.Linear(2, 1)

    def forward(self, x):
        """Map inputs whose last dimension is 3 to outputs whose last dimension is 1."""
        hidden_act = self.activation(self.hidden(x))
        return self.output(hidden_act)


net = MLP()
net.state_dict()

OrderedDict([('hidden.weight',
              tensor([[ 0.2923,  0.0614, -0.3215],
                      [ 0.0194, -0.0606,  0.1464]])),
             ('hidden.bias', tensor([ 0.0050, -0.4540])),
             ('output.weight', tensor([[-0.4543, -0.3464]])),
             ('output.bias', tensor([-0.3170]))])

1. 只有具有可学习参数（或注册了 buffer，例如 BatchNorm 的 `running_mean`）的层，才会在 `state_dict` 中有对应条目（比如上面的 `activation` 就没有）。    
2. 优化器也有 `state_dict`，其中 `param_groups` 保存优化器的超参数，`state` 保存状态信息（例如动量缓冲；在执行任何 `step()` 之前它是空的）。

In [12]:
# SGD with momentum; its state_dict exposes the hyper-parameters (param_groups)
# and the per-parameter state, which stays empty until a step is taken.
optimizer = torch.optim.SGD(
    params=net.parameters(),
    lr=0.001,
    momentum=0.9,
)
optimizer.state_dict()

{'state': {},
 'param_groups': [{'lr': 0.001,
   'momentum': 0.9,
   'dampening': 0,
   'weight_decay': 0,
   'nesterov': False,
   'params': [0, 1, 2, 3]}]}

## 保存和加载模型

### 保存和加载模型参数（state_dict，推荐方式）

In [17]:
# Save: serialize only the parameter dictionary, not the module object itself.
model_state_path = 'model_state.pt'
torch.save(obj=net.state_dict(), f=model_state_path)

In [18]:
# Load: rebuild the architecture in code, then restore the saved parameters.
# A state_dict contains only tensors inside plain containers, so it can be
# loaded with weights_only=True. That restricts the unpickler to safe types and
# avoids executing arbitrary pickle code (requires PyTorch >= 1.13).
model = MLP()
model.load_state_dict(torch.load(model_state_path, weights_only=True))
model

MLP(
  (hidden): Linear(in_features=3, out_features=2, bias=True)
  (activation): ReLU()
  (output): Linear(in_features=2, out_features=1, bias=True)
)

### 保存和加载整个模型（不推荐）

In [20]:
# Saving the whole module pickles the MLP object itself. Loading it later needs
# the MLP class to be importable and runs arbitrary pickle code, which is why
# saving only the state_dict is preferred.
model_path = 'model.pt'
torch.save(model, model_path)

# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to
# unpickle nn.Module objects. Full unpickling must therefore be requested
# explicitly; only do this for files from a trusted source.
model = torch.load(model_path, weights_only=False)