# Fine Tune ASR tutorial

## Download Thai Common Voice

In [None]:
!wget https://voice-prod-bundler-ee1969a6ce8178826482b88e843c335139bd3fb4.s3.amazonaws.com/cv-corpus-6.1-2020-12-11/th.tar.gz

In [None]:
!tar -xvf th.tar.gz

In [None]:
import os
import json
import pandas as pd

import tqdm
import time
import librosa
import numpy as np
from scipy.io import wavfile

import pythainlp
from pythainlp.tokenize import word_tokenize
from pythainlp.util import normalize, isthaichar
import multiprocessing as mp
import tqdm

import warnings
warnings.filterwarnings('ignore')

## Create manifest for NeMo

NeMo manifest format <br>
```json
{
    "audio_filepath": "cv-corpus-6.1-2020-12-11/th/clips/common_voice_th_23656065_trim.wav",
    "text": "เขา#กำลัง#มา#พอดี", 
    "duration": 1.253877551020408
}
{
    "audio_filepath": "cv-corpus-6.1-2020-12-11/th/clips/common_voice_th_23727063_trim.wav",
    "text": "ฉัน#สามารถ#ทำ#อะไร#ให้#คุณ#ได้#บ้าง", 
    "duration": 4.2260317460317465
}
```

We use `#` instead of spaces to separate words. The tool used to build language models produces word-level models and treats spaces as token separators, but we want a character-level language model, so word boundaries have to be marked with a different symbol.

In [None]:
def clean_text(text):
    """Turn a raw Thai transcript into '#'-separated words.

    The text is normalized, every character for which ``isthaichar`` is false
    is removed, the remainder is tokenized with the ``attacut`` engine, and
    each repetition mark 'ๆ' is replaced by the word it repeats.

    Args:
        text: Raw transcript string.

    Returns:
        The cleaned words joined with '#'. Empty string if nothing is left.
    """
    norm_text = normalize(text)
    norm_text = ''.join(char for char in norm_text if isthaichar(char))
    words = word_tokenize(norm_text, engine='attacut')

    expanded = []
    for word in words:
        if word != 'ๆ':
            expanded.append(word)
        elif expanded:
            # Repeat the most recent real word. Looking at the expanded list
            # (not the raw tokens) keeps consecutive marks from copying 'ๆ'.
            expanded.append(expanded[-1])
        # A mark with no preceding word has nothing to repeat and is dropped;
        # indexing words[i-1] here would wrap around to the last word.
    return '#'.join(expanded)

def __process(inp):
    """Convert one clip to a silence-trimmed WAV and build its manifest line.

    Args:
        inp: Tuple ``(filename, transcript, inp_dir, out_dir)``. A single
            tuple is used so the function can be mapped over a process pool.

    Returns:
        One JSON object on a single line, terminated by a newline, with the
        keys ``audio_filepath``, ``text`` and ``duration``.
    """
    filename, transcript, inp_dir, out_dir = inp
    inp_file_path = os.path.join(inp_dir, "clips", filename)
    out_file_path = os.path.join(out_dir, "clips", filename.replace('.mp3', '_trim.wav'))
    text = clean_text(transcript)

    # No explicit sample rate is requested, so `sr` is whatever librosa.load
    # returns by default; the same value is written to the WAV header below.
    y, sr = librosa.load(inp_file_path)
    y_trim, _ = librosa.effects.trim(y, top_db=30)
    # Keyword arguments: newer librosa releases reject positional y/sr here.
    duration = librosa.get_duration(y=y_trim, sr=sr)

    # Clip before the cast: a sample at +1.0 scales to 32768, which does not
    # fit in int16 and would wrap around to a full-scale negative value.
    pcm = np.clip(y_trim * 32768, -32768, 32767).astype('int16')
    wavfile.write(out_file_path, sr, pcm)

    # json.dumps escapes quotes/backslashes in the path and text, and avoids
    # the stray indentation the hand-built string carried in every line.
    entry = {
        "audio_filepath": out_file_path,
        "text": text.strip(),
        "duration": float(duration),
    }
    return json.dumps(entry, ensure_ascii=False) + '\n'

