In [1]:
import nltk
import string
import random
import math

import pandas as pd
import Levenshtein as lev

from nltk import sent_tokenize
from nltk import word_tokenize
from nltk.corpus import reuters
from collections import Counter
from datasets import load_metric
from tqdm import tqdm

_ = nltk.download("reuters")

[nltk_data] Downloading package reuters to /Users/antonal/nltk_data...
[nltk_data]   Package reuters is already up-to-date!


## Data Preparation

In [2]:
# Function to create vocabulary from text tokens
def create_vocab(text_tokens, min_count=10):
    # compute the count of each token and create the vocabulary from those whose count is 10 or more
    filtered_tokens_freq = nltk.FreqDist(text_tokens)
    created_vocabulary = [k for k, v in filtered_tokens_freq.items() if v >= min_count]
    return created_vocabulary

In [3]:
# Get the file ids of the documents in the training subset
training_file_ids = [
    file_id for file_id in reuters.fileids() if file_id.startswith("training/")
]
training_text = reuters.raw(training_file_ids)

In [4]:
# get the tokens for the vocabulary creation
tokens = word_tokenize(training_text)

# Create the vocabulary
vocabulary = create_vocab(tokens)

In [5]:
# Replace the tokens in sentences with *UNK*
sentences = sent_tokenize(training_text)

# Tokenize the sentences
tokenized_sentences = [word_tokenize(sentence) for sentence in sentences]

# Replace token which are not in the vocabulary with "*UNK*"
processed_sentences = [
    [token if token in vocabulary else "*UNK*" for token in tokenized_sentence]
    for tokenized_sentence in tokenized_sentences
]

## N-gram Language Model

In [6]:
# Initialize the unigram, bigram, and trigram counters
unigram_counter = Counter()
bigram_counter = Counter()
trigram_counter = Counter()

In [7]:
for sent in processed_sentences:
    # Update the unigram counter
    unigram_counter.update([(gram,) for gram in ["<s>"] + sent])

    # Update the bigram counter
    bigram_pad_sent = ["<s>"] + sent + ["<e>"]
    bigram_counter.update(
        [(gram1, gram2) for gram1, gram2 in zip(bigram_pad_sent, bigram_pad_sent[1:])]
    )

    # Update the trigram counter
    trigram_pad_sent = ["<s>"] * 2 + sent + ["<e>"] * 2
    trigram_counter.update(
        [
            (gram1, gram2, gram3)
            for gram1, gram2, gram3 in zip(
                trigram_pad_sent, trigram_pad_sent[1:], trigram_pad_sent[2:]
            )
        ]
    )

In [8]:
# The 5 most common unigrams
unigram_counter.most_common(5)

[(('*UNK*',), 93422),
 (('the',), 43182),
 ((',',), 39586),
 (('<s>',), 37700),
 (('.',), 37651)]

In [9]:
# The 5 most common bigrams
bigram_counter.most_common(5)

[(('.', '<e>'), 36457),
 (('*UNK*', '*UNK*'), 9364),
 (('<s>', '*UNK*'), 7194),
 (('<s>', 'The'), 6600),
 (('&', 'lt'), 6300)]

In [10]:
# The 5 most common trigrams
trigram_counter.most_common(5)

[(('.', '<e>', '<e>'), 36457),
 (('<s>', '<s>', '*UNK*'), 7194),
 (('<s>', '<s>', 'The'), 6600),
 (('&', 'lt', ';'), 6300),
 (('said', '.', '<e>'), 5924)]

In [11]:
def calculate_bigram_prob(
    bigram_vocabulary, bigram_counter, unigram_counter, alpha, first_word, second_word
):
    # Calculate vocab size
    bigram_vocab_size = len(bigram_vocabulary)

    # Bigram prob + laplace smoothing
    bigram_prob = (bigram_counter[(first_word, second_word)] + alpha) / (
        unigram_counter[(first_word,)] + alpha * bigram_vocab_size
    )

    # Calculate log probability
    bigram_log_prob = math.log2(bigram_prob)

    return bigram_prob, bigram_log_prob

