In [None]:
import os


# Both variables are set before the unsloth / transformers imports below, so they are
# already present in the environment when those libraries are first imported.
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
# NOTE(review): presumably disables the tokenizers library's internal parallelism so it
# does not clash with DataLoader worker processes used later — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from unsloth import FastModel
from borealis.modeling import BorealisForConditionalGeneration
from borealis.dataset import BorealisBaseDataset
import torch
from torch.utils.data import DataLoader
from transformers import WhisperFeatureExtractor, Qwen2ForCausalLM
from tqdm.auto import tqdm
from datasets import load_dataset, Audio
from jiwer import wer, cer
import string

In [None]:
# Load the base language model and its tokenizer through Unsloth's FastModel.
# NOTE(review): dtype=None presumably lets Unsloth choose the dtype and
# full_finetuning=True presumably loads plain (non-LoRA) weights — confirm in the
# Unsloth documentation.
language_model, tokenizer = FastModel.from_pretrained(
    model_name="Qwen/Qwen2.5-0.5B-Instruct",
    dtype=None,
    auto_model=Qwen2ForCausalLM,
    full_finetuning=True,
)

# Marker tokens for the audio segment; how they are consumed is defined in
# BorealisForConditionalGeneration, which is not shown in this notebook.
start_audio_token = "<|start_of_audio|>"
end_audio_token = "<|end_of_audio|>"

# Register both markers as additional special tokens on the tokenizer.
# NOTE(review): this grows the vocabulary — verify that the model wrapper resizes the
# embedding matrix to match.
tokenizer.add_special_tokens(
    {"additional_special_tokens": [start_audio_token, end_audio_token]}
)

In [None]:
# Load the ToneSpeak dataset and have its "audio" column decoded at 16 kHz.
ds = load_dataset("Vikhrmodels/ToneSpeak")
ds = ds.cast_column("audio", Audio(decode=True, sampling_rate=16_000))

# Despite the variable name, this is a WhisperFeatureExtractor (audio preprocessing
# only), not an encoder network.
whisper_encoder = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3")

In [None]:
# Wrap the language model and tokenizer in the Borealis model used for transcription.
# NOTE(review): later cells read model.encoder.device, so the wrapper presumably builds
# its own audio encoder internally — confirm in borealis.modeling.
model = BorealisForConditionalGeneration(
    language_model=language_model, tokenizer=tokenizer
)

In [None]:
# Restore the trained weights from the checkpoint file.
# map_location="cpu" deserializes the tensors on the CPU no matter which device they
# were saved from; load_state_dict then copies them into the model's own parameters,
# so no stray copy of the checkpoint is left on a GPU.
# SECURITY NOTE: torch.load unpickles the file — only load checkpoints you trust
# (consider weights_only=True where the installed torch version supports it).
model.load_state_dict(
    torch.load(
        "/workspace/Borealis/asr_qwen_ckpts/checkpoint-163525/pytorch_model.bin",
        map_location="cpu",
    )
)

In [None]:
# Move the model to the first GPU and switch it to evaluation mode.
# torch.inference_mode() only disables autograd; without eval() any dropout or
# normalization layers of the freshly constructed wrapper would keep their
# training-time behaviour during the evaluation loops that follow.
model = model.to("cuda:0").eval()

In [None]:
# Build the evaluation dataset from the first 79 rows of the ToneSpeak validation split.
# NOTE(review): 79 is a magic number — presumably a quick smoke-test size; consider a
# named constant.
eval_dataset = BorealisBaseDataset(
    audio_processor=whisper_encoder,
    text_tokenizer=tokenizer,
    hf_ds=ds["validation"].select(range(79)),
    max_text_len=320,
)

In [None]:
# Parallel lists: generated_transcripts[i] is the model output for the example whose
# reference text is ground_truth_texts[i].
generated_transcripts = []
ground_truth_texts = []

with torch.inference_mode():
    # eval_dataset is iterated directly (no DataLoader), so every item is a single
    # example rather than a collated batch.
    for batch in tqdm(eval_dataset, desc="Inference on eval set"):
        # NOTE(review): shapes are presumably (n_mels, frames) and (frames,) without a
        # batch dimension here — confirm against BorealisBaseDataset.
        mel = batch["mel"].to(model.encoder.device)
        att_mask = batch["audio_att_mask"].to(model.encoder.device)

        transcripts = model.generate(
            mel=mel,
            att_mask=att_mask,
            max_new_tokens=320,
            do_sample=True,
            top_p=0.9,
            top_k=50,
        )
        print(transcripts)

        # tokenizer.decode returns ONE string for this example.
        gt_texts = tokenizer.decode(batch["labels"], skip_special_tokens=True)

        print(gt_texts)

        # list.extend(<str>) would add the string character by character, so whole
        # strings are appended; a list of transcripts is still flattened with extend.
        if isinstance(transcripts, str):
            generated_transcripts.append(transcripts)
        else:
            generated_transcripts.extend(transcripts)
        ground_truth_texts.append(gt_texts)