def clean_and_build_menifest(inp_dir, out_dir, num_workers=30):
    """Clean every split of the corpus and write one manifest per split.

    For each of the ``train``, ``dev`` and ``test`` splits, reads
    ``<inp_dir>/<split>.tsv``, processes every row with ``__process`` in a
    process pool, and writes the returned lines to
    ``<out_dir>/<split>_trim.json``.

    Args:
        inp_dir: Directory holding the ``*.tsv`` files and the ``clips`` folder.
        out_dir: Destination directory; ``<out_dir>/clips`` is created if needed.
        num_workers: Number of worker processes in the pool.
    """
    cates = ['train', 'dev', 'test']

    # exist_ok also covers an out_dir that already exists without its 'clips'
    # subfolder, which the previous existence check silently skipped.
    os.makedirs(os.path.join(out_dir, 'clips'), exist_ok=True)

    for cate in cates:
        df = pd.read_csv(os.path.join(inp_dir, cate + ".tsv"), sep='\t')
        jobs = [(row.path, row.sentence, inp_dir, out_dir)
                for row in df.itertuples(index=False)]

        with mp.Pool(num_workers) as p:
            # The lazy imap iterator must be drained inside the `with`,
            # because leaving the block terminates the pool.
            entries = list(tqdm.tqdm(p.imap(__process, jobs), total=len(jobs)))

        # Explicit UTF-8 so Thai text is written identically on every platform;
        # the context manager closes the file even if a write fails.
        manifest_path = os.path.join(out_dir, cate + '_trim.json')
        with open(manifest_path, 'w', encoding='utf-8') as json_file:
            json_file.writelines(entries)

# Build the trimmed clips and the per-split manifests for Thai Common Voice.
num_workers = 16
clean_and_build_menifest(
    inp_dir='cv-corpus-6.1-2020-12-11/th/',
    out_dir='common_voice_th_clean',
    num_workers=num_workers,
)

## Add word error rate metric

The WER metric implemented in NeMo only supports words separated by spaces, but we use `#`, so we need to customize the WER metric code.

In [None]:
import torch

from nemo.utils import logging
from typing import List
from torchmetrics import Metric
import editdistance

