In [4]:
# Install the `snac` package; the main cell imports `SNAC` from it and uses the
# pretrained "hubertsiuzdak/snac_24khz" model to decode generated codes into audio.
!pip install snac


Collecting snac
  Downloading snac-1.2.1-py3-none-any.whl.metadata (3.5 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->snac)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->snac)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->snac)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->snac)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->snac)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->snac)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.

In [5]:
# First pass: install the Unsloth companion packages with --no-deps (pip skips
# their declared dependencies), then the tokenizer/dataset helpers, then Unsloth itself.
# Presumably --no-deps is there to stop pip from replacing the preinstalled torch — TODO confirm.
!pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
!pip install sentencepiece protobuf 'datasets>=3.4.1,<4.0.0' huggingface_hub hf_transfer
!pip install --no-deps unsloth

# Second pass: remove the torch stack, Unsloth and transformers again and drop
# pip's download cache so the pinned wheels below are fetched fresh.
!pip uninstall torch torchvision torchaudio unsloth unsloth_zoo transformers -y
!pip cache purge

# Pin the torch stack to the 2.4.1 / CUDA 12.1 wheels from the PyTorch wheel index.
# NOTE(review): xformers==0.0.29.post3 (installed above, not uninstalled here) reports in the
# next cell's output that it was built for PyTorch 2.6.0+cu124 and is therefore disabled
# ("Xformers = None") with this pin — confirm which of the two versions should change.
!pip install torch==2.4.1+cu121 torchvision==0.19.1+cu121 torchaudio==2.4.1+cu121 --index-url https://download.pytorch.org/whl/cu121

# Reinstall Unsloth from its GitHub repository with the "colab-new" extra.
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

# Additional dependencies
!pip install librosa
# NOTE(review): this upgrade is unconstrained, unlike the 'datasets>=3.4.1,<4.0.0'
# range requested earlier in this cell — confirm that lifting the upper bound is intended.
!pip install -U datasets


Collecting bitsandbytes
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting xformers==0.0.29.post3
  Downloading xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (1.0 kB)
Collecting trl
  Downloading trl-0.19.1-py3-none-any.whl.metadata (10 kB)
Collecting cut_cross_entropy
  Downloading cut_cross_entropy-25.1.1-py3-none-any.whl.metadata (9.3 kB)
Collecting unsloth_zoo
  Downloading unsloth_zoo-2025.7.10-py3-none-any.whl.metadata (8.1 kB)
