In [None]:
import os
import re
from collections import Counter
import copy
from tamil import utf8
import time

# Wall-clock timer for this cell; the elapsed time is printed at the end.
start = time.perf_counter()

# Reference corpus: read as raw lines, then cleaned and used below to build
# the vocabulary, word frequencies and n-gram counts.
with open('/kaggle/input/clean-test-data/data.txt', 'r', encoding='utf-8') as file:
    text = file.readlines()

def remove_parenthesis_text(lines):
    """Clean each line down to spaces and Tamil-block characters, then strip it.

    Three things are removed: "(x)" groups with at most one inner character,
    "<doc>"/"</doc>"-style tags with at most one extra character, and every
    character that is neither a space nor in U+0B80-U+0BFF. That last rule
    also removes tabs and newlines.

    NOTE(review): the `.?` in the first two alternatives matches at most ONE
    character. Longer parenthesised groups and `<doc ...>` tags are therefore
    not removed as a unit, and any Tamil text inside them survives. The
    pattern is kept as-is because changing it would likely shift the token
    positions that the evaluation's error indices refer to.

    Args:
        lines: iterable of raw text lines.

    Returns:
        list of cleaned, stripped lines (same length and order as the input).
    """
    regex = re.compile(r'\(.?\)|<\/?doc.?>|[^ \u0B80-\u0BFF]')
    cleaned = []
    for raw in lines:
        cleaned.append(regex.sub('', raw).strip())
    return cleaned

def clean_line(line):
    """Return *line* keeping only whitespace and Tamil-block (U+0B80-U+0BFF) characters."""
    kept = re.sub(pattern=r'[^஀-௿\s]', repl='', string=line)
    return kept

# Normalise the corpus so that only spaces and Tamil-block characters remain.
text = remove_parenthesis_text(text)

# NOTE: this is effectively a no-op. remove_parenthesis_text has already
# dropped every character that is neither a space nor in U+0B80-U+0BFF,
# which is a subset of what clean_line keeps.
text = [clean_line(line) for line in text]

def generate_ngrams(words, n):
    """Return every contiguous run of *n* items from *words*, as tuples, in order.

    Yields an empty list when *n* exceeds ``len(words)``.
    """
    count = len(words) - n + 1
    grams = []
    for start_idx in range(count):
        grams.append(tuple(words[start_idx:start_idx + n]))
    return grams

# Corpus statistics:
#   unigrams / bigrams / trigrams : Counters over whitespace-split tokens; the
#                                   n-grams never cross a line boundary.
#   words                         : every non-empty space-separated token.
#   vocab                         : the set of distinct tokens.
#   probs                         : token -> count / total token count.
# NOTE(review): after the cleaning above a line holds only spaces and Tamil
# characters, so `words`/`word_count_dict` should carry the same counts as
# `unigrams`; one of the two is probably redundant.
words=[]
unigrams = Counter()
bigrams = Counter()
trigrams = Counter()
for line in text:
    words_in_line = line.strip().split()
    unigrams.update(words_in_line)  # Unigrams
    bigrams.update(generate_ngrams(words_in_line, 2))  # Bigrams
    trigrams.update(generate_ngrams(words_in_line, 3))  # Trigrams
    words.extend(line.split(" "))

# split(" ") yields '' for repeated spaces; drop those.
words=[word for word in words if word!='']

vocab=set(words)
print("unique words in vocab = ",len(vocab))
word_count_dict = {}  # immediately replaced by the Counter below
word_count_dict = Counter(words)
probs = {} 
m = sum(word_count_dict.values())
for k, v in word_count_dict.items():
    probs[k] = v / m

def delete_letter(word):
    """Return candidates formed by deleting one letter of *word*.

    Letters are the units returned by ``utf8.get_letters``. The first letter
    is never deleted, and words with fewer than three letters produce no
    candidates.

    Returns:
        list of strings, one per deleted position 1 .. len(letters) - 1.
    """
    letters = utf8.get_letters(word)
    # Fix: this used to `return word` (a str). Callers feed the result to
    # set.update(), which then added each individual code point of the word
    # as a bogus candidate.
    if len(letters) < 3:
        return []
    return [''.join(letters[:i]) + ''.join(letters[i + 1:]) for i in range(1, len(letters))]

def insert_letter(word):
    """Return *word* followed by each of a fixed set of consonants with pulli (்).

    Despite the name, nothing is inserted mid-word. Each candidate is the
    word's letters re-joined, with one consonant + '்' appended. The result
    has one entry per consonant, in the order listed below.
    """
    stem = ''.join(utf8.get_letters(word))
    endings = ('க', 'ச', 'ட', 'த', 'ப', 'ற', 'ம', 'ய', 'ர', 'வ', 'ஜ', 'ஸ')
    return [stem + consonant + '்' for consonant in endings]

def transpose_letter(word):
    """Return variants of *word* with two adjacent letters swapped.

    Swaps start from the second letter, so the first letter always stays in
    place. Words with fewer than three letters yield an empty list.
    """
    letters = utf8.get_letters(word)
    variants = []
    for pos in range(1, len(letters) - 1):
        swapped = list(letters)
        swapped[pos], swapped[pos + 1] = letters[pos + 1], letters[pos]
        variants.append(''.join(swapped))
    return variants

