EXERCISE 3

In [1]:
!pip install nltk



In [2]:
import nltk
# nltk.download('reuters')
# nltk.download('punkt')
from nltk.corpus import reuters
from nltk.probability import FreqDist
from sklearn.model_selection import train_test_split
from collections import Counter
from nltk.util import ngrams
from nltk.tokenize import sent_tokenize
from nltk.tokenize import word_tokenize
import math
from more_itertools import windowed




In [3]:
# Load the 'reuters' corpus
sentences = reuters.sents()

In [4]:
# Splitting data into Training, Development and Test set
train_sents, test_sents = train_test_split(reuters.sents(), test_size=0.3, random_state=42)
dev_sents, test_sents = train_test_split(test_sents, test_size=0.5, random_state=42)


In [5]:
print(f'Number of sentences in train set: {len(train_sents)}')

Number of sentences in train set: 38301


In [6]:
# Transform the train sentences into words
train_words = [word for sentence in train_sents for word in sentence]
freq_dist_train = FreqDist(train_words)

In [7]:
cleaned_train_sentences = []
for sentence in train_sents:
    cleaned_train_sentence = [word if freq_dist_train[word] > 10 else '<UNK>' for word in sentence]
    cleaned_train_sentences.append(cleaned_train_sentence)

Now lets Build our model

In [8]:
from collections import Counter
from nltk.util import ngrams
from pprint import pprint

unigram_counter = Counter()
bigram_counter = Counter()
trigram_counter = Counter()

for sent in cleaned_train_sentences:
    unigram_counter.update([gram for gram in ngrams(sent, 1, pad_left=True, pad_right=True,
                                                   left_pad_symbol='<s>',right_pad_symbol='<e>') ])
    bigram_counter.update([gram for gram in ngrams(sent, 2, pad_left=True, pad_right=True,
                                                   left_pad_symbol='<s>',right_pad_symbol='<e>') ])
    trigram_counter.update([gram for gram in ngrams(sent, 3, pad_left=True, pad_right=True,
                                                   left_pad_symbol='<s>',right_pad_symbol='<e>') ])
# pprint(unigram_counter.most_common(10))
pprint(bigram_counter.most_common(10))
print('Most common trigrams')
pprint(trigram_counter.most_common(10))

[(('.', '<e>'), 34142),
 (('<s>', '<UNK>'), 8218),
 (('<UNK>', '<UNK>'), 7971),
 ((',', '000'), 7220),
 (("'", 's'), 6427),
 (('<s>', 'The'), 6167),
 (('lt', ';'), 6057),
 (('&', 'lt'), 6055),
 (('said', '.'), 5581),
 (('<UNK>', ','), 5060)]
Most common trigrams
[(('.', '<e>', '<e>'), 34142),
 (('<s>', '<s>', '<UNK>'), 8218),
 (('<s>', '<s>', 'The'), 6167),
 (('&', 'lt', ';'), 6054),
 (('said', '.', '<e>'), 5580),
 (('lt', ';', '<UNK>'), 4843),
 (('U', '.', 'S'), 3977),
 (('.', 'S', '.'), 3726),
 ((';', '<UNK>', '>'), 3027),
 (('<s>', '<s>', '"'), 2528)]


In [9]:
# Build the vocab
vocab = [word[0] for word in unigram_counter]
print(f'Number of tokens in train set: {len(vocab)}')

Number of tokens in train set: 7109


In [22]:
# Transform the test sentences into words
test_words = [word for sentence in test_sents for word in sentence]
freq_dist_test = FreqDist(test_words)

# Replace rare words or Out-of-Vocabulary words in test set
cleaned_test_sentences = []
for sentence in test_sents:
    # cleaned_test_sentence = [word if word in vocab else '<UNK>' for word in sentence]
    cleaned_test_sentence = ['<UNK>' if freq_dist_test[word] <= 10 or word not in train_words else word for word in sentence]
    cleaned_test_sentences.append(cleaned_test_sentence)

print(cleaned_test_sentences[:3])

[['The', 'commission', 'is', 'expected', 'to', '<UNK>', 'the', '<UNK>', 'at', 'a', 'meeting', 'tomorrow', '.'], ['"', 'The', 'United', 'States', 'and', 'the', 'six', 'major', 'industrial', 'countries', 'are', 'fully', 'committed', 'to', '<UNK>', 'our', '<UNK>', 'in', 'these', 'agreements', ',"', 'Baker', 'told', 'the', 'meetings', '.'], ['<UNK>', '<UNK>', 'said', 'in', 'January', 'it', 'was', 'seeking', 'bids', 'for', 'the', 'property', '.']]


