In [1]:
from d2l import torch as d2l
import torch
import torch.nn as nn
import torch.nn.functional as F

In [65]:
def batch_norm(x, gamma, beta, moving_mean, moving_var, eps, momentum):
    if not (torch.is_grad_enabled()):
        return (x - moving_mean)/torch.sqrt(moving_var + eps)
    else:
        if (len(x.shape) == 2):
#             print('linear case')
            mean = x.mean()
            var = x.var()
            moving_mean = momentum*moving_mean + (1 - momentum)*mean
            moving_var = momentum*moving_var + (1 - momentum)*var
            y = gamma*(x - mean)/(var + eps) + beta
            return y, moving_mean, moving_var
        else:
#             print('conv case')
            mean = x.mean(dim=(0, 2, 3), keepdim=True)
            var = x.var(dim=(0, 2, 3), keepdim=True)
            moving_mean = momentum*moving_mean + (1 - momentum)*mean
            moving_var = momentum*moving_var + (1 - momentum)*var
            y = gamma*(x - mean)/(var + eps) + beta
            return y, moving_mean, moving_var

In [64]:
# exp = torch.rand((5, 10))
# batch_norm(exp, 1, 1, 1, 1, 1, 1)
# exp.shape in range(2, 4)
# exp - exp.mean(dim=(2, 3)).reshape(5, 1, 1, 1)
# exp.mean(dim=(0, 2, 3), keepdim=True).shape, exp.mean(dim=(0, 2, 3)).shape

In [66]:
class BatchNorm(nn.Module):
    def __init__(self, num_features, num_dims):
        super().__init__()
        
        if (num_dims == 2):
            shape = (1, num_features)
        else:
            shape = (1, num_features, 1, 1)
        
        self.gamma = nn.Parameter(torch.ones(shape))
        self.beta = nn.Parameter(torch.zeros(shape))
        
        self.moving_mean = torch.zeros(shape)
        self.moving_var = torch.zeros(shape)
    def forward(self, X):
        y, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, 
                                                self.moving_mean, self.moving_var, 1e-5, 0.9)
        return y

In [68]:
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5, padding=2), BatchNorm(6, 4),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5), BatchNorm(6, 4),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(16*5*5, 100), BatchNorm(100, 2),
            nn.Linear(100, 84), BatchNorm(84, 2),
            nn.Linear(84, 10),
        )
    def forward(self, X):
        X = X.view(-1, 1, 28, 28)
        Y = self.features(X);
        
        return Y

In [69]:
X = torch.rand((1, 1, 28, 28))
net = LeNet()
for layer in net.features:
    X = layer(X)
    print(f'{X.shape}')

torch.Size([1, 6, 28, 28])
torch.Size([1, 6, 28, 28])
torch.Size([1, 6, 14, 14])
torch.Size([1, 16, 10, 10])


RuntimeError: The size of tensor a (6) must match the size of tensor b (16) at non-singleton dimension 1