In [None]:
# 🎙️ VITS Azerbaijani Text-to-Speech

This notebook provides a complete workflow for training and using a VITS model for Azerbaijani text-to-speech synthesis, including voice cloning capabilities.

## 📋 Features
- Single-speaker TTS training
- Zero-shot voice cloning
- Audio preprocessing & normalization
- Checkpoint management
- Interactive Gradio demo

## 🗺️ Notebook Structure
1. Environment Setup
2. Dataset Preparation
3. Audio Preprocessing
4. Model Training
5. Inference & Voice Cloning
6. Web Demo

> 💡 This notebook works both in Google Colab and locally. Colab-specific cells are marked with a [COLAB] tag.


In [None]:
## 1. Environment Setup

First, let's set up our environment with all necessary dependencies.


In [None]:
# [COLAB] System packages
# Refreshes the apt package index and then installs the espeak and ffmpeg system
# packages; because of `&&`, the install only runs if the index refresh succeeds.
# NOTE(review): presumably espeak is the backend for the `phonemizer` package and
# ffmpeg the backend for `pydub`, both installed in the next cell — confirm.
!apt-get update -y && apt-get install -y espeak ffmpeg


In [None]:
# Install Python packages
!pip install -q torch torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q numpy scipy librosa unidecode tensorboard phonemizer webdataset gradio tqdm pydub


In [None]:
# [COLAB] Keep-alive function to prevent disconnects
from IPython.display import display, Javascript

def keep_alive():
    """Inject a browser-side timer that clicks the Colab connect button.

    The JavaScript below is handed to the notebook front end; it registers a
    callback that runs every 60000 ms, logs a message, and clicks the element
    matched by the ``colab-connect-button`` selector.
    """
    script = '''
        function ClickConnect(){
            console.log("Clicking connect button...");
            document.querySelector("colab-connect-button").click()
        }
        setInterval(ClickConnect, 60000)
    '''
    display(Javascript(script))

# Uncomment next line if running in Colab:
# keep_alive()


In [None]:
## 2. Dataset Preparation

The VITS model requires:
1. WAV audio files (22050 Hz, mono)
2. Text transcriptions in Azerbaijani
3. Filelists mapping audio to text


In [None]:
# Dataset upload helper
import os
import zipfile

try:
    # google.colab only exists inside a Colab runtime; guarding the import lets
    # this cell be executed locally as well.
    from google.colab import files
except ImportError:
    files = None


def upload_and_extract_dataset():
    """Upload dataset ZIP files through Colab and extract them to datasets/raw.

    Every uploaded file whose name ends in ``.zip`` (case-insensitive) is
    extracted into ``datasets/raw`` with the standard-library ``zipfile``
    module, overwriting existing files; other uploads are skipped with a
    message. An unreadable archive is reported and skipped.

    Assumes ``files.upload()`` has stored each upload in the current working
    directory under its own name, as the previous ``unzip`` call also did.

    Raises:
        RuntimeError: If ``google.colab`` could not be imported (i.e. the
            notebook is not running in Colab).
    """
    if files is None:
        raise RuntimeError(
            "upload_and_extract_dataset() requires Google Colab; "
            "when running locally, copy your dataset into datasets/raw/ instead."
        )

    print("Please upload your dataset ZIP file...")
    uploaded = files.upload()

    target_dir = os.path.join('datasets', 'raw')
    for filename in uploaded:
        if not filename.lower().endswith('.zip'):
            print(f"Skipping {filename} - not a ZIP file")
            continue

        print(f"Extracting {filename} to datasets/raw/...")
        os.makedirs(target_dir, exist_ok=True)
        try:
            # zipfile avoids passing a user-chosen filename through the shell.
            with zipfile.ZipFile(filename) as archive:
                archive.extractall(target_dir)
        except zipfile.BadZipFile as e:
            print(f"Could not extract {filename}: {e}")
            continue

        print("Dataset extracted! Contents:")
        for entry in sorted(os.listdir(target_dir)):
            print(f"  {entry}")

# Uncomment to upload dataset:
# upload_and_extract_dataset()


In [None]:
## 3. Audio Preprocessing

Before training, we'll normalize the audio files to ensure consistent quality:
- Remove DC offset
- Normalize levels
- Resample to 22050 Hz
- Convert to mono


In [None]:
import os
import glob
import librosa
import soundfile as sf
import numpy as np
from tqdm.notebook import tqdm
import multiprocessing

