# Homework №5

    This homework is dedicated to Text-to-Speech (TTS), specifically to the neural vocoder.

In [None]:
### COLAB SETUP 
#!wget https://data.keithito.com/data/speech/LJSpeech-1.1.tar.bz2
#!tar -xf LJSpeech-1.1.tar.bz2
#!pip install torchaudio
#!pip install wandb

In [None]:
### KAGGLE SETUP
!pip uninstall -y torch
!pip uninstall -y torchaudio
!pip install torch==1.7.0+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
!pip install wandb

In [None]:
!wandb login 6aa2251ef1ea5e572e6a7608c0152db29bd9a294

In [None]:
import wandb
# Start a W&B run in the 'wavenet-pytorch' project; the wandb.log / wandb.watch
# calls further down in this notebook report to this run.
wandb.init(project='wavenet-pytorch')
# Debug marker: printed once wandb.init has returned.
print(1)

# Data

    In this homework we will use only LJSpeech https://keithito.com/LJ-Speech-Dataset/.

    Use the following `featurizer` (its configuration is more or less standard for this task):

In [None]:
import torch
# Use the GPU when PyTorch reports one as available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # bare expression so the notebook displays the selected device

In [None]:
from IPython import display
from dataclasses import dataclass

import torch
from torch import nn

import torchaudio

import librosa
from matplotlib import pyplot as plt


@dataclass
class MelSpectrogramConfig:
    """Parameters read by `MelSpectrogram` when it builds its torchaudio transform and librosa mel basis."""
    sr: int = 22050  # passed as `sample_rate` to torchaudio and as `sr` to librosa
    win_length: int = 1024  # passed as `win_length` to torchaudio
    hop_length: int = 256  # passed as `hop_length` to torchaudio
    n_fft: int = 1024  # passed as `n_fft` to both torchaudio and librosa
    f_min: int = 0  # passed as `f_min` (torchaudio) / `fmin` (librosa)
    f_max: int = 8000  # passed as `f_max` (torchaudio) / `fmax` (librosa)
    n_mels: int = 80  # passed as `n_mels` to both torchaudio and librosa
    power: float = 1.0  # assigned to the underlying spectrogram's `power` attribute
        
    # value of melspectrograms if we fed a silence into `MelSpectrogram`
    # (numerically log(1e-5), i.e. the log of the floor used by the clamp in `MelSpectrogram.forward`)
    pad_value: float = -11.5129251


class MelSpectrogram(nn.Module):
    """Log-mel featurizer built on top of `torchaudio.transforms.MelSpectrogram`.

    The torchaudio mel filterbank is overwritten with the one produced by
    `librosa.filters.mel`, and the output is clamped from below and
    log-compressed.

    :param config: `MelSpectrogramConfig` with the STFT / mel parameters.
    """

    def __init__(self, config: MelSpectrogramConfig):
        super().__init__()

        self.config = config

        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.sr,
            win_length=config.win_length,
            hop_length=config.hop_length,
            n_fft=config.n_fft,
            f_min=config.f_min,
            f_max=config.f_max,
            n_mels=config.n_mels
        )

        # There is no way to set power in the constructor in the 0.5.0 version,
        # so it is assigned on the inner spectrogram transform afterwards.
        self.mel_spectrogram.spectrogram.power = config.power

        # Default `torchaudio` mel basis uses HTK formula. In order to be compatible with WaveGlow
        # we decided to use Slaney one instead (as well as `librosa` does by default).
        mel_basis = librosa.filters.mel(
            sr=config.sr,
            n_fft=config.n_fft,
            n_mels=config.n_mels,
            fmin=config.f_min,
            fmax=config.f_max
        ).T

        # In-place copy into the existing filterbank tensor. No device move is
        # done here: the previous trailing `.to(device)` discarded its result
        # and tied the class to a notebook-level global. Move the whole module
        # with `.to(...)` after construction instead.
        self.mel_spectrogram.mel_scale.fb.copy_(torch.tensor(mel_basis))

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """
        :param audio: Expected shape is [B, T]
        :return: Shape is [B, n_mels, T']
        """
        # Clamp before the log so that silence maps to log(1e-5) instead of -inf.
        mel = self.mel_spectrogram(audio) \
            .clamp_(min=1e-5) \
            .log_()

        return mel

