In [55]:
import typing as tp
from collections import OrderedDict

import numpy as np
import torch
from matplotlib import pyplot as plt
from torchvision.datasets import MNIST


# Number of MNIST samples used by this notebook (the dataset is truncated to the first N).
N = 16 * 1024
# The last VAL_SIZE of those N samples are held out for validation ...
VAL_SIZE = 512
# ... and everything before them is used for training.
TRAIN_SIZE = N - VAL_SIZE


class Flatten(torch.nn.Module):
    """Collapse every dimension after the first (batch) one into a single dimension."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` viewed as a 2-D tensor of shape ``(x.shape[0], -1)``."""
        batch_size: int = x.size(0)
        # ``view`` (not ``reshape``) is used on purpose: the result shares storage with ``x``.
        return x.view(batch_size, -1)


class CustomBatchNorm(torch.nn.Module):
    """Batch normalization over every dimension except the channel dimension (dim 1).

    Accepts inputs of shape ``(N, C)`` or ``(N, C, *)``, so a single class covers the
    use cases of ``torch.nn.BatchNorm1d`` / ``BatchNorm2d`` / ``BatchNorm3d``.

    Args:
        num_features: size ``C`` of the channel dimension of the expected input.
        eps: value added to the variance before taking the square root.
        momentum: weight of the newest batch in the running-statistics update
            ``running += momentum * (batch - running)``. ``None`` selects a cumulative
            average in which each batch is weighted by its batch size (this equals the
            PyTorch cumulative average when all batches have the same size).
        affine: when ``False`` the scale/shift parameters are frozen at 1 and 0
            (they are still created, but with ``requires_grad=False``).
        track_running_stats: when ``True`` running mean/variance are accumulated during
            training and used for normalization in eval mode; when ``False`` batch
            statistics are used in both modes.

    Buffers (``None`` when ``track_running_stats`` is ``False``):
        running_mean: running per-channel mean, initialised to 0.
        running_var: running per-channel unbiased variance, initialised to 1.
        num_samples_tracked: total batch size seen so far in training mode.
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-05,
            momentum: tp.Optional[float] = 0.1,
            affine: bool = True,
            track_running_stats: bool = True,
    ):
        super().__init__()
        self._num_features: int = num_features
        self._eps: float = eps
        self._momentum: tp.Optional[float] = momentum
        self._params: torch.nn.ParameterDict = torch.nn.ParameterDict(
            {
                "weights": torch.nn.Parameter(
                    torch.ones(num_features),
                    requires_grad=affine,
                ),
                "bias": torch.nn.Parameter(
                    torch.zeros(num_features),
                    requires_grad=affine,
                ),
            }
        )
        self._track_running_stats: bool = track_running_stats
        # Running statistics are registered buffers (not a plain dict) so that they
        # follow ``.to(device)`` / ``.double()`` and are saved in ``state_dict()``.
        if track_running_stats:
            self.register_buffer("running_mean", torch.zeros(num_features))
            # Variance starts at 1 so that eval before any training step is (almost)
            # the identity instead of a division by sqrt(eps).
            self.register_buffer("running_var", torch.ones(num_features))
            self.register_buffer("num_samples_tracked", torch.tensor(0, dtype=torch.long))
        else:
            self.register_buffer("running_mean", None)
            self.register_buffer("running_var", None)
            self.register_buffer("num_samples_tracked", None)

    @property
    def _running_stats(self) -> tp.Optional[tp.Dict[str, tp.Union[torch.Tensor, int]]]:
        """Backward-compatible dict view (``mean`` / ``var`` / ``count``) of the buffers."""
        if not self._track_running_stats:
            return None
        return {
            "mean": self.running_mean,
            "var": self.running_var,
            "count": int(self.num_samples_tracked),
        }

    @torch.no_grad()
    def _update_running_stats(
            self,
            x: torch.Tensor,
            batch_mean: torch.Tensor,
            batch_var: torch.Tensor,
    ) -> None:
        """Fold the statistics of one training batch into the running buffers.

        Runs under ``no_grad`` so the buffers never become part of the autograd graph
        (otherwise every past batch's graph would be kept alive through them).
        """
        batch_size: int = x.shape[0]
        # Number of values that were averaged per channel: N * prod(spatial dims).
        values_per_channel: int = batch_size
        for extent in x.shape[2:]:
            values_per_channel *= extent
        if values_per_channel == 0:
            # Empty batch: its statistics are NaN, do not poison the running buffers.
            return
        # Bessel correction: the running variance stores the unbiased estimate while
        # normalization itself uses the biased one.
        unbias: float = values_per_channel / max(values_per_channel - 1, 1)
        beta: tp.Union[float, torch.Tensor]
        if self._momentum is None:
            beta = batch_size / (self.num_samples_tracked + batch_size)
        else:
            beta = self._momentum
        self.running_mean.add_(beta * (batch_mean.view(-1) - self.running_mean))
        self.running_var.add_(beta * (unbias * batch_var.view(-1) - self.running_var))
        self.num_samples_tracked.add_(batch_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per channel and apply the affine transform.

        Args:
            x: tensor of shape ``(N, num_features)`` or ``(N, num_features, *)``.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than 2 dimensions or its dim 1 does not
                have size ``num_features``.
        """
        if x.ndim < 2 or x.shape[1] != self._num_features:
            raise ValueError(
                f"expected input of shape (N, {self._num_features}, ...), "
                f"got {tuple(x.shape)}"
            )

        # Reduce over everything except the channel dimension.
        dims_to_be_reduced: tp.Tuple[int, ...] = (0,) + tuple(range(2, x.ndim))
        # Shape that broadcasts a per-channel vector against ``x``.
        dims_to_be_unreduced: tp.Tuple[int, ...] = (1, -1) + (1,) * (x.ndim - 2)

        shift: torch.Tensor
        scale: torch.Tensor
        if self.training or (not self._track_running_stats):
            batch_mean: torch.Tensor = x.mean(
                dim=dims_to_be_reduced,
                keepdim=True,
            )
            batch_var: torch.Tensor = x.var(
                dim=dims_to_be_reduced,
                correction=0,
                keepdim=True,
            )
            if self.training and self._track_running_stats:
                self._update_running_stats(x, batch_mean, batch_var)
            shift, scale = batch_mean, (batch_var + self._eps) ** 0.5
        else:
            shift = self.running_mean.view(dims_to_be_unreduced)
            scale = (self.running_var.view(dims_to_be_unreduced) + self._eps) ** 0.5

        x_normed: torch.Tensor = (x - shift) / scale
        weights = self._params["weights"].view(dims_to_be_unreduced)
        bias = self._params["bias"].view(dims_to_be_unreduced)
        return x_normed * weights + bias


# Fetch MNIST into the current directory (downloaded on first run).
mnist = MNIST("./", download=True)

# Keep only the first N samples, add a singleton channel dimension and divide the
# pixel values by 255.
X: torch.Tensor = mnist.data[:N].view(-1, 1, 28, 28) / 255.
y: torch.Tensor = mnist.targets[:N]

# Leading TRAIN_SIZE samples form the training split, the remainder is validation.
X_train, X_val = X[:TRAIN_SIZE], X[TRAIN_SIZE:]
y_train, y_val = y[:TRAIN_SIZE], y[TRAIN_SIZE:]

In [56]:
# Device on which the network is placed.
device: torch.device = torch.device("cpu")

# Small CNN: BN -> Conv -> BN -> ReLU -> MaxPool -> Flatten -> FC -> BN -> ReLU -> FC.
# The keyword names become the submodule names of the Sequential container.
# To compare against the reference implementation, replace the CustomBatchNorm layers
# with torch.nn.BatchNorm2d (bn_test_0, bn_test_1) and torch.nn.BatchNorm1d (bn_test_2).
net: torch.nn.Module = torch.nn.Sequential(
    OrderedDict(
        bn_test_0=CustomBatchNorm(
            1,
            eps=1e-5,
            momentum=0.1,
            affine=True,
            track_running_stats=True,
        ),
        conv_0=torch.nn.Conv2d(
            1,
            3,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(0, 0),
            dilation=(1, 1),
            groups=1,
            bias=True,
        ),
        bn_test_1=CustomBatchNorm(3),
        act_1=torch.nn.ReLU(inplace=True),
        pooling_2=torch.nn.MaxPool2d(
            kernel_size=(4, 4),
            stride=(4, 4),
            padding=(0, 0),
        ),
        flatten_3=Flatten(),
        # 3 channels x 6 x 6 spatial positions remain after the 3x3 conv and 4x4 pooling
        # of a 28x28 input.
        fc_4=torch.nn.Linear(in_features=3 * 6 * 6, out_features=10, bias=True),
        bn_test_2=CustomBatchNorm(10),
        act_5=torch.nn.ReLU(inplace=True),
        fc_6=torch.nn.Linear(in_features=10, out_features=10),
    )
).to(device)

loss = torch.nn.CrossEntropyLoss()
# NOTE: the (misspelled) name ``oprimizer`` is kept because the training cell uses it.
oprimizer = torch.optim.SGD(params=net.parameters(), lr=0.1)
batch_size: int = 1024

In [57]:
number_of_epochs: int = 10

for epoch in range(number_of_epochs):
    # Fresh sample order every epoch. Batches are gathered through the permutation
    # instead of rebinding X_train / y_train, so re-running this cell never alters
    # the training tensors defined in an earlier cell.
    indexes: np.ndarray = np.random.permutation(TRAIN_SIZE)

    net.train()
    for i in range(0, TRAIN_SIZE, batch_size):
        batch_indexes = indexes[i:i + batch_size]
        X_batch = X_train[batch_indexes].to(device)
        y_batch = y_train[batch_indexes].to(device)

        # A single zero_grad per step is enough: gradients are cleared right before
        # they are recomputed.
        oprimizer.zero_grad()
        y_pred = net(X_batch)
        loss_value = loss(y_pred, y_batch)
        loss_value.backward()
        oprimizer.step()

    # Validation: same device as the network, and no autograd graph is needed.
    net.eval()
    with torch.no_grad():
        y_pred = net(X_val.to(device))
        loss_value = loss(y_pred, y_val.to(device))
    print(f"val loss value: {loss_value.cpu().item()}")

val loss value: 1.7054740190505981
val loss value: 1.3781808614730835
val loss value: 1.1435519456863403
val loss value: 0.944825291633606
val loss value: 0.7943838834762573
val loss value: 0.657880961894989
val loss value: 0.5529645681381226
val loss value: 0.4810158312320709
val loss value: 0.4336111843585968
val loss value: 0.3929685652256012


In [58]:
# Sanity check of exact values: a small random input with 2 channels (dim 1),
# fed to both the custom and the reference batch norm in the next cells.
x = torch.rand((3, 2, 2))

In [59]:
# Output of a freshly constructed (training-mode) custom layer on the check input.
CustomBatchNorm(num_features=2)(x)

tensor([[[-0.3376,  0.3614],
         [-1.4641,  0.3514]],

        [[ 1.9223, -0.1592],
         [ 0.5904,  1.5894]],

        [[-0.4256, -1.3613],
         [-0.1554, -0.9117]]], grad_fn=<AddBackward0>)

In [60]:
# Reference output from PyTorch's own layer on the same input, for comparison.
torch.nn.BatchNorm1d(num_features=2)(x)

tensor([[[-0.3376,  0.3614],
         [-1.4641,  0.3514]],

        [[ 1.9223, -0.1592],
         [ 0.5904,  1.5894]],

        [[-0.4256, -1.3613],
         [-0.1554, -0.9117]]], grad_fn=<NativeBatchNormBackward0>)