# Extracting "cognitive" signals from language models

This notebook will cover:

1. [Extracting surprisal](#1-extracting-surprisal)
2. [Extracting attention scores](#2-extracting-attention-scores)
3. [Extracting word embeddings](#3-extracting-word-embeddings)

We will use Hugging Face's [transformers](https://huggingface.co/docs/transformers/en/index) library.

Each technique in this notebook needs only a single forward pass through the model, and the models are quite small, so you should be able to run everything on your laptop.

In [None]:
%pip install transformers torch matplotlib

## 1. Extracting surprisal

Surprisal is the negative log-probability of a word given its left-hand context:

$\text{surprisal}(w_i) = -\log P(w_i \mid w_1, w_2, \dots, w_{i-1})$
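The base of the logarithm only sets the unit: the natural logarithm (which `F.log_softmax` uses in the code below) gives surprisal in nats, while base 2 gives bits. A minimal plain-Python sketch of the formula; the function name `surprisal` is just for illustration:

```python
import math


def surprisal(p: float, base: float = math.e) -> float:
    """Surprisal of an event with probability p (nats by default, bits with base=2)."""
    return -math.log(p, base)


# A word the model treats as a coin flip (P = 0.5) carries 1 bit of surprisal
print(surprisal(0.5, base=2))
# Converting nats to bits: divide by ln(2)
print(surprisal(0.5) / math.log(2))
```

Lower probability means higher surprisal, and a word predicted with certainty (P = 1) has a surprisal of 0.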

In GPT-style transformer language models, attention to the right-hand context is masked for every token, so the output probabilities at each position are based only on the left-hand context: the output at position $i$ is the model's probability distribution over the token at position $i+1$. This means that we can do a single forward pass over the entire text and get the surprisal values for all tokens at once, instead of having to do a separate forward pass for every token.
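To see why masking makes this possible, here is a minimal sketch of causal masking applied to a matrix of raw attention scores. It is a simplified illustration, not GPT-2's actual implementation (which also scales the scores and uses several heads), and the function name `causal_attention_weights` is hypothetical. Scores for positions to the right of each token are set to $-\infty$ before the softmax, so they receive exactly zero weight:

```python
import torch


def causal_attention_weights(scores: torch.Tensor) -> torch.Tensor:
    """Turn a (seq_len x seq_len) matrix of raw attention scores into causally masked weights.

    Row i holds the weights that token i assigns to tokens 0..seq_len-1.
    """
    seq_len = scores.size(-1)
    # True strictly above the diagonal, i.e. at positions j > i (the right-hand context of token i)
    future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    # exp(-inf) = 0, so masked positions get exactly zero weight after the softmax
    masked_scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(masked_scores, dim=-1)


torch.manual_seed(0)
weights = causal_attention_weights(torch.randn(5, 5))
print(weights)  # lower-triangular; each row sums to 1 over the current and preceding tokens
```

Because row $i$ never puts weight on tokens after position $i$, the prediction made at position $i$ cannot "peek" at the token it is supposed to predict, which is what makes a single forward pass over the whole text valid.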

Let's test this with GPT-2:

In [None]:
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("gpt2")
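# Eager attention can return attention weights (needed for output_attentions=True in section 2)
model = AutoModelForCausalLM.from_pretrained("gpt2", attn_implementation="eager")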

# Encode the input text
input_text = "Words are flowing out like endless rain into a paper cup"
inputs = tokenizer(input_text, return_tensors="pt")

# Generate output logits and convert to log-probabilities
with torch.no_grad():
    outputs = model(**inputs)
logits = outputs.logits[0]
logprobs = F.log_softmax(logits, dim=-1)

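# Get the log-probabilities of the true next tokens and convert to surprisals.
# The prediction at position i is for the token at position i + 1, so we drop the last prediction
next_token_input_ids = inputs.input_ids[0, 1:]
next_token_logprobs = logprobs[:-1].gather(1, next_token_input_ids.unsqueeze(-1)).squeeze(-1)
surprisal = -next_token_logprobs  # in nats, since log_softmax uses the natural logarithm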

# Print the surprisal values
for token_id, surprisal_value in zip(next_token_input_ids, surprisal):
    token = tokenizer.decode(token_id)
    print(f"{token!r}\t{surprisal_value.item()}")

Note that we don't get a surprisal value for the first word. This is because the GPT-2 tokenizer does not add a beginning-of-sequence (BOS) token to the input, so there is no preceding token from which the model could predict the first word (*Words*). The first token the model predicts is therefore the second word in the sequence (*are*).
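If you do need a surprisal value for the first word, one common workaround is to prepend GPT-2's end-of-text token `<|endoftext|>`, which the tokenizer also exposes as `tokenizer.bos_token`, so that the first word is predicted from that token. This is a sketch under the assumption that `<|endoftext|>` is a reasonable stand-in for "start of text"; the helper name `token_surprisals` is hypothetical. It also makes the alignment used in the cell above explicit:

```python
import torch
import torch.nn.functional as F


def token_surprisals(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
    """Surprisal (in nats) of input_ids[1:], given logits of shape (seq_len, vocab_size)."""
    logprobs = F.log_softmax(logits, dim=-1)
    # The prediction at position i is for the token at position i + 1,
    # so the last prediction has no target and the first token has no prediction
    targets = input_ids[1:].unsqueeze(-1)
    return -logprobs[:-1].gather(1, targets).squeeze(-1)


# Usage with the GPT-2 model and tokenizer loaded above (assumption: "<|endoftext|>"
# serves as the context for the first word):
# bos_inputs = tokenizer(tokenizer.bos_token + input_text, return_tensors="pt")
# with torch.no_grad():
#     bos_logits = model(**bos_inputs).logits[0]
# surprisals = token_surprisals(bos_logits, bos_inputs.input_ids[0])
```

With the BOS token prepended, the first returned value now belongs to the first token of the text rather than to the second.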

Can you explain why certain words have high/low surprisal?
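Keep in mind that these values are per *token*, not per word: GPT-2 splits rarer words into several subword tokens. A common way to get word-level surprisal is to sum the surprisals of a word's tokens, which by the chain rule equals the surprisal of the whole token sequence. A minimal sketch, assuming GPT-2-style decoded tokens in which a leading space marks the start of a new word; the helper `word_surprisals` is hypothetical, and it attaches punctuation tokens to the preceding word:

```python
def word_surprisals(tokens, surprisals):
    """Sum token surprisals into word surprisals.

    Assumes GPT-2-style decoded tokens, where a token starting with a space
    begins a new word and any other token continues the previous word.
    """
    words = []  # list of [word, summed surprisal]
    for token, value in zip(tokens, surprisals):
        if token.startswith(" ") or not words:
            words.append([token.strip(), value])
        else:
            # -log P(t1, t2 | context) = -log P(t1 | context) - log P(t2 | context, t1)
            words[-1][0] += token
            words[-1][1] += value
    return [(word, total) for word, total in words]


# Made-up values, only to illustrate the grouping
print(word_surprisals([" end", "less", " rain"], [2.0, 1.5, 3.0]))
# [('endless', 3.5), ('rain', 3.0)]
```

With the variables from the surprisal cell, you could call `word_surprisals([tokenizer.decode(t) for t in next_token_input_ids], surprisal.tolist())`.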

## 2. Extracting attention scores

Transformer language models usually have multiple layers and multiple attention heads within each layer. So for each token in the input sequence, we get a whole bunch of attention scores, which makes them harder to interpret.

GPT-2 has 12 layers and 12 attention heads per layer:

In [None]:
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True)

for layer, attentions in enumerate(outputs.attentions, 1):
    print(f"Layer {layer}: {attentions.size()}")

The first dimension in the tensors is the batch dimension. Can you explain what the other dimensions mean?

To simplify things, let's average the attention scores across all layers and heads and visualize them:

In [None]:
mean_attentions = torch.concatenate(outputs.attentions, dim=0).mean(dim=(0, 1))
mean_attentions.size()

In [None]:
import matplotlib.pyplot as plt

plt.imshow(mean_attentions.numpy(), cmap='gray')
plt.colorbar()
tokens = [tokenizer.decode(token_id) for token_id in inputs.input_ids[0]]
plt.xticks(ticks=range(len(tokens)), labels=tokens, rotation=90)
plt.yticks(ticks=range(len(tokens)), labels=tokens)
plt.show()

Try to answer the following questions:

- Which token(s) receive(s) the most attention?
- Why is the upper right half completely black?
- Do the attention patterns look similar in all layers and attention heads? Can you find heads or layers with more interesting patterns?
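To look for individual heads or layers with more interesting patterns, you can select a single head, or average over the heads of one layer only, instead of averaging over everything. A minimal sketch; the helpers `head_attention`, `layer_mean_attention` and `plot_attention` are hypothetical, indices are 0-based, and a batch size of 1 is assumed as in the cells above:

```python
import matplotlib.pyplot as plt
import torch


def head_attention(attentions, layer: int, head: int) -> torch.Tensor:
    """Select the (seq_len x seq_len) attention matrix of one head from `outputs.attentions`."""
    return attentions[layer][0, head]  # 0 selects the only item in the batch


def layer_mean_attention(attentions, layer: int) -> torch.Tensor:
    """Average the attention matrices of all heads in one layer."""
    return attentions[layer][0].mean(dim=0)


def plot_attention(matrix: torch.Tensor, tokens, title: str = "") -> None:
    """Plot an attention matrix with token labels, in the same style as the averaged plot."""
    plt.imshow(matrix.numpy(), cmap="gray")
    plt.colorbar()
    plt.xticks(ticks=range(len(tokens)), labels=tokens, rotation=90)
    plt.yticks(ticks=range(len(tokens)), labels=tokens)
    plt.title(title)
    plt.show()


# Usage with `outputs` and `tokens` from the cells above:
# plot_attention(head_attention(outputs.attentions, 4, 0), tokens, "Layer 5, head 1")
# plot_attention(layer_mean_attention(outputs.attentions, 0), tokens, "Layer 1, mean over heads")
```

Looping over layers and heads with these helpers makes it easy to compare many attention patterns side by side.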

## 3. Extracting word embeddings

As in most modern NLP models that process text input, the very first layer in a transformer model is a regular embedding layer:

In [None]:
model

Can you figure out what this model's vocabulary size is? And what is its embedding size?

While we *could* use the embeddings produced by the first layer as word representations, they are not a very good representation of a word's meaning (can you explain why?). What people usually do instead is use the hidden states of the transformer layers. Here's how you can access these:

In [None]:
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True)

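for layer, hidden_states in enumerate(outputs.hidden_states):
    # Entry 0 is the embedding layer's output; entry i comes after transformer layer i
    name = "Embeddings" if layer == 0 else f"Layer {layer}"
    print(f"{name}: {hidden_states.size()}")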

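Again, the first dimension is the batch dimension. Note that there is one more entry than there are layers: the first entry is the output of the embedding layer, and each following entry is the output of one transformer layer.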
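The first entry of `outputs.hidden_states` is the output of the embedding layer, not of a transformer layer. In GPT-2 it is the sum of a token embedding and a learned position embedding, which you can check by recomputing it by hand. This sketch assumes the GPT-2 module layout in `transformers` (`model.transformer.wte` for token embeddings, `model.transformer.wpe` for position embeddings), no padding, and a model in evaluation mode so that dropout does nothing; the helper name `embedding_layer_output` is hypothetical:

```python
import torch


def embedding_layer_output(model, input_ids: torch.Tensor) -> torch.Tensor:
    """Recompute hidden_states[0] of a GPT-2 model: token embeddings + position embeddings.

    Assumes the transformers GPT-2 layout (model.transformer.wte / .wpe), no padding,
    and that the model is in eval mode, so dropout does nothing.
    """
    # Positions 0, 1, ..., seq_len - 1
    positions = torch.arange(input_ids.size(-1), device=input_ids.device).unsqueeze(0)
    token_embeds = model.transformer.wte(input_ids)
    position_embeds = model.transformer.wpe(positions)
    return token_embeds + position_embeds


# Usage with the variables from the cells above; compare the result with the first hidden state:
# with torch.no_grad():
#     manual = embedding_layer_output(model, inputs.input_ids)
# torch.allclose(manual, outputs.hidden_states[0])
```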

Now, which layer(s) you use for your representation depends on your use case. The early layers tend to contain less information about context, while the very last layer mainly contains information about the next word to be predicted. A common approach is to combine the last few layers, for example by averaging them (Table 7 in the original BERT paper compares several ways of combining layers: [Devlin et al., 2019](https://doi.org/10.18653/v1/N19-1423)):

In [None]:
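embeddings = outputs.hidden_states[-4:]  # Use the last 4 layers
# Concatenate along the batch dimension (batch size is 1) and average over layers: one vector per token
mean_embedding = torch.mean(torch.concatenate(embeddings), dim=0)
mean_embedding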
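The cell above averages the last four layers, but averaging is only one way to combine them. The feature-based experiments in Table 7 of Devlin et al. (2019) also include concatenating the last four layers, which gives a longer vector per token (four times the hidden size). A minimal sketch of both options; the helper name `combine_layers` is hypothetical and a batch size of 1 is assumed:

```python
import torch


def combine_layers(hidden_states, n_last: int = 4, method: str = "mean") -> torch.Tensor:
    """Combine the last `n_last` entries of `outputs.hidden_states` into one vector per token.

    Each entry has shape (1, seq_len, hidden_size); batch size 1 is assumed.
    Returns (seq_len, hidden_size) for "mean" and (seq_len, n_last * hidden_size) for "concat".
    """
    layers = [h[0] for h in hidden_states[-n_last:]]  # drop the batch dimension
    if method == "mean":
        return torch.stack(layers).mean(dim=0)
    if method == "concat":
        return torch.cat(layers, dim=-1)
    raise ValueError(f"unknown method: {method!r}")


# Usage with the variables from the cells above:
# combine_layers(outputs.hidden_states, method="concat").size()
```

Concatenation keeps each layer's information separate at the cost of a larger vector; averaging keeps the original hidden size.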

We have now extracted contextualized embeddings from GPT-2 (strictly speaking, one per *token*: a word split into several subword tokens gets several vectors). Since GPT-2 is a unidirectional language model, each embedding only contains information about its token and the left-hand context. To get bidirectionally contextualized embeddings, we would have to use a bidirectional model like BERT.
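You can check the unidirectionality claim directly: if two inputs share a prefix, GPT-2's hidden states for the prefix tokens should be identical up to floating-point noise, whatever follows. A sketch; `prefix_difference` is a hypothetical helper, and the second sentence in the usage comment is made up:

```python
import torch


def prefix_difference(model, ids_a: torch.Tensor, ids_b: torch.Tensor, layer: int = -1) -> float:
    """Largest absolute difference between two inputs' hidden states over their shared prefix.

    ids_a and ids_b have shape (1, seq_len); `layer` indexes `outputs.hidden_states`.
    """
    # Count how many leading tokens the two inputs have in common
    n_shared = 0
    max_len = min(ids_a.size(1), ids_b.size(1))
    while n_shared < max_len and ids_a[0, n_shared] == ids_b[0, n_shared]:
        n_shared += 1
    if n_shared == 0:
        raise ValueError("the inputs share no prefix")
    with torch.no_grad():
        h_a = model(input_ids=ids_a, output_hidden_states=True).hidden_states[layer][0, :n_shared]
        h_b = model(input_ids=ids_b, output_hidden_states=True).hidden_states[layer][0, :n_shared]
    return (h_a - h_b).abs().max().item()


# Usage with the GPT-2 model and inputs from above (the second sentence is made up):
# other = tokenizer("Words are flowing out like endless rain into a river", return_tensors="pt")
# prefix_difference(model, inputs.input_ids, other.input_ids)
```

For GPT-2 the result should be close to zero; for a bidirectional model you would expect a clearly larger value, because every token's representation also depends on the tokens to its right.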

In [None]:
# TODO: Use a BERT model to get bidirectionally contextualized embeddings