# Logit Lens with nnsight and NDIF

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/davidbau/neural-mechanics-web/blob/main/labs/week1/logit_lens.ipynb)

This notebook demonstrates the **logit lens** technique using [nnsight](https://nnsight.net/) and the [NDIF](https://ndif.us/) remote inference API. The logit lens lets us peek inside a transformer to see what the model is "thinking" at each layer.

**Key Idea:** At each layer, we project the hidden states into vocabulary space using the model's unembedding matrix. This reveals how the model's predictions evolve as information flows through the network.
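
Written as a formula, the projection looks roughly like this. The notation is introduced here for illustration and is not from the original post: $h^{(\ell)}$ is the hidden state after layer $\ell$, $W_U$ is the unembedding matrix, $g$ is the learned gain of the final RMSNorm, $d$ is the hidden size, and $\epsilon$ is a small constant for numerical stability.

$$
p^{(\ell)} = \operatorname{softmax}\!\big(W_U \, \mathrm{RMSNorm}(h^{(\ell)})\big),
\qquad
\mathrm{RMSNorm}(h) = \frac{h}{\sqrt{\tfrac{1}{d}\sum_{i=1}^{d} h_i^2 + \epsilon}} \odot g
$$

At the final layer this is the model's ordinary next-token distribution; at earlier layers it is the same read-out applied to a hidden state the read-out was never trained on.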

We'll use **Llama 3.1 70B** via NDIF to explore:
1. **Puns** - where the model must hold multiple meanings
2. **Multilingual concepts** - where we can see English emerge as an internal "concept language"

## References
- [nostalgebraist's Logit Lens post](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)
- [Do Llamas Work in English? (Wendler et al., ACL 2024)](https://aclanthology.org/2024.acl-long.820/) - Key paper on multilingual concept representations
- [nnsight documentation](https://nnsight.net/)
- [NDIF - National Deep Inference Fabric](https://ndif.us/)

## Setup

First, install the required packages:

In [None]:
# Install nnsight for model access
!pip install -q nnsight

# Install nnterp for standardized model access (also needed on NDIF server)
!pip install -q nnterp

# Install logitlenskit for visualization
!pip install -q git+https://github.com/davidbau/logitlenskit.git#subdirectory=python

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from nnsight import LanguageModel, CONFIG

# Configure NDIF API key from Colab secrets
try:
    from google.colab import userdata
    CONFIG.set_default_api_key(userdata.get('NDIF_API'))
except Exception:
    pass  # Not in Colab, or the NDIF_API secret is not set

# We use remote=True to run on NDIF's shared GPU resources
# This lets us use Llama 3.1 70B without needing massive local compute!
REMOTE = True

## Load Llama 3.1 70B

Thanks to NDIF, we can run a 70 billion parameter model from a Colab notebook!

In [None]:
# Load Llama 3.1 70B via NDIF
model = LanguageModel("meta-llama/Llama-3.1-70B", device_map="auto")

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Vocabulary size: {model.config.vocab_size}")

---

# Part 1: The Quick Way - LogitLensKit

Before diving into the details, let's see logit lens in action with just **two lines of code**!

The `logitlenskit` library provides a high-level API that handles all the complexity:

In [None]:
from nnterp import StandardizedTransformer
from logitlenskit import collect_logit_lens, show_logit_lens

# Use nnterp's StandardizedTransformer for consistent access across model architectures
st_model = StandardizedTransformer("meta-llama/Llama-3.1-70B")

# Two lines to visualize logit lens!
data = collect_logit_lens("The capital of France is", st_model, remote=REMOTE)
show_logit_lens(data, title="Logit Lens: Capital of France")

### Understanding the Widget

The interactive widget shows:
- **Rows**: Input token positions (top to bottom)
- **Columns**: Layers (left to right, from layer 0 to final layer)
- **Cell text**: Top-1 predicted next token at that layer
- **Cell color**: Probability of the top prediction (darker = higher)

**Interactions:**
- **Hover** over cells to see the probability trajectory in the chart below
- **Click** cells to see top-k predictions with probabilities
- **Shift+click** to pin trajectories for comparison
- **Drag** the column borders to resize and see more/fewer layers
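
To make the cell text and cell color concrete, here is a minimal sketch of how such a grid could be derived from raw probabilities. It is not logitlenskit's internal code: the `[positions, layers, vocab]` array layout, the helper name `top1_grid`, and the three-token vocabulary are all assumptions made for illustration.

```python
import numpy as np

def top1_grid(probs, id_to_token):
    """Build the two grids a logit-lens heatmap needs.

    probs: array of shape [positions, layers, vocab] (assumed layout).
    Returns (tokens, confidences): the top-1 token string and its
    probability for every (position, layer) cell.
    """
    probs = np.asarray(probs, dtype=float)
    top_ids = probs.argmax(axis=-1)   # cell text: most likely token id
    top_p = probs.max(axis=-1)        # cell color: its probability
    tokens = [[id_to_token[i] for i in row] for row in top_ids]
    return tokens, top_p

# Hand-made toy data: 2 positions x 2 layers x 3-token vocabulary.
vocab = ["the", " Paris", " London"]
probs = np.array([
    [[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]],
    [[0.5, 0.2, 0.3], [0.1, 0.1, 0.8]],
])

tokens, confidences = top1_grid(probs, vocab)
print(tokens)       # [['the', ' Paris'], ['the', ' London']]
print(confidences)  # top-1 probability for each cell
```

Each inner list is one row of the widget (one input position), and each entry is one layer.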

---

# Part 2: Understanding the Details - Building Logit Lens by Hand

Now let's understand what's happening under the hood. The logit lens works by:

1. **Intercepting hidden states** at each layer
2. **Applying the final layer normalization** (RMSNorm for Llama)
3. **Projecting to vocabulary space** using the language model head (unembedding matrix)
4. **Converting to probabilities** via softmax
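
Before touching the real model, steps 2 to 4 can be sketched on random arrays. This is a toy illustration under stated assumptions: the sizes, the random weights, and the helper names (`rms_norm`, `softmax`, `toy_logit_lens`) are made up here, and step 1 (intercepting hidden states) is replaced by random data.

```python
import numpy as np

def rms_norm(h, weight, eps=1e-5):
    """RMSNorm: rescale h to unit root-mean-square, then apply a learned gain."""
    rms = np.sqrt(np.mean(h ** 2, axis=-1, keepdims=True) + eps)
    return h / rms * weight

def softmax(x):
    """Numerically stable softmax over the last axis."""
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def toy_logit_lens(hidden_states, norm_weight, unembed):
    """Apply steps 2-4 to a stack of per-layer hidden states.

    hidden_states: [n_layers, d_model]  stand-in for step 1
    norm_weight:   [d_model]            gain of the final norm
    unembed:       [d_model, vocab]     unembedding matrix
    returns:       [n_layers, vocab]    one distribution per layer
    """
    normed = rms_norm(hidden_states, norm_weight)  # step 2: final layer norm
    logits = normed @ unembed                      # step 3: project to vocabulary
    return softmax(logits)                         # step 4: probabilities

# Toy sizes and random weights, purely illustrative.
rng = np.random.default_rng(0)
n_layers, d_model, vocab_size = 4, 8, 16
hidden_states = rng.normal(size=(n_layers, d_model))
norm_weight = np.ones(d_model)
unembed = rng.normal(size=(d_model, vocab_size))

probs = toy_logit_lens(hidden_states, norm_weight, unembed)
print(probs.shape)            # (4, 16)
print(probs.argmax(axis=-1))  # top-1 "token" id at each layer
```

The same norm weight and unembedding matrix are reused for every layer; only the hidden state changes.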

Let's implement this step by step using nnsight:

In [None]:
def get_value(saved):
    """Helper to get value from saved tensor (handles local vs remote)."""
    try:
        return saved.value
    except AttributeError:
        return saved


def logit_lens_manual(prompt, model, layers_to_check=None, remote=True, top_k=10):
    """
    Implement logit lens from scratch using nnsight.
    
    This shows exactly what happens at each step:
    1. Get hidden state from layer output
    2. Apply final layer norm (model.model.norm)
    3. Project to vocabulary (model.lm_head)
    4. Softmax to get probabilities
    """
    n_layers = model.config.num_hidden_layers
    if layers_to_check is None:
        # Sample every 10 layers plus first and last
        layers_to_check = list(range(0, n_layers, 10)) + [n_layers - 1]
        layers_to_check = sorted(set(layers_to_check))
    
    # Use nnsight's trace context to intercept model internals
    saved_logits = None
    with model.trace(prompt, remote=remote):
        logits_list = []
        for layer_idx in layers_to_check:
            # Step 1: Get hidden state from this layer's output
            # model.model.layers[i].output is a tuple; [0] is the hidden state
            hidden = model.model.layers[layer_idx].output[0]
            
            # Step 2: Apply final layer normalization
            # For Llama, this is RMSNorm stored at model.model.norm
            normed = model.model.norm(hidden)
            
            # Step 3: Project to vocabulary space
            # The lm_head maps hidden_size -> vocab_size
            logits = model.lm_head(normed)
            
            # Get last position only (the "next token" prediction)
            last_logits = logits[0, -1] if len(logits.shape) == 3 else logits[-1]
            logits_list.append(last_logits)
        
        # Save all logits to retrieve after trace
        saved_logits = logits_list.save()
    
    # Process results after trace completes
    layer_results = {}
    for i, layer_idx in enumerate(layers_to_check):
        logits = get_value(saved_logits[i])
        # Step 4: Convert to probabilities
        probs = torch.softmax(logits.float(), dim=-1)
        top_probs, top_indices = probs.topk(top_k)
        layer_results[layer_idx] = (top_probs, top_indices)
    
    return layer_results

In [None]:
# Test our manual implementation
prompt = "The capital of France is"
layer_results = logit_lens_manual(prompt, model, remote=REMOTE)

print(f"Prompt: '{prompt}'\n")
print("Layer-by-layer predictions:")
print("=" * 60)

for layer_idx, (probs, indices) in sorted(layer_results.items()):
    top_tokens = [(model.tokenizer.decode([idx]), prob.item()) 
                  for idx, prob in zip(indices, probs)]
    print(f"\nLayer {layer_idx:2d}:")
    for token, prob in top_tokens[:5]:  # Show top 5
        print(f"  {repr(token):15} {prob:.3f}")

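If the run behaves as expected, " Paris" should become the top prediction around the middle layers and gain confidence toward the final layer. Early layers often decode to unrelated tokens, because the final norm and unembedding were trained only on the last layer's output.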

---

# Part 3: Multilingual Concepts - "Espanol: amor, Francais: amour"

One of the most fascinating findings about multilingual LLMs comes from [Wendler et al. (2024)](https://aclanthology.org/2024.acl-long.820/): **"Do Llamas Work in English?"**

Their key insight: When processing non-English text, the model's internal representations pass through three phases:
1. **Input space**: Early layers encode the input language
2. **Concept space**: Middle layers represent meaning in a language-neutral (but English-biased) space
3. **Output space**: Final layers translate back to the target language

Let's test this! We'll prompt the model with a Spanish-to-French translation pattern and see whether English " love" appears in the middle layers:

In [None]:
# Multilingual concept prompt
multilingual_prompt = "Espanol: amor, Francais:"

# Show the full logit lens grid; look for English "love" in the middle layers
data = collect_logit_lens(multilingual_prompt, st_model, remote=REMOTE)
show_logit_lens(data, title='Multilingual Concepts: "amor" → "amour"')

### Tracking Specific Tokens

Let's explicitly track how the probabilities of " love" (English), " amour" (French), and " amor" (Spanish) evolve:

In [None]:
def track_token_probability(prompt, target_token, model, remote=True):
    """
    Track the probability of a specific token across all layers.
    """
    n_layers = model.config.num_hidden_layers
    
    # Get target token ID
    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
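    if len(target_ids) != 1:
        # Only the first sub-token is tracked, so the curve is a proxy for the whole word
        print(f"Warning: '{target_token}' tokenizes to {len(target_ids)} tokens: {target_ids}; tracking only the first")
    target_id = target_ids[0]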
    
    # Collect logits at all layers
    saved_logits = None
    with model.trace(prompt, remote=remote):
        logits_list = []
        for layer_idx in range(n_layers):
            hidden = model.model.layers[layer_idx].output[0]
            logits = model.lm_head(model.model.norm(hidden))
            last_logits = logits[0, -1] if len(logits.shape) == 3 else logits[-1]
            logits_list.append(last_logits)
        saved_logits = logits_list.save()
    
    # Extract probabilities
    layer_probs = []
    for i in range(n_layers):
        logits = get_value(saved_logits[i])
        probs = torch.softmax(logits.float(), dim=-1)
        layer_probs.append(probs[target_id].item())
    
    return layer_probs

In [None]:
# Track English, French, and Spanish words for "love"
multilingual_prompt = "Espanol: amor, Francais:"

love_en = track_token_probability(multilingual_prompt, " love", model, remote=REMOTE)
love_fr = track_token_probability(multilingual_prompt, " amour", model, remote=REMOTE)  
love_es = track_token_probability(multilingual_prompt, " amor", model, remote=REMOTE)

# Plot the trajectories
plt.figure(figsize=(14, 6))

layers = range(len(love_en))
plt.plot(layers, love_en, 'b-o', markersize=3, label='" love" (English)', linewidth=2)
plt.plot(layers, love_fr, 'r-o', markersize=3, label='" amour" (French)', linewidth=2)
plt.plot(layers, love_es, 'g-o', markersize=3, label='" amor" (Spanish)', linewidth=2)

plt.xlabel('Layer', fontsize=12)
plt.ylabel('Probability', fontsize=12)
plt.title(f'Multilingual Concept Representation\nPrompt: "{multilingual_prompt}"', fontsize=14)
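plt.grid(True, alpha=0.3)

# Shade three rough phases. The boundaries are illustrative (quarter points of
# the layer range, i.e. 20/60/80 for an 80-layer model), not values from the paper.
n_points = len(love_en)
plt.axvspan(0, n_points // 4, alpha=0.1, color='gray', label='Input Space')
plt.axvspan(n_points // 4, 3 * n_points // 4, alpha=0.1, color='blue', label='Concept Space')
plt.axvspan(3 * n_points // 4, n_points, alpha=0.1, color='green', label='Output Space')

# Call legend() after axvspan so the shaded regions also appear in the legend
plt.legend(fontsize=11)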

plt.tight_layout()
plt.show()

# Find peak probabilities
print("\nPeak probability layers:")
print(f"  'love' (English): layer {np.argmax(love_en)} ({max(love_en):.3f})")
print(f"  'amour' (French): layer {np.argmax(love_fr)} ({max(love_fr):.3f})")
print(f"  'amor' (Spanish): layer {np.argmax(love_es)} ({max(love_es):.3f})")

### Interpretation

If the Wendler et al. hypothesis is correct, you should see:
- **Early layers**: Low probability for all translations
- **Middle layers**: English "love" peaks higher than French "amour" (English as concept space)
- **Final layers**: French "amour" overtakes as the model prepares the output

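If this pattern appears, it is consistent with the model routing through a concept space that lies closer to English than to other languages before producing the output language. A single prompt tracked through single tokens is only suggestive, not proof that the model "thinks" in English.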
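
The three expectations above can be checked numerically instead of by eye. The sketch below is one possible way to do it: the helper name `summarize_pivot` and the two probability curves are hypothetical, made up to show the shape of the result. With real data you would pass trajectories such as `love_en` and `love_fr`.

```python
import numpy as np

def summarize_pivot(p_en, p_target):
    """Compare an English trajectory with a target-language trajectory.

    Returns the layers where English is strictly ahead, and the first layer
    from which the target language stays ahead through the final layer
    (None if the target is not ahead at the final layer).
    """
    p_en = np.asarray(p_en, dtype=float)
    p_target = np.asarray(p_target, dtype=float)

    english_layers = np.flatnonzero(p_en > p_target).tolist()
    target_ahead = p_target > p_en

    if not target_ahead[-1]:
        takeover = None
    else:
        behind = np.flatnonzero(~target_ahead)
        takeover = int(behind[-1]) + 1 if behind.size else 0

    return {"english_leads_at": english_layers, "target_takeover_layer": takeover}

# Made-up curves over six layers, for illustration only.
toy_en = [0.00, 0.01, 0.20, 0.30, 0.10, 0.02]
toy_fr = [0.00, 0.00, 0.05, 0.25, 0.40, 0.90]

summary = summarize_pivot(toy_en, toy_fr)
print(summary)  # {'english_leads_at': [1, 2, 3], 'target_takeover_layer': 4}
```

A non-empty `english_leads_at` in the middle of the network followed by a late `target_takeover_layer` is the signature described above; an empty list means the English detour did not show up for this prompt and token choice.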

---

# Part 4: Puns - A Window into Dual Meanings

Puns are interesting for interpretability because they require the model to process words with multiple meanings. When does the model "get" the joke? At which layer does the pun's alternative meaning emerge?

In [None]:
# A pun that plays on "current" (electrical vs water)
pun_prompt = "Why do electricians make good swimmers? Because they know the"

data = collect_logit_lens(pun_prompt, st_model, remote=REMOTE)
show_logit_lens(data, title="Pun: Electricians & Swimmers")

In [None]:
# Track "current" probability across layers
current_probs = track_token_probability(pun_prompt, " current", model, remote=REMOTE)

plt.figure(figsize=(12, 5))
plt.plot(range(len(current_probs)), current_probs, 'b-o', markersize=3)
plt.xlabel('Layer')
plt.ylabel('P(" current")')
plt.title(f'When does the model "get" the pun?\n"{pun_prompt}"')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

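# Find the first layer where the probability exceeds the threshold
threshold = 0.1
for i, p in enumerate(current_probs):
    if p > threshold:
        print(f"'current' first exceeds {threshold} probability at layer {i}")
        break
else:
    # The loop finished without a break: the threshold was never crossed
    print(f"'current' never exceeds {threshold} probability")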

## Exercise: Compare Multiple Puns

Do different types of puns show similar patterns? Let's compare!

In [None]:
puns = [
    ("Why do electricians make good swimmers? Because they know the", " current"),
    ("Why did the banker break up with his girlfriend? He lost", " interest"),
    ("Why do cows wear bells? Because their horns don't", " work"),
]

plt.figure(figsize=(14, 6))

for prompt, target in puns:
    probs = track_token_probability(prompt, target, model, remote=REMOTE)
    label = f'"{target.strip()}" ({prompt[:25]}...)'
    plt.plot(range(len(probs)), probs, '-o', markersize=2, label=label)

plt.xlabel('Layer')
plt.ylabel('Probability')
plt.title('Logit Lens: Comparing Pun Recognition Across Layers')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

---

# Part 5: Context Changes Interpretation

The same sentence can be interpreted literally or as a pun depending on context. Does preceding context prime the model toward pun interpretations?

In [None]:
# The same sentence in different contexts
neutral = "I used to be a banker, but I lost my"
after_pun = "I used to be a tailor, but the job didn't suit me. I used to be a banker, but I lost my"

# Track both "job" (literal) and "interest" (pun)
targets = [" job", " interest"]

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

for ax, context_name, prompt in zip(axes, ["Neutral context", "After another pun"], [neutral, after_pun]):
    for target in targets:
        probs = track_token_probability(prompt, target, model, remote=REMOTE)
        ax.plot(range(len(probs)), probs, '-o', markersize=2, label=f'P("{target.strip()}")')
    
    ax.set_xlabel('Layer')
    ax.set_ylabel('Probability')
    ax.set_title(f'{context_name}\n"...banker, but I lost my ___"')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()
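
One way to put a number on the priming effect is to compare the per-layer log-odds of the pun token against the literal token in the two contexts. This is a sketch under assumptions: `log_odds` and `priming_shift` are hypothetical helper names, and the probabilities are invented placeholders, not model output.

```python
import numpy as np

def log_odds(p_pun, p_literal, eps=1e-12):
    """Per-layer log(P(pun) / P(literal)); eps guards against log(0)."""
    p_pun = np.asarray(p_pun, dtype=float)
    p_literal = np.asarray(p_literal, dtype=float)
    return np.log(p_pun + eps) - np.log(p_literal + eps)

def priming_shift(neutral, primed):
    """Difference in log-odds between the primed and neutral contexts.

    Each argument is a (p_pun, p_literal) pair of per-layer trajectories.
    Positive values mean the primed context favors the pun token more.
    """
    return log_odds(*primed) - log_odds(*neutral)

# Invented three-layer trajectories, for illustration only.
neutral = ([0.01, 0.05, 0.10], [0.01, 0.20, 0.40])   # (P(" interest"), P(" job"))
primed = ([0.01, 0.10, 0.40], [0.01, 0.20, 0.20])

shift = priming_shift(neutral, primed)
print(np.round(shift, 3))  # [0.    0.693 2.079]
```

With real trajectories from `track_token_probability`, a shift that grows in later layers would suggest the earlier pun nudges the model toward the pun reading.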

---

# Part 6: Advanced Visualization with LogitLensKit

The LogitLensKit widget provides rich interactivity for exploring logit lens data. Here are more prompts to explore with it:

In [None]:
# Try the multilingual example with the full widget
data = collect_logit_lens(
    "German: Liebe, Italian: amore, English:",
    st_model,
    remote=REMOTE
)
show_logit_lens(data, title="Multilingual Love Across Languages")

In [None]:
# More examples to explore
examples = [
    "The Eiffel Tower is located in",
    "To be or not to be, that is the",
    "In 1969, Neil Armstrong became the first person to walk on the",
    "The quick brown fox jumps over the lazy",
]

for ex in examples:
    data = collect_logit_lens(ex, st_model, remote=REMOTE)
    show_logit_lens(data, title=f'"{ex[:40]}..."')

---

## Summary

In this notebook, we learned:

1. **The Logit Lens** projects intermediate hidden states to vocabulary space to see what the model "thinks" at each layer

2. **Two ways to use it:**
   - **Quick way**: `collect_logit_lens()` + `show_logit_lens()` from logitlenskit
   - **Manual way**: Intercept with nnsight, apply norm and lm_head, softmax

3. **Multilingual concepts**: Models may use English as an internal "concept language" (Wendler et al., 2024)

4. **Puns** are interesting because they require dual meanings, so we can watch when the pun "clicks"

5. **nnsight + NDIF** lets us run Llama 3.1 70B from a notebook without local GPU resources

### Questions to Consider

- At which layer does the correct answer first become the top prediction?
- Do factual vs. creative completions show different layer patterns?
- How does the pattern change for different languages?
- Can you find examples where the middle-layer prediction is "more correct" than the final prediction?
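
For the first question, a small helper makes "first becomes the top prediction" precise. The sketch below assumes you have a `[n_layers, vocab]` array of per-layer probabilities; both function names and the toy numbers are hypothetical. It separates the first layer where the target is top-1 from the layer after which it stays top-1, since the two can differ.

```python
import numpy as np

def first_top1_layer(layer_probs, target_id):
    """First layer at which target_id is the argmax, or None if it never is."""
    top_ids = np.asarray(layer_probs).argmax(axis=-1)
    hits = np.flatnonzero(top_ids == target_id)
    return int(hits[0]) if hits.size else None

def stable_top1_layer(layer_probs, target_id):
    """First layer from which target_id remains the argmax through the last layer."""
    top_ids = np.asarray(layer_probs).argmax(axis=-1)
    is_top = top_ids == target_id
    if not is_top[-1]:
        return None
    misses = np.flatnonzero(~is_top)
    return int(misses[-1]) + 1 if misses.size else 0

# Toy data: 4 layers, 3-token vocabulary; token 2 is the "correct" answer.
toy_probs = np.array([
    [0.5, 0.3, 0.2],
    [0.3, 0.3, 0.4],   # target briefly on top
    [0.6, 0.1, 0.3],   # then loses the lead
    [0.1, 0.1, 0.8],   # and wins for good
])

print(first_top1_layer(toy_probs, target_id=2))   # 1
print(stable_top1_layer(toy_probs, target_id=2))  # 3
```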

### Further Reading

- [Wendler et al. (2024): Do Llamas Work in English?](https://aclanthology.org/2024.acl-long.820/)
- [Tuned Lens (Belrose et al., 2023)](https://arxiv.org/abs/2303.08112)
- [LogitLensKit Documentation](https://davidbau.github.io/logitlenskit/)