# Spelling correction using the Transformer language model

In [1]:
import torch

In [2]:
# Add the parent of the notebook's working directory to the system path so that
# project-local imports in later cells can be resolved (presumably `transformer`,
# `utils`, `eval` and `BrillMoore` live there — confirm against the repo layout).
import os
import sys

# os.path.dirname('') is '', and os.path.abspath('') resolves to the current working directory.
notebook_dir = os.path.abspath(os.path.dirname(''))
parent_dir = os.path.abspath(os.path.join(notebook_dir, os.pardir))

# Guard the append so that re-running this cell does not stack duplicate entries.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

In [3]:
from transformer import model
import torch



In [4]:
from transformer import transformer_probab,transformer_probab_list,transformer_probab_final_word

In [5]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible cells — confirm it is still needed.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
# Nepali sentences used as manual inputs for the corrector in the demo cells below.
# NOTE(review): the misspellings look deliberate (they are the test inputs) — do not "fix" them.
sample_sentences = ['हरेक [MASK] नेपामको संविधानक पालना गर्नुपर्छ ।' ,
                    'म पुस्तकालयबाट थुलो किताब पढ्न चाहन्छ ।',
                    'तर उस समयमा पनि स्वस्थ राजनैतिक वातावरनको अभावले गर्दा देश विकासतर्फ विशेष प्रगति हुन  सकेन।',
                   'नेपालमा आधुनिक रुपमा आर्थक विकाससम्बन्धी कार्यरू प्रारम्भ भएको हालै मात्र हो।',
                   'हार धुनुहोस् र स्वास्थ जीवन जिउनुहोस्।',
                   'जब प्रवीधिहरू एकीकृत हुन सूरु गर्छन् अर्थतन्त्र तथ सँस्कृति पनि निश्चितरूपमा विस्तारै एकीकृत हुने छ।',
                   'उद्देयहरुमा पनि कुनै एक उद्देश्य पूर्ति नहुँदै अर्को नयाँ  उद्देश्यको रुपमा लिइने परम्परा बस्यो।',
                   'लगानीकर्ताहरूको धयान तुरुन्त फेरियो , व्यापारीहरुले वताए ।',
                    'अति धेरै हिज्जे गलती भएका शब्दहरू । तपाईँले टाइप गर्दा हिज्जे जाँच अक्षम पारयो ।',
                    'लामो',
                    'म नेपाली राम्रोसँग बोल्दिन',
                    'तपाईंको उमर कति हो?',
                    'मलाइ एक्लै छोडनुहोस्',
                   'रुसी राष्ट्रपती पुटिनको प्रेम जावन त्य्ती सफल छैन ।']

In [7]:
#help(tokenizer)

# Correction functions

In [8]:
import torch
import pickle
import BrillMoore
import regex as re
from utils import final_candidate_words, return_lexicon_dict
import math
import itertools

In [9]:
def words(text):
    """
    Tokenize `text` into runs of Devanagari characters.

    Only characters in the Unicode range U+0900-U+097F are kept; anything else
    (Latin letters, ASCII punctuation, whitespace, ...) acts as a separator and is
    dropped. The danda (U+0964) lies inside that range, so it is first padded with
    spaces to make it come out as a token of its own instead of staying glued to
    the preceding word.

    Returns a list of token strings (empty when `text` has no Devanagari characters).
    """
    # The replacement is a plain string rather than a backslash-escaped template so the
    # function behaves the same whether `re` is the third-party `regex` package (as
    # imported in this notebook) or the standard-library module, which rejects `\u`
    # escapes inside replacement templates.
    text = re.sub('\u0964', ' \u0964 ', text)
    return re.findall(r'[\u0900-\u097F]+', text.lower())

# Project lexicon returned by `utils.return_lexicon_dict`; the cells below only use it
# for `in` membership tests and `len()`.
final_lexicon_dict = return_lexicon_dict()

In [10]:
# Load the pickled `bma` object; `likelihood_bm` below calls `bma.likelihood(word, candidate)`.
# NOTE: unpickling can execute arbitrary code — only load model files from a trusted source.
bma_pickle_path = os.path.join(notebook_dir, '..','models','bma_27dec.pickle')
with open(bma_pickle_path, 'rb') as model_file:
    bma = pickle.load(model_file)

In [11]:
len(final_lexicon_dict)

125727

In [12]:
def likelihood_bm(sentence,candidate_sentence):
    '''
    Return P(possibly misspelled sentence | candidate correct sentence).

    `sentence` is split on whitespace and paired position-by-position with the
    words of `candidate_sentence` (pairing stops at the shorter of the two).
    The per-word values from `bma.likelihood(word, candidate_word)` are multiplied
    together, i.e. words are treated as independent (naive approach).
    '''
    word_pairs = zip(sentence.split(), candidate_sentence)
    per_word = [bma.likelihood(observed, proposed) for observed, proposed in word_pairs]

    total = 1
    for factor in per_word:
        total *= factor
    return total


