In [1]:
import torch

def batch_norm_2d(X, gamma, beta, moving_mean, moving_var, train_mode, eps=1e-5, momentum=0.1):
    """eps and momentum values are the same as pytorch defaults"""
    def _bn(x, mean, var):
        # expansion for our 1d mean and var variables
        _expansion = (1, mean.shape[0], 1, 1)
        # according to the batchnorm paper eps is added inside the sqrt instead of after as shown in the lecture
        # There should not be any significant error from that but just to document it
        return torch.div(x - mean.view(_expansion), torch.sqrt(var.view(_expansion) + eps)) * gamma + beta

    if not train_mode:
        return _bn(X, moving_mean, moving_var), moving_mean, moving_var

    # we need to set unbiased=False to remove the bezel correction 
    mean = X.mean(dim=(0, 2, 3)) 
    var = X.var(dim=(0, 2, 3), unbiased=False) 

    with torch.no_grad():
        moving_mean.mul_(1.0 - momentum).add_(momentum * mean)
        moving_var.mul_(1.0 - momentum).add_(momentum * var)
    
    return _bn(X, mean, var), moving_mean, moving_var



In [2]:
class MyBatchNorm2d(torch.nn.Module):
    def __init__(self, in_channels) -> None:
        super().__init__()
        shape = (1, in_channels, 1, 1)
        self.gamma = torch.nn.Parameter(torch.ones(shape))
        self.beta = torch.nn.Parameter(torch.zeros(shape))
        self.moving_mean = torch.nn.Parameter(torch.zeros(in_channels))
        self.moving_var = torch.nn.Parameter(torch.zeros(in_channels))

    def forward(self, x):
        # self.training is set by self.eval() and self.train()
        # we update moving mean and var inplace, since it has to be a parameter 
        # this destroys batch_norm_2d purity ;(
        norm, _, _ = batch_norm_2d(
            X=x, 
            gamma=self.gamma, 
            beta=self.beta, 
            moving_mean=self.moving_mean, 
            moving_var=self.moving_var, 
            train_mode=self.training
        )
        return norm

In [3]:
# Check if batch norms do the same thing 
import torch.nn.functional as F

# mb_size = 2, 3 channels and 4 x 4 image size
to_norm = torch.randn((2, 3, 4, 4))
shape = (1, 3, 1, 1)
gamma  = torch.ones(shape)
beta  = torch.zeros(shape)
moving_mean = torch.zeros(3)
moving_var = torch.zeros(3)
eps = 1e-5

bn2d = MyBatchNorm2d(3)

a, b, c = batch_norm_2d(X=to_norm, gamma=gamma, beta=beta, moving_mean=moving_mean, moving_var=moving_var, train_mode=True)
a_ = bn2d(to_norm)
# we need to reset moving_mean and var because we took away batch_norm_2d's purity 
moving_mean = torch.zeros(3)
moving_var = torch.zeros(3)
bns = F.batch_norm(to_norm, weight=gamma, bias=beta, running_mean=moving_mean, running_var=moving_var, training=True)

# check if the results are the same
print(a[0], a_[0], bns[0]) 

# check if moving_mean is updated the same way
print(b, bn2d.moving_mean, moving_mean)

