In [1]:
import json
import re
import string

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.preprocessing.sequence import pad_sequences

In [2]:
# Load the vocabulary lookups and the contraction table from disk.
def _load_json(path):
    '''Read the JSON file at `path` and return the decoded object.

    The file is opened explicitly as UTF-8 so that decoding does not depend
    on the platform's default locale encoding.
    '''
    with open(path, encoding='utf-8') as json_file:
        return json.load(json_file)


dict_w2i = _load_json('dict_w2i_chatbot.json')   # indexed by word further down
dict_i2w = _load_json('dict_i2w_chatbot.json')   # indexed by str(id) further down
contractions_dict = _load_json('contractions_dict.json')   # contracted form -> replacement text

In [3]:
# Settings shared by the model-construction and inference cells.
lat_dim = 512                    # latent size handed to both model constructors
vocab_len = len(dict_w2i) - 3    # NOTE(review): the 3 presumably accounts for special tokens — confirm
batch_size = 64                  # handed to the decoder constructor
maximum_iterations = 10          # upper bound on decoding steps per reply
buck_t1 = 6                      # length the input id sequence is padded to

In [4]:
# Build the encoder and decoder from the project-local model definitions.
import chatbot_models_class as cb

encoder_ = cb.encoder_model(vocab_len=vocab_len, lat_dim=lat_dim)
decoder_ = cb.decoder_model(
    lat_dim=lat_dim,
    batch_size=batch_size,
    vocab_len=vocab_len,
    buck_t1=buck_t1,
)



In [5]:
# Restore the saved encoder / decoder weights.
encoder_weights_path = "models_Att/encoder_att_weights_inv_input"
decoder_weights_path = "models_Att/decoder_att_weights_inv_input"

encoder_.load_weights(encoder_weights_path)
# Kept as the final expression so the cell still displays the load status.
decoder_.load_weights(decoder_weights_path)


<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f85efdc8e80>

In [6]:
#preprocessing input string
def exp_remPunt(sent, contractions=None):
    '''
    Clean an input sentence before it is converted to word ids.

    Steps, in order:
    - set to lower case
    - expand contracted forms, one whitespace-separated token at a time
    - remove digits
    - remove punctuation and lower-case again (expansions may contain capitals)

    Parameters
    ----------
    sent : str
        Raw sentence.
    contractions : dict or None
        Mapping from a lower-case contracted form to its expansion. When
        None (the default) the module-level contractions_dict is used.

    Returns
    -------
    list of str
        A one-element list holding the cleaned sentence. Whitespace between
        tokens is left exactly as it was in the input.
    '''
    if contractions is None:
        contractions = contractions_dict

    punct_table = str.maketrans('', '', string.punctuation)
    digit_table = str.maketrans('', '', string.digits)

    def _expand(match):
        token = match.group(0)
        if token in contractions:
            return contractions[token]
        # Retry without surrounding punctuation so that "don't!" still
        # expands; that punctuation is stripped further down anyway.
        return contractions.get(token.strip(string.punctuation), token)

    # Substitute whole tokens only: a plain str.replace would also rewrite
    # the contraction wherever it occurs inside a longer word.
    expanded = re.sub(r'\S+', _expand, sent.lower())

    cleaned = expanded.translate(digit_table).translate(punct_table).lower()
    return [cleaned]

In [7]:
def string_to_seq(l_sent):
    '''
    Turn cleaned sentences into one padded sequence of word ids.

    l_sent is a list of strings. Each whitespace-separated word that is a
    key of dict_w2i contributes its id, in order; words missing from
    dict_w2i are skipped. The ids from all strings are gathered into a
    single sequence, which is passed to pad_sequences with buck_t1 as the
    target length and padding='post'. The result of pad_sequences is
    returned unchanged.
    '''
    word_ids = [
        dict_w2i[token]
        for line in l_sent
        for token in line.split()
        if token in dict_w2i
    ]

    padded = pad_sequences([word_ids], buck_t1, padding='post')
    # The sequence is returned in reading order; the np.flip reversal that
    # used to sit here was already disabled.
    return padded

