# Tensor与module的保存与加载

https://pytorch.org/tutorials/beginner/saving_loading_models.html

In [1]:
import torch
from torch import nn
import os

# Tensor的保存与加载

In [3]:
# Serialize a tensor to a file, load it back, and verify the round trip.
t = torch.randn(5, 5)
print(t)
torch.save(t, "/tmp/t.tensor")
t1 = torch.load("/tmp/t.tensor")
# Serialization is lossless, so compare exactly. A signed sum of differences
# would let errors cancel out (and would pass for any negative total).
assert torch.equal(t, t1)
os.remove("/tmp/t.tensor")

tensor([[ 0.7387, -1.5761, -1.1319,  0.6973,  0.8781],
        [-1.3609,  0.4777, -0.0969, -0.5937,  0.0431],
        [ 0.8837,  0.4034, -0.7483,  0.3092, -0.8628],
        [-1.3030, -0.8521, -0.1324, -0.4000, -1.0109],
        [ 1.0480,  0.1687,  1.6522, -0.6016, -0.1482]])


In [4]:
from io import BytesIO

# torch.save / torch.load also accept file-like objects, e.g. an in-memory buffer.
t = torch.randn(5, 5)
buffer = BytesIO()
torch.save(t, buffer)
# Writing leaves the stream position at the end; rewind before reading back.
buffer.seek(0)
t1 = torch.load(buffer)
assert torch.equal(t, t1)

# Module的保存与加载

可以只保存模型的参数（即 `state_dict`，PyTorch官方推荐的方式），也可以把整个模型对象（借助pickle）保存起来。

In [6]:
model = nn.Sequential(
    nn.Linear(25, 100),  # 25 input features -> 100 hidden units
    nn.ReLU(),
    nn.Linear(100, 10),  # 100 hidden units -> 10 outputs
)
# Persist only the parameters/buffers (the state_dict), not the module object.
torch.save(model.state_dict(), "/tmp/mlp-params.pt")

In [8]:
# Read the offline parameter file and copy its values into the existing model.
params_path = "/tmp/mlp-params.pt"
params = torch.load(params_path)
# load_state_dict returns a report of missing/unexpected keys; it is the cell's
# displayed output.
model.load_state_dict(params)

<All keys matched successfully>

In [14]:
# Save the whole module object (pickled), not just its parameters. Loading it
# later requires the same class definitions to be importable.
# torch.save does not create parent directories, so make sure ./data exists.
os.makedirs("./data", exist_ok=True)
torch.save(model, "./data/mlp-model.pt")

In [15]:
# The file holds a pickled nn.Module, not just tensors. Recent PyTorch versions
# (2.6+) default torch.load to weights_only=True, which refuses to unpickle
# arbitrary objects, so opt out explicitly. Only do this for trusted files:
# full unpickling can execute arbitrary code.
print(torch.load("./data/mlp-model.pt", weights_only=False))

Sequential(
  (0): Linear(in_features=25, out_features=100, bias=True)
  (1): ReLU()
  (2): Linear(in_features=100, out_features=10, bias=True)
)


# 使用GPU

In [9]:
# nn.Module.to() moves parameters in place and returns the same module, so
# `gpu_model` and `model` are one object, now living on cuda:0.
gpu_model = model.to("cuda:0")

In [10]:
# The input must be on the same device as the model's parameters.
# Named `x` (not `input`) so the Python builtin is not shadowed.
x = torch.randn(1, 25, device="cuda:0")
gpu_model(x)

tensor([[-0.1885,  0.3525,  0.5501,  0.1224, -0.1966, -0.3041,  0.1034, -0.1547,
          0.2408,  0.2632]], device='cuda:0', grad_fn=<AddmmBackward>)

GPU下保存的Tensor或model，加载回来时仍然位于对应的GPU上；如需加载到CPU或其他设备，可以给 `torch.load` 传入 `map_location` 参数（例如 `map_location="cpu"`）。

In [7]:
# A serialized tensor records its device; torch.load restores it onto that
# device. Pass map_location (e.g. "cpu") to torch.load to load it elsewhere.
os.makedirs("./data", exist_ok=True)  # torch.save does not create parent dirs
t = torch.randn(3, 4, device="cuda")
torch.save(t, "./data/t-cuda.pt")
t = torch.load("./data/t-cuda.pt")
t.device

device(type='cuda', index=0)

# 优化器状态的保存与加载

In [9]:
# Bundle model and optimizer state into one dict (e.g. to resume training later).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
checkpoint = dict(
    state_dict=model.state_dict(),
    optimizer=optimizer.state_dict(),
)
torch.save(checkpoint, "/tmp/tmp.ckpt")

In [10]:
ckpt_path = "/tmp/tmp.ckpt"
checkpoint = torch.load(ckpt_path)
# Expect two entries: 'state_dict' (model) and 'optimizer'.
print(checkpoint.keys())

dict_keys(['state_dict', 'optimizer'])


In [None]:
# Copy the checkpointed states back into the live model and optimizer.
model_state = checkpoint["state_dict"]
model.load_state_dict(model_state)
optim_state = checkpoint["optimizer"]
optimizer.load_state_dict(optim_state)