class CUSTOM_WER(Metric):
    """Error-rate metric for CTC output whose words are joined by ``sep_token``.

    ``update`` decodes the targets and the greedy CTC predictions to strings,
    splits them on ``sep_token`` (or into characters when ``use_cer`` is set)
    and stores the summed edit distance and the summed reference length of the
    batch. ``compute`` returns their ratio together with both counters.

    Args:
        vocabulary: Sequence of labels; index ``len(vocabulary)`` is the blank.
        batch_dim_index: Dimension used to count the samples in a batch.
        use_cer: Score characters (separators removed) instead of words.
        ctc_decode: Must be True; other decoding is not implemented.
        log_prediction: Log the first reference/prediction pair of each batch.
        dist_sync_on_step: Forwarded to the ``Metric`` base class.
        sep_token: String that separates words in the decoded text.
    """

    def __init__(self,
                 vocabulary,
                 batch_dim_index=0,
                 use_cer=False,
                 ctc_decode=True,
                 log_prediction=True,
                 dist_sync_on_step=False,
                 sep_token=' ',):
        # NOTE(review): `compute_on_step` may not be accepted by every
        # torchmetrics release - confirm against the installed version.
        super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False)
        self.batch_dim_index = batch_dim_index
        self.blank_id = len(vocabulary)
        self.labels_map = dict(enumerate(vocabulary))
        self.use_cer = use_cer
        self.ctc_decode = ctc_decode
        self.log_prediction = log_prediction

        self.add_state("scores", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)
        self.add_state("words", default=torch.tensor(0), dist_reduce_fx='sum', persistent=False)

        self.sep_token = sep_token

    def ctc_decoder_predictions_tensor(
        self, predictions: torch.Tensor, predictions_len: torch.Tensor = None, return_hypotheses: bool = False,
    ) -> List[str]:
        """
        Decodes a sequence of labels to words
        Args:
            predictions: A torch.Tensor of shape [Batch, Time] of integer indices that correspond
                to the index of some character in the label set.
            predictions_len: Optional tensor of length `Batch` which contains the integer lengths
                of the sequence in the padded `predictions` tensor.
            return_hypotheses: Bool flag whether to return just the decoding predictions of the model
                or a Hypothesis object that holds information such as the decoded `text`,
                the `alignment` emitted by the CTC Model, and the `length` of the sequence (if available).
        Returns:
            Either a list of str which represent the CTC decoded strings per sample,
            or a list of Hypothesis objects containing additional information.
        """
        if return_hypotheses:
            # `Hypothesis` was referenced without ever being imported. It is
            # imported lazily here; the module moved between NeMo releases.
            try:
                from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis
            except ImportError:
                from nemo.collections.asr.parts.rnnt_utils import Hypothesis

        hypotheses = []
        # Drop predictions to CPU
        prediction_cpu_tensor = predictions.long().cpu()
        # iterate over batch
        for ind in range(prediction_cpu_tensor.shape[self.batch_dim_index]):
            prediction = prediction_cpu_tensor[ind].detach().numpy().tolist()
            if predictions_len is not None:
                prediction = prediction[: predictions_len[ind]]
            # CTC decoding procedure: drop blanks and collapse repeated labels
            # unless a blank sits between the repeats.
            decoded_prediction = []
            previous = self.blank_id
            for p in prediction:
                if (p != previous or previous == self.blank_id) and p != self.blank_id:
                    decoded_prediction.append(p)
                previous = p

            text = self.decode_tokens_to_str(decoded_prediction)

            if not return_hypotheses:
                hypothesis = text
            else:
                hypothesis = Hypothesis(
                    y_sequence=None,
                    score=-1.0,
                    text=text,
                    alignments=prediction,
                    length=predictions_len[ind] if predictions_len is not None else 0,
                )

            hypotheses.append(hypothesis)
        return hypotheses

    def decode_tokens_to_str(self, tokens: List[int]) -> str:
        """
        Decode a list of token ids into a single string.
        Args:
            tokens: List of int representing the token ids.
        Returns:
            A decoded string.
        """
        return ''.join(self.decode_ids_to_tokens(tokens))

    def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
        """
        Decode a list of token ids into their label strings, skipping blanks.
        Args:
            tokens: List of int representing the token ids.
        Returns:
            A list of decoded tokens.
        """
        return [self.labels_map[c] for c in tokens if c != self.blank_id]

    def _split_units(self, text: str) -> List[str]:
        """Split decoded text into the units that are scored.

        Characters (separators removed) when ``use_cer`` is set, otherwise the
        non-empty pieces between separators. Empty pieces are discarded so a
        leading, trailing or doubled separator does not count as a word, the
        same way whitespace ``str.split()`` behaves.
        """
        if self.use_cer:
            return list(text.replace(self.sep_token, ''))
        return [unit for unit in text.split(self.sep_token) if unit]

    def compute(self):
        """Return ``(scores / words, scores, words)`` as float tensors."""
        scores = self.scores.detach().float()
        words = self.words.detach().float()
        return scores / words, scores, words

    def update(
        self,
        predictions: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
        predictions_lengths: torch.Tensor = None,
    ) -> None:
        """Score one batch and store its edit distance and reference length.

        Args:
            predictions: Integer label indices passed to the CTC decoder.
            targets: Padded integer label indices of the references.
            target_lengths: Valid length of each row of ``targets``.
            predictions_lengths: Optional valid length of each prediction.

        Raises:
            NotImplementedError: If ``ctc_decode`` is False.
        """
        words = 0.0
        scores = 0.0
        references = []
        with torch.no_grad():
            targets_cpu_tensor = targets.long().cpu()
            tgt_lenths_cpu_tensor = target_lengths.long().cpu()

            # iterate over batch
            for ind in range(targets_cpu_tensor.shape[self.batch_dim_index]):
                tgt_len = tgt_lenths_cpu_tensor[ind].item()
                target = targets_cpu_tensor[ind][:tgt_len].numpy().tolist()
                references.append(self.decode_tokens_to_str(target))
            if self.ctc_decode:
                hypotheses = self.ctc_decoder_predictions_tensor(predictions, predictions_lengths)
            else:
                raise NotImplementedError("Implement me if you need non-CTC decode on predictions")

        # Guard on non-empty lists so an empty batch cannot raise IndexError.
        if self.log_prediction and references and hypotheses:
            logging.info("\n")
            logging.info(f"reference:{references[0]}")
            logging.info(f"predicted:{hypotheses[0]}")

        for h, r in zip(hypotheses, references):
            h_list = self._split_units(h)
            r_list = self._split_units(r)
            words += len(r_list)
            # Compute Levenshtein's distance
            scores += editdistance.eval(h_list, r_list)

        # The counters are overwritten, not accumulated: compute() reflects
        # only the most recent update() call.
        # NOTE(review): kept as-is because the calling model may rely on
        # per-batch values - confirm before switching to accumulation.
        self.scores = torch.tensor(scores, device=self.scores.device, dtype=self.scores.dtype)
        self.words = torch.tensor(words, device=self.words.device, dtype=self.words.dtype)

