# Домашнее задание №1: Распознавание речи


In [None]:
%pip install nvidia-ml-py
%pip install --upgrade jupyter ipywidgets


# Импорты


In [None]:
import os
import json
import glob
import random
from pathlib import Path
from typing import List, Dict, Tuple
from collections import Counter, defaultdict
import re

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Для работы с аудио
import librosa
import soundfile as sf

# NeMo
import nemo.collections.asr as nemo_asr
from nemo.core.config import hydra_runner
from nemo.utils import logging

# Для токенизации
from sentencepiece import SentencePieceTrainer, SentencePieceProcessor

import warnings
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

# Reproducibility: seed the stdlib and NumPy RNGs.
# NOTE(review): no framework-level (e.g. torch) seed is set here.
random.seed(42)
np.random.seed(42)


# Подготовка данных


In [None]:
# Dataset locations
SOVA_PATH = "/share/audio_data/sova/ytub/raid/nanosemantics/nextcloud/sova_done"
SBER_GOLOS_PATH = "/share/audio_data/sber-golos/tar/train"

# Output directories (manifests and tokenizer files live under OUTPUT_DIR)
OUTPUT_DIR = "./asr_data"
MANIFEST_DIR, TOKENIZER_DIR = (
    os.path.join(OUTPUT_DIR, subdir) for subdir in ("manifests", "tokenizer")
)

# Create both output directories up front; existing ones are left alone.
for _out_dir in (MANIFEST_DIR, TOKENIZER_DIR):
    os.makedirs(_out_dir, exist_ok=True)


In [None]:
def load_audio_info(audio_path):
    """Return ``(duration_seconds, sample_rate)`` for an audio file.

    The file header is read first with ``soundfile.info`` so the samples
    do not have to be decoded just to measure the length. If that fails or
    reports no usable frame count, the file is fully decoded with
    ``librosa.load`` at its native sample rate instead.

    Args:
        audio_path: Path to the audio file.

    Returns:
        A ``(duration, sample_rate)`` tuple, or ``(None, None)`` when the
        file cannot be read (the error is printed).
    """
    try:
        info = sf.info(audio_path)
        if info.frames > 0 and info.samplerate > 0:
            return info.frames / info.samplerate, info.samplerate
    except Exception:
        # Deliberate best-effort: any header-read failure falls through to
        # the full decode below, which reports the actual error.
        pass

    try:
        y, sr = librosa.load(audio_path, sr=None)
        duration = len(y) / sr
        return duration, sr
    except Exception as e:
        print(f"Ошибка при загрузке {audio_path}: {e}")
        return None, None

def normalize_text(text):
    """Normalize a transcript for training.

    Lower-cases the text, deletes every character that is not a Cyrillic
    or Latin lower-case letter, a digit or whitespace, then collapses
    whitespace runs into single spaces and trims the ends.

    Args:
        text: Raw transcript string.

    Returns:
        The normalized string (possibly empty).
    """
    text = text.lower()
    text = re.sub(r'[^а-яёa-z0-9\s]', '', text)
    # Collapse whitespace only AFTER the removal above: deleting a
    # stand-alone symbol (the dash in "a - b") would otherwise leave a
    # double space in the result.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def process_sova_dataset(base_path, max_files=None):
    """Build manifest records from the SOVA directory tree.

    Audio files (wav/mp3/flac/ogg) are searched recursively under
    ``base_path``. Each one is paired with a ``.txt`` transcript that has
    the same path stem. A record is kept when the normalized transcript is
    non-empty and the audio is longer than 0.5 seconds.

    Args:
        base_path: Root directory of the dataset.
        max_files: Optional cap on how many audio files are considered.

    Returns:
        A list of dicts with the keys ``audio_filepath``, ``duration`` and
        ``text``.
    """
    audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.ogg']
    audio_files = []
    for ext in audio_extensions:
        audio_files.extend(glob.glob(os.path.join(base_path, '**', ext), recursive=True))
    # glob order depends on the filesystem; sorting keeps the max_files
    # cut (and any later seeded shuffle) reproducible between runs.
    audio_files.sort()

    if max_files:
        audio_files = audio_files[:max_files]

    print(f"Найдено {len(audio_files)} аудиофайлов в СОВА")
    print("Начинаю обработку...")

    manifest = []
    for audio_path in tqdm(audio_files, desc="Обработка СОВА", unit="файл"):
        # Check the transcript first: it is far cheaper than touching the
        # audio, and files without usable text are dropped anyway.
        text_path = audio_path.rsplit('.', 1)[0] + '.txt'
        if not os.path.exists(text_path):
            continue

        with open(text_path, 'r', encoding='utf-8') as f:
            text = normalize_text(f.read())
        if not text:
            continue

        duration, _ = load_audio_info(audio_path)
        if duration is None or duration <= 0.5:
            continue

        manifest.append({
            "audio_filepath": audio_path,
            "duration": duration,
            "text": text
        })

    print(f"Обработано {len(manifest)} записей из {len(audio_files)} файлов")
    return manifest

