In [None]:
import os

# Select the GPU before torch/CUDA are initialised: order devices by PCI bus id and expose only GPU 0.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="0"
import torch
import json
from omegaconf.omegaconf import OmegaConf, open_dict
import shutil

from nemo.collections.tts.models.speechllm.megatron_t5_speechllm_model import MegatronT5SpeechLMModel
from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
from nemo.collections.asr.parts.preprocessing.segment import AudioSegment
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from IPython.display import Audio, display
import torchaudio

# CHANGE THIS TO A LOCAL DIRECTORY
EXP_DIR = "/datap/misc/NotebookInference"

# exist_ok makes this idempotent on re-runs and avoids the exists()/makedirs() race.
os.makedirs(EXP_DIR, exist_ok=True)

## Save a dummy manifest to set up the model's test step

In [None]:
def write_records(fp, records):
    """Write ``records`` to ``fp`` as JSON Lines, one JSON object per line.

    The file is opened in "w" mode, so any existing content is overwritten.
    """
    with open(fp, "w") as handle:
        handle.writelines(json.dumps(rec) + "\n" for rec in records)

# Placeholder codec tensor (8 x 300, all ones, int16) saved to disk so the dummy record
# below can point at a real file for both its answer and its context.
dummy_codes = torch.ones((8, 300), dtype=torch.int16)
dummy_codes_fp = os.path.join(EXP_DIR, "dummy_codes.pt")
torch.save(dummy_codes, dummy_codes_fp)

# A single speechllm-format record; this manifest is what the model's test step is first set up with.
dummy_record = dict(
    question="Phoneme TTS Sample Text",
    answer=dummy_codes_fp,
    context=dummy_codes_fp,
    context_type="REFSPEAKERCODEC",
    question_type="TEXT",
    answer_type="AUDIOCODEC",
    context_duration=5.0,
    answer_duration=5.0,
    taskname="squad",
)

dummy_val_file = os.path.join(EXP_DIR, "dummy_val.json")
write_records(dummy_val_file, [dummy_record])

## Load and set up the model

In [None]:
# CHANGE THESE PATHS TO RELEVANT MOUNTED PATHS IN DOCKER
config_path = "/home/pneekhara/2023/NeMo/examples/tts/speechllm/conf/megatron_t5_speechllm_inference_multiencoder.yaml"
# Alternative checkpoint:
# checkpoint_path = "/datap/misc/temp_checkpoints_new/desta_less_sophia_highLR_step159600.ckpt"
checkpoint_path = "/datap/misc/checkpoints/desta_less_sophia_213850.ckpt"
codecmodel_path = "/datap/misc/checkpoints/SpeechCodec_2402.nemo"
vocab_file = "/datap/misc/checkpoints/9a77f10c2793465e8e8a3fa5fcbef8b0_vocab.txt"

cfg = OmegaConf.load(config_path)

