## Train Variational Autoencoder Model
In this notebook we will train the variational autoencoder model that we created in the previous notebook. We will train the model on data generated with the signal generator from our first notebook.

In [None]:
%matplotlib inline  

import logging
import os
import sys
import time
from typing import List, Tuple

import matplotlib.pyplot as plt
import mxnet as mx
import mxnet.ndarray as nd
import numpy as np
from mxnet import gluon, autograd
from mxnet.gluon.data import DataLoader

from mlss_gdansk2019.SignalGenerator import SignalGenerator
from mlss_gdansk2019.SinBayesianAutoEncoderModel import SinBayesianAutoEncoder

In [None]:
# We want to have repeatable results, so seed every random number generator in use.
np.random.seed(54545)
# MXNet keeps its own random number generators (weight initialisers, random
# ndarray operators); seeding NumPy alone does not make those deterministic.
mx.random.seed(54545)

# Logging settings: print every message of level DEBUG and above to stdout,
# so that the training progress shows up in the notebook output.
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)

### Create Trainer class
The trainer will be responsible for running the forward and backward passes over the model we have created. It uses the Adam optimizer to update the weights of our model. The trainer computes both the train and the test loss, each split into its KLD and L2 components.

**DataLoaders** are used to create mini-batches of samples from a Dataset, and provide a convenient iterator interface for looping over these batches. It’s typically much more efficient to pass a mini-batch of data through a neural network than a single sample at a time, because the computation can be performed in parallel. A required parameter of DataLoader is the size of the mini-batches you want to create, called batch_size.

In [None]:
class SinBayesianAutoEncoderTrainer:
    """
    Train a SinBayesianAutoEncoder model and track its loss on a test set.

    Every epoch consists of one optimisation pass over the training
    DataLoader followed by one evaluation pass over the test DataLoader.
    The optimised objective is ``l2_loss + kld_weight * kld_loss``. All KLD
    values that are logged or returned are the *weighted* term, so the logged
    "total" is the optimised objective and the train and test columns are
    directly comparable.
    """

    def __init__(self,
                 train_dataloader: DataLoader,
                 test_dataloader: DataLoader,
                 model: SinBayesianAutoEncoder,
                 kld_weight: float = 0.01,
                 optimizer: str = "adam",
                 learning_rate: float = 0.001):
        """
        Args:
            train_dataloader: mini-batches used to update the model.
            test_dataloader: mini-batches used only to measure the loss.
            model: model to train; it must already be initialized.
            kld_weight: multiplier applied to the KLD loss in the objective.
            optimizer: name of the optimizer passed to gluon.Trainer.
            learning_rate: learning rate passed to the optimizer.
        """
        self._train_dataloader = train_dataloader
        self._test_dataloader = test_dataloader
        self._model = model
        self._kld_weight = kld_weight
        self._trainer = gluon.Trainer(self._model.collect_params(),
                                      optimizer, {'learning_rate': learning_rate})
        # NOTE: this is the root logger, so the level set here applies process-wide.
        self.logger = logging.getLogger()
        self.logger.setLevel(logging.DEBUG)

    def train(self, n_epoch: int, decay_multiplier: float = 0.01) -> Tuple[List[float],
                                                                           List[float],
                                                                           List[float]]:
        """
        Train the model for given number of epochs updating its weights,
        and exponentially decaying teacher forcing probability
        Args:
            n_epoch: Number of epochs for which model will be trained.
            decay_multiplier: multiplier of epoch in exponential decay of
                              teacher forcing, i.e. the probability used in
                              epoch ``e`` is ``exp(-decay_multiplier * e)``.

        Returns: test_l2_loss_list, test_kld_list, teacher_forcing_prob_list
                 (one entry per epoch; the KLD entries are weighted by kld_weight)
        """
        start = time.time()

        test_l2_loss_list = []
        test_kld_list = []
        teacher_forcing_prob_list = []
        for epoch in range(n_epoch):
            # Starts at 1.0 in epoch 0 and decays towards 0 as training progresses.
            teacher_forcing_prob = np.exp(-decay_multiplier * epoch)

            epoch_l2_train_loss, epoch_kld_train_loss = self._forward_backward(teacher_forcing_prob)
            epoch_l2_test_loss, epoch_kld_test_loss = self._calc_test_accuracy(teacher_forcing_prob)

            epoch_total_train_loss = epoch_l2_train_loss + epoch_kld_train_loss
            epoch_total_test_loss = epoch_l2_test_loss + epoch_kld_test_loss
            # Arguments are passed separately so the message is only formatted
            # when a handler actually emits it.
            self.logger.info(
                'Epoch %d, teacher_forcing_prob=%.2f, training loss (total/l2/kld): '
                '%.4f, %.4f, %.4f, test loss (total/l2/kld): %.4f, %.4f, %.4f',
                epoch,
                teacher_forcing_prob,
                epoch_total_train_loss, epoch_l2_train_loss, epoch_kld_train_loss,
                epoch_total_test_loss, epoch_l2_test_loss, epoch_kld_test_loss)

            test_l2_loss_list.append(epoch_l2_test_loss)
            test_kld_list.append(epoch_kld_test_loss)
            teacher_forcing_prob_list.append(teacher_forcing_prob)

        end = time.time()
        self.logger.info('Time elapsed: %.2fs', end - start)

        return test_l2_loss_list, test_kld_list, teacher_forcing_prob_list

    def _forward_backward(self, teacher_forcing_prob: float) -> Tuple[float, float]:
        """
        Run single epoch forward and backward path with update of parameters on the model.
        Args:
            teacher_forcing_prob: probability of selecting oracle signal
        Returns:
            l2 loss, weighted kld loss (both averaged over the batches of the epoch)
        """
        epoch_l2_loss = 0
        epoch_kld_loss = 0

        for signal_batch in self._train_dataloader:
            with autograd.record():
                l2_loss, kld_loss = self._model.calc_loss(signal_batch, teacher_forcing_prob)
                weighted_kld_loss = self._kld_weight * kld_loss  # weighting the kld loss
                loss = l2_loss + weighted_kld_loss

            loss.backward()
            # The batch size is passed so the trainer can normalise the gradients by it.
            self._trainer.step(signal_batch.shape[0])

            epoch_l2_loss += nd.mean(l2_loss).asscalar()
            # Report the same weighted term that the test pass reports.
            epoch_kld_loss += nd.mean(weighted_kld_loss).asscalar()

        epoch_l2_loss /= len(self._train_dataloader)
        epoch_kld_loss /= len(self._train_dataloader)

        return epoch_l2_loss, epoch_kld_loss

    def _calc_test_accuracy(self, teacher_forcing_prob: float) -> Tuple[float, float]:
        """
        Computes test set single epoch loss (no parameter update is performed).
        Args:
            teacher_forcing_prob: probability of selecting oracle signal
        Returns:
            l2 loss, weighted kld loss (both averaged over the batches of the epoch)
        """
        epoch_l2_loss = 0
        epoch_kld_loss = 0

        for signal_batch in self._test_dataloader:
            l2_loss, kld_loss = self._model.calc_loss(signal_batch, teacher_forcing_prob)
            kld_loss = self._kld_weight * kld_loss

            epoch_l2_loss += nd.mean(l2_loss).asscalar()
            epoch_kld_loss += nd.mean(kld_loss).asscalar()

        epoch_l2_loss /= len(self._test_dataloader)
        epoch_kld_loss /= len(self._test_dataloader)

        return epoch_l2_loss, epoch_kld_loss

