# Seminar for Lecture 13 "VAE Vocoder"


In the lectures, we studied various approaches to creating vocoders. The problem of sound generation is solved by deep generative models. We've discussed autoregressive models that can be reduced to **MAF**. We've considered the inverse counterpart of MAF – **IAF**. We've seen how **normalizing flows** can help us directly optimize likelihood without using autoregression. And we've also considered a vocoder built with the **GAN** paradigm.

In this seminar we will try to apply another popular generative model: the **variational autoencoder (VAE)**. We will try to build an encoder-decoder architecture with **MAF** as the encoder and **IAF** as the decoder. We will train this network by maximizing the ELBO with a couple of additional losses (in vocoders, you can't do without them yet 🤷‍♂️).

⚠️ In this seminar, **"MAF"** does not mean the generative model discussed in the lecture, but a network whose architecture resembles MAF and which takes audio as input. So we will not model the data distribution with our **"MAF"**.

In [None]:
# ! pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2  -f https://download.pytorch.org/whl/torch_stable.html
# ! pip install numpy==1.17.5 matplotlib==3.3.3 tqdm==4.54.0

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from typing import Union
from math import log, pi, sqrt
from IPython.display import display, Audio
import numpy as np

import os
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

device = torch.device("cpu")
if torch.cuda.is_available():
    print('GPU found! 🎉')
    device = torch.device("cuda")

Let's introduce two auxiliary modules:
1. causal convolution – a plain convolution with `kernel_size` and `dilation` hyper-parameters that works in a causal way (it does not look into the future)
2. residual block – the main building block of the WaveNet architecture

Yes, WaveNet is everywhere. We can build MAF and IAF with any architecture, but WaveNet has proven itself as a simple yet powerful one. We will use WaveNet conditioned on mel spectrograms, because we are building a vocoder.
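To make "does not look into the future" concrete, here is a minimal plain-Python sketch of a dilated causal convolution on a list. The function name `causal_conv1d` and the toy weights are hypothetical and only serve the illustration; the notebook itself uses `nn.Conv1d` with trimming. It pads only on the left, so changing later samples cannot change earlier outputs.

```python
def causal_conv1d(x, weights, dilation=1):
    """Dilated causal convolution of a 1-D sequence (illustration only).

    y[t] = sum_i weights[i] * x[t - (K - 1 - i) * dilation],
    where samples before the start of the sequence count as zero.
    """
    k = len(weights)
    pad = dilation * (k - 1)          # the same amount CausalConv pads
    padded = [0.0] * pad + list(x)    # pad on the left only
    return [
        sum(w * padded[t + i * dilation] for i, w in enumerate(weights))
        for t in range(len(x))
    ]


x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
w = [0.5, -1.0]                       # hypothetical kernel of size 2
y = causal_conv1d(x, w, dilation=2)

# Overwrite the "future" (samples 3..5) and convolve again.
x_future = x[:3] + [100.0, 200.0, 300.0]
y_future = causal_conv1d(x_future, w, dilation=2)

# The first three outputs only depend on the first three inputs.
print(y[:3] == y_future[:3])
```

The output has the same length as the input, which is what the trimming in `CausalConv.forward` achieves after symmetric padding.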

In [None]:
class CausalConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1):
        super(CausalConv, self).__init__()

        self.padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation)
        self.conv = nn.utils.weight_norm(self.conv)
        nn.init.kaiming_normal_(self.conv.weight)

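    def forward(self, x):
        x = self.conv(x)
        # drop the right-hand padding so the output never sees future samples;
        # the guard avoids an empty slice when kernel_size == 1 (padding == 0)
        if self.padding > 0:
            x = x[:, :, :-self.padding]
        return x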


class ResBlock(nn.Module):
    def __init__(self, in_channels, out_channels, skip_channels, kernel_size, dilation, cin_channels):
        super(ResBlock, self).__init__()
        self.cin_channels = cin_channels

        self.filter_conv = CausalConv(in_channels, out_channels, kernel_size, dilation)
        self.gate_conv = CausalConv(in_channels, out_channels, kernel_size, dilation)
        self.res_conv = nn.Conv1d(out_channels, in_channels, kernel_size=1)
        self.skip_conv = nn.Conv1d(out_channels, skip_channels, kernel_size=1)
        self.res_conv = nn.utils.weight_norm(self.res_conv)
        self.skip_conv = nn.utils.weight_norm(self.skip_conv)
        nn.init.kaiming_normal_(self.res_conv.weight)
        nn.init.kaiming_normal_(self.skip_conv.weight)

        self.filter_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
        self.gate_conv_c = nn.Conv1d(cin_channels, out_channels, kernel_size=1)
        self.filter_conv_c = nn.utils.weight_norm(self.filter_conv_c)
        self.gate_conv_c = nn.utils.weight_norm(self.gate_conv_c)
        nn.init.kaiming_normal_(self.filter_conv_c.weight)
        nn.init.kaiming_normal_(self.gate_conv_c.weight)

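    def forward(self, x, c=None):
        h_filter = self.filter_conv(x)
        h_gate = self.gate_conv(x)
        # local conditioning is optional, so only add it when it is provided
        if c is not None:
            h_filter = h_filter + self.filter_conv_c(c)
            h_gate = h_gate + self.gate_conv_c(c)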
        out = torch.tanh(h_filter) * torch.sigmoid(h_gate)
        res = self.res_conv(out)
        skip = self.skip_conv(out)
        return (x + res) * sqrt(0.5), skip

For WaveNet it doesn't matter whether it is used as MAF or IAF: it all depends on our interpretation of the input and output variables.

Below is the WaveNet architecture that you are already familiar with from the last seminar. But this time you need to implement the forward pass rather than inference, and it's very simple 😉.

In [None]:
class WaveNet(nn.Module):
    def __init__(self, params):
        super(WaveNet, self).__init__()

        self.front_conv = nn.Sequential(
            CausalConv(1, params.residual_channels, params.front_kernel_size),
            nn.ReLU())

        self.res_blocks = nn.ModuleList()
        for b in range(params.num_blocks):
            for n in range(params.num_layers):
                self.res_blocks.append(ResBlock(
                    in_channels=params.residual_channels,
                    out_channels=params.gate_channels,
                    skip_channels=params.skip_channels,
                    kernel_size=params.kernel_size,
                    dilation=2 ** n,
                    cin_channels=params.mel_channels))

        self.final_conv = nn.Sequential(
            nn.ReLU(),
            nn.Conv1d(params.skip_channels, params.skip_channels, kernel_size=1),
            nn.ReLU(),
            nn.Conv1d(params.skip_channels, params.out_channels, kernel_size=1))

    def forward(self, x, c):
        # x: input tensor with signal or noise [B, 1, T]
        # c: local conditioning [B, C_mel, T]
        out = 0
        ################################################################################
        # YOUR CODE HERE
        ################################################################################
        return out

In [None]:
# check that it works and gives the expected output size
# we will check full correctness later, when the whole network is assembled

class Params:
    mel_channels: int = 80
    num_blocks: int = 4
    num_layers: int = 6
    out_channels: int = 3
    front_kernel_size: int = 2
    residual_channels: int = 64
    gate_channels: int = 64
    skip_channels: int = 128
    kernel_size: int = 2
        
net = WaveNet(Params()).to(device).eval()
with torch.no_grad():
    z = torch.FloatTensor(5, 1, 4096).normal_().to(device)
    c = torch.FloatTensor(5, 80, 4096).zero_().to(device)
    assert list(net(z, c).size()) == [5, 3, 4096]

Excellent 👍! Now we are ready to get started on more complex and interesting things.

Do you remember our talks about vocoders built on IAF (Parallel WaveNet or ClariNet Vocoder)? We casually mentioned that in an IAF we use not just one WaveNet (predicting mu and sigma), but a stack of WaveNets. Let's implement this stack, but first, a few formulas that will help you.

Consider transformations of random variable $z^{(0)} \sim \mathcal{N}(0, I)$: 
$$z^{(0)} \rightarrow z^{(1)} \rightarrow \dots \rightarrow z^{(n)}.$$

Each transformation has the form: 
$$ z^{(k)} = f^{(k)}(z^{(k-1)}) = z^{(k-1)} \cdot \sigma^{(k)} + \mu^{(k)},$$ 
where $\mu^{(k)}_t = \mu(z_{<t}^{(k-1)}; \theta_k)$ and $\sigma^{(k)}_t = \sigma(z_{<t}^{(k-1)}; \theta_k)$ are the shift and scale variables modeled by a Gaussian WaveNet.

It is easy to deduce that the whole transformation $f^{(n)} \circ \dots \circ f^{(2)} \circ f^{(1)}$ can be represented as $f^{(\mathrm{total})}(z) = z \cdot \sigma^{(\mathrm{total})} + \mu^{(\mathrm{total})}$, where
$$\sigma^{(\mathrm{total})} = \prod_{k=1}^n \sigma^{(k)}, ~ ~ ~ \mu^{(\mathrm{total})} = \sum_{k=1}^n \mu^{(k)} \prod_{j > k}^n \sigma^{(j)} $$

We will need $\mu^{(\mathrm{total})}$ and $\sigma^{(\mathrm{total})}$ later to estimate $p(\hat x | z)$.
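The composition formula can be checked numerically for a single time step. The sketch below uses fixed, hypothetical values of $\mu^{(k)}$ and $\log \sigma^{(k)}$ (in the real model they are predicted by WaveNets from the previous $z$), accumulates the totals step by step, and compares them with the closed form. The names `compose_flows` and `closed_form` are illustrative, not part of the notebook.

```python
import math


def compose_flows(z, mus, log_sigmas):
    """Apply z -> z * sigma_k + mu_k for k = 1..n and track the totals."""
    mu_tot, logs_tot = 0.0, 0.0
    for mu, logs in zip(mus, log_sigmas):
        sigma = math.exp(logs)
        z = z * sigma + mu
        mu_tot = mu_tot * sigma + mu      # earlier shifts get rescaled too
        logs_tot = logs_tot + logs        # log-scales simply add up
    return z, mu_tot, logs_tot


def closed_form(mus, log_sigmas):
    """mu_total = sum_k mu_k * prod_{j>k} sigma_j, sigma_total = prod_k sigma_k."""
    n = len(mus)
    sigma_tot = math.exp(sum(log_sigmas))
    mu_tot = sum(mus[k] * math.exp(sum(log_sigmas[k + 1:])) for k in range(n))
    return mu_tot, sigma_tot


# Hypothetical per-flow parameters for one time step.
z0 = 0.3
mus = [0.1, -0.2, 0.5]
log_sigmas = [0.0, math.log(2.0), math.log(0.5)]

z_out, mu_tot, logs_tot = compose_flows(z0, mus, log_sigmas)
mu_cf, sigma_cf = closed_form(mus, log_sigmas)

print(z_out, z0 * sigma_cf + mu_cf)       # both routes give the same sample
```

Accumulating $\log \sigma$ by addition instead of multiplying $\sigma$ keeps the running total in the log domain, which matches the notebook's choice of modelling $\log \sigma$.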

You need to **implement** `forward` method of `WaveNetFlows` model.

📝 Notes: 
1. WaveNet outputs a tensor `output` of size `[B, 2, T]`, where `output[:, 0, :]` is $\mu$ and `output[:, 1, :]` is $\log \sigma$. We model $\log \sigma$ instead of $\sigma$ for stable gradients.
2. Since $\mu(z_{<t}^{(k-1)}; \theta_k)$ and $\sigma(z_{<t}^{(k-1)}; \theta_k)$ depend only on previous samples, the outputs we actually use have length `T - 1`. To keep the length `T` of the modelled noise variable constant, we need to pad it on the left side (with zero).
3. $\mu^{(\mathrm{total})}$ and $\sigma^{(\mathrm{total})}$ will have length `T - 1`, because we do not pad the distribution parameters.
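One possible reading of notes 2 and 3, sketched on plain Python lists instead of tensors. The indexing convention (parameters at position `t` act on sample `t + 1`) and the helper name `shifted_affine_step` are assumptions made for this illustration; treat it as a hint about shapes, not as the reference solution.

```python
import math


def shifted_affine_step(z, mu, log_sigma):
    """One flow step on a length-T sequence.

    mu[t] and log_sigma[t] are assumed to come from a causal network that has
    seen z[0..t], so they parameterise sample t + 1. The last prediction has
    no sample to act on and is dropped.
    """
    mu_used = mu[:-1]                      # length T - 1
    logs_used = log_sigma[:-1]             # length T - 1
    transformed = [
        z[t + 1] * math.exp(logs_used[t]) + mu_used[t]
        for t in range(len(z) - 1)
    ]
    z_new = [0.0] + transformed            # left-pad with zero back to length T
    return z_new, mu_used, logs_used


z = [1.0, 2.0, 3.0, 4.0]
mu = [0.1, 0.2, 0.3, 0.4]                  # hypothetical network outputs
log_sigma = [0.0, 0.0, 0.0, 0.0]

z_new, mu_used, logs_used = shifted_affine_step(z, mu, log_sigma)
print(len(z_new), len(mu_used), len(logs_used))
```

The sample keeps length `T` while the distribution parameters have length `T - 1`, which is the shape contract checked by the assertions in the next code cell.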

In [None]:
class WaveNetFlows(nn.Module):
    def __init__(self, params):
        super(WaveNetFlows, self).__init__()

        self.iafs = nn.ModuleList()
        for i in range(params.num_flows):
            self.iafs.append(WaveNet(params))

    def forward(self, z, c):
        # z: random sample from the standard normal distribution [B, 1, T]
        # c: local conditioning for WaveNet [B, C_mel, T]
        mu_tot, logs_tot = 0., 0.
        ################################################################################
        # YOUR CODE HERE
        ################################################################################
        return z, mu_tot, logs_tot

In [None]:
class Params:
    num_flows: int = 4
    mel_channels: int = 80
    num_blocks: int = 1
    num_layers: int = 5
    out_channels: int = 2
    front_kernel_size: int = 2
    residual_channels: int = 64
    gate_channels: int = 64
    skip_channels: int = 64
    kernel_size: int = 3
        
net = WaveNetFlows(Params()).to(device)

with torch.no_grad():
    z = torch.FloatTensor(3, 1, 4096).normal_().to(device)
    c = torch.FloatTensor(3, 80, 4096).zero_().to(device)
    z_hat, mu, log_sigma = net(z, c)
    assert list(z_hat.size()) == [3, 1, 4096]         # same length as input
    assert list(mu.size()) == [3, 1, 4096 - 1]        # shorter by one sample
    assert list(log_sigma.size()) == [3, 1, 4096 - 1] # shorter by one sample

If you are not familiar with the VAE framework, please try to figure it out first. For example, take a look at this [blog post](https://wiseodd.github.io/techblog/2016/12/10/variational-autoencoder/).


In short, a VAE is a modification of the autoencoder, which consists of an encoder and a decoder. A VAE lets you sample from the data distribution $p(x)$ through its decoder $p(x | z)$, where the prior $p(z)$ is simple and known, e.g. $\mathcal{N}(0, I)$. The interesting part is that the model cannot be trained by direct Maximum Likelihood Estimation, because the marginal likelihood $p(x) = \int p(x | z) p(z) dz$ is intractable.

But we can maximize Evidence Lower Bound (ELBO) which has a form:

$$\max_{\phi, \theta} \mathbb{E}_{q_{\phi}(z | x)} \log p_{\theta}(x | z) - \mathbb{D}_{KL}(q_{\phi}(z | x) || p(z))$$

where $p_{\theta}(x | z)$ is the VAE decoder and $q_{\phi}(z | x)$ is the VAE encoder. For more details, please read the blog post mentioned above or any other material on this topic.

In our case $q_{\phi}(z | x)$ is represented by the MAF WaveNet, and $p_{\theta}(x | z)$ by an IAF built from a stack of WaveNets. To be more precise, our decoder $p_{\theta}(x | z)$ is parameterized by the **one-step-ahead prediction** from an IAF.

🧑‍💻 **let's practice..**

We will start with the easy part: generation (or sampling).

**Implement** the `generate` method, which accepts a mel spectrogram as the conditioning tensor. Inside this method, a random tensor is sampled from the standard normal distribution $\mathcal{N}(0, I)$. This tensor is then transformed into a tensor from the audio distribution via the `decoder`. In the cell below you will see code for loading the pretrained model and a mel spectrogram. Listen to the result: it should sound passable, but MOS 5.0 is not expected. 😄

In [None]:
class WaveNetVAE(nn.Module):
    def __init__(self, encoder_params, decoder_params):
        super(WaveNetVAE, self).__init__()

        self.encoder = WaveNet(encoder_params)
        self.decoder = WaveNetFlows(decoder_params)
        self.log_eps = nn.Parameter(torch.zeros(1))

        self.upsample_conv = nn.ModuleList()
        for s in [16, 16]:
            conv = nn.ConvTranspose2d(1, 1, (3, 2 * s), padding=(1, s // 2), stride=(1, s))
            conv = nn.utils.weight_norm(conv)
            nn.init.kaiming_normal_(conv.weight)
            self.upsample_conv.append(conv)
            self.upsample_conv.append(nn.LeakyReLU(0.4))

    def forward(self, x, c):
        # x: audio signal [B, 1, T]
        # c: mel spectrogram [B, 80, T / HOP_SIZE]
        loss_rec = 0
        loss_kl = 0
        loss_frame_rec = 0
        loss_frame_prior = 0
        ################################################################################
        # YOUR CODE HERE
        ################################################################################
        alpha = 1e-9  # for annealing during training
        return loss_rec + alpha * loss_kl + loss_frame_rec + loss_frame_prior

    def generate(self, c):
        # c: mel spectrogram [B, 80, L] where L - number of mel frames
        # outputs: audio [B, 1, L * HOP_SIZE]
        ################################################################################
        # YOUR CODE HERE
        ################################################################################
        return x_sample

    def upsample(self, c):
        c = c.unsqueeze(1) # [B, 1, C, L]
        for f in self.upsample_conv:
            c = f(c)
        c = c.squeeze(1) # [B, C, T], where T = L * HOP_SIZE
        return c

In [None]:
# saved checkpoint model has following architecture parameters

class ParamsMAF:
    mel_channels: int = 80
    num_blocks: int = 2
    num_layers: int = 10
    out_channels: int = 2
    front_kernel_size: int = 32
    residual_channels: int = 128
    gate_channels: int = 256
    skip_channels: int = 128
    kernel_size: int = 2


class ParamsIAF:
    num_flows: int = 6
    mel_channels: int = 80
    num_blocks: int = 1
    num_layers: int = 10
    out_channels: int = 2
    front_kernel_size: int = 32
    residual_channels: int = 64
    gate_channels: int = 128
    skip_channels: int = 64
    kernel_size: int = 3
        
# load checkpoint
ckpt_path = 'data/checkpoint.pth'
net = WaveNetVAE(ParamsMAF(), ParamsIAF()).eval().to(device)
ckpt = torch.load(ckpt_path, map_location='cpu')
net.load_state_dict(ckpt['state_dict'])

# load original audio and its mel
x = torch.load('data/x.pth').to(device)
c = torch.load('data/c.pth').to(device)

# generate audio from the mel spectrogram
with torch.no_grad():
    x_prior = net.generate(c.unsqueeze(0)).squeeze()

display(Audio(x_prior.cpu(), rate=22050))

If it sounds plausible, **5 points** 🥉 are already yours 🎉! And here comes the most interesting and difficult part: the loss function implementation. The `forward` method will return the loss. But let's talk more precisely about our architecture and how it was trained.

The encoder of our model $q_{\phi}(z|x)$ is parameterized by a Gaussian autoregressive WaveNet, which maps the audio $x$ into a latent representation $z$ of the same length. Specifically, the Gaussian WaveNet (if we talk about a **real MAF**) models $x_t$ given the previous samples $x_{<t}$ as $x_t \sim \mathcal{N}(\mu(x_{<t}; \phi), \sigma(x_{<t}; \phi))$, where the mean $\mu(x_{<t}; \phi)$ and the log-scale $\log \sigma(x_{<t}; \phi)$ are predicted by the WaveNet.

Our **encoder** posterior is constructed as

$$q_{\phi}(z | x) = \prod_{t} q_{\phi}(z_t | x_{\leq t})$$

where

$$q_{\phi}(z_t | x_{\leq t}) = \mathcal{N}(\frac{x_t - \mu(x_{<t}; \phi)}{\sigma(x_{<t}; \phi)}, \varepsilon)$$

We apply the mean $\mu(x_{<t}; \phi)$ and scale $\sigma(x_{<t}; \phi)$ to "whiten" the posterior distribution. We also introduce a trainable scalar $\varepsilon > 0$ to decouple the global variation, which makes the optimization process easier.

Substitution of our model formulas in $\mathbb{D}_{KL}$ formula gives:

$$\mathbb{D}_{KL}(q_{\phi}(z | x) || p(z)) = \sum_t \left[ \log\frac{1}{\varepsilon} + \frac{1}{2}\left(\varepsilon^2 - 1 + \left(\frac{x_t - \mu(x_{<t})}{\sigma(x_{<t})}\right)^2\right) \right]$$

**Implement** calculation of `loss_kl` in `forward` method as KL divergence.
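As a sanity check of the per-sample term in the KL formula, the sketch below compares the closed form for $\mathbb{D}_{KL}(\mathcal{N}(m, \varepsilon^2) || \mathcal{N}(0, 1))$ with a brute-force numerical integral. Here `m` stands for the whitened value $\frac{x_t - \mu(x_{<t})}{\sigma(x_{<t})}$; the function names and the test values are hypothetical.

```python
import math


def kl_closed_form(m, eps):
    """KL( N(m, eps^2) || N(0, 1) ) for a single time step."""
    return math.log(1.0 / eps) + 0.5 * (eps ** 2 - 1.0 + m ** 2)


def kl_numeric(m, eps, lo=-12.0, hi=12.0, n=20000):
    """Midpoint-rule integral of q(z) * (log q(z) - log p(z))."""
    dz = (hi - lo) / n
    log_norm = -0.5 * math.log(2.0 * math.pi)
    total = 0.0
    for i in range(n):
        z = lo + (i + 0.5) * dz
        log_q = log_norm - math.log(eps) - 0.5 * ((z - m) / eps) ** 2
        log_p = log_norm - 0.5 * z ** 2
        total += math.exp(log_q) * (log_q - log_p) * dz
    return total


m, eps = 0.7, 0.5                      # hypothetical whitened sample and scale
kl_exact = kl_closed_form(m, eps)
kl_approx = kl_numeric(m, eps)
print(kl_exact, kl_approx)             # the two values should agree closely
```

With $m = 0$ and $\varepsilon = 1$ the posterior equals the prior and the closed form gives exactly zero, which is a quick way to test an implementation of `loss_kl`.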

---

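The other term in the ELBO can be interpreted as a reconstruction loss. It can be evaluated by sampling from $p_{\theta}(x | z)$, where $z$ is drawn from $q_{\phi}(z | \hat x)$ and $\hat x$ is our ground truth audio. But sampling is not a differentiable operation! 🤔 We can apply the reparametrization trick: write the sample as a deterministic function of the distribution parameters and parameter-free noise, so that gradients can flow through it.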
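A minimal sketch of the reparametrization idea for one latent sample, in plain Python without autograd. The posterior is $\mathcal{N}(m, \varepsilon^2)$ as in the encoder above; the helper name and the numbers are hypothetical. Once the noise is drawn and held fixed, the sample is an ordinary function of `mean` and `eps`, so its derivatives exist (checked here with finite differences).

```python
import random


def reparameterized_sample(mean, eps, noise):
    """z = mean + eps * noise is N(mean, eps^2)-distributed if noise ~ N(0, 1)."""
    return mean + eps * noise


random.seed(0)
noise = random.gauss(0.0, 1.0)          # the only source of randomness
mean, eps, h = 0.3, 0.5, 1e-6           # hypothetical posterior parameters

z = reparameterized_sample(mean, eps, noise)

# Finite-difference derivatives with the noise held fixed.
dz_deps = (reparameterized_sample(mean, eps + h, noise) - z) / h
dz_dmean = (reparameterized_sample(mean + h, eps, noise) - z) / h

print(dz_deps, noise)                   # dz/d(eps) equals the noise value
print(dz_dmean)                         # dz/d(mean) equals 1
```

In the notebook the same pattern applies element-wise to tensors, and autograd propagates gradients to the encoder parameters and to `log_eps`.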

**Implement** calculation of `loss_rec` in the `forward` method as the reconstruction loss, which is just the negative log-likelihood (we minimize the loss) of the ground truth sample $\hat x$ under the distribution $p_{\theta}(x | \hat z)$ predicted by the IAF, where $\hat z \sim q_{\phi}(z | \hat x)$.

--- 

Vocoders that are not trained with pure MLE still cannot be trained without auxiliary losses. We studied many of them, but the STFT loss is our favourite!

**Implement** calculation of `loss_frame_rec` which stands for MSE loss in STFT domain between original audio and its reconstruction.
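To show what an "MSE loss in the STFT domain" measures, here is a tiny plain-Python sketch with a naive DFT. Comparing magnitudes only, the Hann window, and the frame and hop sizes are all assumptions made for the illustration; in the notebook you would more likely rely on `torch.stft` and choose its parameters yourself.

```python
import cmath
import math


def stft_magnitudes(signal, frame_len=8, hop=4):
    """Magnitude spectrogram via a naive DFT over Hann-windowed frames."""
    window = [0.5 - 0.5 * math.cos(2.0 * math.pi * n / frame_len)
              for n in range(frame_len)]
    frames = []
    for start in range(0, len(signal) - frame_len + 1, hop):
        frame = [signal[start + n] * window[n] for n in range(frame_len)]
        mags = []
        for k in range(frame_len // 2 + 1):        # non-negative frequencies
            coeff = sum(
                frame[n] * cmath.exp(-2j * math.pi * k * n / frame_len)
                for n in range(frame_len)
            )
            mags.append(abs(coeff))
        frames.append(mags)
    return frames


def stft_mse(x, y, frame_len=8, hop=4):
    """Mean squared error between the two magnitude spectrograms."""
    spec_x = stft_magnitudes(x, frame_len, hop)
    spec_y = stft_magnitudes(y, frame_len, hop)
    diffs = [(a - b) ** 2
             for fx, fy in zip(spec_x, spec_y)
             for a, b in zip(fx, fy)]
    return sum(diffs) / len(diffs)


# Two toy signals with different frequencies.
x = [math.sin(2.0 * math.pi * 2 * n / 16) for n in range(16)]
y = [math.sin(2.0 * math.pi * 3 * n / 16) for n in range(16)]

print(stft_mse(x, x))    # identical signals give zero loss
print(stft_mse(x, y))    # different spectra give a positive loss
```

A magnitude-based loss ignores phase, so for example a sign-flipped copy of `x` gets zero loss; that tolerance is one reason such losses are popular as auxiliary terms.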

--- 

We can go even further and calculate STFT loss with random sample from $p_\theta(x | z)$. Conditioning on mel spectrogram allows us to do so.

**Implement** calculation of `loss_frame_prior` which stands for MSE loss in STFT domain between original audio and sample from prior.

In [None]:
net = WaveNetVAE(ParamsMAF(), ParamsIAF()).to(device).train()

x = x[:64 * 256]
c = c[:, :64]

net.zero_grad()
loss = net.forward(x.unsqueeze(0).unsqueeze(0), c.unsqueeze(0))
loss.backward()
print(f"Initial loss: {loss.item():.2f}")

ckpt = torch.load(ckpt_path, map_location='cpu')
net.load_state_dict(ckpt['state_dict'])

net.zero_grad()
loss = net.forward(x.unsqueeze(0).unsqueeze(0), c.unsqueeze(0))
loss.backward()
print(f"Optimized loss: {loss.item():.2f}")

If you implemented the losses correctly and the backward pass works smoothly, **8 more points**🥈 are yours 🎉!

For **2 additional points** 🥇 please write a short essay (in Russian) about your thoughts on vocoders. Try to avoid obvious statements such as "vocoder is very important part of TTS pipeline". We are interested in the insights you've got from studying vocoders.

`YOUR TEXT HERE`