In [2]:
!pip install git+https://github.com/huggingface/parler-tts.git
!pip install flash-attn
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

import torch
from parler_tts import ParlerTTSForConditionalGeneration, ParlerTTSStreamer
from transformers import AutoTokenizer
from threading import Thread
import sounddevice as sd
import queue

# Report which CUDA devices PyTorch can see, and remember the result so the
# device selection below agrees with what is printed here.
cuda_available = torch.cuda.is_available()
if cuda_available:
    print(f"CUDA is available. Device count: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"Device {i}: {torch.cuda.get_device_name(i)}")
    print(f"Current device: {torch.cuda.current_device()} - {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("CUDA is not available. Using CPU.")

# Select the device from the check above instead of hard-coding "cuda:0",
# which made the model load fail on machines without CUDA.
torch_device = "cuda:0" if cuda_available else "cpu"  # Use "mps" for Mac
# bfloat16 is kept on CUDA; on CPU fall back to float32, PyTorch's default dtype.
torch_dtype = torch.bfloat16 if cuda_available else torch.float32
model_name = "parler-tts/parler-tts-mini-v1"

# need to set padding max length
# NOTE(review): max_length is not referenced anywhere in the visible cells.
max_length = 50

# load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_name,
).to(torch_device, dtype=torch_dtype)

# Audio codec parameters, read from the loaded model's audio encoder config;
# used later to size streaming chunks and to play audio back.
sampling_rate = model.audio_encoder.config.sampling_rate
frame_rate = model.audio_encoder.config.frame_rate

Collecting git+https://github.com/huggingface/parler-tts.git
  Cloning https://github.com/huggingface/parler-tts.git to /tmp/pip-req-build-4kvtkpys
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/parler-tts.git /tmp/pip-req-build-4kvtkpys
  Resolved https://github.com/huggingface/parler-tts.git to commit 5d0aca9753ab74ded179732f5bd797f7a8c6f8ee
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Looking in indexes: https://download.pytorch.org/whl/cu124
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu124/torchvision-0.19.1%2Bcu124-cp311-cp311-linux_x86_64.whl (7.1 MB)
[2K     [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m189.0 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torchvision
Successfully installed torchvision-0.19.1+cu124
CUDA is available. Device 

  WeightNorm.apply(module, name, dim)


In [3]:
# FIFO that hands generated audio chunks from the producer to the player thread
audio_queue = queue.Queue()

def audio_player():
    """Consume chunks from ``audio_queue`` and play them one after another.

    Runs until a ``None`` item is taken from the queue, which acts as the
    end-of-stream marker. Intended to be used as a thread target.
    """
    while (chunk := audio_queue.get()) is not None:
        sd.play(chunk, samplerate=sampling_rate)
        # Block until this chunk has finished before fetching the next one
        sd.wait()

def generate(text, description, play_steps_in_s=0.5):
    """Run streaming generation and push each audio chunk onto ``audio_queue``.

    ``model.generate`` is run in a background thread with a ``ParlerTTSStreamer``;
    this function iterates the streamer, prints the duration of every chunk and
    enqueues it. It returns once the stream is exhausted (or an empty chunk is
    seen) and the generation thread has finished.

    Args:
        text: String tokenized and passed to the model as ``prompt_input_ids``.
        description: String tokenized and passed to the model as ``input_ids``.
        play_steps_in_s: Chunk size in seconds; converted to streamer steps by
            multiplying with the module-level ``frame_rate``.
    """
    play_steps = int(frame_rate * play_steps_in_s)
    streamer = ParlerTTSStreamer(model, device=torch_device, play_steps=play_steps)

    # Tokenization
    inputs = tokenizer(description, return_tensors="pt").to(torch_device)
    prompt = tokenizer(text, return_tensors="pt").to(torch_device)

    # Generation kwargs
    generation_kwargs = dict(
        input_ids=inputs.input_ids,
        prompt_input_ids=prompt.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_attention_mask=prompt.attention_mask,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        min_new_tokens=10,
    )

    # Run generation in the background so chunks can be consumed as they arrive
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Iterate over chunks of audio and put them in the queue
    for new_audio in streamer:
        if new_audio.shape[0] == 0:
            break
        print(f"Sample of length: {round(new_audio.shape[0] / sampling_rate, 4)} seconds")
        audio_queue.put(new_audio)  # Add audio chunk to queue

    # Wait for the generation thread so it is not left running after we return
    thread.join()

# Now you can do
text = "This is a test of the streamer class"
description = "Jon's talking really fast."
chunk_size_in_s = 0.5

# Start the audio player thread
player_thread = Thread(target=audio_player, daemon=True)
player_thread.start()

try:
    # Generate audio and fill the queue
    generate(text, description, chunk_size_in_s)
finally:
    # Always signal end of audio, even if generation failed; otherwise the
    # player thread would stay blocked on the queue for the rest of the session.
    audio_queue.put(None)
    # Wait until all queued audio chunks are played
    player_thread.join()


`prompt_attention_mask` is specified but `attention_mask` is not. A full `attention_mask` will be created. Make sure this is the intended behaviour.


Sample of length: 0.329 seconds
(14507,)
Sample of length: 0.4992 seconds
(22016,)
Sample of length: 0.4992 seconds
(22016,)
Sample of length: 0.4992 seconds
(22016,)
Sample of length: 0.0658 seconds
(2901,)