In [23]:
def calculate_ngram_probability(ngram_counter, ngram_minus_one_counter, ngram, alpha, vocab_size):
    """
    Calculate bigram probability with Laplace smoothing
    :param ngram_counter: Counter which the key is a tuple of ngram and value its frequency
    :param ngram_minus_one_counter: Counter which the key is a tuple of n-1gram and value its frequency
    :param ngram: tuple
    :param alpha: float hyperparameter for Laplace smoothing
    :param vocab_size: int value which defines the whole size of the corpus
    :return: float probability of the ngram inside the corpus
    """
    ngram_count = ngram_counter[ngram]
    context = ngram[:-1]
    ngram_minus_one_count = ngram_minus_one_counter[context]
    ngram_prob = (ngram_count + alpha) / (ngram_minus_one_count + (alpha * vocab_size))
    # Convert to log probability
    ngram_prob = math.log2(ngram_prob)
    return ngram_prob

In [52]:
def calc_kneser_ney_proba(ngram_counter, ngram_minus_one_counter, continuation_counts, total_continuations, ngram, delta):
    """
    Calculate ngram probability with simplified Kneser-Ney smoothing for bigrams or trigrams
    :param ngram_counter: Counter for ngrams (bigrams or trigrams)
    :param ngram_minus_one_counter: Counter for n-1 grams
    :param continuation_counts: Counter for continuation counts
    :param total_continuations: Total number of unique continuations
    :param ngram: tuple representing the ngram (bigram or trigram)
    :param delta: discount value
    :return: float probability of the ngram
    """
    ngram_count = ngram_counter[ngram]
    context = ngram[:-1]
    ngram_minus_one_count = ngram_minus_one_counter[context]

    adjusted_count = max(ngram_count - delta, 0)
    epsilon = 1e-10

    # For bigrams, use the second token for continuation, for trigrams use the third token
    continuation_token = ngram[-1]

    # Calculate our interpolation weight
    continuation_prob = continuation_counts[continuation_token] / total_continuations

    alpha_weight = delta * prefixes_counter[(context)] / (ngram_minus_one_count + epsilon)
    # Kneser-Ney probability
    kn_probability = (adjusted_count+epsilon) / (ngram_minus_one_count + epsilon) + alpha_weight * continuation_prob
    return math.log2(kn_probability)


In [44]:
prefixes_counter[('to','be')]
# len(set([ng for ng in trigram_counter if ng[:-1] == ('to','be')]))

0

In [41]:
to be continued
to be not

SyntaxError: invalid syntax (<ipython-input-41-dd2745a0b168>, line 1)

PERPLEXITY

In [53]:
continuation_counts = Counter([bigram[1] for bigram in bigram_counter])
total_continuations = len(continuation_counts)

# Convert list of n-grams to a list of tuples
ngram_tuples = [tuple(ng) for ng in bigram_counter]

# Create a Counter for the prefixes
prefixes_counter = Counter(ng[:-1] for ng in set(ngram_tuples))

num_tokens = sum(len(sent) + 1 for sent in cleaned_test_sentences)

In [54]:
# print(len(trigram_counter))

In [55]:
from tqdm import tqdm

# Calculate continuation counts
continuation_counts = Counter([bigram[1] for bigram in bigram_counter])
total_continuations = len(continuation_counts)
print(len(continuation_counts))
total_log_proba_bigram_kn = 0.0
delta = 0.5
with tqdm(total=256000) as pbar:  # Check our time and iters remaining!
    for sent in cleaned_test_sentences:
        padded_sent = ['<s>'] + sent + ['<e>']

        for first_token, second_token in windowed(padded_sent, 2):
            if first_token == '<s>': # Avoid calculating that because unigram counter does not does not have counts for <s>
                pass
            else:
                bigram = (first_token, second_token)
                bigram_prob = calc_kneser_ney_proba(bigram_counter, unigram_counter, continuation_counts, total_continuations, bigram, delta)
                total_log_proba_bigram_kn += bigram_prob
                if bigram_prob>0:
                    print(bigram_prob)
                pbar.update(1)  # Update the progress bar

cross_entropy_bigram_kn = - total_log_proba_bigram_kn / num_tokens
print(f"The total Cross-Entropy of bigram model via Kneser-Ney smoothing for our Test set is: {cross_entropy_bigram_kn: .3f}")

# Calculation of the perplexity of bigram model for the test set via Kneser-Ney smoothing
bigram_perplexity_kn = 2 ** (cross_entropy_bigram_kn)
print(f"Perplexity of bigram model for Test Set: {bigram_perplexity_kn:.3f}")

