In [1]:
#! SETUP 1
import sys, os
_snlp_book_dir = "../../../../"
sys.path.append(_snlp_book_dir) 
import statnlpbook.lm as lm
import statnlpbook.ohhla as ohhla
import math
import numpy as np
import matplotlib.pyplot as plt
import collections

In [2]:
#! SETUP 2
# Locations of the OHHLA train/dev splits, built by string concatenation on the
# book root. NOTE(review): _snlp_book_dir already ends in "/", so both paths
# contain a double slash ("..//data/..."), as seen in the loader's warning output.
_snlp_train_dir = _snlp_book_dir + "/data/ohhla/train"
_snlp_dev_dir = _snlp_book_dir + "/data/ohhla/dev"
# Load all songs of each split and convert them to word sequences
# (presumably flat token lists -- only len() is relied on here; confirm in statnlpbook.ohhla).
_snlp_train_song_words = ohhla.words(ohhla.load_all_songs(_snlp_train_dir))
_snlp_dev_song_words = ohhla.words(ohhla.load_all_songs(_snlp_dev_dir))
# Sanity check: the training split must contain exactly this many tokens;
# a different count means the corpus on disk differs from the expected one.
assert(len(_snlp_train_song_words)==1041496)

Could not load ../../../..//data/ohhla/train/www.ohhla.com/anonymous/nas/distant/tribal.nas.txt.html


In [3]:
class JelinekMercerLM(lm.LanguageModel):
    """Linear interpolation of a unigram, bigram, trigram and fourgram model.

    The probability of a word is the weighted sum

        alpha1 * P_uni + alpha2 * P_bi + alpha3 * P_tri + alpha4 * P_four

    where ``alpha4`` is whatever weight mass is left over:
    ``alpha4 = 1 - (alpha1 + alpha2 + alpha3)``.

    Args:
        unigram, bigram, trigram, fourgram: component language models. Each
            must expose ``probability(word, *history)``; ``unigram.vocab`` and
            ``fourgram.order`` become this model's vocabulary and order.
        alpha1, alpha2, alpha3: interpolation weights of the unigram, bigram
            and trigram models. Each must be non-negative and together they
            must not exceed 1.

    Raises:
        ValueError: if a weight is negative or the three weights sum to more
            than 1 (which would make ``alpha4`` negative and the mixture no
            longer a valid probability distribution).
    """

    # Slack allowed for floating-point rounding when the three explicit
    # weights are intended to sum to exactly 1.
    _WEIGHT_TOLERANCE = 1e-9

    def __init__(self, unigram, alpha1, bigram, alpha2, trigram, alpha3, fourgram):
        remainder = 1 - (alpha1 + alpha2 + alpha3)
        if min(alpha1, alpha2, alpha3) < 0 or remainder < -self._WEIGHT_TOLERANCE:
            raise ValueError(
                "interpolation weights must be non-negative and sum to at most 1, "
                "got alpha1={}, alpha2={}, alpha3={}".format(alpha1, alpha2, alpha3)
            )

        super().__init__(unigram.vocab, fourgram.order)
        self.unigram = unigram
        self.bigram = bigram
        self.trigram = trigram
        self.fourgram = fourgram
        self.alpha1 = alpha1
        self.alpha2 = alpha2
        self.alpha3 = alpha3
        # Clamp rounding residue (e.g. -1e-17) so the fourgram weight is never negative.
        self.alpha4 = max(remainder, 0.0)

    def probability(self, word, *history):
        """Return the interpolated probability of ``word`` given ``history``.

        The full ``history`` is forwarded unchanged to every component model;
        each component is expected to use only as much context as its order
        needs (TODO confirm against the component implementations).
        """
        return (
            self.alpha1 * self.unigram.probability(word, *history)
            + self.alpha2 * self.bigram.probability(word, *history)
            + self.alpha3 * self.trigram.probability(word, *history)
            + self.alpha4 * self.fourgram.probability(word, *history)
        )

In [20]:
# Training sequence after lm.inject_OOVs (presumably some tokens are replaced by
# an OOV marker so the models can score unseen words -- confirm in statnlpbook.lm).
oov_train = lm.inject_OOVs(_snlp_train_song_words)
# One n-gram model per order (1..4), all trained on the same OOV-injected data.
bigram = lm.NGramLM(oov_train, 2)
unigram = lm.NGramLM(oov_train,1)
trigram = lm.NGramLM(oov_train,3)
fourgram = lm.NGramLM(oov_train,4)
# Distinct tokens of the OOV-injected training data.
oov_vocab = set(oov_train)
# Every distinct token of the raw train and dev data combined.
vocab = set(_snlp_train_song_words) | set(_snlp_dev_song_words)
# Initial interpolation weights: unigram 0.24, bigram 0.56, trigram 0.18;
# the fourgram model receives the remainder (1 - 0.98 = 0.02).
my_LM = JelinekMercerLM(unigram, 0.24, bigram, 0.56, trigram, 0.18, fourgram)


