# attentionの可視化
- ついでにエラー解析
- パラメータは学習済のものを使用

## Settings

In [1]:
import os

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import HTML

from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score

import spacy
from bs4 import BeautifulSoup

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, TensorDataset, DataLoader

In [2]:
print('Pytorch version: ', torch.__version__)
print('Currently selected device: ', torch.cuda.current_device())
print('# GPUs available: ', torch.cuda.device_count())
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
# device = torch.device('cpu') # デバッグ用
print(device)

Pytorch version:  1.1.0
Currently selected device:  0
# GPUs available:  3
cuda:2


## Def Network

In [3]:
class ATT(nn.Module):
    def __init__(self, hidden_dim):
        super(ATT, self).__init__()
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, 1)
    def forward(self, inputs):
        b_size = inputs.size(0)
        inputs = inputs.contiguous().view(-1, self.hidden_dim)
        att = self.fc(torch.tanh(inputs))
        return F.softmax(att.view(b_size, -1), dim=1).unsqueeze(2)
    
class LSTM(nn.Module):
    def __init__(self, batch_size, vocab_size, emb_dim, hidden_dim, dropout_rate=0.0, activate='tanh', bidirectional=False, device='cpu'):
        super(LSTM, self).__init__()
        
        self.vocab_size = vocab_size
        self.emb_dim    = emb_dim
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.activate   = activate
        
        self.emb  = nn.Embedding(self.vocab_size, self.emb_dim, padding_idx=0)
        self.lstm = nn.LSTM(self.emb_dim, self.hidden_dim, batch_first=True, bidirectional=self.bidirectional)
        self.att = ATT(hidden_dim * 2)
        
        self.fc0 = nn.Linear(hidden_dim * 2, 100)
        self.fc1 = nn.Linear(100, 2)
        self.do  = nn.Dropout(dropout_rate)
        self.device = device
        self.hidden = self.init_hidden()

    def forward(self, x):

        x = self.emb(x)
        lstm_out, self.hidden = self.lstm(x, self.hidden)
        
        att = self.att(lstm_out)
        feats = (lstm_out * att).sum(dim=1) # (b, s, h) -> (b, h)
        
        y = self.fc0(feats)
        y = self.do(y)
        if self.activate == 'tanh':
            y = self.fc1(torch.tanh(y))
        elif self.activate == 'relu':
            y = self.fc1(F.relu(y))
        tag_scores = F.log_softmax(y)
        return tag_scores, att

    def init_hidden(self):
        # The axes semantics are (num_layers, minibatch_size, hidden_dim)
        num = 2 if self.bidirectional else 1    # bidirectionalのとき2
        h0 = torch.zeros(num, self.batch_size, self.hidden_dim).to(self.device)
        c0 = torch.zeros(num, self.batch_size, self.hidden_dim).to(self.device)
        return (h0, c0)

In [4]:
def confusion2score(confusion):
    tp, fn, fp, tn = confusion.ravel()
    if tp == None:
        tp = 0
    if fn == None:
        fn = 0
    if fp == None:
        fp = 0
    if tn == None:
        tn = 0
    acc = (tp + tn) / (tp + fn + fp + tn)
    if (tp+fp) == 0:
        pre=0
    else:
        pre = tp / (tp + fp)
    if (tp+fn) == 0:
        rec=0
    else:
        rec = tp / (tp + fn)
    if (2*tp+fp+fn) == 0:
        f1=0
    else:
        f1  = (2 * tp) / (2*tp + fp + fn)
    return (acc, pre, rec, f1)

def training(net, train_loader, valid_loader, epoch_num):
    
    for epoch in range(epoch_num):

        train_loss = 0.0
        train_acc  = 0.0

        # train====================
        net.train()
        for xx, yy in train_loader:
            net.batch_size = len(yy)
            net.hidden = net.init_hidden()

            optimizer.zero_grad()    # 勾配の初期化

            output, att = net(xx)
            loss   = criterion(output, yy)

            train_loss += loss.item()
            train_acc += (output.max(1)[1] == yy).sum().item()

            loss.backward(retain_graph=True)     # 逆伝播の計算
            optimizer.step()    # 勾配の更新