7110


256664it [00:01, 237519.52it/s]                            

The total Cross-Entropy of bigram model via Kneser-Ney smoothing for our Test set is:  4.841
Perplexity of bigram model for Test Set: 28.654





In [56]:
# continuation_counts = Counter([bigram[1] for bigram in bigram_counter])
# total_continuations = len(continuation_counts)

# Convert list of n-grams to a list of tuples
ngram_tuples = [tuple(ng) for ng in trigram_counter]

# Create a Counter for the prefixes
prefixes_counter = Counter(ng[:-1] for ng in ngram_tuples)

In [50]:
print(len(ngram_tuples))

578370


In [51]:
print(len(set(ngram_tuples)))

578370


In [None]:
[ng[:-1] for ng in ngram_tuples]

In [57]:
# Calculation of total tokens for test set, including only 'end' token for each sentence
num_tokens = sum(len(sent) + 1 for sent in cleaned_test_sentences)
# Calculate continuation counts
continuation_counts_tri = Counter([trigram[2] for trigram in trigram_counter])
total_continuations_tri = len(continuation_counts_tri)
print(continuation_counts_tri)
total_log_proba_trigram_kn = 0.0
delta = 0.75

with tqdm(total=200000) as pbar:  # Check our time and iters remaining!
    for sent in cleaned_test_sentences:
        padded_sent = ['<s>'] + ['<s>'] + sent + ['<e>']

        for first_token, second_token, third_token in windowed(padded_sent, 3):
            if first_token == '<s>' and second_token == '<s>': # Avoid calculating that because bigram counter does not have counts for <s>, <s>
                pass
            else:
                trigram = (first_token, second_token, third_token)
                trigram_prob = calc_kneser_ney_proba(trigram_counter, bigram_counter, continuation_counts_tri, total_continuations_tri,
                                                     trigram, delta)
                total_log_proba_trigram_kn += trigram_prob
                pbar.update(1)  # Update the progress bar

cross_entropy_trigram_kn = - total_log_proba_trigram_kn / num_tokens
print(f"The total Cross-Entropy of trigram model for our Test set is: {cross_entropy_trigram_kn: .3f}")

# Calculation of the perplexity of bigram model for the test set
trigram_perplexity_kn = 2 ** (cross_entropy_trigram_kn)
print(f"Perplexity of trigram model for Test Set: {trigram_perplexity_kn:.3f}")



256664it [00:01, 200216.46it/s]

The total Cross-Entropy of trigram model for our Test set is:  2.473
Perplexity of trigram model for Test Set: 5.553





In [101]:
bigram_counter[('application', 'at')]

0

# BIGRAM GENERATE NEXT WORD GIVEN SEQUENCE

In [58]:
def generate_candidates(state, ngram_counter, model):
    """
    Given the state calculate the next possible words
    - state: The current word sequence
    - ngram_counter: Counter which the key is a tuple of n-1gram and value its frequency

    Returns:
    - Next state
    """
    # if ngram_counter = trigram_counter
    ngram_width = 1
    if model == 'trigram':
      ngram_width = 2
    prev_words = tuple(state[-ngram_width:])

    # Find candidates words
    next_words = [prev_words_tuple[-1] for (prev_words_tuple) in ngram_counter if prev_words == tuple(prev_words_tuple[:-1])]

    return [state + [next_word] for next_word in next_words]



In [59]:
def score(state, vocab_size, alpha, ngram_counter, ngram_minus_one_counter,length, model='trigram', dist=0, l1=1, l2=0, calculate_ngram_probability_fn=calculate_ngram_probability):
    """
    Calculate the log probability  of the word sequence

    Parameters:
    - state: The current word sequence.
    - vocab_size: The size of the vocabulary
    - alpha: float hyperparameter for Laplace smoothing
    - ngram_counter:
    - ngram_minus_one_counter
    - dist: int distance between words. Only for spell correcting.
    - l1: float hyperparameter for weighting the model. Deffault=1
    - l2: float hyperparameter for weigthing the distance. Deffault=0
    - calculate_ngram_probability_fn:

    Returns:
    - Log Probability
    """
    ngram_width = 2
    if model == 'trigram':
      ngram_width = 3
    probability = 0
    #for i in range(ngram_width, len(state)):
    prev_words = tuple(state[-ngram_width:])

    probability += calc_kneser_ney_proba(ngram_counter, ngram_minus_one_counter, continuation_counts, total_continuations, (prev_words), 0.5) + l2 * math.log2(1 / (dist + 1))
    # probability += l1 * calculate_ngram_probability_fn(ngram_counter, ngram_minus_one_counter,(prev_words),alpha, vocab_size) + l2 * math.log2(1 / (dist + 1))
    return probability