def process_audio_file(file_path, target_sr=22050, target_level=-23.0, output_dir=None):
    """Resample, level-normalize and save one audio file.

    The file is loaded as mono at its native rate, resampled to ``target_sr``
    when the rates differ, has its mean subtracted, is scaled so that its RMS
    equals ``10**(target_level/20)``, and is scaled down to a 0.99 peak if the
    gain pushed the absolute maximum above 0.99.

    Args:
        file_path: Audio file to read.
        target_sr: Sample rate of the written file.
        target_level: Desired RMS level in dB.
        output_dir: Directory to write into (created if missing); only the
            input's base name is kept. When falsy, ``file_path`` is overwritten.

    Returns:
        True when the file was written, False if any step raised (the error is
        printed rather than propagated).
    """
    try:
        # Decide where the result goes before touching the audio.
        if not output_dir:
            destination = file_path
        else:
            os.makedirs(output_dir, exist_ok=True)
            destination = os.path.join(output_dir, os.path.basename(file_path))

        # Native-rate mono load, followed by an optional resample.
        samples, native_sr = librosa.load(file_path, sr=None, mono=True)
        if native_sr != target_sr:
            samples = librosa.resample(samples, orig_sr=native_sr, target_sr=target_sr)

        # Centre the waveform around zero.
        samples = samples - np.mean(samples)

        # Scale to the requested RMS; the epsilon avoids dividing by zero on silence.
        current_rms = np.sqrt(np.mean(samples**2))
        desired_rms = 10**(target_level/20)
        scaled = samples * (desired_rms / (current_rms + 1e-8))

        # Pull the peak back to 0.99 if the gain pushed it past that.
        peak = np.max(np.abs(scaled))
        if peak > 0.99:
            scaled = scaled / peak * 0.99

        sf.write(destination, scaled, target_sr)
    except Exception as e:
        print(f"Error processing {file_path}: {e}")
        return False
    return True

def normalize_dataset(dataset_dir, output_dir=None, target_sr=22050, target_level=-23.0):
    """Normalize every WAV file under a directory using a process pool.

    Files are discovered recursively (``**/*.wav``) and each one is passed to
    ``process_audio_file`` in a worker process. Results are consumed lazily
    with ``Pool.imap`` so the progress bar advances as files finish.

    Args:
        dataset_dir: Root directory searched recursively for ``*.wav`` files.
        output_dir: Forwarded to ``process_audio_file``; when None the files
            are overwritten in place.
        target_sr: Sample rate forwarded to ``process_audio_file``.
        target_level: RMS level in dB forwarded to ``process_audio_file``.

    Returns:
        None. A summary of how many files succeeded is printed.
    """
    from functools import partial  # local import keeps this cell self-contained

    # Sorted so the processing order is deterministic.
    wav_files = sorted(glob.glob(os.path.join(dataset_dir, "**", "*.wav"), recursive=True))
    print(f"Found {len(wav_files)} WAV files")

    if not wav_files:
        print("No WAV files found!")
        return

    # Bind the shared settings once; imap then feeds one path per task.
    worker = partial(
        process_audio_file,
        target_sr=target_sr,
        target_level=target_level,
        output_dir=output_dir,
    )

    # imap yields results as they complete (in input order); starmap would only
    # return after every file had been processed, leaving the bar idle until then.
    with multiprocessing.Pool(processes=os.cpu_count()) as pool:
        results = list(tqdm(pool.imap(worker, wav_files), total=len(wav_files)))

    success_count = results.count(True)
    print(f"Successfully processed {success_count} of {len(wav_files)} files")

# Uncomment to normalize dataset:
# normalize_dataset('datasets/raw', output_dir='datasets/normalized')


In [None]:
### Generate Filelists

Create train/validation splits with the format:
```
path/to/audio.wav|Azerbaijani text
```


In [None]:
# Generate train/val splits
# Invokes the project's data/tools/prepare_filelist.py with datasets/raw as --wavs,
# data/filelists as --output and a --val-ratio of 0.05.
# NOTE(review): presumably 0.05 means 5% of the utterances are held out for
# validation — confirm against the script.
# NOTE(review): this reads datasets/raw; if the normalization step was run with
# output_dir='datasets/normalized', --wavs probably needs to point there — confirm.
!python data/tools/prepare_filelist.py \
    --wavs datasets/raw \
    --output data/filelists \
    --val-ratio 0.05


In [None]:
## 4. Model Training

We'll train the VITS model with periodic checkpointing and monitoring.


In [None]:
# Training monitor
import time
import os
import glob

def monitor_training(interval=60, checkpoint_dir='checkpoints', max_checks=None):
    """Periodically report the checkpoints present in a directory.

    On every pass the ``*.pt`` files in ``checkpoint_dir`` are listed and the
    three most recently modified ones are printed with size and timestamp.

    Args:
        interval: Seconds to sleep between passes.
        checkpoint_dir: Directory that is scanned for ``*.pt`` files.
        max_checks: Number of passes to perform before returning. ``None``
            (the default) keeps polling until interrupted.

    Returns:
        None. Stops quietly on ``KeyboardInterrupt``.
    """
    checks_done = 0

    try:
        while max_checks is None or checks_done < max_checks:
            snapshots = []
            for path in glob.glob(os.path.join(checkpoint_dir, "*.pt")):
                try:
                    info = os.stat(path)
                except OSError:
                    # The file disappeared between glob() and stat(), e.g. an
                    # old checkpoint removed while training is running.
                    continue
                snapshots.append((info.st_mtime, info.st_size, path))

            print(f"\n=== Training Status: {time.strftime('%Y-%m-%d %H:%M:%S')} ===")
            print(f"Found {len(snapshots)} checkpoints")

            if snapshots:
                # Newest first, by modification time.
                snapshots.sort(key=lambda snap: snap[0], reverse=True)
                print("\nMost recent checkpoints:")
                for rank, (mtime, size, path) in enumerate(snapshots[:3], start=1):
                    mod_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(mtime))
                    size_mb = size / (1024 * 1024)
                    print(f"{rank}. {os.path.basename(path)} - {size_mb:.2f} MB - {mod_time}")

            checks_done += 1
            if max_checks is not None and checks_done >= max_checks:
                # Last requested pass: return without a trailing sleep.
                break

            print(f"\nNext check in {interval} seconds...")
            time.sleep(interval)

    except KeyboardInterrupt:
        print("\nMonitoring stopped")

