# In [None]:
import json
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import librosa
import matplotlib.pyplot as plt
from pathlib import Path
import pickle
from sklearn.model_selection import train_test_split
import re
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')

class PashtoTextProcessor:
    """Character-level text front end for Pashto TTS.

    Maps text to integer id sequences wrapped in <SOS>/<EOS>, and back.

    Id layout: the four special tokens come first, so ``<PAD>`` is id 0. Other
    code in this file pads text with 0 (``collate_fn``) and treats index 0 as
    the embedding padding index (``TTSEncoder``), so 0 must not be a real
    character.
    """

    # Collapses any run of whitespace to a single space.
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self):
        # Special tokens
        self.pad_token = '<PAD>'
        self.sos_token = '<SOS>'
        self.eos_token = '<EOS>'
        self.unk_token = '<UNK>'
        special_tokens = [self.pad_token, self.sos_token, self.eos_token, self.unk_token]

        # Pashto alphabet and common characters
        self.pashto_chars = [
            'ا', 'آ', 'ب', 'پ', 'ت', 'ټ', 'ث', 'ج', 'چ', 'ح', 'خ', 'د', 'ډ', 'ذ', 'ر', 'ړ', 'ز', 'ژ', 'س', 'ش', 'ښ',
            'ص', 'ض', 'ط', 'ظ', 'ع', 'غ', 'ف', 'ق', 'ک', 'ګ', 'ل', 'م', 'ن', 'ڼ', 'و', 'ه', 'ي', 'ۍ', 'ې',
            ' ', '.', '،', '؟', '!', '\n'
        ]

        # Specials first so <PAD> == 0 (see class docstring).
        vocabulary = special_tokens + self.pashto_chars
        self.char_to_idx = {token: idx for idx, token in enumerate(vocabulary)}
        self.idx_to_char = {idx: token for idx, token in enumerate(vocabulary)}
        self.vocab_size = len(vocabulary)

        self.pad_idx = self.char_to_idx[self.pad_token]
        self.sos_idx = self.char_to_idx[self.sos_token]
        self.eos_idx = self.char_to_idx[self.eos_token]
        self.unk_idx = self.char_to_idx[self.unk_token]

    def text_to_sequence(self, text: str) -> List[int]:
        """Convert text to ``[<SOS>, char ids..., <EOS>]``.

        Leading/trailing whitespace is stripped and internal whitespace runs
        (including newlines) collapse to one space. Characters outside the
        vocabulary map to ``<UNK>``.
        """
        text = self._WHITESPACE_RE.sub(' ', text.strip().replace('\n', ' '))
        body = [self.char_to_idx.get(char, self.unk_idx) for char in text]
        return [self.sos_idx, *body, self.eos_idx]

    def sequence_to_text(self, sequence: List[int]) -> str:
        """Convert ids back to text, dropping <PAD>/<SOS>/<EOS> and unknown ids.

        ``<UNK>`` ids are rendered as the literal string ``'<UNK>'``.
        """
        skipped = {self.pad_idx, self.sos_idx, self.eos_idx}
        return ''.join(
            self.idx_to_char[idx]
            for idx in sequence
            if idx in self.idx_to_char and idx not in skipped
        )

class AudioProcessor:
    """Audio preprocessing for TTS: load/resample, log-mel extraction, normalisation."""

    def __init__(self, sample_rate=22050, n_mels=80, hop_length=256, win_length=1024):
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = win_length

        # Power (magnitude**2) mel spectrogram transformer
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=self.n_fft,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=n_mels,
            power=2.0
        )

    def _get_resampler(self, orig_freq):
        """Return a cached ``Resample`` transform from ``orig_freq`` to ``self.sample_rate``."""
        # Created lazily via setdefault (not in __init__) so instances unpickled
        # from checkpoints saved before this cache existed still work.
        cache = self.__dict__.setdefault('_resamplers', {})
        if orig_freq not in cache:
            cache[orig_freq] = torchaudio.transforms.Resample(orig_freq, self.sample_rate)
        return cache[orig_freq]

    def load_audio(self, audio_path: str) -> torch.Tensor:
        """Load ``audio_path`` as a 1-D mono waveform at ``self.sample_rate``.

        Returns ``None`` (after printing the error) if loading or resampling
        fails, so callers can substitute a placeholder instead of aborting.
        """
        try:
            waveform, sr = torchaudio.load(audio_path)

            # Down-mix first: resampling is linear, so averaging channels before
            # resampling gives the same signal while resampling only one channel.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            if sr != self.sample_rate:
                waveform = self._get_resampler(sr)(waveform)

            return waveform.squeeze(0)

        except Exception as e:
            print(f"Error loading audio {audio_path}: {e}")
            return None

    def audio_to_mel(self, waveform: torch.Tensor) -> torch.Tensor:
        """Convert a waveform to a natural-log mel spectrogram of shape (time, n_mels)."""
        mel_spec = self.mel_transform(waveform)
        # Small epsilon avoids log(0) on silent frames.
        mel_spec = torch.log(mel_spec + 1e-8)
        return mel_spec.squeeze(0).T  # (time, n_mels)

    def normalize_mel(self, mel_spec: torch.Tensor) -> torch.Tensor:
        """Standardise a mel spectrogram to zero mean / unit std over all its values.

        Statistics are per utterance, so absolute loudness is not preserved.
        """
        return (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)