# Show the first few hypothesis/reference pairs side by side.
for i in range(min(5, len(generated_transcripts))):
    print(f"Сгенерировано: {generated_transcripts[i]}")
    print(f"Правильный текст: {ground_truth_texts[i]}")
    print("-" * 80)

In [None]:
import torch
import librosa
import numpy as np


def test_single_wav(
    wav_path: str,
    model: BorealisForConditionalGeneration,
    audio_processor: WhisperFeatureExtractor,
    max_seconds_len: float = 30.0,
    sampling_rate: int = 16000,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    generation_params: dict = None,
):
    """
    Load one audio file, run it through the model and print the generated transcript.

    Side effects: puts ``model`` into eval mode and moves it to ``device``.

    :param wav_path: Path to the audio file.
    :param model: BorealisForConditionalGeneration instance used for generation.
    :param audio_processor: WhisperFeatureExtractor that turns the waveform into
        log-mel features and an attention mask.
    :param max_seconds_len: Length, in seconds, that the audio is padded to.
    :param sampling_rate: Sampling rate the audio is resampled to (default 16000).
    :param device: Device to run on; the default is chosen once, when the function
        is defined ('cuda' if available, otherwise 'cpu').
    :param generation_params: Keyword arguments forwarded to ``model.generate``.
        When None, sampling with max_new_tokens=320, top_p=0.9 and top_k=50 is used.
    :return: Whatever ``model.generate`` returns for this input (the transcript);
        the same value is also printed.
    """

    model.eval()
    model.to(device)

    # librosa resamples to `sampling_rate` and down-mixes to mono while loading, so no
    # separate resample / channel-averaging step is needed afterwards.
    waveform, _ = librosa.load(wav_path, sr=sampling_rate, mono=True)

    proc = audio_processor(
        waveform,
        sampling_rate=sampling_rate,
        padding="max_length",
        max_length=int(max_seconds_len * sampling_rate),
        return_attention_mask=True,
        return_tensors="pt",
    )
    # squeeze(0) drops the leading batch dimension of the returned tensors.
    # NOTE(review): resulting shapes are presumably (n_mels, frames) and (frames,),
    # with 128 mel bins for the whisper-large-v3 extractor — confirm.
    mel = proc.input_features.squeeze(0).to(device)
    att_mask = proc.attention_mask.squeeze(0).to(device)

    if generation_params is None:
        generation_params = {
            "max_new_tokens": 320,
            "do_sample": True,
            "top_p": 0.9,
            "top_k": 50,
        }

    with torch.inference_mode():
        transcript = model.generate(mel=mel, att_mask=att_mask, **generation_params)

    print(f"Generated transcript for {wav_path}:")
    print(transcript)

    # Returned as well as printed so callers can reuse the result programmatically.
    return transcript

In [None]:
# Smoke-test on one local recording with explicit sampling settings; note that no
# top_k is passed here, unlike the function's built-in default parameters.
test_single_wav(
    wav_path="/workspace/res.wav",
    model=model,
    audio_processor=whisper_encoder,
    generation_params={
        "top_p": 0.9,
        "do_sample": True,
        "max_new_tokens": 320,
        "temperature": 0.95,
    },
)

In [None]:
# Load the ToneRuLS dataset.
# NOTE(review): unlike `ds`, the "audio" column is not cast to 16 kHz here — confirm
# that BorealisBaseDataset resamples or that the data is already 16 kHz.
test_ds = load_dataset("Vikhrmodels/ToneRuLS")

In [None]:
# Display the loaded dataset object (its repr) as the cell output.
test_ds

In [None]:
# Rebind eval_dataset to the first 79 rows of the ToneRuLS validation split; from here
# on the name no longer refers to the ToneSpeak subset built earlier.
eval_dataset = BorealisBaseDataset(
    audio_processor=whisper_encoder,
    text_tokenizer=tokenizer,
    hf_ds=test_ds["validation"].select(range(79)),
    max_text_len=320,
)

In [None]:
# Peek at the first item to inspect what the dataset yields.
next(iter(eval_dataset))

In [None]:
# Parallel lists: generated_transcripts[i] is the model output for the example whose
# reference text is ground_truth_texts[i].
generated_transcripts = []
ground_truth_texts = []

# Generation below samples (do_sample=True); seeding makes the transcripts, and the
# WER/CER computed from them afterwards, repeatable between runs.
torch.manual_seed(42)

