In [2]:
!pip install pyspellchecker
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, pipeline
from spellchecker import SpellChecker
import re
from difflib import SequenceMatcher

Collecting pyspellchecker
  Downloading pyspellchecker-0.8.4-py3-none-any.whl.metadata (9.4 kB)
Downloading pyspellchecker-0.8.4-py3-none-any.whl (7.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m43.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pyspellchecker
Successfully installed pyspellchecker-0.8.4


In [3]:
def normalize_elongated(word):
    """Shorten every run of three or more identical characters to exactly two.

    Example: ``"sooooo"`` becomes ``"soo"``; runs of length one or two are
    left untouched.

    Args:
        word: the string to normalize.

    Returns:
        The normalized string.
    """
    keep_two = lambda run: run.group(1) * 2
    return re.sub(r"(.)\1{2,}", keep_two, word)

class AdvancedSpellChecker:
    """Spell checker combining a dictionary lookup (pyspellchecker) with BERT
    masked-language-model context suggestions and an optional text2text
    grammar-correction pipeline.

    ``spell_check_advanced`` does the following:
      1. Flags tokens whose cleaned, lower-cased form is not in the dictionary.
      2. For each flagged token, asks BERT what fits at that position and
         prefers a suggestion that is also spelled similarly to the typo.
      3. Optionally re-checks dictionary words for which BERT expects a similar
         but different word at that position (a possible real-word error).
    """

    def __init__(self):
        self.spell = SpellChecker()
        # Lower-cased words the user has declared correct; never flagged.
        self.custom_words = set()

        print("Loading language model...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device set to use {self.device}")
        self.model_name = "bert-base-uncased"
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = AutoModelForMaskedLM.from_pretrained(self.model_name).to(self.device)
        self.model.eval()

        try:
            self.grammar_checker = pipeline(
                "text2text-generation",
                model="pszemraj/flan-t5-large-grammar-synthesis",
                device=0 if self.device == "cuda" else -1
            )
        except Exception as e:
            # Grammar checking is optional, so degrade gracefully. Catch Exception
            # (not a bare except) so Ctrl-C / SystemExit still propagate, and
            # surface the reason instead of hiding it.
            self.grammar_checker = None
            print(f"Grammar checker not available ({e})")

    # ---- small shared helpers ----
    @staticmethod
    def _clean_token(word):
        """Remove every non-word character (punctuation, apostrophes, ...) from ``word``."""
        return re.sub(r"[^\w]", "", word)

    @staticmethod
    def _split_punct(token):
        """Split ``token`` into (leading non-word chars, core, trailing non-word chars)."""
        prefix, core, suffix = re.match(r"^(\W*)(.*?)(\W*)$", token, re.S).groups()
        return prefix, core, suffix

    def add_to_custom_dict(self, word):
        """Mark ``word`` (case-insensitively) as correct so it is never flagged."""
        self.custom_words.add(word.lower())
        print(f"Added '{word}' to custom dictionary.")

    def similarity(self, a, b):
        """Case-insensitive ``difflib.SequenceMatcher`` ratio in [0, 1]."""
        return SequenceMatcher(None, a.lower(), b.lower()).ratio()

    # ---- BERT subword-safe contextual suggestions ----
    def get_contextual_suggestions(self, sentence, word_idx, original_word):
        """Return up to 10 whole-word BERT predictions for position ``word_idx``.

        The whitespace token at ``word_idx`` is replaced by the mask token while
        its leading/trailing punctuation is kept, so the model still sees, for
        example, a comma that followed the word.

        Args:
            sentence: text whose whitespace-split tokens are indexed by ``word_idx``.
            word_idx: index into ``sentence.split()`` of the token to predict.
            original_word: unused; kept for interface compatibility.

        Returns:
            A de-duplicated list, in model-ranked order, of at most 10
            alphabetic suggestions, or ``[]`` if the mask could not be scored.

        Raises:
            IndexError: if ``word_idx`` is out of range for ``sentence.split()``.
        """
        words = sentence.split()
        masked_words = words.copy()
        prefix, _, suffix = self._split_punct(words[word_idx])
        masked_words[word_idx] = f"{prefix}{self.tokenizer.mask_token}{suffix}"
        masked_sentence = " ".join(masked_words)

        try:
            # truncation stops over-long inputs from exceeding the model's
            # position limit. If the mask falls past the cut, no mask is found
            # below and we return [].
            inputs = self.tokenizer(masked_sentence, return_tensors="pt", truncation=True).to(self.device)
            with torch.no_grad():
                outputs = self.model(**inputs)

            mask_positions = (inputs.input_ids == self.tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
            if len(mask_positions) == 0:
                return []

            # Score only the first mask position.
            mask_logits = outputs.logits[0, mask_positions[0], :]
            top_tokens = torch.topk(mask_logits, 20).indices.tolist()

            suggestions = []
            for tok in top_tokens:
                word_piece = self.tokenizer.decode([tok]).strip()
                # Skip special tokens ("[...]"), "##" continuation pieces and
                # anything non-alphabetic (punctuation, digits).
                if (not word_piece or
                        word_piece.startswith("[") or
                        word_piece.startswith("##") or
                        not word_piece.isalpha()):
                    continue
                suggestions.append(word_piece)

            return list(dict.fromkeys(suggestions))[:10]
        except Exception:
            # Best-effort: any tokenizer/model failure just yields no suggestions.
            return []

    def check_word_in_context(self, sentence, word_idx):
        """Check whether BERT expects a different but similar word at ``word_idx``.

        Returns:
            ``(replacement, suggestions)``. ``replacement`` is the most similar
            of BERT's top-5 suggestions (similarity > 0.5) when the current word
            is not itself among them; otherwise it is ``None``. ``suggestions``
            is the full suggestion list, or ``[]`` if BERT produced none.
        """
        words = sentence.split()
        # Compare using the same cleaning as spell_check_basic. Otherwise
        # trailing punctuation ("the.") makes a correct word look unexpected,
        # and it gets "corrected" to itself.
        current_word = self._clean_token(words[word_idx]).lower()
        contextual_sugs = self.get_contextual_suggestions(sentence, word_idx, current_word)

        if not contextual_sugs:
            return None, []

        top_five = contextual_sugs[:5]
        if current_word in top_five:
            return None, contextual_sugs

        best_match, best_similarity = None, 0
        for sug in top_five:
            sim = self.similarity(current_word, sug)
            if sim > best_similarity and sim > 0.5:
                best_similarity, best_match = sim, sug
        return best_match, contextual_sugs

    def spell_check_basic(self, text):
        """Flag whitespace tokens whose cleaned lower-case form is not in the dictionary.

        Tokens that are empty after cleaning, contain no letters, or are in the
        custom dictionary are skipped.

        Returns:
            A list of issue dicts with keys ``index``, ``original``,
            ``clean_original``, ``basic_correction`` and ``basic_suggestions``.
        """
        issues = []
        for idx, word in enumerate(text.split()):
            clean_word = self._clean_token(word)
            lower = clean_word.lower()
            if not lower or lower in self.custom_words or not any(c.isalpha() for c in lower):
                continue
            if lower in self.spell:
                continue
            correct = self.spell.correction(lower)
            # candidates() may return a falsy value; call it once, not twice.
            candidates = self.spell.candidates(lower)
            issues.append({
                "index": idx,
                "original": word,
                "clean_original": clean_word,
                "basic_correction": correct if correct else clean_word,
                "basic_suggestions": list(candidates) if candidates else []
            })
        return issues

    def spell_check_advanced(self, text, use_context=True, check_valid_words=True):
        """Dictionary plus context spell check; returns a ``format_results`` dict.

        Elongated character runs are normalized first, and tokens are re-joined
        with single spaces. The normalized text is what appears as ``original``
        in the result.

        Args:
            text: input text.
            use_context: if False, return dictionary-only results.
            check_valid_words: also look for real-word errors among dictionary words.
        """
        words = [normalize_elongated(w) for w in text.split()]
        text = " ".join(words)

        issues = self.spell_check_basic(text)
        if not use_context:
            return self.format_results(text, issues)

        for issue in issues:
            contextual_sugs = self.get_contextual_suggestions(text, issue['index'], issue['clean_original'])
            issue['contextual_suggestions'] = contextual_sugs

            original_lower = issue['clean_original'].lower()
            best_contextual, best_score = None, 0
            for sug in contextual_sugs[:5]:
                score = self.similarity(original_lower, sug)
                if score > best_score:
                    best_score, best_contextual = score, sug

            if best_contextual and best_score > 0.4:
                issue['recommended'] = best_contextual
                issue['confidence'] = 'high' if best_score > 0.6 else 'medium'
            else:
                issue['recommended'] = issue['basic_correction']
                issue['confidence'] = 'low'

        if check_valid_words:
            # Set lookup instead of scanning `issues` for every word.
            flagged = {issue['index'] for issue in issues}
            for idx, word in enumerate(words):
                clean_word = self._clean_token(word)
                lower = clean_word.lower()
                if not lower or lower in self.custom_words or not any(c.isalpha() for c in lower):
                    continue
                if idx in flagged or lower not in self.spell:
                    continue
                contextual_replacement, contextual_sugs = self.check_word_in_context(text, idx)
                if contextual_replacement:
                    issues.append({
                        'index': idx,
                        'original': word,
                        'clean_original': clean_word,
                        'basic_correction': clean_word,
                        'basic_suggestions': [clean_word],
                        'contextual_suggestions': contextual_sugs,
                        'recommended': contextual_replacement,
                        'confidence': 'medium',
                        'type': 'context_error'
                    })

        issues.sort(key=lambda x: x['index'])
        return self.format_results(text, issues)

    def check_grammar(self, text):
        """Run the grammar pipeline on ``text``.

        Returns:
            The generated text, or a message string if the pipeline is
            unavailable or fails.
        """
        if not self.grammar_checker:
            return "Grammar checking not available"
        try:
            result = self.grammar_checker(text, max_length=512, num_return_sequences=1)
            return result[0]["generated_text"]
        except Exception as e:
            return f"Grammar check failed: {str(e)}"

    # ---- Output post-processing (capitalization + punctuation preserved) ----
    def apply_capitalization(self, original, replacement):
        """Copy the ALL-CAPS or Title-case pattern of ``original`` onto ``replacement``."""
        if original.isupper():
            return replacement.upper()
        if original.istitle():
            return replacement.capitalize()
        return replacement

    def format_results(self, text, issues):
        """Apply each issue's recommendation to ``text`` and package the results.

        Surrounding punctuation and capitalization of each replaced token are
        preserved. Every issue dict is given a ``recommended`` key (defaulting
        to ``basic_correction``) so all ``details`` entries share the same keys.

        Returns:
            A dict with keys ``original``, ``corrected``, ``corrections`` (a list
            of ``(original token, recommendation)`` pairs) and ``details`` (the issues).
        """
        words = text.split()
        corrected_words = words.copy()

        for issue in issues:
            idx = issue['index']
            replacement = issue.setdefault('recommended', issue['basic_correction'])
            prefix, core, suffix = self._split_punct(words[idx])
            replacement = self.apply_capitalization(core, replacement)
            corrected_words[idx] = f"{prefix}{replacement}{suffix}"

        return {
            "original": text,
            "corrected": " ".join(corrected_words),
            "corrections": [(i['original'], i['recommended']) for i in issues],
            "details": issues
        }


In [4]:
# Mini demo input UI for Colab
from IPython.display import display
import ipywidgets as widgets

# The checker loads BERT and a large grammar pipeline, which is slow. Build it
# once and reuse it across button clicks instead of reloading per click.
_checker = None


def get_checker():
    """Return the shared AdvancedSpellChecker, creating it on first use."""
    global _checker
    if _checker is None:
        _checker = AdvancedSpellChecker()
    return _checker


text_input = widgets.Textarea(
    value="",
    placeholder="Enter text to spell-check...",
    description="Input:",
    layout=widgets.Layout(width="100%", height="120px")
)

output = widgets.Output()

run_button = widgets.Button(
    description="Run Spell Check",
    button_style="success",
    icon="check"
)


def run_demo(b):
    """Button callback: spell-check the textarea contents and print the results."""
    output.clear_output()
    with output:
        text = text_input.value.strip()
        if not text:
            print("Please enter some text.")
            return

        print("...Running spell + context + grammar models...")

        checker = get_checker()
        result = checker.spell_check_advanced(text)

        print("\nORIGINAL:")
        print(text)
        print("\nCORRECTED:")
        print(result['corrected'])

        if result["details"]:
            print("\nCORRECTION DETAILS:")
            for detail in result["details"]:
                # 'recommended' is absent on dictionary-only results; fall back.
                recommended = detail.get('recommended', detail['basic_correction'])
                print(f" • '{detail['original']}' → '{recommended}'  [{detail.get('confidence','?')}]")


run_button.on_click(run_demo)
display(text_input, run_button, output)

Textarea(value='', description='Input:', layout=Layout(height='120px', width='100%'), placeholder='Enter text …

Button(button_style='success', description='Run Spell Check', icon='check', style=ButtonStyle())

Output()