# Lab 3: Text Generation with an N-Gram Language Model

In this lab, we will implement two text generation strategies covered in lecture: greedy search and sampling.
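
As a quick reference point, the two decision rules can be compared on a single hand-made distribution. The numbers in `toy_dist` are invented for illustration, and `greedy_pick` and `sample_pick` are hypothetical helper names. Greedy search always takes the most probable token. Sampling draws a token in proportion to its probability, so repeated calls can return different tokens.

```python
import random
from typing import Dict


def greedy_pick(prob_dist: Dict[str, float]) -> str:
    # Greedy search: deterministically return the highest-probability token
    return max(prob_dist, key=prob_dist.get)


def sample_pick(prob_dist: Dict[str, float], rng: random.Random) -> str:
    # Sampling: draw one token with probability proportional to its weight
    tokens = list(prob_dist)
    weights = list(prob_dist.values())
    return rng.choices(tokens, weights=weights, k=1)[0]


# Toy next-token distribution (illustrative numbers, not from the lab's models)
toy_dist = {'friend': 0.5, 'own': 0.3, 'dear': 0.2}

print(greedy_pick(toy_dist))  # always 'friend'
rng = random.Random(0)
print([sample_pick(toy_dist, rng) for _ in range(5)])  # depends on the seed
```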

Here's a revision of the n-gram language model implementation from the previous lab. It now includes a `get_prob_dist()` method, which returns the probability of every vocabulary token given the context, sorted from most to least probable.

Look over the implementation to be sure that you understand it.

In [2]:
import pickle
BOS = '<BOS>'
EOS = '<EOS>'
OOV = '<OOV>'
class NGramLM:
    def __init__(self, path, smoothing=0.001, verbose=False):
        with open(path, 'rb') as fin:
            data = pickle.load(fin)
        self.n = data['n']
        self.V = set(data['V'])
        self.model = data['model']
        self.smoothing = smoothing
        self.verbose = verbose

    def get_prob_dist(self, context):
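        # Keep only the n-1 most recent tokens (Markov assumption).
        # An explicit start index is used because context[-n+1:] becomes context[0:] (the whole list) when n = 1.
        context = tuple(context[max(len(context) - (self.n - 1), 0):])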
        # Add <BOS> tokens if the context is too short, i.e., it's at the start of the sequence
        while len(context) < (self.n-1):
            context = (BOS,) + context
        # Handle words that were not encountered during the training by replacing them with a special <OOV> token
        context = tuple((c if c in self.V else OOV) for c in context)
        if context in self.model:
            # Maximum likelihood estimate with additive (add-k) smoothing: P(w | context) = (count + k) / (total + k * |V|)
            norm = sum(self.model[context].values()) + self.smoothing * len(self.V)
            prob_dist = {k: (c + self.smoothing) / norm for k, c in self.model[context].items()}
            for word in self.V - prob_dist.keys():
                prob_dist[word] = self.smoothing / norm
        else:
            # Simplified formula if we never encountered this context; the probability of all tokens is uniform
            prob = 1 / len(self.V)
            prob_dist = {k: prob for k in self.V}
        prob_dist = dict(sorted(prob_dist.items(), key=lambda x: (-x[1], x[0])))
        return prob_dist
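
To check the smoothing arithmetic in `get_prob_dist()`, here is a small standalone sketch using made-up counts over a four-word vocabulary. With smoothing constant k, every word receives (count + k) / (total + k·|V|). Unseen words therefore still get a small probability, and the distribution sums to 1. The default `smoothing=0.001` is add-k (Lidstone) smoothing; classic Laplace smoothing is the special case k = 1, which is used below to keep the numbers round. `add_k_dist` is a hypothetical helper name.

```python
from typing import Dict, Set


def add_k_dist(counts: Dict[str, int], vocab: Set[str], k: float) -> Dict[str, float]:
    # Denominator: observed counts plus k pseudo-counts for every vocabulary word
    norm = sum(counts.values()) + k * len(vocab)
    # Each word gets its count (0 if never seen after this context) plus k
    return {w: (counts.get(w, 0) + k) / norm for w in vocab}


# Hypothetical counts of the words observed after one context
counts = {'friend': 3, 'own': 1}
vocab = {'friend', 'own', 'dear', 'wife'}

dist = add_k_dist(counts, vocab, k=1.0)
print(dist['friend'], dist['dear'])  # 0.5 0.125  (i.e. 4/8 and 1/8)
```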

In [3]:
# Load pre-built n-gram language models
model_unigram = NGramLM('arthur-conan-doyle.tok.train.n1.pkl')
model_bigram = NGramLM('arthur-conan-doyle.tok.train.n2.pkl')
model_trigram = NGramLM('arthur-conan-doyle.tok.train.n3.pkl')
model_4gram = NGramLM('arthur-conan-doyle.tok.train.n4.pkl')
model_5gram = NGramLM('arthur-conan-doyle.tok.train.n5.pkl')

In [4]:
model_bigram.get_prob_dist(['my'])

{'friend': 0.05662633311127694,
 'own': 0.04360886008266746,
 'dear': 0.031893134356918935,
 'mind': 0.02570983466832943,
 'companion': 0.02408265053975325,
 'wife': 0.0201774086311704,
 'hand': 0.016923040374018032,
 'life': 0.013343235291150427,
 'father': 0.013017798465435191,
 'husband': 0.010414303859713295,
 'head': 0.01008886703399806,
 'eyes': 0.00911255655685235,
 'room': 0.00911255655685235,
 'way': 0.00911255655685235,
 'hands': 0.007810809253991401,
 'face': 0.007485372428276164,
 'heart': 0.007485372428276164,
 'house': 0.006834498776845691,
 'old': 0.006834498776845691,
 'sister': 0.006834498776845691,
 'word': 0.006834498776845691,
 'poor': 0.006509061951130454,
 'name': 0.005532751473984743,
 'presence': 0.005532751473984743,
 'attention': 0.005207314648269506,
 'boy': 0.005207314648269506,
 'business': 0.005207314648269506,
 'surprise': 0.005207314648269506,
 'chair': 0.004881877822554268,
 'pocket': 0.004881877822554268,
 'little': 0.004556440996839032,
 'story': 0.00

Great, now we have all the tools we need to start generating text!

We'll start with a simple greedy generation approach: at each step, append the token with the highest probability given the current context. Our task is to implement it below.

Note: the `max_length` parameter guarantees that generation terminates. Generation stops when the model predicts an `<EOS>` token or when the sequence reaches `max_length` tokens (a count that includes the initial context).
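
Because `get_prob_dist()` returns its dictionary already sorted by descending probability, with ties broken alphabetically, and Python dicts preserve insertion order (guaranteed since Python 3.7), the greedy choice is simply the first key. This sketch uses a hypothetical helper, `sort_like_get_prob_dist`, that copies that sort key onto a toy dictionary. It shows that `next(iter(...))` and `max(..., key=...)` agree. They also agree on ties, because `max` returns the first maximal item it encounters.

```python
from typing import Dict


def sort_like_get_prob_dist(prob_dist: Dict[str, float]) -> Dict[str, float]:
    # Same ordering as get_prob_dist(): descending probability, then alphabetical
    return dict(sorted(prob_dist.items(), key=lambda x: (-x[1], x[0])))


toy = sort_like_get_prob_dist({'own': 0.3, 'dear': 0.2, 'friend': 0.5, 'eyes': 0.2})

first_key = next(iter(toy))          # first entry of the sorted dict
argmax_key = max(toy, key=toy.get)   # explicit argmax
print(first_key, argmax_key)  # friend friend
print(list(toy))              # ['friend', 'own', 'dear', 'eyes']
```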

In [76]:
from typing import List

def greedy_generation(model: NGramLM, context: List[str], max_length: int=100) -> List[str]:
    generated_text = context.copy()  # Start with the given context
    while len(generated_text) < max_length:
        current_context = generated_text[-(model.n-1):] if len(generated_text) >= (model.n-1) else generated_text
        prob_dist = model.get_prob_dist(current_context)  # Get the probability distribution for the next word
        next_word = max(prob_dist, key=prob_dist.get)  # Select the word with the highest probability
        if next_word == EOS:  # Stop if EOS is reached
            break
        generated_text.append(next_word)  # Append the selected word to the generated text

    return generated_text


greedy_generation(model_trigram, ['""', 'My', 'dear', 'Watson'])

['""', 'My', 'dear', 'Watson', ',', '"', 'said', 'he', '.', '"']
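
The lab's second strategy, sampling, changes only the selection step of the loop above. So that it runs on its own, the sketch below uses a hypothetical `ToyBigramLM` class with hand-written probabilities. This class exposes the same `n` attribute and `get_prob_dist()` method as `NGramLM`, so `sampling_generation` (also a hypothetical name) should accept one of the loaded models in the same way. An optional seeded `random.Random` makes runs reproducible.

```python
import random
from typing import Dict, List, Optional

EOS = '<EOS>'


class ToyBigramLM:
    """Hypothetical stand-in with the same interface as NGramLM (`n` and `get_prob_dist`)."""
    n = 2
    table: Dict[str, Dict[str, float]] = {
        'my': {'dear': 0.6, 'friend': 0.4},
        'dear': {'Watson': 1.0},
        'friend': {EOS: 1.0},
        'Watson': {',': 0.5, EOS: 0.5},
        ',': {EOS: 1.0},
    }

    def get_prob_dist(self, context: List[str]) -> Dict[str, float]:
        # Bigram model: only the last token matters; unknown contexts end the sequence
        return self.table.get(context[-1], {EOS: 1.0})


def sampling_generation(model, context: List[str], max_length: int = 100,
                        rng: Optional[random.Random] = None) -> List[str]:
    rng = rng or random.Random()
    generated = context.copy()  # do not mutate the caller's list
    while len(generated) < max_length:
        # NGramLM.get_prob_dist() truncates to the last n-1 tokens itself
        prob_dist = model.get_prob_dist(generated)
        # Sampling: draw the next token in proportion to its probability
        tokens, weights = list(prob_dist), list(prob_dist.values())
        next_word = rng.choices(tokens, weights=weights, k=1)[0]
        if next_word == EOS:
            break
        generated.append(next_word)
    return generated


print(sampling_generation(ToyBigramLM(), ['my'], rng=random.Random(42)))
```

Unlike greedy generation, different seeds can produce different continuations of the same context.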

In [24]:
# Paths to the pre-built n-gram language model files
unigram = 'arthur-conan-doyle.tok.train.n1.pkl'
bigram = 'arthur-conan-doyle.tok.train.n2.pkl'
trigram = 'arthur-conan-doyle.tok.train.n3.pkl'
_4gram = 'arthur-conan-doyle.tok.train.n4.pkl'
_5gram = 'arthur-conan-doyle.tok.train.n5.pkl'

In [25]:
with open(trigram, 'rb') as fin:
    dt = pickle.load(fin)
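
The script that built these `.pkl` files isn't shown here. Judging from how `NGramLM` reads them, each file appears to hold a dict with the keys `'n'`, `'V'`, and `'model'`, where `'model'` maps an (n-1)-token context tuple to a dict of next-token counts. The hypothetical `build_ngram_counts` below produces that shape from a toy tokenized corpus. It assumes that sentences are left-padded with `<BOS>`, terminated with `<EOS>`, and that both markers belong to `V`. The real files may differ in details such as OOV handling.

```python
import pickle
from collections import defaultdict
from typing import Dict, List, Tuple

BOS, EOS = '<BOS>', '<EOS>'


def build_ngram_counts(sentences: List[List[str]], n: int) -> dict:
    """Hypothetical reconstruction of the pickled {'n', 'V', 'model'} structure."""
    counts: Dict[Tuple[str, ...], Dict[str, int]] = defaultdict(lambda: defaultdict(int))
    # Assumption: <BOS> and <EOS> belong to V, so get_prob_dist() won't map padding to <OOV>
    vocab = {BOS, EOS}
    for sent in sentences:
        # Left-pad with n-1 <BOS> tokens and terminate with <EOS>
        tokens = [BOS] * (n - 1) + sent + [EOS]
        vocab.update(sent)
        for i in range(n - 1, len(tokens)):
            context = tuple(tokens[i - n + 1:i])  # the n-1 preceding tokens; () for n = 1
            counts[context][tokens[i]] += 1
    # Plain dicts (not lambda-backed defaultdicts) so the result can be pickled
    model = {ctx: dict(nexts) for ctx, nexts in counts.items()}
    return {'n': n, 'V': sorted(vocab), 'model': model}


data = build_ngram_counts([['My', 'dear', 'Watson'], ['My', 'friend']], n=2)
print(data['model'][('My',)])  # {'dear': 1, 'friend': 1}
print(data['model'][(BOS,)])   # {'My': 2}
blob = pickle.dumps(data)      # serializes without error
```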

In [55]:
texts = ['My', 'dear', 'Watson']
n = 3
model = dt['model']
V = dt['V']

context = tuple(texts[-n+1:])
print(texts, context)

['My', 'dear', 'Watson'] ('dear', 'Watson')


In [56]:
while len(context) < (n-1):
    context = (BOS,) + context
context

('dear', 'Watson')

In [57]:
context = tuple((c if c in dt['V'] else OOV) for c in context)
context

('dear', 'Watson')
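
The three exploratory cells above reproduce, step by step, how `get_prob_dist()` prepares its context. A sketch that combines them into one function might look like this. `prepare_context` is a hypothetical name, and the toy vocabulary is made up. The vocabulary includes `<BOS>` because step 3 would otherwise map the padding to `<OOV>`; whether the real `V` contains it is an assumption. The slice uses an explicit start index because `tokens[-n+1:]` becomes `tokens[0:]`, the whole list, when n = 1.

```python
from typing import List, Set, Tuple

BOS, OOV = '<BOS>', '<OOV>'


def prepare_context(tokens: List[str], n: int, vocab: Set[str]) -> Tuple[str, ...]:
    # 1. Markov assumption: keep only the last n-1 tokens (empty for n = 1)
    context = tuple(tokens[max(len(tokens) - (n - 1), 0):])
    # 2. Left-pad with <BOS> when fewer than n-1 tokens are available
    context = (BOS,) * (n - 1 - len(context)) + context
    # 3. Replace out-of-vocabulary tokens with <OOV>
    return tuple(c if c in vocab else OOV for c in context)


vocab = {BOS, 'My', 'dear', 'Watson'}
print(prepare_context(['My', 'dear', 'Watson'], 3, vocab))  # ('dear', 'Watson')
print(prepare_context(['Holmes'], 3, vocab))                # ('<BOS>', '<OOV>')
print(prepare_context(['My', 'dear'], 1, vocab))            # ()
```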
