# Weight Initialization

In [1]:
import torch
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# A small fully connected layer (5 inputs -> 5 outputs) used as the target
# for every initializer demonstrated below.
layer = nn.Linear(in_features=5, out_features=5)

In [3]:
# Weights as created by the constructor. Note: this is NOT Xavier initialization.
# nn.Linear documents its default as U(-sqrt(k), sqrt(k)) with k = 1 / in_features,
# i.e. roughly +/-0.447 for this layer; the values shown below fit that range
# (the Xavier-uniform bound for a 5x5 layer would be about 0.775).
layer.weight.data

tensor([[-0.3234,  0.4328,  0.4079,  0.4000,  0.1625],
        [ 0.2432, -0.0117, -0.3947, -0.4365, -0.0971],
        [ 0.0035,  0.3927,  0.1220,  0.1230, -0.2305],
        [ 0.3974,  0.2450,  0.0654,  0.4270,  0.1636],
        [-0.3298,  0.4426, -0.3099, -0.2948,  0.3945]])

In [4]:
# Uniform initialization: refill the weights in place with samples from U(0, 3).
# The Parameter itself is passed; going through `.data` is not required.
nn.init.uniform_(layer.weight, a=0.0, b=3.0)

Parameter containing:
tensor([[0.6655, 1.9886, 1.4633, 0.9382, 0.1642],
        [2.2031, 2.2106, 0.7725, 1.3087, 1.3745],
        [0.6798, 2.7702, 2.4321, 0.0656, 1.8431],
        [2.7083, 1.7614, 1.5500, 2.2101, 1.3480],
        [0.7889, 0.9701, 2.9232, 2.4584, 2.0841]], requires_grad=True)

In [5]:
# Normal initialization: refill the weights in place with samples from
# a Gaussian with mean 0 and standard deviation 1.
nn.init.normal_(layer.weight, mean=0.0, std=1.0)

Parameter containing:
tensor([[-0.9896, -0.3348,  0.0225,  0.2434, -0.2330],
        [ 1.4850, -0.2462, -0.6582,  0.5431, -0.8584],
        [ 1.0214,  0.3502,  0.8111,  0.3463,  1.1300],
        [ 0.9715, -0.3962,  0.9713,  0.7406, -0.0097],
        [-0.0235,  0.1872, -0.4905,  0.5582, -0.2618]], requires_grad=True)

In [6]:
# Constant initialization: set every entry to the same value (0.5 here).
# Applied to the bias vector, which is a common use for it.
nn.init.constant_(tensor=layer.bias, val=0.5)

Parameter containing:
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000], requires_grad=True)

In [7]:
# Xavier (Glorot) uniform initialization: samples from U(-a, a) with
# a = gain * sqrt(6 / (fan_in + fan_out)); `gain` is a multiplicative scaling
# factor on that bound (about 0.775 for this 5x5 layer with gain=1).
nn.init.xavier_uniform_(layer.weight, gain=1.0)

Parameter containing:
tensor([[ 0.1196,  0.2023, -0.0958, -0.7721,  0.2940],
        [-0.4039, -0.7646,  0.6115, -0.1018,  0.4988],
        [ 0.4709, -0.6105, -0.4641,  0.1008, -0.6438],
        [ 0.7040, -0.7591,  0.3124, -0.4054,  0.6014],
        [-0.1464, -0.6700,  0.6269, -0.5325, -0.2364]], requires_grad=True)

In [8]:
# Xavier (Glorot) normal initialization: samples from N(0, std^2) with
# std = gain * sqrt(2 / (fan_in + fan_out)).
nn.init.xavier_normal_(layer.weight, gain=1.0)

Parameter containing:
tensor([[-0.2238,  0.0364,  0.1786, -0.1559, -0.7159],
        [-0.1859, -0.2145,  0.0938,  0.1755,  0.3503],
        [ 0.0616, -0.3474,  0.6791,  0.6956,  0.2839],
        [ 1.1012, -0.3483,  0.1877,  0.8971, -0.3752],
        [-0.1214, -0.1176, -0.0534,  0.2296, -0.3338]], requires_grad=True)