# Every override is applied BEFORE the trainer and exp_manager are created. When these
# assignments ran afterwards, cfg.trainer.devices / cfg.trainer.precision and
# cfg.exp_manager.exp_dir never reached the already-constructed trainer and logger, so the
# run (and the Sample_Audios output directory) ignored EXP_DIR.
# NOTE(review): the trainer now really runs with precision="bf16"; drop that override if the
# yaml's own trainer precision was what you intended.
with open_dict(cfg):
    if "gradient_as_bucket_view" not in cfg.model:
        cfg.model.gradient_as_bucket_view = False

    cfg.exp_manager.exp_dir = EXP_DIR
    cfg.checkpoint_path = checkpoint_path

    # Trainer: one GPU, bf16; the model's precision mirrors the trainer's.
    cfg.trainer.devices = 1
    cfg.trainer.precision = "bf16"
    cfg.model.precision = cfg.trainer.precision

    # Data and tokenizer (speech_offset / lm_vocab_size were previously assigned twice).
    cfg.model.data.sup_data_path = "/datap/misc/speechllm_codecdatasets/"
    cfg.model.data.speech_offset = 30128
    cfg.model.lm_vocab_size = 30000
    cfg.model.data.add_special_tokens_to_only_first_codebook = True
    cfg.model.data.train_task = "all"
    cfg.model.data.max_seq_length = 2048
    cfg.model.data.context_duration_min = 20.0
    cfg.model.data.context_duration_max = 20.0
    cfg.model.data.test_ds = [dummy_val_file]
    cfg.model.data.num_workers = 0
    cfg.model.override_tokenizer_vocab_file = vocab_file
    cfg.model.english_only_model = True

    # Batching and inference-time sampling.
    cfg.model.global_batch_size = 1
    cfg.model.micro_batch_size = 1
    cfg.model.freeze_model = False
    cfg.model.max_inference_timesteps = 2000
    cfg.model.top_k = 80
    cfg.model.temperature = 0.85

    # Auxiliary models and architecture.
    cfg.model.codecmodel_path = codecmodel_path
    cfg.model.asr_model_name = "stt_en_conformer_transducer_large"
    cfg.model.frozen_model.decoder.layer_type = [1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1]
    cfg.model.alignment_decoder_layerids = [0, 1, 2, 3, 4]
    cfg.model.enc_output_to_layers = [[8, 9], [3, 4, 5, 6, 7]]

trainer = MegatronTrainerBuilder(cfg).create_trainer()
exp_manager(trainer, cfg.exp_manager)

checkpoint_path = cfg.checkpoint_path
# The path is a literal and can never be None; check that it actually exists instead.
assert os.path.exists(checkpoint_path), f"checkpoint path needs to be valid: {checkpoint_path}"

model = MegatronT5SpeechLMModel.load_from_checkpoint(
    checkpoint_path=checkpoint_path, trainer=trainer, cfg=cfg.model
)
model.eval()
model = model.cuda()

codec_model = model.additional_models['codec']
# One test pass over the dummy manifest sets up the model's test step; generate_new_audio
# later swaps real examples into model._test_ds.
trainer.test(model)


## Helper functions

In [None]:
# Generated audio is looked up under the trainer logger's run directory (save_dir/name/version),
# i.e. wherever the logger set up by exp_manager placed this run.
_logger = model.trainer.logger
out_dir = os.path.join(_logger.save_dir, _logger.name, _logger.version, "Sample_Audios")
# Presumably the test step writes the first prediction under this name — TODO confirm against the model's test step.
out_path = os.path.join(out_dir, "predicted_wav_0.wav")


def encode(wav_path, frames_per_second=86):
    """Convert an audio file to NeMo codec codes and save them next to the file.

    Args:
        wav_path: path to an audio file readable by ``AudioSegment``; it is resampled to
            ``codec_model.sample_rate`` before encoding.
        frames_per_second: codec frame rate used to turn the number of code frames into a
            duration in seconds. Defaults to 86, the value previously hard-coded here.

    Returns:
        (codes, codes_path, duration):
            codes      -- codec codes for the single input item (left on the GPU);
            codes_path -- path of the int16 CPU copy, "<path without extension>_codes.pt";
            duration   -- ``codes.size(1) / frames_per_second`` (dim 1 is treated as time).
    """
    features = AudioSegment.segment_from_file(
        wav_path, target_sr=codec_model.sample_rate, n_segments=-1, trim=False,
    )
    audio = torch.tensor(features.samples).cuda()
    audio_length = torch.tensor(audio.size(0)).long().cuda()

    # Inference only: no autograd graph needed for the encoder pass.
    with torch.no_grad():
        codes, _ = codec_model.encode(audio=audio.unsqueeze(0), audio_len=audio_length.unsqueeze(0))
    codes = codes[0]  # drop the batch dimension added above
    print(f"original_codec_codes {codes.size()} audio {audio.size()} audio_length {audio_length}")
    duration = codes.size(1) / frames_per_second

    # splitext copes with any extension (".wav", ".flac", none), unlike the old fixed [:-4] slice.
    target_codec_filepath = os.path.splitext(wav_path)[0] + "_codes.pt"
    torch.save(codes.cpu().type(torch.int16), target_codec_filepath)
    return codes, target_codec_filepath, duration
    
    
    
