# Visualizing results from BERT-tiny

In [1]:
import torch
from transformers import *
import pandas as pd
# Module-level tokenizer shared by visualize(), which uses it to turn the
# saved token ids back into readable word pieces.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def visualize(topk, sentence_num, batch_list, mode='avg', block_size=128):
    """Print, per batch, the token pairs with the strongest pooled Hessian score.

    For every batch id in ``batch_list`` this loads three tensors saved under
    ``tensor/`` (token ids, Hessian and cause mask for ``sentence_num``),
    pools the absolute Hessian into one score per ``block_size`` x
    ``block_size`` block, and for each block row reports the column with the
    highest score once the diagonal has been penalised.

    Args:
        topk: Number of highest-scoring rows printed for each batch.
        sentence_num: Sentence index used in the tensor file names.
        batch_list: Iterable of batch ids used in the tensor file names.
        mode: ``'avg'`` pools each block with its mean; any other value pools
            each block with its maximum.
        block_size: Side length of the square pooling window. Presumably the
            per-token embedding width of the model that produced the
            Hessian -- TODO confirm against the code that saved the tensors.

    Returns:
        None. Everything is printed to stdout.
    """
    # Use the GPU when there is one, otherwise fall back to the CPU so the
    # notebook also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    sent = str(sentence_num)

    for k in batch_list:
        batch = str(k)

        # NOTE(review): torch.load unpickles its input; only point it at
        # trusted files. map_location keeps all three tensors on one device
        # regardless of where they were saved.
        token_ids = torch.load(
            'tensor/batch_{}_x_sent_{}.pt'.format(batch, sent), map_location=device
        ).squeeze(0)
        hess = torch.load(
            'tensor/batch_{}_hess_{}.pt'.format(batch, sent), map_location=device
        ).squeeze(0)
        cause_mask = torch.load(
            'tensor/batch_{}_cause_mask_{}.pt'.format(batch, sent), map_location=device
        )

        ori_sentence = tokenizer.convert_ids_to_tokens(token_ids.tolist())
        print(ori_sentence)

        mask_row = cause_mask[0].tolist()
        print(
            '\n cause mask: ',
            [tok for i, tok in enumerate(ori_sentence) if mask_row[i] != 0],
            '\n',
        )

        # Collapse every block_size x block_size block of |H| to one score.
        # Only the pooling requested by `mode` is computed.
        abs_hess = torch.abs(hess)
        if mode == 'avg':
            pool = torch.nn.functional.avg_pool2d
        else:
            pool = torch.nn.functional.max_pool2d
        pooled = pool(abs_hess[None, None], block_size)[0, 0]

        # Penalise the diagonal by 10 so a block's own position does not win
        # the row-wise max. The identity is sized from the pooled matrix
        # instead of assuming a fixed number of tokens.
        # NOTE(review): a diagonal score above 10 would still stay positive
        # after the penalty -- confirm 10 is large enough for this data.
        n_rows, n_cols = pooled.shape
        penalty = torch.eye(n_rows, n_cols, device=pooled.device, dtype=pooled.dtype) * 10
        values, indices = (pooled - penalty).max(1)

        def token_at(position):
            # Positions past the end of the sentence stay as plain integers.
            if position < len(ori_sentence):
                return ori_sentence[position]
            return position

        # Rows whose best score is exactly 0 are skipped.
        # Tokens are substituted here, per column, so a score that happens to
        # equal a token index is never mistaken for one.
        pairs = [
            [token_at(i), token_at(j), score]
            for i, (j, score) in enumerate(zip(indices.tolist(), values.tolist()))
            if score != 0
        ]

        # Passing the column names up front also works when `pairs` is empty.
        df = pd.DataFrame(pairs, columns=['word', 'most relavent', 'score'])
        df = df.sort_values('score', ascending=False)
        print(df.head(topk))

        print("-" * 30 + '\n')

# Params for visualization

In [7]:
# Number of highest-scoring rows printed for each batch.
topk = 10
# Sentence index used to build the tensor file names.
sentence_num = 5 # select one from 0 to 15
# Batch ids whose saved tensors are loaded and reported, one table per id.
batch_list = [0, 62, 124, 186, 248, 310, 375, 437, 499, 561, 623, 688, 750, 812, 874, 936, 1001, 1063, 1125, 1187, 1249]

## Average

In [8]:
# Rank token pairs using the average-pooled absolute Hessian blocks.
visualize(topk, sentence_num, batch_list, mode='avg')

['[CLS]', 'two', 'young', 'children', 'in', 'blue', 'jerseys', ',', 'one', 'with', 'the', 'number', '9', 'and', 'one', 'with', 'the', 'number', '2', 'are', 'standing', 'on', 'wooden', 'steps', 'in', 'a', 'bathroom', 'and', 'washing', 'their', '[SEP]', 'two', 'kids', 'in', 'jackets', 'walk', 'to', 'school', '.', '[SEP]']

 cause mask:  ['blue', 'in', 'a', 'jackets', 'walk', 'to', 'school'] 

        word most relavent     score