## Fine Tune

In [None]:
import pytorch_lightning as pl

from omegaconf import OmegaConf, DictConfig

import nemo
from nemo.collections.asr.models import EncDecCTCModel
from nemo.collections.common.callbacks import LogEpochTimeCallback
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

In [None]:
# Load the experiment configuration; its `trainer`, `exp_manager` and `model`
# sections are consumed below.
cfg = OmegaConf.load('quartznet_5x3_th.yaml')

# Restore the pretrained CTC model from a checkpoint file.
# NOTE(review): the file name suggests a QuartzNet 5x3 trained on LibriSpeech - confirm.
asr_model = EncDecCTCModel.load_from_checkpoint("QuartzNet5x3_2e02_libri---val_wer=0.08-epoch=97.ckpt")

# Replace the decoder vocabulary with the word separator '#' followed by the
# Thai characters the model should emit.
asr_model.change_vocabulary(
        new_vocabulary=['#','ก','ข','ฃ','ค','ฅ','ฆ','ง','จ','ฉ','ช','ซ','ฌ','ญ','ฎ','ฏ','ฐ','ฑ','ฒ','ณ','ด','ต',
                  'ถ','ท','ธ','น','บ','ป','ผ','ฝ','พ','ฟ','ภ','ม','ย','ร','ล','ว','ศ','ษ','ส','ห','ฬ',
                  'อ','ฮ','ฤ','ฦ','ะ','ั','า','ำ','ิ','ี','ึ','ื','ุ','ู','เ','แ','โ','ใ','ไ','ๅ','ํ','็','่','้','๊','๋']
    )

# Swap the model's metric for CUSTOM_WER so words are split on '#'.
# It is created after change_vocabulary so it sees the new decoder vocabulary.
asr_model._wer = CUSTOM_WER(
            vocabulary=asr_model.decoder.vocabulary,
            batch_dim_index=0,
            use_cer=False,
            ctc_decode=True,
            dist_sync_on_step=True,
            log_prediction=False,
            sep_token='#'
)

# Build the Lightning trainer from the config and hand it to exp_manager
# together with the optional `exp_manager` config section.
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))

# Set up the optimizer and the datasets from the config

asr_model.setup_optimization(optim_config=DictConfig(cfg.model.optim))
# Point to the data we'll use for fine-tuning as the training set
asr_model.setup_training_data(train_data_config=cfg.model.train_ds)
# Point to the new validation data for fine-tuning
asr_model.setup_validation_data(val_data_config=cfg.model.validation_ds)

# Register the test set and attach the trainer to the model
asr_model.setup_test_data(test_data_config=cfg.model.test_ds)
asr_model.set_trainer(trainer)

# Run fine-tuning
trainer.fit(asr_model)

## Test model

In [None]:
# Run the trainer's test loop on the fine-tuned model.
trainer.test(asr_model)

## Inference

In [None]:
import IPython.display as ipd

# Create an audio widget for one cleaned clip so it can be listened to.
ipd.Audio("common_voice_th_clean/clips/common_voice_th_23670069_trim.wav")

In [25]:
# Transcribe a single clip with the fine-tuned model (no language model);
# words in the returned text are separated by '#'.
asr_model.transcribe([
    "common_voice_th_clean/clips/common_voice_th_23670069_trim.wav",
])

Transcribing:   0%|          | 0/1 [00:00<?, ?it/s]

['เมื่อ#ฉัน#รัก#ใคร#ทุก#คน#อย่าง#อดโยน#และ#มัน#ดูก#จับ#ใซ']

# Language model

We use KenLM to build the language model.
For more information about KenLM : https://github.com/kpu/kenlm

### Load test data

