# Testing how PyTorch Group Normalization layers compute mean & variance (whether they collapse image dimensions or not)

In [1]:
import numpy as np 
import torch 

In [2]:
# define custom function for group normalization
def group_norm2(input, num_groups, weight=None, bias=None, eps=1e-05, reduce=True,verbose=True):

    batch_size, num_channels, height, width = input.size() 
    if verbose:
        print(f'input size (BCHW): [{input.shape}]')

    input = input.view(batch_size, num_groups, num_channels//num_groups, height, width)

    if verbose:
        print(f'after grouping: [{input.shape}]')

    if reduce:
        mean = torch.mean(input, dim=(2,3,4), keepdim=True)
        var = torch.var(input, dim=(2,3,4), unbiased=False, keepdim=True)
        if verbose:
            print('mean/var shape: ',mean.shape)
    else:
        mean = torch.mean(input, dim=2, keepdim=True)
        var = torch.var(input, dim=2, unbiased=False, keepdim=True)
        if verbose:
            print('mean/var shape: ',mean.shape)
    
    input = (input - mean) / torch.sqrt(var+eps)

    input = input.view(batch_size, num_channels, height, width)

    if weight is None:
        weight = torch.tensor([1]).repeat(num_channels)
    if bias is None:
        bias = torch.tensor([0]).repeat(num_channels)

    # transform weight & bias [C] -> [N, C, H, W]
    weight = weight.unsqueeze(0).unsqueeze(2).unsqueeze(3) # [1, C, 1, 1]
    weight = weight.expand(batch_size, num_channels, height, width) # [N, C, H, W]
    bias = bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) # [1, C, 1, 1]
    bias = bias.expand(batch_size, num_channels, height, width) # [N, C, H, W]

    input = input*weight + bias 

    return input

In [3]:
N, C, H, W = 1, 6, 3, 3 # batch size, channel, height, width

n_groups = 3

input = torch.randn(N,C,H,W)

# random weights
gamma = torch.randn(C)
beta = torch.randn(C)

out =group_norm = torch.nn.functional.group_norm(
    input,
    num_groups = n_groups,
    weight = gamma,
    bias= beta,
    eps = 1e-05
)

print('Group norm reducing H & W dimensions')
out2= group_norm2(input, num_groups= n_groups, weight = gamma, bias = beta, eps=1e-05)

print('\nGroup norm without reducing H & W dimensions')
out3= group_norm2(input, num_groups= n_groups, weight = gamma, bias = beta, eps=1e-05, reduce=False)

Group norm reducing H & W dimensions
input size (BCHW): [torch.Size([1, 6, 3, 3])]
after grouping: [torch.Size([1, 3, 2, 3, 3])]
mean/var shape:  torch.Size([1, 3, 1, 1, 1])

Group norm without reducing H & W dimensions
input size (BCHW): [torch.Size([1, 6, 3, 3])]
after grouping: [torch.Size([1, 3, 2, 3, 3])]
mean/var shape:  torch.Size([1, 3, 1, 3, 3])


In [4]:
print('pytorch group norm')
display(out.reshape((-1)))

print('group norm (computes mean/var by collapsing H & W dimensions)')
display(out2.reshape((-1)))

print('group norm (computes mean/var only along C dimensions (H & W dimensions are not collapsed))')
display(out3.reshape((-1)))

print('Output differences: ')
print(f'Between pytorch group norm and the custom group norm with H&W collapsing: {(out-out2).sum()}')
print(f'Between pytorch group norm and the custom group norm without H&W collapsing: {(out-out3).sum()}')

pytorch group norm


tensor([-0.9086,  0.5706,  0.3820,  0.9110, -1.0005,  1.7645, -0.3582, -0.3094,
        -1.4918,  1.7138,  1.6224,  1.4765,  0.7816,  1.8332,  1.9592,  1.9472,
         0.7319,  1.8505,  0.1094,  0.1917,  0.3087,  0.3940,  0.2060,  0.1634,
         0.3426,  0.2738,  0.3738, -1.4001, -1.2208, -0.4686, -1.0652, -0.6814,
        -0.5616, -1.3275, -0.5260, -0.0893, -0.5445, -0.4432, -0.5275, -0.5387,
        -0.6562, -0.4527, -0.5573, -0.5553, -0.5849, -1.0185, -0.8639, -0.6323,
        -0.9397, -0.9468, -1.1695, -0.9483, -0.6847, -0.9567])

group norm (computes mean/var by collapsing H & W dimensions)


tensor([-0.9086,  0.5706,  0.3820,  0.9110, -1.0005,  1.7645, -0.3582, -0.3094,
        -1.4918,  1.7138,  1.6224,  1.4765,  0.7816,  1.8332,  1.9592,  1.9472,
         0.7319,  1.8505,  0.1094,  0.1917,  0.3087,  0.3940,  0.2060,  0.1634,
         0.3426,  0.2738,  0.3738, -1.4001, -1.2208, -0.4686, -1.0652, -0.6814,
        -0.5616, -1.3275, -0.5260, -0.0893, -0.5445, -0.4432, -0.5275, -0.5387,
        -0.6562, -0.4527, -0.5573, -0.5553, -0.5849, -1.0185, -0.8639, -0.6323,
        -0.9397, -0.9468, -1.1695, -0.9483, -0.6847, -0.9567])

group norm (computes mean/var only along C dimensions (H & W dimensions are not collapsed))


tensor([-1.0115,  1.1384,  1.1107, -1.0115, -1.0114,  1.1384,  1.1383, -1.0115,
        -1.0115,  1.1727,  2.0070,  1.9963,  1.1727,  1.1727,  2.0071,  2.0070,
         1.1727,  1.1727,  0.1224,  0.1224,  0.3235,  0.3235,  0.3235,  0.3235,
         0.3235,  0.3235,  0.3235, -1.4775, -1.4774, -0.5281, -0.5281, -0.5282,
        -0.5282, -0.5282, -0.5281, -0.5281, -0.5867, -0.4514, -0.4514, -0.5867,
        -0.5867, -0.5867, -0.5867, -0.4514, -0.5867, -1.0133, -0.7042, -0.7042,
        -1.0133, -1.0133, -1.0133, -1.0133, -0.7042, -1.0133])

Output differences: 
Between pytorch group norm and the custom group norm with H&W collapsing: -1.1175870895385742e-07
Between pytorch group norm and the custom group norm without H&W collapsing: -0.6618859767913818
