# 6.3 Parameter Initialization

In [1]:
import torch
from torch import nn

# Seed the RNG so the input X and the lazily-initialized weights (created on
# the first forward pass below) are identical on every Restart & Run All.
torch.manual_seed(0)

# Two lazy Linear layers: their input sizes are inferred from the first batch.
net = nn.Sequential(nn.LazyLinear(8), nn.ReLU(), nn.LazyLinear(1))
X = torch.rand(size=(2, 4))
# The forward pass materializes the lazy parameters; the result is (batch, 1).
net(X).shape



torch.Size([2, 1])

## 6.3.1 Built-in Initialization

In [2]:
def init_normal(module, mean=0.0, std=0.01):
    """Draw a Linear layer's weight from N(mean, std**2) and zero its bias.

    Meant to be passed to `Module.apply`, which calls it on every submodule;
    anything that is not an `nn.Linear` is skipped.

    Args:
        module: the submodule being visited.
        mean: mean of the normal distribution (default 0).
        std: standard deviation of the normal distribution (default 0.01).
    """
    # isinstance (not `type(...) ==`) also matches subclasses of nn.Linear.
    # Lazy layers (nn.LazyLinear) must first be materialized by a forward
    # pass before their parameters can be initialized.
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=mean, std=std)
        # Layers built with bias=False have `bias = None`.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Re-initialize every Linear layer in place, then show the first row of the
# first layer's weight together with its first bias entry.
net.apply(init_normal)
(net[0].weight.detach()[0], net[0].bias.detach()[0])

(tensor([-0.0032, -0.0133,  0.0060,  0.0006]), tensor(0.))

In [3]:
def init_constant(module, value=1):
    """Fill a Linear layer's weight with `value` and zero its bias.

    Meant to be passed to `Module.apply`; modules that are not `nn.Linear`
    are left untouched.

    Args:
        module: the submodule being visited.
        value: constant written into every weight entry (default 1).
    """
    # isinstance also covers subclasses of nn.Linear.
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, value)
        # Layers built with bias=False have `bias = None`.
        if module.bias is not None:
            nn.init.zeros_(module.bias)

# Set all Linear weights to the constant and biases to zero, then inspect
# the first weight row and first bias entry of layer 0.
net.apply(init_constant)
(net[0].weight.detach()[0], net[0].bias.detach()[0])

(tensor([1., 1., 1., 1.]), tensor(0.))

In [4]:
def init_xavier(module, gain=1.0):
    """Apply Xavier (Glorot) uniform initialization to a Linear layer's weight.

    Only the weight is re-drawn; the bias keeps its current values. Modules
    that are not `nn.Linear` are ignored, so this can be passed to
    `Module.apply`.

    Args:
        module: the submodule being visited.
        gain: scaling factor forwarded to `nn.init.xavier_uniform_` (default 1).
    """
    # isinstance also covers subclasses of nn.Linear.
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight, gain=gain)
    
def init_42(module):
    """Set every weight of a Linear layer to the constant 42.

    The bias is not modified. Modules that are not `nn.Linear` are ignored,
    so this can be passed to `Module.apply`.
    """
    # isinstance also covers subclasses of nn.Linear.
    if isinstance(module, nn.Linear):
        nn.init.constant_(module.weight, 42)
    
# Different layers may get different initializers: the constant 42 for the
# output layer and Xavier-uniform for the first Linear layer.
net[2].apply(init_42)
net[0].apply(init_xavier)
print(net[0].weight.detach()[0])
print(net[2].weight.detach())

tensor([ 0.7041, -0.4528, -0.6375, -0.3405])
tensor([[42., 42., 42., 42., 42., 42., 42., 42.]])


In [8]:
def my_init(module, bound=10.0, threshold=5.0):
    """Custom initializer: uniform weights with small magnitudes zeroed out.

    For an `nn.Linear` layer, every weight is first sampled from
    U(-bound, bound); entries whose absolute value is below `threshold` are
    then set to zero. The bias is not modified. The name and shape of the
    layer's first parameter are printed as a trace of which layers were hit.
    Non-Linear modules are ignored, so this can be passed to `Module.apply`.

    Args:
        module: the submodule being visited.
        bound: half-width of the uniform sampling range (default 10).
        threshold: magnitudes below this are zeroed (default 5).
    """
    # isinstance also covers subclasses of nn.Linear.
    if isinstance(module, nn.Linear):
        # Only the first parameter is needed; avoid materializing a full list.
        name, param = next(module.named_parameters())
        print("Init", name, param.shape)
        nn.init.uniform_(module.weight, -bound, bound)
        with torch.no_grad():
            # Multiplying by the boolean mask (rather than masked_fill_) keeps
            # the sign of zeroed entries (-0.0), as the original `*=` did.
            module.weight *= module.weight.abs() >= threshold

# Apply the custom initializer to every Linear layer and show layer 0's weight.
net.apply(my_init)
print(net[0].weight.data)

# Parameters can also be edited directly. Do it under no_grad so autograd
# does not record these in-place changes on the leaf parameter.
with torch.no_grad():
    net[0].weight.add_(1)
    net[0].weight[0, 0] = 42
net[0].weight.detach()[0]

Init weight torch.Size([8, 4])
Init weight torch.Size([1, 8])
tensor([[ 0.0000, -0.0000,  0.0000, -0.0000],
        [ 0.0000, -0.0000, -0.0000,  7.9620],
        [ 0.0000, -5.4582, -0.0000, -9.6926],
        [-0.0000, -0.0000,  0.0000, -0.0000],
        [-9.1287,  6.2427,  7.6209, -5.9902],
        [ 0.0000, -8.5045,  8.2143,  7.5000],
        [-0.0000, -5.7318,  0.0000, -0.0000],
        [-0.0000,  7.8663, -0.0000,  8.3928]])


tensor([42.,  1.,  1.,  1.])