In [1]:
import re
import random
import time

print('Library versions:')

import keras
print(f'keras:{keras.__version__}')
import pandas as pd
print(f'pandas:{pd.__version__}')
import sklearn
print(f'sklearn:{sklearn.__version__}')
import nltk
print(f'nltk:{nltk.__version__}')
import numpy as np
print(f'numpy:{np.__version__}')

from sklearn.feature_extraction.text import CountVectorizer
from nltk.tokenize import casual_tokenize

Library versions:


Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


keras:2.3.1
pandas:1.3.4
sklearn:1.0.2
nltk:3.6.7
numpy:1.21.5


In [2]:
# from tqdm import tqdm.notebook as tqdm # Special jupyter notebook progress bar 💫
from tqdm.notebook import tqdm as tqdm

In [3]:
import os
os.getcwd()

'C:\\Users\\JiatingChen\\Documents\\nlp-coe\\twitter-conversational-chatbot\\EDA Notebooks'

## Model Parameters

In [4]:
# 8192 - large enough for demonstration, larger values make network training slower
MAX_VOCAB_SIZE = 2**13
# seq2seq generally relies on fixed length message vectors - longer messages provide more info
# but result in slower training and larger networks
MAX_MESSAGE_LEN = 30  
# Embedding size for words - gives a trade off between expressivity of words and network size
EMBEDDING_SIZE = 100
# Embedding size for whole messages, same trade off as word embeddings
CONTEXT_SIZE = 100
# Larger batch sizes generally reach the average response faster, but small batch sizes are
# required for the model to learn nuanced responses.  Also, GPU memory limits max batch size.
BATCH_SIZE = 4
# Helps regularize network and prevent overfitting.
DROPOUT = 0.2
# High learning rate helps model reach average response faster, but can make it hard to 
# converge on nuanced responses
LEARNING_RATE=0.005

# Tokens needed for seq2seq
UNK = 0  # words that aren't found in the vocab
PAD = 1  # after message has finished, this fills all remaining vector positions
START = 2  # provided to the model at position 0 for every response predicted

# Implementaiton detail for allowing this to be run in Kaggle's notebook hardware
SUB_BATCH_SIZE = 1000


## Data Prep
Here, we'll prepare the data for training our seq2seq model, including:

- Replace screen names with `@__sn__` token to show model the commonality between them
- Build a vocab to turn tokens into integers suitable for our seq2seq model
- Tokenize input and target text into fixed size vectors
- Partition our dataset into train and test sets