def _read_manifest_entries(manifest_path):
    """Read manifest records stored as a JSON array or as JSON-lines."""
    with open(manifest_path, 'r', encoding='utf-8') as f:
        content = f.read()
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        # Not a single JSON document: parse one object per non-empty line.
        return [json.loads(line) for line in content.splitlines() if line.strip()]
    # A one-line JSON-lines file parses as a single object; wrap it.
    return [data] if isinstance(data, dict) else data


def process_sber_golos_dataset(base_path, max_files=None):
    """Build manifest records from the Sber Golos dataset.

    If ``base_path/manifest.json`` exists, records are taken from it (a
    JSON array or JSON-lines file). Audio paths come from the
    ``audio_filepath`` or ``audio`` key and are resolved against
    ``base_path`` when relative; text comes from ``text`` or
    ``transcript``. Otherwise the directory is scanned recursively for
    audio files with sibling ``.txt`` transcripts.

    A record is kept when the normalized transcript is non-empty and the
    audio is longer than 0.5 seconds.

    Args:
        base_path: Root directory of the dataset.
        max_files: Optional cap on how many manifest entries or audio
            files are considered.

    Returns:
        A list of dicts with the keys ``audio_filepath``, ``duration`` and
        ``text``.
    """
    manifest = []
    manifest_path = os.path.join(base_path, 'manifest.json')

    if os.path.exists(manifest_path):
        data = _read_manifest_entries(manifest_path)
        if max_files:
            data = data[:max_files]

        print(f"Найдено {len(data)} записей в манифесте СберГолос")
        print("Начинаю обработку...")

        for item in tqdm(data, desc="Обработка СберГолос", unit="запись"):
            audio_path = item.get('audio_filepath') or item.get('audio')
            if not audio_path:
                continue

            if not os.path.isabs(audio_path):
                audio_path = os.path.join(base_path, audio_path)

            # `or ''` also covers an explicit null transcript in the file.
            text = normalize_text(item.get('text') or item.get('transcript') or '')
            if not text:
                continue

            duration, _ = load_audio_info(audio_path)
            if duration is None or duration <= 0.5:
                continue

            manifest.append({
                "audio_filepath": audio_path,
                "duration": duration,
                "text": text
            })

        print(f"Обработано {len(manifest)} записей из {len(data)} в манифесте")
    else:
        audio_extensions = ['*.wav', '*.mp3', '*.flac', '*.ogg']
        audio_files = []
        for ext in audio_extensions:
            audio_files.extend(glob.glob(os.path.join(base_path, '**', ext), recursive=True))
        # glob order depends on the filesystem; sort for reproducibility.
        audio_files.sort()

        if max_files:
            audio_files = audio_files[:max_files]

        print(f"Найдено {len(audio_files)} аудиофайлов в СберГолос")
        print("Начинаю обработку...")

        for audio_path in tqdm(audio_files, desc="Обработка СберГолос", unit="файл"):
            # Cheap transcript check before touching the audio.
            text_path = audio_path.rsplit('.', 1)[0] + '.txt'
            if not os.path.exists(text_path):
                continue

            with open(text_path, 'r', encoding='utf-8') as f:
                text = normalize_text(f.read())
            if not text:
                continue

            duration, _ = load_audio_info(audio_path)
            if duration is None or duration <= 0.5:
                continue

            manifest.append({
                "audio_filepath": audio_path,
                "duration": duration,
                "text": text
            })

        print(f"Обработано {len(manifest)} записей из {len(audio_files)} файлов")

    return manifest


