# Lab 3: Text Generation with an N-Gram Langauge Model

In this lab, you will implement two text generation strategies covered in lecture: Greedy Serch and Sampling.

Here's a revision of the n-gram language model implementation from the previous lab. It now includes a `get_prob_dist()` function, which returns the probabilities of all tokens given the context.

Look over the implementation to be sure that you understand it.

In [2]:
import pickle
BOS = '<BOS>'
EOS = '<EOS>'
OOV = '<OOV>'
class NGramLM:
    def __init__(self, path, smoothing=0.001, verbose=False):
        with open(path, 'rb') as fin:
            data = pickle.load(fin)
        self.n = data['n']
        self.V = set(data['V'])
        self.model = data['model']
        self.smoothing = smoothing
        self.verbose = verbose

    def get_prob_dist(self, context):
        # Take only the n-1 most recent context (Markov Assumption)
        context = tuple(context[-self.n+1:])
        # Add <BOS> tokens if the context is too short, i.e., it's at the start of the sequence
        while len(context) < (self.n-1):
            context = (BOS,) + context
        # Handle words that were not encountered during the training by replacing them with a special <OOV> token
        context = tuple((c if c in self.V else OOV) for c in context)
        if context in self.model:
            # Compute the probability distribution using a Maximum Likelihood Estimation and Laplace Smoothing
            norm = sum(self.model[context].values()) + self.smoothing * len(self.V)
            prob_dist = {k: (c + self.smoothing) / norm for k, c in self.model[context].items()}
            for word in self.V - prob_dist.keys():
                prob_dist[word] = self.smoothing / norm
        else:
            # Simplified formula if we never encountered this context; the probability of all tokens is uniform
            prob = 1 / len(self.V)
            prob_dist = {k: prob for k in self.V}
        prob_dist = dict(sorted(prob_dist.items(), key=lambda x: (-x[1], x[0])))
        return prob_dist

In [3]:
# Load pre-built n-gram languae models
model_unigram = NGramLM('arthur-conan-doyle.tok.train.n1.pkl')
model_bigram = NGramLM('arthur-conan-doyle.tok.train.n2.pkl')
model_trigram = NGramLM('arthur-conan-doyle.tok.train.n3.pkl')
model_4gram = NGramLM('arthur-conan-doyle.tok.train.n4.pkl')
model_5gram = NGramLM('arthur-conan-doyle.tok.train.n5.pkl')

Let's take a look at some of the probability distributions.

Are they reasonable?

In [4]:
model_bigram.get_prob_dist(['my'])