with torch.inference_mode():
    # eval_dataset is iterated directly (no DataLoader), so every item is a single
    # example rather than a collated batch.
    for batch in tqdm(eval_dataset, desc="Inference on eval set"):
        mel = batch["mel"].to(model.encoder.device)
        att_mask = batch["audio_att_mask"].to(model.encoder.device)

        transcripts = model.generate(
            mel=mel,
            att_mask=att_mask,
            max_new_tokens=320,
            do_sample=True,
            top_p=0.9,
            top_k=50,
        )

        # tokenizer.decode returns ONE string for this example.
        gt_texts = tokenizer.decode(batch["labels"], skip_special_tokens=True)

        # Store flat strings whether generate() hands back a bare string or a list of
        # strings; a nested list would break the text cleaning / metric step later.
        if isinstance(transcripts, str):
            generated_transcripts.append(transcripts)
        else:
            generated_transcripts.extend(transcripts)
        ground_truth_texts.append(gt_texts)


In [None]:
def extract_assistant_content(text: str) -> str:
    """Return what follows the last "assistant" + newline marker, whitespace-stripped.

    If the marker does not occur, the whole input is returned stripped.
    """
    marker = "assistant\n"
    # rpartition yields ("", "", text) when the marker is absent, so the tail is the
    # full text in that case and the part after the last marker otherwise.
    _, _, tail = text.rpartition(marker)
    return tail.strip()

In [None]:


def clean_text_list(text_list):
    """Lowercase every text and remove punctuation from it.

    Besides the ASCII characters in ``string.punctuation``, every character whose
    Unicode general category is punctuation (P*) is removed, so typographic marks
    common in Russian text (« », —, …) do not count as errors in WER/CER.
    Whitespace is left untouched.

    :param text_list: Iterable of strings.
    :return: List of cleaned strings, in the same order as the input.
    """
    import unicodedata  # local import keeps this definition self-contained

    ascii_punct = set(string.punctuation)

    def _is_punct(char):
        # string.punctuation also holds symbols such as "+" or "$" that are not in a
        # Unicode P* category, so both tests are needed to stay a superset of the
        # ASCII-only behaviour.
        return char in ascii_punct or unicodedata.category(char).startswith("P")

    return [
        "".join(char for char in text.lower() if not _is_punct(char))
        for text in text_list
    ]

In [None]:
# Keep only the part of each decoded label text that extract_assistant_content selects.
# NOTE: the list is rebound to its own transformation, so this cell relies on the
# inference cell having just run; re-running it alone re-applies the extraction to
# strings that were already extracted.
ground_truth_texts = [extract_assistant_content(text) for text in ground_truth_texts]

In [None]:
# Word and character error rates on the cleaned texts; the ground-truth texts are
# passed first and the generated transcripts second (jiwer's reference/hypothesis
# order). Both lists are cleaned separately for each metric.
wer_score = wer(
    clean_text_list(ground_truth_texts), clean_text_list(generated_transcripts)
)
cer_score = cer(
    clean_text_list(ground_truth_texts), clean_text_list(generated_transcripts)
)

In [None]:
# Display the word error rate computed in the previous cell.
wer_score

In [None]:
# Display the character error rate computed alongside the WER.
cer_score

In [1]:
from datasets import load_dataset

In [2]:
# Load every split of the RuASRBenchmark dataset using a single process (num_proc=1).
# NOTE(review): the "audio" column is not cast to 16 kHz here, unlike `ds` — confirm
# that BorealisBaseDataset resamples or that the data is already 16 kHz.
bench_set = load_dataset("Vikhrmodels/RuASRBenchmark", num_proc=1)

Loading dataset shards:   0%|          | 0/73 [00:00<?, ?it/s]

In [3]:
# Display the benchmark's splits with their features and row counts.
bench_set

DatasetDict({
    Russian_LibriSpeech: Dataset({
        features: ['audio', 'text'],
        num_rows: 1352
    })
    Common_Voice_Corpus_22.0: Dataset({
        features: ['audio', 'text'],
        num_rows: 10244
    })
    Tone_Webinars: Dataset({
        features: ['audio', 'text'],
        num_rows: 21587
    })
    Tone_Books: Dataset({
        features: ['audio', 'text'],
        num_rows: 4930
    })
    Tone_Speak: Dataset({
        features: ['audio', 'text'],
        num_rows: 700
    })
    Sova_RuDevices: Dataset({
        features: ['audio', 'text'],
        num_rows: 5799
    })
})

