In [1]:
import spacy
import torch
import editdistance

from spacy.tokens import Token
from spacy.vocab import Vocab

from transformers import AutoModelWithLMHead, AutoTokenizer
from transformers import AutoModelWithLMHead, AutoTokenizer



## Identifying out-of-vocabulary (OOV) words

In [2]:
class oovChecker():
    """Detect out-of-vocabulary (likely misspelled) words and suggest fixes.

    The three public methods form a pipeline:

    1. ``misspellIdentify``   - flag tokens that are not in the vocabulary file,
       are not tagged as a PERSON entity and do not look like a number.
    2. ``candidateGenerator`` - mask each flagged token in the query and take the
       masked-language-model's top 5 predictions as replacement candidates.
    3. ``candidateRanking``   - per flagged token, keep the candidate with the
       smallest edit distance to the original text.
    """

    def __init__(self, vocab_path='./uncased_L-4_H-512_A-8/vocab.txt',
                 model_name="bert-base-cased"):
        """Load the spaCy pipeline, the vocabulary file and the masked LM.

        @params vocab_path: text file with one vocabulary entry per line
        @params model_name: Hugging Face model id used for tokenizer and model
        """
        self.nlp = spacy.load("en_core_web_sm", disable=["tagger", "parser"]) # using default tokeniser with NER
        # Explicit encoding: the platform default is not guaranteed to be UTF-8.
        with open(vocab_path, encoding="utf-8") as f:
            # if want to remove '[unusedXX]' from vocab
            # words = [line.rstrip() for line in f if not line.startswith('[unused')]
            words = [line.rstrip() for line in f]
        self.vocab = Vocab(strings=words)
        # NOTE(review): the default vocab path is named "uncased" while the default
        # model is "bert-base-cased" - confirm this mix is intended.
        # NOTE(review): AutoModelWithLMHead is deprecated in recent transformers
        # releases (AutoModelForMaskedLM is the replacement) - confirm before upgrading.
        self.BertTokenizer = AutoTokenizer.from_pretrained(model_name)
        self.BertModel = AutoModelWithLMHead.from_pretrained(model_name)
        self.mask = self.BertTokenizer.mask_token

        # This is temp: sample query used whenever a method is called without one
        self.query = "this is my India. My name is Rajat Goel. I use tru. income was $9.4 million compared to the prior year of $2.7 milion. number is 62722.2"
        self.debug = False

    ## query --> "aa bb cc..."
    def misspellIdentify(self, query=''):
        """Return the tokens of ``query`` that look misspelled.

        A token is flagged only when ALL of the following hold:
        1. its lower-cased text is not in ``self.vocab``
        2. it is not tagged as a PERSON entity
        3. it does not look like a number

        @params query: text to check; the built-in sample query is used when empty
        @return list of spaCy tokens, in document order
        """
        if query == '':
            query = self.query

        doc = self.nlp(query)
        misspell = [
            token for token in doc
            if (token.text.lower() not in self.vocab
                and token.ent_type_ != 'PERSON'
                and not token.like_num)
        ]

        # Echo of the flagged tokens; the notebook's recorded output includes it.
        print(misspell)
        return misspell

    def candidateGenerator(self, misspellings, query='', debug=False):
        """Suggest replacement candidates for every misspelled token.

        Each token is replaced by the tokenizer's mask token, one at a time, and
        the model's five highest-scoring predictions for the masked position
        become that token's candidates.

        @params misspellings: tokens as returned by ``misspellIdentify``
                              (``.text``, ``.i`` and ``.idx`` are read)
        @params query: text the tokens were taken from; the built-in sample
                       query is used when empty
        @params debug: print the masked query and the running result
        @return dict: {token index in doc: ['candidate-1', ..., 'candidate-5']}
                      (empty list when no mask position survives encoding)
        """
        # Resolve the query once, up front, so an explicit query works too.
        if query == '':
            query = self.query

        response = {}
        for token in misspellings:
            start = token.idx
            end = start + len(token.text)
            if query[start:end] == token.text:
                # Mask exactly this token; a plain str.replace would also hit
                # the same characters inside other words (e.g. "tru" in "true").
                updatedQuery = query[:start] + self.mask + query[end:]
            else:
                # Offsets do not line up with this query: mask the first match only,
                # so the input never carries more than one mask.
                updatedQuery = query.replace(token.text, self.mask, 1)
            if debug: print(updatedQuery)

            model_input = self.BertTokenizer.encode(updatedQuery, return_tensors="pt")
            mask_token_index = torch.where(model_input == self.BertTokenizer.mask_token_id)[1]
            if mask_token_index.numel() == 0:
                # Nothing was masked, so there is nothing to predict.
                response[token.i] = []
                continue

            # Inference only: no autograd graph needed.
            with torch.no_grad():
                token_logits = self.BertModel(model_input)[0]
            mask_token_logits = token_logits[0, mask_token_index, :]

            top_5_tokens = torch.topk(mask_token_logits, 5, dim=1).indices[0].tolist()
            response[token.i] = [self.BertTokenizer.decode([candidate])
                                 for candidate in top_5_tokens]

            if debug: print(response)

        return response

    def candidateRanking(self, misspellingsDict, query='', debug=False):
        """Pick, per misspelled token, the candidate closest in edit distance.

        @params misspellingsDict: {token index in doc: [candidates]} as returned
                                  by ``candidateGenerator``
        @params query: text the indices refer to; the built-in sample query is
                       used when empty. It is re-tokenized here, so it must be
                       the same text that produced the indices.
        @params debug: print the candidates and the running result
        @return dict: {token index in doc: best candidate}; a token with an
                      empty candidate list is left out. Ties go to the earliest
                      candidate in the list.
        """
        if query == '':
            query = self.query

        response = {}
        doc = self.nlp(query)
        for misspell, candidates in misspellingsDict.items():
            if debug: print('misspellingsDict[misspell]', candidates)
            if candidates:
                original = doc[misspell].text
                # min() returns the first minimal element, so ties resolve to the
                # earliest candidate and no arbitrary distance cap is needed.
                response[misspell] = min(
                    candidates,
                    key=lambda candidate: editdistance.eval(original, candidate))

            if debug: print(response)
        return response
        
        

In [3]:
# End-to-end run on the checker's built-in sample query (no query is passed).
checker = oovChecker()
# Step 1: flag tokens that are out of vocabulary, not a PERSON and not number-like.
misspellTokens = checker.misspellIdentify()
print('misspellTokens',misspellTokens)
# Step 2: masked-LM candidates per flagged token, keyed by token index.
candidate = checker.candidateGenerator(misspellTokens)
# Step 3: keep the candidate with the smallest edit distance per token.
answer = checker.candidateRanking(candidate)
print(answer)


[tru, milion]
misspellTokens [tru, milion]
{13: 'you', 28: 'million'}


In [None]:
# Scratch check: edit distance between the misspelling 'milion' and 'billion'.
editdistance.distance('milion','billion')

## OLD CODE
***