# Chatterbox TTS - Colab Notebook
This notebook demonstrates how to use Chatterbox, a state-of-the-art open-source text-to-speech (TTS) model, together with IndexTTS-inspired features for long-text generation (chunking and memory estimation). Follow the steps below to get started.

## 🚨 IMPORTANT: If you get 'generate_long_text' AttributeError

If you see an error like `'ChatterboxTTS' object has no attribute 'generate_long_text'`, the usual cause is Python's import cache still holding an older copy of the package. **SOLUTION:**

1. **RESTART YOUR KERNEL** (Runtime → Restart Runtime in Colab)
2. Run the "Clear Import Cache" cell below
3. Re-run all cells in order

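The `generate_long_text` method exists in the version of Chatterbox this notebook targets, so the error should disappear after a restart. If it persists, the installed package may not include the method; the verification cell in step 3 reports which methods are available.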
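To tell a stale cache apart from a package that genuinely lacks the method, a small diagnostic can compare the cached module with a fresh import. `diagnose_method` is a hypothetical helper written for illustration, not part of Chatterbox. It removes the module from the cache as a side effect, so objects created earlier (such as `model`) still need a kernel restart.

```python
import importlib
import sys


def diagnose_method(module_name: str, class_name: str, method_name: str) -> dict:
    """Check whether a class defines `method_name`, in the cached module and after a fresh import.

    Hypothetical diagnostic helper; it removes `module_name` and its submodules
    from sys.modules as a side effect.
    """
    cached = sys.modules.get(module_name)
    cached_has = None  # None means the module had not been imported yet
    if cached is not None:
        cached_cls = getattr(cached, class_name, None)
        cached_has = cached_cls is not None and hasattr(cached_cls, method_name)

    # Drop the module and its submodules, then import it again from disk.
    for name in [m for m in sys.modules if m == module_name or m.startswith(module_name + ".")]:
        del sys.modules[name]
    importlib.invalidate_caches()
    fresh = importlib.import_module(module_name)
    fresh_cls = getattr(fresh, class_name)

    return {
        "file": getattr(fresh, "__file__", None),
        "cached_has_method": cached_has,
        "fresh_has_method": hasattr(fresh_cls, method_name),
    }
```

In the notebook you would call `diagnose_method("chatterbox.tts", "ChatterboxTTS", "generate_long_text")`. A result of `cached_has_method=False` with `fresh_has_method=True` points to a stale cache, while `fresh_has_method=False` means the installed package does not define the method; `file` shows which copy was loaded.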

## 1. Install Dependencies
First, let's install the required packages.

In [None]:
!pip install chatterbox-tts gradio

## 2. Clear Import Cache (IMPORTANT!)
If you're getting `generate_long_text` attribute errors, run this cell first. It removes cached `chatterbox` modules so that the next import reads the package from disk.
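Removing entries from `sys.modules` only affects future imports. The self-contained demo below (it writes a throwaway module to a temporary directory and does not touch Chatterbox) illustrates why the kernel restart still matters: a class bound before the cache was cleared keeps its old methods, just as a `model` created in an earlier cell would. The module name `stale_demo_mod` is made up for the demo.

```python
import importlib
import sys
import tempfile
from pathlib import Path

previous_flag = sys.dont_write_bytecode
sys.dont_write_bytecode = True  # no .pyc files, so the edited source is always recompiled

with tempfile.TemporaryDirectory() as tmp:
    mod_path = Path(tmp) / "stale_demo_mod.py"
    mod_path.write_text("class Model:\n    def generate(self):\n        return 'v1'\n")
    sys.path.insert(0, tmp)
    try:
        import stale_demo_mod
        OldModel = stale_demo_mod.Model  # name bound before the "upgrade"

        # Simulate upgrading the package on disk by adding a new method.
        mod_path.write_text(
            "class Model:\n"
            "    def generate(self):\n"
            "        return 'v2'\n"
            "    def generate_long_text(self):\n"
            "        return 'long'\n"
        )

        # Clear the cached module and import it again, as the cache-clearing cell does.
        del sys.modules["stale_demo_mod"]
        importlib.invalidate_caches()
        import stale_demo_mod as fresh

        fresh_has = hasattr(fresh.Model, "generate_long_text")
        old_has = hasattr(OldModel, "generate_long_text")
        print("fresh import has generate_long_text:", fresh_has)  # True
        print("old binding has generate_long_text:", old_has)    # False
    finally:
        sys.path.remove(tmp)
        sys.modules.pop("stale_demo_mod", None)
        sys.dont_write_bytecode = previous_flag
```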

In [None]:
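# Clear Python's import cache so the next import loads chatterbox from disk
import sys
import importlib

# Remove cached chatterbox modules (the package and all of its submodules)
modules_to_remove = [key for key in sys.modules if key.startswith('chatterbox')]
for module in modules_to_remove:
    del sys.modules[module]
    print(f"Removed cached module: {module}")

# Also reset the import system's finder caches
importlib.invalidate_caches()

# Note: objects created before this cell (e.g. an existing `model`) still use the old code
print("✅ Import cache cleared successfully!")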

## 3. Import Libraries and Setup

In [None]:
import sys
import importlib
import torch
import torchaudio as ta

# Clear any cached chatterbox modules so the import below reads the installed package from disk
print("🔄 Clearing all chatterbox modules from cache...")
modules_to_remove = [key for key in sys.modules if 'chatterbox' in key.lower()]
for module in modules_to_remove:
    del sys.modules[module]
    print(f"   Removed: {module}")
importlib.invalidate_caches()

# Fresh import: no reload() is needed because the modules were just removed from sys.modules,
# and any import error is raised here instead of being silently swallowed
from chatterbox.tts import ChatterboxTTS
print("✅ ChatterboxTTS imported successfully!")

# Automatically detect the best available device with detailed info
if torch.cuda.is_available():
    device = "cuda"
    print(f"🚀 Using GPU: {torch.cuda.get_device_name(0)}")
    print(f"   CUDA Version: {torch.version.cuda}")
    print(f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"   PyTorch Version: {torch.__version__}")
elif torch.backends.mps.is_available():
    device = "mps"
    print("🍎 Using Apple Metal Performance Shaders (MPS)")
else:
    device = "cpu"
    print("⚠️  Using CPU (No GPU/MPS available)")
    print("   For better performance, ensure CUDA (NVIDIA) or MPS (Apple) is available")

print(f"\nDevice: {device}")

# Comprehensive method verification
print("\n🔍 Verifying ChatterboxTTS methods...")
methods_to_check = ['generate', 'generate_long_text', 'generate_streaming', 'estimate_memory_usage']
for method_name in methods_to_check:
    if hasattr(ChatterboxTTS, method_name):
        print(f"✅ {method_name} method is available!")
    else:
        print(f"❌ {method_name} method not found!")

# Print the signatures of the long-text methods (raises if either is missing)
try:
    import inspect
    sig_long = inspect.signature(ChatterboxTTS.generate_long_text)
    sig_estimate = inspect.signature(ChatterboxTTS.estimate_memory_usage)
    print("\n📋 Method signatures verified:")
    print(f"   generate_long_text{sig_long}")
    print(f"   estimate_memory_usage{sig_estimate}")
except Exception as e:
    print(f"⚠️  Could not verify method signatures: {e}")

print("\n🎉 Setup completed! Ready to load model.")

## 4. Load the Model
Now we'll load the pre-trained Chatterbox TTS model onto the selected device. The first run downloads the model weights.

In [None]:
model = ChatterboxTTS.from_pretrained(device=device)

## 5. Generate Speech from Text
Let's generate some speech using the default voice.

In [None]:
text = "Hello! This is a test of the Chatterbox TTS system. It can generate natural-sounding speech from text."
wav = model.generate(text)

# Save the generated audio
output_path = "test_output.wav"
ta.save(output_path, wav, model.sr)

# Display audio player in the notebook
from IPython.display import Audio
Audio(output_path)
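`ta.save` writes the tensor through torchaudio's audio backend. If that backend is unavailable in your environment, a minimal fallback using only Python's standard `wave` module is sketched below. It assumes a mono waveform of float samples in [-1.0, 1.0]; `save_wav_pcm16` is a hypothetical helper, not part of Chatterbox or torchaudio.

```python
import array
import sys
import wave


def save_wav_pcm16(path: str, samples, sample_rate: int) -> None:
    """Write mono float samples (assumed to lie in [-1.0, 1.0]) to a 16-bit PCM WAV file.

    Values outside [-1.0, 1.0] are clipped before conversion.
    """
    # Clip, scale to the int16 range and pack into a signed 16-bit array.
    pcm = array.array("h", (int(max(-1.0, min(1.0, s)) * 32767) for s in samples))
    if sys.byteorder == "big":
        pcm.byteswap()  # WAV sample data is little-endian
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 2 bytes per sample = 16-bit
        wf.setframerate(sample_rate)
        wf.writeframes(pcm.tobytes())
```

For the tensor above this would be `save_wav_pcm16("test_output_pcm16.wav", wav.squeeze(0).tolist(), model.sr)`. Converting to 16-bit PCM loses a little precision compared with a float WAV, but the result plays almost everywhere.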

## 6. Voice Cloning (Optional)
You can also clone a voice by passing a reference recording as `audio_prompt_path`. Upload the audio file (for example through the Colab file browser) and set `AUDIO_PROMPT_PATH` to its path below.
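Before calling `model.generate` with a reference clip, it can help to fail early with a clear message when `AUDIO_PROMPT_PATH` still holds the placeholder or points at a clip that is too short. The helper below is a hypothetical sketch: it only inspects uncompressed WAV files, and its `min_seconds` threshold is an illustrative default, not a documented Chatterbox requirement.

```python
import os
import wave


def check_audio_prompt(path: str, min_seconds: float = 3.0) -> float:
    """Return the duration in seconds of a WAV reference clip, or raise if it looks unusable.

    Hypothetical helper; `min_seconds` is an assumption, not a Chatterbox rule.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"Audio prompt not found: {path!r} (upload it and set AUDIO_PROMPT_PATH)"
        )
    with wave.open(path, "rb") as wf:  # raises wave.Error for non-WAV files
        duration = wf.getnframes() / wf.getframerate()
    if duration < min_seconds:
        raise ValueError(f"Audio prompt is only {duration:.2f}s long; expected at least {min_seconds}s")
    return duration
```

`check_audio_prompt(AUDIO_PROMPT_PATH)` returns the clip length, so printing it doubles as a quick sanity check.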

In [None]:
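import os

# Upload your reference audio file and set its path here
AUDIO_PROMPT_PATH = "path_to_your_audio.wav"  # Replace with your audio file path
if not os.path.isfile(AUDIO_PROMPT_PATH):
    raise FileNotFoundError(f"Reference audio not found: {AUDIO_PROMPT_PATH!r}. Upload a file and update AUDIO_PROMPT_PATH.")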

# Generate speech with the voice from the audio prompt
text = "This sentence is spoken in the voice from the audio prompt."
wav = model.generate(text, audio_prompt_path=AUDIO_PROMPT_PATH)

# Save and play the generated audio
output_path_cloned = "test_output_cloned.wav"
ta.save(output_path_cloned, wav, model.sr)
Audio(output_path_cloned)

## 7. Long Text Generation Example (IndexTTS-Inspired Features)
This example demonstrates `estimate_memory_usage`, which estimates memory needs before generation, and `generate_long_text`, which splits long input into chunks and can free memory between them.
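The three `chunk_method` options (`"sentences"`, `"clauses"`, `"character"`) can be pictured with the following sketch. `chunk_text` is a hypothetical re-implementation for illustration only; it is not the code inside `generate_long_text`, whose splitting rules may differ. It splits on punctuation with a regular expression and greedily packs the pieces into chunks of at most `max_chunk_size` characters.

```python
from __future__ import annotations

import re
import textwrap


def chunk_text(text: str, method: str = "sentences", max_chunk_size: int = 200) -> list[str]:
    """Split text into chunks of at most max_chunk_size characters (illustrative helper)."""
    text = " ".join(text.split())  # collapse newlines and repeated spaces
    if not text:
        return []
    if method == "character":
        # Plain length-based chunking, breaking at spaces where possible.
        return textwrap.wrap(text, max_chunk_size)
    if method == "sentences":
        pieces = re.split(r"(?<=[.!?])\s+", text)  # break after . ! ?
    elif method == "clauses":
        pieces = re.split(r"(?<=[.!?,;:])\s+", text)  # also break after , ; :
    else:
        raise ValueError(f"unknown chunk method: {method!r}")

    chunks: list[str] = []
    current = ""
    for piece in pieces:
        # A single piece longer than the limit is wrapped at spaces first.
        parts = textwrap.wrap(piece, max_chunk_size) if len(piece) > max_chunk_size else [piece]
        for part in parts:
            candidate = f"{current} {part}" if current else part
            if len(candidate) <= max_chunk_size:
                current = candidate  # still fits: keep packing
            else:
                chunks.append(current)  # full: start a new chunk
                current = part
    if current:
        chunks.append(current)
    return chunks
```

Applied to the `long_text` string in the cell below, `chunk_text(long_text, "sentences", 200)` yields chunks that each end at a sentence boundary, since no single sentence there exceeds 200 characters. Packing whole sentences keeps prosody natural at the cost of uneven chunk lengths.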

In [None]:
# Example of generating audio from long text with memory optimization
long_text = """
This is a demonstration of the new long text generation capabilities. 
The system can now handle ultra-long texts by breaking them into smaller chunks. 
Each chunk is processed separately to avoid memory issues. 
The chunks are then combined to create seamless audio output. 
This approach allows for generating hours of audio content without running out of memory. 
The chunking can be done by sentences, clauses, or character count. 
Sentence-based chunking preserves natural speech patterns for the best quality. 
Memory optimization between chunks ensures efficient resource usage. 
You can also estimate memory usage before generation to plan accordingly.
"""

# Verify methods are available before using them
print("🔍 Checking if model has required methods...")

if hasattr(model, 'estimate_memory_usage'):
    print("✅ estimate_memory_usage method found!")
    # Estimate memory usage
    memory_info = model.estimate_memory_usage(long_text)
    print(f"Text length: {memory_info['text_length']} characters")
    print(f"Estimated memory usage: {memory_info['total_estimated_mb']:.0f}MB")
    print(f"Recommended chunk size: {memory_info['recommended_chunk_size']} characters")
else:
    print("❌ estimate_memory_usage method not found! Using fallback...")
    print(f"Text length: {len(long_text)} characters")
    print("Estimated memory usage: ~50MB (fallback estimate)")

print("\n")

if hasattr(model, 'generate_long_text'):
    print("✅ generate_long_text method found! Generating audio...")
    # Generate audio using long text method
    wav_long = model.generate_long_text(
        text=long_text,
        chunk_method="sentences",
        max_chunk_size=200,
        optimize_memory_between_chunks=True
    )
    
    output_path_long = "test_output_long.wav"
else:
    print("❌ generate_long_text method not found! Using fallback generate method...")
    # Fallback to regular generate method
    wav_long = model.generate(long_text)
    output_path_long = "test_output_long_fallback.wav"

# Save the audio and play it. Audio() is the cell's last top-level expression, so Jupyter renders it
ta.save(output_path_long, wav_long, model.sr)
print(f"✅ Audio saved to: {output_path_long}")
Audio(output_path_long)
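This notebook does not show how `generate_long_text` stitches its per-chunk audio together. One common approach, sketched below as an assumption rather than a description of Chatterbox's code, is to overlap neighbouring chunks by a few milliseconds with a linear crossfade so the joins do not click. `join_with_crossfade` is hypothetical and works on plain lists of float samples.

```python
from __future__ import annotations


def join_with_crossfade(chunks, sample_rate: int, fade_ms: float = 10.0) -> list[float]:
    """Concatenate mono chunks, blending the last fade_ms of each chunk into the next one."""
    fade = int(sample_rate * fade_ms / 1000)  # overlap length in samples
    out: list[float] = []
    for chunk in chunks:
        chunk = list(chunk)
        n = min(fade, len(out), len(chunk))  # shorter overlap if a chunk is tiny
        start = len(out) - n
        for i in range(n):
            w = (i + 1) / (n + 1)  # weight of the incoming chunk rises across the overlap
            out[start + i] = out[start + i] * (1 - w) + chunk[i] * w
        out.extend(chunk[n:])
    return out
```

With per-chunk tensors from `model.generate` (a hypothetical `chunk_wavs` list), the call would be `join_with_crossfade([w.squeeze(0).tolist() for w in chunk_wavs], model.sr)`. A linear fade is the simplest choice; an equal-power curve keeps loudness steadier across the overlap.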

## 8. Gradio UI with Live Server
Launch an interactive Gradio web interface that exposes the generation settings, including the long-text options. `share=True` also creates a temporary public link to the interface.
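The `temperature`, `min_p` and `top_p` sliders in the interface below control how the next speech token is sampled. The sketch applies the commonly used definitions of these three controls to a plain list of logits. It only illustrates what the sliders mean and is not Chatterbox's sampler, which may combine them differently; `filter_probs` is a hypothetical name, and applying min_p before top_p is an assumption.

```python
import math


def filter_probs(logits, temperature=0.8, min_p=0.05, top_p=1.0):
    """Convert logits to a renormalised distribution after temperature, min_p and top_p filtering.

    Illustrative only: min_p is applied before top_p here, which is an assumption.
    """
    scaled = [x / temperature for x in logits]  # lower temperature sharpens the distribution
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    if min_p > 0:
        # min_p: drop tokens less likely than min_p times the most likely token.
        threshold = min_p * max(probs)
        probs = [p if p >= threshold else 0.0 for p in probs]

    if top_p < 1.0:
        # top_p (nucleus): keep the most likely tokens until their mass reaches top_p.
        target = top_p * sum(probs)
        kept, mass = set(), 0.0
        for i in sorted(range(len(probs)), key=lambda j: probs[j], reverse=True):
            kept.add(i)
            mass += probs[i]
            if mass >= target:
                break
        probs = [p if i in kept else 0.0 for i, p in enumerate(probs)]

    total = sum(probs)
    return [p / total for p in probs]
```

With `min_p=0.0` and `top_p=1.0` both filters are skipped, which matches the "0.00 Disables" and "1.0 Disables" notes in the slider labels.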

In [None]:
import gradio as gr
import inspect

def generate_with_gradio(text, audio_prompt, exaggeration, temperature, cfgw, min_p, top_p, repetition_penalty, max_new_tokens, use_long_text, chunk_method, max_chunk_size, optimize_memory):
    try:
        # The model is already loaded in the notebook's global state
        audio_prompt_path = audio_prompt if audio_prompt else None
        
        # Estimate memory usage for long texts (only if this model version provides the method)
        if len(text) > 500 and hasattr(model, "estimate_memory_usage"):
            memory_info = model.estimate_memory_usage(text, int(max_new_tokens))
            print(f"Memory estimate: {memory_info['total_estimated_mb']:.0f}MB, recommended chunk size: {memory_info['recommended_chunk_size']}")
        
        # Choose generation method based on text length and user preference
        if use_long_text or len(text) > max_chunk_size:
            # Use long text generation with chunking
            wav = model.generate_long_text(
                text=text,
                chunk_method=chunk_method,
                max_chunk_size=int(max_chunk_size),
                audio_prompt_path=audio_prompt_path,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfgw,
                min_p=min_p,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                max_new_tokens=int(max_new_tokens),
                optimize_memory_between_chunks=optimize_memory,
            )
        else:
            # Use standard generation for shorter texts
            # Check if the generate method supports max_new_tokens parameter
            generate_params = inspect.signature(model.generate).parameters
            
            if 'max_new_tokens' in generate_params:
                wav = model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=exaggeration,
                    temperature=temperature,
                    cfg_weight=cfgw,
                    min_p=min_p,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                    max_new_tokens=int(max_new_tokens),
                )
            else:
                # Fallback for older version without max_new_tokens
                wav = model.generate(
                    text,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=exaggeration,
                    temperature=temperature,
                    cfg_weight=cfgw,
                    min_p=min_p,
                    top_p=top_p,
                    repetition_penalty=repetition_penalty,
                )
        
        # Move to CPU before converting; .numpy() fails on GPU tensors
        return (model.sr, wav.squeeze(0).detach().cpu().numpy()), None
    except Exception as e:
        return None, str(e)

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="Text to synthesize",
                max_lines=5
            )
            ref_wav = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Reference Audio File", value=None)
            exaggeration = gr.Slider(0.25, 2, step=.05, label="Exaggeration (Neutral = 0.5, extreme values can be unstable)", value=.5)
            cfg_weight = gr.Slider(0.0, 1, step=.05, label="CFG/Pace", value=0.5)

            with gr.Accordion("Long Text Options", open=False):
                use_long_text = gr.Checkbox(label="Force Long Text Mode (for ultra-long texts)", value=False)
                chunk_method = gr.Dropdown(["sentences", "clauses", "character"], label="Chunking Method", value="sentences")
                max_chunk_size = gr.Slider(50, 1000, step=50, label="Max Chunk Size (characters)", value=200)
                optimize_memory = gr.Checkbox(label="Optimize Memory Between Chunks", value=True)

            with gr.Accordion("More options", open=False):
                temp = gr.Slider(0.05, 5, step=.05, label="temperature", value=.8)
                min_p = gr.Slider(0.00, 1.00, step=0.01, label="min_p || Newer Sampler. Recommend 0.02 > 0.1. Handles Higher Temperatures better. 0.00 Disables", value=0.05)
                top_p = gr.Slider(0.00, 1.00, step=0.01, label="top_p || Original Sampler. 1.0 Disables(recommended). Original 0.8", value=1.00)
                repetition_penalty = gr.Slider(1.00, 2.00, step=0.1, label="repetition_penalty", value=1.2)
                max_new_tokens = gr.Slider(100, 2000, step=50, label="Max New Tokens", value=1000)

            run_btn = gr.Button("Generate", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Output Audio")
            error_output = gr.Textbox(label="Error", visible=False)

    def show_error(error_message):
        if error_message:
            return gr.update(visible=True, value=error_message)
        return gr.update(visible=False)

    run_btn.click(
        fn=generate_with_gradio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            cfg_weight,
            min_p,
            top_p,
            repetition_penalty,
            max_new_tokens,
            use_long_text,
            chunk_method,
            max_chunk_size,
            optimize_memory,
        ],
        outputs=[audio_output, error_output],
    ).then(show_error, inputs=error_output, outputs=error_output)

demo.launch(share=True)
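The `max_new_tokens` check in `generate_with_gradio` repeats the whole `model.generate(...)` call in both branches. The same `inspect.signature` idea can be factored into a small helper that drops any keyword the target function does not accept. `call_with_supported_kwargs` is a hypothetical helper, not part of Chatterbox or Gradio.

```python
import inspect


def call_with_supported_kwargs(func, *args, **kwargs):
    """Call func, dropping keyword arguments that its signature does not accept.

    If func takes **kwargs, every keyword is passed through unchanged.
    """
    params = inspect.signature(func).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return func(*args, **kwargs)
    supported = {k: v for k, v in kwargs.items() if k in params}
    return func(*args, **supported)
```

Inside `generate_with_gradio`, the two `model.generate` branches could become one call such as `call_with_supported_kwargs(model.generate, text, audio_prompt_path=audio_prompt_path, temperature=temperature, max_new_tokens=int(max_new_tokens))`. The trade-off is that misspelled keyword names are also dropped silently, so keep the explicit branch if you want such mistakes to raise.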