# Khmer Text-to-Speech Implementation

## Complete Pipeline from Data Analysis to Inference

## Overview

This notebook implements a complete Khmer TTS system using NVIDIA's Tacotron2 architecture. We'll use transfer learning from a pre-trained English model to handle the low-resource nature of Khmer language.

In [1]:
output_path = '../dataset'
print(output_path)

../dataset


## Setup and Dependencies

In [2]:
import torch
import torch.nn as nn
import torchaudio
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.notebook import tqdm
from unidecode import unidecode
import re
import librosa
import soundfile as sf
import logging
import torch.nn.functional as F

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Set random seeds
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

Using device: cuda


## 1. Exploratory Data Analysis (EDA)

### 1.1 Dataset Loading

In [3]:
def load_transcript(file_path):
    """Load and analyze transcript file"""
    df = pd.read_csv(file_path,
                    sep='\t',
                    header=None,
                    names=['audio_file', 'unused', 'text'],
                    encoding='utf-8')
    return df[['audio_file', 'text']]

# # Load dataset
# transcript_path = f"{output_path}/line_index.tsv"
# df = load_transcript(transcript_path)
# print(f"Total samples: {len(df)}")

### 1.2 Audio Analysis

In [4]:
def analyze_audio_files(df, audio_dir):
    """Analyze audio file properties"""
    durations = []
    sample_rates = []
    file_sizes = []

    for file in tqdm(df['audio_file'], desc="Analyzing audio files"):
        path = Path(audio_dir) / f"{file}.wav"
        if path.exists():
            # Get audio info
            info = torchaudio.info(path)
            duration = info.num_frames / info.sample_rate

            durations.append(duration)
            sample_rates.append(info.sample_rate)
            file_sizes.append(path.stat().st_size / 1024)  # Size in KB

    return durations, sample_rates, file_sizes

# Plot distributions
def plot_audio_stats(durations, sample_rates, file_sizes):
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Duration distribution
    sns.histplot(durations, bins=50, ax=axes[0])
    axes[0].set_title('Audio Duration Distribution')
    axes[0].set_xlabel('Duration (seconds)')

    # Sample rate distribution
    sns.countplot(sample_rates, ax=axes[1])
    axes[1].set_title('Sample Rate Distribution')

    # File size distribution
    sns.histplot(file_sizes, bins=50, ax=axes[2])
    axes[2].set_title('File Size Distribution')
    axes[2].set_xlabel('Size (KB)')

    plt.tight_layout()
    plt.show()

### 1.3 Text Analysis

In [5]:
def analyze_text(df):
    """Analyze text properties"""
    # Text length distribution
    text_lengths = df['text'].str.len()

    # Character distribution
    char_freq = pd.Series(list(''.join(df['text']))).value_counts()

    # Plot distributions
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    sns.histplot(text_lengths, bins=50, ax=ax1)
    ax1.set_title('Text Length Distribution')
    ax1.set_xlabel('Number of Characters')

    char_freq.head(20).plot(kind='bar', ax=ax2)
    ax2.set_title('Top 20 Character Frequencies')
    ax2.tick_params(axis='x', rotation=45)

    plt.tight_layout()
    plt.show()

    return text_lengths, char_freq

## 2. Data Processing

### 2.1 Text Processing

In [6]:
class TextProcessor:
    """Text processor for NVIDIA Tacotron2 compatibility"""

    def __init__(self):
        # NVIDIA Tacotron2 symbol set
        self._pad = '_'
        self._punctuation = '!\'(),.:;? '
        self._special = '-'
        self._letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

        # Create symbol mappings
        symbols = self._pad + self._special + self._punctuation + self._letters
        self._symbol_to_id = {s: i for i, s in enumerate(symbols)}

    def process_text(self, text):
        # Romanize Khmer text
        text = unidecode(text).lower()

        # Clean text
        text = re.sub(r'[^a-z!\'(),.:;? -]', '', text)

        # Convert to IDs
        sequence = [self._symbol_to_id[s] for s in text if s in self._symbol_to_id]

        return torch.LongTensor(sequence)

### 2.2 Audio Processing

In [7]:
class AudioProcessor:
    def __init__(self, sample_rate=22050):
        self.sample_rate = sample_rate

    def load_audio(self, file_path):
        """Load and preprocess audio file"""
        waveform, sr = torchaudio.load(file_path)

        # Convert to mono if stereo
        if waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Resample if necessary
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)

        return waveform