def substitute_letter(word):
    """Return variants of *word* with one letter's consonant or vowel sign swapped.

    Two kinds of substitution are generated, one letter position at a time:

    1. Consonant swaps within the groups ர/ற, ல/ள/ழ, ன/ண/ந and ங/ஞ. A
       two-character letter keeps its second character (the sign).
    2. Vowel-sign swaps within the groups ி/ீ, (none)/ா/ை, ு/ூ, ெ/ே and
       ொ/ோ:
       - a bare one-character consonant from ``bare_consonants`` gets ா or ை
         appended;
       - a two-character letter has its sign replaced by each other member
         of its group. The empty member strips the sign.

    Returns:
        list of candidate strings. Callers only use it as a set of candidates.
    """
    letters = utf8.get_letters(word)
    variants = []

    consonant_groups = [['ர', 'ற'], ['ல', 'ள', 'ழ'], ['ன', 'ண', 'ந'], ['ங', 'ஞ']]
    for i, letter in enumerate(letters):
        for group in consonant_groups:
            if letter[0] not in group:
                continue
            for char in group:
                if char == letter[0]:
                    continue
                trial = list(letters)
                trial[i] = char + letter[1] if len(letter) == 2 else char
                variants.append(''.join(trial))

    bare_consonants = ['க', 'ச', 'ட', 'த', 'ப', 'ற', 'ஞ', 'ங', 'ண', 'ந',
                       'ம', 'ன', 'ய', 'ர', 'ல', 'வ', 'ழ', 'ள', 'ஜ', 'ஷ', 'ஸ', 'ஹ']
    sign_groups = [['ி', 'ீ'], ['', 'ா', 'ை'], ['ு', 'ூ'], ['ெ', 'ே'], ['ொ', 'ோ']]
    for i, letter in enumerate(letters):
        if len(letter) == 1 and letter in bare_consonants:
            # This branch previously sat inside the loop over sign_groups even
            # though it never used the group, so it emitted the same two
            # variants five times. It now emits them once.
            for sign in ['ா', 'ை']:
                trial = list(letters)
                trial[i] += sign
                variants.append(''.join(trial))
        elif len(letter) == 2:
            for group in sign_groups:
                if letter[1] not in group:
                    continue
                for sign in group:
                    if sign == letter[1]:
                        continue
                    trial = list(letters)
                    trial[i] = letter[0] + sign
                    variants.append(''.join(trial))

    return variants

def edit_one_letter(word):
    """Return the set of all single-edit candidates for *word*.

    This is the union of the candidates from delete_letter, transpose_letter,
    substitute_letter and insert_letter.
    """
    return set().union(
        delete_letter(word),
        transpose_letter(word),
        substitute_letter(word),
        insert_letter(word),
    )

words_outof_context=0  # unused; only referenced by a commented-out print at the end of this cell

# Compute Context Probability (Using Bigrams and Trigrams)
def ngram_probability(word, prev_word, prev_prev_word):
    """Return the interpolated trigram/bigram/unigram probability of *word*.

    Computes ``0.5 * P(word | prev_prev_word, prev_word)
    + 0.4 * P(word | prev_word) + 0.1 * P(word)`` from the module-level
    ``unigrams``, ``bigrams`` and ``trigrams`` Counters. A missing context
    count is treated as 1 in the denominator, so an unseen context
    contributes 0 instead of dividing by zero.
    """
    # The total token count costs O(|vocab|) to compute and was previously
    # recomputed on every call. It is now cached per Counter object and
    # recomputed if `unigrams` is rebound.
    # NOTE: assumes `unigrams` is not mutated in place after the first call.
    cached = getattr(ngram_probability, "_unigram_total", None)
    if cached is None or cached[0] is not unigrams:
        cached = (unigrams, sum(unigrams.values()))
        ngram_probability._unigram_total = cached
    total = cached[1]

    # Guard against an empty unigram table (previously ZeroDivisionError).
    unigram_prob = unigrams.get(word, 0) / total if total else 0.0
    bigram_prob = bigrams.get((prev_word, word), 0) / unigrams.get(prev_word, 1)
    trigram_prob = trigrams.get((prev_prev_word, prev_word, word), 0) / bigrams.get((prev_prev_word, prev_word), 1)

    # Weighted probability combination
    return 0.5 * trigram_prob + 0.4 * bigram_prob + 0.1 * unigram_prob
 

def get_corrections(word, probs, vocab, prev_word, prev_prev_word, n=3):
    """Suggest spelling corrections for *word* from edit candidates and n-gram context.

    Args:
        word: token to check.
        probs: mapping word -> relative corpus frequency.
        vocab: set of known words; candidates are restricted to it.
        prev_word, prev_prev_word: the two preceding tokens used as n-gram
            context ("<s>" at sentence start).
        n: maximum number of suggestions returned. This was previously
            ignored and hard-coded to 3, which remains the default.

    Returns:
        ``('Nil', {word})`` if *word* is in *vocab* and its n-gram probability
        exceeds 1e-5, meaning the word is accepted as correct. Otherwise
        returns ``(ranked_list, set(ranked_list))``. The ranked list holds at
        most *n* candidates drawn from vocabulary words within two edits of
        *word*, plus *word* itself.
    """
    if word in vocab and ngram_probability(word, prev_word, prev_prev_word) > 0.00001:
        return 'Nil', {word}

    # Vocabulary words one or two edits away from `word`.
    one_edit = edit_one_letter(word)
    two_edit = set()
    for neighbour in one_edit:
        two_edit.update(edit_one_letter(neighbour))
    candidates = one_edit.union(two_edit).intersection(vocab)

    # The input word is always a candidate, even when it is out of vocabulary.
    pool = list(candidates)
    pool.append(word)

    # Stage 1: keep the 10 most frequent distinct candidates. The sort is
    # stable, so ties keep their pool order.
    by_freq = sorted(pool, key=lambda w: probs.get(w, 0), reverse=True)
    shortlist = list(dict.fromkeys(by_freq))[:10]

    # Stage 2: re-rank the shortlist by context probability and keep the top n.
    shortlist.sort(key=lambda w: ngram_probability(w, prev_word, prev_prev_word), reverse=True)
    n_best_list = shortlist[:n]
    return n_best_list, set(n_best_list)

