<a href="https://colab.research.google.com/github/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/transformer_interp/logit_lens_probing_atlas_old.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# GPT-2 interpretability

## Logit lens

### First step: understanding Logit lens

Read: https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens

### Second step: Reproducing the results

Reimplement the Logit lens in a minimal way by reproducing the figure below.
This exercise is deliberately unguided, because being able to use the `transformers` library autonomously is very important.

Resources if you are stuck:
- Read about hooks here: https://pytorch.org/tutorials/beginner/former_torchies/nnft_tutorial.html#forward-and-backward-function-hooks
- `pip install transformer-utils` and use the function `_plot_logit_lens`: https://github.dev/nostalgebraist/transformer-utils/tree/main/src/transformer_utils/logit_lens
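
As a warm-up for the hooks tutorial, here is a minimal sketch of the pattern this exercise needs: register a forward hook on each block, bind the block index with `functools.partial`, run one forward pass, then remove the hooks. The `nn.Sequential` below is a hypothetical toy stand-in for GPT-2's list of blocks (`gpt2.base_model.h`), and its modules return plain tensors.

```python
from functools import partial

import torch
from torch import nn

# Hypothetical toy stand-in for a stack of transformer blocks
blocks = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
captured = [None] * len(blocks)


def save_output(module, inputs, output, idx):
    # Forward hooks are called as hook(module, inputs, output);
    # `idx` is bound in advance with functools.partial
    captured[idx] = output.detach()


handles = [
    block.register_forward_hook(partial(save_output, idx=i))
    for i, block in enumerate(blocks)
]
try:
    with torch.no_grad():
        final = blocks(torch.randn(2, 4))
finally:
    # Always remove the hooks, otherwise they keep firing on later forward passes
    for handle in handles:
        handle.remove()

print([tuple(t.shape) for t in captured])  # → [(2, 4), (2, 4), (2, 4)]
```

The `try`/`finally` mirrors what you want in the real exercise: even if the forward pass crashes, the hooks do not stay attached to the model.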


You should obtain this figure:


![results from the logit lens](https://github.com/EffiSciencesResearch/ML4G-2.0/blob/master/workshops/transformer_interp/results.png?raw=true)

In [None]:
# Setup: Don't read, just run

try:
    import google.colab

    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    # Install packages
    %pip install transformers jaxtyping einops typeguard==2.13.3 -q

    !wget -q https://github.com/EffiSciencesResearch/ML4G-2.0/archive/refs/heads/master.zip
    !unzip -o /content/master.zip 'ML4G-2.0-master/workshops/transformer_interp/*'
    !mv --no-clobber ML4G-2.0-master/workshops/transformer_interp/* .
    !rm -r ML4G-2.0-master

    print("Imports & installations complete!")

else:
    from IPython import get_ipython

    ipython = get_ipython()
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

In [None]:
from functools import partial

import einops
import torch
import transformers
from jaxtyping import Float, Int, jaxtyped
from torch import Tensor
from typeguard import typechecked

from utils import plot_logit_lens_low_level

Hints:
- GPT-2 has tied embeddings, so the embedding and unembedding matrices are the same. You can access them with `model.base_model.wte.weight`
- Do not forget to apply the final LayerNorm to normalise the residual stream before projecting it with the unembedding matrix (and hence before the softmax). You can access it with `model.base_model.ln_f`
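
Put together, the two hints say that the logit lens reads the residual stream `h_l` after layer `l` as `logits_l = ln_f(h_l) @ W_E.T`, where `W_E` is the `(vocab, d_model)` token-embedding matrix. The sketch below checks this on small random modules (toy sizes, not GPT-2's): a bias-free output layer whose weight is tied to the embedding gives the same logits as the explicit product. GPT-2's `lm_head` is likewise a bias-free linear layer.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, d_model, n_tokens = 10, 8, 5  # toy sizes, not GPT-2's

wte = nn.Embedding(vocab_size, d_model)               # token embedding W_E: (vocab, d_model)
lm_head = nn.Linear(d_model, vocab_size, bias=False)
lm_head.weight = wte.weight                           # tie the unembedding to the embedding
ln_f = nn.LayerNorm(d_model)                          # stand-in for the final LayerNorm

h = torch.randn(n_tokens, d_model)                    # a fake residual stream at some layer

with torch.no_grad():
    via_head = lm_head(ln_f(h))                       # what the model does after its last block
    via_matmul = ln_f(h) @ wte.weight.T               # logit lens: ln_f(h) @ W_E.T

print(via_head.shape)  # → torch.Size([5, 10])
```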

In [None]:
# Loading GPT2 and its tokenizer
gpt2 = transformers.AutoModelForCausalLM.from_pretrained("gpt2")
gpt2_tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
gpt2.eval()

# Looking at the model to see the name of the different layers
print(gpt2)

In [None]:
@jaxtyped(typechecker=typechecked)
def plot_logit_lens(
    per_layer_logits: Float[Tensor, "layer nb_tokens vocab=50257"],
    per_layer_token_to_show: Int[Tensor, "layer nb_tokens"],
    input_ids: Int[Tensor, "batch=1 nb_tokens"],
    tokenizer=gpt2_tokenizer,
):
    plot_logit_lens_low_level(
        per_layer_logits.detach(),
        per_layer_token_to_show.detach(),
        per_layer_logits.softmax(dim=-1).detach(),
        tokenizer,
        # Hack: add the end-of-text token to avoid crash in _plot_logit_lens
        input_ids=torch.cat([input_ids, torch.tensor([[50256]])], dim=1),
        # input_ids=input_ids,
        start_ix=0,
        layer_names=None,
        probs=True,
    )

In [None]:
prompt = "Happy birthday to you, happy birthday to"

...  # Implement logit lens

plot_logit_lens(
    per_layer_logits,
    per_layer_token_to_show,
    input_ids,
)

<details>
  <summary>Hint: steps</summary>

```python
# 1. Define a hook that stores the output of the layer
# 2. Add the hook to each layer, use partial to pass the layer index
# 3. Run the model on the input, then remove the hooks
# 4. For each layer
# 4.1. Normalize the output using the final layer norm
# 4.2. Compute the word distribution using the word embeddings
# 4.3. Find the most likely token
```
</details>

<details>
  <summary>Click to see the solution</summary>

```python
n_layers_gpt = len(gpt2.base_model.h)
outputs = [None] * n_layers_gpt

input_ids = gpt2_tokenizer.encode(prompt, return_tensors="pt")

# 1. Define a hook that stores the output of the layer
def memorize_output_layer_hook(self, input, output, layer):
    # Remark: the `global` keyword is not necessary, because we are modifying 
    # the content of a list. `global` would have been necessary if we were
    # overwriting the list (i.e. outputs = ...)
    outputs[layer] = output[0].detach()


# 2. Add the hook to each layer, use partial to pass the layer index
handles = [
    gpt_block.register_forward_hook(partial(memorize_output_layer_hook, layer=layer))
    for layer, gpt_block in enumerate(gpt2.base_model.h)
]

# 3. Run the model on the input, then remove the hooks
try:
    with torch.no_grad():
        gpt2(input_ids)
finally:
    for handle in handles:
        handle.remove()


last_layer_norm = gpt2.base_model.ln_f
word_embeddings = gpt2.base_model.wte.weight


per_layer_token_to_show = []
per_layer_logits = []

# 4. For each layer
for layer, output in enumerate(outputs):
    # 4.1. Normalize the output using the final layer norm
    normalized_output = last_layer_norm(output)

    # 4.2. Compute the word distribution using the word embeddings
    word_distribution = einops.einsum(
        normalized_output, word_embeddings, 
        "batch token d_model, vocab d_model -> token vocab"
    )
    # 4.3. Find the most likely token
    best_token = torch.argmax(word_distribution, dim=-1)
    output_text = gpt2_tokenizer.decode(best_token)
    print(output_text)
    
    per_layer_token_to_show.append(best_token) 
    per_layer_logits.append(word_distribution) 
    

per_layer_logits = torch.stack(per_layer_logits)
per_layer_token_to_show = torch.stack(per_layer_token_to_show)
```
</details>

## Probing

Watch this YouTube video for an introduction to probing: https://www.youtube.com/watch?v=HJn-OTNLnoE

We will use the 20 Newsgroups dataset (loaded with `fetch_20newsgroups`) and classify each post according to the newsgroup it was posted to.
We will implement a small probe and analyse each layer of GPT-2. Which layer contains most of the information we are interested in?

Questions: 
- What is your strategy to use the internal states of gpt-2 as features for classification?
- Propose 2 other strategies that won't work.
- Try to predict the score of each layer at classifying the fetch_20newsgroups dataset.
- Implement and check your prediction.

Bonus: read this: https://arxiv.org/pdf/1704.01444.pdf
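
For the first question, two common ways to turn a `(batch, seq, d_model)` activation into one feature vector per post are the hidden state of the last real token and the mean over real tokens. A minimal sketch with two hypothetical helpers, assuming right padding and an attention mask with 1 for real tokens (GPT-2's tokenizer has no padding token by default, so batching would need one, for example `tokenizer.pad_token = tokenizer.eos_token`):

```python
import torch


def last_token_pool(hidden, attention_mask):
    # hidden: (batch, seq, d_model); attention_mask: (batch, seq), 1 for real tokens.
    # Assumes right padding, so the last real token sits at index (length - 1).
    last_idx = attention_mask.sum(dim=1) - 1
    return hidden[torch.arange(hidden.shape[0]), last_idx]


def mean_pool(hidden, attention_mask):
    # Average only over real tokens so that padding does not dilute the features.
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)


hidden = torch.arange(2 * 3 * 2, dtype=torch.float32).reshape(2, 3, 2)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # the second sequence ends with one padding token
print(last_token_pool(hidden, mask))  # rows [4, 5] and [8, 9]
print(mean_pool(hidden, mask))        # rows [2, 3] and [7, 8]
```

In a causal model like GPT-2, only the last token has attended to the whole post, which is why last-token pooling is a natural default here.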



In [None]:
from sklearn.datasets import fetch_20newsgroups

categories = ["alt.atheism", "soc.religion.christian", "comp.graphics", "sci.med"]

twenty_train = fetch_20newsgroups(
    subset="train", categories=categories, shuffle=True, random_state=42
)

In [None]:
print(twenty_train.data[0])

In [None]:
twenty_train.target[0]

In [None]:
twenty_train.target_names[twenty_train.target[0]]

In [None]:
len(twenty_train.target)

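Documentation of the model outputs (the `hidden_states` field is filled when the model is called with `output_hidden_states=True`): https://huggingface.co/docs/transformers/main_classes/output#transformers.modeling_outputs.CausalLMOutputWithCrossAttentions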
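
A minimal sketch of reading per-layer hidden states through `output_hidden_states=True`, using a tiny randomly initialised GPT-2 (hypothetical sizes) and random token ids so that it runs without downloading weights. `out.hidden_states` holds `n_layer + 1` tensors: the embedding output followed by one per block; in the Hugging Face GPT-2 implementation the last entry is taken after the final LayerNorm `ln_f`. With real posts, tokenize with truncation (e.g. `truncation=True, max_length=1024`), since posts may be longer than GPT-2's 1024-token context.

```python
import torch
import transformers

# Tiny randomly initialised GPT-2 (hypothetical sizes): no weights are downloaded
config = transformers.GPT2Config(
    vocab_size=100, n_positions=32, n_embd=16, n_layer=3, n_head=2
)
model = transformers.GPT2LMHeadModel(config).eval()

# Fake token ids standing in for a tokenized post, shape (batch=1, seq=7)
input_ids = torch.randint(0, config.vocab_size, (1, 7))
with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

# Embedding output + one entry per block, each of shape (1, seq, n_embd)
print(len(out.hidden_states))  # → 4

# Last-token features for every block output, skipping the embedding entry
last_token = torch.stack([h[0, -1] for h in out.hidden_states[1:]])
print(last_token.shape)  # → torch.Size([3, 16])
```

With the real `gpt2`, the same slicing yields a `(12, 768)` array per post, which matches one row `hidden_states[i]` of the array allocated below.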

In [None]:
import numpy as np
import torch
from tqdm.auto import tqdm

torch.cuda.empty_cache()
embed_dim = 768
n_layers = 12
N = len(twenty_train.data)

# We only keep the hidden state of the last token of each post.
# Preallocate a single array: accumulating per-post tensors makes memory explode.
hidden_states = np.zeros((N, n_layers, embed_dim))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

activations = np.zeros((N, n_layers, embed_dim))

gpt2 = gpt2.to(device)
gpt2.eval()


# Fill hidden_states
...

In [None]:
np.save("hidden_states", hidden_states)
hidden_states.shape

In [None]:
# If you do not have a GPU, use this line instead.
# hidden_states = np.load("hidden_states.npy")

In [None]:
# Check the documentation of sklearn and use those imports to score each layer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

scores = []
for layer in range(n_layers):
    ...
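
One possible way to fill in the loop, sketched on synthetic data: for each layer, hold out a test split, fit a standardised logistic-regression probe on the training part, and record its test accuracy. The arrays `features` and `y_toy` are made up, with a class signal that grows with depth; on the real data you would slice `hidden_states[:, layer, :]` and use `twenty_train.target` instead.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)
n_docs, n_layers_toy, dim = 400, 4, 16
y_toy = rng.integers(0, 4, size=n_docs)

# Synthetic activations: pure noise, plus a class-specific bump that grows with depth
features = rng.normal(size=(n_docs, n_layers_toy, dim))
for layer in range(n_layers_toy):
    features[np.arange(n_docs), layer, y_toy] += 1.5 * layer

scores_toy = []
for layer in range(n_layers_toy):
    X_tr, X_te, y_tr, y_te = train_test_split(
        features[:, layer, :], y_toy, test_size=0.25, random_state=0, stratify=y_toy
    )
    # Standardise, then fit a linear probe; score on held-out posts only
    probe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
    probe.fit(X_tr, y_tr)
    scores_toy.append(probe.score(X_te, y_te))

print(scores_toy)  # accuracy rises from near chance (0.25) towards 1.0
```

Scoring on a held-out split matters: with 768 features and a few thousand posts, training accuracy alone can look high even for layers that carry little information.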

In [None]:
import matplotlib.pyplot as plt

plt.plot(list(range(len(scores))), scores)

## Activation Atlas

With our dataset, we can also try to reproduce a minimal version of the activation atlas paper:
https://openai.com/blog/introducing-activation-atlases/


Questions:
- How would you implement dimensionality reduction?
- Install umap-learn.
- Visualize the UMAP projection of the best layer selected previously. Comment on what you see.

Bonus: Plot the sentences alongside the points in the UMAP plot. Check that everything makes sense.
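
For the first question, PCA is the simplest linear option: centre the data and project it onto its top principal directions, obtained here from an SVD. A minimal sketch on synthetic data (the helper `pca_2d` and the data are hypothetical); UMAP, used below, is a nonlinear alternative designed to preserve local neighbourhood structure.

```python
import numpy as np


def pca_2d(X):
    # Centre each feature, then project onto the top-2 right singular vectors
    X_centered = X - X.mean(axis=0)
    _, _, vt = np.linalg.svd(X_centered, full_matrices=False)
    return X_centered @ vt[:2].T


rng = np.random.default_rng(0)
# Synthetic data: 200 points in 50 dims whose variance lies mostly along 2 directions
latent = rng.normal(size=(200, 2)) * np.array([5.0, 3.0])
mixing = rng.normal(size=(2, 50))
X_demo = latent @ mixing + 0.1 * rng.normal(size=(200, 50))

X_2d = pca_2d(X_demo)
print(X_2d.shape)  # → (200, 2)
```

The first output column always carries at least as much variance as the second, since singular values come sorted in decreasing order.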

In [None]:
!pip install umap-learn -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

# Dimension reduction and clustering library
import umap

In [None]:
# Choose the best_layer
best_layer = ...
X = hidden_states[:, best_layer, :]
y = twenty_train.target

standard_embedding = umap.UMAP(random_state=42).fit_transform(X)
plt.scatter(
    standard_embedding[:, 0], standard_embedding[:, 1], c=y.astype(int), s=0.1, cmap="Spectral"
);