### Data Loading and Reshaping
Pulled from [this kernel](https://www.kaggle.com/soaxelbrooke/first-inbound-and-response-tweets).

In [5]:
""" A kernel posted on Kaggle that shows how to pull just the first consumer request and
    company response from the dataset.
"""

import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

tweets = pd.read_csv('../data/twcs/twcs.csv')


# Pick only inbound tweets that aren't in reply to anything...
first_inbound = tweets[pd.isnull(tweets.in_response_to_tweet_id) & tweets.inbound]
print('Found {} first inbound messages.'.format(len(first_inbound)))

# Merge in all tweets in response
inbounds_and_outbounds = pd.merge(first_inbound, tweets, left_on='tweet_id', 
                                  right_on='in_response_to_tweet_id')
print("Found {} responses.".format(len(inbounds_and_outbounds)))

# Filter out cases where reply tweet isn't from company
inbounds_and_outbounds = inbounds_and_outbounds[inbounds_and_outbounds.inbound_y ^ True]

# Et voila!
print("Found {} responses from companies.".format(len(inbounds_and_outbounds)))
print("Tweets Preview:")
print(inbounds_and_outbounds)

Found 787346 first inbound messages.
Found 875292 responses.
Found 794299 responses from companies.
Tweets Preview:
        tweet_id_x author_id_x  inbound_x                    created_at_x  \
0                8      115712       True  Tue Oct 31 21:45:10 +0000 2017   
1                8      115712       True  Tue Oct 31 21:45:10 +0000 2017   
2                8      115712       True  Tue Oct 31 21:45:10 +0000 2017   
3               18      115713       True  Tue Oct 31 19:56:01 +0000 2017   
4               20      115715       True  Tue Oct 31 22:03:34 +0000 2017   
...            ...         ...        ...                             ...   
875287     2987942      823867       True  Wed Nov 22 07:30:39 +0000 2017   
875288     2987944      823868       True  Wed Nov 22 07:43:36 +0000 2017   
875289     2987946      524544       True  Wed Nov 22 08:25:48 +0000 2017   
875290     2987948      823869       True  Wed Nov 22 08:35:16 +0000 2017   
875291     2987950      823870       

In [6]:
tweets.head() 
#inbound: message from others (response_tweet_id); outbound: message sent by the tweet_id (in_response_to_tweet_id)

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
0,1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,@115712 I understand. I would like to assist y...,2.0,3.0
1,2,115712,True,Tue Oct 31 22:11:45 +0000 2017,@sprintcare and how do you propose we do that,,1.0
2,3,115712,True,Tue Oct 31 22:08:27 +0000 2017,@sprintcare I have sent several private messag...,1.0,4.0
3,4,sprintcare,False,Tue Oct 31 21:54:49 +0000 2017,@115712 Please send us a Private Message so th...,3.0,5.0
4,5,115712,True,Tue Oct 31 21:49:35 +0000 2017,@sprintcare I did.,4.0,6.0


In [7]:
%%time
tweets = pd.read_csv('../data/twcs/twcs.csv')

first_inbound = tweets[pd.isnull(tweets.in_response_to_tweet_id) & tweets.inbound]

inbounds_and_outbounds = pd.merge(first_inbound, tweets, left_on='tweet_id', 
                                  right_on='in_response_to_tweet_id').sample(frac=1)

# Filter to only outbound replies (from companies)
inbounds_and_outbounds = inbounds_and_outbounds[inbounds_and_outbounds.inbound_y ^ True]

#tqdm().pandas()  # Enable tracking of progress in dataframe `apply` calls

Wall time: 11.3 s


In [8]:
tqdm().pandas()

0it [00:00, ?it/s]

In [9]:
print(f'Data shape: {inbounds_and_outbounds.shape}')

Data shape: (794299, 14)


### Tokenizing and Vocab Build

We'll use NLTK's `casual_tokenize`, which handles a lot of corner cases found in social media data ("casual" text data) along with scitkit learn's `CountVectorizer`.  We won't use the actual `CountVectorizer`, just use it as a convenient vocabulary builder, which we'll apply with functions that turn text into "word indexes" - integers that represent each word - and back.

In [10]:
inbounds_and_outbounds.head()

Unnamed: 0,tweet_id_x,author_id_x,inbound_x,created_at_x,text_x,response_tweet_id_x,in_response_to_tweet_id_x,tweet_id_y,author_id_y,inbound_y,created_at_y,text_y,response_tweet_id_y,in_response_to_tweet_id_y
514330,1834275,550194,True,Tue Oct 17 19:32:31 +0000 2017,"@british_airways - flight to Seattle, flight w...",18342741834276,,1834274,British_Airways,False,Wed Oct 18 05:46:31 +0000 2017,"@550194 If you want us to take a look, please ...",,1834275.0
86976,332623,195209,True,Sat Oct 07 14:11:32 +0000 2017,Wheres my achievement @115790 @XboxSupport htt...,332621,,332621,XboxSupport,False,Sat Oct 07 19:04:26 +0000 2017,@195209 Hi there! Could you follow the steps h...,332622.0,332623.0
160341,598547,261930,True,Sun Dec 03 13:30:52 +0000 2017,@Tesco this drink is falsely advertised. This ...,598545,,598545,Tesco,False,Sun Dec 03 20:22:29 +0000 2017,"@261930 Hi Laszlo, I'd be happy to look in to ...",598546.0,598547.0
224425,830319,317722,True,Wed Oct 11 00:33:35 +0000 2017,Still pissed about the #huluupdate @hulu_suppo...,830318,,830318,hulu_support,False,Fri Oct 13 03:41:39 +0000 2017,@317722 Rest assured we're making updates base...,,830319.0
860406,2941537,812757,True,Wed Nov 29 12:11:36 +0000 2017,@140046 Any idea where I can buy your #glutenf...,29415362941538,,2941538,sainsburys,False,Wed Nov 29 14:41:45 +0000 2017,"@812757 Hi Michelle, you can place a product r...",,2941537.0


In [11]:
inbounds_and_outbounds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 794299 entries, 514330 to 392195
Data columns (total 14 columns):
 #   Column                     Non-Null Count   Dtype  
---  ------                     --------------   -----  
 0   tweet_id_x                 794299 non-null  int64  
 1   author_id_x                794299 non-null  object 
 2   inbound_x                  794299 non-null  bool   
 3   created_at_x               794299 non-null  object 
 4   text_x                     794299 non-null  object 
 5   response_tweet_id_x        794299 non-null  object 
 6   in_response_to_tweet_id_x  0 non-null       float64
 7   tweet_id_y                 794299 non-null  int64  
 8   author_id_y                794299 non-null  object 
 9   inbound_y                  794299 non-null  bool   
 10  created_at_y               794299 non-null  object 
 11  text_y                     794299 non-null  object 
 12  response_tweet_id_y        263771 non-null  object 
 13  in_response_to_tweet_id_

In [12]:
inbounds_and_outbounds.author_id_y.value_counts(normalize = True).head(20)

AmazonHelp         0.106556
AppleSupport       0.093960
Uber_Support       0.050395
Delta              0.035862
SpotifyCares       0.033917
Tesco              0.031332
AmericanAir        0.030852
comcastcares       0.030015
SouthwestAir       0.026421
TMobileHelp        0.025261
British_Airways    0.024690
Ask_Spectrum       0.021742
VirginTrains       0.018075
UPSHelp            0.017994
hulu_support       0.017870
ChipotleTweets     0.017441
sprintcare         0.015925
XboxSupport        0.015748
AskPlayStation     0.014349
sainsburys         0.013550
Name: author_id_y, dtype: float64

In [13]:
inbounds_and_outbounds.author_id_y.value_counts(normalize = True).tail(20) #108 companies

askvisa            0.000721
ask_progressive    0.000658
GooglePlayMusic    0.000643
YahooCare          0.000640
USCellularCares    0.000628
asksalesforce      0.000578
MTNC_Care          0.000574
MOO                0.000524
KeyBank_Help       0.000482
AskSeagate         0.000473
AskVirginMoney     0.000466
OPPOCareIN         0.000428
AskRobinhood       0.000392
AskTigogh          0.000349
JackBox            0.000253
mediatemplehelp    0.000239
AskDSC             0.000238
CarlsJr            0.000174
HotelTonightCX     0.000165
OfficeSupport      0.000073
Name: author_id_y, dtype: float64

In [14]:
from tqdm.notebook import tqdm
tqdm.pandas()

In [15]:
# Replace anonymized screen names with common token @__sn__
def sn_replace(match):
    _sn = match.group(2).lower()
    if not _sn.isnumeric():
        # This is a company screen name
        return match.group(1) + match.group(2)
    return ' @__sn__'

sn_re = re.compile('(\W@|^@)([a-zA-Z0-9_]+)')
print("Replacing anonymized screen names in X...")
x_text = inbounds_and_outbounds.text_x.progress_apply(lambda txt: sn_re.sub(sn_replace, txt))
print("Replacing anonymized screen names in Y...")
y_text = inbounds_and_outbounds.text_y.progress_apply(lambda txt: sn_re.sub(sn_replace, txt))

Replacing anonymized screen names in X...


  0%|          | 0/794299 [00:00<?, ?it/s]

Replacing anonymized screen names in Y...


  0%|          | 0/794299 [00:00<?, ?it/s]

In [16]:
def sn_replace(match):
    _sn = match.group(2).lower()
    if not _sn.isnumeric():
        # This is a company screen name
        return match.group(1) + match.group(2)
    return ' @__sn__'
inbounds_and_outbounds.text_x[0:10]

514330    @british_airways - flight to Seattle, flight w...
86976     Wheres my achievement @115790 @XboxSupport htt...
160341    @Tesco this drink is falsely advertised. This ...
224425    Still pissed about the #huluupdate @hulu_suppo...
860406    @140046 Any idea where I can buy your #glutenf...
858257    @AmazonHelp Hi, is Amazon Household available ...
518640    @115858 please fix the issue with messages not...
508857                    Chase bank can suck my ass foreal
742272    And suddenly @115911 cuts out and I lose servi...
494175    Hey @Ask_Spectrum are you in contract disputes...
Name: text_x, dtype: object

In [17]:
inbounds_and_outbounds.text_x[0:10].apply(lambda txt: sn_re.sub(sn_replace, txt))

514330    @british_airways - flight to Seattle, flight w...
86976     Wheres my achievement @__sn__ @XboxSupport htt...
160341    @Tesco this drink is falsely advertised. This ...
224425    Still pissed about the #huluupdate @hulu_suppo...
860406     @__sn__ Any idea where I can buy your #gluten...
858257    @AmazonHelp Hi, is Amazon Household available ...
518640     @__sn__ please fix the issue with messages no...
508857                    Chase bank can suck my ass foreal
742272    And suddenly @__sn__ cuts out and I lose servi...
494175    Hey @Ask_Spectrum are you in contract disputes...
Name: text_x, dtype: object

In [18]:
count_vec = CountVectorizer(tokenizer=casual_tokenize, max_features=MAX_VOCAB_SIZE - 3)
print("Fitting CountVectorizer on X and Y text data...")
count_vec.fit(tqdm(x_text + y_text))
analyzer = count_vec.build_analyzer()
vocab = {k: v + 3 for k, v in count_vec.vocabulary_.items()}
vocab['__unk__'] = UNK
vocab['__pad__'] = PAD
vocab['__start__'] = START

# Used to turn seq2seq predictions into human readable strings
reverse_vocab = {v: k for k, v in vocab.items()}
print(f"Learned vocab of {len(vocab)} items.")

Fitting CountVectorizer on X and Y text data...


  0%|          | 0/794299 [00:00<?, ?it/s]

  "The parameter 'token_pattern' will not be used"


Learned vocab of 8192 items.


In [19]:
x_text + y_text

514330    @british_airways - flight to Seattle, flight w...
86976     Wheres my achievement @__sn__ @XboxSupport htt...
160341    @Tesco this drink is falsely advertised. This ...
224425    Still pissed about the #huluupdate @hulu_suppo...
860406     @__sn__ Any idea where I can buy your #gluten...
                                ...                        
495943    @TMobileHelp Switched from Sprint less than 6 ...
85035      @__sn__ @idea_cares https://t.co/p9bYulAVYR @...
573101     @__sn__  please do something about this I️ no...
542974    @Ask_Spectrum A man can't even play xbox. We a...
392195    @Uber_Support You guys couldn't help me find m...
Length: 794299, dtype: object

In [20]:
vocab['__start__']

2

### Vocab Helper Functions
These helper functions take strings and turn them into word indexes used by the actual seq2seq models.  This turns something like "This is how we do it." into a padded array of integers, like [153, 4, 643, 48, 94, 54, 8, 0, 0, 0].  We'll apply the `to_word_idx` function to our text data to get our `N x MESSAGE_LEN` training/test data.

In [21]:
def to_word_idx(sentence):
    full_length = [vocab.get(tok, UNK) for tok in analyzer(sentence)] + [PAD] * MAX_MESSAGE_LEN
    return full_length[:MAX_MESSAGE_LEN]

def from_word_idx(word_idxs):
    return ' '.join(reverse_vocab[idx] for idx in word_idxs if idx != PAD).strip()


In [22]:
x_text.head().apply(to_word_idx)

514330    [484, 155, 2992, 7207, 6330, 154, 2992, 7656, ...
86976     [7738, 4876, 602, 449, 543, 0, 1, 1, 1, 1, 1, ...
160341    [530, 7116, 2489, 3976, 0, 661, 157, 7116, 397...
224425    [6783, 5424, 568, 7088, 0, 501, 4354, 7496, 42...
860406    [449, 875, 3777, 7736, 3763, 1490, 1445, 7916,...
Name: text_x, dtype: object

In [23]:
# Make sure our helpers work as expected...
x_text.head().apply(to_word_idx).apply(from_word_idx)

514330    @british_airways - flight to seattle , flight ...
86976     wheres my achievement @__sn__ @xboxsupport __u...
160341    @tesco this drink is __unk__ advertised . this...
224425    still pissed about the __unk__ @hulu_support l...
860406    @__sn__ any idea where i can buy your #glutenf...
Name: text_x, dtype: object

In [24]:
vocab['anyone']

878

In [25]:
reverse_vocab[878]

'anyone'

In [26]:
print("Calculating word indexes for X...")
x = pd.np.vstack(x_text.progress_apply(to_word_idx).values)
print("Calculating word indexes for Y...")
y = pd.np.vstack(y_text.progress_apply(to_word_idx).values)

Calculating word indexes for X...


  


  0%|          | 0/794299 [00:00<?, ?it/s]

Calculating word indexes for Y...


  after removing the cwd from sys.path.


  0%|          | 0/794299 [00:00<?, ?it/s]

### Train / Test Split
Here, we split our data into training and test sets.  For simplicity, we use a random split, which may result in different distributions between the training and test set, but we won't worry about that for this case.

In [27]:
all_idx = list(range(len(x)))
random.seed(1234)
train_idx = set(random.sample(all_idx, int(0.8 * len(all_idx))))
test_idx = {idx for idx in all_idx if idx not in train_idx}

train_x = x[list(train_idx)]
test_x = x[list(test_idx)]
train_y = y[list(train_idx)]
test_y = y[list(test_idx)]

assert train_x.shape == train_y.shape
assert test_x.shape == test_y.shape

print(f'Training data of shape {train_x.shape} and test data of shape {test_x.shape}.')

Training data of shape (635439, 30) and test data of shape (158860, 30).


In [28]:
train_x

array([[ 484,  155, 2992, ...,    1,    1,    1],
       [7738, 4876,  602, ...,    1,    1,    1],
       [ 530, 7116, 2489, ..., 2489,    4,  157],
       ...,
       [ 532, 6937, 3089, ...,    1,    1,    1],
       [ 461,  552, 4583, ...,    1,    1,    1],
       [ 534, 7909, 3310, ..., 2494, 2221,  157]])

## Model Creation
We'll create and compile the model here.  It will consist of the following components:

- Shared word embeddings
  - A shared embedding layer that turns word indexes (a sparse representation) into a dense/compressed representation.  This embeds both the request from the customer, and also the last words uttered by the model that are fed back into the model.
- Encoder RNN
  - In this case, a single LSTM layer.  This encodes the whole input sentence into a context vector (or thought vector) that represents completely what the customer is saying, and produces a single output.
- Decoder RNN
  - This RNN (also an LSTM in this case) decodes the context vector into a string of tokens/utterances.  For each time step, it takes the context vector and the embedded last utterance and produces the next utterance, which is fed back into the model.  More complex and effective models copy the encoder state into the decoder, add more layers of LSTMs, and apply attention mechanisms - but these are out of the scope of this simple example.
- Next Word Dense+Softmax
  - These two layers take the decoder output and turn it into the next word to be uttered.  The dense layer allows the decoder to not map directly to words uttered, and the softmax turns the dense layer output into a probability distribution, from which we pick the most likely next word.

![seq2seq model structure](https://i.imgur.com/JmuryKu.png)

In [29]:
# keras imports, because there are like... A million of them.
from keras.models import Model
# from keras.optimizers import Adam
from tensorflow.keras.optimizers import Adam
from keras.layers import Dense, Input, LSTM, Dropout, Embedding, RepeatVector, concatenate, \
    TimeDistributed
from keras.utils import np_utils

In [30]:
import tensorflow as tf
tf.__version__

'1.14.0'

In [31]:
import numpy as np
np.__version__

'1.21.5'

In [32]:
def create_model():
    shared_embedding = Embedding(
        output_dim=EMBEDDING_SIZE,
        input_dim=MAX_VOCAB_SIZE,
        input_length=MAX_MESSAGE_LEN,
        name='embedding',
    )
    
    # ENCODER
    
    encoder_input = Input(
        shape=(MAX_MESSAGE_LEN,),
        dtype='int32',
        name='encoder_input',
    )
    
    embedded_input = shared_embedding(encoder_input)
    
    # No return_sequences - since the encoder here only produces a single value for the
    # input sequence provided.
    encoder_rnn = LSTM(
        CONTEXT_SIZE,
        name='encoder',
        dropout=DROPOUT
    )
    
    context = RepeatVector(MAX_MESSAGE_LEN)(encoder_rnn(embedded_input))
    
    # DECODER
    
    last_word_input = Input(
        shape=(MAX_MESSAGE_LEN, ),
        dtype='int32',
        name='last_word_input',
    )
    
    embedded_last_word = shared_embedding(last_word_input)
    # Combines the context produced by the encoder and the last word uttered as inputs
    # to the decoder.
    decoder_input = concatenate([embedded_last_word, context], axis=2)
    
    # return_sequences causes LSTM to produce one output per timestep instead of one at the
    # end of the intput, which is important for sequence producing models.
    decoder_rnn = LSTM(
        CONTEXT_SIZE,
        name='decoder',
        return_sequences=True,
        dropout=DROPOUT
    )
    
    decoder_output = decoder_rnn(decoder_input)
    
    # TimeDistributed allows the dense layer to be applied to each decoder output per timestep
    next_word_dense = TimeDistributed(
        Dense(int(MAX_VOCAB_SIZE / 2), activation='relu'),
        name='next_word_dense',
    )(decoder_output)
    
    next_word = TimeDistributed(
        Dense(MAX_VOCAB_SIZE, activation='softmax'),
        name='next_word_softmax'
    )(next_word_dense)
    
    return Model(inputs=[encoder_input, last_word_input], outputs=[next_word])

s2s_model = create_model()
optimizer = Adam(lr=LEARNING_RATE, clipvalue=5.0)
s2s_model.compile(optimizer='adam', loss='categorical_crossentropy')

## Model Training
We'll train the model here.  After each sub-batch of the dataset, we'll test with static input strings to see how the model is progressing in human readable terms.  Its important to have these tests along with traditional model evaluation to provide a better understanding of how well the model is training.

It's important to pull test strings from the real distribution of the data, also.  It can be hard to really put yourself in customers' shoes when writing test messages, and you will get non-representative results when you provide test examples that don't fit the true distribution of the input data (when your input text doesn't sound like real customer requests).

In [33]:
def add_start_token(y_array):
    """ Adds the start token to vectors.  Used for training data. """
    return np.hstack([
        START * np.ones((len(y_array), 1)),
        y_array[:, :-1],
    ])

def binarize_labels(labels):
    """ Helper function that turns integer word indexes into sparse binary matrices for 
        the expected model output.
    """
    return np.array([np_utils.to_categorical(row, num_classes=MAX_VOCAB_SIZE)
                     for row in labels])

In [34]:
def respond_to(model, text):
    """ Helper function that takes a text input and provides a text output. """
    input_y = add_start_token(PAD * np.ones((1, MAX_MESSAGE_LEN)))
    idxs = np.array(to_word_idx(text)).reshape((1, MAX_MESSAGE_LEN))
    for position in range(MAX_MESSAGE_LEN - 1):
        prediction = model.predict([idxs, input_y]).argmax(axis=2)[0]
        input_y[:,position + 1] = prediction[position]
    return from_word_idx(model.predict([idxs, input_y]).argmax(axis=2)[0])

In [35]:
def train_mini_epoch(model, start_idx, end_idx):
    """ Batching seems necessary in Kaggle Jupyter Notebook environments, since
        `model.fit` seems to freeze on larger batches (somewhere 1k-10k).
    """
    b_train_y = binarize_labels(train_y[start_idx:end_idx])
    input_train_y = add_start_token(train_y[start_idx:end_idx])
    
    model.fit(
        [train_x[start_idx:end_idx], input_train_y], 
        b_train_y,
        epochs=1,
        batch_size=BATCH_SIZE,
    )
    
    rand_idx = random.sample(list(range(len(test_x))), SUB_BATCH_SIZE)
    print('Test results:', model.evaluate(
        [test_x[rand_idx], add_start_token(test_y[rand_idx])],
        binarize_labels(test_y[rand_idx])
    ))
    
    input_strings = [
        "@AppleSupport I fix I this I stupid I problem I",
        "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service.",
    ]
    
    for input_string in input_strings:
        output_string = respond_to(model, input_string)
        print(f'> "{input_string}"\n< "{output_string}"')


### Train the model!

You can stop training by pressing the stop button - the training code is configured to watch for the `KeyboardInterrupt` exception triggered that way.  Also, it will run until the configured stopping point below.


Let's start the training! 🚀

In [36]:
training_time_limit = 360 * 60  # seconds (notebooks terminate after 1 hour)
start_time = time.time()
stop_after = start_time + training_time_limit

class TimesUpInterrupt(Exception):
    pass

try:
    for epoch in range(100):
        print(f'Training in epoch {epoch}...')
        for start_idx in range(0, len(train_x), SUB_BATCH_SIZE):
            train_mini_epoch(s2s_model, start_idx, start_idx + SUB_BATCH_SIZE)
            if time.time() > stop_after:
                raise TimesUpInterrupt
except KeyboardInterrupt:
    print("Halting training from keyboard interrupt.")
except TimesUpInterrupt:
    print(f"Halting after {time.time() - start_time} seconds spent training.")

Training in epoch 0...
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Epoch 1/1
Test results: 4.258560516357422
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ to , . to . ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ to , . to . ."
Epoch 1/1
Test results: 3.863120491027832
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we can help ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ we can help ."
Epoch 1/1
Test results: 3.6045391426086426
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , , we can help . please dm us a dm with your account number . ^ __unk__"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , , we can help . please dm us a dm with your account number . ^ __un

Test results: 2.6538987770080564
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there ! please dm us your account number , and address , and we will be able to help . ^ __unk__"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there ! please dm us your account number , and address , and we will be able to help . ^ __unk__"
Epoch 1/1
Test results: 2.7146755867004395
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your account number so we can look into this for you ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your account number so we can look into this for you ."
Epoch 1/1
Test results: 2.6892468223571777
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , please send us a dm with 

Epoch 1/1
Test results: 2.4935453567504884
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address and phone number and we'll take a look backstage / __unk__ https://t.co/ldfdzrinat"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address and phone number and we'll take a look backstage / __unk__ https://t.co/ldfdzrinat"
Epoch 1/1
Test results: 2.4478374652862547
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address and phone number so we can look into this for you . ^"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address and phone number so we can look into this f

Epoch 1/1
Test results: 2.3128198337554933
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , sorry to hear this . please dm us your email address and we'll be happy to help ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , sorry to hear this . please dm us your email address and we'll be happy to help ."
Epoch 1/1
Test results: 2.3077019901275633
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , we are sorry to hear this . please dm us your tracking number and we will be happy to help . ^ kr"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , we are sorry to hear this . please dm us your tracking number and we will be happy to help . ^ kr"
Epoch 1/1
Test results: 2.422158935546875
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , __unk__ . we are sorry to hear this . please

Epoch 1/1
Test results: 2.172768119812012
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your full name , address and email address so we can take a look . https://t.co/sxpdictw1a"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your full name , address and email address so we can take a look . https://t.co/sxpdictw1a"
Epoch 1/1
Test results: 2.2013617458343506
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your full name , address , and phone number so we can assist . ^ mm"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your full name , address , and phone number so we can assist . ^ kd"
Epoch 1/1
Test result

> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there ! we are sorry to hear this . please dm us your contact number and phone number . ^ ez"
Epoch 1/1
Test results: 2.2233642292022706
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , __unk__ . please dm us your tracking number , contact number , and contact number via dm . ^ ds https://t.co/wkjhdxwgrq"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . we have a great flight . ^ pa"
Epoch 1/1
Test results: 2.2516570568084715
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your idea number and alternate number . ^ ap https://t.co/wkjhdxwgrq"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , i'm sorry to hear about this 

Epoch 1/1
Test results: 2.1426185579299926
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm and we'll continue from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we're sorry to hear about this . please dm us your gamertag and some more info about your issue . ^ kr"
Epoch 1/1
Test results: 2.119567663192749
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm and we'll go from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , sorry to hear this . please dm us your email address so we can take a look into this for you ."
Epoch 1/1
Test results: 2.136117377281189
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we can help you out . please dm us the phone number on

Epoch 1/1
Test results: 2.1586026649475096
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm with your email address so we can follow up ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , i am sorry to hear this . please dm us your full name , address , and phone number . ^ mm https://t.co/wkjhdxwgrq"
Epoch 1/1
Test results: 2.1650350036621093
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm and we can start there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , i'm sorry to hear this . please dm us your contact info and store location ."
Epoch 1/1
Test results: 2.1752389488220216
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm and we'll go from there . https://t.co/gdrqu2

Epoch 1/1
Test results: 2.0866097450256347
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . what device are you using ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we don't have any info on this right now , but we'll let the right team know it's something you'd like to see / gk"
Epoch 1/1
Test results: 2.1243874139785768
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . can you tell us more about what you're experiencing ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , sorry to hear this . please dm us your contact info and store address . thanks !"
Epoch 1/1
Test results: 2.163889051437378
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take this to dm so we can better assist you . https://t.co/gdrqu22ypt"
> "@Amazon

Epoch 1/1
Test results: 2.045133424758911
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us and we'll continue from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , i am sorry to hear this . please dm your confirmation number . ^ ml https://t.co/wkjhdxwgrq"
Epoch 1/1
Test results: 2.0437242336273194
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us and we'll go from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear about this . please dm us your contact number and store address . thanks !"
Epoch 1/1
Test results: 2.0609326639175416
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . dm us which ios version you're using and we'll go from th

Epoch 1/1
Test results: 2.1113169898986817
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we can help . dm us which iphone and ios version you are using . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , i can help with your internet service . can you dm me your account number ? - aaron"
Epoch 1/1
Test results: 2.1085328102111816
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us and we'll go from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , we are sorry to hear about this . please dm us your contact info and store address . thanks !"
Epoch 1/1
Test results: 2.0361676921844483
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . dm us the ios version you are using . https://t.co/gdrqu22ypt"


Test results: 2.034298530578613
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . let's take this to dm so we can better assist you . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your email address so we can connect ."
Epoch 1/1
Test results: 2.013015291213989
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . which ios version are you currently running ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear about this . please dm us your email address so we can connect ."
Epoch 1/1
Test results: 2.0325551013946535
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . which device are you using ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have suc

Test results: 2.0544078874588014
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . what happens when you try to restart ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your email address and we'll take a look ."
Epoch 1/1
Test results: 2.0181142330169677
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take this to dm so we can better assist you . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your contact info and store address . thanks !"
Epoch 1/1
Test results: 2.0550790529251097
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take this to dm so we can better assist you . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand li

Epoch 1/1
Test results: 2.0052889385223387
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . send us a dm with your current ios version and we'll go from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , i am sorry to hear this . please dm us your full name , address , and phone number . ^ ck"
Epoch 1/1
Test results: 2.0061479434967042
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . dm us which iphone you're using and we'll go from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , i am sorry to hear this . please dm us with your paypal email and any applicable screenshots . ^ ez"
Epoch 1/1
Test results: 2.0213373851776124
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us a

In [37]:
respond_to(s2s_model, '''@AppleSupport iPhone 8 touchID doesnt unlock while charging on 
    110v w/ 61w laptop charger to usbc lightning cable just uh.. so you guys know''')

"@__sn__ we'd like to help . send us a dm with your current ios version and we'll go from there . https://t.co/gdrqu22ypt"

In [38]:
respond_to(s2s_model, '''@sprintcare I can't make calls... wtf''')

'@__sn__ hi , __unk__ . please send us a dm with your email address so we can look into this .'

In [39]:
s2s_model

<keras.engine.training.Model at 0x1773e620cc8>

In [None]:
tf.keras.models.save_model(
    model, filepath, overwrite=True, include_optimizer=True, save_format=None,
    signatures=None, options=None, save_traces=True
)

In [40]:
os.getcwd()

'C:\\Users\\JiatingChen\\Documents\\nlp-coe\\twitter-conversational-chatbot\\EDA Notebooks'

In [43]:
import h5py

ImportError: DLL load failed: The specified procedure could not be found.

In [41]:
tf.keras.models.save_model(s2s_model, filepath = '../model')

ImportError: `save_model` requires h5py.