In [2]:
import torch
from torch.utils.data import DataLoader, Dataset

#### dataset

In [189]:
# Dataset dimensions. Inputs and targets are both (dataset_size, features)
# tensors drawn from a standard normal; the targets are only used as regression
# targets in the affine-parameter test further down.
features = 68
dataset_size = 160
batch_size = 16

# Seed so the random data, and therefore every output below, is reproducible
# under Restart & Run All.
torch.manual_seed(0)


class DummyDataset(Dataset):
    """Random regression dataset of ``(input, target)`` pairs.

    Args:
        size: number of samples (defaults to ``dataset_size``).
        n_features: length of every input/target vector (defaults to ``features``).
    """

    def __init__(self, size=dataset_size, n_features=features):
        self.data = torch.randn(size, n_features)
        self.target = torch.randn(size, n_features)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        return self.data[idx], self.target[idx]


dataset = DummyDataset()
dataloader = DataLoader(dataset, batch_size=batch_size)

### models

##### basic: nn.BatchNorm1d (no affine) vs. manual batch statistics

In [16]:
import torch.nn as nn

# Reference layer: plain normalisation with no learnable scale/shift.
eps = 1e-05
bn = nn.BatchNorm1d(features, affine=False, eps=eps)

# Take the first mini-batch (the loader is not shuffled) to compare against.
xb, yb = next(iter(dataloader))

In [17]:
# Reference result from nn.BatchNorm1d. bn has not been switched to eval(), so it
# normalises with this batch's own statistics (and updates its running stats).
y = bn(xb)

In [25]:
# Manual batch norm without affine: per-feature batch mean and *biased*
# (population) variance, which is what nn.BatchNorm1d normalises with in training.
mean = xb.mean(dim=0, keepdim=True)
var_biased = xb.var(dim=0, correction=0)
std = (var_biased + eps).sqrt()

y_h = (xb - mean) / std

In [34]:
# Mask of elements where the manual result and nn.BatchNorm1d disagree under
# torch.isclose's default tolerances.
miss = torch.isclose(y, y_h).logical_not()

In [37]:
# nn.BatchNorm1d's values at the mismatched positions.
y.masked_select(miss)

tensor([0.0003])

In [38]:
# The manual values at the same positions.
y_h.masked_select(miss)

tensor([0.0003])

In [39]:
# Loosening atol from the 1e-08 default to 1e-07 absorbs the near-zero mismatch.
y.allclose(y_h, rtol=1e-05, atol=1e-07)

True

##### running mean and variance

In [159]:
import torch.nn as nn

