In [None]:
!pip install unsloth
!git clone https://github.com/SparkAudio/Spark-TTS
!pip install omegaconf einx

In [None]:
!pip install librosa soundfile

In [None]:
!pip install einops

In [None]:
import tqdm.notebook as tqdm
from unsloth import FastModel
from transformers import CsmForConditionalGeneration
import torch
import datasets
from datasets import load_dataset, Audio, Dataset
from IPython.display import Audio, display
from huggingface_hub import snapshot_download

In [None]:
# Fetch the Spark-TTS checkpoint into a local folder, then load its `LLM`
# sub-directory through Unsloth.
spark_tts_local_dir = "Spark-TTS-0.5B"
snapshot_download("unsloth/Spark-TTS-0.5B", local_dir = spark_tts_local_dir)

llm_load_options = dict(
    model_name = "Spark-TTS-0.5B/LLM",
    max_seq_length = 2048,
    dtype = torch.float32,   # Spark seems to only work on float32 for now
    full_finetuning = True,  # We support full finetuning now!
    load_in_4bit = False,
)
model, tokenizer = FastModel.from_pretrained(**llm_load_options)

In [None]:
import mlflow
from getpass import getpass
import os
def _prompt_env_secret(variable: str) -> str:
    """Prompt for `variable` via getpass, export the answer through os.environ and return it."""
    secret = getpass(f'Enter the {variable}: ')
    os.environ[variable] = secret
    return secret


# Credentials are typed in interactively so they are never stored in the notebook.
MLFLOW_TRACKING_USERNAME = _prompt_env_secret('MLFLOW_TRACKING_USERNAME')
MLFLOW_TRACKING_PASSWORD = _prompt_env_secret('MLFLOW_TRACKING_PASSWORD')

# Tracking server and experiment, also passed through the environment.
# NOTE(review): the experiment name mentions "csm-1b" although this notebook trains
# Spark-TTS, and the trainer cell uses report_to="none" — confirm both are intended.
os.environ.update(
    MLFLOW_TRACKING_URI="https://mlflow-sunbird-ce0ecfc14244.herokuapp.com",
    MLFLOW_EXPERIMENT_NAME="tts-csm-1b",
)

In [None]:
def _load_studio_subset(language_code, speaker_id):
    """Load the train split of the `studio-<language_code>` config of Sunbird/salt
    and attach a constant `speaker_id` column to every row."""
    subset = load_dataset("Sunbird/salt", f"studio-{language_code}", split="train")
    return subset.map(lambda example: {"speaker_id": speaker_id})


# One subset per studio config, each tagged with a speaker id.
# NOTE(review): "lug" and "eng" share speaker_id 1 — presumably the same voice; confirm.
ds_lug = _load_studio_subset("lug", 1)
ds_eng = _load_studio_subset("eng", 1)
ds_ach = _load_studio_subset("ach", 2)
ds_swa = _load_studio_subset("swa", 3)
ds_lgg = _load_studio_subset("lgg", 4)
ds_nyn = _load_studio_subset("nyn", 5)
ds_teo = _load_studio_subset("teo", 6)

# The concatenation order determines which rows the seeded shuffle puts first,
# so it is kept exactly as before.
subsets_in_concat_order = [ds_ach, ds_lug, ds_eng, ds_swa, ds_lgg, ds_nyn, ds_teo]
dataset = datasets.concatenate_datasets(subsets_in_concat_order).shuffle(seed=42)

# Decode every clip at 24 kHz.
sampling_rate = 24000
dataset = dataset.cast_column("audio", datasets.Audio(sampling_rate=sampling_rate))


def _has_usable_duration(example):
    """True when the clip is strictly longer than 0.5 s and strictly shorter than 8 s."""
    num_samples = len(example["audio"]["array"])
    return (0.5 * sampling_rate) < num_samples < (8 * sampling_rate)


dataset = dataset.filter(_has_usable_duration, num_proc=20)

In [None]:
import locale
import torchaudio.transforms as T
import os
import torch
import sys
import numpy as np

