In [1]:
from utils.utils import TacotronPreprocessor, TTSDataset, collate_fn, reconstruct_audio
import pandas as pd
import numpy as np
import re
import torch
import torchaudio
import torchaudio.functional as F
from torch.utils.data import Dataset, DataLoader
from torchaudio import transforms
from torchaudio.functional import preemphasis
import hyperparams as hps

In [2]:
import torch 
from torch import nn

In [3]:
dataset = TTSDataset()
dataloader = DataLoader(dataset, 32, collate_fn=collate_fn, shuffle=False)

In [4]:
data = next(iter(dataloader))

In [5]:
data[1].shape

torch.Size([32, 80, 558])

In [6]:
### Model

In [22]:
class EncoderConvLayer(nn.Module):
    def __init__(self, input_channels, output_channels, kernel_size) -> None:
        super().__init__()
        self.module = nn.Sequential(
            nn.Conv1d(in_channels=input_channels, out_channels=output_channels, kernel_size=kernel_size, bias=False),
            nn.BatchNorm1d(output_channels),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
    def forward(self, x):
        return self.module(x)

In [23]:
class Encoder(nn.Module):
    def __init__(self, characters_num, embedding_size, lstm_hidden_size) -> None:
        super().__init__()
        self.char_embedding = nn.Embedding(characters_num, embedding_size)
        self.conv_layers = nn.Sequential(
            EncoderConvLayer(embedding_size, embedding_size, 5),
            EncoderConvLayer(embedding_size, embedding_size, 5),
            EncoderConvLayer(embedding_size, embedding_size, 5),
        )
        self.rnn = nn.LSTM(input_size=embedding_size,
                           hidden_size=lstm_hidden_size,
                           bidirectional=True, batch_first=True)
        self.rnn_dropout = nn.Dropout(0.1)

    
    def forward(self, x: torch.tensor):
        """
        На вход подается последовательность символов. Размерность [BATCH_SIZE, NUM_CHARACTERS]
        """
        x = self.char_embedding(x)  #[BATCH_SIZE, NUM_CHARACTERS, EMB_SIZE]
        x = x.transpose(1,2) #[BATCH_SIZE, EMB_SIZE, NUM_CHARACTERS]
        x = self.conv_layers(x)
        x = x.transpose(1,2) #[BATCH_SIZE, NUM_CHARACTERS, CONV_EMB]
        x = self.rnn(x)[0]
        x = self.rnn_dropout(x)

        return x

In [24]:
class PreNet(nn.Module):
    def __init__(self, num_mels, prenet_hidden_dim) -> None:
        super().__init__()
        self.module = nn.Sequential(
            nn.Linear(num_mels, prenet_hidden_dim),
            nn.ReLU(),
            nn.Linear(prenet_hidden_dim, prenet_hidden_dim),
            nn.ReLU()
        )
    def forward(self, x):
        x = self.module(x)
        return x
        

In [47]:
# class Tacotron2Attention(nn.Module):
#     def __init__(self) -> None:
#         super().__init__()
        



In [49]:
class Tacotron2(nn.Module):
    def __init__(self, characters_num: int = 0) -> None:
        super().__init__()
        self.characters_num = characters_num
        self.encoder = Encoder(characters_num, hps.CHARACTER_EMB_SIZE, hps.LSTM_HIDDEN_SIZE)

    def forward(self, x):
        x = self.encoder(x)
        return x

In [50]:
vocab_size = dataset.preprocessor.vocab.shape[0]+1
model = Tacotron2(vocab_size)

decoder = Decoder()

In [51]:
temp_input_encoder = data[0]
temp_input_decoder = data[1]
encoder_output = model(temp_input_encoder)

In [52]:
ttt = decoder(temp_input_decoder, encoder_output)

In [53]:
ttt[0].shape

torch.Size([32, 558, 128])

In [46]:
ttt[1].shape

torch.Size([32, 84, 128])

In [35]:
encoder_output.shape

torch.Size([32, 84, 512])

In [20]:
decoder(data[1]).shape

torch.Size([32, 558, 256])

In [14]:
data[1].shape

torch.Size([32, 80, 558])

In [23]:
data[1].shape

torch.Size([32, 80, 558])

In [10]:
temp_input = data[0]

In [11]:
vocab_size = dataset.preprocessor.vocab.shape[0]+1
model = Tacotron2(vocab_size)

In [12]:
model(temp_input).shape

torch.Size([32, 84, 512])

In [14]:
data[1].shape

torch.Size([32, 80, 558])

In [16]:
data[2].shape

torch.Size([32, 558])

In [41]:
model.characters_num

39

In [13]:
temp_input.unique().shape

torch.Size([39])