# Khmer Text-to-Speech

### Import Packages

In [2]:
import os
import numpy as np
import pandas as pd
import librosa
import matplotlib.pyplot as plt
import soundfile as sf
import re
from tqdm import tqdm
import pickle

### Load Datasets

In [3]:
# Index file that maps each utterance id to its transcription
transcriptions_path = '../dataset/line_index.tsv'

# Fields are separated by a double tab; a multi-character separator is parsed
# with the python engine.
transcriptions = pd.read_csv(
    transcriptions_path,
    sep='\t\t',
    names=['line_index', 'transcription'],
    engine='python',
)

# Dataset overview: (rows, columns) followed by the first rows
print(transcriptions.shape, transcriptions.head(), sep='\n')

(2906, 2)
            line_index                                      transcription
0  khm_0308_0011865648  ស្ពាន កំពង់ ចម្លង អ្នកលឿង នៅ ព្រៃវែង ជា ស្ពាន ...
1  khm_0308_0032157149  ភ្លើង កំពុង ឆាប ឆេះ ផ្ទះ ប្រជា ពលរដ្ឋ នៅ សង្កា...
2  khm_0308_0038959268  អ្នក សុំ ទាន ដេក ប្រកាច់ ម្នាក់ ឯង ក្បែរ ខ្លោង...
3  khm_0308_0054635313  ស្ករ ត្នោត ដែល មាន គុណភាព ល្អ ផលិត នៅ ខេត្ត កំ...
4  khm_0308_0055735195         ភ្នំបាខែង មាន កម្ពស់ តែ ចិត សិប ម៉ែត្រ សោះ


## 1. Basic statistics

In [4]:
dataset_path = '../dataset'

# Running totals over every utterance whose wav file is present on disk
total_duration = 0
total_text_length = 0

for line_index, transcription in transcriptions[['line_index', 'transcription']].itertuples(index=False):
    wav_path = os.path.join(dataset_path, 'wavs', f'{line_index}.wav')

    # Report and skip utterances whose audio is missing
    if not os.path.exists(wav_path):
        print(f'File not found: {wav_path}')
        continue

    # sr=None: load without forcing a target sampling rate
    y, sr = librosa.load(wav_path, sr=None)
    duration = librosa.get_duration(y=y, sr=sr)
    total_duration += duration
    total_text_length += len(transcription)

# Report the totals (duration converted from seconds to hours)
total_duration_hours = total_duration / 3600
print(f'Total duration: {total_duration_hours} hours')
print(f'Total text length: {total_text_length} characters')

Total duration: 3.9667806481481387 hours
Total text length: 146460 characters


### 1.2 Distribution Analysis

In [12]:
# Per-utterance audio duration (seconds) and transcription length (characters).
# Both lists stay positionally aligned with the rows of `transcriptions`.
audio_lengths = []
text_lengths = []

for line_index, transcription in zip(transcriptions['line_index'], transcriptions['transcription']):
    wav_path = os.path.join(dataset_path, 'wavs', f'{line_index}.wav')

    if os.path.exists(wav_path):
        y, sr = librosa.load(wav_path, sr=None)
        audio_lengths.append(librosa.get_duration(y=y, sr=sr))
    else:
        # A missing clip must not abort the whole cell. Keep a NaN placeholder
        # so the list still has one entry per row; NaN compares False against
        # any duration threshold, so such rows are never flagged by length.
        print(f'File not found: {wav_path}')
        audio_lengths.append(float('nan'))

    text_lengths.append(len(transcription))

# # Plot histogram
# plt.figure(figsize=(12, 6))
# plt.subplot(1, 2, 1)
# plt.hist(audio_lengths, bins=30, color='blue', edgecolor='black', align= 'mid')
# plt.title('Audio Length Distribution')
# plt.xlabel('Duration (seconds)')
# plt.ylabel('Count')

# plt.subplot(1, 2, 2)
# plt.hist(text_lengths, bins=30, color='blue', edgecolor='black', align= 'mid')
# plt.title('Text Length Distribution')
# plt.xlabel('Length (characters)')
# plt.ylabel('Frequency')

# plt.tight_layout()
# plt.show()

### 1.3 Quality Check

In [13]:
# # Visualize a random audio file
# random_file = os.path.join(dataset_path, 'wavs', transcriptions.sample(1)['line_index'].values[0] + '.wav')
# y, sr = librosa.load(random_file, sr=None)
# plt.figure(figsize=(10, 4))
# plt.plot(y)
# plt.title('Waveform of a random audio file')
# plt.xlabel('Samples')
# plt.ylabel('Amplitude')
# plt.show()

