In [None]:
from nemo.collections.tts.models import T5TTS_Model
from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample
from omegaconf.omegaconf import OmegaConf, open_dict
import torch
import os
import soundfile as sf
from IPython.display import display, Audio
import numpy as np
import os
# Order GPUs by PCI bus id and restrict this kernel to device 0.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

### Checkpoint Paths

In [None]:
# Locations of the hyperparameter file, the model checkpoint, and the audio codec model.
hparams_file = "/datap/misc/ChallengingFinetuneLocalTraining/head1hparams.yaml"
checkpoint_file = "/datap/misc/ChallengingFinetuneLocalTraining/head1_epoch248.ckpt"
codecmodel_path = "/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo"

# Temp out dir for saving audios
out_dir = "/datap/misc/t5tts_inference_notebook_samples"
# exist_ok=True keeps the cell re-runnable and avoids the exists-then-create race.
os.makedirs(out_dir, exist_ok=True)

### Load Model

In [None]:
# Build the model config from the saved hyperparameters and override the
# entries that must point at local resources for inference.
model_cfg = OmegaConf.load(hparams_file).cfg

with open_dict(model_cfg):
    model_cfg.codecmodel_path = codecmodel_path
    if hasattr(model_cfg, 'text_tokenizer'):
        # Backward compatibility for models trained with absolute paths in text_tokenizer
        model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722"
        model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0
    # No training/validation data is needed for inference.
    model_cfg.train_ds = None
    model_cfg.validation_ds = None


model = T5TTS_Model(cfg=model_cfg)
print("Loading weights from checkpoint")
# Deserialize onto the CPU first: without map_location, tensors are restored to
# the device they were saved from, which fails when that GPU index is not
# visible in this process. The weights are moved to the GPU by model.cuda() below.
ckpt = torch.load(checkpoint_file, map_location="cpu")
model.load_state_dict(ckpt['state_dict'])
print("Loaded weights.")

model.use_kv_cache_for_inference = True

model.cuda()
model.eval()

### Initialize Dataset class and helper functions

In [None]:
# Constructor arguments for the inference dataset. Token ids and conditioning
# options are read from the loaded model so that they match it.
dataset_kwargs = {
    'dataset_meta': {},
    'sample_rate': model_cfg.sample_rate,
    'min_duration': 0.5,
    'max_duration': 20,
    'codec_model_downsample_factor': model_cfg.codec_model_downsample_factor,
    'bos_id': model.bos_id,
    'eos_id': model.eos_id,
    'context_audio_bos_id': model.context_audio_bos_id,
    'context_audio_eos_id': model.context_audio_eos_id,
    'audio_bos_id': model.audio_bos_id,
    'audio_eos_id': model.audio_eos_id,
    'num_audio_codebooks': model_cfg.num_audio_codebooks,
    'prior_scaling_factor': None,
    'load_cached_codes_if_available': True,
    'dataset_type': 'test',
    'tokenizer_config': None,
    'load_16khz_audio': model.model_type == 'single_encoder_sv_tts',
    'use_text_conditioning_tokenizer': model.use_text_conditioning_encoder,
    'pad_context_text_to_max_duration': model.pad_context_text_to_max_duration,
    'context_duration_min': model.cfg.get('context_duration_min', 5.0),
    'context_duration_max': model.cfg.get('context_duration_max', 5.0),
}
test_dataset = T5TTSDataset(**dataset_kwargs)

# The dataset is created with tokenizer_config=None, so attach the tokenizers
# built by the model directly.
text_tokenizer, text_conditioning_tokenizer = model._setup_tokenizers(model.cfg, mode='test')
test_dataset.text_tokenizer = text_tokenizer
test_dataset.text_conditioning_tokenizer = text_conditioning_tokenizer



def get_audio_duration(file_path):
    """Return the length of the audio file at ``file_path`` divided by its sample rate."""
    with sf.SoundFile(file_path) as handle:
        num_frames = len(handle)
        sample_rate = handle.samplerate
    return num_frames / sample_rate

def create_record(text, context_audio_filepath=None, context_text=None):
    """Build a manifest-style record for one utterance to synthesize.

    Exactly one of ``context_audio_filepath`` and ``context_text`` must be given.
    A silent placeholder wav is written to ``out_dir`` and used as the record's
    ``audio_filepath``.

    Args:
        text: Transcript to synthesize.
        context_audio_filepath: Path to the context audio file, or None.
        context_text: Context text prompt, or None.

    Returns:
        dict with ``audio_filepath``, ``duration``, ``text`` and ``speaker`` keys,
        plus either ``context_text`` or ``context_audio_filepath`` and
        ``context_audio_duration``.

    Raises:
        AssertionError: if both or neither of the context arguments are supplied.
    """
    # Validate before touching the filesystem so bad input has no side effects.
    if context_text is not None:
        assert context_audio_filepath is None, \
            "Supply either context_audio_filepath or context_text, not both"
    else:
        assert context_audio_filepath is not None, \
            "Supply one of context_audio_filepath or context_text"

    dummy_sample_rate = 22050
    dummy_duration_sec = 3  # 3 seconds of silence
    dummy_audio_fp = os.path.join(out_dir, "dummy_audio.wav")
    sf.write(dummy_audio_fp, np.zeros(dummy_sample_rate * dummy_duration_sec), dummy_sample_rate)

    record = {
        'audio_filepath': dummy_audio_fp,
        'duration': float(dummy_duration_sec),
        'text': text,
        'speaker': "dummy",
    }
    if context_text is not None:
        record['context_text'] = context_text
    else:
        record['context_audio_filepath'] = context_audio_filepath
        record['context_audio_duration'] = get_audio_duration(context_audio_filepath)

    return record

