In [11]:
!pip install transformers torch matplotlib
!pip install ipywidgets



In [12]:
import torch
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
import numpy as np
from ipywidgets import interact, widgets

In [13]:
# Single source of truth for the checkpoint, so the tokenizer and the model
# cannot drift apart when the name is changed.
MODEL_NAME = "bert-base-multilingual-cased"

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
# output_attentions=True makes each forward pass expose per-layer attention
# weights via `outputs.attentions`, which the plotting code below reads.
model = BertModel.from_pretrained(MODEL_NAME, output_attentions=True)

In [14]:
def show_attention(sentence: str, layer: int, head: int) -> None:
    """Plot one attention head's weights for ``sentence`` as a token heatmap.

    Uses the notebook-level ``tokenizer`` and ``model``; the model must return
    attentions (it is loaded with ``output_attentions=True``).

    Args:
        sentence: Text to encode. The tokenizer adds its special tokens, and
            the input is truncated to the tokenizer's maximum length.
        layer: Index into ``outputs.attentions`` (0-based; negative values
            index from the end, as with any Python sequence).
        head: Index of the attention head inside the selected layer (0-based).

    Raises:
        IndexError: If ``layer`` or ``head`` is out of range for the model.
    """
    # truncation=True clips over-long input instead of letting the model fail
    # on it; .to(model.device) keeps inputs and weights on the same device.
    encoded = tokenizer(sentence, return_tensors="pt", truncation=True).to(model.device)

    with torch.no_grad():
        outputs = model(**encoded)

    # attentions[layer] has shape (batch, heads, query_pos, key_pos). Take the
    # single batch item and convert explicitly to a NumPy array so plotting
    # does not rely on implicit tensor coercion (which fails for GPU tensors).
    weights = outputs.attentions[layer][0, head].cpu().numpy()
    token_list = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0].tolist())
    positions = list(range(len(token_list)))

    fig, ax = plt.subplots(figsize=(8, 6))
    image = ax.imshow(weights, cmap="viridis")
    fig.colorbar(image, ax=ax, label="Attention weight")

    ax.set_xticks(positions)
    # ha="right" anchors each rotated label at its tick instead of its centre.
    ax.set_xticklabels(token_list, rotation=45, ha="right")
    ax.set_yticks(positions)
    ax.set_yticklabels(token_list)

    # Rows are the attending (query) tokens, columns the attended-to (key) tokens.
    ax.set_xlabel("Key token (attended to)")
    ax.set_ylabel("Query token (attending)")
    ax.set_title(f"Attention Heatmap — Layer {layer}, Head {head}")

    fig.tight_layout()
    plt.show()

In [15]:
# Slider bounds come from the loaded model's config rather than a hard-coded
# 11, so the widget stays valid if a different checkpoint is loaded.
# continuous_update=False re-runs the model only when the text is submitted
# (Enter / focus loss) instead of on every keystroke.
# The trailing ';' suppresses the echoed function repr that interact() returns.
interact(
    show_attention,
    sentence=widgets.Text(
        value="I love transformers",
        description="Sentence:",
        continuous_update=False,
    ),
    layer=widgets.IntSlider(
        min=0,
        max=model.config.num_hidden_layers - 1,
        step=1,
        value=0,
        description="Layer",
    ),
    head=widgets.IntSlider(
        min=0,
        max=model.config.num_attention_heads - 1,
        step=1,
        value=0,
        description="Head",
    ),
);

interactive(children=(Text(value='I love transformers', description='Sentence:'), IntSlider(value=0, descripti…