class BatchNorm1d(nn.Module):
    """Re-implementation of ``nn.BatchNorm1d(features, affine=False)`` for 2-D input.

    Training mode normalises with the batch statistics and updates exponential
    moving averages of the batch mean and of the *unbiased* batch variance.
    Eval mode normalises with those running estimates instead.

    Args:
        features: number of features ``C``; input is expected as ``(N, C)``.
        eps: value added to the variance for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, features, eps=1e-05, momentum=0.1):
        super().__init__()
        self.features = features
        self.eps = eps
        self.momentum = momentum

        self.register_buffer("running_mean", torch.zeros(features))
        self.register_buffer("running_var", torch.ones(features))

    # Train/eval state lives in nn.Module's own ``training`` flag. That keeps the
    # inherited train(mode)/eval() working, including when a parent module calls
    # child.train(mode) recursively. ``istrain`` remains as an alias for old code.
    @property
    def istrain(self):
        return self.training

    @istrain.setter
    def istrain(self, value):
        self.training = value

    def forward(self, x):
        """Normalise ``x`` (shape ``(N, features)``) per feature.

        Raises:
            ValueError: in training mode when the batch has fewer than 2 samples
                (the unbiased variance would be NaN).
        """
        if not self.training:
            return (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        if x.shape[0] < 2:
            raise ValueError(
                f"Expected more than 1 value per channel when training, got input size {x.shape}"
            )

        mean = x.mean(dim=0)
        # The biased variance is used to normalise the batch ...
        var = x.var(dim=0, correction=0)
        out = (x - mean) / torch.sqrt(var + self.eps)

        # ... while the running estimate tracks the unbiased variance. The running
        # stats are bookkeeping, not part of the loss, so update them without
        # autograd; otherwise the buffers keep every past batch's graph alive.
        with torch.no_grad():
            unbiased_var = x.var(dim=0, correction=1)
            self.running_mean.mul_(1 - self.momentum).add_(mean * self.momentum)
            self.running_var.mul_(1 - self.momentum).add_(unbiased_var * self.momentum)

        return out

        
        
        
        

In [160]:
eps = 1e-05

# Reference vs. replica with identical settings. eps is passed to both so that
# changing it above keeps them in sync.
bn_original = nn.BatchNorm1d(features, affine=False, eps=eps)
bn_replicate = BatchNorm1d(features, eps=eps)

In [161]:
# A training-mode forward pass: this also advances both layers' running statistics.
y, y_h = bn_original(xb), bn_replicate(xb)

In [163]:
# Start from training mode so the cell gives the same kind of result when re-run
# (it switches both layers to eval mode further down).
bn_original.train()
bn_replicate.train()

# Training mode: every batch's output must agree, and after the epoch the
# running statistics (EMA of batch mean / unbiased batch variance) must agree too.
train_test = True
for xb, yb in dataloader:
    y = bn_original(xb)
    y_h = bn_replicate(xb)
    # The two differ only by float rounding, so a tight absolute tolerance is used.
    if not torch.allclose(y, y_h, atol=1e-05):
        print("Training output is not matching")
        train_test = False
        break

running_mean_check = torch.allclose(
    bn_original.running_mean,
    bn_replicate.running_mean,
)
running_var_check = torch.allclose(
    bn_original.running_var,
    bn_replicate.running_var,
)

# Eval mode: both layers must now normalise with their running statistics.
bn_original.eval()
bn_replicate.eval()

eval_test = True
for xb, yb in dataloader:
    y = bn_original(xb)
    y_h = bn_replicate(xb)
    if not torch.allclose(y, y_h):
        print("Eval output is not matching")
        eval_test = False
        break

print(
    f"comparing \n  {train_test=}\n  {running_mean_check=}"
    f"\n  {running_var_check=}\n  {eval_test=}"
)

comparing 
  running_mean_check=True
  running_var_check=True
  eval_test=True


##### full model (learnable weight and bias)

In [206]:
import torch.nn as nn

class BatchNorm1d(nn.Module):
    """Re-implementation of ``nn.BatchNorm1d(features, affine=True)`` for 2-D input.

    Training mode normalises with the batch statistics and updates exponential
    moving averages of the batch mean and of the *unbiased* batch variance.
    Eval mode normalises with those running estimates instead. The normalised
    output is then scaled by ``weight`` and shifted by ``bias``.

    Args:
        features: number of features ``C``; input is expected as ``(N, C)``.
        eps: value added to the variance for numerical stability.
        momentum: weight of the current batch in the running-statistics update.
    """

    def __init__(self, features, eps=1e-05, momentum=0.1):
        super().__init__()
        self.features = features
        self.eps = eps
        self.momentum = momentum

        self.register_buffer("running_mean", torch.zeros(features))
        self.register_buffer("running_var", torch.ones(features))

        # Learnable scale and shift. As nn.Parameter they are registered with the
        # module, so parameters(), named_parameters(), state_dict() and .to(device)
        # all pick them up. parameters() yields weight then bias (registration order).
        self.weight = nn.Parameter(torch.ones(features))
        self.bias = nn.Parameter(torch.zeros(features))

    # Train/eval state lives in nn.Module's own ``training`` flag. That keeps the
    # inherited train(mode)/eval() working, including when a parent module calls
    # child.train(mode) recursively. ``istrain`` remains as an alias for old code.
    @property
    def istrain(self):
        return self.training

    @istrain.setter
    def istrain(self, value):
        self.training = value

    def forward(self, x):
        """Normalise ``x`` (shape ``(N, features)``) per feature, then apply weight/bias.

        Raises:
            ValueError: in training mode when the batch has fewer than 2 samples
                (the unbiased variance would be NaN).
        """
        if self.training:
            if x.shape[0] < 2:
                raise ValueError(
                    f"Expected more than 1 value per channel when training, got input size {x.shape}"
                )

            mean = x.mean(dim=0)
            # The biased variance is used to normalise the batch ...
            var = x.var(dim=0, correction=0)
            out = (x - mean) / torch.sqrt(var + self.eps)

            # ... while the running estimate tracks the unbiased variance. Update
            # without autograd so the buffers don't retain past batches' graphs.
            with torch.no_grad():
                unbiased_var = x.var(dim=0, correction=1)
                self.running_mean.mul_(1 - self.momentum).add_(mean * self.momentum)
                self.running_var.mul_(1 - self.momentum).add_(unbiased_var * self.momentum)
        else:
            out = (x - self.running_mean) / torch.sqrt(self.running_var + self.eps)

        return out * self.weight + self.bias
        

In [214]:
eps = 1e-05

# Reference vs. replica, both with learnable weight/bias. eps is passed to both
# so that changing it above keeps them in sync.
bn_original = nn.BatchNorm1d(features, affine=True, eps=eps)
bn_replicate = BatchNorm1d(features, eps=eps)

In [215]:
import itertools

criterion = nn.MSELoss()

# A single SGD optimizer over both modules' parameters, so one step() updates the
# reference and the replica with exactly the same settings.
parameters = itertools.chain.from_iterable(
    module.parameters() for module in (bn_original, bn_replicate)
)
optimizer = torch.optim.SGD(parameters, lr=0.01)

In [223]:
# One epoch of SGD on both modules with identical data. If the replica is
# correct, its weight and bias follow nn.BatchNorm1d's step for step.
for xb, yb in dataloader:
    optimizer.zero_grad()

    y, y_h = bn_original(xb), bn_replicate(xb)
    loss, loss_h = criterion(y, yb), criterion(y_h, yb)

    # The two graphs are independent, so each backward() fills in only the
    # gradients of its own module.
    loss.backward()
    loss_h.backward()

    optimizer.step()
    

In [222]:
# After training, the replica's learnable parameters should match the reference's.
weight_check, bias_check = (
    torch.allclose(getattr(bn_original, name), getattr(bn_replicate, name))
    for name in ("weight", "bias")
)
print(f"comparing \n  {weight_check=}\n  {bias_check=}")

comparing 
  weight_check=True
  bias_check=True
