In [1]:
from collections import defaultdict
import numpy as np

from bertviz.attention_details import AttentionDetailsData, show
from bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
bert_version = '/Users/bayartsogtyadamsuren/Downloads/bert-japanese-files/bert-wiki-ja'
model = BertModel.from_pretrained(bert_version)
tokenizer = BertTokenizer.from_pretrained(bert_version)

model_file /Users/bayartsogtyadamsuren/Downloads/bert-japanese-files/bert-wiki-ja/wiki-ja.model
!!!the path to the vocab file is in SENTENCE PIECE!!! /Users/bayartsogtyadamsuren/Downloads/bert-japanese-files/bert-wiki-ja/wiki-ja.vocab
Loaded a trained SentencePiece model.


In [3]:
def _get_attention_details(tokens_a, tokens_b, query_vectors, key_vectors, atts):
    key_vectors_dict = defaultdict(list)
    query_vectors_dict = defaultdict(list)
    atts_dict = defaultdict(list)

    slice_a = slice(0, len(tokens_a))  # Positions corresponding to sentence A in input
    slice_b = slice(len(tokens_a), len(tokens_a) + len(tokens_b))  # Position corresponding to sentence B in input
    
    num_layers = len(query_vectors)
    for layer in range(num_layers):
        # Process queries and keys
        query_vector = query_vectors[layer][0] # assume batch_size=1; shape = [num_heads, seq_len, vector_size]
        key_vector = key_vectors[layer][0] # assume batch_size=1; shape = [num_heads, seq_len, vector_size]
        query_vectors_dict['all'].append(query_vector.tolist())
        key_vectors_dict['all'].append(key_vector.tolist())
        query_vectors_dict['a'].append(query_vector[:, slice_a, :].tolist())
        key_vectors_dict['a'].append(key_vector[:, slice_a, :].tolist())
        query_vectors_dict['b'].append(query_vector[:, slice_b, :].tolist())
        key_vectors_dict['b'].append(key_vector[:, slice_b, :].tolist())
        # Process attention
        att = atts[layer][0] # assume batch_size=1; shape = [num_heads, source_seq_len, target_seq_len]
        atts_dict['all'].append(att.tolist())
        atts_dict['aa'].append(att[:, slice_a, slice_a].tolist()) # Append A->A attention for layer, across all heads
        atts_dict['bb'].append(att[:, slice_b, slice_b].tolist()) # Append B->B attention for layer, across all heads
        atts_dict['ab'].append(att[:, slice_a, slice_b].tolist()) # Append A->B attention for layer, across all heads
        atts_dict['ba'].append(att[:, slice_b, slice_a].tolist()) # Append B->A attention for layer, across all heads

    attentions =  {
        'all': {
            'queries': query_vectors_dict['all'],
            'keys': key_vectors_dict['all'],
            'att': atts_dict['all'],
            'left_text': tokens_a + tokens_b,
            'right_text': tokens_a + tokens_b
        },
        'aa': {
            'queries': query_vectors_dict['a'],
            'keys': key_vectors_dict['a'],
            'att': atts_dict['aa'],
            'left_text': tokens_a,
            'right_text': tokens_a
        },
        'bb': {
            'queries': query_vectors_dict['b'],
            'keys': key_vectors_dict['b'],
            'att': atts_dict['bb'],
            'left_text': tokens_b,
            'right_text': tokens_b
        },
        'ab': {
            'queries': query_vectors_dict['a'],
            'keys': key_vectors_dict['b'],
            'att': atts_dict['ab'],
            'left_text': tokens_a,
            'right_text': tokens_b
        },
        'ba': {
            'queries': query_vectors_dict['b'],
            'keys': key_vectors_dict['a'],
            'att': atts_dict['ba'],
            'left_text': tokens_b,
            'right_text': tokens_a
        }
    }
    
    return attentions

