In [None]:
pip install -qU xformers transformers unsloth omegaconf einx einops soundfile librosa torch torchaudio

In [2]:
!git clone https://github.com/SparkAudio/Spark-TTS

Cloning into 'Spark-TTS'...
remote: Enumerating objects: 384, done.[K
remote: Counting objects: 100% (162/162), done.[K
remote: Compressing objects: 100% (68/68), done.[K
remote: Total 384 (delta 115), reused 94 (delta 94), pack-reused 222 (from 1)[K
Receiving objects: 100% (384/384), 7.07 MiB | 8.49 MiB/s, done.
Resolving deltas: 100% (155/155), done.


In [2]:
import tqdm.notebook as tqdm
from transformers import CsmForConditionalGeneration
import torch
import datasets
from IPython.display import Audio, display
from huggingface_hub import snapshot_download
from huggingface_hub import hf_hub_download
import torch
import re
import numpy as np
from typing import Dict, Any
import torchaudio.transforms as T
import huggingface_hub
import transformers
import time
import sys

In [None]:
import os

# huggingface_hub picks up a token from the HF_TOKEN environment variable on its own;
# only show the interactive login widget when no such token is configured.
if not os.environ.get("HF_TOKEN"):
    huggingface_hub.notebook_login()

Clone the Spark-TTS repository and add it to `sys.path` so we can import its BiCodec audio tokenizer and audio utilities.

In [None]:
!git clone https://github.com/SparkAudio/Spark-TTS
sys.path.append('Spark-TTS')
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.audio import audio_volume_normalize

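Download the BiCodec audio-tokenizer files from the `unsloth/Spark-TTS-0.5B` Hugging Face repo (skipping the LLM weights, since the customised model is loaded separately), then load the tokenizer.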

In [None]:
# Fetch the Spark-TTS files, skipping anything matching "*LLM*" (presumably the base
# LLM weights) because the customised TTS model is loaded separately below.
snapshot_download(
    "unsloth/Spark-TTS-0.5B",
    local_dir="Spark-TTS-0.5B",
    ignore_patterns=["*LLM*"],
)
# Use the same CUDA-or-CPU fallback as generate_speech_from_text's default device
# instead of hard-coding "cuda", so this cell also works on CPU-only machines.
audio_tokenizer = BiCodecTokenizer(
    "Spark-TTS-0.5B", "cuda" if torch.cuda.is_available() else "cpu"
)

Load the customised TTS model (`jq/spark-tts-salt`) and its tokenizer.

In [16]:
# Customised Spark-TTS checkpoint; the model and its tokenizer live in the same Hub repo.
MODEL_ID = "jq/spark-tts-salt"

model = transformers.AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",  # keep the dtype stored with the checkpoint
    device_map="auto",   # let transformers/accelerate place the weights
)
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)

In [18]:
@torch.inference_mode()
def generate_speech_from_text(
    text: str,
    temperature: float = 0.8,
    top_k: int = 50,
    top_p: float = 1,
    max_new_audio_tokens: int = 2048,  # Max new tokens (limits audio length)
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generates BiCodec audio token IDs for `text` with the customised Spark-TTS model.

    The result is NOT a waveform: pass the two tensors to
    ``audio_tokenizer.detokenize`` to obtain audio. Uses the notebook-level
    globals ``model`` and ``tokenizer``.

    Args:
        text (str): Text to synthesise. In this notebook it is prefixed with a
            speaker ID, e.g. "243: Hello".
        temperature (float): Sampling temperature for generation.
        top_k (int): Top-k sampling parameter.
        top_p (float): Top-p (nucleus) sampling parameter.
        max_new_audio_tokens (int): Max number of new tokens to generate
            (limits audio length).
        device (torch.device): Device the prompt tensors are moved to.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: ``(pred_global_ids, pred_semantic_ids)``
        as long tensors of shape ``(1, 1, N_global)`` and ``(1, N_semantic)``,
        created on torch's default device (callers move them before detokenizing).
        If no global tokens are generated, ``pred_global_ids`` is a single zero,
        shape ``(1, 1, 1)``. If no semantic tokens are generated, a warning is
        printed and both tensors are empty: shapes ``(1, 1, 0)`` and ``(1, 0)``.
    """
    # The prompt ends at <|start_global_token|>, presumably so the model continues
    # with the global (speaker) tokens followed by the semantic tokens.
    prompt = "".join([
        "<|task_tts|>",
        "<|start_content|>",
        text,
        "<|end_content|>",
        "<|start_global_token|>",
    ])

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    # This is the slow bit. No extra torch.no_grad() is needed here:
    # the @torch.inference_mode() decorator already disables autograd.
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_audio_tokens,  # Limit generation length
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Drop the echoed prompt. Decode without skipping special tokens, because the
    # BiCodec IDs appear as <|bicodec_*_N|> markers (possibly registered as special tokens).
    prompt_len = model_inputs.input_ids.shape[1]
    predicts_text = tokenizer.batch_decode(
        generated_ids[:, prompt_len:], skip_special_tokens=False
    )[0]

    # Extract semantic token IDs using regex
    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicts_text)
    if not semantic_matches:
        print("Warning: No semantic tokens found in the generated output.")
        # Keep the (global_ids, semantic_ids) contract so callers can still unpack the result.
        return (
            torch.zeros((1, 1, 0), dtype=torch.long),
            torch.zeros((1, 0), dtype=torch.long),
        )
    pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0)  # (1, N_semantic)

    # Extract global token IDs using regex
    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicts_text)
    if global_matches:
        pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0)  # (1, N_global)
    else:
        print("Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.")
        pred_global_ids = torch.zeros((1, 1), dtype=torch.long)

    pred_global_ids = pred_global_ids.unsqueeze(0)  # Shape becomes (1, 1, N_global)

    return pred_global_ids, pred_semantic_ids

In [19]:
# Speaker IDs (prefix the input text with "<id>: ", e.g. "243: Hello"):
# 241: Acholi (female)
# 242: Ateso (female)
# 243: Runyankore (female)
# 245: Lugbara (female)
# 246: Swahili (male)
# 248: Luganda (female)

In [24]:
# Both prompts use speaker 243 (Runyankore, female), given as an "<id>: " prefix.
short_input_text, long_input_text = (
    f"243: {sentence}"
    for sentence in (
        "Hello",
        "Uganda is named after the Buganda kingdom, which encompasses a large "
        "portion of the south, including Kampala, and whose language Luganda is widely spoken",
    )
)

In [23]:
def _synthesize_and_play(
    label: str, text: str, device: str = "cuda", sample_rate: int = 16_000
) -> float:
    """Generate speech for `text`, display an audio player, and return the generation time.

    Only the LLM token generation is timed; detokenisation to a waveform is not.

    Args:
        label: Name used in the progress messages ("short", "long", ...).
        text: Speaker-prefixed input text, e.g. "243: Hello".
        device: Device the token IDs are moved to before detokenizing.
        sample_rate: Playback rate passed to the Audio widget (16 kHz in the original cell).

    Returns:
        Generation time in seconds.
    """
    print(f"Generating {label} text...")
    # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
    start_time = time.perf_counter()
    pred_global_ids, pred_semantic_ids = generate_speech_from_text(text)
    gen_time = time.perf_counter() - start_time
    print(f"{label.capitalize()} text generation took {gen_time:.2f} seconds")

    wav_np = audio_tokenizer.detokenize(
        pred_global_ids.to(device).squeeze(0),  # (1, 1, N) -> (1, N)
        pred_semantic_ids.to(device),
    )
    display(Audio(wav_np, rate=sample_rate))
    return gen_time


short_gen_time = _synthesize_and_play("short", short_input_text)
print()  # blank line between the two runs
long_gen_time = _synthesize_and_play("long", long_input_text)

Generating short text...
Short text generation took 1.75 seconds



Generating long text...
Long text generation took 9.49 seconds
