<a href="https://colab.research.google.com/github/Arsuh/Seq2Seq-Chatbot/blob/master/Seq2SeqV2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive so the persistent folder under
# '/content/drive/My Drive/Colab Files/Chatbot/' (used for checkpoints) is
# reachable. force_remount=True remounts even if Drive is already mounted.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

In [0]:
# Fetch the project code (Vocabulary, MainModel, helper, evaluate modules) into
# ./Seq2Seq-Chatbot; the later `%cd Seq2Seq-Chatbot/` and `main_path` expect it there.
!git clone https://github.com/Arsuh/Seq2Seq-Chatbot

In [0]:
from __future__ import absolute_import, division, print_function
%tensorflow_version 2.x
import tensorflow as tf

from google.oauth2 import service_account
from google.cloud import bigquery

import matplotlib.pyplot as plt
import numpy as np
import json
import re
import os
import time
import random
import shutil

In [0]:
%cd Seq2Seq-Chatbot/
from Vocabulary import Vocabulary
from MainModel import loss_fnc
from helper import *
from evaluate import evaluate

# Persistent storage on the mounted Google Drive vs. the (ephemeral) cloned repo.
drive_main_path = '/content/drive/My Drive/Colab Files/Chatbot/'
main_path = '/content/Seq2Seq-Chatbot/'

# Hyper-parameter file; point this at 'hyper_parameters_test.json' for a quick test run.
hparams_path = os.path.join(main_path, 'hyper_parameters_std.json')

# NOTE: 'ckeckpoints' (sic) is kept verbatim -- it presumably names an existing
# Drive folder that already holds saved checkpoints, so "fixing" the spelling
# here would stop them from being found.
ckpt_path = os.path.join(drive_main_path, 'ckeckpoints/')
ckpt_prefix = os.path.join(ckpt_path, 'ckpt')

In [0]:
# Sample prompts; two are drawn at random after every epoch and decoded so
# training progress can be eyeballed. Words are spelled correctly because a
# misspelling (previously "Tomorow") is likely out-of-vocabulary for the model.
test_sentences = ['Hello!',
                  'How are you?',
                  'Tomorrow is my birthday, but I keep feeling sad...',
                  'What is your name sir?',
                  'Artificial intelligence will take over the world some day!',
                  'Can you please bring me some water?',
                  'Come on! This is the easiest thing you are supposed to do!',
                  'My name is Thomas!']

def train_step(hparams, inp, tar, enc_h1, enc_h2):
    """Run one teacher-forced optimisation step on a single batch.

    Encodes ``inp``, then decodes ``tar`` one position at a time starting at
    index 1, always feeding the ground-truth previous token (teacher forcing).
    The summed per-step loss is differentiated w.r.t. the trainable variables
    of the global ``enc`` and ``dec`` and applied with the global ``opt``.

    Args:
        hparams: hyper-parameter dict; ``BATCH_SIZE`` sizes the first decoder input.
        inp: encoder input batch.
        tar: target batch, indexed as ``tar[:, t]`` (presumably (batch, time)
            -- confirm against ``create_dataset``).
        enc_h1, enc_h2: initial encoder hidden states.

    Returns:
        The summed loss divided by ``tar.shape[1]`` -- used for reporting only.
    """
    global enc, dec, opt
    seq_len = tar.shape[1]
    total = 0
    with tf.GradientTape() as tape:
        enc_out, hidden1, hidden2 = enc(inp, enc_h1, enc_h2)
        # The decoder starts from the encoder's final states and token id 1 for
        # every row (presumably the <start> token -- confirm against Vocabulary).
        dec_inp = tf.expand_dims([1] * hparams['BATCH_SIZE'], 1)

        for t in range(1, seq_len):
            pred, hidden1, hidden2, _ = dec(dec_inp, enc_out, hidden1, hidden2)
            total += loss_fnc(tar[:, t], pred)
            # Teacher forcing: feed the ground-truth token, not the prediction.
            dec_inp = tf.expand_dims(tar[:, t], 1)

    # Variables are collected after the forward pass so lazily built layers exist.
    params = enc.trainable_variables + dec.trainable_variables
    opt.apply_gradients(zip(tape.gradient(total, params), params))

    # NOTE(review): seq_len - 1 loss terms are summed but the divisor is seq_len,
    # so the reported value slightly underestimates the per-step mean. Kept as-is
    # so logged losses stay comparable with earlier runs.
    return total / int(seq_len)

