In [None]:
## tensorflow-gpu==2.3.0rc1 bug to load_weight after call inference
!pip install tensorflow==2.2.0

In [None]:
import yaml

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from ipywidgets import Audio
from tensorflow_tts.inference import AutoConfig
from tensorflow_tts.inference import TFAutoModel
from tensorflow_tts.processor.ljspeech import LJSpeechProcessor
from tensorflow_tts.processor.ljspeech import symbols, _symbol_to_id
from tensorflow_tts.utils import TFGriffinLim

In [None]:
# Paths handed to TFGriffinLim below; the YAML file is also read into ds_config.
dataset_config_path = "../preprocess/ljspeech_preprocess.yaml"
stats_path = "../dump/stats.npy"

# Read the config inside a context manager so the file handle is closed
# as soon as parsing finishes.
# NOTE(review): yaml.Loader can construct arbitrary Python objects; only load
# trusted config files (yaml.safe_load would do if the file has no custom tags).
with open(dataset_config_path) as config_file:
    ds_config = yaml.load(config_file, Loader=yaml.Loader)

griffin_lim_tf = TFGriffinLim(dataset_config_path, stats_path)
processor = LJSpeechProcessor(None, "english_cleaners")

# Convert the sentence to symbol ids and append the eos id (last symbol index).
input_text = "i love you so much."
input_ids = processor.text_to_sequence(input_text)
input_ids = np.concatenate([input_ids, [len(symbols) - 1]], -1)

In [None]:
# Create the Tacotron-2 model from its YAML config. No pretrained weights are
# passed here (pretrained_path=None); the checkpoint is loaded in a later cell.
config = AutoConfig.from_pretrained("../examples/tacotron2/conf/tacotron2.v1.yaml")
tacotron2 = TFAutoModel.from_pretrained(
    config=config, 
    pretrained_path=None,
    is_build=False, # don't build model if you want to save it to pb. (TF related bug)
    name="tacotron2"
)

# Inference-time settings applied to the model before it is called.
# NOTE(review): presumably win_front/win_back bound an attention window and
# 3000 caps the number of decoder steps -- confirm against the model API.
tacotron2.setup_window(win_front=6, win_back=6)
tacotron2.setup_maximum_iterations(3000)

# Save the model as a serialized protocol buffer (SavedModel)

In [None]:
# Run one inference pass on the short sentence before the checkpoint is loaded.
# NOTE(review): presumably this first call creates the model's variables (the
# model was constructed with is_build=False) so that load_weights has something
# to restore into -- confirm.
# Inputs are int32 tensors: the id sequence with a leading batch axis of 1,
# its length, and speaker id 0.
(decoder_output, mel_outputs, stop_token_prediction, alignments) = tacotron2.inference(
    input_ids=tf.expand_dims(tf.convert_to_tensor(input_ids, dtype=tf.int32), 0),
    input_lengths=tf.convert_to_tensor([len(input_ids)], tf.int32),
    speaker_ids=tf.convert_to_tensor([0], dtype=tf.int32),
)

In [None]:
# Load weights from the model-120000.h5 checkpoint into the model; this runs
# after the inference call in the previous cell.
tacotron2.load_weights("../examples/tacotron2/checkpoints/model-120000.h5")

In [None]:
# Export the model to ./test_saved in SavedModel (pb) format, passing
# tacotron2.inference as its signature, so the exported model can be reloaded
# and used for inference.
# Note: the signature must be a tf.function with an input_signature.
tf.saved_model.save(tacotron2, "./test_saved", signatures=tacotron2.inference)

# Load the SavedModel and run inference

In [None]:
# Reload the exported SavedModel. From this cell on, `tacotron2` refers to the
# restored object rather than the model built earlier in the notebook.
tacotron2 = tf.saved_model.load("./test_saved")

In [None]:
# Long test sentence for the restored model.
input_text = "Unless you work on a ship, it's unlikely that you use the word boatswain in everyday conversation, so it's understandably a tricky one. The word - which refers to a petty officer in charge of hull maintenance is not pronounced boats-wain Rather, it's bo-sun to reflect the salty pronunciation of sailors, as The Free Dictionary explains."

# Symbol ids for the sentence, followed by the eos id (last symbol index).
input_ids = np.concatenate(
    [processor.text_to_sequence(input_text), [len(symbols) - 1]],
    axis=-1,
)

