In [1]:
# Visualize BERT self-attention for a sentence pair with bertviz's head view.
# The model is only used for inference here, so the forward pass runs under
# torch.no_grad() (as the retriever cell does) to avoid building and keeping
# an autograd graph attached to the returned attention tensors.
import torch
from bertviz import head_view
from transformers import BertModel, BertTokenizer

model_version = 'bert-base-uncased'
# output_attentions=True asks the model to also return per-layer attention weights.
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version)
sentence_a = "The cat sat on the mat"
sentence_b = "The cat lay on the rug"
# Calling the tokenizer directly replaces the deprecated encode_plus(); it
# returns the same input_ids / token_type_ids entries for a sentence pair.
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']
with torch.no_grad():
    # The attentions are taken from the last element of the model output.
    attention = model(input_ids, token_type_ids=token_type_ids)[-1]
# Sentence B begins at the first position whose token type id is 1.
sentence_b_start = token_type_ids[0].tolist().index(1)
input_id_list = input_ids[0].tolist()  # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)
head_view(attention, tokens, sentence_b_start)

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]

<IPython.core.display.Javascript object>

In [None]:
import os
import random

# Set before importing the project modules below, presumably because they read
# this flag at import time — confirm in src/.
os.environ["MILVUS2_ENABLED"] = "false"

from bertviz import model_view
import torch
from src.utils.io_utils import load_checkpoint
from src.utils.tokenization import Tokenization

# Seed Python's `random` (presumably used when building the training features — confirm).
random.seed(1234)

# Data files and the retriever checkpoint being inspected.
train_examples_file = "../data/resources/train/small_training_queries.json"
train_passages_file = "../data/resources/train/Dev_training_passages.json"
model_file = "../data/checkpoints/17032022_125925/model-0"

# Restore the retriever with its query/passage tokenizers and put it in eval mode.
weces_retriever, query_tokenizer, passage_tokenizer = load_checkpoint(model_file)
weces_retriever.eval()
tokenization = Tokenization(query_tokenizer=query_tokenizer, passage_tokenizer=passage_tokenizer)
print(f"Experiment using model={model_file}, query_file={train_examples_file}, passage_file={train_passages_file}")

train_search_feats = tokenization.generate_train_search_feats(
    train_examples_file,
    train_passages_file,
    max_query_length=50,
    max_passage_length=150,
    add_qbound=False,
)


def _segment_encode(encoder, feats):
    """Call ``encoder.segment_encode`` on a single example as a batch of one.

    ``feats`` must provide ``input_ids``, ``segment_ids`` and ``input_mask``;
    each is converted to a tensor on the retriever's device and reshaped to
    ``(1, -1)`` before being passed, in that order, to ``segment_encode``.
    """
    target_device = weces_retriever.device

    def as_row(values):
        return torch.tensor(values, device=target_device).view(1, -1)

    return encoder.segment_encode(
        as_row(feats.input_ids),
        as_row(feats.segment_ids),
        as_row(feats.input_mask),
    )


sample = train_search_feats[10]
with torch.no_grad():
    query_input = sample.query
    pos_passage = sample.positive_passage
    neg_passage = sample.negative_passages[1]
    query_encode = _segment_encode(weces_retriever.query_encoder, query_input)
    pos_pass_encode = _segment_encode(weces_retriever.passage_encoder, pos_passage)
    query_tokenized = tokenization.query_tokenizer.convert_ids_to_tokens(query_input.input_ids)
    query_attentions = query_encode.attentions
    # NOTE(review): if input_ids are padded up to max_query_length, [PAD] positions
    # also appear in this view — confirm, and trim by input_mask if so.
    model_view(query_attentions, query_tokenized)