def test(net, test_loader, y_test):
    net.eval()
    y_pred = y_pred = np.zeros((1,2))
    with torch.no_grad():
        for xx, yy in test_loader:
            net.batch_size = len(yy)
            net.hidden = net.init_hidden()

            output, att = net(xx)
            
            y_pred = np.concatenate([y_pred, output.to('cpu').numpy()], axis=0)

    confusion = confusion_matrix(y_test, np.argmax(y_pred[1:,], axis=1).tolist())
    scores = confusion2score(confusion)
    auc = roc_auc_score(y_test, np.argmax(y_pred[1:,], axis=1).tolist())
    return confusion, [scores[0], scores[1], scores[2], scores[3], auc]

## Prepare Data

In [5]:
spacy_en = spacy.load('en')
def tokenizer(text):
    soup = BeautifulSoup(text)
    clean_txt = soup.get_text()
    words = []
    for tok in spacy_en.tokenizer(clean_txt):
        if tok.text not in "[],.();:<>{}|*-~":
            words.append(tok.lemma_)
    return words

def df2indexseq(df, vocab_idx):
    data = []
    for text in df.values:
        words = tokenizer(text)
        data.append([vocab_idx[word]+1 for word in words if word in vocab_idx.keys()])
    return data

def padding(data):
    # npに変換し、0埋めを行う
    max_length = max([len(d) for d in data])
    padded_data = np.zeros((len(data), max_length))  # 0で埋める
    for i, d1 in enumerate(data):
        for j, d2 in enumerate(d1):
            padded_data[i][j] = d2
    return padded_data

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, data, tags):
        super(MyDataset, self).__init__()
        assert len(data) == len(tags)
        self.data = data
        self.tags = tags
        
    def __len__(self):
        return len(self.tags)
    
    def __getitem__(self, index):
        return self.data[index], self.tags[index]

In [6]:
train_df = pd.read_csv('/home/b2018yniki/data/semeval2010task8/train_original.tsv', sep='\t')
train_df = train_df.assign(causal_flag = [0 if 'Cause-Effect' in relation else 1 for relation in train_df.relation.values]).drop(['relation', 'comment'], axis=1)
train_df.body = [text.replace('"', '') for text in train_df.body.values]
test_df = pd.read_csv('/home/b2018yniki/data/semeval2010task8/test_original.tsv', sep='\t')
test_df = test_df.assign(causal_flag = [0 if 'Cause-Effect' in relation else 1 for relation in test_df.relation.values]).drop(['relation', 'comment'], axis=1)
test_df.body = [text.replace('"', '') for text in test_df.body.values]
    
vocab = []
for text in train_df.body.values:
    vocab.extend(tokenizer(text))
vocab = list(set(vocab))
print('vocabulaly size: {}'.format(len(vocab)))
vocab_idx = dict(zip(vocab, range(len(vocab))))
del vocab
   
X_train = torch.LongTensor(padding(df2indexseq(train_df.body, vocab_idx))).to(device)
y_train = torch.LongTensor(train_df.causal_flag.values).to(device)
X_test  = torch.LongTensor(padding(df2indexseq(test_df.body, vocab_idx))).to(device)
y_test  = torch.LongTensor(test_df.causal_flag.values).to(device)

train_ds = MyDataset(X_train, y_train)
test_ds  = MyDataset(X_test, y_test)
del X_train, X_test, y_train

vocabulaly size: 16619


### 精度確認

In [7]:
np.random.seed(2019)
np.random.RandomState(2019)
torch.manual_seed(2019)

# hyperparameter
epoch      = 300
batch_size = 64
vocab_size = len(vocab_idx)+1
emb_dim    = 200
hidden_dim = 100
activate   = 'relu'
drop_rate  = 0.0
lr = 0.1
l2 = 0.001

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_ds, batch_size=batch_size, shuffle=False)

np.random.seed(2019)
np.random.RandomState(2019)
torch.manual_seed(2019)