In [None]:
# Read the test manifest (one JSON object per line) into two parallel lists:
# the reference transcripts and the paths of the corresponding audio files.
target_transcripts = []
audio_file_paths = []
# Explicit UTF-8: the manifest contains Thai text and must not depend on the
# platform's default encoding.
with open('common_voice_th_clean/test_trim.json', 'r', encoding='utf-8') as test_manifest_file:
    for line in test_manifest_file:
        # Skip blank lines (including an empty file) instead of letting
        # json.loads fail on an empty string.
        if not line.strip():
            continue
        data = json.loads(line)
        target_transcripts.append(data['text'])
        audio_file_paths.append(data['audio_filepath'])

In [None]:
from nemo.collections.asr.modules import BeamSearchDecoderWithLM
from nemo.collections.asr.metrics.wer import WER, word_error_rate
import kenlm_utils

def beam_search(logits, beam_search_decoder, beam_batch_size=32):
    """Decode per-utterance probabilities in batches with a beam-search decoder.

    Args:
        logits: Sequence with one probability matrix per utterance.
        beam_search_decoder: Decoder whose ``forward(log_probs=...,
            log_probs_length=None)`` returns one list of beams per utterance.
        beam_batch_size: Number of utterances handed to the decoder at a time.

    Returns:
        A list with one decoded string per utterance, in input order.
    """
    hypotheses = []
    for start in range(0, len(logits), beam_batch_size):
        # disabling type checking
        with nemo.core.typecheck.disable_checks():
            probs_batch = logits[start:start + beam_batch_size]
            # Use the decoder passed in; the global `beam_search_lm` was used
            # here before, silently ignoring the argument.
            beams_batch = beam_search_decoder.forward(log_probs=probs_batch, log_probs_length=None)
            # presumably each beam is a (score, text) pair with the best
            # candidate first - verify against the decoder's documentation
            hypotheses.extend(beams[0][1] for beams in beams_batch)
    return hypotheses

def decode_with_lm(asr_model, decoder, files, batch_size=1):
    """Transcribe audio files and rescore them with a beam-search decoder.

    Args:
        asr_model: Model whose ``transcribe(files, batch_size=...,
            logprobs=True)`` returns one log-probability tensor per file.
        decoder: Beam-search decoder forwarded to ``beam_search``.
        files: List of audio file paths.
        batch_size: Batch size used for both transcription and beam search.

    Returns:
        A list with one decoded string per file.
    """
    logits = asr_model.transcribe(
        files,
        batch_size=batch_size,
        logprobs=True,
    )
    # The decoder is fed softmax probabilities computed from the model output.
    probs = [kenlm_utils.softmax(logit.cpu().numpy()) for logit in logits]
    # Pass the `decoder` argument through; the global `beam_search_lm` was
    # used here before, silently ignoring the argument.
    return beam_search(probs, decoder, batch_size)

In [None]:
# Beam-search decoder backed by the 6-gram KenLM model.
beam_search_lm = BeamSearchDecoderWithLM(
    vocab=asr_model.decoder.vocabulary,
    beam_width=32,
    alpha=2.0,
    beta=1.7,
    lm_path='common_voice_6gram.bin',
    num_cpus=16,
    input_tensor=False,
)

# Time the LM-assisted decoding of the whole test set.
start_time = time.time()
test_preds = decode_with_lm(asr_model, beam_search_lm, audio_file_paths, batch_size=32)
print("Inferrence Time: ", time.time()-start_time)

# CER: drop the '#' separators so only characters are compared.
cer_hypotheses = [pred.replace('#', '') for pred in test_preds]
cer_references = [ref.replace('#', '') for ref in target_transcripts]
test_cer_value = word_error_rate(
    hypotheses=cer_hypotheses,
    references=cer_references,
    use_cer=True,
)

# WER: turn '#' into spaces, the word separator word_error_rate expects.
wer_hypotheses = [pred.replace('#', ' ') for pred in test_preds]
wer_references = [ref.replace('#', ' ') for ref in target_transcripts]
test_wer_value = word_error_rate(
    hypotheses=wer_hypotheses,
    references=wer_references,
)

print("CER : {:.2f}%, WER : {:.2f}%".format(test_cer_value*100, test_wer_value*100))

In [None]:
# Decode a single clip with the language-model beam search; batch_size is
# left at the function's default.
decode_with_lm(asr_model, 
               beam_search_lm, 
              ["common_voice_th_clean/clips/common_voice_th_23670069_trim.wav"]
              )