# Getting Started with TransformerLens (Easy-Transformer)

Welcome! This notebook walks through the basics of `TransformerLens`, the successor to `Easy-Transformer`. We will cover:
1.  **Installing** the library.
2.  **Loading a pre-trained model** (GPT-2 Small).
3.  **Running the model** and accessing internal activations.
4.  Performing a simple **ablation experiment** to understand the model's behavior.
5.  **Quantifying the effect** of the ablation.

Let's get started!

### 1. Installation

In [None]:
# First, let's install the library. TransformerLens is the new name for Easy-Transformer.
# The '!' allows us to run shell commands directly from the notebook.

!pip install transformer-lens

### 2. Loading a Model

Now that the library is installed, we can import it and load a pre-trained model. We'll use `gpt2-small` (12 layers with 12 attention heads each), which is small enough to experiment with quickly and is widely used for introductory interpretability work.

In [None]:
import torch
import transformer_lens

# Load the GPT-2 Small model from TransformerLens
# This automatically downloads the weights and sets up the model with all the hooks we need.
model = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")

### 3. Running the Model & Accessing Activations

Let's give the model some text and see what it predicts. The key feature of `TransformerLens` is the ability to easily cache and view all the intermediate activations inside the model. We can do this with the `run_with_cache` method.

# Sample text
prompt = "The quick brown fox jumps over the lazy"
print(f"Original prompt: '{prompt}'")

# Run the model and get both the final output (logits) and a cache of all internal activations.
original_logits, cache = model.run_with_cache(prompt)

# The logits tensor has shape [batch_size, sequence_position, vocabulary_size]
print("Logits tensor shape:", original_logits.shape)

# Let's find the model's prediction for the *next* token
last_token_logits = original_logits[0, -1, :]
predicted_token_index = torch.argmax(last_token_logits).item()
predicted_token = model.to_string([predicted_token_index])

print(f"Model's top prediction for the next token: '{predicted_token}'")

# The 'cache' is a dictionary-like object mapping activation names to their tensors.
# Let's see what the query vectors of head 7 in layer 4 look like.
# utils.get_act_name(activation_name, layer_index) builds the cache key,
# e.g. ("q", 4) -> "blocks.4.attn.hook_q".
key_to_check = transformer_lens.utils.get_act_name("q", 4)
activation_shape = cache[key_to_check][0, :, 7, :].shape # Batch, Seq Pos, Head Index, Head Dim

print(f"\nShape of query vector for layer 4, head 7: {activation_shape}")
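
The argmax above shows only the single most likely token. As a rough sketch of how logits turn into a ranked list of candidates, the snippet below applies a softmax to a made-up logit vector and keeps the top `k` entries. The vocabulary, the logit values, and the helper names `softmax` and `top_k` are all invented for illustration and do not come from the model. With the real logits tensor, `torch.topk` on the softmaxed last-position logits would presumably play the same role.

```python
import math
from typing import List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Convert logits to probabilities (subtracting the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: List[float], vocab: List[str], k: int) -> List[Tuple[str, float]]:
    """Return the k most probable (token, probability) pairs, most probable first."""
    probs = softmax(logits)
    ranked = sorted(zip(vocab, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


# Toy 4-token vocabulary and logits; these are NOT real GPT-2 values.
toy_vocab = [" dog", " cat", " fox", " river"]
toy_logits = [3.0, 1.5, 0.5, -1.0]

top_two = top_k(toy_logits, toy_vocab, k=2)
for token, prob in top_two:
    print(f"{token!r}: {prob:.3f}")
```

Looking at several candidates rather than one makes it easier to notice when an intervention reshuffles probability mass without changing the top token.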

### 4. Running an Ablation Experiment

**Ablation** is the process of removing or disabling a part of the model to see how it affects the output. This helps us understand what different components (like an attention head) are responsible for.

Here, we will "zero-ablate" a specific attention head. We'll write a **hook function** that intercepts the head's activation during the forward pass and replaces it with zeros. We can then see how this changes the model's prediction.
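
To make the idea of a hook concrete before applying it to the real model, here is a minimal pure-Python sketch. `ToyHookPoint`, `make_zero_head_hook`, and `toy_forward` are hypothetical names invented for this illustration. They are not part of TransformerLens and only mimic the general pattern: a hook point passes an activation through unchanged unless a registered function returns a replacement.

```python
from typing import Callable, List, Optional

# Toy activation: one row of numbers per attention head, i.e. shape [n_heads][d_head].
Activation = List[List[float]]
Hook = Callable[[Activation], Optional[Activation]]


class ToyHookPoint:
    """Hypothetical stand-in for a hook point: an identity step that hooks may edit."""

    def __init__(self) -> None:
        self.hooks: List[Hook] = []

    def __call__(self, activation: Activation) -> Activation:
        for hook in self.hooks:
            result = hook(activation)
            if result is not None:  # a hook may return a replacement value
                activation = result
        return activation


def make_zero_head_hook(head: int) -> Hook:
    """Build a hook that returns a copy of the activation with one head zeroed."""
    def hook(activation: Activation) -> Activation:
        return [
            [0.0] * len(row) if i == head else list(row)
            for i, row in enumerate(activation)
        ]
    return hook


def toy_forward(activation: Activation, hook_point: ToyHookPoint) -> float:
    """Toy 'forward pass': route per-head outputs through the hook point, then sum them."""
    hooked = hook_point(activation)
    return sum(sum(row) for row in hooked)


heads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # 3 heads, d_head = 2

clean_point = ToyHookPoint()
ablated_point = ToyHookPoint()
ablated_point.hooks.append(make_zero_head_hook(head=1))

clean_output = toy_forward(heads, clean_point)      # 1 + 2 + 3 + 4 + 5 + 6 = 21.0
ablated_output = toy_forward(heads, ablated_point)  # 21.0 - (3.0 + 4.0) = 14.0
print(clean_output, ablated_output)
```

The toy hook returns a modified copy, whereas the hook used on the real model edits the tensor in place and returns it. Either way, everything downstream sees the ablated value.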

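# We will target layer 8, attention head 10 (often written L8H10).
# Published circuit analyses of GPT-2 Small on the indirect object identification task
# report this head as an "S-Inhibition" head. Its role on this prompt is not established,
# so treat the choice as illustrative and try other (layer, head) pairs as well.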
LAYER_TO_ABLATE = 8
HEAD_TO_ABLATE = 10

def zero_ablate_head_hook(
    activation_value: torch.Tensor,
    hook: transformer_lens.hook_points.HookPoint
) -> torch.Tensor:
    """
    This function is our hook. It takes an activation and a hook object.
    It modifies the activation in-place to zero out a specific head.
    """
    print(f"Ablating Layer {LAYER_TO_ABLATE}, Head {HEAD_TO_ABLATE}...")
    # Shape of activation_value: [batch, seq_pos, head_index, d_head]
    activation_value[:, :, HEAD_TO_ABLATE, :] = 0.
    return activation_value

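# 'run_with_hooks' runs a forward pass with our hook function attached at the specified
# activation point and returns the resulting logits. The hook is temporary: it is removed
# again once the call finishes, so later runs of the model are unaffected.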
ablated_logits = model.run_with_hooks(
    prompt,
    fwd_hooks=[(transformer_lens.utils.get_act_name("z", LAYER_TO_ABLATE), zero_ablate_head_hook)]
)

# Find the ablated model's prediction
ablated_last_token_logits = ablated_logits[0, -1, :]
ablated_predicted_token_index = torch.argmax(ablated_last_token_logits).item()
ablated_predicted_token = model.to_string([ablated_predicted_token_index])

print(f"\nOriginal prediction: '{predicted_token}'")
print(f"Prediction after ablating L{LAYER_TO_ABLATE}H{HEAD_TO_ABLATE}: '{ablated_predicted_token}'")

### 5. Quantifying the Effect

Ablating a single head may or may not change the top prediction, and comparing argmax tokens hides smaller shifts in probability. To measure the change more formally, we can calculate the **KL Divergence** between the original next-token probability distribution (from the original logits) and the ablated one. A higher KL divergence means the ablation shifted the distribution more.
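
For reference, the quantity in question is D_KL(P || Q) = sum_i P(i) * (log P(i) - log Q(i)), where P is the original distribution and Q is the ablated one. The sketch below works through it in plain Python on made-up 3-token logits, so the numbers are illustrative and say nothing about GPT-2. The helper names `log_softmax` and `kl_divergence` are chosen for this example only.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def kl_divergence(log_p: List[float], log_q: List[float]) -> float:
    """D_KL(P || Q) = sum_i P(i) * (log P(i) - log Q(i)), computed from log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


# Made-up logits over a 3-token vocabulary, for illustration only.
original = log_softmax([2.0, 1.0, 0.0])
ablated = log_softmax([0.0, 0.0, 0.0])  # a uniform distribution

kl_same = kl_divergence(original, original)     # identical distributions give 0
kl_forward = kl_divergence(original, ablated)   # D_KL(original || ablated)
kl_reverse = kl_divergence(ablated, original)   # D_KL(ablated || original)
print(f"{kl_same:.4f} {kl_forward:.4f} {kl_reverse:.4f}")
```

KL divergence is not symmetric, so the two directions generally give different values. Decide which distribution counts as the reference before comparing numbers across experiments.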

import torch.nn.functional as F

def get_log_probs(logits):
    # Use the last token's logits for prediction
    last_token_logits = logits[0, -1, :]
    return F.log_softmax(last_token_logits, dim=-1)

original_log_probs = get_log_probs(original_logits)
ablated_log_probs = get_log_probs(ablated_logits)

# KL Divergence D_KL(original || ablated)
# Measures how much information is lost when using the ablated distribution to approximate the original.
kl_div = F.kl_div(
    input=ablated_log_probs, # The approximation Q
    target=original_log_probs, # The true distribution P
    log_target=True,
    reduction="sum"
)

print(f"KL Divergence between original and ablated distributions: {kl_div.item():.4f}")
print("\nA higher KL divergence means the ablation changed the next-token distribution more, which suggests (but does not prove) that the component mattered for this prediction.")

### Conclusion

You've successfully used TransformerLens to:
- Load a model.
- Inspect its internal activations.
- Run an ablation experiment by zeroing out a specific attention head.
- Quantify the impact of this change.

This is a core workflow in mechanistic interpretability. From here, you can explore ablating different components, patching activations from one run to another, and much more!