In [None]:
featurizer = MelSpectrogram(MelSpectrogramConfig()).to(device)
#wav, sr = torchaudio.load('../dla-ht4/LJSpeech-1.1/wavs/LJ001-0001.wav')
#mels = featurizer(wav)

In [None]:
#_, axes = plt.subplots(2, 1, figsize=(15, 7))
#axes[0].plot(wav.squeeze())
#axes[1].imshow(mels.squeeze())

#plt.show()

# Model

    1) In this homework you need to implement the classical version of WaveNet.
        Pay attention to:
            1.1) Causal convolutions. We recommend implementing them via padding.
            1.2) The "Condition Network", which aligns the mel spectrogram with the waveform.

    2) (Bonus) If you have already implemented WaveNet, you can try to implement [Parallel WaveGAN](https://www.dropbox.com/s/bj25vnmkblr9y8v/PWG.pdf?dl=0).
        This model is based on WaveNet and GAN.

    3) (Bonus) Fast generation of WaveNet. https://arxiv.org/abs/1611.09482.
        Don't forget to compare performance.

# Code

    1) In this homework you are allowed to use pytorch-lightning.

    2) Try to write code more structurally and cleanly!

    3) Good logging of experiments saves your nerves and time, so we ask you to use W&B.
       Log the loss and the generated and real wavs (in pairs, i.e. the real wav and the wav generated from the corresponding mel).
       Do not remove the logs until we have checked your work and given you a grade!

    4) We also ask you to organize your code in github repo with (Bonus) Docker and setup.py. You can use my template https://github.com/markovka17/dl-start-pack.

    5) Your work must be reproducible, so fix the seed, save the model weights, etc.

    6) At the end of your work, write inference utils. Anyone should be able to take your weights, load them into the model and run it on some melspec.

# Report

    Finally, you need to write a report in W&B https://www.wandb.com/reports. Add examples of generated mels and audio, and compare them with the ground truth (GT).
    Don't forget to add a link to your report.

### IMPORTS

In [None]:
import pandas as pd
import numpy as np
import random

import torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

In [None]:
BATCH_SIZE = 5

### USEFUL

In [None]:
def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy and PyTorch (CPU and all CUDA devices) and
    configures cuDNN for reproducible runs.

    :param seed: integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; safe to call without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may choose different algorithms from run to run,
    # so it has to be disabled for reproducibility as well.
    torch.backends.cudnn.benchmark = False


set_seed(21)

In [None]:
def count_parameters(model):
    """Return the number of trainable scalar parameters of `model`.

    Only parameters with `requires_grad=True` are counted. The result is a
    plain Python `int` (the previous `np.prod(p.size())` produced a numpy
    scalar, and a float for 0-dimensional parameters).

    :param model: any `torch.nn.Module`.
    :return: total element count over the trainable parameters.
    """
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [None]:
# Same device selection as in the cell near the top of the notebook: GPU if
# available, CPU otherwise. Re-running it lets this section work on its own.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device  # bare expression so the notebook displays the selected device

In [None]:
# works if size % hop_len == 0

def aud_len_from_mel(melspec, win_length=1024, hop_length=256):
    """Return the audio length (in samples) implied by a mel spectrogram.

    Computed as `(number_of_frames - 1) * hop_length`, where the number of
    frames is the size of the last dimension of `melspec`.

    :param melspec: tensor whose last dimension indexes frames.
    :param win_length: accepted for interface compatibility; not used.
    :param hop_length: hop between consecutive frames, in samples.
    :return: integer number of samples.
    """
    n_frames = melspec.size(-1)
    return hop_length * (n_frames - 1)

### DATA

