# VoiceTech for All - Multilingual TTS with Accent & Style Transfer

This notebook implements an end-to-end prototype Text-to-Speech (TTS) system for Indian languages, with:
- **Multilingual Support**: 11 Indian languages (SYSPIN + SPICOR)
- **Text Normalization**: Handles Indian language text preprocessing
- **Accent Transfer**: Synthesize speech with different accents
- **Style Transfer**: Synthesize speech with different speaking styles

## Challenge Requirements
- Support multiple languages (especially the nine SYSPIN languages)
- Multi-speaker, multilingual TTS model
- Open-source implementation
- State-of-the-art quality

## 1. Setup and Dependencies

In [None]:
# Install required packages
!pip install -q torch torchaudio librosa numpy scipy matplotlib
!pip install -q g2p-en indic-nlp-library
!pip install -q gdown  # For downloading datasets
!pip install -q pydub soundfile

import torch
import torchaudio
import numpy as np
import librosa
import matplotlib.pyplot as plt
from pathlib import Path
import json
import os
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Check GPU availability: use CUDA when PyTorch can see a GPU, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")
print(f"PyTorch version: {torch.__version__}")

## 2. Text Normalization for Indian Languages

In [None]:
class IndianLanguageNormalizer:
    """Normalize text for Indian languages and map Devanagari letters to phoneme labels.

    ``normalize`` works on any script. ``text_to_phonemes`` only knows the
    independent Devanagari vowels and consonants listed below; vowel signs,
    virama, punctuation and other scripts are skipped.
    """

    # Supported languages (code -> display name)
    LANGUAGES = {
        'hi': 'Hindi',
        'bn': 'Bengali',
        'mr': 'Marathi',
        'kn': 'Kannada',
        'te': 'Telugu',
        'bh': 'Bhojpuri',
        'cc': 'Chhattisgarhi',
        'mg': 'Magahi',
        'mt': 'Maithili',
        'ta': 'Tamil',
        'ml': 'Malayalam'
    }

    # Devanagari independent vowels: अ आ इ ई उ ऊ ऋ ए ऐ ओ औ
    DEVANAGARI_VOWELS = 'अआइईउऊऋएऐओऔ'
    # Devanagari consonants in traditional order: velar, palatal, retroflex,
    # dental and labial stops (five each), then य र ल व श ष स ह.
    DEVANAGARI_CONSONANTS = 'कखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह'

    # Non-alphanumeric characters that normalize() keeps.
    _KEPT_PUNCTUATION = ' .,!?;:'

    def __init__(self):
        self.char_to_phoneme = self._build_phoneme_map()

    def _build_phoneme_map(self) -> Dict[str, str]:
        """Build the character -> phoneme-label mapping for Devanagari letters.

        Raises:
            RuntimeError: if a label list and its character string differ in
                length (zip() would otherwise silently misalign the labels).
        """
        vowel_phonemes = ['a', 'aa', 'i', 'ii', 'u', 'uu', 'ri', 'e', 'ai', 'o', 'au']
        consonant_phonemes = [
            'ka', 'kha', 'ga', 'gha', 'nga',
            'cha', 'chha', 'ja', 'jha', 'nya',
            'tta', 'ttha', 'dda', 'ddha', 'nna',  # retroflex ट ठ ड ढ ण
            'ta', 'tha', 'da', 'dha', 'na',       # dental त थ द ध न
            'pa', 'pha', 'ba', 'bha', 'ma',
            'ya', 'ra', 'la', 'va', 'sha', 'sha', 'sa', 'ha',
        ]

        phoneme_map = {}
        for chars, labels in ((self.DEVANAGARI_VOWELS, vowel_phonemes),
                              (self.DEVANAGARI_CONSONANTS, consonant_phonemes)):
            if len(chars) != len(labels):
                raise RuntimeError(
                    f"phoneme table mismatch: {len(chars)} characters vs {len(labels)} labels"
                )
            phoneme_map.update(zip(chars, labels))
        return phoneme_map

    def normalize(self, text: str, language: str = 'hi') -> str:
        """Collapse whitespace, drop unsupported symbols and lowercase.

        Keeps letters, digits, the punctuation in ``_KEPT_PUNCTUATION`` and
        Unicode combining marks (category ``M*``). Indic vowel signs, virama,
        anusvara and nukta are combining marks rather than alphanumerics, so an
        ``isalnum()``-only filter would turn 'नमस्ते' into 'नमसत'.

        Args:
            text: raw input text.
            language: currently unused; accepted for API symmetry.

        Returns:
            The normalized string.
        """
        import unicodedata

        text = ' '.join(text.split())
        kept = (
            c for c in text
            if c.isalnum()
            or c in self._KEPT_PUNCTUATION
            or unicodedata.category(c).startswith('M')
        )
        return ''.join(kept).lower()

    def text_to_phonemes(self, text: str) -> List[str]:
        """Map each known Devanagari letter to its label and each space to '|'.

        Characters without a mapping (vowel signs, virama, punctuation, other
        scripts) are dropped.
        """
        phonemes = []
        for char in text:
            if char in self.char_to_phoneme:
                phonemes.append(self.char_to_phoneme[char])
            elif char == ' ':
                phonemes.append('|')  # Word boundary
        return phonemes

