In [1]:
import torch
from torch import nn

In [2]:
class MyDense(nn.Module):
    """A chain of bias-free matrix multiplications stored in an ``nn.ParameterList``.

    Shows that tensors wrapped in ``nn.Parameter`` and held in a ``ParameterList``
    are registered with the module (they appear in ``print(net)`` and in
    ``net.parameters()``), including entries added later with ``append``.

    Args:
        in_features: number of input columns, also the size of each square
            hidden weight (default 4).
        num_hidden: number of square ``(in_features, in_features)`` weights
            (default 3).
        out_features: number of columns of the final weight, i.e. of the output
            (default 1).

    With the defaults the parameters are three 4x4 matrices followed by one
    4x1 matrix, drawn from ``torch.randn`` in that order.
    """
    def __init__(self, in_features=4, num_hidden=3, out_features=1):
        super().__init__()
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(in_features, in_features)) for _ in range(num_hidden)]
        )
        # Final weight projects down to ``out_features`` columns; appended after
        # construction to show that ``append`` also registers the parameter.
        self.params.append(nn.Parameter(torch.randn(in_features, out_features)))

    def forward(self, x):
        """Right-multiply ``x`` by every stored weight, in list order.

        Args:
            x: 2-D tensor (``torch.mm`` requires 2-D) with ``in_features`` columns.

        Returns:
            Tensor with the same number of rows as ``x`` and ``out_features`` columns.
        """
        for weight in self.params:
            x = torch.mm(x, weight)
        return x
net = MyDense()
print(net)
# Sanity-check the forward pass, which is not exercised anywhere else: a (1, 4)
# input goes through the three 4x4 weights and the final 4x1 weight, so the
# result must be (1, 1).
assert net(torch.ones(1, 4)).shape == (1, 1)

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [8]:
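# nn.ParameterDict takes a dict of nn.Parameter instances and turns it into a registered parameter dictionary, which can then be used with ordinary dict rules (indexing by key, update(), keys(), items(), ...). Note: update() with a plain dict does not preserve that dict's order; pass an OrderedDict if order matters.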
class MyDictDense(nn.Module):
    """Several independently shaped weights held in an ``nn.ParameterDict``.

    ``forward`` selects one weight by name and right-multiplies the input by it.
    Registered entries (name -> shape), drawn from ``torch.randn`` in this order:
    ``linear1`` (4, 4), ``linear2`` (4, 1), ``linear3`` (4, 2), ``linear4`` (5, 6)
    and ``测试`` (8, 3) (Chinese for "test"; non-ASCII names are accepted).
    For an input with 4 columns only ``linear1``-``linear3`` are shape-compatible;
    the last two only demonstrate adding entries after construction.
    """
    def __init__(self):
        super().__init__()
        # Entries supplied when the ParameterDict is created.
        initial_specs = (('linear1', (4, 4)), ('linear2', (4, 1)))
        self.params = nn.ParameterDict(
            {key: nn.Parameter(torch.randn(*shape)) for key, shape in initial_specs}
        )
        # Entries merged in afterwards, one update() call per entry, like dict.update.
        later_specs = (('linear3', (4, 2)), ('linear4', (5, 6)), ('测试', (8, 3)))
        for key, shape in later_specs:
            self.params.update({key: nn.Parameter(torch.randn(*shape))})

    def forward(self, x, choice='linear1'):
        """Return ``x`` matrix-multiplied by the weight named ``choice``.

        ``x`` must be 2-D (``torch.mm`` semantics) with as many columns as the
        chosen weight has rows. An unknown ``choice`` raises from the dict lookup.
        """
        weight = self.params[choice]
        return x.mm(weight)
net = MyDictDense()
print(net)
# The constructor dict plus the three update() calls should have registered all
# five entries with these shapes (dict comparison, so order is not checked).
assert {name: tuple(p.shape) for name, p in net.params.items()} == {
    'linear1': (4, 4),
    'linear2': (4, 1),
    'linear3': (4, 2),
    'linear4': (5, 6),
    '测试': (8, 3),
}

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
      (linear4): Parameter containing: [torch.FloatTensor of size 5x6]
      (测试): Parameter containing: [torch.FloatTensor of size 8x3]
  )
)


In [11]:
# Feed the same all-ones (1, 4) input through three different ParameterDict entries;
# the output width follows the chosen weight's column count (4, 1, 2).
x = torch.ones(1, 4)
for choice in ('linear1', 'linear2', 'linear3'):
    print(net(x, choice))

tensor([[-2.0481, -1.1100, -2.7620,  3.0115]], grad_fn=<MmBackward>)
tensor([[-3.6352]], grad_fn=<MmBackward>)
tensor([[ 1.5910, -1.8178]], grad_fn=<MmBackward>)


In [18]:
# nn.Sequential calls each child's forward with a single positional argument, so
# MyDictDense always runs with its default choice='linear1' here; the other
# ParameterDict entries cannot be selected through this container.
net = nn.Sequential(
    MyDictDense(),
)
print(net)
out = net(x)  # x is the all-ones (1, 4) tensor defined in the previous cell
print(out)
# Confirm the container used the 'linear1' weight of its only child.
assert torch.equal(out, torch.mm(x, net[0].params['linear1']))

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
        (linear4): Parameter containing: [torch.FloatTensor of size 5x6]
        (测试): Parameter containing: [torch.FloatTensor of size 8x3]
    )
  )
)
tensor([[ 0.9115,  0.4142, -3.3413, -0.1003]], grad_fn=<MmBackward>)