# Factor applied to a word that the candidate leaves unchanged.
alpha = 0.65
def constant_distributive_likelihood(sentence,candidate_sentence,candidate_count):
    '''
    Likelihood of `sentence` given `candidate_sentence` under a constant error model.

    `sentence` is split on whitespace and paired position-by-position with the
    candidate words. A position where the two words are equal contributes `alpha`;
    any other position i contributes `(1 - alpha) / candidate_count[i]`, i.e. the
    remaining mass is spread evenly over the candidates at that position.
    The product of all contributions is returned (1 when there are no pairs).
    '''
    likelihood = 1
    paired_words = zip(sentence.split(), candidate_sentence)
    for position, (observed, proposed) in enumerate(paired_words):
        if observed != proposed:
            likelihood *= (1-alpha)/candidate_count[position]
        else:
            likelihood *= alpha
    return likelihood




In [13]:
def correctize_entire_nn(sentence, model,p_lambda = 1,prior='transformer',trie = False,likelihood = 'default'):
    '''
    Rank every candidate correction of `sentence` by prior score + log-likelihood.

    The sentence is tokenized with `words`, a candidate list is fetched for every
    token with `final_candidate_words`, and the Cartesian product of those lists
    forms the candidate sentences.

    Parameters
    ----------
    sentence : str
        Possibly misspelled input text.
    model :
        Unused by this function; kept so existing callers keep working.
    p_lambda : number
        Weight multiplied onto each value returned by `transformer_probab_list`
        (only for the 'default' and 'bm' modes).
    prior : str
        Only 'transformer' is supported.
    trie : bool
        Forwarded to `final_candidate_words` as `use_trie`.
    likelihood : str
        'default'          LM score * p_lambda + log(constant_distributive_likelihood)
        'bm'               LM score * p_lambda + log(likelihood_bm)
        'defaultWithNoLM'  lexicon bonus/penalty + log(constant_distributive_likelihood)
        'bmWithNoLM'       lexicon bonus/penalty + log(likelihood_bm)
        The lexicon term adds 1.0 per word found in `final_lexicon_dict` and -1.5
        per word that is not.

    Returns
    -------
    (ranked_candidates, ranked_scores)
        `ranked_candidates` is a list of token lists ordered from best to worst and
        `ranked_scores` holds the matching scores in the same order.

    Raises
    ------
    ValueError
        For an unsupported `prior` or `likelihood` (previously the function
        returned None or failed on an unbound local).
    '''
    if prior != 'transformer':
        raise ValueError("unsupported prior %r (only 'transformer' is implemented)" % (prior,))
    if likelihood not in ('default', 'bm', 'defaultWithNoLM', 'bmWithNoLM'):
        raise ValueError("unsupported likelihood %r" % (likelihood,))

    tokens = words(sentence)
    # The likelihood helpers re-split their input on whitespace. Hand them the tokens
    # the candidates were generated from, so observed and candidate words stay aligned
    # even when the raw sentence has punctuation glued to a word or tokens that
    # `words` drops.
    observed = ' '.join(tokens)

    # One candidate list per token; force=True limits the number of candidate sentences.
    candidates = [final_candidate_words(token, use_trie = trie, force = True) for token in tokens]
    cs = list(itertools.product(*candidates))
    candidate_sentences = [' '.join(sent) for sent in cs]

    def log_or_neg_inf(probability):
        # math.log(0) raises; a zero likelihood simply ranks the candidate last.
        return math.log(probability) if probability > 0 else float('-inf')

    if likelihood in ('default', 'defaultWithNoLM'):
        candidate_count = [len(options) for options in candidates]
        log_likelihoods = [log_or_neg_inf(constant_distributive_likelihood(observed, candidate_sentence, candidate_count))
                           for candidate_sentence in cs]
    else:
        log_likelihoods = [log_or_neg_inf(likelihood_bm(observed, candidate_sentence))
                           for candidate_sentence in cs]

    if likelihood in ('default', 'bm'):
        # Assumes transformer_probab_list returns one log-scale score per candidate
        # sentence (it is added to log-likelihoods) — TODO confirm.
        candidate_probabilities = transformer_probab_list(candidate_sentences)
        prior_scores = [row*p_lambda for row in candidate_probabilities]
    else:
        # The *WithNoLM modes never use the language model, so it is not evaluated.
        prior_scores = [sum(-1.5 if w not in final_lexicon_dict else 1.0 for w in c) for c in cs]

    sentences_probab_post = [prior_score + log_likelihood
                             for prior_score, log_likelihood in zip(prior_scores, log_likelihoods)]

    # Rank in float64 and derive both outputs from the same permutation, so candidates
    # and scores cannot disagree because of float32 rounding.
    ascending = torch.argsort(torch.tensor(sentences_probab_post, dtype=torch.float64))
    best_first = [int(k) for k in torch.flip(ascending, dims=(0,))]

    ranked_candidates = [candidate_sentences[k].split() for k in best_first]
    ranked_scores = [sentences_probab_post[k] for k in best_first]
    return ranked_candidates, ranked_scores