In [None]:
class MelSpecAudioDataset(torch.utils.data.Dataset):
    """Dataset of waveforms listed in an LJSpeech-style metadata file.

    Every item is a dict ``{'audio': waveform}``. The normalized text is kept
    in the ``self.csv`` table but is not part of the returned sample.
    """

    def __init__(self, root='../input/dlaht4dataset/LJSpeech-1.1/', csv_path='metadata.csv', transform=None):
        """
        Args:
            root (string): Directory with all the data. It is concatenated
                as-is with `csv_path` and with 'wavs/', so it must end with a
                path separator.
            csv_path (string): Name of the '|'-separated annotation file inside `root`.
            transform (callable, optional): Optional transform to be applied on the waveform.
        """
        import csv  # local import: only needed for the quoting constant below

        self.root = root
        # The text columns may contain unbalanced double quotes. With pandas'
        # default quoting, a field that starts with '"' can swallow the
        # following separators/lines, so quoting is switched off.
        table = pd.read_csv(root + csv_path, sep='|', header=None, quoting=csv.QUOTE_NONE)
        # Column 1 is dropped; only the file name and column 2 are kept.
        table = table.drop(columns=[1]).rename(columns={0: 'filename', 2: 'norm_text'})
        # drop=True: do not keep the old row labels as an extra 'index' column.
        self.csv = table.dropna().reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Return the number of rows left after dropping incomplete ones."""
        return self.csv.shape[0]

    def __getitem__(self, idx):
        """Load the waveform for row `idx` and apply the optional transform.

        Returns:
            dict: ``{'audio': tensor}``. The tensor is the first element
            returned by `torchaudio.load`, squeezed.
        """
        utt_name = self.root + 'wavs/' + self.csv.loc[idx, 'filename'] + '.wav'
        utt = torchaudio.load(utt_name)[0].squeeze()

        if self.transform is not None:
            utt = self.transform(utt)

        return {'audio': utt}

In [None]:
def tr_transform(wav, len_sample=15104):
    """Return a random contiguous crop of `len_sample` samples from `wav`.

    Every valid start position (0 .. len(wav) - len_sample, inclusive) can be
    drawn. A waveform shorter than `len_sample` is zero-padded on the right
    instead of raising.

    :param wav: waveform tensor; the crop is taken along dimension 0
        (assumed to be 1-D, as produced by the dataset's `.squeeze()`).
    :param len_sample: length of the returned crop, in samples.
    :return: tensor with `len_sample` elements along dimension 0.
    """
    n_samples = wav.size(0)

    if n_samples < len_sample:
        # Too short to crop: pad with zeros so batches keep a fixed length.
        padding = wav.new_zeros(len_sample - n_samples)
        return torch.cat((wav, padding))

    # `high` is exclusive, hence the +1 to make the last valid start reachable.
    start = torch.randint(low=0, high=n_samples - len_sample + 1, size=(1,)).item()

    return wav[start:start + len_sample]

### LOADERS

In [None]:
# Build the dataset from the metadata file; every loaded waveform is passed
# through tr_transform (a random fixed-length crop).
my_dataset = MelSpecAudioDataset(csv_path='metadata.csv', transform=tr_transform)
my_dataset_size = len(my_dataset)
print('all train+val samples:', my_dataset_size)

In [None]:
# 80 % of the samples go to training; the remainder is used for validation.
train_len = int(my_dataset_size * 0.8)
val_len = my_dataset_size - train_len
# Random split; no explicit generator is passed here.
train_set, val_set = torch.utils.data.random_split(my_dataset, [train_len, val_len])

In [None]:
# Training batches: reshuffled every epoch, loaded by one worker process.
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, 
                          shuffle=True,
                          num_workers=1, pin_memory=True)

# Validation batches use the same settings. NOTE(review): shuffle=True is also
# enabled here, so the batch order differs between validation passes.
val_loader = DataLoader(val_set, batch_size=BATCH_SIZE, 
                        shuffle=True, 
                        num_workers=1, pin_memory=True)

In [None]:
def field_size(D, L):
    """Return the sum of `2 ** (i % D)` over `i = 0 .. L - 1`.

    :param D: period after which the exponent wraps back to 0.
    :param L: number of terms to add up.
    :return: integer sum (0 when `L` is 0).
    """
    return sum(2 ** (layer % D) for layer in range(L))

In [None]:
field_size(10, 30)

### REAL ARCHITECTURE

In [None]:
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t depends only on inputs up to t.

    Causality is obtained by zero-padding the input on the left by
    `dilation * (kernel_size - 1)` steps before an unpadded `nn.Conv1d`,
    so the output has the same length as the input.
    """

    def __init__(self, in_ch, out_ch, kernel_size, dilation=1):
        super().__init__()

        # The convolution itself adds no padding; forward() pads the left side only.
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=kernel_size, dilation=dilation, padding=0)
        self.pad_size = dilation * (kernel_size - 1)

    def forward(self, x):
        """Left-pad the last dimension of `x` with zeros and convolve."""
        padded = F.pad(x, (self.pad_size, 0))
        return self.conv(padded)