class PashtoTTSDataset(Dataset):
    """Map-style dataset pairing Pashto sentences with mel spectrograms of their recordings.

    The JSON manifest is iterated as a list of records carrying at least ``id``,
    ``file`` (joined onto ``audio_base_path``) and ``sentence``; ``gender`` and
    ``accent`` are optional and default to ``'Unknown'``. Records whose audio file
    does not exist, or whose id sequence (with <SOS>/<EOS>) is longer than
    ``max_text_len``, are discarded at construction time.
    """

    def __init__(self, json_path: str, audio_base_path: str, text_processor: PashtoTextProcessor, 
                 audio_processor: AudioProcessor, max_text_len: int = 200):
        self.text_processor = text_processor
        self.audio_processor = audio_processor
        self.max_text_len = max_text_len
        self.audio_base_path = audio_base_path

        with open(json_path, 'r', encoding='utf-8') as manifest:
            self.data = json.load(manifest)

        print("Validating dataset...")
        candidates = (self._make_sample(entry) for entry in self.data)
        self.valid_samples = [sample for sample in candidates if sample is not None]

        print(f"Valid samples: {len(self.valid_samples)} out of {len(self.data)}")

    def _make_sample(self, entry):
        """Return the cached sample dict for a manifest entry, or ``None`` if it is unusable."""
        audio_path = os.path.join(self.audio_base_path, entry['file'])
        if not os.path.exists(audio_path):
            return None

        encoded_len = len(self.text_processor.text_to_sequence(entry['sentence']))
        if encoded_len > self.max_text_len:
            return None

        return {
            'id': entry['id'],
            'audio_path': audio_path,
            'text': entry['sentence'],
            'gender': entry.get('gender', 'Unknown'),
            'accent': entry.get('accent', 'Unknown')
        }

    def __len__(self):
        return len(self.valid_samples)

    def __getitem__(self, idx):
        """Load one sample: text ids plus its normalised log-mel spectrogram."""
        record = self.valid_samples[idx]
        audio = self.audio_processor

        waveform = audio.load_audio(record['audio_path'])
        if waveform is not None:
            mel = audio.normalize_mel(audio.audio_to_mel(waveform))
        else:
            # Unreadable audio: substitute 100 silent frames so the epoch continues.
            mel = torch.zeros(100, audio.n_mels)

        token_ids = torch.tensor(
            self.text_processor.text_to_sequence(record['text']), dtype=torch.long
        )

        return {
            'id': record['id'],
            'text_sequence': token_ids,
            'mel_spectrogram': mel,
            'text': record['text'],
            'gender': record['gender'],
            'accent': record['accent']
        }

def collate_fn(batch):
    """Collate dataset items into a padded batch.

    Items are ordered longest-text-first (the encoder packs sequences with
    ``enforce_sorted=True``). Text ids are right-padded with 0 and mel
    spectrograms with zero frames, up to the longest item in the batch.
    """
    ordered = sorted(batch, key=lambda sample: len(sample['text_sequence']), reverse=True)

    texts = [sample['text_sequence'] for sample in ordered]
    mels = [sample['mel_spectrogram'] for sample in ordered]

    text_lengths = torch.tensor([len(ids) for ids in texts])
    frame_counts = [mel.shape[0] for mel in mels]

    padded_mels = torch.zeros(len(ordered), max(frame_counts), mels[0].shape[1])
    # Each `row` is a view into padded_mels, so writing into it fills the batch tensor.
    for row, mel in zip(padded_mels, mels):
        row[:mel.shape[0], :] = mel

    return {
        'text_sequences': pad_sequence(texts, batch_first=True, padding_value=0),
        'text_lengths': text_lengths,
        'mel_spectrograms': padded_mels,
        'mel_lengths': torch.tensor(frame_counts),
        'ids': [sample['id'] for sample in ordered],
        'texts': [sample['text'] for sample in ordered]
    }

class TTSEncoder(nn.Module):
    """Text encoder: embedding -> bidirectional LSTM -> linear projection to ``hidden_dim``."""

    def __init__(self, vocab_size, embed_dim=256, hidden_dim=512, num_layers=2, padding_idx=0):
        """
        Args:
            vocab_size: number of token ids.
            embed_dim: embedding width.
            hidden_dim: LSTM width per direction and output width after projection.
            num_layers: stacked LSTM layers.
            padding_idx: token id used for padding; its embedding stays zero and
                receives no gradient. Defaults to 0, the value ``collate_fn`` pads with.
        """
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=padding_idx)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers,
                            batch_first=True, bidirectional=True)
        self.projection = nn.Linear(hidden_dim * 2, hidden_dim)
        self.dropout = nn.Dropout(0.1)

    def forward(self, text_sequences, text_lengths):
        """Encode padded id batches.

        Args:
            text_sequences: (batch, seq_len) long tensor of token ids.
            text_lengths: (batch,) true lengths; any order is accepted.

        Returns:
            (batch, seq_len, hidden_dim) tensor aligned with ``text_sequences``.
            Positions past a sequence's length hold the projection of zeros,
            i.e. the projection bias.
        """
        embedded = self.dropout(self.embedding(text_sequences))

        # enforce_sorted=False lets PyTorch sort internally and restore the
        # original order, so callers need not pre-sort by length.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, text_lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        lstm_out, _ = self.lstm(packed)

        # total_length keeps the time axis equal to the input width even when
        # the input was padded beyond the longest sequence.
        unpacked, _ = nn.utils.rnn.pad_packed_sequence(
            lstm_out, batch_first=True, total_length=text_sequences.size(1)
        )

        return self.projection(unpacked)

