In [None]:

#

# import torch;
# import re;

# v = re.match(r"[0-9]{1,}\.[0-9]{1,}", str(torch.__version__)).group(0)
# xformers = "xformers==" + ("0.0.33.post1" if v=="2.9" else "0.0.32.post2" if v=="2.8" else "0.0.29.post3")
# !pip install --no-deps bitsandbytes accelerate {xformers} peft==0.17.1 trl triton cut_cross_entropy unsloth_zoo
# !pip install sentencepiece protobuf "datasets==4.3.0" "huggingface_hub>=0.34.0" hf_transfer
# !pip install --no-deps unsloth
# !pip install transformers==4.56.2
# !pip install --no-deps trl==0.22.2
# !git clone https://github.com/SparkAudio/Spark-TTS
# !git clone https://github.com/Sibgat-Ul/VoxCPM_bn_en
# !pip install omegaconf einx torchcodec "datasets>=3.4.1,<4.0.0"
# !pip install soxr soundfile einops -q
# !pip install pyworld -q

# import warnings
# warnings.filterwarnings('ignore')

# Uncomment and run the install lines above only if the dependencies are not already installed.

In [13]:
import locale
import torchaudio.transforms as T
import os
import torch
import sys
# sys.path.append('Spark-TTS')
# import librosa
from tqdm import tqdm
from typing import List, Optional
import numpy as np
from datasets import load_dataset
from huggingface_hub import snapshot_download
import pandas as pd
from datasets import Dataset, Audio
from datasets import concatenate_datasets, load_dataset
import unsloth
# import pyworld as pw

# Ordered pitch/speed level names; a name's position is its integer id in prompt tokens.
PITCH_LEVEL_NAMES = ['very_low', 'low', 'moderate', 'high', 'very_high']

# Per-gender (lower, upper) bounds for each level — presumably mean F0 in Hz (TODO confirm).
# Each gender lists its four interior cut points; the outermost bands are open-ended.
PITCH_THRESHOLDS = {
    gender: dict(
        zip(PITCH_LEVEL_NAMES, zip((float('-inf'), *cuts), (*cuts, float('inf'))))
    )
    for gender, cuts in (
        ("male", (145, 164, 211, 250)),
        ("female", (225, 258, 314, 353)),
    )
}

# Level name -> id used in <|pitch_label_N|> / <|speed_label_N|> tokens.
LEVELS_MAP = {name: level_id for level_id, name in enumerate(PITCH_LEVEL_NAMES)}

# Gender name -> id used in <|gender_N|> tokens.
GENDER_MAP = {name: gender_id for gender_id, name in enumerate(("female", "male"))}

# Task name -> task token that opens a prompt.
TASK_TOKEN_MAP = {
    "vc": "<|task_vc|>",
    "tts": "<|task_tts|>",
    "asr": "<|task_asr|>",
    "s2s": "<|task_s2s|>",
    "t2s": "<|task_t2s|>",
    "understand": "<|task_understand|>",
    "caption": "<|task_cap|>",
    "controllable_tts": "<|task_controllable_tts|>",
    "prompt_tts": "<|task_prompt_tts|>",
    "speech_edit": "<|task_edit|>",
}


In [14]:
from huggingface_hub import login
import os

# Authenticate to the Hugging Face Hub with the HF_TOKEN environment variable (None when unset).
login(token=os.environ.get("HF_TOKEN"))

In [15]:
from unsloth import FastModel
from huggingface_hub import snapshot_download

max_seq_length = 4096

snapshot_download("unsloth/Spark-TTS-0.5B", local_dir = "Spark-TTS-0.5B")

def get_model(checkpoint: str, return_peft: bool = False):
    """Load a Spark-TTS LLM checkpoint with Unsloth, optionally attaching LoRA adapters.

    Args:
        checkpoint: model id or path forwarded to ``FastModel.from_pretrained``.
        return_peft: when True and ``checkpoint`` does not start with ``/``, wrap the
            model with rank-128 LoRA adapters on the attention and MLP projections.
            Checkpoints given as absolute paths never get new adapters — presumably
            because they are already fine-tuned (TODO confirm).

    Returns:
        ``(model, tokenizer)``.
    """
    model, tokenizer = FastModel.from_pretrained(
        model_name=str(checkpoint),
        max_seq_length=max_seq_length,
        dtype=torch.float32,
        full_finetuning=False,
        load_in_4bit=False,
    )

    is_absolute_path = checkpoint[0] == "/"
    if is_absolute_path or not return_peft:
        return model, tokenizer

    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    peft_model = FastModel.get_peft_model(
        model,
        r=128,
        target_modules=lora_targets,
        lora_alpha=128,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )
    return peft_model, tokenizer

