# Spark TTS Inference Guide with vLLM

This notebook demonstrates how to perform text-to-speech (TTS) inference using the **jq/spark-tts-salt** model with vLLM.

## Overview

Spark TTS is a powerful text-to-speech model that can generate high-quality speech in multiple languages and voices. This implementation uses:
- **vLLM** for efficient model inference
- **BiCodec tokenizer** for audio token processing
- **Retry logic** to handle generation errors gracefully
- **Text chunking** to process long texts efficiently

## Key Features
- Multi-language support (English, Luganda, Swahili, etc.)
- Multiple speaker IDs for different voices
- Robust error handling with automatic retries
- Flexible text chunking strategies
- Audio playback in Jupyter notebooks

## 1. Setup and Installation

First, install all required dependencies. This includes:
- **vLLM**: For efficient LLM inference
- **transformers & unsloth**: Model loading utilities
- **soundfile & librosa**: Audio processing
- **torch & torchaudio**: Deep learning framework
- **xformers, omegaconf, einx, einops**: Supporting libraries

In [None]:
# Install all required packages
!pip install xformers transformers unsloth omegaconf einx einops soundfile librosa torch torchaudio vllm

## 2. Import Libraries

Import all necessary libraries for TTS inference.

In [None]:
# Core imports
from vllm import LLM
from vllm.sampling_params import SamplingParams
import os
from getpass import getpass
import re
import soundfile as sf
from huggingface_hub import snapshot_download
import torch
import sys
from typing import Tuple, List, Optional
import numpy as np
from IPython.display import Audio, display
import time

# Determine device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## 3. Clone Spark-TTS Repository

Clone the Spark-TTS repository to access the audio tokenizer and utilities.

**Note**: Uncomment the git clone line if you haven't cloned the repository yet.

In [None]:
# Clone the Spark-TTS repository (uncomment if needed)
# !git clone https://github.com/SparkAudio/Spark-TTS

# Add Spark-TTS to Python path
sys.path.append('Spark-TTS')
print("Spark-TTS repository path added to sys.path")
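
Re-running the cell above appends `'Spark-TTS'` to `sys.path` again each time. Below is a minimal sketch of an idempotent alternative. `add_repo_to_path` is a hypothetical helper, not part of Spark-TTS, and it assumes the repository was cloned into the current working directory.

```python
import os
import sys


def add_repo_to_path(repo_dir: str = "Spark-TTS") -> bool:
    """Append repo_dir to sys.path once; return True only if it was added."""
    path = os.path.abspath(repo_dir)
    if path in sys.path:
        return False  # already present, e.g. because the cell was re-run
    if not os.path.isdir(path):
        raise FileNotFoundError(f"{path} not found; clone the repository first")
    sys.path.append(path)
    return True
```

Using an absolute path also keeps the import working if the notebook later changes its working directory.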

## 4. Set Hugging Face Token

Set your Hugging Face token for model access. Get your token from: https://huggingface.co/settings/tokens

In [None]:
# Set Hugging Face token securely
os.environ["HF_TOKEN"] = getpass("Enter your HF_TOKEN: ")
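
`getpass` prompts every time the cell runs. The sketch below only prompts when `HF_TOKEN` is not already set, for example when it was exported in the shell before starting Jupyter. `ensure_hf_token` is a hypothetical helper. `huggingface_hub` reads the `HF_TOKEN` environment variable, which is why setting it here is enough for the downloads that follow.

```python
import os
from getpass import getpass
from typing import Callable


def ensure_hf_token(prompt: Callable[[str], str] = getpass) -> str:
    """Return HF_TOKEN from the environment, prompting only if it is missing."""
    token = os.environ.get("HF_TOKEN", "").strip()
    if not token:
        token = prompt("Enter your HF_TOKEN: ").strip()
        if not token:
            raise ValueError("An empty Hugging Face token was entered")
        os.environ["HF_TOKEN"] = token  # picked up by huggingface_hub downloads
    return token
```

The `prompt` parameter exists so the function can be exercised without an interactive terminal.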

## 5. Load the TTS Model

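Load the Spark TTS model using vLLM with `enforce_eager=True`, which runs the model in eager mode instead of capturing CUDA graphs. This is done for compatibility and typically costs some generation speed.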

**Model**: `jq/spark-tts-salt`

This may take a few minutes depending on your internet connection.

In [None]:
# Load the TTS model with vLLM
print("Loading Spark TTS model...")
model = LLM("jq/spark-tts-salt", enforce_eager=True)
print("✅ Model loaded successfully!")

## 6. Download and Setup Audio Tokenizer

Download the BiCodec tokenizer model files from Hugging Face and initialize the audio tokenizer.

The tokenizer converts between audio and token representations.

In [None]:
# Download tokenizer model files
model_base_repo = "unsloth/Spark-TTS-0.5B"
cache_dir = "Spark-TTS-0.5B"

print(f"Downloading tokenizer files from {model_base_repo}...")
snapshot_download(
    repo_id=model_base_repo,
    local_dir=cache_dir,
    ignore_patterns=["*LLM*"],  # Skip LLM files, we only need tokenizer
)
print(f"✅ Tokenizer files downloaded to {cache_dir}")
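
`ignore_patterns` takes glob-style patterns. The sketch below approximates that filtering with Python's `fnmatch` to show which paths `"*LLM*"` would skip. The file names are illustrative assumptions, not read from the actual `unsloth/Spark-TTS-0.5B` repository, and `files_kept` is a hypothetical helper.

```python
from fnmatch import fnmatch
from typing import Iterable, List


def files_kept(paths: Iterable[str], ignore_patterns: List[str]) -> List[str]:
    """Return the paths that match none of the ignore patterns."""
    return [p for p in paths if not any(fnmatch(p, pat) for pat in ignore_patterns)]


# Illustrative file list; the real repository layout may differ.
repo_files = [
    "config.yaml",
    "BiCodec/model.safetensors",
    "LLM/model.safetensors",
    "LLM/tokenizer.json",
]
print(files_kept(repo_files, ["*LLM*"]))  # ['config.yaml', 'BiCodec/model.safetensors']
```

Because `*` in `fnmatch` also matches `/`, a single `"*LLM*"` pattern covers every file under an `LLM/` folder.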

In [None]:
# Initialize the audio tokenizer
from sparktts.models.audio_tokenizer import BiCodecTokenizer

print("Initializing audio tokenizer...")
audio_tokenizer = BiCodecTokenizer(cache_dir, device)
print("✅ Audio tokenizer initialized!")

## 7. Text Chunking Utilities

These functions split long text into manageable chunks for TTS processing.

### Three Chunking Strategies:

1. **chunk_text**: Splits by sentence boundaries with a maximum character limit
2. **chunk_text_simple**: Splits into individual sentences (recommended for TTS)
3. **chunk_text_with_count**: Groups a fixed number of sentences per chunk
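
All three strategies split on the same regular expression, `(?<=[.!?])\s+`, which also splits after abbreviations. The sketch below shows the problem and one possible fix: merging a fragment back onto the previous piece when that piece ends in a known abbreviation. `split_sentences` and the `ABBREVIATIONS` set are hypothetical, not part of the notebook.

```python
import re
from typing import List

SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")

# Hypothetical abbreviation list; extend it for the languages you synthesize.
ABBREVIATIONS = {"Dr.", "Mr.", "Mrs.", "e.g.", "a.m.", "p.m."}


def split_sentences(text: str) -> List[str]:
    """Split like chunk_text_simple, but re-join pieces broken after abbreviations."""
    merged: List[str] = []
    for part in SENTENCE_SPLIT.split(text.strip()):
        if merged and merged[-1] and merged[-1].split()[-1] in ABBREVIATIONS:
            merged[-1] += " " + part  # previous piece ended in an abbreviation
        else:
            merged.append(part)
    return [p for p in merged if p]


text = "Dr. Okello arrived at 10 a.m. today. Is the clinic open? Yes!"
print(SENTENCE_SPLIT.split(text))
# ['Dr.', 'Okello arrived at 10 a.m.', 'today.', 'Is the clinic open?', 'Yes!']
print(split_sentences(text))
# ['Dr. Okello arrived at 10 a.m. today.', 'Is the clinic open?', 'Yes!']
```

The trade-off is that a sentence which genuinely ends with an abbreviation such as "a.m." gets merged with the next one.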

In [None]:
def chunk_text(text: str, max_chunk_size: int = 500) -> List[str]:
    """
    Split text into chunks based on sentence boundaries.
    
    This approach preserves natural sentence flow and intonation for TTS.
    
    Args:
        text: The input string to chunk
        max_chunk_size: Maximum character length per chunk (soft limit)
    
    Returns:
        List of text chunks, each containing one or more complete sentences
    """
    # Split on sentence-ending punctuation (. ! ?) followed by whitespace
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    
    chunks: List[str] = []
    current_chunk: List[str] = []
    current_length = 0
    
    for sentence in sentences:
        sentence = sentence.strip()
        if not sentence:
            continue
        
        sentence_length = len(sentence)
        
        # Start new chunk if adding this sentence would exceed limit
        if current_chunk and (current_length + sentence_length + 1) > max_chunk_size:
            chunks.append(' '.join(current_chunk))
            current_chunk = []
            current_length = 0
        
        current_chunk.append(sentence)
        current_length += sentence_length + 1
    
    # Add the final chunk
    if current_chunk:
        chunks.append(' '.join(current_chunk))
    
    return chunks


def chunk_text_simple(text: str) -> List[str]:
    """
    Split text into individual sentences.
    
    Recommended for TTS - provides maximum control with one sentence per chunk.
    
    Args:
        text: The input string to chunk
    
    Returns:
        List of individual sentences
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return [s.strip() for s in sentences if s.strip()]


def chunk_text_with_count(text: str, sentences_per_chunk: int = 3) -> List[str]:
    """
    Split text into chunks containing a specific number of sentences.
    
    Args:
        text: The input string to chunk
        sentences_per_chunk: Number of sentences to include in each chunk
    
    Returns:
        List of text chunks
    """
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    sentences = [s.strip() for s in sentences if s.strip()]
    
    chunks: List[str] = []
    
    for i in range(0, len(sentences), sentences_per_chunk):
        chunk = ' '.join(sentences[i:i + sentences_per_chunk])
        chunks.append(chunk)
    
    return chunks


print("✅ Text chunking utilities defined")
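
`max_chunk_size` in `chunk_text` is a soft limit: a single sentence longer than the limit is kept whole. If a hard cap is needed, for example to keep each prompt comfortably within the generation budget, one option is to wrap oversized sentences on whitespace with the standard library's `textwrap`. `chunk_text_hard_limit` is a hypothetical helper sketched under that assumption.

```python
import re
import textwrap
from typing import List


def chunk_text_hard_limit(text: str, max_chars: int = 200) -> List[str]:
    """One sentence per chunk; sentences longer than max_chars are wrapped on whitespace."""
    sentences = [s.strip() for s in re.split(r"(?<=[.!?])\s+", text.strip()) if s.strip()]
    chunks: List[str] = []
    for sentence in sentences:
        if len(sentence) <= max_chars:
            chunks.append(sentence)
        else:
            # textwrap breaks at spaces; single words longer than max_chars are split too
            chunks.extend(textwrap.wrap(sentence, width=max_chars))
    return chunks


long_sentence = " ".join(["word"] * 30) + "."
for chunk in chunk_text_hard_limit("Short one. " + long_sentence, max_chars=40):
    print(len(chunk), chunk)
```

Breaking mid-sentence can hurt intonation, so this is best kept as a fallback for unusually long sentences.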

## 8. Core TTS Functions

These are the main functions that handle the text-to-speech conversion process.

### Function Pipeline:
1. **get_tts_tokens**: Generates audio tokens from text using the LLM
2. **generate_speech_from_text**: Extracts semantic and global tokens
3. **generate_speech_segment_with_retry**: Converts tokens to audio with retry logic
4. **get_speech_segments**: Processes multiple text chunks
5. **text_to_speech**: Main function that orchestrates the entire pipeline

In [None]:
def get_tts_tokens(
    text: str,
    speaker_id: int,
    temperature: float,
    model,
    max_tokens: int = 2048,
) -> str:
    """
    Generate TTS tokens from input text using the model.
    
    Args:
        text: Input text to synthesize
        speaker_id: Speaker voice ID (e.g., 248 for Luganda, 246 for Swahili)
        temperature: Sampling temperature for generation (0.1-1.0)
        model: The loaded vLLM model
        max_tokens: Maximum number of tokens to generate
    
    Returns:
        Generated text containing BiCodec audio tokens
        (<|bicodec_global_N|> and <|bicodec_semantic_N|>)
    """
    sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
    
    # Format prompt with task identifier and speaker ID
    prompt = f"<|task_tts|><|start_content|>{speaker_id}: {text}<|end_content|><|start_global_token|>"
    
    outputs = model.generate(
        prompts=prompt,
        sampling_params=sampling_params
    )
    
    audio_tokens = outputs[0].outputs[0].text
    return audio_tokens


def generate_speech_from_text(
    text: str,
    speaker_id: int = 248,
    temperature: float = 0.7,
    max_new_audio_tokens: int = 2000,
    model=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Generate semantic and global tokens from text.
    
    The model generates two types of tokens:
    - Semantic tokens: Capture linguistic content
    - Global tokens: Capture prosody and speaker characteristics
    
    Args:
        text: Input text to synthesize
        speaker_id: Speaker voice ID
        temperature: Sampling temperature
        max_new_audio_tokens: Maximum tokens to generate (passed to vLLM as max_tokens)
        model: The loaded vLLM model
    
    Returns:
        Tuple of (semantic_token_ids, global_token_ids)
    """
    predicted_tokens = get_tts_tokens(
        text=text,
        speaker_id=speaker_id,
        temperature=temperature,
        model=model,
        max_tokens=max_new_audio_tokens,
    )

    # Extract semantic token IDs using regex
    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicted_tokens)
    if not semantic_matches:
        raise ValueError("No semantic tokens found in the generated output.")

    pred_semantic_ids = (
        torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0)
    )

    # Extract global token IDs using regex
    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicted_tokens)
    if not global_matches:
        print("Warning: No global tokens found. Using zeros as fallback.")
        pred_global_ids = torch.zeros((1, 1), dtype=torch.long)
    else:
        pred_global_ids = (
            torch.tensor([int(token) for token in global_matches])
            .long()
            .unsqueeze(0)
        )

    return pred_semantic_ids, pred_global_ids


def generate_speech_segment_with_retry(
    text: str,
    audio_tokenizer,
    model,
    speaker_id: int = 248,
    temperature: float = 0.7,
    max_new_audio_tokens: int = 2000,
    max_retries: int = 3,
) -> Optional[np.ndarray]:
    """
    Generate a single speech segment with automatic retry logic.
    
    This function handles common generation errors by retrying with slightly
    different parameters if dimension mismatches occur.
    
    Args:
        text: Input text to synthesize
        audio_tokenizer: The audio tokenizer for detokenization
        model: The TTS model
        speaker_id: Speaker ID
        temperature: Sampling temperature
        max_new_audio_tokens: Maximum tokens to generate
        max_retries: Maximum number of retry attempts
        
    Returns:
        Audio waveform as numpy array, or None if all retries failed
    """
    for attempt in range(max_retries):
        try:
            # Generate tokens
            pred_semantic_ids, pred_global_ids = generate_speech_from_text(
                text=text,
                speaker_id=speaker_id,
                temperature=temperature,
                max_new_audio_tokens=max_new_audio_tokens,
                model=model,
            )
            
            # Log token shapes for debugging
            print(f"   Attempt {attempt + 1}: semantic shape={pred_semantic_ids.shape}, "
                  f"global shape={pred_global_ids.shape}")
            
            # Detokenize to waveform
            wav_np = audio_tokenizer.detokenize(
                pred_global_ids.to(device), pred_semantic_ids.to(device)
            )
            
            return wav_np
            
        except RuntimeError as e:
            error_msg = str(e)
            print(f"⚠️  Attempt {attempt + 1}/{max_retries} failed")
            print(f"   Error: {error_msg}")
            
            # Check if it's a dimension mismatch error
            if "cannot be multiplied" in error_msg or "shape" in error_msg.lower():
                if attempt < max_retries - 1:
                    print(f"   Retrying with adjusted temperature...")
                    # Slightly vary temperature to get different generation
                    temperature = temperature + np.random.uniform(-0.05, 0.05)
                    temperature = float(np.clip(temperature, 0.1, 1.0))
                    time.sleep(0.5)
                else:
                    print(f"   ❌ All {max_retries} attempts failed")
                    return None
            else:
                # Different error, re-raise
                raise
                
        except ValueError as e:
            print(f"⚠️  ValueError on attempt {attempt + 1}: {e}")
            if attempt < max_retries - 1:
                print(f"   Retrying...")
                time.sleep(0.5)
            else:
                print(f"   ❌ All {max_retries} attempts failed")
                return None
    
    return None


def get_speech_segments(
    text_chunks: List[str],
    audio_tokenizer,
    model,
    speaker_id: int = 248,
    temperature: float = 0.7,
    max_new_audio_tokens: int = 2000,
    max_retries: int = 3,
) -> List[np.ndarray]:
    """
    Generate speech segments for multiple text chunks.
    
    Processes each chunk independently and adds silence for failed chunks.
    
    Args:
        text_chunks: List of text strings to synthesize
        audio_tokenizer: The audio tokenizer
        model: The TTS model
        speaker_id: Speaker ID
        temperature: Sampling temperature
        max_new_audio_tokens: Maximum tokens per chunk
        max_retries: Maximum retry attempts per chunk
        
    Returns:
        List of audio segments as numpy arrays
    """
    segments = []
    total_chunks = len(text_chunks)
    
    for i, text in enumerate(text_chunks, 1):
        print(f"\nüìù Processing chunk {i}/{total_chunks}")
        print(f"   Text: '{text[:60]}...'" if len(text) > 60 else f"   Text: '{text}'")
        
        wav_np = generate_speech_segment_with_retry(
            text=text,
            audio_tokenizer=audio_tokenizer,
            model=model,
            speaker_id=speaker_id,
            temperature=temperature,
            max_new_audio_tokens=max_new_audio_tokens,
            max_retries=max_retries,
        )
        
        if wav_np is not None:
            segments.append(wav_np)
            print(f"✅ Chunk {i} completed successfully")
        else:
            print(f"⚠️  Chunk {i} failed. Adding silence placeholder...")
            # Add 500 ms of silence at 16 kHz, the BiCodec output sample rate
            silence = np.zeros(int(16000 * 0.5), dtype=np.float32)
            segments.append(silence)
    
    return segments


def text_to_speech(
    text: str,
    audio_tokenizer,
    model,
    chunk_text_simple,
    speaker_id: int = 248,
    temperature: float = 0.7,
    max_new_audio_tokens: int = 2048,
    sample_rate: int = 16000,
    max_retries: int = 3,
) -> Tuple[np.ndarray, int]:
    """
    Convert text to speech waveform - Main TTS function.
    
    This is the primary function you should use for TTS conversion.
    It handles text chunking, generation, and concatenation automatically.
    
    Args:
        text: Input text to synthesize
        audio_tokenizer: The audio tokenizer
        model: The TTS model
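        chunk_text_simple: Function that splits text into chunks (any of the chunking utilities in section 7 works)
        speaker_id: Speaker ID (default: 248)
        temperature: Sampling temperature (default: 0.7)
        max_new_audio_tokens: Maximum tokens per chunk (default: 2048)
        sample_rate: Sample rate of the tokenizer output (default: 16000 Hz); used for the
            duration estimate and the silence fallback, it does not resample the audio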
        max_retries: Maximum retry attempts per chunk (default: 3)
        
    Returns:
        Tuple of (waveform_array, sample_rate)
    """
    # Chunk the text into sentences
    texts = chunk_text_simple(text)
    texts = [t.strip() for t in texts if len(t.strip()) > 0]
    
    print(f"\nüéôÔ∏è  Starting TTS conversion")
    print(f"   Total chunks: {len(texts)}")
    print(f"   Speaker ID: {speaker_id}")
    print(f"   Temperature: {temperature}")
    
    # Generate speech segments
    speech_segments = get_speech_segments(
        text_chunks=texts,
        audio_tokenizer=audio_tokenizer,
        model=model,
        speaker_id=speaker_id,
        temperature=temperature,
        max_new_audio_tokens=max_new_audio_tokens,
        max_retries=max_retries,
    )
    
    # Concatenate all segments
    if speech_segments:
        result_wav = np.concatenate(speech_segments)
        duration = len(result_wav) / sample_rate
        print(f"\n✅ TTS conversion completed!")
        print(f"   Total duration: {duration:.2f} seconds")
        print(f"   Waveform shape: {result_wav.shape}")
    else:
        print("\n⚠️  No speech segments generated. Returning silence.")
        result_wav = np.zeros(sample_rate, dtype=np.float32)
    
    return result_wav, sample_rate


def save_wav(
    text: str,
    outfile: str,
    audio_tokenizer,
    model,
    chunk_text_simple,
    **kwargs,
) -> None:
    """
    Generate speech and save to a WAV file.
    
    Args:
        text: Input text
        outfile: Output file path (e.g., 'output.wav')
        audio_tokenizer: The audio tokenizer
        model: The TTS model
        chunk_text_simple: Text chunking function
        **kwargs: Additional arguments for text_to_speech function
    """
    wav, sr = text_to_speech(
        text=text,
        audio_tokenizer=audio_tokenizer,
        model=model,
        chunk_text_simple=chunk_text_simple,
        **kwargs
    )
    
    sf.write(outfile, wav, sr)
    print(f"\nüíæ Audio saved to: {outfile}")


print("✅ Core TTS functions defined")
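
The retry loop in `generate_speech_segment_with_retry` perturbs the temperature using NumPy's global, unseeded random state, so retries are not reproducible between runs. Below is a minimal generic sketch that takes a seed and a callable instead. `retry_with_jitter` is hypothetical, and unlike the notebook's function it retries on any `RuntimeError` or `ValueError` rather than inspecting the error message.

```python
import time
from typing import Callable, Optional, TypeVar

import numpy as np

T = TypeVar("T")


def retry_with_jitter(
    attempt_fn: Callable[[float], T],
    temperature: float,
    max_retries: int = 3,
    jitter: float = 0.05,
    seed: Optional[int] = None,
    delay_s: float = 0.0,
) -> Optional[T]:
    """Call attempt_fn(temperature) up to max_retries times.

    After each failed attempt (except the last) the temperature is nudged by a
    uniform offset in [-jitter, jitter) and clipped to [0.1, 1.0].
    Returns None if every attempt fails.
    """
    rng = np.random.default_rng(seed)  # same seed -> same temperature sequence
    for attempt in range(max_retries):
        try:
            return attempt_fn(temperature)
        except (RuntimeError, ValueError) as err:
            print(f"Attempt {attempt + 1}/{max_retries} failed: {err}")
            if attempt < max_retries - 1:
                temperature = float(np.clip(temperature + rng.uniform(-jitter, jitter), 0.1, 1.0))
                time.sleep(delay_s)
    return None
```

In the notebook, `attempt_fn` would wrap `generate_speech_from_text` followed by `audio_tokenizer.detokenize` for a single chunk.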

## 9. Usage Examples

Now let's see how to use the TTS system with different examples.

### Speaker IDs Reference:
- **248**: Luganda speaker
- **246**: Swahili speaker
- *More speaker IDs available in the model documentation*

### Example 1: English Text

In [None]:
# Sample English text
english_text = (
    "Hello, I'm Prosi Nafula. I am a nurse who takes care of many people who have cancer "
    "and who have questions about their illness and what to expect. There are many types of cancer. "
    "The type of cancer you have is named after the place where it started. For example, if cancer "
    "starts in the breast then it is called breast cancer. Cancer doesn't spread from one person to "
    "another but it can spread through your own body. All cancers need to be treated."
)

print(f"Text length: {len(english_text)} characters")
print(f"Text preview: {english_text[:100]}...")

In [None]:
%%time

# Generate speech (adjust speaker_id as needed)
speaker_id = 248  # Use appropriate speaker ID
temperature = 0.7

result_wav, sr = text_to_speech(
    text=english_text,
    audio_tokenizer=audio_tokenizer,
    model=model,
    chunk_text_simple=chunk_text_simple,
    speaker_id=speaker_id,
    temperature=temperature
)

In [None]:
# Play the generated audio
display(Audio(result_wav, rate=sr))

In [None]:
# Optional: Save to file
# sf.write('output_english.wav', result_wav, sr)
# print("Audio saved to output_english.wav")
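
Before writing, it can help to check the peak level of the waveform: `sf.write` saves float data to a `.wav` file as 16-bit PCM by default, and samples outside [-1, 1] cannot be represented in that format. Below is a hedged sketch of peak normalization with NumPy. `peak_normalize` and the 0.95 target are assumptions, not part of the notebook.

```python
import numpy as np


def peak_normalize(wav: np.ndarray, target_peak: float = 0.95) -> np.ndarray:
    """Scale a float waveform so its largest absolute sample equals target_peak."""
    wav = np.asarray(wav, dtype=np.float32)
    peak = float(np.max(np.abs(wav))) if wav.size else 0.0
    if peak == 0.0:
        return wav  # silence or empty input: nothing to scale
    return wav * (target_peak / peak)


# Usage with the notebook's variables (not run here):
# sf.write('output_english.wav', peak_normalize(result_wav), sr)
```

This scales quiet recordings up as well as loud ones down; to only prevent overflow, apply it when the peak exceeds 1.0.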

### Example 2: Luganda Text

In [None]:
# Sample Luganda text
luganda_text = (
    "Nze Prosi Nafula. Ndi musawo akola ku bantu abalina kookolo era abalina ebibuuzo ku bulwadde bwabwe n'ekyo kye basuubira. "
    "Waliwo ebika bya kookolo bingi. Ekika kya kookolo ky'olina kiyitibwa erinnya ly'ekifo we kyatandikira. "
    "Okugeza, kookolo bw'atandikira mu mabeere, ayitibwa kookolo w'amabeere. Kookolo tasaasaana okuva ku muntu omu okudda ku mulala "
    "naye asobola okusaasaana mu mubiri gwo. Kkookolo yenna yeetaaga okujjanjabibwa."
)

print(f"Text length: {len(luganda_text)} characters")

In [None]:
%%time

# Generate Luganda speech
speaker_id = 248  # Luganda speaker
temperature = 0.7

result_wav_luganda, sr = text_to_speech(
    text=luganda_text,
    audio_tokenizer=audio_tokenizer,
    model=model,
    chunk_text_simple=chunk_text_simple,
    speaker_id=speaker_id,
    temperature=temperature
)

In [None]:
# Play the generated Luganda audio
display(Audio(result_wav_luganda, rate=sr))

### Example 3: Swahili Text

In [None]:
# Sample Swahili text
swahili_text = (
    "Habari, naitwa Prosi Nafula. Mimi ni muuguzi ambaye hushughulikia watu wengi walio na saratani "
    "na ambao wana maswali kuhusu ugonjwa wao na kile wanachoweza kutarajia. Kuna aina nyingi za saratani. "
    "Aina ya saratani unayokuwa nayo inaitwa kwa jina la mahali ilipoanza. Kwa mfano, saratani ikiwa imeanza "
    "katika matiti basi inaitwa saratani ya matiti. Saratani haisambaii kutoka mtu mmoja hadi mwingine lakini "
    "inaweza kusambaa katika mwili wako. Kansa zote zinahitaji kutibiwa."
)

print(f"Text length: {len(swahili_text)} characters")

In [None]:
%%time

# Generate Swahili speech
speaker_id = 246  # Swahili speaker
temperature = 0.7

result_wav_swahili, sr = text_to_speech(
    text=swahili_text,
    audio_tokenizer=audio_tokenizer,
    model=model,
    chunk_text_simple=chunk_text_simple,
    speaker_id=speaker_id,
    temperature=temperature
)

In [None]:
# Play the generated Swahili audio
display(Audio(result_wav_swahili, rate=sr))

## 10. Custom Usage

Use this cell to generate speech from your own text.

In [None]:
# Enter your custom text here
my_text = "Your text goes here."

# Configure parameters
my_speaker_id = 248  # Choose appropriate speaker ID
my_temperature = 0.7  # 0.1 (conservative) to 1.0 (creative)

In [None]:
%%time

# Generate speech
my_wav, my_sr = text_to_speech(
    text=my_text,
    audio_tokenizer=audio_tokenizer,
    model=model,
    chunk_text_simple=chunk_text_simple,
    speaker_id=my_speaker_id,
    temperature=my_temperature
)

In [None]:
# Play your audio
display(Audio(my_wav, rate=my_sr))

In [None]:
# Save to file (optional)
output_filename = 'my_tts_output.wav'
sf.write(output_filename, my_wav, my_sr)
print(f"✅ Audio saved to {output_filename}")

## 11. Tips and Best Practices

### Temperature Settings:
- **0.1-0.3**: More consistent but potentially monotone
- **0.5-0.7**: Balanced (recommended)
- **0.8-1.0**: More varied but potentially less stable
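
These ranges follow from how temperature sampling works: logits are divided by the temperature before the softmax, so low values concentrate probability on the most likely token and high values flatten the distribution. Below is a small NumPy illustration on toy logits (not taken from the TTS model). `softmax_with_temperature` and `entropy` are illustrative helpers.

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Convert logits to probabilities after dividing by the temperature."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    return probs / probs.sum()


def entropy(probs: np.ndarray) -> float:
    """Shannon entropy in nats; higher means a flatter, more varied distribution."""
    nz = probs[probs > 0]
    return float(-(nz * np.log(nz)).sum())


logits = np.array([2.0, 1.0, 0.5, 0.1])  # toy logits
for t in (0.2, 0.7, 1.0):
    p = softmax_with_temperature(logits, t)
    print(f"T={t}: top-token prob={p.max():.3f}, entropy={entropy(p):.3f}")
```

For the TTS model, a flatter distribution over audio tokens means more varied prosody, but also a higher chance of malformed token sequences.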

### Text Chunking Strategies:
- **chunk_text_simple**: Best for most use cases (one sentence per chunk)
- **chunk_text**: Good for controlling chunk size
- **chunk_text_with_count**: Good for grouping related sentences

### Handling Errors:
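- Each chunk is attempted up to `max_retries` times (3 by default); shape-mismatch `RuntimeError`s and missing-token `ValueError`s trigger a retry, while other errors are raised
- Chunks that still fail are replaced with 500 ms of silence so the rest of the text is still synthesized
- Increase `max_retries` if failures are frequent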

### Performance Tips:
- Longer texts take more time to process
- GPU acceleration significantly speeds up generation
- Consider breaking very long texts into multiple batches
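
vLLM's `LLM.generate` also accepts a list of prompts and returns one output per prompt in the same order, so sending several chunks per call usually makes better use of the GPU than the one-prompt-per-call loop in `get_tts_tokens`. Below is a sketch of the batching side, assuming the notebook's prompt template. `build_prompts` and `batched` are hypothetical helpers, and the vLLM call is shown only as a comment.

```python
from typing import Iterator, List, Sequence


def build_prompts(chunks: Sequence[str], speaker_id: int) -> List[str]:
    # Same prompt template as get_tts_tokens
    return [
        f"<|task_tts|><|start_content|>{speaker_id}: {c}<|end_content|><|start_global_token|>"
        for c in chunks
    ]


def batched(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of at most batch_size items."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


# Usage with the notebook's objects (not run here):
# for batch in batched(build_prompts(texts, 248), batch_size=16):
#     outputs = model.generate(batch, SamplingParams(temperature=0.7, max_tokens=2048))
#     token_strings = [o.outputs[0].text for o in outputs]  # same order as batch
```

With batching, the per-chunk retry logic would need to resubmit only the failed chunks; that adaptation is not shown here.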

### Common Issues:
1. **Dimension mismatch errors**: Usually resolved by retry logic
2. **No audio output**: Check speaker_id and ensure text is not empty
3. **Poor quality**: Try adjusting temperature or using different speaker_id

## 12. Conclusion

You now have a complete TTS inference pipeline using Spark TTS with vLLM!

### Key Features Covered:
- ✅ Model loading and initialization
- ✅ Audio tokenizer setup
- ✅ Text chunking strategies
- ✅ Robust generation with retry logic
- ✅ Multi-language support
- ✅ Audio playback and saving

### Next Steps:
- Experiment with different speaker IDs
- Try various temperature settings
- Test with your own texts and languages
- Integrate into your applications

### Resources:
- [Spark TTS GitHub](https://github.com/SparkAudio/Spark-TTS)
- [vLLM Documentation](https://docs.vllm.ai/)
- [Model on Hugging Face](https://huggingface.co/jq/spark-tts-salt)