In [23]:
# Dev-set perplexity of the interpolated model with the initial weights.
# `vocab - oov_vocab` is the set of tokens that occur in train/dev but not in the
# OOV-injected training data; it is handed to OOVAwareLM (presumably so those
# words are scored as OOV -- confirm in statnlpbook.lm).
lm.perplexity(lm.OOVAwareLM(my_LM, vocab - oov_vocab), _snlp_dev_song_words)

159.74151561586498

In [49]:
# Candidate interpolation weights for the grid search: 5 evenly spaced values
# per weight (alpha1 -> unigram, alpha2 -> bigram, alpha3 -> trigram).
alpha1 = np.linspace(0.2,0.26,5)
alpha2 = np.linspace(0.43,0.53,5)
alpha3 = np.linspace(0.1,0.2,5)

# Results of the grid search: perplexity -> [alpha1, alpha2, alpha3].
# NOTE(review): the default factory is `float`, but the values stored by the
# grid-search loop are lists; also, two weight settings with the same
# perplexity would overwrite each other because perplexity is the key.
perplexity_dic = collections.defaultdict(float)

In [50]:
# Show the candidate unigram weights produced by np.linspace.
print(alpha1)

[ 0.2    0.215  0.23   0.245  0.26 ]


In [51]:
# Exhaustive grid search over the interpolation weights:
#   i -> unigram weight, j -> bigram weight, k -> trigram weight
# (the fourgram weight is the remainder, computed inside JelinekMercerLM).
# Each setting is scored by dev-set perplexity and recorded in perplexity_dic
# as perplexity -> [i, j, k].
for i in alpha1:
    for j in alpha2:
        for k in alpha3:
            print(i,j,k)
            my_LM = JelinekMercerLM(unigram, i, bigram, j, trigram, k, fourgram)
            ppl = lm.perplexity(lm.OOVAwareLM(my_LM,vocab - oov_vocab), _snlp_dev_song_words)
            # Skip settings whose perplexity is infinite (or NaN). The previous
            # check `ppl is not "inf"` compared a number to a string literal by
            # identity, which is always True, so nothing was ever filtered.
            if math.isfinite(ppl):
                perplexity_dic[ppl] = [i,j,k]
            print(ppl)

0.2 0.43 0.1
161.8854736049674
0.2 0.43 0.125
160.64619110468198
0.2 0.43 0.15
159.68641222595102
0.2 0.43 0.175
158.97675506774817
0.2 0.43 0.2
158.50572938867484
0.2 0.455 0.1
160.2589846303951
0.2 0.455 0.125
159.1763346464395
0.2 0.455 0.15
158.37555028238447
0.2 0.455 0.175
157.83433759671888
0.2 0.455 0.2
157.55003962432104
0.2 0.48 0.1
158.91851965583962
0.2 0.48 0.125
157.99908827252867
0.2 0.48 0.15
157.3699789581885
0.2 0.48 0.175
157.0182650625239
0.2 0.48 0.2
156.95453335175307
0.2 0.505 0.1
157.86590021474728
0.2 0.505 0.125
157.1224617868283
0.2 0.505 0.15
156.68641242253528
0.2 0.505 0.175
156.55862422407012
0.2 0.505 0.2
156.7719824401184
0.2 0.53 0.1
157.1124723925157
0.2 0.53 0.125
156.56670987428322
0.2 0.53 0.15
156.35877586514573
0.2 0.53 0.175
156.51251016573661
0.2 0.53 0.2
157.10407057588554
0.215 0.43 0.1
160.89832628567672
0.215 0.43 0.125
159.71885570504773
0.215 0.43 0.15
158.82305167792913
0.215 0.43 0.175
158.1845676078505
0.215 0.43 0.2
157.79600913247825

In [28]:
import heapq

In [52]:
# Iterating a dict yields its keys, so this selects the five lowest
# perplexities recorded by the grid search, in ascending order.
# NOTE(review): despite its name, `max_keys` holds the *smallest* keys
# (i.e. the best settings).
max_keys = heapq.nsmallest(5, perplexity_dic)
max_keys

[155.90276234926196,
 155.9185515790749,
 155.93262939615101,
 155.94298534071186,
 155.9702814571645]

In [53]:
# Print the [alpha1, alpha2, alpha3] weights stored for each of the five
# lowest perplexities, best first.
for key in max_keys:
    print(perplexity_dic[key])

[0.245, 0.505, 0.15000000000000002]
[0.23000000000000001, 0.53000000000000003, 0.125]
[0.245, 0.53000000000000003, 0.125]
[0.23000000000000001, 0.505, 0.15000000000000002]
[0.245, 0.505, 0.125]
