In [13]:
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
from datasets import load_dataset
from torch.nn import CTCLoss
from torch.optim import Adam
import torch.nn as nn

In [2]:
# Скачиваем датасет с русскими данными 
# https://huggingface.co/datasets/mozilla-foundation/common_voice_11_0
cv_11_train = load_dataset("mozilla-foundation/common_voice_11_0", "ru", split="train")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [6]:
# Создание класса датасета
class VoiceDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset
        self.resampler = Resample(orig_freq=48_000, new_freq=16_000)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        audio_path = self.dataset[idx]["path"]
        waveform, sample_rate = torchaudio.load(audio_path, normalize=True)
        waveform = self.resampler(waveform)
        transcription = self.dataset[idx]["sentence"]
        return waveform, transcription

In [7]:
# Инициализация датасета и загрузчика данных
voice_dataset = VoiceDataset(cv_11_train)
dataloader = DataLoader(voice_dataset, batch_size=32, shuffle=True, num_workers=4)

In [18]:
# Пример определения простой модели для голосового распознавания
class SimpleSpeechRecognitionModel(nn.Module):
    def __init__(self, num_classes):
        super(SimpleSpeechRecognitionModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.fc = nn.Linear(64, num_classes)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

In [19]:
# Указание количества классов
num_classes = 10  # Замените на фактическое количество классов в вашем датасете

# Инициализация модели
model = SimpleSpeechRecognitionModel(num_classes)

In [20]:
# Определение функции потерь и оптимизатора
criterion = CTCLoss(blank=0)
optimizer = Adam(model.parameters(), lr=0.001)

In [25]:
# Пример обучения модели на нескольких эпохах
num_epochs = 5
for epoch in range(num_epochs):
    for batch_idx, (waveform, transcription) in enumerate(dataloader):
        optimizer.zero_grad()
        output = model(waveform)
        # Преобразование текстов в тензоры для расчета функции потерь CTC
        target = torch.IntTensor([ord(char) for char in transcription])
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print(f"Эпоха {epoch+1}/{num_epochs}, Шаг {batch_idx}, Потеря: {loss.item()}")

ImportError: Caught ImportError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/tmp/ipykernel_300860/1671197508.py", line 11, in __getitem__
    audio_path = self.dataset[idx]["path"]
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 2810, in __getitem__
    return self._getitem(key)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/arrow_dataset.py", line 2795, in _getitem
    formatted_output = format_table(
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/formatting/formatting.py", line 629, in format_table
    return formatter(pa_table, query_type=query_type)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/formatting/formatting.py", line 396, in __call__
    return self.format_row(pa_table)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/formatting/formatting.py", line 437, in format_row
    row = self.python_features_decoder.decode_row(row)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/formatting/formatting.py", line 215, in decode_row
    return self.features.decode_example(row) if self.features else row
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/features/features.py", line 1939, in decode_example
    return {
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/features/features.py", line 1940, in <dictcomp>
    column_name: decode_nested_example(feature, value, token_per_repo_id=token_per_repo_id)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/features/features.py", line 1340, in decode_nested_example
    return schema.decode_example(obj, token_per_repo_id=token_per_repo_id)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/datasets/features/audio.py", line 191, in decode_example
    array = librosa.to_mono(array)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/lazy_loader/__init__.py", line 78, in __getattr__
    attr = getattr(submod, name)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/lazy_loader/__init__.py", line 77, in __getattr__
    submod = importlib.import_module(submod_path)
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/importlib/__init__.py", line 127, in import_module
    return _bootstrap._gcd_import(name[level:], package, level)
  File "<frozen importlib._bootstrap>", line 1030, in _gcd_import
  File "<frozen importlib._bootstrap>", line 1007, in _find_and_load
  File "<frozen importlib._bootstrap>", line 986, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 680, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 850, in exec_module
  File "<frozen importlib._bootstrap>", line 228, in _call_with_frames_removed
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/librosa/core/audio.py", line 17, in <module>
    from numba import jit, stencil, guvectorize
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/numba/__init__.py", line 56, in <module>
    _ensure_critical_deps()
  File "/home/redalexdad/anaconda3/envs/dl_science/lib/python3.9/site-packages/numba/__init__.py", line 40, in _ensure_critical_deps
    raise ImportError(msg)
ImportError: Numba needs NumPy 1.22 or greater. Got NumPy 1.20.


In [None]:
# Сохранение обученной модели
torch.save(model.state_dict(), "my_model_001.pth")