# Load the checkpoint used for inference below; return_peft is left at False, so no
# new LoRA adapters are attached. Presumably the Bengali fine-tune, judging by the
# `_bn` suffix (TODO confirm).
# Alternative checkpoint (Kaggle path):
# model, tokenizer = get_model("/kaggle/input/spark-tts/transformers/checkpoint_300_1/4/checkpoint-350")
model, tokenizer = get_model("./pretrained_models/spark_tts_bn")

Fetching 31 files: 100%|██████████| 31/31 [00:00<00:00, 2452.86it/s]


==((====))==  Unsloth 2026.1.2: Fast Qwen2 patching. Transformers: 4.56.2.
   \\   /|    NVIDIA GeForce RTX 3060. Num GPUs = 1. Max memory: 11.622 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.1+cu128. CUDA: 8.6. CUDA Toolkit: 12.8. Triton: 3.5.1
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.


In [16]:
import torch
from sparktts.models.audio_tokenizer import BiCodecTokenizer
from sparktts.utils.audio import audio_volume_normalize
from typing import Optional

# BiCodec audio tokenizer loaded from the Spark-TTS-0.5B snapshot and placed on "cuda".
# It exposes `processor`, `feature_extractor`, `model`, `config`, `get_ref_clip`,
# `tokenize` and `detokenize`, which the helpers in this notebook use.
# NOTE(review): in the captured run this cell failed with CUDA OOM while the LLM was
# already resident on the same GPU — free GPU memory first if that recurs.
audio_tokenizer = BiCodecTokenizer("Spark-TTS-0.5B", "cuda")

def extract_wav2vec2_features(wavs: torch.Tensor) -> torch.Tensor:
    """Average wav2vec2 hidden states 11, 14 and 16 for a single waveform.

    The waveform is converted to numpy, passed through ``audio_tokenizer.processor``
    (at a fixed 16 kHz sampling rate) and then through
    ``audio_tokenizer.feature_extractor``. The three hidden-state layers are averaged
    element-wise. The forward pass runs under ``torch.no_grad()``: the features only
    feed discrete tokenization, so building an autograd graph would just hold
    activations in memory.

    Args:
        wavs: waveform tensor whose first dimension is the batch; it must be 1.

    Returns:
        The averaged hidden states (no gradient tracking).

    Raises:
        ValueError: if the batch size is not 1, or the model returned no hidden states.
        IndexError: if the model returned fewer than 17 hidden states.
    """
    if wavs.shape[0] != 1:
        raise ValueError(f"Expected batch size 1, but got shape {wavs.shape}")
    wav_np = wavs.squeeze(0).cpu().numpy()

    processed = audio_tokenizer.processor(
        wav_np,
        sampling_rate=16000,
        return_tensors="pt",
        padding=True,
    )
    input_values = processed.input_values.to(audio_tokenizer.feature_extractor.device)

    with torch.no_grad():
        model_output = audio_tokenizer.feature_extractor(input_values)

    hidden_states = model_output.hidden_states
    if hidden_states is None:
        raise ValueError("Wav2Vec2Model did not return hidden states. Ensure config `output_hidden_states=True`.")

    num_layers = len(hidden_states)
    required_layers = [11, 14, 16]
    if any(layer >= num_layers for layer in required_layers):
        raise IndexError(f"Requested hidden state indices {required_layers} out of range for model with {num_layers} layers.")

    feats_mix = (hidden_states[11] + hidden_states[14] + hidden_states[16]) / 3

    return feats_mix


