In [None]:
# HTML is exposed by IPython.display; importing it from IPython.core.display is deprecated
from IPython.display import Image, HTML

In [None]:
import torch
import helper
import pywt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

## Introduction

The purpose of this notebook is to implement the architecture presented in the following article.

*Wavelet Convolutional Neural Networks*, Fujieda et al. 2018, (link [here](https://arxiv.org/abs/1805.08620))

In short, the wavelet CNN is a convolutional neural network with two main characteristics:
- The first is that the data entering the network are not raw images but **their wavelet decomposition**. The first step is always to decompose the input images by passing them through a low-pass filter and a high-pass filter.
    - Each time, before being passed to the convolution blocks, the low-pass and high-pass outputs are concatenated channel-wise.
    - The low-pass output is recursively decomposed before each level.
    - After the first level, the data is again concatenated, this time with one more component: the shortcut projection.
- The second is the use of projection shortcuts.
    - These shortcuts are added to the channel-wise concatenation before each level of convolution.
    - The deeper we go into the network, the more convolutions the filtered data has to pass through.
   

From the architecture described above, it is possible to decompose the network into basic building blocks. Following an object-oriented approach, a class can be built for each basic block.
The building blocks would be as follows:

- **The decomposition block**, data decomposition into low-pass filters and high-pass filters
- **The shortcut block**, shortcuts used in the process
- **The concatenation block**, channel-wise concatenation that occurs before the convolutions
- **The convolution block**, a two-step convolution that occurs at each level

The purpose of the notebook is to work on each of these blocks and put them together to create the wavelet-CNN.

# First and foremost, reading the data

Here is the link to download the data: https://www.csc.kth.se/cvap/databases/kth-tips/download.html

Even though it is not very important...

The data used in this example is the same as in the article: KTH-TIPS2-b. The authors load the data in the usual way:
- The network takes regular images of size 224x224.
- The training images are first scaled to 256x256,
- then randomly cropped to 224x224,
- then randomly flipped.

It is done like below:

In [None]:
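data_path = "../wavelets/kth-tips2-b_col_200x200/"
# scale (smaller side to 256), take a random 224x224 crop, then flip at random
transform = transforms.Compose([transforms.Resize(256), 
                               transforms.RandomCrop(224),
                               transforms.RandomHorizontalFlip(),
                               transforms.ToTensor()])
dataset = datasets.ImageFolder(data_path, transform=transform)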

In [None]:
# DataLoaders
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, 
                                         shuffle=True)

In [None]:
images, labels = next(iter(dataloader))

The function `imshow` below was taken from a Stack Overflow discussion because it makes it possible to display the images (link [here](https://stackoverflow.com/questions/53570181/error-in-importing-libraries-of-helper-though-helper-is-installed)).

In [None]:
def imshow(image, ax=None, title=None, normalize=True):
    """Imshow for Tensor."""
    if ax is None:
        fig, ax = plt.subplots()
    image = image.numpy().transpose((1, 2, 0))

    if normalize:
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = std * image + mean
        image = np.clip(image, 0, 1)

    ax.imshow(image)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.spines['bottom'].set_visible(False)
    ax.tick_params(axis='both', length=0)
    ax.set_xticklabels('')
    ax.set_yticklabels('')

    return ax

In [None]:
# the images were not normalized when loaded, so the de-normalization step is skipped
imshow(images[0], normalize=False);

## Decomposition block

### Convolution and convolutional networks in a nutshell

A brief explanation of convolutions helps to justify why wavelets can be relevant in image analysis. The authors motivate the use of wavelets by the limits of the regular convolution operation in a CNN. In the CNN literature, a convolution operation is defined as follows (Goodfellow et al., 2016, link [here](https://www.deeplearningbook.org/contents/convnets.html)):

$$
s(t) = (x * w)(t) = \sum_{a=-\infty}^{+\infty}x(a)w(t -a)
$$

The first argument $x$ is called the **input** and the second one $w$ is the **kernel**. The output $s(t)$ is called the **feature map**. In a CNN, each convolution layer generally combines the operation above with several others, very often a normalization, an activation function and a **pooling operation**.
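
As a small illustration of the formula above, the sketch below computes a 1D discrete convolution by hand for finite signals and compares it with `numpy.convolve`. The function name `conv1d_full` and the toy input and kernel are assumptions made for this example only.

```python
import numpy as np

def conv1d_full(x, w):
    """Discrete convolution s(t) = sum_a x(a) * w(t - a) for finite signals."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    s = np.zeros(len(x) + len(w) - 1)
    for t in range(len(s)):
        for a in range(len(x)):
            # w(t - a) is treated as zero outside the support of the kernel
            if 0 <= t - a < len(w):
                s[t] += x[a] * w[t - a]
    return s

x = [1.0, 2.0, 3.0]   # input
w = [0.0, 1.0, 0.5]   # kernel
feature_map = conv1d_full(x, w)
print(feature_map)
print(np.allclose(feature_map, np.convolve(x, w)))
```

Note that most deep learning libraries, PyTorch included, actually compute a cross-correlation (the kernel is not flipped) under the name "convolution"; since the kernel is learned, the distinction rarely matters in practice.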

### A word on multiresolution analysis

Before getting into the technical aspects, let us briefly explain how multiresolution analysis works. Multiresolution analysis (MRA) was introduced by Stéphane Mallat in 1989 in the paper *A Theory for Multiresolution Signal Decomposition: The Wavelet Representation* (link [here](https://www.di.ens.fr/~mallat/papiers/MallatTheory89.pdf)). The multiresolution approach consists in decomposing images into approximation and detail coefficients at resolutions $2^j$ (the standard choice, and the one used in Fujieda et al.). In Mallat's paper, the signal of an image can be decomposed using a *scaling function* $\phi(x)$ and a *wavelet function* $\psi(x)$. The scaling function is used to compute the approximation signal, whereas the wavelet function computes the detail signal. It is worth mentioning that the approximation can be seen as a low-pass filtered version of the image signal, whereas the detail signal corresponds to a high-pass filtered version.
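
To make the approximation/detail split concrete, here is a minimal sketch using `pywt.dwt2` with the Haar wavelet on a toy 8x8 array (the array itself is an arbitrary assumption). A constant image contains no detail, so its three detail sub-bands should be zero and all the information ends up in the approximation.

```python
import numpy as np
import pywt

# toy "image" with constant intensity: nothing for the high-pass filters to pick up
image = np.ones((8, 8))

# one level of 2D decomposition: approximation + (horizontal, vertical, diagonal) details
approx, (horiz, vert, diag) = pywt.dwt2(image, "haar")

# each sub-band has half the height and width of the input
print(approx.shape, horiz.shape, vert.shape, diag.shape)
print(np.allclose(horiz, 0), np.allclose(vert, 0), np.allclose(diag, 0))
```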


### Wavelets in a CNN
Fujieda et al. start from the basic convolution layer of a CNN, but they note that CNNs can be seen as a limited form of multiresolution analysis because they do not take the high-pass filtered information into account. They write the convolution as follows:
$$
x_{l, t+1} = (x_{l, t} * k) \downarrow 2
$$

Where $x_{l, t}$ denotes the low-frequency component of $x$ at level $t$ (the subscript $l$ stands for low-pass), and $\downarrow 2$ denotes a downsampling of the output to half the original size.

In the equation above, the authors regard a CNN as a limited form of multiresolution analysis. Using a wavelet decomposition of an image makes it possible to take into account the whole spectrum of the multiresolution analysis and the hierarchical decomposition of the image. The convolution then becomes:

$$
x_{l, t+1} = (x_{l, t} * k_{l, t}) \downarrow 2
$$
$$
x_{h, t+1} = (x_{l, t} * k_{h, t}) \downarrow 2
$$

In Mallat's approach, $k_{l, t}$ corresponds to the scaling function $\phi(x)$ (approximation coefficients) and $k_{h, t}$ to the wavelet function $\psi(x)$ (detail coefficients).
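
The low-pass and high-pass branches can be sketched in 1D with NumPy. This example assumes the Haar filters for $k_{l,t}$ and $k_{h,t}$; the names `k_low`, `k_high` and `analysis_step` are hypothetical and chosen for illustration only.

```python
import numpy as np

# Haar analysis filters (orthonormal): averaging is low-pass, differencing is high-pass
k_low = np.array([1.0, 1.0]) / np.sqrt(2)
k_high = np.array([-1.0, 1.0]) / np.sqrt(2)

def analysis_step(x, kernel):
    """Convolve with the kernel, then keep every second sample (downsampling by 2)."""
    return np.convolve(x, kernel)[1::2]

x = np.array([1.0, 2.0, 3.0, 4.0])
x_low = analysis_step(x, k_low)     # approximation coefficients
x_high = analysis_step(x, k_high)   # detail coefficients
print(x_low, x_high)

# only the low-pass output is decomposed again at the next level
x_low2 = analysis_step(x_low, k_low)
print(x_low2)
```

Each step halves the length of the signal, and the pair (`x_low`, `x_high`) together keeps as many coefficients as the input had samples.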

### Implementation with pywt
This is considered to be the most important block, since it is the main innovation brought by the architecture. In parallel to the convolution filters, a multi-level decomposition of the image takes place: Figure 1 of the paper shows a 4-level decomposition. In fact, as mentioned earlier, even the convolution layers are trained on the wavelet transform of the image.
- At each step, only the low-pass output is decomposed further. In multiresolution analysis, this corresponds to the approximation coefficients being recursively decomposed at each level (with the size halved at each level, $2^{-j}$).

- The `pywt` package can perform this decomposition. `pywt.wavedec2` computes the multi-level 2D decomposition, and `pywt.coeffs_to_array` packs the resulting coefficients into a single array (tiled spatially, which is not the same as a channel-wise concatenation).
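
A minimal sketch of these two `pywt` functions on a toy 8x8 array (the input and the choice of two levels are assumptions for the example). The `slices` object returned by `pywt.coeffs_to_array` records where each sub-band sits in the packed array, so individual sub-bands can be recovered by indexing.

```python
import numpy as np
import pywt

image = np.arange(64, dtype=float).reshape(8, 8)  # toy 8x8 "image"

# two-level decomposition: [cA2, (cH2, cV2, cD2), (cH1, cV1, cD1)]
coeffs = pywt.wavedec2(image, "haar", mode="periodization", level=2)

# pack all sub-bands into one array; `slices` records where each one lives
arr, slices = pywt.coeffs_to_array(coeffs)

print(arr.shape)
approx = arr[slices[0]]          # coarsest approximation coefficients
diag_2 = arr[slices[1]["dd"]]    # diagonal details at the coarsest level
print(approx.shape, diag_2.shape)
```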

## The shortcut block

The shortcut block is taken directly from the paper *Deep Residual Learning for Image Recognition*, He et al., 2016 (link [here](https://arxiv.org/pdf/1512.03385.pdf)). Basically, it involves an identity mapping of the input, which is then added to the output of the layer. Consider a mapping $\mathcal{F}(x)$ that corresponds, for example, to the feature map of a convolution layer. Before entering the next layer, the input $x$ is added to the mapping $\mathcal{F}(x)$, which gives:
$$\mathcal{F}(x) + x$$

In He et al., $\mathcal{F}(x)$ is the *residual mapping* and the addition of $x$ is an identity *shortcut connection*; they note (interestingly) that identity shortcut connections add neither extra parameters nor computational complexity. In our case, however, identity shortcuts are not suitable because the dimensions differ between the input and the output of a layer. This is where **projection shortcuts** come into play. He et al. write the projection shortcut as:
$$
\mathbf{y} = \mathcal{F}(\mathbf{x}, \{W_i\}) + W_s \mathbf{x}
$$

Where $\mathcal{F}(\mathbf{x}, \{W_i\})$ is the residual mapping to be learned and $W_s$ is a linear projection of $\mathbf{x}$ used to match the dimensions of the residual mapping. Fujieda et al. indicate that the projection shortcuts are performed using a 1x1 convolutional kernel in order to increase the dimensions of the input. All shortcuts are projection shortcuts, but the first one differs from the rest:

- The first shortcut projection is made with the high-pass output of the first-level decomposition. It is concatenated with the other data before entering the second layer. *(this point is still uncertain)*
- For all the other levels, the whole decomposition is projected and concatenated.
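
The projection shortcut formula from He et al. can be sketched in PyTorch as follows. This illustrates the residual formulation only; whether the wavelet CNN adds or concatenates the projected tensor is left open in this notebook. The names `residual_branch` and `projection`, and all the sizes, are assumptions made for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(1, 3, 32, 32)

# F(x, {W_i}): a branch that changes both the channel count and the spatial size
residual_branch = nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1)
# W_s: a 1x1 convolution that projects x to the same shape as the branch output
projection = nn.Conv2d(3, 8, kernel_size=1, stride=2)

# the two shapes match, so the element-wise addition is well defined
y = residual_branch(x) + projection(x)
print(y.shape)
```

An identity shortcut would fail here because `x` has 3 channels and a 32x32 spatial size, whereas the branch output has 8 channels and a 16x16 spatial size.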

Below, a `ProjectionShortcut` class is developed for later use in the wavelet CNN. Creating a class just for a projection shortcut may be a bit *overkill*, but for lack of PyTorch mastery it seems to be the best solution...

In [None]:
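class ProjectionShortcut(nn.Module):
    """Projection shortcut: a 1x1 convolution that changes the number of channels."""
    
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        
        # a 1x1 kernel needs no padding; padding=1 would add a border of zeros
        # and change the spatial size of the output
        self.conv = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, 
                              kernel_size=1, stride=self.stride, padding=0)

    def forward(self, x):
        return self.conv(x)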

In [None]:
x = images[0]

In [None]:
# adding a dimension to x before convolution because Conv2d accepts batches of data
x = x.unsqueeze(0)

In [None]:
ps = ProjectionShortcut(3, 64)
out = ps(x)
print(out.shape)  # the projection changes the number of channels
print(ps)

## Concatenation block

Despite the use of shortcut connections, a vanishing gradient problem can still appear. To address it, the authors rely on a method taken from *DenseNet*, an architecture presented in *Densely Connected Convolutional Networks*, Huang et al., 2018 (link [here](https://arxiv.org/pdf/1608.06993.pdf)). The operation itself is very simple: a **channel-wise concatenation** of several tensors. In the particular case of the Wavelet-CNN, the output of the previous layer is concatenated with a feature map of the wavelet decomposition of the image at the corresponding level.

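### Concatenating PyTorch tensors

PyTorch provides functions for joining tensors, such as `torch.cat` and `torch.stack`. For a channel-wise concatenation, `torch.cat` along dimension 1 is the appropriate one, since `torch.stack` creates a new dimension instead.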
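
A short sketch of the difference between the two functions, using small made-up tensors (the names and sizes are assumptions). Channel-wise concatenation only requires the batch and spatial dimensions to agree; the channel counts may differ.

```python
import torch

prev_output = torch.zeros(2, 8, 4, 4)    # (batch, channels, height, width)
wavelet_maps = torch.ones(2, 12, 4, 4)   # same batch and spatial size, other channel count

# channel-wise concatenation: the channel counts add up (8 + 12)
merged = torch.cat([prev_output, wavelet_maps], dim=1)
print(merged.shape)

# torch.stack requires identical shapes and inserts a new axis instead
stacked = torch.stack([prev_output, prev_output], dim=1)
print(stacked.shape)
```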

## Convolution block
- The convolution block corresponds to the convolution filters applied at each level. The example in the paper uses a four-level decomposition of the input image, hence four blocks.
- It is always the same block, only with different *in_channels* and *out_channels* sizes.
- So, as a first step, a basic `ConvolutionBlock` class is built before using it in the development of the network.

In [None]:
class ConvolutionBlock(nn.Module):
    
    def __init__(self, in_channels, out_channels):
        super().__init__()
        
        ### Convolution layer
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        self.block = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, 
                      out_channels=self.out_channels, 
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=self.out_channels, 
                      out_channels=self.out_channels, 
                      kernel_size=3, stride=2, padding=1)
        )
        
    def forward(self, x):
        return self.block(x)

In [None]:
dummy = torch.ones((1, 3, 224, 224))
block = ConvolutionBlock(3, 64)
block(dummy)
print(block)
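
To keep track of spatial sizes when stacking such blocks, the standard output-size formula for a convolution without dilation, $\lfloor (n + 2p - k) / s \rfloor + 1$, can be wrapped in a small helper. The function name `conv_output_size` is hypothetical; the kernel sizes, strides and paddings below are those of the two convolutions in `ConvolutionBlock`.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution without dilation."""
    return (size + 2 * padding - kernel_size) // stride + 1

# first 3x3 convolution of the block: stride 1 and padding 1 keep the size
after_first = conv_output_size(224, kernel_size=3, stride=1, padding=1)
# second 3x3 convolution: stride 2 halves it
after_second = conv_output_size(after_first, kernel_size=3, stride=2, padding=1)
print(after_first, after_second)
```

Each block therefore halves the height and width, which mirrors the halving of the wavelet decomposition at each level.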

# Bringing things together

Below is *(an attempt at)* a version of the Wavelet-CNN: a class `WaveletCnn` is created that brings together the different components of the network.

It is just a basic example aiming to reproduce the network as shown in Fujieda et al. Their model in Fig. 1 of the paper consists of a CNN with a 4-level decomposition of the input image, hence four corresponding convolution layers, with a global average pooling layer and a fully connected layer at the end.

- The method `img_decomposition` within the class should perform the image decomposition and create an iterable containing the different levels of the multiresolution analysis.
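
One possible shape for such a decomposition is sketched below. It is not the `img_decomposition` method of the class: `haar_pyramid` is a hypothetical helper, and it assumes that the four sub-bands of each level are concatenated channel-wise and that only the low-pass output is decomposed again.

```python
import numpy as np
import pywt
import torch

def haar_pyramid(img, n_level):
    """Return one tensor per level, holding that level's sub-bands stacked on the channel axis."""
    low = img.detach().cpu().numpy()   # pywt works on NumPy arrays
    levels = []
    for _ in range(n_level):
        # decompose the spatial axes of the current low-pass output
        low, (horiz, vert, diag) = pywt.dwt2(low, "haar", mode="periodization", axes=(-2, -1))
        bands = np.concatenate([low, horiz, vert, diag], axis=1)
        levels.append(torch.from_numpy(bands).to(img.dtype))
    return levels

batch = torch.randn(2, 3, 32, 32)      # toy batch of RGB images
levels = haar_pyramid(batch, n_level=3)
for level in levels:
    print(level.shape)
```

With this layout each level has four times the input channel count (12 for RGB) and half the spatial size of the previous level. Note that the round trip through NumPy detaches the result from the autograd graph, which is acceptable here because the wavelet filters are fixed rather than learned.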

**REMINDER**: The code below is unfinished (just like the whole project); its only purpose is to show what the architecture suggested by the authors would look like. There are still many problems to fix, including (but not limited to):
- The handling of the image decomposition within the class
- How the projection shortcuts work and where the addition is performed (if there is one)
- Making the channel counts agree after each concatenation (for example, concatenating four 64-channel tensors gives 256 channels, whereas `block2` expects 64)
- How to handle global average pooling
- Perhaps a special class should be created to handle the wavelet decomposition, which would allow us to extract the relevant arrays at the corresponding level
- ...
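
Regarding global average pooling, one common way to handle it in PyTorch is sketched below. The feature tensor is a made-up stand-in for the output of the last block, and the layer sizes (512 features, 11 classes) simply follow the fully connected layer of the class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
features = torch.randn(2, 512, 14, 14)   # hypothetical output of the last block
fc = nn.Linear(512, 11)

pooled = F.adaptive_avg_pool2d(features, (1, 1))   # one value per channel: (2, 512, 1, 1)
flat = torch.flatten(pooled, 1)                    # (2, 512), the shape nn.Linear expects
log_probs = F.log_softmax(fc(flat), dim=1)
print(log_probs.shape)
```

The flattening step matters: without it, `nn.Linear` would be applied to a trailing dimension of size 1 instead of the 512 pooled features.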

In [None]:
class WaveletCnn(nn.Module):
    
    def __init__(self, in_channels):
        
        super().__init__()
        
        # Higher part of the projection shortcuts, with stride 2
        self.ps1 = ProjectionShortcut(3, 64, 2)
        self.ps2 = ProjectionShortcut(64, 128, 2)
        self.ps3 = ProjectionShortcut(128, 256, 2)
        self.ps4 = ProjectionShortcut(256, 512, 2)
        
        # Creating four convolution blocks
        
        self.block1 = ConvolutionBlock(3, 64)
        self.block2 = ConvolutionBlock(64, 128)
        self.block3 = ConvolutionBlock(128, 256)
        self.block4 = ConvolutionBlock(256, 512)
        
        # Transformations for decomposition
        self.convdec2 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3)
        
        self.convdec3 = nn.Sequential(
            nn.Conv2d(in_channels=3, 
                      out_channels=64, 
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, 
                      out_channels=128, 
                      kernel_size=3, padding=1))
        
        self.convdec4 = nn.Sequential(
            nn.Conv2d(in_channels=3, 
                      out_channels=64, 
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, 
                      out_channels=128, 
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, 
                      out_channels=256, 
                      kernel_size=3, padding=1),
            nn.BatchNorm2d(256),  # must match the 256 output channels of the previous convolution
            nn.ReLU())
        
        # Projection shortcuts for the wavelet decomposition
        self.psdec2 = ProjectionShortcut(3, 64)
        self.psdec3 = ProjectionShortcut(3, 128)
        self.psdec4 = ProjectionShortcut(3, 256)
        
        # fully connected layer (named `fc` because forward() calls self.fc)
        self.fc = nn.Linear(512, 11)

        
          
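    def img_decomposition(self, img, n_level):
        """Function performing the multiresolution decomposition at each of the n_level 
        levels and returning a dictionary with the corresponding arrays and slices"""
        
        def decomposition(img, level):
            # pywt works on NumPy arrays and transforms the last two (spatial) axes by default
            c = pywt.wavedec2(img.detach().cpu().numpy(), 'haar',
                              mode='periodization', level=level)
            arr, slices = pywt.coeffs_to_array(c, axes=(-2, -1))
            # back to a tensor so that the result can be fed to the convolution layers
            return torch.from_numpy(arr).to(img.dtype), slices
        
        # levels are 1-based: "level1" holds the one-level decomposition
        decomposition_dict = {"level"+str(l+1): decomposition(img, l+1) for l in range(n_level)}
        
        return decomposition_dict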

    
    def forward(self, x):
        
        dec_dict = self.img_decomposition(x, 4)
        # fetching the detail coefficient of the first level decomposition
        d1 = dec_dict["level1"][0][dec_dict["level1"][1][1]["dd"]]
        ps1 = self.ps1(d1)
        arr1 = dec_dict["level1"][0]
        
        # First level k1
        x = self.block1(arr1)
        psdec2 = self.psdec2(dec_dict["level2"][0])
        convdec2 = self.convdec2(dec_dict["level2"][0])
        x = torch.cat([x, ps1, psdec2, convdec2], 1)
        ps2 = self.ps2(x)
        
        # Second layer k2
        x = self.block2(x)
        psdec3 = self.psdec3(dec_dict["level3"][0])
        convdec3 = self.convdec3(dec_dict["level3"][0])
        x = torch.cat([x, ps2, psdec3, convdec3], 1)
        ps3 = self.ps3(x)
        
        # Third layer k3
        x = self.block3(x)
        psdec4 = self.psdec4(dec_dict["level4"][0])
        convdec4 = self.convdec4(dec_dict["level4"][0])
        x = torch.cat([x, ps3, psdec4, convdec4], 1)
        ps4 = self.ps4(x)
        
        # Fourth layer k4 (final one before average pooling)
        x = self.block4(x)
        x = torch.cat([x, ps4], 1)
        
        # Global average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        
        # flatten (N, C, 1, 1) to (N, C) before the fully connected layer
        x = self.fc(torch.flatten(x, 1))
        
        # returning the prediction
        output = F.log_softmax(x, dim=1)
        return output