<a href="https://colab.research.google.com/github/anton-selitskiy/WaveNet/blob/main/WaveNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchaudio

Collecting torchaudio
[?25l  Downloading https://files.pythonhosted.org/packages/aa/55/01ad9244bcd595e39cea5ce30726a7fe02fd963d07daeb136bfe7e23f0a5/torchaudio-0.8.1-cp37-cp37m-manylinux1_x86_64.whl (1.9MB)
[K     |████████████████████████████████| 1.9MB 8.1MB/s 
Installing collected packages: torchaudio
Successfully installed torchaudio-0.8.1


In [None]:
import itertools
import math
import pathlib
import random
from dataclasses import dataclass
from typing import Tuple

import librosa
import pandas as pd
import torch
import torch.nn.functional as F
import torchaudio
from IPython import display
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from torch import distributions
from torch import nn
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import MelSpectrogram
from tqdm import tqdm

Causal Convolution Block

We use nn.ConstantPad1d to add zeros on the left. How do we calculate the padding size? If the kernel size is M and the input length is L, then the output has length L-M+1. To keep the length unchanged, we add M-1 zeros (then L+M-1-M+1 = L).

For example, create a tensor:

```
batch_size = 1
in_channel = 1
time =10
inp = torch.arange(time).reshape(batch_size, in_channel, time).float()
#inp = torch.rand(batch_size, in_channel, time)
inp
```
tensor([[[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]]])

Create a kernel:
```
kernel_size = 2
padding_f = nn.ConstantPad1d((kernel_size-1,0), value=0.0)
conv = nn.Conv1d(1,1,kernel_size,bias=False,dilation=1)
conv.weight.data = torch.ones(1,1,kernel_size)
conv.weight.data
```
tensor([[[1., 1.]]])
```
print(padding_f(inp))
print(conv(padding_f(inp)))
```
tensor([[[0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]]])

tensor([[[ 0.,  1.,  3.,  5.,  7.,  9., 11., 13., 15., 17.]]],
       grad_fn=\<SqueezeBackward1\>)

As a result, each pair of adjacent values was added with weights 1, and the length of the output did not change.

If we want to add a dilation D, then we should add (M-1)*D zeros.

In [None]:
# Demo: left-only zero padding followed by a dilated convolution.
batch_size = 1
in_channel = 1
time =10
inp = torch.arange(time).reshape(batch_size, in_channel, time).float()
# inp = torch.rand(batch_size, in_channel, time)
print(inp)
kernel_size = 3
# NOTE: only (kernel_size - 1) zeros are added on the left, while the
# convolution below uses dilation=2. Its window spans
# (kernel_size - 1) * 2 + 1 = 5 samples, so the printed result has 8 time
# steps instead of 10; padding with (kernel_size - 1) * dilation zeros would
# keep the length unchanged.
padding_f = nn.ConstantPad1d((kernel_size-1,0), value=0.0)
conv = nn.Conv1d(1,1,kernel_size,bias=False,dilation=2)
# All-ones weights make each output the plain sum of the taps it covers.
conv.weight.data = torch.ones(1,1,kernel_size)
print(conv.weight.data)
print(padding_f(inp))
print(conv(padding_f(inp)))

tensor([[[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]]])
tensor([[[1., 1., 1.]]])
tensor([[[0., 0., 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]]])
tensor([[[ 2.,  4.,  6.,  9., 12., 15., 18., 21.]]],
       grad_fn=<SqueezeBackward1>)


In [None]:
class CausalConv1d(nn.Conv1d):
    """
    Causal Conv1d.

    A regular ``nn.Conv1d`` whose input is zero-padded on the left only, so
    the output has the same length as the input and the value at step ``t``
    is computed from inputs at steps ``<= t``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int = 1,
        bias: bool = True
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
        )

        # A dilated kernel reaches (kernel_size - 1) * dilation steps into the
        # past; padding by exactly that amount preserves the sequence length.
        left_context = dilation * (kernel_size - 1)
        self.zero_padding = nn.ConstantPad1d(padding=(left_context, 0), value=0.0)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Left-pad ``input`` with zeros, then apply the convolution."""
        return super().forward(self.zero_padding(input))

### Gated Activation Unit
$tanh(W_f*x )\cdot \sigma(W_g*x)$

In [None]:
class GatedConv1d(nn.Module):
    """
    Gated activation unit.

    Two causal convolutions see the same input: the "filter" branch goes
    through tanh, the "gate" branch through sigmoid, and the two results are
    multiplied element-wise.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        dilation: int
    ):
        super().__init__()

        # Both branches share the same convolution geometry.
        conv_args = (in_channels, out_channels, kernel_size, dilation)
        self.filter_conv = CausalConv1d(*conv_args)
        self.gate_conv = CausalConv1d(*conv_args)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``tanh(filter_conv(input)) * sigmoid(gate_conv(input))``."""
        signal = torch.tanh(self.filter_conv(input))
        gate = torch.sigmoid(self.gate_conv(input))
        return signal * gate

Conditioned GAU $\tanh(W_f*x + V_f*c)\cdot \sigma(W_g*x+V_g*c)$, where $c$ is the conditioning signal.

In [None]:
class CondGatedConv1d(GatedConv1d):
    """
    Conditioned gated activation unit.

    Extends ``GatedConv1d`` with a 1x1 convolution over a conditioning
    signal; its output is split in half along the channel axis and added to
    the filter and gate pre-activations respectively.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        cond_in_channels: int,
        kernel_size: int,
        dilation: int
    ):
        super().__init__(in_channels, out_channels, kernel_size, dilation)

        # One convolution with twice the channels replaces two separate
        # condition projections (one for the filter, one for the gate).
        self.cond_conv = nn.Conv1d(cond_in_channels, 2 * out_channels, kernel_size=1)

    def forward(
        self,
        input: torch.Tensor,
        condition: torch.Tensor
    ) -> torch.Tensor:
        """
        :param input: tensor fed to the filter/gate causal convolutions
        :param condition: conditioning tensor; its last dimension must have
            the same size as the last dimension of ``input``
        """
        assert input.size(-1) == condition.size(-1)

        filter_pre = self.filter_conv(input)
        gate_pre = self.gate_conv(input)

        # First half of the channels biases the filter, second half the gate.
        cond_filter, cond_gate = self.cond_conv(condition).chunk(2, dim=1)

        return torch.tanh(filter_pre + cond_filter) * torch.sigmoid(gate_pre + cond_gate)

Example of using chunk
```
torch.chunk(inp, 2, dim=-1)
```
(tensor([[[0., 1., 2., 3., 4.]]]), tensor([[[5., 6., 7., 8., 9.]]]))

In [None]:
# Split inp into two chunks along its last (time) dimension.
torch.chunk(inp, 2, dim=-1)

(tensor([[[0., 1., 2., 3., 4.]]]), tensor([[[5., 6., 7., 8., 9.]]]))

### Residual Block

In [None]:
class CondWaveNetBlock(nn.Module):
    """
    Conditioned WaveNet block.

    Applies a conditioned gated convolution and projects its output twice
    with 1x1 convolutions: once back to the input channel count (added to the
    input as a residual connection) and once to the skip channel count.
    """

    def __init__(
        self,
        gated_in_channels: int,
        gated_out_channels: int,
        cond_in_channels: int,
        skip_out_channels: int,
        kernel_size: int,
        dilation: int
    ):
        super().__init__()

        self.gated_cond = CondGatedConv1d(
            in_channels=gated_in_channels,
            out_channels=gated_out_channels,
            cond_in_channels=cond_in_channels,
            kernel_size=kernel_size,
            dilation=dilation
        )

        self.skip_conv = nn.Conv1d(gated_out_channels, skip_out_channels, kernel_size=1)
        # Projects back to gated_in_channels so the result can be added to the input.
        self.residual_conv = nn.Conv1d(gated_out_channels, gated_in_channels, kernel_size=1)

    def forward(
        self,
        input: torch.Tensor,
        condition: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param input: block input with ``gated_in_channels`` channels
        :param condition: conditioning tensor passed to the gated convolution
        :return: ``(residual_output, skip_output)`` -- the input plus its
            residual projection, and the skip-connection projection
        """
        gated_output = self.gated_cond(input, condition)

        # y = f(x) + x
        residual_output = self.residual_conv(gated_output) + input
        skip_output = self.skip_conv(gated_output)

        return residual_output, skip_output

### Reduce quantisation size from $2^{16}$ to $2^8$
$f(x) = sign(x) \dfrac{\ln(1+\mu |x|)}{\ln(1+\mu)}$

In [None]:
class MuLaw(nn.Module):
    """
    Mu-law companding with quantization to integer codes.

    ``mu`` is the number of quantization levels; the registered buffer ``mu``
    holds ``mu - 1``, i.e. the largest code.
    """

    def __init__(self, mu: float = 256):
        super().__init__()
        self.register_buffer('mu', torch.tensor([mu - 1], dtype=torch.float32))

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Alias for :meth:`encode`."""
        return self.encode(input)

    def encode(self, input: torch.Tensor) -> torch.Tensor:
        """Compress values from [-1, 1] into integer (long) codes."""
        # Keep the signal strictly inside (-1, 1) before companding.
        clipped = torch.clamp(input, -1 + 1e-5, 1 - 1e-5)

        magnitude = torch.log1p(self.mu * torch.abs(clipped))
        companded = torch.sign(clipped) * magnitude / torch.log1p(self.mu)

        # Shift [-1, 1] to [0, 1], scale by the largest code and round to the
        # nearest integer.
        unit_interval = (companded + 1) / 2
        return torch.floor(unit_interval * self.mu + 0.5).long()

    def decode(self, input: torch.Tensor) -> torch.Tensor:
        """Map integer codes back to float values via the inverse mu-law."""
        # Codes -> [-1, 1]
        companded = (input.float() / self.mu) * 2 - 1
        expanded = (1 + self.mu) ** torch.abs(companded) - 1
        return (torch.sign(companded) / self.mu) * expanded

```
mu_law_encoder = MuLaw(256)
input = torch.randn(5).mul(0.1).clamp(-1, 1)
print(f'Input: {input}')
print(f'After MuLaw Encoding: {mu_law_encoder(input)}')
print(f'After MuLaw Decoding: {mu_law_encoder.decode(mu_law_encoder(input))}')

Input: tensor([ 0.1020, -0.0206, -0.0404,  0.0036,  0.2439])

After MuLaw Encoding: tensor([203,  85,  72, 142, 223])

After MuLaw Decoding: tensor([ 0.1007, -0.0210, -0.0399,  0.0034,  0.2457])
```

In [None]:
class OneHot(nn.Module):
    """
    Turn quantized samples of shape [B, 1, T] into one-hot tensors of shape
    [B, n_class, T].
    """

    def __init__(self, n_class: int = 256):
        super().__init__()

        self.n_class = n_class

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Validate that ``input`` is 3-dimensional, then one-hot encode it."""
        assert input.dim() == 3, "Expected shape of input is [B, C, T], where C == 1"
        return self.encode(input)

    def encode(self, input: torch.Tensor) -> torch.Tensor:
        """Write a 1 at each class index given by ``input`` along dim 1."""
        batch, steps = input.size(0), input.size(-1)
        canvas = torch.zeros(batch, self.n_class, steps, device=input.device)
        return canvas.scatter_(1, input, 1)

    def decode(self, input: torch.Tensor) -> torch.Tensor:
        """Not implemented."""
        raise NotImplementedError()

```
upsampling = nn.Upsample(scale_factor=3, mode='nearest')
input = torch.arange(5).view(1, 1, -1).float()
print(f'Input: {input.squeeze()}')
print(f'After Upsampling: {upsampling(input).squeeze()}')

Input: tensor([0., 1., 2., 3., 4.])
After Upsampling: tensor([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.])
```

In [None]:
class CondNet(nn.Module):
    """
    Condition processing network (mel from TTS or something else).

    A 2-layer bidirectional GRU followed by nearest-neighbour upsampling
    along the time axis by a factor of ``hop_size``.
    """

    def __init__(self, input_size: int, hidden_size: int, hop_size: int):
        """
        :param input_size: size of the last dimension of the input
        :param hidden_size: total output width; each GRU direction gets
            ``hidden_size // 2`` units
        :param hop_size: upsampling factor along the time axis
        """
        super().__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hop_size = hop_size

        per_direction = hidden_size // 2
        self.net = nn.GRU(
            input_size=input_size,
            hidden_size=per_direction,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        self.upsampler = nn.Upsample(scale_factor=hop_size, mode='nearest')

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        :param input: batch-first sequence whose last dimension equals
            ``input_size``
        :return: GRU features with the last two dimensions swapped
            (channels before time), upsampled along time by ``hop_size``
        """
        assert input.shape[-1] == self.input_size

        self.net.flatten_parameters()

        features, _ = self.net(input)

        # nn.Upsample stretches the last dimension, so time has to come last.
        channels_first = features.transpose(-1, -2)
        return self.upsampler(channels_first)

WaveNet

In [None]:
class WaveNet(nn.Module):
    """
    Conditioned WaveNet.

    A 1x1 "stem" convolution feeds a stack of ``CondWaveNetBlock`` modules
    whose dilations cycle through ``1, 2, ..., 2 ** (dilation_depth - 1)``
    ``dilation_cycles`` times. The skip outputs of all blocks are summed and
    mapped to ``out_channels`` logits by the head. The condition is processed
    by ``CondNet`` (built for 80 input channels) and upsampled by
    ``upsample_factor``.
    """

    def __init__(
        self,
        # Reduced-width defaults; the wider alternative is 256 for every
        # channel count with gate_channels=512.
        in_channels: int = 64,
        out_channels: int = 64,
        gate_channels: int = 64,
        residual_channels: int = 64,
        skip_channels: int = 64,
        head_channels: int = 64,
        condition_channels: int = 64,
        kernel_size: int = 2,
        dilation_cycles: int = 3,
        dilation_depth: int = 10,
        upsample_factor: int = 480,
    ) -> None:
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.gate_channels = gate_channels
        self.residual_channels = residual_channels
        self.skip_channels = skip_channels
        self.head_channels = head_channels
        self.condition_channels = condition_channels
        self.kernel_size = kernel_size
        self.dilation_cycles = dilation_cycles
        self.dilation_depth = dilation_depth
        self.upsample_factor = upsample_factor

        # 80 -- number of channels in mels
        self.cond = CondNet(80, self.condition_channels, upsample_factor)

        self.stem = nn.Sequential(
            nn.Conv1d(in_channels, residual_channels, kernel_size=1)
        )

        self.blocks = nn.ModuleList([
            CondWaveNetBlock(residual_channels, gate_channels, condition_channels, skip_channels,
                             kernel_size, 2 ** (i % dilation_depth))
            for i in range(dilation_cycles * dilation_depth)
        ])

        # To avoid DDP error: the residual output of the last block is never
        # consumed in _forward, so its projection is frozen.
        self.blocks[-1].residual_conv.requires_grad_(False)

        self.head = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv1d(skip_channels, head_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(head_channels, out_channels, kernel_size=1),
        )

    def _forward(self, input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """
        Run stem, blocks and head on an already upsampled condition.

        :param input: tensor with ``in_channels`` channels
        :param condition: output of ``self.cond`` (already upsampled)
        :return: logits with ``out_channels`` channels
        """
        residual_output = self.stem(input)

        accumulation = 0
        for block in self.blocks:
            residual_output, skip_output = block(residual_output, condition)
            accumulation = accumulation + skip_output

        return self.head(accumulation)

    def forward(self, input: torch.Tensor, condition: torch.Tensor) -> torch.Tensor:
        """
        :param input: samples
        :param condition: mel
        :return: logits with ``out_channels`` channels
        """
        condition = self.cond(condition)
        return self._forward(input, condition)

    @property
    def num_parameters(self) -> int:
        """Total number of elements over all parameters."""
        return sum(p.numel() for p in self.parameters())

    @property
    def receptive_field(self) -> int:
        """Number of input steps that influence a single output step."""
        total_dilation = sum(
            2 ** (i % self.dilation_depth)
            for i in range(self.dilation_cycles * self.dilation_depth)
        )
        return (self.kernel_size - 1) * total_dilation + 1

    def generate(self, condition: torch.Tensor, inference_type: str = "naive", verbose: bool = True) -> torch.Tensor:
        """
        Generate a waveform for ``condition``.

        :param condition: [1, T, 80] mel frames; ``T * upsample_factor``
            samples are generated
        :param inference_type: ``"naive"`` or ``"fast"``
        :param verbose: show a progress bar (naive inference only)
        :return: 1-D tensor of mu-law decoded samples
        :raises ValueError: if ``inference_type`` is not recognised
        :raises NotImplementedError: for ``"fast"`` inference
        """
        # Samples are drawn from out_channels logits, so the decoder needs the
        # same number of levels; MuLaw's default (256) only matched the
        # 256-channel configuration.
        # NOTE(review): assumes training targets were quantized with
        # MuLaw(out_channels) -- confirm against the training code.
        mu_law = MuLaw(self.out_channels).to(condition.device)

        if inference_type == "naive":
            compressed_samples = self._naive_generate(condition, verbose)
        elif inference_type == "fast":
            compressed_samples = self._fast_generate(condition)
        else:
            raise ValueError(f"Invalid type of inference: {inference_type}")

        return mu_law.decode(compressed_samples)

    def _fast_generate(self, condition: torch.Tensor) -> torch.Tensor:
        """Placeholder so that "fast" inference fails with a clear message."""
        raise NotImplementedError("Fast inference is not implemented; use inference_type='naive'")

    @torch.no_grad()
    def _naive_generate(self, condition: torch.Tensor, verbose: bool) -> torch.Tensor:
        """
        Sample one step at a time, re-running the network on the last
        ``receptive_field`` samples for every new sample.

        :param condition: [1, T, 80] mel frames
        :param verbose: wrap the loop in a tqdm progress bar
        :return: 1-D tensor with ``T * upsample_factor`` quantized samples
        """
        # The stem convolution expects in_channels input channels, so the
        # one-hot width must match it (OneHot's default is 256).
        one_hot = OneHot(self.in_channels)

        # The property recomputes a sum on every access; read it once.
        receptive_field = self.receptive_field

        required_num_samples = condition.shape[1] * self.upsample_factor
        # Warm-up context is filled with the middle quantization level.
        generated_samples = torch.full(
            (1, 1, receptive_field + required_num_samples),
            float(self.in_channels // 2),
            device=condition.device,
        )

        condition = self.cond(condition)
        # Extend the condition to the left so it lines up with the warm-up samples.
        condition = F.pad(condition, (receptive_field, 0), mode='replicate')

        iterator = range(required_num_samples)
        if verbose:
            iterator = tqdm(iterator)

        for i in iterator:
            window = slice(i, i + receptive_field)
            current_condition = condition[:, :, window]
            current_one_hot_samples = one_hot(generated_samples[:, :, window].long())

            current_output = self._forward(current_one_hot_samples, current_condition)
            # Batch size is 1: take the logits of the last time step.
            last_logits = current_output[0, :, -1]

            # sampling new sample
            new_sample = distributions.Categorical(logits=last_logits).sample()
            generated_samples[0, 0, i + receptive_field] = new_sample

        # Drop the warm-up context; slicing from the front also returns an
        # empty tensor (not the whole buffer) when no samples were requested.
        return generated_samples[0, 0, receptive_field:]

NameError: ignored

In [None]:
# Build a WaveNet using the default hyperparameters of its constructor.
model = WaveNet()