def build_attribute_tokens(pitch_level: Optional[str]) -> str:
    """Return the ``<|pitch_label_N|>`` token for a known pitch level, else ``""``.

    ``N`` is the level's id in ``LEVELS_MAP``. ``None`` and unknown level names
    yield an empty string, meaning "no style attributes".
    """
    if pitch_level is None or pitch_level not in LEVELS_MAP:
        return ""
    return f"<|pitch_label_{LEVELS_MAP[pitch_level]}|>"

def build_tts_input_with_attributes(
    text: str,
    pitch_level: Optional[str],
    global_tokens: str,
    semantic_tokens: str
) -> str:
    """Assemble a complete TTS training sequence.

    Layout: task token, the text content, an optional style-label section, the
    BiCodec global tokens, the semantic tokens, then ``<|im_end|>``. The style
    section is included only when ``build_attribute_tokens(pitch_level)`` returns a
    non-empty string.
    """
    style_tokens = build_attribute_tokens(pitch_level)

    parts = ["<|task_tts|>", "<|start_content|>", text, "<|end_content|>"]
    if style_tokens:
        parts += ["<|start_style_label|>", style_tokens, "<|end_style_label|>"]
    parts += [
        "<|start_global_token|>", global_tokens, "<|end_global_token|>",
        "<|start_semantic_token|>", semantic_tokens, "<|end_semantic_token|>",
        "<|im_end|>",
    ]
    return "".join(parts)


def formatting_audio_func_v2(example):
    """Convert one dataset row into a Spark-TTS training sequence.

    Steps:
      * Build the content text, prefixed with ``source`` when that value is present
        and not None.
      * Resample the audio to ``audio_tokenizer.config['sample_rate']`` if needed.
      * Optionally volume-normalize it.
      * Encode it into BiCodec global/semantic token ids.
      * Assemble everything with ``build_tts_input_with_attributes``.

    The BiCodec model and the wav2vec2 feature extractor are moved to the GPU only
    while encoding. They are always moved back to the CPU afterwards, including when
    encoding raises, so a failing row cannot leave them resident in GPU memory.

    Args:
        example: row mapping with ``"text"`` and ``"audio"`` (a dict holding
            ``"array"`` and ``"sampling_rate"``), plus optional ``"source"`` and
            ``"pitch_level"``.

    Returns:
        ``{"text": <formatted training string>}``.
    """
    source = example.get("source")
    # A present-but-None source (e.g. a missing value in that column) must not turn
    # into a literal "None: " prefix.
    text = f"{source}: {example['text']}" if source is not None else example["text"]
    pitch_level = example.get("pitch_level", None)

    audio_array = example["audio"]["array"]
    sampling_rate = example["audio"]["sampling_rate"]
    target_sr = audio_tokenizer.config["sample_rate"]
    if sampling_rate != target_sr:
        resampler = T.Resample(orig_freq=sampling_rate, new_freq=target_sr)
        audio_array = resampler(torch.from_numpy(audio_array).float()).numpy()

    if audio_tokenizer.config["volume_normalize"]:
        audio_array = audio_volume_normalize(audio_array)

    ref_wav_np = audio_tokenizer.get_ref_clip(audio_array)

    torch.cuda.empty_cache()
    audio_tokenizer.model.cuda()
    audio_tokenizer.feature_extractor.cuda()
    try:
        # Only discrete token ids are needed, so don't build an autograd graph.
        with torch.no_grad():
            device = audio_tokenizer.device
            audio_tensor = torch.from_numpy(audio_array).unsqueeze(0).float().to(device)
            ref_wav_tensor = torch.from_numpy(ref_wav_np).unsqueeze(0).float().to(device)

            feat = extract_wav2vec2_features(audio_tensor)
            batch = {
                "wav": audio_tensor,
                "ref_wav": ref_wav_tensor,
                "feat": feat.to(device),
            }
            semantic_token_ids, global_token_ids = audio_tokenizer.model.tokenize(batch)

            # reshape(-1) rather than squeeze(): a one-token clip must still give a 1-D sequence.
            global_ids = global_token_ids.reshape(-1).tolist()
            semantic_ids = semantic_token_ids.reshape(-1).tolist()
    finally:
        audio_tokenizer.model.cpu()
        audio_tokenizer.feature_extractor.cpu()
        torch.cuda.empty_cache()

    global_tokens = "".join(f"<|bicodec_global_{i}|>" for i in global_ids)
    semantic_tokens = "".join(f"<|bicodec_semantic_{i}|>" for i in semantic_ids)

    inputs = build_tts_input_with_attributes(
        text=text,
        pitch_level=pitch_level,
        global_tokens=global_tokens,
        semantic_tokens=semantic_tokens,
    )
    return {"text": inputs}

