In [None]:
from nemo.collections.tts.models import T5TTS_Model
from nemo.collections.tts.data.text_to_speech_dataset import T5TTSDataset, DatasetSample
from omegaconf.omegaconf import OmegaConf, open_dict
import torch
import os
import soundfile as sf
from IPython.display import display, Audio
import os
import numpy as np
# Enumerate GPUs in PCI bus order so the index below is stable, then expose
# only GPU 1 to this process. Change the index to match your machine.
# NOTE(review): this must run before anything initializes CUDA to take effect.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="1"

## Set checkpoint and other file paths on your machine

In [None]:
# Machine-specific locations: edit these to point at your own files.
hparams_file = "/datap/misc/duplexcheckpoints/blackwell/duplex_blackwell_medium_decoder_noexpresso_onlyphonemeFT_hparams.yaml"
checkpoint_file = "/datap/misc/duplexcheckpoints/blackwell/duplex_blackwell_medium_decoder_withTC_fromroycheckpoint_lowsestvalloss.ckpt"
codecmodel_path = "/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo"
out_dir = "/datap/misc/t5tts_inference_notebook_samples"

# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs(out_dir, exist_ok=True)

# Placeholder target audio: 3 seconds of silence at 22050 Hz. The manifest
# entries built later point at this file for their "audio_filepath" field.
dummy_audio_filepath = os.path.join(out_dir, "dummy_audio.wav")
sf.write(dummy_audio_filepath, np.zeros(22050 * 3), 22050)

## Load Model

In [None]:
# Load the hyperparameters file; the model configuration sits under its `cfg` key.
model_cfg = OmegaConf.load(hparams_file).cfg

with open_dict(model_cfg):
    # Point the config at the local audio codec checkpoint.
    model_cfg.codecmodel_path = codecmodel_path
    if hasattr(model_cfg, 'text_tokenizer'):
        # Backward compatibility for models trained with absolute paths in text_tokenizer
        # NOTE(review): these are relative paths, so they are resolved against the
        # current working directory - presumably the NeMo repository root.
        model_cfg.text_tokenizer.g2p.phoneme_dict = "scripts/tts_dataset_files/ipa_cmudict-0.7b_nv23.01.txt"
        model_cfg.text_tokenizer.g2p.heteronyms = "scripts/tts_dataset_files/heteronyms-052722"
        model_cfg.text_tokenizer.g2p.phoneme_probability = 1.0
    # This notebook builds its own test dataset, so clear the dataset configs.
    model_cfg.train_ds = None
    model_cfg.validation_ds = None


model = T5TTS_Model(cfg=model_cfg)
# Load weights from checkpoint file
print("Loading weights from checkpoint")
# map_location="cpu": deserialize onto the CPU regardless of the device the
# checkpoint was saved from (only one GPU is visible to this process), and avoid
# holding a second copy of the weights on the GPU. The model is moved to the GPU
# below with model.cuda().
ckpt = torch.load(checkpoint_file, map_location="cpu")
model.load_state_dict(ckpt['state_dict'])
print("Loaded weights.")

# The KV cache is only enabled for learnable positional embeddings when neither
# self-attention nor cross-attention uses flash attention. The original condition
# tested `use_flash_self_attention` twice, so cross-attention was never checked.
# `.get(..., False)` keeps the previous behaviour for configs without the key.
if model_cfg.t5_decoder.pos_emb == "learnable":
    self_attn_not_flash = model_cfg.t5_decoder.use_flash_self_attention is False
    cross_attn_not_flash = model_cfg.t5_decoder.get("use_flash_x_attention", False) is False
    if self_attn_not_flash and cross_attn_not_flash:
        print("Using kv cache for inference.")
        model.use_kv_cache_for_inference = True

model.cuda()
model.eval()

