In [1]:
import torch
from torch import nn

In [12]:
# Seed once, before any random tensor or parameter is created, so that
# Restart & Run All reproduces the same input and every later initialization.
SEED = 0
torch.manual_seed(SEED)

# Two-layer MLP built from lazy layers: input widths are inferred on the first
# forward pass, at which point each LazyLinear materializes as a plain nn.Linear.
net = nn.Sequential(
    nn.LazyLinear(8),
    nn.ReLU(),
    nn.LazyLinear(1),
)
X = torch.rand(size=(2, 4))  # 2 samples with 4 features each
net(X).shape  # this forward pass also materializes the lazy parameters

torch.Size([2, 1])

In [13]:
def init_normal(module):
    """Initialize a fully connected layer with N(0, 0.01^2) weights and zero bias.

    Intended for ``net.apply(init_normal)``, which calls it on every submodule.

    Args:
        module: any ``nn.Module``; only modules whose exact type is
            ``nn.Linear`` are modified, all others are left untouched.
    """
    # Exact type match (not isinstance) also skips an unmaterialized
    # nn.LazyLinear (a Linear subclass whose parameters are not yet allocated).
    if type(module) == nn.Linear:
        nn.init.normal_(module.weight, mean=0, std=0.01)
        # nn.Linear(..., bias=False) stores bias as None; zeros_(None) would raise.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

net.apply(init_normal)
# Inspect the first output unit of the hidden layer; .detach() (rather than the
# discouraged .data) gives an autograd-free view for display.
net[0].weight.detach()[0], net[0].bias.detach()[0]

(tensor([ 0.0208,  0.0104, -0.0115, -0.0033]), tensor(0.))

In [15]:
def init_constant(module, value=1):
    """Fill an nn.Linear's weights with a constant and set its bias to zero.

    Intended for ``net.apply(init_constant)``; other module types are ignored.

    Args:
        module: submodule to initialize; only exact ``nn.Linear`` is modified.
        value: fill value for every weight entry (default 1).
    """
    if type(module) == nn.Linear:
        nn.init.constant_(module.weight, value)
        # nn.Linear(..., bias=False) stores bias as None; zeros_(None) would raise.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

net.apply(init_constant)
# Inspect the first output unit of the hidden layer; .detach() (rather than the
# discouraged .data) gives an autograd-free view for display.
net[0].weight.detach()[0], net[0].bias.detach()[0]

(tensor([1., 1., 1., 1.]), tensor(0.))

In [16]:
def init_Xavier(module):
    """Apply Xavier (Glorot) normal initialization to an nn.Linear's weight.

    Modules whose exact type is not ``nn.Linear`` are left unchanged, and the
    bias is never modified.
    """
    if type(module) is not nn.Linear:
        return
    nn.init.xavier_normal_(module.weight)

def init_42(module):
    """Set every weight of an nn.Linear to 42, leaving its bias unchanged.

    Modules whose exact type is not ``nn.Linear`` are ignored.
    """
    if type(module) is not nn.Linear:
        return
    fill_value = 42
    nn.init.constant_(module.weight, fill_value)

# Different layers can get different schemes: Xavier for the hidden layer,
# the constant 42 for the output layer.
net[0].apply(init_Xavier)
net[2].apply(init_42)
# .detach() (rather than the discouraged .data) for an autograd-free view.
print(net[0].weight.detach()[0])  # first row of the hidden layer's weight
print(net[2].weight.detach())

tensor([ 0.3888, -0.2313, -0.0433, -0.2116])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])


In [17]:
def my_init(module, low=-10, high=10, threshold=5):
    """Uniform initialization with small-magnitude weights zeroed out.

    For an exact ``nn.Linear``, prints the name and shape of the module's first
    parameter, draws each weight from U(low, high), then sets entries with
    absolute value below ``threshold`` to 0. With the defaults, every weight ends
    up either 0 or with magnitude in [5, 10]. The bias is left unchanged, and
    other module types are ignored.

    Args:
        module: submodule visited by ``net.apply(my_init)``.
        low: lower bound of the uniform distribution (default -10).
        high: upper bound of the uniform distribution (default 10).
        threshold: entries with ``|w| < threshold`` become 0 (default 5).
    """
    if type(module) == nn.Linear:
        # Only the first (name, parameter) pair is reported; no need to list them all.
        name, param = next(iter(module.named_parameters()))
        print("Init", name, param.shape)
        nn.init.uniform_(module.weight, low, high)
        # Mask in place under no_grad instead of through .data, which PyTorch
        # discourages because it hides the mutation from autograd.
        with torch.no_grad():
            module.weight.mul_(module.weight.abs() >= threshold)

net.apply(my_init)
# First two rows of the hidden layer's weight, detached like the other
# inspection cells so the display omits autograd metadata (grad_fn).
net[0].weight.detach()[:2]

Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])


tensor([[ 9.6928, -0.0000,  0.0000,  0.0000],
        [-0.0000,  6.2407,  8.9003, -6.4578]], grad_fn=<SliceBackward0>)