In [None]:
# Synthesize with the restored SavedModel. Arguments are passed positionally,
# in the same order the earlier keyword call names them:
# input ids (with a leading batch axis of 1), input lengths, speaker ids.
(decoder_output, mel_outputs, stop_token_prediction, alignments) = tacotron2.inference(
    tf.expand_dims(tf.convert_to_tensor(input_ids, dtype=tf.int32), 0),
    tf.convert_to_tensor([len(input_ids)], tf.int32),
    tf.convert_to_tensor([0], dtype=tf.int32),
)

In [None]:
# Show the alignment of the first batch item as a heat map.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(
    alignments[0],
    origin="lower",
    aspect="auto",
    interpolation="none",
)
fig.colorbar(im, ax=ax, orientation="vertical", pad=0.02, aspect=15)
ax.set(xlabel="Decoder timestep", ylabel="Encoder timestep")
plt.tight_layout()

In [None]:
# Collapse the predicted mels to two dimensions with config.n_mels columns,
# then draw them (rotated by 90 degrees) as an image.
mel_outputs = tf.reshape(mel_outputs, shape=[-1, config.n_mels])
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(
    np.rot90(mel_outputs),
    interpolation="none",
    aspect="auto",
)
fig.colorbar(im, ax=ax, orientation="vertical", pad=0.02, aspect=15)
ax.set(title="Predicted mel spectrogram")
plt.tight_layout()

In [None]:
# Run Griffin-Lim on the mels (with a leading batch axis added), encode the
# first result as WAV at the dataset sampling rate, and show an audio player.
gl_output = griffin_lim_tf(tf.expand_dims(mel_outputs, 0))
tf_wav = tf.audio.encode_wav(
    tf.expand_dims(gl_output[0], 1),  # add the axis encode_wav takes at position 1
    ds_config["sampling_rate"],
)
Audio(value=tf_wav.numpy(), loop=False, autoplay=False)

# Inference with an input of a different shape

In [None]:
# Second test sentence; its id sequence has a different length from the previous one.
input_text = "The Commission further recommends that the Secret Service coordinate its planning as closely as possible with all of the Federal agencies from which it receives information."

# Symbol ids for the sentence, followed by the eos id (last symbol index).
input_ids = np.concatenate(
    [processor.text_to_sequence(input_text), [len(symbols) - 1]],
    axis=-1,
)

In [None]:
# Call the same restored inference function on the new sentence. Arguments are
# positional: input ids (with a leading batch axis of 1), input lengths,
# speaker ids -- all int32.
(decoder_output, mel_outputs, stop_token_prediction, alignments) = tacotron2.inference(
    tf.expand_dims(tf.convert_to_tensor(input_ids, dtype=tf.int32), 0),
    tf.convert_to_tensor([len(input_ids)], tf.int32),
    tf.convert_to_tensor([0], dtype=tf.int32),
)

In [None]:
# Show the alignment of the first batch item as a heat map.
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(
    alignments[0],
    origin="lower",
    aspect="auto",
    interpolation="none",
)
fig.colorbar(im, ax=ax, orientation="vertical", pad=0.02, aspect=15)
ax.set(xlabel="Decoder timestep", ylabel="Encoder timestep")
plt.tight_layout()

In [None]:
# Collapse the predicted mels to two dimensions with config.n_mels columns,
# then draw them (rotated by 90 degrees) as an image.
mel_outputs = tf.reshape(mel_outputs, shape=[-1, config.n_mels])
fig, ax = plt.subplots(figsize=(12, 3))
im = ax.imshow(
    np.rot90(mel_outputs),
    interpolation="none",
    aspect="auto",
)
fig.colorbar(im, ax=ax, orientation="vertical", pad=0.02, aspect=15)
ax.set(title="Predicted mel spectrogram")
plt.tight_layout()

In [None]:
# Run Griffin-Lim on the mels (with a leading batch axis added), encode the
# first result as WAV at the dataset sampling rate, and show an audio player.
gl_output = griffin_lim_tf(tf.expand_dims(mel_outputs, 0))
tf_wav = tf.audio.encode_wav(
    tf.expand_dims(gl_output[0], 1),  # add the axis encode_wav takes at position 1
    ds_config["sampling_rate"],
)
Audio(value=tf_wav.numpy(), loop=False, autoplay=False)