# --- Evaluation on the labelled test set ----------------------------------
# Not used anywhere in this cell.
corrected=0
missed=0
file_no=1
# error_details_5.txt: its first three lines are comma-separated token
# positions (values of `word_cnt` below), labelled no-error, single-error
# and double-error respectively.
with open('/kaggle/input/clean-test-data/error_details_5.txt', 'r',encoding='utf-8') as file:
    values = file.readlines()

no_error_list = [int(value.strip()) for value in values[0].split(',')]
no_error_index = set(no_error_list)

single_error_list = [int(value.strip()) for value in values[1].split(',')]
single_error_index = set(single_error_list)

double_error_list = [int(value.strip()) for value in values[2].split(',')]
double_error_index = set(double_error_list)

# Corrupted test sentences, cleaned like the corpus and split into
# non-empty space-separated tokens.
with open('/kaggle/input/clean-test-data/error_file_5.txt', 'r', encoding='utf-8') as file:
    test_data = file.readlines()
    

test_data = remove_parenthesis_text(test_data)
test_data = [clean_line(line) for line in test_data]
for i in range(len(test_data)):
    test_data[i]=test_data[i].split(" ")

test_data = [[word for word in line if word] for line in test_data]
# Running position of the current token across all test sentences; the
# error-position sets above are indexed by this value.
word_cnt=0

# Reference (correct) sentences.
# NOTE(review): assumes lines 1200-1499 of the cleaned clean_test_data.txt
# correspond one-to-one, in order, to the lines of error_file_5.txt.
with open("/kaggle/input/clean-test-data/clean_test_data.txt", 'r',encoding='utf-8') as file:
    test_data_correct=file.readlines()
    test_data_correct = remove_parenthesis_text(test_data_correct)
    test_data_correct=test_data_correct[1200:1500]

    test_data_correct = [clean_line(line) for line in test_data_correct]
# `x` counts lines with matching token counts but is never used afterwards.
# It also compares the already-filtered test_data[i] with the raw split of
# test_data_correct[i], which may still contain '' tokens.
# NOTE(review): test_data[i] raises IndexError if error_file_5.txt has
# fewer lines than this 300-line slice.
x=0
for i in range(len(test_data_correct)):
    test_data_correct[i]=test_data_correct[i].split(" ")
    if (len(test_data[i])==len(test_data_correct[i])):
        x+=1

test_data_correct = [[word for word in line if word] for line in test_data_correct]

not_in_vocab=0    
correctly_predicted_no_error=0
correctly_predicted_single_error=0
correctly_predicted_top_suggestion=0
correctly_predicted_double_error=0
correctly_given_firstoption_no_error=0

# NOTE: the cleaning above already removes these characters (they are
# neither spaces nor Tamil), so the punctuation skip below never triggers.
punctuations = {'.',',','"',"'"}

for i in range(len(test_data)):
    # Sentence-start context for the n-gram model.
    prev_word="<s>"
    prev_prev_word="<s>"
    for j in range(len(test_data[i])):
        if test_data[i][j] in punctuations:
            continue
        
        # chk_list: ranked suggestions, or the string 'Nil' if the word was
        # accepted as correct. chk: the same suggestions as a set ({word}
        # when accepted).
        chk_list,chk = get_corrections(test_data[i][j],probs,vocab,prev_word,prev_prev_word)
        # Sandhi: if the next word starts with one of these consonants, also
        # accept every suggestion with that consonant + pulli appended.
        if len(chk)>0 and (j+1)!=len(test_data[i]) and test_data[i][j+1][0] in ["க", "ச", "ட", "த","ப","ற"]:
            if(chk_list=='Nil'):
                chk_list = list(chk)
            expected_sandhi=test_data[i][j+1][0]+'்'
            length=len(chk)
            for k in range(length):
                if chk_list[k][-1]!='்':
                    chk_list.append(chk_list[k]+expected_sandhi)
                    chk.add(chk_list[k]+expected_sandhi)
        # A labelled word is counted as correct if its reference form is in
        # the suggestion set (or, for no-error words, if the word was
        # accepted). The "top suggestion" counters also require it to be
        # ranked first.
        # NOTE(review): test_data_correct[i][j] raises IndexError if the
        # corrupted line has more tokens than its reference line.
        if (word_cnt in no_error_index):
            if chk_list=="Nil" or test_data_correct[i][j] in chk:
                correctly_predicted_no_error+=1
            if chk_list=="Nil" or (len(chk_list)>0 and chk_list[0]==test_data_correct[i][j]):
                correctly_given_firstoption_no_error+=1

        # If chk_list is still 'Nil' in the two branches below, chk_list[0]
        # is 'N' and the top-suggestion check is simply False.
        elif (word_cnt in single_error_index):
            
                if test_data_correct[i][j] in chk:
                    correctly_predicted_single_error+=1
                if chk_list[0]==test_data_correct[i][j]:
                    correctly_predicted_top_suggestion+=1
            
        elif (word_cnt in double_error_index):
            
                if test_data_correct[i][j] in chk:
                    correctly_predicted_double_error+=1
                if chk_list[0]==test_data_correct[i][j]:
                    correctly_predicted_top_suggestion+=1
        word_cnt+=1

        if(test_data_correct[i][j] not in vocab):
            not_in_vocab+=1

        # Progress indicator.
        if(word_cnt%1000 == 0):
            print(word_cnt)
        
        # Feed the top suggestion (not the raw token) forward as context.
        prev_prev_word=prev_word
        if(chk_list=="Nil" or chk_list==[]):
            prev_word=test_data[i][j]
        else:
            prev_word=chk_list[0]

