# Context-sensitive Spelling Correction

The goal of the assignment is to implement context-sensitive spelling correction. The input of the code will be a set of text lines and the output will be the same lines with spelling mistakes fixed.

Submit the solution of the assignment to Moodle as a link to your GitHub repository containing this notebook.

Useful links:
- [Norvig's solution](https://norvig.com/spell-correct.html)
- [Norvig's dataset](https://norvig.com/big.txt)
- [Ngrams data](https://www.ngrams.info/download_coca.asp)

Grading:
- 60 points - Implement spelling correction
- 20 points - Justify your decisions
- 20 points - Evaluate on a test set


## Implement context-sensitive spelling correction

Your task is to implement a context-sensitive spelling corrector using an N-gram language model. The idea is to compute the conditional probabilities of the possible corrections given their context. For example, the phrase "dking sport" should be corrected to "doing sport", not "dying sport", while "dking species" should become "dying species".

The best way to start is to analyze [Norvig's solution](https://norvig.com/spell-correct.html) and [N-gram Language Models](https://web.stanford.edu/~jurafsky/slp3/3.pdf).
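A minimal sketch of this idea with a bigram model is shown below. The counts are invented toy numbers, not taken from COCA or any other corpus, and the helpers `sequence_prob` and `best_correction` are hypothetical. Both "doing" and "dying" are one replacement away from "dking". The following word decides between them, because it is scored conditionally on each candidate.

```python
from collections import Counter

# Hypothetical toy counts for illustration only (not from a real corpus).
toy_unigrams = Counter({"doing": 50, "dying": 40, "sport": 30, "species": 20})
toy_bigrams = Counter({
    ("doing", "sport"): 12, ("dying", "sport"): 1,
    ("doing", "species"): 1, ("dying", "species"): 9,
})
toy_total = sum(toy_unigrams.values())

def sequence_prob(w1, w2):
    """P(w1) * P(w2 | w1), both estimated by maximum likelihood from the toy counts."""
    if toy_unigrams[w1] == 0:
        return 0.0
    p_w1 = toy_unigrams[w1] / toy_total
    p_w2_given_w1 = toy_bigrams[(w1, w2)] / toy_unigrams[w1]
    return p_w1 * p_w2_given_w1

def best_correction(options, next_word):
    """Pick the correction of the misspelled word that best explains the next word."""
    return max(options, key=lambda w: sequence_prob(w, next_word))

options = ["doing", "dying"]  # both are one edit away from "dking"
print(best_correction(options, "sport"))    # doing
print(best_correction(options, "species"))  # dying
```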

You may also want to implement:
- spell-checking for a concrete language - Russian, Tatar, etc. - any one you know, such that the solution accounts for language specifics,
- some recent (or not very recent) paper on this topic,
- solution which takes into account keyboard layout and associated misspellings,
- efficiency improvement to make the solution faster,
- any other idea of yours to improve Norvig's solution.

IMPORTANT:  
Your project should not be a mere code copy-paste from somewhere. You must provide:
- Your implementation
- Analysis of why the implemented approach is suggested
- Improvements of the original approach that you have chosen to implement

In [1]:
from nltk.tokenize import word_tokenize
from nltk.tokenize.treebank import TreebankWordDetokenizer
from tqdm.notebook import tqdm
from collections import Counter, defaultdict
import pandas as pd
import numpy as np
import string
import random

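# tokens treated as punctuation; note that `punct` is a string, so `token in punct` is a substring test
punct = string.punctuation + '... '

N = 5  # maximum n-gram order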

data = pd.read_csv('fivegrams.txt', sep='\t', names=['freq', '1', '2', '3', '4', '5'])
data.head()

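|   | freq | 1 | 2 | 3 | 4 | 5 |
|---|---|---|---|---|---|---|
| 0 | 16 | a | babe | in | the | woods |
| 1 | 6 | a | baby | at | her | breast |
| 2 | 9 | a | baby | brother | or | sister |
| 3 | 6 | a | baby | crying | in | the |
| 4 | 6 | a | baby | girl | was | born |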


In [2]:
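# n_grams[n] maps an n-word tuple to its frequency, for n = 0..5;
# n_grams[0][()] accumulates a total count (the denominator for unigram probabilities)
n_grams = [Counter() for _ in range(N + 1)]

# The 1- to 4-gram counts are derived from the sub-sequences of each five-gram,
# so they approximate the corpus counts (overlapping windows count the same text several times)

for _, row in tqdm(data.iterrows(), total=len(data)):
    five_gram = (row['1'], row['2'], row['3'], row['4'], row['5'])

    for n in range(N + 1):
        for j in range(N - n + 1):
            n_grams[n][five_gram[j:j+n]] += row['freq']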


In [3]:
n_grams[5].most_common(10)

[(('at', 'the', 'end', 'of', 'the'), 13588),
 (('i', 'do', "n't", 'want', 'to'), 12744),
 (('in', 'the', 'middle', 'of', 'the'), 9124),
 (('i', 'do', "n't", 'know', 'what'), 8076),
 (('you', 'do', "n't", 'have', 'to'), 7186),
 (('i', 'do', "n't", 'know', 'if'), 6455),
 (('for', 'the', 'first', 'time', 'in'), 6006),
 (('i', 'do', "n't", 'think', 'it'), 5559),
 (('there', 'are', 'a', 'lot', 'of'), 5523),
 (('i', 'do', "n't", 'think', 'that'), 5466)]

In [4]:
# get the neighbouring keys of each key (rows are treated as an aligned grid, ignoring the keyboard's stagger)

keyboard_rows = ['qwertyuiop', 'asdfghjkl', 'zxcvbnm']

keyboard_neighbours = defaultdict(set)

for row_idx, row in enumerate(keyboard_rows):
    for col_idx, key in enumerate(row):
        neighbours = set()
        for i in range(max(0, row_idx - 1), min(len(keyboard_rows), row_idx + 2)):
            for j in range(max(0, col_idx - 1), min(len(keyboard_rows[i]), col_idx + 2)):
                if keyboard_rows[i][j] == key:
                    continue
                neighbours.add(keyboard_rows[i][j])
        keyboard_neighbours[key] = neighbours

keyboard_neighbours

defaultdict(set,
            {'q': {'a', 's', 'w'},
             'w': {'a', 'd', 'e', 'q', 's'},
             'e': {'d', 'f', 'r', 's', 'w'},
             'r': {'d', 'e', 'f', 'g', 't'},
             't': {'f', 'g', 'h', 'r', 'y'},
             'y': {'g', 'h', 'j', 't', 'u'},
             'u': {'h', 'i', 'j', 'k', 'y'},
             'i': {'j', 'k', 'l', 'o', 'u'},
             'o': {'i', 'k', 'l', 'p'},
             'p': {'l', 'o'},
             'a': {'q', 's', 'w', 'x', 'z'},
             's': {'a', 'c', 'd', 'e', 'q', 'w', 'x', 'z'},
             'd': {'c', 'e', 'f', 'r', 's', 'v', 'w', 'x'},
             'f': {'b', 'c', 'd', 'e', 'g', 'r', 't', 'v'},
             'g': {'b', 'f', 'h', 'n', 'r', 't', 'v', 'y'},
             'h': {'b', 'g', 'j', 'm', 'n', 't', 'u', 'y'},
             'j': {'h', 'i', 'k', 'm', 'n', 'u', 'y'},
             'k': {'i', 'j', 'l', 'm', 'o', 'u'},
             'l': {'i', 'k', 'o', 'p'},
             'z': {'a', 's', 'x'},
             'x': {'a', 'c', 'd', 's',

In [5]:
def edit_distance_1(word, weights):
    ''' Return (candidate, weight) pairs for all strings within edit distance 1 of `word`.
    weights = [default edit, edit involving a neighbouring key, unchanged word] '''
    
    letters = "abcdefghijklmnopqrstuvwxyz'"
    splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
    deletes = [(L + R[1:], weights[0]) for L, R in splits if R]
    transposes = [(L + R[1] + R[0] + R[2:], weights[0]) for L, R in splits if len(R) > 1]
    replaces = [(L + c + R[1:], weights[1] if c in keyboard_neighbours[R[0]] else weights[0]) for L, R in splits if R for c in letters]
    inserts = [(L + c + R, weights[1] if (L and c in keyboard_neighbours[L[-1]]) or (R and c in keyboard_neighbours[R[0]]) else weights[0]) for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts + [(word, weights[2])])
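To make the keyboard weighting concrete, the sketch below scores a single-letter replacement the same way `edit_distance_1` does: the neighbouring-key weight if the two letters are adjacent on the keyboard grid, the default weight otherwise, and the unchanged-word weight when nothing is edited. It rebuilds the neighbour map under new, hypothetical names (`NEIGHBOURS`, `replacement_weight`) so it does not overwrite the notebook's variables, and it only handles same-length candidates that differ in at most one letter.

```python
KEYBOARD_ROWS = ["qwertyuiop", "asdfghjkl", "zxcvbnm"]

def build_neighbours(rows):
    """Map each key to the other keys in its surrounding 3x3 block of the row grid."""
    neighbours = {}
    for r, row in enumerate(rows):
        for c, key in enumerate(row):
            near = set()
            for i in range(max(0, r - 1), min(len(rows), r + 2)):
                for j in range(max(0, c - 1), min(len(rows[i]), c + 2)):
                    if rows[i][j] != key:
                        near.add(rows[i][j])
            neighbours[key] = near
    return neighbours

NEIGHBOURS = build_neighbours(KEYBOARD_ROWS)

def replacement_weight(typed, candidate, weights):
    """Weight of turning `typed` into `candidate` with at most one letter replacement.

    weights = [default, neighbouring_key, unchanged], mirroring edit_distance_1.
    """
    if len(typed) != len(candidate):
        raise ValueError("only same-length candidates are handled in this sketch")
    diffs = [(t, c) for t, c in zip(typed, candidate) if t != c]
    if not diffs:
        return weights[2]
    if len(diffs) > 1:
        raise ValueError("candidate is more than one replacement away")
    typed_letter, new_letter = diffs[0]
    return weights[1] if new_letter in NEIGHBOURS.get(typed_letter, set()) else weights[0]

weights = [0.05, 0.1, 5.5]  # the tuned values saved later in the notebook
print(replacement_weight("tewt", "test", weights))  # 0.1  ('s' is next to 'w')
print(replacement_weight("tewt", "text", weights))  # 0.05 ('x' is not)
```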

In [6]:
def interpolated_prob(word, context, lambdas):
    ''' Calculate linear interpolation for 1-5 grams of context and word with given lambdas '''

    total_prob = 0
    
    for n in range(min(N, len(context) + 1)):
        cur_context = context[-n:] if n > 0 else ()
        context_freq = n_grams[n][cur_context]
        prob = n_grams[n+1][(*cur_context, word)] / context_freq if context_freq > 0 else 0
        total_prob += lambdas[n] * prob
    
    return total_prob
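To see how the λ weights combine the different context lengths, here is a small self-contained version of the same computation on invented trigram counts. The function `interpolated_prob_toy` and the counts are hypothetical; the count table is passed in as an argument so that running the cell does not overwrite the notebook's `n_grams` and `N`.

```python
from collections import Counter

def interpolated_prob_toy(word, context, lambdas, counts):
    """Same interpolation as interpolated_prob, with the count table passed in.

    counts[n] maps n-word tuples to frequencies; counts[0][()] is the total count.
    lambdas[n] weights the estimate that uses the last n context words.
    """
    order = len(counts) - 1
    total = 0.0
    for n in range(min(order, len(context) + 1)):
        cur = tuple(context[-n:]) if n > 0 else ()
        denom = counts[n][cur]
        if denom > 0:
            total += lambdas[n] * counts[n + 1][(*cur, word)] / denom
    return total

# Hypothetical toy counts for a trigram model (the notebook uses five-grams).
toy_counts = [Counter() for _ in range(4)]
toy_counts[0][()] = 10
toy_counts[1].update({("the",): 4, ("lazy",): 2, ("dog",): 3, ("cat",): 1})
toy_counts[2].update({("the", "lazy"): 2, ("lazy", "dog"): 2})
toy_counts[3].update({("the", "lazy", "dog"): 2})

lambdas = [0.2, 0.3, 0.5]
# 0.2 * 3/10 + 0.3 * 2/2 + 0.5 * 2/2, i.e. 0.86 up to floating-point rounding
print(interpolated_prob_toy("dog", ("the", "lazy"), lambdas, toy_counts))
# only the unigram term survives for "cat": 0.2 * 1/10, i.e. about 0.02
print(interpolated_prob_toy("cat", ("the", "lazy"), lambdas, toy_counts))
```

An unseen longer context simply contributes nothing, so the shorter contexts still carry the estimate.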

In [7]:
def word_candidates(word, context, weights, lambdas):
    ''' For a given word, context and hyperparameters, return in-vocabulary candidates with their unnormalized scores '''

    # keep only in-vocabulary candidates; if several edits produce the same candidate, keep its largest weight
    candidates = {}
    for cand, weight in edit_distance_1(word, weights):
        if (cand,) in n_grams[1]:
            candidates[cand] = max(candidates.get(cand, weight), weight)
    return [(cand, weight * interpolated_prob(cand, context, lambdas)) for cand, weight in candidates.items()]

In [9]:
def beam_search(tokens, weights, lambdas, beam_width=5):
    ''' Beam search over per-token correction candidates; punctuation tokens are dropped and a sequence's score is the product of its candidates' scores '''

    sequences = [[[], 1.0]]

    for i in range(len(tokens)):
        if tokens[i] in punct:
            continue

        next_sequences = []
        for seq, score in sequences:
            context = tuple(seq[-(N - 1):])
            next_word_candidates = word_candidates(tokens[i], context, weights, lambdas)
            if len(next_word_candidates) == 0:
                next_word_candidates = [(tokens[i], 1)]
            for next_word, next_word_prob in next_word_candidates:
                new_seq = seq + [next_word]
                new_score = score * next_word_prob
                next_sequences.append([new_seq, new_score])
        next_sequences.sort(key=lambda x: x[1], reverse=True)
        sequences = next_sequences[:beam_width]
    return sequences
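`beam_search` multiplies candidate scores, and a product of many factors below 1 can underflow to 0.0 on long inputs, after which all beams tie. A common remedy, sketched here as an alternative rather than what the notebook does, is to add log-scores instead. `beam_search_log` and the `toy_candidates` table are hypothetical; the function takes a candidate generator with the same role as `word_candidates` and leaves punctuation filtering to the caller.

```python
import math

def beam_search_log(tokens, candidates_fn, beam_width=5, context_size=4):
    """Beam search that ranks partial sequences by the sum of log-scores.

    candidates_fn(token, context) returns a list of (word, score) pairs with
    score >= 0, like word_candidates in the notebook.
    """
    beams = [([], 0.0)]  # (words so far, log-score)
    for token in tokens:
        expanded = []
        for seq, log_score in beams:
            context = tuple(seq[-context_size:])
            # no candidates: keep the token unchanged with score 1 (log-score 0)
            options = candidates_fn(token, context) or [(token, 1.0)]
            for word, score in options:
                step = math.log(score) if score > 0 else -math.inf
                expanded.append((seq + [word], log_score + step))
        expanded.sort(key=lambda x: x[1], reverse=True)
        beams = expanded[:beam_width]
    return beams

def toy_candidates(token, context):
    """Hypothetical candidate table that ignores the context."""
    table = {"teh": [("the", 0.9), ("ten", 0.1)], "dgo": [("dog", 0.6), ("ago", 0.4)]}
    return table.get(token, [])

best_words, best_log_score = beam_search_log(["teh", "lazy", "dgo"], toy_candidates)[0]
print(best_words)  # ['the', 'lazy', 'dog']
```

Since the logarithm is monotonic, the ranking is the same as with products whenever the products do not underflow.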

In [11]:
twd = TreebankWordDetokenizer()

def correct_spelling(text, weights, lambdas):
    ''' Correct spelling in the given text '''
    
    tokens = word_tokenize(text.lower())
    best_result = beam_search(tokens, weights, lambdas)[0][0]
    return twd.detokenize(best_result)

In [37]:
# example of correction
correct_spelling("Teh quirk brawn fix jumsp ocwr hte laizy doge", [1, 1.5, 2], [0.1, 0.15, 0.2, 0.25, 0.3])

'the quick brown fox jump ocwr the lazy dog'

In [32]:
# load the sentences used for the dev set and the test set (empty lines are skipped)
with open('test_sentences.txt', 'r', encoding='utf-8') as f:
    test_data = [line for line in f.read().replace('’', '\'').replace('…', '').split('\n') if line.strip()]
test_data[:5]

['The quick brown fox jumps over the lazy dog.',
 'My Mum tries to be cool by saying that she likes all the same things that I do.',
 'A purple pig and a green donkey flew a kite in the middle of the night and ended up sunburnt.',
 'Last Friday I saw a spotted striped blue worm shake hands with a legless lizard.',
 "A song can make or ruin a person's day if they let it get to them."]

In [14]:
# split into dev set and test set
random.seed(1337)
dev_set_len = int(0.2 * len(test_data))
random.shuffle(test_data)
dev_set = test_data[:dev_set_len]
test_set = test_data[dev_set_len:]

In [15]:
cum_prob_swap = 0.15
cum_prob_neighbour = cum_prob_swap + 0.35
cum_prob_skipped = cum_prob_neighbour + 0.2
cum_prob_inserted_neighbour = cum_prob_skipped + 0.2
cum_prob_inserted_random = cum_prob_inserted_neighbour + 0.05

letters = "abcdefghijklmnopqrstuvwxyz"

def misspell_word(word, misspell_prob=0.1):
    ''' Corrupt each letter of the word independently with probability misspell_prob; the error type is drawn using the cumulative thresholds above '''

    new_word = ''
    skip_next = False
    for i, letter in enumerate(word):
        if skip_next:
            skip_next = False
            continue

        if letter.isalpha() and random.random() < misspell_prob:
            misspell_type_roll = random.random()
            if misspell_type_roll < cum_prob_swap and i < len(word) - 1:
                new_word += word[i + 1] + word[i]
                skip_next = True
            elif misspell_type_roll < cum_prob_neighbour:
                new_word += random.choice(list(keyboard_neighbours[letter]))
            elif misspell_type_roll < cum_prob_skipped:
                continue
            elif misspell_type_roll < cum_prob_inserted_neighbour:
                new_word += letter + random.choice(list(keyboard_neighbours[letter]))
            elif misspell_type_roll < cum_prob_inserted_random:
                new_word += letter + random.choice(letters)
            else:
                new_word += random.choice(letters)
        else:
            new_word += letter
    return new_word

def misspell_sentence(sent, misspell_prob=0.1):
    ''' Lowercase and tokenize the sentence, drop punctuation and misspell every token; retry until at least one token changes '''

    tokens = [x for x in word_tokenize(sent.lower()) if x not in punct]
    while True:
        new_tokens = []
        for token in tokens:
            new_tokens.append(misspell_word(token, misspell_prob))
        if new_tokens != tokens:
            break
    return twd.detokenize(new_tokens)
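The cumulative thresholds used by `misspell_word` encode a categorical distribution over error types; the last 5% falls through to the final `else` branch (replacement by a random letter). The sketch below recovers the per-type probabilities and samples from them with `random.Random.choices`. The type names and helpers are hypothetical labels for the branches of `misspell_word`, and the sketch ignores the special case in which a swap drawn on the last letter falls through to a neighbouring-key replacement.

```python
import random

# Cumulative thresholds in the same order as the branches of misspell_word.
CUMULATIVE = [
    ("swap_with_next", 0.15),
    ("replace_with_neighbour", 0.15 + 0.35),
    ("skip", 0.15 + 0.35 + 0.2),
    ("insert_neighbour_after", 0.15 + 0.35 + 0.2 + 0.2),
    ("insert_random_after", 0.15 + 0.35 + 0.2 + 0.2 + 0.05),
]

def per_type_probs(cumulative):
    """Turn cumulative thresholds into per-type probabilities (the rest is a random replacement)."""
    probs, prev = {}, 0.0
    for name, cum in cumulative:
        probs[name] = round(cum - prev, 10)
        prev = cum
    probs["replace_with_random"] = round(1.0 - prev, 10)
    return probs

ERROR_PROBS = per_type_probs(CUMULATIVE)

def sample_error_type(rng):
    """Draw one error type with the probabilities derived above."""
    names = list(ERROR_PROBS)
    return rng.choices(names, weights=[ERROR_PROBS[n] for n in names], k=1)[0]

print(ERROR_PROBS)
print(sample_error_type(random.Random(0)))
```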
    

In [16]:
def calculate_metrics(true_sent, miss_sent, pred_sent):
    ''' Calculate token-level accuracy, precision, recall and F1 score for the correction of one sentence '''

    true_tokens = [x for x in word_tokenize(true_sent.lower()) if x not in punct]
    miss_tokens = word_tokenize(miss_sent.lower())
    pred_tokens = word_tokenize(pred_sent.lower())
    triples = list(zip(true_tokens, miss_tokens, pred_tokens))

    # TP: misspelled token that was fixed; FN: misspelled token left wrong; FP: correct token that was changed
    tp = sum(1 for t, m, p in triples if t != m and t == p)
    fn = sum(1 for t, m, p in triples if t != m and t != p)
    fp = sum(1 for t, m, p in triples if t == m and t != p)

    accuracy = sum(1 for t, p in zip(true_tokens, pred_tokens) if t == p) / len(true_tokens)
    precision = tp / (tp + fp) if tp > 0 else 0
    recall = tp / (tp + fn) if tp > 0 else 0
    f1_score = tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0  # also avoids 0/0 when tp = fp = fn = 0

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1_score,
        }
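A small worked example of the token-level counts, on already-tokenized lists so that it runs without NLTK. The function `token_counts` is a hypothetical helper that applies the same TP/FN/FP rules as `calculate_metrics`; the example sentence is invented.

```python
def token_counts(true_tokens, miss_tokens, pred_tokens):
    """Count TP, FP and FN with the same rules as calculate_metrics."""
    tp = fp = fn = 0
    for t, m, p in zip(true_tokens, miss_tokens, pred_tokens):
        if t != m:            # the token was misspelled...
            if p == t:
                tp += 1       # ...and corrected
            else:
                fn += 1       # ...and left wrong (unchanged or mis-corrected)
        elif p != t:          # a correct token was changed by the corrector
            fp += 1
    return tp, fp, fn

true_tokens = ["the", "quick", "brown", "fox"]
miss_tokens = ["teh", "quick", "brwn", "fox"]
pred_tokens = ["the", "quick", "brawn", "fix"]

tp, fp, fn = token_counts(true_tokens, miss_tokens, pred_tokens)
precision = tp / (tp + fp) if tp > 0 else 0
recall = tp / (tp + fn) if tp > 0 else 0
f1 = tp / (tp + 0.5 * (fp + fn)) if tp > 0 else 0
accuracy = sum(t == p for t, p in zip(true_tokens, pred_tokens)) / len(true_tokens)
print(tp, fp, fn)                       # 1 1 1
print(precision, recall, f1, accuracy)  # 0.5 0.5 0.5 0.5
```

Note that a misspelled token corrected to the wrong word ("brwn" to "brawn") counts only as a false negative, not as a false positive.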

### Genetic Algorithm for parameter tuning

In [17]:
w_range = np.arange(1, 2, 0.05)
lambda_range = np.arange(0.1, 1, 0.05)

def generate_gene():
    ''' Generate a random gene of hyperparameters in the form [default_weight, neighbouring_key_weight, self_weight (word left unchanged), 5x lambdas] '''

    ws = [round(random.choice(w_range), 2) for _ in range(3)]
    ls = [round(random.choice(lambda_range), 2) for _ in range(5)]
    return ws + ls

def mutate(gene):
    ''' Shift each value by -0.05 (probability 0.2) or +0.05 (probability 0.2), otherwise keep it '''

    new_gene = []
    for x in gene:
        prob = random.random()
        mult = 0
        if prob < 0.2:
            mult = -1
        elif prob > 0.8:
            mult = 1
        new_gene.append(round(x + mult * 0.05, 2))
    return new_gene

def merge(gene1, gene2):
    ''' Single-point crossover of two genes; returns the two children '''
    
    x = random.randint(1, len(gene1) - 1)
    child1 = gene1[:x] + gene2[x:]
    child2 = gene2[:x] + gene1[x:]
    return [child1, child2]

In [18]:
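def objective_function(gene):
    ''' Function to be optimized by the genetic algorithm (mean F1 score on the dev set) '''

    # Reseed so that every gene is scored on the same misspelled dev sentences, then restore
    # the previous RNG state so the GA's own random choices are not reset by every evaluation
    rng_state = random.getstate()
    random.seed(1337)
    cur_f1 = 0
    for test_sent in dev_set:
        misspelled = misspell_sentence(test_sent)
        corrected = correct_spelling(misspelled, gene[:3], gene[3:])
        cur_f1 += calculate_metrics(test_sent, misspelled, corrected)['f1']
    random.setstate(rng_state)

    return cur_f1 / len(dev_set)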

In [None]:
NUM_ITERATIONS = 20
POPULATION_SIZE = 16
BEST_SIZE = 8

# initial population
population = sorted([generate_gene() for _ in range(POPULATION_SIZE)], key=objective_function, reverse=True)

progress_bar = tqdm(range(NUM_ITERATIONS))

for i in progress_bar:
    # each iteration: cross over pairs of the best genes, mutate the children, and add new random genes for diversity
    for j in range(0, BEST_SIZE, 2):
        child1, child2 = merge(population[j], population[j + 1])
        population.append(mutate(child1))
        population.append(mutate(child2))
        population.append(generate_gene())

    population = sorted(population, key=objective_function, reverse=True)[:POPULATION_SIZE]

    progress_bar.set_description(f'Best F1: {objective_function(population[0])}')
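`objective_function` is deterministic for a given gene, because it reseeds the random generator before misspelling the dev set, yet the loop above re-evaluates every surviving gene on each `sorted` call and once more for the progress bar. A memoizing wrapper, sketched below with hypothetical names, lets each distinct gene be scored only once; genes are lists, so they are converted to tuples to serve as cache keys.

```python
from functools import lru_cache

def make_cached_objective(objective):
    """Wrap `objective` so that each distinct gene is evaluated only once."""

    @lru_cache(maxsize=None)
    def cached(gene_key):
        return objective(list(gene_key))

    def wrapper(gene):
        return cached(tuple(gene))  # lists are unhashable, tuples can be cache keys

    wrapper.cache_info = cached.cache_info
    return wrapper

# Toy usage with a cheap stand-in objective that records its calls.
calls = []

def toy_objective(gene):
    calls.append(gene)
    return sum(gene)

cached_objective = make_cached_objective(toy_objective)
genes = [[1.0, 0.5], [0.2, 0.3], [1.0, 0.5]]
ranked = sorted(genes, key=cached_objective, reverse=True)
best_score = cached_objective(ranked[0])  # served from the cache
print(len(calls))  # 2: the duplicate gene and the final call hit the cache
```

In the notebook, `make_cached_objective(objective_function)` would replace `objective_function` as the sort key and in the progress-bar description.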


In [None]:
best_weights = population[0][:3]
best_lambdas = population[0][3:]

In [20]:
# saved best gene, so I don't have to run genetic algorithm every time
best_gene = [0.05, 0.1, 5.5, 0.2, 0.35, 0.4, 0.25, 0.6]
best_weights = best_gene[:3]
best_lambdas = best_gene[3:]

## Justify your decisions

Write down justifications for your implementation choices. For example, these choices could be:
- Which ngram dataset to use
- Which weights to assign for edit1, edit2 or absent words probabilities
- Beam search parameters
- etc.

- For the language model I used the [five-grams dataset](https://www.ngrams.info/download_coca.asp) from COCA. Not only five-grams are used: the 1- to 4-gram counts are derived from the sub-sequences of the five-grams.
- I used linear interpolation: the probability of a word is a weighted sum of its unigram, bigram, ..., five-gram probabilities, each weighted by a hyperparameter λ. This helps because a long context is often absent from the data, and a shorter context can then still give a useful estimate.
- Only candidates at edit distance 1 are considered, because edit distance 2 produces many times more candidates and therefore takes much more time to process.
- The edit-distance-1 candidates take the keyboard layout into account. For example, "tewt" is more likely to be "test" than "text", because "s" is adjacent to "w" on a QWERTY keyboard while "x" is not.
- The unedited word has its own weight (the third hyperparameter), so a correctly spelled in-vocabulary word can outrank its edits. A word with no in-vocabulary candidate at edit distance 1 (itself included) is kept unchanged rather than dropped, as happens to "ocwr" in the example above.
- For decoding, beam search with a beam width of 5 was used. The width was chosen to match the five-gram order, although beam width and n-gram order are independent parameters.
- All hyperparameters (the weight of each edit type and the λ values of the linear interpolation) were tuned on the dev set with a genetic algorithm, using the F1 score as the objective function.
- Sentences from [this dataset](https://www.kaggle.com/datasets/nikitricky/random-english-sentences/data) were artificially misspelled and split into a dev set (20%) and a test set (80%).
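Put together, the score that `word_candidates` assigns to a candidate $c$ can be written as below. Here $h_1, \dots, h_k$ are the last $k \le 4$ words already chosen by beam search, $C(\cdot)$ are the counts derived from the five-grams (for $n = 0$ the denominator is the total count stored under the empty tuple), $w(c)$ is the edit weight (default, neighbouring key, or unchanged word), and $\lambda_n$ weights the estimate that uses the last $n$ context words. Terms whose context count is zero contribute nothing, and beam search ranks a sentence by the product of its candidates' scores.

$$
\operatorname{score}(c \mid h) = w(c) \sum_{n=0}^{k} \lambda_n \, \frac{C(h_{k-n+1}, \dots, h_k, c)}{C(h_{k-n+1}, \dots, h_k)},
\qquad
\operatorname{score}(c_1, \dots, c_m) = \prod_{i=1}^{m} \operatorname{score}(c_i \mid h^{(i)})
$$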

## Evaluate on a test set

Your task is to generate a test set and evaluate your work. You may vary the noise probability to generate different datasets with varying complexity. Compare your solution to Norvig's corrector, and report the accuracies.

In [29]:
# Norvig's solution, copied for comparison; norvig_correct_sentence is added to apply it token by token

import re

def words(text): return re.findall(r'\w+', text.lower())

WORDS = Counter(words(open('big.txt').read()))

def P(word, N=sum(WORDS.values())): 
    "Probability of `word`."
    return WORDS[word] / N

def norvig_correction(word): 
    "Most probable spelling correction for word."
    return max(candidates(word), key=P)

def candidates(word): 
    "Generate possible spelling corrections for word."
    return (known([word]) or known(edits1(word)) or known(edits2(word)) or [word])

def known(words): 
    "The subset of `words` that appear in the dictionary of WORDS."
    return set(w for w in words if w in WORDS)

def edits1(word):
    "All edits that are one edit away from `word`."
    letters    = 'abcdefghijklmnopqrstuvwxyz'
    splits     = [(word[:i], word[i:])    for i in range(len(word) + 1)]
    deletes    = [L + R[1:]               for L, R in splits if R]
    transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
    replaces   = [L + c + R[1:]           for L, R in splits if R for c in letters]
    inserts    = [L + c + R               for L, R in splits for c in letters]
    return set(deletes + transposes + replaces + inserts)

def edits2(word): 
    "All edits that are two edits away from `word`."
    return (e2 for e1 in edits1(word) for e2 in edits1(e1))

def norvig_correct_sentence(text):
    tokens = word_tokenize(text)
    corrected_tokens = [norvig_correction(x) for x in tokens]
    return twd.detokenize(corrected_tokens)

In [42]:
# testing language model performance on text with different misspell density
misspell_probabilities = [0.05, 0.1, 0.2, 0.5, 0.9]

for misspell_prob in misspell_probabilities:
    my_metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}
    norvig_metrics = {'accuracy': 0.0, 'precision': 0.0, 'recall': 0.0, 'f1': 0.0}

    random.seed(1337)

    for test_sent in tqdm(test_set):
        misspelled = misspell_sentence(test_sent, misspell_prob=misspell_prob)

        corrected = correct_spelling(misspelled, best_weights, best_lambdas)
        my_metrics = {k: my_metrics[k] + v for k, v in calculate_metrics(test_sent, misspelled, corrected).items()}

        norvig_corrected = norvig_correct_sentence(misspelled)
        norvig_metrics = {k: norvig_metrics[k] + v for k, v in calculate_metrics(test_sent, misspelled, norvig_corrected).items()}

    my_metrics = {k: v / len(test_set) for k, v in my_metrics.items()}
    norvig_metrics = {k: v / len(test_set) for k, v in norvig_metrics.items()}

    print(f'Misspelling Probability: {misspell_prob}\nMy metrics:       {[f"{k}: {v:.5f}" for k, v in my_metrics.items()]}\nNorvig\'s metrics: {[f"{k}: {v:.5f}" for k, v in norvig_metrics.items()]}')


Misspelling Probability: 0.05
My metrics:       ['accuracy: 0.88609', 'precision: 0.74256', 'recall: 0.73940', 'f1: 0.71013']
Norvig's metrics: ['accuracy: 0.87026', 'precision: 0.70586', 'recall: 0.63633', 'f1: 0.63557']



Misspelling Probability: 0.1
My metrics:       ['accuracy: 0.82201', 'precision: 0.80125', 'recall: 0.64749', 'f1: 0.68672']
Norvig's metrics: ['accuracy: 0.81258', 'precision: 0.79899', 'recall: 0.59108', 'f1: 0.64914']



Misspelling Probability: 0.2
My metrics:       ['accuracy: 0.67882', 'precision: 0.88169', 'recall: 0.51326', 'f1: 0.62401']
Norvig's metrics: ['accuracy: 0.67784', 'precision: 0.89506', 'recall: 0.49940', 'f1: 0.61349']



Misspelling Probability: 0.5
My metrics:       ['accuracy: 0.30551', 'precision: 0.86347', 'recall: 0.23515', 'f1: 0.35514']
Norvig's metrics: ['accuracy: 0.31745', 'precision: 0.86894', 'recall: 0.24615', 'f1: 0.36802']



Misspelling Probability: 0.9
My metrics:       ['accuracy: 0.06589', 'precision: 0.48534', 'recall: 0.05993', 'f1: 0.10443']
Norvig's metrics: ['accuracy: 0.06111', 'precision: 0.47198', 'recall: 0.05396', 'f1: 0.09478']


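My solution outperforms Norvig's at low misspelling rates (F1 0.710 vs 0.636 at a misspelling probability of 0.05, and 0.687 vs 0.649 at 0.1) and performs about the same at higher rates (slightly better at 0.2 and 0.9, slightly worse at 0.5), even though it only considers candidates at edit distance 1 while Norvig's solution also considers edit distance 2. This suggests that adding edit-distance-2 candidates, at a higher computational cost, could improve my solution further.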