# (Optional) Spatial Batch Normalization

<div class="alert alert-danger">
    <strong>Note:</strong> This exercise is optional; it is meant to deepen your understanding of batch normalization. When using batch normalization in PyTorch, pay attention to the number of dimensions of the input: for example, <a href="https://pytorch.org/docs/stable/nn.html#batchnorm1d">BatchNorm1d</a> expects inputs of shape (N, C) or (N, C, L), whereas <a href="https://pytorch.org/docs/stable/nn.html#batchnorm2d">BatchNorm2d</a> expects inputs of shape (N, C, H, W).
</div>

We have already seen that batch normalization is a very useful technique for training deep fully-connected networks. Batch normalization can also be used in convolutional networks, but it needs a small modification; the modified version is called "spatial batch normalization".

Since this part is strongly based on batch normalization, a good understanding of batch normalization in general is helpful. If you are not too familiar with the concept and implementation, take a look at the optional notebook `Optional-BatchNormalization&Dropout.ipynb` from exercise 08 first.

# 1. Extension from Batch Normalization

Standard batch normalization accepts inputs of shape $(N, D)$ and produces outputs of shape $(N, D)$, normalizing each of the $D$ features across the mini-batch dimension $N$. For data coming from convolutional layers, batch normalization needs to accept inputs of shape $(N, C, H, W)$ and produce outputs of the same shape, where $N$ is the mini-batch size, $C$ the number of channels, and $(H, W)$ the spatial size of the feature map.

If the feature map was produced by a convolution, each feature channel is the result of applying the same filter at every spatial location of the previous layer's feature maps, for every image in the mini-batch. We therefore expect the statistics of each feature channel to be relatively consistent both across different images and across different locations within the same image. For this reason, spatial batch normalization computes one mean and one variance for each of the $C$ feature channels, taking the statistics over both the mini-batch dimension $N$ and the spatial dimensions $H$ and $W$.
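To make the difference in the reduction axes concrete, here is a small NumPy sketch on random data (the array sizes are arbitrary and chosen only for illustration). Standard batch normalization reduces an $(N, D)$ array over axis 0, while spatial batch normalization reduces an $(N, C, H, W)$ array over axes 0, 2 and 3, leaving one mean and one variance per channel:

```python
import numpy as np

rng = np.random.default_rng(0)

# Standard batch norm input: N samples with D features each
x_fc = rng.standard_normal((8, 5))
mu_fc = x_fc.mean(axis=0)   # one mean per feature, shape (D,)
var_fc = x_fc.var(axis=0)   # one variance per feature, shape (D,)

# Spatial batch norm input: N images, C channels, H x W feature maps
N, C, H, W = 8, 3, 4, 4
x_conv = rng.standard_normal((N, C, H, W))
mu_conv = x_conv.mean(axis=(0, 2, 3))  # one mean per channel, shape (C,)
var_conv = x_conv.var(axis=(0, 2, 3))  # one variance per channel, shape (C,)

# Channel c's statistics pool all N * H * W values of that channel
c = 1
assert np.isclose(mu_conv[c], x_conv[:, c, :, :].mean())
assert np.isclose(var_conv[c], x_conv[:, c, :, :].var())

print(mu_fc.shape, mu_conv.shape)  # (5,) (3,)
```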

To better understand the relationship and the differences between batch normalization and spatial batch normalization, take a look at the comparison in the following figure, taken from the [CS231n lecture slides](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture07.pdf).

<img src='images/SpatialBatchNorm.JPG' width=70% height=70%/>

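Both methods share the same computation: normalize over some dimensions, then scale and shift the result according to $y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$, where $\mu$ and $\sigma^2$ are the mean and variance over the normalized dimensions and $\epsilon$ is a small constant for numerical stability. They differ only in which dimensions the statistics are computed over: batch normalization averages over $N$ for each of the $D$ features, while spatial batch normalization averages over $N$, $H$ and $W$ for each of the $C$ channels, because images are stored in a higher-dimensional tensor.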
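As a minimal sketch of this formula in training mode (NumPy; the function name `spatial_bn_forward_direct` and the value $\epsilon = 10^{-5}$ are assumptions for illustration), the per-channel statistics can be broadcast back over $(N, H, W)$ by keeping the reduced axes, with $\gamma$ and $\beta$ reshaped to $(1, C, 1, 1)$:

```python
import numpy as np


def spatial_bn_forward_direct(x, gamma, beta, eps=1e-5):
    """Training-mode spatial batch norm on an (N, C, H, W) array.

    gamma and beta have shape (C,). eps is an assumed small constant
    for numerical stability.
    """
    # Per-channel statistics; keepdims=True gives shape (1, C, 1, 1) for broadcasting
    mu = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    # Scale and shift every channel with its own gamma and beta
    return gamma.reshape(1, -1, 1, 1) * x_hat + beta.reshape(1, -1, 1, 1)


rng = np.random.default_rng(0)
x = 5.0 + 3.0 * rng.standard_normal((10, 3, 6, 6))
gamma = np.array([1.0, 2.0, 3.0])
beta = np.array([0.0, -1.0, 4.0])
out = spatial_bn_forward_direct(x, gamma, beta)

# After the transform, each channel has mean ~ beta and standard deviation ~ gamma
print(np.round(out.mean(axis=(0, 2, 3)), 3))
print(np.round(out.std(axis=(0, 2, 3)), 3))
```

Whatever the input statistics are, the output of each channel is re-centered at $\beta_c$ and re-scaled to $\gamma_c$, which is exactly what the learnable parameters are meant to control.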

# 2. Implementation

## 2.1 Imports

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
import pytorch_lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger

os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

from exercise_code.layers import (
    spatial_batchnorm_forward, 
    spatial_batchnorm_backward,
)
from exercise_code.tests.gradient_check import (
    eval_numerical_gradient_array,
    eval_numerical_gradient,
    rel_error,
)
from exercise_code.tests.spatial_batchnorm_tests import (
    test_spatial_batchnorm_forward,
    test_spatial_batchnorm_backward,
)

from exercise_code.networks.SpatialBatchNormModel import (
    SimpleNetwork,
    SpatialBatchNormNetwork,
)

%load_ext autoreload
%autoreload 2

# Suppress cluttering warnings in solutions
import warnings

warnings.filterwarnings("ignore")

## 2.2 Spatial Batch Normalization: Forward

<div class="alert alert-info">
    <h3>Task: Implement</h3>
    <p>In the file <code>exercise_code/layers.py</code>, implement the forward pass for spatial batch normalization in the function <code>spatial_batchnorm_forward</code>. Check your implementation by running the following cell:
 </p>
    <p>
    <b>Hints</b>: You can reuse the batch normalization functions you implemented in the optional task <code>Batch Normalization & Dropout</code> of exercise 08. Be careful about how the input dimensions differ between batch normalization and spatial batch normalization.
    </p>
</div>
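One common way to follow the hint is to move the channel axis to the end and flatten, so that each of the $N \cdot H \cdot W$ positions becomes a "sample" for standard $(N, D)$ batch normalization with $D = C$. The sketch below uses a minimal stand-in `batchnorm_forward_nd` (a hypothetical name; training mode only) instead of your exercise-08 function, whose signature may differ:

```python
import numpy as np


def batchnorm_forward_nd(x, gamma, beta, eps=1e-5):
    """Hypothetical stand-in for a standard (N, D) batch norm forward pass (training mode)."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta


def spatial_bn_forward_via_reshape(x, gamma, beta, eps=1e-5):
    """Spatial batch norm by reusing the (N, D) version with D = C."""
    N, C, H, W = x.shape
    # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): each column is one channel
    x_flat = x.transpose(0, 2, 3, 1).reshape(-1, C)
    out_flat = batchnorm_forward_nd(x_flat, gamma, beta, eps)
    # Undo the reshape and the transpose to get back (N, C, H, W)
    return out_flat.reshape(N, H, W, C).transpose(0, 3, 1, 2)


rng = np.random.default_rng(1)
x = rng.standard_normal((4, 3, 5, 5))
gamma = rng.standard_normal(3)
beta = rng.standard_normal(3)
out = spatial_bn_forward_via_reshape(x, gamma, beta)

# Compare against computing the statistics directly over axes (0, 2, 3)
mu = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
expected = gamma.reshape(1, -1, 1, 1) * (x - mu) / np.sqrt(var + 1e-5) + beta.reshape(1, -1, 1, 1)
print(np.allclose(out, expected))  # True
```

The transpose matters: calling `x.reshape(-1, C)` directly on an $(N, C, H, W)$ array would place values from different channels in the same column, so the statistics would no longer be per channel.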

In [2]:
test_spatial_batchnorm_forward()

SpatialBatchnormForwardTest with trivial beta and gamma (train) passed.
SpatialBatchnormForwardTest with nontrivial beta and gamma (train) passed.
SpatialBatchnormForwardTest with trivial beta and gamma (test) passed.
All tests passed for your spatial batchnorm implementation. Tests passed: 3/3

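The third test runs in test mode, where the mini-batch statistics are replaced by running estimates accumulated during training. The following sketch assumes a CS231n-style exponential moving average with `momentum = 0.9`; the function names are hypothetical, and the exact convention in your exercise-08 code may differ:

```python
import numpy as np


def update_running_stats(x, running_mean, running_var, momentum=0.9):
    """Exponential moving average of per-channel statistics (assumed CS231n-style convention)."""
    mu = x.mean(axis=(0, 2, 3))
    var = x.var(axis=(0, 2, 3))
    running_mean = momentum * running_mean + (1 - momentum) * mu
    running_var = momentum * running_var + (1 - momentum) * var
    return running_mean, running_var


def spatial_bn_forward_test(x, gamma, beta, running_mean, running_var, eps=1e-5):
    """Test mode: normalize with the stored running statistics, not batch statistics."""
    shape = (1, -1, 1, 1)  # broadcast per-channel vectors over (N, H, W)
    x_hat = (x - running_mean.reshape(shape)) / np.sqrt(running_var.reshape(shape) + eps)
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)


rng = np.random.default_rng(2)
C = 3
true_mean = np.array([1.0, -2.0, 0.5]).reshape(1, C, 1, 1)
true_std = np.array([2.0, 0.5, 1.0]).reshape(1, C, 1, 1)
running_mean, running_var = np.zeros(C), np.ones(C)

# Simulate 200 training iterations on batches drawn from a fixed distribution
for _ in range(200):
    batch = true_mean + true_std * rng.standard_normal((16, C, 4, 4))
    running_mean, running_var = update_running_stats(batch, running_mean, running_var)

# The running estimates approach the true per-channel mean and standard deviation
print(np.round(running_mean, 2), np.round(np.sqrt(running_var), 2))

# At test time, even a single image is normalized deterministically
single = true_mean + true_std * rng.standard_normal((1, C, 4, 4))
out = spatial_bn_forward_test(single, np.ones(C), np.zeros(C), running_mean, running_var)
```

Note that PyTorch's `nn.BatchNorm2d` defines `momentum` the other way around: its default `momentum=0.1` is the weight given to the new batch statistic.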

## 2.3 Spatial Batch Normalization: Backward

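Now that you have implemented the spatial batch normalization forward pass on top of the batch normalization functions, the backward pass should be straightforward: apply the same dimension handling to the upstream gradient, call the batch normalization backward function, and bring the resulting gradient back to shape $(N, C, H, W)$.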

<div class="alert alert-info">
    <h3>Task: Implement</h3>
    <p>In the file <code>exercise_code/layers.py</code>, implement the backward pass for spatial batch normalization in the function <code>spatial_batchnorm_backward</code>. Run the following cell to check your implementation with a numerical gradient check:
 </p>
    <p>
    <b>Hints</b>: Again, you can reuse the batch normalization functions you implemented in the optional task <code>Batch Normalization & Dropout</code> of exercise 08. Take care of the tensor dimensions.
    </p>
</div>
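A sketch of the same reshape strategy for the backward pass is shown here, using the standard closed-form batch normalization backward formula and checked against a central-difference numerical gradient. All helper names (`bn_forward_with_cache`, `bn_backward`, `to_flat`, `from_flat`, `spatial_bn_forward`, `spatial_bn_backward`) and the cache layout are hypothetical choices for illustration, not the API of `exercise_code/layers.py`:

```python
import numpy as np


def bn_forward_with_cache(x, gamma, beta, eps=1e-5):
    """(N, D) batch norm forward (training mode); returns output and a cache for backward."""
    mu = x.mean(axis=0)
    inv_std = 1.0 / np.sqrt(x.var(axis=0) + eps)
    x_hat = (x - mu) * inv_std
    return gamma * x_hat + beta, (x_hat, inv_std, gamma)


def bn_backward(dout, cache):
    """Closed-form (N, D) batch norm backward pass."""
    x_hat, inv_std, gamma = cache
    M = dout.shape[0]
    dbeta = dout.sum(axis=0)
    dgamma = (dout * x_hat).sum(axis=0)
    dx_hat = dout * gamma
    dx = inv_std / M * (M * dx_hat - dx_hat.sum(axis=0) - x_hat * (dx_hat * x_hat).sum(axis=0))
    return dx, dgamma, dbeta


def to_flat(t):
    """(N, C, H, W) -> (N*H*W, C) with one column per channel."""
    return t.transpose(0, 2, 3, 1).reshape(-1, t.shape[1])


def from_flat(t_flat, shape):
    """(N*H*W, C) -> (N, C, H, W)."""
    N, C, H, W = shape
    return t_flat.reshape(N, H, W, C).transpose(0, 3, 1, 2)


def spatial_bn_forward(x, gamma, beta):
    out_flat, cache = bn_forward_with_cache(to_flat(x), gamma, beta)
    return from_flat(out_flat, x.shape), cache


def spatial_bn_backward(dout, cache):
    # Flatten the upstream gradient exactly like the input, reuse bn_backward, reshape back
    dx_flat, dgamma, dbeta = bn_backward(to_flat(dout), cache)
    return from_flat(dx_flat, dout.shape), dgamma, dbeta


rng = np.random.default_rng(3)
x = rng.standard_normal((2, 3, 4, 5))
gamma, beta = rng.standard_normal(3), rng.standard_normal(3)
dout = rng.standard_normal(x.shape)

_, cache = spatial_bn_forward(x, gamma, beta)
dx, dgamma, dbeta = spatial_bn_backward(dout, cache)

# Numerical gradient of the scalar sum(out * dout) with respect to x (central differences)
h = 1e-5
dx_num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += h
    x_minus[idx] -= h
    f_plus = np.sum(spatial_bn_forward(x_plus, gamma, beta)[0] * dout)
    f_minus = np.sum(spatial_bn_forward(x_minus, gamma, beta)[0] * dout)
    dx_num[idx] = (f_plus - f_minus) / (2 * h)

rel_err = np.max(np.abs(dx - dx_num) / np.maximum(1e-8, np.abs(dx) + np.abs(dx_num)))
print(f"max relative error in dx: {rel_err:.2e}")
```

Because `dgamma` and `dbeta` are sums over the flattened rows, they automatically sum over $N$, $H$ and $W$, which matches the per-channel parameters of spatial batch normalization.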


In [3]:
test_spatial_batchnorm_backward()

SpatialBatchnormBackwardTest passed.


## 2.4 Spatial Batch Normalization in PyTorch Lightning

As in the batch normalization task of the previous exercise, we now run some experiments with PyTorch Lightning to see the effect of spatial batch normalization.

### 2.4.1 Set up TensorBoard

By now you have some experience with TensorBoard, which should be your go-to tool for tuning your network and monitoring the training process. Throughout this notebook, feel free to add further logs or visualizations to TensorBoard!

In [4]:
# A few hyperparameters before we start
batch_size = 50

logdir = "./spatial_batch_norm_logs"
if os.path.exists(logdir):
    # Delete logs from previous runs so that TensorBoard only shows the current ones
    shutil.rmtree(logdir)
os.mkdir(logdir)

epochs = 5
learning_rate = 0.0005

In [5]:
%load_ext tensorboard
%tensorboard --logdir spatial_batch_norm_logs

ERROR: Timed out waiting for TensorBoard to start. It may still be running as pid 6060.

### 2.4.2 Train a model without Spatial Batch Normalization

<div class="alert alert-success">
    <h3>Task: Check Code</h3>
    <p>We have already implemented a <code>SimpleNetwork</code> without spatial batch normalization in <code>exercise_code/networks/SpatialBatchNormModel.py</code>. Feel free to check it out and play around with the parameters. The cell below sets up a short training run for this network.
 </p>
</div>

In [6]:
# Create the network without spatial batch normalization
model = SimpleNetwork(batch_size=batch_size, learning_rate=learning_rate)
# Create a TensorBoard logger that writes to logdir/simple_network
simple_network_logger = TensorBoardLogger(
    save_dir=logdir,
    name="simple_network"
)
trainer = pl.Trainer(max_epochs=epochs, logger=simple_network_logger)

# Train
trainer.fit(model)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 4.8 K 
1 | fc    | Linear     | 15.7 K


Validation sanity check: 0it [00:00, ?it/s]

Val-Acc=0.00125


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Val-Acc=0.6651666666666667


Validating: 0it [00:00, ?it/s]

Val-Acc=0.73625


Validating: 0it [00:00, ?it/s]

Val-Acc=0.7673333333333333


Validating: 0it [00:00, ?it/s]

Val-Acc=0.78475


Validating: 0it [00:00, ?it/s]

Val-Acc=0.7961666666666667


1

### 2.4.3 Train a model with Spatial Batch Normalization

<div class="alert alert-success">
    <h3>Task: Check Code</h3>
    <p>Now that we have seen how the simple network performs, let us look at a model that actually uses spatial batch normalization. Again, we provide you with such a model, <code>SpatialBatchNormNetwork</code>, in <code>exercise_code/networks/SpatialBatchNormModel.py</code>. As before, feel free to check it out and play around with the parameters. The cell below sets up a short training run for this model.
 </p>
</div>

In [7]:
model_bn = SpatialBatchNormNetwork(batch_size=batch_size, learning_rate=learning_rate)
spatial_bn_network_logger = TensorBoardLogger(
    save_dir=logdir,
    name="spatial_bn_network"
)
trainer = pl.Trainer(max_epochs=epochs, logger=spatial_bn_network_logger)
trainer.fit(model_bn)

GPU available: True, used: False
TPU available: False, using: 0 TPU cores

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 4.9 K 
1 | fc    | Linear     | 15.7 K


Validation sanity check: 0it [00:00, ?it/s]

Val-Acc=0.0010833333333333333


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Val-Acc=0.7924166666666667


Validating: 0it [00:00, ?it/s]

Val-Acc=0.81925


Validating: 0it [00:00, ?it/s]

Val-Acc=0.8306666666666667


Validating: 0it [00:00, ?it/s]

Val-Acc=0.8419166666666666


Validating: 0it [00:00, ?it/s]

Val-Acc=0.844


1

### 2.4.4 Observations

Take a look at TensorBoard to compare the performance of both networks:

In [8]:
%load_ext tensorboard
%tensorboard --logdir spatial_batch_norm_logs

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Reusing TensorBoard on port 6006 (pid 19388), started 0:14:10 ago. (Use '!kill 19388' to kill it.)

Recall the batch normalization comparison from the last exercise: the result here is very similar. The network with spatial batch normalization reaches a higher validation accuracy after every epoch (0.792 vs. 0.665 after the first epoch and 0.844 vs. 0.796 after the fifth), and TensorBoard should show a lower validation loss as well. This simple experiment suggests that spatial batch normalization is helpful for convolutional networks.