### Almost there. Now we can train the model we created earlier

In [None]:
with mx.Context(mx.cpu()):
    # Generate n sin functions (both train and test sets)
    signal_generator = SignalGenerator()
    multi, phase, train_signals = signal_generator.generate_signals(50, 500)
    _, _, test_signals = signal_generator.generate_signals(50, 50)
    signal_generator.plot_signals_with_multipliers(multi, phase, train_signals)

    # Create sin auto encoder model
    sin_bae = SinBayesianAutoEncoder()
    sin_bae.initialize(mx.init.Xavier())

    # Make sure the output directory exists *before* the long training run:
    # save_parameters does not create missing directories, and failing at the
    # very end would throw the trained weights away.
    os.makedirs("model", exist_ok=True)

    # Train the model and save to disk
    # Only the training batches are shuffled; the test set is read in a fixed order.
    train_data_loader = DataLoader(nd.array(train_signals), batch_size=20, shuffle=True)
    test_data_loader = DataLoader(nd.array(test_signals), batch_size=20, shuffle=False)
    sin_bae_trainer = SinBayesianAutoEncoderTrainer(train_data_loader,
                                                   test_data_loader,
                                                   sin_bae)
    test_l2_loss_list, test_kld_list, teacher_forcing_prob_list = sin_bae_trainer.train(n_epoch=1000)
    sin_bae.save_parameters("model/model_train.params")


### Visualize the training loss
Notice how the KLD and L2 losses behave. Do you know what might have caused that spike?

In [None]:
# Plot the per-epoch losses recorded by the trainer during training.
# The total is the element-wise sum of the KLD and L2 curves.
total_loss = list(np.array(test_kld_list) + np.array(test_l2_loss_list))

ax = plt.subplot()
# Order matters: it fixes which default colour each curve receives.
for curve_label, curve_values in (("l2 loss", test_l2_loss_list),
                                  ("kld loss", test_kld_list),
                                  ("total loss", total_loss)):
    ax.plot(np.arange(len(curve_values)), curve_values, label=curve_label)

# Axis titles share one font size.
for set_axis_title, axis_title in ((ax.set_xlabel, "iter"), (ax.set_ylabel, "loss")):
    set_axis_title(axis_title, fontsize=14)
ax.legend()