In [1]:
import os
import re
import json
import random
import numpy as np
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torchvision
import evaluate
from datasets import Dataset, Audio
from torchvision.datasets import ImageFolder
from transformers import ( Wav2Vec2CTCTokenizer, 
                           Wav2Vec2FeatureExtractor, 
                           Wav2Vec2Processor, 
                           AutoModelForCTC, 
                           TrainingArguments, 
                           Trainer )

from typing import Iterable, Generator, Any

In [2]:
current_dir = os.getcwd()

def check_exists_path(path):
    """Create the directory ``path`` (including missing parents) if nothing exists there yet.

    Relative paths are resolved against the current working directory at call
    time, which is the same interpretation the existence check uses; absolute
    paths are used as given. If ``path`` already exists, nothing happens.
    """
    if not os.path.exists(path):
        # Create exactly the path that was checked. The old code re-prefixed it
        # with ``current_dir``, which put absolute paths in the wrong place.
        # exist_ok=True tolerates the directory appearing between check and create.
        os.makedirs(path, exist_ok=True)

# Модели Speech Recognition для одного языка
## Общие функции
### Загрузка датасетов

In [3]:
# Anything that is neither a word character nor whitespace (punctuation, symbols).
non_alphanum_chars_regexp = re.compile(r"[^\w\s]")
# Any run of whitespace (spaces, tabs, newlines); collapsed to a single space.
_whitespace_run_regexp = re.compile(r"\s+")

def remove_special_characters(text: str) -> str:
    """Normalize a transcription before it is used for the character vocabulary.

    Removes punctuation/symbols, collapses every whitespace run to a single
    space, strips leading/trailing whitespace and upper-cases the result.
    Without the whitespace step, trailing newlines from transcript files and
    tabs would become extra characters in the CTC vocabulary.
    """
    without_symbols = non_alphanum_chars_regexp.sub("", text)
    single_spaced = _whitespace_run_regexp.sub(" ", without_symbols).strip()
    return single_spaced.upper()

In [4]:
def samples_generator(split_name: str, lang: str) -> Generator[dict[str, Any], None, None]:
    """Yield one ``{"audio": <wav path>, "transcription": <text>}`` record per clip.

    Reads ``./data/{split_name}_audio/{lang}/*.wav``. For each clip it reads the
    transcript ``./data/{split_name}_transcription/{lang}/<same stem>.txt`` and
    normalizes it with ``remove_special_characters``.

    Args:
        split_name: data-folder prefix, e.g. "train", "test" or "valid".
        lang: language sub-folder, e.g. "en".

    Yields:
        dict with the audio file path (str) and the normalized transcription.
    """
    base_path = Path(f"./data/{split_name}_audio/{lang}/")
    # Sort so row order is deterministic; ``Path.glob`` order depends on the
    # filesystem. Later cells map spectrogram-image indices to dataset rows by
    # position, which presumably assumes this sorted order — confirm.
    for path in sorted(base_path.glob("*.wav")):
        file_path_template = f"{path.parent.name}/{path.stem}.{{ext}}"

        audio_path = f"./data/{split_name}_audio/{file_path_template.format(ext='wav')}"
        text_path = f"./data/{split_name}_transcription/{file_path_template.format(ext='txt')}"

        # Explicit encoding: the platform default is locale-dependent and can
        # garble ru/de transcripts (assumes the files are UTF-8 — confirm).
        with open(text_path, "r", encoding="utf-8") as f:
            transcription = f.read()

        yield {"audio": audio_path, "transcription": remove_special_characters(transcription)}

def load_dataset(split_name: str, lang: str) -> Dataset:
    """Build a Hugging Face ``Dataset`` for one split and language.

    Rows come from ``samples_generator``. The ``audio`` column (file paths) is
    cast to ``Audio(sampling_rate=16000)`` so it is decoded as 16 kHz audio.
    """
    generator_kwargs = {"split_name": split_name, "lang": lang}
    raw_dataset = Dataset.from_generator(samples_generator, gen_kwargs=generator_kwargs)
    return raw_dataset.cast_column("audio", Audio(sampling_rate=16000))

### Предобработка данных

In [None]:
def chars_generator(datasets: Iterable[Dataset]) -> Generator[str, None, None]:
    """Yield every character of every ``transcription`` in each dataset, in order."""
    for ds in datasets:
        transcriptions = ds["transcription"]
        for sentence in transcriptions:
            for character in sentence:
                yield character

