In [1]:
from symspellpy import SymSpell, Verbosity
from transformers import pipeline
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from collections import defaultdict
from wordfreq import top_n_list
# Set of the 50,000 entries returned by wordfreq's `top_n_list` for "en".
# `find_typos` skips any token found here, so English words in the text are
# never reported as typos.
english_vocab = set(top_n_list("en", 50000))

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
from tqdm import tqdm

In [4]:
# Shared SymSpell instance used by find_typos, should_correct and suggest_words.
sym_spell = SymSpell()
# Restore the previously saved dictionary instead of rebuilding it from corpora.
# NOTE(review): presumably a pickle-based format — only load files you trust.
sym_spell.load_pickle("../../data/dictionary/symspell_dictionary.pkl")

True

In [5]:
# corpora = [
#     "../data/dictionary/dataset_wot_uncased_blanklines/processed_uncased_blanklines/wiki.txt",
#     "../data/dictionary/dataset_wot_uncased_blanklines/processed_uncased_blanklines/kompas.txt",
#     "../data/dictionary/dataset_wot_uncased_blanklines/processed_uncased_blanklines/tempo.txt",
#     "../data/dictionary/dataset_wot_uncased_blanklines/processed_uncased_blanklines/bppt.txt"
# ]

# for corpus in corpora:
#     sym_spell.create_dictionary(corpus)

# if "jumal" in sym_spell.words:
#     del sym_spell.words["jumal"]

In [6]:
# addon_path = "../../data/dictionary/addon.txt"

# with open(addon_path, "r", encoding="utf-8") as f:
#     words = [w.strip() for w in f.readlines() if w.strip()]

# for word in words:
#     entry = sym_spell.words.get(word)
#     current_count = entry if entry is not None else 0
#     if current_count < 10:
#         sym_spell.create_dictionary_entry(word, 10 - current_count)

# sym_spell.save_pickle("../../data/dictionary/symspell_dictionary.pkl")

In [7]:
# NER pipeline used by context_correct to find out whether a suspected typo is
# part of a named entity. `aggregation_strategy="simple"` is the supported
# replacement for the deprecated `grouped_entities=True` flag and yields the
# same grouped output (dicts carrying "entity_group" and "word").
ner = pipeline("ner", model="cahya/NusaBert-ner-v1.3", aggregation_strategy="simple")

# Masked language model + tokenizer used by output_highest to rank candidates.
model_name = "indolem/indobert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

Device set to use cuda:0
Some weights of the model checkpoint at indolem/indobert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [15]:
def find_typos(text):
    """Collect suspected typos together with the sentences they occur in.

    The text is split into sentences after '.', '!' or '?'. Every purely
    alphabetic token (lower-cased) longer than two characters that is not in
    `english_vocab` is looked up in `sym_spell`; it is reported when the top
    suggestion is a different term or has a count below 10.

    Args:
        text: raw document text.

    Returns:
        defaultdict(list) mapping each suspected typo (lower-case) to the list
        of distinct, stripped sentences that contain it.
    """
    suspects = defaultdict(list)

    for raw_sentence in re.split(r'(?<=[.!?])\s+', text):
        context = raw_sentence.strip()

        for token in re.findall(r"\b[a-zA-Z]+\b", raw_sentence.lower()):
            # Very short tokens and English words are never treated as typos.
            if len(token) <= 2 or token in english_vocab:
                continue

            top = sym_spell.lookup(
                token, Verbosity.CLOSEST, max_edit_distance=2, include_unknown=True
            )[0]

            differs = top.term != token and top.distance > 0
            if not differs and top.count >= 10:
                # Exact, sufficiently frequent dictionary hit: not a typo.
                continue

            contexts = suspects[token]
            if context not in contexts:
                contexts.append(context)

    return suspects