Missing tensor: mel_transformer.spectrogram.window
Missing tensor: mel_transformer.mel_scale.fb


OutOfMemoryError: CUDA out of memory. Tried to allocate 14.00 MiB. GPU 0 has a total capacity of 11.62 GiB of which 25.38 MiB is free. Process 94176 has 4.38 GiB memory in use. Including non-PyTorch memory, this process has 6.53 GiB memory in use. Of the allocated memory 6.36 GiB is allocated by PyTorch, and 4.87 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
from pathlib import Path
from typing import Tuple

def process_prompt_text(gender: str, pitch: str, speed: str, text: str):
    """Build a controllable-TTS prompt from gender/pitch/speed labels and the text.

    The prompt contains:
      * the ``<|task_tts|>`` task token;
      * the text content;
      * a style-label section with ``<|gender_N|><|pitch_label_N|><|speed_label_N|>``;
      * an opening ``<|start_global_token|>``, so the model continues by generating
        the global tokens.

    Args:
        gender: a key of ``GENDER_MAP``.
        pitch: a key of ``LEVELS_MAP``.
        speed: a key of ``LEVELS_MAP``.
        text: text to synthesize.

    Returns:
        The prompt string.

    Raises:
        ValueError: if any label is unknown. Explicit checks are used instead of
            ``assert``, which is stripped when Python runs with ``-O``.
    """
    if gender not in GENDER_MAP:
        raise ValueError(f"Unknown gender {gender!r}; expected one of {list(GENDER_MAP)}")
    if pitch not in LEVELS_MAP:
        raise ValueError(f"Unknown pitch level {pitch!r}; expected one of {list(LEVELS_MAP)}")
    if speed not in LEVELS_MAP:
        raise ValueError(f"Unknown speed level {speed!r}; expected one of {list(LEVELS_MAP)}")

    attribute_tokens = "".join([
        f"<|gender_{GENDER_MAP[gender]}|>",
        f"<|pitch_label_{LEVELS_MAP[pitch]}|>",
        f"<|speed_label_{LEVELS_MAP[speed]}|>",
    ])

    control_tts_inputs = [
        "<|task_tts|>",
        "<|start_content|>",
        text,
        "<|end_content|>",
        "<|start_style_label|>",
        attribute_tokens,
        "<|end_style_label|>",
        "<|start_global_token|>",
    ]

    return "".join(control_tts_inputs)


def process_prompt_wav(text: str, prompt_speech_path: Path, prompt_text: str = None) -> Tuple[str, torch.Tensor]:
    """Build a voice-cloning prompt from a reference clip.

    The clip is tokenized with ``audio_tokenizer.tokenize``, and its global tokens are
    always embedded in the prompt. When ``prompt_text`` (the clip's transcript) is
    given, it is placed directly before ``text`` in the content. The clip's semantic
    tokens then follow an open ``<|start_semantic_token|>``, so the model continues
    the semantic stream.

    Returns:
        ``(prompt_string, global_token_ids)``, where ``global_token_ids`` is the
        tensor returned by the tokenizer, unchanged.
    """
    global_token_ids, semantic_token_ids = audio_tokenizer.tokenize(prompt_speech_path)

    pieces = [TASK_TOKEN_MAP["tts"], "<|start_content|>"]
    if prompt_text is not None:
        pieces.append(prompt_text)
    pieces += [
        text,
        "<|end_content|>",
        "<|start_global_token|>",
        "".join(f"<|bicodec_global_{tok}|>" for tok in global_token_ids.squeeze()),
        "<|end_global_token|>",
    ]
    if prompt_text is not None:
        pieces.append("<|start_semantic_token|>")
        pieces.append("".join(f"<|bicodec_semantic_{tok}|>" for tok in semantic_token_ids.squeeze()))

    return "".join(pieces), global_token_ids

