In [1]:
import re
from types import MethodType

import torch
import torch.nn.functional as F
from evaluate import load
from transformers import WhisperForConditionalGeneration, WhisperProcessor
from undecorated import undecorated
import matplotlib.pyplot as plt

from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
from dasr.models.facebook_denoiser import get_pretrained_model


def normalize_text(text: str) -> str:
    """Normalise a transcript so hypotheses and references are comparable.

    Steps, in order:
      1. lower-case the text;
      2. fold "ё" into "е";
      3. turn the sentence punctuation ``. , ! ? ( )`` into spaces, so the
         words on either side stay separated;
      4. drop every remaining character that is neither a word character
         nor whitespace (e.g. hyphens, quotes);
      5. collapse whitespace runs into one space and strip both ends.

    Args:
        text: raw transcript or model hypothesis.

    Returns:
        The normalised string (possibly empty).
    """
    # Lower-case first: otherwise a capital "Ё" would slip past the
    # "ё" -> "е" fold and come out as "ё".
    text = text.lower().replace("ё", "е")
    text = re.sub(r"[.,!?()]", " ", text)
    text = re.sub(r"[^\w\s]", "", text)
    # Collapse whitespace last: stripping symbols (e.g. a spaced dash in
    # "a - b") can leave fresh double spaces behind.
    return re.sub(r"\s+", " ", text).strip()

import os
from hydra import initialize, compose
from hydra.utils import instantiate

# Environment switches: HYDRA_FULL_ERROR=1 and a NUMBA_CACHE_DIR under /tmp/.
os.environ.update({"HYDRA_FULL_ERROR": "1", "NUMBA_CACHE_DIR": "/tmp/"})

# Compose the configuration named "config.yaml" from the "configs" directory.
with initialize(config_path="configs", version_base=None):
    cfg = compose(config_name="config.yaml")

# Override the batch size taken from the config for this notebook.
cfg.data.batch_size = 4

	Error importing 'hydra_plugins.hydra_colorlog'.
	Plugin is incompatible with this Hydra version or buggy.
	Recommended to uninstall or upgrade plugin.
		ImportError : cannot import name 'SearchPathPlugin' from 'hydra.plugins' (/usr/local/lib/python3.10/dist-packages/hydra/plugins/__init__.py)


In [2]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint used as the recogniser.
# Alternative checkpoint: "jonatasgrosman/wav2vec2-xls-r-1b-russian"
path_model = "jonatasgrosman/wav2vec2-large-xlsr-53-russian"
processor = Wav2Vec2Processor.from_pretrained(path_model)
# `.eval()` returns the module itself; it makes inference mode explicit.
model = Wav2Vec2ForCTC.from_pretrained(path_model).to(device).eval()

# Pretrained "dns64" denoiser, applied to the noisy audio before recognition.
# NOTE(review): assumes the returned object is a torch.nn.Module (it already
# supports `.to`) — confirm that `.eval()` is available on it.
denoiser = get_pretrained_model("dns64").to(device).eval()
# The cell ends on an assignment, so nothing is echoed and no
# output-suppressing `;` line is needed.

''

In [3]:
# Instantiate the `data` node of the Hydra config; the result is unpacked
# into exactly two objects, bound as the train and test loaders.
train_loader, test_loader = instantiate(cfg.data)

Start loading datasets
Train dataset loaded
Test dataset loaded


In [4]:
# Take the first batch from the test loader to inspect its structure.
batch = next(iter(test_loader))
# Bare last expression: the notebook echoes the batch contents.
batch

{'clean_audios': tensor([[-1.8190e-12, -5.4570e-12, -2.9104e-11,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-6.1062e-16, -1.9429e-16,  1.3878e-16,  ..., -3.6380e-12,
          -7.2760e-12,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-8.1855e-12,  5.4570e-12, -1.0914e-11,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]),
 'noise_audios': tensor([[ 0.0260,  0.0118,  0.0027,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0295,  0.0043,  0.0102,  ..., -0.0046, -0.0149,  0.0048],
         [ 0.0143,  0.0060, -0.0099,  ...,  0.0000,  0.0000,  0.0000],
         [-0.0166, -0.0262, -0.0253,  ...,  0.0000,  0.0000,  0.0000]]),
 'clean_attention_masks': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.]]),
 'noise_attention_masks': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
   

In [5]:
# Hypotheses and references accumulated over the whole test loader.
all_predicted_sentences = []
all_target_sentences = []

for batch in test_loader:
    speech = batch["noise_audios"].to(device)
    attention_mask = batch["noise_attention_masks"].to(device)

    # Pure inference: run BOTH networks without autograd. Previously the
    # denoiser ran outside `no_grad`, so a gradient graph was built (and held
    # in memory) for every batch even though nothing is ever back-propagated.
    with torch.no_grad():
        # The denoiser output has its axis 1 squeezed away before it is
        # handed to the processor.
        denoisy_speech = denoiser(speech).squeeze(1)

        inputs = processor(denoisy_speech, sampling_rate=16_000, return_tensors="pt", padding=True, do_normalize=True)
        # NOTE(review): the whole batch is passed to the processor as one
        # tensor and `squeeze(0)` removes a leading axis afterwards — confirm
        # whether `do_normalize` then acts per utterance or across the batch.
        logits = model(inputs.input_values.squeeze(0).to(device), attention_mask=attention_mask).logits

    # Greedy decoding: most likely token per frame, then decode to text.
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_sentences = processor.batch_decode(predicted_ids)

    all_predicted_sentences.extend(predicted_sentences)
    all_target_sentences.extend(batch["transcriptions"])

In [6]:
from torchmetrics.text import WordErrorRate, CharErrorRate

# One metric object each for word error rate and character error rate.
wer, cer = WordErrorRate(), CharErrorRate()

In [7]:
# Put hypotheses and references through the same text normalisation.
all_predicted_sentences = list(map(normalize_text, all_predicted_sentences))
all_target_sentences = list(map(normalize_text, all_target_sentences))

# Echo the (WER, CER) pair as the cell output.
(
    wer(preds=all_predicted_sentences, target=all_target_sentences),
    cer(preds=all_predicted_sentences, target=all_target_sentences),
)

(tensor(0.5798), tensor(0.2800))