# Armenian Speech Recognition

In [16]:
from pathlib import Path

import torch
import torchaudio
from IPython.display import Audio

from datamodule import alphabet, CommonVoiceDataModule
from decoder import BeamCTCDecoder, GreedyDecoder
from torchaudio.models.decoder import ctc_decoder
from lit_conformer import LitConformer
from eval import run_evaluation

In [2]:
class Preprocessor:
    """Turn a raw waveform into log-mel features for the acoustic model.

    Args:
        normalize: if True (default), standardize the log-mel spectrogram to
            zero mean / unit variance computed over the whole utterance. If
            False, the raw log-mel values are returned.
        target_sample_rate: sample rate (Hz) the audio is resampled to before
            the mel transform; the 25 ms window and 10 ms hop are derived from it.
    """

    # Added to the mel power before the log and used as a floor for the
    # standard deviation, so silent / constant inputs yield finite values
    # instead of -inf or NaN.
    _EPS = 1e-9

    def __init__(self, normalize: bool = True, target_sample_rate: int = 16000):
        self.normalize = normalize
        self.target_sample_rate = target_sample_rate
        # One Resample module per source rate; building one computes a filter
        # kernel, so reuse it instead of rebuilding on every call.
        self._resamplers = {}

        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=target_sample_rate,
            n_fft=int(25 / 1000 * target_sample_rate),  # 25 ms window
            n_mels=80,  # number of mel filter banks
            hop_length=int(10 / 1000 * target_sample_rate),  # 10 ms hop
        )

    def __call__(self, signal: torch.Tensor, sr: int) -> torch.Tensor:
        """Return features of shape ``(num_frames, n_mels)`` for one utterance.

        ``signal`` may be 1-D ``(time,)`` or 2-D ``(channels, time)`` (the layout
        returned by ``torchaudio.load``). Multi-channel audio is averaged down
        to mono first.
        """
        signal = self._to_mono(signal)
        signal = self._resample_if_necessary(signal, sr)
        log_mel = torch.log(self.mel_transform(signal) + self._EPS)  # (1, n_mels, frames)
        if self.normalize:
            std = log_mel.std().clamp_min(self._EPS)
            log_mel = (log_mel - log_mel.mean()) / std
        # Drop only the channel axis. A bare squeeze() would also drop a
        # length-1 frame axis on very short clips. Then put time first.
        return log_mel[0].transpose(0, 1)

    @staticmethod
    def _to_mono(signal: torch.Tensor) -> torch.Tensor:
        """Return ``signal`` as a single-channel ``(1, time)`` tensor."""
        if signal.dim() == 1:
            signal = signal.unsqueeze(0)
        if signal.size(0) > 1:
            signal = signal.mean(dim=0, keepdim=True)
        return signal

    def _resample_if_necessary(self, signal, sr):
        """Resample ``signal`` from ``sr`` to ``self.target_sample_rate`` when they differ."""
        if sr != self.target_sample_rate:
            resampler = self._resamplers.get(sr)
            if resampler is None:
                resampler = torchaudio.transforms.Resample(sr, self.target_sample_rate)
                self._resamplers[sr] = resampler
            signal = resampler(signal)
        return signal

In [3]:
def recognize(audio_path, preprocess, model, decoder, return_all=False, device=None):
    """Transcribe a single audio file.

    Args:
        audio_path: path to an audio file readable by ``torchaudio.load``.
        preprocess: callable ``preprocess(signal=..., sr=...)`` returning a
            feature tensor whose first dimension is used as the input length.
        model: acoustic model called as ``model(inputs, input_sizes)`` that
            returns ``(out, output_lengths)``.
        decoder: a ``BeamCTCDecoder`` (receives ``return_all``) or any other
            decoder exposing ``decode(out, output_lengths)``.
        return_all: forwarded to ``BeamCTCDecoder.decode`` only; ignored for
            other decoders.
        device: device to place the model input on. Defaults to the device of
            the model's first parameter, so the function no longer relies on a
            notebook-global ``device`` variable.

    Returns:
        The first element of the decoder's output, i.e. the result for the
        single utterance in the batch.
    """
    if device is None:
        device = next(model.parameters()).device

    signal, sr = torchaudio.load(audio_path)
    features = preprocess(signal=signal, sr=sr)
    # Lengths stay on the CPU exactly as before; only the features are moved.
    input_sizes = torch.LongTensor([features.shape[0]])
    model_input = features.to(device).unsqueeze(0)  # add a batch dimension

    # Pure inference: do not build an autograd graph.
    with torch.no_grad():
        out, output_lengths = model(model_input, input_sizes)

    if isinstance(decoder, BeamCTCDecoder):
        decoded_output = decoder.decode(out, output_lengths, return_all)
    else:
        decoded_output = decoder.decode(out, output_lengths)
    return decoded_output[0]

In [4]:
# Paths and runtime configuration used by the cells below.
data_root = Path('./data')
clips_root = data_root / 'clips'  # directory holding the audio clips
# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint_path = 'checkpoints/version_6/armspeech2text-epoch=602-cer=12.83.ckpt'

In [5]:
# Feature extractor: 16 kHz audio -> normalized log-mel frames.
preprocess = Preprocessor(target_sample_rate=16000, normalize=True)

In [6]:
# Switch to eval mode right after loading. This disables train-only behaviour
# such as dropout, if the model has any, so the example cells give the same
# results regardless of which cells ran before them.
model = LitConformer.load_from_checkpoint(checkpoint_path, labels=alphabet).to(device).eval()

In [7]:
# Two CTC decoders over the same alphabet: beam search (beam of 100, nbest=1)
# and greedy decoding, compared in the cells below.
beam_decoder = BeamCTCDecoder(alphabet, nbest=1, beam_size=100)
greedy_decoder = GreedyDecoder(alphabet)

## Examples

In [13]:
# Pick one clip and play it inline.
audio_path = clips_root.joinpath('common_voice_hy-AM_31545119.mp3')
Audio(audio_path)

In [14]:
# Beam-search transcription of the clip above.
recognize(audio_path, model=model, preprocess=preprocess, return_all=True, decoder=beam_decoder)

[['մետաքսի վրա աստեղնագործված աթանաշերը խորհրդանշում են հաջողություն երկանկություն ֆարստություն իշխանություն։']]

In [15]:
# Greedy transcription of the same clip, for comparison.
recognize(audio_path, decoder=greedy_decoder, model=model, preprocess=preprocess)

['մետաքսի վրա աստեղնագործված աթանաշերը խորհրդանշում են հաջողություն երկանկություն ֆարստություն իշխանություն։']

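## Evaluation

Evaluate the model on the Common Voice test split, first with greedy decoding and then with beam-search decoding.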

In [11]:
# Greedy decoding over the test split (batch size 8).
dm = CommonVoiceDataModule(batch_size=8)
dm.setup()
run_evaluation(model=model, decoder=greedy_decoder, device=device, test_loader=dm.test_dataloader())

100%|█████████████████████████████████████████████████████████████████████████████| 63/63 [00:02<00:00, 25.82it/s]


(tensor(0.5450), tensor(0.1361))

In [12]:
# Beam-search decoding over an identically configured test split (batch size 8).
dm = CommonVoiceDataModule(batch_size=8)
dm.setup()
run_evaluation(model=model, decoder=beam_decoder, device=device, test_loader=dm.test_dataloader())

100%|█████████████████████████████████████████████████████████████████████████████| 63/63 [00:24<00:00,  2.52it/s]


(tensor(0.5450), tensor(0.1361))