# Week 6 Exercise: Probes and Masks

In this exercise, you'll gain hands-on experience with:
- Training linear and MLP probes
- Implementing control tasks
- Testing for overfitting and underfitting
- Training learned masks with sparse regularization
- Comparing probe/mask findings with causal interventions
- Interpreting agreements and disagreements between methods

## Setup

Install required libraries:

In [None]:
!pip install transformers torch numpy pandas matplotlib scikit-learn -q

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import warnings
warnings.filterwarnings('ignore')

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load GPT-2 small
model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
model = model.to(device)
model.eval()

print(f"\nModel: {model_name}")
print(f"Number of layers: {model.config.n_layer}")
print(f"Hidden size: {model.config.n_embd}")

## Part 1: Creating a Probe Dataset

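First, create a small balanced dataset for probing: 15 positive and 15 negative sentences, mostly written as matched pairs so that the two classes differ mainly in their sentiment words.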

In [None]:
# Example: Probing for sentiment (positive vs negative)
positive_examples = [
    "This movie was amazing and wonderful",
    "I love this product, it's fantastic",
    "The weather is beautiful today",
    "What a great experience at the restaurant",
    "The book was inspiring and uplifting",
    "I'm so happy with my purchase",
    "The service was excellent and friendly",
    "This is the best thing ever",
    "I enjoyed every moment of it",
    "Outstanding quality and performance",
    "Absolutely delighted with the results",
    "A truly remarkable achievement",
    "The team did an excellent job",
    "I highly recommend this to everyone",
    "Perfect in every way possible"
]

negative_examples = [
    "This movie was terrible and boring",
    "I hate this product, it's awful",
    "The weather is dreadful today",
    "What a horrible experience at the restaurant",
    "The book was depressing and disappointing",
    "I'm so unhappy with my purchase",
    "The service was terrible and rude",
    "This is the worst thing ever",
    "I regret every moment of it",
    "Poor quality and terrible performance",
    "Absolutely disappointed with the results",
    "A truly disastrous failure",
    "The team did an awful job",
    "I strongly advise against this",
    "Flawed in every way possible"
]

# Create dataset
texts = positive_examples + negative_examples
labels = [1] * len(positive_examples) + [0] * len(negative_examples)

print(f"Dataset size: {len(texts)} examples")
print(f"Positive: {sum(labels)}, Negative: {len(labels) - sum(labels)}")
print(f"\nExample positive: {positive_examples[0]}")
print(f"Example negative: {negative_examples[0]}")

## Part 2: Extracting Hidden States

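Extract the hidden state at a chosen layer and token position. With `output_hidden_states=True`, index 0 of `outputs.hidden_states` is the embedding output and indices 1 through 12 correspond to GPT-2 small's 12 transformer blocks.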
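
One caveat if you later batch inputs: `extract_hidden_states` processes one text at a time, so `position=-1` is always the last real token. With padded batches (Hugging Face tokenizers pad on the right unless `padding_side` is changed), index -1 can land on a pad token. A minimal sketch of selecting the last non-pad position from an attention mask, using NumPy arrays as stand-ins for model outputs (the array names, shapes, and values are illustrative assumptions):

```python
import numpy as np

# Stand-in for a batch of hidden states: [batch, seq_len, hidden_size].
rng = np.random.default_rng(0)
hidden = rng.normal(size=(3, 5, 4))

# Attention mask as a right-padding tokenizer would return it:
# 1 marks real tokens, 0 marks padding.
attention_mask = np.array([
    [1, 1, 1, 1, 1],
    [1, 1, 1, 0, 0],
    [1, 1, 0, 0, 0],
])

# Index of the last real token in each row.
last_idx = attention_mask.sum(axis=1) - 1

# Pick one vector per example: [batch, hidden_size].
last_token_states = hidden[np.arange(hidden.shape[0]), last_idx]
print(last_idx, last_token_states.shape)
```

The same integer-array indexing pattern should carry over to PyTorch tensors.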

In [None]:
def extract_hidden_states(texts, layer_idx=-1, position=-1):
    """
    Extract hidden states from specified layer and position.
    
    Args:
        texts: List of text strings
        layer_idx: Index into outputs.hidden_states (0 = embeddings, -1 = last layer)
        position: Which token position (-1 for last token)
    
    Returns:
        hidden_states: [num_examples, hidden_size] numpy array
    """
    all_hidden_states = []
    
    for text in texts:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)
        
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)
            # Get hidden states from specified layer
            hidden_state = outputs.hidden_states[layer_idx][0, position, :]
            all_hidden_states.append(hidden_state.cpu().numpy())
    
    return np.array(all_hidden_states)


# Extract from last layer as an example
print("Extracting hidden states from last layer...")
hidden_states = extract_hidden_states(texts, layer_idx=-1, position=-1)
print(f"Hidden states shape: {hidden_states.shape}")
print(f"  {hidden_states.shape[0]} examples × {hidden_states.shape[1]} dimensions")

## Part 3: Training Linear Probes

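Train a logistic-regression probe on the last-token representation at every layer. With 30 examples and a 20% held-out split, each test set contains only 6 examples, so accuracy moves in steps of about 0.17. Treat layer-to-layer differences as suggestive, not conclusive.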
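
A single train/test split on a dataset this small gives a noisy estimate. One option, sketched here on synthetic features instead of real hidden states, is stratified k-fold cross-validation with scikit-learn. The feature matrix, the signal placed in dimension 0, and the fold count are all assumptions made for illustration:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score

# Synthetic stand-in for probe features: 30 examples, 16 dimensions,
# with the class signal placed in the first dimension.
rng = np.random.default_rng(0)
y = np.array([1] * 15 + [0] * 15)
X = rng.normal(size=(30, 16))
X[:, 0] += 2.0 * y

# Five stratified folds keep the class balance in every split.
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
probe = LogisticRegression(max_iter=1000)
scores = cross_val_score(probe, X, y, cv=cv, scoring="accuracy")

print(f"Fold accuracies: {np.round(scores, 2)}")
print(f"Mean ± std: {scores.mean():.3f} ± {scores.std():.3f}")
```

To apply this to the exercise, `X` would be the output of `extract_hidden_states` for one layer and `y` the `labels` list converted to an array.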

In [None]:
def train_linear_probe(X, y, test_size=0.2, random_state=42):
    """
    Train a linear probe (logistic regression).
    
    Args:
        X: Features [num_examples, hidden_size]
        y: Labels [num_examples]
        test_size: Fraction for test set
        random_state: Seed for the split and the solver
    
    Returns:
        probe: Trained model
        accuracy: Test accuracy
        splits: Tuple (X_train, X_test, y_train, y_test)
    """
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state, stratify=y
    )
    
    # Train linear probe
    probe = LogisticRegression(max_iter=1000, random_state=random_state)
    probe.fit(X_train, y_train)
    
    # Evaluate
    y_pred = probe.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    
    return probe, accuracy, (X_train, X_test, y_train, y_test)


# Train probes at all layers
print("Training linear probes at all layers...\n")
layer_accuracies = []

for layer_idx in range(model.config.n_layer + 1):  # +1 for embedding layer
    # Extract hidden states from this layer
    X = extract_hidden_states(texts, layer_idx=layer_idx, position=-1)
    
    # Train probe
    probe, accuracy, splits = train_linear_probe(X, labels)
    layer_accuracies.append(accuracy)
    
    print(f"Layer {layer_idx}: {accuracy:.4f}")

print(f"\nBest layer: {np.argmax(layer_accuracies)} (accuracy: {max(layer_accuracies):.4f})")

In [None]:
# Visualize probe accuracy across layers
plt.figure(figsize=(10, 5))
plt.plot(range(len(layer_accuracies)), layer_accuracies, marker='o')
plt.axhline(y=0.5, color='r', linestyle='--', label='Random baseline')
plt.xlabel('Layer')
plt.ylabel('Probe Accuracy')
plt.title('Linear Probe Accuracy Across Layers')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print("\nInterpretation: Higher accuracy indicates the concept is more linearly accessible.")

## Part 4: Training MLP Probes

Compare linear probes with nonlinear (MLP) probes.
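
To see what a linear/MLP gap looks like in a controlled setting, the sketch below builds an XOR-style dataset in which the label depends on the interaction of two features, so no single hyperplane separates the classes. It uses scikit-learn's `MLPClassifier` as a stand-in for the `MLPProbe` class defined in this notebook; the data and hyperparameters are illustrative assumptions:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

# XOR-style labels: positive when exactly one of the two features is > 0.
rng = np.random.default_rng(0)
X = rng.uniform(-1, 1, size=(400, 2))
y = ((X[:, 0] > 0) ^ (X[:, 1] > 0)).astype(int)

X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.3, random_state=0, stratify=y
)

# Linear probe: limited to a single hyperplane.
linear = LogisticRegression(max_iter=1000).fit(X_tr, y_tr)
# Nonlinear probe: one hidden layer, as in MLPProbe.
mlp = MLPClassifier(hidden_layer_sizes=(32,), max_iter=2000, random_state=0).fit(X_tr, y_tr)

lin_acc = linear.score(X_te, y_te)
mlp_acc = mlp.score(X_te, y_te)
print(f"Linear: {lin_acc:.3f}  MLP: {mlp_acc:.3f}  Gap: {mlp_acc - lin_acc:.3f}")
```

A large positive gap here reflects the construction of the data. On real hidden states, a gap can also come from the MLP's extra capacity, which is why the control tasks in Part 5 matter.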

In [None]:
class MLPProbe(nn.Module):
    """Simple MLP probe with one hidden layer."""
    
    def __init__(self, input_size, hidden_size=128, num_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)
    
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


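def train_mlp_probe(X, y, hidden_size=128, epochs=50, lr=0.001, seed=42):
    """
    Train an MLP probe with full-batch Adam and return (probe, test accuracy).
    """
    torch.manual_seed(seed)  # fix weight initialization so runs are comparable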
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    # Convert to tensors
    X_train_t = torch.FloatTensor(X_train).to(device)
    X_test_t = torch.FloatTensor(X_test).to(device)
    y_train_t = torch.LongTensor(y_train).to(device)
    y_test_t = torch.LongTensor(y_test).to(device)
    
    # Create probe
    probe = MLPProbe(X.shape[1], hidden_size).to(device)
    optimizer = optim.Adam(probe.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Train
    probe.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = probe(X_train_t)
        loss = criterion(outputs, y_train_t)
        loss.backward()
        optimizer.step()
    
    # Evaluate
    probe.eval()
    with torch.no_grad():
        outputs = probe(X_test_t)
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == y_test_t).float().mean().item()
    
    return probe, accuracy


# Compare linear vs MLP probes at each layer
print("Training MLP probes at all layers...\n")
mlp_accuracies = []

for layer_idx in range(model.config.n_layer + 1):
    X = extract_hidden_states(texts, layer_idx=layer_idx, position=-1)
    probe, accuracy = train_mlp_probe(X, labels, epochs=50)
    mlp_accuracies.append(accuracy)
    print(f"Layer {layer_idx}: {accuracy:.4f}")

print(f"\nBest MLP layer: {np.argmax(mlp_accuracies)} (accuracy: {max(mlp_accuracies):.4f})")

In [None]:
# Compare linear vs MLP probes
plt.figure(figsize=(10, 5))
plt.plot(range(len(layer_accuracies)), layer_accuracies, marker='o', label='Linear Probe')
plt.plot(range(len(mlp_accuracies)), mlp_accuracies, marker='s', label='MLP Probe')
plt.axhline(y=0.5, color='r', linestyle='--', label='Random baseline')
plt.xlabel('Layer')
plt.ylabel('Probe Accuracy')
plt.title('Linear vs MLP Probe Comparison')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

# Calculate gap
gaps = np.array(mlp_accuracies) - np.array(layer_accuracies)
print(f"\nAverage MLP advantage: {gaps.mean():.4f}")
print(f"Max gap at layer: {np.argmax(gaps)} (gap: {gaps.max():.4f})")
print("\nInterpretation: Large gap suggests information is present but not linearly accessible.")

## Part 5: Control Tasks

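Validate probes with two controls: a selectivity test (the sentiment probe should be near chance on an unrelated property, here tense) and a random-label test (a probe trained on shuffled labels should not generalize to held-out data).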
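
A single shuffle evaluated on a tiny test set can land above or below chance by luck. A more stable version of the random-label control repeats the shuffle many times and compares the real score with the resulting distribution. scikit-learn's `permutation_test_score` does this; the sketch below runs it on synthetic features, where the data and the signal strength are assumptions made for illustration:

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, permutation_test_score

# Synthetic stand-in for probe features with a strong, linearly readable signal
# in the first five dimensions.
rng = np.random.default_rng(0)
y = np.array([1] * 30 + [0] * 30)
X = rng.normal(size=(60, 20))
X[:, :5] += 3.0 * y[:, None]

cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)

# score: cross-validated accuracy on the real labels
# perm_scores: accuracy after each of 50 label shuffles
# p_value: share of shuffles scoring at least as well as the real labels
score, perm_scores, p_value = permutation_test_score(
    LogisticRegression(max_iter=1000), X, y,
    cv=cv, n_permutations=50, random_state=0, scoring="accuracy",
)

print(f"Real-label accuracy: {score:.3f}")
print(f"Shuffled-label accuracy: {perm_scores.mean():.3f} ± {perm_scores.std():.3f}")
print(f"p-value: {p_value:.3f}")
```

The difference between the real-label score and the mean shuffled-label score is one way to summarize how much of the probe's accuracy depends on the actual labels.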

In [None]:
# Selectivity control: Test if sentiment probe responds to tense (should not)
past_tense = [
    "The cat walked to the store yesterday",
    "She finished her homework last night",
    "They played soccer in the park",
    "He cooked dinner for everyone",
    "We watched a movie together"
]

present_tense = [
    "The cat walks to the store today",
    "She finishes her homework tonight",
    "They play soccer in the park",
    "He cooks dinner for everyone",
    "We watch a movie together"
]

control_texts = past_tense + present_tense
control_labels = [0] * len(past_tense) + [1] * len(present_tense)  # 0=past, 1=present

# Extract hidden states for control task
best_layer = np.argmax(layer_accuracies)
print(f"Testing selectivity on best layer: {best_layer}\n")

X_control = extract_hidden_states(control_texts, layer_idx=best_layer, position=-1)

# Train probe on original task (sentiment)
X_sentiment = extract_hidden_states(texts, layer_idx=best_layer, position=-1)
sentiment_probe, sentiment_acc, _ = train_linear_probe(X_sentiment, labels)

# Test on control task (tense)
y_control_pred = sentiment_probe.predict(X_control)
control_acc = accuracy_score(control_labels, y_control_pred)

print(f"Sentiment probe on sentiment task: {sentiment_acc:.4f}")
print(f"Sentiment probe on tense task (control): {control_acc:.4f}")
print(f"Random baseline: 0.5000")

if abs(control_acc - 0.5) < 0.1:
    print("\n✓ PASS: Probe is selective (random on control task)")
else:
    print("\n✗ FAIL: Probe may be picking up confounds")

In [None]:
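# Random label test: can the probe fit labels that carry no signal?
print("Testing for overfitting with random labels...\n")

# Shuffle labels with a fixed seed so the result is reproducible
rng = np.random.default_rng(0)
random_labels = rng.permutation(labels)

# Train probe on random labels; keep the split to inspect train accuracy too
random_probe, random_acc, (X_tr, _, y_tr, _) = train_linear_probe(X_sentiment, random_labels)
random_train_acc = random_probe.score(X_tr, y_tr)

print(f"Probe on real labels (test): {sentiment_acc:.4f}")
print(f"Probe on random labels (train): {random_train_acc:.4f}")
print(f"Probe on random labels (test): {random_acc:.4f}")
print(f"Random baseline: 0.5000")

# High train accuracy with chance-level test accuracy means the probe can
# memorize the training set, so train accuracy alone is never evidence.
# With only 6 test examples, test accuracy moves in steps of 1/6, so this
# threshold is noisy; rerun with several seeds before drawing conclusions.
if random_acc < 0.6:
    print("\n✓ PASS: Held-out accuracy on random labels is near chance")
else:
    print("\n✗ WARNING: Above-chance held-out accuracy on random labels (small test set or leakage)")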

## Part 6: Testing for Underfitting

Check whether the linear probe is too simple by comparing it with the MLP probe at the same layer.

In [None]:
def diagnose_fitting(linear_acc, mlp_acc, threshold=0.15):
    """
    Diagnose overfitting, underfitting, or good fit.
    """
    gap = mlp_acc - linear_acc
    
    print(f"Linear probe accuracy: {linear_acc:.4f}")
    print(f"MLP probe accuracy: {mlp_acc:.4f}")
    print(f"Gap: {gap:.4f}\n")
    
    if linear_acc < 0.6 and mlp_acc < 0.6:
        print("Diagnosis: INFORMATION ABSENT or WRONG LAYER")
        print("  - Concept may not be encoded here")
        print("  - Try different layers")
        print("  - Verify with interventions")
    
    elif linear_acc > 0.8 and gap < threshold:
        print("Diagnosis: GOOD FIT (linear)")
        print("  - Information is linearly accessible")
        print("  - Linear probe is sufficient")
        print("  - Representation is relatively simple")
    
    elif linear_acc < 0.7 and mlp_acc > 0.8:
        print("Diagnosis: UNDERFITTING (linear probe too simple)")
        print("  - Information is present but nonlinear")
        print("  - Linear probe cannot extract it")
        print("  - MLP probe succeeds")
        print("  - Representation is complex/distributed")
    
    elif linear_acc > mlp_acc:
        print("Diagnosis: POSSIBLE OVERFITTING (MLP)")
        print("  - MLP may be overfitting to noise")
        print("  - Trust linear probe more")
        print("  - Try regularization on MLP")
    
    else:
        print("Diagnosis: MODERATE NONLINEARITY")
        print("  - Some nonlinear structure present")
        print("  - Both probes partially successful")


# Diagnose each layer
print("Fitting diagnosis for each layer:\n")
print("=" * 60)

# Use a set so a layer is not reported twice when best_layer is 0 or the last layer
for layer_idx in sorted({0, int(best_layer), model.config.n_layer}):
    print(f"\nLayer {layer_idx}:")
    diagnose_fitting(layer_accuracies[layer_idx], mlp_accuracies[layer_idx])

## Part 7: Learned Masks with Regularization

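Train a per-dimension mask on the hidden state, with an L1 penalty that pushes mask values toward zero, to identify which dimensions are sufficient for the classification task.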
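
Written out, the objective optimized by `train_masked_model` in the cell that follows is shown here, with mask logits $\theta$, classifier weights $W$ and bias $b$, sigmoid $\sigma$, and regularization strength $\lambda$ (`l1_lambda` in the code):

$$
\mathcal{L}(\theta, W, b) = \mathrm{CE}\big(W\,(\sigma(\theta) \odot x) + b,\; y\big) + \lambda \sum_{i} \sigma(\theta_i)
$$

Since $\sigma(\theta_i)$ is always positive, the L1 norm of the mask is just its sum. Note that the linear classifier can rescale its weights to compensate for small mask values, so a low mask value by itself may not mean a dimension is unimportant. This is one reason to check mask results against interventions.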

In [None]:
class MaskedModel(nn.Module):
    """Model with learnable component masks."""
    
    def __init__(self, num_components, num_classes=2):
        super().__init__()
        # Mask parameters (one per component)
        self.mask_logits = nn.Parameter(torch.zeros(num_components))
        
        # Classifier
        self.classifier = nn.Linear(num_components, num_classes)
    
    def forward(self, x):
        # Apply sigmoid to get masks in [0, 1]
        masks = torch.sigmoid(self.mask_logits)
        
        # Mask the input
        x_masked = x * masks
        
        # Classify
        return self.classifier(x_masked), masks


def train_masked_model(X, y, l1_lambda=0.01, epochs=100, lr=0.01):
    """
    Train a model with learned masks and L1 regularization.
    """
    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    
    # Convert to tensors
    X_train_t = torch.FloatTensor(X_train).to(device)
    X_test_t = torch.FloatTensor(X_test).to(device)
    y_train_t = torch.LongTensor(y_train).to(device)
    y_test_t = torch.LongTensor(y_test).to(device)
    
    # Create model
    masked_model = MaskedModel(X.shape[1]).to(device)
    optimizer = optim.Adam(masked_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    
    # Training loop
    train_losses = []
    test_accs = []
    sparsities = []
    
    for epoch in range(epochs):
        masked_model.train()
        optimizer.zero_grad()
        
        outputs, masks = masked_model(X_train_t)
        
        # Task loss + L1 regularization on masks
        task_loss = criterion(outputs, y_train_t)
        l1_reg = masks.abs().sum()
        loss = task_loss + l1_lambda * l1_reg
        
        loss.backward()
        optimizer.step()
        
        train_losses.append(loss.item())
        
        # Evaluate
        if epoch % 10 == 0:
            masked_model.eval()
            with torch.no_grad():
                outputs, masks = masked_model(X_test_t)
                _, predicted = torch.max(outputs, 1)
                accuracy = (predicted == y_test_t).float().mean().item()
                sparsity = (masks > 0.5).float().mean().item()  # fraction of dimensions with mask > 0.5
                
                test_accs.append(accuracy)
                sparsities.append(sparsity)
    
    # Final evaluation
    masked_model.eval()
    with torch.no_grad():
        outputs, final_masks = masked_model(X_test_t)
        _, predicted = torch.max(outputs, 1)
        final_accuracy = (predicted == y_test_t).float().mean().item()
        final_sparsity = (final_masks > 0.5).float().sum().item()  # count of active dimensions, not a fraction
    
    return masked_model, final_accuracy, final_sparsity, final_masks


# Train with different regularization strengths
print("Training masked models with different L1 regularization...\n")

lambdas = [0.0, 0.001, 0.01, 0.05, 0.1, 0.5]
results = []

for l1_lambda in lambdas:
    masked_model, acc, sparsity, masks = train_masked_model(
        X_sentiment, labels, l1_lambda=l1_lambda, epochs=100
    )
    results.append((l1_lambda, acc, sparsity))
    print(f"λ={l1_lambda:.3f}: Accuracy={acc:.4f}, Active components={int(sparsity)}/{X_sentiment.shape[1]}")

In [None]:
# Visualize sparsity-performance tradeoff
lambdas_list, accs, sparsities = zip(*results)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Accuracy vs lambda
ax1.plot(lambdas_list, accs, marker='o')
ax1.set_xlabel('L1 Regularization (λ)')
ax1.set_ylabel('Test Accuracy')
ax1.set_title('Accuracy vs Regularization Strength')
ax1.grid(True, alpha=0.3)

# Sparsity vs accuracy
ax2.plot(sparsities, accs, marker='o')
ax2.set_xlabel('Number of Active Components')
ax2.set_ylabel('Test Accuracy')
ax2.set_title('Sparsity-Performance Tradeoff')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Find best tradeoff (highest accuracy with fewest components)
efficiency = np.array(accs) / (np.array(sparsities) + 1)
best_idx = np.argmax(efficiency)
print(f"\nBest tradeoff: λ={lambdas_list[best_idx]:.3f}")
print(f"  Accuracy: {accs[best_idx]:.4f}")
print(f"  Active components: {int(sparsities[best_idx])}/{X_sentiment.shape[1]}")

In [None]:
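# Visualize learned masks
# Note: `masks` holds the result of the last run in the training loop above
# (the largest λ), which is not necessarily the best-tradeoff λ.
plt.figure(figsize=(12, 4))
mask_values = masks.detach().cpu().numpy()
plt.bar(range(len(mask_values)), mask_values)
plt.axhline(y=0.5, color='r', linestyle='--', label='Threshold (0.5)')
plt.xlabel('Component (Dimension)')
plt.ylabel('Mask Weight')
plt.title(f'Learned Mask Weights (λ={lambdas[-1]})')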
plt.legend()
plt.show()

# Top components
top_k = 10
top_indices = np.argsort(mask_values)[-top_k:][::-1]
print(f"\nTop {top_k} components by mask weight:")
for i, idx in enumerate(top_indices):
    print(f"  {i+1}. Dimension {idx}: {mask_values[idx]:.4f}")

## Part 8: Comparing Probes with Interventions

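Probe and mask results are correlational. To test whether the model actually uses the identified dimensions, validate the findings with causal interventions such as ablation.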
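
Once an intervention yields a baseline and an ablated next-token distribution, the effect needs a summary number. One common choice is the KL divergence between the two, compared against the same measurement after ablating an equal number of randomly chosen dimensions. The sketch below uses made-up distributions over a four-token vocabulary purely to show the computation; the helper name `kl_divergence` and all numbers are illustrative assumptions, not results:

```python
import numpy as np

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) in nats for two discrete distributions over the same support."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # eps guards against log(0) when a probability is exactly zero
    return float(np.sum(p * (np.log(p + eps) - np.log(q + eps))))

# Made-up next-token distributions (illustrative only).
baseline = np.array([0.70, 0.20, 0.05, 0.05])
ablate_top = np.array([0.30, 0.40, 0.20, 0.10])     # after ablating mask-selected dims
ablate_random = np.array([0.68, 0.21, 0.06, 0.05])  # after ablating random dims

effect_top = kl_divergence(baseline, ablate_top)
effect_random = kl_divergence(baseline, ablate_random)
print(f"KL after ablating selected dims: {effect_top:.4f}")
print(f"KL after ablating random dims:   {effect_random:.4f}")
```

If ablating the selected dimensions changes the output no more than ablating random ones, the probe or mask has likely found information that is decodable but not used.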

In [None]:
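def ablation_test(model, text, layer_idx, top_components, position=-1):
    """
    Zero-ablate selected hidden dimensions at one layer and return the
    next-token distributions without and with the ablation.

    layer_idx follows outputs.hidden_states indexing, so block layer_idx - 1
    produces it (valid for 1 <= layer_idx <= n_layer). Caveat: the final
    hidden_states entry is taken after the last layer norm, while this hook
    acts on the raw block output.
    """
    if not 1 <= layer_idx <= model.config.n_layer:
        raise ValueError("layer_idx must be between 1 and n_layer")
    inputs = tokenizer(text, return_tensors="pt").to(device)
    dims = torch.as_tensor([int(i) for i in top_components], device=device)

    def hook(module, args, output):
        # GPT-2 blocks usually return a tuple whose first element is the hidden state
        hidden = output[0] if isinstance(output, tuple) else output
        hidden = hidden.clone()
        hidden[0, position, dims] = 0.0
        return (hidden,) + tuple(output[1:]) if isinstance(output, tuple) else hidden

    # Baseline
    with torch.no_grad():
        probs_baseline = torch.softmax(model(**inputs).logits[0, -1, :], dim=-1)

    # Ablated run; always remove the hook afterwards
    handle = model.transformer.h[layer_idx - 1].register_forward_hook(hook)
    try:
        with torch.no_grad():
            probs_ablated = torch.softmax(model(**inputs).logits[0, -1, :], dim=-1)
    finally:
        handle.remove()

    return probs_baseline, probs_ablated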


# Compare probe and mask findings
print("Comparing probe and mask findings:\n")
print("=" * 60)

print(f"\nProbe analysis (Layer {best_layer}):")
print(f"  Linear probe accuracy: {layer_accuracies[best_layer]:.4f}")
print(f"  MLP probe accuracy: {mlp_accuracies[best_layer]:.4f}")
print(f"  → Sentiment information is {'linearly' if layer_accuracies[best_layer] > 0.8 else 'not clearly linearly'} accessible")

print(f"\nMask analysis:")
print(f"  Best λ: {lambdas_list[best_idx]:.3f}")
print(f"  Active components: {int(sparsities[best_idx])}/{X_sentiment.shape[1]}")
print(f"  → {100 * sparsities[best_idx] / X_sentiment.shape[1]:.1f}% of components kept by the mask")

print("\nInterpretation:")
if layer_accuracies[best_layer] > 0.8 and sparsities[best_idx] < X_sentiment.shape[1] * 0.2:
    print("  ✓ Information is linearly accessible")
    print("  ✓ Representation is sparse (few components matter)")
    print("  → Strong candidate for causal role")
    print("  → Validate with intervention experiments")
elif layer_accuracies[best_layer] > 0.8:
    print("  ✓ Information is linearly accessible")
    print("  ⚠ Representation is distributed (many components)")
    print("  → May be present but not causally used")
    print("  → Intervention critical to validate")
else:
    print("  ⚠ Information not linearly accessible")
    print("  → May be encoded nonlinearly or not present")
    print("  → Check other layers or use MLP probes")

## Part 9: Analyzing Agreement and Disagreement

When do probes and interventions give different answers?

In [None]:
# Create summary table
print("Summary: Probes, Masks, and Interventions\n")
print("=" * 70)

summary_data = {
    'Method': ['Linear Probe', 'MLP Probe', 'Learned Masks', 'Intervention (Week 4)'],
    'Question Answered': [
        'Is info linearly accessible?',
        'Is info computationally accessible?',
        'Which components are sufficient?',
        'Is info causally used?'
    ],
    'Speed': ['Fast', 'Fast', 'Medium', 'Slow'],
    'Causal Claim': ['No', 'No', 'Weak', 'Yes'],
    'Best For': [
        'Quick exploration',
        'Nonlinear patterns',
        'Component selection',
        'Validation'
    ]
}

import pandas as pd
df = pd.DataFrame(summary_data)
print(df.to_string(index=False))

print("\n" + "=" * 70)
print("\nBest Practice Workflow:")
print("1. Start with linear probes (fast exploration)")
print("2. Try MLP probes if linear fails (check nonlinearity)")
print("3. Use learned masks (identify important components)")
print("4. Validate with interventions (test causal role)")
print("5. Investigate disagreements (reveal structure)")

## Part 10: Your Project Template

Apply these methods to your concept.

In [None]:
print("Week 6 Project Template: Probing and Masking Your Concept\n")

print("1. Create your dataset")
MY_CONCEPT = "[Your concept here]"
my_positive_examples = [
    # Examples with your concept
]
my_negative_examples = [
    # Examples without your concept
]

print("\n2. Train linear probes at all layers")
# Use extract_hidden_states and train_linear_probe functions

print("\n3. Train MLP probes for comparison")
# Use train_mlp_probe function

print("\n4. Implement control tasks")
# Create selectivity and random label tests

print("\n5. Train learned masks with regularization")
# Try different λ values, find best tradeoff

print("\n6. Compare with Week 4 intervention results")
# Do probe/mask findings match intervention findings?

print("\n7. Analyze disagreements")
# What do they reveal about your concept's representation?

print("\n8. Document findings")
# Create visualizations and write report

## Summary

In this exercise, you've learned:
- How to train and interpret linear probes
- When to use MLP probes (nonlinear patterns)
- How to validate probes with control tasks
- How to diagnose overfitting and underfitting
- How to train learned masks with sparse regularization
- How to compare auxiliary models with causal interventions
- What agreements and disagreements reveal

Key takeaways:
- **Probes show what could be extracted, not what is used**
- **Always validate with interventions**
- **Disagreements are informative**
- **Use multiple methods for triangulation**

For your project:
1. Systematically probe all layers
2. Use control tasks to validate findings
3. Train masks to identify important components
4. Compare with Week 4/5 intervention results
5. Interpret what the comparison reveals about representation structure