In [None]:
import torch
import torch.nn as nn

import torchaudio
import torchaudio.functional as taf

import librosa
import numpy as np

from tqdm import tqdm
from IPython.display import Audio

In [None]:
# Download (if needed) and open the LJSpeech dataset in the current directory
dataset = torchaudio.datasets.LJSPEECH(root='.', download=True)

# Character-level text processor from the Tacotron2/WaveRNN LJSpeech bundle;
# calling it returns a pair whose first element is the character-id tensor
text_preprocess = (
    torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH
    .get_text_processor()
)

샘플 데이터셋 확인

In [None]:
# First LJSpeech item: (waveform, sample_rate, transcript, normalized_transcript).
# Index the dataset once instead of reloading the item for each field.
sample_waveform, sample_rate, _, sample_text = dataset[0]

# Convert the normalized transcript into the character-id sequence the model consumes
sample_phoneme, _ = text_preprocess(sample_text)

# Inspect the text and its character ids
print("Text:", sample_text)
print("Character sequence:", sample_phoneme)

# Play the audio at the file's own sample rate. LJSpeech is 22.05 kHz, so a
# hard-coded rate of 24000 would play it back too fast.
Audio(sample_waveform, rate=sample_rate)

## Tacotron dataset

In [None]:
class HParams:
    """Hyperparameters shared by the preprocessing, model and training cells (accessed through the global ``hp``)."""
    # Audio preprocessing
    preemphasis = 0.97  # coefficient for taf.preemphasis before the STFT
    n_fft = 2048  # STFT size; linear spectrograms therefore have n_fft // 2 + 1 = 1025 bins
    window = 'hann'
    frame_length_ms = 50
    frame_shift_ms = 12.5
    # NOTE(review): LJSpeech audio is distributed at 22.05 kHz; the mel filterbank and
    # win/hop lengths below are derived from this value — confirm audio is resampled to it.
    sample_rate = 24000
    ref_level_db = 20  # dB reference subtracted inside the dataset's _normalize
    max_level_db = 100  # dynamic range (dB) mapped onto [0, 1] by _normalize

    # min_level_db = -100
    # max_level_db = 100
    
    n_mels = 80
    win_length = int(round(frame_length_ms * sample_rate / 1000))  # samples per analysis window (1200 at 24 kHz)
    hop_length = int(round(frame_shift_ms * sample_rate / 1000))  # samples between frames (300 at 24 kHz)
    n_phonemes = 70  # character-embedding vocabulary size; must exceed the largest id the text processor emits

    # model
    reduction_factor = 2  # r: mel frames predicted per decoder step
    character_embedding_dim = 256
    encoder_cbhg_k = 16  # conv-bank kernel widths 1..K
    encoder_cbhg_dim = 128
    encoder_conv1d_projection = [128, 128]
    encoder_highway = [128, 128, 128, 128]  # highway widths passed to CBHG (must all be equal)
    encoder_bidirectional_gru = 128  # per direction; encoder output is 2 * 128 = 256
    encoder_prenet = [256, 128]
    decoder_prenet = [256, 128]
    decoder_cbhg_k = 8
    decoder_cbhg_dim = 128
    decoder_conv1d_projection = [256, 80]
    decoder_highway = [128, 128, 128, 128]
    decoder_bidirectional_gru = 128
    
    attention_rnn_dim = 256
    decoder_rnn_dim = 256

    # training
    batch_size = 16 # all sequences are padded to max length
    max_decoding_timestep = 200  # cap on decoder steps when decoding without targets
    learning_rate = 0.001  # NOTE(review): the training cell hard-codes lr=0.0001 and does not read this

hp = HParams()

In [None]:
import torch
import torchaudio
import torchaudio.functional as taf
import librosa
import numpy as np


class TacotronDataset(torch.utils.data.Dataset):
    """
    LJSpeech wrapped for Tacotron training.

    Each item is ``(text_ids, linear_spec, mel_spec)``:
        text_ids    : LongTensor [text_len] of character ids
        linear_spec : FloatTensor [1 + n_fft // 2, frames], clipped to [0, 1]
        mel_spec    : FloatTensor [n_mels, frames], clipped to [0, 1]

    References:
    https://github.com/Kyubyong/tacotron/blob/master/utils.py#L21
    https://github.com/ttaoREtw/Tacotron-pytorch/blob/master/src/utils.py
    """

    def __init__(self):
        self.dataset = torchaudio.datasets.LJSPEECH('.', download=True)
        self.text_preprocess = torchaudio.pipelines.TACOTRON2_WAVERNN_CHAR_LJSPEECH.get_text_processor()

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        # (waveform, sample_rate, transcript, normalized_transcript)
        y, sr, _, text = self.dataset[idx]

        # text -> [1, text_len] character ids
        text, _ = self.text_preprocess(text)

        # waveform -> (linear, mel) spectrograms
        lin, mel = self.audio_preprocess(y, sr)

        # squeeze only the batch dim so a one-character text stays 1-D
        return text.squeeze(0), lin, mel

    # Pre-emphasis filter
    def preemphasis(self, wav):
        return taf.preemphasis(wav, hp.preemphasis)

    # Short-Time Fourier Transform (complex, [1 + n_fft // 2, frames])
    def _stft(self, x):
        return librosa.stft(x.numpy().squeeze(), n_fft=hp.n_fft, hop_length=hp.hop_length,
                            win_length=hp.win_length, window=hp.window, pad_mode='constant')

    # Linear-frequency spectrogram, normalized to [0, 1]
    def spectrogram(self, wav):
        return self._linear_from_stft(self._stft(self.preemphasis(wav)))

    # Mel spectrogram, normalized to [0, 1]
    def melspectrogram(self, wav):
        return self._mel_from_stft(self._stft(self.preemphasis(wav)))

    def _linear_from_stft(self, D):
        # _normalize already subtracts ref_level_db. Subtracting it here as well would
        # shift the linear spectrogram by an extra ref_level_db relative to the mel one.
        return self._normalize(self._amp_to_db(np.abs(D)))

    def _mel_from_stft(self, D):
        return self._normalize(self._amp_to_db(self._linear_to_mel(np.abs(D))))

    # amplitude to decibel (floored at 1e-5, i.e. -100 dB)
    def _amp_to_db(self, x):
        return 20 * np.log10(np.maximum(1e-5, x))

    # linear magnitude spectrogram -> mel spectrogram
    def _linear_to_mel(self, mag):
        # The filterbank depends only on hp, so build it once per instance and reuse it
        mel_basis = getattr(self, '_mel_basis', None)
        if mel_basis is None:
            mel_basis = librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.n_mels)
            self._mel_basis = mel_basis
        return np.dot(mel_basis, mag)

    # Map [ref_level_db - max_level_db, ref_level_db] dB onto [0, 1], clipping outside values
    def _normalize(self, x):
        return np.clip((x - hp.ref_level_db + hp.max_level_db) / hp.max_level_db, 0, 1)

    def audio_preprocess(self, y, sr):
        """Return (linear, mel) FloatTensors for waveform ``y`` recorded at rate ``sr``."""

        # The mel filterbank and the STFT window/hop sizes are computed for hp.sample_rate,
        # so bring the audio to that rate first.
        if sr != hp.sample_rate:
            y = taf.resample(y, sr, hp.sample_rate)

        # trim leading/trailing silence
        y, _ = librosa.effects.trim(y)

        # Both representations use identical STFT settings, so compute the STFT once
        D = self._stft(self.preemphasis(y))
        spectrogram = self._linear_from_stft(D)
        melspectrogram = self._mel_from_stft(D)

        return torch.FloatTensor(spectrogram), torch.FloatTensor(melspectrogram)


class TacotronTTSCollate():
    """
    Collate ``(text_ids, linear_spec, mel_spec)`` items into zero-padded batch tensors.

    Items are ordered by text length, longest first. Spectrograms are padded in time
    up to the next multiple of the reduction factor, so the decoder can consume
    ``r`` frames per step.

    Args:
        reduction_factor: frames per decoder step; defaults to ``hp.reduction_factor``.
    """

    def __init__(self, reduction_factor=None):
        self.reduction_factor = hp.reduction_factor if reduction_factor is None else reduction_factor

    def __call__(self, batch):
        """
        Returns:
            text_padded  : LongTensor [B, max_text_len]
            lin_padded   : FloatTensor [B, max_frames, n_lin]
            mel_padded   : FloatTensor [B, max_frames, n_mels]
            text_lengths : LongTensor [B], descending
            seq_lengths  : LongTensor [B], unpadded frame count of each item (same order)
        """
        r = self.reduction_factor
        batch_size = len(batch)

        # get decreasing order by text length within batch
        text_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(text) for text, _, _ in batch]),
            dim=0, descending=True
        )

        # maximum lengths within the batch
        max_text_len = int(text_lengths[0])
        num_lins = batch[0][1].size(0)
        num_mels = batch[0][2].size(0)
        max_seq_len = max(lin.size(1) for _, lin, _ in batch)
        # round up to a multiple of r (no extra r frames when already aligned)
        max_seq_len += (-max_seq_len) % r

        # all-zero padded tensors
        text_padded = torch.zeros(batch_size, max_text_len, dtype=torch.long)
        lin_padded = torch.zeros(batch_size, num_lins, max_seq_len)
        mel_padded = torch.zeros(batch_size, num_mels, max_seq_len)
        seq_lengths = torch.zeros(batch_size, dtype=torch.long)

        # copy each item into its sorted slot
        for i, src in enumerate(ids_sorted_decreasing.tolist()):
            text, lin, mel = batch[src]
            text_padded[i, :text.size(0)] = text
            lin_padded[i, :, :lin.size(1)] = lin
            mel_padded[i, :, :mel.size(1)] = mel
            seq_lengths[i] = lin.size(1)

        return (
            text_padded,
            lin_padded.transpose(1, 2),
            mel_padded.transpose(1, 2),
            text_lengths,
            seq_lengths
        )


In [None]:
import matplotlib.pyplot as plt

def show_melspectrogram(mel_pred, mel_targ):
    """
    Plot a predicted and a target spectrogram one above the other, each with a colorbar.

    Args:
        mel_pred: 2-D array for the prediction; rows are drawn on the y-axis (origin at the bottom).
        mel_targ: 2-D array for the target, same layout as ``mel_pred``.
    """
    fig, (ax_pred, ax_targ) = plt.subplots(2, 1, figsize=(12, 6))
    panels = ((ax_pred, mel_pred, "Predicted"), (ax_targ, mel_targ, "Target"))
    for ax, spec, title in panels:
        im = ax.imshow(spec, aspect="auto", origin="lower", interpolation="none")
        fig.colorbar(im, ax=ax)
        ax.set(title=title, xlabel="Frame", ylabel="Frequency bin")
    fig.tight_layout()
    plt.show()

    # Close this specific figure so repeated calls (e.g. in the training loop) don't accumulate figures
    plt.close(fig)

3.2 Encoder

The goal of the encoder is to extract robust sequential representations of text. The input to the encoder is a character sequence, where each character is represented as a one-hot vector and embedded into a continuous vector. We then apply a set of non-linear transformations, collectively called a “pre-net”, to each embedding. We use a bottleneck layer with dropout as the pre-net in this work, which helps convergence and improves generalization. A CBHG module transforms the prenet outputs into the final encoder representation used by the attention module. We found that this CBHG-based encoder not only reduces overfitting, but also makes fewer mispronunciations than a standard multi-layer RNN encoder (see our linked page of audio samples)

In [None]:
class Conv1DBN(nn.Module):
    """
    1-D convolution -> optional activation -> batch normalization.

    Padding is ``k // 2``: the sequence length is preserved for odd ``k``; for even
    ``k`` the output is one step longer (Conv1DBank trims it back).
    """

    def __init__(self, in_ch, out_ch, k, bias=False, activation=None):
        super(Conv1DBN, self).__init__()
        self.conv_1d = nn.Conv1d(
            in_channels=in_ch,
            out_channels=out_ch,
            kernel_size=k,
            padding=k // 2,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(out_ch)
        self.activation = activation

    def forward(self, x):
        """[B, in_ch, T] -> [B, out_ch, T'] (activation is applied before batch norm)."""
        convolved = self.conv_1d(x)
        activated = convolved if self.activation is None else self.activation(convolved)
        return self.bn(activated)

In [None]:
class Highway(nn.Module):
    """
    Highway layer: ``relu(H(x)) * t + x * (1 - t)`` with gate ``t = sigmoid(T(x))``.

    H's bias starts at 0 and T's bias at -1, biasing the gate toward carrying ``x``
    through at initialization. ``in_size`` should equal ``out_size`` for the carry term.

    https://github.com/r9y9/tacotron_pytorch/blob/master/tacotron_pytorch/tacotron.py
    """

    def __init__(self, in_size, out_size):
        super(Highway, self).__init__()
        self.H = nn.Linear(in_size, out_size)
        nn.init.zeros_(self.H.bias)
        self.T = nn.Linear(in_size, out_size)
        nn.init.constant_(self.T.bias, -1.0)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, inputs):
        gate = self.sigmoid(self.T(inputs))
        transformed = self.relu(self.H(inputs))
        carried = inputs * (1.0 - gate)
        return transformed * gate + carried

In [None]:
class PreNet(nn.Module):
    """
    Bottleneck pre-net: two (Linear -> ReLU -> Dropout(p=0.5)) layers.

    Args:
        ch_emb_dim: input feature size.
        hidden_dims: output sizes of the two layers (only the first two entries are used).
    """

    def __init__(self, ch_emb_dim, hidden_dims):
        super(PreNet, self).__init__()
        first_dim, second_dim = hidden_dims[0], hidden_dims[1]
        self.fc1 = nn.Linear(ch_emb_dim, first_dim)
        self.fc2 = nn.Linear(first_dim, second_dim)
        self.dropout = nn.Dropout(p=0.5)
        self.relu = nn.ReLU()

    def forward(self, x):
        """[..., ch_emb_dim] -> [..., hidden_dims[1]]."""
        for fc in (self.fc1, self.fc2):
            x = self.dropout(self.relu(fc(x)))
        return x

### CBHG

CBHG consists of a bank of 1-D convolutional filters, followed by highway networks (Srivastava et al., 2015) and a bidirectional gated recurrent unit (GRU) (Chung et al., 2014) recurrent neural net (RNN). The input sequence is first convolved with $K$ sets of 1-D convolutional filters, where the $k$-th set contains $C_k$ filters of width $k$ (i.e. $k = 1, 2, \dots, K$).

The convolution outputs are stacked together and further max pooled along time to increase local invariances. Note that we use a stride of 1 to preserve the original time resolution. We further pass the processed sequence to a few fixed-width 1-D convolutions, whose outputs are added with the original input sequence via residual connections (He et al., 2016). Batch normalization (Ioffe & Szegedy, 2015) is used for all convolutional layers.

The convolution outputs are fed into a multi-layer highway network to extract high-level features. Finally, we stack a bidirectional GRU RNN on top to extract sequential features from both forward and backward context. 

In [None]:
class Conv1DBank(nn.Module):
    """
    CBHG convolution bank: ``K`` Conv1DBN layers with kernel widths 1..K (ReLU),
    concatenated along channels, then max-pooled over time with stride 1.

    Shapes: [B, T, cbhg_dim] -> [B, T, K * cbhg_dim].
    """

    def __init__(self, K, cbhg_dim):
        super(Conv1DBank, self).__init__()
        relu = nn.ReLU()
        bank = []
        for width in range(1, K + 1):
            bank.append(Conv1DBN(cbhg_dim, cbhg_dim, width, False, relu))
        self.convolutions = nn.ModuleList(bank)
        # kernel 2 / stride 1 / padding 1 yields T + 1 steps; forward trims back to T
        self.maxpool = nn.MaxPool1d(kernel_size=2, stride=1, padding=1)

    def forward(self, x):
        steps = x.size(1)

        # Conv1d expects channels on dim 1: [B, T, C] -> [B, C, T]
        channels_first = x.transpose(1, 2)

        # Stack every filter set's output; even widths give T + 1 steps, so trim to T
        stacked = torch.cat(
            [conv(channels_first)[:, :, :steps] for conv in self.convolutions],
            dim=1,
        )

        # Max-pool along time, keeping the original resolution
        pooled = self.maxpool(stacked)[:, :, :steps]

        return pooled.transpose(1, 2)

In [None]:
class ConvProjection(nn.Module):
    """
    Conv1D projections of a CBHG module, followed by a linear layer.

    One kernel-3 Conv1DBN is built per entry of ``hidden_dims`` (in_dim -> d1 -> d2 -> ...).
    Every projection except the last uses ReLU; the last is linear, as in the paper's
    ``conv-3-<d1>-ReLU -> conv-3-<d2>-Linear``.

    Args:
        in_dim: input channels (K * cbhg_dim after the conv bank).
        hidden_dims: output channels of each projection conv, in order (non-empty).
        out_dim: output size of the final linear layer (matches the highway input).

    Shapes: [B, T, in_dim] -> [B, T, out_dim].
    """

    def __init__(self, in_dim, hidden_dims, out_dim):
        super(ConvProjection, self).__init__()
        hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("ConvProjection requires at least one projection dimension")

        # convolution layer dimensions: one conv per hidden dim.
        # Using hidden_dims[1:] as the outputs would silently drop the last projection.
        in_dims = [in_dim] + hidden_dims[:-1]
        out_dims = hidden_dims
        activations = [nn.ReLU()] * (len(out_dims) - 1) + [None]

        self.linear = nn.Linear(hidden_dims[-1], out_dim)
        self.convolutions = nn.ModuleList(
            [Conv1DBN(c_in, c_out, k=3, bias=True, activation=act)
             for c_in, c_out, act in zip(in_dims, out_dims, activations)]
        )

    def forward(self, x):
        # Conv1d works channels-first: [B, T, C] -> [B, C, T]
        x = x.transpose(1, 2)

        # convolution projections (kernel 3, padding 1 keeps T)
        for conv in self.convolutions:
            x = conv(x)

        x = x.transpose(1, 2)

        # match dimension with highway input
        return self.linear(x)

In [None]:
class CBHG(nn.Module):
    """
    CBHG: Conv1D bank + max-pool -> Conv1D projections -> residual -> highway net -> bi-GRU.

    Args:
        K: number of conv-bank filter sets (kernel widths 1..K).
        cbhg_dim: channels per filter set; also the expected input feature size.
        proj_dims: output channels of each Conv1D projection.
        highway_dims: width of each highway layer, one layer per entry. All entries
            must be equal (Highway adds its input to its output), and highway_dims[0]
            must equal the input size for the residual connection.
        gru_dim: hidden size per direction of the bidirectional GRU.

    Shapes: [B, T, cbhg_dim] -> [B, T, 2 * gru_dim].
    """

    def __init__(self, K, cbhg_dim, proj_dims, highway_dims, gru_dim):
        super(CBHG, self).__init__()
        highway_dims = list(highway_dims)
        self.convolution_bank = Conv1DBank(K, cbhg_dim)
        self.convolution_proj = ConvProjection(K * cbhg_dim, proj_dims, highway_dims[0])
        # One Highway per listed width (the paper uses 4 layers for the 4 widths in hp).
        # The first layer consumes the projection output of size highway_dims[0].
        layer_inputs = [highway_dims[0]] + highway_dims[:-1]
        self.highways = nn.ModuleList([
            Highway(in_dim, out_dim)
            for in_dim, out_dim in zip(layer_inputs, highway_dims)])
        self.bidirectional_gru = nn.GRU(highway_dims[-1], gru_dim, bidirectional=True, batch_first=True)

    def forward(self, x):
        # residual
        residual = x

        # Conv 1D bank + stacking + maxpool -> [B, T, K * cbhg_dim]
        x = self.convolution_bank(x)

        # Convolution projection -> [B, T, highway_dims[0]]
        x = self.convolution_proj(x)

        # residual connection (out-of-place so autograd never sees an in-place update)
        x = x + residual

        # highway layers -> [B, T, highway_dims[-1]]
        for highway in self.highways:
            x = highway(x)

        # Bidirectional RNN -> [B, T, 2 * gru_dim]
        x, _ = self.bidirectional_gru(x)

        return x
        

In [None]:
class Encoder(nn.Module):
    """
    Tacotron text encoder: character embedding -> pre-net -> CBHG.

    Input:  character ids, [B, text_len].
    Output: [B, text_len, 2 * encoder_bidirectional_gru].
    """

    def __init__(self, hparams):
        super(Encoder, self).__init__()
        emb_dim = hparams.character_embedding_dim

        self.character_embedding = nn.Embedding(hparams.n_phonemes, emb_dim)
        self.prenet = PreNet(emb_dim, hparams.encoder_prenet)
        self.cbhg = CBHG(
            hparams.encoder_cbhg_k,
            hparams.encoder_cbhg_dim,
            hparams.encoder_conv1d_projection,
            hparams.encoder_highway,
            hparams.encoder_bidirectional_gru,
        )

    def forward(self, x):
        embedded = self.character_embedding(x)   # [B, text_len, character_embedding_dim]
        features = self.prenet(embedded)         # [B, text_len, encoder_prenet[-1]]
        return self.cbhg(features)               # [B, text_len, 2 * encoder_bidirectional_gru]

### Decoder

We use a content-based tanh attention decoder (see e.g. Vinyals et al. (2015)), where a stateful recurrent layer produces the attention query at each decoder time step. We concatenate the context vector and the attention RNN cell output to form the input to the decoder RNNs. We use a stack of GRUs with vertical residual connections (Wu et al., 2016) for the decoder.

While we could directly predict raw spectrogram, it’s a highly redundant representation for the purpose of learning alignment between speech signal and text. We use 80-band mel-scale spectrogram as the target, though fewer bands or more concise targets such as cepstrum could be used. We use a post-processing network (discussed below) to convert from the seq2seq target to waveform.

In [None]:
class AttentionRNN(nn.Module):
    """
    Attention RNN: one GRUCell fed with [pre-net output ; previous attention context].

    The GRUCell input size is ``prenet_dim + attn_rnn_hidden_dim``, so the context
    passed as ``attn_out`` must have ``attn_rnn_hidden_dim`` features.
    """

    def __init__(self, prenet_dim, attn_rnn_hidden_dim):
        super(AttentionRNN, self).__init__()
        input_size = prenet_dim + attn_rnn_hidden_dim
        self.gru = nn.GRUCell(input_size, attn_rnn_hidden_dim)

    def forward(self, prenet_out, attn_out, attn_rnn_hidden):
        """Return the next hidden state, [B, attn_rnn_hidden_dim]."""
        gru_input = torch.cat((prenet_out, attn_out), dim=1)
        return self.gru(gru_input, attn_rnn_hidden)

In [None]:
class DecoderRNN(nn.Module):
    """
    Two stacked GRUCells with vertical residual connections (default: 256 cells).

    The input ``[gru_out ; attn_out]`` is first projected to ``hidden_dim`` by
    ``linear1`` so it can be added residually to each GRU's output.
    """

    def __init__(self, input_dim: int, hidden_dim: int):
        super(DecoderRNN, self).__init__()
        # Dimension-matching projection. TODO: consider feeding gru1 the raw input size instead.
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.gru1 = nn.GRUCell(hidden_dim, hidden_dim)
        self.gru2 = nn.GRUCell(hidden_dim, hidden_dim)

    def forward(self, attn_out, gru_out, dec_rnn_hiddens):
        """
        Args:
            attn_out: attention context, [B, ...].
            gru_out: attention-RNN output, [B, ...].
            dec_rnn_hiddens: list of the two layers' hidden states, each [B, hidden_dim];
                entries are replaced in place.
        Returns:
            (decoder output [B, hidden_dim], the same ``dec_rnn_hiddens`` list)
        """
        layer_in = self.linear1(torch.cat([gru_out, attn_out], dim=-1))
        for idx, cell in enumerate((self.gru1, self.gru2)):
            dec_rnn_hiddens[idx] = cell(layer_in, dec_rnn_hiddens[idx])
            layer_in = layer_in + dec_rnn_hiddens[idx]
        return layer_in, dec_rnn_hiddens

Attention mechanism introduced in 'Grammar as a Foreign Language'  
Oriol Vinyals et al. (2015)

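$$
\begin{aligned}
u_i^t &= v^T \tanh(W_1^\prime h_i + W_2^\prime d_t) \\
a_i^t &= \text{softmax}(u_i^t) \\
d_t^\prime &= \sum_{i=1}^{T_A} a_i^t h_i
\end{aligned}
$$

Here $h_i$ are the encoder outputs, $d_t$ is the decoder (attention RNN) state at step $t$, $a_i^t$ are the attention weights, and $d_t^\prime$ is the resulting context vector.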

In [None]:
class Attention(nn.Module):
    """
    Content-based tanh attention decoder (Vinyals et al. (2015)).

        u_i = v^T tanh(W1 d + W2 h_i),  a = softmax(u),  context = sum_i a_i h_i

    Args:
        query_dim: size of the query ``d`` (attention-RNN state). Default 256.
        memory_dim: size of each encoder output ``h_i``. Default 256.
        attn_dim: size of the shared projection space. Default 256.
    """

    def __init__(self, query_dim=256, memory_dim=256, attn_dim=256):
        super(Attention, self).__init__()
        self.W1         = nn.Linear(query_dim, attn_dim)
        self.W2         = nn.Linear(memory_dim, attn_dim)
        self.v          = nn.Linear(attn_dim, 1, bias=False)
        self.tanh       = nn.Tanh()
        self.softmax    = nn.Softmax(dim=1)

    def forward(self, d, h, mask):
        """
        Args:
            d: query, [B, query_dim] (a 3-D [B, 1, query_dim] query is also accepted).
            h: encoder outputs, [B, text_len, memory_dim].
            mask: bool tensor viewable as [B, text_len], True at positions to ignore; or None.
        Returns:
            (context [B, memory_dim], attention weights [B, text_len])
        """
        # Projection of the query -> [B, attn_dim], expanded to [B, 1, attn_dim]
        d_proj = self.W1(d)
        if d_proj.dim() == 2:
            d_proj = d_proj.unsqueeze(1)

        # Projection of encoder outputs -> [B, text_len, attn_dim]
        h_proj = self.W2(h)

        # Scores -> [B, text_len]
        u = self.v(self.tanh(d_proj + h_proj)).squeeze(2)

        # Masked attention: out-of-place masked_fill keeps the operation visible to
        # autograd (an in-place fill through .data bypasses it)
        if mask is not None:
            u = u.masked_fill(mask.view(d.size(0), -1), float("-inf"))

        # Convert to probability -> [B, text_len]
        a = self.softmax(u)

        # Weighted sum of encoder outputs -> [B, memory_dim]
        h_prime = torch.bmm(a.unsqueeze(1), h).squeeze(1)

        return h_prime, a

In [None]:
def get_mask_from_lengths(memory, memory_lengths):
    """Build a padding mask for a batch of variable-length sequences.

    Args:
        memory: tensor whose dim 1 is the padded time axis, e.g. (batch, max_time, dim).
        memory_lengths: valid length of each sequence; a tensor, list, or other array-like of ints.

    Returns:
        Bool tensor of shape (batch, max_time) on ``memory.device``; True marks padded
        positions (time index >= length).
    """
    max_time = memory.size(1)
    # as_tensor accepts lists as well as tensors (no copy or warning for tensors) and
    # places the lengths on memory's device so the comparison never mixes devices.
    lengths = torch.as_tensor(memory_lengths, device=memory.device)
    steps = torch.arange(max_time, device=memory.device)
    return steps.unsqueeze(0) >= lengths.unsqueeze(1)

In [None]:
class Decoder(nn.Module):
    """
    Attention: content-based tanh attention decoder

    Each decoder step runs: pre-net -> attention RNN -> attention -> residual decoder
    RNN -> linear layer emitting ``reduction_factor`` mel frames. After decoding, a
    CBHG post-net maps the mel frames to a linear spectrogram.
    """

    def __init__(self, hparams):
        super(Decoder, self).__init__()

        self.max_decoding_timestep = hparams.max_decoding_timestep

        # prenet: input is one decoder step of r stacked mel frames
        self.n_mels = hparams.n_mels
        self.r      = hparams.reduction_factor
        decoder_prenet_dims = hparams.decoder_prenet
        self.prenet = PreNet(self.n_mels * self.r, decoder_prenet_dims)

        # attention rnn
        self.attn_rnn_dim = hparams.attention_rnn_dim
        self.attention_rnn = AttentionRNN(decoder_prenet_dims[-1], self.attn_rnn_dim)

        # attention
        # NOTE(review): Attention() uses its default 256-d query/memory sizes, so
        # attention_rnn_dim and the encoder output size must both be 256.
        self.attention = Attention()

        # decoder rnn: input is [attention-RNN output ; attention context]
        # (every size below is read from `hparams`, not from the global `hp`)
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        encoder_out_dim = 2 * hparams.encoder_bidirectional_gru
        self.decoder_rnn = DecoderRNN(self.attn_rnn_dim + encoder_out_dim, self.decoder_rnn_dim)

        # postnet
        decoder_K = hparams.decoder_cbhg_k
        decoder_cbhg_dim = hparams.decoder_cbhg_dim
        decoder_proj_dim = hparams.decoder_conv1d_projection
        highway_dims = hparams.decoder_highway
        gru_dim = hparams.decoder_bidirectional_gru

        self.postnet = CBHG(decoder_K, decoder_cbhg_dim, decoder_proj_dim, highway_dims, gru_dim)

        # linears
        # mel frame -> CBHG input size
        self.pre_cbhg = nn.Linear(self.n_mels, decoder_cbhg_dim)
        # decoder RNN output -> r stacked mel frames
        self.decoder_linear = nn.Linear(self.decoder_rnn_dim, self.n_mels * self.r)
        # bidirectional GRU output -> linear-spectrogram bins (n_fft // 2 + 1)
        self.linear = nn.Linear(2 * gru_dim, hparams.n_fft // 2 + 1)

    def forward(self, z, y, len):
        """
        Args:
            z: encoder output, [B, text_len, encoder_dim].
            y: target mel spectrogram [B, seq_len, n_mels] (or already grouped as
               [B, seq_len // r, n_mels * r]), or None to decode freely.
            len: text lengths used to mask attention over padding, or None.
               (Name shadows the builtin; kept for keyword callers.)
        Returns:
            mel_pred [B, seq_len, n_mels], lin_pred [B, seq_len, n_fft // 2 + 1],
            alignments [B, seq_len // r, text_len]
        """
        text_lengths = len

        # Batch size
        B = z.size(0)

        # initial variables
        input_frame      = z.new_zeros(B, self.n_mels * self.r)
        attn_rnn_hidden  = z.new_zeros(B, self.attn_rnn_dim)
        dec_rnn_hiddens  = [z.new_zeros(B, self.decoder_rnn_dim) for _ in range(2)]
        attn_out         = z.new_zeros(B, z.size(2))

        # Store predicted mel frames and alignments (one entry per decoder step)
        pred_mel_frames, pred_alignments = [], []

        # With targets, decode exactly as many steps as the (grouped) target has
        max_T = None
        if y is not None:
            if y.size(2) == self.n_mels:
                y = y.contiguous()
                y = y.view(B, y.size(1) // self.r, -1)
            assert y.size(2) == self.n_mels * self.r
            max_T = y.size(1)
            y = y.transpose(0, 1)

        # Teacher forcing needs targets; without them the decoder always feeds back its output
        teacher_forcing = self.training and y is not None

        mask = get_mask_from_lengths(z, text_lengths) if text_lengths is not None else None

        t = 0

        while True:

            # decoder prenet -> [B, prenet_dim]
            o = self.prenet(input_frame)

            # attention rnn -> [B, attn_rnn_dim]
            attn_rnn_hidden = self.attention_rnn(o, attn_out, attn_rnn_hidden)

            # attention -> context [B, encoder_dim], alignment [B, text_len]
            attn_out, alignment = self.attention(attn_rnn_hidden, z, mask)

            # decoder rnn -> output [B, decoder_rnn_dim], hiddens: list of two [B, decoder_rnn_dim]
            decoder_rnn_out, dec_rnn_hiddens = self.decoder_rnn(attn_out, attn_rnn_hidden, dec_rnn_hiddens)

            # r mel frames for this step -> [B, n_mels * r]
            r_mel_frames = self.decoder_linear(decoder_rnn_out)

            # record once per step (recording twice doubled the alignment length)
            pred_mel_frames.append(r_mel_frames)
            pred_alignments.append(alignment)

            t += 1

            if max_T is None:
                # Free-running decoding: stop on a quiet frame or at the step cap
                if self.is_end_of_frame(r_mel_frames):
                    break
                if t >= self.max_decoding_timestep:
                    print("Mel spectrogram does not seem to be converged.")
                    break
            elif t == max_T:
                break

            # Training: feed the ground-truth frames of the step just predicted
            # Otherwise: feed back the frames predicted at this step
            input_frame = y[t - 1] if teacher_forcing else r_mel_frames

        # Stack alignments -> [B, steps, text_len]
        pred_alignments = torch.stack(pred_alignments).transpose(0, 1)

        # Stack mel frames -> [B, steps, n_mels * r] -> [B, steps * r, n_mels]
        mel_pred = torch.stack(pred_mel_frames).transpose(0, 1).contiguous()
        mel_pred = mel_pred.view(B, -1, self.n_mels)

        # Post-net: project mel frames, run CBHG, then map to the linear spectrogram
        lin_pred = self.pre_cbhg(mel_pred)
        lin_pred = self.postnet(lin_pred)
        lin_pred = self.linear(lin_pred)

        return mel_pred, lin_pred, pred_alignments

    def is_end_of_frame(self, z):
        # Treat the utterance as finished when every predicted value is below 0.2
        return (z < 0.2).all()

In [None]:
class Tacotron(nn.Module):
    """Tacotron: text encoder followed by the attention decoder and CBHG post-net."""

    def __init__(self, hp):
        super(Tacotron, self).__init__()
        self.encoder = Encoder(hp)
        self.decoder = Decoder(hp)

    def forward(self, x, y=None, text_len=None):
        """
        Args:
            x: character ids, [B, text_len].
            y: optional target mel frames (teacher forcing / fixed decoding length).
            text_len: optional text lengths used to mask attention over padding.
        Returns:
            (mel prediction, linear prediction, attention alignments)
        """
        return self.decoder(self.encoder(x), y, text_len)

In [None]:
class TacotronLoss():
    """
    Equal-weight L1 loss on the mel and linear spectrogram predictions.

    Predicted frames at or beyond each item's ``seq_len`` are zeroed before the loss,
    so against zero-padded targets they contribute zero error.
    """

    def __init__(self):
        self.l1_loss = nn.L1Loss()

    def __call__(self, mel_pred, lin_pred, mel_targ, lin_targ, seq_len):
        masked_mel = mel_pred * self.get_mask(mel_pred, seq_len)
        masked_lin = lin_pred * self.get_mask(lin_pred, seq_len)

        mel_term = self.l1_loss(masked_mel, mel_targ)
        lin_term = self.l1_loss(masked_lin, lin_targ)

        return mel_term * 0.5 + lin_term * 0.5

    def get_mask(self, pred, seq_len):
        """Return a mask shaped like ``pred`` [B, T, C]: 1 for time steps < length, else 0."""
        mask = torch.ones_like(pred)
        for row, length in enumerate(seq_len):
            mask[row, length:] = 0
        return mask

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = Tacotron(hp).to(device)

criterion = TacotronLoss()
# NOTE(review): lr is hard-coded to 1e-4, which differs from hp.learning_rate (1e-3) — confirm which is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, betas=(0.9, 0.999))

# Named separately from the raw LJSPEECH `dataset` used by the sample cell, so re-running
# that cell afterwards still sees (waveform, sample_rate, transcript, normalized_transcript) items.
train_dataset = TacotronDataset()
dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=hp.batch_size, collate_fn=TacotronTTSCollate(), shuffle=True)

PLOT_EVERY_N_STEPS = 10  # show a predicted/target spectrogram pair every N optimisation steps

# One pass over the data (a single epoch); `step` counts batches, not epochs.
model.train()
progress = tqdm(dataloader, desc="train")
for step, (text, lin, mel, text_len, seq_len) in enumerate(progress):

    # move device
    text = text.to(device)
    lin = lin.to(device)
    mel = mel.to(device)
    text_len = text_len.to(device)
    seq_len = seq_len.to(device)

    # model prediction (teacher-forced with the target mel frames)
    mel_pred, lin_pred, alignment = model(text, mel, text_len)

    # loss function
    loss = criterion(mel_pred, lin_pred, mel, lin, seq_len)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    progress.set_postfix(loss=f"{loss.item():.4f}")

    if step % PLOT_EVERY_N_STEPS == 0:
        # [T, bins] -> [bins, T] so frequency is on the y-axis
        show_melspectrogram(lin_pred[0].cpu().detach().numpy().transpose(1, 0),
                            lin[0].cpu().detach().numpy().transpose(1, 0))
