In [None]:
# Git ref (branch or tag) of NeMo-text-processing to install; IPython substitutes $BRANCH into the shell line below.
BRANCH = 'main'
# Install nemo_text_processing directly from GitHub at that ref (--quiet trims most of pip's log).
!python -m pip install --quiet git+https://github.com/NVIDIA/NeMo-text-processing.git@$BRANCH#egg=nemo_text_processing

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m154.8/154.8 MB[0m [31m11.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m42.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for nemo_text_processing (setup.py) ... [?25l[?25hdone
  Building wheel for cdifflib (pyproject.toml) ... [?25l[?25hdone
  Building wheel for wget (setup.py) ... [?25l[?25hdone


In [None]:
# RapidFuzz provides the fuzzy string matching (`rapidfuzz.process`) imported in the imports cell.
!pip install --quiet RapidFuzz

In [None]:
import os
import math
import random
from pathlib import Path
from itertools import groupby

import pandas as pd
from tqdm import tqdm
from rapidfuzz import process
from nemo_text_processing.inverse_text_normalization.inverse_normalize import InverseNormalizer

import torch
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import FrequencyMasking, TimeMasking, Vol
from torchmetrics.text import CharErrorRate

In [None]:
# Russian inverse text normalizer from NeMo; words_to_digits() delegates to it.
inv_normalizer_ru = InverseNormalizer(lang="ru")

# Lookup tables for spelling out Russian cardinals; the list index is the digit value.
# Masculine units are used for the last triad, feminine units before "тысяча/тысячи/тысяч".
RU_UNITS_M = ['ноль','один','два','три','четыре','пять','шесть','семь','восемь','девять']
RU_UNITS_F = ['ноль','одна','две','три','четыре','пять','шесть','семь','восемь','девять']
RU_TEENS   = ['десять','одиннадцать','двенадцать','тринадцать','четырнадцать',
              'пятнадцать','шестнадцать','семнадцать','восемнадцать','девятнадцать']
RU_TENS    = ['','десять','двадцать','тридцать','сорок','пятьдесят',
              'шестьдесят','семьдесят','восемьдесят','девяносто']
RU_HUNDS   = ['','сто','двести','триста','четыреста','пятьсот',
              'шестьсот','семьсот','восемьсот','девятьсот']


def digits_to_words(num: int) -> str:
    """Spell out an integer from 0 to 999 999 as Russian words.

    Examples: 0 -> "ноль", 5 -> "пять", 1000 -> "одна тысяча",
    2021 -> "две тысячи двадцать один".

    Args:
        num: Non-negative integer below one million.

    Returns:
        The space-separated spelled-out number.

    Raises:
        ValueError: If ``num`` is negative or >= 1 000 000 (millions are not supported).
    """
    if not 0 <= num < 1_000_000:
        raise ValueError(f"digits_to_words supports 0..999999, got {num!r}")
    if num == 0:
        # Both triads are empty for zero, so it has to be spelled explicitly.
        return RU_UNITS_M[0]

    def _triad(n: int, fem: bool = False) -> list[str]:
        """Spell a value in 0..999; ``fem`` selects the feminine forms of 1 and 2."""
        h, r  = divmod(n, 100)
        t, u  = divmod(r, 10)
        words = []
        if h: words.append(RU_HUNDS[h])
        if t == 1:
            # 10..19 are single words, not "tens + units".
            words.append(RU_TEENS[u])
        else:
            if t: words.append(RU_TENS[t])
            if u:
                words.append((RU_UNITS_F if fem else RU_UNITS_M)[u])
        return words

    th, rest = divmod(num, 1000)
    words = []
    if th:
        # Only emit the thousands group (and its noun) when it is non-zero;
        # otherwise numbers below 1000 would get a stray "тысяч".
        words = _triad(th, fem=True)
        last = th % 100
        if 11 <= last <= 14:
            # 11..14 always take the genitive plural, regardless of the last digit.
            words.append('тысяч')
        else:
            last %= 10
            words.append({1:'тысяча', 2:'тысячи', 3:'тысячи', 4:'тысячи'}.get(last, 'тысяч'))
    if rest:
        words += _triad(rest, fem=False)
    return " ".join(words)

def words_to_digits(words):
    """Run the Russian NeMo inverse normalizer over `words` and return its output."""
    normalized = inv_normalizer_ru.inverse_normalize(words, verbose=False)
    return normalized

 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars. This might take some time...
 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars. This might take some time...


In [None]:
class Config:
    """Central place for the data location, feature settings, model size and training hyperparameters."""

    # Dataset root; the split CSVs and the audio files are resolved relative to it.
    folder = Path("/kaggle/input/asr-numbers-recognition-in-russian/")

    # Audio is resampled to this rate before feature extraction.
    target_samplerate = 16000
    # Mel-spectrogram parameters (hop of 160 samples = 10 ms at 16 kHz).
    n_mels = 80
    n_fft = 400
    hop_length = 160
    # NumbersDataset truncates spectrograms to at most this many frames.
    max_frames = 1000

    # BiLSTM hidden size per direction, number of stacked layers, and the
    # dropout probability shared by the conv stack and the LSTM.
    hidden = 128
    num_layers = 2
    dropout = 0.2

    # Optimisation settings; `lr` is both the AdamW learning rate and the OneCycleLR peak.
    batch_size = 32
    num_epochs = 10
    lr = 1e-3
    # Evaluated once, when the class body runs: CUDA if available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Shared configuration instance used by the rest of the notebook.
cfg = Config()

# Load data

In [None]:
class RuTokenizer:
    """Character-level tokenizer over '<blank>', space and the letters 'а'..'я'.

    Index 0 is the blank symbol that `decode` drops in greedy mode.
    """

    def __init__(self):
        alphabet = [chr(code) for code in range(ord('а'), ord('я') + 1)]
        self.labels = ['<blank>', ' '] + alphabet
        self.token2idx = {}
        for position, symbol in enumerate(self.labels):
            self.token2idx[symbol] = position
        self.idx2token = {position: symbol for symbol, position in self.token2idx.items()}

    def encode(self, text):
        """Lower-case `text` and map each character to its index (KeyError on unknown characters)."""
        lowered = text.lower()
        return [self.token2idx[ch] for ch in lowered]

    def decode(self, logits: torch.Tensor, greedy=True) -> str:
        """Turn model output or token ids into text.

        With `greedy=True`, `logits` is a per-frame score tensor: take the argmax
        of each frame, merge consecutive repeats, then drop blanks (index 0).
        With `greedy=False`, `logits` is already an iterable of token ids.
        """
        if greedy:
            best_path = torch.argmax(logits, dim=-1).tolist()
            token_ids = [key for key, _run in groupby(best_path) if key != 0]
        else:
            token_ids = logits
        return "".join(map(self.idx2token.__getitem__, token_ids))


tokenizer = RuTokenizer()

In [None]:
class NumbersDataset(Dataset):
    """Audio/number pairs read from `<folder>/<subset>.csv`.

    Each item is `(melspec, target, spk_id)` where `melspec` is a `(T, n_mels)`
    log-mel spectrogram, `target` is the spelled-out number encoded by
    `tokenizer`, and `spk_id` comes from the CSV row.

    Args:
        folder: Dataset root (`pathlib.Path`); CSV `filename` values are relative to it.
        subset: CSV stem, e.g. 'train' or 'dev'.
        tokenizer: Object with an `encode(text) -> list[int]` method.
        config: Provides `target_samplerate`, `n_mels`, `n_fft`, `hop_length`, `max_frames`.
            The feature settings are read once here, at construction time.
        augment: If True, apply waveform and spectrogram augmentation.
    """

    def __init__(self, folder, subset, tokenizer, config, augment=False):
        self.df = pd.read_csv(folder / f"{subset}.csv")
        self.texts = [digits_to_words(n) for n in self.df.transcription.tolist()]
        self.folder = folder
        self.tokenizer = tokenizer
        self.config = config
        self.augment = augment

        # Build the transforms once: constructing them per item recomputes the
        # mel filterbank / window (and resampling kernel) for every sample.
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.target_samplerate,
            n_mels=config.n_mels,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
        )
        self._to_db = torchaudio.transforms.AmplitudeToDB()
        # The masking modules draw a fresh random mask on every call.
        self._freq_mask = FrequencyMasking(freq_mask_param=15)
        self._time_mask = TimeMasking(time_mask_param=35)
        # One resampler per source sample rate encountered.
        self._resamplers = {}

    def __len__(self):
        return len(self.df)

    def _resample(self, waveform, sr):
        """Resample `waveform` from `sr` to the target rate, reusing cached kernels."""
        if sr == self.config.target_samplerate:
            return waveform
        if sr not in self._resamplers:
            self._resamplers[sr] = torchaudio.transforms.Resample(sr, self.config.target_samplerate)
        return self._resamplers[sr](waveform)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        waveform, sr = torchaudio.load(self.folder / row.filename)
        waveform = self._resample(waveform, sr)

        # Down-mix to mono while keeping the channel dimension.
        waveform = waveform.mean(dim=0, keepdim=True)

        if self.augment:
            # Additive Gaussian noise, random gain in [-6, 6] dB, and a circular
            # time shift of up to 10% of the clip length.
            waveform = waveform + 0.003 * torch.randn_like(waveform)
            waveform = Vol(random.uniform(-6, 6), gain_type='db')(waveform)
            shift = int(random.uniform(-0.1, 0.1) * waveform.size(1))
            waveform = torch.roll(waveform, shift, dims=-1)

        melspec = self._to_db(self._melspec(waveform))
        melspec = melspec[..., :self.config.max_frames]

        if self.augment:
            melspec = self._freq_mask(melspec)
            melspec = self._time_mask(melspec)

        melspec = melspec.squeeze(0).transpose(0, 1)   # (T, M)
        target = torch.tensor(self.tokenizer.encode(self.texts[idx]), dtype=torch.long)
        return melspec, target, row.spk_id


def collate_fn(batch):
    """Collate `(features, target, speaker_id)` items into one CTC-ready batch.

    Returns:
        Tuple of zero-padded features `(B, T_max, ...)`, feature lengths,
        all targets concatenated into one 1-D tensor, target lengths, and
        the tuple of speaker ids.
    """
    features, targets, ids = zip(*batch)

    def _first_dims(tensors):
        # Length of every tensor along its first axis, as a LongTensor.
        return torch.tensor([t.size(0) for t in tensors], dtype=torch.long)

    feature_lengths = _first_dims(features)
    padded = pad_sequence(features, batch_first=True)

    target_lengths = _first_dims(targets)
    flat_targets = torch.cat(targets)
    return padded, feature_lengths, flat_targets, target_lengths, ids

In [None]:
# Only the training split is augmented; the dev split is evaluated on clean features.
train_ds = NumbersDataset(cfg.folder, 'train', tokenizer, cfg, augment=True)
val_ds = NumbersDataset(cfg.folder, 'dev', tokenizer, cfg, augment=False)

# Shuffle only the training data; collate_fn pads the features and concatenates the targets.
train_loader = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False, collate_fn=collate_fn)

# NOTE(review): len() of a DataLoader is the number of batches, so the "size" printed here
# is a batch count, not a sample count (use len(train_ds) / len(val_ds) for samples).
print(f"Train size: {len(train_loader)}, Val size: {len(val_loader)}")

Train size: 393, Val size: 71


# Init model

In [None]:
class CRNN(nn.Module):
    """Convolutional front-end followed by a bidirectional LSTM and a per-frame classifier.

    The two max-pool layers use a (1, 2) kernel, so only the last (mel) axis is
    quartered and the number of time frames is left unchanged.
    """

    def __init__(self, cfg: Config, vocab_size: int):
        super().__init__()

        def conv_bn_relu(in_channels, out_channels):
            # 3x3 "same" convolution -> batch norm -> ReLU.
            return [
                nn.Conv2d(in_channels, out_channels, (3, 3), padding=(1, 1)),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        def pool_and_drop():
            # Halve the last axis, then apply dropout.
            return [nn.MaxPool2d((1, 2)), nn.Dropout(cfg.dropout)]

        layers = [
            *conv_bn_relu(1, 64),
            *conv_bn_relu(64, 64),
            *pool_and_drop(),
            *conv_bn_relu(64, 128),
            *pool_and_drop(),
        ]
        self.conv = nn.Sequential(*layers)

        # Two poolings leave n_mels // 4 bins, each carrying 128 channels.
        pooled_mels = cfg.n_mels // 4
        rnn_in = pooled_mels * 128
        self.rnn = nn.LSTM(rnn_in, cfg.hidden, num_layers=cfg.num_layers,
                           bidirectional=True, batch_first=True, dropout=cfg.dropout)
        self.fc = nn.Linear(cfg.hidden * 2, vocab_size)

    def forward(self, x):
        """Map a batch of spectrograms to per-frame class scores (unnormalised)."""
        feats = self.conv(x.unsqueeze(1))  # add a channel axis for Conv2d
        batch, channels, frames, mels = feats.shape

        # Move channels last and fold (mel, channel) into one feature axis per frame.
        seq = feats.permute(0, 2, 3, 1).reshape(batch, frames, mels * channels)
        hidden, _ = self.rnn(seq)
        return self.fc(hidden)


model = CRNN(cfg, vocab_size=len(tokenizer.labels)).to(cfg.device)

In [None]:
# Report the number of trainable parameters, in millions.
print(f"Model params: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f} M")

Model params: 3.27 M


In [None]:
# blank=0 matches '<blank>' being the first entry of RuTokenizer.labels.
ctc_loss = nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
optimizer = optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=1e-5)
# The scheduler is stepped once per batch in train_one_epoch(), hence steps_per_epoch=len(train_loader).
scheduler = optim.lr_scheduler.OneCycleLR(optimizer, max_lr=cfg.lr, epochs=cfg.num_epochs,
                                          steps_per_epoch=len(train_loader))

# Train model

In [None]:
# Character error rate metric from torchmetrics; eval_one_epoch() calls it once per speaker.
cer = CharErrorRate()

In [None]:
def train_one_epoch():
    """Run one pass over `train_loader` and return the mean per-batch CTC loss.

    Uses the notebook-level `model`, `optimizer`, `scheduler` and `ctc_loss`;
    the scheduler is advanced after every batch.
    """
    model.train()
    running_loss = 0
    for feats, feat_lens, targets, target_lens, _speakers in tqdm(train_loader):
        feats = feats.to(cfg.device)
        feat_lens = feat_lens.to(cfg.device)
        targets = targets.to(cfg.device)

        optimizer.zero_grad()
        scores = model(feats)
        # CTCLoss expects log-probabilities laid out as (T, B, C).
        log_probs = scores.log_softmax(-1).transpose(0, 1)
        loss = ctc_loss(log_probs, targets, feat_lens, target_lens)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        scheduler.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)


# Reference transcripts do not change between epochs, so their normalised form
# is computed once and reused (inverse normalisation is the slow part of eval).
_REF_DIGITS_CACHE = {}


def eval_one_epoch():
    """Evaluate `model` on `val_loader` and return `(macro_cer, per_spk)`.

    Predictions and references are both passed through `words_to_digits`, so the
    character error rate compares like with like. Previously only the
    prediction was converted, which scored digit strings against spelled-out
    Cyrillic references and made the metric meaningless.

    Returns:
        macro_cer: Unweighted mean of the per-speaker CERs.
        per_spk: Dict mapping speaker id to that speaker's CER.
    """
    model.eval()
    spk2preds = {}
    spk2gts = {}

    with torch.no_grad():
        for X, X_len, y, y_len, ids in tqdm(val_loader):
            logits = model(X.to(cfg.device)).cpu()

            # Targets arrive concatenated; split them back into one tensor per utterance.
            targets_split = torch.split(y, y_len.tolist())
            for i, (logit, tgt, spk) in enumerate(zip(logits, targets_split, ids)):
                # Drop padded frames before greedy CTC decoding.
                pred_words = tokenizer.decode(logit[:X_len[i]])
                pred_str = words_to_digits(pred_words)

                true_words = tokenizer.decode(tgt.tolist(), greedy=False)
                if true_words not in _REF_DIGITS_CACHE:
                    _REF_DIGITS_CACHE[true_words] = words_to_digits(true_words)
                true_str = _REF_DIGITS_CACHE[true_words]

                spk2preds.setdefault(spk, []).append(pred_str)
                spk2gts.setdefault(spk, []).append(true_str)

    per_spk = {s: cer(spk2preds[s], spk2gts[s]).item()
               for s in spk2preds}

    macro_cer = sum(per_spk.values()) / len(per_spk)
    return macro_cer, per_spk

In [None]:
# CER can exceed 1.0, so start from +inf to guarantee the first epoch is recorded.
best_cer = float("inf")
best_state = None
for epoch in range(1, cfg.num_epochs + 1):
    train_loss = train_one_epoch()
    val_cer, per_spk = eval_one_epoch()

    worst = sorted(per_spk.items(), key=lambda x: x[1], reverse=True)[:5]
    best  = sorted(per_spk.items(), key=lambda x: x[1])[:5]
    if val_cer < best_cer:
        best_cer = val_cer
        # state_dict() returns references to the live parameters, which keep
        # changing as training continues. Clone them (onto CPU) so best_state
        # really is a snapshot of the best epoch rather than the latest weights.
        best_state = {name: tensor.detach().cpu().clone()
                      for name, tensor in model.state_dict().items()}

    print(f"\nEpoch {epoch} — val CER {val_cer*100:.2f}% (best {best_cer*100:.2f}%)")
    print("Worst 5 speakers:", ", ".join(f"{s}:{e*100:.1f}%" for s,e in worst))
    print("Best  5 speakers:", ", ".join(f"{s}:{e*100:.1f}%" for s,e in best))

100%|██████████| 393/393 [05:07<00:00,  1.28it/s]
100%|██████████| 71/71 [01:59<00:00,  1.68s/it]



Epoch 1 — val CER 57.29% (best 57.29%)
Worst 5 speakers: spk_K:66.0%, spk_A:61.0%, spk_C:59.4%, spk_B:57.3%, spk_F:57.2%
Best  5 speakers: spk_J:51.5%, spk_H:53.1%, spk_I:54.9%, spk_D:56.0%, spk_E:56.5%


100%|██████████| 393/393 [04:53<00:00,  1.34it/s]
100%|██████████| 71/71 [01:48<00:00,  1.53s/it]



Epoch 2 — val CER 66.74% (best 57.29%)
Worst 5 speakers: spk_B:73.0%, spk_E:71.2%, spk_K:70.6%, spk_D:69.2%, spk_H:66.8%
Best  5 speakers: spk_I:54.7%, spk_A:64.7%, spk_J:65.5%, spk_F:65.6%, spk_C:66.0%


100%|██████████| 393/393 [04:50<00:00,  1.35it/s]
100%|██████████| 71/71 [02:05<00:00,  1.77s/it]



Epoch 3 — val CER 68.50% (best 57.29%)
Worst 5 speakers: spk_B:76.2%, spk_E:75.0%, spk_D:74.8%, spk_C:70.8%, spk_J:70.1%
Best  5 speakers: spk_I:55.9%, spk_K:59.7%, spk_A:64.8%, spk_F:68.3%, spk_H:69.4%


100%|██████████| 393/393 [04:50<00:00,  1.35it/s]
100%|██████████| 71/71 [02:07<00:00,  1.80s/it]



Epoch 4 — val CER 67.73% (best 57.29%)
Worst 5 speakers: spk_B:76.7%, spk_E:76.6%, spk_D:73.1%, spk_J:68.8%, spk_C:67.5%
Best  5 speakers: spk_I:54.6%, spk_K:62.3%, spk_H:65.4%, spk_A:65.9%, spk_F:66.5%


100%|██████████| 393/393 [04:51<00:00,  1.35it/s]
100%|██████████| 71/71 [02:06<00:00,  1.78s/it]



Epoch 5 — val CER 69.52% (best 57.29%)
Worst 5 speakers: spk_B:79.8%, spk_E:75.8%, spk_D:73.1%, spk_J:73.0%, spk_C:69.8%
Best  5 speakers: spk_I:57.2%, spk_K:63.6%, spk_A:65.1%, spk_H:68.7%, spk_F:69.2%


100%|██████████| 393/393 [04:50<00:00,  1.35it/s]
100%|██████████| 71/71 [02:10<00:00,  1.83s/it]



Epoch 6 — val CER 71.13% (best 57.29%)
Worst 5 speakers: spk_E:78.5%, spk_B:78.4%, spk_D:75.9%, spk_J:73.9%, spk_F:71.8%
Best  5 speakers: spk_I:58.4%, spk_K:66.5%, spk_H:68.5%, spk_A:68.7%, spk_C:70.6%


100%|██████████| 393/393 [04:49<00:00,  1.36it/s]
100%|██████████| 71/71 [02:11<00:00,  1.86s/it]



Epoch 7 — val CER 72.83% (best 57.29%)
Worst 5 speakers: spk_E:81.9%, spk_B:80.0%, spk_D:76.9%, spk_J:74.6%, spk_F:73.0%
Best  5 speakers: spk_I:63.1%, spk_A:67.9%, spk_K:68.2%, spk_H:70.9%, spk_C:72.0%


100%|██████████| 393/393 [04:50<00:00,  1.35it/s]
100%|██████████| 71/71 [02:13<00:00,  1.89s/it]



Epoch 8 — val CER 70.80% (best 57.29%)
Worst 5 speakers: spk_E:79.5%, spk_B:78.5%, spk_D:74.4%, spk_J:74.0%, spk_F:70.9%
Best  5 speakers: spk_I:60.0%, spk_A:66.1%, spk_K:66.3%, spk_C:68.6%, spk_H:69.7%


 98%|█████████▊| 384/393 [04:45<00:06,  1.34it/s]


# Create submission

In [None]:
# Number-word vocabulary used by fuzzy_fix() and split_words() to repair decoded text:
# units (both genders of 1 and 2), teens, tens, hundreds and the three forms of "thousand".
NUM_WORDS = [
    "ноль","один","одна","два","две","три","четыре","пять","шесть","семь","восемь",
    "девять","десять","одиннадцать","двенадцать","тринадцать","четырнадцать",
    "пятнадцать","шестнадцать","семнадцать","восемнадцать","девятнадцать",
    "двадцать","тридцать","сорок","пятьдесят","шестьдесят","семьдесят",
    "восемьдесят","девяносто","сто","двести","триста","четыреста","пятьсот",
    "шестьсот","семьсот","восемьсот","девятьсот","тысяча","тысячи","тысяч",
]

def fuzzy_fix(word, threshold=80):
    """Return the closest entry of NUM_WORDS if its match score reaches `threshold`, else `word` unchanged."""
    candidate, similarity, _index = process.extractOne(word, NUM_WORDS)
    if similarity >= threshold:
        return candidate
    return word

def split_words(word):
    """Greedily peel known number words off the front of `word`.

    At each step the longest NUM_WORDS entry that prefixes the remaining text
    is taken. If nothing matches, the whole remainder is kept as the final
    piece and splitting stops.
    """
    pieces = []
    remaining = word
    while remaining:
        prefixes = [w for w in NUM_WORDS if remaining.startswith(w)]
        if not prefixes:
            pieces.append(remaining)
            break
        longest = max(prefixes, key=len)
        pieces.append(longest)
        remaining = remaining[len(longest):]
    return pieces

def preclean(text):
    """Repair decoded text before inverse normalisation.

    Each whitespace-separated token is split into known number words with
    `split_words` (to undo words glued together by the decoder), and every
    piece is then snapped to the closest vocabulary entry with `fuzzy_fix`.

    Returns:
        The cleaned tokens joined by single spaces.
    """
    tokens = []
    for token in text.lower().split():
        # The original called `split_stuck`, which is not defined anywhere in
        # the notebook and fails with NameError on a fresh kernel.
        for part in split_words(token):
            tokens.append(fuzzy_fix(part))

    return " ".join(tokens)

In [None]:
class Predictor:
    """Wrap a trained CRNN for single-file inference: audio path in, digit string out.

    Args:
        best_state: State dict loaded into a freshly built CRNN.
    """

    def __init__(self, best_state):
        self.cfg = cfg
        self.tokenizer = RuTokenizer()
        self.model = CRNN(self.cfg, len(self.tokenizer.labels)).to(self.cfg.device)
        self.model.load_state_dict(best_state)
        self.model.eval()

        # Build the feature transforms once instead of once per file.
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.cfg.target_samplerate,
            n_mels=self.cfg.n_mels,
            n_fft=self.cfg.n_fft,
            hop_length=self.cfg.hop_length,
        )
        self._to_db = torchaudio.transforms.AmplitudeToDB()
        # One resampler per source sample rate encountered.
        self._resamplers = {}

    def _features(self, wav_path):
        """Load `wav_path` and return its log-mel spectrogram as a (1, T, n_mels) batch on the model device."""
        wav, sr = torchaudio.load(wav_path)
        if sr != self.cfg.target_samplerate:
            if sr not in self._resamplers:
                self._resamplers[sr] = torchaudio.transforms.Resample(sr, self.cfg.target_samplerate)
            wav = self._resamplers[sr](wav)
        wav = wav.mean(dim=0, keepdim=True)  # down-mix to mono
        melspec = self._to_db(self._melspec(wav))
        return melspec.squeeze(0).transpose(0, 1).unsqueeze(0).to(self.cfg.device)

    def transcribe(self, wav_path: Path) -> str:
        """Return the recognised number for one audio file as a string of digits.

        The greedy CTC output is cleaned with `preclean`, converted with
        `words_to_digits`, and stripped of every non-digit character.
        Returns '0' when no digits survive.
        """
        melspec = self._features(wav_path)
        with torch.no_grad():
            X = self.model(melspec)
        pred = self.tokenizer.decode(X[0])
        pred = preclean(pred)
        pred = words_to_digits(pred)
        pred = ''.join([c for c in pred if c.isdigit()])
        return pred if len(pred) != 0 else '0'


predictor = Predictor(best_state)

In [None]:
# Transcribe every file in the test folder and write the predictions to submission.csv.
test_paths = os.listdir(cfg.folder / 'test')
records = []
for fn in tqdm(test_paths):
    pred_txt = predictor.transcribe(cfg.folder / 'test' / fn)
    records.append(dict(filename=f"test/{fn}", transcription=pred_txt))

submission = pd.DataFrame(records)
submission.to_csv('submission.csv', index=False)

100%|██████████| 2582/2582 [03:19<00:00, 12.96it/s]
