# Welcome to the RAD-MMM inference tutorial

Before we begin, please download the following dependencies:

1. Language dictionaries from [here](https://drive.google.com/drive/folders/1woNCODwXh9aHu7Fd6b4Jo42aL7f5RFZg); place them in the `RAD-MMM/assets` folder.
2. Download the RAD-MMM checkpoints and their config - [decoder.ckpt](https://drive.google.com/file/d/1ZLFHY5iSMdK852UwF1RqFr7cY2ejzeOw/view), [attribute_model.ckpt](https://drive.google.com/file/d/1EduYNwgtRlezJt6RiXMLBBSSOpIbp2CT/view?usp=sharing) and [config.yaml](https://drive.google.com/file/d/1c_dGA82k2Ow65P0vXwYwRTEipNdzTsTa/view?usp=sharing).
3. Download the HiFi-GAN vocoder checkpoint and its config - [g_00072000](https://drive.google.com/file/d/1VaH5_MhAjAjHlihi2k-lcOOoy4NqtRV4/view) and [config_16khz.json](https://drive.google.com/file/d/1-eBTNfIh-LSstNirQawHW4jsI-t01jTU/view?usp=sharing).

In [None]:
# imports
import pytorch_lightning as pl
import sys
import yaml
sys.path.append('/akshit/scratch/RAD-MMM/vocoders')
sys.path.append('/akshit/scratch/RAD-MMM')
from pytorch_lightning.cli import LightningCLI
from tts_lightning_modules import TTSModel
from data_modules import BaseAudioDataModule
from jsonargparse import lazy_instance
from decoders import RADMMMFlow
from loss import RADTTSLoss
import inspect
from pytorch_lightning.callbacks import ModelCheckpoint
from training_callbacks import LogDecoderSamplesCallback, \
    LogAttributeSamplesCallback
from utils import get_class_args
from tts_text_processing.text_processing import TextProcessing
from common import Encoder
import torch
import IPython.display as ipd

In [None]:
# Locations of the downloaded checkpoints and configs, grouped by the
# directory they were unpacked into.
_RADMMM_DIR = "../generator_ckpt/radmmm_public"
_HIFIGAN_DIR = "../generator_ckpt/hfg_public"

attribute_model_path = f"{_RADMMM_DIR}/attribute_model.ckpt"
gen_config_path = f"{_RADMMM_DIR}/config.yaml"
decoder_model_path = f"{_RADMMM_DIR}/decoder.ckpt"
voc_model_path = f"{_HIFIGAN_DIR}/g_00072000"
voc_config_path = f"{_HIFIGAN_DIR}/config_16khz.json"

# JSON string mapping each locale to its word->IPA dictionary under assets/.
# Every entry follows the same "assets/<locale>_word_ipa_map.txt" pattern, so
# the string is generated from the ordered locale list.
_LOCALES = (
    "en_US", "es_MX", "de_DE", "en_UK", "es_CO",
    "es_ES", "fr_FR", "hi_HI", "pt_BR", "te_TE",
)
phonemizer_cfg = "{" + ",".join(
    f'"{loc}": "assets/{loc}_word_ipa_map.txt"' for loc in _LOCALES
) + "}"

## Load the model

In [None]:
# Parse the generator's YAML config into a plain dict.

with open(gen_config_path) as config_file:
    gen_config = yaml.safe_load(config_file)

In [None]:
def instantiate_class(init):
    """Instantiate the class described by a class_path/init_args dict.

    Args:
        init: Dict of the form {"class_path":...,"init_args":...}.
            "class_path" is a dotted "package.module.ClassName" string.
            "init_args" is optional; a missing key, ``None`` or an empty
            dict all mean "call the class with no arguments".

    Returns:
        The instantiated class object.
    """
    import importlib

    # An empty `init_args:` key in YAML loads as None; `**None` would raise,
    # so normalise it to an empty kwargs dict.
    kwargs = init.get("init_args") or {}
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = importlib.import_module(class_module)
    args_class = getattr(module, class_name)
    return args_class(**kwargs)

In [None]:
# Apply inference-time overrides to the model section of the config, then
# build the keyword arguments for TTSModel from it.

gen_config["model"].update({
    "add_bos_eos_to_text": False,
    "append_space_to_text": True,
    "decoder_path": decoder_model_path,
    "encoders_path": decoder_model_path,
    "handle_phoneme": "word",
    "handle_phoneme_ambiguous": "ignore",
    "heteronyms_path": "tts_text_processing/heteronyms",
    "output_directory": "tutorials/run1",
    "p_phoneme": 1,
    "phoneme_dict_path": "tts_text_processing/cmudict-0.7b",
    "phonemizer_cfg": phonemizer_cfg,
    "prediction_output_dir": "tutorials/out1",
    "prepend_space_to_text": True,
    "sampling_rate": 16000,
    "symbol_set": "radmmm_phonemizer_marker_segregated",
    "vocoder_checkpoint_path": voc_model_path,
    "vocoder_config_path": voc_config_path,
})

hparams = gen_config["model"]
ttsmodel_kwargs = {}
for name, value in hparams.items():
    # Entries shaped like {"class_path": ..., ...} describe submodules and
    # are replaced by the instantiated object; their key is printed.
    if type(value) is dict and "class_path" in value:
        print(name)
        ttsmodel_kwargs[name] = instantiate_class(value)
        continue
    # The "_instantiator" entry is dropped; everything else passes through.
    if name == "_instantiator":
        continue
    ttsmodel_kwargs[name] = value

In [None]:
# Build the TTSModel from the attribute-model checkpoint, passing the
# prepared hyperparameters/submodules as constructor keyword arguments.
model2 = TTSModel.load_from_checkpoint(
    checkpoint_path=attribute_model_path,
    **ttsmodel_kwargs
)


## Initialize the datamodule

In [None]:
# Configure the data section for single-item prediction and build the
# datamodule from it.

data_cfg = gen_config["data"]
data_cfg["inference_transcript"] = "model_inputs/resynthesis_prompts.json"  # ToDo
data_cfg["batch_size"] = 1
data_cfg["phonemizer_cfg"] = phonemizer_cfg

data_module = BaseAudioDataModule(**data_cfg)
data_module.setup(stage="predict")

## Run Inference

In [None]:
# run the input through the model
def run_inference(text, speaker_id, input_language_id, target_accent_id, script=None):
    """Run one utterance through the globally loaded model and datamodule.

    Args:
        text: Raw input text. Only used when ``script`` is not supplied.
        speaker_id: Speaker identifier; used for the ``spk_id`` field and for
            the decoder, duration, energy and f0 speaker fields alike.
        input_language_id: Key into
            ``data_module.tp.phonemizer_backend_dict`` selecting the phoneme
            dictionary used to convert ``text``.
        target_accent_id: Value stored under the ``"language"`` key of the
            inference request.
        script: Optional pre-phonemized script. When ``None`` it is derived
            from ``text``; otherwise it is used as given and ``text`` is
            ignored.

    Returns:
        Whatever ``model2.forward`` returns for the prepared batch.
        NOTE(review): the cells below treat this as an audio file path —
        confirm against TTSModel.forward.
    """
    # Identity check: `== None` can be hijacked by a custom __eq__.
    if script is None:
        script = data_module.tp.convert_to_phoneme(
            text=text,
            phoneme_dict=data_module.tp.phonemizer_backend_dict[input_language_id],
        )
    print("Converted the text to phonemes: ", script)
    inferData = [{
        "script": script,
        "spk_id": speaker_id,
        "decoder_spk_id": speaker_id,
        "duration_spk_id": speaker_id,
        "energy_spk_id": speaker_id,
        "f0_spk_id": speaker_id,
        "language": target_accent_id,
        "emotion": "other",
    }]

    # Replace the prediction set's data with this single request.
    data_module.predictset.data = inferData

    # Build a fresh dataloader over the updated prediction set.
    dl = data_module.predict_dataloader()

    # Only one item was queued, so the first batch is the whole input.
    inp1 = next(iter(dl))

    # Move tensor entries (including Tensor subclasses) to the GPU;
    # non-tensor entries are left untouched.
    for key in inp1.keys():
        if isinstance(inp1[key], torch.Tensor):
            inp1[key] = inp1[key].to(device="cuda")

    return model2.forward(inp1)

In [None]:
# Example 1: speaker "ljs" (native English speaker) reading English text
# with an en_US accent.

text = "Hope you are enjoying our session so far!"
speaker_id = "ljs"
input_language_id = "en_US"
target_accent_id = input_language_id  # accent matches the input language
output_file_path = run_inference(
    text=text,
    speaker_id=speaker_id,
    input_language_id=input_language_id,
    target_accent_id=target_accent_id,
)


In [None]:
# Show an inline audio player for the result of the previous cell.
# NOTE(review): assumes run_inference returned something ipd.Audio accepts
# (e.g. a file path) — confirm against TTSModel.forward.
ipd.Audio(output_file_path)

In [None]:
# first example with user-provided phonemes for fine-grained control over speech

text = "Hope you are enjoying our session so far!"
speaker_id = "ljs"
input_language_id = "en_US"
target_accent_id = input_language_id
script="{h ˈoʊ p} {j uː} {ɑː ɹ} {ɛ n dʒ ˈɔɪ ɪ ŋ} {ˌaʊ ɚ} {s ˈɛ ʃ ə n} {s ˈoʊ} {f ˌɑːɹ!}"
# Pass the phoneme script explicitly; without it run_inference re-derives the
# phonemes from `text` and the hand-written script above is ignored.
output_file_path = run_inference(
    text, speaker_id, input_language_id, target_accent_id, script=script
)


In [None]:
# Show an inline audio player for the result of the previous cell.
# NOTE(review): assumes run_inference returned something ipd.Audio accepts
# (e.g. a file path) — confirm against TTSModel.forward.
ipd.Audio(output_file_path)

In [None]:
# Example 2: speaker "ljs" (native English speaker) reading Hindi text
# with a hi_HI accent.

text = "आशा है कि आप अब तक हमारे सत्र का आनंद ले रहे हैं!"
speaker_id = "ljs"
input_language_id = "hi_HI"
target_accent_id = input_language_id  # accent matches the input language
output_file_path = run_inference(
    text=text,
    speaker_id=speaker_id,
    input_language_id=input_language_id,
    target_accent_id=target_accent_id,
)


In [None]:
# Show an inline audio player for the result of the previous cell.
# NOTE(review): assumes run_inference returned something ipd.Audio accepts
# (e.g. a file path) — confirm against TTSModel.forward.
ipd.Audio(output_file_path)

In [None]:
# second example - with user-provided phonemes

text = "आशा है कि आप अब तक हमारे सत्र का आनंद ले रहे हैं!"
speaker_id = "ljs"
input_language_id = "hi_HI"
target_accent_id = input_language_id
script="{ˈaː ʃ aː} {h ɛː} {k ˈɪ} {ˌaː p} {ˈʌ b} {t ˌə k} {h ə m ˌaː ɾ eː} {s ˈʌ t ɾ ə} {k aː} {aː n ˈʌ n d} {l ˈeː} {ɾ ˌə h eː} {h ɛ̃!}"
# Pass the phoneme script explicitly; without it run_inference re-derives the
# phonemes from `text` and the hand-written script above is ignored.
output_file_path = run_inference(
    text, speaker_id, input_language_id, target_accent_id, script=script
)

In [None]:
# Show an inline audio player for the result of the previous cell.
# NOTE(review): assumes run_inference returned something ipd.Audio accepts
# (e.g. a file path) — confirm against TTSModel.forward.
ipd.Audio(output_file_path)

In [None]:
# Example 3: Hindi text spoken with an English (en_US) accent. The input
# language still selects the phoneme dictionary; only the accent differs.

text = "आशा है कि आप अब तक हमारे सत्र का आनंद ले रहे हैं!"
speaker_id = "ljs"
input_language_id = "hi_HI"
target_accent_id = "en_US"

output_file_path = run_inference(
    text=text,
    speaker_id=speaker_id,
    input_language_id=input_language_id,
    target_accent_id=target_accent_id,
)


In [None]:
# Show an inline audio player for the result of the previous cell.
# NOTE(review): assumes run_inference returned something ipd.Audio accepts
# (e.g. a file path) — confirm against TTSModel.forward.
ipd.Audio(output_file_path)

In [None]:
# Plot the waveform of the most recently synthesized output.

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib.pyplot as plt
import librosa.display
# x and sr are reused by the spectrogram cell that follows.
# NOTE(review): librosa.load resamples to its own default rate unless
# sr=None is passed — confirm that is acceptable for the 16 kHz model output.
x, sr = librosa.load(output_file_path)
plt.figure(figsize=(14, 5))
librosa.display.waveshow(x, sr=sr)

In [None]:
# Plot a dB-scaled spectrogram of the waveform loaded in the previous cell
# (relies on the globals `x` and `sr` defined there).
X = librosa.stft(x)  # short-time Fourier transform of the waveform
Xdb = librosa.amplitude_to_db(abs(X))  # magnitude of the STFT, converted to dB
plt.figure(figsize=(14, 5))
librosa.display.specshow(Xdb, sr=sr, x_axis='time', y_axis='hz')

## Cleanup

In [None]:
# Tear down the datamodule for the same "predict" stage it was set up with
# earlier in the notebook.
data_module.teardown(stage="predict")

In [None]:
# free up GPU memory
# `del` only drops this notebook's reference to the model. Run the garbage
# collector to reclaim any reference cycles, then ask PyTorch to release the
# CUDA memory it keeps cached so it is actually returned to the device.
import gc

del model2
gc.collect()
torch.cuda.empty_cache()