# Logit Lens with nnsight and NDIF

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week1/logit_lens.ipynb)

This notebook demonstrates the **logit lens** technique using [nnsight](https://nnsight.net/) and the [NDIF](https://ndif.us/) remote inference API. The logit lens lets us peek inside a transformer to see what the model is "thinking" at each layer.

**Key Idea:** At each layer, we take the hidden states, apply the model's final normalization, and project them into vocabulary space with the unembedding matrix. This reveals how the model's predictions evolve as information flows through the network.
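
To make the idea concrete, here is a minimal, framework-free sketch of the projection for a single token position. Everything in it is a toy assumption: the 3-dimensional hidden states, the 4-token vocabulary, the `W_U` unembedding rows, and the unit RMSNorm gain are invented for illustration and are not taken from Llama.

```python
import math

def rms_norm(h, eps=1e-6):
    """RMSNorm with unit gain: divide by the root-mean-square of the vector."""
    rms = math.sqrt(sum(x * x for x in h) / len(h) + eps)
    return [x / rms for x in h]

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def logit_lens(h, W_U):
    """Normalize one hidden state, unembed it, and return vocabulary probabilities."""
    normed = rms_norm(h)
    logits = [sum(w * x for w, x in zip(row, normed)) for row in W_U]
    return softmax(logits)

# Hypothetical toy vocabulary and unembedding matrix (one row per token).
vocab = [" current", " water", " wires", " the"]
W_U = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [0.5, 0.5, 0.5],
]

# Hypothetical last-position hidden states after three layers.
hidden_by_layer = {
    0: [0.1, 0.1, 0.1],
    1: [0.2, 0.9, 0.1],
    2: [2.0, 0.3, 0.1],
}

top_by_layer = {}
for layer, h in hidden_by_layer.items():
    probs = logit_lens(h, W_U)
    best = max(range(len(vocab)), key=lambda i: probs[i])
    top_by_layer[layer] = vocab[best]
    print(f"layer {layer}: top token {vocab[best]!r} (p={probs[best]:.2f})")
```

In this toy run the top token changes from layer to layer, which is the kind of drift the real logit lens is meant to expose. The rest of the notebook does the same projection with the model's own `norm` and `lm_head` modules.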

We'll use **Llama 3.1 70B** via NDIF to explore how large language models process **puns**, a case where the model must keep more than one meaning of a word in play at the same time.

## References
- [nostalgebraist's Logit Lens post](https://www.lesswrong.com/posts/AcKRB8wDpdaN6v6ru/interpreting-gpt-the-logit-lens)
- [nnsight documentation](https://nnsight.net/)
- [NDIF - National Deep Inference Fabric](https://ndif.us/)

## Setup

First, install nnsight if needed:

In [None]:
!pip install -q nnsight

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from nnsight import LanguageModel, CONFIG

# Configure NDIF API key from Colab secrets
try:
    from google.colab import userdata
    CONFIG.set_default_api_key(userdata.get('NDIF_API'))
except Exception:
    pass  # Not in Colab, or the NDIF_API secret is not set

# We use remote=True to run on NDIF's shared GPU resources.
# This lets us use Llama 3.1 70B without needing massive local compute.
REMOTE = True

## Load Llama 3.1 70B

Thanks to NDIF, we can run a 70-billion-parameter model from a Colab notebook!

In [None]:
# Create the nnsight wrapper for Llama 3.1 70B; with remote=True the forward pass runs on NDIF
model = LanguageModel("meta-llama/Llama-3.1-70B", device_map="auto")

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")
print(f"Vocabulary size: {model.config.vocab_size}")

## Puns: A Window into Dual Meanings

Puns are interesting for interpretability because they require the model to process words with multiple meanings. When does the model "get" the joke? At which layer does the pun's alternative meaning emerge?

Let's start with a classic pun setup:

In [None]:
# A pun that plays on "current" (electrical vs water)
pun_prompt = "Why do electricians make good swimmers? Because they know the"
# Expected punchline involves "current"

print(f"Prompt: {pun_prompt}")

## Basic Logit Lens: Watching the Pun Develop

Let's see how the model's predictions evolve across layers.

In [None]:
def get_value(saved):
    """Helper to get value from saved tensor (handles local vs remote)."""
    try:
        return saved.value
    except AttributeError:
        return saved

def get_top_predictions(probs, indices, tokenizer):
    """Convert saved top-k probs and indices to token strings."""
    return [(tokenizer.decode([idx]), prob.item()) for idx, prob in zip(indices, probs)]

def logit_lens_layers(prompt, model, layers_to_check=None, remote=True, top_k=10):
    """
    Run logit lens on specified layers.
    Returns top-k predictions per layer (much less data than full vocab).
    """
    n_layers = model.config.num_hidden_layers
    if layers_to_check is None:
        # Check every 10th layer plus the last one; drop out-of-range and duplicate indices
        candidates = [0, 10, 20, 30, 40, 50, 60, 70, n_layers - 1]
        layers_to_check = sorted({idx for idx in candidates if 0 <= idx < n_layers})
    
    layer_results = {}
    
    # Single trace call - compute top-k on server to reduce data transfer
    with model.trace(prompt, remote=remote):
        for layer_idx in layers_to_check:
            hidden = model.model.layers[layer_idx].output[0]
            logits = model.lm_head(model.model.norm(hidden))
            # Compute softmax and top-k on server, only save small tensors
            probs = torch.softmax(logits[0, -1], dim=-1)
            top_probs, top_indices = probs.topk(top_k)
            layer_results[layer_idx] = (top_probs.save(), top_indices.save())
    
    # Extract values after trace completes
    return {k: (get_value(v[0]), get_value(v[1])) for k, v in layer_results.items()}
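
The comment about reducing data transfer can be checked with quick arithmetic. The sketch below assumes float32 probabilities (4 bytes), int64 indices (8 bytes), the 128,256-token Llama 3.1 vocabulary, and the nine default layers selected for an 80-layer model. The actual dtypes and serialization overhead on NDIF may differ, so treat the figures as order-of-magnitude estimates.

```python
# Rough payload estimate; every size below is an assumption, not a measurement.
VOCAB_SIZE = 128_256   # Llama 3.1 vocabulary size (compare model.config.vocab_size)
FLOAT_BYTES = 4        # assumed float32 probabilities
INDEX_BYTES = 8        # assumed int64 token indices

def full_vocab_bytes(n_layers):
    """Bytes needed to return the full last-position distribution for each layer."""
    return n_layers * VOCAB_SIZE * FLOAT_BYTES

def top_k_bytes(n_layers, top_k):
    """Bytes needed to return only top-k probabilities and indices for each layer."""
    return n_layers * top_k * (FLOAT_BYTES + INDEX_BYTES)

n_layers_checked = 9  # default layers_to_check for an 80-layer model
full = full_vocab_bytes(n_layers_checked)
small = top_k_bytes(n_layers_checked, top_k=10)
print(f"full vocab: {full / 1e6:.1f} MB, top-10: {small} bytes, ratio ~{full // small}x")
```

Under these assumptions, saving only the top-k tensors shrinks the download by several thousand times, which matters more as the number of layers and prompts grows.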

In [None]:
# Run logit lens on the pun
layer_results = logit_lens_layers(pun_prompt, model, remote=REMOTE)

print(f"Prompt: '{pun_prompt}'\n")
print("Layer-by-layer predictions:")
print("=" * 60)

for layer_idx, (probs, indices) in sorted(layer_results.items()):
    preds = get_top_predictions(probs, indices, model.tokenizer)
    print(f"\nLayer {layer_idx:2d}:")
    for token, prob in preds[:5]:  # Show top 5
        print(f"  {repr(token):15} {prob:.3f}")

## Tracking the Pun Word

Let's specifically track how the probability of "current" (the pun word) develops across layers.

In [None]:
def track_token_probability(prompt, target_token, model, remote=True):
    """
    Track the probability of a specific token across all layers.
    Only saves single probability value per layer (minimal data transfer).
    """
    n_layers = model.config.num_hidden_layers
    
    # Get target token ID
    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
    if len(target_ids) != 1:
        print(f"Warning: '{target_token}' tokenizes to {len(target_ids)} tokens")
    target_id = target_ids[0]
    
    layer_probs = []
    
    # Single trace call - only save the single probability we need
    with model.trace(prompt, remote=remote):
        for layer_idx in range(n_layers):
            hidden = model.model.layers[layer_idx].output[0]
            logits = model.lm_head(model.model.norm(hidden))
            probs = torch.softmax(logits[0, -1], dim=-1)
            # Only save the single probability for target token
            target_prob = probs[target_id].save()
            layer_probs.append(target_prob)
    
    # Extract probabilities after trace completes
    return [get_value(p).item() for p in layer_probs]

# Track "current" probability
current_probs = track_token_probability(pun_prompt, " current", model, remote=REMOTE)

In [None]:
plt.figure(figsize=(12, 5))
plt.plot(range(len(current_probs)), current_probs, 'b-o', markersize=3)
plt.xlabel('Layer')
plt.ylabel('P(" current")')
plt.title(f'Logit Lens: When does the model "get" the pun?\n"{pun_prompt}"')
plt.grid(True, alpha=0.3)
plt.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

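# Find the first layer where the probability exceeds the threshold
threshold = 0.1
for i, p in enumerate(current_probs):
    if p > threshold:
        print(f"' current' first exceeds {threshold} probability at layer {i}")
        break
else:
    # The loop finished without a break, so no layer crossed the threshold
    print(f"' current' never exceeds {threshold} probability at any layer")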

## Exercise 1: Compare Multiple Puns

Do different types of puns show similar patterns? Let's compare!

In [None]:
puns = [
    ("Why do electricians make good swimmers? Because they know the", " current"),
    ("Why did the banker break up with his girlfriend? He lost", " interest"),
    ("Why do cows wear bells? Because their horns don't", " work"),
    ("What do you call a fish without eyes? A", " f"),  # punchline is "fsh" (a fish with no "i"); we track its first token
]

plt.figure(figsize=(14, 6))

for prompt, target in puns:
    probs = track_token_probability(prompt, target, model, remote=REMOTE)
    label = f'"{target.strip()}" ({prompt[:30]}...)'
    plt.plot(range(len(probs)), probs, '-o', markersize=2, label=label)

plt.xlabel('Layer')
plt.ylabel('Probability')
plt.title('Logit Lens: Comparing Pun Recognition Across Layers')
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## Exercise 2: Pun vs Literal Completion

How does the model decide between a pun completion and a literal one? Let's compare prompts that could go either way.

In [None]:
# Compare pun setup vs literal setup
pun_setup = "Why do electricians make good swimmers? Because they know the"
literal_setup = "Electricians work with wires that carry electrical"

target = " current"

pun_probs = track_token_probability(pun_setup, target, model, remote=REMOTE)
literal_probs = track_token_probability(literal_setup, target, model, remote=REMOTE)

plt.figure(figsize=(12, 5))
plt.plot(range(len(pun_probs)), pun_probs, 'b-o', markersize=3, label='Pun context')
plt.plot(range(len(literal_probs)), literal_probs, 'r-o', markersize=3, label='Literal context')
plt.xlabel('Layer')
plt.ylabel(f'P("{target}")')
plt.title('Logit Lens: Same Word, Different Contexts')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
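
Beyond eyeballing the two curves, one possible summary statistic is the first layer at which each curve reaches half of its own final value. The helper below, `onset_layer`, is a hypothetical name introduced here, and the sample curves are invented to show the calculation; they are not results from the model.

```python
def onset_layer(layer_probs, fraction=0.5):
    """Return the first layer whose probability reaches `fraction` of the
    final-layer probability, or None if the final probability is not positive."""
    final = layer_probs[-1]
    if final <= 0:
        return None
    cutoff = fraction * final
    for layer, p in enumerate(layer_probs):
        if p >= cutoff:
            return layer
    return None  # only reachable if fraction > 1

# Invented example curves (NOT model output): one rises late, one rises early.
late_curve = [0.0, 0.0, 0.01, 0.02, 0.10, 0.40, 0.60]
early_curve = [0.0, 0.05, 0.30, 0.55, 0.70, 0.80, 0.85]

print("late onset layer:", onset_layer(late_curve))
print("early onset layer:", onset_layer(early_curve))
```

In the notebook this could be applied as `onset_layer(pun_probs)` and `onset_layer(literal_probs)`. Logit lens curves are not guaranteed to rise monotonically, so the first crossing is only a rough marker and is best read alongside the plot.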

## Exercise 3: Visualize Full Token Heatmap

Create a heatmap showing how the pun word's probability develops at each token position.

In [None]:
def logit_lens_heatmap(prompt, target_token, model, remote=True):
    """
    Create a logit lens heatmap showing the probability of a target token
    at each layer and position.
    Only saves target token probabilities (minimal data transfer).
    """
    n_layers = model.config.num_hidden_layers
    
    # Get target token ID
    target_ids = model.tokenizer.encode(target_token, add_special_tokens=False)
    target_id = target_ids[0]
    
    # Get token strings
    tokens = model.tokenizer.encode(prompt)
    token_strs = [model.tokenizer.decode([t]) for t in tokens]
    
    layer_probs = []
    
    # Single trace call - only save target token probability at each position
    with model.trace(prompt, remote=remote):
        for layer_idx in range(n_layers):
            hidden = model.model.layers[layer_idx].output[0]
            logits = model.lm_head(model.model.norm(hidden))
            probs = torch.softmax(logits[0], dim=-1)
            # Only save probabilities for target token at all positions
            target_probs = probs[:, target_id].save()
            layer_probs.append(target_probs)
    
    # Stack after trace completes
    all_probs = torch.stack([get_value(p) for p in layer_probs])
    # Cast to float32 first: NumPy has no bfloat16 dtype, so .numpy() fails on bf16 tensors
    target_probs = all_probs.float().cpu().numpy()
    
    return target_probs, token_strs

# Visualize
prompt = "Why do electricians make good swimmers? Because they know the"
target = " current"

probs, tokens = logit_lens_heatmap(prompt, target, model, remote=REMOTE)

plt.figure(figsize=(16, 10))
plt.imshow(probs, aspect='auto', cmap='Blues', vmin=0)
plt.colorbar(label=f'P("{target}")')
plt.xlabel('Token Position')
plt.ylabel('Layer')
plt.title(f'Logit Lens Heatmap: "{target}"\nPrompt: "{prompt}"')
plt.xticks(range(len(tokens)), tokens, rotation=45, ha='right', fontsize=8)
plt.tight_layout()
plt.show()

## Exercise 4: Alternative Scaling Methods

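The standard logit lens applies the model's final RMSNorm to each intermediate hidden state before unembedding. Try different scaling approaches and compare the resulting predictions.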

In [None]:
# TODO: Implement and compare:
# 1. Standard logit lens (with RMSNorm)
# 2. Raw projection (no normalization)
# 3. Tuned lens (learned affine transform per layer)
#
# See: https://arxiv.org/abs/2303.08112 (Tuned Lens paper)
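
As a starting point, the three variants differ only in what happens to the hidden state before unembedding. The framework-free sketch below uses made-up toy values. The matrix `A` and bias `b` stand in for a tuned-lens translator that would normally be learned separately for each layer, the RMSNorm gain is assumed to be 1, and all function names are hypothetical.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def rms_norm(h, eps=1e-6):
    """RMSNorm with unit gain."""
    rms = math.sqrt(sum(x * x for x in h) / len(h) + eps)
    return [x / rms for x in h]

def standard_lens(h, W_U):
    """Variant 1: apply RMSNorm, then unembed."""
    return softmax(matvec(W_U, rms_norm(h)))

def raw_lens(h, W_U):
    """Variant 2: unembed the hidden state with no normalization."""
    return softmax(matvec(W_U, h))

def tuned_lens(h, W_U, A, b):
    """Variant 3: apply a per-layer affine map, then the standard lens."""
    translated = [y + bias for y, bias in zip(matvec(A, h), b)]
    return standard_lens(translated, W_U)

W_U = [[1.0, 0.0], [0.0, 1.0]]   # toy 2-token unembedding matrix
h = [10.0, 5.0]                   # toy hidden state with a large norm
A = [[1.0, 0.0], [0.0, 1.0]]      # identity translator (untrained placeholder)
b = [0.0, 0.0]

p_std = standard_lens(h, W_U)
p_raw = raw_lens(h, W_U)
p_tuned = tuned_lens(h, W_U, A, b)
print("standard:", [round(p, 3) for p in p_std])
print("raw:     ", [round(p, 3) for p in p_raw])
print("tuned:   ", [round(p, 3) for p in p_tuned])
```

With an identity translator the tuned lens reduces to the standard lens, so any difference in a real experiment comes from the learned `A` and `b`. In this toy case the raw projection is more peaked than the normalized one because the large-norm hidden state produces larger logits.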

## Summary

In this notebook, we learned:

1. **The Logit Lens** projects intermediate hidden states to vocabulary space to see what the model "thinks" at each layer

2. **nnsight + NDIF** lets us run Llama 3.1 70B from a notebook without local GPU resources

3. **Puns are interesting** because they depend on a word's dual meanings, so we can watch for the layer where the pun "clicks" in the model

4. **Layer patterns may vary** by context; comparing the pun and literal curves is one way to check whether the model processes the two kinds of completion differently

### Questions to Consider

- At which layer does the pun word first become the top prediction?
- Do puns "develop" differently than literal factual knowledge?
- How does the pattern change for puns that require more cultural knowledge?
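
For the first question, a small helper can scan per-layer predictions for the earliest layer where the target ranks first. This is a sketch: `first_top1_layer` is a hypothetical name, it assumes each layer's list is sorted by descending probability (as the output of `get_top_predictions` is when built from `topk`), and the example data is invented rather than taken from the model.

```python
def first_top1_layer(preds_by_layer, target):
    """Return the earliest checked layer whose top-ranked token equals `target`,
    or None if the target never ranks first."""
    for layer in sorted(preds_by_layer):
        preds = preds_by_layer[layer]
        if preds and preds[0][0] == target:
            return layer
    return None

# Invented predictions (NOT model output), shaped like get_top_predictions results:
# each layer maps to a list of (token_string, probability) pairs.
example = {
    0: [(" the", 0.20), (" a", 0.10)],
    40: [(" ropes", 0.30), (" current", 0.25)],
    79: [(" current", 0.60), (" ropes", 0.10)],
}

print(first_top1_layer(example, " current"))
print(first_top1_layer(example, " water"))
```

To use it with this notebook, build the dictionary by calling `get_top_predictions` on each entry of `layer_results`. The answer is only as fine-grained as the layers that were checked, so pass `layers_to_check=range(n_layers)` to `logit_lens_layers` for a per-layer result.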