In [4]:
def showComputation(config):
    """Return the raw query·key dot products for one query position.

    Compares one query vector against every key vector of the same layer and
    head within a single attention view. No scaling or softmax is applied here.

    Args:
        config: dict with keys
            "attention": mapping of view name -> view dict with "queries" and
                "keys" nested lists (as built by `_get_attention_details`).
            "att_type": view name to read, e.g. "ba".
            "layer", "att_head", "query_index": integer indices selecting the
                layer, head and query position.
            "vector_size" (optional): number of leading vector components to
                multiply; defaults to the full length of the query vector.

    Returns:
        list with one dot product per key vector, in key order.

    Raises:
        IndexError: if "vector_size" exceeds the length of a query/key vector.
    """
    att_dets = config["attention"][config["att_type"]]
    layer = config["layer"]
    head = config["att_head"]
    query_vector = att_dets["queries"][layer][head][config["query_index"]]
    keys = att_dets["keys"][layer][head]
    vector_size = config.get("vector_size", len(query_vector))

    # sum() accumulates left to right from 0, matching the original manual loop.
    return [
        sum(query_vector[j] * key_vector[j] for j in range(vector_size))
        for key_vector in keys
    ]

In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm

In [6]:
# Input samples; the printed head shows columns article_id, title and text.
f = pd.read_csv(
    "/Users/bayartsogtyadamsuren/Downloads/bert-japanese-files/bertviz_samples/bertviz_input_chosen_jp.csv"
)
print(f.head())
print(f.shape[0])  # row count

                    article_id     title  \
0  schIBJP010822173500_art0001  かわいいNEWS   
1  schIBJP010822173500_art0001  かわいいNEWS   
2  schIBJP010822173500_art0001  かわいいNEWS   
3  schIBJP010822173500_art0001  かわいいNEWS   
4  schIBJP010822173500_art0001  かわいいNEWS   

                                                text  
0  ごく普通のスポーツバッグかと思ったら保冷バッグと知り、デザイン性の高さに驚きました。ジムバッ...  
1  秋に向けて大活躍しそうなDEAN &amp; DELUCAのスープポットが9月1日に新発売！...  
2  マスキングテープといえば、我が家のラインナップは「用途を選ばず便利」という理由で無地のものば...  
3  週末は、冷えたビールにアツアツの餃子が定番の我が家。そんな餃子大好きな私が見つけてしまったの...  
4  かさもアップ。常にテーブルにお花がある生活って素敵ですよね。でもズボラな私はなかなか毎日飾る...  
154


In [9]:
# For every row, treat the title as sentence B (queries) and the text as
# sentence A (keys). Take the raw query·key dot products of one attention head
# in the "ba" view and write one TSV line per (title token, text token) pair.
# Per row, `q_x_k_scores` collects the dot products summed over all title tokens.
MAX_CHARS = 512        # character cap applied to each sentence before tokenization
SCORE_LAYER = 9        # layer whose queries/keys are compared
SCORE_HEAD = 6         # head within that layer
HEAD_VECTOR_SIZE = 64  # per-head query/key size; assumed 64 for this model -- TODO confirm
TOKEN2TOKEN_PATH = "/Users/bayartsogtyadamsuren/Downloads/bert-japanese-files/bertviz_samples/bertviz_input_chosen_jp_token2token.tsv"


def _strip_quote_marks(text, newline_replacement):
    """Replace newlines, drop the 〝〞「」 quote characters, and trim whitespace."""
    text = text.replace("\n", newline_replacement)
    for mark in ("〝", "〞", "「", "」"):
        text = text.replace(mark, "")
    return text.strip()


q_x_k_scores = []
para_tokens = []
too_long = 0

# Created once instead of per row. NOTE(review): assumes get_data() keeps no
# per-call state on the wrapper.
details_data = AttentionDetailsData(model, tokenizer)