class TTSDecoder(nn.Module):
    """Autoregressive mel-spectrogram decoder with dot-product attention.

    At every step the top LSTM layer's hidden state queries the encoder outputs.
    The attention context is concatenated with the previous mel frame and fed to
    a 2-layer LSTM, whose output is projected to the next mel frame and a
    stop-token logit.
    """

    def __init__(self, hidden_dim=512, mel_dim=80, attention_dim=128):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.mel_dim = mel_dim
        self.attention_dim = attention_dim

        # Attention projections. Encoder outputs are projected to attention_dim for
        # both keys and values, so the context vector matches the LSTM input size
        # below (attention_dim + mel_dim). Attending over the raw hidden_dim-wide
        # encoder outputs made the LSTM reject its very first input.
        self.attention_query = nn.Linear(hidden_dim, attention_dim)
        self.attention_key = nn.Linear(hidden_dim, attention_dim)
        self.attention_value = nn.Linear(hidden_dim, attention_dim)

        # Decoder layers
        self.lstm = nn.LSTM(attention_dim + mel_dim, hidden_dim, 2, batch_first=True)
        self.mel_projection = nn.Linear(hidden_dim, mel_dim)
        self.stop_projection = nn.Linear(hidden_dim, 1)

        self.dropout = nn.Dropout(0.1)

    def forward(self, encoder_outputs, mel_targets=None, max_length=1000, encoder_mask=None):
        """Decode mel frames from encoder outputs.

        Args:
            encoder_outputs: (batch, seq_len, hidden_dim) encoder states.
            mel_targets: optional (batch, frames, mel_dim) ground truth. When given,
                decoding uses teacher forcing in both train and eval mode, so
                validation losses compare frame-aligned predictions. (Dropout is
                still governed by train/eval mode.)
            max_length: frame cap for free-running inference (no targets).
            encoder_mask: optional (batch, seq_len) bool tensor, True at real
                positions; padded positions receive zero attention weight.

        Returns:
            ``(mel_outputs, stop_logits)`` shaped (batch, frames, mel_dim) and
            (batch, frames, 1).
        """
        if mel_targets is not None:
            return self.teacher_forcing(encoder_outputs, mel_targets, encoder_mask)
        return self.inference(encoder_outputs, max_length, encoder_mask)

    def teacher_forcing(self, encoder_outputs, mel_targets, encoder_mask=None):
        """Training-style decoding: step t is fed ground-truth frame t-1 (zeros at t=0)."""
        batch_size, num_frames, mel_dim = mel_targets.shape
        hidden = self._initial_state(batch_size, encoder_outputs.device)
        keys, values = self._project_memory(encoder_outputs)

        # Shift targets right by one frame to build every step's input up front.
        go_frame = mel_targets.new_zeros(batch_size, 1, mel_dim)
        prev_frames = torch.cat([go_frame, mel_targets[:, :-1]], dim=1)

        mel_steps, stop_steps = [], []
        for t in range(num_frames):
            lstm_out, hidden = self._step(
                prev_frames[:, t:t + 1], hidden, keys, values, encoder_mask
            )
            lstm_out = self.dropout(lstm_out)
            mel_steps.append(self.mel_projection(lstm_out))
            stop_steps.append(self.stop_projection(lstm_out))

        return torch.cat(mel_steps, dim=1), torch.cat(stop_steps, dim=1)

    def inference(self, encoder_outputs, max_length, encoder_mask=None):
        """Free-running decoding: each step is fed the previous prediction.

        Stops once every sequence in the batch has emitted a stop token
        (sigmoid(logit) > 0.5) or after ``max_length`` frames. All items share
        the same number of output frames.
        """
        batch_size = encoder_outputs.size(0)
        device = encoder_outputs.device
        hidden = self._initial_state(batch_size, device)
        keys, values = self._project_memory(encoder_outputs)

        prev_frame = encoder_outputs.new_zeros(batch_size, 1, self.mel_dim)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=device)

        mel_steps, stop_steps = [], []
        for _ in range(max_length):
            lstm_out, hidden = self._step(prev_frame, hidden, keys, values, encoder_mask)

            mel_pred = self.mel_projection(lstm_out)
            stop_pred = self.stop_projection(lstm_out)
            mel_steps.append(mel_pred)
            stop_steps.append(stop_pred)

            prev_frame = mel_pred

            # Per-item stop flags; the original `.item()` only worked for batch 1.
            finished |= torch.sigmoid(stop_pred).view(batch_size) > 0.5
            if bool(finished.all()):
                break

        return torch.cat(mel_steps, dim=1), torch.cat(stop_steps, dim=1)

    def compute_attention(self, encoder_outputs, decoder_hidden, encoder_mask=None):
        """Return attention weights of shape (batch, 1, seq_len).

        Args:
            encoder_outputs: (batch, seq_len, hidden_dim).
            decoder_hidden: (batch, 1, hidden_dim) query state.
            encoder_mask: optional (batch, seq_len) bool, True at real positions.
        """
        keys = self.attention_key(encoder_outputs)
        return self._attention_weights(decoder_hidden, keys, encoder_mask)

    def _attention_weights(self, decoder_hidden, keys, encoder_mask):
        """Softmax of query·key scores; masked positions get exactly zero weight."""
        query = self.attention_query(decoder_hidden)          # (batch, 1, attention_dim)
        scores = torch.bmm(query, keys.transpose(1, 2))       # (batch, 1, seq_len)
        if encoder_mask is not None:
            scores = scores.masked_fill(~encoder_mask.unsqueeze(1), float('-inf'))
        return torch.softmax(scores, dim=-1)

    def _project_memory(self, encoder_outputs):
        """Key/value projections of the encoder outputs (constant across decoder steps)."""
        return self.attention_key(encoder_outputs), self.attention_value(encoder_outputs)

    def _initial_state(self, batch_size, device):
        """Zero (h_0, c_0) for the decoder LSTM."""
        h_0 = torch.zeros(self.lstm.num_layers, batch_size, self.hidden_dim, device=device)
        return h_0, torch.zeros_like(h_0)

    def _step(self, prev_frame, hidden, keys, values, encoder_mask):
        """One decoder step: attend with the top-layer hidden state, then run the LSTM."""
        weights = self._attention_weights(hidden[0][-1].unsqueeze(1), keys, encoder_mask)
        context = torch.bmm(weights, values)                  # (batch, 1, attention_dim)
        decoder_input = torch.cat([context, prev_frame], dim=-1)
        return self.lstm(decoder_input, hidden)

class PashtoTTSModel(nn.Module):
    """Complete TTS model: TTSEncoder + attention TTSDecoder."""

    def __init__(self, vocab_size, mel_dim=80, hidden_dim=512):
        super().__init__()

        self.encoder = TTSEncoder(vocab_size, hidden_dim=hidden_dim)
        self.decoder = TTSDecoder(hidden_dim=hidden_dim, mel_dim=mel_dim)

    def forward(self, text_sequences, text_lengths, mel_targets=None, max_length=1000):
        """Encode text and decode mel frames.

        With ``mel_targets`` the decoder uses teacher forcing; without, it runs
        free inference for at most ``max_length`` frames. Returns
        ``(mel_outputs, stop_logits)``.
        """
        encoder_outputs = self.encoder(text_sequences, text_lengths)

        # True at real characters, False at padding, so attention ignores padding.
        positions = torch.arange(encoder_outputs.size(1), device=encoder_outputs.device)
        lengths = text_lengths.to(encoder_outputs.device)
        encoder_mask = positions.unsqueeze(0) < lengths.unsqueeze(1)

        return self.decoder(
            encoder_outputs, mel_targets, max_length=max_length, encoder_mask=encoder_mask
        )

