In [1]:
import os
import math
import random
from pathlib import Path
from itertools import groupby

import pandas as pd
from tqdm import tqdm

import torch
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import FrequencyMasking, TimeMasking, Vol
from torchmetrics.text import CharErrorRate

In [2]:
class Config:
    """Single container for every value a reader may want to tune."""

    # --- data location ---
    folder = Path("/kaggle/input/asr-numbers-recognition-in-russian/")

    # --- audio front-end (log-mel features) ---
    target_samplerate = 16_000
    n_fft = 400
    hop_length = 160
    n_mels = 80
    max_frames = 1_000  # the dataset crops spectrograms to this many frames

    # --- network size ---
    hidden = 128
    num_layers = 2
    dropout = 0.2

    # --- optimisation ---
    batch_size = 32
    num_epochs = 10
    lr = 0.001

    # --- hardware: use the GPU whenever one is visible ---
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")


cfg = Config()

# Load data

In [3]:
class NumberTokenizer:
    """Digit-level vocabulary: index 0 is ``'<blank>'``, indices 1..10 are '0'..'9'."""

    def __init__(self):
        digits = [str(d) for d in range(10)]
        self.labels = ['<blank>', *digits]
        # Build both lookup directions in a single pass over the label list.
        self.token2idx = {}
        self.idx2token = {}
        for position, token in enumerate(self.labels):
            self.token2idx[token] = position
            self.idx2token[position] = token

    def encode(self, text):
        """Return the vocabulary index of every character in ``text``.

        Raises ``KeyError`` for a character that is not in the vocabulary.
        """
        indices = []
        for char in text:
            indices.append(self.token2idx[char])
        return indices

    def decode(self, logits: torch.Tensor, greedy=True) -> str:
        """Turn model output (or a plain index sequence) into a digit string.

        With ``greedy=True`` the arg-max over the last axis is taken, runs of
        equal indices are collapsed to one, and index 0 is dropped. With
        ``greedy=False`` ``logits`` is treated as an iterable of vocabulary
        indices and mapped to tokens as-is.
        """
        if not greedy:
            ids = logits
        else:
            frame_ids = torch.argmax(logits, dim=-1).tolist()
            ids = [k for k, _run in groupby(frame_ids) if k != 0]
        return "".join(self.idx2token[i] for i in ids)


tokenizer = NumberTokenizer()

In [4]:
class NumbersDataset(Dataset):
    """Spoken-number dataset yielding ``(log-mel frames, digit targets, speaker id)``.

    Rows come from ``<folder>/<subset>.csv``; the code reads the columns
    ``filename`` (path relative to ``folder``), ``transcription`` and ``spk_id``.

    Args:
        folder: Root directory holding the CSV files and the audio.
        subset: CSV stem, e.g. ``'train'`` or ``'dev'``.
        tokenizer: Object with ``encode(str) -> list[int]``.
        config: Object exposing ``target_samplerate``, ``n_mels``, ``n_fft``,
            ``hop_length`` and ``max_frames``.
        augment: When true, waveform and spectrogram augmentation is applied.
    """

    def __init__(self, folder, subset, tokenizer, config, augment=False):
        self.df = pd.read_csv(folder / f"{subset}.csv")
        self.folder = folder
        self.tokenizer = tokenizer
        self.config = config
        self.augment = augment

        # The feature transforms do not depend on the sample, so they are built
        # once here instead of on every __getitem__ call.
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=config.target_samplerate,
            n_mels=config.n_mels,
            n_fft=config.n_fft,
            hop_length=config.hop_length,
        )
        self._to_db = torchaudio.transforms.AmplitudeToDB()
        self._freq_mask = FrequencyMasking(freq_mask_param=15)
        self._time_mask = TimeMasking(time_mask_param=35)
        # Resamplers depend on the source rate; cache one per rate encountered.
        self._resamplers = {}

    def __len__(self):
        return len(self.df)

    def _resample(self, waveform, sr):
        """Bring ``waveform`` to the configured sample rate (no-op if already there)."""
        target_sr = self.config.target_samplerate
        if sr == target_sr:
            return waveform
        if sr not in self._resamplers:
            self._resamplers[sr] = torchaudio.transforms.Resample(sr, target_sr)
        return self._resamplers[sr](waveform)

    def _augment_waveform(self, waveform):
        """Additive Gaussian noise, random gain in [-6, 6] dB, circular shift of up to 10 %."""
        waveform = waveform + 0.003 * torch.randn_like(waveform)
        waveform = Vol(random.uniform(-6, 6), gain_type='db')(waveform)
        shift = int(random.uniform(-0.1, 0.1) * waveform.size(1))
        return torch.roll(waveform, shift, dims=-1)

    def __getitem__(self, idx):
        """Return ``(melspec, target, spk_id)`` for row ``idx``.

        ``melspec`` has shape ``(frames, n_mels)`` with at most
        ``config.max_frames`` frames; ``target`` is a 1-D ``long`` tensor of
        token indices.
        """
        row = self.df.iloc[idx]
        waveform, sr = torchaudio.load(self.folder / row.filename)

        waveform = self._resample(waveform, sr)
        # Down-mix all channels to mono while keeping the channel axis.
        waveform = waveform.mean(dim=0, keepdim=True)

        if self.augment:
            waveform = self._augment_waveform(waveform)

        melspec = self._to_db(self._melspec(waveform))
        melspec = melspec[..., :self.config.max_frames]

        if self.augment:
            melspec = self._freq_mask(melspec)
            melspec = self._time_mask(melspec)

        # (1, n_mels, frames) -> (frames, n_mels)
        melspec = melspec.squeeze(0).transpose(0, 1)
        target = torch.tensor(self.tokenizer.encode(str(row.transcription)), dtype=torch.long)
        return melspec, target, row.spk_id

