<a href="https://colab.research.google.com/github/amashi/BankBox/blob/master/Visualizing_Self_Attention_with_BertViz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://www.comet.com/images/logo_comet_light.png" width="300px"/>

-----
# 🤖 Visualizing Self-Attention with BertViz

Data science and machine learning teams use [Comet](https://www.comet.com?utm_medium=referral&utm_source=Colab&utm_content=VisualizingAttention_blog)'s ML platform to track, compare, explain, and optimize their models across the complete ML lifecycle – from managing experiments to monitoring models in production.

Comet improves productivity, reproducibility, and collaboration, regardless of the tools used for training and deploying models, whether they are managed, open-source, or in-house. The platform can be used on the cloud, virtual private cloud (VPC), or on-premises.

To find out more about Comet, visit our [Documentation Page](https://www.comet.com/docs/v2/?utm_medium=referral&utm_source=Colab&utm_content=VisualizingAttention_blog).

The following Colab is heavily inspired by work from Dhruv Nair, Data Scientist at Comet.

---
**Note:** We suggest you follow along with [this blog here](https://www.comet.com/site/blog/explainable-ai-for-transformers/).


**Note:** [If you can't wait to see the finished project, check out the results here](https://www.comet.com/examples/demo-visualizing-attention-bertviz/view/vyr6Nk6Y1cQIklggSZ4A3zSrf/panels?utm_medium=referral&utm_source=Colab&utm_content=VisualizingAttention_blog).

**Let's get started!** 🚀

## 🚧 Setup and Installation

In [1]:
!pip install comet_ml transformers --quiet

For this next step, you'll need to grab your API key from your account settings. If you don't already have a Comet account, [make one here for free](https://www.comet.com/signup/?utm_medium=referral&utm_source=Colab&utm_content=VisualizingAttention_blog)!
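Pasting a key straight into a shared notebook makes it easy to leak. As a minimal sketch of one alternative, the helper below reads the key from the `COMET_API_KEY` environment variable and only prompts for it when the variable is unset. The name `resolve_api_key` and its parameters are hypothetical and not part of the Comet SDK.

```python
import os
from getpass import getpass


def resolve_api_key(env=os.environ, ask=getpass):
    """Return a Comet API key from the environment, prompting only as a fallback.

    `env` and `ask` are injectable so the helper can be exercised without
    touching the real environment or blocking on input.
    """
    key = env.get("COMET_API_KEY", "").strip()
    if not key:
        # getpass hides the typed key so it does not end up in the cell output
        key = ask("Comet API key: ").strip()
    if not key:
        raise ValueError("No Comet API key provided.")
    return key


# Possible use in the notebook (assumption, not executed here):
# comet_ml.init(api_key=resolve_api_key(), project_name="BertViz")
```

With this approach the key never appears in the saved notebook, so the file can be shared or committed without redacting it first.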

In [2]:
import comet_ml

# Replace the placeholder with your own key; avoid sharing notebooks that contain a real key.
comet_ml.init(api_key="YOUR_API_KEY", project_name="BertViz")

COMET INFO: Valid Comet API Key saved in /root/.comet.config (set COMET_CONFIG to change where it is saved).


In [3]:
import json
import os
import uuid

import torch

from tqdm import tqdm
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,  # for sentiment analysis
    AutoModelForCausalLM,  # for text generation
    AutoModelForQuestionAnswering,  # for question answering
    pipeline,
)

## 🛠 Helper Functions

In [8]:
def name_exp(experiment, task):
    ex_key = experiment.get_key()
    exp_name = f"{task}: {ex_key[:9]}"
    return exp_name


def format_attention_output(attention):
    squeezed = []
    for layer_attention in attention:
        # 1 x num_heads x seq_len x seq_len
        if len(layer_attention.shape) != 4:
            raise ValueError(
                "The attention tensor does not have the correct number of "
                "dimensions. Make sure you set "
                "output_attentions=True when initializing your model."
            )
        squeezed.append(layer_attention.squeeze(0))
    # num_layers x num_heads x seq_len x seq_len
    return torch.stack(squeezed)


def format_special_chars(tokens):
    return [
        t.replace("Ġ", " ")
        .replace("▁", " ")
        .replace("</w>", "")
        .replace("ĊĊ", " ")
        .replace("Ċ", " ")
        for t in tokens
    ]


def get_attn_data(
    model,
    tokenizer,
    text_a,
    text_b=None,
    return_token_type_ids=False,
    prettify_tokens=True,
    layer=None,
    heads=None,
):
    if return_token_type_ids:
        inputs = tokenizer.encode_plus(
            text_a,
            text_b,
            add_special_tokens=True,
            return_tensors="pt",
            return_token_type_ids=True,
        )
        token_type_ids = inputs["token_type_ids"]

    else:
        inputs = tokenizer.encode_plus(
            text_a, return_tensors="pt", add_special_tokens=True
        )

    input_ids = inputs["input_ids"]
    attention = model(input_ids)[-1]
    input_id_list = input_ids[0].tolist()  # Batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)

    if prettify_tokens:
        tokens = format_special_chars(tokens)

    attn = format_attention_output(attention)
    attn_data = {
        "all": {"attn": attn.tolist(), "left_text": tokens, "right_text": tokens}
    }

    if text_b is not None:
        text_b_start = token_type_ids[0].tolist().index(1)

        slice_a = slice(
            0, text_b_start
        )  # Positions corresponding to sentence A in input
        slice_b = slice(
            text_b_start, len(tokens)
        )  # Position corresponding to sentence B in input
        attn_data["aa"] = {
            "attn": attn[:, :, slice_a, slice_a].tolist(),
            "left_text": tokens[slice_a],
            "right_text": tokens[slice_a],
        }
        attn_data["bb"] = {
            "attn": attn[:, :, slice_b, slice_b].tolist(),
            "left_text": tokens[slice_b],
            "right_text": tokens[slice_b],
        }
        attn_data["ab"] = {
            "attn": attn[:, :, slice_a, slice_b].tolist(),
            "left_text": tokens[slice_a],
            "right_text": tokens[slice_b],
        }
        attn_data["ba"] = {
            "attn": attn[:, :, slice_b, slice_a].tolist(),
            "left_text": tokens[slice_b],
            "right_text": tokens[slice_a],
        }

    attn_seq_len = len(attn_data["all"]["attn"][0][0])
    if attn_seq_len != len(tokens):
        raise ValueError(
            f"Attention has {attn_seq_len} positions, while number of tokens is {len(tokens)}"
        )

    return attn_data


def text_generation_viz(prompts, model_version):
    task = "text-gen"
    experiment = comet_ml.Experiment()
    experiment.set_name(name_exp(experiment, task))
    experiment.add_tag(task)
    experiment.log_parameter("model_version", model_version)

    do_lower_case = True
    tokenizer = AutoTokenizer.from_pretrained(
        model_version, do_lower_case=do_lower_case
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_version, output_attentions=True, pad_token_id=tokenizer.eos_token_id
    )

    text_generation = pipeline("text-generation", model=model, tokenizer=tokenizer)
    for prompt in tqdm(prompts):
        generated_text = text_generation(prompt, max_length=20, do_sample=False)[0][
            "generated_text"
        ]
        attn_data = get_attn_data(model, tokenizer, generated_text)

        viz_params = {
            "attention": attn_data,
            "default_filter": "all",
            "bidirectional": False,
            "display_mode": "light",
            "layer": None,
            "head": None,
        }
        experiment.log_asset_data(viz_params, f"attn-view-tg-{prompt}.json")

    experiment.end()


def sentiment_viz(prompts, model_version):
    task = "sentiment"
    experiment = comet_ml.Experiment()
    experiment.set_name(name_exp(experiment, task))
    experiment.add_tag(task)
    experiment.log_parameter("model_version", model_version)

    do_lower_case = True
    tokenizer = AutoTokenizer.from_pretrained(
        model_version, do_lower_case=do_lower_case
    )
    tokenizer.pad_token = "[PAD]"
    model = AutoModelForSequenceClassification.from_pretrained(
        model_version, output_attentions=True, pad_token_id=tokenizer.eos_token_id
    )

    sentiment_analysis = pipeline(
        "sentiment-analysis", model=model, tokenizer=tokenizer
    )
    for prompt in tqdm(prompts):
        prediction = sentiment_analysis(prompt)[0]
        attn_data = get_attn_data(model, tokenizer, prompt)
        viz_params = {
            "attention": attn_data,
            "default_filter": "all",
            "bidirectional": False,
            "display_mode": "light",
            "layer": None,
            "head": None,
        }
        experiment.log_asset_data(
            viz_params,
            f"attn-view-sentiment-{prompt}-{prediction['label']}-{round(prediction['score'], 2)}.json",
        )

    experiment.end()


def qa_viz(context, prompts, model_version):
    task = "question-answering"
    experiment = comet_ml.Experiment()
    experiment.set_name(name_exp(experiment, task))
    experiment.add_tag(task)
    experiment.log_parameter("model_version", model_version)
    experiment.log_text(str(context))

    do_lower_case = True
    tokenizer = AutoTokenizer.from_pretrained(
        model_version, do_lower_case=do_lower_case
    )
    tokenizer.pad_token = "[PAD]"
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_version, output_attentions=True, pad_token_id=tokenizer.eos_token_id
    )

    qa = pipeline("question-answering", model=model, tokenizer=tokenizer)
    for prompt in tqdm(prompts):
        prediction = qa(question=prompt, context=context)
        answer = prediction["answer"]

        start = prediction["start"]
        end = prediction["end"]

        attn_data = get_attn_data(
            model, tokenizer, prompt, text_b=answer, return_token_type_ids=True
        )
        viz_params = {
            "attention": attn_data,
            "default_filter": "all",
            "bidirectional": False,
            "display_mode": "light",
            "layer": None,
            "head": None,
        }
        experiment.log_asset_data(
            viz_params,
            f"attn-view-qa-{prompt}-score-{round(prediction['score'], 3)}-start-{start}-end-{end}.json",
        )

    experiment.end()
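
When `text_b` is supplied, `get_attn_data` cuts each attention matrix into four quadrants (`aa`, `ab`, `ba`, `bb`) at the first position whose token type ID is 1. The sketch below mirrors that slicing on a single head, using plain Python lists in place of tensors. The function `split_pair_attention` and the toy values are illustrative assumptions, not part of the notebook's helpers.

```python
def split_pair_attention(matrix, token_type_ids):
    """Split one head's square attention matrix into aa/ab/ba/bb quadrants.

    Rows are the attending (left) tokens and columns are the attended-to
    (right) tokens, matching attn[:, :, slice_rows, slice_cols] in the helper.
    """
    b_start = token_type_ids.index(1)  # first position belonging to sentence B
    rows_a, rows_b = matrix[:b_start], matrix[b_start:]
    return {
        "aa": [row[:b_start] for row in rows_a],
        "ab": [row[b_start:] for row in rows_a],
        "ba": [row[:b_start] for row in rows_b],
        "bb": [row[b_start:] for row in rows_b],
    }


# Toy example: two tokens in sentence A, one token in sentence B.
matrix = [
    [0.5, 0.3, 0.2],
    [0.1, 0.6, 0.3],
    [0.4, 0.4, 0.2],
]
quadrants = split_pair_attention(matrix, [0, 0, 1])
print(quadrants["ab"])  # attention from sentence A tokens to sentence B tokens
```

Note that `list.index(1)` raises `ValueError` when no position is marked as sentence B, which is also how the original helper would fail on a single-sentence input.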

## 🧞 Text generation

In [9]:
textgen_model_version = "gpt2"
text_gen_prompts = [
    "The animal didn't cross the street because it was too",
    "The dog didn't play at the park because it was too",
    "I went to the store. At the store I bought fresh",
    "At the store he bought flowers, candy, jewelry, and",
    "The dog ran up the street and barked too",
    "In 2016, the Young Men's Christian Association (YMCA) was very",
    "The Doctor asked the Nurse a question. She",
    "The Doctor asked the Nurse a question. He",
]
text_generation_viz(
    text_gen_prompts,
    textgen_model_version,
)

COMET INFO: Couldn't find a Git repository in '/content' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
COMET INFO: Experiment is live on comet.com https://www.comet.com/amashi/bertviz/ddbc54f4d4f84051a51fc0daa9fef0e5

  0%|          | 0/8 [00:00<?, ?it/s]Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 12%|█▎        | 1/8 [00:01<00:07,  1.05s/it]Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
 25%|██▌       | 2/8 [00:01<00:05,  1.16it/s]Setting `pad_token_id` to `eos_token_id`:50256 for open-end 

## 🎭 Sentiment analysis

In [6]:
sa_prompts = [
    "Many people dislike Steve Jobs, while acknowledging his genius.",
    "The quick, brown fox jumps over the lazy dog.",
    "It was a beautiful day.",
    "It was a horrible day.",
    "I am confused.",
    "That movie was so sick but I wish it was longer.",
    "That movie was so awesome but I wish it was longer.",
    "That movie was so gross but I wish it was longer.",
    "That movie was so available but I wish it was longer.",
]
sa_model_version = "distilbert-base-uncased-finetuned-sst-2-english"

sentiment_viz(sa_prompts, sa_model_version)

COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : text-gen: 993890c8f
COMET INFO:     url                   : https://www.comet.com/amashi/bertviz/993890c8f7794a5d931b8bd25e413d24
COMET INFO:   Others:
COMET INFO:     Name : text-gen: 993890c8f
COMET INFO:   Parameters:
COMET INFO:     model_version : gpt2
COMET INFO:   Uploads:
COMET INFO:     environment details : 1
COMET INFO:     filename            : 1
COMET INFO:     installed packages  : 1
COMET

100%|██████████| 9/9 [00:01<00:00,  4.96it/s]
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : sentiment: a5020fb4f
COMET INFO:     url                   : https://www.comet.com/amashi/bertviz/a5020fb4f15947378716e649b28c5e34
COMET INFO:   Others:
COMET INFO:     Name : sentiment: a5020fb4f
COMET INFO:   Parameters:
COMET INFO:     model_version : distilbert-base-uncased-finetuned-sst-2-english
COMET INFO:   Uploads:
COMET INFO:     asset               : 9
COMET INFO:     enviro

## 💬 Question-answering

In [7]:
context = r"""A robot may not injure a human being or, through inaction, allow a human being to come to harm.
A robot must obey the orders given it by human beings except where such orders would conflict with the First Law.
A robot must protect its own existence as long as such protection does not conflict with the First or Second Laws.
"""
questions = [
    "Can a robot hurt a human?",
    "Can a robot injure a human?",
    "Should a robot obey orders from humans?",
    "Can a robot protect itself from a human?",
    "Can a robot love a human?"
]

qa_viz(context, questions, "distilbert-base-uncased-distilled-squad")

COMET INFO: Couldn't find a Git repository in '/content' nor in any parent directory. Set `COMET_GIT_DIRECTORY` if your Git Repository is elsewhere.
COMET INFO: Experiment is live on comet.com https://www.comet.com/amashi/bertviz/07129b778fc74059a4d374d8633c2aca



100%|██████████| 5/5 [00:02<00:00,  2.09it/s]
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO: Comet.ml Experiment Summary
COMET INFO: ---------------------------------------------------------------------------------------
COMET INFO:   Data:
COMET INFO:     display_summary_level : 1
COMET INFO:     name                  : question-answering: 07129b778
COMET INFO:     url                   : https://www.comet.com/amashi/bertviz/07129b778fc74059a4d374d8633c2aca
COMET INFO:   Others:
COMET INFO:     Name : question-answering: 07129b778
COMET INFO:   Parameters:
COMET INFO:     model_version : distilbert-base-uncased-distilled-squad
COMET INFO:   Uploads:
COMET INFO:     asset               : 5
COMET INFO:

## 🕵 Exploring our results in Comet


Visualizing attention in Comet will help us interpret our models’ decisions by showing how they attend to different parts of the input.

To add BertViz to your dashboard, navigate to Comet’s public panels and select either ‘Transformers Model Viewer’ or ‘Transformers Attention Head Viewer.’

[![adding-bertviz-to-comet.gif](https://s12.gifyu.com/images/SWRLV.gif)](https://www.comet.com/examples/demo-visualizing-attention-bertviz/view/vyr6Nk6Y1cQIklggSZ4A3zSrf/panels)

### 👤 Head View

The attention-head view shows how attention flows between tokens within the same transformer layer by uncovering patterns between attention heads. In this view, the tokens on the left are attending to the tokens on the right and attention is represented as a line connecting each token pair. Colors correspond to attention heads and line thickness represents the attention weight.

In the drop-down menu, we can select the experiment we’d like to visualize, and if we logged more than one asset to our experiment, we can also select our asset. We can then choose which attention layer we’d like to visualize and, optionally, we can choose any combination of attention heads we’d like to see. Note that color intensity of the lines connecting tokens corresponds to the attention weights between tokens.

[![SWRLY.gif](https://s12.gifyu.com/images/SWRLY.gif)](https://www.comet.com/examples/demo-visualizing-attention-bertviz/view/vyr6Nk6Y1cQIklggSZ4A3zSrf/panels)
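
To make the head view concrete, the sketch below takes a dictionary shaped like the `attn_data["all"]` entry logged earlier (`attn[layer][head][i][j]`, plus `left_text` and `right_text`) and reports, for each left token, the right token it attends to most in one layer and head. That corresponds to the heaviest line leaving each token in the panel. The helper name `strongest_links` and the toy weights are assumptions made for illustration.

```python
def strongest_links(view, layer, head):
    """Return (left_token, right_token, weight) for each left token's heaviest link."""
    weights = view["attn"][layer][head]  # seq_len x seq_len attention weights
    links = []
    for i, row in enumerate(weights):
        # index of the largest weight in this row (first one wins on ties)
        j = max(range(len(row)), key=row.__getitem__)
        links.append((view["left_text"][i], view["right_text"][j], row[j]))
    return links


# Toy data: one layer, one head, three tokens, causal (lower-triangular) weights.
toy_view = {
    "attn": [[[
        [1.0, 0.0, 0.0],
        [0.7, 0.3, 0.0],
        [0.2, 0.1, 0.7],
    ]]],
    "left_text": ["The", " dog", " barked"],
    "right_text": ["The", " dog", " barked"],
}

links = strongest_links(toy_view, layer=0, head=0)
for left, right, weight in links:
    print(f"{left!r} -> {right!r} ({weight:.2f})")
```

The same function could be pointed at the `aa`, `ab`, `ba`, or `bb` entries for sentence-pair inputs, since they share the same three keys.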


### 🚂 Model View
The model view is a bird’s-eye perspective of attention across all layers and heads. Here we may notice attention patterns across layers, illustrating the evolution of attention patterns from input to output. Each row of figures represents an attention layer and each column represents individual attention heads. To enlarge the figure for any particular head, we can simply click on it. Note that you can find the same line pattern in the model view as in the head view.


[![SWRLt.gif](https://s12.gifyu.com/images/SWRLt.gif)](https://www.comet.com/examples/demo-visualizing-attention-bertviz/view/vyr6Nk6Y1cQIklggSZ4A3zSrf/panels)
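
The model view's grid, one row per layer and one column per head, can also be summarized numerically. As a rough sketch, the function below reduces each head to a single number, the mean entropy of its attention rows, so that sharply focused heads score near zero and diffuse heads score higher. This summary is an illustrative assumption and is not something BertViz or the Comet panels compute; `head_entropy_grid` is a hypothetical name.

```python
import math


def head_entropy_grid(attn):
    """Return grid[layer][head] = mean entropy (in nats) of that head's attention rows."""
    grid = []
    for layer in attn:
        layer_scores = []
        for head in layer:
            # Entropy of each row; zero weights contribute nothing to the sum.
            entropies = [
                -sum(p * math.log(p) for p in row if p > 0) for row in head
            ]
            layer_scores.append(sum(entropies) / len(entropies))
        grid.append(layer_scores)
    return grid


# Toy data: 1 layer x 2 heads x 2 tokens x 2 tokens.
toy_attn = [
    [
        [[1.0, 0.0], [0.0, 1.0]],  # each token attends to exactly one token
        [[0.5, 0.5], [0.5, 0.5]],  # attention spread evenly over both tokens
    ]
]
grid = head_entropy_grid(toy_attn)
print(grid)
```

Applied to the nested list stored under `attn_data["all"]["attn"]`, the result has the same layers-by-heads layout as the panel, which makes it easier to decide which cells are worth enlarging.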

Thanks for making it all the way to the end, and we hope you enjoyed this article. Feel free to connect with us on our Community Slack channel with any questions, comments, or suggestions!

## 📚 Additional Resources

- [The Illustrated Transformer](http://jalammar.github.io/illustrated-transformer/) by Jay Alammar
- [The Illustrated BERT](http://jalammar.github.io/illustrated-bert/) by Jay Alammar

**Questions, comments, or suggestions?** [Join our community Slack and connect with a team member today](https://cometml.slack.com/)!