class TTSTrainer:
    """Train/validate a model returning ``(mel_outputs, stop_logits)``.

    Loss = ``mel_loss_weight * MSE(mel) + stop_loss_weight * BCEWithLogits(stop)``,
    computed against batches shaped like ``collate_fn`` output.
    """

    def __init__(self, model, device, lr=1e-3, mel_loss_weight=1.0, stop_loss_weight=0.5):
        self.model = model
        self.device = device
        self.mel_loss_weight = mel_loss_weight
        self.stop_loss_weight = stop_loss_weight

        # Adam with a small weight decay; LR halves every 10 scheduler steps.
        self.optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-6)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, step_size=10, gamma=0.5)

        # Loss functions
        self.mel_criterion = nn.MSELoss()
        self.stop_criterion = nn.BCEWithLogitsLoss()

        # Per-epoch loss history
        self.train_losses = []
        self.val_losses = []

    def _run_batch(self, batch):
        """Move a collated batch to the device, run the model, and return (loss, mel_loss, stop_loss)."""
        text_sequences = batch['text_sequences'].to(self.device)
        text_lengths = batch['text_lengths'].to(self.device)
        mel_targets = batch['mel_spectrograms'].to(self.device)
        mel_lengths = batch['mel_lengths'].to(self.device)

        mel_outputs, stop_outputs = self.model(text_sequences, text_lengths, mel_targets)

        mel_loss = self.compute_mel_loss(mel_outputs, mel_targets, mel_lengths)
        stop_loss = self.compute_stop_loss(stop_outputs, mel_lengths)
        loss = self.mel_loss_weight * mel_loss + self.stop_loss_weight * stop_loss
        return loss, mel_loss, stop_loss

    @staticmethod
    def _average(totals, num_batches):
        """Average accumulated totals; an empty loader yields zeros instead of dividing by zero."""
        denominator = max(num_batches, 1)
        return tuple(total / denominator for total in totals)

    def train_epoch(self, train_loader):
        """Train for one epoch; returns (avg_loss, avg_mel_loss, avg_stop_loss)."""
        self.model.train()
        totals = [0.0, 0.0, 0.0]

        for batch_idx, batch in enumerate(train_loader):
            loss, mel_loss, stop_loss = self._run_batch(batch)

            # Backward pass with gradient-norm clipping
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            for i, value in enumerate((loss, mel_loss, stop_loss)):
                totals[i] += value.item()

            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}/{len(train_loader)}, '
                      f'Loss: {loss.item():.4f}, '
                      f'Mel: {mel_loss.item():.4f}, '
                      f'Stop: {stop_loss.item():.4f}')

        return self._average(totals, len(train_loader))

    def validate(self, val_loader):
        """Evaluate without gradients; returns (avg_loss, avg_mel_loss, avg_stop_loss)."""
        self.model.eval()
        totals = [0.0, 0.0, 0.0]

        with torch.no_grad():
            for batch in val_loader:
                for i, value in enumerate(self._run_batch(batch)):
                    totals[i] += value.item()

        return self._average(totals, len(val_loader))

    def compute_mel_loss(self, outputs, targets, lengths):
        """Mean over the batch of per-item MSE on the first ``length`` frames.

        Padded frames beyond each item's length are excluded.
        """
        loss = 0
        for i, length in enumerate(lengths):
            loss += self.mel_criterion(outputs[i, :length], targets[i, :length])
        return loss / len(lengths)

    def compute_stop_loss(self, outputs, lengths):
        """BCE stop-token loss.

        The target is 0 before each item's last real frame and 1 from that frame
        onward (including padding), for every item. Previously the longest item
        (``length == max_len``) never got a positive target, and padded frames
        were labelled "don't stop".
        """
        max_len = outputs.size(1)
        frames = torch.arange(max_len, device=outputs.device)
        last_frame = lengths.to(outputs.device).unsqueeze(1) - 1
        stop_targets = (frames.unsqueeze(0) >= last_frame).to(outputs.dtype)
        # outputs is (batch, frames, 1) from the decoder; match its exact shape.
        stop_targets = stop_targets.view_as(outputs)
        return self.stop_criterion(outputs, stop_targets)

    def save_checkpoint(self, epoch, loss, path):
        """Save model/optimizer/scheduler state plus loss history to ``path``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'loss': loss,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses
        }
        torch.save(checkpoint, path)
        print(f"Checkpoint saved: {path}")

    def load_checkpoint(self, path):
        """Restore state written by ``save_checkpoint``; returns (epoch, loss)."""
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.train_losses = checkpoint.get('train_losses', [])
        self.val_losses = checkpoint.get('val_losses', [])
        print(f"Checkpoint loaded: {path}")
        return checkpoint['epoch'], checkpoint['loss']

def main(json_path=r"C:\Users\PC\Desktop\scirpts\json\new6.json",
         audio_path=r"C:\Users\PC\Downloads\AudioFiles",
         batch_size=8,
         learning_rate=1e-3,
         num_epochs=100,
         hidden_dim=512,
         output_dir="tts_checkpoints",
         num_workers=None):
    """Train the Pashto TTS model end to end and write checkpoints/plots to ``output_dir``.

    Every argument defaults to the value this script previously hard-coded, so
    ``main()`` behaves as before; pass overrides to train on other data or settings.

    Args:
        json_path: JSON manifest consumed by ``PashtoTTSDataset``.
        audio_path: directory the manifest's ``file`` entries are relative to.
        batch_size, learning_rate, num_epochs: optimisation settings.
        hidden_dim: model width; also recorded in the final model config.
        output_dir: directory for checkpoints, the loss plot and the final model.
        num_workers: DataLoader worker processes. ``None`` picks 0 on Windows and
            2 elsewhere: Windows starts workers with ``spawn``, which must
            re-import ``collate_fn`` by name and fails when it was defined in a
            notebook cell.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if num_workers is None:
        num_workers = 0 if os.name == 'nt' else 2

    print(f"Using device: {device}")

    # Initialize processors
    text_processor = PashtoTextProcessor()
    audio_processor = AudioProcessor(sample_rate=22050, n_mels=80)

    print(f"Vocabulary size: {text_processor.vocab_size}")

    # Dataset and a fixed 90/10 train/validation split
    dataset = PashtoTTSDataset(json_path, audio_path, text_processor, audio_processor)
    train_indices, val_indices = train_test_split(
        range(len(dataset)), test_size=0.1, random_state=42
    )
    train_dataset = torch.utils.data.Subset(dataset, train_indices)
    val_dataset = torch.utils.data.Subset(dataset, val_indices)

    loader_kwargs = dict(batch_size=batch_size, collate_fn=collate_fn, num_workers=num_workers)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Model and trainer
    model = PashtoTTSModel(
        vocab_size=text_processor.vocab_size,
        mel_dim=audio_processor.n_mels,
        hidden_dim=hidden_dim
    ).to(device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    trainer = TTSTrainer(model, device, lr=learning_rate)

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 50)

        train_loss, train_mel_loss, train_stop_loss = trainer.train_epoch(train_loader)
        val_loss, val_mel_loss, val_stop_loss = trainer.validate(val_loader)

        # StepLR is stepped once per epoch
        trainer.scheduler.step()

        trainer.train_losses.append(train_loss)
        trainer.val_losses.append(val_loss)

        print(f"Train Loss: {train_loss:.4f} (Mel: {train_mel_loss:.4f}, Stop: {train_stop_loss:.4f})")
        print(f"Val Loss: {val_loss:.4f} (Mel: {val_mel_loss:.4f}, Stop: {val_stop_loss:.4f})")
        print(f"Learning Rate: {trainer.optimizer.param_groups[0]['lr']:.6f}")

        # Keep the best model by validation loss
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            trainer.save_checkpoint(epoch, val_loss, output_dir / "best_model.pth")

        # Periodic checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            trainer.save_checkpoint(
                epoch, val_loss, output_dir / f"checkpoint_epoch_{epoch+1}.pth"
            )

        # Refresh the loss plot every 5 epochs; close the figure to avoid leaking memory
        if (epoch + 1) % 5 == 0:
            fig, ax = plt.subplots(figsize=(10, 5))
            ax.plot(trainer.train_losses, label='Train Loss')
            ax.plot(trainer.val_losses, label='Validation Loss')
            ax.set_xlabel('Epoch')
            ax.set_ylabel('Loss')
            ax.set_title('Training Progress')
            ax.legend()
            ax.grid(True)
            fig.savefig(output_dir / 'training_progress.png')
            plt.close(fig)

    print("Training completed!")

    # The processors are pickled whole so inference can reuse the same vocabulary
    # and audio settings; loading this file therefore needs weights_only=False.
    torch.save({
        'model_state_dict': model.state_dict(),
        'text_processor': text_processor,
        'audio_processor': audio_processor,
        'model_config': {
            'vocab_size': text_processor.vocab_size,
            'mel_dim': audio_processor.n_mels,
            'hidden_dim': hidden_dim
        }
    }, output_dir / 'final_model.pth')

    print(f"Final model saved to: {output_dir / 'final_model.pth'}")

