In [1]:
import numpy as np
import torch


# Synthetic dataset sizing: N samples overall; TRAIN_SIZE of them are reserved
# for training and the remaining VAL_SIZE for validation.
N, TRAIN_SIZE = 100, 70
VAL_SIZE = N - TRAIN_SIZE


class CustomBatchNorm(torch.nn.Module):
    """Batch normalization over every dimension of the input except dim 1.

    Dim 1 of the input is treated as the feature/channel dimension; mean and
    variance are computed per channel over all remaining dimensions, so the
    same module handles ``(N, C)``, ``(N, C, L)``, ``(N, C, H, W)``, ... inputs.

    Args:
        num_features: size of dim 1 of the expected input (``C``).
        eps: constant added to the variance before the square root.
        momentum: weight of the newest batch in the running estimates. ``None``
            switches to a cumulative (equal-weight) average over all batches.
        affine: when ``False`` the scale and shift stay frozen at ones / zeros
            (they are still applied, but ``requires_grad`` is off).
        track_running_stats: when ``True`` running estimates are collected in
            training mode and used for normalization in eval mode; when
            ``False`` batch statistics are always used.

    Notes:
        * Normalization uses the biased variance estimate, while the running
          variance accumulates the unbiased one (the ``torch.nn.BatchNorm*``
          convention).
        * Running estimates are initialised lazily from the first training
          batch and kept in a plain dict, so they are not part of
          ``state_dict()``.
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 0.00001,
            momentum: float = 0.1,
            affine: bool = True,
            track_running_stats: bool = True,
    ):
        super().__init__()
        self._eps: float = eps
        # Per-channel scale and shift; frozen at identity when affine=False.
        self._weights = torch.nn.Parameter(
            torch.ones(num_features),
            requires_grad=affine
        )
        self._bias = torch.nn.Parameter(
            torch.zeros(num_features),
            requires_grad=affine
        )
        self._momentum = momentum  # float, or None for a cumulative average
        self._track_running_stats: bool = track_running_stats
        # "mean" / "var" hold 1-D tensors of length C once the first training
        # batch has been seen; they stay None until then.
        self._running_stats = {
            "mean": None,
            "var": None,
            "num_batches_tracked": 0 if track_running_stats else None,
        }

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` per channel and apply the scale and shift.

        Args:
            x: tensor with at least two dimensions; dim 1 is the channel dim.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: if ``x`` has fewer than two dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                f"CustomBatchNorm expects input with at least 2 dimensions, got {x.dim()}"
            )

        # Logical dimension indices (independent of the memory layout).
        dims_to_reduce = tuple(d for d in range(x.dim()) if d != 1)
        # Shape that broadcasts a length-C vector along dim 1, e.g. (1, -1, 1, 1).
        stats_shape = tuple(-1 if d == 1 else 1 for d in range(x.dim()))

        running_mean = self._running_stats["mean"]
        running_var = self._running_stats["var"]
        use_running_stats = (
            not self.training
            and self._track_running_stats
            and running_mean is not None
            and running_var is not None
        )

        if use_running_stats:
            mean = running_mean.to(x.device).view(stats_shape)
            var = running_var.to(x.device).view(stats_shape)
        else:
            # Also reached in eval mode when no training batch has been seen
            # yet: fall back to batch statistics instead of failing on None.
            mean = x.mean(dim=dims_to_reduce, keepdim=True)
            var = x.var(dim=dims_to_reduce, unbiased=False, keepdim=True)
            if self.training and self._track_running_stats:
                self._update_running_stats(x, mean, var)

        x_normed = (x - mean) / torch.sqrt(var + self._eps)
        return x_normed * self._weights.view(stats_shape) + self._bias.view(stats_shape)

    def _update_running_stats(
            self,
            x: torch.Tensor,
            batch_mean: torch.Tensor,
            batch_var: torch.Tensor,
    ) -> None:
        """Fold the statistics of the current batch into the running estimates.

        ``batch_var`` is the biased per-channel variance; it is rescaled to the
        unbiased estimate before being accumulated.
        """
        if self._momentum is None:
            factor = 1.0 / (self._running_stats["num_batches_tracked"] + 1)
        else:
            factor = self._momentum

        # Number of values that were averaged for each channel.
        values_per_channel = x.numel() // max(x.size(1), 1)
        # Bessel's correction; skipped for a single value to avoid a 0-division.
        if values_per_channel > 1:
            bessel = values_per_channel / (values_per_channel - 1)
        else:
            bessel = 1.0

        # The running estimates are bookkeeping only: keep them out of autograd
        # so the graph does not grow with every training step.
        with torch.no_grad():
            fresh = {
                "mean": batch_mean.detach().flatten(),
                "var": batch_var.detach().flatten() * bessel,
            }
            for stat_type, value in fresh.items():
                previous = self._running_stats[stat_type]
                if previous is None:
                    self._running_stats[stat_type] = value.clone()
                else:
                    previous = previous.to(value.device)
                    self._running_stats[stat_type] = previous + (value - previous) * factor
        self._running_stats["num_batches_tracked"] += 1

    def _backward(self):
        """No-op placeholder; nothing in this class calls it."""
        pass


# Seed torch's global RNG so the synthetic data and the initial layer weights
# are the same on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Synthetic binary classification task: N samples with 5 uniform features in
# [0, 1); the label is 1.0 when the mean of a sample's features exceeds 0.5.
X = torch.rand(N, 5)
y = (X.mean(dim=1, keepdim=True) > 0.5).float()

# Sequential split: the first TRAIN_SIZE rows train, the rest validate.
X_train, X_val = X[:TRAIN_SIZE], X[TRAIN_SIZE:]
y_train, y_val = y[:TRAIN_SIZE], y[TRAIN_SIZE:]

# NOTE(review): `l2` is not referenced by any visible cell (the network builds
# its own Linear(5, 3)); kept in case a cell outside this view uses it.
l2 = torch.nn.Linear(5, 3)

# 5 -> 3 -> 1 network with a batch-norm layer in front of every stage. The
# commented-out torch.nn.BatchNorm1d lines are the built-in layers that can be
# swapped in to compare against CustomBatchNorm.
net = torch.nn.Sequential(
    CustomBatchNorm(num_features=5, track_running_stats=False),
    # torch.nn.BatchNorm1d(num_features=5, track_running_stats=False),
    torch.nn.Linear(5, 3),
    CustomBatchNorm(num_features=3, track_running_stats=False),
    # torch.nn.BatchNorm1d(num_features=3, track_running_stats=False),
    torch.nn.Sigmoid(),
    torch.nn.Linear(3, 1),
    CustomBatchNorm(num_features=1, track_running_stats=False),
    # torch.nn.BatchNorm1d(num_features=1, track_running_stats=False),
    torch.nn.Sigmoid(),
)
net.train()

# Plain SGD on binary cross-entropy; the network ends in a Sigmoid, so its
# outputs lie in [0, 1] as BCELoss expects.
optimizer = torch.optim.SGD(net.parameters(), lr=0.9)
loss = torch.nn.BCELoss()

In [None]:
# Full-batch training: one optimizer step per iteration on the whole training
# split, followed by a validation pass. Losses are printed every k iterations.
k = 100
for i in range(50000):
    should_log = i % k == 0
    if should_log:
        print(f"{i}th epoch")

    # --- training step ---
    net.train()
    optimizer.zero_grad()
    y_pred = net(X_train)
    loss_value = loss(y_pred, y_train)
    loss_value.backward()
    optimizer.step()
    if should_log:
        print("train_loss:", loss_value.item())

    # --- validation ---
    # Nothing is back-propagated from here, so skip building the autograd graph.
    net.eval()
    with torch.no_grad():
        y_pred = net(X_val)
        loss_value = loss(y_pred, y_val)
    if should_log:
        print("val_loss:", loss_value.item())
        print()