# Accuracies are fractions of the labelled positions from error_details_5.txt.
print("Correct No error prediction: ", str(correctly_predicted_no_error))
print("Correct Single error prediction: ", str(correctly_predicted_single_error))
print("Correct Double error prediction: ", str(correctly_predicted_double_error))
print("Total no error: ", str(len(no_error_list)))
print("Total single error: ", str(len(single_error_list)))
print("Total double error: ", str(len(double_error_list)))
print("Total Accuracy: ", str((correctly_predicted_no_error+correctly_predicted_single_error+correctly_predicted_double_error)/(len(no_error_list)+len(single_error_list)+len(double_error_list))))
print("Accuracy among errors: ", str((correctly_predicted_single_error+correctly_predicted_double_error)/(len(single_error_list)+len(double_error_list))))
print("Total top suggestion accuracy: ", str((correctly_predicted_top_suggestion+correctly_given_firstoption_no_error)/(len(no_error_list)+len(single_error_list)+len(double_error_list))))
print("Top suggestion accuracy among errors: ", str(correctly_predicted_top_suggestion/(len(single_error_list)+len(double_error_list))))

finish = time.perf_counter()

print("Time =",round(finish-start,2),"sec")
# print("Words out of correction = ",words_outof_context)
# Reference tokens missing from the corpus vocabulary, as a count and as a
# percentage of all labelled positions.
print("Not in Vocab =",not_in_vocab, not_in_vocab*100/(len(single_error_list)+len(double_error_list)+len(no_error_list)),"%")


In [None]:
from transformers import MT5Tokenizer,MT5ForConditionalGeneration, logging
import torch
from scipy.spatial.distance import cosine
from huggingface_hub import hf_hub_download
import os, math, string
import re, Levenshtein
from collections import Counter
import copy
from tamil import utf8
import time
# Silence transformers log output below ERROR level.
logging.set_verbosity_error()
import fasttext.util
# NOTE(review): confirm fastText actually reads FASTTEXT_VERBOSE; it is set
# after the import.
os.environ['FASTTEXT_VERBOSE'] = '0'
# Fetch the Tamil fastText model unless it is already present, then load it
# into the global `ft` used by get_fasttext_suggestions and get_mt5_ranking.
fasttext.util.download_model('ta', if_exists='ignore')
ft = fasttext.load_model('cc.ta.300.bin')

def get_fasttext_suggestions(word, top_n=15, max_distance=4):
    """Return fastText nearest neighbours of *word* that are close in spelling.

    Args:
        word: query word.
        top_n: number of nearest neighbours requested from the global
            fastText model ``ft``.
        max_distance: keep only neighbours whose Levenshtein distance to
            *word* is at most this value. The default of 4 matches the old
            hard-coded limit.

    Returns:
        dict mapping neighbour -> Levenshtein distance. If the lookup fails,
        an empty dict is returned. It used to return ``[]``, which broke
        callers that use ``.keys()`` / ``.get()`` on the result.
    """
    try:
        neighbours = ft.get_nearest_neighbors(word, k=top_n)
    except Exception:
        # Best effort: a failed lookup just means no extra candidates.
        return {}

    suggestions = {}
    for neighbour in neighbours:
        suggested_word = neighbour[1]  # entries are (similarity, word)
        dist = Levenshtein.distance(word, suggested_word)
        if dist <= max_distance:
            suggestions[suggested_word] = dist
    return suggestions

# mT5 model used by get_mt5_ranking to score candidates by loss.
# NOTE(review): this appears to be the raw pretrained checkpoint, not one
# fine-tuned for yes/no QA; confirm its loss is a meaningful ranking signal.
model_name = "google/mt5-base"  
qa_model = MT5ForConditionalGeneration.from_pretrained(model_name)
tokenizer = MT5Tokenizer.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
qa_model.to(device)


# The timer starts after the model loads, so the reported time excludes it.
start = time.perf_counter()


# Reference corpus, used below for the vocabulary and n-gram counts.
with open('/kaggle/input/clean-test-data/data.txt', 'r', encoding='utf-8') as file:
    text = file.readlines()

def remove_parenthesis_text(lines):
    """Clean each line down to spaces and Tamil-block characters, then strip it.

    Three things are removed: "(x)" groups with at most one inner character,
    "<doc>"/"</doc>"-style tags with at most one extra character, and every
    character that is neither a space nor in U+0B80-U+0BFF. That last rule
    also removes tabs and newlines.

    NOTE(review): the `.?` in the first two alternatives matches at most ONE
    character. Longer parenthesised groups and `<doc ...>` tags are therefore
    not removed as a unit, and any Tamil text inside them survives. The
    pattern is kept as-is because changing it would likely shift the token
    positions that the evaluation's error indices refer to.

    Args:
        lines: iterable of raw text lines.

    Returns:
        list of cleaned, stripped lines (same length and order as the input).
    """
    regex = re.compile(r'\(.?\)|<\/?doc.?>|[^ \u0B80-\u0BFF]')
    cleaned = []
    for raw in lines:
        cleaned.append(regex.sub('', raw).strip())
    return cleaned

def clean_line(line):
    """Return *line* keeping only whitespace and Tamil-block (U+0B80-U+0BFF) characters."""
    kept = re.sub(pattern=r'[^஀-௿\s]', repl='', string=line)
    return kept

# Normalise the corpus so that only spaces and Tamil-block characters remain.
text = remove_parenthesis_text(text)

# NOTE: this is effectively a no-op. remove_parenthesis_text has already
# dropped every character that clean_line would remove.
text = [clean_line(line) for line in text]

def generate_ngrams(words, n):
    """Return every contiguous run of *n* items from *words*, as tuples, in order.

    Yields an empty list when *n* exceeds ``len(words)``.
    """
    count = len(words) - n + 1
    grams = []
    for start_idx in range(count):
        grams.append(tuple(words[start_idx:start_idx + n]))
    return grams