def train(hparams, credentials, print_step, offset=0, initial_epoch=1, saving=True, checkpoint_prefix=ckpt_prefix, verbose=True):
    """Train the global encoder/decoder on one BigQuery slice of examples.

    Loads the vocabulary and a single block of indexed examples starting at
    ``offset``, then runs epochs ``initial_epoch`` .. ``hparams['EPOCHS']``
    over it. After every epoch two random ``test_sentences`` are decoded and
    printed and, when ``saving`` is true, a checkpoint of ``opt``/``enc``/``dec``
    is written.

    Args:
        hparams: hyper-parameter dict (reads MAX_LEN, VOCAB_DB, VOCAB,
            NUM_EXAMPLES, MAX_EXAMPLES, BATCH_SIZE, EPOCHS).
        credentials: Google service-account credentials for BigQuery.
        print_step: print a progress line every ``print_step`` batches.
        offset: row offset of the first example to load.
        initial_epoch: epoch number to start counting from (for resuming).
        saving: whether to write a checkpoint after each epoch.
        checkpoint_prefix: file prefix passed to ``Checkpoint.save``.
        verbose: forwarded to the vocabulary/input loaders; also controls the
            setup messages.

    Returns:
        List with the mean batch loss of every epoch that was run.
    """
    global enc, dec, opt
    start = time.time()
    v = Vocabulary(max_len=hparams['MAX_LEN'])
    v.load_bigquery_vocab_from_indexed(credentials, hparams['VOCAB_DB'], hparams['VOCAB'], verbose)
    # Honour the caller's ``verbose`` (this used to be hard-coded to True).
    v.create_inputs_from_indexed(credentials,
                                 offset=offset,
                                 limit_main=hparams['NUM_EXAMPLES'],
                                 verbose=verbose)

    if verbose: print('Vocabulary created!')
    dataset = create_dataset(v, hparams['BATCH_SIZE'], hparams['NUM_EXAMPLES'])
    if verbose:
        elapsed = time.time() - start
        print('Time to initialize model {:.2f} min | {:.2f} hrs\n'.format(elapsed / 60, elapsed / 3600))

    # When NUM_EXAMPLES is None, MAX_EXAMPLES is used to size the epoch.
    if hparams['NUM_EXAMPLES'] is None:
        N_BATCH = hparams['MAX_EXAMPLES'] // hparams['BATCH_SIZE']
    else:
        N_BATCH = hparams['NUM_EXAMPLES'] // hparams['BATCH_SIZE']

    if saving: checkpoint = tf.train.Checkpoint(optimizer=opt, encoder=enc, decoder=dec)

    plt_loss = []
    for epoch in range(initial_epoch, hparams['EPOCHS'] + 1):
        epoch_time = time.time()
        # train_step does not return updated states, so every batch in the
        # epoch starts from these same initial hidden states.
        h1, h2 = enc.initialize_hidden()

        total_loss = 0
        n_batches = 0
        for batch, (inp, tar) in enumerate(dataset.take(N_BATCH)):
            batch_time = time.time()
            batch_loss = train_step(hparams, inp, tar, h1, h2)
            total_loss += batch_loss
            n_batches += 1

            if batch % print_step == 0:
                print('  >>> Epoch: {} | Batch: {}\\{} | Loss: {:.4f} | Time: {:.2f} sec'
                      .format(epoch, batch + 1, N_BATCH, batch_loss, time.time() - batch_time))

        sentences = random.choices(test_sentences, k=2)
        result1, text1, _ = evaluate(sentences[0], v, enc, dec, hparams['MAX_LEN'])
        result2, text2, _ = evaluate(sentences[1], v, enc, dec, hparams['MAX_LEN'])
        print(50*'+')
        print(text1)
        print(result1)
        print(text2)
        print(result2)
        print(50*'+')

        if saving:
            print('Saving model...')
            checkpoint.save(file_prefix=checkpoint_prefix)

        # Average over the batches actually seen: the dataset may yield fewer
        # than N_BATCH, and N_BATCH is 0 when NUM_EXAMPLES < BATCH_SIZE.
        epoch_loss = total_loss / max(n_batches, 1)
        plt_loss.append(epoch_loss)
        print('Epoch: {} | Loss: {:.4f} | Time: {:.2f} min'.format(epoch, epoch_loss, (time.time() - epoch_time) / 60))
    return plt_loss

