In [1]:
import torch
from torch import nn

In [2]:
class CenteredLayer(nn.Module):
    """Parameter-free layer that shifts its input to have zero mean.

    The mean is taken over every element of the input (no ``dim`` is given),
    so a single scalar is subtracted from the whole tensor.
    """

    def __init__(self, **kwargs):
        """Forward any keyword arguments unchanged to ``nn.Module.__init__``."""
        super().__init__(**kwargs)

    def forward(self, x):
        """Return ``x`` minus the scalar mean of all of its elements."""
        overall_mean = x.mean()
        return x - overall_mean

In [3]:
# Quick check on the float vector [1, 2, 3, 4, 5]: its mean is 3, so the
# displayed result should be the same values shifted to be centred on zero.
layer = CenteredLayer()
layer(torch.arange(1, 6, dtype=torch.float))

tensor([-2., -1.,  0.,  1.,  2.])

In [4]:
# A custom layer can be stacked after built-in layers like any other module.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)

In [5]:
# Push a random 4x8 batch through the network. The final stage subtracts the
# mean of its input, so the mean shown below should be zero up to
# floating-point rounding.
y = net(torch.rand(4, 8))
y.mean().item()

1.1175870895385742e-08

In [6]:
class MyDense(nn.Module):
    """Chain of matrix multiplications whose weights live in a ``ParameterList``.

    Four weights are registered, each drawn from ``torch.randn``: three of
    shape (4, 4) followed by one of shape (4, 1).
    """

    def __init__(self):
        """Create the three 4x4 weights, then append the final 4x1 weight."""
        super().__init__()
        square_weights = [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        self.params = nn.ParameterList(square_weights)
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        """Multiply ``x`` by every stored weight in registration order.

        ``torch.mm`` is used, so ``x`` must be a 2-D tensor with 4 columns;
        the result has a single column.
        """
        for weight in self.params:
            x = torch.mm(x, weight)
        return x
# Build the module and print it; the repr lists each entry registered in the
# ParameterList.
net = MyDense()
print(net)

MyDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [11]:
class MyDictDense(nn.Module):
    """Single matrix multiplication whose weight is picked by name.

    Weights are kept in a ``ParameterDict`` and drawn from ``torch.randn``:
    'linear1' is (4, 4), 'linear2' is (4, 1) and 'linear3' is (4, 2).
    """

    def __init__(self):
        """Register 'linear1' and 'linear2', then add 'linear3' via ``update``."""
        super().__init__()
        initial_weights = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(initial_weights)
        self.params.update({'linear3': nn.Parameter(torch.randn(4, 2))})

    def forward(self, x, choice='linear1'):
        """Return ``torch.mm(x, W)`` where ``W`` is the weight stored under ``choice``.

        ``torch.mm`` is used, so ``x`` must be a 2-D tensor with 4 columns.
        """
        weight = self.params[choice]
        return torch.mm(x, weight)


# Build the module and print it; the repr lists each named entry registered
# in the ParameterDict.
net = MyDictDense()
print(net)

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)


In [12]:
# Run the same 1x4 input of ones through each named weight in turn; the
# number of output columns follows the shape of the selected weight.
x = torch.ones(1, 4)
print(net(x, choice='linear1'))
print(net(x, choice='linear2'))
print(net(x, choice='linear3'))

tensor([[-3.0380, -0.7905,  1.2824, -2.4360]], grad_fn=<MmBackward>)
tensor([[-2.5932]], grad_fn=<MmBackward>)
tensor([[-0.7617,  0.8655]], grad_fn=<MmBackward>)


In [15]:
# Custom modules nest inside nn.Sequential like built-in layers. MyDictDense
# runs with its default choice ('linear1'), so it passes a 1x4 tensor on to
# MyDense.
net = nn.Sequential(MyDictDense(), MyDense())
print(net)
print(net(x))

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[16.2906]], grad_fn=<MmBackward>)