In [None]:
# Load both corpora and merge them into one list of manifest records.
print("Загрузка данных из СОВА...")
sova_manifest = process_sova_dataset(SOVA_PATH)

print("\nЗагрузка данных из СберГолос...")
sber_manifest = process_sber_golos_dataset(SBER_GOLOS_PATH)

all_manifest = sova_manifest + sber_manifest

print(f"\nВсего записей: {len(all_manifest)}")
print(f"Из СОВА: {len(sova_manifest)}")
print(f"Из СберГолос: {len(sber_manifest)}")

durations = [item['duration'] for item in all_manifest]
if not durations:
    # min()/max() on an empty list would raise an opaque ValueError; fail
    # with an actionable message instead, since nothing can be trained.
    raise RuntimeError(
        "Не найдено ни одной записи: проверьте пути SOVA_PATH и SBER_GOLOS_PATH"
    )

print("\nСтатистика по длительности:")
print(f"  Минимум: {min(durations):.2f} сек")
print(f"  Максимум: {max(durations):.2f} сек")
print(f"  Среднее: {np.mean(durations):.2f} сек")
print(f"  Медиана: {np.median(durations):.2f} сек")
print(f"  Общая длительность: {sum(durations)/3600:.2f} часов")


In [None]:
# Shuffle in place, then cut the list into 80% train / 10% val / rest test.
random.shuffle(all_manifest)

total_records = len(all_manifest)
train_size = int(0.8 * total_records)
val_size = int(0.1 * total_records)
val_end = train_size + val_size

train_manifest, val_manifest, test_manifest = (
    all_manifest[:train_size],
    all_manifest[train_size:val_end],
    all_manifest[val_end:],
)

print("Разделение данных:")
for split_label, split_records in (
    ("Train", train_manifest),
    ("Val", val_manifest),
    ("Test", test_manifest),
):
    print(f"  {split_label}: {len(split_records)} записей")

def save_manifest(manifest, path):
    """Write ``manifest`` to ``path`` as JSON-lines (one record per line).

    Non-ASCII characters are written as-is (UTF-8), not escaped.
    """
    with open(path, 'w', encoding='utf-8') as out:
        out.writelines(
            json.dumps(record, ensure_ascii=False) + '\n' for record in manifest
        )

# Write each split to MANIFEST_DIR/<split>_manifest.json.
for split_prefix, split_records in (
    ('train', train_manifest),
    ('val', val_manifest),
    ('test', test_manifest),
):
    save_manifest(split_records, os.path.join(MANIFEST_DIR, f'{split_prefix}_manifest.json'))

print("\nМанифесты сохранены!")


# Обучение токенизатора


In [None]:
# Only the training split feeds the tokenizer.
all_texts = [record['text'] for record in train_manifest]

# One transcript per line.
text_file = os.path.join(TOKENIZER_DIR, 'train_texts.txt')
with open(text_file, 'w', encoding='utf-8') as corpus:
    corpus.writelines(line + '\n' for line in all_texts)

total_chars = sum(map(len, all_texts))
print(f"Собрано {len(all_texts)} текстов для обучения токенизатора")
print(f"Общая длина текста: {total_chars} символов")


In [None]:
VOCAB_SIZE = 369

tokenizer_model = os.path.join(TOKENIZER_DIR, 'tokenizer.model')

# Train a BPE SentencePiece model on the transcript file.
# <unk>, <s> and </s> are pieces the trainer already reserves, so they must
# not be redeclared through user_defined_symbols. The four special tokens
# are configured through the dedicated *_id / *_piece options instead.
SentencePieceTrainer.train(
    input=text_file,
    # splitext removes only the trailing extension; str.replace would also
    # rewrite a '.model' substring elsewhere in the path.
    model_prefix=os.path.splitext(tokenizer_model)[0],
    vocab_size=VOCAB_SIZE,
    model_type='bpe',
    character_coverage=0.9995,
    max_sentencepiece_length=16,
    split_by_whitespace=True,
    byte_fallback=True,
    normalization_rule_name='nmt_nfkc_cf',
    pad_id=0,
    pad_piece='<pad>',
    unk_id=1,
    unk_piece='<unk>',
    bos_id=2,
    bos_piece='<s>',
    eos_id=3,
    eos_piece='</s>',
)