net = LSTM(batch_size, vocab_size, emb_dim, hidden_dim, drop_rate, activate, bidirectional=True, device=device).to(device)

criterion = nn.NLLLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=l2)

training(net, train_loader, test_loader, epoch)
conf, result = test(net, test_loader, y_test.to('cpu').numpy().tolist())
print(conf)
print('Accuracy: {}, Precision: {}, Recall: {}, F1: {}, AUC: {}'.format(result[0], result[1], result[2], result[3], result[4]))



[[ 270   58]
 [  47 2342]]
Accuracy: 0.9613544350386456, Precision: 0.8517350157728707, Recall: 0.823170731707317, F1: 0.8372093023255814, AUC: 0.9017486140746717


## func. to visualize

In [8]:
def highlight(word, attn):
    html_color = '#%02X%02X%02X' % (255, int(255*(1 - attn)), int(255*(1 - attn)))
    return '<span style="background-color: {}">{}</span>'.format(html_color, word)

def mk_html(sequence, attns):
    html = ""
    sentence = ""
    for w_id, attn in zip(sequence, attns):
        if w_id == 0:
            break
        word = get_keys_from_value(vocab_idx, w_id-1)
        sentence += ' ' + word
        html += ' ' + highlight(
            word,
            attn
        )
    return sentence, (html + "<br><br>\n")

In [9]:
def get_keys_from_value(d, val):
    return [k for k, v in d.items() if v == val][0]

In [10]:
output_df = pd.DataFrame()

net.eval()
with torch.no_grad():
    for xx, yy in test_loader:
        net.batch_size = len(yy)
        net.hidden = net.init_hidden()
        output, att = net(xx)

        for i in range(len(yy)):
            sentence, html = mk_html(xx[i].to('cpu').numpy().tolist(), att[i,:,0].to('cpu').numpy().tolist())
            output_df = pd.concat([
                output_df, 
                pd.DataFrame([yy[i].to('cpu').numpy(), np.argmax(output[i].to('cpu').numpy()), sentence, html])
            ], axis=1)



In [12]:
output_df.index = ['y_true', 'y_pred', 'sentence', 'html']
output_df = output_df.T.reset_index(drop=True)
output_df.head()

Unnamed: 0,y_true,y_pred,sentence,html
0,1,1,The much common audit be about waste and recycle,"<span style=""background-color: #FFBCBC"">The</..."
1,1,1,The company fabricate plastic chair,"<span style=""background-color: #FFECEC"">The</..."
2,1,1,The school master teach the lesson with a stick,"<span style=""background-color: #FFDFDF"">The</..."
3,1,1,The suspect dump the dead body into a local r...,"<span style=""background-color: #FFE2E2"">The</..."
4,0,0,influenza be a infectious disease of bird cau...,"<span style=""background-color: #FFDBDB"">influ..."


In [13]:
# 念のため確認
confusion_matrix(output_df.y_true.values.tolist(), output_df.y_pred.values.tolist())

array([[ 270,   58],
       [  47, 2342]])

In [14]:
HTML(output_df.head(1)["html"].values[0])

### 因果ありを正解

In [44]:
tp_df = output_df[output_df.y_true == '0']
tp_df = tp_df[tp_df.y_pred == '0']
tp_df.reset_index(drop=True, inplace=True)
print(tp_df.shape)
tp_df.head()

(270, 4)


Unnamed: 0,y_true,y_pred,sentence,html
0,0,0,influenza be a infectious disease of bird cau...,"<span style=""background-color: #FFDBDB"">influ..."
1,0,0,Of the hundred of strain of avian influenza A...,"<span style=""background-color: #FFFBFB"">Of</s..."
2,0,0,In South Africa which have one of the well po...,"<span style=""background-color: #FFFAFA"">In</s..."
3,0,0,Traffic vibration on the street outside have ...,"<span style=""background-color: #FFDFDF"">Traff..."
4,0,0,The slide which be trigger by a avalanche con...,"<span style=""background-color: #FF9797"">The</..."


In [45]:
print(tp_df[tp_df.index == 0].sentence.values[0])
HTML(tp_df[tp_df.index == 0].html.values[0])

 influenza be a infectious disease of bird cause by type A strain of the influenza virus