In [None]:
class WaveNetLayer(nn.Module):
    """One gated, mel-conditioned residual layer of WaveNet.

    :param input_ch: number of channels of the residual (hidden) stream.
    :param skip_ch: number of channels of the skip-connection output.
    :param layer_num: index of the layer; the dilation is
        `2 ** (layer_num % dilation_cycle)`.
    :param mel_ch: number of channels of the conditioning mel spectrogram
        (default 80, the value that used to be hard-coded).
    :param dilation_cycle: number of layers after which the dilation restarts
        from 1 (default 10, the value that used to be hard-coded).
    """

    def __init__(self, input_ch, skip_ch, layer_num, mel_ch=80, dilation_cycle=10):
        super().__init__()

        self.dil_now = 2 ** (layer_num % dilation_cycle)

        # Dilated causal convolutions over the hidden stream (filter / gate).
        self.W_f = CausalConv1d(input_ch, input_ch, kernel_size=2, dilation=self.dil_now)
        self.W_g = CausalConv1d(input_ch, input_ch, kernel_size=2, dilation=self.dil_now)
        # 1x1 projections of the conditioning features (filter / gate).
        self.V_f = nn.Conv1d(mel_ch, input_ch, kernel_size=1)
        self.V_g = nn.Conv1d(mel_ch, input_ch, kernel_size=1)

        self.skip_conv = nn.Conv1d(input_ch, skip_ch, kernel_size=1)
        self.resid_conv = nn.Conv1d(input_ch, input_ch, kernel_size=1)

    def forward(self, melspec, wav):
        """Apply the gated unit and produce the skip and residual outputs.

        :param melspec: conditioning features; must have the same time length
            as `wav` (or a length broadcastable against it) for the additions
            below to work.
        :param wav: hidden stream with `input_ch` channels.
        :return: tuple `(skip, residual)`; `residual` is the 1x1-projected
            gated output plus the input `wav`.
        """
        filter_out = torch.tanh(self.W_f(wav) + self.V_f(melspec))
        gate_out = torch.sigmoid(self.W_g(wav) + self.V_g(melspec))
        z = filter_out * gate_out

        skip_res = self.skip_conv(z)
        resid_res = self.resid_conv(z) + wav

        return skip_res, resid_res