print(f"Токенизатор обучен! Размер словаря: {VOCAB_SIZE}")


In [None]:
# Load the trained model and round-trip a few sentences through it.
tokenizer = SentencePieceProcessor()
tokenizer.load(tokenizer_model)

print(f"Размер словаря токенизатора: {tokenizer.get_piece_size()}")

first_train_text = all_texts[0] if all_texts else "тестовый текст"
test_texts = [
    "привет как дела",
    "распознавание речи это интересно",
    first_train_text,
]

for text in test_texts:
    tokens = tokenizer.encode(text, out_type=str)
    decoded = tokenizer.decode(tokenizer.encode(text))

    # Show at most the first ten pieces of long tokenizations.
    if len(tokens) > 10:
        tokens_line = f"Токены ({len(tokens)}): {tokens[:10]}..."
    else:
        tokens_line = f"Токены ({len(tokens)}): {tokens}"

    print(f"\nТекст: {text}")
    print(tokens_line)
    print(f"Декодировано: {decoded}")


# Обучение модели


In [None]:
# Load the pretrained QuartzNet checkpoint as the starting point.
try:
    asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(
        model_name="QuartzNet15x5Base-En"
    )
except Exception as e:
    print(f"Не удалось загрузить предобученную модель: {e}")
    # Re-raise: every later cell uses asr_model, so swallowing the error
    # here would only move the failure to a confusing NameError downstream.
    raise
else:
    print("Загружена модель QuartzNet15x5Base-En")


In [None]:
# The checkpoint loaded above is an English model, so its character set
# does not cover the Russian transcripts. Replace the vocabulary with the
# characters that actually occur in the training texts.
new_labels = sorted({ch for item in train_manifest for ch in item['text']})
asr_model.change_vocabulary(new_vocabulary=new_labels)

asr_model._cfg.train_ds.manifest_filepath = os.path.join(MANIFEST_DIR, 'train_manifest.json')
asr_model._cfg.validation_ds.manifest_filepath = os.path.join(MANIFEST_DIR, 'val_manifest.json')

asr_model._cfg.train_ds.sample_rate = 16000
asr_model._cfg.validation_ds.sample_rate = 16000
asr_model._cfg.train_ds.batch_size = 16
asr_model._cfg.validation_ds.batch_size = 16

# Editing the config alone does not create dataloaders on a restored
# model; build them explicitly from the updated sections.
asr_model.setup_training_data(train_data_config=asr_model._cfg.train_ds)
asr_model.setup_validation_data(val_data_config=asr_model._cfg.validation_ds)

print("Конфигурация модели обновлена")
print(f"Train manifest: {asr_model._cfg.train_ds.manifest_filepath}")
print(f"Val manifest: {asr_model._cfg.validation_ds.manifest_filepath}")


# Настройка Trainer


In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

checkpoint_dir = os.path.join(OUTPUT_DIR, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

# Logger rooted at OUTPUT_DIR with the run name 'asr_training'.
logger = TensorBoardLogger(save_dir=OUTPUT_DIR, name='asr_training')

# Both callbacks watch the same metric and treat lower as better.
monitor_kwargs = {'monitor': 'val_loss', 'mode': 'min'}

checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename='asr-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    save_last=True,
    **monitor_kwargs,
)

early_stopping = EarlyStopping(patience=5, verbose=True, **monitor_kwargs)

trainer_kwargs = {
    'max_epochs': 50,
    'accelerator': 'gpu',
    'devices': 1,
    'logger': logger,
    'callbacks': [checkpoint_callback, early_stopping],
    'log_every_n_steps': 10,
    'val_check_interval': 0.5,
    'gradient_clip_val': 1.0,
    'accumulate_grad_batches': 1,
}
trainer = Trainer(**trainer_kwargs)