In [None]:
# Build an empty test-mode dataset. No manifest is registered (dataset_meta={});
# the generation cell assigns `test_dataset.data_samples` directly for each dialogue.
test_dataset = T5TTSDataset(
    dataset_meta={},
    dataset_type='test',
    # Audio / codec settings taken from the model config.
    sample_rate=model_cfg.sample_rate,
    codec_model_downsample_factor=model_cfg.codec_model_downsample_factor,
    num_audio_codebooks=model_cfg.num_audio_codebooks,
    load_16khz_audio=model.model_type == 'single_encoder_sv_tts',
    load_cached_codes_if_available=True,
    # Duration limits passed to the dataset.
    min_duration=0.5,
    max_duration=20,
    context_duration_min=model.cfg.get('context_duration_min', 5.0),
    context_duration_max=model.cfg.get('context_duration_max', 5.0),
    # Special token ids are taken from the loaded model.
    bos_id=model.bos_id,
    eos_id=model.eos_id,
    audio_bos_id=model.audio_bos_id,
    audio_eos_id=model.audio_eos_id,
    context_audio_bos_id=model.context_audio_bos_id,
    context_audio_eos_id=model.context_audio_eos_id,
    # Text handling; the tokenizers themselves are attached right after construction.
    tokenizer_config=None,
    use_text_conditioning_tokenizer=model.use_text_conditioning_encoder,
    pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,
    prior_scaling_factor=None,
)

# Attach the tokenizers that the model sets up in test mode.
tokenizers = model._setup_tokenizers(model.cfg, mode='test')
test_dataset.text_tokenizer, test_dataset.text_conditioning_tokenizer = tokenizers

### Set dialogues to generate

Each item in the list can be a single-turn or a multi-turn dialogue.

[SPK-BWL-B-F] is the speaker tag for the female speaker and [SPK-BWL-B-M] is the speaker tag for the male speaker.

Here is an example ChatGPT prompt that can generate dialogues in this format:

```
Generate dialogues for a 30 second podcast about <TOPIC>.
The conversation should be between a male and female speaker formatted as follows:
[SPK-BWL-B-F] Sentence by a female speaker
[SPK-BWL-B-M] Sentence by a male speaker
where [SPK-BWL-B-F] and [SPK-BWL-B-M] indicate speaker tags. Do not include any quotation marks in the text, and if there are any numbers, spell them out. Basically, keep the text normalized so that it is suitable for a TTS model. Keep the conversation fun and engaging, with the speakers talking and responding to each other.
```


In [None]:
# Each entry is generated as one unit; an entry may hold a single turn or
# several turns, with every turn introduced by its speaker tag.
dialogues = [
    "[SPK-BWL-B-F] Have you been keeping up with multi-turn TTS models? They are so fascinating [SPK-BWL-B-M] Absolutely. The way they handle conversations feels almost human.",
    "[SPK-BWL-B-F] I know. It is amazing how they can remember context across multiple exchanges [SPK-BWL-B-M] Exactly. It is not just about saying words anymore. They actually make the responses flow naturally.",
    "[SPK-BWL-B-F] And it is not just for assistants or chatbots. I heard they are being used for things like audiobooks and interactive storytelling.",
    "[SPK-BWL-B-M] That is true. Plus, they can even adjust tone to match emotions. Imagine hearing a story where the narrator sounds genuinely excited during a twist.",
    "[SPK-BWL-B-F] Or even a bit sarcastic during a funny moment. That makes everything feel more real.",
    "[SPK-BWL-B-M] For sure. It is like conversations with a machine are finally catching up to the way we actually talk.",
    "[SPK-BWL-B-F] The future of tech keeps surprising me. Every day, there is something new to explore.",
]
# Discard empty strings so blank placeholder entries are never generated.
dialogues = [turn for turn in dialogues if turn]

### Generation

The code below generates 4 candidate samples for each item in the dialogues list. It then asks which one you like best (index 0, 1, 2 or 3) and adds that candidate to the dialogue generated so far. You may modify the code to automate this and always select the first generation if you do not want to choose manually. After every item, it also plays the combined dialogue generated up to that item.

In [None]:
from pydub import AudioSegment

def play_combined_audio(audio_files):
    """Concatenate audio files in order, save the result, and play it inline.

    Args:
        audio_files: Iterable of paths to audio files, in playback order.

    Returns:
        Path of the combined file, ``combined_audio.wav`` inside ``out_dir``.
        The file is overwritten on every call.
    """
    output_path = os.path.join(out_dir, "combined_audio.wav")
    # Start from an empty segment so an empty input still yields a valid result.
    stitched = sum(
        (AudioSegment.from_file(path) for path in audio_files),
        AudioSegment.empty(),
    )
    stitched.export(output_path, format="wav")
    display(Audio(output_path))
    return output_path
    


import time

NUM_CANDIDATES = 4        # candidates generated for every dialogue item
CONTEXT_WINDOW_MS = 5000  # tail of the podcast reused as audio context (pydub slices in ms)