In [12]:
def calculate_trigram_prob(
    vocabulary, trigram_counter, bigram_counter, alpha, word1, word2, word3
):
    # Calculate vocab size
    vocab_size = len(vocabulary)

    # Bigram prob + laplace smoothing
    trigram_prob = (trigram_counter[(word1, word2, word3)] + alpha) / (
        bigram_counter[
            (
                word1,
                word2,
            )
        ]
        + alpha * vocab_size
    )

    # Calculate log probability
    trigram_log_prob = math.log2(trigram_prob)

    return trigram_prob, trigram_log_prob

In [13]:
# Get the file ids of the documents in the testing subset
testing_file_ids = [
    file_id for file_id in reuters.fileids() if file_id.startswith("test/")
]
testing_text = reuters.raw(testing_file_ids[:100])
testing_sentences = sent_tokenize(testing_text)

tokenized_testing_sentences = [
    word_tokenize(sentence) for sentence in testing_sentences
]

# Replace token which are not in the vocabulary with "*UNK*"
processed_testing_sentences = [
    [token if token in vocabulary else "*UNK*" for token in tokenized_sentence]
    for tokenized_sentence in tokenized_testing_sentences
]

In [14]:
sum_log_prob = 0
bigram_cnt = 0
alpha = 1

for sent in processed_testing_sentences:
    sent = ["<s>"] + sent + ["<e>"]

    # Iterate over the bigrams of the sentence
    for idx in range(1, len(sent)):
        bigram_prob, bigram_log_prob = calculate_bigram_prob(
            vocabulary, bigram_counter, unigram_counter, alpha, sent[idx - 1], sent[idx]
        )

        sum_log_prob += bigram_log_prob
        bigram_cnt += 1

HC = -sum_log_prob / bigram_cnt
perpl = math.pow(2, HC)
print("Bigram:")
print("\tCross Entropy: {0:.3f}".format(HC))
print("\tPerplexity: {0:.3f}".format(perpl))

Bigram:
	Cross Entropy: 8.353
	Perplexity: 327.069


In [15]:
sum_log_prob = 0
trigram_cnt = 0
alpha = 1

for sent in processed_testing_sentences:
    sent = ["<s>"] + ["<s>"] + sent + ["<e>"] + ["<e>"]

    for idx in range(2, len(sent) - 1):
        trigram_prob, trigram_log_prob = calculate_trigram_prob(
            vocabulary,
            trigram_counter,
            bigram_counter,
            alpha,
            sent[idx - 2],
            sent[idx - 1],
            sent[idx],
        )
        sum_log_prob += trigram_log_prob
        trigram_cnt += 1

HC = -sum_log_prob / trigram_cnt
perpl = math.pow(2, HC)
print("Trigram:")
print("\tCross Entropy: {0:.3f}".format(HC))
print("\tPerplexity: {0:.3f}".format(perpl))

Trigram:
	Cross Entropy: 10.356
	Perplexity: 1310.347


## Noisy Text Generation

In [16]:
# method to create a new text with random errors
def wrong_text_creator(text, error_probability=0.05):
    result = []
    for sentence in text:
        wrong_sent = []
        for char in sentence:
            if char.isspace():
                wrong_sent.append(char)
            elif char.isalpha():
                if random.random() < error_probability:
                    wrong_sent.append(random.choice(string.ascii_letters))
                else:
                    wrong_sent.append(char)
            elif char.isnumeric():
                if random.random() < error_probability:
                    wrong_sent.append(random.choice(string.digits))
                else:
                    wrong_sent.append(char)
            elif char in string.punctuation:
                if random.random() < error_probability:
                    wrong_sent.append(random.choice(string.punctuation))
                else:
                    wrong_sent.append(char)
            else:
                wrong_sent.append(char)
        result.append("".join(wrong_sent))

    return result

In [17]:
# Store the train bigram probabilities for beam search decoder function
bigram_probs_dict = {}
alpha = 1
vocab_size = len(set(vocabulary))
for set_of_2 in list(bigram_counter.items()):
    bigram_prob = (set_of_2[1] + alpha) / (
        unigram_counter[(set_of_2[0][0],)] + alpha * vocab_size
    )
    bigram_probs_dict[set_of_2[0]] = bigram_prob

## Beam Search Decoder

