In [1]:
import re
import random
import time

print('Library versions:')

import keras
print(f'keras:{keras.__version__}')
import pandas as pd
print(f'pandas:{pd.__version__}')
import sklearn
print(f'sklearn:{sklearn.__version__}')
import nltk
print(f'nltk:{nltk.__version__}')
import numpy as np
print(f'numpy:{np.__version__}')

from sklearn.feature_extraction.text import CountVectorizer
from nltk.tokenize import casual_tokenize

Library versions:


Using TensorFlow backend.


keras:2.3.1
pandas:1.1.5
sklearn:0.22.1
nltk:3.6.7
numpy:1.19.5


In [2]:
import tensorflow as tf
print(f'tensofrlow:{tf.__version__}')

tensofrlow:1.15.0


In [3]:
# from tqdm import tqdm.notebook as tqdm # Special jupyter notebook progress bar 💫
from tqdm.notebook import tqdm as tqdm
from tqdm.notebook import tqdm
tqdm.pandas()

In [4]:
import os
os.getcwd()

'C:\\Users\\JiatingChen\\Documents\\nlp-coe\\twitter-conversational-chatbot\\EDA-notebooks'

## Model Parameters

In [10]:
# 8192 - large enough for demonstration, larger values make network training slower
MAX_VOCAB_SIZE = 2**13
# seq2seq generally relies on fixed length message vectors - longer messages provide more info
# but result in slower training and larger networks
MAX_MESSAGE_LEN = 30  
# Embedding size for words - gives a trade off between expressivity of words and network size
EMBEDDING_SIZE = 100
# Embedding size for whole messages, same trade off as word embeddings
CONTEXT_SIZE = 100
# Larger batch sizes generally reach the average response faster, but small batch sizes are
# required for the model to learn nuanced responses.  Also, GPU memory limits max batch size.
BATCH_SIZE = 4
# Helps regularize network and prevent overfitting.
DROPOUT = 0.2
# High learning rate helps model reach average response faster, but can make it hard to 
# converge on nuanced responses
LEARNING_RATE=0.005

# Tokens needed for seq2seq
UNK = 0  # words that aren't found in the vocab
PAD = 1  # after message has finished, this fills all remaining vector positions
START = 2  # provided to the model at position 0 for every response predicted

# Implementaiton detail for allowing this to be run in Kaggle's notebook hardware
SUB_BATCH_SIZE = 1000


ERROR! Session/line number was not unique in database. History logging moved to new session 115


## Data Prep
Here, we'll prepare the data for training our seq2seq model, including:

- Replace screen names with `@__sn__` token to show model the commonality between them
- Build a vocab to turn tokens into integers suitable for our seq2seq model
- Tokenize input and target text into fixed size vectors
- Partition our dataset into train and test sets

