# BatchNorm2D

In [None]:
import torch

In [None]:
class BatchNorm2D(torch.nn.Module):
  """Batch normalization over the channel dimension of (N, C, H, W) input.

  A from-scratch counterpart of ``torch.nn.BatchNorm2d``.

  In training mode, the input is normalized with the per-channel batch mean
  and biased batch variance, and exponential running estimates of the mean
  and the unbiased variance are updated. In eval mode, the running estimates
  are used instead. A learnable per-channel scale (``gamma``) and shift
  (``beta``) are then applied.

  Args:
    n_feat: number of channels C of the expected (N, C, H, W) input.
    eps: value added to the variance for numerical stability.
    momentum: weight given to the current batch statistic in the running
      update, i.e. ``running = (1 - momentum) * running + momentum * batch``.
  """

  def __init__(self, n_feat, eps=1e-5, momentum=0.1):
    super().__init__()
    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    self.eps = eps
    self.m = momentum
    # Create the tensors directly on the target device before wrapping them.
    # Calling .to(device) on an existing Parameter returns a plain, non-leaf
    # tensor when the device changes. That tensor would not be registered as
    # a trainable parameter of the module.
    self.beta = torch.nn.Parameter(torch.zeros(n_feat, device=self.device))
    self.gamma = torch.nn.Parameter(torch.ones(n_feat, device=self.device))
    # Running statistics: saved in state_dict and moved with the module,
    # but not trained.
    self.register_buffer('mu', torch.zeros(n_feat, device=self.device))
    self.register_buffer('sigma2', torch.ones(n_feat, device=self.device))

  def forward(self, x):
    """Normalize ``x`` per channel and apply the learned affine transform.

    Args:
      x: input tensor of shape (N, C, H, W) with C == n_feat.

    Returns:
      Tensor of the same shape as ``x``.

    Raises:
      ValueError: if ``x`` is not 4-D, or if in training mode each channel has
        fewer than two values (the unbiased variance would be undefined).
    """
    if x.dim() != 4:
      raise ValueError(f"expected 4D input (N, C, H, W), got {x.dim()}D input")

    if self.training:
      # Number of values that contribute to each channel's statistics.
      n = x.numel() // x.size(1)
      if n < 2:
        raise ValueError(
            f"expected more than 1 value per channel when training, got {n}")
      mean = x.mean(dim=[0, 2, 3])
      # The current batch is normalized with the biased variance ...
      var = x.var(dim=[0, 2, 3], unbiased=False)
      with torch.no_grad():
        self.mu.mul_(1 - self.m).add_(self.m * mean)
        # ... but the running estimate tracks the unbiased variance. The
        # n / (n - 1) Bessel correction therefore applies here, not to the
        # mean.
        self.sigma2.mul_(1 - self.m).add_(self.m * var * n / (n - 1))
    else:
      mean = self.mu
      var = self.sigma2

    # Reshape the per-channel vectors to (1, C, 1, 1) so they broadcast over x.
    shape = (1, -1, 1, 1)
    z = (x - mean.view(shape)) / torch.sqrt(var.view(shape) + self.eps)
    return self.gamma.view(shape) * z + self.beta.view(shape)
    