In [None]:
import torch
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load the pre-trained tokenizer and model. output_attentions=True makes the
# forward pass return the attention weights alongside the hidden states.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased', output_attentions=True)

# Example sentence
sentence = "The quick brown fox jumps over the lazy dog"

# Tokenize the sentence; return_tensors='pt' yields batched PyTorch tensors
# with a leading batch dimension of 1.
inputs = tokenizer(sentence, return_tensors='pt')
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

# Run the forward pass without tracking gradients (inference only).
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    attentions = outputs.attentions

# `attentions` is a tuple with one tensor per layer, each shaped
# (batch, num_heads, seq_len, seq_len). Pick the first layer and the single
# sentence in the batch to get a (num_heads, seq_len, seq_len) array.
attention_matrix = attentions[0][0].numpy()

# Token strings for the axis labels. convert_ids_to_tokens expects a flat
# sequence of ids, so drop the batch dimension first. Computed once because it
# is identical for every head.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())

# Plot one heatmap per attention head in the first layer. The head count is
# read from the data rather than hard-coded; 2.5 inches per head reproduces
# the original 30x5 figure for a 12-head model.
num_heads = attention_matrix.shape[0]
fig, axes = plt.subplots(1, num_heads, figsize=(2.5 * num_heads, 5),
                         squeeze=False)
axes = axes.ravel()  # always a 1-D array of Axes, even for a single head
for i, ax in enumerate(axes):
    sns.heatmap(attention_matrix[i], xticklabels=tokens,
                yticklabels=tokens, cmap='viridis', ax=ax)
    ax.set_title(f'Head {i+1}')
plt.tight_layout()
plt.show()
