In [1]:
%%capture
!pip install pyctcdecode
!python -m pip install pypi-kenlm
!pip install jiwer

![](https://developer-blogs.nvidia.com/wp-content/uploads/2019/12/automatic-speech-recognition_updated.png)

![](https://www.researchgate.net/profile/Diana-Militaru/publication/299594444/figure/fig1/AS:346834426974208@1459703179403/The-block-diagram-of-an-automatic-speech-recognition-and-understanding-system.png)

in this notebook we will try to demonstrate how to calculate CER,WER metric on validation dataset using xls-r wav2vec2 model,we will be using public best available pretrained model from huggingface to demonstrate the metric calculation process. for understanding how to train wav2vec2 on this dataset please check our past work [wav2vec2 starter](https://www.kaggle.com/code/nazmuddhohaansary/wave2vec2-starter-for-dl-sprint-commonvoice)

# Imports

In [2]:
import os
import numpy as np
from tqdm.auto import tqdm
from glob import glob
from transformers import AutoFeatureExtractor, pipeline
import pandas as pd
import librosa
import IPython
from datasets import load_metric
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
import torch
import re
import gc
import wave
from scipy.io import wavfile
import scipy.signal as sps

tqdm.pandas()
import warnings
warnings.filterwarnings("ignore")



# Configs

In [3]:
#according to our experiment this is the best model -> arijitx/wav2vec2-xls-r-300m-bengali
class CFG:
    model_name = 'arijitx/wav2vec2-xls-r-300m-bengali' #arijitx/wav2vec2-large-xlsr-bengali,arijitx/wav2vec2-xls-r-300m-bengali, Tahsin-Mayeesha/wav2vec2-bn-300m
    valid_df_path = '../input/dlsprint/validation.csv'
    sample_sub_df_path = '../input/dlsprint/sample_submission.csv'
    valid = "../input/dlsprint/validation_files/"
    test = "../input/dlsprint/test_files/"
    valid_wav = '../input/validation-fileswav-format/validation_files_wav/'
    test_wav = '../input/test-wav-files-dl-sprint/test_files_wav/'
    batch_size = 48#not using this param now
    single_SPEECH_FILE = "../input/dlsprint/validation_files/common_voice_bn_30620258.mp3"
    



# single sample inference demo

In [4]:
asr = pipeline("automatic-speech-recognition", model=CFG.model_name, device=0)
feature_extractor = AutoFeatureExtractor.from_pretrained(
        CFG.model_name, cache_dir=None, use_auth_token=False
    )
speech, sr = librosa.load(CFG.single_SPEECH_FILE, sr=feature_extractor.sampling_rate)
prediction = asr(
            speech, chunk_length_s=112, stride_length_s=None
        )

pred = prediction["text"]
pred


Downloading:   0%|          | 0.00/1.99k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.18G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/287 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/309 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/261 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/755 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/78.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/3.48G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/78.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/57.2M [00:00<?, ?B/s]

'তার পিতার নাম কালীপ্রসন্ন ভট্টাচার্য'

# check the original audio

In [5]:
IPython.display.Audio(CFG.single_SPEECH_FILE)

# Fix paths

In [6]:
df = pd.read_csv('../input/dlsprint/validation.csv')
directory ="../input/dlsprint/validation_files/"
df["path"]=df["path"].progress_apply(lambda x:os.path.join(directory,str(x)))
df.head(3)

  0%|          | 0/7747 [00:00<?, ?it/s]

Unnamed: 0,client_id,path,sentence,up_votes,down_votes,age,gender,accents,locale
0,c0494c8220a53efec93f188e32be94d3c1832c48117423...,../input/dlsprint/validation_files/common_voic...,"কৃষি, সেবা, রেমিটেন্স, ব্যবসা ও অন্যান্য।",3.0,0.0,,,,bn
1,c0494c8220a53efec93f188e32be94d3c1832c48117423...,../input/dlsprint/validation_files/common_voic...,তিনি ছিলেন চাকমা ভাষার প্রথম আধুনিক গীতিকার।,6.0,1.0,,,,bn
2,c06b36547c86713d53bb2bf696a34b696de586c5ab1aa9...,../input/dlsprint/validation_files/common_voic...,ইংরেজির সাথে সাথে তাদের হিন্দী ও সংস্কৃত শিক্ষ...,3.0,1.0,,,,bn


# Custom dataset class

librosa with mp3 is super slow,so we will be using wav files for faster inference

In [7]:
class bn_asr_Dataset(Dataset):
    '''
    args:
        df      : path of the dataframe
        dir     : directory of sound files
    '''
    def __init__(self,df,dir):
        self.df = pd.read_csv(df)
        self.dir = dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
   
        #speech, _ = librosa.load(self.dir+self.df.path[i], sr=feature_extractor.sampling_rate) 
        path = self.dir+self.df.path[i]
        path = os.path.splitext(path)[0]+'.wav'
        # Read file
        sampling_rate, data = wavfile.read(path)
        # Resample data
        number_of_samples = round(len(data) * float(feature_extractor.sampling_rate) / sampling_rate)
        speech = sps.resample(data, number_of_samples)
        return speech
  


# making prediction on whole validation set

In [8]:
%%time
#single image inference
''' 
#super slow inference...

predictions = []
references = []
for i in range(len(df.path)):
    speech, sr = librosa.load(df.path[i], sr=feature_extractor.sampling_rate)
    prediction = asr(speech, chunk_length_s=112, stride_length_s=None)
    pred = prediction["text"]
    predictions.append(pred)
    references.append(df.sentence[i])
    
print(len(predictions),len(references))
'''

df = pd.read_csv(CFG.valid_df_path)
valid_dataset = bn_asr_Dataset(CFG.valid_df_path,CFG.valid_wav)#CFG.valid
predictions = []
references = []
# for i,pred_sentence in enumerate(tqdm(asr(valid_dataset, chunk_length_s=112, stride_length_s=None,batch_size=CFG.batch_size), total=len(valid_dataset))):
#     references.append(df.sentence[i])
#     predictions.append(pred_sentence['text'])
    
for i in range(len(valid_dataset)):
    pred = asr(valid_dataset.__getitem__(i), chunk_length_s=112, stride_length_s=None)
    references.append(df.sentence[i])
    predictions.append(pred['text'])
  

CPU times: user 19min 28s, sys: 5.04 s, total: 19min 33s
Wall time: 21min 14s


In [9]:
torch.cuda.empty_cache() 
gc.collect()
!nvidia-smi

Sat Jul 16 16:29:04 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   49C    P0    36W / 250W |   2167MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

# WER (word error rate) calculation process



![](https://miro.medium.com/max/700/1*MUGLdWm3zMYK7dLmyo3pqA.png)

**WER = 100 (insertions(INS) + substitutions(SUB) + deletions(DEL))**

![](http://www.italk2learn.eu/wp-content/uploads/2015/02/speech-bubble-image.png)

# CER (character error rate) calculation process



character error rate (cer) is a common metric of the performance of an automatic speech recognition system. This value indicates the percentage of characters that were incorrectly predicted. The lower the value, the better the performance of the ASR system with a CER of 0 being a perfect score.

CER calculation is based on the concept of [Levenshtein distance](https://towardsdatascience.com/evaluating-ocr-output-quality-with-character-error-rate-cer-and-word-error-rate-wer-853175297510#9bd1), where we count the minimum number of character-level operations required to transform the ground truth text (aka reference text) into the OCR output.

Character Error Rate (CER) formula :

![](https://miro.medium.com/max/700/1*KsWFDKnLI7mudmhbzGjc4w.png)

where:

* S = Number of Substitutions
* D = Number of Deletions
* I = Number of Insertions
* N = Number of characters in reference text (aka ground truth)

Let’s look at an example:

**Ground Truth Reference Text**: 809475127

**ASR Transcribed Output Text**: 80g475Z7

Several errors require edits to transform ASR output into the ground truth:

1. g instead of 9 (at reference text character 3)
2. Missing 1 (at reference text character 7)
3. Z instead of 2 (at reference text character 8)

With that, here are the values to input into the equation:

* Number of Substitutions (S) = 2
* Number of Deletions (D) = 1
* Number of Insertions (I) = 0
* Number of characters in reference text (N) = 9

Based on the above, we get (2 + 1 + 0) / 9 = 0.3333. When converted to a percentage value, the CER becomes 33.33%. This implies that every 3rd character in the sequence was incorrectly transcribed.

We repeat this calculation for all the pairs of transcribed output and corresponding ground truth, and take the mean of these values to obtain an overall CER percentage.

**Reference :** [Evaluate OCR Output Quality with Character Error Rate (CER) and Word Error Rate (WER)](https://towardsdatascience.com/evaluating-ocr-output-quality-with-character-error-rate-cer-and-word-error-rate-wer-853175297510#5aec)

# calculating metric on whole validation set

In [10]:

df = pd.DataFrame(columns=['predictions', 'references'])
df.predictions = predictions
df.references = references
df.to_csv('./results.csv',index = False) #use it for error analysis and other stuffs
df.head(10)

Unnamed: 0,predictions,references
0,কৃষি সেবা রেমিটেন্স ব্যবসা ও অন্যান্য,"কৃষি, সেবা, রেমিটেন্স, ব্যবসা ও অন্যান্য।"
1,তিনি ছিলেন চাকমা বাসার প্রথ মাধ্যমের গীতিকা।,তিনি ছিলেন চাকমা ভাষার প্রথম আধুনিক গীতিকার।
2,ইংরেজি সাথে সাথে তাদের হিন্দি ও সংস্কৃত শিক্ষা...,ইংরেজির সাথে সাথে তাদের হিন্দী ও সংস্কৃত শিক্ষ...
3,শিক্ষার ধারণ তার প্রথম আন্তর্জাতিক উইক েটশিকার...,শিখর ধাওয়ান তার প্রথম আন্তর্জাতিক উইকেট শিকার...
4,চতুর্থ সপ্তাহ থেকে অবস্থার উন্নতি হতে থাকে কিন...,"চতুর্থ সপ্তাহ থেকে অবস্থার উন্নতি হতে থাকে,কিন..."
5,এখানে তিনি এয়াকিনসের অধীনে অধ্যয়ন করেছেন,এখানে তিনি এয়াকিনসের অধীনে অধ্যয়ন করেছিলেন।
6,শীতকালীন গেমস এ কোন কোন পদক জিততে পারেনি।,শীতকালীন গেমসে এখনো কোন পদক জিততে পারেনি।
7,অবশেষে তিনি কাবুলে ফিরে আসেন,"অবশেষে, তিনি কাবুলে ফিরে আসেন।"
8,তিনি অ্যাথেলসেএসে চক্রটির শীর্ষ হন,তিনি এথেন্সে এসে সক্রেটিসের শিষ্য হন।
9,তিন বছর বয়সে তার বাবা মারা যান,তিন বছর বয়সে তাঁর বাবা মারা যান।


# Without Post Processing

In [11]:
cer = load_metric("cer")
wer = load_metric("wer")

cer_score = cer.compute(predictions=predictions, references=references)
print("validation cer_score -> ",cer_score)
wer_score = wer.compute(predictions=predictions, references=references)
print("validation wer_score -> ",wer_score)

Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

Downloading builder script:   0%|          | 0.00/1.90k [00:00<?, ?B/s]

validation cer_score ->  0.09787704766628824
validation wer_score ->  0.30921300101701055


# With  post processing

during error analysis using the results.csv file we've seen that the model is frequently missing to predict punctuations, almost all the sentences in ground truth ends with '।' but while predicting using the public best trained model we can see that the model is missing to predict '।' most of the times, so in the simple post processing code below we will check if the predicted sentence ends with '।' or not,if no then we forcefully add '।' at the end of the predicted sentence.

In [12]:
for i in range(len(predictions)):
    if(predictions[i][-1] == '।'):
        continue
    else:
        predictions[i] = predictions[i]+'।'

In [13]:
cer_score = cer.compute(predictions=predictions, references=references)
print("Final validation cer_score -> ",cer_score)
wer_score = wer.compute(predictions=predictions, references=references)
print("Final validation wer_score -> ",wer_score)

Final validation cer_score ->  0.09301847750965592
Final validation wer_score ->  0.28501372267654884


**WOW great!!!
with above post processing word error rate improved from 0.30921300101701055 to 0.28501372267654884 and that's 0.024199278340461705 improvement,not bad no?**


# Submission with post processing

In [14]:
df = pd.read_csv('../input/dlsprint/sample_submission.csv')
len(df.path)

7747

In [15]:
%%time

test_dataset = bn_asr_Dataset(CFG.sample_sub_df_path,CFG.test_wav)

# for i,prediction in enumerate(tqdm(asr(test_dataset, chunk_length_s=112, stride_length_s=None,batch_size=CFG.batch_size), total=len(test_dataset))):
#     df.sentence[i] = prediction["text"]
    
for i in range(len(test_dataset)):
    pred = asr(test_dataset.__getitem__(i), chunk_length_s=112, stride_length_s=None)
    
    #applying simple post processing with error handler
    try:
        if(pred["text"][-1] == '।'):
            df.sentence[i] = pred["text"]
        else:
            df.sentence[i] = pred["text"]+'।'
    except:
        print("predicted text at idx ",i," is -> ",pred["text"])
        df.sentence[i] = pred["text"]+'।'


predicted text at idx  3936  is ->  
CPU times: user 20min 37s, sys: 5.7 s, total: 20min 43s
Wall time: 22min 21s


In [16]:
df.head(3)

Unnamed: 0,path,sentence
0,common_voice_bn_31675220.mp3,এছাড়াও নিউজল্যান্ড এই ক্রিকেট দলের হয়ে খেলছে...
1,common_voice_bn_31513116.mp3,এই ফল পাখি রাখায় কিন্তু নিজে পড়ে থাকা ফল খেল...
2,common_voice_bn_31558126.mp3,জন পরিকল্পিত।


In [17]:
df.to_csv('./submission.csv',index = False)
df.sentence[1]

'এই ফল পাখি রাখায় কিন্তু নিজে পড়ে থাকা ফল খেলে কুকুর অসুস্থ হয়ে পড়ে।'

In [18]:
IPython.display.Audio('../input/dlsprint/test_files/common_voice_bn_31675220.mp3')

In [19]:
df.sentence[0]

'এছাড়াও নিউজল্যান্ড এই ক্রিকেট দলের হয়ে খেলছেন তিনি।'

In [20]:
df.sentence[80]

'প্রথম শ্রেণির ক্রিকেট প্রতিযোগিতা শেলিল্ডের উদ্বোধনী আসরে অংশ নেয়।'

# optional (post ASR correction attempt)

in this section we will try to implement the recent best research on POST OCR (optical character recognition) CORRECTION titled[ Post-OCR Document Correction with large Ensembles of Character
Sequence-to-Sequence Models](https://arxiv.org/pdf/2109.06264.pdf) this research work was done in ocr domain and not in ASR domain so i was thinking what will happen if we try this approach in ASR domain? **well if you never try you'll never know**.
The core of this system is a standard sequence-to-sequence model that can correct sequences of characters. In the below implementation, we used a Transformer as the sequence model, which takes as input a segment of characters from the document to correct, and the output is the corrected segment. To train this sequence model, it is necessary to align the raw documents with their corresponding correct transcriptions, which is not always straightforward.Since the output is not necessarily of the same length as the input (because of possible insertions or deletions of characters), a decoding method like Greedy Search or Beam Search
is needed to produce the most likely corrected sequence according to the model.
for the below experiment we will be using results.csv where references column contains actual clean annotation and predictions column contains output of STT model including errors


In [21]:
!git clone https://github.com/jarobyte91/post_ocr_correction.git
os.chdir('./post_ocr_correction')
!pwd
!ls

Cloning into 'post_ocr_correction'...
remote: Enumerating objects: 225, done.[K
remote: Counting objects: 100% (22/22), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 225 (delta 11), reused 6 (delta 2), pack-reused 203[K
Receiving objects: 100% (225/225), 287.00 KiB | 1.65 MiB/s, done.
Resolving deltas: 100% (107/107), done.
/kaggle/working/post_ocr_correction
LICENSE    download_data.py  lib	pyproject.toml	  setup.cfg  train
README.md  evaluate	     notebooks	requirements.txt  tests


In [22]:
!pip install .

Processing /kaggle/working/post_ocr_correction
  Installing build dependencies ... [?25l- \ | / - done
[?25h  Getting requirements to build wheel ... [?25l- done
[?25h  Installing backend dependencies ... [?25l- \ | done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- done
[?25hCollecting pytorch-beam-search<2,>1.2
  Downloading pytorch_beam_search-1.2.2-py3-none-any.whl (18 kB)
Collecting nltk>=3.6.5
  Downloading nltk-3.7-py3-none-any.whl (1.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
Building wheels for collected packages: post-ocr-correction
  Building wheel for post-ocr-correction (pyproject.toml) ... [?25l- done
[?25h  Created wheel for post-ocr-correction: filename=post_ocr_correction-0.0.1-py3-none-any.whl size=6368 sha256=fc1dccb07b6ace86862aa2bb9897869c03b0eba104b290b924765aad388cf053
  Stored in directory: /tmp/pip-ephem-wheel-cache-p_3xh26i/whee

In [23]:
os.chdir('..')
!ls

__notebook__.ipynb  post_ocr_correction  results.csv  submission.csv


In [24]:
results = pd.read_csv('../input/commonvoice-bn-xls-r-metric-calculation/results.csv')
print(len(results))
preds = results.predictions.tolist()
refs = results.references.tolist()
results.head()

7747


Unnamed: 0,predictions,references
0,কৃষি সেবা রেমিটেন্স ব্যবসা ও অন্যান্য,"কৃষি, সেবা, রেমিটেন্স, ব্যবসা ও অন্যান্য।"
1,তিনি ছিলেন চাকমা বাসার প্রথ মাধ্যমের গীতিকা।,তিনি ছিলেন চাকমা ভাষার প্রথম আধুনিক গীতিকার।
2,ইংরেজি সাথে সাথে তাদের হিন্দি ও সংস্কৃত শিক্ষা...,ইংরেজির সাথে সাথে তাদের হিন্দী ও সংস্কৃত শিক্ষ...
3,শিক্ষার ধারণ তার প্রথম আন্তর্জাতিক উইক েটশিকার...,শিখর ধাওয়ান তার প্রথম আন্তর্জাতিক উইকেট শিকার...
4,চতুর্থ সপ্তাহ থেকে অবস্থার উন্নতি হতে থাকে কিন...,"চতুর্থ সপ্তাহ থেকে অবস্থার উন্নতি হতে থাকে,কিন..."


Be careful,
if you predict on train set using the best publicly available ASR bangla model from huggingface you will see the model making NaN prediction for many audio samples in train set,to get the index of those NaN output files i used the code below 

In [25]:
#no nan output in results.csv but they exist in train.csv (give it a try)
idx = [i for i, x in zip(range(len(preds)), preds) if not isinstance(x,str)]
for ele in sorted(idx, reverse = True):
    del preds[ele]
    del refs[ele]
len(preds),len(refs)


(7747, 7747)

In [26]:
torch.cuda.empty_cache()
gc.collect()
!nvidia-smi

Sat Jul 16 16:52:23 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.82.01    Driver Version: 470.82.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P0    35W / 250W |   2167MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [27]:
from pytorch_beam_search import seq2seq
from post_ocr_correction import correction

for i in range(len(refs)):
    refs[i] = list(refs[i])
    preds[i] = list(preds[i])


# train data and model
source = preds
target = refs
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
print(source_index)
print(".....")
print(target_index)
print(".....")
print(X)
print(".....")
print(Y)

<Seq2Seq Index with 78 items>
.....
<Seq2Seq Index with 80 items>
.....
tensor([[ 1, 20, 58,  ...,  0,  0,  0],
        [ 1, 35, 54,  ...,  0,  0,  0],
        [ 1, 11,  7,  ...,  0,  0,  0],
        ...,
        [ 1, 22, 44,  ...,  0,  0,  0],
        [ 1, 16, 11,  ...,  0,  0,  0],
        [ 1, 20, 56,  ...,  0,  0,  0]])
.....
tensor([[ 1, 30, 68,  ...,  0,  0,  0],
        [ 1, 45, 64,  ...,  0,  0,  0],
        [ 1, 21, 17,  ...,  0,  0,  0],
        ...,
        [ 1, 32, 54,  ...,  0,  0,  0],
        [ 1, 26, 21,  ...,  0,  0,  0],
        [ 1, 30, 66,  ...,  0,  0,  0]])



![](https://i.ibb.co/741XzvD/post-stt-corrector.png)

In [28]:
%%time

epochs = 6000
batch_size = 1024
PATH = './post_ASR_corrector.pt'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
X= X.to(device)
Y= Y.to(device)

# def post_asr_corrector():
#     model = seq2seq.Transformer(source_index, target_index,max_sequence_length = len(results)+4,
#                     embedding_dimension = 512,
#                     feedforward_dimension = 1024,
#                     attention_heads = 2,
#                     encoder_layers = 2,
#                     decoder_layers = 2)
#     return model

def post_asr_corrector():
    model = seq2seq.Transformer(source_index, target_index,max_sequence_length = 256,dropout = 0.0,embedding_dimension = 192)
    return model
model = post_asr_corrector()
model.to(device)
model.train()
model.fit(X, Y, epochs = epochs, progress_bar = 1,batch_size = batch_size)
model.eval()
torch.save(model.state_dict(), PATH)

cuda
Model: Seq2Seq Transformer
Source index: <Seq2Seq Index with 78 items>
Target index: <Seq2Seq Index with 80 items>
Max sequence length: 256
Embedding dimension: 192
Feedforward dimension: 128
Encoder layers: 2
Decoder layers: 2
Attention heads: 2
Activation: relu
Dropout: 0.0
Trainable parameters: 1,186,768



  0%|          | 0/6000 [00:00<?, ?it/s]

Training started
X_train.shape: torch.Size([7747, 156])
Y_train.shape: torch.Size([7747, 184])
Epochs: 6,000
Learning rate: 0.0001
Weight decay: 0
Epoch | Train                 | Minutes
      | Loss     | Error Rate |
---------------------------------------
    1 |   4.0696 |     97.039 |     0.1
    2 |   3.6394 |     95.553 |     0.1
    3 |   3.5109 |     95.452 |     0.2
    4 |   3.4444 |     95.431 |     0.3
    5 |   3.3961 |     95.371 |     0.4
    6 |   3.3464 |     95.160 |     0.4
    7 |   3.2853 |     94.610 |     0.5
    8 |   3.2081 |     93.900 |     0.6
    9 |   3.1242 |     93.295 |     0.6
   10 |   3.0456 |     92.772 |     0.7
   11 |   2.9753 |     92.322 |     0.8
   12 |   2.9174 |     92.027 |     0.9
   13 |   2.8719 |     91.814 |     0.9
   14 |   2.8354 |     91.660 |     1.0
   15 |   2.8055 |     91.497 |     1.1
   16 |   2.7795 |     91.333 |     1.1
   17 |   2.7571 |     91.214 |     1.2
   18 |   2.7371 |     91.106 |     1.3
   19 |   2.7189 |   

# load and infer

In [29]:
del model
torch.cuda.empty_cache()
gc.collect()


model = post_asr_corrector()

model.load_state_dict(torch.load(PATH))
model.to(device)
model.eval()
model

Model: Seq2Seq Transformer
Source index: <Seq2Seq Index with 78 items>
Target index: <Seq2Seq Index with 80 items>
Max sequence length: 256
Embedding dimension: 192
Feedforward dimension: 128
Encoder layers: 2
Decoder layers: 2
Attention heads: 2
Activation: relu
Dropout: 0.0
Trainable parameters: 1,186,768



Transformer(
  (source_embeddings): Embedding(78, 192)
  (target_embeddings): Embedding(80, 192)
  (positional_embeddings): Embedding(256, 192)
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=192, out_features=192, bias=True)
          )
          (linear1): Linear(in_features=192, out_features=128, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (linear2): Linear(in_features=128, out_features=192, bias=True)
          (norm1): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((192,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.0, inplace=False)
          (dropout2): Dropout(p=0.0, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonD

In [30]:
# test data
test = 'শীতকালীন গেমসে এখনো কোন পদক জিততে পারেনি।'
reference = 'শীতকালীন গেমস এ কোন কোন পদক জিততে পারেনি।'
new_source = [list(test)]
X_new = source_index.text2tensor(new_source).to(device)

# plain beam search
predictions, log_probabilities = seq2seq.beam_search(
    model, 
    X_new,
    progress_bar = 0
)
just_beam = target_index.tensor2text(predictions[:, 0, :])[0]
just_beam = re.sub(r"<START>|<PAD>|<UNK>|<END>.*", "", just_beam)
just_beam

'তীতকালীন গেমসখতখন-খত'

In [31]:
log_probabilities

tensor([[-1.6066, -1.6481, -1.8176, -2.1724, -3.2150]], device='cuda:0')

In [32]:
# post ASR correction
# disjoint_beam = correction.disjoint(
#     test,
#     model,
#     source_index,
#     target_index,
#     50,
#     "beam_search",
# )

In [33]:
print("\nresults")
print("  test data                      ", test)
print("  plain beam search              ", just_beam)
#print("  disjoint windows, beam search  ", disjoint_beam)



results
  test data                       শীতকালীন গেমসে এখনো কোন পদক জিততে পারেনি।
  plain beam search               তীতকালীন গেমসখতখন-খত


# improvement ideas

for better post ASR correction,i would like to recommend going deep in [ROBART](https://arxiv.org/pdf/2202.01157.pdf)
![](https://i.ibb.co/YWVGRVF/post-asr-corrector.png)

![](https://i.ibb.co/9sZbvD7/post-asr.png)
one example implementation of levenshtein transformer can be found [here](https://github.com/nmfisher/levenshtein_transformer/blob/master/Untitled.ipynb)

more about post ASR correction was discussed [here](https://www.kaggle.com/competitions/dlsprint/discussion/335411)

![](https://images.unsplash.com/photo-1499744937866-d7e566a20a61?ixlib=rb-1.2.1&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=870&q=80)