<a href="https://colab.research.google.com/github/YonggunJung/Fastcompus/blob/main/5_Batch_normalization_and_its_variations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Batch Normalization, Layer Normalization, Instance Normalization, and Group Normalization

![](./data/group_norm.png)



https://wandb.ai/wandb_fc/GroupNorm/reports/Group-Normalization-in-Pytorch-With-Examples---VmlldzoxMzU0MzMy



https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html

In [None]:
import torch, torchvision
import torch.nn as nn
import torchvision.models as models
import torchvision.datasets as datasets

import matplotlib.pyplot as plt
from PIL import Image


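Batch norm, layer norm, instance norm, and group norm: the four layers differ mainly in which axes of an [N, C, H, W] tensor the mean and variance are computed over.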

https://pytorch.org/docs/stable/nn.html





https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html#torch.nn.LayerNorm

https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html
    

In [None]:
class BatchNorm(nn.Module):
    """Thin wrapper around ``nn.BatchNorm2d``.

    Each channel is normalized with statistics shared across the batch and
    spatial axes. The input is a 4-D ``[N, C, H, W]`` tensor and the output
    has the same shape.

    Args:
        in_channel: number of channels ``C`` (``num_features`` of the layer).
        out_channels: not used; kept so all wrappers in this notebook have a
            similar constructor.
    """

    def __init__(self, in_channel, out_channels):
        super().__init__()
        # Library-default eps and momentum are used for this layer.
        self.bn = nn.BatchNorm2d(num_features=in_channel)

    def forward(self, x):
        # [N, C, H, W] -> [N, C, H, W]
        return self.bn(x)


## Typical use: sequence models such as RNNs, where batch statistics are awkward.
class LayerNorm(nn.Module):
    """Thin wrapper around ``nn.LayerNorm`` with ``eps=1e-08``.

    Each sample is normalized over its trailing dimensions given by
    ``in_shape``. The output has the same shape as the input.

    Args:
        in_shape: shape of the trailing dimensions to normalize over, e.g.
            ``[C, H, W]`` to normalize every sample as a whole.
        out_channels: not used; kept so all wrappers in this notebook have a
            similar constructor.
    """

    def __init__(self, in_shape, out_channels):
        super().__init__()
        self.ln = nn.LayerNorm(normalized_shape=in_shape, eps=1e-08)

    def forward(self, x):
        # The last len(in_shape) dims of x must equal in_shape; shape is preserved.
        return self.ln(x)


## Typical use: style transfer and domain adaptation.
class InstanceNorm(nn.Module):
    """Thin wrapper around ``nn.InstanceNorm2d`` with ``eps=1e-08``.

    Each (sample, channel) plane of a ``[N, C, H, W]`` input is normalized
    with its own spatial statistics. All other ``nn.InstanceNorm2d``
    arguments keep their library defaults.

    Args:
        in_channel: number of channels ``C`` (``num_features`` of the layer).
        out_channels: not used; kept so all wrappers in this notebook have a
            similar constructor.
    """

    def __init__(self, in_channel, out_channels):
        super().__init__()
        self.In = nn.InstanceNorm2d(num_features=in_channel, eps=1e-08)

    def forward(self, x):
        # [N, C, H, W] -> [N, C, H, W]
        return self.In(x)


## Stable with small batch sizes: statistics are computed per sample, not per batch.
class GroupNorm(nn.Module):
    """Thin wrapper around ``nn.GroupNorm``.

    The channels of each sample are split into groups, and each group is
    normalized with its own statistics. The output has the same shape as the
    input (``[N, C, *]``).

    Args:
        group_size: despite its name, the NUMBER of groups. It is passed as
            ``num_groups`` to ``nn.GroupNorm``. For example,
            ``group_size=2`` with 64 channels gives 2 groups of 32 channels.
        in_channel: number of input channels ``C``; must be divisible by
            ``group_size``.
        out_channels: not used; kept so all wrappers in this notebook have a
            similar constructor.
        eps: value added to the variance for numerical stability. Defaults
            to 1e-08, the value previously hard-coded here.
    """

    def __init__(self, group_size, in_channel, out_channels, eps=1e-08):
        super().__init__()
        # nn.GroupNorm(num_groups, num_channels, ...): group_size is num_groups.
        self.gn = nn.GroupNorm(group_size, in_channel, eps=eps)

    def forward(self, x):
        # [N, C, *] -> [N, C, *]
        return self.gn(x)


In [None]:
in_channel = 64
# Random stand-in feature map laid out as [B, C, H, W].
feature = torch.randn(8, in_channel, 120, 120)

BN = BatchNorm(in_channel=in_channel, out_channels=64)
out_feat = BN(feature)
print(out_feat.shape)  # normalization leaves the shape unchanged

torch.Size([8, 64, 120, 120])


In [None]:
# Normalize over every axis except the batch axis, i.e. [C, H, W].
LN = LayerNorm(
    in_shape=list(feature.shape[1:]),
    out_channels=64,
)
out_feat = LN(feature)
print(out_feat.shape)

torch.Size([8, 64, 120, 120])


In [None]:
# One set of statistics per (sample, channel) plane.
IN = InstanceNorm(in_channel=in_channel, out_channels=64)
out_feat = IN(feature)
print(out_feat.shape)

torch.Size([8, 64, 120, 120])


In [None]:
# GroupNorm's first argument is the number of groups; the 64 channels are
# split evenly among them.
for n_groups in (2, 4):  # 2 groups of 32 channels, then 4 groups of 16
    GN = GroupNorm(group_size=n_groups, in_channel=in_channel, out_channels=64)
    out_feat = GN(feature)
    print(out_feat.shape)

torch.Size([8, 64, 120, 120])
torch.Size([8, 64, 120, 120])
