In [1]:
from collections import defaultdict
import numpy as np

In [2]:
train_path = 'wiki-en-train.word'
test_path = 'wiki-en-test.word'

train_sents = []
with open(train_path) as f:
    for row in f.readlines():
        train_sents.append(row.strip())

test_sents = []
with open(test_path) as f:
    for row in f.readlines():
        test_sents.append(row.strip())

In [3]:
class NGramLM:
    def __init__(self, n):
        assert isinstance(n, int), 'n must be integer'
        assert n > 1, 'n must be larger than 1'
        self.n = n
        self.count_ngram = defaultdict(int)
        self.count_next = defaultdict(lambda: defaultdict(int))
        
        self.num_ngram = 0
        
    def train(self, sents):
        ## count frequency
        for sent in sents:
            
            # simple tokenization
            tokens = ['<s>'] * (self.n-1) + sent.split(' ') + ['</s>'] * (self.n-1)

            for i in range(len(tokens)-self.n+1):
                ngram = tuple(tokens[i:i+self.n])
                h = ngram[:-1]  # history
                w = ngram[-1]  # next word
                
                # count n-gram frequency
                self.count_ngram[ngram] += 1

                # count history and the next word frequency
                self.count_next[h][w] += 1
                
        self.num_ngram = sum([len(self.count_next[key]) for key in self.count_next.keys()])
        
    def calc_prob_ngram(self, h, w, d=0.75):
        h_next = len(self.count_next[h])
        if h_next == 0:
            return 0
        else:
            # kneser-ney smoothing
            h_total = sum(self.count_next[h].values())
            lmb = d * h_next / h_total
            
            p_cont = len([key for key in self.count_next.keys() if w in self.count_next[key]]) / self.num_ngram
            p_kn = max(self.count_ngram[(h + (w, ))]-d, 0) / h_next + lmb * p_cont
            
            return p_kn
        
    def calc_mle(self, sent):
        # simple tokenization
        tokens = tuple(['<s>'] * (self.n-1) + sent.split(' ') + ['</s>'] * (self.n-1))
        
        prob = 0.
        for i in range(len(tokens)-self.n+1):            
            p = self.calc_prob_ngram(tokens[i:i+self.n-1], tokens[(i+self.n-1)%(len(tokens))])
            if p != 0:
                prob += -np.log(p) / len(sent.split())

        return prob
    
    def calc_mle_corpus(self, sents):
        prob = 0.
        for sent in sents:
            prob += self.calc_mle(sent)

        return prob / len(sents)

In [4]:
for n in [2, 3, 4, 5]:
    ng = NGramLM(n)
    ng.train(train_sents)
    print('n : {}, entropy : {}'.format(n, ng.calc_mle_corpus(test_sents)))

n : 2, entropy : 3.7080974631812897
n : 3, entropy : 1.8182439265814083
n : 4, entropy : 0.428697616506698
n : 5, entropy : 0.1205694109177372