# # Listen to the audio file if running locally
# import IPython.display as ipd
# ipd.Audio(random_file)

### 1.4 Identify and handle outliers

In [14]:
# Identify outliers by clip duration (seconds).
# The masks are positional: entry i of `audio_lengths` belongs to row i of `transcriptions`.
durations = np.asarray(audio_lengths)
short_audio = transcriptions[durations < 1]   # clips shorter than 1 second
long_audio = transcriptions[durations > 10]   # clips longer than 10 seconds

# print(f"Short audio files: {short_audio}") # All Dataset audio are longer than 1 second
# print(f"Long audio files: {long_audio}") 

# Overview
# - 17 audio files are longer than 10 seconds; these are treated as outliers
# - No audio file is shorter than 1 second

In [15]:
# Removing Outliers
# NOTE: this is destructive - the over-long clips are deleted from disk.
for line_index in long_audio['line_index']:
    wav_path = os.path.join(dataset_path, 'wavs', f'{line_index}.wav')
    try:
        os.remove(wav_path)
    except FileNotFoundError:
        # Already deleted by an earlier run of this cell; re-running must not fail.
        pass

# Remove outliers from the dataset.
# .copy() makes the filtered frame independent of the original, so the columns
# added to it later are written to a real frame rather than to a slice.
transcriptions = transcriptions[~transcriptions['line_index'].isin(long_audio['line_index'])].copy()

# Check dataset size
print(f'Number of samples after removing outliers: {len(transcriptions)}')

Number of samples after removing outliers: 2889


## 2. Data Processing

### 2.1 Text Normalization

In [16]:
def normalize_text(text):
    """Reduce a transcription to Khmer characters separated by single spaces.

    Steps:
      1. Lower-case the text.
      2. Turn every whitespace run (tabs, newlines, non-breaking spaces, ...)
         into one plain space.
      3. Drop every character that is neither in the Khmer range used by the
         pattern below nor a plain space.
      4. Collapse the spaces left behind by step 3 and strip both ends.

    Args:
        text: The raw transcription string.

    Returns:
        The normalized string; empty if nothing survives the filter.
    """
    # Whitespace has to be unified BEFORE filtering: the filter only keeps the
    # plain space character, so a tab or newline between two words would be
    # deleted outright and the two words would be glued together.
    text = re.sub(r"\s+", " ", text.lower())
    text = re.sub(r"[^ក-៹ ]", "", text)
    # Removing characters can leave doubled spaces behind; collapse them again.
    return re.sub(r"\s+", " ", text).strip()

# Add a cleaned copy of every transcription next to the raw text
transcriptions['normalized_transcription'] = transcriptions['transcription'].map(normalize_text)

# Save the normalized transcriptions
# transcriptions.to_csv("dataset/processed_transcriptions.csv", index=False)

### 2.2 Audio Processing

In [17]:
def preprocess_audio(audio_path, target_sample_rate=16000):
    """Load one audio file and return it resampled, normalized and trimmed.

    Args:
        audio_path: Path of the audio file to read.
        target_sample_rate: Sampling rate (Hz) of the returned signal.

    Returns:
        A tuple ``(audio, target_sample_rate)``.
    """
    waveform, native_rate = librosa.load(audio_path, sr=None)

    # Resample only when the file is not already at the requested rate
    needs_resampling = native_rate != target_sample_rate
    if needs_resampling:
        waveform = librosa.resample(waveform, orig_sr=native_rate, target_sr=target_sample_rate)

    # Amplitude normalization, then silence trimming (librosa default settings)
    normalized = librosa.util.normalize(waveform)
    trimmed, _ = librosa.effects.trim(normalized)

    return trimmed, target_sample_rate

# Preprocess every remaining clip and write the result back over the same wav file
wav_dir = os.path.join(dataset_path, 'wavs')
for line_index in transcriptions['line_index']:
    audio_path = os.path.join(wav_dir, f'{line_index}.wav')
    audio, sr = preprocess_audio(audio_path)
    sf.write(audio_path, audio, sr)

### 2.3 Feature Extraction