39     [SEP]         [SEP]  0.025032
30     [SEP]         [SEP]  0.025031
0      [CLS]         [SEP]  0.017122
38         .         [SEP]  0.008138
32      kids         [SEP]  0.008067
35      walk         [SEP]  0.007766
37    school         [SEP]  0.007385
29     their         [SEP]  0.007075
19       are         [SEP]  0.006904
26  bathroom         [SEP]  0.006482
------------------------------

['[CLS]', 'two', 'young', 'children', 'in', 'blue', 'jerseys', ',', 'one', 'with', 'the', 'number', '9', 'and', 'one', 'with', 'the', 'number', '2', 'are', 'standing',

['[CLS]', 'two', 'young', 'children', 'in', 'blue', 'jerseys', ',', 'one', 'with', 'the', 'number', '9', 'and', 'one', 'with', 'the', 'number', '2', 'are', 'standing', 'on', 'wooden', 'steps', 'in', 'a', 'bathroom', 'and', 'washing', 'their', '[SEP]', 'two', 'kids', 'in', 'jackets', 'walk', 'to', 'school', '.', '[SEP]']

 cause mask:  ['blue', 'in', 'a', 'jackets', 'walk', 'to', 'school'] 

        word most relavent     score
36        to       washing  0.124716
28   washing            to  0.124716
34   jackets            to  0.082201
32      kids       washing  0.063362
26  bathroom       washing  0.061207
35      walk       washing  0.050184
29     their            to  0.042939
37    school            to  0.042271
25         a       washing  0.027991
31       two       washing  0.025272
------------------------------

['[CLS]', 'two', 'young', 'children', 'in', 'blue', 'jerseys', ',', 'one', 'with', 'the', 'number', '9', 'and', 'one', 'with', 'the', 'number', '2', 'are', 'standing',

## Max

In [6]:
# Rank token pairs using the max-pooled absolute Hessian blocks.
# NOTE(review): this cell's execution count (In [6]) is lower than the
# parameter cell's (In [7]), and the captured output below shows a different
# sentence than the 'avg' run -- it was likely produced with an earlier
# sentence_num; re-run to refresh.
visualize(topk, sentence_num, batch_list, mode='max')

['[CLS]', 'a', 'man', 'selling', 'don', '##uts', 'to', 'a', 'customer', 'during', 'a', 'world', 'exhibition', 'event', 'held', 'in', 'the', 'city', 'of', 'angeles', '[SEP]', 'a', 'man', 'selling', 'don', '##uts', 'to', 'a', 'customer', 'during', 'a', 'world', 'exhibition', 'event', 'while', 'people', 'wait', 'in', 'line', '[SEP]']

 cause mask:  ['people', 'wait', 'in', 'line'] 

          word most relavent     score
20       [SEP]         [SEP]  0.522676
39       [SEP]         [SEP]  0.522676
13       event         event  0.458439
33       event         event  0.458439
0        [CLS]         event  0.412748
28    customer         [CLS]  0.303048
23     selling         [CLS]  0.299778
25       ##uts         ##uts  0.250136
5        ##uts         ##uts  0.250136
32  exhibition         event  0.247122
------------------------------

['[CLS]', 'a', 'man', 'selling', 'don', '##uts', 'to', 'a', 'customer', 'during', 'a', 'world', 'exhibition', 'event', 'held', 'in', 'the', 'city', 'of', 'a

['[CLS]', 'a', 'man', 'selling', 'don', '##uts', 'to', 'a', 'customer', 'during', 'a', 'world', 'exhibition', 'event', 'held', 'in', 'the', 'city', 'of', 'angeles', '[SEP]', 'a', 'man', 'selling', 'don', '##uts', 'to', 'a', 'customer', 'during', 'a', 'world', 'exhibition', 'event', 'while', 'people', 'wait', 'in', 'line', '[SEP]']

 cause mask:  ['people', 'wait', 'in', 'line'] 

          word most relavent     score
35      people          wait  2.892927
36        wait        people  2.892927
12  exhibition    exhibition  2.713221
32  exhibition    exhibition  2.713221
38        line        people  1.384827
8     customer          wait  1.331163
2          man        people  1.041624
31       world          wait  0.857795
26          to          wait  0.576123
33       event        people  0.537918
------------------------------

['[CLS]', 'a', 'man', 'selling', 'don', '##uts', 'to', 'a', 'customer', 'during', 'a', 'world', 'exhibition', 'event', 'held', 'in', 'the', 'city', 'of', 'a