In [24]:
import re
import random
from pathlib import Path
from dataclasses import dataclass

import torch
import torch.nn as nn
import torchvision
from jiwer import RemoveWhiteSpace, ReduceToListOfListOfChars, Compose, wer, cer
from datasets import Dataset, Audio
from torchvision.datasets import ImageFolder
from transformers import ( Wav2Vec2CTCTokenizer, 
                           Wav2Vec2FeatureExtractor, 
                           Wav2Vec2Processor, 
                           Wav2Vec2ForCTC
                        )

from typing import Generator, Any

# Сравнение результатов моделей

In [2]:
non_alphanum_chars_regexp = re.compile(r"[^\w\s]", flags=re.IGNORECASE)

def remove_special_characters(text: str) -> str:
    return re.sub(non_alphanum_chars_regexp, '', text).upper()

In [3]:
def samples_generator(lang: str) -> Generator[dict[str, Any], None, None]:
    base_path = Path(f"./data/test_audio/{lang}/")
    file_path_patttern = f"*.wav"
    for path in base_path.glob(file_path_patttern):
        file_path_template = f"{path.parent.name}/{path.stem}.{{ext}}"

        audio_path = f"./data/test_audio/{file_path_template.format(ext='wav')}"
        text_path = f"./data/test_transcription/{file_path_template.format(ext='txt')}"

        with open(text_path, "r") as f:
            transcription = f.read()
        transcription = remove_special_characters(transcription)

        yield { "audio": audio_path, "transcription": transcription }

def load_dataset(lang: str) -> Dataset:
    return Dataset.from_generator(samples_generator, gen_kwargs={"lang": lang}).cast_column("audio", Audio(sampling_rate=16000))

In [4]:
def create_processor(tokenizer_type: str) -> Wav2Vec2Processor:
    if tokenizer_type == "en":
        from transformers import AutoProcessor
        return AutoProcessor.from_pretrained("facebook/wav2vec2-base")

    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(f"./models/{tokenizer_type}/", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token=" ")
    feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=True)
    return Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [5]:
def tokenize_dataset(dataset: Dataset, processor: Wav2Vec2Processor) -> Dataset:
    
    def prepare_dataset(batch):
        audio = batch["audio"]
        batch = processor(audio["array"], sampling_rate=audio["sampling_rate"], text=batch["transcription"])
        batch["input_length"] = len(batch["input_values"][0])
        return batch
    
    return dataset.map(prepare_dataset, remove_columns=dataset.column_names)

In [6]:
def create_dataset_and_processor(lang: str) -> (Dataset, Wav2Vec2Processor):
    dataset = load_dataset(lang)
    processor = create_processor(lang)
    dataset = tokenize_dataset(dataset, processor)
    return dataset, processor

In [7]:
def load_model(model_state_path: str, processor: Wav2Vec2Processor, model_type: str = "facebook/wav2vec2-base") -> Wav2Vec2ForCTC:
    model = Wav2Vec2ForCTC.from_pretrained(
        model_type,
        pad_token_id=processor.tokenizer.pad_token_id, 
        vocab_size=len(processor.tokenizer)
    )
    model.load_state_dict(torch.load(model_state_path))
    return model

In [8]:
def load_datasets_processors_models(model_paths: dict[str, str]) -> (dict[str, Dataset], dict[str, Wav2Vec2Processor], dict[str, Wav2Vec2ForCTC]):
    datasets, processors, models = {}, {}, {}
    for lang in ["en", "ru", "de"]:
        datasets[lang], processor = create_dataset_and_processor(lang)
        processors[lang] = processor
        models[lang] = load_model(model_paths[lang], processor)
    return datasets, processors, models

In [9]:
class CNNModel(nn.Module):

    def __init__(self):
        super(CNNModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding = 2),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        self.fc1 = nn.Linear(18432, 1000)
        self.fc2 = nn.Linear(1000, 3)

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = x.reshape(x.size(0), -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return x

In [10]:
lid_model_path = "./models/lid/lid_model_state.pth"

model_paths = {
    "en": "./models/en/en_model_state.pth",
    "ru": "./models/ru/ru_model_state.pth",
    "de": "./models/de/de_model_state.pth"
}

multilingual_model_path = "./models/multilingual/multilingual_model_state.pth"

In [None]:
lid_dataset = ImageFolder(root="./data/test_spectrogram", transform=torchvision.transforms.ToTensor())
lid_model = CNNModel()
lid_model.load_state_dict(torch.load(lid_model_path))

datasets, processors, models = load_datasets_processors_models(model_paths)

multilingual_processor = create_processor("multilingual")
multilingual_model = load_model(multilingual_model_path, multilingual_processor, "facebook/wav2vec2-xls-r-300m")

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}.")

lid_model = lid_model.to(device)
for key, model in models.items():
    models[key] = model.to(device)
multilingual_model = multilingual_model.to(device)

Device: cuda.


In [13]:
@dataclass
class EvaluationPipelineResult:
    predicted_class: str
    true_class: str
    single_model_result: str
    multi_model_result: str
    reference_result: str

