# Download WikiText-2 dataset

In [None]:
! wget https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/train.txt
! wget https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/valid.txt
! wget https://raw.githubusercontent.com/pytorch/examples/master/word_language_model/data/wikitext-2/test.txt

# Helper function

In [2]:
from collections import Counter, defaultdict
import math
import copy
import random
import operator

flatten = lambda l: [item for sublist in l for item in sublist]

# some helper functions
def prepare_data(filename):
    data = [l.strip().split() + ['</s>'] for l in open(filename) if l.strip()]
    corpus = flatten(data)
    vocab = set(corpus)
    return vocab, data

# N-Gram LM
An language model assign a probability to each possible next word given a history of previous words (context). 

$P(w_t|w_{t-1}, w_{t-2}, ... , w_1)$

### Markov Assumption
Since calculating the probability of the whole sentence is not feasible, the Markov Assumption is introduced.

It assumes that each next word only depend on the previous K words (In an NGram language model, K = N-1).
- Unigram: $P(w_t|w_{t-1}, w_{t-2}, ... , w_1) = P(w_t)$
- Bigram: $P(w_t|w_{t-1}, w_{t-2}, ... , w_1) = P(w_t|w_{t-1}) $
- Trigram: $P(w_t|w_{t-1}, w_{t-2}, ... , w_1) = P(w_t|w_{t-1}, w_{t-2})$

For a NGram language model, the probability is calculating by counting the frequency:

$P(w_t|w_{t-1}, w_{t-2}, ... ,w_{t-N+1}) = \frac{C(w_t, w_{t-1}, w_{t-2}, ... ,w_{t-N+1} )}{C(w_{t-1}, w_{t-2}, ... ,w_{t-N+1})}$

In [3]:
class NGramLM():
    def __init__(self, N):
        self.N = N
        self.vocab = set()
        self.data = []
        self.prob = {}
        self.count = defaultdict(Counter)
    
    # For N = 1, the probability is stored in a dict   P = prob[next_word]
    # For N > 1, the probability is in a nested dict   P = prob[context][next_word]
    def train(self, vocab, data, smoothing_k=0):
        self.vocab = vocab
        self.data = data
        self.smoothing_k = smoothing_k

        if self.N == 1:
            self.counts = Counter(flatten(data))
            self.prob = self.get_prob(self.counts)
        else:
            self.vocab.add('<s>')
            counts = self.count_ngram()
            
            self.prob = {}
            for context, counter in counts.items():
                self.prob[context] = self.get_prob(counter)
    
    def count_ngram(self):
        counts = defaultdict(Counter)
        for sentence in self.data:
            sentence = (self.N - 1) * ['<s>'] + sentence 
            for i in range(len(sentence)-self.N+1):
                context = sentence[i:i+self.N-1]
                context = " ".join(context)
                word = sentence[i+self.N-1]
                counts[context][word] += 1

        self.counts = counts
        return counts
        
    # normalize counts into probability(considering smoothing)
    def get_prob(self, counter):
        total = float(sum(counter.values()))
        k = self.smoothing_k
        
        prob = {}
        for word, count in counter.items():
            prob[word] = (count + k) / (total + len(self.vocab) * k)
        return prob
        
    def get_ngram_logprob(self, word, seq_len, context=""):
        if self.N == 1 and word in self.prob.keys():
            return math.log(self.prob[word]) / seq_len
        elif self.N > 1 and not self._is_unseen_ngram(context, word):
            return math.log(self.prob[context][word]) / seq_len
        else:
            # assign a small probability to the unseen ngram
            # to avoid log by zero and penalise unseen word or context
            return math.log(1/len(self.vocab)) / seq_len
        
    def get_ngram_prob(self, word, context=""):
        if self.N == 1 and word in self.prob.keys():
            return self.prob[word]
        elif self.N > 1 and not self._is_unseen_ngram(context, word):
            return self.prob[context][word]
        elif word in self.vocab and self.smoothing_k > 0:
            # probability assigned by smoothing
            return self.smoothing_k / (sum(self.counts[context].values()) + self.smoothing_k*len(self.vocab))
        else:
            # unseen word or context
            return 0
            
    # In this method, the perplexity is measured on sentence-level, averaging over all sentences.
    # Actually, it is also possible to calculate perplexity by merging all sentences into a long one.
    def perplexity(self, test_data):
        log_ppl = 0
        if self.N == 1:
            for sentence in test_data:
                for word in sentence:
                    log_ppl += self.get_ngram_logprob(word=word, seq_len=len(sentence))
        else:
            for sentence in test_data:
                for i in range(len(sentence)-self.N+1):
                    context = sentence[i:i+self.N-1]
                    context = " ".join(context)
                    word = sentence[i+self.N-1]
                    log_ppl += self.get_ngram_logprob(context=context, word=word, seq_len=len(sentence))
                        
        log_ppl /= len(test_data)
        ppl = math.exp(-log_ppl)
        return ppl
    
    def _is_unseen_ngram(self, context, word):
        if context not in self.prob.keys() or word not in self.prob[context].keys():
            return True
        else:
            return False
    
    # generate the most probable k words
    def generate_next(self, context, k):
        contex = (self.N-1) * '<s>' + context
        context = context.split()
        ngram_context_list = context[-self.N+1:]
        ngram_context = " ".join(ngram_context_list)
        
        if ngram_context in self.prob.keys():
            candidates = self.prob[ngram_context]
            most_probable_words = sorted(candidates.items(), key=lambda kv: kv[1], reverse=True)
            for i in range(min(k, len(most_probable_words))):
                print(" ".join(context)+" "+most_probable_words[i][0]+"\t P={}".format(most_probable_words[i][1]))
        else:
            print("Unseen context!")
            
    # generate the next n words with greedy search
    def generate_next_n(self, context, n):
        contex = (self.N-1) * '<s>' + context
        context = context.split()
        ngram_context_list = context[-self.N+1:]
        ngram_context = " ".join(ngram_context_list)
        
        for i in range(n):
            try:
                candidates = self.prob[ngram_context]
                most_likely_next = max(candidates.items(), key=operator.itemgetter(1))[0]
                context += [most_likely_next]
                ngram_context_list = ngram_context_list[1:] + [most_likely_next]
                ngram_context = " ".join(ngram_context_list)
            except:
                break
        print(" ".join(context))
    