In [16]:
def should_correct(entity_label, suggestion):
    """Decide whether a typo may be replaced by `suggestion`.

    Args:
        entity_label: NER label of the entity containing the typo, or None
            when the typo is not part of any recognised entity.
        suggestion: proposed replacement word.

    Returns:
        True only when the label is "ORG" or None and `suggestion` is a key
        of `sym_spell.words`; False for every other label.
    """
    # These labels are never rewritten.
    if entity_label in {"PERSON", "GPE", "LOC"}:
        return False

    correctable = entity_label == "ORG" or entity_label is None
    if not correctable:
        return False

    # Only accept replacements that are themselves dictionary words.
    return suggestion in sym_spell.words

In [17]:
# Character sequences as they appear in the OCR output, mapped to the
# sequences they may be swapped for when generating alternative spellings.
ocr_confusions = {
    'rn': ['m'],
    'm': ['rn'],
    'l': ['t', 'i'],
    't': ['l'],
    '0': ['o'],
    '1': ['l', 'i'],
    'o': ['0'],
    'n': ['ri', 'ni'],
    'vv': ['w'],
    'w': ['vv'],
    'e': ['c'],
}


def expand_ocr_variants(word):
    """Generate alternative spellings of `word` from the OCR confusion table.

    For every table key contained in `word`, one variant is produced per
    substitute, with all occurrences of the key replaced at once.

    Args:
        word: the token to expand.

    Returns:
        Set of variant strings (empty when no table key occurs in `word`).
    """
    return {
        word.replace(seen, intended)
        for seen, replacements in ocr_confusions.items()
        if seen in word
        for intended in replacements
    }

In [18]:
def suggest_words(word):
    """Build the set of replacement candidates for a suspected typo.

    The set contains one SymSpell suggestion (the first one whose count is
    above 10, otherwise simply the first suggestion returned) plus every OCR
    variant of `word` that is a key of `sym_spell.words`.

    Args:
        word: the suspected typo.

    Returns:
        Set of candidate strings; it always holds at least one element.
    """
    hits = sym_spell.lookup(word, Verbosity.CLOSEST, max_edit_distance=2, include_unknown=True)

    frequent_terms = [hit.term for hit in hits if hit.count > 10]
    primary = frequent_terms[0] if frequent_terms else hits[0].term

    # Keep only the OCR-confusion variants that are real dictionary words.
    known_variants = {
        variant for variant in expand_ocr_variants(word) if variant in sym_spell.words
    }

    return {primary} | known_variants