### Set transcript and context pairs to test

In [None]:
# Change sample text and prompt audio/text here
audio_base_dir = "/"
test_entries = [
    create_record(
        text="This is an example of a regular text without any challlenging texts like repeated words or numbers just to do a sanity check.",
        context_text="Speaker and Emotion: | Language:en Dataset:Riva Speaker:Lindy_WIZWIKI |",
        #context_audio_filepath="/datap/misc/LibriTTSfromNemo/LibriTTS/test-clean/7729/102255/7729_102255_000012_000001.wav", # Supply either context_audio_filepath or context_text, not both
    ),
]

# Wrap every manifest entry in a DatasetSample and hand the list to the dataset.
data_samples = [
    DatasetSample(
        dataset_name="sample",
        manifest_entry=record,
        audio_dir=audio_base_dir,
        feature_dir=audio_base_dir,
        text=record['text'],
        speaker=None,
        speaker_index=0,
        tokenizer_names=["english_phoneme"], # Change this for multilingual: "english_phoneme", "spanish_phoneme", "english_chartokenizer", "german_chartokenizer".. 
    )
    for record in test_entries
]
test_dataset.data_samples = data_samples

# One sample per batch, in the same order as test_entries.
loader_kwargs = dict(
    batch_size=1,
    collate_fn=test_dataset.collate_fn,
    num_workers=0,
    shuffle=False,
)
test_data_loader = torch.utils.data.DataLoader(test_dataset, **loader_kwargs)

### Generate With and Without Attention Prior

In [None]:
import time

import matplotlib.pyplot as plt

# How many times each prior setting is generated for every batch.
num_repeats = 1

# Running counter used to give every saved wav a unique file name.
item_idx = 0
for bidx, batch in enumerate(test_data_loader):
    print("Processing batch {} out of {}".format(bidx, len(test_data_loader)))
    model.t5_decoder.reset_cache(use_cache=True)
    # Move tensor entries to the GPU; leave everything else untouched.
    batch_cuda = {
        key: (batch[key].cuda() if isinstance(batch[key], torch.Tensor) else batch[key])
        for key in batch
    }

    for _repeat in range(num_repeats):
        # Generate the same batch with and without the attention prior for comparison.
        for apply_prior in [True, False]:
            # Restart the timer for each generation so the reported time covers
            # only this infer_batch call, not the previous run and its I/O.
            st = time.time()
            predicted_audio, predicted_audio_lens, _, _, cross_attn_np, all_heads_attn_np = model.infer_batch(
                batch_cuda,
                max_decoder_steps=430,
                temperature=0.6,
                topk=80,
                use_cfg=True,
                cfg_scale=2.5,
                prior_epsilon=0.1,
                lookahead_window_size=5,
                return_cross_attn_probs=True,
                estimate_alignment_from_layers=[4,6,7],
                apply_attention_prior=apply_prior,
                apply_prior_to_layers=[3,4,5,6,7,8,9,10],
                compute_all_heads_attn_maps=True,
                start_prior_after_n_audio_steps=10
            )
            print("generation time", time.time() - st)
            for idx in range(predicted_audio.size(0)):
                # Trim the padded waveform to its predicted length before saving.
                audio_len = int(predicted_audio_lens[idx])
                predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
                predicted_audio_np = predicted_audio_np[:audio_len]
                audio_path = os.path.join(out_dir, f"predicted_audio_{item_idx}.wav")
                sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)
                # Map (batch index, position in batch) back to the source entry;
                # valid because the loader does not shuffle.
                entry_idx = bidx * test_data_loader.batch_size + idx
                print(test_entries[entry_idx]['text'])
                print("Prior Used?", apply_prior)
                display(Audio(audio_path))
                item_idx += 1
                plt.imshow(cross_attn_np[idx])
                plt.show()
                # Uncomment to plot the attention map of every individual head:
#                 for hidx, head_cross_attn in enumerate(all_heads_attn_np[idx]):
#                     layer_num = hidx // model.cfg.t5_decoder.xa_n_heads
#                     head_num = hidx % model.cfg.t5_decoder.xa_n_heads
#                     print("item, layer, head", idx, layer_num, head_num)
#                     plt.imshow(all_heads_attn_np[idx][hidx])
#                     plt.show()

            print("------------------------------------")
            print("------------------------------------")