In [None]:
import matplotlib.pyplot as plt
import torch

from data import sequence_to_text
from preprocess import get_dataset, DataLoader, collate_fn_transformer
from module import TextPrenet, SpeechPrenet, SpeechPostnet

### Check network input
Load the network inputs the same way `train.py` does and check that they are as expected.

In [None]:
# Seed torch's global RNG so the shuffled first batch (and the randomly
# initialised prenet/postnet weights further down) are reproducible under
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Batch settings follow train.py so the inputs match those used in training.
# Only one batch is inspected here, so load it in the main process: with worker
# processes a torch DataLoader prefetches several batches per worker up front,
# nearly all of which would go unused (assumes `preprocess.DataLoader` is
# torch's DataLoader).
NUM_WORKERS = 0

dataset = get_dataset()
dataloader = DataLoader(dataset, batch_size=32,
                        shuffle=True, collate_fn=collate_fn_transformer,
                        drop_last=True, num_workers=NUM_WORKERS)

The first batch from the dataloader:

In [None]:
# Pull a single batch; collate_fn_transformer packs it as six fields.
tensor = next(iter(dataloader))
(text, mel, mel_input,
 pos_text, pos_mel, text_length) = tensor

In [None]:
# Show the first sample of every field in the batch, one line per field.
for label, field in (("Text:", text),
                     ("Mel:", mel),
                     ("Mel Input:", mel_input),
                     ("Pos Text:", pos_text),
                     ("Pos Mel:", pos_mel),
                     ("Text Length:", text_length)):
    print(label, field[0])

As expected, the inputs to the text encoders are the phoneme sequences:

In [None]:
# Decode the first text sequence back into a readable string for a visual check.
sequence_to_text(text[0].tolist())

The inputs to the speech encoders are the mel spectrograms (80 mel filters per window):

In [None]:
# Spectrogram of the first sample, transposed so mel filters are rows (y axis)
# and time steps are columns (x axis). aspect="auto" stretches the 80 mel rows
# to fill a short, wide canvas instead of a huge, mostly blank figure.
# NOTE(review): the tail of the time axis may be padding if
# collate_fn_transformer pads to the longest item in the batch — confirm.
fig, ax = plt.subplots(figsize=(15, 4))
image = ax.imshow(mel_input[0].numpy().T, origin="lower", aspect="auto")
ax.set(title="Mel input (first sample)", xlabel="Time", ylabel="Mel Filters")
fig.colorbar(image, ax=ax);

### Test Text Prenet

In [None]:
# Both sizes follow the Transformer-TTS configuration.
text_prenet = TextPrenet(
    embedding_size=512,
    num_hidden=256,
)

In [None]:
# Forward pass for inspection only: no_grad skips recording the autograd graph,
# so the output already has requires_grad=False and needs no .detach().
with torch.no_grad():
    text_prenet_output = text_prenet(text)

In [None]:
# Shape of the batched text input fed to the text prenet.
print("Input shape:", text.size())

In [None]:
print("Output shape:", text_prenet_output.shape)
# Turn the "one 256-dim embedding per phoneme" expectation into a check, so a
# shape regression fails Run All instead of slipping past a visual read.
# 256 must match the num_hidden passed to TextPrenet above.
assert text_prenet_output.shape == (*text.shape, 256), (
    f"expected {(*text.shape, 256)}, got {tuple(text_prenet_output.shape)}")

As expected, the prenet outputs a 256-dimensional embedding for each phoneme.

In [None]:
# Heat-map of the first sample's prenet output: one column per phoneme,
# one row per embedding dimension.
fig, ax = plt.subplots(figsize=(3, 20))
ax.imshow(text_prenet_output[0].T)
ax.set_xlabel("Phonemes")
ax.set_ylabel("Embedding");

### Test Speech Prenet

In [None]:
# num_mels: 80 mel filters per frame of input.
# hidden_size: taken from Ren's paper.
# output_size: must match whatever the decoder expects.
speech_prenet = SpeechPrenet(
    num_mels=80,
    hidden_size=256,
    output_size=256,
)

In [None]:
# Forward pass for inspection only: no_grad skips recording the autograd graph,
# so the output already has requires_grad=False and needs no .detach().
with torch.no_grad():
    speech_prenet_output = speech_prenet(mel_input)

In [None]:
# Shape of the batched mel input fed to the speech prenet.
print("Input shape:", mel_input.size())

In [None]:
# Shape produced by the speech prenet for the same batch.
print("Output shape:", speech_prenet_output.size())

### Test Speech Postnet

In [None]:
# num_mels: 80 mel filters; num_hidden: as indicated in Ren's paper.
speech_postnet = SpeechPostnet(
    num_mels=80,
    num_hidden=256,
)

In the full model, the Postnet takes the mel-spectrogram predicted by the speech decoder and refines it further. No decoder is run here, so `mel_input` stands in for the decoder output.

In [None]:
# Stand-in for the decoder's mel output: swap axes 1 and 2 of mel_input
# (presumably (batch, time, mels) -> (batch, mels, time); confirm against the
# layout SpeechPostnet expects).
decoder_output = mel_input.transpose(2, 1)

In [None]:
# Forward pass for inspection only: no_grad skips recording the autograd graph,
# so the output already has requires_grad=False and needs no .detach().
with torch.no_grad():
    speech_postnet_output = speech_postnet(decoder_output)

In [None]:
# Shape of the (stand-in) decoder output fed to the postnet.
print("Input shape:", decoder_output.size())

In [None]:
# Shape produced by the speech postnet for the same batch.
print("Output shape:", speech_postnet_output.size())

In [None]:
# Postnet output for the first sample, plotted without transposing (rows are
# presumably mel filters, columns time — hence the hedged axis label).
# The postnet was just constructed with no weights loaded, so this plot only
# checks plumbing and shapes, not output quality.
# aspect="auto" fills a short, wide canvas instead of a huge, mostly blank figure.
fig, ax = plt.subplots(figsize=(15, 4))
image = ax.imshow(speech_postnet_output[0].numpy(), origin="lower", aspect="auto")
ax.set(title="Speech Postnet output (first sample)",
       xlabel="Time", ylabel="(Supposedly) Mel Filters")
fig.colorbar(image, ax=ax);