In [29]:
def get_index_by_lid_index(image_index: int, true_class: str) -> int:
    if true_class in ["en", "ru"]:
        image_index -= len(datasets["de"])
    if true_class == "ru":
        image_index -= len(datasets["en"])
    return image_index

def predict(model: Wav2Vec2ForCTC, processor: Wav2Vec2Processor, entry) -> (torch.Tensor, str):
    input_dict = processor(entry["input_values"], sampling_rate=16000, return_tensors="pt", padding=True)
    input_dict = {k: v.to(device) for k, v in input_dict.items()}

    logits = model(**input_dict).logits

    pred_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(pred_ids)[0]

def evaluate_predict_pipeline(lid_sample_idx: int) -> EvaluationPipelineResult:
    image, true_class_idx = lid_dataset[lid_sample_idx]

    image = image[None,:].to(device)
    lid_logits = lid_model(image)
    predicted_class_idx = torch.argmax(lid_logits, dim=-1)

    true_class = lid_dataset.classes[true_class_idx]
    predicted_class = lid_dataset.classes[predicted_class_idx]

    model = models[predicted_class]
    processor = processors[predicted_class]
    true_processor = processors[true_class]
    dataset = datasets[true_class]

    index = get_index_by_lid_index(lid_sample_idx, true_class)
    entry = dataset[index]

    single_language_prediction = predict(model, processor, entry)
    multi_language_prediction = predict(multilingual_model, multilingual_processor, entry)

    reference = true_processor.decode(entry["labels"])

    return EvaluationPipelineResult(predicted_class, true_class, single_language_prediction, multi_language_prediction, reference)

def print_prediction():
    with torch.no_grad():
        rand_int = random.randint(0, len(lid_dataset) - 1)
        result = evaluate_predict_pipeline(rand_int)
        
        print(f"Predicted class: {result.predicted_class}, true class: {result.true_class}.")
        print(f"Prediction (single language): {result.single_model_result}")
        print(f"Prediction (multilingual): {result.multi_model_result}")
        print(f"Reference: {result.reference_result}")

def print_metrics():
    with torch.no_grad():
        single_wer = 0
        single_cer = 0
        single_nospace_cer = 0
        multi_wer = 0
        multi_cer = 0
        multi_nospace_cer = 0

        transform = Compose([
            RemoveWhiteSpace(),
            ReduceToListOfListOfChars()
        ])
        samples_count = len(lid_dataset)
        
        for lid_index in range(samples_count):
            result = evaluate_predict_pipeline(lid_index)
            ref = result.reference_result
            single_res = result.single_model_result
            multi_res = result.multi_model_result

            single_wer += wer(ref, single_res)
            single_cer += cer(ref, single_res)
            single_nospace_cer += cer(ref, single_res, transform, transform)

            multi_wer += wer(ref, multi_res)
            multi_cer += cer(ref, multi_res)
            multi_nospace_cer += cer(ref, multi_res, transform, transform)

        single_wer /= samples_count
        single_cer /= samples_count
        single_nospace_cer /= samples_count
        multi_wer /= samples_count
        multi_cer /= samples_count
        multi_nospace_cer /= samples_count

        print(f"Single language: wer = {single_wer:.4f}, cer = {single_cer:.4f}, cer (ignoring spaces) = {single_nospace_cer:.4f}")
        print(f"Multilingual: wer = {multi_wer:.4f}, cer = {multi_cer:.4f}, cer (ignoring spaces) = {multi_nospace_cer:.4f}")


In [30]:
print_metrics()

Single language: wer = 0.8780, cer = 0.5066, cer (ignoring spaces) = 0.5285
Multilingual: wer = 1.0967, cer = 0.9592, cer (ignoring spaces) = 1.0677


In [16]:
for i in range(5):
    print_prediction()
    print()

Predicted class: en, true class: en.
Prediction (single language): CALIN ABOUT  REINT TRANACTIFOM MY CARD
Prediction (multilingual): olihebo reez fant twinboc onl my card
Reference: TEL ME ABOUT RECENT TRANSACTIONS ON MY CARD

Predicted class: de, true class: en.
Prediction (single language): UTZ MUSTE MANE MUNE ERKENUGE DRAGEN WON TZIEN INND DUNGIN TEN TAEGE DEL FR MINNITIE MÖISCEN
Prediction (multilingual): hso mafsomot o oney ihont withcrawt wontime aand douringd tir hayfre o aeat  mv
Reference: WHATS THE MOST AMOUNT OF MONEY I CAN WITHDRAW AT ONE TIME AND DURING AN ENTIRE DAY FROM AN ATM MACHINE

Predicted class: ru, true class: ru.
Prediction (single language): КАК НЕТ КРИТВОВ МЕСТНЫЙ СЧЁТ  СПАРАТНЕРОМ
Prediction (multilingual): као мнедкритомесный счёт спртняро
Reference: КАК МНЕ ОТКРЫТЬ СОВМЕСТНЫЙ СЧЁТ С ПАРТНЕРОМ

Predicted class: de, true class: ru.
Prediction (single language): BETHUMIER GRTE BERTETALE A WUTEI EFEMLEH KERDTOU MEINER
Prediction (multilingual): поче у м карто п