In [None]:
from nemo.collections.tts.models import MagpieTTSDecoderModel
from nemo.collections.tts.data.text_to_speech_dataset import MagpieTTSDataset, DatasetSample
from omegaconf.omegaconf import OmegaConf, open_dict
import torch
import os
import soundfile as sf
from IPython.display import display, Audio
import numpy as np
import os
# CUDA environment: enumerate devices in PCI bus order and make only device 0
# visible to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

### Checkpoint Paths

In [None]:
# Hyper-parameter YAML and model checkpoint of the run to evaluate.
hparams_file = "/datap/misc/experiment_checkpoints/magpie_decoder/Debug2_hparams.yaml"
checkpoint_file = "/datap/misc/experiment_checkpoints/magpie_decoder/Debug2_epoch13.ckpt"
# Audio codec model; the commented line is an alternative codec checkpoint.
# codecmodel_path = "/datap/misc/checkpoints/AudioCodec_21Hz_no_eliz.nemo"
codecmodel_path = "/datap/misc/checkpoints/21fps_causal_codecmodel.nemo"


# Temp out dir for saving audios
out_dir = "/datap/misc/t5tts_inference_notebook_samples"
# exist_ok=True makes the cell safely re-runnable and avoids the race between
# a separate existence check and the directory creation.
os.makedirs(out_dir, exist_ok=True)

### Load Model

In [None]:
# Read the model configuration stored under the `cfg` key of the hparams file.
model_cfg = OmegaConf.load(hparams_file).cfg

# open_dict lets us set keys on the config even if it is in struct mode.
# The codec path is pointed at the local checkpoint and the train/validation
# dataset configs are cleared, since only inference is run here.
with open_dict(model_cfg):
    model_cfg.codecmodel_path = codecmodel_path
    model_cfg.train_ds = None
    model_cfg.validation_ds = None


model = MagpieTTSDecoderModel(cfg=model_cfg)
print("Loading weights from checkpoint")
# map_location="cpu": deserialize tensors onto the CPU regardless of the device
# they were saved from, so loading does not depend on the training-time GPU
# index being visible here; the model is moved to the GPU explicitly below.
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# only load checkpoints from a trusted source.
ckpt = torch.load(checkpoint_file, map_location="cpu", weights_only=False)
model.load_state_dict(ckpt['state_dict'])
print("Loaded weights.")

model.cuda()
model.eval()

### Initialize Dataset class and helper functions

In [None]:
# Dataset object used only for inference. It is created with an empty manifest
# (dataset_meta={}); the samples to synthesize are attached to it later in the
# notebook.
test_dataset = MagpieTTSDataset(
    dataset_meta={},
    dataset_type='test',
    # Audio / codec settings mirrored from the loaded model.
    sample_rate=model.sample_rate,
    codec_model_samples_per_frame=model.codec_model_samples_per_frame,
    num_audio_codebooks=model.num_audio_codebooks,
    load_16khz_audio=False,
    load_cached_codes_if_available=True,
    # Duration limits for the samples.
    min_duration=0.5,
    max_duration=20,
    # Special token ids mirrored from the loaded model.
    bos_id=None,
    eos_id=model.eos_id,
    audio_bos_id=model.audio_bos_id,
    audio_eos_id=model.audio_eos_id,
    context_audio_bos_id=model.context_audio_bos_id,
    context_audio_eos_id=model.context_audio_eos_id,
    # Text / context handling. Context duration bounds default to 5.0 when the
    # model config does not define them.
    tokenizer_config=None,
    use_text_conditioning_tokenizer=True,
    pad_context_text_to_max_duration=model.pad_context_text_to_max_duration,
    context_duration_min=model.cfg.get('context_duration_min', 5.0),
    context_duration_max=model.cfg.get('context_duration_max', 5.0),
    prior_scaling_factor=None,
)
# No tokenizer config was passed above; reuse the model's own tokenizers.
test_dataset.text_tokenizer = model.tokenizer
test_dataset.text_conditioning_tokenizer = model.tokenizer.first_tokenizer



def get_audio_duration(file_path):
    """Return the duration of an audio file: its frame count divided by its sample rate.

    Args:
        file_path: Path of the audio file, opened with ``soundfile.SoundFile``.

    Returns:
        ``len(file) / file.samplerate`` as a float.
    """
    with sf.SoundFile(file_path) as handle:
        num_frames = len(handle)
        sample_rate = handle.samplerate
        return num_frames / sample_rate

