# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement a context-sensitive spelling corrector using an N-gram language model. The idea is to compute the conditional probabilities of the possible correction options. For example, the phrase "dking sport" should be fixed as "doing sport", not "dying sport", while "dking species" should be fixed as "dying species".
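
A minimal sketch of this idea, using toy counts that are invented purely for illustration (they do not come from any corpus), and a hypothetical helper `pick_correction`:

```python
from collections import Counter

# Toy counts, invented for illustration only.
bigram_counts = Counter({
    ("doing", "sport"): 40,
    ("dying", "sport"): 1,
    ("doing", "species"): 1,
    ("dying", "species"): 30,
})
unigram_counts = Counter({"doing": 500, "dying": 200})


def pick_correction(candidates, next_word):
    """Return the candidate with the highest estimate of P(next_word | candidate)."""
    def cond_prob(cand):
        return bigram_counts[(cand, next_word)] / unigram_counts[cand]
    return max(candidates, key=cond_prob)


print(pick_correction(["doing", "dying"], "sport"))
print(pick_correction(["doing", "dying"], "species"))
```

With these counts the same misspelling resolves differently depending on the following word, which is the behavior the task asks for.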

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).

When solving this task, we expect you'll face (and successfully deal with) some problems, or come up with ideas for improving the model. Some of them are:

- taking into account keyboard layout and associated misspellings;
- efficiency improvement to make the solution faster;
- ...

Please don't forget to describe such cases, and what you decided to do with them, in the Justification section.

##### IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

In [2]:
import re
import math
import pandas as pd
from tqdm import tqdm
from collections import Counter
from nltk.tokenize import word_tokenize
import random
import time
from typing import List


In [3]:
sentences = []

with open("/kaggle/input/nlp-assignment-1-training-dataset/eng_news_2024_300K-sentences.txt", "r", encoding="utf-8") as f:
    for line in tqdm(f):
        parts = line.strip().split("\t")
        if len(parts) == 2:
            _, sentence = parts

            # Remove citations and references like [1], (2024), etc.
            sentence = re.sub(r"\[\d+\]", "", sentence)
            sentence = re.sub(r"\(\d{4}\)", "", sentence)
        
            # Remove special characters, numbers, and punctuation
            sentence = re.sub(r"[^a-zA-Z\s]", "", sentence)
        
            # Lowercase and remove extra spaces
            sentence = sentence.lower().strip()
            sentence = re.sub(r"\s+", " ", sentence)

            tokens = word_tokenize(sentence)
            sentences.append(tokens)

with open("/kaggle/input/nlp-assignment-1-training-dataset/eng-simple_wikipedia_2021_300K-sentences.txt", "r", encoding="utf-8") as f:
    for line in tqdm(f):
        parts = line.strip().split("\t")
        if len(parts) == 2:
            _, sentence = parts

            # Remove citations and references like [1], (2024), etc.
            sentence = re.sub(r"\[\d+\]", "", sentence)
            sentence = re.sub(r"\(\d{4}\)", "", sentence)
        
            # Remove special characters, numbers, and punctuation
            sentence = re.sub(r"[^a-zA-Z\s]", "", sentence)
        
            # Lowercase and remove extra spaces
            sentence = sentence.lower().strip()
            sentence = re.sub(r"\s+", " ", sentence)

            tokens = word_tokenize(sentence)
            sentences.append(tokens)



300000it [00:50, 5957.30it/s]
300000it [00:42, 7104.21it/s]


In [4]:
large_dataset = True
if large_dataset:
    text = ". ".join([" ".join(tokens) for tokens in sentences])
else:
    training_file = "/kaggle/input/bigtxt/big.txt"
    with open(training_file, encoding='utf-8') as f:
        text = f.read()


In [5]:
class LangModel:
    def __init__(self, k=0.01):
        # Smoothing parameter.
        self.k = k
        self.word_to_id = {}
        self.id_to_word = []
        self.unigrams = Counter()
        self.bigrams = Counter()
        self.trigrams = Counter()
        self.total_words = 0
        self.vocab_size = 0
        self.unknown_word_id = -1
        
    def tokenize_text(self, text):
        """Tokenizes the input text into sentences and words."""
        sentences = [s.strip() for s in re.split(r'[.!?]+', text) if s.strip()]
        return [re.findall(r'\b\w+\b', sentence) for sentence in sentences]

    def get_word_id(self, word):
        """Returns the numeric ID for a given word, creating a new ID if unseen."""
        if word not in self.word_to_id:
            word_id = len(self.word_to_id)
            self.word_to_id[word] = word_id
            self.id_to_word.append(word)
        return self.word_to_id[word]

    def train(self, text):
        """Train the language model on raw text (a string, not a file path)."""
        sentences = self.tokenize_text(text.lower())
        for sentence in sentences:
            word_ids = [self.get_word_id(word) for word in sentence]
            self.total_words += len(word_ids)
            
            for i, wid in enumerate(word_ids):
                self.unigrams[wid] += 1
                if i < len(word_ids) - 1:
                    self.bigrams[(word_ids[i], word_ids[i+1])] += 1
                if i < len(word_ids) - 2:
                    self.trigrams[(word_ids[i], word_ids[i+1], word_ids[i+2])] += 1
        
        self.vocab_size = len(self.word_to_id)
        print(f"Total sentences: {len(sentences)}")
        print(f"Vocabulary size: {self.vocab_size}")
        print(f"Total words: {self.total_words}")
        return True

    def get_smoothed_probability(self, numerator, denominator):
        """Applies add-k smoothing (Laplace smoothing when k = 1) and returns the probability."""
        return (numerator + self.k) / (denominator + self.k * self.vocab_size)
    
    def get_gram1_prob(self, word_id):
        """Returns the unigram probability with smoothing. If word_id is unknown, we assume it has a count of 0."""
        count = self.unigrams.get(word_id, 0)
        return self.get_smoothed_probability(count, self.total_words)
    
    def get_gram2_prob(self, word1_id, word2_id):
        """Returns the bigram probability with smoothing."""
        if word1_id == self.unknown_word_id or word2_id == self.unknown_word_id:
            return self.k / (self.total_words + self.k)
    
        unigram_count = self.unigrams.get(word1_id, 0)
        bigram_count = self.bigrams.get((word1_id, word2_id), 0)
    
        # Ensure bigram count does not exceed unigram count
        bigram_count = min(bigram_count, unigram_count)
    
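        # Note: total_words is added to the denominator, so this is a heavily
        # dampened estimate rather than textbook add-k smoothing.
        return self.get_smoothed_probability(bigram_count, unigram_count + self.total_words)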
    
    def get_gram3_prob(self, word1_id, word2_id, word3_id):
        """Returns the trigram probability with smoothing."""
        if any(word_id == self.unknown_word_id for word_id in (word1_id, word2_id, word3_id)):
            return self.k / (self.total_words + self.k)
    
        bigram_count = self.bigrams.get((word1_id, word2_id), 0)
        trigram_count = self.trigrams.get((word1_id, word2_id, word3_id), 0)
    
        # Ensure trigram count does not exceed bigram count
        trigram_count = min(trigram_count, bigram_count)
    
        # Note: total_words is added to the denominator here as well (see get_gram2_prob).
        return self.get_smoothed_probability(trigram_count, bigram_count + self.total_words)

    def score(self, words):
        """Computes the log probability score for a sequence of words."""
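        # An empty sequence has no score; check before padding, otherwise
        # the check can never trigger.
        if not words:
            return float('-inf')

        # Convert words to IDs (use unknown_word_id if the word is not known)
        sentence = [self.word_to_id.get(w, self.unknown_word_id) for w in words]
        # Pad with two unknown word IDs so every word starts a trigram window.
        sentence.extend([self.unknown_word_id] * 2)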
        
        result = 0.0
        # For each trigram window, add log probabilities.
        for i in range(len(sentence) - 2):
            prob1 = self.get_gram1_prob(sentence[i])
            prob2 = self.get_gram2_prob(sentence[i], sentence[i + 1])
            prob3 = self.get_gram3_prob(sentence[i], sentence[i + 1], sentence[i + 2])
            # Avoid taking log(0) by using a very small number.
            result += math.log(prob1 if prob1 > 0 else 1e-10)
            result += math.log(prob2 if prob2 > 0 else 1e-10)
            result += math.log(prob3 if prob3 > 0 else 1e-10)
        return result


In [6]:
def get_deletes1(word):
    """Return all strings formed by deleting one character from 'word'."""
    return [word[:i] + word[i+1:] for i in range(len(word)) if (word[:i] + word[i+1:])]

def get_deletes2(word):
    """Return a list of lists; each inner list contains candidates from a two‐level deletion."""
    results = []
    for i in range(len(word)):
        nw = word[:i] + word[i+1:]
        if nw:
            group = get_deletes1(nw)
            group.append(nw)
            results.append(group)
    return results


class SpellCorrector:
    def __init__(self, lang_model):
        self.lang_model = lang_model  

        # Penalty parameters for candidate scoring
        self.known_words_penalty = 20.0
        self.unknown_words_penalty = 5.0
        self.max_candidates_to_check = 14


    def word_is_known(self, word):
        """Return True if the word is known by the language model."""
        return word in self.lang_model.word_to_id

    def edits(self, word):
        """Generate candidate corrections using two-level deletions."""
        result = set()
        dels = get_deletes2(word)
        for group in dels:
            for cand in group:
                if self.word_is_known(cand):
                    result.add(cand)
        return list(result)

    def edits2(self, word):
        """Generate candidate corrections using one-level edits: deletion, transposition, replacement, and insertion."""
        alphabet = "abcdefghijklmnopqrstuvwxyz"
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        
        deletes = {L + R[1:] for L, R in splits if R}
        transposes = {L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1}
        replaces = {L + c + R[1:] for L, R in splits if R for c in alphabet}
        inserts = {L + c + R for L, R in splits for c in alphabet}
        
        return list(deletes | transposes | replaces | inserts)

    def filter_candidates_by_frequency(self, unique_candidates, orig_word):
        """If too many candidate corrections exist, only keep the ones with the highest frequency."""
        if len(unique_candidates) <= self.max_candidates_to_check:
            return
        
        candidate_counts = [(self.lang_model.unigrams.get(self.lang_model.word_to_id.get(cand), 0), cand)
                            for cand in unique_candidates]
        
        candidate_counts.sort(key=lambda x: x[0], reverse=True)
        
        unique_candidates.clear()
        for i in range(min(self.max_candidates_to_check, len(candidate_counts))):
            unique_candidates.add(candidate_counts[i][1])
        unique_candidates.add(orig_word)

    def get_candidates(self, sentence, position):
        """Return a list of candidate words for the target word in the sentence."""
        if position >= len(sentence):
            return []

        orig_word = sentence[position]
        
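        # Try the one-level edits (edits2) first.
        # Note: edits2 returns unfiltered strings and is never empty, so the
        # two-level deletion fallback below is not reached as written.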
        candidates = self.edits2(orig_word)
        first_level = True
        if not candidates:
            candidates = self.edits(orig_word)
            first_level = False

        if not candidates:
            return []

        # Always keep the original word as a candidate, and remember whether it is known.
        known_word = self.word_is_known(orig_word)
        candidates.append(orig_word)

        # Remove duplicates and filter by frequency.
        unique_candidates = set(candidates)
        self.filter_candidates_by_frequency(unique_candidates, orig_word)

        scored_candidates = []
        for cand in unique_candidates:
            context = self._build_context(sentence, position, cand)
            score = self.lang_model.score(context)
            score = self._apply_scoring_penalties(score, orig_word, known_word, first_level, cand)
            scored_candidates.append((cand, score))

        # Sort candidates by descending score.
        scored_candidates.sort(key=lambda x: x[1], reverse=True)
        return [cand for cand, _ in scored_candidates]

    def _build_context(self, sentence, position, candidate):
        """Build a small context window around the candidate word."""
        context = []
        for i, w in enumerate(sentence):
            if i == position:
                context.append(candidate)
            elif (i < position and i + 2 >= position) or (i > position and i <= position + 2):
                context.append(w)
        return context

    def _apply_scoring_penalties(self, score, orig_word, known_word, first_level, candidate):
        """Apply penalties based on the known/unknown word status and correction level."""
        if candidate != orig_word:
            if known_word:
                if first_level:
                    score -= self.known_words_penalty
                else:
                    score *= 50.0
            else:
                score -= self.unknown_words_penalty
        return score

    def fix_fragment(self, text):
        """Correct a text fragment while preserving some original formatting."""
        # Split into tokens (words + punctuation)
        tokens = re.findall(r"\w+[’']*\w*|[.!?,;:]", text)
        
        # Correct words while leaving punctuation untouched
        corrected_tokens = []
        for token in tokens:
            if token.isalpha():
                # Process only alphabetic tokens
                corrected = self._correct_word(token, tokens, len(corrected_tokens))
                corrected_tokens.append(corrected)
            else:
                # Keep punctuation as-is
                corrected_tokens.append(token)
        
        # Reconstruct the text with original spacing (simplified)
        return ' '.join(corrected_tokens).replace(' ,', ',').replace(' .', '.')

    def _correct_word(self, word, words, index):
        """Correct a single word in the sentence."""
        candidates = self.get_candidates(words, index)
        if candidates:
            corrected = candidates[0]
            if word.istitle():
                corrected = corrected.capitalize()
            return corrected
        return word


In [7]:
model = LangModel()
if model.train(text):
    print("Model trained successfully!")

corrector = SpellCorrector(model)
print("Spell corrector created!")

Total sentences: 600000
Vocabulary size: 236442
Total words: 9924458
Model trained successfully!
Spell corrector created!


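# As we can see here, the model is somewhat sensitive to punctuation (and the spacing around it), as is Norvig's corrector. Capitalized words are also affected: the vocabulary is lowercased, so "She" is treated as unknown and becomes "The". Therefore, I decided to use a dataset without punctuation for the validation part.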

In [8]:
print(corrector.fix_fragment("He likes to eat applle.") == "He likes to eat apple.", corrector.fix_fragment("He likes to eat applle."))
print(corrector.fix_fragment("She doesn't know how to fix teh problem.") == "She doesn't know how to fix the problem.", corrector.fix_fragment("She doesn't know how to fix teh problem."))
print(corrector.fix_fragment("iPHone is expensive."))
print(corrector.fix_fragment("Mr. Smth went to paris."))
print(corrector.fix_fragment("Wait... did you mean that??"))
print(corrector.fix_fragment("state-of-the-arte technology."))
print(corrector.fix_fragment("“Hello, it’s nice here!” she said."))


True He likes to eat apple.
False The doesn't know how to fix the problem.
iPHone is expensive.
Or. Smth went to paris.
Wait... did you mean that ? ?
state of the art technology.
Hello, it’s nice here ! she said.


## Justify your decisions

Write down justifications for your implementation choices. For example, these choices could be:
- Which ngram dataset to use
- Which weights to assign for edit1, edit2 or absent words probabilities
- Beam search parameters
- etc.

# Dataset
I collected [300k Wikipedia + 300k News](https://wortschatz.uni-leipzig.de/en/download/English) sentences (Simple English Wikipedia 2021 and English news 2024, per the file names) and preprocessed them by removing citations and references (such as [1] or (2024)), special characters, numbers, and punctuation, then lowercasing and collapsing extra spaces.

For evaluation I use the [2014 European Union WEB 10k](https://wortschatz.uni-leipzig.de/en/download/English) dataset with the same preprocessing. I then go through the sentences and inject typos, corrupting each letter with a fixed probability (0.03 in the evaluation code).
# Choosing n in ngram model
The choice of n is crucial in an n-gram model, since it directly affects accuracy. In the articles I reviewed, n = 2 or 3 mostly showed the best performance. I chose n = 3 to have more context available.
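
To make the role of n concrete, here is a small sketch of how n-grams of any order can be extracted and counted. The helper `ngrams` and the sample sentence are illustrative and are not part of the notebook's `LangModel`.

```python
from collections import Counter


def ngrams(tokens, n):
    """Return an iterator over consecutive n-token tuples."""
    return zip(*(tokens[i:] for i in range(n)))


tokens = "the cat sat on the mat".split()
bigram_counts = Counter(ngrams(tokens, 2))
trigram_counts = Counter(ngrams(tokens, 3))

print(len(bigram_counts), len(trigram_counts))
```

A larger n captures more context per tuple, but each tuple is seen less often, which is why smoothing becomes more important as n grows.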
# Smoothing
To avoid assigning zero probability to n-grams that were not found in the training text, I applied additive (Laplace-style) smoothing with k = 0.01. With it, the model generalizes better when faced with new or rare word combinations.
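
As a point of reference, textbook add-k smoothing can be sketched as follows. This is the standard form, with a hypothetical function name and toy numbers; the notebook's `LangModel` additionally adds `total_words` to the bigram and trigram denominators, so its estimates are more heavily dampened than this form.

```python
def add_k_probability(ngram_count, context_count, vocab_size, k=0.01):
    """Textbook add-k estimate: (count + k) / (context_count + k * V)."""
    return (ngram_count + k) / (context_count + k * vocab_size)


# Toy numbers, chosen only to show that unseen n-grams get a small non-zero mass.
seen = add_k_probability(ngram_count=3, context_count=10, vocab_size=1000)
unseen = add_k_probability(ngram_count=0, context_count=10, vocab_size=1000)

print(seen, unseen)
```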
# Generate correction candidates
1. First, I generate corrections using insertions, replacements, transpositions, and deletions. This broad search ensures that most plausible corrections are considered.
2. Applying the first method recursively to depth 2 is slow and generates many candidates, so for the second level I kept only deletions. If the first method yields no valid candidate, I fall back to two-level deletions (e.g., "test" → "tst" → "st").
3. If neither method works, the original word is returned.
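
The three steps above can be sketched as a single function. This is an illustrative sketch rather than the notebook's code: the function names are hypothetical, and it assumes the one-edit candidates are filtered against the vocabulary before deciding whether to fall back.

```python
ALPHABET = "abcdefghijklmnopqrstuvwxyz"


def one_edit_candidates(word):
    """All strings one deletion, transposition, replacement, or insertion away."""
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = {l + r[1:] for l, r in splits if r}
    transposes = {l + r[1] + r[0] + r[2:] for l, r in splits if len(r) > 1}
    replaces = {l + c + r[1:] for l, r in splits if r for c in ALPHABET}
    inserts = {l + c + r for l, r in splits for c in ALPHABET}
    return deletes | transposes | replaces | inserts


def two_deletion_candidates(word):
    """All non-empty strings obtained by deleting one or two characters."""
    one = {word[:i] + word[i + 1:] for i in range(len(word))}
    two = {w[:i] + w[i + 1:] for w in one for i in range(len(w))}
    return (one | two) - {""}


def generate_candidates(word, vocabulary):
    """Step 1: one-edit candidates; step 2: two deletions; step 3: the word itself."""
    known = one_edit_candidates(word) & vocabulary
    if not known:
        known = two_deletion_candidates(word) & vocabulary
    return known or {word}


vocab = {"test", "the", "st"}
print(generate_candidates("teh", vocab))
```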
# Candidate selection
Because candidate generation produces a large number of candidates, they need to be bounded. I rank candidates by their unigram frequency and keep only the top 14, plus the original word. This frequency-based filtering ensures that only the most probable candidates are scored, improving efficiency and reducing the computational load.
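
One compact way to express this bounding step is with `heapq.nlargest`. The sketch below is an alternative formulation, not the notebook's `filter_candidates_by_frequency`; the function name and the toy counts are assumptions made for illustration.

```python
import heapq
from collections import Counter


def top_candidates(candidates, unigram_counts, original, limit=14):
    """Keep the `limit` most frequent candidates and always retain the original word."""
    best = heapq.nlargest(limit, candidates, key=lambda w: unigram_counts[w])
    return set(best) | {original}


# Toy counts, invented for illustration; unknown words count as 0 in a Counter.
counts = Counter({"the": 100, "then": 20, "ten": 5})
print(top_candidates({"the", "then", "ten", "teh"}, counts, "teh", limit=2))
```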
# Penalties
To avoid over-correcting valid known words, I added penalties that favor keeping the original word. When the original word is known, any replacement incurs a large penalty (20, subtracted from the log score). When it is unknown, a smaller penalty (5) still applies, so likely typos are corrected while valid but out-of-vocabulary words (e.g., rare words, names) keep a slight advantage. There is also a multiplicative penalty (a factor of 50) for two-deletion candidates when the original word is known. It reflects that two deletions are less likely to be a valid correction unless absolutely necessary.
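
The arithmetic of these penalties can be checked with a few lines. The sketch mirrors `_apply_scoring_penalties` with the constants from `SpellCorrector`; the function name `penalized`, its argument names, and the example score of -30 are illustrative. Because scores are log probabilities (negative numbers), subtracting 20 corresponds to multiplying the probability by e^-20 (roughly 2e-9), and multiplying a negative score by 50 pushes it far down.

```python
def penalized(score, changed, original_known, first_level,
              known_penalty=20.0, unknown_penalty=5.0, second_level_factor=50.0):
    """Apply the penalty rules to a log-probability score."""
    if not changed:
        return score  # the original word is never penalized
    if not original_known:
        return score - unknown_penalty
    if first_level:
        return score - known_penalty
    return score * second_level_factor


print(penalized(-30.0, changed=True, original_known=True, first_level=True))
print(penalized(-30.0, changed=True, original_known=False, first_level=True))
print(penalized(-30.0, changed=True, original_known=True, first_level=False))
```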
# Context
Each candidate word's score is computed by considering its probability in the given context, which is determined by the trigram model. The final score combines the likelihood of the candidate in the context with penalties for known and unknown words.
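
The context window itself is small: up to two words on each side of the candidate. A slicing-based sketch that should be equivalent to `_build_context` (the name `build_context` and the `width` parameter are illustrative) looks like this:

```python
def build_context(sentence, position, candidate, width=2):
    """Return up to `width` words on each side, with the candidate substituted in."""
    left = sentence[max(0, position - width):position]
    right = sentence[position + 1:position + 1 + width]
    return left + [candidate] + right


sentence = ["he", "is", "dking", "sport", "every", "day"]
print(build_context(sentence, 2, "doing"))
```

Only this window is passed to the language model, so the cost of scoring a candidate does not grow with sentence length.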
# Efficiency
A key challenge in this task is the need to handle large datasets and quickly generate candidate corrections. To improve efficiency, I limited the number of candidates considered and used a simple heuristic to filter candidates by frequency. This reduces the number of computations and speeds up the spelling correction process.

Additionally, the model uses a pre-built word_to_id dictionary to quickly look up word IDs, which further speeds up both training and inference phases.
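
The ID lookup can be illustrated with a minimal sketch. Here `word_id` is a hypothetical helper that plays the role of `LangModel.get_word_id`: each distinct word is assigned an integer once, and n-gram counters are then keyed by small integer tuples instead of strings.

```python
from collections import Counter

word_to_id = {}


def word_id(word):
    """Return the integer ID for a word, assigning the next free ID if it is new."""
    return word_to_id.setdefault(word, len(word_to_id))


ids = [word_id(w) for w in "the cat saw the cat".split()]
bigrams = Counter(zip(ids, ids[1:]))

print(ids)
```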
# Keyboard Layout and Misspellings
I did not implement this in my approach.
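
For reference, one possible way to model keyboard layout is to make substitutions between neighboring keys cheaper than other substitutions. The sketch below is not part of the implementation above: the row offsets, the distance threshold, and the cost values are all assumptions chosen for illustration.

```python
import math

ROWS = ["qwertyuiop", "asdfghjkl", "zxcvbnm"]
ROW_OFFSETS = [0.0, 0.5, 1.0]  # assumed horizontal stagger of each QWERTY row

# Approximate (x, y) position of every letter key.
KEY_POS = {
    ch: (col + ROW_OFFSETS[r], float(r))
    for r, row in enumerate(ROWS)
    for col, ch in enumerate(row)
}


def are_adjacent(a, b, max_dist=1.2):
    """True if two distinct keys are close enough to be plausible slips."""
    (x1, y1), (x2, y2) = KEY_POS[a], KEY_POS[b]
    return a != b and math.hypot(x1 - x2, y1 - y2) <= max_dist


def substitution_cost(typed, intended):
    """Assumed cost scheme: 0 for a match, 0.5 for a neighboring key, 1 otherwise."""
    if typed == intended:
        return 0.0
    return 0.5 if are_adjacent(typed, intended) else 1.0


print(substitution_cost("k", "o"), substitution_cost("k", "y"))
```

Under these assumptions, "dking" → "doing" is a cheaper substitution than "dking" → "dying", which could be combined with the language model score as an additional signal.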

# My approach
My approach is based on the N-gram method with slight modifications: reducing the number of generated correction candidates, smoothing, and filtering and bounding candidates. In addition, penalties make better use of context and reduce the chance of "correcting" words that are already right.

## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate different datasets with varying complexity (or just take another dataset). Compare your solution to Norvig's corrector, and report the accuracies.

In [9]:
evaluate_sentences = []

with open("/kaggle/input/nlp-assignment-1-training-dataset/eng-eu_web_2014_10K-sentences.txt", "r", encoding="utf-8") as f:
    for line in tqdm(f):
        parts = line.strip().split("\t")
        if len(parts) == 2:
            _, sentence = parts

            # Remove citations and references like [1], (2024), etc.
            sentence = re.sub(r"\[\d+\]", "", sentence)
            sentence = re.sub(r"\(\d{4}\)", "", sentence)
        
            # Remove special characters, numbers, and punctuation
            sentence = re.sub(r"[^a-zA-Z\s]", "", sentence)
        
            # Lowercase and remove extra spaces
            sentence = sentence.lower().strip()
            sentence = re.sub(r"\s+", " ", sentence)

            tokens = word_tokenize(sentence)
            evaluate_sentences.append(tokens)


10000it [00:01, 5889.64it/s]


In [10]:
number_sentences = 1000
evaluate_sentences = evaluate_sentences[:number_sentences]

In [11]:
def words(text): return re.findall(r'\w+', text.lower())

WORDS = Counter(words(open('/kaggle/input/bigtxt/big.txt').read()))

def P(word, N=sum(WORDS.values())): 
    "Probability of `word`."
    return WORDS[word] / N

def correction(word): 
    "Most probable spelling correction for word."
    return max(candidates(word), key=P)

def candidates(word): 
    "Generate possible spelling corrections for word."
    return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])

def known(words): 
    "The subset of `words` that appear in the dictionary of WORDS."
    return set(w for w in words if w in WORDS)

def edits1(word):
    "All edits that are one edit away from `word`."
    letters    = 'abcdefghijklmnopqrstuvwxyz'
    splits     = [(word[:i], word[i:])    for i in range(len(word) + 1)]
    deletes    = [L + R[1:]               for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
    replaces   = [L + c + R[1:]           for L, R in splits if R for c in letters]
    inserts    = [L + c + R               for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def edits2(word): 
    "All edits that are two edits away from `word`."
    return (e2 for e1 in edits1(word) for e2 in edits1(e1))

def correct_sentence_function(sentence):
    """Corrects a sentence by applying word-level correction."""
    words = sentence.split()
    return ' '.join([correction(word) for word in words])


In [12]:
TYPO_PROB = 0.03  # Chance of making a typo for a single letter
SECOND_TYPO_CF = 0.2  # Chance of making two typos, relative to TYPO_PROB
REPLACE_PROB = 0.7
INSERT_PROB = 0.1
REMOVE_PROB = 0.1
TRANSPOSE_PROB = 0.1

def add_typos(word):
    if not word:
        return word

    typo_word = list(word)
    typo_applied = False

    i = 0
    while i < len(typo_word):
        if random.random() < TYPO_PROB:
            typo_applied = True
            typo_type = random.choices(
                ['replace', 'insert', 'remove', 'transpose'],
                [REPLACE_PROB, INSERT_PROB, REMOVE_PROB, TRANSPOSE_PROB]
            )[0]

            if typo_type == 'replace' and len(typo_word) > 1:
                typo_word[i] = random.choice('abcdefghijklmnopqrstuvwxyz')
            elif typo_type == 'insert':
                typo_word.insert(i, random.choice('abcdefghijklmnopqrstuvwxyz'))
                i += 1
            elif typo_type == 'remove' and len(typo_word) > 1:
                typo_word.pop(i)
                i -= 1
            elif typo_type == 'transpose' and i + 1 < len(typo_word):
                typo_word[i], typo_word[i + 1] = typo_word[i + 1], typo_word[i]
                i += 1

        i += 1

    return ''.join(typo_word) if typo_applied else word

def generate_typo_sentences(sentences):
    typo_sentences = []
    typo_flags = []

    for sentence in sentences:
        typo_sentence = []
        typo_flag = []
        for word in sentence:
            typo_word = add_typos(word)
            typo_sentence.append(typo_word)
            typo_flag.append(typo_word != word)
        typo_sentences.append(typo_sentence)
        typo_flags.append(typo_flag)

    return typo_sentences, typo_flags

def assess_spell_checker(original, typo, processed, flags):
    errored, fixed, broken = 0, 0, 0
    total_words = sum(len(s) for s in original)

    for orig, typo, proc, flag in zip(original, typo, processed, flags):
        for o, t, p, f in zip(orig, typo, proc, flag):
            if f:
                errored += 1
                if p == o:
                    fixed += 1
            elif p != o:
                broken += 1

    error_rate = (errored / total_words) * 100
    fix_rate = (fixed / errored * 100) if errored else 0
    broken_rate = (broken / total_words) * 100

    return {
        "errors": error_rate,
        "fix_rate": fix_rate,
        "broken": broken_rate
    }

def evaluate_spell_checkers(sentences):
    typo_sentences, typo_flags = generate_typo_sentences(sentences)
    
    start_time = time.time()
    processed_1 = [corrector.fix_fragment(" ".join(s)).split() for s in typo_sentences]
    assessment_1 = assess_spell_checker(sentences, typo_sentences, processed_1, typo_flags)
    processing_time_1 = time.time() - start_time
  
    start_time = time.time()
    processed_2 = [correct_sentence_function(" ".join(s)).split() for s in typo_sentences]
    assessment_2 = assess_spell_checker(sentences, typo_sentences, processed_2, typo_flags)
    processing_time_2 = time.time() - start_time

    total_words = sum(len(s) for s in sentences)
    words_per_second_1 = total_words / processing_time_1 if processing_time_1 > 0 else 0
    words_per_second_2 = total_words / processing_time_2 if processing_time_2 > 0 else 0

    return {
        "model_1": assessment_1,
        "model_2": assessment_2,
        "words_per_second_1": words_per_second_1,
        "words_per_second_2": words_per_second_2
    }

result = evaluate_spell_checkers(evaluate_sentences)
print("Total percent of errored words:", result['model_1']['errors'])
print("N-gram model: fix-rate =", result['model_1']['fix_rate'], "; broken-rate =", result['model_1']['broken'], "; words-per-second =", result["words_per_second_1"])
print("Norvig model: fix-rate =", result['model_2']['fix_rate'], "; broken-rate =", result['model_2']['broken'], "; words-per-second =", result["words_per_second_2"])

Total percent of errored words: 13.58706273206957
N-gram model: fix-rate = 76.77094570298453 ; broken-rate = 1.2556185264803597 ; words-per-second = 1583.5481727525726
Norvig model: fix-rate = 67.92520676015822 ; broken-rate = 5.804182137971468 ; words-per-second = 125.71400391755107


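# As we can see, my model is more than 10x faster than the classic Norvig corrector (about 1,584 vs. 126 words per second in the run above), its fix rate is on average 9-10 percentage points higher, and its broken rate does not exceed 2%, compared to 5-7% for Norvig.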

# Overall, my solution performs considerably better than the simple Norvig approach on this test set.

#### Useful resources (also included in the archive in moodle):

1. [Possible dataset with N-grams](https://www.ngrams.info/download_coca.asp)
2. [Damerau–Levenshtein distance](https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance#:~:text=Informally%2C%20the%20Damerau–Levenshtein%20distance,one%20word%20into%20the%20other.)