# Khmer Text-to-Speech Model training (Tecotron2)

## Import Libraries

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

# TTS Model
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm
from torch.nn.utils.rnn import pad_sequence
from IPython.display import Audio

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA version: {torch.version.cuda}")

# Check backend version
# print(f"Audion backend: {torchaudio.get_audio_backend()}")

Using device: cuda
CUDA available: True
CUDA version: 12.4


In [4]:
# Load the dataset as dataframe
processed_transcriptions = pd.read_csv('../dataset/processed_transcriptions.csv')

print(processed_transcriptions.head())

            line_index                                      transcription  \
0  khm_0308_0011865648  ស្ពាន កំពង់ ចម្លង អ្នកលឿង នៅ ព្រៃវែង ជា ស្ពាន ...   
1  khm_0308_0032157149  ភ្លើង កំពុង ឆាប ឆេះ ផ្ទះ ប្រជា ពលរដ្ឋ នៅ សង្កា...   
2  khm_0308_0038959268  អ្នក សុំ ទាន ដេក ប្រកាច់ ម្នាក់ ឯង ក្បែរ ខ្លោង...   
3  khm_0308_0054635313  ស្ករ ត្នោត ដែល មាន គុណភាព ល្អ ផលិត នៅ ខេត្ត កំ...   
4  khm_0308_0055735195         ភ្នំបាខែង មាន កម្ពស់ តែ ចិត សិប ម៉ែត្រ សោះ   

                            normalized_transcription  \
0  ស្ពាន កំពង់ ចម្លង អ្នកលឿង នៅ ព្រៃវែង ជា ស្ពាន ...   
1  ភ្លើង កំពុង ឆាប ឆេះ ផ្ទះ ប្រជា ពលរដ្ឋ នៅ សង្កា...   
2  អ្នក សុំ ទាន ដេក ប្រកាច់ ម្នាក់ ឯង ក្បែរ ខ្លោង...   
3  ស្ករ ត្នោត ដែល មាន គុណភាព ល្អ ផលិត នៅ ខេត្ត កំ...   
4         ភ្នំបាខែង មាន កម្ពស់ តែ ចិត សិប ម៉ែត្រ សោះ   

                                          mel_path  \
0  ../mel_spectrograms\khm_0308_0011865648_mel.npy   
1  ../mel_spectrograms\khm_0308_0032157149_mel.npy   
2  ../mel_spectrograms\khm_0308_003895

## 1. Model Training

### 1.1 Train Test split

In [5]:
# Split the dataset
train_data, val_data = train_test_split(processed_transcriptions, test_size=0.1, random_state=42)

In [6]:
print(train_data.columns)

Index(['line_index', 'transcription', 'normalized_transcription', 'mel_path',
       'tokenized_transcription'],
      dtype='object')


### 1.2 Load Processed Data

In [7]:
import ast
class TTSDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        # Load the tokenized transcription
        tokens = self.data.iloc[idx]['tokenized_transcription']
        
        # Convert tokens to tensor
        if isinstance(tokens, str):
            tokens = ast.literal_eval(tokens)
        tokens = torch.tensor(tokens, dtype=torch.long)
        
        # Load the mel spectrogram
        mel_path = self.data.iloc[idx]['mel_path']
        mel_spectrogram = np.load(mel_path)
        mel_spectrogram = torch.tensor(mel_spectrogram, dtype=torch.float32)
        
        # Ensure mel spectrogram has shape [time, n_mels]
        if mel_spectrogram.shape[1] != 80:  # Assuming 80 mel bins
            mel_spectrogram = mel_spectrogram.transpose(0, 1)
        
        return tokens, mel_spectrogram

def collate_fn(batch):
    tokens, mel_spectrograms = zip(*batch)
    
    # Pad tokens
    tokens = pad_sequence(tokens, batch_first=True, padding_value=0)
    
    # Get the maximum sequence length for mel spectrograms
    max_length = max(mel.size(0) for mel in mel_spectrograms)
    
    # Pad mel spectrograms to have the same sequence length
    padded_mels = []
    for mel in mel_spectrograms:
        # Calculate padding needed
        pad_length = max_length - mel.size(0)
        if pad_length > 0:
            # Pad at the end (right side) of the sequence
            padded_mel = torch.nn.functional.pad(mel, (0, 0, 0, pad_length))
        else:
            padded_mel = mel
        padded_mels.append(padded_mel)
    
    # Stack the padded mel spectrograms and reshape to [batch, time, n_mels]
    mel_spectrograms = torch.stack(padded_mels)
    mel_spectrograms = mel_spectrograms.transpose(1, 2)  # Change to [batch, n_mels, time]
    mel_spectrograms = mel_spectrograms.transpose(1, 2)  # Change to [batch, time, n_mels]
    
    return tokens, mel_spectrograms