class AudioDataset(torch.utils.data.Dataset):
    def __init__(self, df, audio_dir, text_processor, audio_processor):
        self.df = df
        self.audio_dir = Path(audio_dir)
        self.text_processor = text_processor
        self.audio_processor = audio_processor

        # Add mel spectrogram transformer
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=22050,
            n_fft=1024,
            hop_length=256,
            n_mels=80
        )

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        # Process text
        text = self.text_processor.process_text(row['text'])

        # Load audio
        audio_path = self.audio_dir / f"{row['audio_file']}.wav"
        waveform = self.audio_processor.load_audio(audio_path)

        # Generate mel spectrogram
        mel_spec = self.mel_transform(waveform)
        mel_spec = torch.log(torch.clamp(mel_spec, min=1e-5))

        return {
            'text': text,
            'text_lengths': torch.LongTensor([len(text)]),
            'mel_target': mel_spec,
            'mel_lengths': torch.LongTensor([mel_spec.size(2)])
        }

In [8]:
def collate_fn(batch):
    """
    Custom collate function to match NVIDIA Tacotron2 input format
    """
    # Sort batch by text length (descending)
    batch = sorted(batch, key=lambda x: x['text_lengths'][0], reverse=True)

    # Get max lengths
    max_text_len = max([item['text'].size(0) for item in batch])
    max_mel_len = max([item['mel_target'].size(2) for item in batch])

    # Initialize padded tensors
    text_padded = torch.zeros(len(batch), max_text_len).long()
    mel_padded = torch.zeros(len(batch), 80, max_mel_len)

    # Fill padded tensors
    text_lengths = []
    mel_lengths = []

    for i, item in enumerate(batch):
        text = item['text']
        mel = item['mel_target']

        text_padded[i, :text.size(0)] = text
        mel_padded[i, :, :mel.size(2)] = mel

        text_lengths.append(item['text_lengths'][0])
        mel_lengths.append(item['mel_lengths'][0])

    # Convert lengths to tensors
    text_lengths = torch.LongTensor(text_lengths)
    mel_lengths = torch.LongTensor(mel_lengths)

    # Return in NVIDIA Tacotron2 format: (inputs, input_lengths, targets, max_len, output_lengths)
    return (text_padded, text_lengths, mel_padded, max_mel_len, mel_lengths)

## 3. Model Training

### 3.1 Load Pre-trained Model

In [9]:
def load_pretrained_model():
    """Load NVIDIA Tacotron2 model"""
    model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub',
                          'nvidia_tacotron2',
                          model_math='fp16',
                          pretrained=False)
    return model.to(device)

# Load HiFi-GAN vocoder
def load_vocoder():
    """Load Waveglow vocoder"""
    try:
        # Load Waveglow
        waveglow = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub',
                               'nvidia_waveglow',
                               model_math='fp16')

        # No need to call remove_weightnorm for Waveglow
        waveglow = waveglow.to(device)
        waveglow.eval()

        return waveglow

    except Exception as e:
        print(f"Error loading Waveglow: {str(e)}")
        raise

### 3.2 Training Loop

In [10]:
import traceback
import os


def compute_metrics(mel_outputs, mel_targets, alignments, input_lengths, output_lengths):
    """
    Compute various TTS evaluation metrics with proper length handling
    """
    metrics = {}
    
    # For each sample in batch
    batch_size = mel_outputs.size(0)
    mcd_values = []
    frame_mse_values = []
    duration_ratios = []
    
    for i in range(batch_size):
        # Get actual lengths for this sample
        target_len = output_lengths[i]
        
        # Compute metrics on valid lengths
        mel_output = mel_outputs[i, :, :target_len]
        mel_target = mel_targets[i, :, :target_len]
        
        # MCD and Frame MSE for this sample
        mcd = F.mse_loss(mel_output, mel_target, reduction='mean')
        mcd_values.append(mcd.item())
        
        frame_mse = torch.mean((mel_output - mel_target) ** 2, dim=0).mean()
        frame_mse_values.append(frame_mse.item())
        
        # Duration ratio for this sample
        pred_duration = mel_output.size(-1)
        target_duration = target_len.item()
        duration_ratios.append(pred_duration / target_duration)
    
    # Average metrics across batch
    metrics['mcd'] = np.mean(mcd_values)
    metrics['frame_mse'] = np.mean(frame_mse_values)
    metrics['duration_ratio'] = np.mean(duration_ratios)
    
    # Attention metrics
    # Coverage: how much of input received attention
    coverage = torch.sum(alignments > 0.5, dim=2).float() / alignments.size(2)
    metrics['attention_coverage'] = coverage.mean().item()
    
    # Entropy: measure of attention sharpness
    entropy = -torch.sum(alignments * torch.log(alignments + 1e-8), dim=2) / alignments.size(2)
    metrics['attention_entropy'] = entropy.mean().item()
    
    return metrics