# Corpus statistics:
#   unigrams / bigrams / trigrams : Counters over whitespace-split tokens; the
#                                   n-grams never cross a line boundary.
#   words                         : every non-empty space-separated token.
#   vocab                         : the set of distinct tokens.
#   probs                         : token -> count / total token count.
# NOTE(review): after cleaning, `words`/`word_count_dict` should carry the
# same counts as `unigrams`; one of the two is probably redundant.
words=[]
unigrams = Counter()
bigrams = Counter()
trigrams = Counter()
for line in text:
    words_in_line = line.strip().split()
    unigrams.update(words_in_line)  
    bigrams.update(generate_ngrams(words_in_line, 2))  
    trigrams.update(generate_ngrams(words_in_line, 3))  
    words.extend(line.split(" "))

# split(" ") yields '' for repeated spaces; drop those.
words=[word for word in words if word!='']

vocab=set(words)
print("unique words in vocab = ",len(vocab))
word_count_dict = {}  # immediately replaced by the Counter below
word_count_dict = Counter(words)
probs = {}
m = sum(word_count_dict.values())
for k, v in word_count_dict.items():
    probs[k] = v / m

def delete_letter(word):
    """Return candidates formed by deleting one letter of *word*.

    Letters are the units returned by ``utf8.get_letters``. The first letter
    is never deleted, and words with fewer than three letters produce no
    candidates.

    Returns:
        list of strings, one per deleted position 1 .. len(letters) - 1.
    """
    letters = utf8.get_letters(word)
    # Fix: this used to `return word` (a str). Callers feed the result to
    # set.update(), which then added each individual code point of the word
    # as a bogus candidate.
    if len(letters) < 3:
        return []
    return [''.join(letters[:i]) + ''.join(letters[i + 1:]) for i in range(1, len(letters))]

def insert_letter(word):
    """Return *word* followed by each of a fixed set of consonants with pulli (்).

    Despite the name, nothing is inserted mid-word. Each candidate is the
    word's letters re-joined, with one consonant + '்' appended. The result
    has one entry per consonant, in the order listed below.
    """
    stem = ''.join(utf8.get_letters(word))
    endings = ('க', 'ச', 'ட', 'த', 'ப', 'ற', 'ம', 'ய', 'ர', 'வ', 'ஜ', 'ஸ')
    return [stem + consonant + '்' for consonant in endings]

def transpose_letter(word):
    """Return variants of *word* with two adjacent letters swapped.

    Swaps start from the second letter, so the first letter always stays in
    place. Words with fewer than three letters yield an empty list.
    """
    letters = utf8.get_letters(word)
    variants = []
    for pos in range(1, len(letters) - 1):
        swapped = list(letters)
        swapped[pos], swapped[pos + 1] = letters[pos + 1], letters[pos]
        variants.append(''.join(swapped))
    return variants

def substitute_letter(word):
    """Return variants of *word* with one letter's consonant or vowel sign swapped.

    Two kinds of substitution are generated, one letter position at a time:

    1. Consonant swaps within the groups ர/ற, ல/ள/ழ, ன/ண/ந and ங/ஞ. A
       two-character letter keeps its second character (the sign).
    2. Vowel-sign swaps within the groups ி/ீ, (none)/ா/ை, ு/ூ, ெ/ே and
       ொ/ோ:
       - a bare one-character consonant from ``bare_consonants`` gets ா or ை
         appended;
       - a two-character letter has its sign replaced by each other member
         of its group. The empty member strips the sign.

    Returns:
        list of candidate strings. Callers only use it as a set of candidates.
    """
    letters = utf8.get_letters(word)
    variants = []

    consonant_groups = [['ர', 'ற'], ['ல', 'ள', 'ழ'], ['ன', 'ண', 'ந'], ['ங', 'ஞ']]
    for i, letter in enumerate(letters):
        for group in consonant_groups:
            if letter[0] not in group:
                continue
            for char in group:
                if char == letter[0]:
                    continue
                trial = list(letters)
                trial[i] = char + letter[1] if len(letter) == 2 else char
                variants.append(''.join(trial))

    bare_consonants = ['க', 'ச', 'ட', 'த', 'ப', 'ற', 'ஞ', 'ங', 'ண', 'ந',
                       'ம', 'ன', 'ய', 'ர', 'ல', 'வ', 'ழ', 'ள', 'ஜ', 'ஷ', 'ஸ', 'ஹ']
    sign_groups = [['ி', 'ீ'], ['', 'ா', 'ை'], ['ு', 'ூ'], ['ெ', 'ே'], ['ொ', 'ோ']]
    for i, letter in enumerate(letters):
        if len(letter) == 1 and letter in bare_consonants:
            # This branch previously sat inside the loop over sign_groups even
            # though it never used the group, so it emitted the same two
            # variants five times. It now emits them once.
            for sign in ['ா', 'ை']:
                trial = list(letters)
                trial[i] += sign
                variants.append(''.join(trial))
        elif len(letter) == 2:
            for group in sign_groups:
                if letter[1] not in group:
                    continue
                for sign in group:
                    if sign == letter[1]:
                        continue
                    trial = list(letters)
                    trial[i] = letter[0] + sign
                    variants.append(''.join(trial))

    return variants

def edit_one_letter(word):
    """Return the set of all single-edit candidates for *word*.

    This is the union of the candidates from delete_letter, transpose_letter,
    substitute_letter and insert_letter.
    """
    return set().union(
        delete_letter(word),
        transpose_letter(word),
        substitute_letter(word),
        insert_letter(word),
    )

words_outof_context=0  # unused in this cell