In [8]:
def chat_loop(sentence): 
    '''
    Generate and print the bot's reply to one user sentence.

    The sentence is cleaned with exp_remPunt, turned into a padded id
    sequence with string_to_seq, passed through the encoder, and then
    decoded one step at a time until '<eos>' is sampled or
    maximum_iterations steps have run. The reply is printed as
    "BOT: <words>"; the function returns nothing.

    Relies on module-level names: encoder_, decoder_, dict_w2i, dict_i2w,
    lat_dim and maximum_iterations.
    '''
    
    # Text -> padded id sequence -> tensor.
    sentence=exp_remPunt(sentence)
    sentence=string_to_seq(sentence)
    sent_tens= tf.convert_to_tensor(sentence) 
    # Encode the input; the encoder starts from two all-zero states of
    # shape (1, lat_dim*2).
    encoder_embedding=encoder_.encoder_embedding_layer(sent_tens)
    encoder_output,enc_state_h,enc_state_c =encoder_.encoder(encoder_embedding,
                                                            initial_state=[tf.zeros((1, lat_dim*2)), tf.zeros((1, lat_dim*2))])

    # NOTE(review): decoder_embedding is not read again in this function.
    # The call presumably exists so the embedding layer is built before its
    # variables are accessed below — confirm before removing it.
    dec_infere_inp = tf.expand_dims([ dict_w2i['<sos>']],1)
    decoder_embedding=decoder_.decoder_embedding_layer(dec_infere_inp)#!!




    # Sampler used at inference time; it is given the embedding matrix later,
    # in initialize().
    inference_sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler()#(self.decoder_embedding_layer)

    # Inference decoder: reuses the decoder's attention cell and dense output
    # layer together with the sampler above.
    inference_decoder=tfa.seq2seq.BasicDecoder(cell=decoder_.attention_decoder,
                                               sampler=inference_sampler,
                                               output_layer=decoder_.dense_layer,
                                               maximum_iterations=maximum_iterations
                                              )


    # Give the attention mechanism the encoder outputs for this sentence.
    decoder_.attention_mechanism.setup_memory(encoder_output)


    # One '<sos>' id: a batch of a single sequence.
    start_tokens=tf.fill([1], dict_w2i['<sos>'])

    # First variable of the embedding layer (assumed to be the embedding
    # matrix — TODO confirm against chatbot_models_class).
    decoder_embedding_matrix = decoder_.decoder_embedding_layer.variables[0] 
    # Produce the first decoder inputs/state from the encoder's final states.
    _,inputs_t0,states_t0=inference_decoder.initialize(decoder_embedding_matrix,
                                       initial_state=decoder_.get_initial_states(batch_size=1,
                                                                                 enc_state_h=enc_state_h,
                                                                                 enc_state_c=enc_state_c),
                                       start_tokens=start_tokens,
                                       end_token=dict_w2i['<eos>'],
                                       )

    # Values carried from one decoding step to the next.
    input_iterator=inputs_t0
    state_iterator=states_t0  
    # Starts as an empty (1, 0) int32 array; np.append (no axis) flattens it,
    # so it is 1-D after the first step.
    predictions=np.empty((1,0),dtype=np.int32) 

    # Decode step by step, feeding each step's next inputs/state back in.
    stop_condition=False
    iteration_counter=0
    while not stop_condition:

        outputs,states_tn, inputs_tn, finished = inference_decoder.step(iteration_counter,input_iterator,state_iterator)
        iteration_counter+=1

        input_iterator = inputs_tn
        state_iterator = states_tn
        predictions = np.append(predictions, outputs.sample_id)

        # Hard cap on the reply length.
        if iteration_counter>=maximum_iterations:
            stop_condition=True

        # '<eos>' ends the reply and is dropped from the collected ids.
        # The truth test assumes sample_id holds a single element (batch of
        # one, as set up by start_tokens above).
        if outputs.sample_id==dict_w2i['<eos>']:
            stop_condition=True
            predictions=predictions[:-1]
            
    # Map ids back to words; dict_i2w is indexed with the id as a string.
    bot_s="BOT:"
    for i in predictions:
        bot_s=bot_s+" " +dict_i2w[str(i)]
    print(bot_s)

In [9]:
# Interactive console: answer each line typed by the user until the exit
# keyword is entered.
print("Type \"quitbot\" to exit. Let\'s start! \n \n")

stop_condition = False
while not stop_condition:
    sentence = input("YOU: ")
    # The flag is derived straight from the comparison with the exit keyword.
    stop_condition = (sentence == 'quitbot')
    if stop_condition:
        print("BOT: good bye")
    else:
        chat_loop(sentence)

Type "quitbot" to exit. Let's start! 
 

YOU: hello there
BOT: hello
YOU: how are you?
BOT: i am fine
YOU: tell me something
BOT: i love you
YOU: good night
BOT: good night
YOU: quitbot
BOT: good bye