In [14]:
def correctize_with_window_nn(sentence,model,window = 5,p_lambda = 1,prior = 'transformer',trie = False,likelihood = 'default'):
    '''
    Correct `sentence` window by window and return one result per window.

    A sentence with at most `window` tokens is passed to `correctize_entire_nn`
    unchanged, giving a one-element list. Longer sentences are cut into chunks of
    `window` tokens that start every `window - 1` tokens, so neighbouring chunks
    share one token. A chunk is only started while `start + window` is below the
    index of the last token; whatever follows the last such start forms the final
    chunk. Each chunk is re-joined with spaces before being corrected.

    Returns a list with the `correctize_entire_nn` result of every chunk, in order.
    '''
    tokens = words(sentence)
    options = dict(p_lambda=p_lambda, prior=prior, trie=trie, likelihood=likelihood)

    if len(tokens) <= window:
        return [correctize_entire_nn(sentence, model, **options)]

    stride = window - 1
    last_index = len(tokens) - 1
    chunks = []
    for start in range(0, len(tokens), stride):
        if start + window < last_index:
            chunks.append(tokens[start:start + window])
    # Everything from the first uncovered start position onward is the tail chunk.
    chunks.append(tokens[stride * len(chunks):])

    return [correctize_entire_nn(' '.join(chunk), model, **options) for chunk in chunks]


In [15]:
def return_choices2(sample_sentences,model,p_lambda = 1,trie = False,model_type ='knlm' ,likelihood = 'default',top_k = 10):
    '''
    Run the windowed corrector and keep the best `top_k` candidates of every window.

    Parameters
    ----------
    sample_sentences : str
        The sentence to correct (a single string despite the plural name).
    model, p_lambda, trie, likelihood :
        Forwarded unchanged to the windowed corrector.
    model_type : str
        'knlm' uses `correctize_with_window_knlm`, 'transformer' uses
        `correctize_with_window_nn`.
    top_k : int
        Maximum number of candidates kept per window (default 10).

    Returns
    -------
    (window_candidates, window_probab)
        Two lists with one entry per window: the first `top_k` candidate token
        lists and their matching scores.

    Raises
    ------
    ValueError
        For an unknown `model_type` (previously None was returned silently).
    '''
    if model_type == 'knlm':
        # Resolved lazily so the transformer path works without it.
        # NOTE(review): `correctize_with_window_knlm` is not defined in the visible
        # notebook cells — confirm it exists before relying on the default.
        corrector = correctize_with_window_knlm
    elif model_type == 'transformer':
        corrector = correctize_with_window_nn
    else:
        raise ValueError("unknown model_type %r (expected 'knlm' or 'transformer')" % (model_type,))

    ranked_windows = corrector(sample_sentences,model,p_lambda =p_lambda,trie = trie,likelihood = likelihood)

    window_candidates = []
    window_probab = []
    for window in ranked_windows:
        # window[0] holds the ranked candidates, window[1] the matching scores.
        limit = min(len(window[0]), top_k)
        window_candidates.append(window[0][:limit])
        window_probab.append(window[1][:limit])
    return window_candidates,window_probab
        
def extract_choices(sample_sentences,model,p_lambda = 1,trie = False,likelihood = 'default',model_type = 'knlm'):
    '''
    Collect, for every token position of the sentence, the distinct words proposed
    by the top-ranked candidates of the windowed corrector.

    Windows returned by `return_choices2` overlap by one token, so window k is
    placed at offset k * (length of the first window - 1). Within a position the
    words keep the order in which they are first seen (earlier windows first, and
    better-ranked candidates first within a window), so element 0 is the word of
    the best candidate of the first window covering that position.

    The result is sized from the candidates themselves rather than from
    `sample_sentences.split()`: whitespace splitting disagrees with the tokens
    actually corrected when punctuation is glued to a word or when the sentence
    has tokens the tokenizer drops. That used to cause an IndexError or stray
    empty lists in the result.

    Returns a list (one entry per token position) of lists of words; an empty
    list when there are no candidates.
    '''
    wc,_scores = return_choices2(sample_sentences,model,p_lambda = p_lambda,trie = trie ,model_type = model_type,likelihood = likelihood)
    if not wc or not wc[0]:
        return []

    # Consecutive windows share one token, hence the "- 1".
    stride = len(wc[0][0])-1
    choices_list = []
    offset = 0
    for window in wc:
        for candidate in window:
            for i,w in enumerate(candidate):
                index = i + offset
                # Grow on demand so every corrected position gets a slot.
                while len(choices_list) <= index:
                    choices_list.append([])
                if w not in choices_list[index]:
                    choices_list[index].append(w)
        offset += stride
    return choices_list

In [30]:
extract_choices(sample_sentences[1],model=model,p_lambda = 0.8,trie = True,likelihood = 'bm',model_type = 'transformer')

20
2


[['म', 'मा'],
 ['पुस्तकालयबाट'],
 ['थुलो', 'थलो', 'ठूलो', 'ठुला', 'ठुलो'],
 ['किताब'],
 ['पढ्न', 'पढ्ने', 'बढ्न', 'चढ्न'],
 ['चाहन्छ', 'चाहन्छन्', 'चाहन्छु', 'चाहिन्छ'],
 ['।']]

In [22]:
extract_choices(sample_sentences[1],model=model,p_lambda = 0.8,trie = True,likelihood = 'bm',model_type = 'transformer')

Length of candidates: 1250
In candidate_probabilities...
20
Done...
Length of candidates: 125
In candidate_probabilities...
2
Done...


