# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement context-sensitive spelling corrector using N-gram language model. The idea is to compute conditional probabilities of possible correction options. For example, the phrase "dking sport" should be fixed as "doing sport" not "dying sport", while "dking species" -- as "dying species".

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).

When solving this task, we expect you'll face (and successfully deal with) some problems or make up the ideas of the model improvement. Some of them are: 

- solving a problem of n-grams frequencies storing for a large corpus;
- taking into account keyboard layout and associated misspellings;
- efficiency improvement to make the solution faster;
- ...

Please don't forget to describe such cases, and what you decided to do with them, in the Justification section.

##### IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

In [1]:
import collections
import re

###############################
# Utility and Data Preparation
###############################

def words(text):
    """Extract words from text, lowercasing all characters."""
    return re.findall(r'\w+', text.lower())

def build_all_ngrams(file_path):
    """
    Build n-gram counts from a file.
    The file is assumed to have at least 6 columns:
      count word1 word2 word3 word4 word5
    and each line gives a five-gram.
    """
    unigrams = collections.Counter()
    bigrams = collections.Counter()
    trigrams = collections.Counter()
    fourgrams = collections.Counter()
    fivegrams = collections.Counter()
    
    with open(file_path, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            parts = line.strip().split()
            if len(parts) < 6:
                continue
            count = int(parts[0])
            tokens = [w.lower() for w in parts[1:6]]
            fivegrams[tuple(tokens)] += count
            # Two distinct fourgrams per fivegram
            fourgrams[tuple(tokens[0:4])] += count
            fourgrams[tuple(tokens[1:5])] += count
            # Three trigrams per fivegram
            trigrams[tuple(tokens[0:3])] += count
            trigrams[tuple(tokens[1:4])] += count
            trigrams[tuple(tokens[2:5])] += count
            # Four bigrams per fivegram
            bigrams[tuple(tokens[0:2])] += count
            bigrams[tuple(tokens[1:3])] += count
            bigrams[tuple(tokens[2:4])] += count
            bigrams[tuple(tokens[3:5])] += count
            # Unigrams from all tokens
            for token in tokens:
                unigrams[token] += count
    total_unigrams = sum(unigrams.values())
    total_bigrams = sum(bigrams.values())
    total_trigrams = sum(trigrams.values())
    total_fourgrams = sum(fourgrams.values())
    total_fivegrams = sum(fivegrams.values())
    return unigrams, bigrams, trigrams, fourgrams, fivegrams, total_unigrams, total_bigrams, total_trigrams, total_fourgrams, total_fivegrams

# Build n-gram counts from 'fivegrams.txt'
# (Make sure you have a file named 'fivegrams.txt' in the working directory)
unigrams, bigrams, trigrams, fourgrams, fivegrams, total_unigrams, total_bigrams, total_trigrams, total_fourgrams, total_fivegrams = build_all_ngrams('fivegrams.txt')

ngrams = {1: unigrams, 2: bigrams, 3: trigrams, 4: fourgrams, 5: fivegrams}
ngrams_total = {1: total_unigrams, 2: total_bigrams, 3: total_trigrams, 4: total_fourgrams, 5: total_fivegrams}

In [2]:
import math
from functools import lru_cache

#############################
# Norvig’s Baseline Corrector
#############################

def known(words_list):
    return set(w for w in words_list if w in unigrams)

@lru_cache(maxsize=None)
def candidates(word):
    """
    Generate candidate corrections for a word.
    Uses one or two edit distances. Returns a set.
    """
    def edits1(word):
        letters = 'abcdefghijklmnopqrstuvwxyz'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(word):
        return (e2 for e1 in edits1(word) for e2 in edits1(e1))

    return known([word]) or known(edits1(word)) or known(edits2(word)) or {word}

@lru_cache(maxsize=None)
def norvig_word_probability(word):
    return unigrams[word] / total_unigrams

def norvig_correct_word(word):
    lower = word.lower()
    if lower in unigrams:
        return word
    return max(candidates(lower), key=lambda w: norvig_word_probability(w))

def norvig_correct_sentence(sentence):
    tokens = sentence.split()
    corrected_tokens = [norvig_correct_word(token) for token in tokens]
    return " ".join(corrected_tokens)


In [3]:
#############################
# SymSpell candidate generator
#############################

MAX_SYMSPELL_DISTANCE = 2

def delete_edit(word):
    """
    Generate all possible strings that result from deleting one character from word.
    """
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    return {L + R[1:] for L, R in splits if R}

def precalculate_symspell(unigrams):
    """
    Pre-calculate a mapping from deletion edits to the set of original words that produced them.
    """
    symspell_mapping = {}
    for word in unigrams:
        edits = delete_edit(word)
        for edit in edits:
            if edit:
                symspell_mapping.setdefault(edit, set()).add(word)
    return symspell_mapping

symspell_mapping = precalculate_symspell(unigrams)

def edit_distance(s, t, max_distance):
    """
    Compute the Levenshtein edit distance between strings s and t.
    Early exit if the distance exceeds max_distance.
    """
    if abs(len(s) - len(t)) > max_distance:
        return max_distance + 1

    previous_row = list(range(len(t) + 1))
    for i, sc in enumerate(s, start=1):
        current_row = [i] + [0] * len(t)
        for j, tc in enumerate(t, start=1):
            cost = 0 if sc == tc else 1
            current_row[j] = min(previous_row[j] + 1,        # deletion
                                 current_row[j - 1] + 1,       # insertion
                                 previous_row[j - 1] + cost)   # substitution
        if min(current_row) > max_distance:
            return max_distance + 1
        previous_row = current_row
    return previous_row[len(t)]

def symspell_candidates(word, max_distance=2):
    """
    Generate candidate corrections for a word using the SymSpell algorithm.
    
    Assumes:
      - `unigrams` is a global set of valid words.
      - `symspell_mapping` is a global mapping from deletion edits to original words.
    
    Returns:
      A dictionary mapping candidate words to their edit distance from the query word.
    """
    suggestions = {}
    
    # If the word is an exact match, add it as a candidate with distance 0.
    suggestions[word] = 0

    visited = {word}
    candidates = {word}

    while candidates:
        candidate = candidates.pop()
        
        # Check if the candidate deletion exists in our mapping.
        if candidate in symspell_mapping:
            for suggestion in symspell_mapping[candidate]:
                if suggestion not in suggestions:
                    distance = edit_distance(word, suggestion, max_distance)
                    if distance <= max_distance:
                        suggestions[suggestion] = distance
        
        # Only expand further deletions if we haven't exceeded the max_distance.
        # The number of deletions already applied is len(word) - len(candidate)
        if len(word) - len(candidate) < max_distance:
            for deletion in delete_edit(candidate):
                if deletion not in visited:
                    visited.add(deletion)
                    candidates.add(deletion)
    
    return suggestions


In [4]:
#############################
# N-gram Spell Corrector
# With Interpolated Kneser-Ney Smoothing
# SymSpell Candidate Generator
# Edit Distance Error Model
#############################

ERROR_PENALTY = 0.001
KNESER_NEY_DISCOUNT = 0.75
EPSILON = 1e-10
MAX_NGRAM = 5
MIN_IMPROVEMENT_MARGIN = 0.5

#############################
# Pre-calculate context counts
#############################

def build_context_counts(ngrams):
    """
    Pre-calculate context counts for n-grams.
    
    Returns two dictionaries:
      - context_counts: mapping from a context (tuple) to the total frequency with which it appears.
      - context_unique_counts: mapping from a context (tuple) to the number of unique words following it.
    """
    context_counts = {}  # Total count of contexts
    context_unique_counts = {}  # Number of unique words that follow each context
    
    # For n-grams where n >= 2, the context is the first n-1 tokens.
    for n in range(2, MAX_NGRAM + 1):
        context_counts[n - 1] = collections.Counter()
        context_unique_counts[n - 1] = collections.Counter()
        for ngram, count in ngrams.get(n, {}).items():
            context = ngram[:-1]
            context_counts[n - 1][context] += count
            context_unique_counts[n - 1][context] += 1
    return context_counts, context_unique_counts

context_counts, context_unique_counts = build_context_counts(ngrams)

#############################
# Interpolated Kneser-Ney Smoothing
#############################
@lru_cache(maxsize=None)
def calculate_backoff(context):
    """
    Calculate the backoff weight for a given context.
    """
    context_len = len(context)
    if context_len == 0:
        return 0.0

    # Get the frequency of the context. If absent, assume 0.
    context_freq = context_counts.get(context_len, {}).get(context, 0)
    if context_freq == 0:
        return 0.0

    # Get the number of unique continuations.
    unique_count = context_unique_counts.get(context_len, {}).get(context, 0)
    return KNESER_NEY_DISCOUNT * unique_count / context_freq

@lru_cache(maxsize=None)
def kneser_ney_language_model(word, context):
    """
    Calculate the interpolated Kneser-Ney probability of a word given a context.
    
    Args:
        word: The target word (string).
        context: A tuple of preceding words.
    
    Returns:
        The probability of 'word' given 'context' under the interpolated Kneser-Ney model.
    """
    context_len = len(context)
    
    # Base case: empty context; use the unigram probability.
    if context_len == 0:
        return norvig_word_probability(word)
    
    # If context is longer than allowed, use the last (MAX_NGRAM-1) words.
    if context_len >= MAX_NGRAM:
        context = context[-(MAX_NGRAM - 1):]
        context_len = MAX_NGRAM - 1

    ngram_count = ngrams.get(context_len + 1, {}).get(context + (word,), 0)
    context_total = context_counts.get(context_len, {}).get(context, 0)

    
    if context_total == 0:
        return kneser_ney_language_model(word, context[1:])
    else:
        discounted_prob = max(ngram_count - KNESER_NEY_DISCOUNT, 0) / context_total

    backoff = calculate_backoff(context)
    
    # Recursive backoff to a lower-order model by removing the first word of the context.
    return discounted_prob + backoff * kneser_ney_language_model(word, context[1:])


#############################
# Error Model
#############################

def error_model(candidate_distance):
    """
    Calculate the probability of an error given the edit distance.
    """
    return ERROR_PENALTY ** candidate_distance


#############################
# N-gram Spell Corrector
#############################

def ngram_correct_word(word, context):
    """
    Correct a word using the N-gram model with interpolated Kneser-Ney smoothing.
    """
    if word in unigrams:
        return word
    # Compute score for the original word
    orig_lm_prob = kneser_ney_language_model(word, context)
    orig_score = orig_lm_prob * error_model(0)
    orig_log = math.log(orig_score + EPSILON)

    # Generate candidate corrections using SymSpell.
    candidates_dict = symspell_candidates(word, MAX_SYMSPELL_DISTANCE)
    candidate_scores = {}
    for candidate, distance in candidates_dict.items():
        lm_prob = kneser_ney_language_model(candidate, context)
        candidate_prob = lm_prob * error_model(distance)
        candidate_scores[candidate] = candidate_prob

    # Find the candidate with the highest probability
    best_candidate = max(candidate_scores, key=candidate_scores.get)
    best_log = math.log(candidate_scores[best_candidate] + EPSILON)

    # Only correct if the candidate is different and its score is better by the threshold margin.
    if best_candidate != word and (best_log - orig_log) > MIN_IMPROVEMENT_MARGIN:
        return best_candidate
    else:
        return word

def ngram_correct_sentence(sentence):
    """
    Correct a sentence using the N-gram model with interpolated Kneser-Ney smoothing.
    """
    tokens = sentence.lower().split()
    corrected_tokens = []
    for i, token in enumerate(tokens):
        context = tuple(tokens[max(0, i - MAX_NGRAM + 1):i])
        corrected_tokens.append(ngram_correct_word(token, context))
    return " ".join(corrected_tokens)



In [5]:
#############################
# Beam Search
#############################

BEAM_WIDTH = 10
LOOKAHEAD_DEPTH  = 2

def lookahead_score(tokens, index, context, depth):
    """
    Recursively estimate an optimistic additional log score for the next 'depth' tokens.
    
    For each token in the lookahead window, if the token exists in our vocabulary,
    we assume the best candidate is the token itself.
    """
    if index >= len(tokens) or depth == 0:
        return 0.0
    
    token = tokens[index]
    lm_prob = kneser_ney_language_model(token, context)
    log_prob = math.log(lm_prob + EPSILON)
    new_context = (context + (token,))[-(MAX_NGRAM - 1):]
    return log_prob + lookahead_score(tokens, index + 1, new_context, depth - 1)


def beam_search_sentence(sentence, beam_width=BEAM_WIDTH, lookahead_depth=LOOKAHEAD_DEPTH):
    """
    Correct a sentence using beam search with lookahead.
    
    At each token position, we generate candidate corrections (using symspell_candidates)
    and score them using our Kneser-Ney language model combined with the error model.
    A lookahead heuristic estimates the future score for remaining tokens.
    
    Returns:
      The corrected sentence (as a string).
    """
    tokens = sentence.lower().split()
    beam = [([], 0.0, tuple())]

    for i, token in enumerate(tokens):
        new_beam = []
        for corrected_tokens, cum_log_score, context in beam:
            # Compute the original token's log score in this context.
            orig_lm_prob = kneser_ney_language_model(token, context)
            orig_score = orig_lm_prob * error_model(0)
            orig_log = math.log(orig_score + EPSILON)

            # Generate candidate corrections.
            candidates = symspell_candidates(token, MAX_SYMSPELL_DISTANCE)

            for candidate, distance in candidates.items():
                # Compute candidate score.
                lm_prob = kneser_ney_language_model(candidate, context)
                candidate_prob = lm_prob * error_model(distance)
                log_candidate_prob = math.log(candidate_prob + EPSILON)

                # If the candidate is not the same as the original,
                # only allow it if its log score improves by the threshold margin.
                if candidate != token and (log_candidate_prob - orig_log) < MIN_IMPROVEMENT_MARGIN:
                    continue  # skip candidate

                new_cum_log = cum_log_score + log_candidate_prob
                new_context = (context + (candidate,))[-(MAX_NGRAM - 1):]

                heuristic = lookahead_score(tokens, i + 1, new_context, lookahead_depth)
                total_estimated_score = new_cum_log + heuristic

                new_beam.append((corrected_tokens + [candidate], new_cum_log, new_context, total_estimated_score))

        new_beam.sort(key=lambda x: x[3], reverse=True)
        beam = [(ct, cl, ctx) for ct, cl, ctx, _ in new_beam[:beam_width]]

    best_candidate = max(beam, key=lambda state: state[1])
    return " ".join(best_candidate[0])

## Justify your decisions

Write down justificaitons for your implementation choices. For example, these choices could be:
- Which ngram dataset to use
- Which weights to assign for edit1, edit2 or absent words probabilities
- Beam search parameters
- etc.

## Data:
---
We extact the ngrams from the provided "fivegrams.txt" dataset. Our function build_all_ngrams reads a file where each line includes a count and five tokens. We compute counts for unigrams through fivegrams and store totals. The resulting ngrams dictionary and total counts are later used by the language models.

## Spell Correction Algorithms:
---
### Norvig’s Baseline Corrector:
A well‐known method that uses one- and two-edit transformations to generate candidate corrections and then chooses the candidate with the highest unigram probability from a large corpus.

### N-gram Model with Interpolated Kneser–Ney Smoothing:
This approach leverages higher-order n‑gram statistics (up to five-grams) to compute the probability of a candidate word given its context. Kneser–Ney smoothing is applied to better handle unseen n‑grams, and a candidate’s likelihood is combined with an error model (based on edit distance) to penalize corrections that require many changes.

### Beam Search with Lookahead:
Building on the n‑gram model, this method performs a global search over the entire sentence. At each token, candidate corrections are scored, and the algorithm maintains a beam of top hypotheses. A lookahead heuristic (which also uses Kneser–Ney probabilities) is optionally incorporated to provide an estimate of future score, aiming to balance immediate corrections with longer-range context.

## Implementation improvements:
---
### SymSpell
A faster candidate generation algorithm, uses only 1 edit operation and precomputed neighbouring words.

### Interpolated Kneser-Ney smoothing
The core of our algorithms is the kneser_ney_language_model, which calculates the probability of a word given its context. It uses recursive backoff and discounting based on pre-calculated context counts.

### Error Model
We incorporate an error penalty based on edit distance (implemented as error_model). This model penalizes corrections that require more changes.

### Thresholding
A minimum improvement margin is used to decide whether to change a word. If a candidate’s (log) score isn’t better than that of the original word by a set threshold, the original is kept.

### Beam Search
A beam of top hypotheses is maintained to control search complexity. An optional lookahead function estimates the best future score by assuming the best candidate for upcoming tokens is the token itself (using Kneser–Ney probabilities). This value is weighted before being added to the cumulative score.

## Obstacles:
---
Although both n-gram approaches showed better in-context accuracy for spell correction, they greatly suffered from over-correction. To prevent our algorithms from replacing correct words we had to introduce distance based penalties and improvement thresholding.

## Evaluation:
---
### Dataset:
We evaluated our models on the Holbrook dataset, which provides sentences with error markup.

## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate different datasets with varying compexity (or just take another dataset). Compare your solution to the Norvig's corrector, and report the accuracies.

In [6]:
def parse_holbrook_sentence(line):
    """
    Given a line from the Holbrook dataset, return a tuple (erroneous, correct).
    We remove the error tags for the erroneous version, keeping the inner (erroneous) text.
    For the correct version, we replace the entire <ERR targ=...>...</ERR> with the value of targ.
    
    For example:
      "I have four ... and <ERR targ=sister> siter </ERR> ."
    returns:
      erroneous: "I have four ... and siter ."
      correct:   "I have four ... and sister ."
    """
    # Pattern to capture error tags: <ERR targ=XXX> YYY </ERR>
    pattern = re.compile(r'<ERR\s+targ=([^>]+)>\s*([^<]+)\s*</ERR>')
    
    # Create erroneous version: simply remove the tag markup.
    erroneous = pattern.sub(r'\2', line)
    # Create corrected version: replace entire tag with the target attribute.
    correct = pattern.sub(r'\1', line)
    
    # Remove any extra whitespace.
    erroneous = re.sub(r'\s+', ' ', erroneous).strip().lower()
    correct = re.sub(r'\s+', ' ', correct).strip().lower()
    return erroneous, correct

def evaluate_holbrook(corpus_file, correction_functions, num_samples=None):
    """
    Evaluate the correction functions using the Holbrook dataset.
    For each sentence, compute token-level accuracy: (# tokens matching gold) / (total gold tokens).
    Returns the average accuracy per correction function.
    """
    sentences = []
    with open(corpus_file, 'r', encoding='utf-8', errors='ignore') as f:
        for line in f:
            line = line.strip()
            if line:
                erroneous, correct = parse_holbrook_sentence(line)
                # Only consider sentences where corrections are present.
                if erroneous != correct:
                    sentences.append((erroneous, correct))
            if num_samples and len(sentences) >= num_samples:
                break
    
    metrics = {}  # Maps correction function name to a list of per-sentence metrics.
    for erroneous, gold in sentences:
        gold_tokens = gold.split()
        erroneous_tokens = erroneous.split()
        for func in correction_functions:
            corrected = func(erroneous)
            corr_tokens = corrected.split()
            # Align tokens using zip (assuming tokenization is consistent).
            true_positives = sum(1 for g, er, c in zip(gold_tokens, erroneous_tokens, corr_tokens) if g == c and g != er)
            false_negatives = sum(1 for g, er, c in zip(gold_tokens, erroneous_tokens, corr_tokens) if g != c and g == er)
            false_positives = sum(1 for g, er, c in zip(gold_tokens, erroneous_tokens, corr_tokens) if g != c and g != er)
            
            # Accuracy: fraction of tokens in the corrected sentence that are correct.
            accuracy = sum(1 for g, c in zip(gold_tokens, corr_tokens) if g == c) / len(gold_tokens)
            precision = (true_positives / (true_positives + false_positives)) if (true_positives + false_positives) > 0 else 0.0
            recall = (true_positives / (true_positives + false_negatives)) if (true_positives + false_negatives) > 0 else 0.0
            f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0.0

            # Store per-sentence metrics.
            metrics.setdefault(func.__name__, []).append(
                (accuracy, precision, recall, f1, true_positives, false_positives, false_negatives)
            )
    
    # Average metrics over all sentences.
    avg_metrics = {}
    for func, scores in metrics.items():
        avg_accuracy = sum(score[0] for score in scores) / len(scores)
        avg_precision = sum(score[1] for score in scores) / len(scores)
        avg_recall = sum(score[2] for score in scores) / len(scores)
        avg_f1 = sum(score[3] for score in scores) / len(scores)
        total_tp = sum(score[4] for score in scores)
        total_fp = sum(score[5] for score in scores)
        total_fn = sum(score[6] for score in scores)
        avg_metrics[func] = (avg_accuracy, avg_precision, avg_recall, avg_f1,
                             total_tp, total_fp, total_fn)
    
    return avg_metrics

print("\n=== Evaluation on Holbrook Dataset ===\n")
holbrook_metrics = evaluate_holbrook('holbrook-tagged.txt', 
                                        correction_functions=[
                                            norvig_correct_sentence, 
                                            ngram_correct_sentence,
                                            beam_search_sentence
                                            ],
                                        # num_samples=1000
                                        )
for func, score in holbrook_metrics.items():
    print(f"{func}: {score}")


=== Evaluation on Holbrook Dataset ===

norvig_correct_sentence: (0.7184713911500502, 0.15640287737394507, 0.18021109371262453, 0.13899407014260975, 373, 4840, 1340)
ngram_correct_sentence: (0.7440060209447952, 0.14394864457467987, 0.19674377155650577, 0.14043256551595992, 332, 4881, 818)
beam_search_sentence: (0.7433419808360842, 0.17310096989954568, 0.21874368349087447, 0.16543528273444696, 390, 4823, 904)


#### Useful resources (also included in the archive in moodle):

1. [Possible dataset with N-grams](https://www.ngrams.info/download_coca.asp)
2. [Damerau–Levenshtein distance](https://en.wikipedia.org/wiki/Damerau–Levenshtein_distance#:~:text=Informally%2C%20the%20Damerau–Levenshtein%20distance,one%20word%20into%20the%20other.)