In [1]:
import pickle
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
# Load the word -> id vocabulary saved at training time.
# NOTE: pickle.load can execute arbitrary code while unpickling; only point
# this at a file you produced yourself or otherwise trust.
with open('/mnt/d/gru_without_lda/word_index','rb') as infile:  # closed even if unpickling fails
  word_index = pickle.load(infile)

# Reverse lookup (id -> word) used to turn predicted token ids back into text.
index_to_word = {index: word for word, index in word_index.items()}

# Trained Keras model queried by the decoding cell below.
model = load_model('/mnt/d/gru_without_lda/model.h5')

In [3]:
def decode_sequence_beam_search(input_seq,beam_width=3,MAX_SEQ_LENGTH=20,vocab_size=20000):
  """Decode a reply for ``input_seq`` using length-normalised beam search.

  Relies on the notebook-level ``model``, ``word_index`` and ``index_to_word``.
  Every beam starts from the "awwal" token and is frozen (kept with its score,
  never extended again) once it has emitted the "aakhir" token.

  Side effect: sets ``index_to_word[0] = ""`` so that the padding id 0 decodes
  to an empty string.

  Args:
    input_seq: first model input, passed unchanged to ``model.predict`` next to
      the decoder sequence (presumably the padded query ids of shape
      ``(1, MAX_SEQ_LENGTH)`` -- confirm against the model definition).
    beam_width: number of hypotheses kept after every time step.
    MAX_SEQ_LENGTH: length of each decoder sequence, start token included.
    vocab_size: number of entries in the per-step distribution returned by the
      model; the prediction is reshaped to exactly this size.

  Returns:
    ``(target_seq, scores)``: the kept token-id sequences (each of length
    ``MAX_SEQ_LENGTH``, zero padded) and their length-normalised
    log-probability scores, ordered from highest to lowest score.
  """
  index_to_word[0] = ""
  target_seq = np.zeros((beam_width,MAX_SEQ_LENGTH))  # one row of token ids per beam
  target_seq[:,0] = word_index["awwal"]  # every beam starts with the start token
  scores = np.zeros((beam_width,))  # length-normalised score of each beam
  is_complete = [False]*beam_width  # True once a beam has emitted "aakhir"
  fed_seq = [[input_seq,target_seq[beam:beam+1]] for beam in range(beam_width)]  # model inputs per beam

  for step in range(MAX_SEQ_LENGTH-1):
    # Once every beam is finished nothing can change any more, so stop instead
    # of rebuilding and re-sorting identical candidates for the remaining steps.
    if is_complete and all(is_complete):
      break

    all_candidates = []  # up to beam_width*beam_width extended sequences for this step
    all_scores = []  # score of each candidate, aligned with all_candidates
    all_bool = []  # whether each candidate is finished, aligned with all_candidates
    for beam in range(beam_width):
      if is_complete[beam]:
        # A finished beam competes unchanged, keeping the score it ended with.
        all_candidates.append(np.array(target_seq[beam]))
        all_scores.append(scores[beam])
        all_bool.append(is_complete[beam])
        continue

      # NOTE(review): indexing the prediction with the time step assumes the
      # model returns one distribution per decoder step along the first axis --
      # confirm against the model definition.
      output = model.predict(fed_seq[beam])[step]

      # log(0) legitimately yields -inf for impossible tokens; silence the
      # divide-by-zero RuntimeWarning it would otherwise emit on every step.
      with np.errstate(divide="ignore"):
        log_probs = np.log(output).reshape(vocab_size,)

      # Scores are stored as (sum of log-probabilities) / length**0.7: undo the
      # normalisation for the current length, add the new token's
      # log-probability, then normalise for the extended length.
      tmp_scores = (scores[beam]*((step+1)**0.7)+log_probs)/((step+2)**0.7)
      top_indexes = np.argsort(-1*tmp_scores)[:beam_width]  # best beam_width next tokens

      for index in top_indexes:
        new_candidate = np.array(target_seq[beam])
        new_candidate[step+1] = index
        all_candidates.append(new_candidate)
        all_scores.append(tmp_scores[index])
        all_bool.append(index_to_word[index]=="aakhir")

    if step==0:
      # All beams are identical after the start token, so the expansions of the
      # first beam already are the overall top beam_width candidates.
      kept = np.arange(0,beam_width)
    else:
      kept = np.argsort(-1*np.array(all_scores))[:beam_width]  # best candidates across all beams

    target_seq = [all_candidates[k] for k in kept]
    scores = [all_scores[k] for k in kept]
    is_complete = [all_bool[k] for k in kept]

    fed_seq = [[input_seq,np.array(t).reshape(1,MAX_SEQ_LENGTH)] for t in target_seq]

  return target_seq,scores

In [4]:

# The tokenizer is never fitted here: it only carries the vocabulary that was
# saved at training time so queries are encoded with the same word ids.
tokenizer = Tokenizer(num_words=20000,oov_token="<oov>")
tokenizer.word_index = word_index

# Interactive chat: read a line, encode and pad it, then print every beam
# together with its score. Typing "exit" ends the session.
while True:
  query = input("HUMAN: ")
  if query=='exit':
    break

  query = tokenizer.texts_to_sequences([query])[0]
  query = pad_sequences([query],maxlen=20,padding='post')
  responses,scores = decode_sequence_beam_search(query,beam_width=5)

  print("BOT:\n")
  for i,r in enumerate(responses):
    # Padding id 0 decodes to "" (set up by decode_sequence_beam_search).
    s = ' '.join(index_to_word[t] for t in r)
    print(s,scores[i])



HUMAN: hi!
BOT:

awwal you got a lot more than than than than aakhir          -1.9483557
awwal you got a lot more than than than than than aakhir         -1.983195
awwal you got a lot more than than than than than than than aakhir       -2.0171983
awwal you got a lot more than than than than than than than than aakhir      -2.0424464
awwal you got a lot more than than than than than than aakhir        -2.056576
HUMAN: hello
BOT:

awwal you know that you shouldn't be a real aakhir           -2.3113902
awwal you know that is you you shouldn't be a real aakhir         -2.4650555
awwal you know that shouldn't be a real but you aakhir          -2.5453212
awwal you know that is you you shouldn't be a real but i can do aakhir     -2.821666
awwal you know that is you you shouldn't be a real but all the way you could be your way -3.0350852
HUMAN: hey,how are you
BOT:

awwal all of the situation you know this is you know aakhir         -2.0356996
awwal if you take this back to your account aakhi