<a href="https://colab.research.google.com/github/Shivansh1205/SarvMAI/blob/main/Auto_Spell_Checker.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import re, random, time, csv
from collections import defaultdict, Counter

# ---------- normalization & edit distance ----------
def normalize(word: str) -> str:
    """Reduce *word* to a canonical lowercase key used for fuzzy comparison.

    Steps, applied in this order:

    1. lowercase and trim surrounding whitespace;
    2. collapse any run of two or more identical vowels to a single vowel;
    3. fold ``w`` to ``v``, ``ph`` to ``f`` and ``bh`` to ``b``;
    4. drop every character outside ``a-z`` and ``0-9``.

    Because stripping happens last, vowels separated by punctuation
    (``"a-a"``) are not treated as a run and come out as ``"aa"``.
    """
    # Each entry collapses a run of one specific vowel into a single letter.
    vowel_runs = (
        (r'(a){2,}', 'a'),
        (r'(i){2,}', 'i'),
        (r'(e){2,}', 'e'),
        (r'(o){2,}', 'o'),
        (r'(u){2,}', 'u'),
    )
    # Spelling variants that are treated as the same sound.
    foldings = (('w', 'v'), ('ph', 'f'), ('bh', 'b'))

    text = word.lower().strip()
    for pattern, single in vowel_runs:
        text = re.sub(pattern, single, text)
    for spelled, folded in foldings:
        text = text.replace(spelled, folded)
    return re.sub(r'[^a-z0-9]', '', text)

def levenshtein(a: str, b: str) -> int:
    """Return the Levenshtein edit distance between *a* and *b*.

    The distance is the minimum number of single-character insertions,
    deletions and substitutions needed to turn one string into the other.
    Only two rows of the dynamic-programming table are kept at a time.
    """
    if a == b:
        return 0

    # Put the shorter string on the inner axis so each row stays small.
    longer, shorter = (a, b) if len(a) >= len(b) else (b, a)

    previous_row = list(range(len(shorter) + 1))
    for row_no, ch_long in enumerate(longer, start=1):
        current_row = [row_no]
        for col_no, ch_short in enumerate(shorter, start=1):
            cost_insert = current_row[col_no - 1] + 1
            cost_delete = previous_row[col_no] + 1
            # A bool adds as 0/1: matching characters cost nothing.
            cost_replace = previous_row[col_no - 1] + (ch_long != ch_short)
            current_row.append(min(cost_insert, cost_delete, cost_replace))
        previous_row = current_row
    return previous_row[-1]

# ---------- BK-tree ----------
class BKNode:
    """One vertex of a BK-tree: a stored word and its child links."""

    def __init__(self, word):
        # Child nodes keyed by an integer distance; starts out empty.
        self.children = {}
        # The word stored at this vertex.
        self.word = word

class BKTree:
    """Burkhard-Keller tree for nearest-word lookup under a distance function.

    Each node keeps its children keyed by their distance to the node's word.
    ``query`` uses those keys to skip subtrees that cannot contain a match;
    that pruning is only exact when ``dist_fn`` obeys the triangle inequality
    (true for Levenshtein distance).
    """

    def __init__(self, dist_fn=levenshtein):
        """Create an empty tree that compares words with *dist_fn*."""
        self.root = None
        self.dist = dist_fn

    def add(self, word):
        """Insert *word* into the tree; a word already stored is ignored.

        Without the duplicate check a repeated word would be filed again
        under distance 0 and then show up twice in every ``query`` result.
        """
        if self.root is None:
            self.root = BKNode(word)
            return
        node = self.root
        while True:
            if word == node.word:
                # Already present: nothing to add.
                return
            d = self.dist(word, node.word)
            child = node.children.get(d)
            if child is None:
                node.children[d] = BKNode(word)
                return
            # A word at this distance already exists; descend into its subtree.
            node = child

    def query(self, word, max_distance):
        """Return ``(stored_word, distance)`` pairs within *max_distance* of *word*.

        The result is a list in traversal order (not sorted by distance) and
        is empty when the tree is empty or nothing is close enough.
        """
        if self.root is None:
            return []
        res = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            d = self.dist(word, node.word)
            if d <= max_distance:
                res.append((node.word, d))
            # Only children whose key lies in [d - r, d + r] can hold a match.
            low, high = d - max_distance, d + max_distance
            for cd, child in node.children.items():
                if low <= cd <= high:
                    stack.append(child)
        return res