[['म', 'मा'],
 ['पुस्तकालयबाट'],
 ['थुलो', 'थलो', 'ठूलो', 'ठुला', 'ठुलो'],
 ['किताब'],
 ['पढ्न', 'पढ्ने', 'बढ्न', 'चढ्न'],
 ['चाहन्छ', 'चाहन्छन्', 'चाहन्छु', 'चाहिन्छ'],
 ['।']]

In [23]:
extract_choices(sample_sentences[2],model=model,p_lambda =0.135 ,trie = True,likelihood = 'bm',model_type = 'transformer')

Length of candidates: 3125
In candidate_probabilities...
49
Done...
Length of candidates: 1250
In candidate_probabilities...
20
Done...
Length of candidates: 625
In candidate_probabilities...
10
Done...
Length of candidates: 625
In candidate_probabilities...
10
Done...


[['तर', 'तार', 'कर'],
 ['उस', 'उसो', 'उप'],
 ['समयमा', 'समयका', 'समयमै'],
 ['पनि', 'नि'],
 ['स्वस्थ', 'अस्वस्थ'],
 ['राजनैतिक', 'राजनीतिक', 'राजनीतिका'],
 ['वातावरनको', 'वातावरणको'],
 ['अभावले', 'अभावमा'],
 ['गर्दा', 'गर्दै'],
 ['देश', 'दश'],
 ['विकासतर्फ'],
 ['विशेष', 'बिशेष', 'विशेषत'],
 ['प्रगति', 'प्रति', 'प्रालि'],
 ['हुन', 'छुन', 'कुन', 'धुन', 'जुन'],
 ['सकेन', 'सकिन', 'सकेनन्'],
 ['।']]

In [28]:
sample_sentences[2]

'तर उस समयमा पनि स्वस्थ राजनैतिक वातावरनको अभावले गर्दा देश विकासतर्फ विशेष प्रगति हुन  सकेन।'

In [30]:
extract_choices(sample_sentences[2],model=model,p_lambda =0.8 ,trie = True,likelihood = 'bm',model_type = 'transformer')

Length of candidates: 3125
In candidate_probabilities...
49
Done...
Length of candidates: 1250
In candidate_probabilities...
20
Done...
Length of candidates: 625
In candidate_probabilities...
10
Done...
Length of candidates: 625
In candidate_probabilities...
10
Done...


[['तर', 'तार', 'कर'],
 ['उस', 'उसो', 'उप', 'आस'],
 ['समयमा', 'समयमै'],
 ['पनि', 'उनि'],
 ['स्वस्थ', 'अस्वस्थ'],
 ['राजनैतिक', 'राजनीतिक', 'राजनीतिको', 'राजनीति', 'राजनीतिका'],
 ['वातावरणको'],
 ['अभावले', 'अभावमा'],
 ['गर्दा', 'गर्दै', 'गर्ला'],
 ['देश', 'दश'],
 ['विकासतर्फ'],
 ['विशेष', 'बिशेष', 'विदेश'],
 ['प्रगति', 'प्रति', 'प्रगतिको', 'प्रगाढ'],
 ['हुन', 'छुन'],
 ['सकेन', 'सकेनन्', 'सकिन', 'सकिएन'],
 ['।']]

# Evaluation

In [15]:
# import csv

# def read_csv_file(file_path):
#     data = []
#     with open(file_path, 'r',encoding='utf-8') as csv_file:
#         csv_reader = csv.reader(csv_file)
#         for row in csv_reader:
#             data.append(row)
#     return data

# # Example usage:
# file_path = 'data/raw_sentences_np_65k.csv'
# csv_data = read_csv_file(file_path)

# for row in csv_data:
#     print(row)


In [11]:
# NOTE(review): within the visible notebook `csv_data` is only assigned in the commented-out
# cell above, so this line raises NameError on a fresh kernel unless that cell is restored.
type(csv_data)

list

In [16]:
from eval import gather_dataset,WER,word_accuracy,char_accuracy

# Evaluation data file, located in ../data relative to the notebook's working directory.
dataset_file = os.path.join(notebook_dir, '..','data','eval_data2.pic')

In [17]:
# Load the evaluation set; the next cell indexes every item as t[0] / t[1].
dataset = gather_dataset(dataset_file)

In [18]:
# Presumably each dataset item is (correct sentence, misspelled sentence) — inferred from
# the variable names; verify against `gather_dataset`.
# Token lists are produced with the same `words` tokenizer the corrector uses.
correct_tokens = [words(t[0]) for t in dataset]
error_tokens = [words(t[1]) for t in dataset]
# Raw (untokenized) second elements, fed to `extract_choices` in the cells below.
error_sentences = [t[1] for t in dataset] 

### Brill-Moore (`bm`) likelihood