def collate_fn(batch):
    """Merge ``(features, target, id)`` samples into one padded batch.

    Returns:
        X_pad: features zero-padded along the first (time) axis, batch first.
        X_len: original first-axis length of every feature tensor (``long``).
        y: all targets concatenated into a single 1-D tensor.
        y_len: length of every individual target (``long``).
        ids: tuple with the third element of every sample.
    """
    features, targets, ids = zip(*batch)

    feature_lengths = [f.shape[0] for f in features]
    target_lengths = [t.shape[0] for t in targets]

    X_pad = pad_sequence(features, batch_first=True)
    X_len = torch.tensor(feature_lengths, dtype=torch.long)

    y_flat = torch.cat(targets)
    y_len = torch.tensor(target_lengths, dtype=torch.long)

    return X_pad, X_len, y_flat, y_len, ids

In [5]:
# Only the training split is augmented and shuffled; the dev split is read as-is.
train_ds = NumbersDataset(folder=cfg.folder, subset='train', tokenizer=tokenizer,
                          config=cfg, augment=True)
val_ds = NumbersDataset(folder=cfg.folder, subset='dev', tokenizer=tokenizer,
                        config=cfg, augment=False)

loader_kwargs = dict(batch_size=cfg.batch_size, collate_fn=collate_fn)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# NOTE: len() of a DataLoader is the number of batches, not the number of samples.
n_train_batches, n_val_batches = len(train_loader), len(val_loader)
print(f"Train size: {n_train_batches}, Val size: {n_val_batches}")

Train size: 393, Val size: 71


# Init model

