## 存储数据

加载和保存张量

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

In [6]:
# Serialize one tensor to a file, then deserialize it into a new name.
x = torch.arange(4)
torch.save(x, 'x-file')

x_loaded = torch.load('x-file')  # a fresh tensor with the same values as x
x_loaded

tensor([0, 1, 2, 3])

存储一个张量列表，然后载入内存

In [8]:
# A Python list of tensors is saved as one object; loading yields the same list.
y = torch.zeros(4)
torch.save([x, y], 'x-files')
loaded_pair = torch.load('x-files')
x2, y2 = loaded_pair
x2, y2

(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))

写入并读回一个字典（键为字符串，值为张量）

In [10]:
# A dict from string keys to tensors round-trips through a single file.
mydict = dict(x=x, y=y)
torch.save(mydict, 'mydict')
mydict_restored = torch.load('mydict')
mydict_restored

{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}

## 存储模型

加载和保存模型参数<br>
保存整个模型并不方便，因为模型结构是由代码定义的；因此通常只保存模型参数，加载时再用代码重新构建模型架构。

In [23]:
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    The defaults reproduce the original fixed sizes (20 -> 256 -> 10).
    The submodule attribute names ``hidden`` and ``output`` become the
    prefixes of the ``state_dict`` keys (e.g. ``'hidden.weight'``), so they
    must stay unchanged for saved parameter files to load into new instances.
    """

    def __init__(self, num_inputs: int = 20, num_hiddens: int = 256,
                 num_outputs: int = 10) -> None:
        super().__init__()
        self.hidden = nn.Linear(num_inputs, num_hiddens)
        self.output = nn.Linear(num_hiddens, num_outputs)

    def forward(self, x):
        """Apply hidden layer, ReLU, then output layer.

        ``nn.Linear`` acts on the last dimension, so ``x`` has shape
        (..., num_inputs) and the result has shape (..., num_outputs).
        """
        return self.output(F.relu(self.hidden(x)))

torch.manual_seed(0)  # fix the RNG so parameters and X are reproducible on re-run
net = MLP()
X = torch.randn((2, 20))
Y = net(X)  # keep the output so the reloaded model can be checked against it later
# No backpropagation has run, so these are still the initial (random) parameters.
net.state_dict()

OrderedDict([('hidden.weight',
              tensor([[-0.1468,  0.1776,  0.1083,  ..., -0.1544, -0.0365,  0.0215],
                      [ 0.1555,  0.0358, -0.0167,  ...,  0.0078, -0.1334,  0.0618],
                      [-0.1813,  0.0030, -0.0446,  ...,  0.0290,  0.1827,  0.0977],
                      ...,
                      [-0.0624, -0.0120, -0.0701,  ..., -0.0054, -0.1180, -0.1159],
                      [ 0.1780, -0.0015, -0.1902,  ...,  0.0556, -0.0849,  0.0781],
                      [ 0.0743, -0.1456,  0.1431,  ...,  0.0568, -0.0824, -0.2091]])),
             ('hidden.bias',
              tensor([-0.1470, -0.1126, -0.0338,  0.1389, -0.0379, -0.0129, -0.1170,  0.1143,
                      -0.1976,  0.0881,  0.1911, -0.1842, -0.0071, -0.1714,  0.0930, -0.1521,
                      -0.0197,  0.1058, -0.0522, -0.0329, -0.1956, -0.1331,  0.0983,  0.1527,
                       0.1246,  0.1663, -0.1743,  0.0323, -0.1801, -0.1035, -0.1778, -0.0511,
                      -0.1080,

In [24]:
# Persist only the parameters: state_dict() maps each parameter name to its tensor.
params = net.state_dict()
torch.save(params, 'mlp.params')

如何把这个模型加载回来呢？<br>
需要同时保留参数文件 'mlp.params' 和定义 MLP 模型结构的代码：参数文件只保存参数，不包含模型结构。

In [26]:
# Build a brand-new MLP. It gets its own random initialization, so the
# values shown here differ from net's parameters above.
clone = MLP()
fresh_params = clone.state_dict()
fresh_params

OrderedDict([('hidden.weight',
              tensor([[ 0.1829, -0.1702, -0.1242,  ..., -0.2197,  0.1551, -0.1883],
                      [ 0.1727, -0.1204, -0.0834,  ...,  0.1236, -0.1408,  0.1631],
                      [-0.0018, -0.1559,  0.0014,  ..., -0.1992,  0.1859, -0.1096],
                      ...,
                      [ 0.0745,  0.1741, -0.1104,  ..., -0.0006,  0.0380, -0.1719],
                      [ 0.2228, -0.2065,  0.0681,  ..., -0.1548,  0.1072, -0.1154],
                      [ 0.1347,  0.0424, -0.0475,  ..., -0.0508,  0.0799,  0.2043]])),
             ('hidden.bias',
              tensor([ 0.1032,  0.2030,  0.0548,  0.1337, -0.1278,  0.1162, -0.1547, -0.1803,
                      -0.2155,  0.0194,  0.1343,  0.0308, -0.2006,  0.1551, -0.1674, -0.1454,
                       0.1072, -0.1228, -0.1588,  0.1023, -0.0068,  0.1959,  0.0463,  0.1583,
                      -0.2176, -0.1607,  0.0768, -0.1161,  0.0479, -0.0008, -0.0203,  0.1730,
                       0.1072,

In [None]:
# torch.load needs the path of the file to read. Loading the parameter file on
# its own shows that it holds only a mapping of parameter names to tensors,
# not the model architecture.
torch.load('mlp.params')

In [28]:
# The inner call returns the saved OrderedDict; load_state_dict copies it into clone.
clone.load_state_dict(torch.load('mlp.params'))

# The printed tensors are truncated ("..."), so verify exact equality in code
# rather than by eye: every parameter and the forward output must match net.
net_params = net.state_dict()
clone_params = clone.state_dict()
assert clone_params.keys() == net_params.keys()
assert all(torch.equal(clone_params[name], net_params[name]) for name in net_params)
assert torch.equal(clone(X), net(X))

clone_params  # now identical to net's parameters

OrderedDict([('hidden.weight',
              tensor([[-0.1468,  0.1776,  0.1083,  ..., -0.1544, -0.0365,  0.0215],
                      [ 0.1555,  0.0358, -0.0167,  ...,  0.0078, -0.1334,  0.0618],
                      [-0.1813,  0.0030, -0.0446,  ...,  0.0290,  0.1827,  0.0977],
                      ...,
                      [-0.0624, -0.0120, -0.0701,  ..., -0.0054, -0.1180, -0.1159],
                      [ 0.1780, -0.0015, -0.1902,  ...,  0.0556, -0.0849,  0.0781],
                      [ 0.0743, -0.1456,  0.1431,  ...,  0.0568, -0.0824, -0.2091]])),
             ('hidden.bias',
              tensor([-0.1470, -0.1126, -0.0338,  0.1389, -0.0379, -0.0129, -0.1170,  0.1143,
                      -0.1976,  0.0881,  0.1911, -0.1842, -0.0071, -0.1714,  0.0930, -0.1521,
                      -0.0197,  0.1058, -0.0522, -0.0329, -0.1956, -0.1331,  0.0983,  0.1527,
                       0.1246,  0.1663, -0.1743,  0.0323, -0.1801, -0.1035, -0.1778, -0.0511,
                      -0.1080,

In [31]:
# Switch the clone to evaluation mode. train(False) is exactly what eval() does,
# and it also returns the module, so its repr is displayed below.
clone.train(False)

MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)