In [None]:
def clean_text_list(text_list):
    """Lowercase every text and remove punctuation from it.

    Besides the ASCII characters in ``string.punctuation``, every character whose
    Unicode general category is punctuation (P*) is removed, so typographic marks
    common in Russian text (« », —, …) do not count as errors in WER/CER.
    Whitespace is left untouched.

    NOTE: a function with this name is also defined in an earlier cell of the
    notebook; whichever definition runs last is the one in effect.

    :param text_list: Iterable of strings.
    :return: List of cleaned strings, in the same order as the input.
    """
    import unicodedata  # local import keeps this definition self-contained

    ascii_punct = set(string.punctuation)

    def _is_punct(char):
        # string.punctuation also holds symbols such as "+" or "$" that are not in a
        # Unicode P* category, so both tests are needed to stay a superset of the
        # ASCII-only behaviour.
        return char in ascii_punct or unicodedata.category(char).startswith("P")

    return [
        "".join(char for char in text.lower() if not _is_punct(char))
        for text in text_list
    ]

def extract_assistant_content(text: str) -> str:
    """Strip and return the text after the final "assistant" + newline marker.

    Input without the marker is returned stripped and otherwise unchanged.
    """
    marker = "assistant\n"
    marker_start = text.rfind(marker)
    if marker_start != -1:
        # Drop everything up to and including the last marker.
        text = text[marker_start + len(marker):]
    return text.strip()

In [None]:
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

from tqdm.auto import tqdm
import torch

# Benchmark configuration.
batch_size = 64  # number of examples generated per batch
generation_seed = 42  # fixed seed: generation samples, so unseeded metrics would drift

for split in bench_set:
    print(f"Подсчёт метрик на сплите {split}")

    # Re-seed for every split so a split's result does not depend on which splits
    # were processed before it (repeatable up to GPU nondeterminism).
    torch.manual_seed(generation_seed)

    # Batched, order-preserving loader over this split.
    test_loader = DataLoader(
        BorealisBaseDataset(
            audio_processor=whisper_encoder,
            text_tokenizer=tokenizer,
            hf_ds=bench_set[split],
            max_text_len=320,
        ),
        batch_size=batch_size,
        shuffle=False,  # keep dataset order for evaluation
        num_workers=16,  # parallel data loading; tune to the available hardware
        pin_memory=True,  # page-locked host memory for the host-to-GPU copies
    )

    # Parallel lists: generated_transcripts[i] belongs to ground_truth_texts[i].
    ground_truth_texts = []
    generated_transcripts = []

    with torch.inference_mode():
        for batch in tqdm(test_loader, desc=f"Processing split {split}"):
            # Move the batch to the device the model's encoder lives on.
            mel = batch["mel"].to(model.encoder.device)  # (B, 128, 3000)
            att_mask = batch["audio_att_mask"].to(model.encoder.device)  # (B, 3000)

            # Generate transcripts for the whole batch at once.
            transcripts = model.generate(
                mel=mel,
                att_mask=att_mask,
                max_new_tokens=320,
                do_sample=True,
                top_p=0.9,
                top_k=50,
            )

            # Decode the reference texts of the whole batch.
            gt_texts = tokenizer.batch_decode(batch["labels"], skip_special_tokens=True)

            ground_truth_texts.extend(
                [extract_assistant_content(text) for text in gt_texts]
            )
            generated_transcripts.extend(transcripts)

    # Clean each list once and reuse the result for both metrics
    # (references first, hypotheses second).
    cleaned_references = clean_text_list(ground_truth_texts)
    cleaned_hypotheses = clean_text_list(generated_transcripts)
    wer_score = wer(cleaned_references, cleaned_hypotheses)
    cer_score = cer(cleaned_references, cleaned_hypotheses)

    print(f"WER score for split {split}: {wer_score:.4f}")
    print(f"CER score for split {split}: {cer_score:.4f}")

In [None]:
# Shell out (IPython "!") to print the installed ffmpeg version.
!ffmpeg -version

In [None]:
import random


def test_random_sample(dataset):
    if len(dataset) == 0:
        print("Датасет пустой!")
        return

    index = random.randint(0, len(dataset) - 1)
    print(f"Тестируем сэмпл с индексом: {index}")

    try:
        sample = dataset[index]
        print("Сэмпл успешно загружен!")
        print("Ключи в сэмпле:", list(sample.keys()))

        # Печать форм (для отладки)
        print("Форма mel:", sample["mel"].shape)
        print("Форма audio_att_mask:", sample["audio_att_mask"].shape)
        print("Форма labels:", sample["labels"].shape)
        print("Форма text_att_mask:", sample["text_att_mask"].shape)

        # Опционально: декодировать текст для проверки
        decoded_text = dataset.tokenizer.decode(
            sample["labels"], skip_special_tokens=True
        )
        print("Декодированный текст:", decoded_text)

    except Exception as e:
        print(f"Ошибка при загрузке сэмпла {index}: {e}")
        # Если ошибка в конкретном файле, можно добавить больше отладки
        if "path" in dataset.audios[index]:
            print("Путь к аудио:", dataset.audios[index]["path"])


# Run the check on the current eval_dataset (the name was last rebound to the
# ToneRuLS validation subset).
test_random_sample(eval_dataset)