# 自定义层
这里的"自定义"指的是计算逻辑上的自定义: 不是用现有的层进行搭建 (搭积木), 而是创建新的层 (新的积木).

In [1]:
import torch
from torch import nn
from torch.nn import functional as F

`CenteredLayer`: 它从输入中减去输入全部元素的均值, 构造时不接受任何参数, 也不含可学习参数

In [3]:
class CenteredLayer(nn.Module):
    """Parameter-free layer that shifts its input to have zero mean.

    The mean is taken over every element of the input (no ``dim`` is
    given), so a single scalar is subtracted from the whole tensor.
    """

    def forward(self, X):
        """Return ``X`` minus the scalar mean of all its elements."""
        overall_mean = X.mean()
        return X - overall_mean

# Chain a 3 -> 6 linear layer with CenteredLayer and feed it a random
# batch of 4 samples; the mean of the output is displayed next to the model.
net = nn.Sequential(nn.Linear(3, 6), CenteredLayer())
iX = torch.randn(4, 3)
io = net(iX)
# Notebook display expression: shows the model structure and the output mean.
net, io.mean()

(Sequential(
   (0): Linear(in_features=3, out_features=6, bias=True)
   (1): CenteredLayer()
 ),
 tensor(0., grad_fn=<MeanBackward0>))

---
接下来实现一个带参数的层: 例如 MyLinear 层, 它接受 n_in, n_out, bias 作为构造参数, 含有权重 weight 以及 (当 bias=True 时) 偏置 bias, 行为与 nn.Linear 相同.

In [9]:
class MyLinearA(nn.Module):
    """First attempt at a hand-written linear layer: ``X @ weight.T (+ bias)``.

    The weight and bias are created as plain tensors with
    ``requires_grad=True`` rather than ``nn.Parameter`` objects, so
    ``state_dict()`` of this module comes back empty.

    Args:
        n_in: Size of the last dimension of the input.
        n_out: Size of the last dimension of the output.
        bias: When true, a zero-initialised bias of length ``n_out`` is added.
    """

    def __init__(self, n_in, n_out, bias=True):
        super().__init__()
        self.hav_bias = bias
        # Stored as (n_out, n_in); forward() multiplies by the transpose.
        self.weight = torch.randn(n_out, n_in, requires_grad=True)
        if bias:
            self.bias = torch.zeros(n_out, requires_grad=True)

    def forward(self, X):
        """Project ``X`` with the weight and add the bias if one exists."""
        projected = X @ self.weight.T
        if not self.hav_bias:
            return projected
        return projected + self.bias
# Stack two MyLinearA layers (4 -> 32 -> 1), run a random batch of 3 samples,
# and display the output, the model and its state_dict.
net = nn.Sequential(MyLinearA(4, 32), MyLinearA(32, 1))
net(torch.randn(3, 4)), net, net.state_dict()
# Something is clearly wrong here: state_dict() cannot find any parameters.

(tensor([[12.8676],
         [-3.5359],
         [-4.9836]], grad_fn=<AddBackward0>),
 Sequential(
   (0): MyLinearA()
   (1): MyLinearA()
 ),
 OrderedDict())

以下来自参考代码, 可见需要把张量用 nn.Parameter 包装 (转换类型), 才能被注册为模型参数并出现在 state_dict 中

In [11]:
class MyLinear(nn.Module):
    """Linear layer with a built-in ReLU, using registered parameters.

    Both tensors are wrapped in ``nn.Parameter``, so they appear in
    ``state_dict()`` under the keys ``weight`` and ``bias``.

    Args:
        i_: Size of the last dimension of the input.
        o_: Size of the last dimension of the output.
    """

    def __init__(self, i_, o_):
        super().__init__()
        # Weight is stored as (i_, o_), so forward() needs no transpose.
        weight_init = torch.randn(i_, o_)
        self.weight = nn.Parameter(weight_init)
        bias_init = torch.randn(o_)
        self.bias = nn.Parameter(bias_init)

    def forward(self, X):
        """Return ``relu(X @ weight + bias)``."""
        pre_activation = torch.matmul(X, self.weight) + self.bias
        return F.relu(pre_activation)
# Build an 8 -> 4 MyLinear layer, run a random batch of 3 samples, and
# display the output, the module and its state_dict.
net = MyLinear(8, 4)
net(torch.randn(3, 8)), net, net.state_dict()

(tensor([[0.0000, 2.4367, 0.0000, 2.7845],
         [0.0000, 2.2957, 1.3112, 0.0000],
         [0.4345, 3.5540, 0.0000, 1.0051]], grad_fn=<ReluBackward0>),
 MyLinear(),
 OrderedDict([('weight',
               tensor([[-2.4772, -0.9873, -0.1902, -0.8379],
                       [ 1.2272, -1.5311,  0.6536, -1.3510],
                       [ 0.3239, -0.5081, -0.8020,  0.7834],
                       [-0.0946,  0.0701, -0.8087,  0.9963],
                       [-1.6395,  0.6347, -0.8426,  0.7625],
                       [ 0.5313,  0.7384,  0.8902,  1.1225],
                       [-0.1281,  0.8938, -0.2132, -0.5563],
                       [ 0.1858,  1.1956, -0.2290, -2.4159]])),
              ('bias', tensor([-0.5148,  1.0647, -0.2149,  0.5073]))]))