# Домашнее задание №2: Синтез речи


In [None]:
%pip install git+https://github.com/OHF-Voice/piper1-gpl.git
%pip install tensorboard
%pip install onnx
%pip install onnxruntime


# Импорты


In [None]:
import os
import json
import glob
import random
import re
from pathlib import Path
from typing import List, Dict, Tuple

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

import librosa
import soundfile as sf
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

import warnings
# Suppress every warning for the rest of the session to keep cell output short.
# NOTE(review): this also hides deprecation and numerical warnings from the
# audio/ML libraries used below.
warnings.filterwarnings('ignore')

# Seed the Python, NumPy and PyTorch random generators with the same value so
# that the shuffle-based train/val/test split is repeatable.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)


# Подготовка данных


In [None]:
# Source corpus and output locations.
DATA_PATH = "/share/audio_data/sova/ytub/raid/nanosemantics/nextcloud/sova_done"
OUTPUT_DIR = "./tts_data"
NATAHA_DIR = "/tf/nspeganov/nataha"
TRAIN_DIR = os.path.join(NATAHA_DIR, "train")
VAL_DIR = os.path.join(NATAHA_DIR, "valid")
TEST_DIR = os.path.join(OUTPUT_DIR, "test")

# Create the split directories up front (parents are created as needed).
for _split_dir in (TRAIN_DIR, VAL_DIR, TEST_DIR):
    os.makedirs(_split_dir, exist_ok=True)

# Target sample rate (Hz) that prepared clips are resampled to and saved with.
SAMPLE_RATE = 22050
# Accepted clip length in seconds; clips outside this range are skipped.
MIN_DURATION = 0.5
MAX_DURATION = 10.0


In [None]:
def normalize_text(text):
    """Lower-case *text* and collapse each whitespace run into one space.

    Leading and trailing whitespace is removed from the result.
    """
    lowered = text.lower().strip()
    collapsed = re.sub(r'\s+', ' ', lowered)
    return collapsed.strip()

def trim_silence(audio, sr, threshold_db=-40):
    """Strip leading and trailing silence from *audio*.

    Args:
        audio: Signal handed to ``librosa.effects.trim``.
        sr: Sample rate. Accepted for interface compatibility; not used.
        threshold_db: Silence threshold; its absolute value is passed to
            librosa as ``top_db``.

    Returns:
        The trimmed signal, or *audio* unchanged if trimming raised an error.
    """
    try:
        trimmed, _ = librosa.effects.trim(audio, top_db=abs(threshold_db))
    except Exception:
        # Best-effort: keep the untrimmed clip when librosa rejects the input.
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate so a long preparation run can be interrupted.
        return audio
    return trimmed

def load_audio_info(audio_path):
    """Load an audio file at its native sample rate.

    Returns:
        ``(duration_seconds, sample_rate, samples)`` on success, or
        ``(None, None, None)`` after printing the error if loading failed.
    """
    try:
        samples, native_sr = librosa.load(audio_path, sr=None)
        return len(samples) / native_sr, native_sr, samples
    except Exception as e:
        print(f"Ошибка при загрузке {audio_path}: {e}")
        return None, None, None

def _find_audio_in_part(part, part_path, audio_extensions):
    """Collect audio files located anywhere below *part_path*.

    Returns:
        ``(paths, ext_counts)``: the matching file paths and a dict mapping
        every extension in *audio_extensions* to the number of files found.
    """
    # Collect all sub-directories first so the progress bar has a known total.
    print("  Сканирование структуры директорий...")
    all_dirs = [root for root, _dirs, _files in os.walk(part_path)]
    print(f"  Найдено {len(all_dirs)} поддиректорий")

    part_files = []
    ext_counts = {ext: 0 for ext in audio_extensions}

    pbar_search = tqdm(all_dirs, desc=f"Поиск в {part}", unit="дир", ncols=100, leave=False)
    for root in pbar_search:
        try:
            entries = os.listdir(root)
        except OSError:
            # Unreadable directory (PermissionError is an OSError): skip it.
            continue

        for file in entries:
            file_path = os.path.join(root, file)
            if not os.path.isfile(file_path):
                continue
            file_ext = os.path.splitext(file)[1].lower()
            if file_ext not in audio_extensions:
                continue
            part_files.append(file_path)
            ext_counts[file_ext] += 1
            pbar_search.set_postfix({
                'найдено': len(part_files),
                'wav': ext_counts['.wav'],
                'mp3': ext_counts['.mp3']
            })

    return part_files, ext_counts


