
# Description

This notebook examines the attention weights of a T5 model fine-tuned for [sentiment span extraction](https://huggingface.co/mrm8488/t5-base-finetuned-span-sentiment-extraction). The model takes a piece of text and a sentiment label (positive, negative, or neutral) and extracts the part of the text that expresses that sentiment. For example, given the input text `question: negative context: You're a nice person, but your feet stink.`, the model should return the span `your feet stink.`. If "negative" is replaced with "positive" in the input text, the model returns the span `nice person,`.

We want to see whether the attention mechanism highlights the negative sentiment span when the input text asks for negative sentiment, and likewise the positive span when it asks for positive sentiment. We'll use the [BertViz](https://github.com/jessevig/bertviz) library to view the attention weights.


# Environment setup

This notebook pins an older version of Hugging Face Transformers because the T5 model used here doesn't work with the most recent version.

In [None]:
!pip install transformers==4.11.3 sentencepiece==0.1.96 bertviz==1.0.0
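
Because the notebook depends on exact package versions, a quick check after installation can catch a runtime that kept a different version. The sketch below is an illustration rather than part of the original notebook: `PINS` and `find_mismatches` are hypothetical names, and the pinned versions simply mirror the install command above.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip install command above.
PINS = {
    "transformers": "4.11.3",
    "sentencepiece": "0.1.96",
    "bertviz": "1.0.0",
}


def find_mismatches(pins, get_version=version):
    """Return {package: (wanted, found)} for every pin that is not satisfied.

    `found` is None when the package is not installed. `get_version` is
    injectable so the function can be exercised without the real packages.
    """
    mismatches = {}
    for name, wanted in pins.items():
        try:
            found = get_version(name)
        except PackageNotFoundError:
            found = None
        if found != wanted:
            mismatches[name] = (wanted, found)
    return mismatches


# An empty dict means every pinned package is installed at the pinned version.
print(find_mismatches(PINS))
```

If the printed dictionary is not empty, restarting the runtime after the install step is usually the first thing to try.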

In [None]:
from transformers import (T5ForConditionalGeneration, 
                          AutoTokenizer)

In [None]:
!nvidia-smi

In [None]:
%matplotlib inline
import torch
import transformers
import numpy as np

print(torch.__version__)
print(transformers.__version__)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model

## Loading the model and tokenizer

In [None]:
from transformers import T5ForConditionalGeneration, AutoTokenizer

model_name = "mrm8488/t5-base-finetuned-span-sentiment-extraction"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

## Inference

In [None]:
def get_sentiment_span(text, sentiment):
    """
    Given a string of text and sentiment type,
    return the substring of input text that contains 
    the specified sentiment.
    """
    query = f"question: {sentiment} context: {text}"
    input_ids = tokenizer.encode(
        query, 
        return_tensors="pt", 
        add_special_tokens=True).to(device)
    generated_ids = model.generate(
        input_ids=input_ids, 
        num_beams=1, 
        max_length=80).squeeze()
    predicted_span = tokenizer.decode(
        generated_ids, 
        skip_special_tokens=True, 
        clean_up_tokenization_spaces=True)
    return predicted_span

In [None]:
text = "You're a nice person, but your feet stink."
get_sentiment_span(text, "positive")


In [None]:
get_sentiment_span(text, "negative")
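
Since the model is expected to return a substring of the input, it can help to locate that substring by character offset before looking at attention weights. The following is a minimal sketch under that assumption; `locate_span` is a hypothetical helper, and it does exact matching only, so a generated span that differs from the input in casing or spacing will not be found.

```python
def locate_span(text, span):
    """Return (start, end) character offsets of `span` in `text`, or None.

    Matching is exact after stripping surrounding whitespace from `span`.
    """
    span = span.strip()
    if not span:
        return None
    start = text.find(span)
    if start == -1:
        return None
    return start, start + len(span)


text = "You're a nice person, but your feet stink."
# The span below is the expected model output quoted in the description.
offsets = locate_span(text, "your feet stink.")
print(offsets)
print(text[offsets[0]:offsets[1]])
```

In the notebook, the second argument would be the string returned by `get_sentiment_span(text, "negative")`.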

# Visualizing attention weights with BertViz

The T5 model has 12 encoder layers and 12 decoder layers, which together provide three kinds of attention:

1. encoder self-attention (in each encoder layer)
2. decoder self-attention (in each decoder layer)
3. cross-attention (in each decoder layer, attending over the encoder output)

Each attention mechanism has 12 heads per layer, so each of the three kinds has 144 sets of attention weights, one for each choice of layer and attention head. As mentioned above, we'll use the [BertViz](https://github.com/jessevig/bertviz) library to view the weights of the attention heads.
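
The head counts above can be reproduced from the model configuration. The sketch below assumes a T5-style config exposing `num_layers`, `num_decoder_layers`, and `num_heads`; `count_attention_maps` is a hypothetical helper, and a stand-in object with the T5-base values is used so the example runs without loading the model.

```python
from types import SimpleNamespace


def count_attention_maps(config):
    """Count (layer, head) pairs for each kind of attention.

    Encoder self-attention lives in the encoder layers; decoder
    self-attention and cross-attention both live in the decoder layers.
    """
    encoder_maps = config.num_layers * config.num_heads
    decoder_maps = config.num_decoder_layers * config.num_heads
    return {
        "encoder self-attention": encoder_maps,
        "decoder self-attention": decoder_maps,
        "cross-attention": decoder_maps,
    }


# Stand-in for model.config, using the layer and head counts stated above.
t5_base_like = SimpleNamespace(num_layers=12, num_decoder_layers=12, num_heads=12)
counts = count_attention_maps(t5_base_like)
print(counts)
print(sum(counts.values()))
```

In the notebook, passing `model.config` instead of the stand-in should give the same numbers if the checkpoint keeps the T5-base architecture.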

In [None]:
from bertviz import head_view


In [None]:
def _attention_outputs(text, sentiment):
    """
    Run one forward pass on the query and return the model
    output (including attention weights) and the query tokens.
    """
    query = f"question: {sentiment} context: {text}"
    input_ids = tokenizer.encode(
        query,
        return_tensors="pt",
        add_special_tokens=True).to(device)

    # The query is also passed as the decoder input, so all three
    # attention types are square and share one set of token labels.
    # These weights are therefore not the ones used during generate().
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            decoder_input_ids=input_ids,
            output_attentions=True,
            return_dict=True)

    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    return output, tokens


def view_cross_attn_heads(text, sentiment, layer=None, heads=None):
    """Show cross-attention weights (decoder attending to encoder)."""
    output, tokens = _attention_outputs(text, sentiment)
    head_view(output.cross_attentions, tokens, layer=layer, heads=heads)


def view_decoder_attn_heads(text, sentiment, layer=None, heads=None):
    """Show decoder self-attention weights."""
    output, tokens = _attention_outputs(text, sentiment)
    head_view(output.decoder_attentions, tokens, layer=layer, heads=heads)


def view_encoder_attn_heads(text, sentiment, layer=None, heads=None):
    """Show encoder self-attention weights."""
    output, tokens = _attention_outputs(text, sentiment)
    head_view(output.encoder_attentions, tokens, layer=layer, heads=heads)

To start, we want to see whether any of the attention heads show the word "positive" in the input text attending to tokens in the extracted subsequence.
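
Inspecting 144 heads by eye is slow, so a rough numeric pass can shortlist heads worth viewing. The sketch below is an assumption-laden illustration rather than the notebook's method: `rank_heads_by_span_attention` is a hypothetical helper that sums the attention a single query token places on a set of span tokens, and the toy weights and token positions are made up. It expects the layout `attentions[layer][batch][head][query][key]`, which matches the per-layer tuples returned when `output_attentions=True`.

```python
def rank_heads_by_span_attention(attentions, query_index, span_indices, top_k=5):
    """Rank (layer, head) pairs by attention mass from one token to a span.

    Returns up to `top_k` tuples of (mass, layer, head), largest mass first.
    Only the first item in the batch is inspected.
    """
    scores = []
    for layer, layer_attn in enumerate(attentions):
        for head, head_attn in enumerate(layer_attn[0]):
            row = head_attn[query_index]
            mass = sum(float(row[j]) for j in span_indices)
            scores.append((mass, layer, head))
    scores.sort(key=lambda item: item[0], reverse=True)
    return scores[:top_k]


# Toy weights for three made-up tokens: ["positive", "nice", "stink"].
uniform = [[1 / 3, 1 / 3, 1 / 3]] * 3
focused = [[0.1, 0.8, 0.1], [1 / 3, 1 / 3, 1 / 3], [1 / 3, 1 / 3, 1 / 3]]
toy_attentions = [
    [[uniform, uniform]],  # layer 0: batch of one, two heads
    [[uniform, focused]],  # layer 1: head 1 attends from token 0 to token 1
]

top = rank_heads_by_span_attention(
    toy_attentions, query_index=0, span_indices=[1], top_k=1)
print(top)
```

With real model output, `query_index` would be the position of the sentiment word in the token list and `span_indices` the positions of the extracted span's tokens; both depend on how the tokenizer splits the query, so they should be read off the token list rather than assumed.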

In [None]:
text = "You're a nice person, but your feet stink."

In [None]:
view_encoder_attn_heads(text, "positive")

In [None]:
view_decoder_attn_heads(text, "positive")

In [None]:
view_cross_attn_heads(text, "positive")

In [None]:
view_encoder_attn_heads(text, "positive", layer=6, heads=[11])

In [None]:
view_encoder_attn_heads(text, "negative")

In [None]:
view_cross_attn_heads(text, "negative")

Now we'll look at an example with positive *and* negative sentiment.

In [None]:
text = "It was the best of times, it was the worst of times."

In [None]:
get_sentiment_span(text, "positive")

In [None]:
get_sentiment_span(text, "negative")

In [None]:
view_encoder_attn_heads(text, "positive")

In [None]:
view_encoder_attn_heads(text, "negative")

In [None]:
view_decoder_attn_heads(text, "positive")

In [None]:
view_decoder_attn_heads(text, "negative")

In [None]:
view_cross_attn_heads(text, "positive")

In [None]:
view_cross_attn_heads(text, "negative")