In [9]:
import torch
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, pipeline, WhisperForConditionalGeneration, WhisperTokenizer, WhisperTokenizerFast
import soundfile as sf
import evaluate

In [3]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "ajd12342/parler-tts-mini-v1-paraspeechcaps" # Replace with "ajd12342/parler-tts-mini-v1-paraspeechcaps-only-base" for the model trained only on the PSC-Base dataset
model = ParlerTTSForConditionalGeneration.from_pretrained(model_name).to(device)
description_tokenizer = AutoTokenizer.from_pretrained(model_name)
transcription_tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")

  WeightNorm.apply(module, name, dim)
Config of the text_encoder: <class 'transformers.models.t5.modeling_t5.T5EncoderModel'> is overwritten by shared text_encoder config: T5Config {
  "_name_or_path": "google/flan-t5-large",
  "architectures": [
    "T5ForConditionalGeneration"
  ],
  "classifier_dropout": 0.0,
  "d_ff": 2816,
  "d_kv": 64,
  "d_model": 1024,
  "decoder_start_token_id": 0,
  "dense_act_fn": "gelu_new",
  "dropout_rate": 0.1,
  "eos_token_id": 1,
  "feed_forward_proj": "gated-gelu",
  "initializer_factor": 1.0,
  "is_encoder_decoder": true,
  "is_gated_act": true,
  "layer_norm_epsilon": 1e-06,
  "model_type": "t5",
  "n_positions": 512,
  "num_decoder_layers": 24,
  "num_heads": 16,
  "num_layers": 24,
  "output_past": true,
  "pad_token_id": 0,
  "relative_attention_max_distance": 128,
  "relative_attention_num_buckets": 32,
  "tie_word_embeddings": false,
  "transformers_version": "4.46.1",
  "use_cache": true,
  "vocab_size": 32128
}

Config of the audio_encoder: <

In [4]:
input_description = "In a clear environment, a male voice speaks with a sad tone.".replace('\n', ' ').rstrip()
input_transcription = "Was that your landlord?".replace('\n', ' ').rstrip()

input_description_tokenized = description_tokenizer(input_description, return_tensors="pt").to(model.device)
input_transcription_tokenized = transcription_tokenizer(input_transcription, return_tensors="pt").to(model.device)

### Simple inference

In [5]:
generation = model.generate(input_ids=input_description_tokenized.input_ids, prompt_input_ids=input_transcription_tokenized.input_ids)

audio_arr = generation.cpu().numpy().squeeze()
sf.write("output_simple.wav", audio_arr, model.config.sampling_rate)

The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


### Inference with classifier-free guidance

In [6]:
guidance_scale = 1.5
generation = model.generate(input_ids=input_description_tokenized.input_ids, prompt_input_ids=input_transcription_tokenized.input_ids, guidance_scale=guidance_scale)

audio_arr = generation.cpu().numpy().squeeze()
sf.write("output_with_cfg.wav", audio_arr, model.config.sampling_rate)

### Inference with classifier-free guidance and ASR-based resampling

In [7]:
# Taken from https://github.com/huggingface/parler-tts/blob/d108732cd57788ec86bc857d99a6cabd66663d68/training/eval.py
asr_model_name_or_path = 'distil-whisper/distil-large-v2'
asr_pipeline = pipeline(model=asr_model_name_or_path, device=device, chunk_length_s=25.0)

def wer(asr_pipeline, prompt, audio, sampling_rate):
    """
    Calculate Word Error Rate (WER) for a single audio sample against a reference text.
    Args:
        asr_pipeline: Huggingface ASR pipeline
        prompt: Reference text string
        audio: Audio array
        sampling_rate: Audio sampling rate
    
    Returns:
        float: Word Error Rate as a percentage
    """
    # Load WER metric
    metric = evaluate.load("wer")

    # Handle Whisper's return_language parameter
    return_language = None
    if isinstance(asr_pipeline.model, WhisperForConditionalGeneration):
        return_language = True

    # Transcribe audio
    transcription = asr_pipeline(
        {"raw": audio, "sampling_rate": sampling_rate},
        return_language=return_language,
    )

    # Get appropriate normalizer
    if isinstance(asr_pipeline.tokenizer, (WhisperTokenizer, WhisperTokenizerFast)):
        tokenizer = asr_pipeline.tokenizer
    else:
        tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3")

    english_normalizer = tokenizer.normalize
    basic_normalizer = tokenizer.basic_normalize

    # Choose normalizer based on detected language
    normalizer = (
        english_normalizer
        if isinstance(transcription.get("chunks", None), list) 
        and transcription["chunks"][0].get("language", None) == "english"
        else basic_normalizer
    )

    # Calculate WER
    norm_pred = normalizer(transcription["text"])
    norm_ref = normalizer(prompt)
    
    return 100 * metric.compute(predictions=[norm_pred], references=[norm_ref])

In [10]:
num_retries = 3
wer_threshold = 20
generated_audios = []
word_errors = []
for i in range(num_retries):
    generation = model.generate(input_ids=input_description_tokenized.input_ids, prompt_input_ids=input_transcription_tokenized.input_ids, guidance_scale=guidance_scale)
    audio_arr = generation.cpu().numpy().squeeze()

    word_error = wer(asr_pipeline, input_transcription, audio_arr, model.config.sampling_rate)

    if word_error < wer_threshold:
        break
    generated_audios.append(audio_arr)
    word_errors.append(word_error)
else:
    # Pick the audio with the lowest WER
    audio_arr = generated_audios[word_errors.index(min(word_errors))]

sf.write("output_with_asr_resampling_and_cfg.wav", audio_arr, model.config.sampling_rate)

You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