def _ask_for_candidate(num_candidates):
    """Prompt until the user enters a valid candidate index.

    Re-prompts on non-numeric or out-of-range input instead of raising, so a
    typo does not abort the session and discard the audio generated so far.
    Negative numbers are rejected rather than silently indexing from the end.

    Args:
        num_candidates: Number of candidates offered; valid answers are
            ``0 .. num_candidates - 1``.

    Returns:
        The chosen index as an int.
    """
    while True:
        user_input = input("Enter Candidate number that sounds the best:").strip()
        try:
            choice = int(user_input)
        except ValueError:
            print(f"Invalid input {user_input!r}: enter a number from 0 to {num_candidates - 1}.")
            continue
        if 0 <= choice < num_candidates:
            return choice
        print(f"Candidate {choice} does not exist: enter a number from 0 to {num_candidates - 1}.")


context_path = None      # audio context for the next item; set after the first item
generated_audios = []    # selected candidate for every item generated so far
audio_dir = "/"          # manifest paths below are absolute

for didx, dialogue in enumerate(dialogues):
    # Manifest entry for this item. The target audio is only a placeholder
    # (the silent dummy file); the model generates the actual speech.
    entry = {
        "audio_filepath": dummy_audio_filepath,
        "duration": 3.0,
        "text": dialogue,
        "speaker": "dummy",
    }
    if didx == 0:
        # Nothing has been generated yet, so condition on a text context.
        entry["context_text"] = "MIXED SPEECH TTS"
    else:
        # Condition on the tail of the podcast generated so far.
        # NOTE(review): the duration is declared as a fixed 5 s even if the
        # audio generated so far is shorter than that - confirm the dataset
        # tolerates a shorter context file.
        entry['context_audio_filepath'] = context_path
        entry['context_audio_duration'] = CONTEXT_WINDOW_MS / 1000.0

    data_sample = DatasetSample(
        dataset_name="sample",
        manifest_entry=entry,
        audio_dir=audio_dir,
        feature_dir=audio_dir,
        text=entry['text'],
        speaker=None,
        speaker_index=0,
        tokenizer_names=["english_phoneme"]
    )
    # The dataset holds exactly one sample: the current dialogue item.
    test_dataset.data_samples = [data_sample]

    test_data_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        collate_fn=test_dataset.collate_fn,
        num_workers=0,
        shuffle=False
    )

    item_idx = 0
    for bidx, batch in enumerate(test_data_loader):
        print("Processing batch {} out of {}".format(bidx, len(test_data_loader)))
        model.t5_decoder.reset_cache(use_cache=True)
        # Move tensors to the GPU; non-tensor entries are passed through unchanged.
        batch_cuda = {
            key: value.cuda() if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

        candidates = []
        for try_idx in range(NUM_CANDIDATES):
            st = time.time()
            predicted_audio, predicted_audio_lens, _, _ = model.infer_batch(
                batch_cuda,
                max_decoder_steps=500,
                temperature=0.6,
                topk=80,
                use_cfg=True,
                cfg_scale=1.6
            )
            print("generation time", time.time() - st)
            for idx in range(predicted_audio.size(0)):
                predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
                # Trim to the valid length reported by the model.
                predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]
                audio_path = os.path.join(out_dir, f"predicted_audio_{try_idx}_{didx}_{item_idx}.wav")
                sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)
                # Report the dialogue index (didx); item_idx restarts at 0 for every dialogue.
                print("Dialogue:", didx, "Candidate: ", try_idx)
                display(Audio(audio_path))
                candidates.append(audio_path)

        selected_audio_idx = _ask_for_candidate(len(candidates))
        audio_path = candidates[selected_audio_idx]
        item_idx += 1
        generated_audios.append(audio_path)
        print("Podcast generated until now:")
        combined_audio_path = play_combined_audio(generated_audios)
        last_gen_audio = AudioSegment.from_file(combined_audio_path)
        last_5_seconds = last_gen_audio[-CONTEXT_WINDOW_MS:]  # Duration is in milliseconds
        context_path = os.path.join(out_dir, "context.wav")
        # pydub's export() defaults to mp3; request wav explicitly so the file
        # contents match the .wav extension that the dataset will read.
        last_5_seconds.export(context_path, format="wav")