In [60]:
def beam_search_sequence(initial_state, max_depth, beam_width, vocab_size, alpha, ngram_counter, ngram_minus_one_counter, generate_candidates_fn, score_fn):
    candidates = [(initial_state, 0)]

    for depth in tqdm(range(max_depth)):
        new_candidates = []
        for candidate, prob in candidates:
            for next_state in generate_candidates_fn(candidate, bigram_counter, 'bigram'):
                # print(prefixes_counter[(next_state[:-1],])
                length = prefixes_counter[tuple(next_state[:-1])]
                new_prob = prob + score_fn(next_state, vocab_size, alpha, ngram_counter, ngram_minus_one_counter,length,'bigram')
                new_candidates.append((next_state, new_prob))



        new_candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)
        candidates = new_candidates[:beam_width]
        print(candidates)

    best_sequence, best_prob = max(candidates, key=lambda x: x[1])
    return best_sequence


In [61]:
test_sentence = "The problem is"
initial_state = test_sentence.split(' ')[-1:]
max_depth = 21
beam_width = 3
best_sequence = beam_search_sequence(initial_state, max_depth, beam_width,len(vocab),0.01,bigram_counter,unigram_counter, generate_candidates, score)

print(test_sentence, ' '.join(best_sequence[1:]))  # Excluding the "<start>" token

  5%|▍         | 1/21 [00:00<00:02,  7.42it/s]

[(['is', '<UNK>'], -3.607799573047905), (['is', 'expected'], -3.9393275835653627), (['is', 'a'], -4.039012828941677)]


 10%|▉         | 2/21 [00:00<00:04,  3.90it/s]

[(['is', 'expected', 'to'], -4.684152532247547), (['is', '<UNK>', '<UNK>'], -6.815213976767074), (['is', '<UNK>', ','], -7.47088938691119)]


 14%|█▍        | 3/21 [00:01<00:08,  2.04it/s]

[(['is', 'expected', 'to', 'the'], -8.434240049611482), (['is', 'expected', 'to', '<UNK>'], -8.489639731663491), (['is', 'expected', 'to', 'be'], -9.099471403587014)]


 19%|█▉        | 4/21 [00:01<00:07,  2.41it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>'], -11.697054135382661), (['is', 'expected', 'to', 'be', '<UNK>'], -12.085219248659584), (['is', 'expected', 'to', 'the', '<UNK>'], -12.301865666536564)]


 24%|██▍       | 5/21 [00:01<00:06,  2.62it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>'], -14.904468539101831), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>'], -15.292633652378754), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>'], -15.509280070255734)]


 29%|██▊       | 6/21 [00:02<00:05,  2.73it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -18.111882942821), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>'], -18.500048056097924), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>'], -18.716694473974904)]


 33%|███▎      | 7/21 [00:02<00:04,  2.85it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -21.31929734654017), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -21.707462459817094), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -21.924108877694074)]


 38%|███▊      | 8/21 [00:02<00:04,  2.91it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -24.52671175025934), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -24.914876863536264), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -25.131523281413244)]


 43%|████▎     | 9/21 [00:03<00:05,  2.08it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -27.73412615397851), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -28.122291267255434), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -28.338937685132414)]


 48%|████▊     | 10/21 [00:03<00:04,  2.32it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -30.94154055769768), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -31.329705670974604), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -31.546352088851584)]


 52%|█████▏    | 11/21 [00:04<00:04,  2.47it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -34.14895496141685), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -34.537120074693775), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -34.753766492570755)]


 57%|█████▋    | 12/21 [00:04<00:03,  2.33it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -37.35636936513602), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -37.744534478412945), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -37.961180896289925)]


 62%|██████▏   | 13/21 [00:05<00:03,  2.24it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -40.56378376885519), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -40.951948882132115), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -41.168595300009095)]


 67%|██████▋   | 14/21 [00:05<00:03,  2.16it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -43.77119817257436), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -44.159363285851285), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -44.376009703728265)]


 71%|███████▏  | 15/21 [00:06<00:03,  1.65it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -46.97861257629353), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -47.366777689570455), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -47.583424107447435)]


 76%|███████▌  | 16/21 [00:07<00:02,  1.91it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -50.1860269800127), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -50.574192093289625), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -50.790838511166605)]


 81%|████████  | 17/21 [00:07<00:01,  2.14it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -53.39344138373187), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -53.781606497008795), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -53.998252914885775)]


 86%|████████▌ | 18/21 [00:07<00:01,  2.34it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -56.60085578745104), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -56.989020900727965), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -57.205667318604945)]


 90%|█████████ | 19/21 [00:08<00:00,  2.51it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -59.80827019117021), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -60.196435304447135), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -60.413081722324115)]


 95%|█████████▌| 20/21 [00:08<00:00,  1.96it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -63.01568459488938), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -63.403849708166305), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -63.620496126043285)]