In [None]:
class WaveNet(nn.Module):
    """Mel-conditioned WaveNet that predicts a distribution over `mu` classes.

    :param hidden_ch: channels of the residual stream.
    :param skip_ch: channels of the skip connections.
    :param num_layers: number of `WaveNetLayer` blocks.
    :param mu: number of output classes per time step.
    """

    def __init__(self, hidden_ch, skip_ch, num_layers, mu):
        super().__init__()

        self.skip_ch = skip_ch
        self.mu = mu

        # Lifts the single-channel waveform to `hidden_ch` channels.
        self.embedding = CausalConv1d(1, hidden_ch, kernel_size=512)

        self.layers = nn.ModuleList(
            [WaveNetLayer(hidden_ch, skip_ch, layer_num=i) for i in range(num_layers)]
        )

        self.out_conv = nn.Conv1d(skip_ch, mu, kernel_size=1)
        self.end_conv = nn.Conv1d(mu, mu, kernel_size=1)

    def _upsample_mel(self, melspec):
        """Stretch the mel frames to audio resolution and drop the first step."""
        upsampled = torch.nn.functional.interpolate(melspec, aud_len_from_mel(melspec))
        return upsampled[:, :, 1:]

    def _logits(self, melspec, wav):
        """Run embedding, all layers and the output head.

        :param melspec: conditioning already at the time resolution of `wav`.
        :param wav: waveform input of shape [B, 1, T].
        :return: class scores of shape [B, mu, T].
        """
        hidden = self.embedding(wav)

        skip_total = torch.zeros(
            (hidden.size(0), self.skip_ch, hidden.size(-1)),
            device=hidden.device, dtype=hidden.dtype,
        )
        for layer in self.layers:
            skip_one, hidden = layer(melspec, hidden)
            skip_total = skip_total + skip_one

        return self.end_conv(F.relu(self.out_conv(F.relu(skip_total))))

    def forward(self, melspec, wav):
        """Teacher-forced pass.

        :param melspec: mel spectrogram [B, n_mels, frames]; it is upsampled
            to `aud_len_from_mel(melspec) - 1` steps, which has to match the
            time length of `wav`.
        :param wav: input waveform [B, 1, T].
        :return: class scores [B, mu, T].
        """
        return self._logits(self._upsample_mel(melspec), wav)

    def inference(self, melspec):
        """Generate a waveform sample by sample from `melspec` (greedy argmax).

        Generation starts from a single zero sample. At every step the whole
        network is re-run on everything generated so far, so the cost grows
        quadratically with the output length.

        :param melspec: mel spectrogram [B, n_mels, frames].
        :return: float tensor [B, 1, aud_len_from_mel(melspec) - 1] holding the
            predicted class indices.
        """
        # Local import keeps the class usable before the notebook cell that
        # imports tqdm has been executed.
        from tqdm import tqdm

        new_wav_len = aud_len_from_mel(melspec)
        whole_melspec = self._upsample_mel(melspec)

        # One start sample per batch item (a fixed batch of 1 made torch.cat
        # fail for larger batches).
        wav = torch.zeros(
            (melspec.size(0), 1, 1),
            device=whole_melspec.device, dtype=whole_melspec.dtype,
        )
        for j in tqdm(range(1, new_wav_len)):
            # Step j sees the first j conditioning frames and the j samples
            # available so far (the zero start sample included).
            logits = self._logits(whole_melspec[:, :, :j], wav)
            # Only the prediction for the newest position is kept.
            next_sample = torch.argmax(logits[:, :, -1:], dim=1, keepdim=True)
            # argmax returns int64; cast explicitly instead of relying on
            # torch.cat to promote mixed dtypes.
            wav = torch.cat((wav, next_sample.to(wav.dtype)), dim=-1)

        # Drop the artificial start sample.
        return wav[:, :, 1:]

In [None]:
# 30 layers (WaveNetLayer cycles its dilation every 10 layers); mu=256 output
# classes, the same number as the quantization_channels of the mu-law
# transforms created further down in this notebook.
model = WaveNet(hidden_ch=120, skip_ch=240, num_layers=30, mu=256)
model = model.to(device)

In [None]:
wandb.watch(model)

In [None]:
count_parameters(model)

In [None]:
model

In [None]:
#checkpoint = torch.load('../input/epoch3/epoch_3', map_location=device)
#model.load_state_dict(checkpoint['model_state_dict'])

### TRAINING

In [None]:
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR

In [None]:
# Adam over all model parameters with a constant learning rate; the StepLR
# schedule below is currently disabled.
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
#scheduler = StepLR(opt, step_size=500, gamma=0.7)

In [None]:
# Mu-law companding with 256 quantization channels: the encoder is applied to
# waveforms before they are fed to the model, the decoder is used when
# plotting in the inference section.
mu_law_encoder = torchaudio.transforms.MuLawEncoding(quantization_channels=256).to(device)
mu_law_decoder = torchaudio.transforms.MuLawDecoding(quantization_channels=256).to(device)

