In [2]:
# ==============================================================================
#
# indic_text_handler.py
#
# This module replaces the traditional Tacotron-style text processing pipeline
# (symbol-to-id mapping) with a modern, IndicBERT-based approach. It handles
# text normalization, tokenization, and contextual embedding generation for
# Indic languages, providing a rich semantic input for the VITS2 model.
#
# ==============================================================================

import torch
from transformers import AutoTokenizer, AutoModel
from indicnlp.normalize.indic_normalize import IndicNormalizerFactory
import re

# --- Module-level constants ---
# Hugging Face model identifier handed to `AutoTokenizer.from_pretrained` and
# `AutoModel.from_pretrained` (the multilingual IndicBERT checkpoint from AI4Bharat).
INDIC_BERT_MODEL_NAME: str = "ai4bharat/indic-bert"
# Output width of the linear projection applied to the BERT embeddings.
# This value should match the `hidden_channels` parameter in your VITS2 config.
VITS2_HIDDEN_DIM: int = 192

class IndicBERTProcessor:
    """
    A comprehensive text processor for VITS2 that leverages IndicBERT.

    This class encapsulates the entire text-to-embedding pipeline:
    1. Normalization: Cleans and canonicalizes Indic script text.
    2. Tokenization: Uses the IndicBERT tokenizer.
    3. Embedding: Generates contextual token embeddings using the IndicBERT model.
    4. Projection: Maps the high-dimensional BERT embeddings to the VITS2 model's
       hidden dimension (`VITS2_HIDDEN_DIM`).

    NOTE(review): the projection layer is a freshly initialised `torch.nn.Linear`
    owned by this object. Its weights are random until something trains or loads
    them, so two processor instances produce different projections for the same
    text. Confirm how the surrounding training code is expected to manage it.
    """
    def __init__(self, device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initializes the processor by loading the IndicBERT model, tokenizer,
        and setting up the text normalizer.

        Args:
            device (str): The device to run the model on ('cuda' or 'cpu').
                The default is resolved once, when the class is defined.
        """
        print(f"Initializing IndicBERTProcessor on device: {device}")
        self.device = device

        # 1. Load the IndicBERT tokenizer and encoder from Hugging Face.
        print(f"Loading IndicBERT model: {INDIC_BERT_MODEL_NAME}")
        self.tokenizer = AutoTokenizer.from_pretrained(INDIC_BERT_MODEL_NAME)
        self.model = AutoModel.from_pretrained(INDIC_BERT_MODEL_NAME).to(self.device)
        self.model.eval()  # Inference mode: disables dropout in the encoder
        print("IndicBERT model and tokenizer loaded successfully.")

        # 2. Initialize the Indic text normalizer factory.
        # Normalizers are created lazily per language in `_get_normalizer`
        # and cached in `self.normalizers`, keyed by language code.
        self.normalizer_factory = IndicNormalizerFactory()
        self.normalizers = {}

        # 3. Define the projection layer.
        # Bridges the encoder's hidden size (read from the model config) to the
        # width expected by the VITS2 text encoder.
        bert_hidden_dim = self.model.config.hidden_size  # Should be 768
        self.projection = torch.nn.Linear(bert_hidden_dim, VITS2_HIDDEN_DIM).to(self.device)

    def _get_normalizer(self, lang_code):
        """
        Retrieves or creates a normalizer for a given language code.

        Args:
            lang_code (str): The ISO 639-1 code for the language (e.g., 'hi' for Hindi).

        Returns:
            The normalizer object produced by the factory for `lang_code`.
            The same object is returned on every later call for that code.
        """
        if lang_code not in self.normalizers:
            # `remove_nuktas=False` keeps nukta marks in the text, which is
            # generally recommended to preserve phonetic accuracy.
            self.normalizers[lang_code] = self.normalizer_factory.get_normalizer(
                lang_code, remove_nuktas=False
            )
        return self.normalizers[lang_code]

    def _clean_text(self, text: str, lang_code: str) -> str:
        """
        Performs normalization and basic cleaning on the input text.

        Args:
            text (str): The raw input text.
            lang_code (str): The language code.

        Returns:
            str: The normalized text with every run of whitespace collapsed to
            a single space and leading/trailing whitespace removed.
        """
        normalizer = self._get_normalizer(lang_code)
        normalized_text = normalizer.normalize(text)

        # Collapse multiple whitespace characters (incl. newlines/tabs) into one.
        return re.sub(r'\s+', ' ', normalized_text).strip()

    def text_to_embeddings(self, text: str, lang_code: str) -> torch.Tensor:
        """
        The main public method to convert a raw text string into a tensor of
        contextual embeddings ready for the VITS2 model.

        This function replaces the old `text_to_sequence` function.

        Args:
            text (str): The input text string (e.g., "नमस्ते दुनिया").
            lang_code (str): The language code (e.g., "hi").

        Returns:
            torch.Tensor: A tensor of shape [sequence_length, VITS2_HIDDEN_DIM]
                          containing the projected contextual embeddings, one
                          row per tokenizer token (special tokens included).

        Raises:
            ValueError: If the tokenized text is longer than the encoder's
                `max_position_embeddings`. The text is not truncated, because
                dropping part of it silently would change what is synthesized.
        """
        # Step 1: Clean and normalize the input text
        cleaned_text = self._clean_text(text, lang_code)

        # Step 2: Tokenize the text using the IndicBERT tokenizer.
        # The tokenizer adds the model's special tokens ([CLS], [SEP]).
        tokenized_inputs = self.tokenizer(
            cleaned_text,
            return_tensors="pt",
            padding=True
        )

        # Guard against inputs the encoder cannot position-encode. Without this
        # check an over-long sequence fails inside the model's forward pass
        # with an error that does not mention the input length.
        seq_len = tokenized_inputs["input_ids"].shape[-1]
        max_positions = getattr(
            getattr(self.model, "config", None), "max_position_embeddings", None
        )
        if max_positions is not None and seq_len > max_positions:
            raise ValueError(
                f"Input text tokenizes to {seq_len} tokens, which exceeds the "
                f"model limit of {max_positions}. Split the text into shorter "
                "segments before calling text_to_embeddings."
            )

        tokenized_inputs = tokenized_inputs.to(self.device)

        # Step 3: Generate contextual embeddings using IndicBERT.
        # The encoder is used for inference only, so gradients are disabled to
        # save memory and computation.
        with torch.no_grad():
            model_output = self.model(**tokenized_inputs)

        # Shape: [batch_size, sequence_length, bert_hidden_dim]
        last_hidden_state = model_output.last_hidden_state

        # Step 4: Project embeddings to the VITS2 hidden dimension.
        # NOTE(review): this runs outside `no_grad`, so the result carries
        # autograd state from the projection weights — confirm this is intended
        # before calling it from a dataset's `__getitem__`.
        projected_embeddings = self.projection(last_hidden_state)

        # A single string yields a batch of one; drop that batch axis.
        return projected_embeddings.squeeze(0)  # [sequence_length, VITS2_HIDDEN_DIM]

# ==============================================================================
#
# Integration Notes & Example Usage
#
# ==============================================================================

def main():
    """
    Demonstrate the IndicBERTProcessor on a Hindi sample sentence and print
    notes on wiring its output into a VITS2 training pipeline.
    """
    # --- Example Usage ---
    sample_text = "नमस्ते दुनिया, यह एक परीक्षण है। एक परीक्षण है एक परीक्षण है"
    sample_lang = "hi"
    handler = IndicBERTProcessor()

    # This tensor is the new input for the model's text encoder.
    sample_embeddings = handler.text_to_embeddings(sample_text, sample_lang)

    # Everything below is printed verbatim, one entry per print() call.
    report_lines = (
        f"Original Text: {sample_text}",
        f"Shape of generated embeddings: {sample_embeddings.shape}",
        "---",
        "SUCCESS: Contextual embeddings generated successfully.",
        "\n--- Integration Guide ---",
        "1. OBSOLETE FILES: The files `text/symbols.py`, `text/cleaners.py`, and the original `text/__init__.py` are now obsolete and can be removed. The vocabulary and cleaning logic are handled by this new module.",
        "\n2. VITS2 TextEncoder MODIFICATION: Your VITS2 `TextEncoder` module must be modified.",
        "   - The original `TextEncoder` likely has an `nn.Embedding` layer as its first component to convert integer IDs to vectors.",
        "   - This `nn.Embedding` layer should be REMOVED.",
        "   - The `forward` method of your `TextEncoder` should now accept the tensor generated by `text_to_embeddings` directly.",
        "   - The output of this new pipeline is already a sequence of vectors, so it can be fed directly into the Transformer blocks of your TextEncoder.",
        "\n3. DATA LOADER: Your dataset's `__getitem__` method should now call `processor.text_to_embeddings(text, lang)` instead of `text_to_sequence(text, cleaner_names)`.",
        "   - The collate function will need to handle padding of these embedding tensors if you are batching variable-length sequences.",
    )
    for line in report_lines:
        print(line)

# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()



Initializing IndicBERTProcessor on device: cuda
Loading IndicBERT model: ai4bharat/indic-bert
IndicBERT model and tokenizer loaded successfully.
Original Text: नमस्ते दुनिया, यह एक परीक्षण है। एक परीक्षण है एक परीक्षण है
Shape of generated embeddings: torch.Size([25, 192])
---
SUCCESS: Contextual embeddings generated successfully.

--- Integration Guide ---
1. OBSOLETE FILES: The files `text/symbols.py`, `text/cleaners.py`, and the original `text/__init__.py` are now obsolete and can be removed. The vocabulary and cleaning logic are handled by this new module.

2. VITS2 TextEncoder MODIFICATION: Your VITS2 `TextEncoder` module must be modified.
   - The original `TextEncoder` likely has an `nn.Embedding` layer as its first component to convert integer IDs to vectors.
   - This `nn.Embedding` layer should be REMOVED.
   - The `forward` method of your `TextEncoder` should now accept the tensor generated by `text_to_embeddings` directly.
   - The output of this new pipeline is already a s

In [6]:
!pip install sentencepiece

Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.7 kB)
Downloading sentencepiece-0.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m949.5 kB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: sentencepiece
Successfully installed sentencepiece-0.2.0