def train_model(model, text_processor, train_loader, optimizer, num_epochs=100, checkpoint_path=None):
    """Train Tacotron2 model with checkpoint training loop"""
    # Initialize training state
    start_epoch = 0
    best_loss = float('inf')

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=5, verbose=True
    )

    patience_counter = 0
    max_patience = 10
    
    history = {
        'train_losses': [],
        'learning_rates': [],
        'gradients': [],
        'mcd_values': [],
        'frame_mse_values': [],
        'attention_coverage': [],
        'attention_entropy': [],
        'duration_ratios': []
    }

    # Load checkpoint if available
    if checkpoint_path and os.path.exists(checkpoint_path):
        print(f"Loading checkpoint: {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=device)

        # Load model and optimizer state
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        print(checkpoint)

        # Load traning state
        start_epoch = checkpoint['epoch'] + 1
        
        # Get best loss from metrics
        if 'metrics' in checkpoint:
            best_loss = checkpoint['metrics']['loss']
            print(f"Loaded best loss from metrics: {best_loss:.4f}")

        if 'history' in checkpoint:
            history = checkpoint['history']
        
        print(f"Checkpoint loaded. Resuming training from epoch: {start_epoch}")

    for epoch in range(start_epoch, start_epoch + num_epochs):
        model.train()
        epoch_metrics = {
            'losses': [],
            'mcd': [],
            'frame_mse': [],
            'attention_coverage': [],
            'attention_entropy': [],
            'duration_ratios': [],
            'gradients': []
        }

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        for batch in progress_bar:
            optimizer.zero_grad()

            # Move batch to device
            inputs, input_lengths, targets, max_len, output_lengths = batch
            inputs = inputs.to(device)
            input_lengths = input_lengths.to(device)
            targets = targets.to(device)
            output_lengths = output_lengths.to(device)

            # Forward pass
            outputs = model((inputs, input_lengths, targets, max_len, output_lengths))
            mel_outputs, mel_outputs_postnet, gate_outputs, alignments = outputs

            # Calculate loss with L2 regularization
            mel_loss = F.mse_loss(mel_outputs, targets) + F.mse_loss(mel_outputs_postnet, targets)
            
            # Gate loss for stop token prediction
            gate_loss = F.binary_cross_entropy_with_logits(
                gate_outputs, 
                torch.zeros_like(gate_outputs)
            )
            
            # Total loss
            l2_lambda = 1e-6
            l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
            loss = mel_loss + gate_loss + l2_lambda * l2_norm

            # Backward pass
            loss.backward()

            # Gradient clipping
            grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            epoch_metrics['gradients'].append(grad_norm.item())

            optimizer.step()

            # Compute additional metrics
            metrics = compute_metrics(
                mel_outputs, 
                targets, 
                alignments,
                input_lengths,  # From the batch
                output_lengths  # From the batch
            )
            
            # Store metrics
            epoch_metrics['losses'].append(loss.item())
            epoch_metrics['mcd'].append(metrics['mcd'])
            epoch_metrics['frame_mse'].append(metrics['frame_mse'])
            epoch_metrics['attention_coverage'].append(metrics['attention_coverage'])
            epoch_metrics['attention_entropy'].append(metrics['attention_entropy'])
            epoch_metrics['duration_ratios'].append(metrics['duration_ratio'])
            
            # Update progress bar
            progress_bar.set_postfix({
                'loss': f"{loss.item():.4f}",
                'mcd': f"{metrics['mcd']:.4f}",
                'lr': f"{optimizer.param_groups[0]['lr']:.6f}"
            })

        # Calculate epoch averages
        avg_metrics = {
            'loss': np.mean(epoch_metrics['losses']),
            'mcd': np.mean(epoch_metrics['mcd']),
            'frame_mse': np.mean(epoch_metrics['frame_mse']),
            'attention_coverage': np.mean(epoch_metrics['attention_coverage']),
            'attention_entropy': np.mean(epoch_metrics['attention_entropy']),
            'duration_ratio': np.mean(epoch_metrics['duration_ratios']),
            'gradient': np.mean(epoch_metrics['gradients'])
        }

        # Update history
        for key, value in avg_metrics.items():
            if key + 's' in history:  # Add 's' for plural in history keys
                history[key + 's'].append(value)

        # Print epoch summary
        print(f"\nEpoch {epoch+1}/{num_epochs} Summary:")
        for metric, value in avg_metrics.items():
            print(f"{metric.replace('_', ' ').title()}: {value:.4f}")

        # Learning rate scheduling based on MCD
        scheduler.step(avg_metrics['mcd'])

        # Save checkpoint if loss improved
        if avg_metrics['loss'] < best_loss:
            best_loss = avg_metrics['loss']
            patience_counter = 0
            
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'metrics': avg_metrics,
                'history': history,
                'best_loss': best_loss
            }, f'{output_path}/models/best_model.pt')
            
            print("\nNew best model saved!")
        else:
            patience_counter += 1
        
        # For Debug Only
        if epoch == 0:
            generate_test_sample(model, text_processor, f"{output_path}/samples/test_epoch_{epoch+1}.wav")

        # Regular checkpoint and evaluation
        if (epoch + 1) % 10 == 0:
            # Save checkpoint
            checkpoint_path = f'{output_path}/models/checkpoint_epoch_{epoch+1}.pt'
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'metrics': avg_metrics,
                'history': history
            }, checkpoint_path)

            # Generate test samples
            generate_test_sample(model, text_processor, f"{output_path}/samples/test_epoch_{epoch+1}.wav")
            
            # Plot training history
            plot_training_history(history, epoch+1)
            
            # Plot attention alignments
            plot_attention(alignments[0], f"{output_path}/attention/attention_epoch_{epoch+1}.png")

        # Early stopping check
        if patience_counter >= max_patience:
            print("\nEarly stopping triggered!")
            break

