In [7]:
from keras.models import model_from_json
import nltk
import gensim
import numpy as np
from ipywidgets import widgets 

### Parameters

In [17]:
# All artifacts below are read from ./data/.
# Serialized Keras network architecture (JSON).
MODEL_NAME = 'nn_3_in_1_lstm_1_out.json'

# Alternative artifact pair (presumably a small test corpus):
#CORPORA_NAME = 'test.txt.corpora.dat'
#MODEL_WEIGHT = 'test.txt.nn_3_in_1_lstm_1_out.weights.hdf5'

# gensim dictionary and trained weights for the blogs corpus.
CORPORA_NAME = 'en_US.blogs.txt.corpora.dat'
#MODEL_WEIGHT = 'en_US.blogs.txt.nn_3_in_1_lstm_1_out.3.1000000.weights.hdf5'
MODEL_WEIGHT = 'en_US.blogs.txt.nn_3_in_1_lstm_1_out.4.5000000.weights.hdf5'

# Context tokens fed to the model, and tokens it predicts per step.
IN_SEQ_LENGTH, OUT_SEQ_LENGTH = 3, 1

### Load corpora

In [18]:
# Reserved placeholder tokens looked up in the vocabulary:
#   UNK - substituted for tokens that are not in the vocabulary
#   EMP - left padding when the context is shorter than the input length
#   NUM - presumably a placeholder for numbers (not used in this notebook)
META_UNKNOWN, META_EMPTY, META_NUMBER = (
    '<<<!UNK!>>>',
    '<<<!EMP!>>>',
    '<<<!NUM!>>>',
)

In [19]:
corpora = gensim.corpora.Dictionary.load('./data/' + CORPORA_NAME)
vocab_size = len(corpora)
print('Number of words in corpora: %d' % vocab_size)

# The prediction helpers read corpora.id2token directly. gensim only fills
# that reverse mapping lazily (inside Dictionary.__getitem__), which the old
# `tmp = list(corpora.items()); del(tmp)` trick triggered as a side effect.
# Build it explicitly instead of relying on that hidden behavior.
corpora.id2token = {token_id: token for token, token_id in corpora.token2id.items()}

Number of words in corpora: 20002


### Load model and weights

In [20]:
# Rebuild the network architecture from its JSON description (weights are
# not part of this file).
model_json_path = './data/' + MODEL_NAME
with open(model_json_path, mode='r') as json_file:
    model = model_from_json(json_file.read())

In [21]:
weights_path = './data/' + MODEL_WEIGHT  # trained parameters (HDF5 file)
model.load_weights(weights_path)

In [22]:
# Print a per-layer overview (output shapes and parameter counts).
model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
embedding_2 (Embedding)          (None, 3, 64)         1280128     embedding_input_4[0][0]          
____________________________________________________________________________________________________
lstm_2 (LSTM)                    (None, 1024)          4460544     embedding_2[0][0]                
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 14)            20502050    lstm_2[0][0]                     
Total params: 26242722
____________________________________________________________________________________________________


### Test model

In [23]:
def get_last_token_ids(inp, seq_len):
    """Encode the tail of ``inp`` as a fixed-length array of vocabulary ids.

    The text is lower-cased, split into sentences, and each sentence is split
    into word/punctuation tokens with NLTK. Each token is mapped through
    ``corpora.token2id``; tokens not in the vocabulary become the id of
    ``META_UNKNOWN``. Only the last ``seq_len`` ids are kept. If there are
    fewer, the result is left-padded with the id of ``META_EMPTY``.

    Args:
        inp: Raw input text (may be empty).
        seq_len: Length of the returned array.

    Returns:
        numpy int array of length ``seq_len``.
    """
    token2id = corpora.token2id
    res = np.full(seq_len, token2id[META_EMPTY], dtype=int)

    # Explicit membership test instead of a bare `except:` that would also
    # hide unrelated errors; the UNK id is only looked up when needed.
    id_tokens = [
        token2id[token] if token in token2id else token2id[META_UNKNOWN]
        for sentence in nltk.tokenize.sent_tokenize(inp.lower())
        for token in nltk.tokenize.wordpunct_tokenize(sentence)
    ]

    n_keep = min(seq_len, len(id_tokens))
    if n_keep > 0:
        # Right-align the most recent tokens; the left part stays padding.
        # Guarded because `id_tokens[-0:]` would select the whole list and
        # fail to fit into an empty slice when seq_len == 0.
        res[seq_len - n_keep:] = id_tokens[len(id_tokens) - n_keep:]

    return res

