In [1]:
import torch as th
from torch import nn
from torch.nn import functional as th_f

In [2]:
kernel_size: int = 3  # side length of the square sliding window used in the unfold/fold cells below

In [3]:
# Seeded local generator: the demo input is identical on every "Restart & Run All"
# without touching the global RNG state.
x = th.randn(2, 3, 16, 16, generator=th.Generator().manual_seed(0))
# NOTE: the tensor layout is (batch, channels, height, width), so `w` receives
# dim 2 and `h` dim 3. The names are kept because later cells pass `(w, h)`
# positionally; the input is square, so the values coincide.
b, c, w, h = x.shape

In [7]:
# Extract one kernel_size x kernel_size patch per pixel. Padding by kernel_size // 2
# (odd kernel) keeps the number of patches equal to the number of spatial positions,
# so the result has shape (batch, channels * kernel_size**2, height * width).
out = th_f.unfold(x, kernel_size, stride=1, padding=kernel_size // 2)

In [8]:
out.shape  # (batch, channels * kernel_size**2, number of patches)

torch.Size([2, 27, 256])

In [9]:
16 ** 2  # number of spatial positions in a 16 x 16 map

256

In [19]:
# Split the flattened patch axis into (channel, position inside the window).
window_area = kernel_size ** 2
out_channel = out.view(b, c, window_area, -1)
print(out_channel.shape)

torch.Size([2, 3, 9, 256])


In [11]:
# fold must be given the same padding/stride as the unfold call that produced `out`;
# otherwise the expected number of blocks (14 * 14) does not match the 16 * 16 columns
# of `out` and fold raises a RuntimeError.
# `(w, h)` holds (dim 2, dim 3) of `x`, i.e. the output size in the order fold expects.
# NOTE: fold sums overlapping patches, so `x_recon` equals `x` scaled per pixel by the
# number of windows covering it, not `x` itself.
x_recon = th_f.fold(out, (w, h), kernel_size, padding=kernel_size // 2, stride=1)

In [12]:
x_recon.shape  # same (batch, channels, height, width) as the original input

torch.Size([2, 3, 16, 16])

In [26]:
from typing import Optional


class LocalGroupNorm(nn.Module):
    """Group normalization with statistics taken from a local spatial window.

    For every spatial position, the mean and unbiased variance are computed over
    the ``kernel_size x kernel_size`` neighbourhood of that position and over all
    channels belonging to the same group. The value at the position is then
    normalized with these statistics and, optionally, scaled and shifted by
    per-channel learnable parameters.

    The input is zero-padded by ``kernel_size // 2`` so that the output keeps the
    spatial size of the input; the padded zeros take part in the statistics of
    positions near the border.

    Args:
        channels: Number of input channels; must be divisible by ``num_groups``.
        num_groups: Number of channel groups that share statistics.
        kernel_size: Side length of the square window; must be a positive odd
            integer so the window is centred on each position.
        epsilon: Constant added to the variance for numerical stability.
        affine: If ``True``, learn a per-channel ``weight`` (initialised to 1)
            and ``bias`` (initialised to 0); otherwise both are ``None``.

    Raises:
        AssertionError: If ``channels`` is not divisible by ``num_groups``.
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, channels: int, num_groups: int, kernel_size: int, epsilon: float = 1e-5, affine: bool = True) -> None:
        super().__init__()
        assert channels % num_groups == 0
        if kernel_size < 1 or kernel_size % 2 == 0:
            # An even window cannot be centred, and "same" padding would change
            # the spatial size, which breaks the reshapes in forward().
            raise ValueError(f"kernel_size must be a positive odd integer, got {kernel_size}")

        self.__channels = channels
        self.__num_groups = num_groups
        self.__kernel_size = kernel_size
        self.__epsilon = epsilon
        self.__affine = affine

        # Shaped (1, C, 1, 1) so they broadcast over batch and spatial dimensions.
        self.weight: Optional[nn.Parameter] = nn.Parameter(th.ones(1, channels, 1, 1)) if self.__affine else None
        self.bias: Optional[nn.Parameter] = nn.Parameter(th.zeros(1, channels, 1, 1)) if self.__affine else None

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Normalize ``x`` with local per-group statistics.

        Args:
            x: Tensor of shape ``(batch, channels, height, width)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        b, c, h, w = x.size()
        k = self.__kernel_size
        groups = self.__num_groups

        # One k x k patch per position: (b, c * k * k, h * w). The channel axis of
        # unfold is ordered channel-major, so it can be split into
        # (group, channels per group, window position).
        patches = th_f.unfold(x, k, padding=k // 2)
        patches = patches.view(b, groups, -1, k * k, h * w)

        # Statistics over (channels in group, window positions) for each location.
        # The unbiased variance needs more than one element per window, i.e.
        # channels_per_group * k * k > 1.
        mean = th.mean(patches, dim=[2, 3]).view(b, groups, 1, h, w)
        var = th.var(patches, dim=[2, 3], unbiased=True).view(b, groups, 1, h, w)

        out = (x.view(b, groups, -1, h, w) - mean) / th.sqrt(var + self.__epsilon)

        out = out.view(b, c, h, w)

        if self.__affine:
            out = out * self.weight + self.bias

        return out

In [27]:
# Seeded local generator keeps the demo input reproducible across kernel restarts
# without altering the global RNG state.
x = th.randn(2, 4, 16, 16, generator=th.Generator().manual_seed(0))

In [28]:
# 4 channels split into 2 groups, statistics over a 3 x 3 window.
lgn = LocalGroupNorm(channels=4, num_groups=2, kernel_size=3)

In [29]:
# Apply the local group norm to the random batch; note that this rebinds `out`,
# which earlier cells used for the unfold result.
out = lgn(x)

In [30]:
out.shape  # normalization keeps the input's (batch, channels, height, width) shape

torch.Size([2, 4, 16, 16])