# Train with toy dataset

At this step, let's train a Bigram language model on the toy dataset. The ngram probabilities are shown to give you an intuition of how the ngrams are counted. 

In [4]:
corpus = ["I like ice cream",
         "I like chocolate",
         "I hate beans"]
data = [l.strip().split() + ['</s>'] for l in corpus if l.strip()]
vocab = set(flatten(data))
print(data)
print(vocab)

[['I', 'like', 'ice', 'cream', '</s>'], ['I', 'like', 'chocolate', '</s>'], ['I', 'hate', 'beans', '</s>']]
{'</s>', 'beans', 'chocolate', 'like', 'cream', 'I', 'hate', 'ice'}


In [5]:
def print_probability(lm):
    for context in lm.vocab:
        for word in lm.vocab:
            prob = lm.get_ngram_prob(word, context)
            print("P({}\t|{}) = {}".format(word, context, prob))
        print("--------------------------")

# smoothing
Smoothing method is used to deal with the sparsity problem in NGram LM.
The probability of frequent words are stolen by the less frequent words.

For example, with an add-1 smoothing, the probability is calculated as:

$$P(w_t | context) = \frac{C(w_t, context)+1}{C(context)+|V|}$$

Q1: What is the disadvantage of smoothing?

A:

In [6]:
lm = NGramLM(2)
lm.train(vocab, data, smoothing_k=0)

######################################################
# Q2: try with add-1 smoothing and see the probability
######################################################
# lm.train(vocab, data, smoothing_k=1)

print_probability(lm)

P(</s>	|</s>) = 0
P(beans	|</s>) = 0
P(<s>	|</s>) = 0
P(chocolate	|</s>) = 0
P(like	|</s>) = 0
P(cream	|</s>) = 0
P(I	|</s>) = 0
P(hate	|</s>) = 0
P(ice	|</s>) = 0
--------------------------
P(</s>	|beans) = 1.0
P(beans	|beans) = 0
P(<s>	|beans) = 0
P(chocolate	|beans) = 0
P(like	|beans) = 0
P(cream	|beans) = 0
P(I	|beans) = 0
P(hate	|beans) = 0
P(ice	|beans) = 0
--------------------------
P(</s>	|<s>) = 0
P(beans	|<s>) = 0
P(<s>	|<s>) = 0
P(chocolate	|<s>) = 0
P(like	|<s>) = 0
P(cream	|<s>) = 0
P(I	|<s>) = 1.0
P(hate	|<s>) = 0
P(ice	|<s>) = 0
--------------------------
P(</s>	|chocolate) = 1.0
P(beans	|chocolate) = 0
P(<s>	|chocolate) = 0
P(chocolate	|chocolate) = 0
P(like	|chocolate) = 0
P(cream	|chocolate) = 0
P(I	|chocolate) = 0
P(hate	|chocolate) = 0
P(ice	|chocolate) = 0
--------------------------
P(</s>	|like) = 0
P(beans	|like) = 0
P(<s>	|like) = 0
P(chocolate	|like) = 0.5
P(like	|like) = 0
P(cream	|like) = 0
P(I	|like) = 0
P(hate	|like) = 0
P(ice	|like) = 0.5
-----------------

# Train on WikiText-2 dataset and evaluate perplexity
### Evaluating perplexity

Q3: why do we need to calculate log perplexity?

A:

$ PPL(W) = P(w_1, w_2, ... , w_n)^{-\frac{1}{n}}$