def plot_training_history(history, epoch):
    """Enhanced training history visualization"""
    plt.figure(figsize=(15, 15))
    
    # Plot all metrics
    metrics = [
        ('train_losses', 'Training Loss'),
        ('mcd_values', 'Mel-Cepstral Distortion'),
        ('frame_mse_values', 'Frame MSE'),
        ('attention_coverage', 'Attention Coverage'),
        ('attention_entropy', 'Attention Entropy'),
        ('duration_ratios', 'Duration Ratio'),
        ('learning_rates', 'Learning Rate'),
        ('gradients', 'Gradient Norm')
    ]
    
    for idx, (metric, title) in enumerate(metrics, 1):
        plt.subplot(4, 2, idx)
        plt.plot(history[metric])
        plt.title(title)
        plt.xlabel('Epoch')
        plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(f'{output_path}/training_history_epoch_{epoch}.png')
    plt.close()

def plot_attention(attention, filename):
    """Plot attention alignment"""
    plt.figure(figsize=(10, 6))
    plt.imshow(attention.detach().cpu().numpy(), aspect='auto', origin='lower')
    plt.colorbar()
    plt.title('Attention Alignment')
    plt.xlabel('Decoder Steps')
    plt.ylabel('Encoder Steps')
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()

def generate_test_sample(model, text_processor, output_file):
    """Generate test samples with error handling and audio saving"""
    model.eval()
    Path(output_file).parent.mkdir(exist_ok=True)
    
    test_texts = [
        "ខ្ញុំស្រឡាញ់អ្នក",  # I love you
        "សួស្តី",           # Hello
    ]
    
    try:
        for i, test_text in enumerate(test_texts):
            with torch.no_grad():
                print(f"\nGenerating sample for: {test_text}")
                
                # Process text
                sequence = text_processor.process_text(test_text)
                sequence = sequence.unsqueeze(0).to(device)
                input_lengths = torch.LongTensor([sequence.size(1)]).to(device)

                # Generate mel spectrogram
                mel_outputs, mel_lengths, alignments = model.infer(sequence, input_lengths)
                
                # Create unique output paths
                current_output_file = output_file.replace('.wav', f'_{i+1}.wav')
                mel_plot_path = current_output_file.replace('.wav', '_mel.png')
                attention_path = current_output_file.replace('.wav', '_attention.png')
                
                # Save mel spectrogram plot
                plot_mel_spectrogram(mel_outputs[0], mel_plot_path, 
                                   title=f"Mel Spectrogram: {test_text}")
                
                # Move to CPU and convert to numpy
                mel_spec = mel_outputs[0].cpu().numpy()
                
                # Create mel filterbank
                mel_basis = librosa.filters.mel(
                    sr=22050,
                    n_fft=1024,
                    n_mels=80,
                    fmin=0,
                    fmax=8000
                )
                
                # Convert to linear spectrogram
                linear_spec = np.maximum(1e-10, np.dot(mel_basis.T, mel_spec))
                
                # Generate audio using librosa's Griffin-Lim
                audio_numpy = librosa.griffinlim(
                    linear_spec,
                    n_iter=64,
                    hop_length=256,
                    win_length=1024,
                    window='hann'
                )
                
                # Normalize audio
                audio_numpy = audio_numpy / np.abs(audio_numpy).max()
                
                # Apply fade in/out
                fade_length = int(0.01 * 22050)  # 10ms fade
                fade_in = np.linspace(0, 1, fade_length)
                fade_out = np.linspace(1, 0, fade_length)
                
                audio_numpy[:fade_length] *= fade_in
                audio_numpy[-fade_length:] *= fade_out
                
                # Save audio
                sf.write(current_output_file, audio_numpy, 22050)
                
                # Save attention plot
                plot_attention(alignments[0], attention_path)
                
                print(f"Generated: {current_output_file}")
                print(f"Mel shape: {mel_outputs.shape}")
                print(f"Audio length: {len(audio_numpy)/22050:.2f} seconds")
                
    except Exception as e:
        print(f"Error generating test sample: {str(e)}")
        import traceback
        traceback.print_exc()
        
    print("\nTest sample generation complete!")

