In [1]:
from transformers import BertModel, BertTokenizer
from bertviz import model_view, head_view
import torch


# --- Configuration ----------------------------------------------------------
# Local checkpoint file holding the fine-tuned weights.
model_path = "chexbert.pth"
input_text = "Cardiomegaly and small bilateral pleural effusions but no\n evidence of CHF."
# input_text = "no pneumonia"

# Pretrained name used for BOTH the tokenizer and the base architecture; it
# must match whatever was used to train the checkpoint in `model_path`.
tokenizer_name = "bert-base-uncased"  # Replace with the correct tokenizer if different
tokenizer = BertTokenizer.from_pretrained(tokenizer_name)

# Build the BERT architecture with attention outputs enabled. The strict
# load_state_dict call further down replaces every weight with the checkpoint's.
model = BertModel.from_pretrained(tokenizer_name, output_attentions=True)

# --- Load the checkpoint ----------------------------------------------------
# map_location="cpu" lets a checkpoint whose tensors were saved on a GPU be
# deserialized on a machine without CUDA; load_state_dict copies the values
# into the model's own parameters afterwards.
# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only open checkpoints that come from a trusted source.
state_dict = torch.load(model_path, map_location="cpu")

# Some checkpoints wrap the weights in a dict under 'model_state_dict'.
if 'model_state_dict' in state_dict:
    state_dict = state_dict['model_state_dict']

# --- Map checkpoint keys onto plain BertModel keys --------------------------
# Prefixes are tried in this order and only the first match is stripped, so
# 'module.bert.' must come before the shorter 'module.'.
known_prefixes = ('module.bert.', 'bert.', 'module.')

# Computed once: calling model.state_dict() rebuilds the whole mapping, so it
# must not be evaluated per checkpoint entry.
model_keys = set(model.state_dict())

new_state_dict = {}
for key, value in state_dict.items():
    new_key = key
    for prefix in known_prefixes:
        if key.startswith(prefix):
            new_key = key[len(prefix):]
            break

    # Keep only entries that exist in the plain BERT model; anything else in
    # the checkpoint is dropped.
    if new_key in model_keys:
        new_state_dict[new_key] = value

# Strict load: raises if the filtered checkpoint does not cover every model key.
model.load_state_dict(new_state_dict)

# --- Run the model and visualize attention ----------------------------------
inputs = tokenizer.encode(input_text, return_tensors='pt')

# Only the attention values are needed, so skip building the autograd graph.
with torch.no_grad():
    outputs = model(inputs)

# With output_attentions=True the attention weights are the last output element.
attention = outputs[-1]

# Token strings for the single encoded sequence, used as labels in the view.
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

# Alternative visualization:
# model_view(attention, tokens)
head_view(attention, tokens)


<IPython.core.display.Javascript object>

In [2]:
from transformers import BertTokenizer, BertModel
from bertviz import model_view, head_view

# Baseline: the same sentence run through the stock pretrained BERT weights.
model_version = "bert-base-uncased"
sentence_a = "Cardiomegaly and small bilateral pleural effusions but no\n evidence of CHF."
# sentence_a = "no pneumonia"

# Tokenizer and model are both loaded from the same pretrained name.
tokenizer = BertTokenizer.from_pretrained(model_version)
model = BertModel.from_pretrained(
    model_version,
    output_attentions=True,  # makes the forward pass return attention weights
)

# Encode the sentence into a tensor of input ids and run it through the model.
inputs = tokenizer.encode(sentence_a, return_tensors="pt")
outputs = model(inputs)

# The attention weights are the last element of the model outputs.
attention = outputs[-1]

# Token strings for the single encoded sequence.
tokens = tokenizer.convert_ids_to_tokens(inputs[0])

head_view(attention, tokens)

<IPython.core.display.Javascript object>