$ log(PPL(W)) = -\frac{1}{n}\sum^n_{k=1}log(P(w_k|w_1, w_2, ... , w_{k-1}))$

In [7]:
vocab, train_data = prepare_data('train.txt')
_, valid_data = prepare_data('valid.txt')
_, test_data = prepare_data('test.txt')
print(len(vocab))

33278


In [8]:
lm = NGramLM(3)
lm.train(vocab, train_data)

In [9]:
print(lm.perplexity(valid_data))
print(lm.perplexity(test_data))

449.27961912112
355.09754856529463


# Generating sentence
With the pre-trained NGram language model, we can predict possible next words with given context.

In [10]:
# generate the most probable following words given the context
print(" ".join(valid_data[12]))

# actually the only useful contexts in the Trigram LM is ["where", "they"]
context = "The eggs hatch at night , and the larvae swim to the water surface where they"  

lm.generate_next(context, 3)

The eggs hatch at night , and the larvae swim to the water surface where they drift with the ocean currents , preying on <unk> . This stage involves three <unk> and lasts for 15 – 35 days . After the third moult , the juvenile takes on a form closer to the adult , and adopts a <unk> lifestyle . The juveniles are rarely seen in the wild , and are poorly known , although they are known to be capable of digging extensive burrows . It is estimated that only 1 larva in every 20 @,@ 000 survives to the <unk> phase . When they reach a carapace length of 15 mm ( 0 @.@ 59 in ) , the juveniles leave their burrows and start their adult lives . </s>
The eggs hatch at night , and the larvae swim to the water surface where they were	 P=0.12408759124087591
The eggs hatch at night , and the larvae swim to the water surface where they are	 P=0.08029197080291971
The eggs hatch at night , and the larvae swim to the water surface where they would	 P=0.051094890510948905


In [11]:
# generate the next n words with greedy search
lm.generate_next_n(context, 10)

# This is not a good method in practice,
# because wrong predictions in the early steps would hinder the following predictions
lm.generate_next_n(context, 20)

The eggs hatch at night , and the larvae swim to the water surface where they were not able to get the part of the <unk>
The eggs hatch at night , and the larvae swim to the water surface where they were not able to get the part of the <unk> of the <unk> of the <unk> of the <unk> of


# Effect of N

Q4: Why does the perplexity increase when N is large?

A:

In [12]:
for n in range(1,6):
    lm = NGramLM(n)
    lm.train(vocab, train_data)
    print("************************")
    print("{}-gram LM perplexity on valid set: {}".format(n, lm.perplexity(valid_data)))
    print("{}-gram LM perplexity on test  set: {}".format(n, lm.perplexity(test_data)))

************************
1-gram LM perplexity on valid set: 727.3800286163042
1-gram LM perplexity on test  set: 661.8109371098135
************************
2-gram LM perplexity on valid set: 141.41008166795802
2-gram LM perplexity on test  set: 121.75625147213104
************************
3-gram LM perplexity on valid set: 449.27961912112
3-gram LM perplexity on test  set: 355.09754856529463
************************
4-gram LM perplexity on valid set: 1251.3004060217763
4-gram LM perplexity on test  set: 968.4259408279687
************************
5-gram LM perplexity on valid set: 1754.5532742330947
5-gram LM perplexity on test  set: 1344.5022350288755


# Interpolation
Adding interpolation would utilize different ngram context, which would alleviate the problem when N grows large.

In [13]:
class InterpolateNGramLM(NGramLM):
    
    def __init__(self, N):
        super(InterpolateNGramLM, self).__init__(N)
        self.ngram_lms = []
        self.lambdas = []
        
    def train(self, vocab, data, smoothing_k=0, lambdas=[]):
        assert len(lambdas) == self.N
        assert sum(lambdas) - 1 < 1e-9
        self.vocab = vocab
        self.lambdas = lambdas
        
        for i in range(self.N, 0, -1):
            lm = NGramLM(i)
            print("Training {}-gram language model".format(i))
            lm.train(vocab, data, smoothing_k)
            self.ngram_lms.append(lm)
    
    def get_ngram_logprob(self, word, seq_len, context):
        prob = 0
        for i, (coef, lm) in enumerate(zip(self.lambdas, self.ngram_lms)):
            context_words = context.split()
            cutted_context = " ".join(context_words[-self.N + i + 1:])
            prob += coef * lm.get_ngram_prob(context=cutted_context, word=word)
        return math.log(prob) / seq_len
        

In [14]:
##################################################
# Q5: tune your coefficients to decrease perplexity
##################################################
ilm = InterpolateNGramLM(3)
ilm.train(vocab, train_data, lambdas=["coefficients here"])

Training 3-gram language model
Training 2-gram language model
Training 1-gram language model


In [15]:
print(ilm.perplexity(valid_data))
print(ilm.perplexity(test_data))

140.07040126740415
111.56423076848951