# The `sparktts` package is not pip-installed; it lives inside the ./Spark-TTS
# checkout created by the `git clone` in the setup cell. That directory must be
# on sys.path or the imports below fail with ModuleNotFoundError on a fresh
# kernel. The membership test keeps the cell idempotent when re-run.
SPARK_TTS_REPO_DIR = "Spark-TTS"
if SPARK_TTS_REPO_DIR not in sys.path:
    sys.path.append(SPARK_TTS_REPO_DIR)

from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.audio import audio_volume_normalize

# Audio tokenizer built from the downloaded checkpoint directory, on "cuda".
audio_tokenizer = BiCodecTokenizer("Spark-TTS-0.5B", "cuda")
@torch.no_grad()
def extract_wav2vec2_features( wavs: torch.Tensor) -> torch.Tensor:
    """Extract wav2vec2 features as the mean of three hidden-state layers.

    The waveform is run through `audio_tokenizer.processor` (with
    `sampling_rate=16000`) and `audio_tokenizer.feature_extractor`; hidden
    states 11, 14 and 16 are then averaged.

    The whole function runs under `torch.no_grad()`: the features are only used
    as tokenizer input, so no autograd graph is built for the forward pass.

    Args:
        wavs: Waveform tensor whose first dimension (batch) must be 1.

    Returns:
        The element-wise mean of the three selected hidden states.

    Raises:
        ValueError: If the batch dimension is not 1, or if the feature
            extractor returned no hidden states.
        IndexError: If the model exposes too few hidden states for the
            requested layer indices.
    """
    if wavs.shape[0] != 1:
        raise ValueError(f"Expected batch size 1, but got shape {wavs.shape}")

    # Single source of truth for the layers that are validated AND averaged.
    required_layers = [11, 14, 16]

    # The processor is fed a NumPy array without the batch dimension.
    wav_np = wavs.squeeze(0).cpu().numpy()
    processed = audio_tokenizer.processor(
        wav_np,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    )
    input_values = processed.input_values.to(audio_tokenizer.feature_extractor.device)

    model_output = audio_tokenizer.feature_extractor(input_values)
    hidden_states = model_output.hidden_states
    if hidden_states is None:
        raise ValueError("Wav2Vec2Model did not return hidden states. Ensure config `output_hidden_states=True`.")

    num_layers = len(hidden_states)
    if any(l >= num_layers for l in required_layers):
        raise IndexError(f"Requested hidden state indices {required_layers} out of range for model with {num_layers} layers.")

    # Same left-to-right summation order as h[11] + h[14] + h[16], then divide by 3.
    feats_mix = sum(hidden_states[l] for l in required_layers) / len(required_layers)

    return feats_mix
    
def formatting_audio_func(example):
    """Convert one dataset row into a Spark-TTS training string.

    The clip is resampled to the tokenizer's sample rate when needed, optionally
    volume-normalized (per the tokenizer config), and tokenized into global and
    semantic BiCodec tokens. The result is the transcript (prefixed with
    "<speaker_id>: " when the row has a `speaker_id`) wrapped together with the
    audio tokens in the task/start/end marker tokens.

    Returns:
        dict: `{"text": <full training string>}`.
    """
    transcript = example["text"]
    if "speaker_id" in example:
        transcript = f"{example['speaker_id']}: {example['text']}"

    waveform = example["audio"]["array"]
    source_sr = example["audio"]["sampling_rate"]
    codec_sr = audio_tokenizer.config['sample_rate']

    # Bring the clip to the sample rate the tokenizer is configured for.
    if source_sr != codec_sr:
        resample = T.Resample(orig_freq=source_sr, new_freq=codec_sr)
        waveform = resample(torch.from_numpy(waveform).float()).numpy()

    if audio_tokenizer.config["volume_normalize"]:
        waveform = audio_volume_normalize(waveform)

    reference_clip = audio_tokenizer.get_ref_clip(waveform)

    def _to_device_batch(array):
        # NumPy array -> float tensor with a leading batch dim on the tokenizer's device.
        return torch.from_numpy(array).unsqueeze(0).float().to(audio_tokenizer.device)

    wav_batch = _to_device_batch(waveform)
    ref_batch = _to_device_batch(reference_clip)
    features = extract_wav2vec2_features(wav_batch)

    semantic_token_ids, global_token_ids = audio_tokenizer.model.tokenize({
        "wav": wav_batch,
        "ref_wav": ref_batch,
        "feat": features.to(audio_tokenizer.device),
    })

    # squeeze() drops the batch dim before the ids are rendered as special tokens.
    global_tokens = "".join(
        f"<|bicodec_global_{i}|>" for i in global_token_ids.squeeze().cpu().numpy()
    )
    semantic_tokens = "".join(
        f"<|bicodec_semantic_{i}|>" for i in semantic_token_ids.squeeze().cpu().numpy()
    )

    pieces = (
        "<|task_tts|>",
        "<|start_content|>",
        transcript,
        "<|end_content|>",
        "<|start_global_token|>",
        global_tokens,
        "<|end_global_token|>",
        "<|start_semantic_token|>",
        semantic_tokens,
        "<|end_semantic_token|>",
        "<|im_end|>",
    )
    return {"text": "".join(pieces)}