In [18]:
# Beam search decoder
def beam_search_decoder(
    input_sentence, bigram_probabilities, vocabulary, max_depth, beam_size=2
):
    input_tokens = ["<s>"] + word_tokenize(input_sentence) + ["<e>"]
    lambda_1 = 0.5
    lambda_2 = 0.5

    beam = [([], 0)]
    for current_word in input_tokens[1:]:
        candidates = []
        for candidate, candidate_score in beam:
            for vocab_word in vocabulary:
                temp_candidate = "<s>" if not candidate else candidate[-1]
                bigram_prob = bigram_probabilities.get((temp_candidate, vocab_word), 0)
                distance = lev.distance(current_word, vocab_word)
                word_similarity = max(len(current_word), len(vocab_word)) - distance
                new_score = (
                    candidate_score
                    + lambda_1 * math.log2(1 + bigram_prob)
                    + lambda_2 * math.log2(1 + word_similarity)
                )
                new_candidate = candidate + [vocab_word]
                candidates.append((new_candidate, new_score))
        candidates.sort(key=lambda x: x[1], reverse=True)
        beam = candidates[:beam_size]
        if len(beam[0][0]) >= max_depth:
            break

    return " ".join(beam[0][0])

In [19]:
corrected_sentences = []
# create the faulty text
noisy_sentences = wrong_text_creator(testing_sentences)

for wrong_sentence in tqdm(noisy_sentences, desc="Correcting sentences"):
    corrected_sentence = beam_search_decoder(
        wrong_sentence, bigram_probs_dict, vocabulary, len(wrong_sentence)
    )
    corrected_sentences.append(corrected_sentence)

Correcting sentences: 100%|██████████| 619/619 [11:13<00:00,  1.09s/it]


In [20]:
# print the original, noisy and corrected sentences into a dataframe so that we can compare them
pd.DataFrame(
    {
        "Original": testing_sentences,
        "Noisy": noisy_sentences,
        "Corrected": corrected_sentences,
    }
)

Unnamed: 0,Original,Noisy,Corrected
0,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-JAPAN RI...,ASIAN EXPORTERS FEAR DAMAGE FROM U.S.-yAPAN RI...,AUSTRALIAN EXPORTERS FEBRUARY DAMAGE FROM U.S....
1,They told Reuter correspondents in Asian capit...,They told Reuter correspondents in Asian cBQit...,They told Reuters corresponding in Australian ...
2,But some exporters said that while the conflic...,But some exporters Saiw that while the cenflic...,But some exporters Statistics that while the c...
3,The U.S. Has said it will impose 300 mln dlrs ...,Tce U.S. Has said it will impose 300 mln dlrO ...,Thatcher U.S. Has said it will impose 300 mln ...
4,Unofficial Japanese estimates put the impact o...,USofficiEl Japanese estimates put the impaIt o...,official Japanese estimates put the impact on ...
...,...,...,...
614,A dividend of 11 marks would\n be proposed fo...,A Nividend of 11 marks would\n be proposed fo...,A dividend of 1.1 markets would be proposed fo...
615,Share analysts said they saw supervisory board...,Share analysts said they saw superJisory board...,Share analysts said they saw superior board ap...
616,"""Anything else would be more than a surprise,""...","""Anything else would be more than a surpFise,""...",`` anything else would be more than a surprise...
617,Company sources said VW would have to dig into...,lompany sources sOid VW woQld haLe to dig iVto...,Company sources said Venezuela would have to d...


### Evaluation

In [21]:
# Character error rate metric
cer = load_metric("cer")
# Word error rate metric
wer = load_metric("wer")

truth_sentences = testing_sentences

# Compute the WER and CER scores
wer_score = wer.compute(predictions=corrected_sentences, references=truth_sentences)
cer_score = cer.compute(predictions=corrected_sentences, references=truth_sentences)

print("WER score: {0:.3f}".format(wer_score))
print("CER score: {0:.3f}".format(cer_score))

  cer = load_metric("cer")


Downloading builder script:   0%|          | 0.00/2.16k [00:00<?, ?B/s]

ValueError: Loading cer requires you to execute the dataset script in that repo on your local machine. Make sure you have read the code there to avoid malicious use, then set the option `trust_remote_code=True` to remove this error.