In [6]:
class CRNN(nn.Module):
    """Convolutional front-end followed by a bidirectional LSTM and a linear head.

    ``forward`` takes a ``(batch, time, n_mels)`` tensor and returns per-frame
    scores of shape ``(batch, time, vocab_size)``. Both max-pool layers use a
    ``(1, 2)`` kernel, so only the last (mel) axis is halved twice and the
    number of time steps is left unchanged.
    """

    def __init__(self, cfg: Config, vocab_size: int):
        super().__init__()
        p_drop = cfg.dropout
        conv3x3 = dict(kernel_size=(3, 3), padding=(1, 1))

        layers = [
            # stage 1: two 64-channel convolutions, then halve the mel axis
            nn.Conv2d(1, 64, **conv3x3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, **conv3x3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(p_drop),
            # stage 2: one 128-channel convolution, halve the mel axis again
            nn.Conv2d(64, 128, **conv3x3),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),
            nn.Dropout(p_drop),
        ]
        self.conv = nn.Sequential(*layers)

        # After two /2 poolings the mel axis has n_mels // 4 bins, each with 128 channels.
        rnn_in = 128 * (cfg.n_mels // 4)
        self.rnn = nn.LSTM(
            input_size=rnn_in,
            hidden_size=cfg.hidden,
            num_layers=cfg.num_layers,
            batch_first=True,
            dropout=p_drop,
            bidirectional=True,
        )
        # Forward and backward hidden states are concatenated, hence the factor 2.
        self.fc = nn.Linear(2 * cfg.hidden, vocab_size)

    def forward(self, x):
        feats = self.conv(x.unsqueeze(1))  # add a channel axis -> (b, c, t, m)
        # Move channels last and merge (mel, channel) into one feature axis per frame.
        seq = feats.permute(0, 2, 3, 1).flatten(start_dim=2)
        out, _state = self.rnn(seq)
        return self.fc(out)


model = CRNN(cfg, vocab_size=len(tokenizer.labels)).to(cfg.device)

In [7]:
# Count only parameters that receive gradients, reported in millions.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Model params: {n_trainable/1e6:.2f} M")

Model params: 3.26 M


In [8]:
# CTC criterion with index 0 as the blank symbol; infinite losses are zeroed out.
ctc_loss = nn.CTCLoss(reduction='mean', blank=0, zero_infinity=True)

optimizer = optim.AdamW(params=model.parameters(), lr=cfg.lr, weight_decay=1e-5)

# The schedule is sized in optimiser steps: one step per training batch, for every epoch.
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=cfg.lr,
    steps_per_epoch=len(train_loader),
    epochs=cfg.num_epochs,
)

# Train model

In [9]:
# Character-error-rate metric; eval_one_epoch calls it once per speaker on lists of strings.
cer = CharErrorRate()

In [10]:
def train_one_epoch():
    """Run one pass over ``train_loader`` and return the mean per-batch CTC loss.

    Uses the module-level ``model``, ``optimizer``, ``scheduler`` and
    ``ctc_loss``; the scheduler is advanced after every batch.
    """
    model.train()
    batch_losses = []

    for X, X_len, y, y_len, _ids in tqdm(train_loader):
        X = X.to(cfg.device)
        X_len = X_len.to(cfg.device)
        y = y.to(cfg.device)

        optimizer.zero_grad()

        # Model output is (batch, time, classes); it is passed to the loss time-first.
        log_probs = model(X).log_softmax(-1).transpose(0, 1)
        loss = ctc_loss(log_probs, y, X_len, y_len)
        loss.backward()

        # Clip the global gradient norm to 5.0 before the update.
        nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        scheduler.step()

        batch_losses.append(loss.item())

    return sum(batch_losses) / len(train_loader)


def eval_one_epoch():
    """Evaluate on ``val_loader`` and return ``(macro_cer, per_speaker_cer)``.

    Predictions are greedily decoded, grouped by speaker id, scored with
    ``cer`` per speaker, and the per-speaker scores are averaged with equal
    weight to obtain the macro CER.
    """
    model.eval()
    hyps_by_spk = {}
    refs_by_spk = {}

    with torch.no_grad():
        for X, X_len, y, y_len, ids in val_loader:
            logits = model(X.to(cfg.device)).cpu()
            # Targets arrive concatenated; cut them back into one tensor per sample.
            targets_split = torch.split(y, y_len.tolist())

            for n_frames, logit, tgt, spk in zip(X_len, logits, targets_split, ids):
                # Ignore padded frames when decoding the prediction.
                hyp = tokenizer.decode(logit[:n_frames])
                ref = tokenizer.decode(tgt.tolist(), greedy=False)
                hyps_by_spk.setdefault(spk, []).append(hyp)
                refs_by_spk.setdefault(spk, []).append(ref)

    per_spk = {}
    for spk, hyps in hyps_by_spk.items():
        per_spk[spk] = cer(hyps, refs_by_spk[spk]).item()

    macro_cer = sum(per_spk.values()) / len(per_spk)
    return macro_cer, per_spk