In [18]:
def extract_mel_spectogram(audio, sr, n_mels=80, n_fft=2048, hop_length=512):
    """Compute a mel spectrogram of ``audio`` on a decibel scale.

    Args:
        audio: The audio signal passed to librosa as ``y``.
        sr: Sampling rate of ``audio``.
        n_mels: Number of mel bands.
        n_fft: FFT window size.
        hop_length: Hop between successive frames, in samples.

    Returns:
        The mel power spectrogram converted to dB, with the spectrogram's own
        maximum used as the reference (``ref=np.max``).
    """
    stft_settings = dict(n_fft=n_fft, hop_length=hop_length)
    power_mel = librosa.feature.melspectrogram(y=audio, sr=sr, n_mels=n_mels, **stft_settings)
    return librosa.power_to_db(power_mel, ref=np.max)

# Example spectrogram: `audio` and `sr` are whatever the preprocessing loop above left bound (its last clip)
mel_spec = extract_mel_spectogram(audio=audio, sr=sr)

# Visualize Spectrogram
# plt.figure(figsize=(10, 4))
# librosa.display.specshow(mel_spec, sr=sr, hop_length=512, x_axis='time', y_axis='mel')
# plt.colorbar(format='%+2.0f dB')
# plt.title('Mel Spectrogram')
# plt.tight_layout()
# plt.show()

In [19]:
# Extract Mel Spectrogram for all audio files
output_dir = "../mel_spectrograms"
os.makedirs(output_dir, exist_ok=True)

# Row label -> path of the saved spectrogram. Paths are collected here and
# written to the frame once after the loop, so the frame is not modified while
# it is being iterated and a failed clip never keeps a path from an earlier run.
mel_paths = {}

# Process and save mel spectrograms for all audio files
for idx, line_index in tqdm(transcriptions['line_index'].items(), total=len(transcriptions)):
    audio_path = os.path.join(dataset_path, 'wavs', f'{line_index}.wav')
    mel_path = os.path.join(output_dir, f"{line_index}_mel.npy")

    try:
        # Load audio file
        audio, sr = librosa.load(audio_path, sr=None)

        # Extract Mel Spectrogram
        mel_spec = extract_mel_spectogram(audio, sr)

        # Save the mel spectrogram as .npy file
        np.save(mel_path, mel_spec)
    except Exception as e:
        # Best effort: report the failing clip and carry on with the rest
        print(f"Error processing {audio_path}: {e}")
        continue

    mel_paths[idx] = mel_path

# Update dataset with mel spectrogram paths; the assignment aligns on the row
# labels, so rows whose clip failed are left as missing values.
transcriptions['mel_path'] = pd.Series(mel_paths, dtype=object)

100%|██████████| 2889/2889 [00:13<00:00, 213.58it/s]


### 2.4 Tokenize Transcriptions

In [20]:
# Build the character vocabulary from the normalized transcriptions
from collections import Counter

# Count every character across all transcriptions, in order of first appearance
vocab = Counter(
    char
    for sentence in transcriptions['normalized_transcription']
    for char in sentence
)

# Number the characters 0..N-1 in that same first-seen order
char_to_index = dict(zip(vocab, range(len(vocab))))

# Tokenize transcriptions
def tokenize_text(text, char_to_index):
    """Map each character of ``text`` to its id in ``char_to_index``.

    Characters that are not keys of ``char_to_index`` are skipped.

    Args:
        text: The string to tokenize.
        char_to_index: Mapping from character to integer id.

    Returns:
        A list with one id per known character, in the order they occur.
    """
    token_ids = []
    for symbol in text:
        if symbol not in char_to_index:
            continue
        token_ids.append(char_to_index[symbol])
    return token_ids

# Turn every normalized transcription into its list of character ids
transcriptions['tokenized_transcription'] = transcriptions['normalized_transcription'].map(
    lambda sentence: tokenize_text(sentence, char_to_index)
)

# print(transcriptions.head())

# Make sure the assets directory is there before writing into it
assets_dir = "../assets"
os.makedirs(assets_dir, exist_ok=True)

# Persist the character -> id mapping
mapping_path = "../assets/char_to_index.pkl"
with open(mapping_path, "wb") as mapping_file:
    pickle.dump(char_to_index, mapping_file)

In [21]:
# Save the processed transcriptions (all columns built above) without the row index
processed_csv_path = "../dataset/processed_transcriptions.csv"
transcriptions.to_csv(processed_csv_path, index=False)

# # Save the spectograms
# with open("../assets/mel_spectrograms.pkl", "wb") as f:
#     pickle.dump(mel_spec, f)

## Check out the training notebook