def plot_mel_spectrogram(mel_spectrogram, plot_path, title="Mel Spectrogram"):
    """
    Plot mel spectrogram with detailed visualization
    
    Args:
        mel_spectrogram: Tensor containing the mel spectrogram
        plot_path: Path to save the plot
        title: Title for the plot
    """
    # Convert to numpy array
    mel_np = mel_spectrogram.cpu().numpy()
    
    plt.figure(figsize=(12, 6))
    
    # Create main spectrogram plot
    plt.subplot(1, 1, 1)
    im = plt.imshow(mel_np, 
                    aspect='auto', 
                    origin='lower',
                    interpolation='none', 
                    cmap='viridis')
    
    # Add colorbar
    plt.colorbar(im, ax=plt.gca())
    
    # Add labels and title
    plt.xlabel('Frames')
    plt.ylabel('Mel Channels')
    plt.title(title)
    
    # Add grid
    plt.grid(False)
    
    # Adjust layout and save
    plt.tight_layout()
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    
    print(f"Mel spectrogram plot saved to: {plot_path}")

## 4. Evaluation

### 4.1 Evaluation Metrics

In [11]:
def calculate_metrics(model, test_loader):
    """Calculate evaluation metrics"""
    model.eval()

    mel_l1_loss = 0
    mel_l2_loss = 0

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            text = batch['text'].to(device)
            text_lengths = batch['text_lengths'].to(device)
            mel_target = batch['mel_target'].to(device)

            # Generate mel spectrograms
            mel_output, mel_output_postnet, _, _ = model(
                text, text_lengths, mel_target)

            # Calculate metrics
            mel_l1_loss += nn.L1Loss()(mel_output_postnet, mel_target).item()
            mel_l2_loss += nn.MSELoss()(mel_output_postnet, mel_target).item()

    # Average metrics
    mel_l1_loss /= len(test_loader)
    mel_l2_loss /= len(test_loader)

    return {
        'mel_l1_loss': mel_l1_loss,
        'mel_l2_loss': mel_l2_loss
    }

### 4.2 Alignment Analysis

In [12]:
def plot_alignment(alignment):
    """Plot attention alignment"""
    fig, ax = plt.subplots(figsize=(10, 6))
    im = ax.imshow(alignment.cpu().numpy(), aspect='auto', origin='lower')
    fig.colorbar(im, ax=ax)
    ax.set_title('Alignment')
    plt.tight_layout()
    return fig

## 5. Inference

### 5.1 Text to Speech Pipeline

In [13]:
def text_to_speech(text, model, vocoder, text_processor, output_path=None):
    """Complete TTS pipeline"""
    model.eval()
    vocoder.eval()

    # Process text
    sequence = text_processor.process_text(text)
    sequence = sequence.unsqueeze(0).to(device)
    sequence_length = torch.LongTensor([sequence.size(1)]).to(device)

    # Generate mel spectrogram
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, _, alignments = model.infer(
            sequence, sequence_length)

        # Generate audio
        audio = vocoder(mel_outputs_postnet)
        audio = audio.squeeze(0).cpu().numpy()

    # Save if path provided
    if output_path:
        sf.write(output_path, audio, 22050)

    return audio, mel_outputs_postnet, alignments

# Example usage
def inference_demo():
    # Load models
    model = load_pretrained_model()
    vocoder = load_vocoder()
    text_processor = TextProcessor()

    # Test text
    khmer_text = "ខ្ញុំស្រឡាញ់អ្នក"  # "I love you" in Khmer

    # Generate speech
    audio, mel, alignment = text_to_speech(
        khmer_text,
        model,
        vocoder,
        text_processor,
        "output.wav"
    )

    # Plot alignment
    plot_alignment(alignment[0])
    plt.show()

### 5.2 Batch Processing