# Tokenize the first 12,000 rows of the shuffled dataset into training strings;
# the raw audio column is dropped once it has been turned into tokens.
dataset = dataset.take(12_000).map(formatting_audio_func, remove_columns=["audio"])

# The codec and the feature extractor are no longer needed on the GPU: move
# them to CPU and release cached CUDA memory.
print("Moving Bicodec model and Wav2Vec2Model to cpu.")
for _offloaded_module in (audio_tokenizer.model, audio_tokenizer.feature_extractor):
    _offloaded_module.cpu()
torch.cuda.empty_cache()

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

# Optimizer / schedule settings for one full-precision (float32) pass over the data.
training_args = TrainingArguments(
    output_dir = "outputs",
    num_train_epochs = 1,  # Set this for 1 full training run.
    per_device_train_batch_size = 2,
    gradient_accumulation_steps = 4,
    learning_rate = 2e-4,
    lr_scheduler_type = "linear",
    warmup_steps = 5,
    weight_decay = 0.01,
    optim = "adamw_8bit",
    fp16 = False,  # Full float32 training, so mixed precision is disabled
    bf16 = False,  # Full float32 training, so mixed precision is disabled
    logging_steps = 50,
    seed = 3407,
    # NOTE(review): with "none" nothing is reported to MLflow even though its
    # environment variables are set earlier in the notebook — confirm intent.
    report_to = "none",  # Use this for WandB etc
)

# Supervised fine-tuning on the "text" column produced by the formatting step.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 2048,
    dataset_num_proc = 2,
    packing = False,  # Can make training 5x faster for short sequences.
    args = training_args,
)

In [None]:
# Run the fine-tuning job; the value returned by `train()` is kept in `trainer_stats`.
trainer_stats = trainer.train()

In [None]:
# A second BiCodecTokenizer instance created on "cuda" (the first one had its models
# moved to CPU after dataset tokenization); the inference cell uses it to detokenize.
gpu_audio_tokenizer = BiCodecTokenizer("Spark-TTS-0.5B", "cuda")

In [None]:
import torch
import re
import numpy as np
from typing import Dict, Any
import torchaudio.transforms as T

# Switch the fine-tuned model to Unsloth's inference mode before generating.
FastModel.for_inference(model) # Enable native 2x faster inference