def create_record(text, context_audio_filepath=None, context_text=None):
    """Build a manifest-style record dict for one utterance to synthesize.

    Exactly one of ``context_audio_filepath`` and ``context_text`` must be
    provided.

    Args:
        text: Transcript to synthesize; stored under ``'text'``.
        context_audio_filepath: Path of the context audio. When given, the
            record also gets ``'context_audio_duration'`` computed with
            ``get_audio_duration``.
        context_text: Text context, stored under ``'context_text'``.

    Returns:
        A dict with ``'audio_filepath'``, ``'duration'``, ``'text'``,
        ``'speaker'`` and the context key(s) described above.

    Raises:
        AssertionError: If both or neither of the two context arguments are
            provided.
    """
    # Validate before touching the filesystem so a bad call has no side effects.
    assert (context_text is None) != (context_audio_filepath is None), (
        "create_record expects exactly one of context_text / context_audio_filepath"
    )

    dummy_audio_fp = os.path.join(out_dir, "dummy_audio.wav")
    # 3 seconds of silence. sf.write's return value is not needed.
    # NOTE(review): this placeholder file is written but the record below points
    # at a fixed LibriTTS utterance instead - confirm which one is intended.
    sf.write(dummy_audio_fp, np.zeros(22050 * 3), 22050)

    record = {
        'audio_filepath' : "/datap/misc/LibriTTSfromNemo/LibriTTS/dev-clean/2428/83705/2428_83705_000002_000000.wav",
        'duration': 6.23,
        'text': text,
        'speaker': "dummy",
    }
    if context_text is not None:
        record['context_text'] = context_text
    else:
        record['context_audio_filepath'] = context_audio_filepath
        record['context_audio_duration'] = get_audio_duration(context_audio_filepath)

    return record

### Set transcript and context pairs to test

In [None]:
import pprint
# Change sample text and prompt audio/text here
# Base directory handed to DatasetSample for both audio and features.
audio_base_dir = "/"
test_entries = [
    create_record(
        text="Alrighty, this thing seems to be all good.",
        context_audio_filepath="/datap/misc/LibriTTSfromNemo/LibriTTS/dev-clean/2428/83705/2428_83705_000002_000000.wav",
    ),
]

# Wrap every record in a DatasetSample and attach the list to the dataset,
# which was created without a manifest.
data_samples = [
    DatasetSample(
        dataset_name="sample",
        manifest_entry=record,
        audio_dir=audio_base_dir,
        feature_dir=audio_base_dir,
        text=record['text'],
        speaker=None,
        speaker_index=0,
        tokenizer_names=["qwen"], # Change this for multilingual: "english_phoneme", "spanish_phoneme", "english_chartokenizer", "german_chartokenizer"..
    )
    for record in test_entries
]
test_dataset.data_samples = data_samples

# One sample per batch, in the same order as test_entries.
test_data_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=test_dataset.collate_fn,
)

### Generation Loop

In [None]:
import matplotlib.pyplot as plt

# Running index over all generated utterances, across batches. It names the
# output file and selects the matching transcript, so it must advance once per
# utterance; otherwise every clip would overwrite predicted_audio_0.wav.
item_idx = 0
for batch in test_data_loader:
    # Move tensors to the GPU; non-tensor entries are passed through unchanged.
    batch_cuda = {
        key: (batch[key].cuda() if isinstance(batch[key], torch.Tensor) else batch[key])
        for key in batch
    }
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        predicted_audio, predicted_audio_lens, _, _ = model.infer_batch(
            batch_cuda,
            temperature=0.6,
            max_decoder_steps=430,
            use_local_transformer_for_inference=True,
        )
    for idx in range(predicted_audio.size(0)):
        predicted_audio_np = predicted_audio[idx].float().detach().cpu().numpy()
        # Keep only the first predicted_audio_lens[idx] samples of this item.
        predicted_audio_np = predicted_audio_np[:predicted_audio_lens[idx]]
        audio_path = os.path.join(out_dir, f"predicted_audio_{item_idx}.wav")
        sf.write(audio_path, predicted_audio_np, model.cfg.sample_rate)
        # The loader does not shuffle, so item_idx lines up with test_entries
        # for any batch size.
        print(test_entries[item_idx]['text'])
        display(Audio(audio_path))
        item_idx += 1
    