In [14]:
def process_transcript(transcript_path, output_dir, model, vocoder, text_processor):
    """Process entire transcript file"""
    output_dir = Path(output_dir)
    output_dir.mkdir(exist_ok=True)

    # Load transcript
    df = load_transcript(transcript_path)

    # Process each entry
    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Processing"):
        output_path = output_dir / f"{row['audio_file']}.wav"

        # Skip if already exists
        if output_path.exists():
            continue

        try:
            # Generate speech
            text_to_speech(
                row['text'],
                model,
                vocoder,
                text_processor,
                output_path
            )
        except Exception as e:
            logger.error(f"Error processing {row['audio_file']}: {str(e)}")

# Usage

In [15]:
def main():
    # Define Path Variables
    transcript_path = f"{output_path}/line_index.tsv"
    audio_path = f"{output_path}/wavs"

    # Load data
    df = load_transcript(transcript_path)
    print(f"Total samples: {len(df)}")

    # Perform EDA
    # durations, sample_rates, file_sizes = analyze_audio_files(df, audio_path)
    # plot_audio_stats(durations, sample_rates, file_sizes)
    # analyze_text(df)

    # Initialize processors
    text_processor = TextProcessor()
    audio_processor = AudioProcessor()

    # Create dataset
    dataset = AudioDataset(df, audio_path, text_processor, audio_processor)

    # Parameters
    train_params = {
        'batch_size': 24,
        'num_epochs': 300,
        'learning_rate': 1e-3,
        'num_workers': 0
    }

    # Create data loaders
    train_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=train_params['batch_size'],
        shuffle=True,
        num_workers=train_params['num_workers'],
        collate_fn=collate_fn
    )

    # Load model
    model = load_pretrained_model()
    vocoder = load_vocoder()

    # Train model
    optimizer = torch.optim.Adam(model.parameters(), lr=train_params['learning_rate'])
    criterion = nn.MSELoss()

    train_model(
        model=model,
        text_processor=text_processor,
        train_loader=train_loader,
        optimizer=optimizer,
        num_epochs=train_params['num_epochs']
    )

    # # Training from previous checkpoint
    # train_model(
    #     model=model,
    #     text_processor=text_processor,
    #     train_loader=train_loader,
    #     optimizer=optimizer,
    #     num_epochs=train_params['num_epochs'],
    #     checkpoint_path=f'{output_path}/models/checkpoint_epoch_100.pt'
    # )

    # Evaluate
    # metrics = calculate_metrics(model, train_loader)
    # print("Evaluation metrics:", metrics)

In [None]:
main()

Total samples: 2906


Using cache found in C:\Users\ADMIN/.cache\torch\hub\NVIDIA_DeepLearningExamples_torchhub
Using cache found in C:\Users\ADMIN/.cache\torch\hub\NVIDIA_DeepLearningExamples_torchhub
  ckpt = torch.load(ckpt_file)
  WeightNorm.apply(module, name, dim)


Epoch 1/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 1/300 Summary:
Loss: 9.0566
Mcd: 6.9636
Frame Mse: 6.9636
Attention Coverage: 0.0014
Attention Entropy: 0.0297
Duration Ratio: 1.0000
Gradient: 5.1085

New best model saved!

Generating sample for: ខ្ញុំស្រឡាញ់អ្នក


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')


Mel spectrogram plot saved to: ../dataset/samples/test_epoch_1_1_mel.png
Generated: ../dataset/samples/test_epoch_1_1.wav
Mel shape: torch.Size([1, 80, 1000])
Audio length: 11.60 seconds

Generating sample for: សួស្តី


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')


Mel spectrogram plot saved to: ../dataset/samples/test_epoch_1_2_mel.png
Generated: ../dataset/samples/test_epoch_1_2.wav
Mel shape: torch.Size([1, 80, 1000])
Audio length: 11.60 seconds

Test sample generation complete!


Epoch 2/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 2/300 Summary:
Loss: 4.2772
Mcd: 3.1843
Frame Mse: 3.1843
Attention Coverage: 0.0038
Attention Entropy: 0.0231
Duration Ratio: 1.0000
Gradient: 2.1352

New best model saved!


Epoch 3/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 3/300 Summary:
Loss: 3.3666
Mcd: 2.5683
Frame Mse: 2.5683
Attention Coverage: 0.0040
Attention Entropy: 0.0205
Duration Ratio: 1.0000
Gradient: 1.7724

New best model saved!


Epoch 4/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 4/300 Summary:
Loss: 2.9863
Mcd: 2.3304
Frame Mse: 2.3304
Attention Coverage: 0.0040
Attention Entropy: 0.0196
Duration Ratio: 1.0000
Gradient: 1.5815

New best model saved!


Epoch 5/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 5/300 Summary:
Loss: 2.7691
Mcd: 2.1988
Frame Mse: 2.1988
Attention Coverage: 0.0040
Attention Entropy: 0.0187
Duration Ratio: 1.0000
Gradient: 1.4405

