# Weight Initialization

## Option 1. After instantiating a network (via `net.apply`)

In [None]:
import torch
import torch.nn as nn

def weights_init(m):
    """Initialize the parameters of a single module in place.

    Meant to be passed to ``nn.Module.apply`` so it runs on every submodule:

    * ``Conv2d`` / ``ConvTranspose2d``: Xavier-uniform weights (gain 1.0),
      zero bias.
    * ``Linear``: weights drawn from N(0, 1), zero bias.
    * ``BatchNorm2d`` / ``InstanceNorm2d``: weight 1, bias 0, applied only to
      the affine parameters that exist.

    Modules of any other type are left unchanged.

    Args:
        m: The module to initialize.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.xavier_uniform_(m.weight, gain=1.0)
        if m.bias is not None:
            # nn.init.* modify the tensor under no_grad; writing through
            # `.data` bypasses autograd's bookkeeping.
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, mean=0.0, std=1.0)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
        # Norm layers built with affine=False (InstanceNorm2d's default)
        # have weight/bias set to None, so check each one separately.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [None]:
class Net(nn.Module):
    """Toy network (3x3 conv -> batch norm -> linear) used to demo initialization.

    The conv keeps the spatial size (padding 1, stride 1) and outputs one
    channel, so the flattened per-sample feature size is H * W. Because the
    linear layer has ``in_features=1``, inputs must be shaped ``(N, 1, 1, 1)``.
    """

    def __init__(self):
        super().__init__()

        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=3 // 2, bias=True)
        self.bn = nn.BatchNorm2d(1)
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Return the ``(N, 1)`` output for an ``(N, 1, 1, 1)`` input batch."""
        x = self.bn(self.conv(x))
        # flatten (unlike view) also accepts non-contiguous tensors.
        return self.linear(torch.flatten(x, start_dim=1))

net = Net()

# Show the default PyTorch initialization first, then re-initialize every
# submodule with weights_init via Module.apply and show the result.
print('before init')
for name, param in net.named_parameters():
    print(f'{name}: {param}')

net.apply(weights_init)

print('\n\nAfter init')
for name, param in net.named_parameters():
    print(f'{name}: {param}')

## Option 2. In the network's constructor (`nn.Module.__init__`)

In [None]:
import torch
import torch.nn as nn

class Net(nn.Module):
    """Toy network (3x3 conv -> batch norm -> linear) that initializes itself.

    Unlike Option 1, no separate ``apply`` call is needed: ``__init__``
    walks ``self.modules()`` right after the layers are created.
    Inputs must be shaped ``(N, 1, 1, 1)``, because the one-channel conv
    keeps the spatial size and the linear layer has ``in_features=1``.
    """

    def __init__(self):
        super().__init__()

        ## layers ##
        self.conv = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, stride=1, padding=3 // 2, bias=True)
        self.bn = nn.BatchNorm2d(1)
        self.linear = nn.Linear(in_features=1, out_features=1)

        ## initialization ##
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.xavier_uniform_(m.weight, gain=1.0)
                if m.bias is not None:
                    # nn.init.* run under no_grad instead of writing via `.data`.
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, mean=0.0, std=1.0)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Affine parameters are None when the layer has affine=False.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """Return the ``(N, 1)`` output for an ``(N, 1, 1, 1)`` input batch."""
        x = self.bn(self.conv(x))
        # flatten (unlike view) also accepts non-contiguous tensors.
        return self.linear(torch.flatten(x, start_dim=1))

# No apply() call is needed here: Net.__init__ runs its own initialization loop.
net = Net()