In [1]:
import torch
import torch.nn as nn

class GroupNorm2d(nn.Module):
    """Group normalization over the channel axis of a 4-D ``(b, c, h, w)`` tensor.

    Channels are split into consecutive groups of ``g_size`` channels. When
    ``c`` is not divisible by ``g_size``, the trailing ``c % g_size`` channels
    form one extra, smaller group. Each group is normalized per sample over its
    channels and spatial positions to zero mean and unit variance, and then the
    scalar affine transform ``gamma * x + beta`` is applied.

    Note: ``gamma`` and ``beta`` are plain Python attributes shared by every
    channel. They are not learnable ``nn.Parameter`` objects, unlike the affine
    weights of ``torch.nn.GroupNorm``.

    Args:
        gamma: Scalar scale applied after normalization.
        beta: Scalar shift applied after normalization.
        g_size: Number of channels per group; must be >= 1.
        eps: Added to the variance for numerical stability.

    Raises:
        ValueError: If ``g_size`` is smaller than 1.
    """

    def __init__(self, gamma=0.1, beta=0.1, g_size=30, eps=1e-6):
        super().__init__()
        if g_size < 1:
            # Fail at construction instead of with ZeroDivisionError in forward().
            raise ValueError(f"g_size must be >= 1, got {g_size}")
        self.gamma = gamma
        self.beta = beta
        self.g_size = g_size
        self.eps = eps

    def _normalize(self, x, groups):
        """Normalize ``x`` of shape (b, c, h, w) as ``groups`` equal channel groups, then apply the affine."""
        b, c, h, w = x.shape
        # reshape (not view): x may be a non-contiguous channel slice of the input.
        grouped = x.reshape(b, groups, c // groups, h, w)
        mu = grouped.mean(dim=(2, 3, 4), keepdim=True)
        # Biased (population) variance, as in torch.nn.GroupNorm. The unbiased
        # estimator divides by n - 1, so a single-element group would give NaN.
        var = grouped.var(dim=(2, 3, 4), unbiased=False, keepdim=True)
        normed = (grouped - mu) / torch.sqrt(var + self.eps)
        return (self.gamma * normed + self.beta).reshape(b, c, h, w)

    def forward(self, x):
        """Apply group normalization.

        Args:
            x: Tensor of shape (b, c, h, w).

        Returns:
            A contiguous tensor with the same shape as ``x``.
        """
        _, c, _, _ = x.shape
        # Number of leading channels covered by complete groups of g_size.
        full = c - c % self.g_size
        parts = []
        if full > 0:
            parts.append(self._normalize(x[:, :full], full // self.g_size))
        if full < c:
            # Leftover channels (c not divisible by g_size) form one smaller group.
            parts.append(self._normalize(x[:, full:], 1))
        if not parts:
            # Zero channels: nothing to normalize.
            return x.contiguous()
        out = parts[0] if len(parts) == 1 else torch.cat(parts, dim=1)
        return out.contiguous()

In [2]:
# Smoke test: normalize a random batch whose channel count (100) is not a
# multiple of the default g_size (30), which exercises the remainder-group path.
torch.manual_seed(0)  # reproducible input, so Restart & Run All gives the same output
gn = GroupNorm2d()
x = torch.rand(16, 100, 80, 60)
out = gn(x)
assert out.shape == x.shape, "GroupNorm2d must preserve the input shape"
assert torch.isfinite(out).all(), "normalized output contains NaN/inf"
print(out)

tensor([[[[ 2.3825e-01,  2.5032e-01, -5.4962e-02,  ...,  1.3481e-01,
            1.0489e-01,  2.2473e-01],
          [ 2.4581e-01,  1.7981e-01,  1.4824e-01,  ...,  2.6627e-01,
           -6.5136e-02,  1.1843e-02],
          [ 1.5665e-01,  9.1661e-02,  3.3077e-02,  ...,  1.0420e-01,
           -4.5235e-02, -5.3352e-02],
          ...,
          [ 1.9743e-01,  2.3774e-01,  1.4452e-01,  ..., -6.6780e-04,
            2.6675e-01,  1.8721e-01],
          [ 6.1020e-02,  1.8107e-01,  7.6758e-02,  ...,  2.7124e-01,
            9.1800e-02,  6.2228e-02],
          [ 9.0488e-02,  9.4766e-02,  1.9284e-01,  ...,  1.1275e-01,
            1.9793e-01,  2.7298e-01]],

         [[ 9.1280e-04,  1.8919e-01,  2.6979e-01,  ...,  4.2882e-02,
            1.9420e-01,  2.1890e-02],
          [ 3.7759e-02,  2.0600e-01,  4.3851e-02,  ..., -2.6925e-02,
           -5.4148e-02, -1.6984e-02],
          [ 2.5967e-01,  9.4841e-02, -7.5193e-03,  ..., -1.8854e-02,
            2.3971e-01, -9.5049e-03],
          ...,
     