def multi_initializer_train(hparams, credentials, print_step, initial_epoch=1, saving=True, checkpoint_prefix=ckpt_prefix, verbose=True):
    """Train the global encoder/decoder over several consecutive data slices.

    Each epoch walks ``reps`` slices of ``NUM_EXAMPLES`` rows (offsets 0,
    NUM_EXAMPLES, 2*NUM_EXAMPLES, ...), reloading the inputs for every slice
    via ``reinitialize_vocab``. After every epoch two random
    ``test_sentences`` are decoded. When ``saving`` is true, a checkpoint is
    written and the epoch's batch losses are passed to ``save_plot``.

    Args:
        hparams: hyper-parameter dict (reads MAX_LEN, VOCAB_DB, VOCAB,
            NUM_EXAMPLES, MAX_EXAMPLES, BATCH_SIZE, EPOCHS, OFFSET_REP).
            OFFSET_REP is either 'max' (MAX_EXAMPLES // NUM_EXAMPLES slices)
            or an int-convertible slice count.
        credentials: Google service-account credentials for BigQuery.
        print_step: print a progress line every ``print_step`` batches.
        initial_epoch: epoch number to start counting from (for resuming).
        saving: whether to checkpoint and save the loss plot after each epoch.
        checkpoint_prefix: file prefix passed to ``Checkpoint.save``.
        verbose: forwarded to the vocabulary loaders.

    Returns:
        List with the mean batch loss (over all slices) of every epoch run.

    Raises:
        ValueError: if NUM_EXAMPLES is None (the slice size is required to
            compute offsets and the number of slices).
    """
    global enc, dec, opt
    # Validate before the (slow) BigQuery vocabulary load; a None slice size
    # previously crashed later with a TypeError on ``//`` or ``offset +=``.
    if hparams['NUM_EXAMPLES'] is None:
        raise ValueError('multi-initializer training requires NUM_EXAMPLES '
                         '(the size of each data slice) to be set')
    # Loop-invariant, so computed once instead of once per epoch.
    if hparams['OFFSET_REP'] == 'max':
        reps = int(hparams['MAX_EXAMPLES'] // hparams['NUM_EXAMPLES'])
    else:
        reps = int(hparams['OFFSET_REP'])

    v = Vocabulary(max_len=hparams['MAX_LEN'])
    v.load_bigquery_vocab_from_indexed(credentials, hparams['VOCAB_DB'], hparams['VOCAB'], verbose)
    if verbose: print('Vocabulary created!')

    N_BATCH = hparams['NUM_EXAMPLES'] // hparams['BATCH_SIZE']

    if saving: checkpoint = tf.train.Checkpoint(optimizer=opt, encoder=enc, decoder=dec)

    plt_loss = []
    for epoch in range(initial_epoch, hparams['EPOCHS'] + 1):
        epoch_time = time.time()
        ep_losses = []
        h1, h2 = enc.initialize_hidden()

        offset = 0
        total_loss = 0
        for rep in range(reps):
            v, dataset = reinitialize_vocab(v, hparams, credentials, offset, verbose=verbose)

            for batch, (inp, tar) in enumerate(dataset.take(N_BATCH)):
                batch_time = time.time()
                batch_loss = train_step(hparams, inp, tar, h1, h2)
                ep_losses.append(batch_loss)
                total_loss += batch_loss

                if batch % print_step == 0:
                    print('  >>> Epoch: {} | Batch: {}\\{} | Loss: {:.4f} | Time: {:.2f} sec'
                          .format(epoch, batch + 1, N_BATCH, batch_loss, time.time() - batch_time))

            offset += hparams['NUM_EXAMPLES']
            # NOTE(review): clear_session resets Keras' global state between
            # slices while enc/dec are still in use -- confirm this is intended.
            tf.keras.backend.clear_session()
            print(' -> Rep: {} done!'.format(rep + 1))

        sentences = random.choices(test_sentences, k=2)
        result1, text1, _ = evaluate(sentences[0], v, enc, dec, hparams['MAX_LEN'])
        result2, text2, _ = evaluate(sentences[1], v, enc, dec, hparams['MAX_LEN'])
        print(50*'+')
        print(text1)
        print(result1)
        print(text2)
        print(result2)
        print(50*'+')

        if saving:
            print('Saving model...')
            checkpoint.save(file_prefix=checkpoint_prefix)
            # NOTE(review): writes next to the global ckpt_path, not next to
            # checkpoint_prefix -- confirm that is intended.
            save_plot(ckpt_path, ep_losses)

        # One mean over every batch of every slice, used for both the plot and
        # the log line. Previously the plotted value was divided by N_BATCH only,
        # i.e. it was ``reps`` times the printed loss.
        epoch_loss = total_loss / max(len(ep_losses), 1)
        plt_loss.append(epoch_loss)
        print('Epoch: {} | Loss: {:.4f} | Time: {:.2f} min'.format(epoch, epoch_loss, (time.time() - epoch_time) / 60))
    return plt_loss

In [0]:
# Load hyper-parameters, build BigQuery service-account credentials from the
# path stored in the hyper-parameter file, and create the encoder, decoder and
# optimizer. enc/dec/opt are module globals that the training functions read
# through their `global enc, dec, opt` declarations.
hparams = load_hyper_params(hparams_path)
credentials = service_account.Credentials.from_service_account_file(hparams['CREDENTIALS_PATH'])
enc, dec, opt = create_model(hparams)

In [0]:
#'''
# Sanity check: restore the latest checkpoint into enc/dec/opt and decode one
# sentence. The surrounding `#'''` lines act as a toggle -- drop the leading
# `#` on both to turn this cell into an inert string literal.
# Because the same enc/dec/opt objects are restored, running this cell is also
# what lets the training cell resume from saved weights.
v = Vocabulary(max_len=hparams['MAX_LEN'])
v.load_bigquery_vocab_from_indexed(credentials, hparams['VOCAB_DB'], hparams['VOCAB'], True)

checkpoint = tf.train.Checkpoint(optimizer=opt, encoder=enc, decoder=dec)
# If the folder holds no checkpoint, latest_checkpoint returns None and the
# model keeps its freshly initialised weights.
checkpoint.restore(tf.train.latest_checkpoint(ckpt_path))
#checkpoint.restore(drive_main_path + 'checkpoints-final-1/' + 'ckpt-2')

result, _, _ = evaluate(u'The world is changing once again!', v, enc, dec, hparams['MAX_LEN'])
print(result)
#'''

In [0]:
# Remove the project's __pycache__ (presumably so edited project modules are
# not served from stale bytecode).
if os.path.isdir('./__pycache__/'): shutil.rmtree(path='./__pycache__/', ignore_errors=True, onerror=None)

# Epoch number to start/resume from, and how often (in batches) to log progress.
# NOTE(review): resuming at epoch 4 only continues real training if a checkpoint
# has been restored into enc/dec/opt first (e.g. by the evaluation cell).
initial_epoch = 4
print_step = 500

mode = hparams['TRAINING_MODE']
if mode == 'single':
    plt_loss = train(hparams, credentials, print_step, initial_epoch=initial_epoch, saving=True)
elif mode == 'multi':
    plt_loss = multi_initializer_train(hparams, credentials, print_step, initial_epoch=initial_epoch, saving=True)
else:
    # A bad config value is a ValueError, not a generic Exception.
    raise ValueError('Please enter a valid TRAINING_MODE: \'single\' or \'multi\' '
                     '(for \'multi\' please use OFFSET_REP >= 2 or \'max\')')

# Plot against the real epoch numbers; plotting the bare list labelled the
# x-axis "epoch" but started it at 0 instead of initial_epoch.
epochs = range(initial_epoch, initial_epoch + len(plt_loss))
plt.plot(epochs, plt_loss)
plt.ylabel('loss')
plt.xlabel('epoch')
plt.show()