New best model saved!


Epoch 6/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 6/300 Summary:
Loss: 2.5668
Mcd: 2.0896
Frame Mse: 2.0896
Attention Coverage: 0.0040
Attention Entropy: 0.0177
Duration Ratio: 1.0000
Gradient: 1.3575

New best model saved!


Epoch 7/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 7/300 Summary:
Loss: 2.5056
Mcd: 2.0133
Frame Mse: 2.0133
Attention Coverage: 0.0040
Attention Entropy: 0.0178
Duration Ratio: 1.0000
Gradient: 1.3854

New best model saved!


Epoch 8/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 8/300 Summary:
Loss: 2.4085
Mcd: 1.9586
Frame Mse: 1.9586
Attention Coverage: 0.0040
Attention Entropy: 0.0170
Duration Ratio: 1.0000
Gradient: 1.4022

New best model saved!


Epoch 9/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 9/300 Summary:
Loss: 2.3906
Mcd: 1.9104
Frame Mse: 1.9104
Attention Coverage: 0.0040
Attention Entropy: 0.0170
Duration Ratio: 1.0000
Gradient: 1.2475

New best model saved!


Epoch 10/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 10/300 Summary:
Loss: 2.3095
Mcd: 1.8664
Frame Mse: 1.8664
Attention Coverage: 0.0040
Attention Entropy: 0.0166
Duration Ratio: 1.0000
Gradient: 1.1326

New best model saved!

Generating sample for: ខ្ញុំស្រឡាញ់អ្នក


  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')
  plt.savefig(plot_path, dpi=300, bbox_inches='tight')


Mel spectrogram plot saved to: ../dataset/samples/test_epoch_10_1_mel.png
Generated: ../dataset/samples/test_epoch_10_1.wav
Mel shape: torch.Size([1, 80, 1000])
Audio length: 11.60 seconds

Generating sample for: សួស្តី
Mel spectrogram plot saved to: ../dataset/samples/test_epoch_10_2_mel.png
Generated: ../dataset/samples/test_epoch_10_2.wav
Mel shape: torch.Size([1, 80, 3])
Audio length: 0.02 seconds

Test sample generation complete!




Epoch 11/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 11/300 Summary:
Loss: 2.2388
Mcd: 1.8252
Frame Mse: 1.8252
Attention Coverage: 0.0040
Attention Entropy: 0.0159
Duration Ratio: 1.0000
Gradient: 1.1552

New best model saved!


Epoch 12/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 12/300 Summary:
Loss: 2.2490
Mcd: 1.7984
Frame Mse: 1.7984
Attention Coverage: 0.0041
Attention Entropy: 0.0165
Duration Ratio: 1.0000
Gradient: 1.1162


Epoch 13/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 13/300 Summary:
Loss: 2.1967
Mcd: 1.7682
Frame Mse: 1.7682
Attention Coverage: 0.0041
Attention Entropy: 0.0160
Duration Ratio: 1.0000
Gradient: 1.1463

New best model saved!


Epoch 14/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 14/300 Summary:
Loss: 2.1766
Mcd: 1.7403
Frame Mse: 1.7403
Attention Coverage: 0.0042
Attention Entropy: 0.0159
Duration Ratio: 1.0000
Gradient: 1.0837

New best model saved!


Epoch 15/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 15/300 Summary:
Loss: 2.1477
Mcd: 1.7185
Frame Mse: 1.7185
Attention Coverage: 0.0040
Attention Entropy: 0.0166
Duration Ratio: 1.0000
Gradient: 1.0821

New best model saved!


Epoch 16/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 16/300 Summary:
Loss: 2.1123
Mcd: 1.6925
Frame Mse: 1.6925
Attention Coverage: 0.0040
Attention Entropy: 0.0164
Duration Ratio: 1.0000
Gradient: 1.0909

New best model saved!


Epoch 17/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 17/300 Summary:
Loss: 2.1271
Mcd: 1.6779
Frame Mse: 1.6779
Attention Coverage: 0.0040
Attention Entropy: 0.0161
Duration Ratio: 1.0000
Gradient: 1.0374


Epoch 18/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 18/300 Summary:
Loss: 2.0699
Mcd: 1.6575
Frame Mse: 1.6575
Attention Coverage: 0.0041
Attention Entropy: 0.0157
Duration Ratio: 1.0000
Gradient: 1.0577

New best model saved!


Epoch 19/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 19/300 Summary:
Loss: 2.0495
Mcd: 1.6384
Frame Mse: 1.6384
Attention Coverage: 0.0041
Attention Entropy: 0.0155
Duration Ratio: 1.0000
Gradient: 0.9804