# Test normalizer on a Hindi phrase ("hello world").
normalizer = IndianLanguageNormalizer()
test_text = "नमस्ते दुनिया"
normalized = normalizer.normalize(test_text)
phonemes = normalizer.text_to_phonemes(normalized)
print(f"Original: {test_text}")
print(f"Normalized: {normalized}")
print(f"Phonemes: {phonemes}")

## 3. Multilingual TTS Model Architecture

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class MultilingualTTSEncoder(nn.Module):
    """Character encoder: embedding -> 2-layer bidirectional LSTM -> linear projection.

    Takes integer token IDs shaped (batch, seq) (the LSTM is batch-first) and
    returns features shaped (batch, seq, hidden_dim).
    """

    def __init__(self, vocab_size: int, embedding_dim: int = 256, hidden_dim: int = 512):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(
            embedding_dim,
            hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )
        # Both LSTM directions are concatenated, so project 2 * hidden_dim back down.
        self.linear = nn.Linear(2 * hidden_dim, hidden_dim)

    def forward(self, x):
        embedded = self.embedding(x)
        recurrent_out = self.lstm(embedded)[0]
        return self.linear(recurrent_out)

class AccentStyleTransferModule(nn.Module):
    """Condition encoder features on an accent ID and a style ID.

    Each ID is looked up in its own embedding table, repeated along the
    sequence axis, concatenated with the encoder output, and passed through a
    two-layer MLP back to ``hidden_dim`` features.
    """

    def __init__(self, hidden_dim: int = 512, num_accents: int = 5, num_styles: int = 3):
        super().__init__()
        self.num_accents = num_accents
        self.num_styles = num_styles

        self.accent_embedding = nn.Embedding(num_accents, hidden_dim)
        self.style_embedding = nn.Embedding(num_styles, hidden_dim)

        # Input is [encoder | accent | style]: three hidden_dim-wide slices.
        self.fusion = nn.Sequential(
            nn.Linear(3 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

    def forward(self, encoder_output, accent_id, style_id):
        # Unpacking enforces a 3-D (batch, seq, features) encoder output.
        _, seq_len, _ = encoder_output.shape

        def _tile(table, ids):
            # Assumes ids is 1-D (batch,): (batch,) -> (batch, 1, H) -> (batch, seq_len, H)
            return table(ids).unsqueeze(1).expand(-1, seq_len, -1)

        features = torch.cat(
            [
                encoder_output,
                _tile(self.accent_embedding, accent_id),
                _tile(self.style_embedding, style_id),
            ],
            dim=-1,
        )
        return self.fusion(features)

class MultilingualTTSDecoder(nn.Module):
    """Turn conditioned features into mel-spectrogram frames.

    A 2-layer unidirectional, batch-first LSTM followed by a linear layer. One
    output frame is produced per input step, so output length equals input length.
    """

    def __init__(self, hidden_dim: int = 512, mel_bins: int = 80):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=False,
        )
        self.linear = nn.Linear(hidden_dim, mel_bins)

    def forward(self, x):
        hidden_seq, _state = self.lstm(x)
        return self.linear(hidden_seq)

class MultilingualTTS(nn.Module):
    """Complete multilingual TTS model: token IDs (+ language, accent, style) -> mel frames.

    Pipeline: character encoder -> add a per-language vector -> accent/style
    fusion -> LSTM decoder. One mel frame is produced per input token, giving
    an output of shape (batch, seq_len, mel_bins).
    """

    def __init__(self, vocab_size: int, num_languages: int = 11,
                 num_accents: int = 5, num_styles: int = 3,
                 embedding_dim: int = 256, hidden_dim: int = 512, mel_bins: int = 80):
        super().__init__()

        self.vocab_size = vocab_size
        self.num_languages = num_languages

        # Language embedding. It is added to the encoder OUTPUT, which has
        # hidden_dim features (not embedding_dim), so it must be hidden_dim wide.
        # With an embedding_dim width, any configuration where
        # embedding_dim != hidden_dim (including the defaults, 256 vs 512)
        # failed on every forward pass with a shape mismatch.
        self.language_embedding = nn.Embedding(num_languages, hidden_dim)

        # Encoder
        self.encoder = MultilingualTTSEncoder(vocab_size, embedding_dim, hidden_dim)

        # Accent & Style Transfer
        self.transfer_module = AccentStyleTransferModule(hidden_dim, num_accents, num_styles)

        # Decoder
        self.decoder = MultilingualTTSDecoder(hidden_dim, mel_bins)

    def forward(self, text_ids, language_id, accent_id, style_id):
        """Predict mel frames.

        Args:
            text_ids: LongTensor (batch, seq_len) of token IDs.
            language_id, accent_id, style_id: LongTensors of shape (batch,).

        Returns:
            Tensor (batch, seq_len, mel_bins).
        """
        # Encode text
        encoder_output = self.encoder(text_ids)

        # Add language information, broadcast over the sequence axis.
        lang_emb = self.language_embedding(language_id).unsqueeze(1)
        encoder_output = encoder_output + lang_emb

        # Apply accent and style transfer
        transferred = self.transfer_module(encoder_output, accent_id, style_id)

        # Decode to mel-spectrogram
        return self.decoder(transferred)

# Test model: build it on the selected device and report its parameter count.
vocab_size = 500
model = MultilingualTTS(vocab_size=vocab_size).to(device)
print("Model created successfully")
print(f"Total parameters: {sum(param.numel() for param in model.parameters()):,}")

## 4. Dataset Handling - SYSPIN Dataset

In [None]:
from torch.utils.data import Dataset, DataLoader
import librosa

class SYSPINDataset(Dataset):
    """Per-language SYSPIN TTS dataset yielding fixed-length token IDs and dB mel targets.

    Each item is a dict with:
      - ``text_ids``: LongTensor (max_seq_len,), padded/truncated token IDs
      - ``mel_spec``: FloatTensor (max_seq_len, n_mels), mel-spectrogram in dB
        (``librosa.power_to_db(..., ref=np.max)``), padded/truncated in time
      - ``accent_id`` / ``style_id`` / ``language_id``: 0-d LongTensors
      - ``speaker`` / ``text``: raw metadata fields
    """

    # Unicode block (first, last code point) of the script each language is written in.
    # Appended to the vocabulary so native-script characters get real IDs instead of
    # all collapsing to <unk>. Keep len(vocab) <= the model's vocab_size.
    _SCRIPT_RANGES = {
        'hi': (0x0900, 0x097F),  # Devanagari
        'mr': (0x0900, 0x097F),
        'bh': (0x0900, 0x097F),
        'cc': (0x0900, 0x097F),
        'mg': (0x0900, 0x097F),
        'mt': (0x0900, 0x097F),
        'bn': (0x0980, 0x09FF),  # Bengali
        'ta': (0x0B80, 0x0BFF),  # Tamil
        'te': (0x0C00, 0x0C7F),  # Telugu
        'kn': (0x0C80, 0x0CFF),  # Kannada
        'ml': (0x0D00, 0x0D7F),  # Malayalam
    }

    # Label counts; these match MultilingualTTS's default num_accents / num_styles.
    _NUM_ACCENTS = 5
    _NUM_STYLES = 3

    # librosa.power_to_db's default top_db=80 floors values at (max - 80) dB, i.e. -80
    # with ref=np.max. Used for time padding and failed loads so those frames read as
    # silence rather than as the loudest possible frame (0 dB).
    _SILENCE_DB = -80.0

    def __init__(self, data_dir: str, language: str, normalizer: 'IndianLanguageNormalizer',
                 max_seq_len: int = 150, sr: int = 22050, n_mels: int = 80):
        self.data_dir = Path(data_dir)
        self.language = language
        self.normalizer = normalizer
        self.max_seq_len = max_seq_len
        self.sr = sr
        self.n_mels = n_mels

        # Language to ID mapping (follows the normalizer's LANGUAGES order)
        self.lang_to_id = {lang: idx for idx, lang in enumerate(normalizer.LANGUAGES.keys())}

        # Build vocabulary (reads self.language, set above)
        self.vocab = self._build_vocab()
        self.char_to_id = {char: idx for idx, char in enumerate(self.vocab)}

        # Load SYSPIN metadata
        self.samples = self._load_syspin_metadata()
        print(f"Loaded {len(self.samples)} samples from SYSPIN dataset for {language}")

    def _build_vocab(self) -> List[str]:
        """Build the token inventory.

        Order: special tokens, Latin letters, digits, punctuation, then every code
        point of this language's script block (if known). The script block is
        appended last so the IDs of the ASCII tokens stay unchanged.
        """
        vocab = ['<pad>', '<unk>', '<start>', '<end>', '|']
        vocab.extend('abcdefghijklmnopqrstuvwxyz')
        vocab.extend('0123456789')
        vocab.extend('.,!?;:')
        script_range = self._SCRIPT_RANGES.get(self.language)
        if script_range is not None:
            first, last = script_range
            vocab.extend(chr(code_point) for code_point in range(first, last + 1))
        return vocab

    def _load_syspin_metadata(self) -> List[Dict]:
        """Load SYSPIN dataset metadata

        Expected structure:
        data_dir/
          {language}/
            metadata.json  (contains list of {"text": "...", "audio": "path/to/audio.wav", "speaker": "..."})
            wavs/
              speaker_001/
                *.wav

        Returns an empty list (after printing download instructions) when
        metadata.json is missing.
        """
        samples = []
        lang_dir = self.data_dir / self.language
        metadata_file = lang_dir / 'metadata.json'

        if metadata_file.exists():
            # Explicit UTF-8: transcripts are in Indic scripts, which the platform
            # default encoding (e.g. cp1252 on Windows) cannot decode.
            with open(metadata_file, encoding='utf-8') as f:
                samples = json.load(f)
            print(f"✓ Loaded metadata from {metadata_file}")
        else:
            print(f"⚠ Metadata file not found at {metadata_file}")
            print(f"  Please download SYSPIN dataset from: https://spiredatasets.ee.iisc.ac.in/syspincorpus")
            print(f"  Expected structure: {self.data_dir}/{self.language}/metadata.json")

        return samples

    def _resolve_audio_path(self, audio_path: str) -> Path:
        """Resolve an audio path taken from metadata.json.

        Absolute paths, and relative paths that exist from the working directory,
        are returned unchanged. Otherwise the path is tried relative to
        ``data_dir/language`` (the directory that holds metadata.json and wavs/).
        """
        path = Path(audio_path)
        if path.is_absolute() or path.exists():
            return path
        candidate = self.data_dir / self.language / path
        return candidate if candidate.exists() else path

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load a clip as a dB mel-spectrogram of shape (max_seq_len, n_mels).

        On any loading error, prints the error and returns an all-silence tensor.
        """
        try:
            audio, sr = librosa.load(str(self._resolve_audio_path(audio_path)), sr=self.sr)
            mel_spec = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=self.n_mels)
            mel_spec = librosa.power_to_db(mel_spec, ref=np.max)
            mel_spec = torch.from_numpy(mel_spec).float().T  # (time, n_mels)

            # Pad with silence or truncate along time to exactly max_seq_len frames.
            num_frames = mel_spec.shape[0]
            if num_frames < self.max_seq_len:
                padding = torch.full((self.max_seq_len - num_frames, self.n_mels), self._SILENCE_DB)
                mel_spec = torch.cat([mel_spec, padding])
            else:
                mel_spec = mel_spec[:self.max_seq_len]

            return mel_spec
        except Exception as e:
            print(f"Error loading {audio_path}: {e}")
            return torch.full((self.max_seq_len, self.n_mels), self._SILENCE_DB)

    def text_to_ids(self, text: str) -> torch.Tensor:
        """Normalize text and encode it as exactly max_seq_len token IDs.

        Spaces become the word-boundary token '|', characters outside the vocabulary
        become '<unk>', and the sequence is truncated or right-padded with '<pad>'.
        """
        normalized = self.normalizer.normalize(text, self.language)
        unk_id = self.char_to_id['<unk>']
        boundary_id = self.char_to_id['|']
        ids = [boundary_id if char == ' ' else self.char_to_id.get(char, unk_id)
               for char in normalized]

        ids = ids[:self.max_seq_len]
        ids.extend([self.char_to_id['<pad>']] * (self.max_seq_len - len(ids)))
        return torch.tensor(ids, dtype=torch.long)

    @staticmethod
    def _speaker_to_accent_id(speaker, num_accents: int = 5) -> int:
        """Deterministically bucket a speaker identifier into one of num_accents IDs.

        Uses CRC-32 rather than the built-in hash(): str hashes are salted per
        process (PYTHONHASHSEED), so hash() gives the same speaker a different
        accent label on every run and in every DataLoader worker process.
        """
        import zlib

        return zlib.crc32(str(speaker).encode('utf-8')) % num_accents

    def __len__(self) -> int:
        return len(self.samples or [])

    def __getitem__(self, idx: int) -> Dict:
        if not self.samples:
            raise RuntimeError(f"No samples loaded. Please download SYSPIN dataset.")

        sample = self.samples[idx]
        text = sample.get('text', '')
        audio_path = sample.get('audio', '')
        speaker = sample.get('speaker', 'unknown')

        text_ids = self.text_to_ids(text)
        mel_spec = self._load_audio(audio_path)

        # Pseudo accent label: a stable bucket of the speaker identifier.
        accent_id = torch.tensor(self._speaker_to_accent_id(speaker, self._NUM_ACCENTS),
                                 dtype=torch.long)
        # Placeholder style label: the metadata schema read here has no style field,
        # so a random label is drawn. torch's RNG (unlike numpy's global RNG) is
        # re-seeded per DataLoader worker, so workers do not repeat the same draws.
        style_id = torch.randint(0, self._NUM_STYLES, (), dtype=torch.long)

        return {
            'text_ids': text_ids,
            'mel_spec': mel_spec,
            'accent_id': accent_id,
            'style_id': style_id,
            'language_id': torch.tensor(self.lang_to_id.get(self.language, 0), dtype=torch.long),
            'speaker': speaker,
            'text': text
        }

# Download SYSPIN dataset
print("\n📥 SYSPIN Dataset Setup:")
print("\nTo use real SYSPIN data:")
print("1. Download from: https://spiredatasets.ee.iisc.ac.in/syspincorpus")
print("2. Extract to: ./syspin_data/")
print("3. Structure should be:")
print("   syspin_data/")
print("     hi/metadata.json")
print("     hi/wavs/speaker_001/*.wav")
print("     bn/metadata.json")
print("     ... (other languages)")


def _make_demo_dataloader(ds, num_samples: int = 8, batch_size: int = 4, seed: int = 0) -> DataLoader:
    """Build a DataLoader of synthetic samples so later cells run without SYSPIN data.

    Token IDs come from ``ds.text_to_ids``; mel targets are random values in
    [-80, 0] dB. Losses and metrics on this data say nothing about model quality.
    """
    generator = torch.Generator().manual_seed(seed)
    demo_texts = ['नमस्ते दुनिया', 'नमस्कार', 'hello world', 'tts demo 123']
    samples = []
    for i in range(num_samples):
        text = demo_texts[i % len(demo_texts)]
        samples.append({
            'text_ids': ds.text_to_ids(text),
            'mel_spec': -80.0 * torch.rand(ds.max_seq_len, ds.n_mels, generator=generator),
            'accent_id': torch.tensor(i % 5, dtype=torch.long),
            'style_id': torch.tensor(i % 3, dtype=torch.long),
            'language_id': torch.tensor(ds.lang_to_id.get(ds.language, 0), dtype=torch.long),
            'speaker': 'demo',
            'text': text,
        })
    return DataLoader(samples, batch_size=batch_size, shuffle=True)


# Try to load dataset. Both names are always bound so later cells never hit NameError.
dataset = None
dataloader = None
try:
    dataset = SYSPINDataset('./syspin_data', 'hi', normalizer)
    if len(dataset) > 0:
        dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
        print(f"\n✅ Dataset loaded successfully!")
        print(f"   Samples: {len(dataset)}")
        print(f"   Vocabulary size: {len(dataset.vocab)}")

        batch = next(iter(dataloader))
        print(f"\n   Batch shapes:")
        for key, val in batch.items():
            if isinstance(val, torch.Tensor):
                print(f"     {key}: {val.shape}")
    else:
        print("\n⚠ Dataset is empty. Please download SYSPIN data.")
except Exception as e:
    print(f"\n⚠ Could not load SYSPIN dataset: {e}")
    dataloader = None

# Fall back to synthetic batches so the training / evaluation cells can still run.
if dataloader is None:
    if dataset is not None:
        print("   Using demo mode for testing...")
        dataloader = _make_demo_dataloader(dataset)
    else:
        print("   No dataset object could be created; cells that use `dataset`/`dataloader` will fail.")

## 5. Training Loop

In [None]:
class TTSTrainer:
    """Trainer for multilingual TTS model.

    Optimizes ``model(text_ids, language_id, accent_id, style_id)`` against
    ``batch['mel_spec']`` with MSE loss, Adam, and gradient-norm clipping at 1.0.
    ``losses`` holds the average loss of every completed epoch.
    """

    def __init__(self, model: nn.Module, device: torch.device, learning_rate: float = 1e-3):
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        self.criterion = nn.MSELoss()
        self.losses = []

    def train_step(self, batch: Dict) -> float:
        """Run one optimization step on a collated batch and return its loss."""
        self.model.train()

        text_ids = batch['text_ids'].to(self.device)
        mel_spec = batch['mel_spec'].to(self.device)
        accent_id = batch['accent_id'].to(self.device)
        style_id = batch['style_id'].to(self.device)
        language_id = batch['language_id'].to(self.device)

        pred_mel = self.model(text_ids, language_id, accent_id, style_id)
        loss = self.criterion(pred_mel, mel_spec)

        self.optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to keep recurrent-layer updates bounded.
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        self.optimizer.step()

        return loss.item()

    def train_epoch(self, dataloader: DataLoader, epoch: int, log_every: int = 10) -> float:
        """Train for one epoch and return the mean per-batch loss.

        Args:
            dataloader: any iterable of collated batch dicts.
            epoch: epoch number, used only in log messages.
            log_every: print the batch loss every this many batches (0 disables).

        Raises:
            ValueError: if ``dataloader`` yields no batches.
        """
        epoch_loss = 0.0
        num_batches = 0

        for batch_idx, batch in enumerate(dataloader, start=1):
            loss = self.train_step(batch)
            epoch_loss += loss
            num_batches = batch_idx

            if log_every and batch_idx % log_every == 0:
                print(f"Epoch {epoch}, Batch {batch_idx}: Loss = {loss:.4f}")

        # Count batches actually seen instead of len(dataloader): avoids a bare
        # ZeroDivisionError on an empty loader and works for unsized iterables.
        if num_batches == 0:
            raise ValueError("train_epoch received a dataloader that yielded no batches")

        avg_loss = epoch_loss / num_batches
        self.losses.append(avg_loss)
        return avg_loss

    def save_checkpoint(self, path: str):
        """Save model/optimizer state and the loss history to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'losses': self.losses
        }, path)
        print(f"Checkpoint saved to {path}")

    def load_checkpoint(self, path: str):
        """Restore state written by ``save_checkpoint``."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.losses = checkpoint.get('losses', [])
        print(f"Checkpoint loaded from {path}")

# Wrap the model in a trainer on the selected device.
trainer = TTSTrainer(model=model, device=device)
print("Trainer initialized")

## 6. Training Execution

In [None]:
# Train for a few epochs on the dataloader produced by the dataset cell.
num_epochs = 3

for epoch in range(num_epochs):
    avg_loss = trainer.train_epoch(dataloader, epoch + 1)
    # "\n" (not "\\n"): print a real blank line, not a literal backslash-n.
    print(f"\nEpoch {epoch + 1} completed. Average Loss: {avg_loss:.4f}")

# Save checkpoint
trainer.save_checkpoint('multilingual_tts_model.pt')

# Plot training loss (one point per epoch)
plt.figure(figsize=(10, 5))
plt.plot(trainer.losses, marker='o')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)
plt.show()

## 7. Inference and Synthesis

In [None]:
class TTSInference:
    """Inference engine for TTS: text -> mel frames -> (approximate) waveform."""

    def __init__(self, model: nn.Module, dataset: 'SYSPINDataset', device: torch.device):
        self.model = model.to(device)
        self.model.eval()
        self.dataset = dataset
        self.device = device

    def synthesize(self, text: str, language: str = 'hi', accent_id: int = 0, style_id: int = 0) -> torch.Tensor:
        """Synthesize a mel-spectrogram for ``text``.

        Tokenization uses ``self.dataset.text_to_ids``; ``language`` selects only
        the language embedding (unknown codes fall back to ID 0).

        Returns:
            CPU tensor (seq_len, mel_bins) for the single input.
        """
        # The model object may be shared with a trainer that switched it back to
        # train mode, so re-assert eval mode on every call.
        self.model.eval()
        with torch.no_grad():
            text_ids = self.dataset.text_to_ids(text).unsqueeze(0).to(self.device)
            lang_id = torch.tensor([self.dataset.lang_to_id.get(language, 0)],
                                   dtype=torch.long, device=self.device)
            accent_tensor = torch.tensor([accent_id], dtype=torch.long, device=self.device)
            style_tensor = torch.tensor([style_id], dtype=torch.long, device=self.device)

            mel_spec = self.model(text_ids, lang_id, accent_tensor, style_tensor)

        return mel_spec.squeeze(0).cpu()

    def mel_to_audio(self, mel_spec: torch.Tensor, sr: int = 22050, n_iter: int = 32) -> np.ndarray:
        """Approximately invert a dB mel-spectrogram to a waveform with Griffin-Lim.

        Expects ``mel_spec`` shaped (time, n_mels) in the same units as the
        training targets: ``librosa.power_to_db(..., ref=np.max)`` of a
        ``librosa.feature.melspectrogram`` computed with librosa's default
        n_fft/hop_length. Values are clipped to [-80, 0] dB (power_to_db's default
        top_db range) before conversion back to power.

        Griffin-Lim is only a baseline; a neural vocoder (HiFi-GAN, WaveGlow)
        gives far better quality in production.
        """
        mel_db = mel_spec.detach().cpu().numpy().T  # (n_mels, time), as librosa expects
        mel_db = np.clip(mel_db, -80.0, 0.0)
        mel_power = librosa.db_to_power(mel_db)
        return librosa.feature.inverse.mel_to_audio(mel_power, sr=sr, n_iter=n_iter)

# Initialize inference engine
inference = TTSInference(model, dataset, device)

# Test synthesis with different accents and styles ("hello world" in each language)
test_texts = {
    'hi': 'नमस्ते दुनिया',
    'bn': 'নমস্কার বিশ্ব',
    'mr': 'नमस्कार जग'
}

print("Testing synthesis with different accents and styles:")
for lang, text in test_texts.items():
    for accent in range(2):
        for style in range(2):
            mel_spec = inference.synthesize(text, language=lang, accent_id=accent, style_id=style)
            print(f"  {lang} (accent={accent}, style={style}): mel_spec shape = {mel_spec.shape}")

## 8. Evaluation Metrics

In [None]:
class TTSEvaluator:
    """Objective distances between predicted and reference mel-spectrograms.

    Every metric takes two same-shaped tensors and returns a Python float.
    """

    @staticmethod
    def mel_spectrogram_distance(pred_mel: torch.Tensor, target_mel: torch.Tensor) -> float:
        """Mean squared error over all elements (squared L2 distance / element count)."""
        return torch.nn.functional.mse_loss(pred_mel, target_mel).item()

    @staticmethod
    def spectral_convergence(pred_mel: torch.Tensor, target_mel: torch.Tensor, eps: float = 1e-8) -> float:
        """Compute spectral convergence ||target - pred||_F / ||target||_F.

        The denominator is clamped to ``eps`` so an all-zero target yields a
        finite value instead of inf/NaN.
        """
        numerator = torch.norm(target_mel - pred_mel, p='fro')
        denominator = torch.norm(target_mel, p='fro').clamp_min(eps)
        return (numerator / denominator).item()

    @staticmethod
    def log_magnitude_distance(pred_mel: torch.Tensor, target_mel: torch.Tensor, in_db: bool = False) -> float:
        """Mean absolute difference between natural-log magnitudes.

        Args:
            in_db: False (default) means the inputs are linear magnitudes, clamped
                at 1e-5 before taking the log. True means the inputs are power
                spectrograms already in decibels (10*log10(power)). These are
                converted directly with ln|X| = dB * ln(10) / 20. Clamping would be
                wrong there: every non-positive dB value would collapse to
                log(1e-5).
        """
        if in_db:
            import math

            scale = math.log(10.0) / 20.0
            pred_log = pred_mel * scale
            target_log = target_mel * scale
        else:
            pred_log = torch.log(torch.clamp(pred_mel, min=1e-5))
            target_log = torch.log(torch.clamp(target_mel, min=1e-5))
        return torch.nn.functional.l1_loss(pred_log, target_log).item()

    @staticmethod
    def evaluate_batch(model: nn.Module, batch: Dict, device: torch.device,
                       mels_in_db: bool = True) -> Dict[str, float]:
        """Evaluate model on a batch.

        Args:
            mels_in_db: whether ``batch['mel_spec']`` (and hence the model's
                predictions) are dB-scaled. SYSPINDataset builds its targets with
                ``librosa.power_to_db``, hence the default True.

        Leaves the model in eval mode.
        """
        model.eval()
        with torch.no_grad():
            text_ids = batch['text_ids'].to(device)
            mel_spec = batch['mel_spec'].to(device)
            accent_id = batch['accent_id'].to(device)
            style_id = batch['style_id'].to(device)
            language_id = batch['language_id'].to(device)

            pred_mel = model(text_ids, language_id, accent_id, style_id)

            metrics = {
                'mse_loss': TTSEvaluator.mel_spectrogram_distance(pred_mel, mel_spec),
                'spectral_convergence': TTSEvaluator.spectral_convergence(pred_mel, mel_spec),
                'log_magnitude_distance': TTSEvaluator.log_magnitude_distance(
                    pred_mel, mel_spec, in_db=mels_in_db),
            }

        return metrics

# Evaluate on one batch drawn from the dataloader.
evaluator = TTSEvaluator()
test_batch = next(iter(dataloader))
metrics = evaluator.evaluate_batch(model, test_batch, device)

print("Evaluation Metrics:")
for name, score in metrics.items():
    print(f"  {name}: {score:.4f}")

## 9. Advanced Features: Multi-Language Support

In [None]:
class MultilingualTTSPipeline:
    """Convenience wrapper bundling a model, its dataset tokenizer and a TTSInference engine."""

    def __init__(self, model: nn.Module, dataset: 'SYSPINDataset', device: torch.device):
        self.model = model
        self.dataset = dataset
        self.device = device
        self.inference = TTSInference(model, dataset, device)
        self.normalizer = dataset.normalizer

    def process_text(self, text: str, language: str) -> torch.Tensor:
        """Normalize ``text`` for ``language`` and convert it to padded token IDs.

        Note: ``dataset.text_to_ids`` normalizes once more, using the dataset's
        own language.
        """
        normalized = self.normalizer.normalize(text, language)
        return self.dataset.text_to_ids(normalized)

    def synthesize_multilingual(self, texts: Dict[str, str], accent_id: int = 0, style_id: int = 0) -> Dict[str, torch.Tensor]:
        """Synthesize one mel-spectrogram per ``{language_code: text}`` entry."""
        return {
            lang: self.inference.synthesize(text, language=lang, accent_id=accent_id, style_id=style_id)
            for lang, text in texts.items()
        }

    def get_supported_languages(self) -> Dict[str, str]:
        """Return a copy of the language code -> name mapping.

        A copy is returned so callers cannot mutate the normalizer's shared
        class-level ``LANGUAGES`` dict.
        """
        return dict(self.normalizer.LANGUAGES)

# Initialize pipeline
pipeline = MultilingualTTSPipeline(model, dataset, device)

# Test multilingual synthesis (the greeting "namaste/namaskar" in each language)
multilingual_texts = {
    'hi': 'नमस्ते',
    'bn': 'নমস্কার',
    'mr': 'नमस्कार',
    'kn': 'ನಮಸ್ಕಾರ',
    'te': 'నమస్కారం'
}

print("Supported Languages:")
for lang_code, lang_name in pipeline.get_supported_languages().items():
    print(f"  {lang_code}: {lang_name}")

# "\n" (not "\\n"): print a real blank line, not a literal backslash-n.
print("\nMultilingual Synthesis Results:")
results = pipeline.synthesize_multilingual(multilingual_texts)
for lang, mel_spec in results.items():
    print(f"  {lang}: mel_spec shape = {mel_spec.shape}")

## 10. Export and Deployment

In [None]:
# Export model for deployment
def export_model(model: nn.Module, export_path: str, seq_len: int = 150, opset_version: int = 12):
    """Export model to ONNX format for deployment.

    Traces the model with batch-size-1 dummy inputs: ``text_ids`` of shape
    (1, seq_len) and a single language/accent/style ID each.

    Args:
        model: module called as ``model(text_ids, language_id, accent_id, style_id)``.
        export_path: destination ``.onnx`` file.
        seq_len: length of the dummy token sequence (150 matches
            SYSPINDataset's default max_seq_len).
        opset_version: ONNX opset to target.

    Export errors are printed and swallowed so the notebook keeps running.
    """
    model.eval()

    # Dummy inputs must live on the same device as the weights; exporting a CUDA
    # model with CPU tensors fails with a device-mismatch error.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    # Keep dummy token IDs inside the embedding table when the model reports its size.
    vocab_size = getattr(model, 'vocab_size', 500)

    # Create dummy inputs
    dummy_text = torch.randint(0, vocab_size, (1, seq_len), dtype=torch.long, device=model_device)
    dummy_lang = torch.zeros(1, dtype=torch.long, device=model_device)
    dummy_accent = torch.zeros(1, dtype=torch.long, device=model_device)
    dummy_style = torch.zeros(1, dtype=torch.long, device=model_device)

    # Export to ONNX
    try:
        torch.onnx.export(
            model,
            (dummy_text, dummy_lang, dummy_accent, dummy_style),
            export_path,
            input_names=['text_ids', 'language_id', 'accent_id', 'style_id'],
            output_names=['mel_spectrogram'],
            opset_version=opset_version
        )
        print(f"Model exported to {export_path}")
    except Exception as e:
        print(f"Export failed: {e}")

# Export the model
export_model(model, 'multilingual_tts_model.onnx')

# Save configuration. Hyper-parameters are read back from the model itself, so the
# file always describes the weights that were actually trained and exported.
config = {
    'vocab_size': model.vocab_size,
    'num_languages': model.num_languages,
    'num_accents': model.transfer_module.num_accents,
    'num_styles': model.transfer_module.num_styles,
    'embedding_dim': model.encoder.embedding.embedding_dim,
    'hidden_dim': model.encoder.linear.out_features,
    'mel_bins': model.decoder.linear.out_features,
    'languages': pipeline.get_supported_languages()
}

# Explicit UTF-8: ensure_ascii=False writes any non-ASCII characters verbatim.
with open('tts_config.json', 'w', encoding='utf-8') as f:
    json.dump(config, f, indent=2, ensure_ascii=False)

print("Configuration saved to tts_config.json")
# "\n" (not "\\n"): print a real blank line, not a literal backslash-n.
print("\nModel ready for deployment!")