{'friend': 0.05662633311127694,
 'own': 0.04360886008266746,
 'dear': 0.031893134356918935,
 'mind': 0.02570983466832943,
 'companion': 0.02408265053975325,
 'wife': 0.0201774086311704,
 'hand': 0.016923040374018032,
 'life': 0.013343235291150427,
 'father': 0.013017798465435191,
 'husband': 0.010414303859713295,
 'head': 0.01008886703399806,
 'eyes': 0.00911255655685235,
 'room': 0.00911255655685235,
 'way': 0.00911255655685235,
 'hands': 0.007810809253991401,
 'face': 0.007485372428276164,
 'heart': 0.007485372428276164,
 'house': 0.006834498776845691,
 'old': 0.006834498776845691,
 'sister': 0.006834498776845691,
 'word': 0.006834498776845691,
 'poor': 0.006509061951130454,
 'name': 0.005532751473984743,
 'presence': 0.005532751473984743,
 'attention': 0.005207314648269506,
 'boy': 0.005207314648269506,
 'business': 0.005207314648269506,
 'surprise': 0.005207314648269506,
 'chair': 0.004881877822554268,
 'pocket': 0.004881877822554268,
 'little': 0.004556440996839032,
 'story': 0.00

In [5]:
model_bigram.get_prob_dist(['.'])

{'"': 0.23905809267120004,
 '<EOS>': 0.12003662173077437,
 'I': 0.08395797400078746,
 'He': 0.043522818983357844,
 'It': 0.04326904185981749,
 'The': 0.03899712694688821,
 "'": 0.02368590715995356,
 'But': 0.022036355856941265,
 'You': 0.021063543550036576,
 'There': 0.018991030374457027,
 'We': 0.014634523087014295,
 'Then': 0.012519713724178018,
 'She': 0.012139048038867487,
 'A': 0.01040490436134174,
 'If': 0.01032031198682829,
 'His': 0.009812757739747583,
 'And': 0.009347499679923602,
 'This': 0.00879764924558617,
 'That': 0.008543872122045817,
 'In': 0.008290094998505464,
 'As': 0.008205502623992013,
 'What': 0.007782540751424758,
 'Now': 0.007444171253370953,
 'Holmes': 0.0066828398827498935,
 'When': 0.006640543695493168,
 'My': 0.006217581822925912,
 'They': 0.005921508512128833,
 'At': 0.005710027575845205,
 'No': 0.005117880954251048,
 'On': 0.004779511456197243,
 'For': 0.004271957209116537,
 'Well': 0.004187364834603086,
 'So': 0.0041450686473463606,
 'One': 0.003426033463

In [6]:
model_trigram.get_prob_dist(['my', 'dear'])

{'Watson': 0.4791320028224717,
 'fellow': 0.12196736734818324,
 'sir': 0.07841070448546514,
 'young': 0.043565374195290656,
 'Von': 0.017431376477659785,
 'boy': 0.017431376477659785,
 '.': 0.008720043905116165,
 'Arthur': 0.008720043905116165,
 'Hopkins': 0.008720043905116165,
 'Watson,"—he': 0.008720043905116165,
 'daughter': 0.008720043905116165,
 'girl': 0.008720043905116165,
 'lady': 0.008720043905116165,
 'little': 0.008720043905116165,
 'madam': 0.008720043905116165,
 'son': 0.008720043905116165,
 'wife': 0.008720043905116165,
 '!': 8.711332572543621e-06,
 '"': 8.711332572543621e-06,
 '&': 8.711332572543621e-06,
 "'": 8.711332572543621e-06,
 "'S": 8.711332572543621e-06,
 "'d": 8.711332572543621e-06,
 "'em": 8.711332572543621e-06,
 "'ll": 8.711332572543621e-06,
 "'m": 8.711332572543621e-06,
 "'re": 8.711332572543621e-06,
 "'s": 8.711332572543621e-06,
 "'ve": 8.711332572543621e-06,
 '(': 8.711332572543621e-06,
 ')': 8.711332572543621e-06,
 '):': 8.711332572543621e-06,
 '+': 8.7113

Great, now we have all the tools we need to start generating text!

We'll start with a simple greedy generation approach. Your task is to implement greedy generation below.

Note: we have a `max_length` parameter to be sure that the generation process doesn't go on forever. You can stop when your sequence either reaches an `<EOS>` token or is the maximum length.

In [15]:
from typing import List
def greedy_generation(model: NGramLM, context: List[str], max_length: int = 100) -> List[str]:
#     return model.get_prob_dist(context)
   return sorted(model.get_prob_dist(context).items(), key=lambda x: x[1], reverse=True)[0]

greedy_generation(model_trigram, ['""', 'My', 'dear', 'Watson'])

(',', 0.6391663429122855)

Great! How does the generation look? Feel free to try out a several samples below.

Is it deterministic? Are the generated sequences interesting?

Consider trying different model types. What are the different qualities that you see from the unigram, bigram, trigram, 4-gram, and 5-gram models?

In [17]:
print(greedy_generation(model_unigram, ['""', 'My', 'dear', 'Watson']))
print(greedy_generation(model_bigram, ['""', 'My', 'dear', 'Watson']))
print(greedy_generation(model_trigram, ['""', 'My', 'dear', 'Watson']))
print(greedy_generation(model_4gram, ['""', 'My', 'dear', 'Watson']))
print(greedy_generation(model_4gram, ['""', 'My', 'dear', 'Watson']))

('!', 5.954862144941345e-05)
(',', 0.5712095268053272)
(',', 0.6391663429122855)
(',', 0.32271205582220786)
(',', 0.32271205582220786)


Now it's time to implement sampling.

We now include a `topk` argument. This reduces the candidate set of probabilities down to only the `topk` highest-probability items. This helps reduce the chance of generating highly unlikely sequences.

Note: consider using [`random.choices`](https://docs.python.org/3/library/random.html#random.choices) to help in sampling.

In [23]:
from typing import List
import random
def sampling_generation(model: NGramLM, context: List[str], max_length: int = 100, topk=10) -> List[str]:
    word_prob = sorted(model.get_prob_dist(context).items(), key=lambda x: x[1], reverse=True)[:topk]
    words = [x[0] for x in word_prob]
    probs = [x[1] for x in word_prob]    
    return random.choices(words, weights=probs, k=1)[0]

sampling_generation(model_trigram, ['""', 'My', 'dear', 'Watson'])

','

Now qualitatively compare your sampling generation with the greedy generation.

What do you notice about the generated sequences? How do models of different sizes behave? What is the effect of `topk`?

In [24]:
print(sampling_generation(model_unigram, ['""', 'My', 'dear', 'Watson']))
print(sampling_generation(model_bigram, ['""', 'My', 'dear', 'Watson']))
print(sampling_generation(model_trigram, ['""', 'My', 'dear', 'Watson']))
print(sampling_generation(model_4gram, ['""', 'My', 'dear', 'Watson']))
print(sampling_generation(model_4gram, ['""', 'My', 'dear', 'Watson']))

"
—
,
,
,


We see for bigram model though "," has highest probability, the samping strategy allows other token to appear.

## Optional Extras:
 - Try implementing a beam search strategy. Does it tend to lead to qualitatively better results than the other two approaches?
 - What strategy might you take to efficiently find the most likely sequence for an n-gram language model?
 
### Bream Search Generation

In [42]:
def bream_search_generation(model: NGramLM, context: List[str], max_length: int = 100, topk=2) -> List[str]:
    
    pred_token = []
    for i in range(len(context)):
        word_prob = sorted(model.get_prob_dist(context[i]).items(), key=lambda x: x[1], reverse=True)[:5*topk]
        
        print(word_prob)
        for j in range(len(word_prob)):
            pred_token.append(word_prob[j][0])
    

    return pred_token

topk = 3
pred_tokens = bream_search_generation(model_trigram, [['""', 'My', 'dear']], topk=topk)

print()

new_context = []
for i in range(topk):
    new_context.append(['""', 'My', 'dear'] + [pred_tokens[i]])
    
print(new_context)
print() 
    
_ = bream_search_generation(model_trigram, new_context)
print(_)

[('fellow', 0.18240542306932345), ('Holmes', 0.13680786709832352), ('Watson', 0.12160868177465686), ('sir', 0.09121031112732357), ('young', 0.04561275515632361), ('Mr.', 0.030413569832656966), (',', 0.015214384508990314), ('Gregory', 0.015214384508990314), ('Inspector', 0.015214384508990314), ('Professor', 0.015214384508990314), ('chap', 0.015214384508990314), ('doctor', 0.015214384508990314), ('madam', 0.015214384508990314), ('man', 0.015214384508990314), ('wife', 0.015214384508990314)]

[['""', 'My', 'dear', 'fellow'], ['""', 'My', 'dear', 'Holmes'], ['""', 'My', 'dear', 'Watson']]

[(',', 0.4795515264996689), ('!', 0.0456922339186628), ('.', 0.0456922339186628), ('--', 0.02285753430913616), ('wanted', 0.02285753430913616), ('"', 2.2834699609526637e-05), ('&', 2.2834699609526637e-05), ("'", 2.2834699609526637e-05), ("'S", 2.2834699609526637e-05), ("'d", 2.2834699609526637e-05)]
[('!', 0.19388981506610323), (',', 0.15511960609467687), ('"', 3.8770208971426355e-05), ('&', 3.87702089714