In [24]:
def get_next_word(inp, top=7):
    """Return the ``top`` highest-scoring next tokens for the text ``inp``.

    The last ``IN_SEQ_LENGTH`` token ids of ``inp`` are fed to ``model`` as a
    batch of one, and the output row is ranked by score. Callers treat the
    scores as probabilities.

    Args:
        inp: Context text (may be empty).
        top: Number of candidates to return (default 7, the previous fixed
            count).

    Returns:
        List of ``(score, token)`` tuples, highest score first.
    """
    tids = get_last_token_ids(inp, IN_SEQ_LENGTH)
    # Add a leading batch axis of size 1 and take the single output row.
    scores = model.predict(np.asarray(tids)[np.newaxis, :])[0]

    # Stable sort on negated scores gives descending order with ties broken by
    # lower id, like repeated np.argmax, but without zeroing entries in place
    # (which could re-select a zeroed id if real scores were <= 0).
    best_ids = np.argsort(-scores, kind='stable')[:top]

    return [(scores[i], corpora.id2token[i]) for i in best_ids]

In [25]:
# Shared accumulator of (probability, text) candidates filled by
# get_next_words_requesive() when no explicit `out` list is given.
next_words_beam = []

def get_next_words_requesive(inp, seq_len, prob=1.0, out=None):
    """Enumerate every ``seq_len``-word continuation of ``inp``.

    The misspelled name ("recursive") is kept for existing callers.

    Recursively extends ``inp`` with each candidate returned by
    ``get_next_word``. This is exhaustive expansion with no pruning, so cost
    grows exponentially in ``seq_len``. Candidate probabilities are multiplied
    along the path, and each complete continuation is appended to ``out`` as
    ``(probability, text)``.

    Args:
        inp: Text to extend. If empty, the first word gets no leading space.
        seq_len: Number of words still to add; ``<= 0`` adds nothing.
        prob: Probability accumulated so far along this path.
        out: List to append results to. Defaults to the module-level
            ``next_words_beam``.
    """
    if out is None:
        out = next_words_beam
    # `<= 0` rather than `== 0`: a negative depth used to fall through and
    # still predict one word.
    if seq_len <= 0:
        return

    for word_prob, word in get_next_word(inp):
        text = inp + ' ' + word if inp else word
        path_prob = prob * word_prob
        if seq_len > 1:
            get_next_words_requesive(text, seq_len - 1, path_prob, out)
        else:
            out.append((path_prob, text))

In [26]:
def get_next_words(inp, words_to_predict, top=10):
    """Rank every ``words_to_predict``-word continuation of ``inp``.

    Empties the shared ``next_words_beam`` list and refills it via
    ``get_next_words_requesive``. It then leaves the list ordered by
    probability, highest first, and returns its first ``top`` entries as
    ``(probability, text)`` tuples.
    """
    del next_words_beam[:]

    get_next_words_requesive(inp, words_to_predict)

    next_words_beam[:] = sorted(next_words_beam,
                                key=lambda candidate: candidate[0],
                                reverse=True)
    return next_words_beam[:top]

### Test cases

In [27]:
# Top candidates for the first word when no context is given.
get_next_word(inp='')

[(0.1350736, 'i'),
 (0.071712047, 'the'),
 (0.04414444, 'it'),
 (0.028704248, 'and'),
 (0.023586264, 'this'),
 (0.023142431, 'but'),
 (0.023014085, 'we')]

In [28]:
# One-word continuations of a greeting.
get_next_words('Hi', words_to_predict=1)

[(0.2609286904335022, 'Hi everyone'),
 (0.13318325579166412, 'Hi there'),
 (0.080077745020389557, 'Hi ,'),
 (0.046894118189811707, 'Hi everybody'),
 (0.025618864223361015, 'Hi all'),
 (0.025262799113988876, 'Hi friends'),
 (0.019602682441473007, 'Hi girls')]

In [29]:
# Two-word continuations from an empty context, scored by the product of both steps.
get_next_words('', words_to_predict=2)

[(0.012147098318613381, "i '"),
 (0.0099463143717977154, 'i ’'),
 (0.0088157689885602353, 'it was'),
 (0.0079136054076203166, 'i was'),
 (0.0078452155595398199, "it '"),
 (0.007791965758301822, 'this is'),
 (0.0075279102646594875, 'it is'),
 (0.0069070183151729303, 'i have'),
 (0.0065189287421396624, 'i am'),
 (0.006335929678655261, 'it ’')]

In [30]:
# Two-word continuations of a question prefix.
get_next_words('What do', words_to_predict=2)

[(0.062955060069654678, 'What do you have'),
 (0.046869745883043379, 'What do you think'),
 (0.04005599855683073, 'What do you really'),
 (0.035279480308652467, 'What do you know'),
 (0.034403582593434123, 'What do you want'),
 (0.033588090571583962, 'What do you see'),
 (0.026183887447594145, 'What do you can'),
 (0.010686122208101445, 'What do i think'),
 (0.0095483829820464106, 'What do i really'),
 (0.0079956908612359157, 'What do they really')]