# Compute Context Probability (Using Bigrams and Trigrams)
def ngram_probability(word, prev_word, prev_prev_word):
    """Return the interpolated trigram/bigram/unigram probability of *word*.

    Computes ``0.5 * P(word | prev_prev_word, prev_word)
    + 0.4 * P(word | prev_word) + 0.1 * P(word)`` from the module-level
    ``unigrams``, ``bigrams`` and ``trigrams`` Counters. A missing context
    count is treated as 1 in the denominator, so an unseen context
    contributes 0 instead of dividing by zero.
    """
    # The total token count costs O(|vocab|) to compute and was previously
    # recomputed on every call. It is now cached per Counter object and
    # recomputed if `unigrams` is rebound.
    # NOTE: assumes `unigrams` is not mutated in place after the first call.
    cached = getattr(ngram_probability, "_unigram_total", None)
    if cached is None or cached[0] is not unigrams:
        cached = (unigrams, sum(unigrams.values()))
        ngram_probability._unigram_total = cached
    total = cached[1]

    # Guard against an empty unigram table (previously ZeroDivisionError).
    unigram_prob = unigrams.get(word, 0) / total if total else 0.0
    bigram_prob = bigrams.get((prev_word, word), 0) / unigrams.get(prev_word, 1)
    trigram_prob = trigrams.get((prev_prev_word, prev_word, word), 0) / bigrams.get((prev_prev_word, prev_word), 1)

    # Weighted probability combination
    return 0.5 * trigram_prob + 0.4 * bigram_prob + 0.1 * unigram_prob

def normalize_scores(scores):
    """Min-max normalise the values of *scores* into [0, 1].

    Returns a new dict with the same keys. If all values are equal (including
    the single-entry case), every key maps to 1.0. An empty input returns an
    empty dict; previously min()/max() raised ValueError.
    """
    if not scores:
        return {}
    low = min(scores.values())
    high = max(scores.values())

    # Prevent division by zero when every score is identical.
    if low == high:
        return {key: 1.0 for key in scores}

    span = high - low
    return {key: (value - low) / span for key, value in scores.items()}

def log_transform_scores(scores):
    """Map each value v in *scores* to log(v + 1e-8), keeping keys and their order.

    The 1e-8 offset keeps zero probabilities finite.
    """
    transformed = {}
    for key in scores:
        transformed[key] = math.log(scores[key] + 1e-8)
    return transformed

# Layers that fuse fastText vectors into the mT5 input embeddings inside
# get_mt5_ranking:
#   fasttext_proj    : 300-d fastText vector -> 768-d (presumably the mT5-base
#                      embedding width, since the result is combined with
#                      qa_model.shared embeddings).
#   attention_layer  : mT5 token embeddings attend over the fastText context.
#   projection_layer : [embeddings ; attention output] (1536-d) -> 768-d.
# NOTE(review): these layers are freshly constructed and never trained in this
# file, and no random seed is set, so the resulting scores depend on random
# initialisation and will vary between runs. Confirm this is intended.
fasttext_proj = torch.nn.Linear(300, 768).to(device)
attention_layer = torch.nn.MultiheadAttention(embed_dim=768, num_heads=4, batch_first=True).to(device)
projection_layer = torch.nn.Linear(1536, 768).to(device)
contrastive_loss_fn = torch.nn.CosineEmbeddingLoss()

def get_mt5_ranking(sentence, word, suggestions):
    """Score candidate corrections for *word* in *sentence* with mT5 and fastText.

    If at most one candidate is supplied, fastText neighbours of *word* from
    get_fasttext_suggestions are added first. Every candidate is then
    reduced to Tamil-block characters and whitespace.

    Each candidate is scored as ``1 / (1 + loss)``. Here ``loss`` is the sum
    of two terms:
      - the mT5 loss on a templated question about the candidate, with input
        embeddings augmented by attention over projected fastText vectors;
      - a cosine-embedding loss between the projected fastText vectors of
        *word* and the candidate.
    When exactly one candidate was supplied and the cleaned pool has more
    than three entries, fastText candidates are further damped by
    ``1 / (1 + edit_distance)``. Finally the scores are normalised to sum
    to 1 whenever their total is positive.

    NOTE(review): fasttext_proj, attention_layer and projection_layer are
    untrained, randomly initialised layers, so scores depend on their
    initialisation. Confirm this is intended.

    Args:
        sentence: the sentence text (str) the word appears in.
        word: the token being corrected.
        suggestions: list of candidate strings.

    Returns:
        dict mapping each cleaned candidate to its score.
    """
    candidates = list(suggestions)
    ln = len(candidates)
    qa_scores = {}

    # Raw candidates are kept alongside their cleaned forms.
    cleaned_suggestions = set(candidates)
    # Bound up front so the distance-damping lookup below can never hit an
    # unbound name.
    fasttext_suggestions = {}
    if len(candidates) <= 1:
        # `or {}` also tolerates a non-dict failure value such as [].
        fasttext_suggestions = get_fasttext_suggestions(word) or {}
        candidates.extend(fasttext_suggestions.keys())

    for suggestion in candidates:
        # Keep only Tamil-block characters and whitespace. This also drops
        # ASCII letters and digits, which made the former separate
        # [a-zA-Z0-9] pass redundant.
        suggestion = re.sub(r"[^\u0B80-\u0BFF\s]", "", suggestion).strip()
        if suggestion:
            cleaned_suggestions.add(suggestion)

    # Everything here is inference only. The fastText projections used to
    # run outside no_grad and built autograd graphs that were never used.
    with torch.no_grad():
        word_vec = fasttext_proj(torch.tensor(ft.get_word_vector(word)).unsqueeze(0).to(device))
        sentence_vec = fasttext_proj(torch.tensor(ft.get_sentence_vector(sentence)).unsqueeze(0).to(device))
        target = torch.tensor([1.0]).to(device)  # (word, candidate) treated as a positive pair

        for suggestion in cleaned_suggestions:
            input_text = f"Is '{suggestion}' the correct spelling for '{word}' in this sentence: {sentence}?"
            inputs = tokenizer(input_text, return_tensors="pt").to(device)
            suggestion_vec = fasttext_proj(torch.tensor(ft.get_word_vector(suggestion)).unsqueeze(0).to(device))

            inputs_embeds = qa_model.shared(inputs.input_ids)
            # (1, 2, 768) context of [sentence, candidate], tiled along the
            # sequence axis so the keys/values cover every input position.
            context = torch.cat([sentence_vec, suggestion_vec], dim=0).unsqueeze(0)
            context = context.repeat(1, inputs_embeds.shape[1], 1)
            attn_output, _ = attention_layer(inputs_embeds, context, context)

            combined = projection_layer(torch.cat([inputs_embeds, attn_output], dim=-1))
            loss = qa_model(inputs_embeds=combined, labels=inputs["input_ids"]).loss.item()

            loss += contrastive_loss_fn(word_vec, suggestion_vec, target).item()
            qa_scores[suggestion] = 1 / (1 + loss)  # lower loss -> higher score

            if ln == 1 and len(cleaned_suggestions) > 3:
                fasttext_distance = fasttext_suggestions.get(suggestion)
                if fasttext_distance is not None:
                    # Damp candidates that are further away in spelling.
                    qa_scores[suggestion] *= 1 / (1 + fasttext_distance)

    total_score = sum(qa_scores.values())
    if total_score > 0:
        qa_scores = {w: score / total_score for w, score in qa_scores.items()}

    return qa_scores