In [11]:
# Keep the lowest dev CER seen so far together with a *snapshot* of the weights
# that produced it. Starting from +inf guarantees best_state is set after the
# first epoch even if CER never drops below 100 %.
best_cer = float("inf")
best_state = None
for epoch in range(1, cfg.num_epochs + 1):
    train_loss = train_one_epoch()
    val_cer, per_spk = eval_one_epoch()

    worst = sorted(per_spk.items(), key=lambda x: x[1], reverse=True)[:5]
    best  = sorted(per_spk.items(), key=lambda x: x[1])[:5]
    if val_cer < best_cer:
        best_cer = val_cer
        # state_dict() hands back references to the live parameter tensors, so
        # they would keep changing during later epochs; clone to freeze them.
        best_state = {name: tensor.detach().clone()
                      for name, tensor in model.state_dict().items()}

    print(f"\nEpoch {epoch} — val CER {val_cer*100:.2f}% (best {best_cer*100:.2f}%)")
    print("Worst 5 speakers:", ", ".join(f"{s}:{e*100:.1f}%" for s,e in worst))
    print("Best  5 speakers:", ", ".join(f"{s}:{e*100:.1f}%" for s,e in best))

100%|██████████| 393/393 [08:45<00:00,  1.34s/it]



Epoch 1 — val CER 100.00% (best 100.00%)
Worst 5 speakers: spk_J:100.0%, spk_I:100.0%, spk_K:100.0%, spk_H:100.0%, spk_F:100.0%
Best  5 speakers: spk_J:100.0%, spk_I:100.0%, spk_K:100.0%, spk_H:100.0%, spk_F:100.0%


100%|██████████| 393/393 [05:14<00:00,  1.25it/s]



Epoch 2 — val CER 80.27% (best 80.27%)
Worst 5 speakers: spk_K:93.0%, spk_C:84.2%, spk_A:83.8%, spk_I:81.4%, spk_D:80.9%
Best  5 speakers: spk_E:72.0%, spk_B:72.3%, spk_J:76.2%, spk_H:79.1%, spk_F:79.7%


100%|██████████| 393/393 [05:22<00:00,  1.22it/s]



Epoch 3 — val CER 38.52% (best 38.52%)
Worst 5 speakers: spk_K:59.9%, spk_I:45.4%, spk_C:42.6%, spk_F:39.7%, spk_A:38.8%
Best  5 speakers: spk_B:30.1%, spk_H:30.2%, spk_E:30.7%, spk_J:31.0%, spk_D:36.8%


100%|██████████| 393/393 [05:24<00:00,  1.21it/s]



Epoch 4 — val CER 36.17% (best 36.17%)
Worst 5 speakers: spk_K:50.3%, spk_I:43.9%, spk_C:38.9%, spk_A:37.9%, spk_F:36.6%
Best  5 speakers: spk_E:27.5%, spk_H:30.4%, spk_B:31.0%, spk_J:31.3%, spk_D:34.0%


100%|██████████| 393/393 [05:11<00:00,  1.26it/s]



Epoch 5 — val CER 32.07% (best 32.07%)
Worst 5 speakers: spk_K:45.9%, spk_I:40.2%, spk_A:35.6%, spk_C:34.1%, spk_F:32.5%
Best  5 speakers: spk_E:21.9%, spk_H:25.6%, spk_J:26.4%, spk_B:27.4%, spk_D:31.1%


100%|██████████| 393/393 [04:59<00:00,  1.31it/s]



Epoch 6 — val CER 29.35% (best 29.35%)
Worst 5 speakers: spk_K:40.4%, spk_I:37.0%, spk_A:31.8%, spk_D:29.7%, spk_F:29.7%
Best  5 speakers: spk_E:23.6%, spk_B:23.6%, spk_H:23.9%, spk_J:24.1%, spk_C:29.6%


100%|██████████| 393/393 [05:02<00:00,  1.30it/s]



Epoch 7 — val CER 31.56% (best 29.35%)
Worst 5 speakers: spk_K:42.2%, spk_I:39.6%, spk_A:34.4%, spk_C:33.9%, spk_F:32.0%
Best  5 speakers: spk_E:24.5%, spk_H:25.3%, spk_B:25.9%, spk_J:26.6%, spk_D:31.2%