def create_vocabulary(datasets: Iterable[Dataset], lang: str):
    """Build a character vocabulary and write it to ``./models/{lang}/vocab.json``.

    Each distinct transcription character gets an id in sorted order. ``[UNK]``
    and ``[PAD]`` are then appended as the last two ids. ``create_processor``
    loads the tokenizer from this directory.
    """
    alphabet = sorted(set(chars_generator(datasets)))
    vocab = {character: token_id for token_id, character in enumerate(alphabet)}

    for special_token in ("[UNK]", "[PAD]"):
        vocab[special_token] = len(vocab)

    vocab_dir = f"./models/{lang}"
    check_exists_path(vocab_dir)
    with open(f"{vocab_dir}/vocab.json", "w") as out_file:
        json.dump(vocab, out_file)

In [None]:
def create_processor(lang: str) -> Wav2Vec2Processor:
    """Combine the language's CTC tokenizer with a 16 kHz wav2vec2 feature extractor.

    The tokenizer is read from ``./models/{lang}/``, where ``create_vocabulary``
    writes ``vocab.json``. Spaces act as the word delimiter.
    """
    feature_extractor = Wav2Vec2FeatureExtractor(
        feature_size=1,
        sampling_rate=16000,
        padding_value=0.0,
        do_normalize=True,
        return_attention_mask=True,
    )
    tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
        f"./models/{lang}/",
        unk_token="[UNK]",
        pad_token="[PAD]",
        word_delimiter_token=" ",
    )
    return Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def tokenize_dataset(dataset: Dataset, processor: Wav2Vec2Processor) -> Dataset:
    """Convert raw ``audio``/``transcription`` rows into model inputs.

    Each row is replaced by the processor's output for that clip plus its text.
    The collator later reads ``input_values`` (a batch of one) and ``labels``.
    An extra ``input_length`` column holds the number of audio samples. All
    original columns are dropped.
    """
    original_columns = dataset.column_names

    def _encode_example(example):
        audio = example["audio"]
        encoded = processor(
            audio["array"],
            sampling_rate=audio["sampling_rate"],
            text=example["transcription"],
        )
        encoded["input_length"] = len(encoded["input_values"][0])
        return encoded

    return dataset.map(_encode_example, remove_columns=original_columns)