@torch.inference_mode()
def generate_speech_from_text(
    text: str,
    gender: str,
    pitch: str,
    speed: str,
    wav_path: str = None,
    wav_text: str = None,
    temperature: float = 0.8,  
    top_k: int = 50,           
    top_p: float = 1,       
    max_new_audio_tokens: int = 4096, 
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
) -> np.ndarray:

    torch.compiler.reset()
    audio_tokenizer.model.to(device)
    model.to(device)

    if wav_path != None:
        prompt, global_token_ids = process_prompt_wav(text, wav_path, wav_text)
    else:
        prompt = process_prompt_text(gender=gender, pitch=pitch, speed=speed, text=text)
        
    model_inputs = tokenizer([prompt], return_tensors="pt").to(device)
    # print(prompt, global_token_ids)
    # print(model_inputs)

    # print("Generating token sequence...")
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_new_audio_tokens,
        do_sample=True,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
    )
    # print("Token sequence generated.")
    # print("generated tokens: ", tokenizer.decode(generated_ids[0], special_tokens=True))

    generated_ids_trimmed = generated_ids[:, model_inputs.input_ids.shape[1]:]

    predicts_text = tokenizer.batch_decode(generated_ids_trimmed, skip_special_tokens=False)[0]

    semantic_matches = re.findall(r"<\|bicodec_semantic_(\d+)\|>", predicts_text)
    if not semantic_matches:
        print("Warning: No semantic tokens found in the generated output.")
        return np.array([], dtype=np.float32)

    pred_semantic_ids = torch.tensor([int(token) for token in semantic_matches]).long().unsqueeze(0)

    global_matches = re.findall(r"<\|bicodec_global_(\d+)\|>", predicts_text)
    if not global_matches:
        if global_token_ids is not None:
            pred_global_ids = global_token_ids
        else:
            print("Warning: No global tokens found in the generated output (controllable mode). Might use defaults or fail.")
    else:
        pred_global_ids = torch.tensor([int(token) for token in global_matches]).long().unsqueeze(0)
        pred_global_ids = pred_global_ids.unsqueeze(0)
        
    # print(f"Found {pred_semantic_ids.shape} semantic tokens.")
    # print(f"Found {pred_global_ids.shape} global tokens.")

    audio_tokenizer.device = device
    audio_tokenizer.model.to(device)
    wav_np = audio_tokenizer.detokenize(
        pred_global_ids.to(device).squeeze(0),
        pred_semantic_ids.to(device) 
    )

    return wav_np

In [None]:
import torch
import re
import numpy as np
from typing import Dict, Any
import torchaudio.transforms as T

# FastModel.for_inference(model)  # optional Unsloth inference switch, left disabled here
text = "আপনি পিচ এবং গতির মতো পরামিতিগুলি সামঞ্জস্য করে একটি কাস্টমাইজড ভয়েস তৈরি করতে পারেন।"
wav_path = './cli/en_male_1.wav'  # not used in this cell
wav_text = 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'  # not used in this cell
print(f"Generating speech for: '{text}'")
# No wav_path: controllable mode, so the voice is driven by the gender/pitch/speed labels.
generated_waveform = generate_speech_from_text(text=text, gender="male", pitch="moderate", speed="moderate")

if generated_waveform.size > 0:
    import soundfile as sf
    # Aliased so this cell does not rebind the `Audio` name imported from `datasets`.
    from IPython.display import Audio as IPythonAudio, display

    output_filename = "generated_speech_controllable_male.wav"
    sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(IPythonAudio(generated_waveform, rate=sample_rate))
else:
    print("Audio generation failed (no tokens found?).")

Generating speech for: 'আপনি পিচ এবং গতির মতো পরামিতিগুলি সামঞ্জস্য করে একটি কাস্টমাইজড ভয়েস তৈরি করতে পারেন।'
Audio saved to generated_speech_controllable_male.wav


In [None]:
import re  # stdlib regular expressions; keeps the global name `re` bound without the third-party `regex` package
text = "আপনি পিচ এবং গতির মতো পরামিতিগুলি সামঞ্জস্য করে একটি কাস্টমাইজড ভয়েস তৈরি করতে পারেন।"
wav_path = './cli/en_male_1.wav'
wav_text = 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'  # defined but not passed below
print(f"Generating speech for: '{text}'")
# wav_path given: the reference clip is tokenized into the prompt; gender/pitch/speed are ignored.
generated_waveform = generate_speech_from_text(text=text, gender="male", pitch="moderate", speed="moderate", wav_path=wav_path)