100%|██████████| 393/393 [05:01<00:00,  1.30it/s]



Epoch 8 — val CER 29.97% (best 29.35%)
Worst 5 speakers: spk_K:40.8%, spk_I:39.4%, spk_C:34.1%, spk_A:30.9%, spk_F:29.5%
Best  5 speakers: spk_E:22.4%, spk_H:23.7%, spk_B:24.0%, spk_J:26.2%, spk_D:28.7%


100%|██████████| 393/393 [05:03<00:00,  1.29it/s]



Epoch 9 — val CER 28.91% (best 28.91%)
Worst 5 speakers: spk_K:39.5%, spk_I:37.9%, spk_C:32.9%, spk_A:30.6%, spk_F:28.1%
Best  5 speakers: spk_H:22.0%, spk_E:22.1%, spk_J:24.2%, spk_B:24.8%, spk_D:27.0%


100%|██████████| 393/393 [04:55<00:00,  1.33it/s]



Epoch 10 — val CER 29.60% (best 28.91%)
Worst 5 speakers: spk_K:39.2%, spk_I:38.0%, spk_C:34.4%, spk_A:31.3%, spk_F:29.3%
Best  5 speakers: spk_H:22.4%, spk_E:22.8%, spk_J:24.8%, spk_B:25.2%, spk_D:28.7%


# Create submission

In [14]:
class Predictor:
    """Wraps a trained ``CRNN`` for single-file inference.

    Args:
        best_state: ``state_dict`` to load into a freshly built ``CRNN``.
    """

    def __init__(self, best_state):
        self.cfg = cfg
        self.tokenizer = NumberTokenizer()
        self.model = CRNN(self.cfg, len(self.tokenizer.labels)).to(self.cfg.device)
        self.model.load_state_dict(best_state)
        self.model.eval()

        # The feature transforms are identical for every file, so build them once
        # rather than on every transcribe() call.
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.cfg.target_samplerate,
            n_mels=self.cfg.n_mels,
            n_fft=self.cfg.n_fft,
            hop_length=self.cfg.hop_length,
        )
        self._to_db = torchaudio.transforms.AmplitudeToDB()
        # Resamplers depend on the source rate; cache one per rate encountered.
        self._resamplers = {}

    def _resample(self, wav, sr):
        """Bring ``wav`` to the configured sample rate (no-op if already there)."""
        target_sr = self.cfg.target_samplerate
        if sr == target_sr:
            return wav
        if sr not in self._resamplers:
            self._resamplers[sr] = torchaudio.transforms.Resample(sr, target_sr)
        return self._resamplers[sr](wav)

    def transcribe(self, wav_path: Path) -> str:
        """Return the greedily decoded digit string for the audio at ``wav_path``.

        NOTE: unlike ``NumbersDataset``, the spectrogram is not cropped to
        ``cfg.max_frames`` here; the full recording is fed to the model.
        """
        wav, sr = torchaudio.load(wav_path)
        wav = self._resample(wav, sr)
        # Down-mix all channels to mono while keeping the channel axis.
        wav = wav.mean(dim=0, keepdim=True)

        melspec = self._to_db(self._melspec(wav))
        # (1, n_mels, frames) -> (1, frames, n_mels), i.e. a batch of one.
        melspec = melspec.squeeze(0).transpose(0, 1).unsqueeze(0).to(self.cfg.device)

        with torch.no_grad():
            X = self.model(melspec)
        pred = self.tokenizer.decode(X[0])
        return pred


predictor = Predictor(best_state)

In [13]:
# os.listdir returns entries in arbitrary order; sort so the submission rows are
# reproducible across runs and machines.
test_dir = cfg.folder / 'test'
test_paths = sorted(os.listdir(test_dir))
records = []
for fn in tqdm(test_paths):
    pred_txt = predictor.transcribe(test_dir / fn)
    # An empty prediction (no digits decoded) is submitted as 0.
    pred_num = int(pred_txt.replace(' ', '')) if pred_txt else 0
    records.append({'filename': f"test/{fn}", 'transcription': pred_num})

submission = pd.DataFrame(records)
submission.to_csv('submission.csv', index=False)

100%|██████████| 2582/2582 [01:11<00:00, 36.32it/s]
