In [1]:
!pip install transformers torch torchaudio
!pip install snac bitsandbytes  # For audio decoding and quantization

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [20]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from snac import SNAC
import soundfile as sf

# Special token IDs that frame the prompt and the generated speech (fixed for Veena)
START_OF_SPEECH_TOKEN = 128257
END_OF_SPEECH_TOKEN = 128258
START_OF_HUMAN_TOKEN = 128259
END_OF_HUMAN_TOKEN = 128260
START_OF_AI_TOKEN = 128261
END_OF_AI_TOKEN = 128262
# Token ID at which the audio-code range starts
AUDIO_CODE_BASE_OFFSET = 128266

# Speaker voices that can be requested
speakers = ["kavya", "agastya", "maitri", "vinaya"]

# 4-bit NF4 quantization: bfloat16 compute with double quantization enabled
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Tokenizer and quantized language model come from the same checkpoint
tokenizer = AutoTokenizer.from_pretrained("maya-research/veena-tts", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    "maya-research/veena-tts",
    quantization_config=quantization_config,
    device_map="auto",
    trust_remote_code=True,
)

# SNAC decoder: switch to eval mode, then move it onto the GPU
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
snac_model = snac_model.eval().cuda()

def generate_speech(text, speaker="kavya", temperature=0.4, top_p=0.9):
    """Generate speech audio for ``text`` using the given speaker voice.

    The prompt is laid out as
    ``[HUMAN] <spk_speaker> text [/HUMAN] [AI] [SPEECH]``; audio tokens are
    sampled from the module-level ``model`` and decoded with the module-level
    ``snac_model``.

    Args:
        text: Text to synthesize.
        speaker: Speaker name inserted into the ``<spk_...>`` prompt tag.
        temperature: Sampling temperature passed to ``model.generate``.
        top_p: Nucleus-sampling threshold passed to ``model.generate``.

    Returns:
        The decoded audio as returned by ``decode_snac_tokens`` (a NumPy array).

    Raises:
        ValueError: If generation yields no complete 7-token audio frame, or
            the generated tokens cannot be decoded.
    """
    tokens_per_frame = 7  # decode_snac_tokens consumes tokens in groups of 7

    # Prepare input with speaker token
    prompt = f"<spk_{speaker}> {text}"
    prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)

    # Construct full sequence: [HUMAN] <spk_speaker> text [/HUMAN] [AI] [SPEECH]
    input_tokens = [
        START_OF_HUMAN_TOKEN,
        *prompt_tokens,
        END_OF_HUMAN_TOKEN,
        START_OF_AI_TOKEN,
        START_OF_SPEECH_TOKEN,
    ]

    input_ids = torch.tensor([input_tokens], device=model.device)

    # Token budget scales with the text length, plus a fixed margin of 21
    max_tokens = int(len(text) * 1.3) * tokens_per_frame + 21

    # Generate audio tokens
    with torch.no_grad():
        output = model.generate(
            input_ids,
            # Single unpadded sequence: every position is a real token.
            attention_mask=torch.ones_like(input_ids),
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=1.05,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=[END_OF_SPEECH_TOKEN, END_OF_AI_TOKEN],
        )

    # Keep only the newly generated IDs that fall inside the audio-code range
    generated_ids = output[0][len(input_tokens):].tolist()
    audio_range_end = AUDIO_CODE_BASE_OFFSET + tokens_per_frame * 4096
    snac_tokens = [
        token_id for token_id in generated_ids
        if AUDIO_CODE_BASE_OFFSET <= token_id < audio_range_end
    ]

    # Drop a trailing partial frame: decode_snac_tokens returns None for any
    # length that is not a multiple of 7, which would otherwise leak out of this
    # function as a None "audio" and fail later when written to disk.
    usable = len(snac_tokens) - len(snac_tokens) % tokens_per_frame
    if usable == 0:
        raise ValueError("No audio tokens generated")

    # Decode audio
    audio = decode_snac_tokens(snac_tokens[:usable], snac_model)
    if audio is None:
        raise ValueError("Failed to decode generated audio tokens")
    return audio