In [17]:
Gap = 0.2
a = 10
def find_p_lambda(error_sentences, likelihood='bm', gap=Gap, n_sentences=a, n_steps=8):
    '''
    Grid-search the language-model weight `p_lambda` on the first sentences of the
    evaluation set and print the metrics for every value tried.

    Parameters
    ----------
    error_sentences : list of str
        Misspelled input sentences.
    likelihood : str
        Likelihood mode forwarded to `extract_choices` (default 'bm').
    gap : float
        Step between consecutive p_lambda values; the values tried are
        gap*0, gap*1, ..., gap*(n_steps-1).
    n_sentences : int
        Number of leading sentences evaluated.
    n_steps : int
        Number of p_lambda values tried.

    `gap` and `n_sentences` default to the values of `Gap` and `a` at definition
    time, so reassigning those globals in later cells cannot silently change a
    later call. Reads the notebook globals `model`, `correct_tokens` and
    `error_tokens`. Prints only; returns None.
    '''
    for j in range(n_steps):
        p_lambda = gap*j
        predicted_tokens = []
        for i,s in enumerate(error_sentences[:n_sentences]):
            c = extract_choices(s,model=model,p_lambda =p_lambda,trie = True,likelihood = likelihood,model_type = 'transformer')
            # Keep the first-listed choice at every token position as the prediction.
            c = [t[0] for t in c]
            predicted_tokens.append(c)
            if i%2 == 0:
                print(i)  # progress
        # WER1 scores the uncorrected input, WER2 the corrected output.
        print(p_lambda,' : ',word_accuracy(correct_tokens[:n_sentences],predicted_tokens[:n_sentences],error_tokens[:n_sentences]),'WER1: ',WER(correct_tokens[:n_sentences],error_tokens[:n_sentences] ),'WER2: ',WER(correct_tokens[:n_sentences],predicted_tokens[:n_sentences] ))
find_p_lambda(error_sentences)

0
2
4
6
8
0.0  :  (0.0, 0, 42) WER1:  0.2727272727272727 WER2:  0.2727272727272727
0
2
4
6
8
0.2  :  (0.5714285714285714, 24, 42) WER1:  0.2727272727272727 WER2:  0.12337662337662338
0
2
4
6
8
0.4  :  (0.6428571428571429, 27, 42) WER1:  0.2727272727272727 WER2:  0.11038961038961038
0
2
4
6
8
0.6000000000000001  :  (0.7142857142857143, 30, 42) WER1:  0.2727272727272727 WER2:  0.09740259740259741
0
2
4
6
8
0.8  :  (0.7142857142857143, 30, 42) WER1:  0.2727272727272727 WER2:  0.09740259740259741
0
2
4
6
8
1.0  :  (0.7142857142857143, 30, 42) WER1:  0.2727272727272727 WER2:  0.11688311688311688
0
2
4
6
8
1.2000000000000002  :  (0.7142857142857143, 30, 42) WER1:  0.2727272727272727 WER2:  0.12987012987012986
0
2
4
6
8
1.4000000000000001  :  (0.6904761904761905, 29, 42) WER1:  0.2727272727272727 WER2:  0.14285714285714285


In [18]:
# Evaluate the `bm` likelihood with p_lambda = 0.7 on the first `a` sentences.
predicted_tokens = []
a = 324  # number of evaluation sentences to score
# Slicing by `a` replaces the hard-coded `i==323` stop, which had to be kept in sync
# with `a` by hand and never printed the summary when the dataset was shorter.
for i,s in enumerate(error_sentences[:a]):
    c = extract_choices(s,model=model,p_lambda =0.7 ,trie = True,likelihood = 'bm',model_type = 'transformer')
    # Keep the first-listed choice at every token position as the prediction.
    c = [t[0] for t in c]
    predicted_tokens.append(c)
    if i%2 == 0:
        # Running metrics over the sentences processed so far.
        print(i,WER(correct_tokens[:i+1],predicted_tokens[:i+1] ),word_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]),char_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]))
# Summary: WER of the uncorrected input, WER after correction, word accuracy, char accuracy.
print(WER(correct_tokens[:a],error_tokens[:a] ),WER(correct_tokens[:a],predicted_tokens[:a] ),word_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]),char_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]))

0 0.07692307692307693 (1.0, 4, 4) (1.0, 4, 4)
2 0.1282051282051282 (0.6923076923076923, 9, 13) (0.7142857142857143, 10, 14)
4 0.11428571428571428 (0.7083333333333334, 17, 24) (0.7037037037037037, 19, 27)
6 0.12371134020618557 (0.6428571428571429, 18, 28) (0.6875, 22, 32)
8 0.11278195488721804 (0.6756756756756757, 25, 37) (0.7209302325581395, 31, 43)
10 0.1 (0.7045454545454546, 31, 44) (0.7115384615384616, 37, 52)
12 0.109375 (0.68, 34, 50) (0.6666666666666666, 40, 60)
14 0.1036036036036036 (0.7213114754098361, 44, 61) (0.6944444444444444, 50, 72)
16 0.11507936507936507 (0.6714285714285714, 47, 70) (0.6626506024096386, 55, 83)
18 0.11397058823529412 (0.6666666666666666, 50, 75) (0.6444444444444445, 58, 90)
20 0.11486486486486487 (0.6705882352941176, 57, 85) (0.6534653465346535, 66, 101)
22 0.1276595744680851 (0.6382978723404256, 60, 94) (0.625, 70, 112)
24 0.1318051575931232 (0.6336633663366337, 64, 101) (0.6166666666666667, 74, 120)
26 0.1276595744680851 (0.6481481481481481, 70, 108) (

198 0.12337867763366023 (0.6699266503667481, 548, 818) (0.6666666666666666, 640, 960)
200 0.1222707423580786 (0.6715151515151515, 554, 825) (0.6683884297520661, 647, 968)
202 0.12172573189522343 (0.6730538922155689, 562, 835) (0.669050051072523, 655, 979)
204 0.12206286237412267 (0.6729857819905213, 568, 844) (0.6693629929221436, 662, 989)
206 0.12258454106280194 (0.671746776084408, 573, 853) (0.6686686686686687, 668, 999)
208 0.1217910447761194 (0.674364896073903, 584, 866) (0.671583087512291, 683, 1017)
210 0.12138216184288246 (0.6750285062713797, 592, 877) (0.6741245136186771, 693, 1028)
212 0.12083089526038619 (0.6768361581920904, 599, 885) (0.6763005780346821, 702, 1038)
214 0.12044006948465547 (0.6778523489932886, 606, 894) (0.6761904761904762, 710, 1050)
216 0.12195121951219512 (0.6766004415011038, 613, 906) (0.6719775070290535, 717, 1067)
218 0.12109152927799886 (0.6797814207650273, 622, 915) (0.6747211895910781, 726, 1076)
220 0.12082045518404046 (0.6778741865509761, 625, 922)

In [19]:
# Score only what has actually been predicted (the loop above may have stopped early).
a = len(predicted_tokens)
# Displayed pair: WER of the uncorrected input vs. WER of the predictions, both against `correct_tokens`.
WER(correct_tokens[:a],error_tokens[:a] ),WER(correct_tokens[:a],predicted_tokens[:a] )

(0.25042111173498033, 0.11903425042111174)

In [20]:
word_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a])

