Импорт всех библиотек

In [23]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

Реализация модели
Модель будет похожа на DeepSpeech2.
Плюсы:
 - логичная и простая архитектура;
 - мало гиперпараметров (их число зависит от размера модели);
 - работает лучше, чем обычная RNN: GRU-ячейки смягчают проблему затухающих градиентов.
Минусы:
 - требует большого объёма данных для обучения;
 - требовательна к вычислительным ресурсам.

Архитектура (B — размер батча, T — число кадров спектрограммы, H — gru_hidden_size):
    - входной слой: мел-спектрограмма размерности (B, n_mels, T), где n_mels = 128;
    - CNN-слои (Conv1d по времени) выделяют локальные признаки: (B, n_mels, T) -> (B, 128, T); при нечётном kernel_size padding = kernel_size // 2 сохраняет длину T;
    - RNN-слои (GRU) обрабатывают последовательность кадров во времени: (B, T, 128) -> (B, T, H);
    - полносвязные слои дают предсказание для каждого момента времени: (B, T, H) -> (B, T, num_classes);
    - Softmax переводит логиты в вероятности символов (для CTC-loss используется log-softmax).



CNN

In [2]:
class ConvBlock(nn.Module):
    """Conv1d over time, then LayerNorm over channels, ReLU and dropout.

    Works on channel-first tensors: (B, in_channels, T) -> (B, out_channels, T').
    With an odd ``kernel_size`` (the default is 5) the padding keeps T' == T.
    """

    def __init__(self, in_channels, out_channels, kernel_size=5, dropout_rate=0.3):
        super().__init__()
        same_padding = kernel_size // 2
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size, padding=same_padding
        )
        # LayerNorm normalizes the LAST dimension, so forward() temporarily
        # moves the channel axis to the end.
        self.layer_norm = nn.LayerNorm(out_channels)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        features = self.conv(x)
        normalized = self.layer_norm(features.transpose(1, 2)).transpose(1, 2)
        return self.dropout(self.relu(normalized))


RNN

In [3]:
class GRUBlock(nn.Module):
    """Single-layer unidirectional GRU followed by LayerNorm and dropout.

    Batch-first: (B, T, input_size) -> (B, T, hidden_size). The final hidden
    state returned by ``nn.GRU`` is discarded; only per-step outputs are kept.
    """

    def __init__(self, input_size, hidden_size, dropout_rate=0.3):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.layer_norm = nn.LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        step_outputs = self.gru(x)[0]
        return self.dropout(self.layer_norm(step_outputs))


Объединяем все части в одну модель