### Data Loading and Reshaping
Pulled from [this kernel](https://www.kaggle.com/soaxelbrooke/first-inbound-and-response-tweets).

In [6]:
""" A kernel posted on Kaggle that shows how to pull just the first consumer request and
    company response from the dataset.
"""

tweets = pd.read_csv('../data/twcs/twcs.csv')


# Pick only inbound tweets that aren't in reply to anything...
first_inbound = tweets[pd.isnull(tweets.in_response_to_tweet_id) & tweets.inbound]
print('Found {} first inbound messages.'.format(len(first_inbound)))

# Merge in all tweets in response
inbounds_and_outbounds = pd.merge(first_inbound, tweets, left_on='tweet_id', 
                                  right_on='in_response_to_tweet_id')
print("Found {} responses.".format(len(inbounds_and_outbounds)))

# Filter out cases where reply tweet isn't from company
inbounds_and_outbounds = inbounds_and_outbounds[inbounds_and_outbounds.inbound_y ^ True]

# Et voila!
print("Found {} responses from companies.".format(len(inbounds_and_outbounds)))
print("Tweets Preview:")
print(inbounds_and_outbounds)

Found 787346 first inbound messages.
Found 875292 responses.
Found 794299 responses from companies.
Tweets Preview:
        tweet_id_x author_id_x  inbound_x                    created_at_x  \
0                8      115712       True  Tue Oct 31 21:45:10 +0000 2017   
1                8      115712       True  Tue Oct 31 21:45:10 +0000 2017   
2                8      115712       True  Tue Oct 31 21:45:10 +0000 2017   
3               18      115713       True  Tue Oct 31 19:56:01 +0000 2017   
4               20      115715       True  Tue Oct 31 22:03:34 +0000 2017   
...            ...         ...        ...                             ...   
875287     2987942      823867       True  Wed Nov 22 07:30:39 +0000 2017   
875288     2987944      823868       True  Wed Nov 22 07:43:36 +0000 2017   
875289     2987946      524544       True  Wed Nov 22 08:25:48 +0000 2017   
875290     2987948      823869       True  Wed Nov 22 08:35:16 +0000 2017   
875291     2987950      823870       

In [7]:
tweets.head() 
#inbound: message from others (response_tweet_id); outbound: message sent by the tweet_id (in_response_to_tweet_id)

Unnamed: 0,tweet_id,author_id,inbound,created_at,text,response_tweet_id,in_response_to_tweet_id
0,1,sprintcare,False,Tue Oct 31 22:10:47 +0000 2017,@115712 I understand. I would like to assist y...,2.0,3.0
1,2,115712,True,Tue Oct 31 22:11:45 +0000 2017,@sprintcare and how do you propose we do that,,1.0
2,3,115712,True,Tue Oct 31 22:08:27 +0000 2017,@sprintcare I have sent several private messag...,1.0,4.0
3,4,sprintcare,False,Tue Oct 31 21:54:49 +0000 2017,@115712 Please send us a Private Message so th...,3.0,5.0
4,5,115712,True,Tue Oct 31 21:49:35 +0000 2017,@sprintcare I did.,4.0,6.0


In [8]:
%%time
tweets = pd.read_csv('../data/twcs/twcs.csv')

first_inbound = tweets[pd.isnull(tweets.in_response_to_tweet_id) & tweets.inbound]

inbounds_and_outbounds = pd.merge(first_inbound, tweets, left_on='tweet_id', 
                                  right_on='in_response_to_tweet_id').sample(frac=1)

# Filter to only outbound replies (from companies)
inbounds_and_outbounds = inbounds_and_outbounds[inbounds_and_outbounds.inbound_y ^ True]

#tqdm().pandas()  # Enable tracking of progress in dataframe `apply` calls

Wall time: 9.94 s


In [9]:
tqdm().pandas()

0it [00:00, ?it/s]

In [10]:
print(f'Data shape: {inbounds_and_outbounds.shape}')

Data shape: (794299, 14)


### Tokenizing and Vocab Build

We'll use NLTK's `casual_tokenize`, which handles a lot of corner cases found in social media data ("casual" text data) along with scitkit learn's `CountVectorizer`.  We won't use the actual `CountVectorizer`, just use it as a convenient vocabulary builder, which we'll apply with functions that turn text into "word indexes" - integers that represent each word - and back.

In [11]:
inbounds_and_outbounds.head()

Unnamed: 0,tweet_id_x,author_id_x,inbound_x,created_at_x,text_x,response_tweet_id_x,in_response_to_tweet_id_x,tweet_id_y,author_id_y,inbound_y,created_at_y,text_y,response_tweet_id_y,in_response_to_tweet_id_y
610043,2122295,625251,True,Wed Nov 08 15:29:12 +0000 2017,Hey @120533 on peut payer en 4x maintenant ou ...,2122294,,2122294,AmazonHelp,False,Wed Nov 08 15:40:34 +0000 2017,"@625251 Exactement, le paiement en 4 fois est ...",2122293.0,2122295.0
605412,2106941,621789,True,Wed Nov 08 01:06:33 +0000 2017,@AirAsiaSupport is the new Manila-Bali route i...,21069402106942,,2106942,AirAsiaSupport,False,Wed Nov 08 07:23:32 +0000 2017,"@621789 Hi Pia, for Manila-Bali, you have to b...",2106943.0,2106941.0
272269,1000516,356865,True,Sun Oct 22 16:46:47 +0000 2017,So we got RedZone or Bears vs Panthers on cbs....,1000515,,1000515,comcastcares,False,Sun Oct 22 17:19:58 +0000 2017,@356865 Happy to hear you are pleased with the...,,1000516.0
309113,1129407,380895,True,Tue Oct 24 11:11:00 +0000 2017,"When @115821 is great, it's great. When @11582...",1129406,,1129406,AmazonHelp,False,Tue Oct 24 11:15:00 +0000 2017,@380895 Hi Billy- Without providing personal i...,1129405.0,1129407.0
107767,412613,213467,True,Tue Oct 10 00:39:08 +0000 2017,@azuresupport #azcommunity,412612,,412612,AzureSupport,False,Tue Oct 10 00:40:51 +0000 2017,@213467 Hello! How can we help you today? ^MH,,412613.0


In [12]:
inbounds_and_outbounds.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 794299 entries, 610043 to 589229
Data columns (total 14 columns):
 #   Column                     Non-Null Count   Dtype  
---  ------                     --------------   -----  
 0   tweet_id_x                 794299 non-null  int64  
 1   author_id_x                794299 non-null  object 
 2   inbound_x                  794299 non-null  bool   
 3   created_at_x               794299 non-null  object 
 4   text_x                     794299 non-null  object 
 5   response_tweet_id_x        794299 non-null  object 
 6   in_response_to_tweet_id_x  0 non-null       float64
 7   tweet_id_y                 794299 non-null  int64  
 8   author_id_y                794299 non-null  object 
 9   inbound_y                  794299 non-null  bool   
 10  created_at_y               794299 non-null  object 
 11  text_y                     794299 non-null  object 
 12  response_tweet_id_y        263771 non-null  object 
 13  in_response_to_tweet_id_

In [13]:
inbounds_and_outbounds.author_id_y.value_counts(normalize = True).head(20)

AmazonHelp         0.106556
AppleSupport       0.093960
Uber_Support       0.050395
Delta              0.035862
SpotifyCares       0.033917
Tesco              0.031332
AmericanAir        0.030852
comcastcares       0.030015
SouthwestAir       0.026421
TMobileHelp        0.025261
British_Airways    0.024690
Ask_Spectrum       0.021742
VirginTrains       0.018075
UPSHelp            0.017994
hulu_support       0.017870
ChipotleTweets     0.017441
sprintcare         0.015925
XboxSupport        0.015748
AskPlayStation     0.014349
sainsburys         0.013550
Name: author_id_y, dtype: float64

In [14]:
inbounds_and_outbounds.author_id_y.value_counts(normalize = True).tail(20) #108 companies

askvisa            0.000721
ask_progressive    0.000658
GooglePlayMusic    0.000643
YahooCare          0.000640
USCellularCares    0.000628
asksalesforce      0.000578
MTNC_Care          0.000574
MOO                0.000524
KeyBank_Help       0.000482
AskSeagate         0.000473
AskVirginMoney     0.000466
OPPOCareIN         0.000428
AskRobinhood       0.000392
AskTigogh          0.000349
JackBox            0.000253
mediatemplehelp    0.000239
AskDSC             0.000238
CarlsJr            0.000174
HotelTonightCX     0.000165
OfficeSupport      0.000073
Name: author_id_y, dtype: float64

In [16]:
# Replace anonymized screen names with common token @__sn__
def sn_replace(match):
    _sn = match.group(2).lower()
    if not _sn.isnumeric():
        # This is a company screen name
        return match.group(1) + match.group(2)
    return ' @__sn__'

sn_re = re.compile('(\W@|^@)([a-zA-Z0-9_]+)')
print("Replacing anonymized screen names in X...")
x_text = inbounds_and_outbounds.text_x.progress_apply(lambda txt: sn_re.sub(sn_replace, txt))
print("Replacing anonymized screen names in Y...")
y_text = inbounds_and_outbounds.text_y.progress_apply(lambda txt: sn_re.sub(sn_replace, txt))

Replacing anonymized screen names in X...


  0%|          | 0/794299 [00:00<?, ?it/s]

Replacing anonymized screen names in Y...


  0%|          | 0/794299 [00:00<?, ?it/s]

In [17]:
def sn_replace(match):
    _sn = match.group(2).lower()
    if not _sn.isnumeric():
        # This is a company screen name
        return match.group(1) + match.group(2)
    return ' @__sn__'
inbounds_and_outbounds.text_x[0:10]

610043    Hey @120533 on peut payer en 4x maintenant ou ...
605412    @AirAsiaSupport is the new Manila-Bali route i...
272269    So we got RedZone or Bears vs Panthers on cbs....
309113    When @115821 is great, it's great. When @11582...
107767                           @azuresupport #azcommunity
613898    @VirginTrains do you have a policy on recruiti...
457395    Is it possible to stream @116935 from your And...
820438    @TMobileHelp hi I am a customer on TMobile One...
160448    @116016 is thee worst bank. They could careles...
630258    @116297 Very disappointed in your service to m...
Name: text_x, dtype: object

In [58]:
inbounds_and_outbounds.text_x[630258]

'@116297 Very disappointed in your service to me as a customer of many years.  Again issues with you your customer service department. They promise one thing and do another. I guess Fool me once, shame on you. Fool me twice, shame on me.'

In [59]:
inbounds_and_outbounds.text_y[630258]

'@640349 We have nothing but love for you. We have received your DM and will be glad to take a deeper look into this for you. ^AW'

In [18]:
inbounds_and_outbounds.text_x[0:10].apply(lambda txt: sn_re.sub(sn_replace, txt))

610043    Hey @__sn__ on peut payer en 4x maintenant ou ...
605412    @AirAsiaSupport is the new Manila-Bali route i...
272269    So we got RedZone or Bears vs Panthers on cbs....
309113    When @__sn__ is great, it's great. When @__sn_...
107767                           @azuresupport #azcommunity
613898    @VirginTrains do you have a policy on recruiti...
457395    Is it possible to stream @__sn__ from your And...
820438    @TMobileHelp hi I am a customer on TMobile One...
160448     @__sn__ is thee worst bank. They could carele...
630258     @__sn__ Very disappointed in your service to ...
Name: text_x, dtype: object

In [19]:
count_vec = CountVectorizer(tokenizer=casual_tokenize, max_features=MAX_VOCAB_SIZE - 3)
print("Fitting CountVectorizer on X and Y text data...")
count_vec.fit(tqdm(x_text + y_text))
analyzer = count_vec.build_analyzer()
vocab = {k: v + 3 for k, v in count_vec.vocabulary_.items()}
vocab['__unk__'] = UNK
vocab['__pad__'] = PAD
vocab['__start__'] = START

# Used to turn seq2seq predictions into human readable strings
reverse_vocab = {v: k for k, v in vocab.items()}
print(f"Learned vocab of {len(vocab)} items.")

Fitting CountVectorizer on X and Y text data...


  0%|          | 0/794299 [00:00<?, ?it/s]



Learned vocab of 8192 items.


In [20]:
x_text + y_text

610043    Hey @__sn__ on peut payer en 4x maintenant ou ...
605412    @AirAsiaSupport is the new Manila-Bali route i...
272269    So we got RedZone or Bears vs Panthers on cbs....
309113    When @__sn__ is great, it's great. When @__sn_...
107767    @azuresupport #azcommunity @__sn__ Hello! How ...
                                ...                        
9343      About to board @VirginTrains to Wembley for @_...
673749    @AirAsiaSupport \nSaya mahu booking tiket KL-K...
13800     @TMobileHelp just sent you guys a dm with an o...
248940    Can We Get Fox Sports On Basic Cable Instead O...
589229    Because @__sn__ will tell you you’re signing u...
Length: 794299, dtype: object

In [21]:
vocab['__start__']

2

### Vocab Helper Functions
These helper functions take strings and turn them into word indexes used by the actual seq2seq models.  This turns something like "This is how we do it." into a padded array of integers, like [153, 4, 643, 48, 94, 54, 8, 0, 0, 0].  We'll apply the `to_word_idx` function to our text data to get our `N x MESSAGE_LEN` training/test data.

In [22]:
def to_word_idx(sentence):
    full_length = [vocab.get(tok, UNK) for tok in analyzer(sentence)] + [PAD] * MAX_MESSAGE_LEN
    return full_length[:MAX_MESSAGE_LEN]

def from_word_idx(word_idxs):
    return ' '.join(reverse_vocab[idx] for idx in word_idxs if idx != PAD).strip()


In [23]:
x_text.head().apply(to_word_idx)

610043    [3450, 449, 5111, 0, 0, 2630, 353, 0, 5179, 40...
605412    [451, 3976, 7088, 4943, 0, 6173, 3847, 3831, 7...
272269    [6609, 7681, 3250, 0, 5158, 0, 7621, 0, 5111, ...
309113    [7734, 449, 3976, 3268, 154, 3994, 3268, 157, ...
107767    [481, 20, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...
Name: text_x, dtype: object

In [24]:
# Make sure our helpers work as expected...
x_text.head().apply(to_word_idx).apply(from_word_idx)

610043    hey @__sn__ on __unk__ __unk__ en 4x __unk__ o...
605412    @airasiasupport is the new __unk__ route inclu...
272269    so we got __unk__ or __unk__ vs __unk__ on cbs...
309113    when @__sn__ is great , it's great . when @__s...
107767                           @azuresupport #azcommunity
Name: text_x, dtype: object

In [25]:
vocab['anyone']

878

In [26]:
reverse_vocab[878]

'anyone'

In [27]:
print("Calculating word indexes for X...")
x = pd.np.vstack(x_text.progress_apply(to_word_idx).values)
print("Calculating word indexes for Y...")
y = pd.np.vstack(y_text.progress_apply(to_word_idx).values)

Calculating word indexes for X...


  


  0%|          | 0/794299 [00:00<?, ?it/s]

Calculating word indexes for Y...


  after removing the cwd from sys.path.


  0%|          | 0/794299 [00:00<?, ?it/s]

### Train / Test Split
Here, we split our data into training and test sets.  For simplicity, we use a random split, which may result in different distributions between the training and test set, but we won't worry about that for this case.

In [28]:
all_idx = list(range(len(x)))
random.seed(1234)
train_idx = set(random.sample(all_idx, int(0.8 * len(all_idx))))
test_idx = {idx for idx in all_idx if idx not in train_idx}

train_x = x[list(train_idx)]
test_x = x[list(test_idx)]
train_y = y[list(train_idx)]
test_y = y[list(test_idx)]

assert train_x.shape == train_y.shape
assert test_x.shape == test_y.shape

print(f'Training data of shape {train_x.shape} and test data of shape {test_x.shape}.')

Training data of shape (635439, 30) and test data of shape (158860, 30).


In [29]:
train_x

array([[3450,  449, 5111, ...,    1,    1,    1],
       [ 451, 3976, 7088, ...,    1,    1,    1],
       [6609, 7681, 3250, ...,    1,    1,    1],
       ...,
       [ 568, 7207, 1288, ...,    1,    1,    1],
       [1490, 7681, 3182, ...,    1,    1,    1],
       [1178,  449, 7765, ..., 6497,  552,  269]])

## Model Creation
We'll create and compile the model here.  It will consist of the following components:

- Shared word embeddings
  - A shared embedding layer that turns word indexes (a sparse representation) into a dense/compressed representation.  This embeds both the request from the customer, and also the last words uttered by the model that are fed back into the model.
- Encoder RNN
  - In this case, a single LSTM layer.  This encodes the whole input sentence into a context vector (or thought vector) that represents completely what the customer is saying, and produces a single output.
- Decoder RNN
  - This RNN (also an LSTM in this case) decodes the context vector into a string of tokens/utterances.  For each time step, it takes the context vector and the embedded last utterance and produces the next utterance, which is fed back into the model.  More complex and effective models copy the encoder state into the decoder, add more layers of LSTMs, and apply attention mechanisms - but these are out of the scope of this simple example.
- Next Word Dense+Softmax
  - These two layers take the decoder output and turn it into the next word to be uttered.  The dense layer allows the decoder to not map directly to words uttered, and the softmax turns the dense layer output into a probability distribution, from which we pick the most likely next word.

![seq2seq model structure](https://i.imgur.com/JmuryKu.png)

In [6]:
# keras imports, because there are like... A million of them.
from keras.models import Model
# from keras.optimizers import Adam
from tensorflow.keras.optimizers import Adam
from keras.layers import Dense, Input, LSTM, Dropout, Embedding, RepeatVector, concatenate, \
     TimeDistributed
from keras.utils import np_utils

# from tensorflow.keras import Model, Sequential
# from tensorflow.keras.layers import Dense, Input, LSTM, Dropout, Embedding, RepeatVector, concatenate, \
#      TimeDistributed

In [7]:
import tensorflow as tf
tf.__version__

'1.15.0'

In [8]:
import numpy as np
np.__version__

'1.19.5'

In [9]:
import keras
keras.__version__

'2.3.1'

In [12]:
def create_model():
    shared_embedding = Embedding(
        output_dim=EMBEDDING_SIZE,
        input_dim=MAX_VOCAB_SIZE,
        input_length=MAX_MESSAGE_LEN,
        name='embedding',
    )
    
    # ENCODER
    
    encoder_input = Input(
        shape=(MAX_MESSAGE_LEN,),
        dtype='int32',
        name='encoder_input',
    )
    
    embedded_input = shared_embedding(encoder_input)
    
    # No return_sequences - since the encoder here only produces a single value for the
    # input sequence provided.
    encoder_rnn = LSTM(
        CONTEXT_SIZE,
        name='encoder',
        dropout=DROPOUT
    )
    
    context = RepeatVector(MAX_MESSAGE_LEN)(encoder_rnn(embedded_input))
    
    # DECODER
    
    last_word_input = Input(
        shape=(MAX_MESSAGE_LEN, ),
        dtype='int32',
        name='last_word_input',
    )
    
    embedded_last_word = shared_embedding(last_word_input)
    # Combines the context produced by the encoder and the last word uttered as inputs
    # to the decoder.
    decoder_input = concatenate([embedded_last_word, context], axis=2)
    
    # return_sequences causes LSTM to produce one output per timestep instead of one at the
    # end of the intput, which is important for sequence producing models.
    decoder_rnn = LSTM(
        CONTEXT_SIZE,
        name='decoder',
        return_sequences=True,
        dropout=DROPOUT
    )
    
    decoder_output = decoder_rnn(decoder_input)
    
    # TimeDistributed allows the dense layer to be applied to each decoder output per timestep
    next_word_dense = TimeDistributed(
        Dense(int(MAX_VOCAB_SIZE / 2), activation='relu'),
        name='next_word_dense',
    )(decoder_output)
    
    next_word = TimeDistributed(
        Dense(MAX_VOCAB_SIZE, activation='softmax'),
        name='next_word_softmax'
    )(next_word_dense)
    
    return Model(inputs=[encoder_input, last_word_input], outputs=[next_word])

s2s_model = create_model()
optimizer = Adam(lr=LEARNING_RATE, clipvalue=5.0)
s2s_model.compile(optimizer='adam', loss='categorical_crossentropy')

## Model Training
We'll train the model here.  After each sub-batch of the dataset, we'll test with static input strings to see how the model is progressing in human readable terms.  Its important to have these tests along with traditional model evaluation to provide a better understanding of how well the model is training.

It's important to pull test strings from the real distribution of the data, also.  It can be hard to really put yourself in customers' shoes when writing test messages, and you will get non-representative results when you provide test examples that don't fit the true distribution of the input data (when your input text doesn't sound like real customer requests).

In [35]:
def add_start_token(y_array):
    """ Adds the start token to vectors.  Used for training data. """
    return np.hstack([
        START * np.ones((len(y_array), 1)),
        y_array[:, :-1],
    ])

def binarize_labels(labels):
    """ Helper function that turns integer word indexes into sparse binary matrices for 
        the expected model output.
    """
    return np.array([np_utils.to_categorical(row, num_classes=MAX_VOCAB_SIZE)
                     for row in labels])

In [36]:
def respond_to(model, text):
    """ Helper function that takes a text input and provides a text output. """
    input_y = add_start_token(PAD * np.ones((1, MAX_MESSAGE_LEN)))
    idxs = np.array(to_word_idx(text)).reshape((1, MAX_MESSAGE_LEN))
    for position in range(MAX_MESSAGE_LEN - 1):
        prediction = model.predict([idxs, input_y]).argmax(axis=2)[0]
        input_y[:,position + 1] = prediction[position]
    return from_word_idx(model.predict([idxs, input_y]).argmax(axis=2)[0])

In [37]:
def train_mini_epoch(model, start_idx, end_idx):
    """ Batching seems necessary in Kaggle Jupyter Notebook environments, since
        `model.fit` seems to freeze on larger batches (somewhere 1k-10k).
    """
    b_train_y = binarize_labels(train_y[start_idx:end_idx])
    input_train_y = add_start_token(train_y[start_idx:end_idx])
    
    model.fit(
        [train_x[start_idx:end_idx], input_train_y], 
        b_train_y,
        epochs=1,
        batch_size=BATCH_SIZE,
    )
    
    rand_idx = random.sample(list(range(len(test_x))), SUB_BATCH_SIZE)
    print('Test results:', model.evaluate(
        [test_x[rand_idx], add_start_token(test_y[rand_idx])],
        binarize_labels(test_y[rand_idx])
    ))
    
    input_strings = [
        "@AppleSupport I fix I this I stupid I problem I",
        "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service.",
    ]
    
    for input_string in input_strings:
        output_string = respond_to(model, input_string)
        print(f'> "{input_string}"\n< "{output_string}"')


### Train the model!

You can stop training by pressing the stop button - the training code is configured to watch for the `KeyboardInterrupt` exception triggered that way.  Also, it will run until the configured stopping point below.


Let's start the training! 🚀

In [38]:
training_time_limit = 360 * 60  # seconds (notebooks terminate after 1 hour)
start_time = time.time()
stop_after = start_time + training_time_limit

class TimesUpInterrupt(Exception):
    pass

try:
    for epoch in range(100):
        print(f'Training in epoch {epoch}...')
        for start_idx in range(0, len(train_x), SUB_BATCH_SIZE):
            train_mini_epoch(s2s_model, start_idx, start_idx + SUB_BATCH_SIZE)
            if time.time() > stop_after:
                raise TimesUpInterrupt
except KeyboardInterrupt:
    print("Halting training from keyboard interrupt.")
except TimesUpInterrupt:
    print(f"Halting after {time.time() - start_time} seconds spent training.")

Training in epoch 0...
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where

Epoch 1/1
Test results: 4.176258781433106
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , , , , , . . ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , , , , . . ."
Epoch 1/1
Test results: 3.8346783142089844
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ we __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __unk__ __un

Epoch 1/1
Test results: 2.765815565109253
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there ! we want to help . please dm us your email address and we'll take a look backstage / __unk__ https://t.co/ldfdzrinat"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there ! we want to help . please dm us your email address and we'll take a look backstage / __unk__ https://t.co/ldfdzrinat"
Epoch 1/1
Test results: 2.7641054344177247
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there ! we ’ re sorry to hear this . please dm us your full name , address , address , and address . - __unk__"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there ! we ’ re sorry to hear this . please dm us your full name , address , address , and address . - __unk__"
Epoch 1/1
Test results: 2.740233102798462
> "@AppleSupport I fix 

Epoch 1/1
Test results: 2.4758928365707398
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , __unk__ . we have a great day . ^ hp"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . we have a great day . ^ hp"
Epoch 1/1
Test results: 2.510856426239014
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are here to help . please dm us your email address and more details . - __unk__"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are here to help . please dm us your email address and more details . - __unk__"
Epoch 1/1
Test results: 2.4436936779022216
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , __unk__ . we can help you with your account . please dm us your account number and i will be happy to help . ^ __unk__"
> "@AmazonHelp I hadnt expected that such a 

Epoch 1/1
Test results: 2.410180543899536
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address so we can connect ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address so we can connect ."
Epoch 1/1
Test results: 2.3673566608428955
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there , we are sorry to hear this . please dm us your account email address and we'll look into this for you ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , sorry to hear this . please send us a dm with your email address and more details so we can assist ."
Epoch 1/1
Test results: 2.2846765422821047
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi , i am sorry to hear this . pl

Epoch 1/1
Test results: 2.2513327293395995
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . please dm us the phone number on your account so we can follow up ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we can help with the service issues . can you please dm your full name , address , and phone number , and address on the"
Epoch 1/1
Test results: 2.2951872901916506
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ hi there ! we can help . dm us your email address and we'll get started . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . we are sorry to hear this . please dm us your account email address and we'll take a look ."
Epoch 1/1
Test results: 2.2925282154083253
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help .

> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your email address and we will be able to help you ."
Epoch 1/1
Test results: 2.1661465606689454
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm with your email address so we can connect ."
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your confirmation number and we'll take a look ."
Epoch 1/1
Test results: 2.1782838344573974
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm and we'll get started . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , we are sorry to hear this . please dm us your contact info and store address . thanks 

Epoch 1/1
Test results: 2.154663122177124
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . dm us the country you're located in . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear that . please dm us your tracking number and phone number . ^ ac https://t.co/wkjhdxwgrq"
Epoch 1/1
Test results: 2.1458360786437987
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . please dm us the ios version you are using . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your contact number and store details . ^ rr"
Epoch 1/1
Test results: 2.1449705142974853
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm with more details and we'll go from there . h

Epoch 1/1
Test results: 2.1726361446380613
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us which ios version you're using . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear about this . please dm us your email address and we'll be happy to look into this for you ."
Epoch 1/1
Test results: 2.1808878707885744
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . send us a dm and we'll get started . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there ! we are sorry to hear this . can you dm us your account's email address ? we'll take a look backstage / cg https://t.co/ldfdzrinat"
Epoch 1/1
Test results: 2.0714572525024413
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to he

Epoch 1/1
Test results: 2.128258949279785
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us and we'll go from there . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your email address and more details about your concern ."
Epoch 1/1
Test results: 2.0959259586334227
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . dm us which iphone and software version you are using . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , i would be happy to help with your internet services . when you can plz dm your phone number and address to chat ? ^ dt"
Epoch 1/1
Test results: 2.0503558750152586
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take this to dm so we can g

Epoch 1/1
Test results: 2.069374320983887
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us the exact ios version you are using . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , we are sorry to hear this . please dm us with your order number and bakery-cafe location and we will look into this for you ."
Epoch 1/1
Test results: 2.0570416564941407
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . dm us and we'll get started . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear about this . please dm us your email address so we can look into this for you ."
Epoch 1/1
Test results: 1.9996230545043945
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take this to dm 

Epoch 1/1
Test results: 2.001162254333496
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . which device are you using ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address and we'll reach out right away ."
Epoch 1/1
Test results: 2.0393237600326537
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . which version of ios 11 are you using ?"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear about this . please dm us your email address and more details so we can look into this for you ."
Epoch 1/1
Test results: 2.035539478302002
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we'd like to help . which device are you using ?"
> "@AmazonHelp I hadnt expected that such a big bran

Epoch 1/1
Test results: 1.9993338747024536
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take a look at this together . dm us the ios version you're using . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your email address and we'll look into this for you ."
Epoch 1/1
Test results: 2.0264577436447144
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take a look at this together . dm us and we'll get started . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , we are sorry to hear this . please dm us your contact info and store address . thanks !"
Epoch 1/1
Test results: 2.0612673454284667
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we can hel

Epoch 1/1
Test results: 1.9559007415771483
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take a closer look at this . dm us and we'll take a closer look at this . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi , __unk__ . please dm us your tracking number and contact number via dm . ^ ds https://t.co/wkjhdxwgrq"
Epoch 1/1
Test results: 2.022499376296997
> "@AppleSupport I fix I this I stupid I problem I"
< "@__sn__ we want to help . let's take this to dm and we'll explore ways to provide you assistance . https://t.co/gdrqu22ypt"
> "@AmazonHelp I hadnt expected that such a big brand like amazon would have such a poor customer service."
< "@__sn__ hi there , i am sorry to hear this . please dm us your tracking number and contact number via the link so we can look into this for"
Epoch 1/1
Test results: 1.9363671741485595
> "@AppleSupport I fix I this

In [60]:
respond_to(s2s_model, '''@AppleSupport iPhone 8 touchID doesnt unlock while charging on 
    110v w/ 61w laptop charger to usbc lightning cable just uh.. so you guys know''')

"@__sn__ we want to help . please send us a dm with your current ios version and we'll go from there . https://t.co/gdrqu22ypt"

In [61]:
respond_to(s2s_model, '''@sprintcare I can't make calls... wtf''')

"@__sn__ hi there ! we don't have any info on this right now , but we'll let our devs know it's something you'd like to see / nq"

In [62]:
respond_to(s2s_model, '''''@sprintcare is the worst customer service''')

"@__sn__ hi there ! we don't have any info on this right now , but we'll let our devs know it's something you'd like to see / nq"

In [63]:
respond_to(s2s_model, '''''@VerizonSupport My friend is without internet we need to play videogames together please our skills diminish every moment without internet''')

"@__sn__ hi there ! we are sorry to hear about this . can you dm us your account's email address or username ? we'll take a look backstage / nq"

In [64]:
respond_to(s2s_model, '''@XboxSupport can I change me sons Xbox live account to his Hotmail account, currently linked to my Hotmail account''')

"@__sn__ hey there ! can you dm us your account's username or email address ? we'll take a look backstage / nq https://t.co/ldfdzrinat"

In [65]:
respond_to(s2s_model, """@116297 Very disappointed in your service to me as a customer of many years.  Again issues with you your customer service department. They promise one thing and do another. I guess Fool me once, shame on you. Fool me twice, shame on me.""")

"@__sn__ hi there ! we don't have any info on this right now , but we'll let our team know it's something you'd like to see / nq"

In [44]:
#s2s_model

In [45]:
#os.getcwd()

In [13]:
import h5py
from keras.utils.vis_utils import plot_model

In [48]:
s2s_model.save("../model/s2s_model.h5")

In [52]:
tf.keras.utils.plot_model(s2s_model)

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.
ERROR! Session/line number was not unique in database. History logging moved to new session 100


In [50]:
s2s_model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
last_word_input (InputLayer)    (None, 30)           0                                            
__________________________________________________________________________________________________
encoder_input (InputLayer)      (None, 30)           0                                            
__________________________________________________________________________________________________
embedding (Embedding)           (None, 30, 100)      819200      encoder_input[0][0]              
                                                                 last_word_input[0][0]            
__________________________________________________________________________________________________
encoder (LSTM)                  (None, 100)          80400       embedding[0][0]            

In [6]:
# keras.models.load_model(
#     '../model/s2s_model.h5'
# )