# Testing how PyTorch Group Normalization layers compute mean & variance (whether they collapse image dimensions or not)

In [1]:
import numpy as np 
import torch 

In [2]:
# define custom function for group normalization
def group_norm2(input, num_groups, weight=None, bias=None, eps=1e-05, reduce=True,verbose=True):

    batch_size, num_channels, height, width = input.size() 
    if verbose:
        print(f'input size (BCHW): [{input.shape}]')

    input = input.view(batch_size, num_groups, num_channels//num_groups, height, width)

    if verbose:
        print(f'after grouping: [{input.shape}]')

    if reduce:
        mean = torch.mean(input, dim=(2,3,4), keepdim=True)
        var = torch.var(input, dim=(2,3,4), unbiased=False, keepdim=True)
        if verbose:
            print('mean/var shape: ',mean.shape)
    else:
        mean = torch.mean(input, dim=2, keepdim=True)
        var = torch.var(input, dim=2, unbiased=False, keepdim=True)
        if verbose:
            print('mean/var shape: ',mean.shape)
    
    input = (input - mean) / torch.sqrt(var+eps)

    input = input.view(batch_size, num_channels, height, width)

    if weight is None:
        weight = torch.tensor([1]).repeat(num_channels)
    if bias is None:
        bias = torch.tensor([0]).repeat(num_channels)

    # transform weight & bias [C] -> [N, C, H, W]
    weight = weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) # [1, C, 1, 1]
    weight = weight.expand(batch_size, num_channels, height, width) # [N, C, H, W]
    bias = bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) # [1, C, 1, 1]
    bias = bias.expand(batch_size, num_channels, height, width) # [N, C, H, W]

    input = input*weight + bias 

    return input

In [3]:
N, C, H, W = 1, 6, 3, 3 # batch size, channel, height, width

n_groups = 3

input = torch.randn(N,C,H,W)

# random weights
gamma = torch.randn(C)
beta = torch.randn(C)

out =group_norm = torch.nn.functional.group_norm(
    input,
    num_groups = n_groups,
    weight = gamma,
    bias= beta,
    eps = 1e-05
)

print('Group norm reducing H & W dimensions')
out2= group_norm2(input, num_groups= n_groups, weight = gamma, bias = beta, eps=1e-05)

print('\nGroup norm without reducing H & W dimensions')
out3= group_norm2(input, num_groups= n_groups, weight = gamma, bias = beta, eps=1e-05, reduce=False)

Group norm reducing H & W dimensions
input size (BCHW): [torch.Size([1, 6, 3, 3])]
after grouping: [torch.Size([1, 3, 2, 3, 3])]
mean/var shape:  torch.Size([1, 3, 1, 1, 1])

Group norm without reducing H & W dimensions
input size (BCHW): [torch.Size([1, 6, 3, 3])]
after grouping: [torch.Size([1, 3, 2, 3, 3])]
mean/var shape:  torch.Size([1, 3, 1, 3, 3])


In [4]:
print('pytorch group norm')
display(out.reshape((-1)))

print('group norm (computes mean/var by collapsing H & W dimensions)')
display(out2.reshape((-1)))

print('group norm (computes mean/var only along C dimensions (H & W dimensions are not collapsed))')
display(out3.reshape((-1)))

print('Output differences: ')
print(f'Between pytorch group norm and the custom group norm with H&W collapsing: {(out-out2).sum():.5f}')
print(f'Between pytorch group norm and the custom group norm without H&W collapsing: {(out-out3).sum():.5f}')

pytorch group norm


tensor([ 4.8982e-01, -1.0988e+00,  1.5307e+00, -7.2852e-01,  1.9463e+00,
        -1.8436e+00, -2.6760e-01,  2.0959e+00, -1.6022e+00,  2.3373e-01,
        -1.2876e-01, -1.7684e-02,  5.6397e-02, -1.5088e-01, -8.1340e-01,
        -5.0214e-01, -2.1661e+00, -5.2516e-01, -1.0343e+00, -1.3529e-02,
        -2.4171e-01, -4.0385e-01,  3.7268e-02, -5.0453e-01,  2.8989e-01,
         5.3631e-02, -2.9229e-01,  2.5705e+00,  8.6528e-01,  2.4461e+00,
         4.0048e-02, -4.1047e-01, -3.0143e-01,  2.0454e+00,  7.0477e-02,
         1.2665e+00,  1.3012e+00,  3.4624e-01,  3.5497e+00, -5.6892e-01,
         1.3078e+00,  2.7435e-01,  1.2711e-01, -1.0073e+00,  9.0564e-01,
         3.1841e+00, -5.2309e-01, -1.8582e+00, -3.5025e+00,  2.7110e-01,
         1.7705e+00,  3.2191e-03,  1.9183e+00,  6.7090e-01])

group norm (computes mean/var by collapsing H & W dimensions)


tensor([ 4.8982e-01, -1.0988e+00,  1.5307e+00, -7.2852e-01,  1.9463e+00,
        -1.8436e+00, -2.6760e-01,  2.0959e+00, -1.6022e+00,  2.3373e-01,
        -1.2876e-01, -1.7684e-02,  5.6397e-02, -1.5088e-01, -8.1340e-01,
        -5.0214e-01, -2.1661e+00, -5.2516e-01, -1.0343e+00, -1.3529e-02,
        -2.4171e-01, -4.0385e-01,  3.7268e-02, -5.0453e-01,  2.8989e-01,
         5.3631e-02, -2.9229e-01,  2.5705e+00,  8.6528e-01,  2.4461e+00,
         4.0048e-02, -4.1047e-01, -3.0143e-01,  2.0454e+00,  7.0477e-02,
         1.2665e+00,  1.3012e+00,  3.4624e-01,  3.5497e+00, -5.6892e-01,
         1.3078e+00,  2.7435e-01,  1.2711e-01, -1.0073e+00,  9.0564e-01,
         3.1841e+00, -5.2309e-01, -1.8582e+00, -3.5025e+00,  2.7110e-01,
         1.7705e+00,  3.2190e-03,  1.9183e+00,  6.7090e-01])

group norm (computes mean/var only along C dimensions (H & W dimensions are not collapsed))


tensor([ 1.6560,  1.6541,  1.6560,  1.6559,  1.6560, -1.9822, -1.9426, -1.9822,
        -1.9822,  0.0700,  0.0694,  0.0700,  0.0700,  0.0700, -1.1048, -1.0920,
        -1.1048, -1.1048, -0.4418, -0.4417, -0.4418, -0.4418,  0.3042, -0.4418,
        -0.4417,  0.3041, -0.4418,  1.7377,  1.7372,  1.7377,  1.7376, -1.0781,
         1.7376,  1.7373, -1.0781,  1.7377, -1.0334,  1.7502,  1.7502,  1.7502,
         1.7502, -1.0334,  1.7494, -1.0334,  1.7500,  2.5037, -1.1869, -1.1870,
        -1.1870, -1.1870,  2.5037, -1.1859,  2.5038, -1.1868])

Output differences: 
Between pytorch group norm and the custom group norm with H&W collapsing: 0.00000
Between pytorch group norm and the custom group norm without H&W collapsing: -0.48400
