# The Geometry of Puns: Visualizing Humor in Representation Space

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Nix07/neural-mechanics-web/blob/main/labs/week4/pun_geometry.ipynb)

This notebook explores the **geometric structure of pun representations** in large language models, inspired by Marks & Tegmark's "Geometry of Truth" methodology.

**Key Questions:**
- Do puns and non-puns separate in activation space?
- Is there a linear "pun direction" analogous to the "truth direction"?
- At which layer does pun understanding emerge?

We'll run **Llama 3 70B** via NDIF, project its high-dimensional activations to 2D with PCA, and look for linear concept directions.

## References
- [The Geometry of Truth](https://arxiv.org/abs/2310.06824) - Marks & Tegmark
- [Linear Representations of Sentiment](https://arxiv.org/abs/2310.15154) - Tigges et al.
- [nnsight documentation](https://nnsight.net/)
- [NDIF - National Deep Inference Fabric](https://ndif.us/)

## Setup

In [None]:
!pip install -q nnsight scikit-learn

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from nnsight import LanguageModel

# Use remote=True to run on NDIF's shared GPU resources
REMOTE = True

# For reproducibility
np.random.seed(42)
torch.manual_seed(42)

## Load Llama 3 70B

In [None]:
model = LanguageModel("meta-llama/Meta-Llama-3-70B", device_map="auto")

print(f"Model: {model.config._name_or_path}")
print(f"Layers: {model.config.num_hidden_layers}")
print(f"Hidden size: {model.config.hidden_size}")

## Part 1: Prepare Pun and Non-Pun Datasets

Following the Geometry of Truth methodology, we need matched pairs of pun and non-pun sentences. The key insight is that comparing similar sentences that differ only in "pun-ness" helps isolate the pun representation.

In [None]:
# Pun examples with their punchlines
puns = [
    "Why do electricians make good swimmers? Because they know the current.",
    "Why did the banker break up with his girlfriend? He lost interest.",
    "What do you call a fish without eyes? A fsh.",
    "Why don't scientists trust atoms? Because they make up everything.",
    "What did the ocean say to the beach? Nothing, it just waved.",
    "Why do cows wear bells? Because their horns don't work.",
    "I used to hate facial hair, but then it grew on me.",
    "Why did the scarecrow win an award? He was outstanding in his field.",
    "What do you call a bear with no teeth? A gummy bear.",
    "Why can't a bicycle stand on its own? It's two tired.",
    "What do you call a fake noodle? An impasta.",
    "Why did the math book look so sad? It had too many problems.",
    "What do you call a sleeping dinosaur? A dino-snore.",
    "Why did the golfer bring two pairs of pants? In case he got a hole in one.",
    "What did the grape say when it got stepped on? Nothing, it just let out a little wine.",
    "Why do seagulls fly over the sea? Because if they flew over the bay, they'd be bagels.",
    "What do you call a can opener that doesn't work? A can't opener.",
    "Why did the coffee file a police report? It got mugged.",
    "What do you call a pig that does karate? A pork chop.",
    "Why don't eggs tell jokes? They'd crack each other up.",
]

# Non-pun sentences (literal, similar structure but no wordplay)
non_puns = [
    "Why do electricians wear rubber gloves? To protect themselves from shocks.",
    "Why did the banker open a new account? To manage his investments.",
    "What do you call a fish that lives in freshwater? A trout.",
    "Why don't scientists make assumptions? Because they need evidence.",
    "What did the ocean look like at sunset? Calm and peaceful.",
    "Why do cows produce milk? To feed their calves.",
    "I used to avoid exercise, but then I started running.",
    "Why did the scarecrow need repairs? The birds had damaged it.",
    "What do you call a bear in winter? A hibernating animal.",
    "Why can't a bicycle go uphill easily? The gradient is steep.",
    "What do you call a fresh noodle? An al dente pasta.",
    "Why did the math book look so old? It had been used for years.",
    "What do you call a prehistoric reptile? A dinosaur.",
    "Why did the golfer check the weather? To plan his game.",
    "What did the grape taste like? Sweet and juicy.",
    "Why do seagulls live by the coast? Because that's their habitat.",
    "What do you call a manual can opener? A hand-operated tool.",
    "Why did the coffee taste bitter? It was over-extracted.",
    "What do you call a pig on a farm? Livestock.",
    "Why don't eggs last forever? They spoil over time.",
]

print(f"Puns: {len(puns)}")
print(f"Non-puns: {len(non_puns)}")
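
Because the comparison relies on matched pairs, it can help to check for surface confounds such as sentence length before extracting activations. The following is a minimal sketch, not part of the original methodology: `pair_length_gaps` is a hypothetical helper, it counts whitespace-separated words rather than model tokens, and it is shown on two pairs taken from the lists above.

```python
def pair_length_gaps(sentences_a, sentences_b):
    """Return per-pair word-count differences (a minus b)."""
    if len(sentences_a) != len(sentences_b):
        raise ValueError("Lists must have the same number of sentences.")
    return [len(a.split()) - len(b.split())
            for a, b in zip(sentences_a, sentences_b)]

# Two pairs standing in for the full `puns` / `non_puns` lists
toy_puns = [
    "Why did the banker break up with his girlfriend? He lost interest.",
    "What do you call a fake noodle? An impasta.",
]
toy_non_puns = [
    "Why did the banker open a new account? To manage his investments.",
    "What do you call a fresh noodle? An al dente pasta.",
]

gaps = pair_length_gaps(toy_puns, toy_non_puns)
mean_gap = sum(gaps) / len(gaps)
print(f"Word-count gaps (pun - non-pun): {gaps}, mean = {mean_gap:+.2f}")
```

A mean gap far from zero on the full lists would suggest that a classifier could pick up on length rather than wordplay.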

## Part 2: Extract Activations

We'll extract hidden states at multiple layers to see where the pun/non-pun distinction emerges.

In [None]:
def get_activation(prompt, model, layer_idx, position=-1, remote=True):
    """
    Get the hidden state activation at a specific layer and position.
    
    Args:
        prompt: Input text
        model: nnsight LanguageModel
        layer_idx: Which layer to extract from
        position: Token position (-1 = last token)
        remote: Use NDIF remote inference
    
    Returns:
        numpy array of shape (hidden_size,)
    """
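    with model.trace(prompt, remote=remote):
        hidden = model.model.layers[layer_idx].output[0].save()
    
    # Older nnsight versions wrap saved results in a proxy with `.value`;
    # newer ones return the tensor directly
    hidden = getattr(hidden, "value", hidden)
    
    # Get activation at specified position. Cast to float32 first because
    # bfloat16 tensors cannot be converted to NumPy directly.
    activation = hidden[0, position, :].float().cpu().numpy()
    return activation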

def collect_activations(sentences, model, layer_idx, position=-1, remote=True):
    """
    Collect activations for multiple sentences.
    
    Returns:
        numpy array of shape (n_sentences, hidden_size)
    """
    activations = []
    for i, sentence in enumerate(sentences):
        act = get_activation(sentence, model, layer_idx, position, remote)
        activations.append(act)
        if (i + 1) % 5 == 0:
            print(f"  Processed {i + 1}/{len(sentences)}")
    
    return np.array(activations)

In [None]:
# Collect activations at multiple layers
# Following Geometry of Truth, we examine early, middle, and late layers
n_layers = model.config.num_hidden_layers
layers_to_examine = {
    'early': n_layers // 4,      # ~Layer 20 for 80-layer model
    'middle': n_layers // 2,      # ~Layer 40
    'late': 3 * n_layers // 4,    # ~Layer 60
    'final': n_layers - 1         # Last layer
}

print(f"Examining layers: {layers_to_examine}")

# Dictionary to store activations
pun_activations = {}
nonpun_activations = {}

for name, layer_idx in layers_to_examine.items():
    print(f"\nCollecting activations for layer {layer_idx} ({name})...")
    print("  Puns:")
    pun_activations[name] = collect_activations(puns, model, layer_idx, remote=REMOTE)
    print("  Non-puns:")
    nonpun_activations[name] = collect_activations(non_puns, model, layer_idx, remote=REMOTE)

print("\nDone collecting activations!")
print(f"Activation shape: {pun_activations['middle'].shape}")

## Part 3: PCA Visualization

Following the Geometry of Truth paper, we use PCA to visualize the high-dimensional activations in 2D. The key question: **do puns and non-puns cluster separately?**

In [None]:
def plot_pca_scatter(pun_acts, nonpun_acts, layer_name, ax=None):
    """
    Create a PCA scatter plot of pun vs non-pun activations.
    """
    # Combine for joint PCA
    all_acts = np.vstack([pun_acts, nonpun_acts])
    labels = ['pun'] * len(pun_acts) + ['non-pun'] * len(nonpun_acts)
    
    # Fit PCA
    pca = PCA(n_components=2)
    projected = pca.fit_transform(all_acts)
    
    # Split back
    pun_proj = projected[:len(pun_acts)]
    nonpun_proj = projected[len(pun_acts):]
    
    # Plot
    if ax is None:
        fig, ax = plt.subplots(figsize=(8, 6))
    
    ax.scatter(pun_proj[:, 0], pun_proj[:, 1], c='coral', label='Puns', 
               alpha=0.7, s=100, edgecolors='darkred', linewidth=0.5)
    ax.scatter(nonpun_proj[:, 0], nonpun_proj[:, 1], c='steelblue', label='Non-puns',
               alpha=0.7, s=100, edgecolors='darkblue', linewidth=0.5)
    
    ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
    ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
    ax.set_title(f'{layer_name.capitalize()} Layer')
    ax.legend()
    ax.grid(True, alpha=0.3)
    
    return pca

In [None]:
# Create PCA plots for all examined layers
fig, axes = plt.subplots(2, 2, figsize=(14, 12))
axes = axes.flatten()

for i, (name, layer_idx) in enumerate(layers_to_examine.items()):
    plot_pca_scatter(pun_activations[name], nonpun_activations[name], 
                     f'{name} (L{layer_idx})', ax=axes[i])

plt.suptitle('PCA Visualization: Puns vs Non-Puns Across Layers', fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
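
To make the projection step less of a black box, here is a minimal sketch of PCA written directly with NumPy's SVD. It runs on synthetic data (an assumption made so the cell works without the model), and `pca_project` is a hypothetical helper; its output should agree with `sklearn.decomposition.PCA` up to the sign of each component.

```python
import numpy as np

def pca_project(X, n_components=2):
    """Project rows of X onto the top principal components via SVD."""
    X_centered = X - X.mean(axis=0)
    # Rows of Vt are the principal axes, ordered by decreasing variance
    _, S, Vt = np.linalg.svd(X_centered, full_matrices=False)
    projected = X_centered @ Vt[:n_components].T
    explained_ratio = (S ** 2) / np.sum(S ** 2)
    return projected, explained_ratio[:n_components]

rng = np.random.default_rng(0)
# Synthetic stand-in for activations: 40 examples, 64 dimensions,
# with the first 20 rows shifted along one coordinate
X = rng.normal(size=(40, 64))
X[:20, 0] += 8.0

projected, ratio = pca_project(X)
print("Projected shape:", projected.shape)
print("Explained variance ratio:", ratio)
```

Note that PCA is unsupervised: it finds directions of high variance, which only coincide with the pun/non-pun split if that split is one of the largest sources of variance in the data.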

## Part 4: The "Pun Direction"

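Following Marks & Tegmark, we compute the **mass-mean direction**: the difference between the mean pun activation and the mean non-pun activation. If pun-ness is linearly represented, this vector should point from the non-pun cluster toward the pun cluster.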

In [None]:
def compute_pun_direction(pun_acts, nonpun_acts):
    """
    Compute the pun direction as the mass-mean direction (difference of class means).
    
    pun_direction = mean(pun_activations) - mean(nonpun_activations)
    """
    mean_pun = np.mean(pun_acts, axis=0)
    mean_nonpun = np.mean(nonpun_acts, axis=0)
    
    direction = mean_pun - mean_nonpun
    
    # Normalize to unit vector
    direction_normalized = direction / np.linalg.norm(direction)
    
    return direction_normalized, mean_pun, mean_nonpun

# Compute pun direction for each layer
pun_directions = {}
for name in layers_to_examine.keys():
    direction, mean_pun, mean_nonpun = compute_pun_direction(
        pun_activations[name], nonpun_activations[name]
    )
    pun_directions[name] = {
        'direction': direction,
        'mean_pun': mean_pun,
        'mean_nonpun': mean_nonpun
    }
    
    # Compute separation (distance between means)
    separation = np.linalg.norm(mean_pun - mean_nonpun)
    print(f"{name}: mean separation = {separation:.4f}")
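
As a sanity check on the idea itself, the sketch below plants a known axis in synthetic data and tests whether the difference of class means recovers it. Everything here (the planted axis, the sample sizes, the noise level) is an assumption chosen for illustration and has nothing to do with the model's actual activations.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_size = 32

# Hypothetical ground-truth concept axis used to generate the toy data
true_axis = np.zeros(hidden_size)
true_axis[3] = 1.0

# Two classes with identical noise, offset by 4 units along true_axis
class_a = rng.normal(size=(200, hidden_size)) + 2.0 * true_axis
class_b = rng.normal(size=(200, hidden_size)) - 2.0 * true_axis

# Difference of class means, normalized to a unit vector
diff = class_a.mean(axis=0) - class_b.mean(axis=0)
direction = diff / np.linalg.norm(diff)

cosine = float(direction @ true_axis)
print(f"Distance between means: {np.linalg.norm(diff):.2f}")
print(f"Cosine with planted axis: {cosine:.3f}")
```

With far fewer examples and far more dimensions, as in this notebook, sampling noise in the two means contributes more to the estimated direction, so a large reported separation is not by itself evidence of a meaningful axis.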

### Testing the Pun Direction

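We can use the pun direction as a simple classifier: project each sentence's activation onto the direction and check whether puns score higher. Note that the accuracy below is measured on the same sentences used to compute the direction, so it is an optimistic estimate.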

In [None]:
def score_on_direction(activations, direction):
    """
    Project activations onto a direction.
    Returns scalar scores for each example.
    """
    return activations @ direction

def evaluate_direction(pun_acts, nonpun_acts, direction):
    """
    Evaluate how well a direction separates puns from non-puns.
    """
    pun_scores = score_on_direction(pun_acts, direction)
    nonpun_scores = score_on_direction(nonpun_acts, direction)
    
    # Simple threshold classifier: predict pun if score > threshold
    all_scores = np.concatenate([pun_scores, nonpun_scores])
    all_labels = np.array([1] * len(pun_scores) + [0] * len(nonpun_scores))
    
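    # Threshold at the median score (a simple heuristic, not an optimized
    # threshold; with balanced classes it labels half the examples as puns)
    threshold = np.median(all_scores)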
    predictions = (all_scores > threshold).astype(int)
    accuracy = accuracy_score(all_labels, predictions)
    
    return {
        'pun_scores': pun_scores,
        'nonpun_scores': nonpun_scores,
        'accuracy': accuracy,
        'threshold': threshold
    }

In [None]:
# Evaluate pun direction at each layer
print("Classification accuracy using pun direction:\n")

results = {}
for name in layers_to_examine.keys():
    result = evaluate_direction(
        pun_activations[name],
        nonpun_activations[name],
        pun_directions[name]['direction']
    )
    results[name] = result
    
    print(f"{name:8s}: {result['accuracy']*100:.1f}% accuracy")
    print(f"          Pun scores: mean={np.mean(result['pun_scores']):.3f}, std={np.std(result['pun_scores']):.3f}")
    print(f"          Non-pun scores: mean={np.mean(result['nonpun_scores']):.3f}, std={np.std(result['nonpun_scores']):.3f}")
    print()
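
One way to get a less optimistic number is to fit the direction on part of the data and score the rest. The sketch below does this on synthetic data, thresholding at the midpoint between the two training means instead of the median. `heldout_direction_accuracy` is a hypothetical helper, and the split size is an arbitrary assumption.

```python
import numpy as np

def heldout_direction_accuracy(pos, neg, n_train):
    """Fit a mean-difference direction on the first n_train rows of each
    class and report accuracy on the remaining rows."""
    pos_train, pos_test = pos[:n_train], pos[n_train:]
    neg_train, neg_test = neg[:n_train], neg[n_train:]

    mean_pos, mean_neg = pos_train.mean(axis=0), neg_train.mean(axis=0)
    direction = mean_pos - mean_neg
    direction /= np.linalg.norm(direction)
    # Threshold at the projection of the midpoint between the class means
    threshold = 0.5 * (mean_pos + mean_neg) @ direction

    pos_correct = (pos_test @ direction > threshold).sum()
    neg_correct = (neg_test @ direction <= threshold).sum()
    return float((pos_correct + neg_correct) / (len(pos_test) + len(neg_test)))

rng = np.random.default_rng(1)
shift = np.zeros(16)
shift[0] = 3.0
pos = rng.normal(size=(20, 16)) + shift   # synthetic "pun" activations
neg = rng.normal(size=(20, 16)) - shift   # synthetic "non-pun" activations

acc = heldout_direction_accuracy(pos, neg, n_train=14)
print(f"Held-out accuracy: {acc * 100:.1f}%")
```

Applied to the real activations, the same split could be made over sentence pairs so that a pun and its matched non-pun always land on the same side of the split.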

In [None]:
# Visualize score distributions
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for i, name in enumerate(layers_to_examine.keys()):
    ax = axes[i]
    result = results[name]
    
    # Histogram of scores
    ax.hist(result['pun_scores'], bins=15, alpha=0.6, color='coral', 
            label=f'Puns (mean={np.mean(result["pun_scores"]):.2f})', density=True)
    ax.hist(result['nonpun_scores'], bins=15, alpha=0.6, color='steelblue',
            label=f'Non-puns (mean={np.mean(result["nonpun_scores"]):.2f})', density=True)
    
    ax.axvline(result['threshold'], color='black', linestyle='--', 
               label=f'Threshold (acc={result["accuracy"]*100:.0f}%)')
    
    ax.set_xlabel('Projection onto Pun Direction')
    ax.set_ylabel('Density')
    ax.set_title(f'{name.capitalize()} Layer')
    ax.legend(fontsize=8)

plt.suptitle('Score Distributions: Puns vs Non-Puns Projected onto Pun Direction', y=1.02)
plt.tight_layout()
plt.show()

## Part 5: Linear Separability Analysis

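Beyond the mean-difference direction, let's see how well a logistic regression classifier can separate puns from non-puns. With only 40 examples in an 8192-dimensional space, training accuracy will almost certainly be 100%, so the held-out test accuracy (12 examples at `test_size=0.3`) is the number to watch, and even that is a noisy estimate.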

In [None]:
def test_linear_separability(pun_acts, nonpun_acts, test_size=0.3):
    """
    Train a logistic regression classifier and return train/test accuracy.
    """
    X = np.vstack([pun_acts, nonpun_acts])
    y = np.array([1] * len(pun_acts) + [0] * len(nonpun_acts))
    
    # Split
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=42, stratify=y
    )
    
    # Train logistic regression
    clf = LogisticRegression(max_iter=1000, random_state=42)
    clf.fit(X_train, y_train)
    
    train_acc = accuracy_score(y_train, clf.predict(X_train))
    test_acc = accuracy_score(y_test, clf.predict(X_test))
    
    return {
        'train_accuracy': train_acc,
        'test_accuracy': test_acc,
        'classifier': clf
    }

# Test linear separability at each layer
print("Linear separability (Logistic Regression):\n")
print(f"{'Layer':<10} {'Train Acc':<12} {'Test Acc':<12}")
print("-" * 34)

separability_results = {}
for name in layers_to_examine.keys():
    result = test_linear_separability(
        pun_activations[name],
        nonpun_activations[name]
    )
    separability_results[name] = result
    print(f"{name:<10} {result['train_accuracy']*100:>8.1f}%    {result['test_accuracy']*100:>8.1f}%")
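
To see why training accuracy says little when there are more dimensions than examples, the sketch below fits a linear classifier to pure noise with random labels. It swaps in a minimum-norm least-squares fit for logistic regression (an assumption made to keep the example NumPy-only); the point it illustrates is the same, since any labeling of 40 points in general position in a much higher-dimensional space is linearly separable.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 40, 512  # far fewer examples than dimensions

X = rng.normal(size=(n_samples, n_features))   # pure noise "activations"
y = rng.choice([-1.0, 1.0], size=n_samples)    # labels unrelated to X

# Minimum-norm least-squares fit; with n_features > n_samples it
# interpolates the training labels exactly
w, *_ = np.linalg.lstsq(X, y, rcond=None)
train_acc = float(np.mean(np.sign(X @ w) == y))

# Fresh noise with fresh random labels: the fit carries no real signal
X_new = rng.normal(size=(2000, n_features))
y_new = rng.choice([-1.0, 1.0], size=2000)
test_acc = float(np.mean(np.sign(X_new @ w) == y_new))

print(f"Train accuracy on random labels: {train_acc * 100:.1f}%")
print(f"Accuracy on fresh random data:   {test_acc * 100:.1f}%")
```

A useful control for the real experiment is therefore to shuffle the pun/non-pun labels and rerun the probe: test accuracy on shuffled labels gives a baseline to compare against.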

## Part 6: Visualize Pun Direction in PCA Space

Let's overlay the pun direction on the PCA plot. If the arrow between the class means aligns with the visible cluster separation, that is consistent with (though not proof of) a linear representation of pun-ness.

In [None]:
def plot_with_direction(pun_acts, nonpun_acts, direction, mean_pun, mean_nonpun, layer_name):
    """
    PCA plot with the pun direction visualized as an arrow.
    """
    # Combine for joint PCA
    all_acts = np.vstack([pun_acts, nonpun_acts])
    
    # Fit PCA
    pca = PCA(n_components=2)
    projected = pca.fit_transform(all_acts)
    
    # Split back
    pun_proj = projected[:len(pun_acts)]
    nonpun_proj = projected[len(pun_acts):]
    
    # Project the class means into PCA space
    mean_pun_proj = pca.transform(mean_pun.reshape(1, -1))[0]
    mean_nonpun_proj = pca.transform(mean_nonpun.reshape(1, -1))[0]
    
    # Plot
    fig, ax = plt.subplots(figsize=(10, 8))
    
    ax.scatter(pun_proj[:, 0], pun_proj[:, 1], c='coral', label='Puns', 
               alpha=0.6, s=100, edgecolors='darkred', linewidth=0.5)
    ax.scatter(nonpun_proj[:, 0], nonpun_proj[:, 1], c='steelblue', label='Non-puns',
               alpha=0.6, s=100, edgecolors='darkblue', linewidth=0.5)
    
    # Plot means
    ax.scatter([mean_pun_proj[0]], [mean_pun_proj[1]], c='red', s=300, 
               marker='*', edgecolors='black', linewidth=2, label='Pun mean', zorder=5)
    ax.scatter([mean_nonpun_proj[0]], [mean_nonpun_proj[1]], c='blue', s=300,
               marker='*', edgecolors='black', linewidth=2, label='Non-pun mean', zorder=5)
    
    # Draw arrow from non-pun mean to pun mean (the pun direction)
    ax.annotate('', xy=(mean_pun_proj[0], mean_pun_proj[1]),
                xytext=(mean_nonpun_proj[0], mean_nonpun_proj[1]),
                arrowprops=dict(arrowstyle='->', color='green', lw=3))
    
    # Label the arrow
    mid_x = (mean_pun_proj[0] + mean_nonpun_proj[0]) / 2
    mid_y = (mean_pun_proj[1] + mean_nonpun_proj[1]) / 2
    ax.annotate('Pun Direction', xy=(mid_x, mid_y), fontsize=12, 
                color='darkgreen', fontweight='bold',
                xytext=(10, 10), textcoords='offset points')
    
    ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]*100:.1f}%)')
    ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]*100:.1f}%)')
    ax.set_title(f'{layer_name}: PCA with Pun Direction')
    ax.legend(loc='upper right')
    ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()

# Plot for the best layer (based on separability)
best_layer = max(separability_results.keys(), 
                 key=lambda k: separability_results[k]['test_accuracy'])
print(f"Best layer for visualization: {best_layer}")

plot_with_direction(
    pun_activations[best_layer],
    nonpun_activations[best_layer],
    pun_directions[best_layer]['direction'],
    pun_directions[best_layer]['mean_pun'],
    pun_directions[best_layer]['mean_nonpun'],
    best_layer.capitalize()
)
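
The 2D arrow only shows the part of the pun direction that lies in the plane of the first two principal components. A rough way to quantify how much is lost is to measure the share of the direction's squared norm captured by that plane. The sketch below does so on synthetic data; `fraction_in_pca_plane` is a hypothetical helper, not something from the paper.

```python
import numpy as np

def fraction_in_pca_plane(acts_a, acts_b, n_components=2):
    """Share of the squared norm of the unit mean-difference direction
    that lies in the span of the top principal axes of the pooled data."""
    pooled = np.vstack([acts_a, acts_b])
    centered = pooled - pooled.mean(axis=0)
    _, _, Vt = np.linalg.svd(centered, full_matrices=False)
    axes = Vt[:n_components]        # orthonormal rows, one per principal axis

    direction = acts_a.mean(axis=0) - acts_b.mean(axis=0)
    direction /= np.linalg.norm(direction)
    coords = axes @ direction       # coordinates of the direction in the plane
    return float(np.sum(coords ** 2))

rng = np.random.default_rng(0)
a = rng.normal(size=(20, 64))
a[:, 0] += 5.0                      # synthetic class shifted one way
b = rng.normal(size=(20, 64))
b[:, 0] -= 5.0                      # synthetic class shifted the other way

frac = fraction_in_pca_plane(a, b)
print(f"Fraction of direction inside the PC1-PC2 plane: {frac:.3f}")
```

A value near 1 means the 2D plot shows essentially the whole direction; a small value means the class difference lies mostly outside the plotted plane, and the plot may understate the separation.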

## Part 7: Generalization Test

The key test from Geometry of Truth: does the pun direction generalize to new, unseen puns? With only three examples per class, treat the result as a sanity check rather than a measurement.

In [None]:
# Held-out test examples (not used for computing the direction)
test_puns = [
    "Why do programmers prefer dark mode? Because light attracts bugs.",
    "What do you call a lazy kangaroo? A pouch potato.",
    "Why did the stadium get hot? All the fans left.",
]

test_nonpuns = [
    "Why do programmers use version control? To track code changes.",
    "What do you call a baby kangaroo? A joey.",
    "Why did the stadium close? For maintenance.",
]

# Get activations for test examples at the best layer
best_layer_idx = layers_to_examine[best_layer]

print(f"Testing generalization at {best_layer} layer (L{best_layer_idx})...\n")

test_pun_acts = collect_activations(test_puns, model, best_layer_idx, remote=REMOTE)
test_nonpun_acts = collect_activations(test_nonpuns, model, best_layer_idx, remote=REMOTE)

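# Score on pun direction, centered on the midpoint between the training class
# means so that the sign of the score is meaningful
direction = pun_directions[best_layer]['direction']
midpoint = (pun_directions[best_layer]['mean_pun'] + pun_directions[best_layer]['mean_nonpun']) / 2

print("\nTest pun scores (should be positive):")
for pun, score in zip(test_puns, score_on_direction(test_pun_acts - midpoint, direction)):
    print(f"  {score:+.3f}: {pun[:50]}...")

print("\nTest non-pun scores (should be negative):")
for nonpun, score in zip(test_nonpuns, score_on_direction(test_nonpun_acts - midpoint, direction)):
    print(f"  {score:+.3f}: {nonpun[:50]}...")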

## Exercise 1: Position Analysis

We've been using the final token position. How does the representation change at different positions (e.g., middle of punchline, end of setup)?

In [None]:
# TODO: Compare activations at different positions
# 1. Final token (current approach)
# 2. Middle of the sentence
# 3. End of the setup (before punchline)
#
# Question: At which position is the pun/non-pun distinction clearest?

pass
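
As a possible starting point for option 3, the sketch below splits a question-style joke into setup and punchline at the first question mark. `split_setup` is a hypothetical helper and only handles the "question? answer." format; sentences without a question mark are returned unsplit. Tokenizing the setup on its own would then give a token count from which a position index can be derived, keeping in mind that the tokenizer may prepend a beginning-of-sequence token.

```python
def split_setup(sentence):
    """Split a question-style joke into (setup, punchline) at the first '?'.
    Returns (sentence, '') when there is no question mark."""
    head, sep, tail = sentence.partition("?")
    if not sep:
        return sentence, ""
    return head + sep, tail.strip()

setup, punchline = split_setup(
    "Why did the coffee file a police report? It got mugged."
)
print("Setup:    ", setup)
print("Punchline:", punchline)

# Statement-style puns have no question mark and are left whole
unsplit = split_setup("I used to hate facial hair, but then it grew on me.")
print("Unsplit:  ", unsplit)
```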

## Exercise 2: Fine-Grained Layer Analysis

We sampled only 4 layers. Plot linear separability (test accuracy) as a function of layer across the whole model.

In [None]:
# TODO: Plot separability vs layer
# For each layer (or every 4th layer):
#   1. Collect activations
#   2. Train logistic regression
#   3. Record test accuracy
# Plot: x-axis = layer, y-axis = accuracy
#
# Question: At which layer does pun understanding "emerge"?

pass

## Exercise 3: Pun Subtypes

Are different types of puns (homograph, homophone, compound) represented differently?

In [None]:
# TODO: Label puns by type and visualize separately
# Categories:
#   - Homograph: same spelling, different meanings (current, interest)
#   - Homophone: same sound, different spellings (knight/night)
#   - Compound: play on phrases (time flies)
#
# Question: Do different pun types cluster together or separately?
# Do they have different "directions"?

pass
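
One way to approach the last question is to compute a separate mean-difference direction per subtype and compare them by cosine similarity. The sketch below shows the bookkeeping on synthetic data in which two subtypes share an axis and a third does not. The data, the subtype shifts, and the helper names are all assumptions for illustration.

```python
import numpy as np

def unit_mean_difference(pos, neg):
    """Unit vector pointing from the mean of neg to the mean of pos."""
    d = pos.mean(axis=0) - neg.mean(axis=0)
    return d / np.linalg.norm(d)

rng = np.random.default_rng(0)
hidden_size, n = 48, 30

def shifted_noise(axis, amount):
    """Synthetic 'activations': Gaussian noise shifted along one axis."""
    acts = rng.normal(size=(n, hidden_size))
    acts[:, axis] += amount
    return acts

non_pun_acts = rng.normal(size=(n, hidden_size))
subtype_acts = {
    "homograph": shifted_noise(axis=0, amount=6.0),
    "homophone": shifted_noise(axis=0, amount=6.0),  # same axis as homograph
    "compound": shifted_noise(axis=1, amount=6.0),   # a different axis
}

directions = {name: unit_mean_difference(acts, non_pun_acts)
              for name, acts in subtype_acts.items()}
names = list(directions)
D = np.stack([directions[name] for name in names])
cosines = D @ D.T  # entry (i, j) is the cosine between directions i and j

for name, row in zip(names, cosines):
    print(f"{name:<10}", np.round(row, 2))
```

Because every subtype is compared against the same non-pun set, the directions share some noise from that set's mean, which pushes cosines slightly above zero even for unrelated subtypes.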

## Exercise 4: 3D Visualization

Create an interactive 3D PCA plot to see more structure.

In [None]:
# TODO: Create 3D PCA visualization
# Use plotly for interactive 3D scatter:
#   pip install plotly
#   import plotly.express as px
#
# This can reveal structure hidden in 2D projections

pass

## Summary

In this notebook, we applied the "Geometry of Truth" methodology to puns:

1. **PCA Visualization:** Revealed whether puns and non-puns cluster separately at different layers

2. **Pun Direction:** Computed a candidate linear "pun-ness" direction as the difference of class means (the mass-mean direction)

3. **Classification:** Tested how well the pun direction separates puns from non-puns

4. **Generalization:** Checked whether the direction transfers to held-out examples

### Key Findings

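With only 20 examples per class and 3 held-out pairs, treat the following as hypotheses to check against your own outputs rather than established results:

- Puns form a (partially) linearly separable cluster in activation space
- The separation improves in middle-to-late layers
- A simple linear direction classifies puns with reasonable accuracy
- The direction generalizes to new pun examples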

### Questions for Your Research

- How does your concept compare to puns in terms of linear separability?
- At which layer does your concept "emerge" in the representation?
- Is your concept direction as generalizable as the truth direction from Marks & Tegmark?

### Next Steps

1. Apply this methodology to your research concept
2. Compare geometric structure with causal importance (Week 5)
3. Use the concept direction for steering experiments