# ---------- k-gram fallback ----------
def build_kgram_index(words, k=3):
    """Build an inverted index from character k-grams to the words containing them.

    Each word is normalized first; the index keys are k-grams of the
    normalized form, while the stored values are the original words. A word
    whose normalized form is shorter than *k* is indexed under that whole
    normalized form.

    Returns a ``defaultdict(set)`` mapping gram -> set of original words.
    """
    index = defaultdict(set)
    for word in words:
        key = normalize(word)
        if len(key) < k:
            # Too short to slice: the whole normalized form is its only gram.
            index[key].add(word)
            continue
        for start in range(len(key) - k + 1):
            index[key[start:start + k]].add(word)
    return index

def top_k_by_kgram_overlap(query, kgram_index, k=3, top_n=20):
    """Return up to *top_n* indexed words sharing the most k-grams with *query*.

    *query* is normalized and split into its distinct k-grams (or used whole
    when shorter than *k*). Every word listed in *kgram_index* under one of
    those grams scores one point per shared gram; words are returned in
    descending score order, without the scores.
    """
    norm_q = normalize(query)
    if len(norm_q) >= k:
        grams = {norm_q[i:i + k] for i in range(len(norm_q) - k + 1)}
    else:
        grams = {norm_q}

    # Flatten all posting lists; a word appears once per gram it shares.
    overlap = Counter(
        word
        for gram in grams
        for word in kgram_index.get(gram, [])
    )
    return [word for word, _score in overlap.most_common(top_n)]

# ---------- build structures ----------
# Load the reference vocabulary: one entry per non-blank line, whitespace trimmed.
with open("reference.txt", encoding="utf-8") as f:
    stripped_lines = (line.strip() for line in f)
    reference = [entry for entry in stripped_lines if entry]

# norm_map: normalized -> canonical
# When several entries share a normalized form, the last one in the file wins.
norm_map = {}
for w in reference:
    norm_map[normalize(w)] = w

# Fill the BK-tree from the dict keys rather than from a set: dict iteration
# follows insertion (file) order, whereas the iteration order of a set of
# strings can change from one interpreter run to the next because str hashes
# are randomized. That keeps the tree shape, and therefore tie-breaking between
# equally distant candidates, reproducible. The keys are already unique and
# already normalized, so nothing is normalized twice.
bk = BKTree()
for norm_w in norm_map:
    bk.add(norm_w)
kgram_index = build_kgram_index(reference, k=3)

def match_best(input_word, bk_tree, kgram_index, norm_map, bk_radius=2):
    """Return ``(best_word, distance)`` for the reference entry closest to *input_word*.

    Candidates are gathered in three stages, each used only if the previous
    one produced nothing:

    1. BK-tree neighbours within *bk_radius* edits of the normalized input;
    2. the top 20 words by 3-gram overlap, scored by edit distance;
    3. a linear scan over every normalized key in *norm_map*.

    The winner is the candidate with the smallest distance (the earliest one
    on ties), mapped back to its canonical spelling through *norm_map*.
    Returns ``(None, None)`` when there is nothing at all to compare against.
    """
    norm_in = normalize(input_word)

    def distance_of(candidate):
        return candidate[1]

    # Stage 1: BK-tree search
    candidates = list(bk_tree.query(norm_in, max_distance=bk_radius))

    # Stage 2: fallback via k-gram overlap
    if not candidates:
        for ref_word in top_k_by_kgram_overlap(input_word, kgram_index, k=3, top_n=20):
            ref_norm = normalize(ref_word)
            candidates.append((ref_norm, levenshtein(norm_in, ref_norm)))

    # Stage 3: brute force
    if not candidates:
        candidates = [(ref_norm, levenshtein(norm_in, ref_norm)) for ref_norm in norm_map]
        if not candidates:
            return None, None

    # min() keeps the first of several equally distant candidates.
    best_norm, best_dist = min(candidates, key=distance_of)
    return norm_map.get(best_norm, best_norm), best_dist

# ---------- synthetic test set ----------
# Fixed seed so the synthetic misspellings draw the same random sequence each run.
random.seed(42)


def stretch_vowels(word):
    """Return *word* with each vowel (either case) randomly kept or doubled."""
    pieces = []
    for ch in word:
        if ch.lower() in "aeiou":
            # One random draw per vowel: repeat it once or twice.
            ch = ch * random.choice([1, 2])
        pieces.append(ch)
    return "".join(pieces)
def typo_delete(word):
    """Return *word* with one randomly chosen character removed.

    Words of length 0 or 1 are returned unchanged.
    """
    if len(word) > 1:
        drop_at = random.randrange(len(word))
        return word[:drop_at] + word[drop_at + 1:]
    return word
def typo_swap(word):
    """Return *word* with one randomly chosen pair of adjacent characters swapped.

    Words shorter than two characters are returned unchanged.
    """
    if len(word) < 2:
        return word
    pos = random.randrange(len(word) - 1)
    # Rebuild the string with positions pos and pos + 1 exchanged.
    return word[:pos] + word[pos + 1] + word[pos] + word[pos + 2:]
