In [2]:
import torch
import torch.nn as nn

## Layer Normalization
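- The mean and standard deviation are computed over the last D dimensions of
the input, where D is the number of dimensions in `normalized_shape`; every
index along the remaining leading dimensions gets its own statistics.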

In [12]:
## Example 1: Layer normalization on the last dimension
## input: shape (2, 3, 4); every length-4 vector x[i, j, :] is normalized on its own
x = torch.randn(2,3,4)
## elementwise_affine = False: no learnable scale/shift, so the output is the
## plain normalized tensor and can be reproduced by hand below
layerNorm = torch.nn.LayerNorm(4, elementwise_affine = False)
## output from layer normalization
y1 = layerNorm(x)

## compute the mean and the (biased) variance over the normalized dimension
## here dim = 2, i.e. the last dimension
mean = x.mean(2, keepdim = True)
print(f"shape of mean in normalization:{mean.shape}")

## biased variance: mean of the squared deviations (divides by N, not N-1)
s2 = (x-mean)**2
var = s2.mean(2, keepdim = True)
## eps is added to the variance inside the square root for numerical stability
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)
print("y1 should equal to y2")
print(f"y1 from pytorch:{y1[0,:,:]}")
print(f"y2 from pytorch:{y2[0,:,:]}")
print(f"check one of the mean of output from layernorm: {y1[0,0,:].mean()}")

## verify the claim instead of relying on eyeballing the printed tensors;
## the tolerance absorbs float32 rounding differences between the two paths
assert torch.allclose(y1, y2, atol = 1e-5), "manual layer norm differs from torch.nn.LayerNorm"

shape of mean in normalization:torch.Size([2, 3, 1])
y1 should equal to y2
y1 from pytorch:tensor([[-0.7852,  0.6866,  1.2550, -1.1563],
        [ 0.3378, -0.3639,  1.3829, -1.3568],
        [ 1.1503, -0.3477, -1.4580,  0.6554]])
y2 from pytorch:tensor([[-0.7852,  0.6866,  1.2550, -1.1563],
        [ 0.3378, -0.3639,  1.3829, -1.3568],
        [ 1.1503, -0.3477, -1.4580,  0.6554]])
check one of the mean of output from layernorm: 0.0


In [13]:
## Example 2: Layer normalization on the last two dimensions
## input: shape (2, 3, 4); every 3x4 slice x[i, :, :] is normalized as one group
x = torch.randn(2,3,4)
## elementwise_affine = False: no learnable scale/shift, so the output is the
## plain normalized tensor and can be reproduced by hand below
layerNorm = torch.nn.LayerNorm((3,4), elementwise_affine = False)
## output from layer normalization
y1 = layerNorm(x)

## compute the mean and the (biased) variance over the normalized dimensions
## here dim = (1, 2), i.e. the last two dimensions
mean = x.mean((1,2), keepdim = True)
print(f"shape of the mean for layer normalization:{mean.shape}")

## biased variance: mean of the squared deviations (divides by N, not N-1)
s2 = (x-mean)**2
var = s2.mean((1,2), keepdim = True)
## eps is added to the variance inside the square root for numerical stability
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)
print("y1 should equal to y2")
print(f"y1 from pytorch:{y1[0,:,:]}")
print(f"y2 from pytorch:{y2[0,:,:]}")
print(f"check one of the mean of output from layernorm: {y1[0,:,:].mean()}")

## verify the claim instead of relying on eyeballing the printed tensors;
## the tolerance absorbs float32 rounding differences between the two paths
assert torch.allclose(y1, y2, atol = 1e-5), "manual layer norm differs from torch.nn.LayerNorm"

shape of the mean for layer normalization:torch.Size([2, 1, 1])
y1 should equal to y2
y1 from pytorch:tensor([[ 2.3093, -0.4518, -0.8980,  0.3362],
        [-2.0318, -0.2471,  0.1082,  0.2216],
        [ 0.3212,  0.3306,  0.7360, -0.7342]])
y2 from pytorch:tensor([[ 2.3093, -0.4518, -0.8980,  0.3362],
        [-2.0318, -0.2471,  0.1082,  0.2216],
        [ 0.3212,  0.3306,  0.7360, -0.7342]])
check one of the mean of output from layernorm: 1.9868215517249155e-08


In [14]:
## Example 3: Layer normalization on the last three dimensions
## input: shape (2, 3, 4, 5); every 3x4x5 block x[i, :, :, :] is normalized as one group
x = torch.randn(2,3,4,5)
## elementwise_affine = False: no learnable scale/shift, so the output is the
## plain normalized tensor and can be reproduced by hand below
layerNorm = torch.nn.LayerNorm((3,4,5), elementwise_affine = False)
## output from layer normalization
y1 = layerNorm(x)

## compute the mean and the (biased) variance over the normalized dimensions
## here dim = (1, 2, 3), i.e. the last three dimensions
mean = x.mean((1,2,3), keepdim = True)
print(f"shape of the mean for layer normalization:{mean.shape}")

## biased variance: mean of the squared deviations (divides by N, not N-1)
s2 = (x-mean)**2
var = s2.mean((1,2,3), keepdim = True)
## eps is added to the variance inside the square root for numerical stability
y2 = (x-mean)/torch.sqrt(var+layerNorm.eps)
print("y1 should equal to y2")
print(f"y1 from pytorch:{y1[0,:,:]}")
print(f"y2 from pytorch:{y2[0,:,:]}")
print(f"check one of the mean of output from layernorm: {y1[0,:,:,:].mean()}")

## verify the claim instead of relying on eyeballing the printed tensors;
## the tolerance absorbs float32 rounding differences between the two paths
assert torch.allclose(y1, y2, atol = 1e-5), "manual layer norm differs from torch.nn.LayerNorm"

shape of the mean for layer normalization:torch.Size([2, 1, 1, 1])
y1 should equal to y2
y1 from pytorch:tensor([[[-2.6620e+00, -1.4094e+00,  1.9217e+00, -2.1725e-01, -6.4769e-01],
         [-1.1906e-01,  1.5994e-03, -9.5264e-01, -6.4971e-01,  3.7357e-01],
         [ 1.8343e-01,  1.9132e+00,  1.2004e-01, -3.2016e-01, -9.1583e-01],
         [ 4.9316e-01,  8.7327e-01, -3.7027e-01,  1.3901e+00, -1.1653e+00]],

        [[ 2.2722e+00, -2.2620e-01, -9.3603e-02,  9.2616e-02, -2.6175e-01],
         [ 5.0905e-01, -6.1658e-01,  2.1296e+00, -1.4511e+00, -7.3320e-01],
         [-1.2508e+00, -1.6316e-03,  9.8874e-01,  6.4759e-01, -5.0525e-01],
         [ 1.7273e+00, -1.4001e+00, -3.9541e-01,  1.0999e-01,  1.9819e+00]],

        [[-3.1121e-01,  3.5838e-01, -5.3852e-01, -3.8746e-01,  1.9319e+00],
         [ 4.9711e-01,  3.4969e-01, -5.8687e-01,  5.7491e-01,  5.3429e-01],
         [-3.0983e-01, -1.0575e+00,  1.0758e-01,  3.1456e-01,  2.9268e-01],
         [-7.8982e-02, -8.4685e-01, -1.2334e+00, -1.183