100%|██████████| 21/21 [00:09<00:00,  2.30it/s]

[(['is', 'expected', 'to', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -66.22309899860855), (['is', 'expected', 'to', 'be', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -66.61126411188548), (['is', 'expected', 'to', 'the', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>', '<UNK>'], -66.82791052976245)]
The problem is expected to <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK> <UNK>





In [None]:
math.log2((bigram_counter[('to','be',)]+0.01) / (unigram_counter[('to',)]+ (0.01 * len(vocab))))

-4.424585941845405

In [None]:
(bigram_counter[('to','be',)]+0.01)

1122.01

In [None]:
(unigram_counter[('to',)])+ (0.01 * len(vocab))

24095.16

TRIGRAM GENERATE NEXT WORD GIVEN SEQUENCE

In [None]:
#TRIGRAM MODEL

In [17]:
def beam_search_decode(initial_state, max_depth, beam_width, generate_candidates_fn, score_fn):
    """
    Generate candidate words for a misspelled word, using between words distance.

    Parameters:
    - state: The current state.
    - word: The misspelled word.
    - word_list: List of words to search for candidates.
    - max_candidates: Maximum number of candidates
    - distance_fn: Distance function. Deffault damerau_levenshtein_distance

    Returns:
    - A list of candidate words.
    """
    candidates = [(initial_state, 0)]

    for depth in range(max_depth):
        new_candidates = []
        for candidate, prob in candidates:
            for next_state in generate_candidates_fn(candidate, trigram_counter, 'trigram'):
                length = prefixes_counter[tuple(next_state[:-1])]
                new_prob = prob + score_fn(next_state, len(vocab), 0.01, trigram_counter, bigram_counter,length, 'trigram')
                new_candidates.append((next_state, new_prob))



        new_candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)

        candidates = new_candidates[:beam_width]

    best_sequence, best_prob = max(candidates, key=lambda x: x[1])
    print(best_sequence[-1],end=" ")
    return best_sequence


test_sentence = "I am coming to"
initial_state = test_sentence.split(' ')[-2:]
max_depth = 30
beam_width = 3
print(test_sentence, end=" ")
best_sequence = beam_search_decode(initial_state, max_depth, beam_width, generate_candidates, score)
print(' '.join(best_sequence))

I am coming to <e> coming to the U . S . Agriculture Department ' s , <UNK> and <UNK> . O > 3RD QTR FEB 28 NET Shr 1 . 5 mln dlrs . <e> <e>


In [None]:
# import time

# # test_sentence = "I would like to"
# # initial_state = test_sentence.split(' ')[-2:]
# # max_depth = 20
# # beam_width = 3
# # print(test_sentence, end=" ")
# # best_sequence = beam_search_decode(initial_state, max_depth, beam_width, generate_candidates, score)

# for i in range(2,len(best_sequence)):
#   print(best_sequence[i], end=" ")  # Excluding the 2 first <start>" tokens
#   time.sleep(0.2)

# Spelling Corrector

In [50]:
# Leveinstein Destance with transposition.
def damerau_levenshtein_distance(s1, s2):
    """
    Calculate the Damerau–Levenshtein distance between two strings.

    Parameters:
    - s1: first string
    - s2: second string

    Returns:
    - Damerau Levenshtein distance
    """
    len_s1 = len(s1)
    len_s2 = len(s2)
    if abs(len_s1 - len_s2) >= 3: ##min(len_s1,len_s2):
        return 10
    d = [[0] * (len_s2 + 1) for _ in range(len_s1 + 1)]

    for i in range(len_s1 + 1):
        d[i][0] = i
    for j in range(len_s2 + 1):
        d[0][j] = j

    for i in range(1, len_s1 + 1):
        for j in range(1, len_s2 + 1):
            cost = 0 if s1[i - 1] == s2[j - 1] else 1
            d[i][j] = min(
                d[i - 1][j] + 1,  # deletion
                d[i][j - 1] + 1,  # insertion
                d[i - 1][j - 1] + cost,  # substitution
            )
            if i > 1 and j > 1 and s1[i - 1] == s2[j - 2] and s1[i - 2] == s2[j - 1]:
                d[i][j] = min(d[i][j], d[i - 2][j - 2] + cost)  # transposition

    return d[len_s1][len_s2]


In [51]:
# Take a word and the vocab and produce candidates/
def generate_candidate_with_distance(state, word, word_list, max_candidates=5, distance_fn=damerau_levenshtein_distance):
    """
    Generate candidate words for a misspelled word, using between words distance.

    Parameters:
    - state: The current state.
    - word: The misspelled word.
    - word_list: List of words to search for candidates.
    - max_candidates: Maximum number of candidates
    - distance_fn: Distance function. Deffault damerau_levenshtein_distance

    Returns:
    - A list of candidate words.
    """
    candidates = []

    for candidate in word_list:
        distance = distance_fn(word, candidate)
        if distance <=5:
          candidates.append((candidate, distance))

    # Sort candidates by Distance distance in ascending order
    candidates.sort(key=lambda x: x[1])
    next_words = candidates[:max_candidates]

    # Return next word and distance
    return [(state + [next_word[0]], next_word[1]) for next_word in next_words]

# Example usage
misspelled_word = "candidat"
initial_state = ['<s>','<s>']
candidates = generate_candidate_with_distance(initial_state,misspelled_word, vocab, 10)
print(f"Candidate words for '{misspelled_word}': {candidates}")

Candidate words for 'candidat': [(['<s>', '<s>', 'candidate'], 1), (['<s>', '<s>', 'candidates'], 2), (['<s>', '<s>', 'canadian'], 3), (['<s>', '<s>', 'credit'], 4), (['<s>', '<s>', 'scandal'], 4), (['<s>', '<s>', 'confident'], 4), (['<s>', '<s>', 'capital'], 4), (['<s>', '<s>', 'Canada'], 4), (['<s>', '<s>', 'Canadian'], 4), (['<s>', '<s>', 'consider'], 4)]


In [52]:
# def score(state, vocab_size, alpha, ngram_counter, ngram_minus_one_counter, dist=0, l1=1, l2=0, calculate_ngram_probability_fn=calculate_ngram_probability):
#     """
#     Calculate the log probability  of the word sequence

#     Parameters:
#     - state: The current word sequence.
#     - vocab_size: The size of the vocabulary
#     - alpha: float hyperparameter for Laplace smoothing
#     - ngram_counter:
#     - ngram_minus_one_counter
#     - dist: int distance between words. Only for spell correcting.
#     - l1: float hyperparameter for weighting the model. Deffault=1
#     - l2: float hyperparameter for weigthing the distance. Deffault=0
#     - calculate_ngram_probability_fn:

#     Returns:
#     - Log Probability
#     """
#     probability = 0
#     for i in range(2, len(state)):

#         prev_words = tuple(state[i-len(state):])
#         probability += l1 * calculate_ngram_probability_fn(ngram_counter, ngram_minus_one_counter,(prev_words),alpha, vocab_size) + l2 * math.log2(1 / (dist + 1))
#     return probability

In [53]:
def beam_search_spelling(sentence, beam_width, l1, l2, generate_candidates_fn, score_fn):
    """
    Spelling correction with contect awereness using beam search

    Parameters:
    - sentence: The sentence we try to correct
    - beam_width: The width of beam search algorithm.
    - generate_candidates_fn: function that generates candidate words
    - score_fn: Function that calculates the log probability

    Returns:
    - The most probable sequence corrected.
    """

    initial_state = ['<s>','<s>']
    candidates = [(initial_state, 0)]
    # sentence = word_tokenize(sentence)
    max_depth = len(sentence)
    for depth in range(max_depth):
        new_candidates = []
        for candidate, prob in candidates:
            for next_state, dist in generate_candidates_fn(candidate, sentence[depth],vocab):

                # Prob we add the previous prob, the prob of the next state and the inverse of the distance
                new_prob = prob + score_fn(next_state,len(vocab),0.01, trigram_counter, bigram_counter, dist, l1=0.2, l2=0.8)

                new_candidates.append((next_state, new_prob))


        new_candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)

        candidates = new_candidates[:beam_width]
        print(candidates)
        # print(candidates)
    best_sequence, best_prob = max(candidates, key=lambda x: x[1])
    return best_sequence[2:]


test_sentence = word_tokenize("I ae coming to down")
beam_width = 5
best_sequence = beam_search_spelling(test_sentence, beam_width, 0.2, 0.8, generate_candidate_with_distance, score)
print(' '.join(best_sequence))  # Excluding the "<start>" token

[(['<s>', '<s>', ','], 7459890415935.098), (['<s>', '<s>', '.'], 5969166918932.01), (['<s>', '<s>', 'a'], 1611882000363.2622), (['<s>', '<s>', 'I'], 415175273960.16223), (['<s>', '<s>', 'DI'], 18589937640.007267)]
[(['<s>', '<s>', ',', 'be'], 7459890415935.098), (['<s>', '<s>', ',', 'a'], 7459890415935.098), (['<s>', '<s>', ',', 'an'], 7459890415935.098), (['<s>', '<s>', ',', 'as'], 7459890415935.098), (['<s>', '<s>', ',', 'are'], 7459890415935.098)]
[(['<s>', '<s>', ',', 'are', 'coming'], 7459890415935.1045), (['<s>', '<s>', ',', 'be', 'coming'], 7459890415935.099), (['<s>', '<s>', ',', 'be', 'closing'], 7459890415935.099), (['<s>', '<s>', ',', 'be', 'voting'], 7459890415935.099), (['<s>', '<s>', ',', 'a', 'closing'], 7459890415935.099)]
[(['<s>', '<s>', ',', 'are', 'coming', 'to'], 7459890415935.368), (['<s>', '<s>', ',', 'a', 'closing', 'two'], 7459890415935.356), (['<s>', '<s>', ',', 'be', 'coming', 'to'], 7459890415935.186), (['<s>', '<s>', ',', 'be', 'closing', 'to'], 74598904159

In [None]:
trigram_counter.get

['I', 'are', 'coming', 'to', 'down']


In [None]:
ngram_count = trigram_counter[('<s>','<s>','I')]
context = ('<s>','<s>',)
ngram_minus_one_count = bigram_counter[context]
ngram_prob = (ngram_count + 0.01) / (ngram_minus_one_count + (0.01 * len(vocab)))
print(ngram_prob)
# Convert to log probability
# ngram_prob = math.log2(ngram_prob)

0.2059215306369581


In [None]:
math.log2(0.2)

-2.321928094887362

0

# EVALUATE

In [54]:
import random

def replace_characters(sentence, probability):
    modified_sentence = []
    for word in sentence:
        modified_word = ''
        for char in word:
            if char != ' ' and random.random() < probability:
                # Replace non-space character with a visually or acoustically similar character
                # You can customize this part based on your preference or use external libraries for similarity
                modified_char = get_similar_char(char)
                modified_word += modified_char
            else:
                modified_word += char
        modified_sentence.append(modified_word)
    return modified_sentence

def get_similar_char(char):
    # Replace this with your logic to get a visually or acoustically similar character
    # For simplicity, using a basic example here (you can expand this based on your requirements)
    similar_chars = {'a': 'e', 'b': 'd', 'c': 'e', 'd': 'b', 'e': 'a', 'f': 'g',
                     'g': 'f', 'h': 'i', 'i': 'h', 'j': 'k', 'k': 'j', 'l': 'm',
                     'm': 'l', 'n': 'o', 'o': 'n', 'p': 'q', 'q': 'p', 'r': 's',
                     's': 'r', 't': 'u', 'u': 't', 'v': 'w', 'w': 'v', 'x': 'y',
                     'y': 'x', 'z': 'z'}
    return similar_chars.get(char, char)

def modify_corpus(corpus, probability):
    modified_corpus = []
    for sentence in corpus:
        modified_sentence = replace_characters(sentence, probability)
        modified_corpus.append(modified_sentence)
    return modified_corpus

# Example usage with a probability of 0.1 (10% chance of replacing each non-space character)
modified_test_corpus = modify_corpus(test_sents, 0.1)

print_cnt = 0
for sent in modified_test_corpus:
  print_cnt +=1
  print(sent)
  print("______________________")
  if print_cnt == 20:
    break;

['The', 'commirsion', 'is', 'expected', 'to', 'apprnve', 'the', 'applicathon', 'at', 'a', 'meeting', 'uomorrow', '.']
______________________
['"', 'The', 'Unitad', 'States', 'end', 'the', 'six', 'major', 'hndusurial', 'countrier', 'are', 'fully', 'commiuted', 'to', 'hmplemanting', 'our', 'undertajings', 'io', 'there', 'agreementr', ',"', 'Bejer', 'told', 'the', 'maetings', '.']
______________________
['Son', 'Line', 'said', 'in', 'Jaouary', 'iu', 'was', 'seekinf', 'bids', 'for', 'the', 'psoperty', '.']
______________________
['Cain', '-', 'Sloan', 'has', 'four', 'stnres', 'hn', 'Nashville', '.']
______________________
['Romero', 'seid', 'he', 'would', 'tell', 'big', 'buyers', 'of', 'copra', 'leal', 'io', 'Loodon', 'that', 'the', 'Philhppioes', 'was', 'doiog', 'its', 'best', 'to', 'meet', 'EC', 'standarbs', '.']
______________________
['Pretex', 'qsnfius', 'also', 'dhpped', 'uo', '601', '.', '7', 'mln', 'stg', 'afuer', '614', '.', '4', 'mln', '.']
______________________
['Asked', 'in', 

In [55]:
import numpy as np

# Take a portion of the test_corpus and modified_test_corpus
org_sent = test_sents[:5]
wrg_sent = modified_test_corpus[:5]

def correct_corpus_np(corpus, vocab, max_candidates=5):
    corrected_corpus = []
    for sentence in corpus:
        corrected_sentence = beam_search_spelling(sentence, 3,0.2,0.8, generate_candidate_with_distance, score)
        corrected_corpus.append(corrected_sentence)
    return corrected_corpus

corrected_test_corpus = correct_corpus_np(wrg_sent, vocab, 5)

print("Original Test Corpus:")
print(org_sent)

print("\nModified Test Corpus:")
print(wrg_sent)

print("\nCorrected Test Corpus:")
print(corrected_test_corpus)


[(['<s>', '<s>', 'The'], 62131691893200.945), (['<s>', '<s>', 'They'], 3142179875280.0146), (['<s>', '<s>', 'the'], 2444701822364.836)]
[(['<s>', '<s>', 'The', 'commission'], 62131691893200.945), (['<s>', '<s>', 'The', 'Commission'], 62131691893200.945), (['<s>', '<s>', 'The', 'commissions'], 62131691893200.945)]
[(['<s>', '<s>', 'The', 'commission', 'is'], 62131691893201.04), (['<s>', '<s>', 'The', 'commission', 'in'], 62131691893201.02), (['<s>', '<s>', 'The', 'Commission', 'in'], 62131691893201.0)]
[(['<s>', '<s>', 'The', 'commission', 'is', 'expected'], 62131691893201.04), (['<s>', '<s>', 'The', 'commission', 'is', 'expect'], 62131691893201.04), (['<s>', '<s>', 'The', 'commission', 'is', 'expects'], 62131691893201.04)]
[(['<s>', '<s>', 'The', 'commission', 'is', 'expected', 'to'], 62131691893201.914), (['<s>', '<s>', 'The', 'commission', 'is', 'expected', 'Co'], 62131691893201.04), (['<s>', '<s>', 'The', 'commission', 'is', 'expected', 'no'], 62131691893201.04)]
[(['<s>', '<s>', 'T

In [56]:
!pip install evaluate
!pip install jiwer
from evaluate import load

# Flatten the list of lists
flattened_corrected_test_corpus = [' '.join(sentence) for sentence in corrected_test_corpus]
flattened_org_sent = [' '.join(sentence) for sentence in org_sent]

# Transform predictions
predictions = [' '.join(flattened_corrected_test_corpus)]
references = [' '.join(flattened_org_sent)]

wer = load("wer")  # Load Word-Error-Rate metric
wer_score = wer.compute(predictions=predictions, references=references)
print(f"WER score is: {wer_score}")

cer = load("cer")
cer_score = cer.compute(predictions=predictions, references=references)
print(f"CER score is: {cer_score}")


# Flatten the list of lists
flattened_corrected_test_corpus = [' '.join(sentence) for sentence in wrg_sent]
flattened_org_sent = [' '.join(sentence) for sentence in org_sent]

# Transform predictions
predictions = [' '.join(flattened_corrected_test_corpus)]
references = [' '.join(flattened_org_sent)]

wer = load("wer")  # Load Word-Error-Rate metric
wer_score = wer.compute(predictions=predictions, references=references)
print(f"WER score is: {wer_score}")

cer = load("cer")
cer_score = cer.compute(predictions=predictions, references=references)
print(f"CER score is: {cer_score}")


Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets>=2.0.0 (from evaluate)
  Downloading datasets-2.16.1-py3-none-any.whl (507 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m507.1/507.1 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
Collecting dill (from evaluate)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from evaluate)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: dill, responses, mul

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

WER score is: 0.36470588235294116


Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

CER score is: 0.10425531914893617
WER score is: 0.35294117647058826
CER score is: 0.07234042553191489


ZeroDivisionError: division by zero