In [1]:
import os
import pandas as pd
import torchaudio
from torch.utils.data import Dataset
from torchaudio.transforms import Resample
from num2words import num2words

In [None]:
class NumberASRDataset(Dataset):
    def __init__(self, csv_path, audio_dir, target_sample_rate=16000, transform=None):
        self.csv = pd.read_csv(csv_path)
        self.audio_dir = audio_dir
        self.target_sample_rate = target_sample_rate
        self.transform = transform

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        row = self.csv.iloc[idx]
        audio_path = os.path.join(self.audio_dir, os.path.basename(row["filename"]))

        waveform, sample_rate = torchaudio.load(audio_path)

        if sample_rate != self.target_sample_rate:
            resampler = Resample(
                orig_freq=sample_rate, new_freq=self.target_sample_rate
            )
            waveform = resampler(waveform)

        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        if self.transform:
            waveform = self.transform(waveform)

        number = int(row["transcription"])
        text = num2words(number, lang="ru", to="cardinal")

        return waveform, text

In [None]:
dataset = NumberASRDataset(
    csv_path="data/train.csv",
    audio_dir="data/train",
    target_sample_rate=16000,
)

waveform, text = dataset[0]
print(waveform.shape)
print(text)

torch.Size([1, 48706])
сто тридцать девять тысяч четыреста семьдесят три


In [None]:
from torch.nn.utils.rnn import pad_sequence
import torch


def collate_fn(batch, tokenizer):
    waveforms = []
    waveform_lengths = []
    targets = []
    target_lengths = []

    for waveform, text in batch:
        waveforms.append(waveform.squeeze(0))
        waveform_lengths.append(waveform.shape[-1])

        tokenized = torch.tensor(tokenizer(text), dtype=torch.long)
        targets.append(tokenized)
        target_lengths.append(len(tokenized))

    waveforms = pad_sequence(waveforms, batch_first=True).unsqueeze(1)

    targets = pad_sequence(targets, batch_first=True, padding_value=0)

    waveform_lengths = torch.tensor(waveform_lengths, dtype=torch.long)
    target_lengths = torch.tensor(target_lengths, dtype=torch.long)

    return {
        "waveforms": waveforms,
        "waveform_lengths": waveform_lengths,
        "targets": targets,
        "target_lengths": target_lengths,
    }

In [None]:
class SimpleTokenizer:
    def __init__(self, alphabet):
        self.alphabet = sorted(list(set(alphabet)))
        self.char2idx = {c: i + 1 for i, c in enumerate(self.alphabet)}
        self.idx2char = {i + 1: c for i, c in enumerate(self.alphabet)}

    def __call__(self, text):
        return [self.char2idx[c] for c in text if c in self.char2idx]

    def decode(self, indices):
        return "".join([self.idx2char[i] for i in indices if i != 0])

    def vocab_size(self):
        return len(self.char2idx) + 1

In [None]:
alphabet = " абвгдеёжзийклмнопрстуфхцчшщъыьэюя0123456789-"


tokenizer = SimpleTokenizer(alphabet)


from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=lambda batch: collate_fn(batch, tokenizer),
)


batch = next(iter(train_loader))
print(batch["waveforms"].shape)
print(batch["waveform_lengths"])
print(batch["targets"].shape)
print(batch["target_lengths"])

torch.Size([8, 1, 66988])
tensor([66988, 54004, 66245, 53590, 47735, 50472, 51200, 44388])
torch.Size([8, 56])
tensor([55, 52, 46, 56, 46, 49, 55, 43])


In [None]:
import torch.nn as nn


class SmallASRModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
        )

        self.encoder = nn.LSTM(
            input_size=64,
            hidden_size=128,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
        )

        self.classifier = nn.Linear(128 * 2, vocab_size)

    def forward(self, x, lengths):
        x = self.conv(x)
        print(x.shape)
        lengths = lengths // 4
        x = x.permute(0, 2, 1)

        x = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        x, _ = self.encoder(x)
        x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)

        logits = self.classifier(x)
        return logits

In [None]:
model = SmallASRModel(vocab_size=len(alphabet))
params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {params/1e6:.2f}M")

Total parameters: 0.62M


In [None]:
import torch.optim as optim
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SmallASRModel(vocab_size=len(alphabet)).to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-3)
ctc_loss = nn.CTCLoss(blank=0, zero_infinity=True)

for epoch in range(10):
    model.train()
    total_loss = 0

    for batch in train_loader:
        waveforms = batch["waveforms"].to(device)
        waveform_lengths = batch["waveform_lengths"].to(device)
        targets = batch["targets"].to(device)
        target_lengths = batch["target_lengths"].to(device)

        optimizer.zero_grad()
        print(waveforms.shape)
        print(waveform_lengths)
        logits = model(waveforms, waveform_lengths)

        log_probs = F.log_softmax(logits, dim=-1)
        input_lengths = torch.full(
            size=(log_probs.size(0),), fill_value=log_probs.size(1), dtype=torch.long
        ).to(device)

        loss = ctc_loss(
            log_probs.permute(1, 0, 2),
            targets,
            input_lengths,
            target_lengths,
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}: Loss {total_loss / len(train_loader):.4f}")

torch.Size([8, 1, 63087])
tensor([40331, 56778, 63087, 49560, 49556, 45502, 54365, 55346])
torch.Size([8, 64, 15772])
