# (Optional) Spatial Batch Normalization

<div class="alert alert-danger">
    <strong>Note:</strong> This exercise is optional and can be done for a better understanding of batch normalization. Also, when using batch normalization in PyTorch, pay attention to the number of dimensions of the input (see <a href="https://pytorch.org/docs/stable/nn.html#batchnorm1d">BatchNorm1d</a>, <a href="https://pytorch.org/docs/stable/nn.html#batchnorm2d">BatchNorm2d</a>, etc.).
</div>

We have already seen that batch normalization is a very useful technique for training deep fully-connected networks. Batch normalization can also be used for convolutional networks, but it needs a small modification, which we call "spatial batch normalization".

Since this part is strongly based on batch normalization, a good understanding of batch normalization in general is helpful. If you are not too familiar with the concept and implementation, take a look at the optional notebook `Optional-BatchNormalization&Dropout.ipynb` from exercise 08 first.

# 1. Extension from Batch Normalization

Standard batch normalization accepts inputs of shape $(N, D)$ and produces outputs of shape $(N, D)$, normalizing across the mini-batch dimension $N$. For data coming from convolutional layers, batch normalization needs to accept inputs of shape $(N, C, H, W)$ and produce outputs of the same shape, where $N$ is the mini-batch size, $C$ is the number of channels, and $(H, W)$ is the spatial size of the feature map.

If the feature map was produced by a convolution, each channel is obtained by applying the same filter at every spatial location of the previous layer's feature maps, for every image in the batch. We therefore expect the statistics of each feature channel to be relatively consistent both across different images and across different locations within the same image. For this reason, spatial batch normalization computes one mean and one variance for each of the $C$ feature channels, pooling the statistics over both the mini-batch dimension $N$ and the spatial dimensions $H$ and $W$.
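As a minimal NumPy sketch (random data and arbitrary sizes chosen purely for illustration), the per-channel statistics are obtained by reducing over axes 0, 2 and 3. For comparison, applying vanilla batch normalization to the flattened `(N, C*H*W)` input would instead produce one statistic per channel *and* per spatial location:

```python
import numpy as np

rng = np.random.default_rng(0)
N, C, H, W = 4, 3, 5, 5
x = rng.normal(loc=2.0, scale=3.0, size=(N, C, H, W))

# Spatial batch norm: one mean/variance per channel, pooled over N, H and W.
mu_spatial = x.mean(axis=(0, 2, 3))    # shape (C,)
var_spatial = x.var(axis=(0, 2, 3))    # shape (C,)

# Vanilla batch norm on the flattened input: statistics over N only,
# i.e. a separate mean for every (channel, row, column) entry.
x_flat = x.reshape(N, C * H * W)
mu_vanilla = x_flat.mean(axis=0)       # shape (C*H*W,)

print(mu_spatial.shape, var_spatial.shape, mu_vanilla.shape)  # (3,) (3,) (75,)
```

Each channel's mean is therefore computed from $N \cdot H \cdot W$ values instead of only $N$, which also makes the estimate less noisy.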

To better understand the relationship and the differences between batch normalization and spatial batch normalization, consider the comparison below, taken from the [CS231n lecture slides](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture07.pdf).

<img src='https://i2dl.vc.in.tum.de/static/images/exercise_09/img3.jpeg' width=70% height=70%/>

Both variants follow the same computation rule: normalize over some dimensions, then scale and shift the result via $y = \gamma \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$, where $\epsilon$ is a small constant for numerical stability. They differ only in the dimensions over which $\mu$ and $\sigma^2$ are computed, because images are stored in higher-dimensional tensors.
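Written out for spatial batch normalization, with $c$ indexing channels and the averages taken over the mini-batch and both spatial dimensions, the statistics and the output are as follows. Here $\epsilon$ denotes the usual small constant for numerical stability, and $\sqrt{\sigma_c^2 + \epsilon}$ plays the role of the per-channel standard deviation; $\gamma_c$ and $\beta_c$ are the learnable scale and shift, one pair per channel:

$$
\mu_c = \frac{1}{NHW} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} x_{nchw}, \qquad
\sigma_c^2 = \frac{1}{NHW} \sum_{n=1}^{N} \sum_{h=1}^{H} \sum_{w=1}^{W} \left(x_{nchw} - \mu_c\right)^2,
$$

$$
y_{nchw} = \gamma_c \, \frac{x_{nchw} - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}} + \beta_c .
$$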

## (Optional) Mount folder in Colab

Remove the surrounding triple quotes in the following cell to mount your Google Drive if you are running the notebook in Google Colab:

In [None]:
# Use the following lines if you want to use Google Colab
# We presume you created a folder "i2dl" within your main drive folder, and put the exercise there.
# NOTE: terminate all other colab sessions that use GPU!
# NOTE 2: Make sure the correct exercise folder (e.g. exercise_09) is given.

"""
from google.colab import drive
import os

gdrive_path='/content/gdrive/MyDrive/i2dl/exercise_09'

# This will mount your google drive under 'MyDrive'
drive.mount('/content/gdrive', force_remount=True)

# In order to access the files in this notebook we have to navigate to the correct folder
os.chdir(gdrive_path)

# Check manually if all files are present
print(sorted(os.listdir()))
"""


## Set up PyTorch environment in colab
- (OPTIONAL) Enable GPU via Runtime --> Change runtime type --> GPU
- Uncomment the following cell if you are using the notebook in Google Colab:

In [None]:
############ COLAB ############
# You can simply use the versions preinstalled in Colab. There is no need to reinstall PyTorch.

# 2. Implementation

## 2.1 Imports

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import os
import shutil
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from exercise_code.layers import (
    spatial_batchnorm_forward, 
    spatial_batchnorm_backward,
)
from exercise_code.tests.gradient_check import (
    eval_numerical_gradient_array,
    eval_numerical_gradient,
    rel_error,
)
from exercise_code.tests.spatial_batchnorm_tests import (
    test_spatial_batchnorm_forward,
    test_spatial_batchnorm_backward,
)

from exercise_code.networks.SpatialBatchNormModel import (
    SimpleNetwork,
    SpatialBatchNormNetwork,
)

%load_ext autoreload
%autoreload 2

# Suppress cluttering warnings in solutions
import warnings
warnings.filterwarnings('ignore')

os.environ['KMP_DUPLICATE_LIB_OK']='True' # To prevent the kernel from dying.

<div class="alert alert-warning">
    <h3>Note: Google Colab</h3>
    <p>
If you don't have a GPU, you can run this notebook on Google Colab, which provides free GPU access. Of course, you can also run it on your CPU.
         </p>
</div>

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('You are using the following device: ', device)

## 2.2 Spatial Batch Normalization: Forward

<div class="alert alert-info">
    <h3>Task: Implement</h3>
    <p>In the file <code>exercise_code/layers.py </code>, implement the forward pass for spatial batch normalization in the function <code>spatial_batchnorm_forward</code>. Check your implementation by running the following cell:
 </p>
    <p>
    <b>Hints</b>: You can reuse the batch normalization function from the optional task <code>Batch Normalization & Dropout</code> in exercise 08. Pay attention to how the input dimensions differ between batch normalization and spatial batch normalization.
    </p>
</div>
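One way to follow the hint, sketched in NumPy under our own assumptions (training-mode statistics only, no running averages, and hypothetical function names that are *not* the signatures expected in `exercise_code/layers.py`): move the channel axis to the end, flatten to `(N*H*W, C)`, apply vanilla batch normalization, and undo the reshape.

```python
import numpy as np

def batchnorm_forward_sketch(x, gamma, beta, eps=1e-5):
    """Hypothetical vanilla batch norm (training mode) for x of shape (M, D)."""
    mu = x.mean(axis=0)
    var = x.var(axis=0)
    x_hat = (x - mu) / np.sqrt(var + eps)
    return gamma * x_hat + beta

def spatial_batchnorm_forward_sketch(x, gamma, beta, eps=1e-5):
    """Hypothetical spatial batch norm for x of shape (N, C, H, W); gamma, beta have shape (C,)."""
    N, C, H, W = x.shape
    # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): each row is one spatial position of one image.
    x_2d = x.transpose(0, 2, 3, 1).reshape(-1, C)
    out_2d = batchnorm_forward_sketch(x_2d, gamma, beta, eps)
    # Undo the reshape, then the transpose, to get back to (N, C, H, W).
    return out_2d.reshape(N, H, W, C).transpose(0, 3, 1, 2)

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=4.0, size=(2, 3, 4, 5))
out = spatial_batchnorm_forward_sketch(x, gamma=np.ones(3), beta=np.zeros(3))
print(out.mean(axis=(0, 2, 3)))  # close to 0 for every channel
print(out.std(axis=(0, 2, 3)))   # close to 1 for every channel
```

The transpose before the reshape matters: calling `x.reshape(-1, C)` directly on the `(N, C, H, W)` array would place values from different channels in the same column, so the statistics would no longer be per channel.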

In [None]:
test_spatial_batchnorm_forward()

## 2.3 Spatial Batch Normalization: Backward

Now that you have implemented the spatial batch normalization forward pass on top of the batch normalization functions, the backward pass should be straightforward: reuse the batch normalization backward function with the same handling of the tensor dimensions.

<div class="alert alert-info">
    <h3>Task: Implement</h3>
    <p>In the file <code>exercise_code/layers.py</code>, implement the backward pass for spatial batch normalization in the function <code>spatial_batchnorm_backward</code>. Run the following to check your implementation using a numeric gradient check:
 </p>
    <p>
    <b>Hints</b>: Again, you can reuse the batch normalization function defined in exercise 08 optional task <code>Batch Normalization & Dropout</code>. Take care of the tensor dimensions.
    </p>
</div>
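A matching backward sketch, again with hypothetical names and a cache layout of our own choosing (not the one used in `exercise_code/layers.py`): reshape the upstream gradient exactly like the input, apply the standard closed-form batch normalization backward pass to the `(N*H*W, C)` matrix, and reshape back. A central-difference check then compares the analytic gradient with respect to `x` against a numerical estimate.

```python
import numpy as np

def to_2d(a):
    """(N, C, H, W) -> (N*H*W, C), keeping channels in columns."""
    return a.transpose(0, 2, 3, 1).reshape(-1, a.shape[1])

def from_2d(a, shape):
    """(N*H*W, C) -> (N, C, H, W); inverse of to_2d."""
    N, C, H, W = shape
    return a.reshape(N, H, W, C).transpose(0, 3, 1, 2)

def spatial_bn_forward(x, gamma, beta, eps=1e-5):
    x2 = to_2d(x)
    mu, var = x2.mean(axis=0), x2.var(axis=0)
    x_hat = (x2 - mu) / np.sqrt(var + eps)
    out = from_2d(gamma * x_hat + beta, x.shape)
    return out, (x_hat, var, gamma, eps, x.shape)

def spatial_bn_backward(dout, cache):
    x_hat, var, gamma, eps, shape = cache
    d2 = to_2d(dout)
    m = d2.shape[0]  # number of values per channel: N*H*W
    dbeta = d2.sum(axis=0)
    dgamma = (d2 * x_hat).sum(axis=0)
    dx_hat = d2 * gamma
    # Closed-form batch norm gradient w.r.t. the input, statistics taken over m rows.
    dx2 = (m * dx_hat - dx_hat.sum(axis=0) - x_hat * (dx_hat * x_hat).sum(axis=0)) / (m * np.sqrt(var + eps))
    return from_2d(dx2, shape), dgamma, dbeta

rng = np.random.default_rng(1)
x = rng.normal(size=(2, 3, 4, 4))
gamma, beta = rng.normal(size=3), rng.normal(size=3)
dout = rng.normal(size=x.shape)

_, cache = spatial_bn_forward(x, gamma, beta)
dx, dgamma, dbeta = spatial_bn_backward(dout, cache)

# Central differences on the scalar function f(x) = sum(out * dout).
h = 1e-5
dx_num = np.zeros_like(x)
for idx in np.ndindex(x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += h
    x_minus[idx] -= h
    f_plus = (spatial_bn_forward(x_plus, gamma, beta)[0] * dout).sum()
    f_minus = (spatial_bn_forward(x_minus, gamma, beta)[0] * dout).sum()
    dx_num[idx] = (f_plus - f_minus) / (2 * h)

print(np.max(np.abs(dx - dx_num)))  # should be tiny
```

Because the reshaping is a pure permutation of entries, the gradient simply flows back through the same permutation; all the actual math lives in the vanilla batch normalization backward formula.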


In [None]:
test_spatial_batchnorm_backward()

## 2.4 Spatial Batch Normalization in PyTorch

Similar to the batch normalization task from the previous exercise, we now run a few experiments in PyTorch to see the effect of spatial batch normalization.

### 2.4.1 Setup TensorBoard

By now you have some experience with TensorBoard, a handy tool for tuning your network and monitoring the training process. Throughout this notebook, feel free to add further logs or visualizations to your TensorBoard!

In [None]:
# define directory to save logs
logdir = './logs'

if os.path.exists(logdir):
    # Delete the logs from previous runs so that each run starts with a clean directory.
    shutil.rmtree(logdir)
    
os.mkdir(logdir)

### 2.4.2 Setup the training pipeline for the network

In [None]:
def create_tqdm_bar(iterable, desc):
    return tqdm(enumerate(iterable), total=len(iterable), ncols=150, desc=desc)

def train_model(model, train_loader, val_loader, loss_func, tb_logger, epochs=10, name='Model'):
    
    model = model.to(device)
    optimizer = model.configure_optimizers()
    # Decay the learning rate by a factor of 0.7 roughly five times over the whole run.
    # StepLR is stepped once per iteration here, so step_size is an (integer) number of iterations.
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=max(1, epochs * len(train_loader) // 5), gamma=0.7)
    validation_loss = 0
    for epoch in range(epochs):
        
        # Train: batch norm layers normalize with mini-batch statistics and update their running averages.
        model.train()
        training_loop = create_tqdm_bar(train_loader, desc=f'Training Epoch [{epoch + 1}/{epochs}]')
        training_loss = 0
        for train_iteration, batch in training_loop:
            optimizer.zero_grad()
            loss = model.training_step(batch, loss_func)
            loss.backward()
            optimizer.step()
            scheduler.step()
            
            training_loss += loss.item()

            # Update the progress bar.
            training_loop.set_postfix(train_loss = "{:.8f}".format(training_loss / (train_iteration + 1)), val_loss = "{:.8f}".format(validation_loss))

            # Update the tensorboard logger.
            tb_logger.add_scalar(f'{name}/train_loss', loss.item(), epoch * len(train_loader) + train_iteration)

        # Validation: batch norm layers switch to the running averages collected during training.
        model.eval()
        val_loop = create_tqdm_bar(val_loader, desc=f'Validation Epoch [{epoch + 1}/{epochs}]')
        validation_loss = 0
        with torch.no_grad():
            for val_iteration, batch in val_loop:
                loss = model.validation_step(batch, loss_func)
                validation_loss += loss.item()

                # Update the progress bar.
                val_loop.set_postfix(val_loss = "{:.8f}".format(validation_loss / (val_iteration + 1)))

                # Update the tensorboard logger.
                tb_logger.add_scalar(f'{name}/val_loss', validation_loss / (val_iteration + 1), epoch * len(val_loader) + val_iteration)
        # This value is for the progress bar of the training loop.
        validation_loss /= len(val_loader)

In [None]:
# A few hyperparameters before we start things off
batch_size = 50

epochs = 5
learning_rate = 0.0005

### 2.4.3 Train a model without Spatial Batch Normalization

<div class="alert alert-success">
    <h3>Task: Check Code</h3>
    <p>We have already implemented a <code>SimpleNetwork</code> without spatial batch normalization in <code>exercise_code/networks/SpatialBatchNormModel.py</code>. Feel free to check it out and play around with the parameters. The cell below sets up a short training run for this network.
 </p>
</div>

In [None]:
model = SimpleNetwork(batch_size=batch_size, learning_rate=learning_rate).to(device)

path = os.path.join('logs', 'Spatial_BN_model_logs')

if os.path.exists(path):
    shutil.rmtree(path)
path = os.path.join(path, 'simple-model')
tb_logger = SummaryWriter(path)

# Train the classifier.
train_dl, val_dl, _ = model.prepare_data()

loss_func = F.cross_entropy # Cross-entropy loss, since this is a classification task.

train_model(model, train_dl, val_dl, loss_func, tb_logger, epochs=epochs, name='SpatialBatchNorm')

### 2.4.4 Train a model with Spatial Batch Normalization

<div class="alert alert-success">
    <h3>Task: Check Code</h3>
    <p> Now that we have seen how the simple network performs, let us look at a model that actually uses spatial batch normalization. Again, we provide such a model, <code>SpatialBatchNormNetwork</code>, in <code>exercise_code/networks/SpatialBatchNormModel.py</code>. As before, feel free to check it out and play around with the parameters. The cell below sets up a short training run for this model. 
 </p>
</div>

In [None]:
model = SpatialBatchNormNetwork(batch_size=batch_size, learning_rate=learning_rate)

path = os.path.join('logs', 'Spatial_BN_model_logs')
path = os.path.join(path, 'SBN-model')
tb_logger = SummaryWriter(path)

# Train the classifier.
train_dl, val_dl, _ = model.prepare_data()

loss_func = F.cross_entropy # Cross-entropy loss, since this is a classification task.
train_model(model, train_dl, val_dl, loss_func, tb_logger, epochs=epochs, name='SpatialBatchNorm')

### 2.4.5 Observations
Finally, launch TensorBoard to compare the performance of both networks:

In [None]:
################# COLAB ONLY #################
# %load_ext tensorboard
# %tensorboard --logdir=./ --port 6006

# If possible, use the command line instead, which tends to cause fewer problems. From the working directory, run: tensorboard --logdir=./ --port 6006

#### Reminder: Run Tensorboard

<div class="alert alert-success">
    <p> 
    There are a few ways to run TensorBoard (make sure that version <code>2.8.0</code> is installed): <br><br>
    1) (In Jupyter) Use the cell above. <br>
    2) (In VS Code) Press <code> Ctrl + Shift + p </code> and look for <code> Tensorboard </code> <br>
    3) Run <code> tensorboard --logdir=./ </code> in a terminal.
    </p>
</div>

Recall the comparison with and without batch normalization from the previous exercise: the difference here should be very similar, i.e., the model using spatial batch normalization can reach a lower validation loss and a higher validation accuracy. This simple experiment suggests that spatial batch normalization is helpful for convolutional networks.