In [1]:
import os

import torch
from torch import nn

参数绑定

In [2]:
# Parameter tying: one module object is placed at two positions in the network,
# so both positions read and update the very same weight and bias tensors.
shared = nn.Linear(8, 8)

net = nn.Sequential(
    nn.Linear(4, 8),
    nn.ReLU(),
    shared,  # second hidden layer
    nn.ReLU(),
    shared,  # third hidden layer, tied to the second one
    nn.ReLU(),
    nn.Linear(8, 4),
)

自定义带参数的层

In [3]:
class MyLinear(nn.Module):
    """Hand-written fully connected layer followed by a ReLU.

    Computes ``relu(x @ weight + bias)``.

    Args:
        input: Number of input features (size of the last dimension of ``x``).
        output: Number of output features.

    Note:
        ``input`` shadows the Python builtin of the same name inside
        ``__init__``; the name is kept so existing callers keep working.
    """

    def __init__(self, input, output):
        # The arguments of __init__ are the values passed when the instance is
        # created, e.g. MyLinear(5, 3); `self` is the instance being built.
        super().__init__()
        # nn.Parameter registers the tensors with the module so they show up in
        # parameters() / state_dict() and are tracked by autograd.
        self.weight = nn.Parameter(torch.rand(input, output))
        self.bias = nn.Parameter(torch.zeros(output))

    def forward(self, x):
        """Return ``relu(x @ weight + bias)`` for an input whose last dim is ``input``."""
        # Use the Parameters themselves. Going through `.data` would bypass
        # autograd, so no gradient would ever reach weight or bias.
        linear = torch.matmul(x, self.weight) + self.bias
        return nn.functional.relu(linear)
# nn.ReLU() is a Module (a layer object); nn.functional.relu() is the plain
# function that performs the computation.

In [4]:
def init_normal(m):
    """Initialise an ``nn.Linear`` in place: weights ~ N(0, 1), bias = 0.

    Written to be passed to ``Module.apply``. Modules whose type is not exactly
    ``nn.Linear`` are left untouched.

    Args:
        m: Any module; only exact ``nn.Linear`` instances are modified.
    """
    if type(m) is nn.Linear:
        nn.init.normal_(m.weight, mean=0, std=1)
        # A layer built with bias=False has m.bias set to None; skip it
        # instead of crashing inside zeros_.
        if m.bias is not None:
            nn.init.zeros_(m.bias)
# In torch, a trailing underscore in a function name means the tensor argument
# is modified in place rather than a new value being returned.

# Module.apply runs the given function on this module and, recursively, on
# every module nested inside it. Layers can also be initialised by hand.
net.apply(init_normal)

Sequential(
  (0): Linear(in_features=4, out_features=8, bias=True)
  (1): ReLU()
  (2): Linear(in_features=8, out_features=8, bias=True)
  (3): ReLU()
  (4): Linear(in_features=8, out_features=8, bias=True)
  (5): ReLU()
  (6): Linear(in_features=8, out_features=4, bias=True)
)

In [5]:
def init_42(m):
    """Initialise an ``nn.Linear`` in place: every weight = 42, bias = 0.

    Written to be passed to ``Module.apply``. Modules whose type is not exactly
    ``nn.Linear`` are left untouched.

    Args:
        m: Any module; only exact ``nn.Linear`` instances are modified.
    """
    if type(m) is nn.Linear:
        nn.init.constant_(m.weight, 42)
        # A layer built with bias=False has m.bias set to None; skip it
        # instead of crashing inside zeros_.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# 42 is the answer to the universe!

保存与加载:torch.save / torch.load 可以存取任意张量,也支持由张量组成的列表和字典

In [6]:
# torch.save does not create missing parent directories, so make sure the
# output folder exists before the first write (keeps Restart & Run All working
# on a fresh checkout).
os.makedirs("./save", exist_ok=True)

# Round-trip a single tensor through disk; x1 and y1 should hold equal values.
x1 = torch.rand(2,2)
torch.save(x1, "./save/x-file")
y1 = torch.load("./save/x-file")
x1, y1

(tensor([[0.8642, 0.7480],
         [0.4792, 0.8119]]),
 tensor([[0.8642, 0.7480],
         [0.4792, 0.8119]]))

In [7]:
# A list of tensors can be written to a single file and unpacked again on load.
x2, y2 = torch.rand(2, 2), torch.rand(2, 2)
torch.save([x2, y2], "./save/xy-file")
x3, y3 = torch.load("./save/xy-file")
x2, y2, x3, y3

(tensor([[0.4036, 0.6333],
         [0.6070, 0.6539]]),
 tensor([[0.8855, 0.1170],
         [0.8896, 0.2891]]),
 tensor([[0.4036, 0.6333],
         [0.6070, 0.6539]]),
 tensor([[0.8855, 0.1170],
         [0.8896, 0.2891]]))

In [8]:
# A dict mapping names to tensors round-trips through torch.save / torch.load as well.
x4, y4 = torch.rand(2, 2), torch.rand(2, 2)
mydict = dict(x=x4, y=y4)
torch.save(mydict, "./save/mydict")
mydict2 = torch.load("./save/mydict")
mydict, mydict2

({'x': tensor([[0.8538, 0.0476],
          [0.3852, 0.0888]]),
  'y': tensor([[0.8780, 0.8141],
          [0.7528, 0.0242]])},
 {'x': tensor([[0.8538, 0.0476],
          [0.3852, 0.0888]]),
  'y': tensor([[0.8780, 0.8141],
          [0.7528, 0.0242]])})

In [13]:
# net.state_dict() holds the network's parameter tensors (here every weight and
# bias). Saving it stores values only, not the architecture, so the same network
# must be rebuilt in code before the values can be loaded back.
# NOTE(review): torch.save can also pickle a whole module object; saving the
# state_dict is the commonly recommended approach - confirm against the PyTorch docs.
torch.save(net.state_dict(),"./save/net_params")

In [16]:
# Restoring parameters into a network uses the dedicated API load_state_dict().
# The cell displays both the key-matching report and the network returned by eval().
net.load_state_dict(torch.load("./save/net_params")), net.eval()

(<All keys matched successfully>,
 Sequential(
   (0): Linear(in_features=4, out_features=8, bias=True)
   (1): ReLU()
   (2): Linear(in_features=8, out_features=8, bias=True)
   (3): ReLU()
   (4): Linear(in_features=8, out_features=8, bias=True)
   (5): ReLU()
   (6): Linear(in_features=8, out_features=4, bias=True)
 ))