In [29]:
def output_highest(model, tokenizer, sentence, cands, word):
    """Return the candidate a masked language model scores highest in place of `word`.

    The sentence is lower-cased, exactly one occurrence of `word` is replaced
    by the literal ``[MASK]`` token, and the model's logits at that position
    are compared across `cands`.

    Only the first whole-word occurrence is masked. Plain substring
    replacement would also rewrite the inside of longer words (e.g. "ada"
    inside "kepada"), corrupting the context and possibly scoring the wrong
    position.

    Args:
        model: masked-LM callable; receives the tokenizer output as keyword
            arguments and returns an object exposing `.logits`.
        tokenizer: callable producing `input_ids`; must also provide
            `mask_token_id` and `convert_tokens_to_ids`.
        sentence: sentence containing the typo (any casing).
        cands: list of candidate replacement words.
        word: the typo as it appears in the lower-cased sentence.

    Returns:
        The element of `cands` with the highest logit at the masked position.

    Raises:
        ValueError: if `word` does not occur in the lower-cased sentence.
    """
    sentence = sentence.lower()
    window = 500  # characters of context kept on each side for long sentences

    # Prefer a whole-word hit; fall back to a raw substring hit so that callers
    # passing a fragment still get an answer.
    hit = re.search(rf"\b{re.escape(word)}\b", sentence)
    if hit is not None:
        word_start, word_end = hit.span()
    else:
        word_start = sentence.find(word)
        if word_start < 0:
            raise ValueError(f"word {word!r} does not occur in the sentence")
        word_end = word_start + len(word)

    # Crop very long sentences to a window around the typo.
    if len(sentence) >= window * 2:
        left = max(0, word_start - window)
        right = min(len(sentence), word_end + window)
    else:
        left, right = 0, len(sentence)

    # Mask exactly the located occurrence and nothing else.
    masked = sentence[left:word_start] + '[MASK]' + sentence[word_end:right]

    inputs = tokenizer(masked, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    mask_positions = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    mask_logits = logits[0, mask_positions, :]

    # NOTE(review): candidates missing from the tokenizer vocabulary presumably
    # all map to the unknown-token id and therefore share one score — confirm.
    candidate_ids = tokenizer.convert_tokens_to_ids(cands)
    candidate_scores = mask_logits[0, candidate_ids]

    return cands[candidate_scores.argmax().item()]

In [30]:
def context_correct(typos):
    """Choose a context-aware correction for every suspected typo.

    For each (word, sentence) pair the sentence is run through the NER
    pipeline, candidates come from `suggest_words`, the masked language model
    picks the best one via `output_highest`, and `should_correct` decides
    whether the replacement is allowed for the entity type found.

    Args:
        typos: mapping of suspected typo -> list of sentences containing it.

    Returns:
        List of dicts with keys "word" and "correction", one per accepted
        (word, sentence) pair.
    """
    accepted = []

    for word, sentences in tqdm(typos.items(), desc='Processing words'):
        needle = word.lower()

        for sentence in sentences:
            entities = ner(sentence)

            # Label of the first entity whose text contains the typo;
            # None when no recognised entity covers it.
            entity_label = next(
                (ent["entity_group"] for ent in entities if needle in ent["word"].lower()),
                None,
            )

            candidates = suggest_words(word)
            if not candidates:
                continue  # nothing to choose from

            best_word = output_highest(model, tokenizer, sentence, list(candidates), word)
            if should_correct(entity_label, best_word):
                accepted.append({"word": word, "correction": best_word})

    return accepted

In [31]:
# Half-open range [start, n) of OCR result indices processed by this run.
start, n = 479, 662

for i in tqdm(range(start, n), desc='Iterating data'):
    # Read the raw OCR output for this index, dropping undecodable bytes.
    with open(f"../../data/raw/ocr_result/ocr_{i}.txt", 'r', encoding="utf-8", errors="ignore") as src:
        text = src.read()

    typos = find_typos(text)
    res = context_correct(typos)

    # Apply each accepted correction to whole words, ignoring case.
    corrected_text = text
    for fix in res:
        whole_word = re.compile(rf"\b{re.escape(fix['word'])}\b", re.IGNORECASE)
        corrected_text = whole_word.sub(fix['correction'], corrected_text)

    with open(f"./symspell_res/res_{i}.txt", "w", encoding="utf-8", errors="ignore") as dst:
        dst.write(corrected_text)

Processing words: 100%|██████████| 131/131 [01:53<00:00,  1.16it/s]
Processing words: 100%|██████████| 4/4 [00:02<00:00,  1.68it/s]it]
Processing words: 100%|██████████| 10/10 [00:05<00:00,  1.96it/s] 
Processing words: 100%|██████████| 8/8 [00:02<00:00,  2.76it/s]t]
Processing words: 100%|██████████| 12/12 [00:04<00:00,  2.65it/s]
Processing words: 100%|██████████| 42/42 [00:05<00:00,  7.27it/s]
Processing words: 100%|██████████| 15/15 [00:01<00:00,  7.65it/s]
Processing words: 0it [00:00, ?it/s]83 [02:16<23:18,  7.94s/it]
Processing words: 100%|██████████| 22/22 [00:08<00:00,  2.73it/s]
Processing words: 100%|██████████| 18/18 [00:02<00:00,  6.46it/s]
Processing words: 100%|██████████| 25/25 [00:05<00:00,  4.82it/s]
Processing words: 100%|██████████| 18/18 [00:14<00:00,  1.28it/s]
Processing words: 100%|██████████| 30/30 [00:10<00:00,  2.88it/s]
Processing words: 100%|██████████| 32/32 [00:09<00:00,  3.42it/s]
Processing words: 100%|██████████| 45/45 [00:06<00:00,  6.73it/s]
Processi