# Interpretability
In this notebook we will explore how our models use the attention mechanism during prediction.

We will only use the Twitter dataset because it is the best publicly available corpus for the customer-support domain.

In [1]:
import torch.nn as nn
from tqdm import tqdm
import pickle

In [2]:
# Load the shared vocabulary and the dual-encoder training triples.
# Context managers close the files instead of leaking the handles.
# NOTE(review): pickle.load can execute arbitrary code; only load files produced
# by this project's own preprocessing.
with open('../data/twitter/tmp/word2ix.pkl', 'rb') as vocab_fh:
    word2ix = pickle.load(vocab_fh)
with open('../data/twitter/tmp/de_train.pkl', 'rb') as data_fh:
    trn_q_vecs, trn_a_vecs, trn_y = pickle.load(data_fh)
# Inverse vocabulary: token id -> token string.
ix2word = {v: k for k, v in word2ix.items()}
# Keep only the pairs labelled 1 (presumably the true question/answer matches;
# TODO confirm label semantics) so every example has a real answer.
trn_q_vecs = trn_q_vecs[trn_y == 1]
trn_a_vecs = trn_a_vecs[trn_y == 1]

## Dual Encoders
Let's start by loading our dual-attention dual encoder.

In [3]:
import sys
# Put the retrieval package directory on the import path so that `ElmoDE` and
# `utils` can be imported in the next cell.
sys.path.append('../retrieval')

In [4]:
from ElmoDE import ElmoDE
from utils import pad_sequences

In [5]:
# Serialized dual-encoder checkpoint. The model class name is embedded in the file
# name ("<prefix>.<ClassName>.torch") and parsed out in the next cell.
model_path = "../data/twitter/models/transfer.ElmoDE.torch"

In [6]:
# Recover the class name from the checkpoint file name
# ("transfer.ElmoDE.torch" -> "ElmoDE"); the class is looked up in the
# already-imported module of the same name.
modelclass_name = model_path.rsplit('.', 2)[-2]
model = getattr(sys.modules[modelclass_name], modelclass_name).load(
    model_path,
    map_location={'cuda:2': 'cuda:0'},  # remap tensors saved on cuda:2 onto cuda:0
)
# Move to GPU and switch to inference mode; the tuple is just the cell's display output.
(model.cuda(), model.eval())

(ElmoDE(
   (encoder): SentenceEncoder(
     (embedding): EmbeddingLayer(
       (elmo): ElmoEmbeddings(
         (elmo): Elmo(
           (_elmo_lstm): _ElmoBiLm(
             (_token_embedder): _ElmoCharacterEncoder(
               (char_conv_0): Conv1d(16, 32, kernel_size=(1,), stride=(1,))
               (char_conv_1): Conv1d(16, 32, kernel_size=(2,), stride=(1,))
               (char_conv_2): Conv1d(16, 64, kernel_size=(3,), stride=(1,))
               (char_conv_3): Conv1d(16, 128, kernel_size=(4,), stride=(1,))
               (char_conv_4): Conv1d(16, 256, kernel_size=(5,), stride=(1,))
               (char_conv_5): Conv1d(16, 512, kernel_size=(6,), stride=(1,))
               (char_conv_6): Conv1d(16, 1024, kernel_size=(7,), stride=(1,))
               (_highways): Highway(
                 (_layers): ModuleList(
                   (0): Linear(in_features=2048, out_features=4096, bias=True)
                   (1): Linear(in_features=2048, out_features=4096, bias=True)
         

### Selecting the example:
Now that we have loaded our model, we need to select a good example for plotting the attention weights.

In [7]:
def get_pair_text(q_vec, a_vec, ix, vocab=None):
    """Decode the ``ix``-th question/answer pair back into tokens.

    Args:
        q_vec: indexable collection of question id-sequences.
        a_vec: indexable collection of answer id-sequences, aligned with ``q_vec``.
        ix: position of the pair to decode.
        vocab: optional id -> token mapping; defaults to the notebook's ``ix2word``.

    Returns:
        dict with keys ``"question"`` and ``"answer"`` holding the token lists.
    """
    # Read from the arguments (not the trn_* globals) so the helper works for
    # any pair of aligned sequence collections.
    id2tok = ix2word if vocab is None else vocab
    return {
        "question": [id2tok[i] for i in q_vec[ix]],
        "answer": [id2tok[i] for i in a_vec[ix]],
    }

In [63]:
# Index of the training pair to visualize. Other indices flagged by the author
# (presumably as giving clear attention plots): 20, 300, 1994, 400.
example_ix = 800

In [64]:
# Decode the selected example back into question/answer tokens.
pair = get_pair_text(trn_q_vecs, trn_a_vecs, example_ix)

In [65]:
# Show the tokenized question of the selected pair.
print (pair["question"])

['_APPLE_', 'I', '’', 've', 'downloaded', 'the', 'new', 'iPhone', 'update', 'and', 'my', 'battery', 'lasts', 'half', 'as', 'long', '.', 'What', '’', 's', 'going', 'on', '?', '!', '?', '_EOT_']


In [66]:
# Number of question tokens; compare with the attention length printed by align_attention.
len(pair["question"])

26

In [67]:
# Show the tokenized answer of the selected pair.
print (pair["answer"])

['_USERID_', 'We', 'completely', 'understand', 'wanting', 'to', 'have', 'good', 'battery', 'life', '.', 'How', 'long', 'is', 'the', 'battery', 'lasting', 'on', 'a', 'full', 'charge', 'after', 'the', 'update', '?']


In [53]:
# Turn the single question and the single answer into batches of one;
# pad_sequences returns three values for each side.
(e1_inputs, e1_lengths, e1_idxs), (e2_inputs, e2_lengths, e2_idxs) = (
    pad_sequences([side_vecs[example_ix]]) for side_vecs in (trn_q_vecs, trn_a_vecs)
)

In [54]:
# Run the dual encoder on the padded pair and collect the attention weights
# (presumably question-side, per the name) plus the model's score for the pair.
# `get_attn_weigths` (sic) is the method name as called on the model.
q_attn, prob = model.get_attn_weigths(e1_inputs.cuda(), e1_lengths, e1_idxs, e2_inputs.cuda(), e2_lengths, e2_idxs)

In [55]:
# Score for this pair; presumably the probability that the answer matches the question.
prob

tensor([[0.9415]], device='cuda:0')

In [56]:
def align_attention(doc, doc_attn):
    """Pair each token of ``doc`` with its attention weight expressed as a percentage.

    Args:
        doc: sequence of tokens.
        doc_attn: per-position attention weights; only the first ``len(doc)``
            entries are read, each converted with ``.item()``.

    Returns:
        ``(attn_align, sorted_align)``: the ``(token, percent)`` pairs in document
        order, and the same pairs ordered from highest to lowest weight.
    """
    # Debug aid: token count vs. attention length, printed so mismatches are visible.
    print(len(doc), len(doc_attn))
    aligned = [(token, doc_attn[pos].item()*100) for pos, token in enumerate(doc)]
    ranked = sorted(aligned, key=lambda entry: entry[1], reverse=True)
    return aligned, ranked

In [59]:
# Align the question tokens with q_attn[0][0] and list them by attention weight.
# NOTE(review): the saved output below appears stale. It reports 75 positions while
# the current question has 26 tokens, and it lists tokens (e.g. 'Worked') that are
# not in that question. This cell (In [59]) ran before example_ix was set
# (In [63]); use Restart & Run All to refresh it.
attn_align, sorted_align = align_attention(pair["question"], q_attn[0][0])
sorted_align

75 75


[('👌', 49.17473495006561),
 ('Worked', 29.57480251789093),
 ('_APPLE_', 17.98672527074814),
 ('🏽', 3.0652744695544243),
 ('_EOT_', 0.11834370670840144),
 ('_USERID_', 0.01909219427034259),
 ('_EOT_', 0.017554494843352586),
 ('.', 0.012524795602075756),
 ('?', 0.00774120053392835),
 ('_APPLE_', 0.004826409349334426),
 ('?', 0.004819091554963961),
 ('Please', 0.0037497171433642507),
 ('.', 0.0017171305444207974),
 ('_EOT_', 0.0014407723938347772),
 ('afterwards', 0.0012076463463017717),
 ('keep', 0.000926772099774098),
 ('_EOT_', 0.000451875575890881),
 ('.', 0.00036163633012620267),
 ('_APPLE_', 0.00025873762297123903),
 ('ima', 0.00025843701223493554),
 ('posted', 0.00024589144231867976),
 ('iOS', 0.00016307300256812596),
 ('_EOT_', 0.00015981811429810477),
 ('so', 0.00015751962791910046),
 ('update', 0.0001569897790432151),
 ('back', 0.0001437989340047352),
 ('model', 0.00013501961575457244),
 ('update', 0.00013086736316836323),
 ('Is', 0.0001254759240509884),
 ('us', 0.00012099006880

In [None]:
# FIXME: `a_attn` is never defined in this notebook. `get_attn_weigths` only
# returns `q_attn` and `prob`, so an unconditional lookup raises NameError and
# stops "Restart & Run All" here. Align the answer only when answer-side attention
# has actually been computed.
if "a_attn" in globals():
    from pprint import pprint
    attn_align, sorted_align = align_attention(pair["answer"], a_attn[0][0])
    pprint(attn_align)
else:
    print("No answer-side attention available: get_attn_weigths only returns q_attn and prob.")

# Sequence-to-sequence
Let's start by loading the seq2seq model.

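**NOTE**:
Restart the notebook kernel before running this section. Both `../retrieval` and `../generative` contain a module named `utils`. Once the retrieval version has been imported, Python keeps returning it from its module cache, so the seq2seq helpers (`load_seq2seq`, `prepare_data`) would not be picked up.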

In [1]:
import sys
# Expose the generative package (seq2seq model and its own `utils`). The kernel must
# be restarted first; otherwise `utils` still resolves to the retrieval package's
# module imported earlier in the notebook.
sys.path.append('../generative')
from seq2seq import Encoder, Decoder, Attention
from utils import load_seq2seq, prepare_data
import pickle

In [2]:
# Load the vocabulary and the seq2seq test pairs (input and output id-sequences).
# Context managers close the files instead of leaking the handles.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('../data/twitter/tmp/word2ix.pkl', 'rb') as vocab_fh:
    word2ix = pickle.load(vocab_fh)
with open('../data/twitter/tmp/seq2seq_test.pkl', 'rb') as data_fh:
    in_seqs, out_seqs = pickle.load(data_fh)
# Inverse vocabulary: token id -> token string.
ix2word = {v: k for k, v in word2ix.items()}

In [3]:
# Serialized checkpoint of the trained seq2seq encoder/decoder.
model_path = "../data/twitter/models/trained.seq2seq.torch"

In [4]:
# Restore the trained encoder/decoder pair from the checkpoint.
encoder, decoder = load_seq2seq(model_path)
# Move both to GPU and switch to inference mode. The 4-tuple is only the cell's
# display output, which is why each module appears twice below.
(encoder.cuda(), decoder.cuda(), encoder.eval(), decoder.eval())

(Encoder(
   (embedding): Embedding(26470, 300)
   (rnn): LSTM(300, 300, num_layers=2, dropout=0.2, bidirectional=True)
 ), Decoder(
   (embedding): Embedding(26470, 300)
   (embedding_dropout): Dropout(p=0.2)
   (rnn): LSTM(300, 300, num_layers=2, dropout=0.2)
   (concat): Linear(in_features=600, out_features=300, bias=True)
   (out): Linear(in_features=300, out_features=26470, bias=True)
   (attn): Attention()
 ), Encoder(
   (embedding): Embedding(26470, 300)
   (rnn): LSTM(300, 300, num_layers=2, dropout=0.2, bidirectional=True)
 ), Decoder(
   (embedding): Embedding(26470, 300)
   (embedding_dropout): Dropout(p=0.2)
   (rnn): LSTM(300, 300, num_layers=2, dropout=0.2)
   (concat): Linear(in_features=600, out_features=300, bias=True)
   (out): Linear(in_features=300, out_features=26470, bias=True)
   (attn): Attention()
 ))

In [5]:
import torch
def decoding(encoder, decoder, bos_token, input_seq, input_length, targets, mask):
    """Run teacher-forced decoding and record the decoder's attention at each step.

    The decoder is fed the *reference* tokens from ``targets`` rather than its own
    predictions. The returned token tensor is therefore the target sequence itself
    (BOS row followed by ``targets[1:]``), which is what lets the attention matrix
    line up with the reference answer.

    Args:
        encoder: callable as ``encoder(input_seq, input_length)`` returning
            ``(outputs, hidden, cell)``.
        decoder: object exposing ``n_layers`` and
            ``attn_forward(input, hidden, cell, encoder_outputs)`` returning
            ``(output, hidden, cell, attn_weights)``.
        bos_token: vocabulary id used as the first decoder input.
        input_seq: source id tensor; ``input_seq.shape[1]`` is used as the batch size.
        input_length: passed through to the encoder.
        targets: target id tensor indexed by time step along dimension 0.
        mask: unused; kept for call compatibility.

    Returns:
        ``(pred_tokens, attention_weights)``. ``pred_tokens`` stacks the BOS row and
        each ``targets[t].view(1, -1)`` along dimension 0. ``attention_weights`` holds
        one list of floats per step (``targets.shape[0] - 1`` steps), taken from
        ``attn_weights[0][0]`` (presumably the first batch element's attention over
        the source; TODO confirm layout).
    """
    attention_weights = []
    with torch.no_grad():
        encoder_outputs, encoder_hidden, encoder_cell = encoder(input_seq, input_length)
        # One BOS token per sequence, created on the inputs' device instead of
        # whatever CUDA device happens to be current.
        batch_size = input_seq.shape[1]
        decoder_input = torch.full(
            (1, batch_size), bos_token, dtype=torch.long, device=input_seq.device
        )
        pred_tokens = decoder_input.clone()
        # Initialize the decoder with the encoder's final states, truncated to the
        # decoder's layer count.
        decoder_hidden = encoder_hidden[:decoder.n_layers]
        decoder_cell = encoder_cell[:decoder.n_layers]
        for t in range(1, targets.shape[0]):
            _, decoder_hidden, decoder_cell, attn_weights = decoder.attn_forward(
                decoder_input, decoder_hidden, decoder_cell, encoder_outputs
            )
            # Teacher forcing: the next input is the reference token, not the prediction.
            decoder_input = targets[t].view(1, -1)
            pred_tokens = torch.cat((pred_tokens, decoder_input.clone()), dim=0)
            attention_weights.append(attn_weights[0][0].cpu().tolist())
    return pred_tokens, attention_weights
            

In [6]:
def get_pair_text(q_vec, a_vec, ix, vocab=None):
    """Decode the ``ix``-th input/output pair back into tokens.

    Args:
        q_vec: indexable collection of input (question) id-sequences.
        a_vec: indexable collection of output (answer) id-sequences, aligned with ``q_vec``.
        ix: position of the pair to decode.
        vocab: optional id -> token mapping; defaults to the notebook's ``ix2word``.

    Returns:
        dict with keys ``"question"`` and ``"answer"`` holding the token lists.
    """
    # Same signature as the dual-encoder section's helper; an explicit vocab
    # makes it reusable and testable without the notebook globals.
    id2tok = ix2word if vocab is None else vocab
    return {
        "question": [id2tok[i] for i in q_vec[ix]],
        "answer": [id2tok[i] for i in a_vec[ix]],
    }

In [10]:
# Index of the test pair to inspect. Other candidates noted by the author:
# 300, 800, 1994, 400.
example_ix = 20

In [11]:
# Decode the selected test pair back into tokens.
pair = get_pair_text(in_seqs, out_seqs, example_ix)

In [12]:
# Show the tokenized input sequence (includes the _BOS_/_EOS_ markers).
print (pair["question"])

['_BOS_', '_APPLE_', 'help', '!', 'My', 'phone', 'keeps', 'rebooting', 'every', 'couple', 'of', 'minutes', '.', 'This', 'screen', '_URL_', 'can', 'I', 'stop', 'this', '?', '!', '_URL_', '_EOT_', '_EOS_']


In [13]:
# Show the tokenized output sequence.
print (pair["answer"])

['_BOS_', '_USERID_', 'We', 'know', 'the', 'importance', 'of', 'getting', 'that', 'fixed', 'quickly', '!', "You'll", 'want', 'to', 'get', 'your', 'device', 'updated', 'to', 'correct', 'that', 'issue', '.', "Here's", 'the', 'steps', ':', '_URL_', '_EOS_']


In [17]:
# Hand-written copy of the battery-life question from the dual-encoder section,
# wrapped in _BOS_/_EOS_ markers (presumably so both models are inspected on the
# same pair).
q = ['_BOS_', '_APPLE_', 'I', '’', 've', 'downloaded', 'the', 'new', 'iPhone', 'update', 'and', 'my', 'battery', 'lasts', 'half', 'as', 'long', '.', 'What', '’', 's', 'going', 'on', '?', '!', '?', '_EOT_', '_EOS_']
# Map tokens to vocabulary ids; an out-of-vocabulary token raises KeyError if
# word2ix is a plain dict.
q_in = [word2ix[word] for word in q]

In [20]:
# Reference answer for the question above, wrapped in _BOS_/_EOS_ markers.
a = ['_BOS_', '_USERID_', 'We', 'completely', 'understand', 'wanting', 'to', 'have', 'good', 'battery', 'life', '.', 'How', 'long', 'is', 'the', 'battery', 'lasting', 'on', 'a', 'full', 'charge', 'after', 'the', 'update', '?', '_EOS_']
# Map tokens to vocabulary ids.
a_in = [word2ix[word] for word in a]

Prepare input:

In [21]:
# Extract fields from batch
# Build a batch of one from the (question, answer) id lists; prepare_data returns
# the padded inputs, their lengths, the targets, a mask and the max target length.
inputs, lengths, targets, mask, max_target_len = prepare_data([q_in], [a_in])

In [22]:
# Teacher-forced pass over the reference answer, collecting one attention vector
# per decoding step (`attn_weigths` (sic) is the name used by later cells).
pred_tokens, attn_weigths = decoding(encoder, decoder, word2ix["_BOS_"], inputs, lengths, targets, mask)

In [23]:
def restore_words(pred_tokens):
    """Map each scalar id tensor in ``pred_tokens`` to its token and space-join them."""
    tokens = []
    for token_id in pred_tokens:
        tokens.append(ix2word[token_id.item()])
    return " ".join(tokens)

# Flatten the token tensor and decode it back to text.
restore_words(pred_tokens.view(1, -1)[0])

'_BOS_ _USERID_ We completely understand wanting to have good battery life . How long is the battery lasting on a full charge after the update ? _EOS_'

In [24]:
# Decode the source sequence the same way, for comparison with the target above.
restore_words(inputs.view(1, -1)[0])

'_BOS_ _APPLE_ I ’ ve downloaded the new iPhone update and my battery lasts half as long . What ’ s going on ? ! ? _EOT_ _EOS_'

In [25]:
import numpy as np
# Stack the per-step attention vectors into a (decoding steps, source length) matrix.
alignments = np.array(attn_weigths)

In [26]:
# Rows: decoding steps (target tokens after _BOS_); columns: source tokens.
alignments.shape

(26, 28)

In [27]:
import os
import matplotlib.pyplot as plt
import matplotlib

# Tick labels: source tokens on the x-axis; decoded target tokens on the y-axis,
# minus the leading _BOS_ (it is the first decoder input, not a decoding step).
source_words = restore_words(inputs.view(1, -1)[0]).split()
target_words = restore_words(pred_tokens.view(1, -1)[0]).split()[1:]

matrix = alignments
# Each attention row must correspond to one target token and each column to one
# source token; re-splitting joined text would misalign if a token contained spaces.
assert matrix.shape == (len(target_words), len(source_words)), (
    matrix.shape, len(target_words), len(source_words))

# Set to True to print each weight (as a percentage) inside its cell.
ANNOTATE_CELLS = False
PLOT_PATH = "plots/attention_align.png"

fig, ax = plt.subplots(figsize=(25, 15))
im = ax.imshow(matrix, cmap="Blues")
fig.colorbar(im, ax=ax, label="attention weight")

# Show every tick and label it with its token.
ax.set_xticks(np.arange(len(source_words)))
ax.set_yticks(np.arange(len(target_words)))
ax.set_xticklabels(source_words)
ax.set_yticklabels(target_words)
ax.set(xlabel="source (question) tokens",
       ylabel="target (answer) tokens",
       title="seq2seq decoder attention")

# Rotate the x tick labels so long tokens do not overlap.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

if ANNOTATE_CELLS:
    for i in range(len(target_words)):
        for j in range(len(source_words)):
            ax.text(j, i, int(matrix[i, j] * 100), ha="center", va="center", color="black")

fig.tight_layout()
# savefig raises if the output directory does not exist, so create it first.
os.makedirs(os.path.dirname(PLOT_PATH), exist_ok=True)
fig.savefig(PLOT_PATH)