<a href="https://colab.research.google.com/github/151ali/lr-pytorch/blob/main/4_weight_initialisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Reference: [PyTorch `torch.nn.init` documentation](https://pytorch.org/docs/stable/nn.init.html)

In [1]:
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Simple CNN classifier
class CNN(nn.Module):
  """Small two-convolution CNN classifier with explicit weight initialization.

  Pipeline: conv(3x3 -> 6) -> ReLU -> maxpool(2x2) -> conv(3x3 -> 16) -> ReLU
  -> maxpool(2x2) -> flatten -> linear(num_classes).

  Both convolutions use kernel 3, stride 1 and padding 1, so they preserve the
  spatial size. Each 2x2 / stride-2 pool halves it (rounding down). ``fc1``
  expects ``16 * 7 * 7`` features, so the input height and width must each
  reduce to 7 after two poolings (e.g. 28x28 inputs).

  Args:
    in_channels: number of channels of the input images.
    num_classes: number of output logits produced per sample.
  """

  def __init__(self, in_channels, num_classes):
    super().__init__()

    self.conv1 = nn.Conv2d(
        in_channels=in_channels,
        out_channels=6,
        kernel_size=3,
        stride=1,
        padding=1
    )
    # A single pooling module is reused after both convolutions.
    self.pool = nn.MaxPool2d(
        kernel_size=(2, 2),
        stride=(2, 2)
    )
    self.conv2 = nn.Conv2d(
        in_channels=6,
        out_channels=16,
        kernel_size=3,
        stride=1,
        padding=1
    )
    # 16 channels * 7 * 7 spatial positions after two 2x2 poolings.
    self.fc1 = nn.Linear(16*7*7, num_classes)

    # Override PyTorch's default layer initialization.
    self.initialize_weights()

  def forward(self, x):
    """Return class logits of shape (batch, num_classes).

    Assumes ``x`` is (batch, in_channels, H, W), with H and W reducing to 7
    after two 2x2 poolings (e.g. 28x28). TODO confirm against the data loader.
    """
    x = F.relu(self.conv1(x))
    x = self.pool(x)
    x = F.relu(self.conv2(x))
    x = self.pool(x)
    # Flatten everything but the batch dimension. reshape is not in-place,
    # so the result must be reassigned before feeding the linear layer.
    x = x.reshape(x.shape[0], -1)
    x = self.fc1(x)
    return x

  def initialize_weights(self):
    """Re-initialize parameters of all Conv2d, Linear and BatchNorm2d submodules.

    - Conv2d / Linear: weight gets Kaiming (He) uniform init with the
      ``nn.init.kaiming_uniform_`` defaults; bias is set to 0.
    - BatchNorm2d: weight (scale) is set to 1 and bias (shift) to 0.

    Parameters that do not exist (``bias=False`` layers, or BatchNorm with
    ``affine=False``) are skipped instead of raising.
    """
    for module in self.modules():
      if isinstance(module, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_uniform_(module.weight)
        if module.bias is not None:
          nn.init.constant_(module.bias, 0)

      elif isinstance(module, nn.BatchNorm2d):
        # weight/bias are None when the layer was built with affine=False.
        if module.weight is not None:
          nn.init.constant_(module.weight, 1)
        if module.bias is not None:
          nn.init.constant_(module.bias, 0)


In [3]:
# Build the network for 3-channel inputs producing 10 class logits.
model = CNN(in_channels=3, num_classes=10)