# In[56]:
from torch import nn, Tensor
import torch

# For learning purposes only; PyTorch already provides this as nn.BatchNorm1d.
class MiniBatchNormalization(nn.Module):
    """Batch normalization over the batch dimension (a minimal nn.BatchNorm1d).

    Accepts input of shape ``(batch, input_dim)`` or ``(batch, input_dim, length)``
    (any further trailing dims are reduced over as well). In training mode each
    feature channel is normalized with the statistics of the current batch and
    exponential running averages of those statistics are updated; in eval mode
    the running averages are used instead. The normalized value is then scaled
    by ``gamma`` and shifted by ``beta``.

    Args:
        input_dim: Number of feature channels (size of dimension 1 of the input).
        exp_rate: Decay factor of the running averages; each training step does
            ``running = exp_rate * running + (1 - exp_rate) * batch_stat``.
            Equivalent to ``nn.BatchNorm1d(momentum=1 - exp_rate)``.
        eps: Added to the variance inside the square root for numerical stability.
    """

    def __init__(self, input_dim, exp_rate=0.99, eps=1e-5):
        super().__init__()
        self.exp_rate = exp_rate
        self.eps = eps
        # Identity affine transform at initialization (gamma=1, beta=0); a zero
        # gamma would make the output constant and block gradients to the input.
        self.gamma = nn.Parameter(torch.ones(input_dim))
        self.beta = nn.Parameter(torch.zeros(input_dim))
        # Buffers: saved in state_dict and moved by .to(device), but not trained.
        # Running var starts at 1 so eval before training does not divide by ~eps.
        self.register_buffer("mean", torch.zeros(input_dim))
        self.register_buffer("var", torch.ones(input_dim))

    @property
    def is_training(self) -> bool:
        """Alias for ``nn.Module.training`` so ``.train()``/``.eval()`` take effect."""
        return self.training

    @is_training.setter
    def is_training(self, value: bool) -> None:
        self.training = value

    def forward(self, inp: Tensor) -> Tensor:
        """Normalize ``inp`` per channel (dim 1) and apply ``gamma``/``beta``.

        Raises:
            ValueError: If ``inp`` has fewer than 2 dims, or if training with
                only one value per channel (the batch variance is undefined).
        """
        if inp.dim() < 2:
            raise ValueError(
                f"expected input of shape (batch, input_dim, ...), got {inp.dim()} dim(s)"
            )
        # Reduce over every dim except the channel dim 1.
        reduce_dims = [0, *range(2, inp.dim())]
        # Shape that broadcasts a per-channel vector against ``inp``.
        channel_shape = [1, -1] + [1] * (inp.dim() - 2)

        if self.training:
            count = inp.numel() // inp.shape[1]
            if count < 2:
                raise ValueError("expected more than 1 value per channel when training")
            mean = inp.mean(dim=reduce_dims)
            centered = inp - mean.view(channel_shape)
            # Biased variance normalizes the batch, as nn.BatchNorm1d does.
            var = centered.pow(2).mean(dim=reduce_dims)
            # Update running stats outside autograd so the graph of this step is
            # not kept alive by the buffers; the running var uses the unbiased
            # estimate.
            with torch.no_grad():
                self.mean.mul_(self.exp_rate).add_(mean, alpha=1 - self.exp_rate)
                self.var.mul_(self.exp_rate).add_(
                    var * (count / (count - 1)), alpha=1 - self.exp_rate
                )
        else:
            var = self.var
            centered = inp - self.mean.view(channel_shape)

        # eps inside the sqrt keeps the gradient finite when var == 0.
        normed = centered / torch.sqrt(var.view(channel_shape) + self.eps)
        return self.gamma.view(channel_shape) * normed + self.beta.view(channel_shape)