In [3]:
import torch
from transformers import AutoTokenizer, AutoModel, utils
utils.logging.set_verbosity_error()  # Suppress standard warnings
from bertviz import model_view, head_view
from scipy.linalg import toeplitz


model_name = "microsoft/xtremedistil-l12-h384-uncased"  # Find popular HuggingFace models here: https://huggingface.co/models
model = AutoModel.from_pretrained(model_name, output_attentions=True)  # Configure model to return attention values
tokenizer = AutoTokenizer.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm


## Simple, full attention
"The cat sat on the mat"

In [2]:
input_text = "The cat sat on the mat"  
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view

<IPython.core.display.Javascript object>

In [3]:
t = torch.stack([attention[i].mean(dim=[0,1]) for i in range(len(attention))]).mean(dim=0)
t_a = tuple([torch.unsqueeze(torch.unsqueeze(t, 0).repeat(len(attention), 1, 1),0) for _ in range(len(attention))])
head_view(t_a, tokens)  # Display model view

<IPython.core.display.Javascript object>

## Anaphora resolution type 1
"The cat sat on the mat, because it was tired"\
*(it = cat)*

Removing [CLS] and [SEP] tokens and renormalizing, because we want the between words attention

In [4]:
input_text = "The cat sat on the mat, because it was tired"  
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view

<IPython.core.display.Javascript object>

In [5]:
t = torch.stack([attention[i].mean(dim=[0,1]) for i in range(len(attention))]).mean(dim=0)
t = t[1:-1, 1:-1]
t = (t*(t.sum(dim=1)**-1)[:,None])
t_a = tuple([torch.unsqueeze(torch.unsqueeze(t, 0).repeat(len(attention), 1, 1),0) for _ in range(len(attention))])
head_view(t_a, tokens[1:-1])

<IPython.core.display.Javascript object>

## Anaphora resolution type 2
"The cat sat on the mat, because it was soft"\
*(it = mat)*

These anaphorae resolution don't seem to work too much for this model

In [6]:
input_text = "The cat sat on the mat, because it was soft"
inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view

<IPython.core.display.Javascript object>

In [7]:
t = torch.stack([attention[i].mean(dim=[0,1]) for i in range(len(attention))]).mean(dim=0)
t = t[1:-1, 1:-1]
t = (t*(t.sum(dim=1)**-1)[:,None])
t_a = tuple([torch.unsqueeze(torch.unsqueeze(t, 0).repeat(len(attention), 1, 1),0) for _ in range(len(attention))])
head_view(t_a, tokens[1:-1])

<IPython.core.display.Javascript object>

## Anaphora vs no anaphora

In [4]:
def compute_attentionCost(t, strategy='mean'):
    """Computes the min cost of the flow of attention across the linear graph of tokens.

    Args:
        t (tensor): dim=[num_layers, num_heads, num_tokens, num_tokens]

    Returns:
        float: cost of the flow to match attentional and linear graph
    """
    if strategy == 'mean':
        t = t.mean(dim=[0,1])
    elif strategy == 'max':
        t = t.max(dim=[0,1])
    else:
        raise ValueError(f"Invalid strategy: {strategy}")
    t_steps = torch.tensor(toeplitz(torch.arange(len(t)), torch.arange(len(t))))
    t_costs = t*t_steps
    return t_costs.sum().item()

In [27]:
input_text = "Ludwig van Beethoven was a famous German composer."
# input_text = "Alice is eating pizza. I like her."

inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
outputs = model(inputs)  # Run model
attention = outputs[-1]  # Retrieve attention from model outputs

t = torch.stack(attention)[:,0,:,1:-1,1:-1]
t = t*t.sum(dim=3, keepdim=True)**-1
assert t.sum(dim=3).allclose(torch.tensor(1.0))
print(t.shape)
cost = compute_attentionCost(t)
n_token = len(inputs[0])-2
print(f"Number of tokens: {n_token}\nt_costs.sum(): {cost:.2f}, t_costs.sum()/len(t_costs): {cost/n_token:.2f}")

torch.Size([12, 12, 9, 9])
Number of tokens: 9
t_costs.sum(): 22.18, t_costs.sum()/len(t_costs): 2.46


In [28]:
tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
model_view(attention, tokens)  # Display model view

<IPython.core.display.Javascript object>

In [29]:
t = torch.stack([attention[i].mean(dim=[0,1]) for i in range(len(attention))]).mean(dim=0)
t = t[1:-1, 1:-1]
t = (t*(t.sum(dim=1)**-1)[:,None])
t_a = tuple([torch.unsqueeze(torch.unsqueeze(t, 0).repeat(len(attention), 1, 1),0) for _ in range(len(attention))])
head_view(t_a, tokens[1:-1])

<IPython.core.display.Javascript object>

# <>

In [35]:
t = torch.stack(attention)[:,0,:,1:-1,1:-1]
t = t*t.sum(dim=3, keepdim=True)**-1
compute_attentionCost(t)

tensor(97.0174, grad_fn=<SumBackward0>)

In [13]:
attention[0][0,0].sum(axis=1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)

In [14]:
model.encoder.layer[1]

BertLayer(
  (attention): BertAttention(
    (self): BertSdpaSelfAttention(
      (query): Linear(in_features=384, out_features=384, bias=True)
      (key): Linear(in_features=384, out_features=384, bias=True)
      (value): Linear(in_features=384, out_features=384, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=384, out_features=384, bias=True)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=384, out_features=1536, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=1536, out_features=384, bias=True)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [15]:
model.encoder.layer[0]

BertLayer(
  (attention): BertAttention(
    (self): BertSdpaSelfAttention(
      (query): Linear(in_features=384, out_features=384, bias=True)
      (key): Linear(in_features=384, out_features=384, bias=True)
      (value): Linear(in_features=384, out_features=384, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (output): BertSelfOutput(
      (dense): Linear(in_features=384, out_features=384, bias=True)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
  )
  (intermediate): BertIntermediate(
    (dense): Linear(in_features=384, out_features=1536, bias=True)
    (intermediate_act_fn): GELUActivation()
  )
  (output): BertOutput(
    (dense): Linear(in_features=1536, out_features=384, bias=True)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [16]:
model.encoder.layer[0].attention

BertAttention(
  (self): BertSdpaSelfAttention(
    (query): Linear(in_features=384, out_features=384, bias=True)
    (key): Linear(in_features=384, out_features=384, bias=True)
    (value): Linear(in_features=384, out_features=384, bias=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (output): BertSelfOutput(
    (dense): Linear(in_features=384, out_features=384, bias=True)
    (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [17]:
list(model.modules())

[BertModel(
   (embeddings): BertEmbeddings(
     (word_embeddings): Embedding(30522, 384, padding_idx=0)
     (position_embeddings): Embedding(512, 384)
     (token_type_embeddings): Embedding(2, 384)
     (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
     (dropout): Dropout(p=0.1, inplace=False)
   )
   (encoder): BertEncoder(
     (layer): ModuleList(
       (0-11): 12 x BertLayer(
         (attention): BertAttention(
           (self): BertSdpaSelfAttention(
             (query): Linear(in_features=384, out_features=384, bias=True)
             (key): Linear(in_features=384, out_features=384, bias=True)
             (value): Linear(in_features=384, out_features=384, bias=True)
             (dropout): Dropout(p=0.1, inplace=False)
           )
           (output): BertSelfOutput(
             (dense): Linear(in_features=384, out_features=384, bias=True)
             (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
             (dropout): Dropou