# Start training
# Launches the project's train.py with the flags listed below.
# NOTE(review): this passes config/base_vits.json, while the inference cell loads
# 'configs/base_vits.json' — one of the two directory names is presumably a typo; confirm.
# NOTE(review): `!` runs the command in the foreground, so a monitor_training() call
# placed after it in this cell only starts once train.py has exited — confirm intent.
!python train.py \
    --config config/base_vits.json \
    --batch_size 16 \
    --epochs 1000 \
    --checkpoint_dir checkpoints \
    --log_dir logs \
    --save_every 10 \
    --keep_last 3

# Uncomment to monitor training:
# monitor_training(interval=60)


In [None]:
## 5. Inference & Voice Cloning


In [None]:
import torch
import IPython.display as ipd
from model.vits import VITSInference

# Pick a checkpoint (best.pt when present, otherwise the newest *.pt), build the
# inference wrapper and play one synthesized sentence.
# NOTE(review): the config path here is 'configs/base_vits.json' while the training
# cell passes config/base_vits.json — confirm which directory name is correct.
checkpoint_dir = 'checkpoints'
checkpoint_files = glob.glob(f"{checkpoint_dir}/*.pt")

if checkpoint_files:
    # best.pt wins; otherwise fall back to the most recently modified checkpoint.
    best_model = os.path.join(checkpoint_dir, 'best.pt')
    checkpoint_path = (
        best_model
        if os.path.exists(best_model)
        else max(checkpoint_files, key=os.path.getmtime)
    )

    print(f"Using checkpoint: {os.path.basename(checkpoint_path)}")

    # `tts` is left in the notebook namespace for the cells that follow.
    tts = VITSInference(checkpoint=checkpoint_path, config='configs/base_vits.json')

    # Smoke test: synthesize a short sentence and render an audio player.
    text = "Salam dünya! Bu VITS nümunəsidir."
    audio = tts.synthesize(text)
    ipd.display(ipd.Audio(audio, rate=22050))
else:
    print(f"No checkpoints found in {checkpoint_dir}!")


In [None]:
# Voice cloning demo
def clone_voice(reference_wav=None):
    """Demonstrate voice cloning with a reference audio file.

    Plays the reference recording, then synthesizes a fixed sentence with the
    notebook-level ``tts`` model using that recording as ``speaker_ref``.

    Args:
        reference_wav: Path to the reference recording. When None, the user is
            asked to upload one if the Colab runtime is loaded; otherwise the
            example file ``datasets/raw/02.wav`` is used.

    Returns:
        None. If no usable reference file is available a message is printed
        and nothing is synthesized.
    """
    import sys  # local import: needed only for the runtime check below

    if reference_wav is None:
        # globals() holds identifier names, so 'google.colab' can never be one of
        # its keys; sys.modules is what records whether the Colab package is loaded.
        if 'google.colab' in sys.modules:
            from google.colab import files  # imported here so no other cell is required

            print("Upload a reference voice file (.wav):")
            uploaded = files.upload()
            if uploaded:
                reference_wav = next(iter(uploaded))
        else:
            # Use example file
            reference_wav = 'datasets/raw/02.wav'

    if not reference_wav or not os.path.exists(reference_wav):
        print(f"Reference audio not found: {reference_wav!r}; skipping voice cloning.")
        return

    # Original reference audio
    y, sr = librosa.load(reference_wav, sr=22050)
    print("Reference voice:")
    ipd.display(ipd.Audio(y, rate=sr))

    # Cloned speech
    text = "Mənim səsimlə danışan süni zəka!"
    print(f"\nCloned voice saying: {text}")
    cloned = tts.synthesize(text, speaker_ref=reference_wav)
    ipd.display(ipd.Audio(cloned, rate=22050))

# Run the demo only when a `tts` model exists in the notebook namespace
if 'tts' in globals():
    clone_voice()


In [None]:
## 6. Web Demo

Launch an interactive Gradio demo for testing the model.


In [None]:
# Launch Gradio demo
if 'google.colab' in globals():
    !python app.py --share  # Public URL
else:
    !python app.py  # Local URL