(0.6913303437967115, 925, 1338)

In [None]:
char_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a])

#### (0.6903553299492385, 1088, 1576)

### constant distributive likelihood

In [22]:
Gap = 0.2
a = 10
def find_p_lambda(error_sentences, likelihood='default', gap=Gap, n_sentences=a, n_steps=8):
    '''
    Grid-search the language-model weight `p_lambda` on the first sentences of the
    evaluation set and print the metrics for every value tried.

    Parameters
    ----------
    error_sentences : list of str
        Misspelled input sentences.
    likelihood : str
        Likelihood mode forwarded to `extract_choices` (default 'default', the
        constant distributive likelihood).
    gap : float
        Step between consecutive p_lambda values; the values tried are
        gap*0, gap*1, ..., gap*(n_steps-1).
    n_sentences : int
        Number of leading sentences evaluated.
    n_steps : int
        Number of p_lambda values tried.

    `gap` and `n_sentences` default to the values of `Gap` and `a` at definition
    time, so reassigning those globals in later cells cannot silently change a
    later call. Reads the notebook globals `model`, `correct_tokens` and
    `error_tokens`. Prints only; returns None.
    '''
    for j in range(n_steps):
        p_lambda = gap*j
        predicted_tokens = []
        for i,s in enumerate(error_sentences[:n_sentences]):
            c = extract_choices(s,model=model,p_lambda =p_lambda,trie = True,likelihood = likelihood,model_type = 'transformer')
            # Keep the first-listed choice at every token position as the prediction.
            c = [t[0] for t in c]
            predicted_tokens.append(c)
            if i%2 == 0:
                print(i)  # progress
        # WER1 scores the uncorrected input, WER2 the corrected output.
        print(p_lambda,' : ',word_accuracy(correct_tokens[:n_sentences],predicted_tokens[:n_sentences],error_tokens[:n_sentences]),'WER1: ',WER(correct_tokens[:n_sentences],error_tokens[:n_sentences] ),'WER2: ',WER(correct_tokens[:n_sentences],predicted_tokens[:n_sentences] ))
find_p_lambda(error_sentences)

0
2
4
6
8
0.0  :  (0.0, 0, 42) WER1:  0.2727272727272727 WER2:  0.2727272727272727
0
2
4
6
8
0.2  :  (0.5714285714285714, 24, 42) WER1:  0.2727272727272727 WER2:  0.11688311688311688
0
2
4
6
8
0.4  :  (0.6666666666666666, 28, 42) WER1:  0.2727272727272727 WER2:  0.1038961038961039
0
2
4
6
8
0.6000000000000001  :  (0.6904761904761905, 29, 42) WER1:  0.2727272727272727 WER2:  0.11038961038961038
0
2
4
6
8
0.8  :  (0.6904761904761905, 29, 42) WER1:  0.2727272727272727 WER2:  0.14285714285714285
0
2
4
6
8
1.0  :  (0.6904761904761905, 29, 42) WER1:  0.2727272727272727 WER2:  0.15584415584415584
0
2
4
6
8
1.2000000000000002  :  (0.6904761904761905, 29, 42) WER1:  0.2727272727272727 WER2:  0.16233766233766234
0
2
4
6
8
1.4000000000000001  :  (0.6666666666666666, 28, 42) WER1:  0.2727272727272727 WER2:  0.19480519480519481