tensor([[[-1.6069,  1.4866,  0.2035,  0.1400],
         [ 0.8278, -0.6942, -0.4111, -1.8127],
         [ 0.0092,  0.6123, -1.1293, -0.4382],
         [ 0.2795,  0.1995,  1.6047, -0.8003]],

        [[-0.3410, -0.9540,  0.3204,  0.3103],
         [-0.8923, -1.6469, -0.5020,  1.1118],
         [-0.4890,  2.8161,  0.1520,  0.3219],
         [-0.2665,  1.3162,  0.4014,  1.6630]],

        [[ 0.2422, -0.7985, -0.2573,  0.0778],
         [-0.0048,  1.1459, -0.2289,  1.3111],
         [-0.5052, -1.4229,  0.9413,  0.0635],
         [-0.5695,  1.6662,  1.3684, -0.9769]]]) tensor([[[-1.6069,  1.4866,  0.2035,  0.1400],
         [ 0.8278, -0.6942, -0.4111, -1.8127],
         [ 0.0092,  0.6123, -1.1293, -0.4382],
         [ 0.2795,  0.1995,  1.6047, -0.8003]],

        [[-0.3410, -0.9540,  0.3204,  0.3103],
         [-0.8923, -1.6469, -0.5020,  1.1118],
         [-0.4890,  2.8161,  0.1520,  0.3219],
         [-0.2665,  1.3162,  0.4014,  1.6630]],

        [[ 0.2422, -0.7985, -0.2573,  0.0778],
   

In [4]:
# Copy and Paste from last exercise 
import pytorch_lightning as pl
from typing import Any
import torchmetrics
import torch

class CNN(pl.LightningModule):

    def __init__(self, loss: callable, lr: float, conv_layers, classification_head:torch.nn.Module ,num_classes:int =10) -> None:
        super().__init__()
        self.conv_layers = conv_layers
        self.classification_head = classification_head 
        self.num_classes = num_classes
        self.loss = loss
        self.lr = lr
        
        self.test_auroc = torchmetrics.AUROC(num_classes=self.num_classes)
        self.test_acc = torchmetrics.Accuracy(num_classes=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv_layers(x)
        x = self.classification_head(x)
        return x

    def _step(self, batch) -> torch.Tensor:
        x, y = batch
        pred = self.forward(x)
        loss = self.loss(pred, y)
        return pred, loss

    def training_step(self, batch) -> torch.Tensor:
        pred, loss = self._step(batch)
        self.log("train/loss", loss)
        pred = torch.nn.functional.softmax(pred, dim=1)
        acc = torchmetrics.functional.accuracy(pred, batch[-1], num_classes=self.num_classes)
        self.log("train/acc", acc)
        return loss
    
    def _eval_step(self, batch, auroc, acc):
        pred, loss = self._step(batch)
        pred = torch.nn.functional.softmax(pred, dim=1)
        auroc.update(pred, batch[-1])
        acc.update(pred, batch[-1])
        return loss
        
    def test_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def test_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor:
        loss = self._eval_step(batch, self.test_auroc, self.test_acc)
        self.log("test", loss)

    def validation_epoch_end(self, outputs) -> None:
        print(f"Test AUROC: {self.test_auroc.compute().data}")
        print(f"Test Accuracy: {self.test_acc.compute().data}")

        
    def configure_optimizers(self) -> Any:
        optim = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optim

In [5]:
import torchvision
from torchvision.datasets import MNIST

transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),]
)

test_data = MNIST(
    root="data", 
    download=True, 
    train=False, 
    transform=transform)

dl_test = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle=False, num_workers=2)

In [6]:
loss = torch.nn.CrossEntropyLoss()
lr = 1e-3
num_classes = 10

conv_layers_my_bn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, (3, 3)),
    MyBatchNorm2d(32)
)

conv_layers_torch_bn = torch.nn.Sequential(
    torch.nn.Conv2d(1, 32, (3, 3)),
    torch.nn.BatchNorm2d(32)
)

pool_head = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(32, num_classes),
)

pool_head_ = torch.nn.Sequential(
    torch.nn.AdaptiveAvgPool2d((1, 1)),
    torch.nn.Flatten(),
    torch.nn.Linear(32, num_classes),
)

custom_bn_model = CNN(loss, lr, conv_layers_my_bn, pool_head, num_classes)

torch_bn_model = CNN(loss, lr, conv_layers_torch_bn, pool_head_, num_classes)



In [100]:
EPOCHS = 1
# I train and test on the test set, because the goal is to show that the custom batch_norm layer behaves the same 
# Numbers can vary, because initialization of weights and biases can vary

trainer_custom_bn = pl.Trainer(max_epochs=EPOCHS, log_every_n_steps=10)
trainer_custom_bn.fit(custom_bn_model, dl_test)

trainer_torch_bn = pl.Trainer(max_epochs=EPOCHS, log_every_n_steps=10)
trainer_torch_bn.fit(torch_bn_model, dl_test)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name                | Type             | Params
---------------------------------------------------------
0 | conv_layers         | Sequential       | 448   
1 | classification_head | Sequential       | 330   
2 | loss                | CrossEntropyLoss | 0     
3 | test_auroc          | AUROC            | 0     
4 | test_acc            | Accuracy         | 0     
---------------------------------------------------------
778       Trainable params
0         Non-trainable params
778       Total params
0.003     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(

  | Name                | Type             | Params
---------------------------------------------------------
0 | conv_layers         | Sequential       | 384   
1 | classification_head | Sequential       | 330   
2 | loss                | CrossEntropyLoss | 0     
3 | test_auroc          | AUROC            | 0     
4 | test_acc            | Accuracy         | 0     
---------------------------------------------------------
714       Trainable params
0         Non-trainable params
714       Total params
0.003     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [12]:
trainer_custom_bn.test(custom_bn_model, dl_test)
trainer_torch_bn.test(torch_bn_model, dl_test)

  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.622795045375824
Test Accuracy: 0.19509999454021454
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          test               50.35130310058594
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


  rank_zero_warn(


Testing: 0it [00:00, ?it/s]

Test AUROC: 0.666571855545044
Test Accuracy: 0.2069000005722046
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
          test               2.100219488143921
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test': 2.100219488143921}]

# Part c observations:

- there is no significant difference between the torch implementation and my implementation.
- the difference is probably because of different initialization of the Parameters and other random factors