# Logit Lens Debug Version (Local GPT-2)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week1/logit_lens_debug.ipynb)

**This is a debug version using GPT-2 locally.** Use it to test nnsight patterns on a small model before running them on large models via NDIF (the National Deep Inference Fabric).

**Colab Setup:** Go to Runtime > Change runtime type > Select **T4 GPU**

This notebook demonstrates the **logit lens** technique using [nnsight](https://nnsight.net/). The logit lens decodes a transformer's intermediate hidden states as next-token predictions, showing what the model would predict if it stopped at each layer.

**Key Idea:** At each layer, we take the hidden states, apply the final layer norm, and project them into vocabulary space with the model's unembedding matrix (`lm_head` in GPT-2). This reveals how the model's predictions evolve as information flows through the network.
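
As a rough illustration of that projection, here is a minimal pure-Python sketch with no model involved. The three-word vocabulary, the unembedding matrix, and the "early" and "late" hidden states are all made up for the example, and the layer norm omits the learned scale and shift that GPT-2's `ln_f` has.

```python
import math

def layer_norm(x, eps=1e-5):
    """Normalize a vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(hidden, unembed):
    """Project one hidden state (length d) through an unembedding matrix (vocab x d)."""
    normed = layer_norm(hidden)
    logits = [sum(w * h for w, h in zip(row, normed)) for row in unembed]
    return softmax(logits)

# Hypothetical toy setup: 3-word vocabulary, 4-dimensional hidden states.
vocab = ["Paris", "London", "the"]
unembed = [
    [1.0, -1.0, 0.0, 0.0],   # row for "Paris"
    [-1.0, 1.0, 0.0, 0.0],   # row for "London"
    [0.0, 0.0, 1.0, -1.0],   # row for "the"
]
# Made-up hidden states for the last position after an early and a late layer.
hidden_by_layer = {
    "early": [0.1, 0.0, 0.9, -0.8],
    "late": [2.0, -1.5, 0.1, 0.0],
}

top_by_layer = {}
for name, hidden in hidden_by_layer.items():
    probs = logit_lens(hidden, unembed)
    best = max(range(len(vocab)), key=lambda i: probs[i])
    top_by_layer[name] = vocab[best]
    print(f"{name:5s} -> {vocab[best]!r} (p={probs[best]:.2f})")
```

In this toy setup the early state decodes to a generic word and the late state to the answer, which is the kind of progression the logit lens is meant to expose on a real model.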

## References
- [nostalgebraist's Logit Lens post](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)
- [nnsight Logit Lens tutorial](https://nnsight.net/notebooks/tutorials/probing/logit_lens/)
- [nnsight documentation](https://nnsight.net/)

## Setup

Install nnsight and check GPU availability:

In [None]:
!pip install -q nnsight

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from nnsight import LanguageModel

# Check GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

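# Local execution: no NDIF needed. The traces below run on this machine;
# on NDIF you would pass remote=True to model.trace() instead.
REMOTE = False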

## Load GPT-2

GPT-2 (the 124M-parameter `openai-community/gpt2` checkpoint) is small enough to run locally on Colab, so we can debug nnsight patterns quickly before moving to larger models.

In [None]:
# Load GPT-2 locally
model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=True)

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.n_layer}")
print(f"Hidden size: {model.config.n_embd}")
print(f"Vocabulary size: {model.config.vocab_size}")

## Test Prompt

Let's start with a simple prompt to test the logit lens:

In [None]:
# Classic logit lens example
prompt = "The Eiffel Tower is in the city of"
print(f"Prompt: {prompt}")

## Basic Logit Lens (Following nnsight Tutorial)

This follows the pattern from the [nnsight logit lens tutorial](https://nnsight.net/notebooks/tutorials/probing/logit_lens/), and additionally saves the tokenized inputs so we can label the output:

In [None]:
# Get all layers (GPT-2 uses transformer.h)
layers = model.transformer.h
probs_layers = []

with model.trace() as tracer:
    with tracer.invoke(prompt) as invoker:
        input_tokens = invoker.inputs.save()
        for layer_idx, layer in enumerate(layers):
            # GPT-2: apply ln_f (final layer norm) then lm_head
            layer_output = model.lm_head(model.transformer.ln_f(layer.output[0]))
            probs = torch.nn.functional.softmax(layer_output, dim=-1).save()
            probs_layers.append(probs)

print(f"Collected probabilities from {len(probs_layers)} layers")
print(f"Input tokens shape: {input_tokens[1]['input_ids'].shape}")

In [None]:
# Stack probabilities and get top predictions
probs = torch.cat([p.value for p in probs_layers])
max_probs, tokens = probs.max(dim=-1)

# Decode tokens
words = [[model.tokenizer.decode(t.cpu()) for t in layer_tokens] for layer_tokens in tokens]
input_words = [model.tokenizer.decode(t) for t in input_tokens[1]['input_ids'][0]]

print(f"Input words: {input_words}")
print(f"\nTop prediction at each layer for last token:")
for i, w in enumerate(words):
    print(f"  Layer {i:2d}: {repr(w[-1]):15} (prob: {max_probs[i, -1]:.3f})")

## Visualize with Plotly (Like nnsight Tutorial)

In [None]:
!pip install -q plotly

In [None]:
import plotly.express as px
import plotly.io as pio

# Set renderer for Colab
try:
    import google.colab
    pio.renderers.default = "colab"
except ImportError:
    pio.renderers.default = "notebook_connected"

fig = px.imshow(
    max_probs.detach().cpu().numpy(),
    x=input_words,
    y=list(range(len(words))),
    color_continuous_scale=px.colors.diverging.RdYlBu_r,
    color_continuous_midpoint=0.50,
    text_auto=True,
    labels=dict(x="Input Tokens", y="Layers", color="Probability")
)

fig.update_layout(title='Logit Lens: Top Token Probability by Layer', xaxis_tickangle=0)
fig.update_traces(text=words, texttemplate="%{text}")
fig.show()

## Track Specific Token Probability

Now let's track a specific target token across layers:

In [None]:
def track_token_probability_gpt2(prompt, target_token, model):
    """
    Track the probability of a specific token across all layers.
    GPT-2 version (uses transformer.h, transformer.ln_f).
    """
    # Get the target token ID. GPT-2's BPE vocabulary distinguishes " Paris"
    # (with a leading space) from "Paris", so pass the form that follows the prompt.
    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
    if len(target_ids) != 1:
        print(f"Warning: '{target_token}' tokenizes to {len(target_ids)} tokens; tracking only the first")
    target_id = target_ids[0]
    
    layers = model.transformer.h
    layer_probs = []
    
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            for layer in layers:
                logits = model.lm_head(model.transformer.ln_f(layer.output[0]))
                probs = torch.softmax(logits[0, -1], dim=-1)
                target_prob = probs[target_id].save()
                layer_probs.append(target_prob)
    
    return [p.value.item() for p in layer_probs]

# Test with "Paris"
target = " Paris"
probs = track_token_probability_gpt2(prompt, target, model)

plt.figure(figsize=(10, 5))
plt.plot(range(len(probs)), probs, 'b-o', markersize=5)
plt.xlabel('Layer')
plt.ylabel(f'P("{target}")')
plt.title(f'Logit Lens: Tracking "{target}" probability\nPrompt: "{prompt}"')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Test with Pun (GPT-2 may not get it)

GPT-2 is much smaller than the models we will run on NDIF, so it may assign little probability to the punchline, but the same tracking pattern applies:

In [None]:
pun_prompt = "Why do electricians make good swimmers? Because they know the"
target = " current"

probs = track_token_probability_gpt2(pun_prompt, target, model)

plt.figure(figsize=(10, 5))
plt.plot(range(len(probs)), probs, 'b-o', markersize=5)
plt.xlabel('Layer')
plt.ylabel(f'P("{target}")')
plt.title(f'Logit Lens: Tracking "{target}" probability\nPrompt: "{pun_prompt}"')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nFinal layer probability: {probs[-1]:.4f}")

## Logit Lens with Top-K (Optimized Pattern)

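This is the pattern we want to use on NDIF: compute the top-k inside the trace, on the server, so that only `top_k` probabilities and token IDs per layer are returned instead of the full vocabulary distribution: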

In [None]:
def logit_lens_topk_gpt2(prompt, model, layers_to_check=None, top_k=10):
    """
    Run logit lens returning top-k predictions per layer.
    GPT-2 version.
    """
    n_layers = model.config.n_layer
    if layers_to_check is None:
        layers_to_check = list(range(n_layers))
    
    all_layers = model.transformer.h
    layer_results = {}
    
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            for layer_idx in layers_to_check:
                layer = all_layers[layer_idx]
                logits = model.lm_head(model.transformer.ln_f(layer.output[0]))
                probs = torch.softmax(logits[0, -1], dim=-1)
                top_probs, top_indices = probs.topk(top_k)
                layer_results[layer_idx] = (top_probs.save(), top_indices.save())
    
    return {k: (v[0].value, v[1].value) for k, v in layer_results.items()}

# Test
results = logit_lens_topk_gpt2(prompt, model)

print(f"Prompt: '{prompt}'\n")
print("Top-5 predictions per layer:")
print("=" * 50)

for layer_idx, (probs, indices) in sorted(results.items()):
    print(f"\nLayer {layer_idx:2d}:")
    for p, idx in zip(probs[:5], indices[:5]):
        token = model.tokenizer.decode([idx.item()])  # convert the 0-dim tensor to a plain int
        print(f"  {repr(token):15} {p.item():.3f}")
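
To get a feel for how much the top-k pattern shrinks what has to come back from the trace, here is a back-of-the-envelope sketch in plain Python. It uses random logits rather than a real model; the sizes match GPT-2 small (12 layers, 50,257-token vocabulary), and the helper names are hypothetical.

```python
import heapq
import math
import random

VOCAB_SIZE = 50257  # GPT-2 vocabulary size
N_LAYERS = 12       # number of transformer blocks in GPT-2 small
TOP_K = 10          # same default as logit_lens_topk_gpt2

def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def top_k(probs, k):
    """Return (prob, index) pairs for the k largest probabilities, largest first."""
    return heapq.nlargest(k, ((p, i) for i, p in enumerate(probs)))

random.seed(0)
full_values = 0   # numbers returned if every layer's full distribution is saved
topk_values = 0   # numbers returned if only top-k probs and indices are saved
for _ in range(N_LAYERS):
    # Random stand-in for one layer's last-position logits.
    probs = softmax([random.gauss(0.0, 1.0) for _ in range(VOCAB_SIZE)])
    kept = top_k(probs, TOP_K)
    full_values += len(probs)
    topk_values += 2 * len(kept)  # one probability plus one token ID per entry

print(f"Full distributions: {full_values:,} values")
print(f"Top-{TOP_K} only:        {topk_values:,} values")
```

For a single prompt position this is 603,084 values versus 240, and the gap grows with vocabulary size and layer count on larger models.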

## Heatmap for Target Token

In [None]:
def logit_lens_heatmap_gpt2(prompt, target_token, model):
    """
    Create heatmap of target token probability at each layer and position.
    GPT-2 version.
    """
    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
    target_id = target_ids[0]
    
    tokens = model.tokenizer.encode(prompt)
    token_strs = [model.tokenizer.decode([t]) for t in tokens]
    
    layers = model.transformer.h
    layer_probs = []
    
    with model.trace() as tracer:
        with tracer.invoke(prompt):
            for layer in layers:
                logits = model.lm_head(model.transformer.ln_f(layer.output[0]))
                probs = torch.softmax(logits[0], dim=-1)
                target_probs = probs[:, target_id].save()
                layer_probs.append(target_probs)
    
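    all_probs = torch.stack([p.value for p in layer_probs])
    # detach() first: tensors computed inside the trace may still require grad,
    # and .numpy() raises an error on tensors that do
    return all_probs.detach().cpu().numpy(), token_strs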

# Visualize
probs, tokens = logit_lens_heatmap_gpt2(prompt, " Paris", model)

plt.figure(figsize=(12, 6))
plt.imshow(probs, aspect='auto', cmap='Blues', vmin=0)
plt.colorbar(label='P(" Paris")')
plt.xlabel('Token Position')
plt.ylabel('Layer')
plt.title(f'Logit Lens Heatmap\nPrompt: "{prompt}"')
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right')
plt.tight_layout()
plt.show()

## Summary

This debug notebook demonstrates the nnsight patterns that work locally:

1. **Basic pattern**: `model.trace()` + `tracer.invoke(prompt)`
2. **Iterate over layers**: `for layer in model.transformer.h`
3. **GPT-2 logit lens**: `model.lm_head(model.transformer.ln_f(layer.output[0]))`
4. **Save results**: `.save()` on tensors you want to keep
5. **Access after trace**: `saved_tensor.value`

### Differences for Llama on NDIF:
- Use `model.model.layers` instead of `model.transformer.h`
- Use `model.model.norm` instead of `model.transformer.ln_f`
- Add `remote=True` to `model.trace()`
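
One way to keep a single logit-lens function working across both architectures is to look up the module paths by name. The sketch below is only an illustration: `MODULE_PATHS`, `resolve`, and `get_lens_modules` are hypothetical names, the `lm_head` entry for Llama is an assumption (the list above only mentions the layer and norm paths), and the `SimpleNamespace` objects stand in for real nnsight models.

```python
from functools import reduce
from types import SimpleNamespace

# Dotted attribute paths per architecture, following the list above.
# Assumption: both architectures expose the unembedding as `lm_head`.
MODULE_PATHS = {
    "gpt2": {"layers": "transformer.h", "norm": "transformer.ln_f", "head": "lm_head"},
    "llama": {"layers": "model.layers", "norm": "model.norm", "head": "lm_head"},
}

def resolve(obj, dotted_path):
    """Follow a dotted attribute path such as 'transformer.h' starting from obj."""
    return reduce(getattr, dotted_path.split("."), obj)

def get_lens_modules(model, arch):
    """Return the layer list, final norm, and unembedding head for an architecture."""
    return {name: resolve(model, path) for name, path in MODULE_PATHS[arch].items()}

# Stand-in objects with the same attribute layout as the real models.
fake_gpt2 = SimpleNamespace(
    transformer=SimpleNamespace(h=["block0", "block1"], ln_f="layer_norm"),
    lm_head="unembed",
)
fake_llama = SimpleNamespace(
    model=SimpleNamespace(layers=["layer0", "layer1"], norm="rms_norm"),
    lm_head="unembed",
)

for arch, fake in [("gpt2", fake_gpt2), ("llama", fake_llama)]:
    modules = get_lens_modules(fake, arch)
    print(arch, "->", modules)
```

With a real model, the returned `layers`, `norm`, and `head` entries would take the place of `model.transformer.h`, `model.transformer.ln_f`, and `model.lm_head` inside the trace.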