# Attention Visualization with BERTViz

This notebook demonstrates how to visualize attention mechanisms in transformer models using the BERTViz library. We will cover both encoder-only models (like BERT) and encoder-decoder models (like translation models).

In both cases, we load a pretrained model from the [Hugging Face Transformers library](https://huggingface.co/docs/transformers/index) with `output_attentions=True`, so that a forward pass returns the attention weights of every layer and head alongside the usual outputs. BERTViz then renders these weights as interactive visualizations.
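In a standard transformer, the numbers BERTViz draws for one layer and head form a matrix of scaled dot-product attention weights, softmax(QKᵀ/√d), where Q and K are the query and key vectors of the tokens. Row i describes how query token i distributes its attention over all key tokens, so every row sums to 1. The following pure-Python sketch computes such a matrix from toy vectors. It illustrates the formula only; it is not the Transformers implementation, and the vectors are made up.

```python
import math


def softmax(row):
    # Subtract the row maximum before exponentiating for numerical stability
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_weights(queries, keys):
    """Return softmax(Q K^T / sqrt(d)) for a single attention head.

    queries: one vector per query token
    keys:    one vector per key token
    Row i holds how much query token i attends to each key token.
    """
    d = len(keys[0])
    scale = math.sqrt(d)
    weights = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / scale for k in keys]
        weights.append(softmax(scores))
    return weights


# Toy 2-dimensional query/key vectors for three tokens (made-up values)
Q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
K = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W = attention_weights(Q, K)
for row in W:
    print([round(w, 3) for w in row])
```

In a real model, Q and K are per-head linear projections of the hidden states, and BERTViz receives the resulting matrices directly, so nothing like this needs to be computed by hand.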

In [None]:
!pip install torch==2.6.0 torchvision==0.21.0 transformers==4.57.3 bertviz==1.4.1
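If any of these packages had already been imported before the install ran, the running kernel may still be using the old versions until it is restarted. The sketch below is an optional check that uses only the standard library's `importlib.metadata` to compare installed versions against the pins from the install cell. The helper name `version_mismatches` is hypothetical, not part of any library.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the install cell above
PINNED = {
    "torch": "2.6.0",
    "torchvision": "0.21.0",
    "transformers": "4.57.3",
    "bertviz": "1.4.1",
}


def version_mismatches(pinned):
    """Return {package: (expected, installed_or_None)} for every package
    that is missing or installed at a different version."""
    mismatches = {}
    for name, expected in pinned.items():
        try:
            # Drop a local version suffix such as "+cu124" that some PyTorch wheels carry
            installed = version(name).split("+")[0]
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


problems = version_mismatches(PINNED)
print(problems if problems else "All pinned versions are installed.")
```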

In [None]:
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import head_view, model_view

utils.logging.set_verbosity_error()  # Hide Transformers warnings; remove this line to see them

#### Encoder Model Visualization

In [None]:
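tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# output_attentions=True makes every forward pass also return the attention weights;
# the eager attention implementation computes those weights explicitly
model = AutoModel.from_pretrained("bert-base-uncased", output_attentions=True, attn_implementation="eager")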

In [None]:
sentence = "The kid likes to go to school because it likes to learn new things."
inputs = tokenizer.encode(sentence, return_tensors="pt")
outputs = model(inputs)
# Tuple with one tensor per layer, each of shape (batch, num_heads, seq_len, seq_len)
attention = outputs.attentions
tokens = tokenizer.convert_ids_to_tokens(inputs[0])
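The `tokens` list passed to `head_view` is what BERTViz uses as labels. For BERT it starts with `[CLS]`, ends with `[SEP]`, and may contain subword pieces prefixed with `##`, because BERT splits words with WordPiece, matching the longest vocabulary entry first. The sketch below imitates that greedy longest-match-first split on a tiny hypothetical vocabulary. It is a simplification: the real `bert-base-uncased` tokenizer also strips accents, splits off punctuation, and uses a vocabulary of about 30,000 entries.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first WordPiece split of a single lowercase word.

    Pieces after the first carry the '##' continuation prefix. If some
    position cannot be matched, the whole word maps to `unk`.
    """
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # Try the longest remaining substring first, then shrink it
        while start < end:
            piece = word[start:end]
            if start > 0:
                piece = "##" + piece
            if piece in vocab:
                match = piece
                break
            end -= 1
        if match is None:
            return [unk]
        pieces.append(match)
        start = end
    return pieces


def tokenize(text, vocab):
    # Lowercase and split on whitespace only, then add BERT-style markers
    tokens = ["[CLS]"]
    for word in text.lower().split():
        tokens.extend(wordpiece(word, vocab))
    tokens.append("[SEP]")
    return tokens


# Hypothetical toy vocabulary, not the real bert-base-uncased vocabulary
VOCAB = {"the", "kid", "likes", "to", "learn", "play", "##ing", "un", "##believ", "##able"}
print(tokenize("The kid likes playing", VOCAB))
print(wordpiece("unbelievable", VOCAB))
```

Because a single word can become several tokens, one word in the sentence may appear as several rows and columns in the attention view.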

In [None]:
head_view(attention, tokens)
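In the head view, line weight reflects the attention score between two tokens. The same numbers can be read programmatically. The sketch below takes one layer's attention as a nested list of shape `[num_heads][seq_len][seq_len]`, optionally averages it over heads, and lists the tokens a given query token attends to most. The helper names are hypothetical and the weights are made up; averaging over heads is a simplification that can hide what individual heads do.

```python
def mean_over_heads(layer_attn):
    """Average a [num_heads][seq_len][seq_len] nested list over its heads."""
    num_heads = len(layer_attn)
    seq_len = len(layer_attn[0])
    return [
        [sum(layer_attn[h][i][j] for h in range(num_heads)) / num_heads for j in range(seq_len)]
        for i in range(seq_len)
    ]


def top_attended(matrix, tokens, query_index, k=3):
    """Return the k (token, weight) pairs that token `query_index` attends to most."""
    row = matrix[query_index]
    ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
    return [(tokens[j], row[j]) for j in ranked[:k]]


# Toy 2-head attention over 4 tokens (made-up weights; each row sums to 1)
toy_tokens = ["[CLS]", "kid", "it", "[SEP]"]
toy_layer = [
    [[0.7, 0.1, 0.1, 0.1], [0.2, 0.5, 0.2, 0.1], [0.1, 0.6, 0.2, 0.1], [0.25, 0.25, 0.25, 0.25]],
    [[0.4, 0.2, 0.2, 0.2], [0.1, 0.7, 0.1, 0.1], [0.2, 0.4, 0.3, 0.1], [0.1, 0.1, 0.1, 0.7]],
]
avg = mean_over_heads(toy_layer)
print(top_attended(avg, toy_tokens, query_index=2, k=2))
```

With the real output above, one option is `layer_attn = attention[5][0].tolist()` (layer 5 chosen arbitrarily, batch index 0), followed by `top_attended(mean_over_heads(layer_attn), tokens, tokens.index("it"))`.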

#### Encoder-Decoder Model Visualization
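An encoder-decoder model produces three kinds of attention, and `model_view` shows each one separately. Encoder self-attention relates source tokens to source tokens, decoder self-attention relates target tokens to target tokens, and cross-attention lets target-token queries attend to source-token keys, so its matrices are generally not square. The sketch below simply spells out the tensor shapes one would expect for each output. The layer, head, and length values in the example call are hypothetical placeholders, not the actual configuration of the model used here.

```python
def expected_attention_shapes(encoder_layers, decoder_layers, encoder_heads,
                              decoder_heads, src_len, tgt_len, batch=1):
    """Expected (number_of_tensors, tensor_shape) for each attention output."""
    return {
        # Encoder self-attention: every source token attends to every source token
        "encoder_attentions": (encoder_layers, (batch, encoder_heads, src_len, src_len)),
        # Decoder self-attention: target tokens attend to target tokens
        "decoder_attentions": (decoder_layers, (batch, decoder_heads, tgt_len, tgt_len)),
        # Cross-attention: rows are target tokens, columns are source tokens
        "cross_attentions": (decoder_layers, (batch, decoder_heads, tgt_len, src_len)),
    }


# Hypothetical sizes for illustration only
shapes = expected_attention_shapes(encoder_layers=6, decoder_layers=6, encoder_heads=8,
                                   decoder_heads=8, src_len=7, tgt_len=9)
for name, (count, shape) in shapes.items():
    print(f"{name}: {count} tensors of shape {shape}")
```

After running the cells below, the real values can be compared using `model.config.encoder_layers`, `model.config.decoder_attention_heads`, `len(outputs.cross_attentions)`, and `outputs.cross_attentions[0].shape`.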

In [None]:
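tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")
# As above, request attention weights and use the eager attention implementation
model = AutoModel.from_pretrained("Helsinki-NLP/opus-mt-en-de", output_attentions=True, attn_implementation="eager")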
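The decoder of a translation model generates text left to right, so its self-attention is causally masked: target token i may attend only to itself and to earlier target tokens. In the decoder panels of `model_view`, one should therefore expect attention only toward earlier or same-position tokens. The sketch below applies such a mask to a toy score matrix by dropping masked positions before the softmax and giving them weight 0. Library implementations typically achieve the same effect by adding a very large negative number to the masked scores; the numbers here are made up.

```python
import math


def causal_attention_weights(scores):
    """Apply a causal (look-ahead) mask to a square score matrix and softmax each row.

    Position i may only attend to positions 0..i, so every weight above
    the diagonal is exactly 0.
    """
    n = len(scores)
    weights = []
    for i in range(n):
        visible = scores[i][: i + 1]  # keys 0..i are allowed
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # Masked (future) keys receive zero weight
        row = [e / total for e in exps] + [0.0] * (n - i - 1)
        weights.append(row)
    return weights


# Toy raw scores for 4 target tokens (made-up numbers)
scores = [
    [0.5, 1.0, 2.0, 0.1],
    [1.0, 0.2, 0.3, 0.4],
    [0.0, 0.0, 0.0, 3.0],
    [0.3, 0.3, 0.3, 0.3],
]
for row in causal_attention_weights(scores):
    print([round(w, 3) for w in row])
```

Note that the first target token can only attend to itself, so its row is always a single 1.0 followed by zeros.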

In [None]:
encoder_input_ids = tokenizer("She sees the small elephant.", return_tensors="pt").input_ids
# text_target tokenizes with the target-language (German) settings;
# it replaces the deprecated tokenizer.as_target_tokenizer() context manager
decoder_input_ids = tokenizer(text_target="Sie sieht den kleinen Elefanten.", return_tensors="pt").input_ids

outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text,
)
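In the cross-attention matrices, rows correspond to target (German) tokens and columns to source (English) tokens. A common rough heuristic reads a word alignment off these matrices by picking, for each target token, the source token with the highest weight. The sketch below does this on made-up weights; the helper name is hypothetical, and attention weights are not guaranteed to be a faithful alignment, so treat the result as a visual aid rather than ground truth.

```python
def argmax_alignment(cross_attn, src_tokens, tgt_tokens):
    """For each target token, pick the source token with the highest cross-attention weight.

    cross_attn: [tgt_len][src_len] nested list (one head, or an average over heads).
    Returns a list of (target_token, source_token, weight) triples.
    """
    alignment = []
    for t, row in enumerate(cross_attn):
        s = max(range(len(row)), key=lambda j: row[j])
        alignment.append((tgt_tokens[t], src_tokens[s], row[s]))
    return alignment


# Toy example with made-up weights (rows = target tokens, columns = source tokens)
src = ["▁She", "▁sees", "</s>"]
tgt = ["▁Sie", "▁sieht", "</s>"]
toy_cross = [
    [0.80, 0.15, 0.05],
    [0.10, 0.85, 0.05],
    [0.05, 0.05, 0.90],
]
for triple in argmax_alignment(toy_cross, src, tgt):
    print(triple)
```

With the real output, one option is `cross = outputs.cross_attentions[-1][0].mean(dim=0).tolist()` (last layer, averaged over heads, chosen arbitrarily) followed by `argmax_alignment(cross, encoder_text, decoder_text)`.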