print("Trainer настроен")


# Запуск обучения


In [None]:
# Run the training loop on asr_model with the configured trainer.
trainer.fit(asr_model)

# Save the model to a .nemo file in the working directory.
asr_model.save_to("./asr_model.nemo")


# Анализ результатов


In [None]:
# Load the model back from ./asr_model.nemo; this rebinds asr_model.
asr_model = nemo_asr.models.EncDecCTCModel.restore_from(
    restore_path="./asr_model.nemo"
)


# Визуализация метрик


In [None]:
def plot_training_metrics(log_dir):
    """Plot loss and WER curves from TensorBoard event files.

    Reads the scalar tags ``train_loss``, ``val_loss`` and ``val_wer``
    from ``log_dir`` (any that are missing are skipped), draws the losses
    on one panel and the WER on another, saves the figure to
    ``OUTPUT_DIR/training_metrics.png`` and shows it.

    Plotting is best-effort: any failure is printed, not raised.

    Args:
        log_dir: Directory containing the TensorBoard event files.
    """
    try:
        from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

        event_acc = EventAccumulator(log_dir)
        event_acc.Reload()

        scalar_tags = event_acc.Tags()['scalars']

        def plot_tag(ax, tag, label):
            """Draw one scalar series on ``ax``; return True if it exists."""
            if tag not in scalar_tags:
                return False
            events = event_acc.Scalars(tag)
            ax.plot([s.step for s in events], [s.value for s in events], label=label)
            return True

        # Only two panels are ever drawn, so use a 1x2 grid rather than a
        # 2x2 one that leaves the bottom row empty.
        fig, (loss_ax, wer_ax) = plt.subplots(1, 2, figsize=(15, 5))

        has_train_loss = plot_tag(loss_ax, 'train_loss', 'Train Loss')
        has_val_loss = plot_tag(loss_ax, 'val_loss', 'Val Loss')

        loss_ax.set_xlabel('Step')
        loss_ax.set_ylabel('Loss')
        loss_ax.set_title('Training and Validation Loss')
        # legend() without any labelled series only produces a warning.
        if has_train_loss or has_val_loss:
            loss_ax.legend()
        loss_ax.grid(True)

        if plot_tag(wer_ax, 'val_wer', 'Val WER'):
            wer_ax.set_xlabel('Step')
            wer_ax.set_ylabel('WER')
            wer_ax.set_title('Word Error Rate')
            wer_ax.legend()
            wer_ax.grid(True)
        else:
            # No WER was logged: hide the panel instead of showing empty axes.
            wer_ax.set_visible(False)

        plt.tight_layout()
        plt.savefig(os.path.join(OUTPUT_DIR, 'training_metrics.png'), dpi=300)
        plt.show()

    except Exception as e:
        print(f"Ошибка при загрузке метрик: {e}")

# Plot the curves from the directory reported by the logger object.
plot_training_metrics(logger.log_dir)


# Тестирование модели


In [None]:
def test_model(model, test_manifest, num_samples=10):
    """Transcribe a random sample of test records and print comparisons.

    Args:
        model: Object with a ``transcribe(list_of_paths)`` method that
            returns one result per path.
        test_manifest: List of records with ``audio_filepath`` and
            ``text`` keys.
        num_samples: Maximum number of records to sample.

    Returns:
        A list of dicts with the keys ``audio``, ``true`` and
        ``predicted``, one per successfully transcribed record.
    """
    samples = random.sample(test_manifest, min(num_samples, len(test_manifest)))
    results = []

    for sample in samples:
        audio_path = sample['audio_filepath']
        true_text = sample['text']

        try:
            hypothesis = model.transcribe([audio_path])[0]
        except Exception as e:
            # Best-effort: report the failing file and continue.
            print(f"Ошибка при обработке {audio_path}: {e}")
            continue

        # transcribe() may return plain strings or hypothesis objects that
        # carry the transcript in `.text`; accept both.
        predicted_text = getattr(hypothesis, 'text', hypothesis)

        results.append({
            'audio': audio_path,
            'true': true_text,
            'predicted': predicted_text
        })

        print(f"\nАудио: {os.path.basename(audio_path)}")
        print(f"Истинный текст: {true_text}")
        print(f"Распознанный текст: {predicted_text}")

    return results

