In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# GroupNorm

Reference: [PyTorch documentation — `torch.nn.GroupNorm`](https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html#torch.nn.GroupNorm)

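$$
y = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} \cdot \gamma + \beta
$$

The mean $\mathrm{E}[x]$ and the (biased) variance $\mathrm{Var}[x]$ are computed separately for each sample over each group. A group is `num_channels / num_groups` consecutive channels together with all of their spatial positions. $\gamma$ and $\beta$ are learnable per-channel affine parameters of length `num_channels`, initialised to 1 and 0 respectively.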

In [2]:
# 4 channels split into 2 groups: channels {0, 1} and {2, 3} are normalized together.
norm = nn.GroupNorm(num_channels=4, num_groups=2)
norm

GroupNorm(2, 4, eps=1e-05, affine=True)

In [5]:
# gamma: per-channel scale (one value per channel, initialised to ones).
# .detach() is the supported way to read a parameter's values without autograd
# tracking; it is preferred over the legacy .data attribute.
norm.weight.detach()

tensor([1., 1., 1., 1.])

In [6]:
# beta: per-channel shift (one value per channel, initialised to zeros).
# .detach() is the supported way to read a parameter's values without autograd
# tracking; it is preferred over the legacy .data attribute.
norm.bias.detach()

tensor([0., 0., 0., 0.])

In [7]:
# eps: small constant added to the variance under the square root so the
# denominator of the normalization formula can never be zero
norm.eps

1e-05

# 运算 (Computation)

In [8]:
# One sample, 4 channels, 5x5 spatial grid, filled with the distinct values 0..99
# so that each element's position in the output is easy to trace.
x = torch.arange(100.0).view(1, 4, 5, 5)
x

tensor([[[[ 0.,  1.,  2.,  3.,  4.],
          [ 5.,  6.,  7.,  8.,  9.],
          [10., 11., 12., 13., 14.],
          [15., 16., 17., 18., 19.],
          [20., 21., 22., 23., 24.]],

         [[25., 26., 27., 28., 29.],
          [30., 31., 32., 33., 34.],
          [35., 36., 37., 38., 39.],
          [40., 41., 42., 43., 44.],
          [45., 46., 47., 48., 49.]],

         [[50., 51., 52., 53., 54.],
          [55., 56., 57., 58., 59.],
          [60., 61., 62., 63., 64.],
          [65., 66., 67., 68., 69.],
          [70., 71., 72., 73., 74.]],

         [[75., 76., 77., 78., 79.],
          [80., 81., 82., 83., 84.],
          [85., 86., 87., 88., 89.],
          [90., 91., 92., 93., 94.],
          [95., 96., 97., 98., 99.]]]])

In [9]:
# (N, C, H, W) layout; C = 4 must equal the norm's num_channels
x.size()

torch.Size([1, 4, 5, 5])

In [10]:
# Channels 0-1 form the first group (values 0..49) and channels 2-3 form the second
# group (values 50..99). Each group is normalized with its own mean and variance.
# Both groups are evenly spaced runs of 50 values, so they yield identical output
# patterns. gamma = 1 and beta = 0 at this point, so the affine step leaves the values unchanged.
norm(x)

tensor([[[[-1.6977, -1.6285, -1.5592, -1.4899, -1.4206],
          [-1.3513, -1.2820, -1.2127, -1.1434, -1.0741],
          [-1.0048, -0.9355, -0.8662, -0.7969, -0.7276],
          [-0.6583, -0.5890, -0.5197, -0.4504, -0.3811],
          [-0.3118, -0.2425, -0.1732, -0.1039, -0.0346]],

         [[ 0.0346,  0.1039,  0.1732,  0.2425,  0.3118],
          [ 0.3811,  0.4504,  0.5197,  0.5890,  0.6583],
          [ 0.7276,  0.7969,  0.8662,  0.9355,  1.0048],
          [ 1.0741,  1.1434,  1.2127,  1.2820,  1.3513],
          [ 1.4206,  1.4899,  1.5592,  1.6285,  1.6977]],

         [[-1.6977, -1.6285, -1.5592, -1.4899, -1.4206],
          [-1.3513, -1.2820, -1.2127, -1.1434, -1.0741],
          [-1.0048, -0.9355, -0.8662, -0.7969, -0.7276],
          [-0.6583, -0.5890, -0.5197, -0.4504, -0.3811],
          [-0.3118, -0.2425, -0.1732, -0.1039, -0.0346]],

         [[ 0.0346,  0.1039,  0.1732,  0.2425,  0.3118],
          [ 0.3811,  0.4504,  0.5197,  0.5890,  0.6583],
          [ 0.7276,  0.79