Orthogonal Weight Initialization (with a specific gain) and Zero Bias Initialization

In [2]:
import torch.nn as nn

In [3]:
def init(module, weight_init, bias_init, gain =1):
    """Apply a weight initialiser and a bias initialiser to ``module`` in place.

    Args:
        module: Object exposing a ``weight`` tensor and, optionally, a ``bias``.
        weight_init: Callable invoked as ``weight_init(weight_data, gain=gain)``.
        bias_init: Callable invoked as ``bias_init(bias_data)``; it is skipped
            when the module has no ``bias`` attribute or the bias is ``None``.
        gain: Value forwarded to ``weight_init`` as its ``gain`` keyword.

    Returns:
        The same ``module`` object, so the call can be chained.
    """
    weight_init(module.weight.data, gain=gain)

    # A missing attribute and an explicit ``None`` are both treated as "no bias".
    bias = getattr(module, 'bias', None)
    if bias is not None:
        bias_init(bias.data)

    return module

def debug_init(module, gain=0.01):
    """Re-initialise ``module`` and print its weight statistics before and after.

    The weight is initialised with ``nn.init.orthogonal_`` scaled by ``gain``
    and the bias (when present) is filled with zeros, both through ``init``.

    Args:
        module: Layer exposing a ``weight`` tensor and, optionally, a ``bias``.
        gain: Scaling factor passed to the orthogonal initialiser. Defaults to
            ``0.01``, the value this helper previously hard-coded.

    Returns:
        The same ``module`` object, initialised in place.
    """
    print(f"\nInitializing layer: {module.__class__.__name__}")
    print(f"Before init - weight shape: {module.weight.shape}")
    print(f"Before init - weight stats: mean = {module.weight.mean():.4f}, std = {module.weight.std():.4f}")

    init(module,
         nn.init.orthogonal_,
         lambda x: nn.init.constant_(x, 0),
         gain=gain)
    # Format spec is ".4f" (no sign-space flag) so it lines up with the other stats.
    print(f"After init - weight stats: mean = {module.weight.mean():.4f}, std = {module.weight.std():.4f}")
    # Same bias test as ``init``: skip layers whose bias is absent or None
    # (e.g. nn.Linear(..., bias=False)) instead of printing "None".
    if getattr(module, 'bias', None) is not None:
        print(f"Bias values: {module.bias}")

    return module

# Demo: build a small linear layer (3 inputs, 3 outputs) and run it through the initialiser.
layer = nn.Linear(in_features=3, out_features=3)
# Kept as the cell's final bare expression so the notebook displays the returned module.
debug_init(layer)


Initializing layer: Linear
Before init - weight shape: torch.Size([3, 3])
Before init - weight stats: mean = 0.0917, std = 0.3389
After init - weight stats: mean = -0.0030, std =  0.0052
Bias values: Parameter containing:
tensor([0., 0., 0.], requires_grad=True)


Linear(in_features=3, out_features=3, bias=True)