def get_corrections(word, sentence, probs, vocab, prev_word, prev_prev_word, n=3):
    """Suggest corrections for *word* by combining n-gram and mT5 scores.

    Pipeline:
      1. Accept *word* immediately if it is in *vocab* and its n-gram
         probability exceeds 1e-5.
      2. Collect vocabulary words within two edits, plus *word* itself.
         Rank them by corpus frequency (*probs*) and keep the top 6.
         Candidates ranked 7-10 are held back as padding.
      3. Score the top 6 with a log-transformed, min-max-normalised n-gram
         probability and with get_mt5_ranking.
      4. Compute the final score. For a candidate other than *word* whose
         normalised n-gram score exceeds 0.1, it is 0.7 * mT5 + 0.3 * n-gram.
         Otherwise it is the n-gram score, or the mT5 score for candidates
         the ranker added itself (e.g. fastText neighbours).

    Args:
        word: token to check.
        sentence: list of tokens of the sentence containing *word*.
        probs: mapping word -> relative corpus frequency.
        vocab: set of known words.
        prev_word, prev_prev_word: n-gram context ("<s>" at sentence start).
        n: maximum number of suggestions returned. This was previously
            ignored and hard-coded to 3, which remains the default.

    Returns:
        ``('Nil', {word}, [])`` when *word* is accepted. Otherwise returns
        ``(ranked_list, set(ranked_list), final_scores)``, where ranked_list
        has at most *n* entries.
    """
    if word in vocab and ngram_probability(word, prev_word, prev_prev_word) > 0.00001:
        return 'Nil', {word}, []

    # Vocabulary words within one or two edits. The second-level set is built
    # once and intersected with vocab once. Previously the growing union was
    # re-intersected on every iteration, which gave the same result with
    # quadratic work.
    one_edit = edit_one_letter(word)
    two_edit = set()
    for neighbour in one_edit:
        two_edit.update(edit_one_letter(neighbour))
    candidates = one_edit.union(two_edit).intersection(vocab)

    # The input word is always a candidate, even when it is out of vocabulary.
    pool = list(candidates)
    pool.append(word)

    by_freq = sorted(pool, key=lambda w: probs.get(w, 0), reverse=True)
    ranked = list(dict.fromkeys(by_freq))
    reserve = ranked[6:10]  # appended after the scored candidates as padding
    shortlist = ranked[:6]
    sentence_text = " ".join(sentence)

    ngram = {w: ngram_probability(w, prev_word, prev_prev_word) for w in shortlist}
    ngram = normalize_scores(log_transform_scores(ngram))

    bert = get_mt5_ranking(sentence_text, word, shortlist)
    sorted_bert = sorted(bert.items(), key=lambda x: x[1], reverse=True)

    final_scores = {}
    for suggestion, score in sorted_bert:
        if suggestion not in ngram:
            # Added by the ranker itself, so there is no n-gram score.
            final_scores[suggestion] = score
        elif suggestion != word and ngram[suggestion] > 0.1:
            final_scores[suggestion] = 0.7 * score + 0.3 * ngram[suggestion]
        else:
            final_scores[suggestion] = ngram[suggestion]

    n_best_list = sorted((s for s, _ in sorted_bert), key=lambda w: final_scores[w], reverse=True)
    n_best_list.extend(reserve)
    n_best_list = n_best_list[:n]
    return n_best_list, set(n_best_list), final_scores

# --- Evaluation on the labelled test set ----------------------------------
# Not used anywhere in this cell.
corrected=0
missed=0
file_no=1
# error_details_5.txt: its first three lines are comma-separated token
# positions (values of `word_cnt` below), labelled no-error, single-error
# and double-error respectively.
with open('/kaggle/input/clean-test-data/error_details_5.txt', 'r',encoding='utf-8') as file:
    values = file.readlines()

no_error_list = [int(value.strip()) for value in values[0].split(',')]
no_error_index = set(no_error_list)

single_error_list = [int(value.strip()) for value in values[1].split(',')]
single_error_index = set(single_error_list)

double_error_list = [int(value.strip()) for value in values[2].split(',')]
double_error_index = set(double_error_list)

# Corrupted test sentences split into non-empty space-separated tokens.
# (Unlike the previous cell, clean_line is not applied here; it would be a
# no-op after remove_parenthesis_text anyway.)
with open('/kaggle/input/clean-test-data/error_file_5.txt', 'r', encoding='utf-8') as file:
    test_data = file.readlines()