if __name__ == '__main__':
    # Running this file/cell directly starts a full training run.
    main()

# Function to test the trained model
def test_model(model_path, text_input):
    """Generate a mel spectrogram for ``text_input`` from a saved final checkpoint.

    Args:
        model_path: checkpoint like the ``final_model.pth`` written by ``main``,
            holding model weights, ``model_config``, and the pickled text/audio
            processors.
        text_input: Pashto text to synthesise.

    Returns:
        numpy array of predicted mel frames, shape (frames, n_mels), in the same
        normalised log-mel space the model was trained on. No waveform is
        produced; a vocoder is needed for that.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # weights_only=False is required: the checkpoint pickles processor objects,
    # which torch>=2.6 refuses to load by default. Unpickling can execute
    # arbitrary code, so only load checkpoints you produced or trust.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    text_processor = checkpoint['text_processor']

    model = PashtoTTSModel(**checkpoint['model_config']).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    # Single-item batch
    text_sequence = text_processor.text_to_sequence(text_input)
    text_tensor = torch.tensor([text_sequence], dtype=torch.long).to(device)
    text_lengths = torch.tensor([len(text_sequence)]).to(device)

    # No targets -> free-running inference
    with torch.no_grad():
        mel_outputs, stop_outputs = model(text_tensor, text_lengths)

    mel_spec = mel_outputs.squeeze(0).cpu().numpy()

    print(f"Generated mel spectrogram shape: {mel_spec.shape}")
    print(f"Input text: {text_input}")

    return mel_spec

# Additional utility functions for inference and vocoder integration

class MelToAudioConverter:
    """Convert log-mel spectrograms to waveforms: inverse mel filterbank + Griffin-Lim."""

    def __init__(self, sample_rate=22050, n_fft=1024, hop_length=256, win_length=1024, n_iter=60,
                 n_mels=80):
        """
        Args:
            sample_rate, n_fft, hop_length, win_length: STFT settings; should match
                the analysis settings used to compute the mels.
            n_iter: Griffin-Lim iterations.
            n_mels: number of mel bins (previously hard-coded to 80).
        """
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_iter = n_iter
        self.n_mels = n_mels

        n_stft = n_fft // 2 + 1
        # Forward mel scale (not used by mel_to_audio).
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels=n_mels, sample_rate=sample_rate, n_stft=n_stft
        )
        self.inverse_mel_scale = torchaudio.transforms.InverseMelScale(
            n_stft=n_stft, n_mels=n_mels, sample_rate=sample_rate
        )
        self.griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=n_fft, hop_length=hop_length, win_length=win_length, n_iter=n_iter
        )

    def mel_to_audio(self, mel_spectrogram):
        """Convert a (time, n_mels) natural-log mel spectrogram to a numpy waveform.

        Input may be a numpy array or tensor of any float dtype/device. It is
        cast to float32 on CPU to match the transforms' filterbanks.

        NOTE(review): mels from PashtoTTSModel are per-utterance z-score
        normalised (AudioProcessor.normalize_mel), so exp() here does not
        recover the true power scale. Denormalise first for meaningful audio.
        """
        log_mel = torch.as_tensor(mel_spectrogram, dtype=torch.float32).cpu()
        mel_spec = torch.exp(log_mel.T)  # (n_mels, time)

        # Mel -> linear-frequency spectrogram
        linear_spec = self.inverse_mel_scale(mel_spec)

        # Phase reconstruction
        waveform = self.griffin_lim(linear_spec)

        return waveform.detach().cpu().numpy()

def create_inference_pipeline():
    """Build and return the ``PashtoTTSInference`` class (text -> waveform)."""

    class PashtoTTSInference:
        """Text-to-speech from a checkpoint written by ``main`` (``final_model.pth``)."""

        def __init__(self, model_path):
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

            # weights_only=False: the checkpoint pickles processor objects, which
            # torch>=2.6 refuses by default. Only load checkpoints you trust.
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

            self.text_processor = checkpoint['text_processor']
            self.audio_processor = checkpoint['audio_processor']

            # Reconstruct with the same STFT settings the training mels used.
            ap = self.audio_processor
            self.mel_converter = MelToAudioConverter(
                sample_rate=ap.sample_rate,
                n_fft=ap.n_fft,
                hop_length=ap.hop_length,
                win_length=ap.win_length,
            )

            self.model = PashtoTTSModel(**checkpoint['model_config']).to(self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()

            print("TTS model loaded successfully!")

        def synthesize(self, text, output_path=None):
            """Convert ``text`` to speech; returns ``(audio, mel_spec)`` as numpy arrays.

            If ``output_path`` is given, the audio is also written there.
            """
            text_sequence = self.text_processor.text_to_sequence(text)
            text_tensor = torch.tensor([text_sequence], dtype=torch.long).to(self.device)
            text_lengths = torch.tensor([len(text_sequence)]).to(self.device)

            with torch.no_grad():
                mel_outputs, stop_outputs = self.model(text_tensor, text_lengths)

            mel_spec = mel_outputs.squeeze(0).cpu().numpy()
            audio = self.mel_converter.mel_to_audio(mel_spec)

            if output_path:
                # str(): older torchaudio backends do not accept Path objects.
                torchaudio.save(
                    str(output_path),
                    torch.as_tensor(audio).unsqueeze(0),
                    self.audio_processor.sample_rate,
                )
                print(f"Audio saved to: {output_path}")

            return audio, mel_spec

        def batch_synthesize(self, texts, output_dir):
            """Synthesize each text to ``output_dir/synthesis_<n>.wav``.

            Returns the list of ``(audio, mel_spec)`` results in input order.
            """
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)

            results = []
            for i, text in enumerate(texts):
                output_path = output_dir / f"synthesis_{i+1}.wav"
                results.append(self.synthesize(text, output_path))
                print(f"Synthesized: {text[:50]}...")
            return results

    return PashtoTTSInference

# Advanced training features

class AdvancedTTSTrainer(TTSTrainer):
    """TTSTrainer with optional CUDA mixed-precision training and early stopping."""

    def __init__(self, model, device, lr=1e-3, mel_loss_weight=1.0, stop_loss_weight=0.5):
        super().__init__(model, device, lr, mel_loss_weight, stop_loss_weight)

        # Placeholder for attention visualisations; nothing in this class fills it yet.
        self.attention_plots = []

        # Gradient scaling only on CUDA. torch.device() normalises plain strings
        # such as 'cuda', which previously crashed on `.type`.
        self.scaler = torch.cuda.amp.GradScaler() if torch.device(device).type == 'cuda' else None

        # Early stopping state
        self.patience = 15
        self.best_val_loss = float('inf')
        self.patience_counter = 0

    def _amp_loss(self, batch):
        """Move ``batch`` to the device, run the model, and return the weighted total loss."""
        text_sequences = batch['text_sequences'].to(self.device)
        text_lengths = batch['text_lengths'].to(self.device)
        mel_targets = batch['mel_spectrograms'].to(self.device)
        mel_lengths = batch['mel_lengths'].to(self.device)

        mel_outputs, stop_outputs = self.model(text_sequences, text_lengths, mel_targets)
        mel_loss = self.compute_mel_loss(mel_outputs, mel_targets, mel_lengths)
        stop_loss = self.compute_stop_loss(stop_outputs, mel_lengths)
        return self.mel_loss_weight * mel_loss + self.stop_loss_weight * stop_loss

    def train_epoch_mixed_precision(self, train_loader):
        """Train one epoch, using autocast + GradScaler when on CUDA.

        Returns the mean total loss (0.0 for an empty loader).
        """
        self.model.train()
        total_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            self.optimizer.zero_grad()

            if self.scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = self._amp_loss(batch)

                self.scaler.scale(loss).backward()
                # Unscale before clipping so max_norm applies to the true gradients.
                self.scaler.unscale_(self.optimizer)
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                loss = self._amp_loss(batch)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

        return total_loss / max(len(train_loader), 1)

    def check_early_stopping(self, val_loss):
        """Return True once ``patience`` consecutive calls fail to improve the best val loss."""
        if val_loss < self.best_val_loss:
            self.best_val_loss = val_loss
            self.patience_counter = 0
            return False

        self.patience_counter += 1
        if self.patience_counter >= self.patience:
            print(f"Early stopping triggered after {self.patience} epochs without improvement")
            return True
        return False

# Data augmentation for better training

class TTSDataAugmenter:
    """Waveform-level data augmentation helpers for TTS training.

    Every method takes a waveform tensor and returns a new tensor; the input
    tensor is not modified in place.
    """

    def __init__(self, sample_rate=22050):
        # Sample rate in Hz; only pitch_shift uses it (passed on to librosa).
        self.sample_rate = sample_rate

    def add_noise(self, waveform, noise_factor=0.005):
        """Return ``waveform`` plus zero-mean Gaussian noise with std ``noise_factor``."""
        noise = torch.randn_like(waveform) * noise_factor
        return waveform + noise

    def time_stretch(self, waveform, stretch_factor=None):
        """Resample ``waveform`` along its last axis to change its duration.

        The output length is ``int(T / stretch_factor)`` (at least 1), so a
        factor > 1 shortens the signal and a factor < 1 lengthens it. This is
        plain linear interpolation, i.e. resampling, so pitch changes together
        with speed. It is NOT a pitch-preserving time stretch.

        Args:
            waveform: floating-point tensor whose last dimension is time. Any
                leading (e.g. channel) dimensions are preserved.
            stretch_factor: positive speed factor; drawn uniformly from
                [0.9, 1.1) when None.

        Returns:
            Tensor with the same leading dimensions as ``waveform`` and a
            resized last dimension.

        Raises:
            ValueError: if ``stretch_factor`` is not positive.
        """
        if stretch_factor is None:
            stretch_factor = np.random.uniform(0.9, 1.1)
        if stretch_factor <= 0:
            raise ValueError(f"stretch_factor must be positive, got {stretch_factor}")

        original_length = waveform.shape[-1]
        # Never ask interpolate() for a zero-length output.
        new_length = max(1, int(original_length / stretch_factor))

        # interpolate(mode='linear') only accepts 3-D (batch, channel, time)
        # input, so fold every leading dimension into the batch axis and
        # restore the original leading shape afterwards.
        leading_shape = waveform.shape[:-1]
        stretched = torch.nn.functional.interpolate(
            waveform.reshape(-1, 1, original_length),
            size=new_length,
            mode='linear',
            align_corners=True,
        )
        return stretched.reshape(*leading_shape, new_length)

    def pitch_shift(self, waveform, n_steps=None):
        """Shift the pitch of ``waveform`` by ``n_steps`` semitones via librosa.

        Args:
            waveform: audio tensor. It is detached and copied to CPU before
                being handed to ``librosa.effects.pitch_shift``.
            n_steps: semitone shift; drawn uniformly from [-2, 2) when None.

        Returns:
            Pitch-shifted tensor placed on the same device as ``waveform``.
        """
        if n_steps is None:
            n_steps = np.random.uniform(-2, 2)  # ±2 semitones

        # Tensor.numpy() raises for CUDA tensors and for tensors that require
        # grad, so detach and move to CPU first.
        audio_np = waveform.detach().cpu().numpy()
        shifted = librosa.effects.pitch_shift(audio_np, sr=self.sample_rate, n_steps=n_steps)
        return torch.tensor(shifted, device=waveform.device)

# Model evaluation metrics

class TTSEvaluator:
    """Evaluation metrics for a TTS model.

    Runs the model on individual dataset samples and reports the mean-squared
    error between predicted and target mel spectrograms.
    """

    def __init__(self, model, text_processor, audio_processor, device):
        self.model = model
        # Stored for callers; the methods of this class do not use them.
        self.text_processor = text_processor
        self.audio_processor = audio_processor
        self.device = device

    def compute_mel_distance(self, predicted_mel, target_mel):
        """Return the mean-squared error between two mel tensors as a Python float."""
        mse = torch.nn.functional.mse_loss(predicted_mel, target_mel)
        return mse.item()

    def evaluate_dataset(self, dataset, num_samples=None):
        """Return the average mel MSE of the model over (a subset of) ``dataset``.

        Each sample is run through the model on its own, without teacher
        forcing. The prediction and the target are truncated to their common
        length along dim 0 before comparison. NOTE(review): this assumes dim 0
        of both mels is time; confirm against the dataset's mel layout.

        Args:
            dataset: indexable collection whose items are dicts with tensor
                entries ``'text_sequence'`` and ``'mel_spectrogram'``.
            num_samples: if not None, evaluate this many randomly chosen
                samples without replacement (capped at ``len(dataset)``).
                ``0`` evaluates nothing. ``None`` evaluates every sample.

        Returns:
            Mean MSE over samples with a non-empty overlap, or ``float('inf')``
            if no such sample was evaluated.
        """
        # Remember the caller's mode so evaluation does not leave a training
        # model stuck in eval mode.
        was_training = self.model.training
        self.model.eval()

        if num_samples is not None:
            indices = np.random.choice(len(dataset), min(num_samples, len(dataset)), replace=False)
        else:
            indices = range(len(dataset))

        total_mel_distance = 0.0
        valid_samples = 0

        try:
            with torch.no_grad():
                for idx in indices:
                    # np.random.choice yields numpy ints; some datasets want int.
                    sample = dataset[int(idx)]

                    # Prepare a batch of one.
                    text_seq = sample['text_sequence'].unsqueeze(0).to(self.device)
                    text_len = torch.tensor([len(sample['text_sequence'])]).to(self.device)
                    target_mel = sample['mel_spectrogram'].to(self.device)

                    # Generate prediction and drop the batch dimension.
                    pred_mel, _ = self.model(text_seq, text_len)
                    pred_mel = pred_mel.squeeze(0)

                    min_len = min(pred_mel.shape[0], target_mel.shape[0])
                    if min_len == 0:
                        # MSE over empty tensors is NaN and would poison the mean.
                        continue

                    total_mel_distance += self.compute_mel_distance(
                        pred_mel[:min_len], target_mel[:min_len]
                    )
                    valid_samples += 1
        finally:
            self.model.train(was_training)

        return total_mel_distance / valid_samples if valid_samples > 0 else float('inf')

# Usage example and testing

def run_complete_training(config_overrides=None):
    """Run the full Pashto TTS pipeline: split data, train, save, and smoke-test.

    Args:
        config_overrides: optional dict whose entries replace the defaults in
            ``config`` below (e.g. ``{'num_epochs': 5}``). ``None`` keeps the
            defaults, so existing no-argument calls behave as before.

    Side effects:
        Writes a checkpoint every 5 epochs and ``final_pashto_tts_model.pth``
        under ``pashto_tts_output/``, then synthesizes the sample sentences
        into ``pashto_tts_output/test_outputs``.

    Raises:
        ValueError: if the dataset is too small to give non-empty train and
            validation splits (training on an empty loader would divide by
            zero when averaging the epoch loss).
    """

    # Configuration
    config = {
        'json_path': r"C:\Users\PC\Desktop\scirpts\json\new6.json",
        'audio_path': r"C:\Users\PC\Downloads\AudioFiles",
        'batch_size': 16,
        'learning_rate': 1e-3,
        'num_epochs': 50,
        'hidden_dim': 512,
        'mel_dim': 80,
        'sample_rate': 22050,
        'use_mixed_precision': True,
        'early_stopping': True,
        # Seed for the train/val/test split so runs are comparable.
        'seed': 42,
        # Save the best-validating weights as the final model instead of the
        # weights from the last epoch that ran.
        'restore_best_weights': True,
    }
    if config_overrides:
        config.update(config_overrides)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on: {device}")

    # Initialize components
    text_processor = PashtoTextProcessor()
    audio_processor = AudioProcessor(
        sample_rate=config['sample_rate'],
        n_mels=config['mel_dim']
    )

    # Create the dataset (no augmentation is applied here).
    dataset = PashtoTTSDataset(
        config['json_path'],
        config['audio_path'],
        text_processor,
        audio_processor
    )

    # Split data 80/10/10.
    train_size = int(0.8 * len(dataset))
    val_size = int(0.1 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    if train_size == 0 or val_size == 0:
        raise ValueError(
            f"Dataset has {len(dataset)} samples; at least 10 are needed for "
            "non-empty train/validation splits"
        )

    # A seeded generator makes the split reproducible across runs.
    split_generator = torch.Generator().manual_seed(config['seed'])
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size], generator=split_generator
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0  # Set to 0 for Windows compatibility
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0
    )

    # Initialize model
    model = PashtoTTSModel(
        vocab_size=text_processor.vocab_size,
        mel_dim=config['mel_dim'],
        hidden_dim=config['hidden_dim']
    ).to(device)

    # Initialize trainer
    trainer = AdvancedTTSTrainer(model, device, lr=config['learning_rate'])

    # Training loop
    output_dir = Path("pashto_tts_output")
    output_dir.mkdir(exist_ok=True)

    best_val_loss = float('inf')
    best_epoch = None
    best_state = None

    print("Starting training...")
    for epoch in range(config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")

        # Train (the two trainer methods return different shapes).
        if config['use_mixed_precision']:
            train_loss = trainer.train_epoch_mixed_precision(train_loader)
        else:
            train_loss, _, _ = trainer.train_epoch(train_loader)

        # Validate
        val_loss, _, _ = trainer.validate(val_loader)

        print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")

        # Keep a CPU copy of the best weights. With early stopping, the last
        # epoch trained comes `patience` epochs after the best one.
        if config['restore_best_weights'] and val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch + 1
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }

        # Save checkpoint
        if (epoch + 1) % 5 == 0:
            trainer.save_checkpoint(
                epoch, val_loss, output_dir / f"checkpoint_epoch_{epoch+1}.pth"
            )

        # Early stopping
        if config['early_stopping'] and trainer.check_early_stopping(val_loss):
            break

        trainer.scheduler.step()

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored best weights from epoch {best_epoch} (val loss {best_val_loss:.4f})")

    # Save final model
    final_model_path = output_dir / "final_pashto_tts_model.pth"
    torch.save({
        'model_state_dict': model.state_dict(),
        'text_processor': text_processor,
        'audio_processor': audio_processor,
        'model_config': {
            'vocab_size': text_processor.vocab_size,
            'mel_dim': config['mel_dim'],
            'hidden_dim': config['hidden_dim']
        },
        'config': config,
        'best_epoch': best_epoch,
    }, final_model_path)

    print(f"Training completed! Model saved to: {final_model_path}")

    # Test the model
    print("\nTesting model...")
    inference_pipeline = create_inference_pipeline()
    tts_model = inference_pipeline(final_model_path)

    # Test with sample Pashto text
    test_texts = [
        "انګېزه د شیانو علت، سبب او رېښې ته وایي.",
        "افغانان د انیس په نوم مجله خپروي.",
        "اورکی د اور د بلېدو لامل کېږي."
    ]

    test_output_dir = output_dir / "test_outputs"
    tts_model.batch_synthesize(test_texts, test_output_dir)

    print("Testing completed!")

# Quick start function
def quick_start():
    """Run ``run_complete_training`` with its default configuration.

    Any exception raised during the run is caught so the script ends with a
    friendly hint. The full traceback is printed as well, so the actual
    failure point is not lost.
    """
    import traceback

    print("Starting Pashto TTS training...")
    print("Make sure your paths are correct:")
    print("JSON: C:\\Users\\PC\\Desktop\\scirpts\\json\\new6.json")
    print("Audio: C:\\Users\\PC\\Downloads\\AudioFiles")

    try:
        run_complete_training()
    except Exception as e:
        print(f"Error during training: {e}")
        # str(e) alone hides where the failure happened (data loading, model,
        # CUDA, inference...), so show the traceback too.
        traceback.print_exc()
        print("Check your file paths and make sure audio files exist!")

# Run the training
if __name__ == '__main__':
    # Only start training when executed as a script, not on import.
    quick_start()

Using device: cuda
Vocabulary size: 50
Validating dataset...
Valid samples: 9753 out of 10000
Training samples: 8777
Validation samples: 976
Model parameters: 13,809,617

Epoch 1/100
--------------------------------------------------
