In [177]:
from keras.models import model_from_json
import nltk
import gensim
import numpy as np
from ipywidgets import widgets 

### Parameters

In [159]:
# Artefacts loaded from ./data/ by the cells below:
#   MODEL_NAME   -> Keras architecture (JSON), read with model_from_json
#   CORPORA_NAME -> gensim Dictionary mapping tokens <-> ids
#   MODEL_WEIGHT -> trained weights (HDF5), applied with model.load_weights
MODEL_NAME = 'nn_3_in_1_lstm_1_out.json'

# Alternative artefacts built from the `test.txt` corpus:
#CORPORA_NAME = 'test.txt.corpora.dat'
#MODEL_WEIGHT = 'test.txt.nn_3_in_1_lstm_1_out.weights.hdf5'

CORPORA_NAME, MODEL_WEIGHT = (
    'en_US.blogs.txt.corpora.dat',
    'en_US.blogs.txt.nn_3_in_1_lstm_1_out.3.1000000.weights.hdf5',
)

# IN_SEQ_LENGTH: number of context token ids fed to the network per prediction.
# OUT_SEQ_LENGTH: presumably the number of predicted tokens; not used by the cells shown.
IN_SEQ_LENGTH, OUT_SEQ_LENGTH = 3, 1

### Load corpora

In [160]:
# Placeholder tokens that are looked up in the corpora dictionary like ordinary words.
META_UNKNOWN, META_EMPTY, META_NUMBER = (
    '<<<!UNK!>>>',  # substituted for tokens missing from the dictionary
    '<<<!EMP!>>>',  # left padding when the input has fewer tokens than the context
    '<<<!NUM!>>>',  # presumably replaces numbers during preprocessing; unused here
)

In [161]:
corpora = gensim.corpora.Dictionary.load('./data/' + CORPORA_NAME)
vocab_size = len(corpora)
print('Number of words in corpora: %d' % (vocab_size))

# gensim fills Dictionary.id2token lazily (only on item access), and the
# prediction helpers index the id -> token mapping. Build the reverse mapping
# explicitly rather than materialising a throw-away list(corpora.items())
# just for that side effect.
corpora.id2token = {tid: token for token, tid in corpora.token2id.items()}

Number of words in corpora: 20002


### Load model and weights

In [162]:
# Rebuild the network architecture (layers only, no weights) from its Keras JSON description.
with open('./data/' + MODEL_NAME, mode='r') as f:
    model = model_from_json(f.read())

In [163]:
# Restore the trained parameters into the architecture built from MODEL_NAME.
model.load_weights('./data/'  + MODEL_WEIGHT)

In [164]:
# Print the layer-by-layer architecture with output shapes and parameter counts.
model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
embedding_3 (Embedding)          (None, 3, 64)         1280128     embedding_input_7[0][0]          
____________________________________________________________________________________________________
lstm_4 (LSTM)                    (None, 1024)          4460544     embedding_3[0][0]                
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 14)            20502050    lstm_4[0][0]                     
Total params: 26242722
____________________________________________________________________________________________________


### Prediction helpers

In [165]:
def get_last_token_ids(inp, seq_len):
    """Encode the last ``seq_len`` tokens of ``inp`` as ``corpora`` ids.

    The text is lower-cased, split into sentences with
    ``nltk.tokenize.sent_tokenize`` and each sentence into word/punctuation
    tokens with ``nltk.tokenize.wordpunct_tokenize``. All tokens are
    concatenated across sentences. Tokens missing from ``corpora`` map to the
    ``META_UNKNOWN`` id. When fewer than ``seq_len`` tokens exist, the result
    is left-padded with the ``META_EMPTY`` id.

    Parameters
    ----------
    inp : str
        Raw input text; may be empty.
    seq_len : int
        Number of ids to return.

    Returns
    -------
    numpy.ndarray
        Integer array of length ``seq_len``.
    """
    token2id = corpora.token2id
    res = np.full(seq_len, token2id[META_EMPTY], dtype=int)

    id_tokens = []
    for sentence in nltk.tokenize.sent_tokenize(inp.lower()):
        for token in nltk.tokenize.wordpunct_tokenize(sentence):
            tid = token2id.get(token)
            if tid is None:
                # Out-of-vocabulary token: fall back to the unknown marker.
                tid = token2id[META_UNKNOWN]
            id_tokens.append(tid)

    n_keep = min(seq_len, len(id_tokens))
    # Use an explicit start index: id_tokens[-0:] would be the WHOLE list,
    # which breaks the assignment when seq_len == 0 and the input is non-empty.
    res[seq_len - n_keep:] = id_tokens[len(id_tokens) - n_keep:]

    return res