In [23]:
# Evaluate the constant distributive (`default`) likelihood with p_lambda = 0.8
# on the first `a` sentences.
predicted_tokens = []
a = 324  # number of evaluation sentences to score
# Slicing by `a` replaces the hard-coded `i==323` stop, which had to be kept in sync
# with `a` by hand and never printed the summary when the dataset was shorter.
for i,s in enumerate(error_sentences[:a]):
    c = extract_choices(s,model=model,p_lambda =0.8 ,trie = True,likelihood = 'default',model_type = 'transformer')
    # Keep the first-listed choice at every token position as the prediction.
    c = [t[0] for t in c]
    predicted_tokens.append(c)
    if i%2 == 0:
        # Running metrics over the sentences processed so far.
        print(i,WER(correct_tokens[:i+1],predicted_tokens[:i+1] ),word_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]),char_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]))
# Summary: WER of the uncorrected input, WER after correction, word accuracy, char accuracy.
print(WER(correct_tokens[:a],error_tokens[:a] ),WER(correct_tokens[:a],predicted_tokens[:a] ),word_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]),char_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]))

0 0.07692307692307693 (1.0, 4, 4) (1.0, 4, 4)
2 0.20512820512820512 (0.6923076923076923, 9, 13) (0.7857142857142857, 11, 14)
4 0.17142857142857143 (0.7083333333333334, 17, 24) (0.7407407407407407, 20, 27)
6 0.17525773195876287 (0.6428571428571429, 18, 28) (0.71875, 23, 32)
8 0.15789473684210525 (0.6486486486486487, 24, 37) (0.7441860465116279, 32, 43)
10 0.13529411764705881 (0.6818181818181818, 30, 44) (0.7307692307692307, 38, 52)
12 0.15104166666666666 (0.66, 33, 50) (0.6833333333333333, 41, 60)
14 0.13513513513513514 (0.7049180327868853, 43, 61) (0.7083333333333334, 51, 72)
16 0.1388888888888889 (0.6714285714285714, 47, 70) (0.6746987951807228, 56, 83)
18 0.13602941176470587 (0.6666666666666666, 50, 75) (0.6555555555555556, 59, 90)
20 0.13513513513513514 (0.6705882352941176, 57, 85) (0.6633663366336634, 67, 101)
22 0.14285714285714285 (0.648936170212766, 61, 94) (0.6517857142857143, 73, 112)
24 0.14326647564469913 (0.6534653465346535, 66, 101) (0.65, 78, 120)
26 0.13829787234042554 (

198 0.14710534640936412 (0.6589242053789731, 539, 818) (0.7302083333333333, 701, 960)
200 0.14660012476606363 (0.6593939393939394, 544, 825) (0.731404958677686, 708, 968)
202 0.1460708782742681 (0.6610778443113773, 552, 835) (0.7334014300306435, 718, 979)
204 0.1464754348489472 (0.659952606635071, 557, 844) (0.7350859453993933, 727, 989)
206 0.14583333333333334 (0.6600234466588512, 563, 853) (0.7347347347347347, 734, 999)
208 0.14507462686567163 (0.663972286374134, 575, 866) (0.7384464110127827, 751, 1017)
210 0.14412285883047843 (0.6659064994298746, 584, 877) (0.7402723735408561, 761, 1028)
212 0.14335868929198362 (0.6677966101694915, 591, 885) (0.7418111753371869, 770, 1038)
214 0.14302258251302838 (0.668903803131991, 598, 894) (0.7409523809523809, 778, 1050)
216 0.1443328550932568 (0.6688741721854304, 606, 906) (0.7366447985004686, 786, 1067)
218 0.14326321773735076 (0.6721311475409836, 615, 915) (0.7388475836431226, 795, 1076)
220 0.14329867940432706 (0.6702819956616052, 618, 922) 

In [None]:
# Score only what has actually been predicted (the loop above may have stopped early).
a = len(predicted_tokens)
# Displayed pair: WER of the uncorrected input vs. WER of the predictions, both against `correct_tokens`.
WER(correct_tokens[:a],error_tokens[:a] ),WER(correct_tokens[:a],predicted_tokens[:a] )

In [40]:
# Ablation: `bm` likelihood without the language model (`bmWithNoLM`), first `a` sentences.
predicted_tokens = []
a = 324  # number of evaluation sentences to score
# Slicing by `a` replaces the hard-coded `i==323` stop, which had to be kept in sync
# with `a` by hand and never printed the summary when the dataset was shorter.
for i,s in enumerate(error_sentences[:a]):
    c = extract_choices(s,model=model,p_lambda =0 ,trie = True,likelihood = 'bmWithNoLM',model_type = 'transformer')
    # Keep the first-listed choice at every token position as the prediction.
    c = [t[0] for t in c]
    predicted_tokens.append(c)
    if i%2 == 0:
        # Running metrics over the sentences processed so far.
        print(i,WER(correct_tokens[:i+1],predicted_tokens[:i+1] ),word_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]),char_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]))
# Summary: WER of the uncorrected input, WER after correction, word accuracy, char accuracy.
print(WER(correct_tokens[:a],error_tokens[:a] ),WER(correct_tokens[:a],predicted_tokens[:a] ),word_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]),char_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]))