def decode_snac_tokens(snac_tokens, snac_model):
    """Turn a flat list of interleaved SNAC token IDs into an audio array.

    Tokens arrive in frames of 7. Within a frame, slot 0 feeds level 0,
    slots 1 and 4 feed level 1, and slots 2, 3, 5 and 6 feed level 2. Each
    slot has its own ID offset of ``AUDIO_CODE_BASE_OFFSET + slot * 4096``,
    which is subtracted to recover the raw code.

    Returns ``None`` when the list is empty or its length is not a multiple
    of 7. Raises ``ValueError`` when a recovered code falls outside 0..4095.
    """
    frame_size = 7
    if not snac_tokens or len(snac_tokens) % frame_size:
        return None

    # Frame slots that belong to each hierarchical level, in emission order
    slots_by_level = ((0,), (1, 4), (2, 3, 5, 6))

    # Tensors are created on whichever device holds the decoder's weights
    decoder_device = next(snac_model.parameters()).device

    frame_starts = range(0, len(snac_tokens), frame_size)
    per_level_codes = [
        [
            snac_tokens[start + slot] - (AUDIO_CODE_BASE_OFFSET + slot * 4096)
            for start in frame_starts
            for slot in slots
        ]
        for slots in slots_by_level
    ]

    hierarchical_codes = []
    for codes in per_level_codes:
        level_tensor = torch.tensor(codes, dtype=torch.int32, device=decoder_device)
        level_tensor = level_tensor.unsqueeze(0)
        out_of_range = (level_tensor < 0) | (level_tensor > 4095)
        if out_of_range.any():
            raise ValueError("Invalid SNAC token values")
        hierarchical_codes.append(level_tensor)

    # Run the SNAC decoder without tracking gradients
    with torch.no_grad():
        audio_hat = snac_model.decode(hierarchical_codes)

    waveform = audio_hat.squeeze().clamp(-1, 1)
    return waveform.cpu().numpy()

# --- Example Usage ---

# Hindi
text_hindi = " कर्मण्येवाधिकारस्ते मा फलेषु कदाचन"

# English
text_english = "Today I learned about a new technology that uses artificial intelligence to generate human-like voices."

# Code-mixed
text_mixed = "मैंने तो पूरा presentation  तैयार कर लिया है! कल रात को ही मैंने पूरा code base चेक किया।"

# Synthesize each sample with its speaker and save it as a 24000 Hz WAV file
for sample_text, sample_speaker, wav_name in (
    (text_hindi, "kavya", "output_hindi_kavya.wav"),
    (text_english, "agastya", "output_english_agastya.wav"),
    (text_mixed, "maitri", "output_mixed_maitri.wav"),
):
    audio = generate_speech(sample_text, speaker=sample_speaker)
    sf.write(wav_name, audio, 24000)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [21]:
from IPython.display import Audio, display

# Play the Hindi sample. The relative path is the same one sf.write() used, so
# playback does not depend on the working directory being /content.
audio_file_path = 'output_hindi_kavya.wav'
display(Audio(audio_file_path, autoplay=True))

In [22]:
from IPython.display import Audio, display

# Play the English sample. The relative path is the same one sf.write() used, so
# playback does not depend on the working directory being /content.
audio_file_path = 'output_english_agastya.wav'
display(Audio(audio_file_path, autoplay=True))

In [23]:
from IPython.display import Audio, display

# Play the code-mixed sample. The relative path is the same one sf.write() used,
# so playback does not depend on the working directory being /content.
audio_file_path = 'output_mixed_maitri.wav'
display(Audio(audio_file_path, autoplay=True))

In [24]:
# Hinglish
from IPython.display import Audio, display

text_hindi = " Mujhe shopping karni hai, but I'm too busy."
audio = generate_speech(text_hindi, speaker="kavya")

# One path variable for both writing and playback, so the two cannot drift apart
# (previously the file was written relative to the cwd but read from /content).
audio_file_path = 'output_hindi_kavya2.wav'
sf.write(audio_file_path, audio, 24000)
display(Audio(audio_file_path, autoplay=True))

In [25]:
# Hindi (long sentence)
from IPython.display import Audio, display

text_hindi = "किसी व्यक्ति को आप चाहे कितनी ही सलाह दे दो किन्तु उसका मूल स्वभाव नहीं बदलता ठीक उसी तरह जैसे ठन्डे पानी को उबालने पर तो वह गर्म हो जाता है लेकिन बाद में वह पुनः ठंडा हो जाता है."
audio = generate_speech(text_hindi, speaker="vinaya")

# One path variable for both writing and playback, so the two cannot drift apart
# (previously the file was written relative to the cwd but read from /content).
audio_file_path = 'output_hindi_vinaya3.wav'
sf.write(audio_file_path, audio, 24000)
display(Audio(audio_file_path, autoplay=True))

In [26]:
# Hinglish
from IPython.display import Audio, display

text_hindi = "jaldi karo guys, or we’ll be late for the movie"
audio = generate_speech(text_hindi, speaker="vinaya")

# One path variable for both writing and playback, so the two cannot drift apart
# (previously the file was written relative to the cwd but read from /content).
audio_file_path = 'output_hindi_vinaya2.wav'
sf.write(audio_file_path, audio, 24000)
display(Audio(audio_file_path, autoplay=True))