## Build VAE model
**In this notebook we will build a Bayesian Autoencoder for sinusoid signal reconstruction.**

<img src="misc/SpeechReconstruction.png" alt="Drawing" style="width: 600px;"/>
<img src="misc/vae_vis2.png" alt="Drawing" style="width: 900px;"/>

In [None]:
from typing import NamedTuple

import mxnet as mx
import numpy as np
from mxnet import gluon
from mxnet.gluon import rnn, nn
from mxnet import ndarray as nd

### Encoder implementation
We will implement the encoder as a simple single-layer [GRU](https://mxnet.incubator.apache.org/api/python/gluon/rnn.html#mxnet.gluon.rnn.GRU) whose hidden state h has 10 features. We will treat the last GRU state as the encoding of the sinusoid signal, so the encoder latent space has 10 dimensions.
<img src="misc/GRU_image.png" alt="Drawing" style="width: 600px;"/>
<img src="misc/encoder_vis.png" alt="Drawing" style="width: 800px;"/>
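To make the "last state as encoding" idea concrete, here is a minimal NumPy sketch of a single-layer GRU run over a batch of scalar sequences. The gate equations are written to follow the GRU formulation given in the MXNet `rnn.GRU` documentation as we understand it (the reset gate scales the hidden-to-hidden candidate term). The `gru_encode` helper and its random weights are hypothetical stand-ins for the trained `rnn.GRU` parameters, not the library implementation.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gru_encode(signal, hidden_size=10, seed=0):
    """Run a single-layer GRU over each signal and return the last hidden state.

    signal: array of shape (m, length), one scalar sample per time step.
    Returns: array of shape (m, hidden_size).
    """
    rng = np.random.default_rng(seed)
    input_size = 1
    # Hypothetical random weights standing in for trained parameters.
    # Rows are stacked in gate order: reset (r), update (z), candidate (n).
    W_i = rng.normal(scale=0.1, size=(3 * hidden_size, input_size))
    W_h = rng.normal(scale=0.1, size=(3 * hidden_size, hidden_size))
    b_i = np.zeros(3 * hidden_size)
    b_h = np.zeros(3 * hidden_size)

    m, length = signal.shape
    h = np.zeros((m, hidden_size))          # initial state: all zeros
    for t in range(length):
        x = signal[:, t:t + 1]              # (m, 1): one sample per signal
        gi = x @ W_i.T + b_i                # input contributions, (m, 3 * hidden)
        gh = h @ W_h.T + b_h                # hidden contributions, (m, 3 * hidden)
        i_r, i_z, i_n = np.split(gi, 3, axis=1)
        h_r, h_z, h_n = np.split(gh, 3, axis=1)
        r = sigmoid(i_r + h_r)              # reset gate
        z = sigmoid(i_z + h_z)              # update gate
        n = np.tanh(i_n + r * h_n)          # candidate state
        h = (1.0 - z) * n + z * h           # new hidden state
    return h                                # last state is used as the encoding


print(gru_encode(np.ones((4, 50))).shape)   # (4, 10)
print(gru_encode(np.ones((4, 120))).shape)  # (4, 10): length does not change the encoding size
```

Whatever the sequence length, the encoding has `hidden_size` entries, which is why the last state can serve as a fixed-size summary of the signal.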

In [None]:
class SinEncoder(gluon.Block):
    """
    Sinus GRU encoder. Last GRU state is used as an encoded representation of the sinus signal.
    """
    def __init__(self, **kwargs):
        super(SinEncoder, self).__init__(**kwargs)

        with self.name_scope():
            self.rnn = rnn.GRU(hidden_size=10)

    def forward(self, signal: nd.NDArray):
        """
        Args:
            signal: Sin signal (m, signal_length), m - num of signals (batch_size)

        Returns: last GRU output, Array(m, s), where m - batch size, s - dimensionality of the GRU hidden state

        """
        output = nd.transpose(signal, axes=(1, 0))  # (length, m)
        output = output.expand_dims(axis=2)  # (length, m, 1), the TNC layout expected by rnn.GRU

        initial_hidden_state = self.rnn.begin_state(func=nd.zeros, batch_size=signal.shape[0])
        output, _final_state = self.rnn(output, initial_hidden_state)  # output: (length, m, hidden_size)

        # For a single-layer GRU the last output step equals the final hidden state
        output = output[-1, :]  # (m, hidden_size)

        return output

We can now visualize the encoder structure. It should encode a sequence of any length into a 10-dimensional space. Let's check whether that is true.

In [None]:
SinEncoder()

You might be confused by `TNC` in the printed summary. It is the layout of the input and output tensors: T, N and C stand for sequence length, batch size, and feature dimension respectively.
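The reshaping that `SinEncoder.forward` performs to reach the `TNC` layout can be reproduced in NumPy. This sketch mirrors the `nd.transpose(..., axes=(1, 0))` and `expand_dims(axis=2)` calls on a batch stored as one row per signal; the toy ramp signal is only for illustration.

```python
import numpy as np

m, length = 4, 50
# Batch of m signals stored one row per signal: layout (N, T)
signal = np.tile(np.arange(length, dtype=np.float32), (m, 1))

tnc = np.transpose(signal, axes=(1, 0))  # (T, N): time becomes the leading axis
tnc = np.expand_dims(tnc, axis=2)        # (T, N, C) with C = 1 feature per time step

print(tnc.shape)     # (50, 4, 1)
print(tnc[7, 2, 0])  # 7.0: sample 7 of signal 2
```

Each time step `t` then gives the GRU a `(N, C)` slice holding one sample from every signal in the batch.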

### Let's feed some data forward through the encoder

In [None]:
encoder = SinEncoder()
encoder.initialize()
encoder(nd.ones((4, 50)))

### Decoder implementation
The decoder reconstructs the signal from a sample drawn from the Gaussian latent space.

We will implement the decoder as a single-layer [GRU](https://mxnet.incubator.apache.org/api/python/gluon/rnn.html#mxnet.gluon.rnn.GRU) cell followed by a projection layer, implemented as a fully connected layer with a linear activation. The fixed-dimensional latent sample is decoded into a sequence of signal frames.

Decoding happens in a loop. At each step the decoder predicts one frame of 5 samples of the reconstructed signal. That frame is concatenated with the sample drawn from the Gaussian latent space and fed to the GRU as input at the next step; at the first step an all-zero frame is used instead.
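The decoding loop can be sketched in NumPy as below, assuming inference without teacher forcing. The `decode` helper is hypothetical, and `step_fn` stands in for the GRUCell plus projection layer; unlike the real GRUCell, the toy `step_fn` here is stateless, so the recurrent hidden state that the notebook's decoder carries between steps is omitted.

```python
import numpy as np


def decode(latent_sample, length, step_fn, samples_per_step=5):
    """Autoregressive frame-by-frame decoding (hypothetical helper, inference only).

    latent_sample: (batch, n_latent) sample from the latent space.
    step_fn: stand-in for GRUCell + projection layer, mapping
             (batch, n_latent + samples_per_step) -> (batch, samples_per_step).
    """
    batch = latent_sample.shape[0]
    frame = np.zeros((batch, samples_per_step))  # the first step sees an all-zero frame
    frames = []
    for _ in range(length // samples_per_step):
        step_input = np.concatenate([latent_sample, frame], axis=1)
        frame = step_fn(step_input)              # prediction is fed back at the next step
        frames.append(frame)
    return np.concatenate(frames, axis=1)        # (batch, (length // samples_per_step) * samples_per_step)


def toy_step(x):
    # Toy predictor: previous frame (last 5 columns of the input) plus one
    return x[:, -5:] + 1.0


out = decode(np.zeros((2, 1)), 52, toy_step)
print(out.shape)    # (2, 50): 52 // 5 = 10 frames of 5 samples
print(out[0, :10])  # [1. 1. 1. 1. 1. 2. 2. 2. 2. 2.]
```

Note that a length which is not a multiple of `samples_per_step` is truncated to the last full frame, which is also what the integer division in `SinDecoder.forward` does.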

During training we will use teacher forcing, which helps the decoder converge faster. At each step after the first, instead of feeding back the predicted frame, the decoder may receive the corresponding frame of the reference (oracle) signal. Whether the predicted or the original frame is used is decided at random with a constant probability, `teacher_forcing_prob`.
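The frame-selection rule can be sketched as a small NumPy function. The `select_input_frame` name and its `rng` parameter are hypothetical; the alignment shown (at step `i`, the oracle frame covering samples `(i - 1) * 5` to `i * 5 - 1`, i.e. the frame the model predicted at step `i - 1`) is the one used by `SinDecoder._select_signal` further below.

```python
import numpy as np


def select_input_frame(predicted, oracle, teacher_forcing_prob, step,
                       samples_per_step=5, rng=np.random):
    """Choose the frame fed to the decoder at `step` (hypothetical helper).

    With probability teacher_forcing_prob (and never at step 0) the oracle frame
    aligned with the previous prediction replaces the prediction.
    """
    if step > 0 and rng.rand() < teacher_forcing_prob:
        start = (step - 1) * samples_per_step  # frame that was predicted at step - 1
        return oracle[:, start:start + samples_per_step]
    return predicted


oracle = np.arange(50, dtype=float).reshape(1, 50)  # one reference signal
predicted = np.full((1, 5), -1.0)                   # dummy model prediction
print(select_input_frame(predicted, oracle, 1.0, step=3))  # [[10. 11. 12. 13. 14.]]
print(select_input_frame(predicted, oracle, 1.0, step=0))  # [[-1. -1. -1. -1. -1.]]
```

Because one random draw is made per step, the whole batch is either teacher-forced or not at a given step.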

The decoder implements the following flow:
<img src="misc/decoder_vis2.png" alt="Drawing" style="height: 400px;"/>

In [None]:
class StatefulGRUCell:
    """
    Wrapper over GRUCell caching hidden states.
    """
    def __init__(self, gru_cell: rnn.GRUCell):
        """
        Args:
            gru_cell: GRUCell which will be used to run the forward pass.
        """
        self._gru_cell = gru_cell
        self._state = None

    def __call__(self, input_: nd.NDArray):
        """
        Run forward pass on GRUCell.
        First forward call will be initialized with GRUCell.begin_state, all further calls
        will use the latest state returned by GRUCell.
        Args:
            input_: input tensor with shape (batch_size, input_size).
        """
        if self._state is None:  # first call: start from the cell's initial (zero) state
            self._state = self._gru_cell.begin_state(input_.shape[0])
        output, new_state = self._gru_cell(input_, self._state)
        self._state = new_state

        return output

In [None]:
class SinDecoder(gluon.Block):

    def __init__(self, samples_per_step: int = 5, **kwargs):
        """
        Args:
            samples_per_step: How many samples to generate at each decoding step.
            **kwargs:
        """
        super(SinDecoder, self).__init__(**kwargs)

        self._samples_per_step = samples_per_step

        with self.name_scope():
            self._decoder_rnn = rnn.GRUCell(hidden_size=20)
            self._projection_layer = nn.Dense(samples_per_step, flatten=False)

    def forward(self, gaussian_latent_space_sample: nd.NDArray,
                length: int, target_signal: nd.NDArray,
                teacher_forcing_prob: float):
        """
        Args:
            gaussian_latent_space_sample: (batch_size, n_latent) sample from the latent space.
            length: The total number of samples to generate.
            target_signal: True sinus signal (batch_size, length) used for teacher forcing during training, or None.
            teacher_forcing_prob: The probability of using teacher forcing at each decoding step.

        Returns:
            Decoded sinus signal: (batch_size, floor(length / samples_per_step) * samples_per_step)
        """

        batch_size = gaussian_latent_space_sample.shape[0]
        stateful_rnn = StatefulGRUCell(self._decoder_rnn)
        predicted_frame = mx.nd.zeros(shape=(batch_size, self._samples_per_step))
        reconstructed_signal = [] 
        for i in range(length // self._samples_per_step):
            if target_signal is not None:
                predicted_frame = self._select_signal(predicted_frame, target_signal,
                                                      teacher_forcing_prob, i)  # teacher forcing
            rnn_input = nd.concat(gaussian_latent_space_sample,
                                  predicted_frame, dim=1)  # (batch_size, n_latent + samples_per_step)
            rnn_output = stateful_rnn(rnn_input)  # (batch_size, rnn_hidden_size)
            predicted_frame = self._projection_layer(rnn_output)  # (batch_size, samples_per_step)
            reconstructed_signal.append(predicted_frame)

        return nd.concat(*reconstructed_signal, dim=1)  # (batch_size, floor(length/samples_per_step) * samples_per_step)

    def _select_signal(self,
                       predicted_signal: nd.NDArray,
                       oracle_signal: nd.NDArray,
                       teacher_forcing_prob: float, 
                       step: int):
        """
        Based on teacher_forcing_prob, select either predicted_signal or the aligned frame
        of the oracle signal.
        Args:
            predicted_signal: (batch_size, _samples_per_step)
            oracle_signal: (batch_size, total number of samples)
            teacher_forcing_prob: probability that the oracle signal will be selected
            step: decoding step, used to align the oracle and predicted signals
        Returns:
            signal for the forward pass
        """
        if step > 0 and np.random.rand() < teacher_forcing_prob:
            start_sample = (step - 1) * self._samples_per_step
            stop_sample = start_sample + self._samples_per_step
            teacher_forcing_signal = oracle_signal[:, start_sample:stop_sample]
            return teacher_forcing_signal
        else:
            return predicted_signal

Time to visualize the decoder we have created. It should consist of two components: the decoder RNN (`_decoder_rnn`, a GRUCell), which takes the latent variable and the previously predicted frame, and the projection layer (`_projection_layer`), which predicts the next 5 samples from the decoder RNN output.
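As a worked example of the decoder's size, the parameter count can be computed by hand. This assumes Gluon's GRUCell stores `i2h_weight` of shape `(3 * hidden, input)`, `h2h_weight` of shape `(3 * hidden, hidden)` and two bias vectors of length `3 * hidden`, and that `nn.Dense` holds a `(units, in_units)` weight plus a bias; the helper functions are hypothetical. The input size uses the 1-dimensional latent sample from the cell below (`nd.ones((2, 1))`) plus the 5-sample frame.

```python
def gru_cell_params(input_size, hidden_size):
    # i2h_weight (3h, in) + h2h_weight (3h, h) + i2h_bias (3h) + h2h_bias (3h)
    gates = 3 * hidden_size
    return gates * input_size + gates * hidden_size + 2 * gates


def dense_params(in_units, units):
    # weight (units, in_units) + bias (units)
    return in_units * units + units


n_latent, samples_per_step, hidden = 1, 5, 20
rnn_params = gru_cell_params(n_latent + samples_per_step, hidden)
proj_params = dense_params(hidden, samples_per_step)
print(rnn_params, proj_params)  # 1680 105
```

Since both layers infer their input size on the first forward call, these numbers can be compared with `decoder.collect_params()` after running the cell below.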

In [None]:
SinDecoder()

In [None]:
decoder = SinDecoder()
decoder.initialize()
decoder(nd.ones((2, 1)), 50, None, None)

### Put all the pieces together into a Gaussian variational autoencoder

Now that both the encoder and the decoder are defined, let's put all the pieces together, adding a final layer responsible for the Gaussian latent space embedding.

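To compute the Gaussian latent space we will use a fully connected layer with `2 * n_latent` outputs (4 by default), representing the mean and the log variance of a Gaussian with diagonal covariance. These values are predicted from the encoder embedding. Predicting the log variance rather than the variance itself keeps the variance positive.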
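The split into mean and log variance and the subsequent sampling (the reparameterization trick) can be sketched in NumPy. The `sample_latent` helper is hypothetical; it splits the last axis in half the way `split(axis=1, num_outputs=2)` does and draws the noise from a standard normal, so the randomness does not depend on the predicted parameters.

```python
import numpy as np


def sample_latent(dense_output, rng):
    """Split a (m, 2 * n_latent) output into mean / log variance and draw a sample.

    Reparameterization trick: x = mu + std * eps, with eps ~ N(0, I).
    """
    mean, log_var = np.split(dense_output, 2, axis=1)
    std = np.exp(0.5 * log_var)            # log-variance parametrization keeps std > 0
    eps = rng.standard_normal(mean.shape)  # noise independent of mean and log_var
    return mean, log_var, mean + std * eps


rng = np.random.default_rng(0)
# mean = [1, -2], variance = [1, 4], repeated for many rows
dense_output = np.tile([1.0, -2.0, 0.0, np.log(4.0)], (100_000, 1))
mean, log_var, z = sample_latent(dense_output, rng)
print(z.mean(axis=0), z.std(axis=0))  # close to [1, -2] and [1, 2]
```

Because the sample is a deterministic function of `mean` and `log_var` once `eps` is drawn, gradients of the loss can flow back into the dense layer.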

Additionally, we will allow the sampled latent value to be overridden (the `latent_space_override` argument), which is useful during inference.
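One way the override could be used is to decode a regular grid of latent values and see how the reconstructions change across the latent space. The `latent_grid` helper below is a hypothetical NumPy sketch; the number of columns must equal `n_latent`, because the decoder's GRUCell fixes its input size (`n_latent + samples_per_step`) on the first call.

```python
import numpy as np


def latent_grid(n_latent=2, points_per_axis=5, limit=2.0):
    """Regular grid of latent values, shape (points_per_axis ** n_latent, n_latent)."""
    axis = np.linspace(-limit, limit, points_per_axis)
    mesh = np.meshgrid(*([axis] * n_latent), indexing="ij")
    return np.stack([g.ravel() for g in mesh], axis=1)


grid = latent_grid()
print(grid.shape)  # (25, 2)
print(grid[12])    # [0. 0.]: the centre of the grid
```

In the notebook, such a grid might be passed as `latent_space_override=nd.array(grid)` with `teacher_forcing_prob=0`, so that no oracle frames are mixed into the decoded signals.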

Apart from this, we will also add a `calc_loss` method to the model, which computes the L2 reconstruction loss and the KL divergence (KLD) between the latent posterior and the standard normal prior.
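Both loss terms can be written out in NumPy to make the formulas explicit. For a posterior $\mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and a standard normal prior, the KL divergence has the closed form $D_{KL} = -\tfrac{1}{2}\sum_j \left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)$, which is what `calc_loss` computes (as `negative_kld`, negated on return). The L2 term assumes `gluon.loss.L2Loss` with its default weight, which returns half the squared error averaged over the non-batch axes; the two helpers are hypothetical.

```python
import numpy as np


def l2_loss(pred, label):
    # Half the squared error, averaged over the non-batch axis: one value per example
    return 0.5 * np.mean((pred - label) ** 2, axis=1)


def kl_to_standard_normal(mean, log_var):
    # Closed-form KL(N(mean, diag(exp(log_var))) || N(0, I)), one value per example
    return -0.5 * np.sum(1 + log_var - mean ** 2 - np.exp(log_var), axis=1)


print(l2_loss(np.ones((2, 50)), np.zeros((2, 50))))                # [0.5 0.5]
print(kl_to_standard_normal(np.ones((1, 1)), np.zeros((1, 1))))   # [0.5]
```

The KLD is zero only when every latent dimension has mean 0 and variance 1, so it pulls the posterior toward the prior while the L2 term pushes toward accurate reconstructions.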

In [None]:
class SinBayesianAutoEncoder(gluon.Block):

    def __init__(self, n_latent=2, **kwargs):
        super(SinBayesianAutoEncoder, self).__init__(**kwargs)

        with self.name_scope():
            self.encoder = SinEncoder()
            self.latent_space = nn.Dense(n_latent * 2)  # mean and log variance of the latent space
            self.decoder = SinDecoder(samples_per_step=5)  # predict 5 samples at once
            self.l2loss = gluon.loss.L2Loss()

    def forward(self, signal: nd.NDArray, teacher_forcing_prob: float,
                latent_space_override: nd.NDArray = None):
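        """
        Args:
            signal: Sin signal (m, signal_length), m - num of signals (batch_size)
            teacher_forcing_prob: The probability of activating the teacher forcing
            latent_space_override: Optional value (m, n_latent) that replaces the sampled latent value.

        Returns:
            SinBAEOutput with the latent mean, latent log variance, latent value and decoded signal.
        """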
        sig_embedding = self.encoder(signal)  # (m,s), s - dim of the encoder embedding

        # Posterior of the latent space
        # Gaussian variance must be positive, therefore using log variance parametrization
        ls_mean, ls_log_var = self.latent_space(sig_embedding).split(axis=1, num_outputs=2)
        ls_std = nd.exp(ls_log_var * 0.5)  # std = exp(log_var / 2)

        # Sampling from the unit gaussian instead of sampling from the latent space posterior
        # allows gradients to flow through the latent_space_mean / latent_space_log_var parameters
        # z = (x-mu)/std, thus: x = mu + z*std
        normal_sample = nd.random.normal(0, 1, shape=ls_mean.shape)
        ls_val = ls_mean + ls_std * normal_sample

        if isinstance(latent_space_override, nd.NDArray):
            ls_val = latent_space_override

        length = signal.shape[1]
        reconstructed_sig = self.decoder(ls_val, length, signal, teacher_forcing_prob)  # (m,length)

        return SinBAEOutput(ls_mean, ls_log_var, ls_val, reconstructed_sig)

    def calc_loss(self, signal: nd.NDArray, teacher_forcing_prob: float):
        """
        Compute the reconstruction and KL divergence loss terms for a batch of signals.

        Args:
            signal: Sin signal: (m, signal_length), m - num of signals (batch_size).
                signal_length must be a multiple of samples_per_step (5) so the decoded signal matches it.
            teacher_forcing_prob: The probability of activating the teacher forcing

        Returns: L2 loss between decoded and input signals (m,), KLD loss (m,)

        """

        decoded_signal_output = self(signal, teacher_forcing_prob)

        latent_space_mean = decoded_signal_output.latent_space_mean
        latent_space_log_var = decoded_signal_output.latent_space_log_var

        # gluon L2Loss takes (pred, label)
        l2_loss = self.l2loss(decoded_signal_output.decoded_signal, signal)
        # Closed-form KL(N(mean, exp(log_var)) || N(0, I)) is -negative_kld
        negative_kld = 0.5 * nd.sum(
            1 + latent_space_log_var - latent_space_mean ** 2 - nd.exp(latent_space_log_var), axis=1)

        return l2_loss, -negative_kld


class SinBAEOutput(NamedTuple):
    """
    Args:
        latent_space_mean: array(m, n_latent), The mean of the posterior of the latent space
        latent_space_log_var: array(m, n_latent), Log variance of the posterior of the latent space
        latent_space_val: array(m, n_latent), Value sampled from the posterior of the latent space (or the override)
        decoded_signal: array(m, signal_length)
    """
    latent_space_mean: nd.NDArray
    latent_space_log_var: nd.NDArray
    latent_space_val: nd.NDArray
    decoded_signal: nd.NDArray

### Create actual model
Now that all the building blocks are in place, let's create a `SinBayesianAutoEncoder` instance and visualize it.

In [None]:
SinBayesianAutoEncoder()

In [None]:
vae = SinBayesianAutoEncoder()
vae.initialize()
vae(nd.ones((4, 50)), 0.6, None)