test_data = remove_parenthesis_text(test_data)
for i in range(len(test_data)):
    test_data[i]=test_data[i].split(" ")
test_data = [[word for word in line if word] for line in test_data]
# Running position of the current token across all test sentences; the
# error-position sets above are indexed by this value.
word_cnt=0

# Reference (correct) sentences.
# NOTE(review): assumes lines 1200-1499 of the cleaned clean_test_data.txt
# correspond one-to-one, in order, to the lines of error_file_5.txt.
with open("/kaggle/input/clean-test-data/clean_test_data.txt", 'r',encoding='utf-8') as file:
    test_data_correct=file.readlines()
    test_data_correct = remove_parenthesis_text(test_data_correct)
    test_data_correct=test_data_correct[1200:1500]
    test_data_correct = [clean_line(line) for line in test_data_correct]

for i in range(len(test_data_correct)):
    test_data_correct[i]=test_data_correct[i].split(" ")

test_data_correct = [[word for word in line if word] for line in test_data_correct]

not_in_vocab=0
correctly_predicted_no_error=0
correctly_predicted_single_error=0
correctly_predicted_top_suggestion=0
correctly_predicted_double_error=0
correctly_given_firstoption_no_error=0

# NOTE: the cleaning above already removes these characters, so the
# punctuation skip below never triggers.
punctuations = {'.',',','"',"'"}
for i in range(len(test_data)):
    # Sentence-start context for the n-gram model.
    prev_word="<s>"
    prev_prev_word="<s>"
    # Despite the name, this is the corrupted sentence from error_file_5.txt;
    # it is passed to the mT5 ranker as context.
    crct_sentence=test_data[i]
    for j in range(len(test_data[i])):
        if test_data[i][j] in punctuations:
            continue

        # chk_list: ranked suggestions, or the string 'Nil' if the word was
        # accepted as correct. chk: the same suggestions as a set ({word}
        # when accepted). scores: final per-candidate scores ([] when
        # accepted); not used below.
        chk_list,chk, scores = get_corrections(test_data[i][j], crct_sentence, probs,vocab,prev_word,prev_prev_word)
        # Sandhi: if the next word starts with one of these consonants, also
        # accept every suggestion with that consonant + pulli appended.
        if len(chk)>0 and (j+1)!=len(test_data[i]) and test_data[i][j+1][0] in ["க", "ச", "ட", "த","ப","ற"]:
            if(chk_list=='Nil'):
                chk_list = list(chk)
            expected_sandhi=test_data[i][j+1][0]+'்'
            length=len(chk)
            for k in range(length):
                if chk_list[k][-1]!='்':
                    chk_list.append(chk_list[k]+expected_sandhi)
                    chk.add(chk_list[k]+expected_sandhi)
        # A labelled word is counted as correct if its reference form is in
        # the suggestion set (or, for no-error words, if the word was
        # accepted). The "top suggestion" counters also require it to be
        # ranked first.
        # NOTE(review): test_data_correct[i][j] raises IndexError if the
        # corrupted line has more tokens than its reference line.
        if (word_cnt in no_error_index):
            if chk_list=="Nil" or test_data_correct[i][j] in chk:
                correctly_predicted_no_error+=1
            if chk_list=="Nil" or len(chk_list)>0 and chk_list[0]==test_data_correct[i][j]:
                correctly_given_firstoption_no_error+=1

        # If chk_list is still 'Nil' in the two branches below, chk_list[0]
        # is 'N' and the top-suggestion check is simply False.
        elif (word_cnt in single_error_index):

                if test_data_correct[i][j] in chk:
                    correctly_predicted_single_error+=1
                if chk_list[0]==test_data_correct[i][j]:
                    correctly_predicted_top_suggestion+=1

        elif (word_cnt in double_error_index):

                if test_data_correct[i][j] in chk:
                    correctly_predicted_double_error+=1
                if chk_list[0]==test_data_correct[i][j]:
                    correctly_predicted_top_suggestion+=1
        word_cnt+=1

        if(test_data_correct[i][j] not in vocab):
            not_in_vocab+=1

        # Progress indicator.
        if(word_cnt%1000 == 0):
            print(word_cnt)

        # Feed the top suggestion (not the raw token) forward as context.
        prev_prev_word=prev_word
        if(chk_list=="Nil" or chk_list==[]):
            prev_word=test_data[i][j]
        else:
            prev_word=chk_list[0]
            
# Accuracies are fractions of the labelled positions from error_details_5.txt.
print("Correct No error prediction: ", str(correctly_predicted_no_error))
print("Correct Single error prediction: ", str(correctly_predicted_single_error))
print("Correct Double error prediction: ", str(correctly_predicted_double_error))
print("Total no error: ", str(len(no_error_list)))
print("Total single error: ", str(len(single_error_list)))
print("Total double error: ", str(len(double_error_list)))
print("Total Accuracy: ", str((correctly_predicted_no_error+correctly_predicted_single_error+correctly_predicted_double_error)/(len(no_error_list)+len(single_error_list)+len(double_error_list))))
print("Accuracy among errors: ", str((correctly_predicted_single_error+correctly_predicted_double_error)/(len(single_error_list)+len(double_error_list))))
print("Total top suggestion accuracy: ", str((correctly_predicted_top_suggestion+correctly_given_firstoption_no_error)/(len(no_error_list)+len(single_error_list)+len(double_error_list))))
print("Top suggestion accuracy among errors: ", str(correctly_predicted_top_suggestion/(len(single_error_list)+len(double_error_list))))

finish = time.perf_counter()

print("Time =",round(finish-start,2),"sec")
# Reference tokens missing from the corpus vocabulary, as a count and as a
# percentage of all labelled positions.
print("Not in Vocab =",not_in_vocab, not_in_vocab*100/(len(single_error_list)+len(double_error_list)+len(no_error_list)),"%")
