# Logit Lens Debug Version (Local GPT-2)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week1/logit_lens_debug.ipynb)

**This is a debug version using GPT-2 locally.** Use this to test nnsight patterns before running on large models via NDIF.

**Colab Setup:** Go to Runtime > Change runtime type > Select **T4 GPU**

This notebook demonstrates the **logit lens** technique using [nnsight](https://nnsight.net/). The logit lens lets us peek inside a transformer to see what the model is "thinking" at each layer.

**Key Idea:** At each layer, we take the hidden state, apply the model's final layer norm, and project the result into vocabulary space with the unembedding matrix (`lm_head`). This reveals how the model's predictions evolve as information flows through the network.
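Written as an equation, the lens at layer $\ell$ is the following. This is a sketch in our own notation; the symbols are not taken from the references below.

$$
\text{logits}^{(\ell)} = W_U \,\mathrm{LN}_f\!\left(h^{(\ell)}\right), \qquad
p^{(\ell)} = \mathrm{softmax}\!\left(\text{logits}^{(\ell)}\right)
$$

Here $h^{(\ell)}$ is the output of block $\ell$ at one token position, $\mathrm{LN}_f$ is the final layer norm (`transformer.ln_f` in GPT-2), and $W_U \in \mathbb{R}^{|V| \times d}$ is the unembedding matrix (`lm_head.weight`, which GPT-2 ties to the token embedding). At the last layer this reproduces the model's actual output distribution, because GPT-2 itself applies `ln_f` and then `lm_head` to the final block's output.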

## References
- [nostalgebraist's Logit Lens post](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)
- [nnsight Logit Lens tutorial](https://nnsight.net/notebooks/tutorials/probing/logit_lens/)
- [nnsight documentation](https://nnsight.net/)

## Setup

Install nnsight and check GPU availability:

In [None]:
!pip install -q nnsight

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from nnsight import LanguageModel

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Local execution - no NDIF needed
REMOTE = False

## Load GPT-2

GPT-2 (124M parameters) is small enough to run locally on a Colab GPU, which lets us debug nnsight patterns without NDIF.
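As a rough sanity check on "small enough", the parameter count can be derived from the config values printed below. This sketch assumes GPT-2 small's standard config (`n_layer=12`, `n_embd=768`, `vocab_size=50257`, `n_positions=1024`) and an MLP width of `4 * n_embd`; `gpt2_param_count` is a hypothetical helper, not part of nnsight or transformers.

```python
def gpt2_param_count(n_layer=12, n_embd=768, vocab_size=50257, n_positions=1024):
    """Count GPT-2 parameters from its config (lm_head is tied to wte, so it is not counted twice)."""
    d, d_ff = n_embd, 4 * n_embd                     # GPT-2 MLP inner size is 4 * n_embd
    embeddings = vocab_size * d + n_positions * d    # wte (token) + wpe (position)
    per_layer = (
        2 * d                      # ln_1 (weight + bias)
        + d * 3 * d + 3 * d        # attn.c_attn: fused Q, K, V projection
        + d * d + d                # attn.c_proj
        + 2 * d                    # ln_2
        + d * d_ff + d_ff          # mlp.c_fc
        + d_ff * d + d             # mlp.c_proj
    )
    final_ln = 2 * d                                 # ln_f
    return embeddings + n_layer * per_layer + final_ln

n = gpt2_param_count()
print(f"{n:,} parameters, ~{n * 4 / 1e9:.2f} GB in float32")
# prints: 124,439,808 parameters, ~0.50 GB in float32
```

Half a gigabyte of float32 weights fits easily on a T4, with plenty of room left for activations.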

In [None]:
# Load GPT-2 locally
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.n_layer}")
print(f"Hidden size: {model.config.n_embd}")
print(f"Vocabulary size: {model.config.vocab_size}")

## Test Prompt

Let's start with a simple prompt to test the logit lens:

In [None]:
# Classic logit lens example
prompt = "The Eiffel Tower is in the city of"
print(f"Prompt: {prompt}")

## Basic Logit Lens: Watching Predictions Develop

Let's see how the model's predictions evolve across layers. This uses the same pattern as the NDIF notebook.
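To see what the nnsight code below computes without loading a model, here is a NumPy sketch of the same per-layer steps on random toy weights: layer norm, unembedding, softmax over the vocabulary at the last position, then top-k. All names and shapes (`toy_logit_lens`, `W_U` stored as `[d_model, vocab]`) are illustrative assumptions; GPT-2's real `lm_head.weight` is stored as `[vocab, d_model]` and applied as `h @ weight.T`.

```python
import numpy as np

def layer_norm(x, gamma, beta, eps=1e-5):
    # Normalize over the hidden dimension, then scale and shift (LayerNorm as used by GPT-2)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def toy_logit_lens(hidden_per_layer, gamma, beta, W_U, k=5):
    """hidden_per_layer: list of [seq, d_model] arrays, one per layer.
    Returns {layer: (top_probs, top_ids)} for the last token position."""
    results = {}
    for layer_idx, h in enumerate(hidden_per_layer):
        logits = layer_norm(h, gamma, beta) @ W_U   # [seq, vocab]
        probs = softmax(logits[-1])                 # last position only
        top_ids = np.argsort(probs)[::-1][:k]       # ids of the k largest probs, descending
        results[layer_idx] = (probs[top_ids], top_ids)
    return results

rng = np.random.default_rng(0)
n_layers, seq_len, d_model, vocab = 4, 6, 16, 50
hidden = [rng.normal(size=(seq_len, d_model)) for _ in range(n_layers)]
gamma, beta = np.ones(d_model), np.zeros(d_model)
W_U = rng.normal(size=(d_model, vocab))

results = toy_logit_lens(hidden, gamma, beta, W_U, k=3)
for layer_idx, (p, ids) in results.items():
    print(layer_idx, ids.tolist(), np.round(p, 3).tolist())
```

The nnsight version replaces `argsort` with `probs.topk(top_k)` and reads each `h` from `model.transformer.h[layer_idx].output[0]` instead of a random array.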

In [None]:
def get_value(saved):
    """Helper to get value from saved tensor (handles local vs remote)."""
    try:
        return saved.value
    except AttributeError:
        return saved

def get_top_predictions(probs, indices, tokenizer):
    """Convert saved top-k probs and indices to (token string, probability) pairs."""
    return [(tokenizer.decode([int(idx)]), prob.item()) for idx, prob in zip(indices, probs)]

def logit_lens_layers(prompt, model, layers_to_check=None, remote=False, top_k=10):
    """
    Run logit lens on specified layers.
    Returns top-k predictions per layer (much less data than full vocab).
    GPT-2 version (uses transformer.h, transformer.ln_f).
    """
    n_layers = model.config.n_layer
    if layers_to_check is None:
        # Check every 2nd layer plus first and last for GPT-2 (12 layers)
        layers_to_check = [0, 2, 4, 6, 8, 10, n_layers-1]
        layers_to_check = [l for l in layers_to_check if l < n_layers]
    
    layer_results = {}
    
    # Single trace call: softmax and top-k run inside the trace (on the NDIF server when remote=True) to reduce data transfer
    with model.trace(prompt, remote=remote):
        for layer_idx in layers_to_check:
            hidden = model.transformer.h[layer_idx].output[0]
            logits = model.lm_head(model.transformer.ln_f(hidden))
            # Compute softmax and top-k, only save small tensors
            probs = torch.softmax(logits[0, -1], dim=-1)
            top_probs, top_indices = probs.topk(top_k)
            layer_results[layer_idx] = (top_probs.save(), top_indices.save())
    
    # Extract values after trace completes
    return {k: (get_value(v[0]), get_value(v[1])) for k, v in layer_results.items()}

In [None]:
# Run logit lens on the classic example
layer_results = logit_lens_layers(prompt, model, remote=REMOTE)

print(f"Prompt: '{prompt}'\n")
print("Layer-by-layer predictions:")
print("=" * 60)

for layer_idx, (probs, indices) in sorted(layer_results.items()):
    preds = get_top_predictions(probs, indices, model.tokenizer)
    print(f"\nLayer {layer_idx:2d}:")
    for token, prob in preds[:5]:  # Show top 5
        print(f"  {repr(token):15} {prob:.3f}")

## Track Specific Token Probability

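Now let's track the probability of one target token across every layer. Note the leading space in `" Paris"`: GPT-2's BPE tokenizer folds a preceding space into the token, so `" Paris"` and `"Paris"` are different token IDs.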
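Once you have per-layer lens logits, tracking a token boils down to one indexing step. Here is a toy NumPy sketch; the random logits get an artificial boost to one token id from layer 6 onward, purely so the curve has a visible shape, and `track_target` is a hypothetical name.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def track_target(layer_logits, target_id):
    """layer_logits: [n_layers, seq, vocab] lens logits (after ln_f and unembedding).
    Returns P(target_id) at the last position, one float per layer."""
    probs = softmax(layer_logits[:, -1, :])   # [n_layers, vocab]
    return probs[:, target_id].tolist()

rng = np.random.default_rng(1)
layer_logits = rng.normal(size=(12, 8, 100))
layer_logits[6:, -1, 42] += 5.0   # synthetic: make token 42 "emerge" from layer 6 onward
curve = track_target(layer_logits, target_id=42)
print([round(p, 3) for p in curve])
```

The nnsight function below does the same slice, `probs[target_id]`, separately for each layer inside the trace, so only one scalar per layer is saved.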

In [None]:
def track_token_probability(prompt, target_token, model, remote=False):
    """
    Track the probability of a specific token across all layers.
    Only saves single probability value per layer (minimal data transfer).
    GPT-2 version (uses transformer.h, transformer.ln_f).
    """
    n_layers = model.config.n_layer
    
    # Get target token ID
    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
    if len(target_ids) != 1:
        print(f"Warning: '{target_token}' tokenizes to {len(target_ids)} tokens; tracking only the first")
    target_id = target_ids[0]
    
    layer_probs = []
    
    # Single trace call - only save the single probability we need
    with model.trace(prompt, remote=remote):
        for layer_idx in range(n_layers):
            hidden = model.transformer.h[layer_idx].output[0]
            logits = model.lm_head(model.transformer.ln_f(hidden))
            probs = torch.softmax(logits[0, -1], dim=-1)
            # Only save the single probability for target token
            target_prob = probs[target_id].save()
            layer_probs.append(target_prob)
    
    # Extract probabilities after trace completes
    return [get_value(p).item() for p in layer_probs]

# Track " Paris" probability
target = " Paris"
probs = track_token_probability(prompt, target, model, remote=REMOTE)

plt.figure(figsize=(10, 5))
plt.plot(range(len(probs)), probs, 'b-o', markersize=5)
plt.xlabel('Layer')
plt.ylabel(f'P("{target}")')
plt.title(f'Logit Lens: Tracking "{target}" probability\nPrompt: "{prompt}"')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Test with Pun (GPT-2 may not get it)

GPT-2 is much smaller than the models we run on NDIF, so it may not get the pun, but we can still check that the tracking code works:

In [None]:
pun_prompt = "Why do electricians make good swimmers? Because they know the"
target = " current"

probs = track_token_probability(pun_prompt, target, model, remote=REMOTE)

plt.figure(figsize=(10, 5))
plt.plot(range(len(probs)), probs, 'b-o', markersize=5)
plt.xlabel('Layer')
plt.ylabel(f'P("{target}")')
plt.title(f'Logit Lens: Tracking "{target}" probability\nPrompt: "{pun_prompt}"')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nFinal layer probability: {probs[-1]:.4f}")
print("(Note: GPT-2 may not understand puns as well as larger models)")

## Logit Lens with Top-K (Optimized Pattern)

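This is the pattern to use on NDIF: compute softmax and top-k inside the trace, so that only `top_k` probabilities and indices per layer come back from the server instead of full-vocabulary tensors.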
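How much does computing top-k inside the trace save? A back-of-the-envelope sketch, assuming float32 values, int64 indices (what `torch.topk` returns), GPT-2 small's 12 layers and 50,257-token vocabulary, and an illustrative 10-token prompt. It ignores serialization overhead, so treat the numbers as rough; `bytes_returned` is a hypothetical helper.

```python
def bytes_returned(n_layers, seq_len, vocab_size, k, float_bytes=4, index_bytes=8):
    """Rough payload size for different things you could .save() from a logit-lens trace."""
    return {
        "full logits, all positions": n_layers * seq_len * vocab_size * float_bytes,
        "full probs, last position": n_layers * vocab_size * float_bytes,
        "top-k probs + indices": n_layers * k * (float_bytes + index_bytes),
        "single target prob": n_layers * float_bytes,
    }

# GPT-2 small: 12 layers, 50,257-token vocabulary; seq_len=10 is an illustrative prompt length
for name, size in bytes_returned(12, 10, 50257, k=10).items():
    print(f"{name:28} {size:>12,} bytes")
```

Saving full logits would return about 24 MB per trace, while the top-10 pattern returns 1,440 bytes. The gap grows with larger vocabularies, more layers, and longer prompts.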

In [None]:
# Test with all layers
results = logit_lens_layers(prompt, model, layers_to_check=list(range(model.config.n_layer)), remote=REMOTE)

print(f"Prompt: '{prompt}'\n")
print("Top-5 predictions per layer:")
print("=" * 50)

for layer_idx, (probs, indices) in sorted(results.items()):
    preds = get_top_predictions(probs, indices, model.tokenizer)
    print(f"\nLayer {layer_idx:2d}:")
    for token, prob in preds[:5]:
        print(f"  {repr(token):15} {prob:.3f}")

## Heatmap for Target Token
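The heatmap keeps every token position instead of only the last one, giving a `[n_layers, seq_len]` matrix of target-token probabilities. Here is a NumPy sketch of that indexing on random toy logits; `target_heatmap` is a hypothetical name.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def target_heatmap(layer_logits, target_id):
    """layer_logits: [n_layers, seq, vocab]. Returns [n_layers, seq] matrix of P(target_id)."""
    probs = softmax(layer_logits)     # softmax over the vocabulary at every (layer, position)
    return probs[..., target_id]      # rows = layers, columns = token positions

rng = np.random.default_rng(2)
layer_logits = rng.normal(size=(12, 9, 100))
heat = target_heatmap(layer_logits, target_id=7)
print(heat.shape)  # (12, 9): one row per layer, one column per token position
```

The last column of this matrix is the same quantity the single-token tracking plot shows. Earlier columns show what the lens predicts as the next token at earlier positions.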

In [None]:
def logit_lens_heatmap(prompt, target_token, model, remote=False):
    """
    Create a logit lens heatmap showing the probability of a target token
    at each layer and position.
    Only saves target token probabilities (minimal data transfer).
    GPT-2 version (uses transformer.h, transformer.ln_f).
    """
    n_layers = model.config.n_layer
    
    # Get target token ID
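    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
    if len(target_ids) != 1:
        print(f"Warning: '{target_token}' tokenizes to {len(target_ids)} tokens; using only the first")
    target_id = target_ids[0]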
    
    # Get token strings
    tokens = model.tokenizer.encode(prompt)
    token_strs = [model.tokenizer.decode([t]) for t in tokens]
    
    layer_probs = []
    
    # Single trace call - only save target token probability at each position
    with model.trace(prompt, remote=remote):
        for layer_idx in range(n_layers):
            hidden = model.transformer.h[layer_idx].output[0]
            logits = model.lm_head(model.transformer.ln_f(hidden))
            probs = torch.softmax(logits[0], dim=-1)
            # Only save probabilities for target token at all positions
            target_probs = probs[:, target_id].save()
            layer_probs.append(target_probs)
    
    # Stack after trace completes
    all_probs = torch.stack([get_value(p) for p in layer_probs])
    target_probs = all_probs.cpu().numpy()
    
    return target_probs, token_strs

# Visualize
probs, tokens = logit_lens_heatmap(prompt, " Paris", model, remote=REMOTE)

plt.figure(figsize=(12, 6))
plt.imshow(probs, aspect='auto', cmap='Blues', vmin=0)
plt.colorbar(label='P(" Paris")')
plt.xlabel('Token Position')
plt.ylabel('Layer')
plt.title(f'Logit Lens Heatmap\nPrompt: "{prompt}"')
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
plt.tight_layout()
plt.show()

## Summary

This debug notebook uses the **same nnsight pattern** as the NDIF notebook:

1. **Single trace call**: `with model.trace(prompt, remote=remote):`
2. **Loop over layers inside trace**: `for layer_idx in range(n_layers):`
3. **Access layer output**: `model.transformer.h[layer_idx].output[0]`
4. **Apply logit lens**: `model.lm_head(model.transformer.ln_f(hidden))`
5. **Compute top-k inside the trace**: `probs.topk(top_k)` runs on the NDIF server when `remote=True`, which reduces data transfer
6. **Save results**: `.save()` on tensors you want to keep
7. **Access after trace**: Use `get_value()` helper (handles local vs remote)

### Differences for Llama on NDIF:
- Use `model.model.layers[layer_idx]` instead of `model.transformer.h[layer_idx]`
- Use `model.model.norm` instead of `model.transformer.ln_f`
- Add `remote=True` to `model.trace()`
- Use `model.config.num_hidden_layers` instead of `model.config.n_layer`
- `model.lm_head` stays the same for both architectures
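One way to keep a single code path for both architectures is a small dispatch helper. This is a hypothetical sketch: `lens_modules` is not part of nnsight, and it assumes `hasattr` on the nnsight-wrapped model reflects the underlying module tree. The stand-in objects exist only to exercise the dispatch logic without downloading weights.

```python
from types import SimpleNamespace

def lens_modules(model):
    """Return (layers, final_norm, unembed, n_layers) for GPT-2-style or Llama-style models.
    Hypothetical helper: dispatches on whether a `transformer` attribute exists."""
    if hasattr(model, "transformer"):  # GPT-2: transformer.h / transformer.ln_f
        return (model.transformer.h, model.transformer.ln_f,
                model.lm_head, model.config.n_layer)
    # Llama: model.layers / model.norm
    return (model.model.layers, model.model.norm,
            model.lm_head, model.config.num_hidden_layers)

# Stand-in objects (not real models) that mimic the two attribute layouts
gpt2_like = SimpleNamespace(
    transformer=SimpleNamespace(h=["block"] * 12, ln_f="ln_f"),
    lm_head="lm_head", config=SimpleNamespace(n_layer=12))
llama_like = SimpleNamespace(
    model=SimpleNamespace(layers=["layer"] * 32, norm="norm"),
    lm_head="lm_head", config=SimpleNamespace(num_hidden_layers=32))

print(lens_modules(gpt2_like)[1], lens_modules(llama_like)[1])  # ln_f norm
```

Inside a trace you would then use `layers[layer_idx].output[0]` and `unembed(final_norm(hidden))`, just as the GPT-2 functions above use `model.transformer.h` and `model.transformer.ln_f`.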