In [4]:
class ASRModel(nn.Module):
    """DeepSpeech2-style acoustic model that emits one class distribution per frame.

    Pipeline: 3 ``ConvBlock`` layers over time -> a stack of ``GRUBlock`` layers
    -> a 2-layer MLP applied to every frame -> log-softmax over classes.

    Args:
        input_channels: number of input feature channels (e.g. mel bins).
        gru_hidden_size: hidden size H of every GRU block.
        num_classes: size of the output vocabulary; presumably
            ``len(alphabet) + 1`` when id 0 is reserved for the CTC blank —
            confirm against the text preprocessor used for the labels.
        dropout_rate: dropout probability used in every block.
        num_gru_layers: number of stacked GRU blocks (default 7).

    Raises:
        ValueError: if ``num_gru_layers`` < 1 (the MLP expects H-sized input).
    """

    def __init__(
        self,
        input_channels,
        gru_hidden_size,
        num_classes,
        dropout_rate=0.3,
        num_gru_layers=7,
    ):
        super(ASRModel, self).__init__()
        if num_gru_layers < 1:
            raise ValueError("num_gru_layers must be at least 1")

        conv_out_channels = 128

        # Convolutional feature extractor: (B, input_channels, T) -> (B, 128, T).
        self.conv_blocks = nn.ModuleList(
            [
                ConvBlock(input_channels, 32, dropout_rate=dropout_rate),
                ConvBlock(32, 64, dropout_rate=dropout_rate),
                ConvBlock(64, conv_out_channels, dropout_rate=dropout_rate),
            ]
        )

        # Recurrent stack. Only the first GRU sees the 128 conv channels; each
        # later GRU consumes the previous block's H-sized output, so its input
        # size must be gru_hidden_size (otherwise any H != 128 crashes).
        self.gru_blocks = nn.ModuleList(
            [
                GRUBlock(
                    conv_out_channels if layer == 0 else gru_hidden_size,
                    gru_hidden_size,
                    dropout_rate=dropout_rate,
                )
                for layer in range(num_gru_layers)
            ]
        )

        # Per-frame classifier (nn.Linear acts on the last dim, so it is
        # applied independently to every time step).
        self.fc = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(gru_hidden_size, gru_hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(gru_hidden_size // 2, num_classes),
        )
        # nn.CTCLoss expects log-probabilities. argmax (greedy decoding) gives
        # the same result as on softmax; use .exp() to get probabilities.
        self.log_softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Score every frame of a batch of feature sequences.

        Args:
            x: (B, input_channels, T) feature tensor.

        Returns:
            (B, T, num_classes) log-probabilities, one distribution per frame.
            The conv blocks use same-padding and no stride, so T is unchanged.
            For ``nn.CTCLoss`` transpose to (T, B, num_classes).
        """
        for conv in self.conv_blocks:
            x = conv(x)  # (B, C_out, T)

        x = x.permute(0, 2, 1)  # (B, C_out, T) -> (B, T, C_out) for batch-first GRU

        for gru in self.gru_blocks:
            x = gru(x)  # (B, T, H)

        # Keep every time step: CTC training and greedy decoding both need a
        # prediction for each frame, not just the last one.
        x = self.fc(x)  # (B, T, num_classes)
        return self.log_softmax(x)


Предобработка данных

In [None]:
class TextPreprocessor:
    def __init__(self, char_map: str | list | set, use_blank_token: bool = True):
        if isinstance(char_map, list):
            char_map = "".join(char_map)
        if isinstance(char_map, set):
            char_map = "".join(char_map)
        self.char_map = char_map
        self.use_blank_token = use_blank_token
        self.char_to_int = self.__create_char_to_index()
        self.int_to_char = self.__create_int_char()

    def __create_int_char(self):
        start_index = 0
        if self.use_blank_token:
            start_index = 1
        index_to_char = {
            idx: char for idx, char in enumerate(self.char_map, start=start_index)
        }
        return index_to_char

    def __create_char_to_index(self):
        start_index = 0
        if self.use_blank_token:
            start_index = 1
        char_to_index = {
            char: idx for idx, char in enumerate(self.char_map, start=start_index)
        }
        return char_to_index

    def __convert_list_to_tensor(
        self, data: list[list[int]], torch_device="cpu"
    ) -> torch.Tensor:
        # Converts a batch of lists to a torch.Tensor
        return torch.tensor(data, device=torch_device)

    def __convert_list_to_numpy(self, data: list[list[int]]) -> np.ndarray:
        # Converts a batch of lists to a numpy array
        return np.array(data)

    def encode(
        self, text: str, types: str = "list", torch_device="cpu"
    ) -> list[int] | torch.Tensor | np.ndarray:
        encoded = [self.char_to_int[char] for char in text]
        if types == "torch":
            return torch.tensor(encoded, device=torch_device)
        elif types == "numpy":
            return np.array(encoded)
        return encoded

    def decode(self, indices: list[int] | torch.Tensor | np.ndarray) -> str:
        if isinstance(indices, torch.Tensor):
            indices = indices.tolist()
        elif isinstance(indices, np.ndarray):
            indices = indices.tolist()
        return "".join([self.int_to_char.get([idx], "") for idx in indices])

    def encode_batch(
        self, texts: list[str], types: str = "list", torch_device="cpu"
    ) -> list[list[int]] | torch.Tensor | np.ndarray:
        encoded_batch = [self.encode(text, types, torch_device) for text in texts]
        if types == "torch":
            return self.__convert_list_to_tensor(encoded_batch, torch_device)
        elif types == "numpy":
            return self.__convert_list_to_numpy(encoded_batch)
        return encoded_batch

    def decode_batch(
        self, indices: list[list[int] | torch.Tensor | np.ndarray]
    ) -> list[str]:
        return [self.decode(idx) for idx in indices]


# Example usage: build the notebook-wide preprocessor (used later by the data
# pipeline and the greedy decoder) and round-trip a few strings through it.
RU_ALPHABET = " абвгдеёжзийклмнопрстуфхцчшщъыьэюя"
text_preprocessor = TextPreprocessor(char_map=RU_ALPHABET, use_blank_token=True)

# Single string -> tensor of ids
encoded_text = text_preprocessor.encode("абв", types="torch")
print(encoded_text)

# Batch of equal-length strings -> 2-D numpy array
encoded_batch = text_preprocessor.encode_batch(["абв", "где"], types="numpy")
print(encoded_batch)

# And back to text
decoded_text = text_preprocessor.decode(encoded_text)
print(decoded_text)

decoded_batch = text_preprocessor.decode_batch(encoded_batch)
print(decoded_batch)


tensor([2, 3, 4])
[[2 3 4]
 [5 6 7]]
абв
['абв', 'где']


In [7]:
# Shared mel-spectrogram settings; the sample rate assumes 16 kHz input audio.
# NOTE(review): MelSpectrogram returns power values with no log compression —
# confirm this is intended for the model's input scale.
MEL_KWARGS = {"sample_rate": 16000, "n_mels": 128, "n_fft": 1024}

# Training features: mel spectrogram followed by random frequency and time
# masking (data augmentation).
train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(**MEL_KWARGS),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)

# Validation features: the same mel spectrogram without augmentation.
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(**MEL_KWARGS)


In [9]:
def _encode_label(utterance):
    """Encode a lower-cased transcript with the notebook-level ``text_preprocessor``."""
    return text_preprocessor.encode(utterance, types="torch")


def collate_audio_batch(data, audio_transform, encode_text=_encode_label):
    """Turn a list of (waveform, utterance) pairs into padded CTC batch tensors.

    Args:
        data: iterable of ``(waveform, utterance)``; waveform is assumed to be
            (channels, samples) as returned by ``torchaudio.load``.
        audio_transform: callable mapping a (1, samples) waveform to a
            (1, n_feats, frames) feature tensor.
        encode_text: callable mapping a lower-cased string to a 1-D tensor of ids.

    Returns:
        spectrograms: (B, n_feats, T_max) float tensor, zero-padded in time.
        labels: (B, L_max) int64 tensor, padded with 0 (the reserved blank id
            when the preprocessor uses ``use_blank_token=True``).
        input_lengths: list[int], number of feature frames per sample.
        label_lengths: list[int], number of label ids per sample.
    """
    spectrograms, labels, input_lengths, label_lengths = [], [], [], []
    for waveform, utterance in data:
        # The squeeze(0) below assumes a single channel: mix multi-channel
        # audio down to mono so the channel axis really disappears.
        if waveform.dim() > 1 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)
        # (1, n_feats, frames) -> (frames, n_feats) so pad_sequence pads time.
        spec = audio_transform(waveform).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)

        # Force int64: an empty transcript would otherwise encode to float32.
        label = torch.as_tensor(encode_text(utterance.lower()), dtype=torch.long)
        labels.append(label)

        # The model's ConvBlocks use same-padding and no stride, so the number
        # of output frames equals the number of input frames. Halving it made
        # CTC ignore the second half of every utterance.
        input_lengths.append(spec.shape[0])
        label_lengths.append(label.shape[0])

    # (B, T_max, n_feats) -> (B, n_feats, T_max): channel-first for Conv1d.
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).transpose(
        1, 2
    )
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths


def data_processing_valid(data):
    """DataLoader collate_fn for validation: mel spectrogram, no augmentation."""
    return collate_audio_batch(data, valid_audio_transforms)


def data_processing_train(data):
    """DataLoader collate_fn for training: mel spectrogram plus frequency/time masking."""
    return collate_audio_batch(data, train_audio_transforms)


In [32]:
def GreedyDecoder(
    output, labels, label_lengths, blank_label=0, collapse_repeated=True
):
    """Best-path CTC decoding of a batch, plus decoding of the reference labels.

    ``output`` is indexed as (B, T, C): argmax over dim 2 picks one class per
    frame. Frames equal to ``blank_label`` are dropped; with
    ``collapse_repeated`` a frame equal to the immediately preceding frame is
    also dropped (so "a, blank, a" still yields two "a"s).

    Returns:
        (decodes, targets): predicted strings and reference strings, the latter
        decoded from ``labels[i][:label_lengths[i]]``. Both use the global
        ``text_preprocessor``.
    """
    best_paths = torch.argmax(output, dim=2)
    decodes, targets = [], []
    for sample_idx, path in enumerate(best_paths):
        reference_ids = labels[sample_idx][: label_lengths[sample_idx]].tolist()
        targets.append(text_preprocessor.decode(reference_ids))

        frame_ids = path.tolist()
        kept = [
            token
            for position, token in enumerate(frame_ids)
            if token != blank_label
            and not (
                collapse_repeated
                and position > 0
                and token == frame_ids[position - 1]
            )
        ]
        decodes.append(text_preprocessor.decode(kept))
    return decodes, targets


In [10]:
# Characters kept by normalization_text (lowercase Russian alphabet incl. "ё").
_RU_LETTERS = frozenset("абвгдеёжзийклмнопрстуфхцчшщъыьэюя")


def normalization_text(text: str) -> str:
    """
    Normalizes a transcript to lowercase Russian words separated by single spaces.

    The text is lower-cased and every character outside the Russian alphabet is
    removed. Any whitespace (spaces, tabs, newlines) acts as a word separator,
    so words on different lines are not glued together. Runs of separators
    collapse to one space, and there is no leading or trailing space.

    :param text: The text to be normalized.
    :return: The normalized text; an empty string if no letters remain.
    """
    words = (
        "".join(char for char in word if char in _RU_LETTERS)
        for word in text.lower().split()
    )
    # Drop words that became empty (e.g. pure digits or punctuation).
    return " ".join(word for word in words if word)


In [27]:
class AudioDataset(Dataset):
    """Dataset of (waveform, normalized transcript) pairs listed in a CSV manifest.

    The CSV is read without a header; column 0 is the audio file path and
    column 1 the transcript text-file path, both relative to ``audio_path``.
    """

    def __init__(
        self,
        data_path_csv: str,
        audio_path: str,
        sep=",",
        sample_rate: int | None = 16000,
    ):
        """
        :param data_path_csv: path to the manifest CSV.
        :param audio_path: root directory the manifest paths are relative to.
        :param sep: CSV separator.
        :param sample_rate: target rate; audio with a different rate is
            resampled to it (the mel transforms assume 16000). ``None`` keeps
            the file's native rate.
        """
        self.data_path_csv = data_path_csv
        self.audio_path = audio_path
        self.sample_rate = sample_rate
        self.data = pd.read_csv(
            data_path_csv,
            sep=sep,
            header=None,
            usecols=[0, 1],
            names=["path", "sentence"],
        )
        print(self.data.head())

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(waveform, sentence)`` for row ``idx``."""
        path = os.path.join(self.audio_path, self.data.iloc[idx, 0])
        sentence_path = os.path.join(self.audio_path, self.data.iloc[idx, 1])
        # Explicit UTF-8: the platform default (e.g. cp1251 on Windows) would
        # garble or reject Cyrillic transcripts.
        # NOTE(review): assumes transcripts are UTF-8 encoded — confirm.
        with open(sentence_path, "r", encoding="utf-8") as file:
            sentence = file.read()
        sentence = normalization_text(sentence)

        audio, source_rate = torchaudio.load(path)
        if self.sample_rate is not None and source_rate != self.sample_rate:
            audio = torchaudio.functional.resample(
                audio, source_rate, self.sample_rate
            )
        return audio, sentence

# Both manifests reference files under the same dataset root.
DATASET_ROOT = "D:/Dataset/speech_ru"

train_dataset = AudioDataset(
    data_path_csv="D:/Dataset/speech_ru/public_youtube700.csv",
    audio_path=DATASET_ROOT,
)
valid_dataset = AudioDataset(
    data_path_csv="D:/Dataset/speech_ru/public_youtube700_val.csv",
    audio_path=DATASET_ROOT,
)

                                       path  \
0  public_youtube700/f/47/0b1f5586bbd7.opus   
1  public_youtube700/8/c8/7367ee2bf1d7.opus   
2  public_youtube700/1/1d/44616e05b60c.opus   
3  public_youtube700/8/5d/a7e9f299fb9c.opus   
4  public_youtube700/7/5b/b975fff943bf.opus   

                                  sentence  
0  public_youtube700/f/47/0b1f5586bbd7.txt  
1  public_youtube700/8/c8/7367ee2bf1d7.txt  
2  public_youtube700/1/1d/44616e05b60c.txt  
3  public_youtube700/8/5d/a7e9f299fb9c.txt  
4  public_youtube700/7/5b/b975fff943bf.txt  
                                          path  \
0  public_youtube700_val/a/e3/0f2e35efe76d.wav   
1  public_youtube700_val/f/0d/4cda8ca9e32c.wav   
2  public_youtube700_val/c/44/643792385a6d.wav   
3  public_youtube700_val/7/c6/3f876449790b.wav   
4  public_youtube700_val/4/a8/cf00f3c177ac.wav   

                                      sentence  
0  public_youtube700_val/a/e3/0f2e35efe76d.txt  
1  public_youtube700_val/f/0d/4cda8ca9e32c.txt  

In [28]:
BATCH_SIZE = 16

# Training batches are shuffled and augmented; validation keeps manifest order
# and uses the plain spectrogram transform.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_processing_train,
)
valid_dataloader = DataLoader(
    valid_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=data_processing_valid,
)