## Introduction
#### Variational Autoencoder for Deep Learning of Images, Labels and Captions

This notebook aims to document the process of implementing [this paper](https://proceedings.neurips.cc/paper/2016/file/eb86d510361fc23b59f18c1bc9802cc6-Paper.pdf) from scratch.
I used [this article](https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73) to build a better intuition for what variational autoencoders do. I recently learned about autoencoders in my dimensionality reduction class in CSC311 offered by UofT, but it only gave us a very brief introduction to them. Nevertheless, the intuition behind the variational autoencoder makes sense to me. 

#### Convolution Review

I was quite rusty on how traditional convolutional networks work, so I had to do a quick review. [This article](https://towardsdatascience.com/a-comprehensive-guide-to-convolutional-neural-networks-the-eli5-way-3bd2b1164a53) was a great help in getting me back up to speed. 

Diving deep into "convolution" and "cross-correlation" led into a lot of signal processing that I am unfamiliar with. I feel that all I need to know is that convolution is cross-correlation with a flipped kernel. PyTorch performs cross-correlation, so I won't have to flip kernels here. 
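
To make the "flipped kernel" statement concrete, here is a minimal NumPy sketch. The helpers `correlate2d_valid` and `convolve2d_valid` are names I made up for illustration (they are not library functions), and they only cover the unpadded "valid" region. The two operations differ for an asymmetric kernel and agree for a symmetric one.

```python
import numpy as np

def correlate2d_valid(image, kernel):
    """Slide the kernel over the image without flipping it (valid region only)."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def convolve2d_valid(image, kernel):
    """True convolution: cross-correlation with the kernel flipped on both axes."""
    return correlate2d_valid(image, kernel[::-1, ::-1])

image = np.arange(16, dtype=float).reshape(4, 4)
asymmetric = np.array([[1.0, 2.0], [3.0, 4.0]])
symmetric = np.ones((2, 2))

corr = correlate2d_valid(image, asymmetric)
conv = convolve2d_valid(image, asymmetric)
print(corr[0, 0], conv[0, 0])  # 34.0 16.0
```

Since the kernels in this notebook are learned, the flip only changes which orientation the network ends up learning, which is why it can be ignored here.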

In [1]:
import torch.nn as nn 
import torch.nn.functional as F
import torch
import numpy as np
from scipy import ndimage
import itertools
import torch.nn.init as init
import math
from torch.distributions.multivariate_normal import MultivariateNormal

In [2]:
a_test = np.array([[[1, 2, 0, 0],
               [5, 3, 0, 4],
               [0, 0, 0, 7],
               [9, 3, 0, 0]], 
              
              [[7, 2, 2, 0],
               [4, 3, 0, 1],
               [0, 5, 0, 0],
               [0, 2, 1, 1]]])

k_test = np.array([[[1,1,1],[1,1,0],[1,0,0]], [[1,1,1],[1,0,1],[1,1,1]]])

In [3]:
ndimage.correlate(a_test[0, ...], k_test[0, ...], mode='constant', cval=0) + \
ndimage.correlate(a_test[1, ...], k_test[1, ...], mode='constant', cval=0)

array([[10, 24, 11,  3],
       [25, 31, 18,  6],
       [22, 27, 23, 14],
       [16, 18, 18,  8]])

In [4]:
tensor_a = torch.FloatTensor(a_test[np.newaxis, ...])
tensor_k = torch.FloatTensor(k_test[np.newaxis, ...])

In [7]:
F.conv2d(tensor_a, tensor_k, padding=1)

tensor([[[[10., 24., 11.,  3.],
          [25., 31., 18.,  6.],
          [22., 27., 23., 14.],
          [16., 18., 18.,  8.]]]])

### Image decoder: DDGM

First I have to create the operation defined in (1) and (3).  

For the convolution, it wasn't clear to me what the padding and stride were, so I assumed that we stick to CNN conventions and make the output shape the same as the input shape. 
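
As a quick check on that assumption, here is a small sketch of the usual output-size formula for a zero-padded convolution without dilation, `(in + 2 * padding - kernel) // stride + 1`. The helper names `conv_output_size` and `same_padding` are my own. With stride 1 and padding `(kernel_size - 1) // 2`, odd kernel sizes keep the input size, while an even kernel size comes out one element short.

```python
def conv_output_size(input_size, kernel_size, stride=1, padding=0):
    """Spatial output size of a convolution (no dilation)."""
    return (input_size + 2 * padding - kernel_size) // stride + 1

def same_padding(kernel_size):
    """Padding used in this notebook to keep the shape for odd kernels."""
    return (kernel_size - 1) // 2

# Odd kernels: the 28-wide input stays 28 wide.
odd_sizes = {k: conv_output_size(28, k, padding=same_padding(k)) for k in (1, 3, 5, 7)}
print(odd_sizes)  # {1: 28, 3: 28, 5: 28, 7: 28}

# Even kernel: the same padding rule loses one element.
even_size = conv_output_size(28, 4, padding=same_padding(4))
print(even_size)  # 27
```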

In [8]:
def DDGM_convolve(D, S):
    """
    Perform sum_(kl) D^(kl, l) * S^(n, kl, l) as described in the paper. 
    D: 4D tensor with shape (KL-1, KL, kW, kH)
        KL-1: Number of "slices" in the previous layer, i.e. the out channels.
        KL: Number of "slices" in the current layer, i.e. the in channels.
        kW, kH: kernel width and kernel height
    
    S: 3D tensor with the shape (KL, cW, cH)
        KL: Number of 2D slices
        cW, cH: code width and code height
    """
    result = None
    kernel_size = D.shape[2]
    for i in range(D.shape[1]):
        D_kl = D[:, i, ...].unsqueeze(1)
        S_kl = S[i, ...].unsqueeze(0).unsqueeze(0)
        current_sum = F.conv2d(S_kl, D_kl, stride=1, padding=(kernel_size-1)//2)
        
        if result is None:
            result = current_sum
        else:
            result += current_sum
        
    return result.squeeze(0)

In [9]:
D = torch.rand((5, 3, 5, 5), requires_grad=True)
S = torch.rand((3, 28, 28))
DDGM_convolve(D, S).shape

torch.Size([5, 28, 28])

Then, I have to create an unpooling layer as defined in (2). I copied an excerpt of it here: 


For the stochastic unpooling, $S^{(n,k_1,1)}$ is partitioned into contiguous $p_x \times p_y$ pooling blocks (analogous to pooling blocks in CNN-based activation maps). Let $z^{(n,k_1,1)}_{i,j} \in \{0, 1\}^{p_x p_y}$ be a vector
of $p_x p_y - 1$ zeros, and a single one; $z^{(n,k_1,1)}_{i,j}$ corresponds to pooling block $(i, j)$ in $S^{(n,k_1,1)}$. 


The location of the non-zero element of $z^{(n,k_1,1)}_{i,j}$ identifies the location of the single non-zero element $i,j$
in the corresponding pooling block of $S^{(n,k_1,1)}$. 


The non-zero element in pooling block $(i,j)$ of $S^{(n,k_1,1)}$ is set to $\tilde{S}_{i,j}^{(n,k_1,2)}$, i.e., element $(i,j)$ in slice $k_1$ of $\tilde{S}^{(n,2)}$. 

Within the prior of the decoder, we impose $z^{(n,k_1,1)}_{i,j} \sim \text{Mult}(1; 1/(p_x p_y), \dots, 1/(p_x p_y))$. 

Both $\tilde{S}^{(n,2)}$ and $S^{(n,1)}$ are 3D tensors with $K_1$ 2D slices; as a result of the unpooling, the 2D slices in the sparse $S^{(n,1)}$ have $p_x p_y$ times more elements than the corresponding slices in the dense $\tilde{S}^{(n,2)}$.

I tried to illustrate what I believe is happening below:

<img src="./images/figure1.jpg" width="600">


In [10]:
def unpool(S, prob_vecs):
    """
    Performs stochastic unpooling on S, where the location of the non-zero element in each pooling block (i, j)
    in slice k is given by z^k(i, j), and z^k(i, j)
    is sampled from a multinomial distribution Mult(1, prob_vecs[k, i, j]).
    
    Initially, each prob_vec = (1/(pxpy), ... 1/(pxpy)), i.e. the prior distribution. 
    
    The length of each prob_vec must be pool_size**2. Here we are assuming that our pooling blocks
    will always be square.
    
    S: 3D tensor with the shape (KL, cW, cH)
    
    prob_vecs: 4D tensor with the shape (KL, cW, cH, pool_size**2):
        pool_size: side length of a square pooling block, so pool_size**2 = px * py
    """
    
    K, w, h = S.shape
    pool_size = int(math.sqrt(prob_vecs.shape[3])) # We are assuming that the pooling blocks are square.
    
    result = torch.zeros(K, w*pool_size, h*pool_size)
    for k in range(K):
        for block in itertools.product(range(w), range(h)):
            i, j = block
            z = torch.zeros(pool_size**2)
            idz = torch.multinomial(prob_vecs[k, i, j], 1).item()
            z[idz] = S[k, i, j]
            result[k, i*pool_size:(i+1)*pool_size, j*pool_size:(j+1)*pool_size] = z.reshape(-1, pool_size)
    
    return result
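
The nested loops above are easy to follow but slow. As a hedged aside, the same index bookkeeping can be done without loops. The sketch below uses NumPy rather than PyTorch, and `unpool_vectorized` is a hypothetical helper of mine, not part of the paper or of this notebook's model. It also swaps in inverse-CDF sampling where the code above uses `torch.multinomial`.

```python
import numpy as np

def unpool_vectorized(S, prob_vecs, rng):
    """
    S: array of shape (K, w, h)
    prob_vecs: array of shape (K, w, h, p*p) of non-negative weights
    Returns an array of shape (K, w*p, h*p) with one non-zero entry per block.
    """
    K, w, h = S.shape
    n_cells = prob_vecs.shape[-1]
    p = int(round(np.sqrt(n_cells)))  # square pooling blocks assumed

    # Sample one cell index per block by inverting the CDF of each weight vector.
    probs = prob_vecs / prob_vecs.sum(axis=-1, keepdims=True)
    cdf = np.cumsum(probs, axis=-1)
    u = rng.random((K, w, h, 1))
    idx = np.minimum((u > cdf).sum(axis=-1), n_cells - 1)

    # One-hot vectors z, scaled by the value each block should carry.
    blocks = np.eye(n_cells)[idx] * S[..., None]   # (K, w, h, p*p)
    blocks = blocks.reshape(K, w, h, p, p)         # (K, w, h, row, col)
    # Interleave block and within-block axes: (K, w, row, h, col) -> (K, w*p, h*p)
    return blocks.transpose(0, 1, 3, 2, 4).reshape(K, w * p, h * p)

rng = np.random.default_rng(0)
S = rng.random((2, 3, 3)) + 0.1   # strictly positive, so non-zeros are easy to count
prior = np.ones((2, 3, 3, 4))     # uniform weights over 2x2 blocks
out = unpool_vectorized(S, prior, rng)

per_block = out.reshape(2, 3, 2, 3, 2)  # (K, w, row, h, col)
print(out.shape)  # (2, 6, 6)
```

Summing `per_block` over its row and col axes recovers `S`, which is a handy invariant to test any unpooling implementation against.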

With these two functions, we can implement the DDGM. For the data generation layer, I used the trick described in 
[this article](https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73). 

$$X^{(n)} \sim \mathcal{N}(\tilde{S}^{(n, 1)}, \alpha_0^{-1} \textbf{I})$$

can be expressed as: 

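$$X^{(n)} = \tilde{S}^{(n, 1)} + \alpha_0^{-1/2} Z $$

where $Z \sim \mathcal{N}(\textbf{0}, \textbf{I})$, because an affine transformation of a Gaussian is again Gaussian. Note the exponent: the covariance is $\alpha_0^{-1} \textbf{I}$, so the noise is scaled by the standard deviation $\alpha_0^{-1/2}$. This allows backprop to "reach" the dictionary layers because the random draw is separated from the parameters. This is known as the "reparameterization trick".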
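
As a sanity check of the scaling, here is a scalar sketch in plain Python (no PyTorch), with `reparameterized_sample` as a hypothetical helper name. It draws from $\mathcal{N}(\mu, \alpha^{-1})$ as $\mu + \alpha^{-1/2}\epsilon$ and confirms that the empirical variance is close to $1/\alpha$. It also shows that once $\epsilon$ is held fixed, the sample is an ordinary differentiable function of $\mu$, which is what lets gradients pass through.

```python
import math
import random

def reparameterized_sample(mean, precision, eps):
    """x = mean + eps / sqrt(precision), where eps ~ N(0, 1) is drawn outside."""
    return mean + eps / math.sqrt(precision)

rng = random.Random(0)
mean, precision = 2.0, 4.0  # target distribution: N(2, 1/4)

draws = [reparameterized_sample(mean, precision, rng.gauss(0.0, 1.0))
         for _ in range(20000)]
emp_mean = sum(draws) / len(draws)
emp_var = sum((d - emp_mean) ** 2 for d in draws) / len(draws)
print(emp_mean, emp_var)  # should be close to 2.0 and 0.25

# With the noise held fixed, a central finite difference gives dx/dmean = 1.
eps, h = 0.7, 1e-6
grad_mean = (reparameterized_sample(mean + h, precision, eps)
             - reparameterized_sample(mean - h, precision, eps)) / (2 * h)
print(grad_mean)
```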

In [11]:
class DDGMDecoder(nn.Module):
    """
    A 2 layered DDGM decoder, with a stochastic unpooling layer between the two layers,
    and a final data generation layer. 
    """
    def __init__(self, K2, K1, Nc, d2_kernel, d1_kernel, iW, iH, pool_size=3):
        """
        d2 is of shape (K1, K2, kW, kH)
            K1: Number of "slices" in layer 1
            K2: Number of "slices" in layer 2
            kW, kH: kernel width, kernel height
            
        d1 is of shape(Nc, K1, kW, kH):
            Nc: Number of channels of the image (1 for grayscale, 3 for rgb)
            K1: Number of "slices" in layer 1
            kW, kH: kernel width, kernel height
            
        distribution is of shape (K1, iW, iH, pool_size**2):
            iW, iH: input (code) width, input (code) height
        """
        
        super().__init__()
        self._d2 = nn.Parameter(torch.Tensor(K1, K2, d2_kernel, d2_kernel))
        self._d1 = nn.Parameter(torch.Tensor(Nc, K1, d1_kernel, d1_kernel))
        self._pool_size = pool_size
        self._distribution = torch.ones(K1, iW, iH, pool_size**2) # Uniform weights; torch.multinomial normalizes them, giving 1/(pxpy) each
        self._precision = nn.Parameter(torch.rand(1))
        self._reset_parameters()
        
    # Initialization method is taken from the pytorch implementation of CNNs. 
    # https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv2d
    
    def _reset_parameters(self):
        init.kaiming_uniform_(self._d2, a=math.sqrt(5))
        init.kaiming_uniform_(self._d1, a=math.sqrt(5))
        
    def forward(self, x):
        """
        X: code generated from encoder
            Shape: (K2, iW, iH)
        """
        x = DDGM_convolve(self._d2, x)
        x = unpool(x, self._distribution)
        x = DDGM_convolve(self._d1, x)
        
        # Data Generation
        Z = MultivariateNormal(torch.zeros(x.shape[1]*x.shape[2]), torch.eye(x.shape[1]*x.shape[2]))
        slices = []
        for k in range(x.shape[0]):
            mean = x[k, ...].reshape(-1)
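            # Covariance is (1/precision) * I, so scale unit noise by the standard deviation precision**-0.5
            noise = Z.sample() / torch.sqrt(self._precision)
            slice_ = (mean + noise).reshape(x.shape[1], -1)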
            slices.append(slice_)
        return torch.stack(slices)
    
    def set_distribution(self, distribution):
        """
        Sets the distribution for stochastic unpooling.
        """
        self._distribution = distribution

In [12]:
ddgm = DDGMDecoder(K2 = 6, K1 = 4, Nc = 3, d2_kernel=3, d1_kernel=3, iW=10, iH=10)
# ddgm(torch.rand(6, 10, 10))

### Image Encoder: Deep CNN

For the encoder, we can use PyTorch's `nn.Conv2d` modules, since they already perform the convolution we need, so we do not have to define our own convolution as we did above. However, we do need to write our own pooling function. 

In [13]:
def pool(C, prob_vecs):
    """
    Performs stochastic pooling on C, where the location of the element chosen in each pooling block (i, j)
    in slice k is determined by z^k(i, j). z^k(i, j) is sampled from a 
    multinomial distribution Mult(1, prob_vecs[k, i, j]), and prob_vecs[k, i, j] is derived from 
    MLP(C^k(i, j)). 
    
    
    C is of size (KL, iW, iH):
        KL: Number of 2D "slices". 
        iW, iH: image width, image height. 
    
    
    """
    
    K, w, h = C.shape
    pool_size = int(math.sqrt(prob_vecs.shape[-1]))
    w_new, h_new = w//pool_size, h//pool_size
    result = torch.zeros(K, w_new, h_new)
    for k in range(K):
        for block in itertools.product(range(w_new), range(h_new)):
            i, j = block
            
            idz = torch.multinomial(prob_vecs[k, i, j], 1).item()
            C_block = C[k, i*pool_size:(i+1)*pool_size, j*pool_size:(j+1)*pool_size].reshape(-1)
#             prob_vec = F.softmax(mlp(C_block), dim=-1)
#             idz = torch.multinomial(prob_vec, 1).item()
            
            result[k, i, j] = C_block[idz]
    return result
        

In [14]:
class CNNEncoder(nn.Module):
    """
    A 2 layered CNN encoder, with a stochastic pooling layer between the two layers,
    and a final code generation layer.
    """
    
    def __init__(self, Nc, K1, K2, f1_kernel, f2_kernel, iW, iH, pool_size=3):
        """
        Parameters:
        K1: Number of "slices" in layer 1
        K2: Number of "slices" in layer 2
        Nc: Number of channels of the image (1 for grayscale, 3 for rgb)
        iW, iH: input width, input height
        f1_kernel and f2_kernel: Size of the kernels for each of the filter banks.
        """
        
        super().__init__()
        self.layer1 = nn.Conv2d(Nc, K1, kernel_size=f1_kernel, padding=(f1_kernel-1)//2, bias=False)
        self.layer2 = nn.Conv2d(K1, K2, kernel_size=f2_kernel, padding=(f2_kernel-1)//2, bias=False)
        
        # MLP for distribution of stochastic pooling
        self.mlp_pool = nn.Sequential(
            nn.Linear(pool_size**2, 20), # 20 is an arbitrary hidden size.
            nn.Tanh(), 
            nn.Linear(20, pool_size**2)
        )
        
        # MLP for code generation mean 
        self.mlp_mean = nn.Sequential(
            nn.Linear((iW//pool_size) * (iH//pool_size), 20),
            nn.Tanh(), 
            nn.Linear(20, (iW//pool_size) * (iH//pool_size))
        )
        
        # MLP for code generation covariance 
        self.mlp_covar = nn.Sequential(
            nn.Linear((iW//pool_size) * (iH//pool_size), 20),
            nn.Tanh(), 
            nn.Linear(20, (iW//pool_size) * (iH//pool_size))
        )
        
        self._pool_size = pool_size
        
    def get_distribution(self, C):
        """
        Gets all the distributions for pooling, i.e the vector z. 
        This distribution will be reused as the posterior for the decoder's unpooling as well.
        """
        K, w, h = C.shape
        pool_size = self._pool_size
        w_new, h_new = w // pool_size, h // pool_size
      
        result = torch.zeros(K, w_new, h_new, self._pool_size**2)
        for k in range(K):
            for block in itertools.product(range(w_new), range(h_new)):
                i, j = block
                C_block = C[k, i*pool_size:(i+1)*pool_size, j*pool_size:(j+1)*pool_size].reshape(-1)
                prob_vec = F.softmax(self.mlp_pool(C_block), dim=-1)
                result[k, i, j, :] = prob_vec
                
        return result

        
    def forward(self, x):
        """
        x: Input image
            Shape: (Nc, iW, iH)
        """
        
        C1 = self.layer1(x.unsqueeze(0))
        distribution = self.get_distribution(C1.squeeze(0))
        
        C1 = pool(C1.squeeze(0), distribution)
        C2 = self.layer2(C1.unsqueeze(0))
        C2 = C2.squeeze(0)
        
        # Final Code generation
        slices = []
        Z = MultivariateNormal(torch.zeros(C2.shape[1]*C2.shape[2]), torch.eye(C2.shape[1]*C2.shape[2]))
        for k in range(C2.shape[0]):
            mean = self.mlp_mean(C2[k, ...].reshape(-1))
            covar = torch.diag(self.mlp_covar(C2[k, ...].reshape(-1))) @ Z.sample()
            slice_ = (mean + covar).reshape(C2.shape[1], -1)
            slices.append(slice_)
    
        return torch.stack(slices), distribution # (s, z)

        

In [15]:
encoder = CNNEncoder(3, 4, 6, 3, 3, 32, 32)
s, z = encoder(torch.rand(3, 32, 32))
s.shape, z.shape

(torch.Size([6, 10, 10]), torch.Size([4, 10, 10, 9]))

### Putting the VAE together

Finally, we can piece the encoder and decoder together to form the VAE.

In [16]:
class VAE(nn.Module):
    
    def __init__(self, Nc, K1, K2, f1, f2, d2, d1, iW, iH, pool_size=3):
        """
        Nc: Number of channels of the image. 
        K1: Number of 2D "slices" in layer 1. This is shared for the encoder and decoder.
        K2: Number of 2D "slices" in layer 2. This is shared for the encoder and decoder. 
        f1, f2: Kernel size for the first and second filter bank respectively
        d2, d1: Kernel size for the second and first dictionary respectively. 
        iW, iH: width and height of the image. 
        pool_size: The size of the pooling block. We are assuming that pooling blocks are square.
        """
        
        super().__init__()
        self.encoder = CNNEncoder(Nc, K1, K2, f1, f2, iW, iH, pool_size)
        cW, cH = iW // pool_size, iH // pool_size # code width, code height
        self.decoder = DDGMDecoder(K2, K1, Nc, d2, d1, cW, cH, pool_size)
        
    def forward(self, x):
        """
        x: Input image
            Shape: (Nc, iW, iH)
        """
        
        s, z = self.encoder(x)
        self.decoder.set_distribution(z)
        return self.decoder(s)
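
One thing worth noting, as far as I can tell from the code above: pooling floors the spatial size and unpooling multiplies it back, so the reconstruction only has the same shape as the input when the image size is a multiple of `pool_size`. The sketch below traces the sizes under the assumption that every convolution preserves the shape; `vae_shapes` is a hypothetical helper, not part of the model.

```python
def vae_shapes(image_size, pool_size):
    """Trace one spatial dimension through the encoder and decoder."""
    code_size = image_size // pool_size          # stochastic pooling floors the size
    reconstruction_size = code_size * pool_size  # unpooling multiplies it back
    return code_size, reconstruction_size

print(vae_shapes(32, 3))  # (10, 30): a 32x32 image is reconstructed as 30x30
print(vae_shapes(30, 3))  # (10, 30): sizes match
print(vae_shapes(32, 2))  # (16, 32): sizes match
```

So before computing a reconstruction loss, the input would need to be cropped or resized, or `pool_size` chosen to divide the image size.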

### Caption Generator: RNN

I created a simple figure to translate section 3.2 into a visual model architecture to help implement the model.

<img src="./images/model_arch.jpg" width=600>

Therefore, we need 3 layers to perform caption generation:

- 2 layered MLP with tanh and softmax activation
    - Generates first word from $s^{(n)}$
    - Last layer converts hidden state $h_t^{(n)}$ into one hot word vector $y_t^{(n)}$
    
- Embedding layer
    - Converts one hot word vector $y_t^{(n)}$ into word representation $w_t^{(n)}$
    
- RNN (Can be LSTM or GRU)
    - Recursively generates other words until the stop symbol is generated.
    


For the implementation below, I opted to use the indices instead of a one hot vector since the operations support indices better. 
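
Using indices loses nothing: multiplying a one-hot vector by an embedding matrix simply selects one row, which is what an index lookup does directly. A minimal NumPy sketch, where `embedding_matrix` holds made-up random weights:

```python
import numpy as np

rng = np.random.default_rng(0)
V, M = 5, 3                                 # vocabulary size, embedding size
embedding_matrix = rng.normal(size=(V, M))  # made-up weights for illustration

word_index = 2
one_hot = np.zeros(V)
one_hot[word_index] = 1.0

via_one_hot = one_hot @ embedding_matrix    # matrix product with a one-hot vector
via_index = embedding_matrix[word_index]    # direct row lookup

print(np.allclose(via_one_hot, via_index))  # True
```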
    

In [17]:
class CaptionRNN(nn.Module):
    
    CAPTION_LIMIT = 30
    
    def __init__(self, V, M, H, C, stop_index):
        """
        V: Size of the vocabulary
        M: Size of embedded vector
        H: Number of features for hidden state h
        C: Flattened size of input code s
        stop_index: The index that represents the end of a sentence.
        """
        
        super().__init__()
        
        self.mlp_l1 = nn.Linear(in_features=C, out_features=H)
        self.mlp_l2 = nn.Linear(in_features=H, out_features=V)
        
        self.gru = nn.GRU(input_size=M, hidden_size=H, batch_first=True)
        self.embedding = nn.Embedding(V, M)
        self.stop_index = stop_index
        
        self.V = V
    
    def forward(self, x, limit=None):
        """
        x: Unflattened code s.
        Limit denotes the maximum number of words we should generate. 
            Default value can be set by changing CaptionRNN.CAPTION_LIMIT. 
        """
        
        h1 = torch.tanh(self.mlp_l1(x.reshape(-1)))  # F.tanh is deprecated in favor of torch.tanh
        y1 = torch.multinomial(F.softmax(self.mlp_l2(h1), dim=-1), 1)
        
        words = [y1]
        ht = h1
        wt = self.embedding(y1)
        yt = y1  # so a stop symbol sampled as the very first word also ends the caption
        
        # Prevents the RNN from endlessly creating words
        if not limit: 
            limit = CaptionRNN.CAPTION_LIMIT
            
        while len(words) < limit and not yt == self.stop_index:
            # Output and hidden are the same in this case, so we just get the output. 
            
            ht = self.gru(wt.unsqueeze(0), ht.unsqueeze(0).unsqueeze(0))[0]
            ht = ht.squeeze(0).squeeze(0)
            yt = torch.multinomial(F.softmax(self.mlp_l2(ht), dim=-1), 1)
            words.append(yt)
            wt = self.embedding(yt)
    
        return words


In [18]:
test = CaptionRNN(V=10, M=20, H=20, C=600, stop_index=None)
# test(torch.rand((6, 10, 10)))

## Complete model

Finally, we put the VAE and CaptionRNN together to create the complete model.

In [19]:
class VAECaption(nn.Module):
    
    def __init__(self, Nc, K1, K2, f1, f2, d2, d1, iW, iH, V, M, H, stop_index, pool_size=3):
        """
        Nc: Number of channels of the image. 
        K1: Number of 2D "slices" in layer 1. This is shared for the encoder and decoder.
        K2: Number of 2D "slices" in layer 2. This is shared for the encoder and decoder. 
        f1, f2: Kernel size for the first and second filter bank respectively
        d2, d1: Kernel size for the second and first dictionary respectively. 
        iW, iH: width and height of the image. 
        V: Size of the vocabulary
        M: Size of embedded vector
        H: Number of features for hidden state h
        stop_index: The index that indicates end of sentence.
        
        pool_size: The size of the pooling block. We are assuming that pooling blocks are square.
        """
        super().__init__()
        
        self.encoder = CNNEncoder(Nc, K1, K2, f1, f2, iW, iH, pool_size)
        cW, cH = iW // pool_size, iH // pool_size # code width, code height
        self.decoder = DDGMDecoder(K2, K1, Nc, d2, d1, cW, cH, pool_size)
        self.captioner = CaptionRNN(V, M, H, K2*cW*cH, stop_index)
        
    def forward(self, x, limit=None):
        """
        x: Input image
            Shape: (Nc, iW, iH)
        Limit denotes the maximum number of words we should generate. 
            Default value can be set by changing CaptionRNN.CAPTION_LIMIT. 
        """
        
        s, z = self.encoder(x)
        self.decoder.set_distribution(z)
        x_reconstructed = self.decoder(s)
        caption = self.captioner(s, limit)  # forward the word limit instead of silently ignoring it
        return x_reconstructed, caption

In [21]:
model = VAECaption(Nc=3, K1 = 4, K2 = 6, f1=3, f2=3, d2=3, d1=3, iW=32, iH=32, V=20, M=30, H=40, stop_index=None)
model(torch.rand((3, 32, 32)))



(tensor([[[ 0.4807,  0.0745, -0.3269,  ...,  2.0071, -1.4856,  1.0216],
          [-2.9287,  0.2027, -1.5701,  ...,  1.6463,  2.6577, -0.3641],
          [-2.4491, -1.9397,  0.2091,  ..., -0.4314,  1.2842,  1.8869],
          ...,
          [ 2.2065, -1.4675, -1.0623,  ...,  0.2752, -0.0412,  2.9423],
          [ 0.2044,  1.6219, -0.3174,  ...,  0.1843,  2.2494, -0.5284],
          [ 1.8016,  0.6920,  2.4196,  ..., -1.5634,  0.5712, -0.3255]],
 
         [[ 3.3454, -0.8931, -1.9007,  ..., -1.3881,  1.9869,  0.4065],
          [-0.5608, -1.1998,  1.6078,  ..., -1.5067, -0.4759, -1.3706],
          [ 1.1602,  1.8228, -2.3619,  ..., -0.0963,  0.8125,  0.9036],
          ...,
          [ 0.3968,  0.5434, -1.6605,  ..., -0.6974,  1.6825,  2.9489],
          [-1.1734,  1.4167,  2.9717,  ...,  1.1532, -0.3852, -0.8545],
          [-1.6451,  0.0143, -2.3994,  ...,  1.0019, -1.4454, -0.6245]],
 
         [[ 0.4202, -0.7991,  1.7699,  ...,  2.0978,  0.6428,  0.0891],
          [ 0.7188, -0.1137,