0 0.15384615384615385 (0.75, 3, 4) (0.75, 3, 4)
2 0.28205128205128205 (0.46153846153846156, 6, 13) (0.42857142857142855, 6, 14)
4 0.2857142857142857 (0.3333333333333333, 8, 24) (0.2962962962962963, 8, 27)
6 0.23711340206185566 (0.32142857142857145, 9, 28) (0.34375, 11, 32)
8 0.21052631578947367 (0.40540540540540543, 15, 37) (0.4418604651162791, 19, 43)
10 0.18235294117647058 (0.45454545454545453, 20, 44) (0.46153846153846156, 24, 52)
12 0.18229166666666666 (0.44, 22, 50) (0.43333333333333335, 26, 60)
14 0.1981981981981982 (0.39344262295081966, 24, 61) (0.3888888888888889, 28, 72)
16 0.20238095238095238 (0.38571428571428573, 27, 70) (0.39759036144578314, 33, 83)
18 0.20588235294117646 (0.36, 27, 75) (0.36666666666666664, 33, 90)
20 0.21283783783783783 (0.35294117647058826, 30, 85) (0.3564356435643564, 36, 101)
22 0.21580547112462006 (0.3404255319148936, 32, 94) (0.3392857142857143, 38, 112)
24 0.2148997134670487 (0.3465346534653465, 35, 101) (0.3416666666666667, 41, 120)
26 0.2101063829

196 0.19218449711723254 (0.3560794044665012, 287, 806) (0.34285714285714286, 324, 945)
198 0.1920278392913635 (0.3581907090464548, 293, 818) (0.34479166666666666, 331, 960)
200 0.19058016219588272 (0.35878787878787877, 296, 825) (0.3450413223140496, 334, 968)
202 0.18952234206471494 (0.36167664670658684, 302, 835) (0.3472931562819203, 340, 979)
204 0.19011290814769607 (0.36018957345971564, 304, 844) (0.34580384226491406, 342, 989)
206 0.19021739130434784 (0.35873388042203985, 306, 853) (0.34534534534534533, 345, 999)
208 0.191044776119403 (0.3568129330254042, 309, 866) (0.3421828908554572, 348, 1017)
210 0.19078558771411694 (0.35917901938426455, 315, 877) (0.3443579766536965, 354, 1028)
212 0.19016968987712113 (0.3615819209039548, 320, 885) (0.348747591522158, 362, 1038)
214 0.18934568616097278 (0.3635346756152125, 325, 894) (0.3495238095238095, 367, 1050)
216 0.18938307030129126 (0.36534216335540837, 331, 906) (0.34957825679475163, 373, 1067)
218 0.18959636156907334 (0.365027322404371

In [52]:
# Ablation: constant distributive likelihood without the language model
# (`defaultWithNoLM`), first `a` sentences.
predicted_tokens = []
a = 324  # number of evaluation sentences to score
# Slicing by `a` replaces the hard-coded `i==323` stop, which had to be kept in sync
# with `a` by hand and never printed the summary when the dataset was shorter.
for i,s in enumerate(error_sentences[:a]):
    c = extract_choices(s,model=model,p_lambda =0 ,trie = True,likelihood = 'defaultWithNoLM',model_type = 'transformer')
    # Keep the first-listed choice at every token position as the prediction.
    c = [t[0] for t in c]
    predicted_tokens.append(c)
    if i%2 == 0:
        # Running metrics over the sentences processed so far.
        print(i,WER(correct_tokens[:i+1],predicted_tokens[:i+1] ),word_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]),char_accuracy(correct_tokens[:i+1],predicted_tokens[:i+1],error_tokens[:i+1]))
# Summary: WER of the uncorrected input, WER after correction, word accuracy, char accuracy.
print(WER(correct_tokens[:a],error_tokens[:a] ),WER(correct_tokens[:a],predicted_tokens[:a] ),word_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]),char_accuracy(correct_tokens[:a],predicted_tokens[:a],error_tokens[:a]))

0 0.3076923076923077 (0.75, 3, 4) (0.75, 3, 4)
2 0.48717948717948717 (0.38461538461538464, 5, 13) (0.5714285714285714, 8, 14)
4 0.42857142857142855 (0.2916666666666667, 7, 24) (0.4074074074074074, 11, 27)
6 0.3917525773195876 (0.25, 7, 28) (0.46875, 15, 32)
8 0.38345864661654133 (0.2702702702702703, 10, 37) (0.5348837209302325, 23, 43)
10 0.35294117647058826 (0.29545454545454547, 13, 44) (0.5384615384615384, 28, 52)
12 0.3489583333333333 (0.3, 15, 50) (0.5, 30, 60)
14 0.35135135135135137 (0.3114754098360656, 19, 61) (0.4861111111111111, 35, 72)
16 0.35714285714285715 (0.3, 21, 70) (0.4939759036144578, 41, 83)
18 0.35661764705882354 (0.28, 21, 75) (0.4888888888888889, 44, 90)
20 0.36824324324324326 (0.25882352941176473, 22, 85) (0.48514851485148514, 49, 101)
22 0.3708206686930091 (0.26595744680851063, 25, 94) (0.4732142857142857, 53, 112)
24 0.3753581661891118 (0.24752475247524752, 25, 101) (0.4666666666666667, 56, 120)
26 0.3617021276595745 (0.26851851851851855, 29, 108) (0.48818897637

200 0.32096069868995636 (0.3103030303030303, 256, 825) (0.5444214876033058, 527, 968)


KeyboardInterrupt: 

In [19]:
WER(correct_tokens[:201],error_tokens[:201] )

0.25733000623830316