In [14]:
import bertviz
from bertviz import model_view, head_view
from transformers import BertTokenizer, BertModel

In [15]:
def show_model_view(model, tokenizer, sentence_a, sentence_b=None, hide_delimiter_attn=False,display_mode="dark"):
    inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    if sentence_b:
        token_type_ids = inputs['token_type_ids']
        attention = model(input_ids, token_type_ids=token_type_ids)[-1]
        sentence_b_start = token_type_ids[0].tolist().index(1)
    else:
        attention = model(input_ids)[-1]
        sentence_b_start = None
    input_id_list = input_ids[0].tolist() # Batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)  
    if hide_delimiter_attn:
        for i, t in enumerate(tokens):
            if t in ("[SEP]", "[CLS]"):
                for layer_attn in attention:
                    layer_attn[0, :, i, :] = 0
                    layer_attn[0, :, :, i] = 0
    model_view(attention, tokens, sentence_b_start, display_mode=display_mode)

In [24]:

def show_head_view(model, tokenizer, sentence_a, sentence_b=None, layer=None, heads=None):
    inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    if sentence_b:
        token_type_ids = inputs['token_type_ids']
        attention = model(input_ids, token_type_ids=token_type_ids)[-1]
        sentence_b_start = token_type_ids[0].tolist().index(1)
    else:
        attention = model(input_ids)[-1]
        sentence_b_start = None
    input_id_list = input_ids[0].tolist() # Batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)    
    print('attention shape {} {}'.format(len(attention), len(attention[0])))
    head_view(attention, tokens, sentence_b_start, layer=layer, heads=heads)

In [22]:
# for the task finetuned model, first transfer the model to the transformers format
from transformers import BertConfig
import torch
import json
pt_checkpoint_path = '/data7/emobert/exp/finetune/onlytext/meld_bert_base_lr2e-5_bs32/ckpt/model_step_0.pt'
save_checkpoint_dir = '/data7/emobert/exp/finetune/onlytext/meld_bert_base_lr2e-5_bs32/ckpt/transformer_format'
config_path = '/data7/MEmoBert/code/uniter/config/uniter-base.json'
config = json.load(open(config_path, 'r'))
config = BertConfig().from_dict(config)
config.output_attentions == True
state_dict = torch.load(pt_checkpoint_path)
model = BertModel.from_pretrained(pretrained_model_name_or_path=None, config=config, state_dict=state_dict)
model.save_pretrained(save_checkpoint_dir)

In [None]:
# for the mlm pretrained model 
checkpoint_dir = '/data7/MEmoBert/emobert/exp/mlm_pretrain/results/opensub/bert_base_uncased_1000w_linear_lr1e4_warm4k_bs256_acc2_4gpu/checkpoint-93980'
do_lower_case = True
model = BertModel.from_pretrained(checkpoint_dir, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(checkpoint_dir, do_lower_case=do_lower_case)

In [17]:
# for the only text task finetuned model 
save_checkpoint_dir = '/data7/emobert/exp/finetune/onlytext/meld_bert_base_lr2e-5_bs32/ckpt/transformer_format'
do_lower_case = True
model = BertModel.from_pretrained(checkpoint_dir, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(checkpoint_dir, do_lower_case=do_lower_case)

Some weights of BertModel were not initialized from the model checkpoint at /data7/MEmoBert/emobert/exp/mlm_pretrain/results/opensub/bert_base_uncased_1000w_linear_lr1e4_warm4k_bs256_acc2_4gpu/checkpoint-93980 and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
sentence_a = "No don't I beg of you!"
sentence_b = None
show_model_view(model, tokenizer, sentence_a, sentence_b, hide_delimiter_attn=False, display_mode="dark")

<IPython.core.display.Javascript object>

In [25]:
sentence_a = "No don't I beg of you!"
sentence_b = None
show_head_view(model, tokenizer, sentence_a, sentence_b)

attention shape 12


<IPython.core.display.Javascript object>