In [46]:
print(tp_df[tp_df.index == 1].sentence.values[0])
HTML(tp_df[tp_df.index == 1].html.values[0])

 Of the hundred of strain of avian influenza A virus only four have cause human infection H5N1 and


In [47]:
print(tp_df[tp_df.index == 2].sentence.values[0])
HTML(tp_df[tp_df.index == 2].html.values[0])

 In South Africa which have one of the well police to public ratio on the continent the share of murder that result in a be about 18 % compare to % in the US and % in the UK


### 因果なしを因果ありと予測した例を可視化

In [19]:
output_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2717 entries, 0 to 2716
Data columns (total 4 columns):
y_true      2717 non-null object
y_pred      2717 non-null object
sentence    2717 non-null object
html        2717 non-null object
dtypes: object(4)
memory usage: 85.0+ KB


In [22]:
fp_df = output_df[output_df.y_true == '1']
fp_df = fp_df[fp_df.y_pred == '0']
fp_df.reset_index(drop=True, inplace=True)
fp_df.head()

Unnamed: 0,y_true,y_pred,sentence,html
0,1,0,oil have a anti bacterial and anti effect,"<span style=""background-color: #FFADAD"">oil</..."
1,1,0,the blister that appear in the mouth be cause...,"<span style=""background-color: #FFF2F2"">the</..."
2,1,0,or be cause by a intestinal parasite call,"<span style=""background-color: #FF5353"">or</s..."
3,1,0,The monitor station receive the signal throug...,"<span style=""background-color: #FFEFEF"">The</..."
4,1,0,-PRON- have the lead and have a sound that ca...,"<span style=""background-color: #FF9696"">-PRON..."


In [23]:
print(fp_df[fp_df.index == 0].sentence.values[0])
HTML(fp_df[fp_df.index == 0].html.values[0])

 oil have a anti bacterial and anti effect


In [24]:
print(fp_df[fp_df.index == 1].sentence.values[0])
HTML(fp_df[fp_df.index == 1].html.values[0])

 the blister that appear in the mouth be cause by the herpes simplex virus type 1 for short


In [25]:
print(fp_df[fp_df.index == 2].sentence.values[0])
HTML(fp_df[fp_df.index == 2].html.values[0])

 or be cause by a intestinal parasite call


In [26]:
print(fp_df[fp_df.index == 3].sentence.values[0])
HTML(fp_df[fp_df.index == 3].html.values[0])

 The monitor station receive the signal through a communication device and the combine signal be process to retrieve GPS datum


In [27]:
print(fp_df[fp_df.index == 4].sentence.values[0])
HTML(fp_df[fp_df.index == 4].html.values[0])

 -PRON- have the lead and have a sound that capture the with a combination of a British / Scandinavian accent that have not lose in the mix


In [28]:
print(fp_df[fp_df.index == 5].sentence.values[0])
HTML(fp_df[fp_df.index == 5].html.values[0])

 on the cause behind the below capacity production of ethanol by the sugar factory in the state figure in the Council on Tuesday


In [29]:
print(fp_df[fp_df.index == 6].sentence.values[0])
HTML(fp_df[fp_df.index == 6].html.values[0])

 The time machine itself be a plain gray box with a distinctive electronic hum create by the sound of a mechanical and a car engine rather than a process digital effect


In [30]:
print(fp_df[fp_df.index == 7].sentence.values[0])
HTML(fp_df[fp_df.index == 7].html.values[0])

 Although he initially return to his wife and child when they return from evacuation after the war now see his wife a and


In [32]:
print(fp_df[fp_df.index == 8].sentence.values[0])
HTML(fp_df[fp_df.index == 8].html.values[0])

 a adjective the word be derive from a noun mean force power or cause


In [33]:
print(fp_df[fp_df.index == 9].sentence.values[0])
HTML(fp_df[fp_df.index == 9].html.values[0])

 The produce signal depart from a ideal a the input angle approach the extreme of the range


### 因果ありを因果なしと予測した例

