In [8]:
from keras.models import model_from_json
import nltk
import gensim
import numpy as np
from ipywidgets import widgets 

Using Theano backend.


### Parameters

In [9]:
# File name (inside ./data/) of the JSON architecture description read by load_model().
MODEL_NAME = 'nn_3_in_1_lstm_1_out.json'

# NOTE(review): presumably a smaller test dictionary/checkpoint kept for quick switching — confirm.
#CORPORA_NAME = 'test.txt.corpora.dat'
#MODEL_WEIGHT = 'test.txt.nn_3_in_1_lstm_1_out.weights.hdf5'

# File name (inside ./data/) of the saved gensim dictionary.
CORPORA_NAME = 'en_US.blogs.txt.corpora.dat'
# Weight checkpoints (inside ./data/), one per training-set size; passed to run_test_cases().
MODEL_WEIGHT_1M = 'en_US.blogs.txt.nn_3_in_1_lstm_1_out.3.1000000.weights.hdf5' # 1M training samples
MODEL_WEIGHT_5M = 'en_US.blogs.txt.nn_3_in_1_lstm_1_out.4.5000000.weights.hdf5'  # 5M training samples
MODEL_WEIGHT_10M = 'en_US.blogs.txt.nn_3_in_1_lstm_1_out.weights-epoch-03-loss-4.70-acc-0.20.hdf5' # 10M training samples
MODEL_WEIGHT_20M = 'en_US.blogs.txt.nn_3_in_1_lstm_1_out.start-10000000.weights-epoch-03-loss-4.60-acc-0.22.hdf5' # 20M training samples

# Number of trailing input tokens handed to the model for each prediction.
IN_SEQ_LENGTH = 3
# NOTE(review): not referenced in the visible cells; presumably the number of
# words the network emits per prediction — confirm.
OUT_SEQ_LENGTH = 1

### Load corpora

In [10]:
# Special vocabulary entries; they are looked up in the dictionary like ordinary tokens.
# Stands in for any input token that is missing from the vocabulary.
META_UNKNOWN = '<<<!UNK!>>>'
# Fills the leading positions when the input has fewer tokens than the model expects.
META_EMPTY = '<<<!EMP!>>>'
# NOTE(review): not used in the visible cells; presumably a placeholder for
# numeric tokens — confirm.
META_NUMBER = '<<<!NUM!>>>'

In [11]:
# Load the saved gensim dictionary (token <-> id mapping) from ./data/.
corpora = gensim.corpora.Dictionary.load('./data/'+CORPORA_NAME)
vocab_size = len(corpora)
print('Number of words in corpora: %d'%(vocab_size))
# NOTE(review): the list is thrown away immediately, so this line only matters
# for its side effect — presumably iterating items() makes gensim build its
# lazily populated id2token mapping, which is needed for id -> word lookups.
# Confirm before removing.
tmp = list(corpora.items())
del(tmp)

Number of words in corpora: 20002


### Load model and weights

In [20]:
model = None

def load_model(weights):
    """Build the network from its JSON description and load a weight checkpoint.

    The architecture is read from ``./data/<MODEL_NAME>`` and the weights from
    ``./data/<weights>``. The module-level ``model`` is replaced only after
    both steps have succeeded, so a failing weight load never leaves an
    architecture without the requested weights behind as the active model.

    Args:
        weights: File name of the weight checkpoint inside ``./data/``.
    """
    global model
    with open('./data/' + MODEL_NAME, mode='r') as f:
        loaded = model_from_json(f.read())
    loaded.load_weights('./data/' + weights)
    # Publish only the fully initialised model.
    model = loaded

### Test model

In [13]:
def get_last_token_ids(inp, seq_len):
    """Convert the tail of a text into a fixed-length array of vocabulary ids.

    The text is lower-cased, split into sentences and then into word and
    punctuation tokens. Each token is mapped to its id in ``corpora``; tokens
    missing from the vocabulary map to the id of ``META_UNKNOWN``. Only the
    last ``seq_len`` ids are kept, and shorter inputs are left-padded with the
    id of ``META_EMPTY``.

    Args:
        inp: Input text; may be empty.
        seq_len: Length of the returned array.

    Returns:
        Integer numpy array of length ``seq_len``.
    """
    res = np.full(seq_len, corpora.token2id[META_EMPTY], dtype=int)

    id_tokens = []
    for sentence in nltk.tokenize.sent_tokenize(inp.lower()):
        for token in nltk.tokenize.wordpunct_tokenize(sentence):
            try:
                tid = corpora.token2id[token]
            except KeyError:
                # Out-of-vocabulary word: fall back to the "unknown" marker.
                tid = corpora.token2id[META_UNKNOWN]
            id_tokens.append(tid)

    n_used = min(seq_len, len(id_tokens))
    if n_used > 0:
        # Slice from an explicit start index: id_tokens[-0:] would be the
        # whole list rather than an empty tail.
        res[seq_len - n_used:] = id_tokens[len(id_tokens) - n_used:]

    return res

