In [1]:
import torch
from transformers import AutoTokenizer
import math
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from transformers import AutoModel, AutoTokenizer
import torch
from IPython.display import display, HTML

  from .autonotebook import tqdm as notebook_tqdm


# Tools

In [2]:
tokenizer = AutoTokenizer.from_pretrained("tokenizer")

In [3]:
def to_tensor(sentence):
    tokens = tokenizer(sentence)
    return torch.tensor([tokens["input_ids"]])

In [4]:
def get_attention_map(model, x):
    attn_maps = []

    # Embed
    B, T = x.size()
    mask = (x != 0)
    x = model.emb_static(x) + model.emb_pos(torch.arange(0, T))
    
    for block in model.blocks:
        # Get the attention map
        B, T, C = x.size()
        qkv = block.attn.attn_matrix(x)
        q, k, v = qkv.split(block.attn.n_embd, dim=2)
        k = k.view(B, T, block.attn.n_head, C // block.attn.n_head).transpose(1, 2)
        q = q.view(B, T, block.attn.n_head, C // block.attn.n_head).transpose(1, 2)
        v = v.view(B, T, block.attn.n_head, C // block.attn.n_head).transpose(1, 2)
        scale_factor = 1 / math.sqrt(q.size(-1))
        attn_weight = q @ k.transpose(-2, -1) * scale_factor
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_maps.append(attn_weight)
        
        # Pass through block
        x = block(x, mask, "cpu")
        
    return torch.cat(attn_maps)

In [7]:
def colorize(words, color_array):
    # words is a list of words
    # color_array is an array of numbers between 0 and 1 of length equal to words
    cmap = matplotlib.cm.get_cmap('Greens')
    template = '<span class="barcode" style="color: black; background-color: {};">{}</span>'
    colored_string = ''
    for word, color in zip(words, color_array):
        color = matplotlib.colors.rgb2hex(cmap(color)[:3])
        colored_string += template.format(color, '&nbsp;' + word + '&nbsp;')
    return colored_string

In [8]:
def compare_models(attn_maps1, attn_maps2, sentence, selected):
    # Average attention scores over the sequence
    avg_attention_scores1 = attn_maps1[:, :, 0, :].mean(dim=[0, 1]).detach().numpy()
    avg_attention_scores2 = attn_maps2[:, :, 0, :].mean(dim=[0, 1]).detach().numpy()

    # Normalize the attention scores
    normalized_attention_scores1 = (
        (avg_attention_scores1 - avg_attention_scores1.min())
        / (avg_attention_scores1.max() - avg_attention_scores1.min())
    )
    normalized_attention_scores2 = (
        (avg_attention_scores2 - avg_attention_scores2.min())
        / (avg_attention_scores2.max() - avg_attention_scores2.min())
    )

    # Tokenize the text for display
    tokens = tokenizer.convert_ids_to_tokens(tokenizer(sentence)["input_ids"])
    for word in selected:
        if word not in tokens:
            return False, None, None, None, None
    indices = [idx for idx in range(len(tokens)) if tokens[idx] in selected]

    # Remove special tokens like [CLS] and [SEP]
    tokens = tokens[1:-1]
    normalized_attention_scores1 = normalized_attention_scores1[1:-1]
    normalized_attention_scores2 = normalized_attention_scores2[1:-1]

    colored_string1 = colorize(tokens, normalized_attention_scores1)
    colored_string2 = colorize(tokens, normalized_attention_scores2)
    
    return (
        True,
        colored_string1,
        colored_string2,
        [normalized_attention_scores1[idx-1] for idx in indices],
        [normalized_attention_scores2[idx-1] for idx in indices]
    )

# Data

In [9]:
sentences = [
    "The movie was fantastic",
    "I hated the movie",
    "The plot was boring",
    "I love this movie",
    "The plot was terrible",
    "This movie is great",
    "The scenes were dirty",
    "I'm satisfied with movie",
    "The DVD arrived late",
    "The subtitles work perfectly",
    "The movie was disappointing",
    "I enjoyed the movie",
    "The pacing is unreliable",
    "The cast were friendly",
    "The script is slow",
    "The movie was great",
    "The DVD was poor",
    "The plot was fascinating",
    "The set was sturdy",
    "The cinematography was ruined",
    "The documentary was engaging",
    "The DVD crashes often",
    "The scenes were delicious",
    "The DVD broke down",
    "The scenery was breathtaking",
    "The service was prompt",
    "The plot was predictable",
    "The tickets overpriced",
    "The service was excellent",
    "The projector overheats",
    "The theater is scenic",
    "The projector stopped",
    "The festival was vibrant",
    "The popcorn runs out",
    "The movie was fun",
    "The screening was delayed",
    "The impact was pleasant",
    "The streaming is unstable",
    "The snacks are fresh",
    "The DVD cracked",
    "The theater has selection",
    "The interface is difficult",
    "The cinema is spacious",
    "The equipment broke",
    "The staff are friendly",
    "The seats are uncomfortable",
    "The movie was heavenly",
    "The equipment is outdated",
    "The theater is well-kept",
    "The plot was confusing"
]

In [10]:
good_tokens = [
    set({"fantastic"}),
    set({"hated"}),
    set({"boring"}),
    set({"love"}),
    set({"terrible"}),
    set({"great"}),
    set({"dirty"}),
    set({"satisfied"}),
    set({"late"}),
    set({"perfectly"}),
    set({"disappointing"}),
    set({"excellent"}),
    set({"unreliable"}),
    set({"friendly"}),
    set({"slow"}),
    set({"great"}),
    set({"poor"}),
    set({"fascinating"}),
    set({"sturdy"}),
    set({"ruined"}),
    set({"engaging"}),
    set({"crashes"}),
    set({"delicious"}),
    set({"broke"}),
    set({"breathtaking"}),
    set({"prompt"}),
    set({"predictable"}),
    set({"overpriced"}),
    set({"excellent"}),
    set({"overheats"}),
    set({"scenic"}),
    set({"stopped"}),
    set({"vibrant"}),
    set({"quickly"}),
    set({"fun"}),
    set({"delayed"}),
    set({"pleasant"}),
    set({"unstable"}),
    set({"fresh"}),
    set({"cracked"}),
    set({"selection"}),
    set({"difficult"}),
    set({"spacious"}),
    set({"broke"}),
    set({"friendly"}),
    set({"uncomfortable"}),
    set({"heavenly"}),
    set({"outdated"}),
    set({"happy"}),
    set({"confusing"})
]

In [11]:
positive_review = [
    True,
    False,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    True,
    False,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False,
    True,
    False
]

# Get Results

In [12]:
result = torch.load("results/1.1/masked_first/1.pt", map_location=torch.device('cpu'))
model_1 = result["model"]
result = torch.load("results/2/masked_first/1.pt", map_location=torch.device('cpu'))
model_2 = result["model"]
colored1 = []
colored2 = []
score_list1 = []
score_list2 = []
good_sentences = []
important = []
review_rating = []

for idx in range(len(sentences)):
    sentence = sentences[idx]
    x = to_tensor(sentence)
    attn_maps_1 = get_attention_map(model_1, x)
    attn_maps_2 = get_attention_map(model_2, x)
    (
        good_parse,
        highlighted1,
        highlighted2,
        scores1,
        scores2,
    ) = compare_models(attn_maps_1, attn_maps_2, sentence, good_tokens[idx])
    if not good_parse:
        continue
    colored1.append(highlighted1)
    colored2.append(highlighted2)
    score_list1.append(scores1[0])
    score_list2.append(scores2[0])
    good_sentences.append(sentence)
    important.append(list(good_tokens[idx])[0])
    review_rating.append(positive_review[idx])

In [19]:
string = """<!DOCTYPE html>
<html>
<head>
<style>
table {
  font-family: arial, sans-serif;
  border-collapse: collapse;
  width: 100%;
}

td, th {
  border: 1px solid #dddddd;
  text-align: left;
  padding: 8px;
}

tr:nth-child(even) {
  background-color: #dddddd;
}
</style>
</head>
<body>

<table>
  <tr>
    <th>Label</th>
    <th>Optimal Token</th>
    <th>l_{1.1}-MD Token Selection</th>
    <th>l_{2}-MD Token Selection</th>
    <th>Better Selector</th>
  </tr>
"""

In [20]:
for idx in range(40):
    string += f"""
  <tr>
    <td>{"+" if review_rating[idx] else "-"}</td>
    <td>{important[idx]}</td>
    <td>{colored1[idx]}</td>
    <td>{colored2[idx]}</td>
    <td>{"=" if score_list1[idx] == score_list2[idx] else "1.1" if score_list1[idx] > score_list2[idx] else "2"}</td>
  </tr>
"""

In [21]:
string += """</table>

</body>
</html>"""

In [None]:
f = open("results/img/attention_map.html", "w")
f.write(string)
f.close()

In [123]:
print(f"Score for 1.1-MD = {torch.tensor(score_list1).mean().item()}")

tensor(0.7594)

In [124]:
print(f"Score for GD = {torch.tensor(score_list2).mean().item()}")

tensor(0.5533)