def prepare_tts_dataset(base_path, parts=('part_0', 'part_1', 'part_2'), max_files=None, trim_silence_audio=True):
    """Scan a corpus for audio/transcript pairs and load them into memory.

    For every audio file found under ``base_path/<part>`` the function loads
    the audio, filters it by duration, reads the transcript from the file with
    the same name and a ``.txt`` extension, normalizes the text, resamples the
    audio to ``SAMPLE_RATE`` and optionally trims silence.

    Args:
        base_path: Root directory that contains the part sub-directories.
        parts: Names of the sub-directories to scan. Missing ones are
            reported and skipped.
        max_files: If truthy, only the first ``max_files`` discovered audio
            files are processed.
        trim_silence_audio: When true, silence is trimmed and clips that end
            up shorter than ``MIN_DURATION`` seconds are dropped.

    Returns:
        A list of dicts with the keys ``"audio"`` (resampled samples),
        ``"text"``, ``"duration"`` (seconds) and ``"path"`` (source file).
        All accepted audio is kept in memory.
    """
    import time

    dataset = []
    audio_extensions = ['.wav', '.mp3', '.flac', '.ogg']
    audio_files = []

    print("Поиск аудиофайлов...")
    search_start = time.time()

    for part in parts:
        part_path = os.path.join(base_path, part)
        if not os.path.exists(part_path):
            print(f"Предупреждение: часть {part} не найдена")
            continue

        print(f"\nОбработка {part}...")
        part_start = time.time()

        part_files, ext_counts = _find_audio_in_part(part, part_path, audio_extensions)

        audio_files.extend(part_files)
        part_time = time.time() - part_start
        print(f"  {part}: найдено {len(part_files)} файлов за {part_time:.1f}с ({part_time/60:.1f} мин)")
        for ext, count in ext_counts.items():
            if count > 0:
                print(f"    {ext}: {count} файлов")

    search_time = time.time() - search_start

    if max_files and len(audio_files) > max_files:
        print(f"\nОграничение: было {len(audio_files)} файлов, оставляем {max_files}")
        audio_files = audio_files[:max_files]

    print(f"\n{'='*60}")
    print(f"Поиск завершен: найдено {len(audio_files)} аудиофайлов")
    print(f"Время поиска: {search_time:.1f}с ({search_time/60:.1f} мин)")
    print(f"{'='*60}\n")

    print("Начинаю обработку файлов...")

    start_time = time.time()
    processed = 0
    skipped_no_text = 0
    skipped_duration = 0
    skipped_trim = 0

    pbar = tqdm(audio_files, desc="Подготовка данных", unit="файл", ncols=100)
    for audio_path in pbar:
        duration, sr, y = load_audio_info(audio_path)
        if duration is None:
            # Loading failed; load_audio_info already reported the error.
            continue

        if duration < MIN_DURATION or duration > MAX_DURATION:
            skipped_duration += 1
            continue

        # The transcript sits next to the audio file with a .txt extension.
        text_path = audio_path.rsplit('.', 1)[0] + '.txt'
        if not os.path.exists(text_path):
            skipped_no_text += 1
            continue

        with open(text_path, 'r', encoding='utf-8') as f:
            text = f.read().strip()

        text = normalize_text(text)

        if len(text) > 0:
            resampled_audio = librosa.resample(y, orig_sr=sr, target_sr=SAMPLE_RATE) if sr != SAMPLE_RATE else y

            if trim_silence_audio:
                resampled_audio = trim_silence(resampled_audio, SAMPLE_RATE)
                # Trimming can shrink a clip below the minimum length.
                if len(resampled_audio) < SAMPLE_RATE * MIN_DURATION:
                    skipped_trim += 1
                    continue

            dataset.append({
                "audio": resampled_audio,
                "text": text,
                "duration": len(resampled_audio) / SAMPLE_RATE,
                "path": audio_path
            })
            processed += 1

        elapsed = time.time() - start_time
        pbar.set_postfix({
            'Обработано': processed,
            'Пропущено': skipped_no_text + skipped_duration + skipped_trim,
            'Время': f"{elapsed:.1f}с"
        })

    elapsed_total = time.time() - start_time
    # Guard against a zero elapsed time (empty input or a coarse clock).
    files_per_second = len(audio_files) / elapsed_total if elapsed_total > 0 else 0.0

    print(f"\nПодготовка завершена!")
    print(f"  Обработано записей: {len(dataset)}")
    print(f"  Пропущено (нет текста): {skipped_no_text}")
    print(f"  Пропущено (длительность): {skipped_duration}")
    print(f"  Пропущено (после обрезки): {skipped_trim}")
    print(f"  Время обработки: {elapsed_total:.1f} секунд ({elapsed_total/60:.1f} минут)")
    print(f"  Скорость: {files_per_second:.1f} файлов/сек")

    return dataset