In [None]:
@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of tokenized examples into one CTC training batch.

    Audio features and label ids have different lengths, so they are padded
    separately: ``input_values`` go through the processor's feature-extractor
    padding, and ``labels`` go through its tokenizer padding (``labels=``).
    Padded label positions are then set to -100 so the loss ignores them.
    """

    # Processor used for both kinds of padding. It was previously annotated
    # ``AutoProcessor``, a name imported only in a much later cell, so defining
    # this class on a fresh kernel raised NameError.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to ``processor.pad`` (e.g. "longest" or True).
    padding: bool | str = "longest"

    def __call__(self, features: list[dict[str, list[int] | torch.Tensor]]) -> dict[str, torch.Tensor]:
        """Collate ``features`` into padded ``input_values`` (+ mask) and ``labels`` tensors."""
        # ``input_values`` is stored as a batch of one by ``tokenize_dataset``,
        # hence the [0].
        input_features = [{"input_values": feature["input_values"][0]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(input_features, padding=self.padding, return_tensors="pt")
        labels_batch = self.processor.pad(labels=label_features, padding=self.padding, return_tensors="pt")

        # Replace padding with -100 so those positions are ignored by the loss.
        batch["labels"] = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        return batch

### Пайплайн обучения

In [None]:
# WER metric from the ``evaluate`` library, loaded once and shared by every callback.
wer = evaluate.load("wer")

def get_compute_metrics(processor: Wav2Vec2Processor):
    """Build the ``compute_metrics`` callback passed to ``Trainer``.

    Args:
        processor: processor whose tokenizer decodes both predictions and labels.

    Returns:
        A function that takes an evaluation prediction (``predictions`` logits
        and ``label_ids``) and returns ``{"wer": <word error rate>}``.
    """

    def compute_metrics(pred):
        # Greedy decoding: most likely token id per frame.
        pred_ids = np.argmax(pred.predictions, axis=-1)

        # -100 marks padded label positions; map them back to the pad id so the
        # tokenizer treats them as padding. np.where builds a new array instead
        # of overwriting the caller's ``pred.label_ids`` in place.
        label_ids = np.where(pred.label_ids == -100, processor.tokenizer.pad_token_id, pred.label_ids)

        pred_str = processor.batch_decode(pred_ids)
        # group_tokens=False: references are real label sequences, so genuinely
        # repeated characters must not be merged as they are for CTC output.
        label_str = processor.batch_decode(label_ids, group_tokens=False)

        return {"wer": wer.compute(predictions=pred_str, references=label_str)}

    return compute_metrics

In [None]:
def train(lang: str, num_epochs: int, checkpoint_dir: str | None = None) -> tuple[Wav2Vec2Processor, AutoModelForCTC]:
    """Fine-tune ``facebook/wav2vec2-base`` with a CTC head for one language.

    Args:
        lang: language code; selects ``./data/*_audio/{lang}`` and ``./models/{lang}/``.
        num_epochs: number of training epochs.
        checkpoint_dir: optional Trainer checkpoint to resume from; falsy = start fresh.

    Returns:
        ``(processor, model)``: the processor built from this language's vocabulary
        and the trained model. Because ``load_best_model_at_end=True`` with
        ``metric_for_best_model="wer"``, the model is the best checkpoint by
        validation WER.
    """
    train_dataset = load_dataset("train", lang)
    test_dataset = load_dataset("test", lang)
    validation_dataset = load_dataset("valid", lang)

    # The vocabulary covers all splits so evaluation text has no unseen characters.
    # ``lang`` is required; omitting it raised TypeError before training started.
    create_vocabulary([train_dataset, test_dataset, validation_dataset], lang)

    processor = create_processor(lang)

    train_dataset = tokenize_dataset(train_dataset, processor)
    validation_dataset = tokenize_dataset(validation_dataset, processor)

    data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)
    compute_metrics = get_compute_metrics(processor)

    model = AutoModelForCTC.from_pretrained(
        "facebook/wav2vec2-base",
        ctc_loss_reduction="mean",
        pad_token_id=processor.tokenizer.pad_token_id,
        # Size the newly initialised CTC head to this language's vocabulary.
        # Otherwise it keeps the base config's vocab_size, and label ids beyond
        # it are out of range.
        vocab_size=len(processor.tokenizer),
    )

    training_args = TrainingArguments(
        output_dir=f"./models/{lang}/",
        remove_unused_columns=False,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=4,
        num_train_epochs=num_epochs,
        learning_rate=1e-5,
        warmup_steps=500,
        gradient_checkpointing=True,
        # Mixed precision needs a GPU; fp16=True on a CPU-only machine raises.
        fp16=torch.cuda.is_available(),
        group_by_length=True,
        evaluation_strategy="steps",
        per_device_eval_batch_size=8,
        # save_steps must be a multiple of eval_steps for load_best_model_at_end.
        save_steps=1000,
        eval_steps=200,
        logging_steps=200,
        load_best_model_at_end=True,
        metric_for_best_model="wer",
        greater_is_better=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=validation_dataset,
        tokenizer=processor,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )

    # ``or None`` keeps the original "falsy means fresh start" behaviour.
    trainer.train(resume_from_checkpoint=checkpoint_dir or None)

    # trainer.model holds the best-checkpoint weights after training.
    return processor, trainer.model

## Дообучение моделей

In [8]:
# English ASR model.
en_processor, en_model = train(num_epochs=90, lang="en")

In [9]:
# Russian ASR model.
ru_processor, ru_model = train(num_epochs=90, lang="ru")

In [10]:
# German ASR model (must use the "de" data and vocabulary, not "ru").
de_processor, de_model = train(lang="de", num_epochs=90)

## Проверка результатов

In [None]:
from transformers import AutoProcessor

# Rebuild the English processor from ./models/en/vocab.json (written by
# create_vocabulary during training) so decoding uses the same vocabulary the
# fine-tuned CTC head was trained on. That vocabulary lives in ./models/en/,
# not in the facebook/wav2vec2-base pretraining checkpoint.
en_processor = create_processor("en")
en_model = AutoModelForCTC.from_pretrained("./models/en/checkpoint-10000")

In [None]:
class CNNModel(nn.Module):
    """Small CNN that classifies a 3-channel image (here: a spectrogram) into 3 classes.

    Three conv blocks (5x5 conv with padding 2 -> BatchNorm -> ReLU -> 2x2
    max-pool) take 3 channels to 128. Two linear layers follow
    (18432 -> 1000 -> 3). Since 18432 = 128 * 12 * 12, the spatial size after
    three halvings must be 12x12, e.g. 96x96 inputs (TODO confirm the
    spectrogram size).

    Attribute names and the order inside each Sequential define the state_dict
    keys (``layer{1,2,3}.{0,1}.*``, ``fc1.*``, ``fc2.*``). Keep them stable so
    saved checkpoints still load.
    """

    def __init__(self):
        super().__init__()
        self.layer1 = self._conv_block(3, 32)
        self.layer2 = self._conv_block(32, 64)
        self.layer3 = self._conv_block(64, 128)

        # NOTE(review): ``relu`` and ``softmax`` are never used in ``forward``.
        # There is no activation between fc1 and fc2, and raw logits are
        # returned. They hold no parameters.
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        self.fc1 = nn.Linear(18432, 1000)
        self.fc2 = nn.Linear(1000, 3)

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Size-preserving 5x5 conv + BatchNorm + ReLU, then a 2x2 max-pool that halves H and W."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=1, padding=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, x):
        """Return unnormalized class scores, shape (batch, 3), for a batch of images."""
        features = self.layer3(self.layer2(self.layer1(x)))
        flat = features.flatten(start_dim=1)
        return self.fc2(self.fc1(flat))

In [None]:
def load_image_dataset() -> ImageFolder:
    """Load the test spectrogram images under ``./data/test_spectrogram`` as tensors.

    Each image is labelled by its sub-folder (class) name and converted with
    ``ToTensor``.
    """
    to_tensor = torchvision.transforms.ToTensor()
    return ImageFolder(root="./data/test_spectrogram", transform=to_tensor)

In [None]:
def load_test_datasets() -> dict[str, Dataset]:
    """Load and tokenize the "test" split for en, ru and de.

    Each language is tokenized with its own processor (the notebook globals
    ``en_processor``, ``ru_processor``, ``de_processor``). The result is keyed
    by language code in the order en, ru, de.
    """
    processors_by_lang = {"en": en_processor, "ru": ru_processor, "de": de_processor}
    return {
        lang: tokenize_dataset(load_dataset("test", lang), lang_processor)
        for lang, lang_processor in processors_by_lang.items()
    }

In [None]:
# Saved weights of the language-identification CNN: run folder + checkpoint file name.
lid_model_path = (
    "./models/lid/12-07-2023T20-47-58/"
    "CNNModel-date(12-07-2023T20-51-36)-accuracy(79.53).pth"
)

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}.")

In [None]:
lid_model = CNNModel()
# weights_only=True: the file only needs tensors; refuse arbitrary pickled objects.
lid_model.load_state_dict(torch.load(lid_model_path, map_location=device, weights_only=True))
# Move once and switch to inference mode, so BatchNorm uses its running
# statistics instead of per-batch statistics when classifying single images.
lid_model = lid_model.to(device).eval()

In [None]:
# Language code -> (ASR model, processor). The keys are looked up with the LID
# class names, so they presumably must match the spectrogram folder names — confirm.
models = dict(
    en=(en_model, en_processor),
    ru=(ru_model, ru_processor),
    de=(de_model, de_processor),
)

image_dataset = load_image_dataset()
datasets = load_test_datasets()

In [14]:
def get_index_by_image_index(image_index: int, true_class: str) -> int:
    """Map a global image-dataset index to the row index inside ``datasets[true_class]``.

    Assumes the image dataset lists all "de" images first, then "en", then "ru",
    with one image per test-dataset row in the same order — confirm against the
    data layout. Any other class name gets no offset.
    """
    # Classes whose images come before ``true_class`` in the image dataset.
    preceding_classes = {"en": ("de",), "ru": ("de", "en")}.get(true_class, ())
    offset = sum(len(datasets[lang]) for lang in preceding_classes)
    return image_index - offset

def print_prediction():
    """Pick a random test spectrogram, identify its language, and transcribe the matching clip.

    ``lid_model`` predicts the language from the spectrogram. The ASR model for
    the *predicted* language transcribes the clip. The clip itself and its
    reference text come from the *true* language's test dataset, tokenized with
    that language's processor. Prints both classes, the prediction and the
    reference.
    """
    with torch.no_grad():
        image_index = random.randint(0, len(image_dataset) - 1)
        image, true_class_idx = image_dataset[image_index]

        # The CNN (BatchNorm2d, reshape by batch size) needs a batch dimension,
        # and its output is logits, so argmax to get a class index.
        lid_logits = lid_model.to(device).eval()(image.unsqueeze(0).to(device))
        predicted_class_idx = int(lid_logits.argmax(dim=-1).item())

        true_class = image_dataset.classes[true_class_idx]
        predicted_class = image_dataset.classes[predicted_class_idx]
        print(f"Predicted class: {predicted_class}, true class: {true_class}.")

        # The audio belongs to the true language's dataset; the LID prediction
        # only chooses which ASR model (and tokenizer) transcribes it.
        entry = datasets[true_class][get_index_by_image_index(image_index, true_class)]
        model, processor = models[predicted_class]
        _, reference_processor = models[true_class]

        # ``input_values`` were already produced by the feature extractor in
        # ``tokenize_dataset`` (stored as a batch of one). Use them directly
        # instead of normalizing a second time.
        input_values = torch.tensor(entry["input_values"], dtype=torch.float32, device=device)
        logits = model.to(device).eval()(input_values=input_values).logits

        pred_ids = torch.argmax(logits, dim=-1)
        transcription = processor.batch_decode(pred_ids)[0]
        print(f"Prediction: {transcription}")

        # Labels are token ids from the true language's tokenizer; group_tokens=False
        # keeps genuinely repeated characters.
        reference = reference_processor.decode(entry["labels"], group_tokens=False)
        print(f"Reference: {reference}")

In [None]:
# Show five random LID + ASR predictions, separated by blank lines.
for _ in range(5):
    print_prediction()
    print()