# `with` guarantees the file is closed even if a row raises; UTF-8 is explicit
# because the tokens are Japanese.
with open(TOKEN2TOKEN_PATH, "w", encoding="utf-8") as ff:
    ff.write("id\ttitle token\ttext token\tscore\n")

    for _, x in tqdm(f.iterrows(), total=len(f)):
        sentence_a = _strip_quote_marks(str(x["text"]), "。")
        sentence_b = _strip_quote_marks(x["title"], "")

        # Fixed: the original checked sentence_a twice, so over-long titles went uncounted.
        # NOTE(review): character truncation does not bound the token count of the pair.
        if len(sentence_a) > MAX_CHARS or len(sentence_b) > MAX_CHARS:
            too_long += 1
            sentence_a = sentence_a[:MAX_CHARS]
            sentence_b = sentence_b[:MAX_CHARS]

        tokens_a, tokens_b, queries, keys, atts = details_data.get_data(sentence_a, sentence_b)
        attentions = _get_attention_details(tokens_a, tokens_b, queries, keys, atts)
        q_x_k_score = np.zeros((len(tokens_a),))

        for query_index, title_token in enumerate(tokens_b):
            config = {
                "attention": attentions,
                "att_type": "ba",
                "vector_size": HEAD_VECTOR_SIZE,
                "layer": SCORE_LAYER,
                "att_head": SCORE_HEAD,
                "query_index": query_index,
            }
            # Computed once per title token (previously recomputed for every text token).
            scores = showComputation(config)
            q_x_k_score += np.array(scores)

            for para_token, score in zip(tokens_a, scores):
                ff.write(f"{x['article_id']}\t{title_token}\t{para_token}\t{score}\n")

        assert len(q_x_k_score) == len(tokens_a)

        q_x_k_scores.append(q_x_k_score)
        para_tokens.append(tokens_a)

print("Total Too Longs: ", too_long)