New best model saved!


Epoch 20/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 20/300 Summary:
Loss: 2.0659
Mcd: 1.6235
Frame Mse: 1.6235
Attention Coverage: 0.0041
Attention Entropy: 0.0158
Duration Ratio: 1.0000
Gradient: 1.0460

Generating sample for: ខ្ញុំស្រឡាញ់អ្នក
Mel spectrogram plot saved to: ../dataset/samples/test_epoch_20_1_mel.png
Generated: ../dataset/samples/test_epoch_20_1.wav
Mel shape: torch.Size([1, 80, 1000])
Audio length: 11.60 seconds

Generating sample for: សួស្តី
Mel spectrogram plot saved to: ../dataset/samples/test_epoch_20_2_mel.png
Generated: ../dataset/samples/test_epoch_20_2.wav
Mel shape: torch.Size([1, 80, 1000])
Audio length: 11.60 seconds

Test sample generation complete!


Epoch 21/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 21/300 Summary:
Loss: 2.0057
Mcd: 1.6115
Frame Mse: 1.6115
Attention Coverage: 0.0040
Attention Entropy: 0.0148
Duration Ratio: 1.0000
Gradient: 1.0919

New best model saved!


Epoch 22/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 22/300 Summary:
Loss: 2.0577
Mcd: 1.5962
Frame Mse: 1.5962
Attention Coverage: 0.0041
Attention Entropy: 0.0156
Duration Ratio: 1.0000
Gradient: 1.0740


Epoch 23/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 23/300 Summary:
Loss: 1.9789
Mcd: 1.5759
Frame Mse: 1.5759
Attention Coverage: 0.0042
Attention Entropy: 0.0151
Duration Ratio: 1.0000
Gradient: 0.8974

New best model saved!


Epoch 24/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 24/300 Summary:
Loss: 1.9620
Mcd: 1.5644
Frame Mse: 1.5644
Attention Coverage: 0.0042
Attention Entropy: 0.0151
Duration Ratio: 1.0000
Gradient: 0.9682

New best model saved!


Epoch 25/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 25/300 Summary:
Loss: 1.9714
Mcd: 1.5541
Frame Mse: 1.5541
Attention Coverage: 0.0041
Attention Entropy: 0.0150
Duration Ratio: 1.0000
Gradient: 0.9771


Epoch 26/300:   0%|          | 0/122 [00:00<?, ?it/s]


Epoch 26/300 Summary:
Loss: 1.9569
Mcd: 1.5457
Frame Mse: 1.5457
Attention Coverage: 0.0043
Attention Entropy: 0.0146
Duration Ratio: 1.0000
Gradient: 1.0375

New best model saved!


Epoch 27/300:   0%|          | 0/122 [00:00<?, ?it/s]

### Synthesis

In [None]:
def print_model_info(checkpoint_path):
    """Print model information from checkpoint"""
    # Load checkpoint
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Load Tacotron2 model
    model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub',
                          'nvidia_tacotron2',
                          model_math='fp16',
                          pretrained=False)

    # Load state dict from checkpoint
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    print("\nModel Architecture:")
    print(model)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"\nTotal parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    # Print training info from checkpoint
    print(f"\nCheckpoint information:")
    print(f"Epoch: {checkpoint['epoch']}")
    print(f"Loss: {checkpoint['loss']}")

    return model

# Usage
checkpoint_path = f"{output_path}/models/checkpoint_epoch_100.pt"
model = print_model_info(checkpoint_path)

  checkpoint = torch.load(checkpoint_path, map_location=device)
Using cache found in /root/.cache/torch/hub/NVIDIA_DeepLearningExamples_torchhub



Model Architecture:
Tacotron2(
  (embedding): Embedding(148, 512)
  (encoder): Encoder(
    (convolutions): ModuleList(
      (0-2): 3 x Sequential(
        (0): ConvNorm(
          (conv): Conv1d(512, 512, kernel_size=(5,), stride=(1,), padding=(2,))
        )
        (1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (lstm): LSTM(512, 256, batch_first=True, bidirectional=True)
  )
  (decoder): Decoder(
    (prenet): Prenet(
      (layers): ModuleList(
        (0): LinearNorm(
          (linear_layer): Linear(in_features=80, out_features=256, bias=False)
        )
        (1): LinearNorm(
          (linear_layer): Linear(in_features=256, out_features=256, bias=False)
        )
      )
    )
    (attention_rnn): LSTMCell(768, 1024)
    (attention_layer): Attention(
      (query_layer): LinearNorm(
        (linear_layer): Linear(in_features=1024, out_features=128, bias=False)
      )
      (memory_layer): LinearNorm(
        (linear_

: 

: 