In [1]:
'''
Modified from https://github.com/ptrblck/pytorch_misc/blob/master/batch_norm_manual.py

@author: Xin Dong
'''

import torch
import torch.nn as nn

class MyBatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MyBatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            print('mean size:', mean.size())
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            print('variance size:', var.size())
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                print('running_mean size:', self.running_mean.size())
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
                print('running_var size:', self.running_var.size())
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input
    
    
    
def compare_bn(bn1, bn2):
    err = False
    if not torch.allclose(bn1.running_mean, bn2.running_mean):
        print('Diff in running_mean: {} vs {}'.format(
            bn1.running_mean, bn2.running_mean))
        err = True

    if not torch.allclose(bn1.running_var, bn2.running_var):
        print('Diff in running_var: {} vs {}'.format(
            bn1.running_var, bn2.running_var))
        err = True

    if bn1.affine and bn2.affine:
        if not torch.allclose(bn1.weight, bn2.weight):
            print('Diff in weight: {} vs {}'.format(
                bn1.weight, bn2.weight))
            err = True

        if not torch.allclose(bn1.bias, bn2.bias):
            print('Diff in bias: {} vs {}'.format(
                bn1.bias, bn2.bias))
            err = True

    if not err:
        print('All parameters are equal!')
        
        
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)

my_bn.train()
bn.train()
for i in range(10):
    x = torch.randn(10, 3, 100, 100)
    out1 = my_bn(x)
    out2 = bn(x)
    print('train:', torch.allclose(out1, out2))
    
compare_bn(my_bn, bn)

my_bn.eval()
bn.eval()
x = torch.randn(10, 3, 100, 100)
out1 = my_bn(x)
out2 = bn(x)
print(torch.allclose(out1, out2))
print('Max diff: ', (out1 - out2).abs().max())

mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])
train: True
mean size: torch.Size([3])


In [10]:
# Init BatchNorm layers
my_bn = MyBatchNorm2d(3, affine=True)
bn = nn.BatchNorm2d(3, affine=True)

compare_bn(my_bn, bn)  # weight and bias should be different
# Load weight and bias
my_bn.load_state_dict(bn.state_dict())
compare_bn(my_bn, bn)

# Run train
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    torch.allclose(out1, out2)
    print('Max diff: ', (out1 - out2).abs().max())

# Run eval
my_bn.eval()
bn.eval()
for _ in range(10):
    scale = torch.randint(1, 10, (1,)).float()
    bias = torch.randint(-10, 10, (1,)).float()
    x = torch.randn(10, 3, 100, 100) * scale + bias
    print('**:', x.min(), x.max())
    out1 = my_bn(x)
    out2 = bn(x)
    compare_bn(my_bn, bn)

    torch.allclose(out1, out2)
    print('Max diff: ', (out1 - out2).abs().max())

All parameters are equal!
All parameters are equal!
All parameters are equal!
Max diff:  tensor(7.1526e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(7.1526e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(1.4305e-06, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(9.5367e-07, grad_fn=<MaxBackward1>)
All parameters are equal!
Max diff:  tensor(7.1526e-07, grad_fn=<MaxBackward1>)
**: tensor(-27.1325) tensor(8.6804)
All parameters are equal!
Max diff:  tensor(4.7684e-07, grad_fn=<MaxBackward1>)
**: tensor(-43.7317) tensor(39.4

In [2]:
class SM_BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(SM_BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([2, 3])
            print('mean size:', mean.size())
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            print('variance size:', var.size())
            n = input.numel() / input.size(1)
            with torch.no_grad():
                for i in range(mean.size(0)):
                    self.running_mean = exponential_average_factor * mean[i]\
                        + (1-exponential_average_factor) * self.running_mean
                print('running_mean size:', self.running_mean.size())
#                 self.running_mean = exponential_average_factor * mean\
#                     + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1)\
                    + (1 - exponential_average_factor) * self.running_var
                print('running_var size:', self.running_var.size())
        else:
            mean = self.running_mean
            var = self.running_var

#         input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input
    
    
sm_bn = SM_BatchNorm2d(3, affine=True)
x = torch.randn(10, 3, 100, 100)
out_sm = sm_bn(x)

mean size: torch.Size([10, 3])
variance size: torch.Size([3])
running_mean size: torch.Size([3])
running_var size: torch.Size([3])


RuntimeError: The size of tensor a (100) must match the size of tensor b (3) at non-singleton dimension 4

In [10]:
class SV_BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(SV_BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            mean = input.mean([0, 2, 3])
            print('mean size:', mean.size())
            # use biased var in train
            var = input.var([2, 3], unbiased=False)
            print('variance size:', var.size())
            n = input.numel() / (input.size(1) * input.size(0))
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean\
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                for i in range(var.size(0)):
                    self.running_var = exponential_average_factor * var[i] * n / (n - 1)\
                        + (1 - exponential_average_factor) * self.running_var
#                 self.running_var = exponential_average_factor * var * n / (n - 1)\
#                     + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

#         input = (input - mean[None, :, None, None]) / (torch.sqrt(var[None, :, None, None] + self.eps))
        input = (input - mean[None, :, None, None]) / (torch.sqrt(var[:, :, None, None] + self.eps))
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]

        return input
    
    
sv_bn = SV_BatchNorm2d(3, affine=True)
x = torch.randn(10, 3, 100, 100)
out_sv = sv_bn(x)

mean size: torch.Size([3])
variance size: torch.Size([10, 3])


In [None]:
class BatchNorm2d(nn.BatchNorm2d):
    def forward(self, x):
        self._check_input_dim(x)
        y = x.transpose(0,1)
        return_shape = y.shape
        y = y.contiguous().view(x.size(1), -1)
        mu = y.mean(dim=1)
        sigma2 = y.var(dim=1)
        if self.training is not True:
            y = y - self.running_mean.view(-1, 1)
            y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
        else:
            if self.track_running_stats is True:
                with torch.no_grad():
                    self.running_mean = (1-self.momentum)*self.running_mean + self.momentum*mu
                    self.running_var = (1-self.momentum)*self.running_var + self.momentum*sigma2
            y = y - mu.view(-1,1)
            y = y / (sigma2.view(-1,1)**.5 + self.eps)

        y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
        return y.view(return_shape).transpose(0,1)
