# Visualizing results from BERT-tiny

In [50]:
import torch
from transformers import *
import pandas as pd
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def visualize(topk, sentence_num, batch_list, mode='avg'):
    
    for k in batch_list:
        
        batch = str(k)
        sent = str(sentence_num)
        batch_0_x_sent_0 = torch.load('tensor/batch_{}_x_sent_{}.pt'.format(batch, sent)).squeeze(0)
        batch_0_hess_0 = torch.load('tensor/batch_{}_hess_{}.pt'.format(batch, sent)).squeeze(0)
        cause_mask = torch.load('tensor/batch_{}_cause_mask_{}.pt'.format(batch, sent))
        # print(batch_0_x_sent_0)

        ori_sentence = tokenizer.convert_ids_to_tokens(batch_0_x_sent_0)
        print(tokenizer.convert_ids_to_tokens(batch_0_x_sent_0))
        print('\n cause mask: ', [ori_sentence[i] for i in range(len(ori_sentence)) if cause_mask[0][i] != 0], '\n')
        
        temp = dict([i, ori_sentence[i]] for i in range(len(ori_sentence)))

        # print(torch.abs(batch_0_hess_0))
        abs_hess = torch.abs(batch_0_hess_0)
        # no_diag_hess = abs_hess - abs_hess.max() * torch.eye(abs_hess.shape[0], device='cuda:0')

        # abs_hess.split(128)
        # conv = torch.nn.Conv2d(1, 1, (128, 128), padding=0, stride=128).to('cuda:0')(abs_hess.unsqueeze(0).unsqueeze(0))
        avg_hess = torch.nn.AvgPool2d(128, stride=128).to('cuda:0')(abs_hess.unsqueeze(0))[0] - torch.eye(40, device='cuda:0') * 10
        max_hess = torch.nn.MaxPool2d(128, stride=128).to('cuda:0')(abs_hess.unsqueeze(0))[0] - torch.eye(40, device='cuda:0') * 10
        if mode == 'avg':
            values, indices = torch.topk(avg_hess.view(-1), 2*topk)
        else:
            values, indices = torch.topk(max_hess.view(-1), 2*topk)

        row = indices // 40
        col = indices % 40
        # pairs = [[i, int(indices[i]), float(values[i])] for i in range(indices.shape[0]) if float(values[i]) != 0]
        pairs = [[int(row[i]), int(col[i]), float(values[i])] for i in range(indices.shape[0]) if float(values[i]) != 0]
        df = pd.DataFrame(pairs)
        df.columns = ['word', 'most relavent', 'score']
        df = df.replace(temp)
        df.sort_values('score', ascending=False, inplace=True)
        print(df[df.index%2==0].head(topk))

        print("-" * 30 + '\n')

# Params for visualization

In [51]:
topk = 10
sentence_num = 9 # select one from 0 to 15
# batch_list = [0, 62, 124, 186, 248, 310, 375, 437, 499, 561, 623, 688, 750, 812, 874, 936, 1001, 1063, 1125, 1187, 1249]
batch_list = [0, 310, 623, 936, 1249]

## Average

In [52]:
visualize(topk, sentence_num, batch_list, mode='avg')

['[CLS]', 'two', 'young', 'boys', 'of', 'opposing', 'teams', 'play', 'football', ',', 'while', 'wearing', 'full', 'protection', 'uniforms', 'and', 'helmets', '.', '[SEP]', 'boys', 'play', 'football', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

 cause mask:  ['boys', 'opposing', 'teams', 'play', 'boys', 'play', 'football'] 

        word most relavent     score
0      [SEP]         [SEP]  0.044089
2      [CLS]         [SEP]  0.025706
4      [CLS]         [SEP]  0.024833
6      [SEP]             .  0.014351
8   football         [SEP]  0.014056
10         .         [SEP]  0.013824
12     [SEP]      football  0.011118
14     [SEP]      football  0.010987
16  football         [SEP]  0.010924
18     [SEP]       helmets  0.010910
------------------------------

['[CLS]', 'two', 'young', 'boys', 'of', 'opposing', 'teams', 'play', 'football', ',', 'while', 'wearing', 'full', 'p

## Max

In [4]:
visualize(topk, sentence_num, batch_list, mode='max')

['[CLS]', 'two', 'young', 'boys', 'of', 'opposing', 'teams', 'play', 'football', ',', 'while', 'wearing', 'full', 'protection', 'uniforms', 'and', 'helmets', '.', '[SEP]', 'dog', 'eats', 'out', 'of', 'bowl', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

 cause mask:  ['boys', 'dog'] 

        word most relavent     score
24     [SEP]         [SEP]  0.604129
18     [SEP]         [SEP]  0.604129
0      [CLS]         [SEP]  0.594729
17         .         [CLS]  0.275844
8   football         [SEP]  0.233400
19       dog         [SEP]  0.167147
16   helmets         [SEP]  0.165270
20      eats         [CLS]  0.146375
23      bowl         [SEP]  0.145649
6      teams         [SEP]  0.135518
------------------------------

['[CLS]', 'two', 'young', 'boys', 'of', 'opposing', 'teams', 'play', 'football', ',', 'while', 'wearing', 'full', 'protection', 'uniforms', 'and', 'helmets', '.', '[SEP]', 'dog

['[CLS]', 'two', 'young', 'boys', 'of', 'opposing', 'teams', 'play', 'football', ',', 'while', 'wearing', 'full', 'protection', 'uniforms', 'and', 'helmets', '.', '[SEP]', 'dog', 'eats', 'out', 'of', 'bowl', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]']

 cause mask:  ['boys', 'dog'] 

          word most relavent     score
20        eats           dog  4.188957
19         dog          eats  4.188957
14    uniforms          eats  2.664770
12        full          eats  1.026166
23        bowl          eats  1.023168
16     helmets          eats  0.935852
13  protection          eats  0.837115
21         out          eats  0.782858
3         boys          eats  0.638935
0        [CLS]          eats  0.618322
------------------------------

['[CLS]', 'two', 'young', 'boys', 'of', 'opposing', 'teams', 'play', 'football', ',', 'while', 'wearing', 'full', 'protection', 'uniforms', 'and', 'helmet