# Create the dataloaders
train_dataset = TTSDataset(train_data)
val_dataset = TTSDataset(val_data)

train_loader = DataLoader(
    train_dataset, 
    batch_size=32, 
    shuffle=True, 
    collate_fn=collate_fn,
    drop_last=True  # Drop incomplete batches
)
val_loader = DataLoader(
    val_dataset, 
    batch_size=32, 
    shuffle=False, 
    collate_fn=collate_fn,
    drop_last=True  # Drop incomplete batches
)

### 1.3 Define Model

In [8]:
class Tacotron2(nn.Module):
    def __init__(self, n_vocab=256, embedding_dim=512, encoder_dim=512, decoder_dim=512, n_mels=80):
        super(Tacotron2, self).__init__()
        
        # Text Embedding
        self.embedding = nn.Embedding(n_vocab, embedding_dim)
        
        # Encoder
        self.encoder_prenet = nn.Sequential(
            nn.Linear(embedding_dim, encoder_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(encoder_dim, encoder_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.encoder_lstm = nn.LSTM(
            input_size=encoder_dim,
            hidden_size=encoder_dim // 2,
            num_layers=1,
            batch_first=True,
            bidirectional=True
        )
        
        # Decoder
        self.decoder_prenet = nn.Sequential(
            nn.Linear(n_mels, decoder_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(decoder_dim, decoder_dim),
            nn.ReLU(),
            nn.Dropout(0.5)
        )
        self.decoder_lstm = nn.LSTM(
            input_size=decoder_dim + encoder_dim,
            hidden_size=decoder_dim,
            num_layers=2,
            batch_first=True
        )
        
        # Output projection
        self.mel_projection = nn.Linear(decoder_dim, n_mels)
        
        # Postnet
        self.postnet = nn.Sequential(
            nn.Conv1d(n_mels, 512, kernel_size=5, padding=2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Conv1d(512, 512, kernel_size=5, padding=2),
            nn.BatchNorm1d(512),
            nn.Tanh(),
            nn.Dropout(0.5),
            nn.Conv1d(512, n_mels, kernel_size=5, padding=2),
        )
        
    def forward(self, text, mel_target=None):
        # Text embedding
        embedded = self.embedding(text)  # [batch, text_length, embedding_dim]
        
        # Encoder
        encoder_input = self.encoder_prenet(embedded)
        encoder_output, _ = self.encoder_lstm(encoder_input)  # [batch, text_length, encoder_dim]
        
        batch_size = text.size(0)
        
        # Initialize decoder input (moved outside the conditional)
        go_frame = torch.zeros(batch_size, 1, self.mel_projection.out_features).to(text.device)
        
        if self.training and mel_target is not None:
            # Teacher forcing: concatenate go_frame with target mel frames
            decoder_inputs = torch.cat((go_frame, mel_target[:, :-1, :]), dim=1)
            
            # Process through decoder prenet
            decoder_inputs = self.decoder_prenet(decoder_inputs)  # [batch, mel_length, decoder_dim]
            
            # Prepare encoder outputs for attention
            # Expand encoder outputs to match decoder sequence length
            expanded_encoder_output = encoder_output.unsqueeze(1)  # [batch, 1, text_length, encoder_dim]
            expanded_encoder_output = expanded_encoder_output.expand(
                -1, decoder_inputs.size(1), -1, -1)  # [batch, mel_length, text_length, encoder_dim]
            
            # For now, use simple averaging of encoder outputs
            context_vectors = expanded_encoder_output.mean(dim=2)  # [batch, mel_length, encoder_dim]
            
            # Concatenate decoder inputs with context vectors
            decoder_lstm_input = torch.cat((decoder_inputs, context_vectors), dim=-1)
            
            # Decoder LSTM
            decoder_output, _ = self.decoder_lstm(decoder_lstm_input)
            
            # Project to mel-spectrogram
            mel_output = self.mel_projection(decoder_output)
            
            # Postnet processing
            mel_output_postnet = self.postnet(mel_output.transpose(1, 2)).transpose(1, 2)
            mel_output_refined = mel_output + mel_output_postnet
            
            return mel_output, mel_output_refined
        else:
            # Determine output length
            if mel_target is not None:
                # For validation: use target length
                target_length = mel_target.size(1)
            else:
                # For inference: use maximum length
                target_length = 1000
                
            mel_outputs = []
            current_frame = go_frame
            
            for _ in range(target_length):
                # Process current frame through prenet
                prenet_out = self.decoder_prenet(current_frame)
                
                # Get context vector (simplified attention)
                context_vector = encoder_output.mean(dim=1, keepdim=True)
                
                # Concatenate and process through LSTM
                decoder_input = torch.cat((prenet_out, context_vector.expand(-1, 1, -1)), dim=-1)
                decoder_output, _ = self.decoder_lstm(decoder_input)
                
                # Generate next frame
                current_frame = self.mel_projection(decoder_output)
                mel_outputs.append(current_frame)
                
                # Stop if we predict a stop token (you'll need to implement this)
                
            mel_outputs = torch.cat(mel_outputs, dim=1)
            mel_outputs_postnet = self.postnet(mel_outputs.transpose(1, 2)).transpose(1, 2)
            mel_outputs_refined = mel_outputs + mel_outputs_postnet
            
            return mel_outputs, mel_outputs_refined

In [9]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    epoch_loss = 0
    batch_count = 0
    
    for tokens, mel_spectrograms in tqdm(loader, desc="Training", leave=False):
        try:
            tokens = tokens.to(device)
            mel_spectrograms = mel_spectrograms.to(device)
            
            # Forward pass
            optimizer.zero_grad()
            mel_output, mel_output_refined = model(tokens, mel_spectrograms)
            
            # Compute loss (combining both outputs)
            loss = criterion(mel_output, mel_spectrograms) + criterion(mel_output_refined, mel_spectrograms)
            
            # Backward pass and optimization
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Add gradient clipping
            optimizer.step()
            
            epoch_loss += loss.item()
            batch_count += 1
            
        except RuntimeError as e:
            print(f"Error in batch: {e}")
            continue
    
    return epoch_loss / batch_count if batch_count > 0 else float('inf')

# Modified validation function with error handling
def validate_one_epoch(model, loader, criterion, device):
    model.eval()
    epoch_loss = 0
    batch_count = 0
    
    with torch.no_grad():
        for tokens, mel_spectrograms in tqdm(loader, desc="Validation", leave=False):
            try:
                tokens = tokens.to(device)
                mel_spectrograms = mel_spectrograms.to(device)
                
                # Forward pass
                mel_output, mel_output_refined = model(tokens, mel_spectrograms)
                
                # Compute loss
                loss = criterion(mel_output, mel_spectrograms) + criterion(mel_output_refined, mel_spectrograms)
                
                epoch_loss += loss.item()
                batch_count += 1
                
            except RuntimeError as e:
                print(f"Error in validation batch: {e}")
                continue
    
    return epoch_loss / batch_count if batch_count > 0 else float('inf')

# Initialize model and move to device
model = Tacotron2(
    n_vocab=256,
    embedding_dim=512,
    encoder_dim=512,
    decoder_dim=512,
    n_mels=80
).to(device)

# Initialize optimizer and criterion
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

### 1.4 Train the Model

In [10]:
# Training loop
num_epochs = 20
best_val_loss = float('inf')

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    
    # Train
    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    print(f"Train Loss: {train_loss:.4f}")
    
    # Validate
    val_loss = validate_one_epoch(model, val_loader, criterion, device)
    print(f"Validation Loss: {val_loss:.4f}")
    
    # Save the best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_tacotron2_model.pth")
        print("Saved best model!")

Epoch 1/20


                                                         

Train Loss: 1574.8424


                                                         

Validation Loss: 6020.9544
Saved best model!
Epoch 2/20


                                                         

Train Loss: 983.3863


                                                         

Validation Loss: 5055.8333
Saved best model!
Epoch 3/20


                                                         

Train Loss: 934.9413


                                                         

Validation Loss: 2319.3285
Saved best model!
Epoch 4/20


                                                         

Train Loss: 890.3351


                                                         

Validation Loss: 1834.5391
Saved best model!
Epoch 5/20


                                                         

Train Loss: 827.4638


                                                         

Validation Loss: 2094.4998
Epoch 6/20


                                                         

Train Loss: 745.3214


                                                         

Validation Loss: 3061.4314
Epoch 7/20


                                                         

Train Loss: 657.5508


                                                         

Validation Loss: 3397.1665
Epoch 8/20


                                                         

Train Loss: 544.4181


                                                         

Validation Loss: 3288.4491
Epoch 9/20


                                                         

Train Loss: 431.5051


                                                         

Validation Loss: 2657.4871
Epoch 10/20


                                                         

Train Loss: 310.2101


                                                         

Validation Loss: 3156.0175
Epoch 11/20


                                                         

Train Loss: 205.6654


                                                         

Validation Loss: 3406.4183
Epoch 12/20


                                                         

Train Loss: 171.2692


                                                         

Validation Loss: 3576.1490
Epoch 13/20


                                                         

Train Loss: 164.0181


                                                         

Validation Loss: 3634.3490
Epoch 14/20


                                                         

Train Loss: 154.7960


                                                         

Validation Loss: 3714.4941
Epoch 15/20


                                                         

Train Loss: 145.2554


                                                         

Validation Loss: 3710.2556
Epoch 16/20


                                                         

Train Loss: 139.0830


                                                         

Validation Loss: 3900.7633
Epoch 17/20


                                                         

Train Loss: 136.8008


                                                         

Validation Loss: 3965.8009
Epoch 18/20


                                                         

Train Loss: 134.8769


                                                         

Validation Loss: 4075.8133
Epoch 19/20


                                                         

Train Loss: 128.4979


                                                         

Validation Loss: 4107.8998
Epoch 20/20


                                                         

Train Loss: 126.8000


                                                         

Validation Loss: 4288.8414




# Khmer Text-to-Speech Demonstration

In [11]:
# import torch
# import torchaudio
# import numpy as np
# import librosa
# from IPython.display import Audio
# import pandas as pd

# def create_char_to_id_mapping(processed_transcriptions, max_vocab_size=256):
#     """Create character to ID mapping with vocabulary size limit"""
#     # Create a set of all unique characters
#     unique_chars = set()
#     for text in processed_transcriptions['normalized_transcription']:
#         chars = text.strip().split()
#         unique_chars.update(chars)
    
#     # Sort characters to ensure consistent ordering
#     sorted_chars = sorted(list(unique_chars))
    
#     # Limit vocabulary size if necessary
#     if len(sorted_chars) + 2 > max_vocab_size:  # +2 for <pad> and <unk>
#         print(f"Warning: Truncating vocabulary from {len(sorted_chars)} to {max_vocab_size-2}")
#         sorted_chars = sorted_chars[:max_vocab_size-2]
    
#     # Create mapping dictionary
#     char_to_id = {char: idx for idx, char in enumerate(sorted_chars)}
    
#     # Add special tokens
#     char_to_id['<pad>'] = len(char_to_id)
#     char_to_id['<unk>'] = len(char_to_id)
    
#     print(f"Total vocabulary size: {len(char_to_id)}")
#     return char_to_id

# def text_to_sequence(text, char_to_id):
#     """Convert Khmer text to sequence of token IDs"""
#     chars = text.strip().split()
#     return [char_to_id.get(char, char_to_id['<unk>']) for char in chars]

# def generate_mel_spectrogram(model, text_sequence, device):
#     """Generate mel spectrogram with error handling"""
#     try:
#         with torch.no_grad():
#             text_tensor = torch.LongTensor([text_sequence]).to(device)
#             print(f"Input tensor shape: {text_tensor.shape}")
#             mel_output, mel_output_refined = model(text_tensor)
#             return mel_output_refined.cpu().numpy()[0]
#     except Exception as e:
#         print(f"Error in generate_mel_spectrogram: {e}")
#         raise

# def generate_audio_librosa(mel_spectrogram):
#     """Generate audio using librosa"""
#     try:
#         # Create mel filterbank
#         mel_basis = librosa.filters.mel(
#             sr=22050,
#             n_fft=1024,
#             n_mels=80,
#             fmin=0,
#             fmax=8000,
#         )
        
#         # Print shapes for debugging
#         print(f"Mel spectrogram shape: {mel_spectrogram.shape}")
#         print(f"Mel basis shape: {mel_basis.shape}")
        
#         # Compute pseudoinverse of mel filterbank
#         mel_inverse = np.linalg.pinv(mel_basis)
        
#         # Convert mel spectrogram to linear spectrogram
#         linear_spectrogram = np.dot(mel_inverse, mel_spectrogram.T).T
        
#         # Ensure non-negative values
#         linear_spectrogram = np.maximum(1e-10, linear_spectrogram)
        
#         # Use librosa's istft
#         audio = librosa.istft(
#             linear_spectrogram.astype(np.complex64),
#             hop_length=256,
#             win_length=1024
#         )
        
#         return audio
#     except Exception as e:
#         print(f"Error in librosa audio generation: {e}")
#         raise

# def generate_audio_torchaudio(mel_spectrogram):
#     """Generate audio using torchaudio"""
#     try:
#         mel_tensor = torch.FloatTensor(mel_spectrogram)
#         griffin_lim = torchaudio.transforms.GriffinLim(
#             n_fft=1024,
#             n_iter=32,
#             win_length=1024,
#             hop_length=256
#         )
#         waveform = griffin_lim(mel_tensor)
#         return waveform.numpy()
#     except Exception as e:
#         print(f"Error in torchaudio audio generation: {e}")
#         raise

# def run_demo(force_cpu=False, use_torch_version=False):
#     """Run the complete demo pipeline"""
#     try:
#         # 1. Setup device
#         device = torch.device('cpu') if force_cpu else torch.device('cuda')
#         print(f"Using device: {device}")
        
#         # 2. Load processed transcriptions
#         processed_transcriptions = pd.read_csv('../dataset/processed_transcriptions.csv')
        
#         # 3. Load model state dict
#         state_dict = torch.load("best_tacotron2_model.pth", 
#                               map_location='cpu' if force_cpu else device)
        
#         saved_vocab_size = state_dict['embedding.weight'].shape[0]
#         print(f"Saved model vocabulary size: {saved_vocab_size}")
        
#         # 4. Initialize model
#         model = Tacotron2(
#             n_vocab=saved_vocab_size,
#             embedding_dim=512,
#             encoder_dim=512,
#             decoder_dim=512,
#             n_mels=80
#         )
        
#         # 5. Load weights
#         model.load_state_dict(state_dict)
#         model = model.to(device)
#         model.eval()
        
#         # 6. Create character mapping
#         char_to_id = create_char_to_id_mapping(processed_transcriptions, max_vocab_size=saved_vocab_size)
        
#         # 7. Prepare test text
#         test_text = "សួស្តី"  # "Hello"
#         print(f"\nProcessing text: {test_text}")
        
#         # 8. Convert text to sequence
#         text_sequence = text_to_sequence(test_text, char_to_id)
#         print(f"Token sequence: {text_sequence}")
        
#         # 9. Generate mel spectrogram
#         print("\nGenerating mel spectrogram...")
#         mel_spectrogram = generate_mel_spectrogram(model, text_sequence, device)
        
#         # 10. Convert to audio using selected method
#         print("Converting to audio...")
#         if use_torch_version:
#             audio = generate_audio_torchaudio(mel_spectrogram)
#         else:
#             audio = generate_audio_librosa(mel_spectrogram)
        
#         # 11. Normalize audio
#         audio = audio / np.max(np.abs(audio))
        
#         return Audio(audio, rate=22050)
        
#     except Exception as e:
#         print(f"\nError in demo: {e}")
#         if torch.cuda.is_available():
#             print(f"CUDA Memory allocated: {torch.cuda.memory_allocated()/1024**2:.2f} MB")
#             print(f"CUDA Memory cached: {torch.cuda.memory_reserved()/1024**2:.2f} MB")
#         raise

# # Try running with both implementations
# print("Testing with librosa implementation...")
# try:
#     audio_player = run_demo(force_cpu=True, use_torch_version=False)
#     print("\nLibrosa implementation successful!")
#     display(audio_player)
# except Exception as e:
#     print(f"\nLibrosa implementation failed: {e}")
#     print("\nTrying torchaudio implementation...")
#     try:
#         audio_player = run_demo(force_cpu=True, use_torch_version=True)
#         print("\nTorchaudio implementation successful!")
#         display(audio_player)
#     except Exception as e:
#         print(f"\nBoth implementations failed: {e}")