def case_variant(word):
    """Return *word* lowercased if it is all uppercase, otherwise uppercased."""
    if word.isupper():
        return word.lower()
    return word.upper()

# Build (misspelling, gold) pairs from up to 30 randomly chosen reference words.
test_pairs = []
# random.sample raises ValueError when asked for more items than exist,
# so clamp the sample size for small reference lists.
sample_size = min(30, len(reference))
for w in random.sample(reference, sample_size):
    # dict.fromkeys drops duplicate variants like a set would, but keeps them
    # in generation order, so the pair order is the same on every run for a
    # given seed (set order over strings varies with hash randomization).
    variants = dict.fromkeys([
        stretch_vowels(w),
        typo_swap(w),
        typo_delete(w),
        case_variant(w),
        typo_swap(stretch_vowels(w)),
    ])
    for v in variants:
        test_pairs.append((v, w))

# ---------- evaluation ----------
# Score match_best against the synthetic pairs: exact-match accuracy, top-3
# recall, mean length-normalized edit distance to the gold word, and timing.
total = len(test_pairs)
exact = 0
top3 = 0
edit_sum = 0.0
mismatches = []

start = time.perf_counter()
for err, gold in test_pairs:
    best, dist = match_best(err, bk, kgram_index, norm_map, bk_radius=2)

    # Rebuild the ranked candidate list the same way match_best does, so the
    # three best candidates can be checked for recall.
    # NOTE: this extra work sits inside the timed region, so the reported
    # per-word time covers the lookup plus this top-3 computation.
    norm_err = normalize(err)  # normalized once per test case
    candidates = list(bk.query(norm_err, max_distance=2))
    if not candidates:
        for w in top_k_by_kgram_overlap(err, kgram_index, k=3, top_n=20):
            norm_w = normalize(w)
            candidates.append((norm_w, levenshtein(norm_err, norm_w)))
    if not candidates:
        candidates = [(nr, levenshtein(norm_err, nr)) for nr in norm_map]
    candidates.sort(key=lambda x: x[1])
    top3_list = [norm_map.get(n, n) for n, _ in candidates[:3]]

    if best == gold:
        exact += 1
    else:
        mismatches.append((err, gold, top3_list))
    if gold in top3_list:
        top3 += 1
    # Case-insensitive distance, scaled by the longer word (at least 1).
    edit_sum += levenshtein(best.lower(), gold.lower()) / max(len(best), len(gold), 1)
end = time.perf_counter()

# Guard every ratio so an empty test set reports zeros instead of raising
# ZeroDivisionError.
exact_acc = exact / total if total else 0.0
top3_recall = top3 / total if total else 0.0
avg_edit = edit_sum / total if total else 0.0
time_taken = end - start
per_word = time_taken / total if total else 0

print("=== Evaluation ===")
print(f"Total cases: {total}")
print(f"Exact-match accuracy: {exact}/{total} = {exact_acc:.2%}")
print(f"Top-3 recall: {top3}/{total} = {top3_recall:.2%}")
print(f"Avg normalized edit distance: {avg_edit:.3f}")
print(f"Total time: {time_taken:.4f}s, per-word: {per_word*1000:.2f}ms")
print("Sample mismatches:", mismatches[:5])

# save proof
# Write the (error, gold) pairs exactly as they were evaluated.
with open("gold_bktest.tsv", "w", encoding="utf-8", newline="") as gold_file:
    gold_writer = csv.writer(gold_file, delimiter="\t")
    gold_writer.writerow(["error", "gold"])
    gold_writer.writerows([err, gold] for err, gold in test_pairs)

# Write the matcher's prediction for each error, recomputed with the same settings.
with open("predicted_bktest.tsv", "w", encoding="utf-8", newline="") as predicted_file:
    predicted_writer = csv.writer(predicted_file, delimiter="\t")
    predicted_writer.writerow(["error", "predicted"])
    predicted_writer.writerows(
        [err, match_best(err, bk, kgram_index, norm_map, bk_radius=2)[0]]
        for err, _gold in test_pairs
    )
print("Saved gold and predicted for proof.")


=== Evaluation ===
Total cases: 147
Exact-match accuracy: 136/147 = 92.52%
Top-3 recall: 147/147 = 100.00%
Avg normalized edit distance: 0.032
Total time: 0.1160s, per-word: 0.79ms
Sample mismatches: [('Shail', 'Sahil', ['Sunil', 'Sahil']), ('aKviitaa', 'Kavita', ['Savita', 'Kavita']), ('Rkeha', 'Rekha', ['Sneha', 'Rekha', 'Neha']), ('Kaal', 'Kajal', ['Ram', 'Kajal', 'Aam']), ('rPitii', 'Priti', ['Ritu', 'Priti'])]
Saved gold and predicted for proof.