In [None]:
print("=" * 60)
print("НАЧАЛО ПОДГОТОВКИ ДАННЫХ")
print("=" * 60)

dataset = prepare_tts_dataset(DATA_PATH, parts=['part_0', 'part_1', 'part_2'])

# Everything below divides by len(dataset) and takes min/max of the durations,
# so stop here with a clear message instead of an opaque ZeroDivisionError.
if not dataset:
    raise RuntimeError(
        "Датасет пуст: не найдено ни одной подходящей записи. "
        "Проверьте DATA_PATH и наличие .txt файлов рядом с аудио."
    )

print("\n" + "=" * 60)
print("РАЗДЕЛЕНИЕ ДАННЫХ")
print("=" * 60)

print("Перемешивание данных...")
random.shuffle(dataset)

# 80% train, 10% validation; the remainder (about 10%) becomes the test split.
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))

train_data = dataset[:train_size]
val_data = dataset[train_size:train_size + val_size]
test_data = dataset[train_size + val_size:]

print(f"\nРазделение данных:")
print(f"  Train: {len(train_data)} записей ({100*len(train_data)/len(dataset):.1f}%)")
print(f"  Val: {len(val_data)} записей ({100*len(val_data)/len(dataset):.1f}%)")
print(f"  Test: {len(test_data)} записей ({100*len(test_data)/len(dataset):.1f}%)")

print("\nВычисление статистики по длительности...")
durations = [item['duration'] for item in tqdm(dataset, desc="Обработка длительностей", unit="запись", leave=False)]

print(f"\nСтатистика по длительности:")
print(f"  Минимум: {min(durations):.2f} сек")
print(f"  Максимум: {max(durations):.2f} сек")
print(f"  Среднее: {np.mean(durations):.2f} сек")
print(f"  Медиана: {np.median(durations):.2f} сек")
print(f"  Общая длительность: {sum(durations)/3600:.2f} часов")

print("\n" + "=" * 60)
print("ПОДГОТОВКА ДАННЫХ ЗАВЕРШЕНА")
print("=" * 60)


# Сохранение подготовленных данных