def play_codec(codec_path, sample_rate=None):
    """Decode a saved codec-code tensor with ``codec_model`` and play it inline.

    Args:
        codec_path: file written with ``torch.save`` holding the codes; a batch dimension is
            added here, so presumably shaped (num_codebooks, frames) — TODO confirm.
        sample_rate: rate used when writing the decoded wav. Defaults to
            ``codec_model.sample_rate`` (the rate ``encode`` resamples to); the previous
            hard-coded 22050 played audio at the wrong speed for codecs with another rate.

    Side effects: overwrites ``EXP_DIR/temp.wav`` and displays an audio widget.
    """
    if sample_rate is None:
        sample_rate = int(codec_model.sample_rate)

    # map_location lets codes saved from any device load before being moved to the GPU.
    codes = torch.load(codec_path, map_location="cpu").to("cuda").unsqueeze(0).long()
    codes_len = torch.tensor([codes.shape[2]], dtype=torch.long, device="cuda")
    with torch.no_grad():
        decoded_audios, _ = codec_model.decode(tokens=codes, tokens_len=codes_len)

    temp_wav_path = os.path.join(EXP_DIR, "temp.wav")
    # .float() guards against non-float32 decoder output, which torchaudio.save may reject.
    torchaudio.save(temp_wav_path, decoded_audios[0][None].float().cpu(), sample_rate)
    display(Audio(temp_wav_path))

def generate_new_audio(
    text,
    context,
    context_duration=4.0,
    context_type="REFSPEAKERCODEC",
    temperature=0.85,
    top_k=80,
    text_task="Phoneme TTS ",
    lang="en",
):
    """Synthesize ``text`` by running the model's test step on one in-memory record.

    Args:
        text: sentence to speak; prefixed with ``text_task`` to form the record's question.
        context: speaker/style context -- a codec-code file path when ``context_type`` is
            "REFSPEAKERCODEC", or a "TEXT CONTEXT: ..." string when it is "TEXT".
        context_duration: context duration (seconds) stored in the record.
        context_type: "REFSPEAKERCODEC" (audio context) or "TEXT" (text context).
        temperature: sampling temperature, written to ``model.cfg.temperature``.
        top_k: top-k sampling cut-off, written to ``model.cfg.top_k``.
        text_task: task prefix prepended to ``text``.
        lang: language tag stored in the record (previously always "en").

    Returns:
        ``out_path``, where the notebook expects the first predicted wav to be written.

    Side effects: mutates ``model.cfg`` (temperature, top_k, data.test_ds), replaces the
    test dataset's examples and ``model._test_dl``, and calls ``trainer.test``.
    """
    # These settings persist on model.cfg after this call returns.
    model.cfg.temperature = temperature
    model.cfg.top_k = top_k

    # One speechllm-format record. The answer points at the dummy codec file; presumably it only
    # needs to be loadable, with answer_duration a placeholder.
    record = {
        "question": text_task + text,
        "question_type": "TEXT",
        "answer": dummy_codes_fp,
        "context": context,
        "answer_type": "AUDIOCODEC",
        "context_type": context_type,
        "context_duration": context_duration,
        "answer_duration": 2.0,
        "taskname": "squad",
        "lang": lang,
    }

    # Clear first in case load_data accumulates into the existing list (can't tell from here).
    model._test_ds.examples = []
    model._test_ds.examples = model._test_ds.load_data([record])

    sampler = torch.utils.data.distributed.DistributedSampler(
        model._test_ds, num_replicas=1, rank=0, shuffle=False, seed=1
    )
    # A single example is loaded in the main process, consistent with the notebook's
    # cfg.model.data.num_workers = 0. The previous num_workers=1 + persistent_workers=True
    # kept a worker process alive between calls for every loader built here.
    model._test_dl = torch.utils.data.DataLoader(
        model._test_ds,
        collate_fn=model._test_ds.collate_fn,
        sampler=sampler,
        batch_size=1,
        drop_last=False,
        num_workers=0,
        pin_memory=False,
    )

    # Run inference on the explicit dataloader.
    model.cfg.data.test_ds = None
    trainer.test(model, model._test_dl)
    print("Out path:", out_path)
    print("Inference done")
    return out_path