In [None]:
NUM_EPOCHS=7

In [None]:
@torch.no_grad()
def validate(model, loader, featurizer, mu_law_encoder):
    """Evaluate `model` on `loader` and log the losses to W&B.

    Logs 'val_item_loss' for every batch and, once at the end, 'val_loss'
    (sum of the batch losses, as before) together with 'val_loss_mean'
    (average over batches, comparable across loader sizes).

    :param model: network called as `model(melspec, wav_input)`; the data is
        moved to the device of its parameters.
    :param loader: iterable of dicts with an 'audio' tensor of shape [B, T].
    :param featurizer: callable turning the audio batch into mel spectrograms.
    :param mu_law_encoder: callable quantizing the audio into class indices.
    :return: mean batch loss, or None when `loader` yields nothing.
    """
    run_device = next(model.parameters()).device
    was_training = model.training
    model.eval()

    total_loss = 0.0
    num_batches = 0
    for el in loader:
        wav = el['audio'].to(run_device)
        melspec = featurizer(wav)
        wav = mu_law_encoder(wav).unsqueeze(1).type(torch.float)

        # Predict sample t+1 from samples up to t: feed all but the last
        # sample and compare against all but the first.
        logits = model(melspec, wav[:, :, :-1])  # [B, classes, T-1]
        ans = wav.type(torch.long)[:, 0, 1:]     # [B, T-1]

        # cross_entropy accepts [B, C, T] scores with [B, T] targets, so no
        # reshape and no hard-coded class count is needed.
        loss = F.cross_entropy(logits, ans)
        wandb.log({'val_item_loss': loss.item()})
        total_loss += loss.item()
        num_batches += 1

    # Restore the mode the caller had set.
    model.train(was_training)

    metrics = {'val_loss': total_loss}
    mean_loss = None
    if num_batches:
        mean_loss = total_loss / num_batches
        metrics['val_loss_mean'] = mean_loss
    wandb.log(metrics)

    return mean_loss

In [None]:
# Main training loop: one checkpoint and one validation pass per epoch.
for i in tqdm(range(NUM_EPOCHS)):
    # Make sure the model is in training mode at the start of every epoch
    # (a no-op when it already is).
    model.train()

    for el in train_loader:
        wav = el['audio'].to(device)
        melspec = featurizer(wav)
        # Quantized waveform as float class indices, shape [B, 1, T].
        wav = mu_law_encoder(wav).unsqueeze(1).type(torch.float)

        opt.zero_grad()

        # Teacher forcing: the model sees samples [0, T-1) ...
        new_wav = model(melspec, wav[:, :, :-1])
        new_wav = new_wav.transpose(-1, -2)  # [B, T-1, classes]

        # ... and is trained to predict samples [1, T).
        ans = wav.type(torch.long)[:, 0, 1:]
        # The class count is taken from the logits instead of a literal 256.
        loss = F.cross_entropy(new_wav.reshape(-1, new_wav.size(-1)), ans.reshape(-1))

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 20)

        opt.step()

        wandb.log({'train_loss': loss.item()})

    # Besides the weights, store the optimizer state and the epoch index so
    # that training can be resumed; 'model_state_dict' keeps its old key.
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'epoch': i},
               'epoch_' + str(i))
    validate(model, val_loader, featurizer, mu_law_encoder)

### INFERENCE

In [None]:
@torch.no_grad()
def inference(model, loader, featurizer, mu_law_encoder):
    for el in loader:
        wav = el['audio'][:, :4096].to(device)
        melspec = featurizer(wav)
        wav = mu_law_encoder(wav).unsqueeze(1).type(torch.float)  # to device?

        new_wav = model.inference(melspec)

        plt.plot(mu_law_decoder(wav.squeeze().detach().cpu()))
        plt.show()
        plt.plot(mu_law_decoder(new_wav.squeeze().detach().cpu()))
        plt.show

        break

In [None]:
#plt.plot(mu_law_decoder(wav.squeeze().detach().cpu()))
#plt.plot(mu_law_decoder(new_wav.squeeze().detach().cpu()))
#plt.show