In [None]:
def save_dataset_csv(data, output_dir, csv_path):
    """Write each record's audio to *output_dir* and list them in *csv_path*.

    Audio is saved as ``000000.wav``, ``000001.wav``, ... at ``SAMPLE_RATE``.
    The manifest has one ``<wav name>|<text>`` line per record; ``|`` and
    newlines inside the text are replaced by spaces so they cannot break the
    line format.

    Returns:
        *csv_path*, unchanged.
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(csv_path, 'w', encoding='utf-8') as manifest:
        progress = tqdm(data, desc="Сохранение", unit="файл")
        for index, record in enumerate(progress):
            wav_name = f"{index:06d}.wav"
            sf.write(os.path.join(output_dir, wav_name), record['audio'], SAMPLE_RATE)

            transcript = record['text'].replace('|', ' ').replace('\n', ' ').strip()
            manifest.write(f"{wav_name}|{transcript}\n")

    print(f"CSV сохранен: {csv_path}")
    return csv_path

# Persist the train and validation splits; the returned manifest paths are
# reused when the training command is built.
train_csv = save_dataset_csv(
    data=train_data,
    output_dir=TRAIN_DIR,
    csv_path=os.path.join(NATAHA_DIR, "train.csv"),
)
val_csv = save_dataset_csv(
    data=val_data,
    output_dir=VAL_DIR,
    csv_path=os.path.join(NATAHA_DIR, "valid.csv"),
)

print(f"\nCSV файлы сохранены:")
print(f"  Train: {train_csv}")
print(f"  Val: {val_csv}")


# Создание конфигурационного файла


In [None]:
config_path = os.path.join(OUTPUT_DIR, "nata_config.json")

# Voice configuration written next to the other outputs.
# NOTE(review): the Piper training docs describe --data.config_path as the
# path the trainer writes its config to — confirm whether this hand-written
# file is read or overwritten by training.
config = dict(
    audio=dict(sample_rate=SAMPLE_RATE),
    espeak=dict(voice="ru"),
    phoneme_type="espeak",
    num_symbols=256,
    num_speakers=1,
    inference=dict(
        noise_scale=0.667,
        length_scale=1.0,
        noise_w=0.8,
    ),
    hop_length=256,
    piper_version="1.3.0",
)

with open(config_path, 'w', encoding='utf-8') as config_file:
    config_file.write(json.dumps(config, indent=2, ensure_ascii=False))

print(f"Конфигурация сохранена: {config_path}")


In [None]:
# Загрузка базового чекпоинта


# Загрузка базового чекпоинта


In [None]:
import urllib.request

CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

base_checkpoint_url = "https://huggingface.co/datasets/rhasspy/piper-checkpoints/resolve/main/ru/ru_RU/ruslan/medium/epoch=2436-step=1724372.ckpt"
base_checkpoint_path = os.path.join(CHECKPOINT_DIR, "epoch=2436-step=1724372.ckpt")

if not os.path.exists(base_checkpoint_path):
    print("Загрузка базового чекпоинта...")
    # Download to a temporary name and rename only after the transfer has
    # finished. An interrupted download then never leaves a truncated file at
    # base_checkpoint_path that a later run would mistake for a valid one.
    partial_path = base_checkpoint_path + ".part"
    try:
        urllib.request.urlretrieve(base_checkpoint_url, partial_path)
        os.replace(partial_path, base_checkpoint_path)
    finally:
        # Only present here if the download or rename failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)
    print(f"Чекпоинт загружен: {base_checkpoint_path}")
else:
    print(f"Чекпоинт уже существует: {base_checkpoint_path}")


In [None]:
import subprocess
import sys

cache_dir = os.path.join(OUTPUT_DIR, "nata_cache")
os.makedirs(cache_dir, exist_ok=True)

log_dir = os.path.join(OUTPUT_DIR, "lightning_logs")
os.makedirs(log_dir, exist_ok=True)

print("Запуск обучения...")
print("Команда обучения будет выполнена в следующей ячейке")

# The command is later run with cwd=OUTPUT_DIR, so every path argument is made
# absolute here; a path relative to the notebook directory (OUTPUT_DIR is
# "./tts_data") would otherwise be resolved inside OUTPUT_DIR a second time.
# sys.executable is used instead of "python3" so the child process runs in the
# same interpreter/environment as this notebook.
train_cmd = [
    sys.executable, "-m", "piper.train", "fit",
    "--data.voice_name", "nata",
    "--data.csv_path", os.path.abspath(train_csv),
    "--data.audio_dir", os.path.abspath(TRAIN_DIR),
    "--model.sample_rate", str(SAMPLE_RATE),
    "--data.espeak_voice", "ru",
    "--data.cache_dir", os.path.abspath(cache_dir),
    "--data.config_path", os.path.abspath(config_path),
    "--data.batch_size", "16",
    "--ckpt_path", os.path.abspath(base_checkpoint_path),
]

print(" ".join(train_cmd))


In [None]:
import subprocess
import sys

cache_dir = os.path.join(OUTPUT_DIR, "nata_cache")
os.makedirs(cache_dir, exist_ok=True)

log_dir = os.path.join(OUTPUT_DIR, "lightning_logs")
os.makedirs(log_dir, exist_ok=True)

print("Настройка команды обучения...")

# The command is run with cwd=OUTPUT_DIR, so every path argument is made
# absolute here; a path relative to the notebook directory (OUTPUT_DIR is
# "./tts_data") would otherwise be resolved inside OUTPUT_DIR a second time.
# sys.executable is used instead of "python3" so the child process runs in the
# same interpreter/environment as this notebook.
train_cmd = [
    sys.executable, "-m", "piper.train", "fit",
    "--data.voice_name", "nata",
    "--data.csv_path", os.path.abspath(train_csv),
    "--data.audio_dir", os.path.abspath(TRAIN_DIR),
    "--model.sample_rate", str(SAMPLE_RATE),
    "--data.espeak_voice", "ru",
    "--data.cache_dir", os.path.abspath(cache_dir),
    "--data.config_path", os.path.abspath(config_path),
    "--data.batch_size", "16",
    "--ckpt_path", os.path.abspath(base_checkpoint_path),
]

print("Команда обучения:")
print(" ".join(train_cmd))


In [None]:
# Run training as a child process with OUTPUT_DIR as its working directory
# (relative paths inside train_cmd are resolved against that directory).
# The exit code is only reported; a non-zero code does not raise.
result = subprocess.run(args=train_cmd, cwd=OUTPUT_DIR)
print(f"Обучение завершено с кодом: {result.returncode}")


# Визуализация метрик из TensorBoard


In [None]:
def _version_sort_key(name):
    """Order run directories by trailing integer, so ``version_10`` follows ``version_9``."""
    match = re.search(r'(\d+)$', name)
    # Names without a numeric suffix sort before all numbered runs.
    return (int(match.group(1)) if match else -1, name)


def _plot_scalar(event_acc, scalar_tags, ax, candidates, label):
    """Plot the first tag of *candidates* that exists in *scalar_tags* onto *ax*."""
    for tag in candidates:
        if tag in scalar_tags:
            points = event_acc.Scalars(tag)
            ax.plot([s.step for s in points], [s.value for s in points], label=label)
            return


def plot_training_metrics(log_dir):
    """Plot loss curves from the newest run directory inside *log_dir*.

    The newest run is the sub-directory with the highest trailing number.
    Total loss and mel loss (train and validation, when those tags were
    logged) are drawn on the top row of a 2x2 figure, which is saved as
    ``training_metrics.png`` in ``OUTPUT_DIR`` and shown.

    Any error (missing tensorboard, unreadable logs, ...) is printed rather
    than raised. The function always returns ``None``.
    """
    try:
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

        version_dirs = [d for d in os.listdir(log_dir) if os.path.isdir(os.path.join(log_dir, d))]
        if not version_dirs:
            print("Логи TensorBoard не найдены")
            return

        # A plain max() compares strings, which ranks "version_9" above
        # "version_10"; compare the numeric suffix instead.
        latest_version = max(version_dirs, key=_version_sort_key)
        event_dir = os.path.join(log_dir, latest_version)

        event_acc = EventAccumulator(event_dir)
        event_acc.Reload()

        scalar_tags = event_acc.Tags()['scalars']

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Each metric may be logged with either a "/" or a "_" separator.
        loss_ax = axes[0, 0]
        _plot_scalar(event_acc, scalar_tags, loss_ax, ('train/loss', 'train_loss'), 'Train Loss')
        _plot_scalar(event_acc, scalar_tags, loss_ax, ('val/loss', 'val_loss'), 'Val Loss')
        loss_ax.set_xlabel('Step')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Loss')
        loss_ax.legend()
        loss_ax.grid(True)

        mel_ax = axes[0, 1]
        _plot_scalar(event_acc, scalar_tags, mel_ax, ('train/mel_loss', 'train_mel_loss'), 'Train Mel Loss')
        _plot_scalar(event_acc, scalar_tags, mel_ax, ('val/mel_loss', 'val_mel_loss'), 'Val Mel Loss')
        mel_ax.set_xlabel('Step')
        mel_ax.set_ylabel('Mel Loss')
        mel_ax.set_title('Mel Spectrogram Loss')
        mel_ax.legend()
        mel_ax.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, 'training_metrics.png'), dpi=300)
        plt.show()

    except Exception as e:
        print(f"Ошибка при загрузке метрик: {e}")

# Plot the scalars found under log_dir (set in the training-command cell).
plot_training_metrics(log_dir)


# Загрузка обученной модели


In [None]:
from piper import PiperVoice

# Prefer the most recently modified checkpoint written during training; fall
# back to the downloaded base checkpoint when none exists.
checkpoints = glob.glob(os.path.join(log_dir, "**", "*.ckpt"), recursive=True)
if checkpoints:
    best_checkpoint = max(checkpoints, key=os.path.getmtime)
    print(f"Используем последний чекпоинт: {best_checkpoint}")
    model_path = best_checkpoint
else:
    print("Чекпоинты не найдены, используем базовую модель")
    model_path = base_checkpoint_path if os.path.exists(base_checkpoint_path) else None

# Later cells guard on `if voice:`, so `voice` must always be bound — also
# when loading fails. Without this a load error would leave the name undefined
# and turn every later guard into a NameError.
# NOTE(review): a Lightning .ckpt is passed to PiperVoice.load here; the Piper
# docs describe exporting the checkpoint to ONNX before loading a voice —
# confirm that this format is accepted.
voice = None
if model_path is not None:
    try:
        voice = PiperVoice.load(model_path, use_cuda=True)
        if checkpoints:
            print("Модель загружена")
    except Exception as e:
        print(f"Ошибка при загрузке модели {model_path}: {e}")


In [None]:
import wave

EXAMPLES_DIR = os.path.join(OUTPUT_DIR, "examples")
os.makedirs(EXAMPLES_DIR, exist_ok=True)

# Three fixed phrases plus one transcript from the test split (or a stock
# sentence when the test split is empty).
test_texts = [
    "Привет, как дела?",
    "Распознавание и синтез речи это интересная область.",
    "Сегодня хорошая погода.",
    test_data[0]['text'] if test_data else "Тестовый текст для синтеза речи."
]

print("Генерация примеров...")

for i, text in enumerate(test_texts):
    if not voice:
        print(f"Модель не загружена, пропускаем: {text[:50]}...")
        continue

    # Synthesize straight into example_NN.wav.
    output_path = os.path.join(EXAMPLES_DIR, f"example_{i:02d}.wav")
    with wave.open(output_path, "wb") as wav_file:
        voice.synthesize_wav(text, wav_file)
    print(f"Сохранено: {output_path} - {text[:50]}...")


# Экспорт модели в ONNX (опционально)


In [None]:
onnx_path = os.path.join(OUTPUT_DIR, "model.onnx")

if not voice:
    print("Модель не загружена для экспорта")
else:
    # Best-effort export: the import and the export itself are both allowed
    # to fail, in which case the error is only reported.
    # NOTE(review): confirm that piper.export.export_onnx exists in the
    # installed Piper version.
    try:
        from piper.export import export_onnx
        export_onnx(voice, onnx_path)
        print(f"Модель экспортирована в ONNX: {onnx_path}")
    except Exception as e:
        print(f"Ошибка при экспорте в ONNX: {e}")
        print("Экспорт может быть недоступен в этой версии Piper")


# Вычисление метрик качества


In [None]:
def calculate_wer(true_text, predicted_text):
    """Return the word error rate of *predicted_text* against *true_text*.

    ``jiwer.wer`` is used when it is importable and accepts the input.
    Otherwise a built-in fallback lower-cases both texts, splits them on
    whitespace and divides the word-level edit distance (substitutions,
    insertions, deletions) by the number of reference words.

    An empty reference yields 0.0 when the prediction is empty too, and 1.0
    otherwise.
    """
    try:
        from jiwer import wer
        return wer(true_text, predicted_text)
    except Exception:
        # jiwer is not installed or the call raised: use the fallback below.
        pass

    true_words = true_text.lower().split()
    pred_words = predicted_text.lower().split()

    if not true_words:
        return 1.0 if pred_words else 0.0

    # Word-level Levenshtein distance, keeping only the previous row.
    # A position-by-position comparison would count every word after a single
    # inserted or dropped word as an error.
    previous = list(range(len(pred_words) + 1))
    for i, true_word in enumerate(true_words, start=1):
        current = [i]
        for j, pred_word in enumerate(pred_words, start=1):
            substitution = previous[j - 1] + (true_word != pred_word)
            current.append(min(previous[j] + 1, current[j - 1] + 1, substitution))
        previous = current

    return previous[-1] / len(true_words)

def _get_speaker_encoder():
    """Return the ECAPA speaker encoder, loading it on the first call only."""
    encoder = getattr(_get_speaker_encoder, "_cached", None)
    if encoder is None:
        from speechbrain.inference.speaker import EncoderClassifier
        encoder = EncoderClassifier.from_hparams(
            source="speechbrain/spkrec-ecapa-voxceleb",
            savedir="pretrained_models/spkrec-ecapa-voxceleb"
        )
        _get_speaker_encoder._cached = encoder
    return encoder


def calculate_speaker_similarity(audio1_path, audio2_path):
    """Return the cosine similarity of the speaker embeddings of two files.

    Both files are loaded at 16 kHz and embedded with the SpeechBrain ECAPA
    model.

    Returns:
        The similarity as a float, or 0.0 after printing the error if anything
        failed. Callers that average the results should be aware that failures
        therefore count as 0.0.
    """
    try:
        y1, _ = librosa.load(audio1_path, sr=16000)
        y2, _ = librosa.load(audio2_path, sr=16000)

        # Re-creating the encoder for every pair would reload the model each
        # time; reuse one instance across calls.
        classifier = _get_speaker_encoder()

        emb1 = classifier.encode_batch(torch.tensor(y1).unsqueeze(0))
        emb2 = classifier.encode_batch(torch.tensor(y2).unsqueeze(0))

        # Compare the embeddings as flat vectors. cosine_similarity defaults
        # to dim=1; for embeddings with a singleton axis there (SpeechBrain
        # appears to return (batch, 1, emb_dim) — verify) that yields one
        # value per embedding component and .item() fails.
        similarity = torch.nn.functional.cosine_similarity(emb1.flatten(), emb2.flatten(), dim=0)
        return similarity.item()
    except Exception as e:
        print(f"Ошибка при вычислении similarity: {e}")
        return 0.0

print("Вычисление метрик на 100 примерах...")
# Slicing already copes with fewer than 100 items.
sample_data = test_data[:100]

wers = []
similarities = []

for i, item in enumerate(tqdm(sample_data, desc="Оценка качества", unit="пример")):
    if not voice:
        continue

    true_text = item['text']
    reference_audio_path = item['path']
    temp_wav = os.path.join(EXAMPLES_DIR, f"temp_synth_{i}.wav")

    try:
        with wave.open(temp_wav, "wb") as wav_file:
            voice.synthesize_wav(true_text, wav_file)

        # NOTE(review): this compares the reference text with itself, so the
        # reported WER is always 0. A real measurement needs an ASR transcript
        # of temp_wav as the second argument.
        wer_score = calculate_wer(true_text, true_text)
        wers.append(wer_score)

        if os.path.exists(reference_audio_path):
            sim = calculate_speaker_similarity(temp_wav, reference_audio_path)
            similarities.append(sim)
    except Exception as e:
        print(f"Ошибка при обработке примера {i}: {e}")
    finally:
        # Remove the temporary file even when synthesis or scoring failed, so
        # failed examples do not leave files behind in EXAMPLES_DIR.
        try:
            if os.path.exists(temp_wav):
                os.remove(temp_wav)
        except OSError as e:
            print(f"Ошибка при обработке примера {i}: {e}")

if wers:
    print(f"\nРезультаты WER:")
    print(f"  Средний: {np.mean(wers):.4f}")
    print(f"  Медианный: {np.median(wers):.4f}")

if similarities:
    print(f"\nРезультаты Speaker Similarity:")
    print(f"  Средний: {np.mean(similarities):.4f}")
    print(f"  Медианный: {np.median(similarities):.4f}")


In [None]:
# MODEL_DIR is not assigned by any earlier cell shown in this notebook. Keep a
# value defined elsewhere if there is one, otherwise fall back to a directory
# under OUTPUT_DIR so the cell does not fail with a NameError.
try:
    MODEL_DIR
except NameError:
    MODEL_DIR = os.path.join(OUTPUT_DIR, "models")
os.makedirs(MODEL_DIR, exist_ok=True)

final_model_path = os.path.join(MODEL_DIR, "final_model.pt")
# Only voices that expose a `model` attribute can be saved this way.
if voice and hasattr(voice, 'model'):
    torch.save({
        'model_state_dict': voice.model.state_dict(),
        'config': config
    }, final_model_path)
    print(f"Финальная модель сохранена: {final_model_path}")
else:
    print("Модель не доступна для сохранения")