0it [00:00, ?it/s][A
1it [00:01,  1.96s/it][A
2it [00:06,  2.59s/it][A
3it [00:07,  2.39s/it][A
4it [00:10,  2.56s/it][A
5it [00:13,  2.67s/it][A
6it [00:16,  2.76s/it][A
7it [00:18,  2.51s/it][A
8it [00:20,  2.34s/it][A
9it [00:21,  1.86s/it][A
10it [00:22,  1.57s/it][A
11it [00:24,  1.80s/it][A
12it [00:25,  1.44s/it][A
13it [00:28,  1.90s/it][A
14it [00:28,  1.47s/it][A
15it [00:29,  1.26s/it][A
16it [00:30,  1.31s/it][A
17it [00:31,  1.03s/it][A
18it [00:31,  1.15it/s][A
19it [00:33,  1.20s/it][A
20it [00:35,  1.23s/it][A
21it [00:36,  1.17s/it][A
22it [00:39,  1.97s/it][A
23it [00:40,  1.54s/it][A
24it [00:42,  1.59s/it][A
25it [00:43,  1.49s/it][A
26it [00:44,  1.30s/it][A
27it [00:44,  1.04s/it][A
28it [00:47,  1.48s/it][A
29it [00:47,  1.26s/it][A
30it [00:48,  1.06s/it][A
31it [00:49,  1.08s/it][A
32it [00:50,  1.02it/s][A
33it [00:52,  1.23s/it][A
34it [00:53,  1.18s/it][A
35it [00:54,  1.27s/it][A
36it [00:56,  1.39s/it][A
37it [00:57,  

Total Too Longs:  3


In [48]:
# Quick check: number of sentence-A (text) tokens stored for the second row (index 1).
len(para_tokens[1])

114

In [60]:
# Number of rows processed by the scoring loop.
# NOTE(review): after a complete run this should equal len(f) (154 printed above).
# The saved output shows 134, which suggests stale or partial kernel state; the
# column assignment into f below needs the lengths to match, so re-run from a fresh kernel.
len(para_tokens)

134

In [50]:
# Peek at the first five titles.
f.loc[:, "title"].head(5)

0    かわいいNEWS
1    かわいいNEWS
2    かわいいNEWS
3    かわいいNEWS
4    かわいいNEWS
Name: title, dtype: object

In [54]:
# Serialize each row's tokens and summed scores as comma-joined strings.
# NOTE(review): a token that is itself "," would make paragraph_tokens ambiguous to split.
para_tokens_ = list(map(",".join, para_tokens))
q_x_k_scores_ = [",".join(map(str, scores)) for scores in q_x_k_scores]
f["paragraph_tokens"] = para_tokens_
f["q*k"] = q_x_k_scores_

In [55]:
print(f.iloc[:5])  # first five rows, same as f.head()

                    article_id     title  \
0  schIBJP010822173500_art0001  かわいいNEWS   
1  schIBJP010822173500_art0001  かわいいNEWS   
2  schIBJP010822173500_art0001  かわいいNEWS   
3  schIBJP010822173500_art0001  かわいいNEWS   
4  schIBJP010822173500_art0001  かわいいNEWS   

                                           paragraph  \
0  ごく普通のスポーツバッグかと思ったら保冷バッグと知り、デザイン性の高さに驚きました。ジムバッ...   
1  秋に向けて大活躍しそうなDEAN &amp; DELUCAのスープポットが9月1日に新発売！...   
2  マスキングテープといえば、我が家のラインナップは「用途を選ばず便利」という理由で無地のものば...   
3  週末は、冷えたビールにアツアツの餃子が定番の我が家。そんな餃子大好きな私が見つけてしまったの...   
4  かさもアップ。常にテーブルにお花がある生活って素敵ですよね。でもズボラな私はなかなか毎日飾る...   

                                    paragraph_tokens  \
0  [CLS],▁,ごく,普通の,スポーツ,バッグ,か,と思った,ら,保,冷,バッグ,と,知り,...   
1  [CLS],▁,秋,に向けて,大,活躍し,そうな,de,an,▁&,amp,;,▁de,lu...   
2  [CLS],▁,マス,キング,テープ,といえば,、,我が,家の,ラインナップ,は,用途,を,...   
3  [CLS],▁,週末,は,、,冷,えた,ビール,に,ア,ツ,ア,ツ,の,餃,子が,定番,の,...   
4  [CLS],▁,かさ,も,アップ,。,常に,テーブル,に,お,花,がある,生活,って,素,敵...   

                                                 q*k  
0  110

In [58]:
# Keep only the columns that are exported below.
f_ = f.loc[:, ["title", "paragraph_tokens", "q*k"]]
f_.head()

Unnamed: 0,title,paragraph_tokens,q*k
0,かわいいNEWS,"[CLS],▁,ごく,普通の,スポーツ,バッグ,か,と思った,ら,保,冷,バッグ,と,知り,...","110.55391529070833,115.06578941142755,-8.67406..."
1,かわいいNEWS,"[CLS],▁,秋,に向けて,大,活躍し,そうな,de,an,▁&,amp,;,▁de,lu...","113.93482727274728,114.10239268376226,129.7944..."
2,かわいいNEWS,"[CLS],▁,マス,キング,テープ,といえば,、,我が,家の,ラインナップ,は,用途,を,...","56.445613481652494,126.40534373321229,82.30859..."
3,かわいいNEWS,"[CLS],▁,週末,は,、,冷,えた,ビール,に,ア,ツ,ア,ツ,の,餃,子が,定番,の,...","53.238954647136055,112.03032352120688,111.8142..."
4,かわいいNEWS,"[CLS],▁,かさ,も,アップ,。,常に,テーブル,に,お,花,がある,生活,って,素,敵...","94.78399916033817,89.06140743303354,102.265414..."


In [59]:
# Export as tab-separated values without the DataFrame index.
# `index` is a boolean parameter, so pass False rather than relying on None's falsiness.
f_.to_csv("/Users/bayartsogtyadamsuren/Downloads/bert_viz_samples.tsv", sep="\t", index=False)