In [31]:
fn_df = output_df[output_df.y_true == '0']
fn_df = fn_df[fn_df.y_pred == '1']
fn_df.reset_index(drop=True, inplace=True)
print(fn_df.shape)
fn_df.head()

(58, 4)


Unnamed: 0,y_true,y_pred,sentence,html
0,0,1,The same effect be achieve the traditional wa...,"<span style=""background-color: #FFBBBB"">The</..."
1,0,1,The subject of be the source of a implication...,"<span style=""background-color: #FFDEDE"">The</..."
2,0,1,The treaty establish a double majority rule f...,"<span style=""background-color: #FFBFBF"">The</..."
3,0,1,The receiver be output the same tone to my 5,"<span style=""background-color: #FFCBCB"">The</..."
4,0,1,These type of script help visually impair ind...,"<span style=""background-color: #FFE2E2"">These..."


In [34]:
print(fn_df[fn_df.index == 0].sentence.values[0])
HTML(fn_df[fn_df.index == 0].html.values[0])

 The same effect be achieve the traditional way with a team of worker like


In [35]:
print(fn_df[fn_df.index == 1].sentence.values[0])
HTML(fn_df[fn_df.index == 1].html.values[0])

 The subject of be the source of a implication while the subject of be the recipient of a implication


In [36]:
print(fn_df[fn_df.index == 2].sentence.values[0])
HTML(fn_df[fn_df.index == 2].html.values[0])

 The treaty establish a double majority rule for Council decision


In [37]:
print(fn_df[fn_df.index == 3].sentence.values[0])
HTML(fn_df[fn_df.index == 3].html.values[0])

 The receiver be output the same tone to my 5


In [38]:
print(fn_df[fn_df.index == 4].sentence.values[0])
HTML(fn_df[fn_df.index == 4].html.values[0])

 These type of script help visually impair individual to get much enjoyment from programme because the action on screen be describe for them


In [39]:
print(fn_df[fn_df.index == 5].sentence.values[0])
HTML(fn_df[fn_df.index == 5].html.values[0])

 I will not tell you how much music mean now in this time when you be battle depression from unemployment


In [40]:
print(fn_df[fn_df.index == 6].sentence.values[0])
HTML(fn_df[fn_df.index == 6].html.values[0])

 The dust from the set off the alarm


In [41]:
print(fn_df[fn_df.index == 7].sentence.values[0])
HTML(fn_df[fn_df.index == 7].html.values[0])

 The drug have a low half life that make it ideal for kill of the susceptible bacterium responsible for sleep sickness


In [42]:
print(fn_df[fn_df.index == 8].sentence.values[0])
HTML(fn_df[fn_df.index == 8].html.values[0])

 I still shiver a I remember try to page through economics text by the flicker from candle while clothe in and little knit glove with the fingertip cut off in the 4 p.m.


In [43]:
print(fn_df[fn_df.index == 9].sentence.values[0])
HTML(fn_df[fn_df.index == 9].html.values[0])

 In the last couple of year I have be work with my fear from darkness


### 因果無しを正解

In [49]:
tn_df = output_df[output_df.y_true == '1']
tn_df = tn_df[tn_df.y_pred == '1']
tn_df.reset_index(drop=True, inplace=True)
print(tn_df.shape)
tn_df.head()

(2342, 4)


Unnamed: 0,y_true,y_pred,sentence,html
0,1,1,The much common audit be about waste and recycle,"<span style=""background-color: #FFBCBC"">The</..."
1,1,1,The company fabricate plastic chair,"<span style=""background-color: #FFECEC"">The</..."
2,1,1,The school master teach the lesson with a stick,"<span style=""background-color: #FFDFDF"">The</..."
3,1,1,The suspect dump the dead body into a local r...,"<span style=""background-color: #FFE2E2"">The</..."
4,1,1,The ear of the African elephant be significan...,"<span style=""background-color: #FFD6D6"">The</..."


In [50]:
print(tn_df[tn_df.index == 0].sentence.values[0])
HTML(tn_df[tn_df.index == 0].html.values[0])

 The much common audit be about waste and recycle