In [None]:
# Text prompts describing the target voice: three Riva speakers sharing one template,
# plus an attribute-style PromptTTS description.
# NOTE(review): the PromptTTS entry has two spaces after "Language:en"; kept verbatim —
# confirm whether that is intentional.
_RIVA_CONTEXT_TEMPLATE = "TEXT CONTEXT: | Language:en Dataset:Riva Speaker:{} |"
text_contexts = [
    _RIVA_CONTEXT_TEMPLATE.format(speaker)
    for speaker in ("Lindy_CMU_FEARFUL", "Lindy_CMU_HAPPY", "Rodney_CMU_HAPPY")
] + [
    "TEXT CONTEXT: | Language:en  Dataset:PromptTTS Gender:female SpeakingRate:2. Slow emotion:neutral Pitch:4. High SNR:5. Clean REVERB:5. Very close-sounding |",
]

## Generate audio from a text context

In [None]:
text = "As I closed my laptop for the night, my reflection in the screen continued to smile back at me."
text_task = "Phoneme TTS " # Can be "Text to speech this " (for sentence-piece tokenizer) or "Phoneme TTS " (for phoneme tokenizer)
context = text_contexts[1] # Sample Text Context
context_type = "TEXT" # Can be REFSPEAKERCODEC (for audio context), TEXT (for text context)
generate_new_audio(
    text,
    context,
    context_type=context_type,
    context_duration=5.0, # Does not matter, should just be > 3 so that dataset does not filter it out.
    top_k=80, # Can play around with this to check robustness
    temperature=0.8, # Can play around with this. temperature < 0.85 can be more robust
    text_task=text_task
)
display(Audio(out_path))

## Listen to some ground-truth context audio clips

In [None]:
# Ground-truth codec files to use as audio (REFSPEAKERCODEC) contexts.
context_paths = [
    "/datap/misc/speechllm_codecdatasets/codecs/RivattsAllLanguagesUpdated_train_nemo_codec_bw_6.0/target_codes_en_Lindy_44khz_CMU_HAPPY_LINDY_CMU_HAPPY_000570.pt",
]

# Print each path with its index, then decode and play it inline.
for path_idx, codes_file in enumerate(context_paths):
    print(path_idx, codes_file)
    play_codec(codes_file)

## Generate audio from an audio context

In [None]:
text = "As I closed my laptop for the night, my reflection in the screen continued to smile back at me."
# NOTE(review): the text-context cell and generate_new_audio's default use "Phoneme TTS ";
# confirm which prefix matches this checkpoint's tokenizer.
text_task = "Text to speech this " # Can be "Text to speech this " (for sentence-piece tokenizer) or "Phoneme TTS " (for phoneme tokenizer)
context = context_paths[0] # Sample audio (codec) context
context_type = "REFSPEAKERCODEC" # Can be REFSPEAKERCODEC (for audio context), TEXT (for text context)
generate_new_audio(
    text,
    context,
    context_type=context_type,
    context_duration=5.0, # Does not matter, should just be > 3 so that dataset does not filter it out.
    temperature=0.8, # Can play around with this. temperature < 0.85 can be more robust
    text_task=text_task
)
display(Audio(out_path))