In [1]:
import torch
from torch import nn
from torch.nn import functional as F

class CenteredLayer(nn.Module):
    def __init__(self):
        super().__init__()
    def forward(self, X):
        return X-X.mean()

In [2]:
layer = CenteredLayer()
layer(torch.FloatTensor([1,2,3,4,5]))

tensor([-2., -1.,  0.,  1.,  2.])

In [3]:
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())

In [4]:
Y = net(torch.rand(4, 8))
Y.mean()

tensor(4.6566e-09, grad_fn=<MeanBackward0>)

In [5]:
class MyLinear(nn.Module):
    def __init__(self, in_units, units):
        super().__init__()
        self.myweight = nn.Parameter(torch.rand(in_units, units))
        self.mybias = nn.Parameter(torch.rand(units, ))

    def forward(self, X):
        linear = torch.matmul(X, self.myweight) + self.mybias
        return F.relu(linear)

In [6]:
dense = MyLinear(5, 3)
dense.myweight

Parameter containing:
tensor([[0.2011, 0.2613, 0.4144],
        [0.4295, 0.8542, 0.6101],
        [0.3472, 0.9560, 0.4946],
        [0.9982, 0.5486, 0.5743],
        [0.6757, 0.8661, 0.6392]], requires_grad=True)

In [7]:
dense(torch.rand(2, 5))

tensor([[1.9120, 2.7579, 1.9180],
        [1.9961, 2.6571, 2.0515]], grad_fn=<ReluBackward0>)

In [8]:
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
net(torch.rand(2, 64))

tensor([[46.8640],
        [47.3184]], grad_fn=<ReluBackward0>)

In [9]:
print(net)

Sequential(
  (0): MyLinear()
  (1): MyLinear()
)


In [10]:
print(net[0].state_dict())

OrderedDict([('myweight', tensor([[0.6763, 0.9285, 0.6650, 0.4379, 0.8160, 0.9833, 0.8908, 0.1153],
        [0.4434, 0.2526, 0.8612, 0.8973, 0.9209, 0.1105, 0.0112, 0.1107],
        [0.8394, 0.3208, 0.9308, 0.4441, 0.5919, 0.5947, 0.3682, 0.5327],
        [0.9144, 0.4309, 0.3892, 0.6737, 0.8484, 0.9908, 0.9117, 0.8884],
        [0.4068, 0.6723, 0.1192, 0.1207, 0.9574, 0.4091, 0.1627, 0.6301],
        [0.0984, 0.5851, 0.6725, 0.2058, 0.3099, 0.9555, 0.6498, 0.4833],
        [0.8902, 0.0149, 0.4872, 0.9073, 0.4683, 0.0861, 0.2122, 0.1248],
        [0.7613, 0.7016, 0.1263, 0.7739, 0.9175, 0.5732, 0.0743, 0.2472],
        [0.6960, 0.8611, 0.8251, 0.1119, 0.7616, 0.8704, 0.7637, 0.8231],
        [0.4605, 0.7099, 0.9542, 0.2153, 0.9965, 0.0847, 0.6976, 0.4207],
        [0.2057, 0.8186, 0.9140, 0.8425, 0.7280, 0.3020, 0.4794, 0.2164],
        [0.9116, 0.1172, 0.6713, 0.3868, 0.3698, 0.7831, 0.4887, 0.3897],
        [0.0601, 0.0545, 0.9092, 0.9795, 0.5273, 0.2835, 0.1712, 0.8860],
        [0.3