In [1]:
import io
from collections import defaultdict
import numpy as np

In [2]:
def load_data(filename):
    fin = io.open(filename, 'r', encoding='utf-8')
    data = []
    vocab = defaultdict(lambda:0)
    for line in fin:
        sentence = line.split()
        data.append(sentence)
        for word in sentence:
            vocab[word] += 1
    return data, vocab

In [3]:
def remove_rare_words(data, vocab, mincount):
    data_with_unk = data[:]
    for i in range(len(data_with_unk)):
        for j in range(len(data_with_unk[i])):
            if vocab[data_with_unk[i][j]] < mincount:
                data_with_unk[i][j] = '<unk>'
    return data_with_unk

In [4]:
print("load training set")
train_data, vocab = load_data("reddit.txt")
train_data = remove_rare_words(train_data, vocab, 5)

print("load validation set")
test_data, _ = load_data("test.ted.txt")
test_data = remove_rare_words(test_data, vocab, 5)

load training set
load validation set


In [5]:
def build_ngram(data, n):
    total_number_words = 0
    counts = defaultdict(lambda: defaultdict(lambda: 0.0))
    assert n >= 1, 'n should be greater than 1'
    for sentence in data:
        sentence = tuple(sentence)
        for gram_size in range(n):
            for idx in range(len(sentence)):
                total_number_words += 1.
                if gram_size+idx < len(sentence):
                    counts[sentence[idx:gram_size+idx]][sentence[idx+gram_size]] += 1.
    total_number_words /= n
    freq  = defaultdict(lambda: defaultdict(lambda: 0.0))
    for context in counts:
        for word in counts[context]:
            freq[context][word] = counts[context][word]/sum(counts[context].values())
    return freq

In [6]:
n = 2
print("build ngram model with n = ", n)
model = build_ngram(train_data, n)

build ngram model with n =  2


In [7]:
model

defaultdict(<function __main__.build_ngram.<locals>.<lambda>()>,
            {(): defaultdict(<function __main__.build_ngram.<locals>.<lambda>.<locals>.<lambda>()>,
                         {'<': 0.0002449208786088494,
                          '--': 0.00021206563879546718,
                          '-': 0.0033074274745471454,
                          '<unk>': 0.05138459945480214,
                          '(': 0.00561426355477765,
                          'ツ': 1.1947359932138996e-05,
                          ')': 0.006067267618871253,
                          '_/¯': 8.960519949104246e-06,
                          '/': 0.00022998667869367567,
                          '..': 0.0003594164112918481,
                          ';': 0.0005475873302230373,
                          ':': 0.004481255587879801,
                          '’': 0.015985567589201975,
                          '!': 0.004939237718611797,
                          '”': 0.0020479766150341594,
                      

In [8]:
def get_prob(model, context, w):
    return model[context][w] if model[context][w] != 0.0 else 0.4*get_prob(model, context[1:],w)

def perplexity(model, data, n):
    perp = 0.0
    for sentence in data:
        sentence = tuple(sentence)
        probs = 0.0
        for idx in range(1,len(sentence)):
            probs += (-1.0/len(sentence))*np.log(get_prob(model, sentence[max(0,idx-n+1):idx], sentence[idx]))
        perp += probs/len(data)
    return np.exp(perp)

In [9]:
print("The perplexity is", perplexity(model, test_data, n))

The perplexity is 114.17708228241057