In [14]:
def get_next_word(inp, top=7):
    """Predict the most likely words to follow ``inp``.

    The last ``IN_SEQ_LENGTH`` tokens of ``inp`` are fed to the loaded model
    as a batch of one, and the highest-scoring vocabulary entries of the
    resulting score vector are returned.

    Args:
        inp: Text typed so far; may be empty.
        top: Number of candidates to return (default 7).

    Returns:
        List of ``(score, word)`` tuples ordered from highest to lowest score.
        Each vocabulary entry appears at most once.
    """
    tids = get_last_token_ids(inp, IN_SEQ_LENGTH)
    batch = np.array(tids)[np.newaxis, :]
    p = model.predict(batch)[0]

    # Stable sort of the negated scores: highest score first, ties resolved in
    # favour of the lowest id (the order repeated argmax calls would give).
    best_ids = np.argsort(-p, kind='mergesort')[:top]

    # Dictionary indexing resolves id -> token and builds the reverse mapping
    # on demand, so it does not rely on id2token having been filled beforehand.
    return [(p[i], corpora[int(i)]) for i in best_ids]

In [15]:
next_words_beam = []

def get_next_words_requesive(inp, seq_len, prob=1.0):
    """Recursively extend ``inp`` by ``seq_len`` predicted words.

    Every candidate returned by ``get_next_word`` is appended to the text and
    expanded further until ``seq_len`` words have been added. Each completed
    phrase is stored in the module-level ``next_words_beam`` list as a
    ``(probability, phrase)`` tuple, where the probability is ``prob``
    multiplied by the scores of all words added along the way.

    Args:
        inp: Text to extend; may be empty.
        seq_len: Number of words still to append; 0 does nothing.
        prob: Accumulated probability of ``inp`` (1.0 at the top level).
    """
    if seq_len == 0:
        return

    for word_prob, word in get_next_word(inp):
        # No separating space when there is no text yet.
        phrase = inp + ' ' + word if len(inp) > 0 else word
        path_prob = prob * word_prob
        if seq_len > 1:
            get_next_words_requesive(phrase, seq_len - 1, path_prob)
        else:
            next_words_beam.append((path_prob, phrase))

In [16]:
def get_next_words(inp, words_to_predict, top=10):
    """Return the ``top`` most probable continuations of ``inp``.

    Args:
        inp: Text to continue; may be empty.
        words_to_predict: Number of words appended to ``inp`` in each result.
        top: Maximum number of results to return.

    Returns:
        List of ``(probability, phrase)`` tuples, most probable first.
    """
    # Start from an empty accumulator; the recursive helper appends to it.
    del next_words_beam[:]
    get_next_words_requesive(inp, words_to_predict)

    def score(entry):
        return entry[0]

    next_words_beam.sort(key=score, reverse=True)
    return next_words_beam[:top]

### Test cases

In [29]:
def print_next_words(inp, words_to_predict, top=10):
    """Print ``inp`` as a header followed by its predicted continuations.

    Args:
        inp: Text to continue; may be empty.
        words_to_predict: Number of words appended in each prediction.
        top: Maximum number of predictions to print.
    """
    predictions = get_next_words(inp, words_to_predict, top)

    # Header line first, then one prediction tuple per line.
    lines = [inp + ':']
    lines.extend(str(entry) for entry in predictions)
    print('\n'.join(lines))

In [34]:
def run_test_cases(weights):
    """Load the given weight checkpoint and print predictions for fixed prompts.

    Args:
        weights: File name of the weight checkpoint inside ``./data/``.
    """
    load_model(weights)

    # (prompt, number of words to predict)
    cases = (
        ('', 1),
        ('', 2),
        ('Hi', 1),
        ('What do', 2),
        ('this week I', 2),
    )
    for text, n_words in cases:
        print_next_words(text, n_words)

### Run test cases

In [36]:
# Predictions from the checkpoint trained on 1M samples.
run_test_cases(MODEL_WEIGHT_1M)