Downloading xformers-0.0.29.post3-cp311-cp311-manylinux_2_28_x86_64.whl (43.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.4/43.4 MB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl (72.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.9/72.9 MB[0m [31m9.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading trl-0.19.1-py3-none-any.whl (376 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [6]:
import gradio as gr
import torch
from unsloth import FastLanguageModel
from IPython.display import display, Audio
import numpy as np

# Global model variables
# All four start as None and are assigned by load_models(); the inference
# helpers below read them as module-level globals.
model = None       # language model returned by FastLanguageModel.from_pretrained
tokenizer = None   # tokenizer returned alongside the model
snac_model = None  # SNAC decoder used by redistribute_codes()
device = None      # "cuda" or "cpu", chosen in load_models()

def load_models():
    """Initialize and load all required models for Sanskrit TTS inference.

    Populates the module-level ``model``, ``tokenizer``, ``snac_model`` and
    ``device`` globals:

    * ``device`` is ``"cuda"`` when ``torch.cuda.is_available()``, else ``"cpu"``.
    * ``model`` / ``tokenizer`` come from ``FastLanguageModel.from_pretrained``
      for ``"rverma0631/Sanskrit_TTS"``; the model is moved to ``device`` and
      switched to inference mode via ``FastLanguageModel.for_inference``.
    * ``snac_model`` is the pretrained ``"hubertsiuzdak/snac_24khz"`` SNAC
      model in eval mode, kept on the CPU.

    Raises:
        ImportError: if the ``snac`` package cannot be imported. Previously
            this case only printed a warning and then failed with an
            unrelated ``AttributeError`` on ``None``.
    """
    global model, tokenizer, snac_model, device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Loading models on: {device}")

    # Load the fine-tuned Sanskrit TTS model
    model, tokenizer = FastLanguageModel.from_pretrained(
        "rverma0631/Sanskrit_TTS",
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=False,
    )

    model = model.to(device)
    FastLanguageModel.for_inference(model)

    # Load SNAC model for audio generation. Only the import is guarded: a
    # missing package is reported clearly instead of leaving snac_model as None.
    try:
        from snac import SNAC
    except ImportError as exc:
        raise ImportError(
            "SNAC model import failed. Make sure SNAC is installed (pip install snac)."
        ) from exc

    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    snac_model.to("cpu")

    print("Models loaded successfully!")

def redistribute_codes(code_list):
    """Redistribute generated codes into hierarchical layers for audio synthesis.

    The flat ``code_list`` is consumed in frames of 7 codes. Within a frame,
    position ``k`` has ``k * 4096`` subtracted and is routed as follows:

    * position 0          -> layer 1
    * positions 1, 4      -> layer 2
    * positions 2, 3, 5, 6 -> layer 3

    Trailing codes that do not fill a complete frame are ignored. (The
    original bound ``(len + 1) // 7`` ran one frame too far for lengths of the
    form ``7k + 6`` and raised ``IndexError``.)

    Args:
        code_list: Sequence of integer codes; elements may be Python ints or
            0-dim integer tensors.

    Returns:
        The result of ``snac_model.decode`` on the three ``(1, n)`` int64
        layer tensors.
    """
    frame_size = 7
    layer_offset = 4096

    # Normalise to plain ints so the layer tensors are always built on the CPU
    # (where snac_model lives) regardless of where the inputs came from.
    flat_codes = [int(code) for code in code_list]
    n_frames = len(flat_codes) // frame_size

    layer_1 = []
    layer_2 = []
    layer_3 = []

    for frame_idx in range(n_frames):
        base = frame_size * frame_idx
        layer_1.append(flat_codes[base])
        layer_2.append(flat_codes[base + 1] - layer_offset)
        layer_3.append(flat_codes[base + 2] - (2 * layer_offset))
        layer_3.append(flat_codes[base + 3] - (3 * layer_offset))
        layer_2.append(flat_codes[base + 4] - (4 * layer_offset))
        layer_3.append(flat_codes[base + 5] - (5 * layer_offset))
        layer_3.append(flat_codes[base + 6] - (6 * layer_offset))

    # Explicit int64 keeps the dtype stable even when a layer is empty.
    codes = [
        torch.tensor(layer, dtype=torch.int64).unsqueeze(0)
        for layer in (layer_1, layer_2, layer_3)
    ]

    # Decoding is inference only; no autograd graph is needed.
    with torch.no_grad():
        audio_hat = snac_model.decode(codes)
    return audio_hat

def sanskrit_tts_inference(sanskrit_text, chosen_voice=""):
    """
    Generate Sanskrit speech from input text using the fine-tuned model.

    The prompt ``"<voice>: <text>"`` is tokenized, wrapped in the fixed
    start/end token ids, and passed to ``model.generate``. The generated ids
    after the last ``128257`` token (with ``128258`` removed) are trimmed to a
    multiple of 7, shifted down by ``128266`` and handed to
    ``redistribute_codes`` for decoding.

    Args:
        sanskrit_text (str): Input Sanskrit text in Devanagari script.
        chosen_voice (str): Voice identifier prepended to the prompt. When
            empty (the default), ``1070`` is used, which is the value the
            previous implementation applied unconditionally.

    Returns:
        tuple: ``(audio_data, status_message)``. ``audio_data`` is
        ``(24000, numpy_array)`` on success and ``None`` otherwise; failures
        are reported through ``status_message`` rather than raised.
    """
    # Also covers None, which would otherwise crash on .strip().
    if not sanskrit_text or not sanskrit_text.strip():
        return None, "Please enter some Sanskrit text."

    default_voice = 1070
    crop_after_token = 128257   # everything up to its last occurrence is dropped
    eos_token = 128258          # passed to generate() as eos_token_id, then stripped
    code_offset = 128266        # subtracted from every remaining token id
    frame_size = 7              # redistribute_codes consumes codes in groups of 7

    try:
        # Honour the caller's voice; fall back to the previous hard-coded one.
        voice = chosen_voice if chosen_voice else default_voice
        prompt = f"{voice}: " + sanskrit_text

        # Tokenize and wrap the prompt in the fixed start/end token ids.
        text_ids = tokenizer(prompt, return_tensors="pt").input_ids
        start_token = torch.tensor([[128259]], dtype=torch.int64)
        end_tokens = torch.tensor([[128009, 128260]], dtype=torch.int64)
        input_ids = torch.cat([start_token, text_ids, end_tokens], dim=1).to(device)

        # A single prompt needs no padding, so every position is attended to.
        attention_mask = torch.ones_like(input_ids)

        # Generate audio codes using the model
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=1200,
            do_sample=True,
            temperature=0.6,
            top_p=0.95,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=eos_token,
            use_cache=True
        )

        # One prompt and num_return_sequences=1 give exactly one row.
        row = generated_ids[0]

        # Keep only what follows the last crop marker, if there is one.
        marker_positions = (row == crop_after_token).nonzero(as_tuple=True)[0]
        if len(marker_positions) > 0:
            row = row[marker_positions[-1].item() + 1:]

        row = row[row != eos_token]

        # Trim to whole frames; bail out early when nothing usable remains
        # instead of letting the decoder fail on an empty input.
        usable_length = (row.size(0) // frame_size) * frame_size
        if usable_length == 0:
            return None, "❌ Failed to generate audio - no valid codes produced."

        # Vectorised offset + a single host transfer, yielding plain ints.
        code_list = (row[:usable_length] - code_offset).tolist()

        samples = redistribute_codes(code_list)
        audio_sample = samples.detach().squeeze().to("cpu").numpy()
        return (24000, audio_sample), f"✅ Generated audio for: {sanskrit_text}"

    except Exception as e:
        # Errors are surfaced in the UI status box rather than raised.
        return None, f"❌ Error during inference: {str(e)}"

# Bring every model into memory before any UI component is constructed.
print("Loading models... This may take a moment.")
load_models()

# Static page content, kept apart from the layout code below.
intro_markdown = """
    # 🕉️ Sanskrit Text-to-Speech

    Enter Sanskrit text in Devanagari script and generate speech using your fine-tuned model.
    """

example_rows = [
    [text]
    for text in (
        "नमस्ते",
        "संस्कृत एक प्राचीन भाषा है।",
        "ॐ शान्ति शान्ति शान्तिः",
        "सर्वे भवन्तु सुखिनः",
    )
]

# Page layout: text input and trigger button on the left, audio player and
# status box on the right, clickable examples underneath.
with gr.Blocks(title="Sanskrit Text-to-Speech") as demo:
    gr.Markdown(intro_markdown)

    with gr.Row():
        with gr.Column():
            sanskrit_input = gr.Textbox(
                value="नमस्ते",
                lines=3,
                placeholder="Enter Sanskrit text in Devanagari script...",
                label="Sanskrit Text",
            )
            generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(type="numpy", label="Generated Sanskrit Speech")
            status_output = gr.Textbox(interactive=False, lines=2, label="Status")

    # Sample inputs; results are not precomputed (cache_examples=False).
    gr.Examples(
        example_rows,
        inputs=[sanskrit_input],
        outputs=[audio_output, status_output],
        fn=sanskrit_tts_inference,
        cache_examples=False,
    )

    # The button runs inference on the textbox content and fills both outputs.
    generate_btn.click(
        sanskrit_tts_inference,
        inputs=[sanskrit_input],
        outputs=[audio_output, status_output],
    )

# Serve the app with a public share link on port 7860, bound to all interfaces.
if __name__ == "__main__":
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.6.0+cu124 with CUDA 1204 (you have 2.4.1+cu121)
    Python  3.11.11 (you have 3.11.13)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


🦥 Unsloth Zoo will now patch everything to make training faster!
Loading models... This may take a moment.
Loading models on: cuda
==((====))==  Unsloth 2025.7.9: Fast Llama patching. Transformers: 4.54.0.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.4.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.0.0
\        /    Bfloat16 = FALSE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/1.61G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/230 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/22.8M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/508 [00:00<?, ?B/s]

chat_template.jinja: 0.00B [00:00, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/389M [00:00<?, ?B/s]

Unsloth 2025.7.9 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


config.json:   0%|          | 0.00/300 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/79.5M [00:00<?, ?B/s]

  state_dict = torch.load(model_path, map_location="cpu")


Models loaded successfully!
Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://f7e2a102e5b11bd692.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
