# 读写文件

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

## 保存、加载张量

使用 `torch.save` 和 `torch.load` 分别保存和加载：

In [5]:
x = torch.randn([4, 5])
torch.save(x, "x-file")
print(x)

tensor([[-1.3624, -0.6662, -2.3517,  0.0261, -0.7324],
        [ 1.8455,  0.4608, -0.3855,  1.1429, -0.3526],
        [ 0.2578,  0.3612,  0.4851, -0.3225,  0.7457],
        [-1.8319, -0.1821, -0.8958,  0.1574,  1.2175]])


In [6]:
y = torch.load("x-file")
print(y)

tensor([[-1.3624, -0.6662, -2.3517,  0.0261, -0.7324],
        [ 1.8455,  0.4608, -0.3855,  1.1429, -0.3526],
        [ 0.2578,  0.3612,  0.4851, -0.3225,  0.7457],
        [-1.8319, -0.1821, -0.8958,  0.1574,  1.2175]])


  y = torch.load("x-file")


也可以存储一个 tensor table，甚至字典也可以。

In [9]:
torch.save([x, y], "x-y-file")
x1, y1 = torch.load("x-y-file")
print(x1)
print(y1)
xy_dict = { "x": x, "y": y }
torch.save(xy_dict, "xy-dict")
xy_dict1 = torch.load("xy-dict")
print(xy_dict1)

tensor([[-1.3624, -0.6662, -2.3517,  0.0261, -0.7324],
        [ 1.8455,  0.4608, -0.3855,  1.1429, -0.3526],
        [ 0.2578,  0.3612,  0.4851, -0.3225,  0.7457],
        [-1.8319, -0.1821, -0.8958,  0.1574,  1.2175]])
tensor([[-1.3624, -0.6662, -2.3517,  0.0261, -0.7324],
        [ 1.8455,  0.4608, -0.3855,  1.1429, -0.3526],
        [ 0.2578,  0.3612,  0.4851, -0.3225,  0.7457],
        [-1.8319, -0.1821, -0.8958,  0.1574,  1.2175]])
{'x': tensor([[-1.3624, -0.6662, -2.3517,  0.0261, -0.7324],
        [ 1.8455,  0.4608, -0.3855,  1.1429, -0.3526],
        [ 0.2578,  0.3612,  0.4851, -0.3225,  0.7457],
        [-1.8319, -0.1821, -0.8958,  0.1574,  1.2175]]), 'y': tensor([[-1.3624, -0.6662, -2.3517,  0.0261, -0.7324],
        [ 1.8455,  0.4608, -0.3855,  1.1429, -0.3526],
        [ 0.2578,  0.3612,  0.4851, -0.3225,  0.7457],
        [-1.8319, -0.1821, -0.8958,  0.1574,  1.2175]])}


  x1, y1 = torch.load("x-y-file")
  xy_dict1 = torch.load("xy-dict")


## 加载、保存模型参数

In [10]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(20, 256)
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        return self.output(F.relu(self.hidden(x)))

net = MLP()
X = torch.randn(size=(2, 20))
Y = net(X)

使用 `torch.save` 和 `state_dict` 方法保存模型参数：

In [11]:
torch.save(net.state_dict(), "mlp-params")

使用 `load_state_dict` 方法加载参数：

In [12]:
clone = MLP()
clone.load_state_dict(torch.load('mlp-params'))
clone.eval()

  clone.load_state_dict(torch.load('mlp-params'))


MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

## 同时加载、保存模型参数和架构

使用 `torch.save` 直接保存模型：

In [14]:
torch.save(net, "mlp.pt")
clone2 = torch.load("mlp.pt")
clone2.eval()

  clone2 = torch.load("mlp.pt")


MLP(
  (hidden): Linear(in_features=20, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)