In [166]:
def get_next_word(inp, top=7):
    """Return the ``top`` most probable next tokens after ``inp``.

    The last ``IN_SEQ_LENGTH`` token ids of ``inp`` (see
    ``get_last_token_ids``) are fed to ``model`` as a batch of one. The
    resulting output row is ranked by value.

    Parameters
    ----------
    inp : str
        Context text; may be empty (the context is then all padding).
    top : int, optional
        Number of candidates to return. The default of 7 is the count that
        was previously hard-coded.

    Returns
    -------
    list of (probability, token) tuples
        Sorted by descending probability. Ties keep the lower token id first,
        matching the former repeated-argmax selection.
    """
    tids = get_last_token_ids(inp, IN_SEQ_LENGTH)
    probs = model.predict(np.array(tids)[np.newaxis, :])[0]

    # A stable sort of -probs orders equal probabilities by ascending id.
    # Unlike the old version, it does not zero out entries of the output row.
    best_ids = np.argsort(-probs, kind='stable')[:top]

    # Dictionary.__getitem__ builds the id -> token map on demand, so this
    # does not depend on id2token having been filled beforehand.
    return [(probs[i], corpora[int(i)]) for i in best_ids]

In [167]:
# Shared accumulator: filled by get_next_words_requesive, reset/sorted/read by get_next_words.
next_words_beam = []

def get_next_words_requesive(inp, seq_len, prob=1.0, results=None):
    """Exhaustively extend ``inp`` by ``seq_len`` tokens, collecting every path.

    The misspelled name is kept so existing callers keep working.

    Every candidate returned by ``get_next_word`` is expanded recursively; this
    is not a pruned beam search. With the 7 candidates ``get_next_word``
    returns, the tree has ``7**seq_len`` leaves and one ``model.predict``
    call per internal node. Each leaf is appended to ``results`` as
    ``(joint_probability, text)``, where the probability is ``prob`` times
    the product of the per-step probabilities.

    Parameters
    ----------
    inp : str
        Text to extend; a space is inserted before each new token unless
        ``inp`` is empty.
    seq_len : int
        Number of tokens still to generate. Values <= 0 add nothing.
    prob : float, optional
        Probability of the path so far (1.0 at the root).
    results : list, optional
        Accumulator for the leaves. Defaults to the global ``next_words_beam``.

    Returns
    -------
    None
        Results are delivered by appending to ``results``.
    """
    if results is None:
        results = next_words_beam
    if seq_len <= 0:
        return

    for word_prob, word in get_next_word(inp):
        text = inp + ' ' + word if inp else word
        joint = prob * word_prob
        if seq_len > 1:
            get_next_words_requesive(text, seq_len - 1, joint, results)
        else:
            results.append((joint, text))

In [168]:
def get_next_words(inp, words_to_predict, top=10):
    """Rank all ``words_to_predict``-token continuations of ``inp``.

    Empties the shared ``next_words_beam`` list, fills it via
    ``get_next_words_requesive``, sorts it in place by descending joint
    probability (a stable sort, so equal scores keep insertion order), and
    returns a new list holding its first ``top`` ``(probability, text)`` entries.
    """
    beam = next_words_beam
    del beam[:]

    get_next_words_requesive(inp, words_to_predict)

    beam.sort(reverse=True, key=lambda entry: entry[0])
    return beam[0:top]

### Test cases

In [169]:
# Top-7 next tokens for an empty prompt: the whole context is META_EMPTY padding.
get_next_word('')

[(0.13940275, 'i'),
 (0.071074374, 'the'),
 (0.046676371, 'it'),
 (0.031658143, 'and'),
 (0.024488971, 'this'),
 (0.022508644, 'but'),
 (0.021897659, 'we')]

In [181]:
# One-token continuations: only 7 candidates exist, so fewer than top=10 come back.
get_next_words('Hi', 1)

[(0.12810064852237701, 'Hi everyone'),
 (0.12010793387889862, 'Hi !'),
 (0.094802573323249817, 'Hi ,'),
 (0.054005932062864304, 'Hi it'),
 (0.052542150020599365, 'Hi the'),
 (0.036244608461856842, 'Hi you'),
 (0.023914987221360207, 'Hi for')]

In [171]:
# Two-token continuations of an empty prompt: 7*7 = 49 paths ranked, top 10 shown.
get_next_words('', 2)

[(0.012725758562348588, "i '"),
 (0.0093280714576638446, "it '"),
 (0.0092455425771342803, 'i have'),
 (0.0081878243198056921, 'this is'),
 (0.0081202295937043978, 'i ’'),
 (0.007741433969887801, 'i am'),
 (0.006880082826056122, 'it ’'),
 (0.0068554046504845645, 'it is'),
 (0.0064491832820475214, 'it was'),
 (0.0056063334160483258, 'i don')]

In [175]:
# Five-token continuations: exhaustive 7-way expansion, i.e. 1+7+49+343+2401 = 2801
# model.predict calls and 7**5 = 16807 ranked paths -- noticeably slow.
get_next_words('What do', 5)

[(0.00043809224046883355, 'What do you know that ’ s'),
 (0.00032913239100016757, "What do you know that ' s"),
 (0.00019197904706720826, 'What do you think that ’ s'),
 (0.00018286607957294815, 'What do you want to know that'),
 (0.00014402114180568702, 'What do you know of course ,'),
 (0.00013085715655169867, 'What do you think it ’ s'),
 (0.00011256596185976002, 'What do i know that ’ s'),
 (0.00010323181508792545, "What do you think that ' s"),
 (8.1762451390438693e-05, "What do you think it ' s"),
 (7.6961083522705296e-05, 'What do you have to seen the')]