In [7]:
import torch
class BatchNorm(torch.nn.Module):
    """Batch normalisation over dim 0 with a learnable per-feature affine transform.

    In training mode each feature (column) is normalised with the mean and
    biased variance of the current batch, and exponential moving averages of
    those statistics are accumulated. In eval mode the accumulated running
    statistics are used instead, so inference no longer depends on which other
    samples share the batch (and a batch of one is handled sensibly).

    Args:
        emb_size: number of features; ``scale``, ``bias`` and the running
            statistics are all stored with shape ``(1, emb_size)``.
        eps: constant added to the variance for numerical stability.
        momentum: weight given to the newest batch when updating the running
            statistics.
    """

    def __init__(self, emb_size, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum
        self.scale = torch.nn.Parameter(torch.ones(1, emb_size))
        self.bias = torch.nn.Parameter(torch.zeros(1, emb_size))
        # Buffers, not parameters: saved in state_dict and moved by .to(),
        # but never touched by the optimiser.
        self.register_buffer("running_mean", torch.zeros(1, emb_size))
        self.register_buffer("running_var", torch.ones(1, emb_size))

    def forward(self, x):
        """Normalise ``x`` along dim 0, then apply ``scale`` and ``bias``."""
        if self.training:
            mean = x.mean(dim=0, keepdim=True)
            var = x.var(dim=0, keepdim=True, unbiased=False)
            self._update_running_stats(mean, var, x.shape[0])
        else:
            mean, var = self.running_mean, self.running_var
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.bias

    def _update_running_stats(self, mean, var, n):
        """Blend one batch's statistics into the running estimates (outside autograd)."""
        if n < 2:
            # The unbiased variance of a single sample is undefined; keep the
            # running estimates as they are instead of corrupting them.
            return
        with torch.no_grad():
            # Store the unbiased variance, as torch.nn.BatchNorm1d does.
            unbiased_var = var * (n / (n - 1))
            self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
            self.running_var.mul_(1 - self.momentum).add_(self.momentum * unbiased_var)

In [9]:
import torch
class LayerNorm(torch.nn.Module):
    """Layer normalisation over the last dimension with a learnable affine transform.

    Each sample is normalised across its ``emb_size`` features using that
    sample's own mean and biased variance. Unlike batch norm, the result does
    not depend on the other samples in the batch, and forward does not
    distinguish between train and eval mode.

    Args:
        emb_size: size of the last input dimension; ``scale`` and ``bias``
            are stored with shape ``(1, emb_size)``.
        eps: constant added to the variance for numerical stability.
    """

    def __init__(self, emb_size, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = torch.nn.Parameter(torch.ones(1, emb_size))
        self.bias = torch.nn.Parameter(torch.zeros(1, emb_size))

    def forward(self, x):
        """Return ``x`` normalised along its last dimension; output shape equals input shape."""
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        # View the (1, emb_size) parameters as (emb_size,) so they broadcast
        # against inputs of any rank without prepending an extra dimension
        # (a 1-D input would otherwise come back as (1, emb_size)).
        return self.scale.reshape(-1) * norm_x + self.bias.reshape(-1)

In [11]:
torch.manual_seed(0)  # fixed seed so the printed statistics are reproducible

# BatchNorm normalises each feature (column) across the batch (dim 0), so
# after the layer every column should have mean ~0 and variance ~1.
inp = torch.randn(2, 5)
bn = BatchNorm(5)
out = bn(inp)
mean = out.mean(dim=0, keepdim=True)
var = out.var(dim=0, unbiased=False, keepdim=True)

# The normalised variance is var / (var + eps). With only two samples a
# column's batch variance can be small enough for eps to matter, so some
# columns may land slightly below 1.
print("After BN")
print("Mean:\n", mean)
print("Variance:\n", var)

After BN
Mean:
 tensor([[ 0.0000e+00, -1.7881e-07,  0.0000e+00,  5.9605e-08,  0.0000e+00]],
       grad_fn=<MeanBackward1>)
Variance:
 tensor([[1.0000, 0.9997, 1.0000, 1.0000, 1.0000]], grad_fn=<VarBackward0>)


In [12]:
torch.manual_seed(0)  # fixed seed so the printed statistics are reproducible

# LayerNorm normalises each sample (row) across its features (dim -1), so
# after the layer every row should have mean ~0 and variance ~1.
inp = torch.randn(2, 5)
ln = LayerNorm(5)
out = ln(inp)
mean = out.mean(dim=-1, keepdim=True)
var = out.var(dim=-1, unbiased=False, keepdim=True)

print("After LN")
print("Mean:\n", mean)
print("Variance:\n", var)

After BN
Mean:
 tensor([[ 0.0000e+00],
        [-3.2573e-08]], grad_fn=<MeanBackward1>)
Variance:
 tensor([[1.0000],
        [1.0000]], grad_fn=<VarBackward0>)
