In [1]:
import torch

def batch_norm_2d(X, gamma, beta, moving_mean, moving_var, train_mode, eps=1e-5, momentum=0.1):
    """eps and momentum values are the same as pytorch defaults"""

    if not train_mode:
        X_hat = torch.div(X - moving_mean[None, ::, None, None], torch.sqrt(moving_var[None, ::, None, None] + eps))
    else:
        # we need to set unbiased=False to remove the bezel correction 
        mean = X.mean(dim=(0, 2, 3)) 
        var = X.var(dim=(0, 2, 3), unbiased=False) 
        X_hat = torch.div(X - mean[None, ::, None, None], torch.sqrt(var[None, ::, None, None] + eps))

        # update moving_mean and var
        moving_mean = (1.0 - momentum) * moving_mean + momentum * mean
        moving_var = (1.0 - momentum) * moving_var + momentum * var
    
    norm_batch = X_hat * gamma + beta
    return norm_batch, moving_mean, moving_var



In [2]:
class MyBatchNorm2d(torch.nn.Module):
    def __init__(self, in_channels) -> None:
        super().__init__()
        shape = (1, in_channels, 1, 1)
        self.gamma = torch.nn.Parameter(torch.ones(shape))
        self.beta = torch.nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.zeros(in_channels)
        self.moving_var = torch.zeros(in_channels)

    def forward(self, x):
        # self.training is set by self.eval() and self.train()
        norm, mm, mv = batch_norm_2d(x, self.gamma, self.beta, self.moving_mean, self.moving_var, self.training)
        self.moving_mean = mm
        self.mv = mv
        return norm

In [4]:
# mb_size = 16, 3 channels and 12 x 12 image size
to_norm = torch.randn((2, 3, 4, 4))
shape = (1, 3, 1, 1)
gamma  = torch.ones(shape)
beta  = torch.zeros(shape)
moving_mean = torch.zeros(3)
moving_var = torch.zeros(3)
eps = 1e-5

bn2d = MyBatchNorm2d(3)

In [5]:
# Check if batch norms to the same by hand

import torch.nn.functional as F
a, b, c = batch_norm_2d(X=to_norm, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, train_mode=True)
a_ = bn2d(to_norm)
bns = F.batch_norm(to_norm, weight=gamma, bias=beta, running_mean=moving_mean, running_var=moving_var, training=True)

print(a[0], a_[0], bns[0]) 


tensor([[[ 0.6874, -0.2554, -0.2423,  0.1149],
         [ 1.0755, -0.1256,  1.4470, -0.3999],
         [ 0.3599, -0.0605, -1.5511, -0.2335],
         [ 1.3836, -0.4995, -0.9170,  0.5237]],

        [[-1.6499, -2.2293, -0.1241,  0.1928],
         [-0.4826, -1.3876, -0.8286,  1.0646],
         [ 0.6182,  2.1962, -0.3039, -0.7923],
         [ 1.8773, -0.9112,  0.8893,  0.2331]],

        [[-1.7067, -1.9626, -1.0487,  0.8881],
         [ 0.0134,  0.5152,  0.1008,  0.3373],
         [ 0.5981,  0.6143, -0.9635, -1.1860],
         [-0.4608,  2.1036, -0.3587,  0.9629]]]) tensor([[[ 0.6874, -0.2554, -0.2423,  0.1149],
         [ 1.0755, -0.1256,  1.4470, -0.3999],
         [ 0.3599, -0.0605, -1.5511, -0.2335],
         [ 1.3836, -0.4995, -0.9170,  0.5237]],

        [[-1.6499, -2.2293, -0.1241,  0.1928],
         [-0.4826, -1.3876, -0.8286,  1.0646],
         [ 0.6182,  2.1962, -0.3039, -0.7923],
         [ 1.8773, -0.9112,  0.8893,  0.2331]],

        [[-1.7067, -1.9626, -1.0487,  0.8881],
   

In [6]:
# Copy and Paste from last exercise 
import pytorch_lightning as pl
from typing import Any
import torchmetrics
import torch

class CNN(pl.LightningModule):

    def __init__(self, loss: callable, lr: float, conv_layers, classification_head:torch.nn.Module ,num_classes:int =10) -> None:
        super().__init__()
        self.conv_layers = conv_layers
        self.classification_head = classification_head 
        self.num_classes = num_classes
        self.loss = loss
        self.lr = lr
        
        self.test_auroc = torchmetrics.AUROC(num_classes=self.num_classes)
        self.test_acc = torchmetrics.Accuracy(num_classes=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_layers(x)
        x = self.classification_head(x)
        return x

    def _step(self, batch) -> torch.Tensor:
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)
        return pred, loss

    def training_step(self, batch) -> torch.Tensor:
        pred, loss = self._step(batch)
        self.log("train/loss", loss)
        pred = torch.nn.functional.softmax(pred, dim=1)
        acc = torchmetrics.functional.accuracy(pred, batch[-1], num_classes=self.num_classes)
        self.log("train/acc", acc)
        return loss
    
    def _eval_step(self, batch, auroc, acc):
        pred, loss = self._step(batch)
        pred = torch.nn.functional.softmax(pred, dim=1)
        auroc.update(pred, batch[-1])
        acc.update(pred, batch[-1])
        return loss
        
    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def test_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def validation_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

        
    def configure_optimizers(self) -> Any:
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim

In [7]:
import torchvision
from torchvision.datasets import MNIST

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),]
)