# Transcribe up to 20 randomly sampled test records and keep the results.
test_results = test_model(asr_model, test_manifest, num_samples=20)


# Вычисление метрик качества


In [None]:
def calculate_wer(true_text, predicted_text):
    """Return the word error rate of ``predicted_text`` against ``true_text``.

    Both strings are lower-cased and split on whitespace. ``jiwer`` is
    used when it is installed; otherwise the word-level edit distance
    (substitutions, insertions, deletions) is computed here.

    Args:
        true_text: Reference transcript.
        predicted_text: Hypothesis transcript.

    Returns:
        Edit distance divided by the number of reference words. For an
        empty reference: 0.0 if the hypothesis is empty too, else 1.0.
    """
    true_words = true_text.lower().split()
    pred_words = predicted_text.lower().split()

    # Degenerate cases are handled up front so both back-ends agree.
    if not true_words:
        return 1.0 if pred_words else 0.0
    if not pred_words:
        return 1.0  # every reference word was deleted

    try:
        from jiwer import wer
    except ImportError:
        wer = None
    if wer is not None:
        return wer(' '.join(true_words), ' '.join(pred_words))

    # Word-level Levenshtein distance, keeping only the previous row.
    # Comparing words by position would count every word after a single
    # insertion or deletion as an error.
    prev_row = list(range(len(pred_words) + 1))
    for i, true_word in enumerate(true_words, start=1):
        curr_row = [i]
        for j, pred_word in enumerate(pred_words, start=1):
            curr_row.append(min(
                prev_row[j] + 1,                           # deletion
                curr_row[j - 1] + 1,                       # insertion
                prev_row[j - 1] + (true_word != pred_word),  # substitution
            ))
        prev_row = curr_row

    return prev_row[-1] / len(true_words)

def calculate_metrics(model, test_manifest):
    """Compute per-utterance WER statistics over ``test_manifest``.

    Each record is transcribed on its own. Records whose transcription or
    scoring raises are reported and skipped.

    Args:
        model: Object with a ``transcribe(list_of_paths)`` method that
            returns one result per path.
        test_manifest: List of records with ``audio_filepath`` and
            ``text`` keys.

    Returns:
        A dict with ``avg_wer``, ``median_wer``, ``min_wer``, ``max_wer``
        and ``all_wers``. When no record could be scored, the four
        statistics are NaN and ``all_wers`` is empty.
    """
    wers = []
    total = len(test_manifest)

    print(f"Вычисление метрик на {total} примерах...")

    for i, sample in enumerate(test_manifest):
        if i % 100 == 0:
            print(f"Обработано {i}/{total}")

        audio_path = sample['audio_filepath']
        true_text = sample['text']

        try:
            hypothesis = model.transcribe([audio_path])[0]
            # transcribe() may return plain strings or hypothesis objects
            # that carry the transcript in `.text`; accept both.
            predicted_text = getattr(hypothesis, 'text', hypothesis)
            wers.append(calculate_wer(true_text, predicted_text))
        except Exception as e:
            print(f"Ошибка при обработке {audio_path}: {e}")

    if wers:
        avg_wer = np.mean(wers)
        median_wer = np.median(wers)
        min_wer = np.min(wers)
        max_wer = np.max(wers)
    else:
        # np.min/np.max raise on an empty list, and a WER of 0 would read
        # as a perfect score; report NaN so "nothing scored" is explicit.
        avg_wer = median_wer = min_wer = max_wer = float('nan')

    print(f"\nРезультаты:")
    print(f"  Средний WER: {avg_wer:.4f}")
    print(f"  Медианный WER: {median_wer:.4f}")
    print(f"  Минимальный WER: {min_wer:.4f}")
    print(f"  Максимальный WER: {max_wer:.4f}")

    return {
        'avg_wer': avg_wer,
        'median_wer': median_wer,
        'min_wer': min_wer,
        'max_wer': max_wer,
        'all_wers': wers
    }

# Score the full test split; the returned statistics dict is kept in `metrics`.
metrics = calculate_metrics(asr_model, test_manifest)
