In [1]:
import torch
import torch.nn.functional as F
from torch import nn

In [9]:
class CenterLayer(nn.Module):
    """Parameter-free layer that subtracts the mean from its input.

    Args:
        dim: Dimension (int or tuple of ints) over which the mean is taken.
            ``None`` (the default) averages over every element of the input,
            exactly like ``X - X.mean()``. When a dim is given, the mean is
            computed with ``keepdim=True`` so it broadcasts back against ``X``
            (e.g. ``dim=1`` centres each row of a 2-D batch independently).

    Note:
        ``torch.mean`` only accepts floating-point or complex tensors, so
        integer inputs must be cast (e.g. ``X.float()``) before calling.
    """

    def __init__(self, dim=None):
        super().__init__()
        # Plain attribute, not an nn.Parameter: the layer stays parameter-free.
        self.dim = dim

    def forward(self, X):
        """Return ``X`` minus its mean over ``self.dim`` (or over all elements)."""
        if self.dim is None:
            return X - X.mean()
        return X - X.mean(dim=self.dim, keepdim=True)

In [10]:
layer = CenterLayer()  # CenterLayer defines no parameters, so there is nothing to initialise

In [14]:
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32))  # torch.tensor replaces the legacy torch.FloatTensor constructor; float dtype is needed by mean()

tensor([-2., -1.,  0.,  1.,  2.])

In [15]:
# Float dtype is required: CenterLayer calls X.mean(), which rejects integer tensors.
X = torch.arange(0, 5, 1, dtype=torch.float32)

In [16]:
layer(X)  # X = [0, 1, 2, 3, 4] has mean 2, so every element is shifted down by 2

tensor([-2., -1.,  0.,  1.,  2.])

In [17]:
# 5 inputs -> 10 hidden units (ReLU) -> 1 output.
net = nn.Sequential(
    nn.Linear(5, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
)

In [18]:
def init_ones(m):
    """Set an ``nn.Linear`` module's weights to 1 and its bias to 0, in place.

    Meant to be passed to ``Module.apply``, which calls it on every submodule;
    modules whose type is not exactly ``nn.Linear`` are left untouched.

    Args:
        m: the module to (possibly) initialise; modified in place.
    """
    # Exact type match (not isinstance) so Linear subclasses are skipped,
    # matching the original selection rule.
    if type(m) is nn.Linear:
        nn.init.ones_(m.weight)
        # nn.Linear(..., bias=False) stores bias as None; zeros_(None) would raise.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [19]:
net.apply(init_ones)  # apply() visits every submodule and returns net itself, so the cell displays net's repr

Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=1, bias=True)
)

In [20]:
net[0].weight  # first Linear's weight, shape (out_features, in_features) = (10, 5); all ones after init_ones

Parameter containing:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], requires_grad=True)

In [21]:
net[0].bias  # zeroed by init_ones

Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], requires_grad=True)

In [22]:
torch.save(net.state_dict(), "test.params")  # saves only the tensors in the state_dict, not the architecture, so a matching net must be rebuilt before loading

In [24]:
# Same architecture as net (required for load_state_dict); weights start random.
net2 = nn.Sequential(
    nn.Linear(5, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
)

In [25]:
net2[0].weight  # freshly (randomly) initialised, so it differs from net's all-ones weights

Parameter containing:
tensor([[ 0.2046,  0.1665,  0.0622, -0.4409, -0.3601],
        [ 0.1157, -0.0843, -0.0527,  0.2909, -0.4455],
        [-0.2787, -0.0076,  0.1805, -0.3056,  0.2448],
        [-0.3835, -0.2038, -0.0088,  0.1396, -0.2942],
        [-0.0318, -0.1889, -0.4036, -0.3052, -0.1948],
        [ 0.1329, -0.2048, -0.2307,  0.4171,  0.4152],
        [ 0.1933,  0.0916, -0.2804, -0.0864, -0.1279],
        [ 0.3767, -0.1776, -0.2878,  0.3111, -0.3860],
        [-0.2102, -0.2435,  0.4064,  0.3377,  0.2058],
        [-0.2186, -0.4056, -0.0354,  0.2544,  0.3469]], requires_grad=True)

In [26]:
# weights_only=True restricts unpickling to tensors/containers, which is all a state_dict needs.
net2.load_state_dict(torch.load('test.params', weights_only=True))

<All keys matched successfully>

In [27]:
net2[0].weight  # now equal to net[0].weight: the loaded state_dict overwrote the random init

Parameter containing:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], requires_grad=True)

In [28]:
net2 is net  # nn.Module defines no __eq__, so == is an identity test: two distinct objects, even with equal weights

False

In [29]:
net2  # same layer structure as net, so the saved state_dict keys line up

Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=1, bias=True)
)

In [30]:
net  # repr identical to net2's above

Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=1, bias=True)
)

In [33]:
# Dedicated seeded generator: reproducible inputs without touching the global RNG.
# Note: this rebinds X (previously a 1-D tensor) to a (2, 5) batch.
X = torch.rand((2, 5), generator=torch.Generator().manual_seed(0))

In [35]:
torch.eq(net(X), net2(X))  # elementwise; exact equality is expected since both nets hold identical parameters

tensor([[True],
        [True]])

In [36]:
torch.save(X, "x.param")  # torch.save works on plain tensors too, not only state_dicts

In [37]:
# weights_only=True is sufficient for a plain tensor and avoids arbitrary unpickling.
y = torch.load('x.param', weights_only=True)

In [38]:
torch.eq(y, X)  # elementwise check that the round-tripped tensor matches the original

tensor([[True, True, True, True, True],
        [True, True, True, True, True]])