test_data = MNIST(
    root="data", 
    download=True, 
    train=False, 
    transform=transform)

dl_test = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [13]:
loss = torch.nn.CrossEntropyLoss()
lr = 1e-3
num_classes = 10

conv_layers_my_bn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, (3, 3)),
    MyBatchNorm2d(32)
)

conv_layers_torch_bn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, (3, 3)),
    torch.nn.BatchNorm2d(32)
)

pool_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(32, num_classes),
)

pool_head_ = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(32, num_classes),
)

custom_bn_model = CNN(loss, lr, conv_layers_my_bn, pool_head, num_classes)

torch_bn_model = CNN(loss, lr, conv_layers_torch_bn, pool_head_, num_classes)



In [11]:
EPOCHS = 1
# I train and test on the test set, because the goal is to show that the custom batch_norm layer behaves the same 
# Numbers can vary, because initialization of weights and biases can vary

trainer_custom_bn = pl.Trainer(max_epochs=EPOCHS, log_every_n_steps=10)
trainer_custom_bn.fit(custom_bn_model, dl_test)

trainer_torch_bn = pl.Trainer(max_epochs=EPOCHS, log_every_n_steps=10)
trainer_torch_bn.fit(torch_bn_model, dl_test)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name                | Type             | Params
---------------------------------------------------------
0 | conv_layers         | Sequential       | 384   
1 | classification_head | Sequential       | 330   
2 | loss                | CrossEntropyLoss | 0     
3 | test_auroc          | AUROC            | 0     
4 | test_acc            | Accuracy         | 0     
---------------------------------------------------------
714       Trainable params
0         Non-trainable params
714       Total params
0.003     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name                | Type             | Params
---------------------------------------------------------
0 | conv_layers         | Sequential       | 384   
1 | classification_head | Sequential       | 330   
2 | loss                | CrossEntropyLoss | 0     
3 | test_auroc          | AUROC            | 0     
4 | test_acc            | Accuracy         | 0     
---------------------------------------------------------
714       Trainable params
0         Non-trainable params
714       Total params
0.003     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [12]:
trainer_custom_bn.test(custom_bn_model, dl_test)
trainer_torch_bn.test(torch_bn_model, dl_test)

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.622795045375824
Test Accuracy: 0.19509999454021454
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          test               50.35130310058594
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.666571855545044
Test Accuracy: 0.2069000005722046
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          test               2.100219488143921
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test': 2.100219488143921}]