@torch.inference_mode()
def generate_speech_from_text(
    text: str,
    temperature: float = 0.8,   # Generation temperature
    top_k: int = 50,            # Generation top_k
    top_p: float = 1,        # Generation top_p
    max_new_audio_tokens: int = 2048, # Max tokens for audio part
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> "tuple[torch.Tensor, torch.Tensor]":
    """
    Generate BiCodec token IDs for `text` with the fine-tuned model.

    This function stops at the token level: it does NOT produce a waveform.
    Pass the returned tensors to a BiCodecTokenizer's `detokenize` to get audio.

    Args:
        text (str): The text input to be converted to speech.
        temperature (float): Sampling temperature for generation.
        top_k (int): Top-k sampling parameter.
        top_p (float): Top-p (nucleus) sampling parameter.
        max_new_audio_tokens (int): Max number of new tokens to generate (limits audio length).
        device (torch.device): Device to run generation on.

    Returns:
        tuple: `(pred_global_ids, pred_semantic_ids)`, both long tensors built
        with `torch.tensor` (they are not moved to `device`).
        `pred_global_ids` has shape (1, 1, N_global) and `pred_semantic_ids`
        has shape (1, N_semantic).

    Raises:
        ValueError: If the generated text contains no semantic tokens, so
            there is nothing that could be turned into audio.
    """
    # Generation resumes right after the opening global-token marker, so the
    # model emits the global tokens first and the semantic tokens afterwards.
    prompt = "".join([
        "<|task_tts|>",
        "<|start_content|>",
        text,
        "<|end_content|>",
        "<|start_global_token|>"
    ])

    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)

    print("Generating token sequence...")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_audio_tokens, # Limit generation length
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id, # Stop token
        pad_token_id=tokenizer.pad_token_id # Use models pad token id
    )
    print("Token sequence generated.")

    # Drop the prompt tokens; keep only what the model produced.
    generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]

    # Special tokens must survive decoding: the audio tokens ARE special tokens.
    predicts_text = tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]

    # Extract semantic token IDs using regex
    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicts_text)
    if not semantic_matches:
        # Fail loudly and consistently: callers unpack two tensors, so returning
        # a single empty array here would only surface as a confusing
        # "not enough values to unpack" error further down.
        raise ValueError("No semantic tokens found in the generated output.")

    pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0) # Add batch dim

    # Extract global token IDs using regex
    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicts_text)
    if not global_matches:
        print("Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.")
        pred_global_ids = torch.zeros((1, 1), dtype=torch.long)
    else:
        pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0) # Add batch dim

    pred_global_ids = pred_global_ids.unsqueeze(0) # Shape becomes (1, 1, N_global)

    print(f"Found {pred_semantic_ids.shape[1]} semantic tokens.")
    print(f"Found {pred_global_ids.shape[2]} global tokens.")

    # Side effect kept from the original implementation: the module-level
    # `audio_tokenizer` is pointed at `device` and its model is moved there.
    # NOTE(review): nothing in this function uses `audio_tokenizer` afterwards —
    # confirm whether this is still needed.
    audio_tokenizer.device = device
    audio_tokenizer.model.to(device)

    return pred_global_ids, pred_semantic_ids

In [None]:
# Text to synthesize. Swap in one of the commented alternatives to try other inputs.
# NOTE(review): training strings were prefixed with "<speaker_id>: " whenever a speaker
# id was present; the default input below has no such prefix — confirm that is intended.
input_text = 'There are 123 of them on 11th June 2025.'
#input_text = "Nsobola okwogera Oluganda n'ennimi endala."
#input_text = "Once there was a boy who lived on the moon. All his friends were down on Earth."
#input_text = "2: Ekitiibwa ky'omuntu eky'obutonde; okwenkanankana, wamu n'obuyinza obutayinza kugyibwawo ebyabantu bonna, gwe musingi gw'eddembe; obwenkanya n'emirembe mu nsi."
#input_text = "Ndi musomesa mu Makerere University in the department of computer science."

# Step 1: text -> global / semantic token ids.
pred_global_ids, pred_semantic_ids = generate_speech_from_text(input_text)

# Step 2: token ids -> waveform, using the GPU-resident tokenizer.
device = 'cuda'
global_ids_on_device = pred_global_ids.to(device).squeeze(0)  # drop the extra leading dim
semantic_ids_on_device = pred_semantic_ids.to(device)
wav_np = gpu_audio_tokenizer.detokenize(global_ids_on_device, semantic_ids_on_device)

In [None]:
# Play the generated waveform inline. The playback rate is read from the tokenizer
# config, falling back to 16000 when the "sample_rate" key is missing.
sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
from IPython.display import Audio, display
display(Audio(wav_np, rate=sample_rate))

In [None]:
# Upload the fine-tuned model to the Hub repository 'jq/spark-tts-salt'.
# NOTE(review): only `model` is pushed here, not the tokenizer — confirm that is intended.
model.push_to_hub('jq/spark-tts-salt')