:
(0.13940274715423584, 'i')
(0.07107437402009964, 'the')
(0.046676371246576309, 'it')
(0.031658142805099487, 'and')
(0.024488970637321472, 'this')
(0.022508643567562103, 'but')
(0.021897658705711365, 'we')
:
(0.012725758562348588, "i '")
(0.0093280714576638446, "it '")
(0.0092455425771342803, 'i have')
(0.0081878243198056921, 'this is')
(0.0081202295937043978, 'i ’')
(0.007741433969887801, 'i am')
(0.006880082826056122, 'it ’')
(0.0068554046504845645, 'it is')
(0.0064491832820475214, 'it was')
(0.0056063334160483258, 'i don')
Hi:
(0.12810064852237701, 'Hi everyone')
(0.12010793387889862, 'Hi !')
(0.094802573323249817, 'Hi ,')
(0.054005932062864304, 'Hi it')
(0.052542150020599365, 'Hi the')
(0.036244608461856842, 'Hi you')
(0.023914987221360207, 'Hi for')
What do:
(0.054811323741992268, 'What do you have')
(0.032900612315368782, 'What do you know')
(0.025432437904105765, 'What do you want')
(0.021123467560435927, 'What do you see')
(0.014198034420835781, 'What do i have')
(0.0140031641

In [37]:
# Predictions from the checkpoint trained on 5M samples.
run_test_cases(MODEL_WEIGHT_5M)

:
(0.13507360219955444, 'i')
(0.07171204686164856, 'the')
(0.044144440442323685, 'it')
(0.028704248368740082, 'and')
(0.023586263880133629, 'this')
(0.023142430931329727, 'but')
(0.023014085367321968, 'we')
:
(0.012147098318613381, "i '")
(0.0099463143717977154, 'i ’')
(0.0088157689885602353, 'it was')
(0.0079136054076203166, 'i was')
(0.0078452155595398199, "it '")
(0.007791965758301822, 'this is')
(0.0075279102646594875, 'it is')
(0.0069070183151729303, 'i have')
(0.0065189287421396624, 'i am')
(0.006335929678655261, 'it ’')
Hi:
(0.2609286904335022, 'Hi everyone')
(0.13318325579166412, 'Hi there')
(0.080077745020389557, 'Hi ,')
(0.046894118189811707, 'Hi everybody')
(0.025618864223361015, 'Hi all')
(0.025262799113988876, 'Hi friends')
(0.019602682441473007, 'Hi girls')
What do:
(0.062955060069654678, 'What do you have')
(0.046869745883043379, 'What do you think')
(0.04005599855683073, 'What do you really')
(0.035279480308652467, 'What do you know')
(0.034403582593434123, 'What do you

In [38]:
# Predictions from the checkpoint trained on 10M samples.
run_test_cases(MODEL_WEIGHT_10M)

:
(0.11459830403327942, 'i')
(0.061969183385372162, 'the')
(0.043100703507661819, 'it')
(0.030616007745265961, 'and')
(0.030332930386066437, 'but')
(0.022078005596995354, 'we')
(0.020217198878526688, 'this')
:
(0.0070295590680283793, 'i have')
(0.0069738208677774249, 'it is')
(0.0068474870980202995, "i '")
(0.006551913649257779, 'it was')
(0.00609389225094914, 'this is')
(0.0060626188695754868, 'i am')
(0.0056759218183881077, 'it ’')
(0.0052342204385173452, "it '")
(0.0051041708036366762, 'i ’')
(0.0049886636686967689, 'but i')
Hi:
(0.15030717849731445, 'Hi everyone')
(0.12912140786647797, 'Hi readers')
(0.082765072584152222, 'Hi there')
(0.062626980245113373, 'Hi ,')
(0.038216222077608109, 'Hi -')
(0.037745315581560135, 'Hi !')
(0.033834852278232574, 'Hi all')
What do:
(0.22502352050447705, 'What do you think')
(0.087862486056573985, 'What do you do')
(0.076162893950597166, 'What do you mean')
(0.035002434447635844, 'What do you want')
(0.028762933582406225, 'What do you guys')
(0.027

In [39]:
# Predictions from the checkpoint trained on 20M samples.
run_test_cases(MODEL_WEIGHT_20M)

:
(0.018361009657382965, 'the')
(0.017092021182179451, 'i')
(0.0087438393384218216, 'it')
(0.0054858271032571793, 'and')
(0.0050167012959718704, 'a')
(0.0042165094055235386, 'this')
(0.0041746222414076328, 'in')
:
(0.0012691810995876796, "it '")
(0.0011691764382460168, "i '")
(0.0010671753182412771, 'i was')
(0.0010643443033528377, 'it was')
(0.0010042405745425073, 'in the')
(0.00097995062343296524, 'it is')
(0.00095021885145014917, 'it ’')
(0.0009489200756838817, 'i am')
(0.00089666646798108179, 'i have')
(0.00089408345590946681, 'i ’')
Hi:
(0.29048067331314087, 'Hi everyone')
(0.079488232731819153, 'Hi guys')
(0.064141124486923218, 'Hi readers')
(0.059681512415409088, 'Hi -')
(0.05562620609998703, 'Hi ,')
(0.016783943399786949, 'Hi !')
(0.01605776883661747, 'Hi reverend')
What do:
(0.31630035661740052, 'What do you think')
(0.040568585313280225, 'What do i know')
(0.036355796617380642, 'What do you know')
(0.031426119270950403, 'What do you do')
(0.031321494042487741, 'What do i do')