if generated_waveform.size > 0:
    import soundfile as sf
    # Aliased so this cell does not rebind the `Audio` name imported from `datasets`.
    from IPython.display import Audio as IPythonAudio, display

    output_filename = "generated_speech_controllable_male2.wav"
    sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(IPythonAudio(generated_waveform, rate=sample_rate))
else:
    print("Audio generation failed (no tokens found?).")

Generating speech for: 'আপনি পিচ এবং গতির মতো পরামিতিগুলি সামঞ্জস্য করে একটি কাস্টমাইজড ভয়েস তৈরি করতে পারেন।'
Audio saved to generated_speech_controllable_male2.wav


In [None]:
text = "আপনি পিচ এবং গতির মতো পরামিতিগুলি সামঞ্জস্য করে একটি কাস্টমাইজড ভয়েস তৈরি করতে পারেন।"
wav_path = './cli/LJ001-0001.wav'
wav_text = 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'  # defined but not passed below
print(f"Generating speech for: '{text}'")
# wav_path given: the reference clip is tokenized into the prompt; gender/pitch/speed are ignored.
generated_waveform = generate_speech_from_text(text=text, gender="male", pitch="moderate", speed="moderate", wav_path=wav_path)

if generated_waveform.size > 0:
    import soundfile as sf
    # Aliased so this cell does not rebind the `Audio` name imported from `datasets`.
    from IPython.display import Audio as IPythonAudio, display

    output_filename = "generated_speech_controllable_female.wav"
    sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(IPythonAudio(generated_waveform, rate=sample_rate))
else:
    print("Audio generation failed (no tokens found?).")

Generating speech for: 'আপনি পিচ এবং গতির মতো পরামিতিগুলি সামঞ্জস্য করে একটি কাস্টমাইজড ভয়েস তৈরি করতে পারেন।'
Audio saved to generated_speech_controllable_female.wav


In [None]:
text = 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'
wav_path = './cli/LJ001-0001.wav'
print(f"Generating speech for: '{text}'")
# wav_path given: the reference clip is tokenized into the prompt; gender/pitch/speed are ignored.
generated_waveform = generate_speech_from_text(text=text, gender="male", pitch="moderate", speed="moderate", wav_path=wav_path)

if generated_waveform.size > 0:
    import soundfile as sf
    # Aliased so this cell does not rebind the `Audio` name imported from `datasets`.
    from IPython.display import Audio as IPythonAudio, display

    output_filename = "generated_speech_controllable_female2.wav"
    sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(IPythonAudio(generated_waveform, rate=sample_rate))
else:
    print("Audio generation failed (no tokens found?).")

Generating speech for: 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'
Audio saved to generated_speech_controllable_female2.wav


In [18]:
text = 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'
wav_path = './cli/en_male_1.wav'
print(f"Generating speech for: '{text}'")
# wav_path given: the reference clip is tokenized into the prompt; gender/pitch/speed are ignored.
generated_waveform = generate_speech_from_text(text=text, gender="male", pitch="moderate", speed="moderate", wav_path=wav_path)

if generated_waveform.size > 0:
    import soundfile as sf
    # Aliased so this cell does not rebind the `Audio` name imported from `datasets`.
    from IPython.display import Audio as IPythonAudio, display

    output_filename = "generated_speech_controllable_male3.wav"
    sample_rate = audio_tokenizer.config.get("sample_rate", 16000)
    sf.write(output_filename, generated_waveform, sample_rate)
    print(f"Audio saved to {output_filename}")
    display(IPythonAudio(generated_waveform, rate=sample_rate))
else:
    print("Audio generation failed (no tokens found?).")

Generating speech for: 'বাংলা সার্বভৌম ভাষাভিত্তিক জাতিরাষ্ট্র বাংলাদেশের একমাত্র রাষ্ট্রভাষা তথা সরকারি ভাষা।'
Audio saved to generated_speech_controllable_male3.wav
