In [None]:
from transformers import BertTokenizer, BertModel
from bertviz import model_view, head_view
# Checkpoint identifier handed to both from_pretrained() calls below.
model_name_or_path = 'batterydata/batteryscibert-cased'
tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
# output_attentions=True makes the model output include the attention weights.
model = BertModel.from_pretrained(model_name_or_path, output_attentions=True)

# Example sentence whose attention pattern is visualised.
sentence = (
    "Performance characteristics of intercalation materials such as "
    "lithium cobalt oxide (LCO) are contrasted with that of conversion materials."
)
inputs = tokenizer.encode(sentence, return_tensors='pt')
outputs = model(inputs)
# With attentions enabled, they are the last element of the model output.
attention = outputs[-1]
# Token strings for row 0 of the encoded batch; used as labels in the views.
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

In [None]:
# Render bertviz's head view from the notebook-level `attention` and `tokens`.
head_view(attention=attention, tokens=tokens)

In [None]:
def show_model_view(model, tokenizer, sentence_a, sentence_b=None, hide_delimiter_attn=False, display_mode="dark"):
    """Tokenize a sentence (or sentence pair), run the model, and render bertviz's model view.

    Args:
        model: Callable model; invoked as
            ``model(input_ids, output_attentions=True[, token_type_ids=...])``.
        tokenizer: Tokenizer that is callable on ``(text, text_pair=...)`` and
            provides ``convert_ids_to_tokens``.
        sentence_a: First (or only) sentence.
        sentence_b: Optional second sentence. When truthy, the encoded
            ``token_type_ids`` are passed to the model, and the index of the
            first token whose type id is 1 is passed to ``model_view`` as the
            start of sentence B.
        hide_delimiter_attn: When True, attention to and from delimiter tokens
            is set to zero (for batch index 0) before rendering.
        display_mode: Forwarded unchanged to ``model_view``.

    Raises:
        ValueError: If the model output contains no attention weights, or
            (from ``list.index``) if a sentence pair has no token of type 1.
    """
    # Function-scope import: the file's visible imports do not bring torch into scope,
    # and it is already required at runtime by return_tensors='pt'.
    import torch

    # Calling the tokenizer directly supersedes the deprecated encode_plus().
    encoded = tokenizer(sentence_a, text_pair=sentence_b, return_tensors='pt', add_special_tokens=True)
    input_ids = encoded['input_ids']

    # Request attentions explicitly so the result does not depend on how the
    # model was configured when it was loaded.
    forward_kwargs = {'output_attentions': True}
    sentence_b_start = None
    if sentence_b:
        token_type_ids = encoded['token_type_ids']
        forward_kwargs['token_type_ids'] = token_type_ids
        sentence_b_start = token_type_ids[0].tolist().index(1)

    # Pure inference: no autograd graph is needed for visualisation.
    with torch.no_grad():
        outputs = model(input_ids, **forward_kwargs)

    if hasattr(outputs, 'attentions'):
        attention = outputs.attentions
    else:
        # Plain tuple output: keep the original convention that attentions come last.
        attention = outputs[-1]
    if attention is None:
        raise ValueError("Model output contains no attention weights; cannot render model view.")

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())  # batch index 0

    if hide_delimiter_attn:
        # Keep the BERT literals for backward compatibility and add whatever
        # delimiter tokens this tokenizer declares.
        delimiters = {"[SEP]", "[CLS]"}
        for attr_name in ("sep_token", "cls_token"):
            special = getattr(tokenizer, attr_name, None)
            if special is not None:
                delimiters.add(special)
        delimiter_positions = [pos for pos, tok in enumerate(tokens) if tok in delimiters]
        for layer_attn in attention:
            for pos in delimiter_positions:
                # Zero the two sequence axes at this position (presumably
                # attention from and to the delimiter — axis order per the
                # Hugging Face attention layout, to be confirmed).
                layer_attn[0, :, pos, :] = 0
                layer_attn[0, :, :, pos] = 0

    model_view(attention, tokens, sentence_b_start, display_mode=display_mode)

In [None]:
# Switch to a second checkpoint; this rebinds the notebook-level
# `model_name_or_path`, `tokenizer` and `model`.
model_name_or_path = 'batterydata/batteryonlybert-uncased-squad-v1'
tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
model = BertModel.from_pretrained(model_name_or_path, output_attentions=True)

# Single-sentence input, so no sentence B is supplied.
sentence_a = "The cathode of this Li-ion battery is LiFePO4."
show_model_view(
    model,
    tokenizer,
    sentence_a,
    sentence_b=None,
    hide_delimiter_attn=False,
    display_mode="dark",
)