# Experiment 035B: Context-Controlled Excitation

**AKIRA Project - Oscar Goldman - Shogu Research Group @ Datamutant.ai**

---

## What This Experiment Tests

In 035A, we showed that different discrimination types produce different activation patterns.

Now we test something stronger: **the same tokens produce different activations depending on context**.

### The Setup

We use a fixed query phrase, "The answer is", and vary the context that precedes it. For example:

```
MATH CONTEXT:    "2 + 2 = ? The answer is"
TRIVIA CONTEXT:  "What is the capital of France? The answer is"
SENTIMENT:       "How do you feel today? The answer is"
```

The tokens "The answer is" are identical in every probe, but if context controls excitation, their activations should differ according to the kind of answer the context calls for.

### Hypothesis

If AQ theory is correct:
- Same surface tokens + different context = different activation patterns
- Context determines which AQ excite, not the tokens themselves
- Activations for "The answer is" should cluster by CONTEXT TYPE, not by token identity

### Why This Matters

This directly tests the claim that context SELECTS which AQ excite from the weight field.

If activations were purely token-based, "The answer is" would always produce similar patterns.
If activations are context-controlled (AQ theory), the same tokens produce different patterns based on what came before.

---

## 1. Setup and Installation

In [None]:
# Install dependencies (uncomment for Colab)
# !pip install transformers torch numpy scikit-learn matplotlib seaborn -q

In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
from sklearn.decomposition import PCA
from sklearn.metrics import silhouette_score, pairwise_distances
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass, field
import warnings

warnings.filterwarnings('ignore')

# Check device
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {DEVICE}")
print(f"PyTorch version: {torch.__version__}")

## 2. Configuration

In [None]:
@dataclass
class ExperimentConfig:
    """Configuration for Context-Controlled Excitation experiment."""
    model_name: str = "gpt2"
    layers_to_probe: List[int] = field(default_factory=list)
    random_seed: int = 42
    
    def __post_init__(self) -> None:
        if not self.layers_to_probe:
            if "gpt2" in self.model_name.lower():
                self.layers_to_probe = [0, 3, 6, 9, 11]
            elif "pythia" in self.model_name.lower():
                self.layers_to_probe = [0, 2, 4, 5]
            else:
                self.layers_to_probe = [0, 3, 6, 9, 11]
        
        np.random.seed(self.random_seed)
        torch.manual_seed(self.random_seed)


config = ExperimentConfig()
print(f"Model: {config.model_name}")
print(f"Layers to probe: {config.layers_to_probe}")

## 3. Context-Controlled Probes

Each probe has:
- A **context** that sets up a specific discrimination type
- A **fixed query** ("The answer is") that we measure activations for

The key point: the query tokens are identical across all probes; only the context differs.

In [None]:
# The fixed query phrase we'll measure activations for
QUERY_PHRASE = "The answer is"

# Context + Query probes organized by context type
CONTEXT_PROBES: Dict[str, List[str]] = {
    'math': [
        "What is 2 + 2? The answer is",
        "Calculate 5 times 3. The answer is",
        "What is 10 divided by 2? The answer is",
        "Solve: 7 - 4 = ? The answer is",
        "What is 8 + 1? The answer is",
        "Compute 6 times 2. The answer is",
    ],
    'geography': [
        "What is the capital of France? The answer is",
        "Name the largest country by area. The answer is",
        "What continent is Brazil in? The answer is",
        "What is the capital of Japan? The answer is",
        "Name the longest river in Africa. The answer is",
        "What ocean is between Europe and America? The answer is",
    ],
    'science': [
        "What is the chemical symbol for water? The answer is",
        "At what temperature does water boil? The answer is",
        "What planet is closest to the sun? The answer is",
        "What gas do plants produce? The answer is",
        "What is the speed of light? The answer is",
        "How many bones in the human body? The answer is",
    ],
    'sentiment': [
        "How are you feeling today? The answer is",
        "What is your mood right now? The answer is",
        "Are you happy or sad? The answer is",
        "How would you describe your emotions? The answer is",
        "What is your emotional state? The answer is",
        "Do you feel good or bad? The answer is",
    ],
    'yesno': [
        "Is the sky blue? The answer is",
        "Can fish swim? The answer is",
        "Is fire cold? The answer is",
        "Do birds fly? The answer is",
        "Is water wet? The answer is",
        "Can humans breathe underwater? The answer is",
    ],
}

# Colors for visualization
CONTEXT_COLORS: Dict[str, str] = {
    'math': '#3498db',        # blue
    'geography': '#9b59b6',   # purple
    'science': '#f39c12',     # orange
    'sentiment': '#2ecc71',   # green
    'yesno': '#e74c3c',       # red
}

print(f"Query phrase: '{QUERY_PHRASE}'")
print(f"Number of context types: {len(CONTEXT_PROBES)}")
print(f"Total probes: {sum(len(v) for v in CONTEXT_PROBES.values())}")
for ctx, prompts in CONTEXT_PROBES.items():
    print(f"  {ctx}: {len(prompts)} probes")
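
Because the design depends on every probe ending in the same query phrase, it can be worth checking that property before running the model. The sketch below is one way to do so. `split_probe` is a hypothetical helper (not part of the notebook), and the two-entry `probes` dict is only a stand-in for `CONTEXT_PROBES`.

```python
from typing import Dict, List, Tuple

QUERY = "The answer is"  # same fixed phrase as QUERY_PHRASE in the notebook


def split_probe(probe: str, query: str) -> Tuple[str, str]:
    """Split a probe into (context, query); raise if it does not end with the query."""
    if not probe.endswith(query):
        raise ValueError(f"Probe does not end with the query phrase: {probe!r}")
    context = probe[: -len(query)].rstrip()
    return context, query


# Stand-in for CONTEXT_PROBES (two probes copied from the notebook)
probes: Dict[str, List[str]] = {
    "math": ["What is 2 + 2? The answer is"],
    "yesno": ["Is the sky blue? The answer is"],
}

# Keep only the context part of each probe, grouped by context type
contexts = {
    ctx: [split_probe(p, QUERY)[0] for p in prompts]
    for ctx, prompts in probes.items()
}
print(contexts)
```

A probe that fails this check would make the "identical tokens" comparison invalid, so raising early is preferable to silently measuring the wrong positions.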

## 4. Activation Capture Class

In [None]:
class ActivationCapture:
    """Captures activations from specified layers using forward hooks."""
    
    def __init__(self, model: nn.Module, layer_indices: List[int]) -> None:
        assert len(layer_indices) > 0, "Must specify at least one layer to probe"
        
        self.activations: Dict[int, torch.Tensor] = {}
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []
        self.layer_indices = layer_indices
        
        # Get the transformer blocks
        if hasattr(model, 'transformer'):
            layers = model.transformer.h
        elif hasattr(model, 'gpt_neox'):
            layers = model.gpt_neox.layers
        else:
            raise ValueError(f"Unknown model architecture: {type(model)}")
        
        assert len(layers) > max(layer_indices), \
            f"Model has {len(layers)} layers but requested layer {max(layer_indices)}"
        
        for idx in layer_indices:
            layer = layers[idx]
            hook = layer.register_forward_hook(self._make_hook(idx))
            self.hooks.append(hook)
        
        print(f"Registered hooks on layers: {layer_indices}")
    
    def _make_hook(self, layer_idx: int):
        def hook(module: nn.Module, input: Tuple, output: Tuple) -> None:
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output
            self.activations[layer_idx] = hidden_states.detach()
        return hook
    
    def clear(self) -> None:
        self.activations = {}
    
    def remove_hooks(self) -> None:
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        print("Removed all hooks")
    
    def get_activation_at_positions(self, layer_idx: int, positions: List[int]) -> np.ndarray:
        """Get activations for specific token positions.
        
        Args:
            layer_idx: Which layer
            positions: List of token positions to extract
            
        Returns:
            Array of shape [len(positions), hidden_dim]
        """
        assert layer_idx in self.activations, f"Layer {layer_idx} not captured"
        act = self.activations[layer_idx]
        
        # Extract specified positions
        extracted = act[0, positions, :].cpu().numpy()
        return extracted
    
    def get_mean_activation(self, layer_idx: int, positions: List[int]) -> np.ndarray:
        """Get mean activation across specified positions."""
        extracted = self.get_activation_at_positions(layer_idx, positions)
        return extracted.mean(axis=0)
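
To make the pooling step in `get_mean_activation` concrete, here is a dependency-free sketch in which plain lists stand in for one sequence of hidden states (the `[seq_len, hidden_dim]` slice left after dropping the batch index). The numbers are made up and `mean_over_positions` is a hypothetical name used only for illustration.

```python
from typing import List

Vector = List[float]


def mean_over_positions(hidden: List[Vector], positions: List[int]) -> Vector:
    """Average the hidden-state rows at the given token positions."""
    if not positions:
        raise ValueError("positions must be non-empty")
    selected = [hidden[p] for p in positions]
    dim = len(selected[0])
    return [sum(row[d] for row in selected) / len(selected) for d in range(dim)]


# Toy "hidden states": 5 tokens x 2 dimensions (made-up numbers)
hidden_states = [
    [9.0, 9.0],  # context token
    [9.0, 9.0],  # context token
    [1.0, 2.0],  # " The"
    [3.0, 4.0],  # " answer"
    [5.0, 6.0],  # " is"
]
query_positions = [2, 3, 4]

pooled = mean_over_positions(hidden_states, query_positions)
print(pooled)  # [3.0, 4.0]
```

Only the three query rows enter the average; the context rows affect the result solely through whatever the model has already mixed into the query positions.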

## 5. Analysis Functions

In [None]:
def compute_category_distances(
    activations: np.ndarray,
    labels: List[str]
) -> Tuple[float, float, float]:
    """Compute within-category and between-category distances."""
    within_distances = []
    between_distances = []
    
    distances = pairwise_distances(activations, metric='euclidean')
    
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            if labels[i] == labels[j]:
                within_distances.append(distances[i, j])
            else:
                between_distances.append(distances[i, j])
    
    within_mean = np.mean(within_distances) if within_distances else 0
    between_mean = np.mean(between_distances) if between_distances else 0
    ratio = between_mean / within_mean if within_mean > 0 else float('inf')
    
    return within_mean, between_mean, ratio


def compute_silhouette(activations: np.ndarray, labels: List[str]) -> float:
    """Compute silhouette score."""
    unique_labels = list(set(labels))
    label_to_int = {label: i for i, label in enumerate(unique_labels)}
    int_labels = [label_to_int[label] for label in labels]
    
    if len(set(int_labels)) < 2:
        return 0.0
    
    return silhouette_score(activations, int_labels)


def run_pca_analysis(
    activations: np.ndarray,
    n_components: int = 2,
    verbose: bool = True
) -> Tuple[np.ndarray, PCA]:
    """Apply PCA."""
    scaler = StandardScaler()
    activations_scaled = scaler.fit_transform(activations)
    
    pca = PCA(n_components=n_components, random_state=42)
    reduced = pca.fit_transform(activations_scaled)
    
    if verbose:
        explained_var = sum(pca.explained_variance_ratio_) * 100
        print(f"PCA: {n_components} components explain {explained_var:.1f}% variance")
    
    return reduced, pca
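
As a worked example of the between/within distance ratio computed by `compute_category_distances`, the sketch below repeats the same arithmetic on four toy 2-D points using only the standard library. The points and the `distance_ratio` name are illustrative assumptions, not notebook data.

```python
import math
from itertools import combinations
from typing import List, Sequence, Tuple


def distance_ratio(
    points: Sequence[Sequence[float]], labels: List[str]
) -> Tuple[float, float, float]:
    """Return (mean within-label distance, mean between-label distance, ratio)."""
    within, between = [], []
    for i, j in combinations(range(len(points)), 2):
        d = math.dist(points[i], points[j])  # Euclidean distance
        (within if labels[i] == labels[j] else between).append(d)
    within_mean = sum(within) / len(within) if within else 0.0
    between_mean = sum(between) / len(between) if between else 0.0
    ratio = between_mean / within_mean if within_mean > 0 else float("inf")
    return within_mean, between_mean, ratio


# Two tight pairs, far apart from each other
points = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]
labels = ["math", "math", "yesno", "yesno"]

within_mean, between_mean, ratio = distance_ratio(points, labels)
print(f"within={within_mean:.3f} between={between_mean:.3f} ratio={ratio:.3f}")
# within=1.000 between=10.025 ratio=10.025
```

A ratio near 1 means probes from different contexts are no farther apart than probes from the same context; larger values indicate separation by context type.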

## 6. Visualization Functions

In [None]:
def plot_activation_scatter(
    activations_2d: np.ndarray,
    labels: List[str],
    layer_idx: int,
    title_suffix: str = ""
) -> None:
    """Create 2D scatter plot."""
    plt.figure(figsize=(10, 8))
    
    for category in CONTEXT_COLORS:
        mask = [l == category for l in labels]
        if any(mask):
            indices = [i for i, m in enumerate(mask) if m]
            plt.scatter(
                activations_2d[indices, 0],
                activations_2d[indices, 1],
                c=CONTEXT_COLORS[category],
                label=category,
                alpha=0.7,
                s=100,
                edgecolors='white',
                linewidth=0.5
            )
    
    plt.xlabel('PC1', fontsize=12)
    plt.ylabel('PC2', fontsize=12)
    plt.title(f'Context-Controlled Activations - Layer {layer_idx}{title_suffix}', fontsize=14)
    plt.legend(loc='best', fontsize=10)
    plt.grid(True, alpha=0.3)
    plt.show()


def plot_similarity_matrix(
    activations: np.ndarray,
    labels: List[str],
    layer_idx: int
) -> None:
    """Create similarity heatmap."""
    norms = np.linalg.norm(activations, axis=1, keepdims=True)
    normalized = activations / (norms + 1e-8)
    similarity = normalized @ normalized.T
    
    sorted_indices = sorted(range(len(labels)), key=lambda i: labels[i])
    similarity_sorted = similarity[sorted_indices][:, sorted_indices]
    labels_sorted = [labels[i] for i in sorted_indices]
    
    plt.figure(figsize=(12, 10))
    
    sns.heatmap(
        similarity_sorted,
        cmap='RdYlBu_r',
        vmin=-1,
        vmax=1,
        square=True,
        cbar_kws={'label': 'Cosine Similarity'}
    )
    
    # Add category boundaries
    unique_labels = []
    boundaries = [0]
    for i, label in enumerate(labels_sorted):
        if label not in unique_labels:
            unique_labels.append(label)
            if i > 0:
                boundaries.append(i)
    boundaries.append(len(labels_sorted))
    
    for b in boundaries[1:-1]:
        plt.axhline(y=b, color='black', linewidth=2)
        plt.axvline(x=b, color='black', linewidth=2)
    
    plt.title(f'Activation Similarity for "{QUERY_PHRASE}" - Layer {layer_idx}', fontsize=14)
    plt.xlabel('Probe Index (sorted by context type)')
    plt.ylabel('Probe Index (sorted by context type)')
    plt.show()


def plot_layer_comparison(metrics_by_layer: Dict[int, Dict[str, float]]) -> None:
    """Plot metrics across layers."""
    layers = sorted(metrics_by_layer.keys())
    
    silhouettes = [metrics_by_layer[l]['silhouette'] for l in layers]
    ratios = [metrics_by_layer[l]['distance_ratio'] for l in layers]
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    axes[0].plot(layers, silhouettes, 'o-', color='#3498db', linewidth=2, markersize=10)
    axes[0].set_xlabel('Layer Index', fontsize=12)
    axes[0].set_ylabel('Silhouette Score', fontsize=12)
    axes[0].set_title('Context Clustering Quality by Layer', fontsize=14)
    axes[0].grid(True, alpha=0.3)
    axes[0].set_ylim(-0.2, 1.0)
    
    axes[1].plot(layers, ratios, 'o-', color='#e74c3c', linewidth=2, markersize=10)
    axes[1].set_xlabel('Layer Index', fontsize=12)
    axes[1].set_ylabel('Between/Within Distance Ratio', fontsize=12)
    axes[1].set_title('Context Separation by Layer', fontsize=14)
    axes[1].grid(True, alpha=0.3)
    axes[1].axhline(y=1.0, color='gray', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    plt.show()

## 7. Load Model

In [None]:
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
model = AutoModelForCausalLM.from_pretrained(config.model_name)
model = model.to(DEVICE)
model.eval()

print(f"Model loaded on {DEVICE}")
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## 8. Find Query Token Positions

We need to identify which token positions correspond to "The answer is" in each probe.

In [None]:
def find_query_positions(full_text: str, query: str, tokenizer) -> List[int]:
    """Find token positions of the query phrase within the full text.
    
    The query closes every probe, so the search runs from the end and returns
    the last match. Positions index into the tokenized full text
    (tokenized without special tokens).
    """
    # Tokenize full text
    full_tokens = tokenizer.encode(full_text, add_special_tokens=False)
    
    # Try the query with a space prefix first (how it appears in context),
    # then fall back to the query without the prefix
    for candidate in (" " + query, query):
        query_tokens = tokenizer.encode(candidate, add_special_tokens=False)
        query_len = len(query_tokens)
        for i in range(len(full_tokens) - query_len, -1, -1):
            if full_tokens[i:i+query_len] == query_tokens:
                return list(range(i, i + query_len))
    
    # If still not found, return last N positions (where N = query token count)
    print(f"Warning: Could not find exact query match in '{full_text[:50]}...'")
    start = max(0, len(full_tokens) - query_len)
    return list(range(start, len(full_tokens)))


# Test on first probe
test_probe = list(CONTEXT_PROBES.values())[0][0]
test_positions = find_query_positions(test_probe, QUERY_PHRASE, tokenizer)
print(f"Test probe: '{test_probe}'")
print(f"Query positions: {test_positions}")

# Show what tokens those positions correspond to
full_tokens = tokenizer.encode(test_probe, add_special_tokens=False)
query_token_ids = [full_tokens[p] for p in test_positions]
query_decoded = tokenizer.decode(query_token_ids)
print(f"Tokens at those positions: '{query_decoded}'")
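
The matching step boils down to locating one list of token IDs inside another. The dependency-free sketch below shows that idea on made-up IDs (they are not real GPT-2 token IDs), searching from the end because the query phrase closes every probe. `find_last_subsequence` is a hypothetical helper named here only for illustration.

```python
from typing import List, Optional


def find_last_subsequence(haystack: List[int], needle: List[int]) -> Optional[List[int]]:
    """Return positions of the last occurrence of needle in haystack, or None."""
    n = len(needle)
    if n == 0:
        return None
    for start in range(len(haystack) - n, -1, -1):
        if haystack[start:start + n] == needle:
            return list(range(start, start + n))
    return None


# Made-up token IDs: the query [7, 8, 9] appears twice
full_ids = [1, 7, 8, 9, 2, 3, 7, 8, 9]
query_ids = [7, 8, 9]

positions = find_last_subsequence(full_ids, query_ids)
print(positions)  # [6, 7, 8]

# Every probe ends with the query, so the match should be a suffix of the sequence
is_suffix = positions is not None and positions[-1] == len(full_ids) - 1
print(is_suffix)  # True
```

Checking that the matched positions form a suffix is a cheap guard against measuring an earlier, accidental occurrence of the phrase inside the context.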

## 9. Run Probes and Capture Activations

For each probe, we capture the activations ONLY at the positions corresponding to "The answer is", then average across those positions to get one vector per probe per layer.

In [None]:
# Set up activation capture
capture = ActivationCapture(model, config.layers_to_probe)

# Storage for results
all_activations: Dict[int, List[np.ndarray]] = {
    layer: [] for layer in config.layers_to_probe
}
all_labels: List[str] = []
all_texts: List[str] = []

print(f"Running context-controlled probes...")
print(f"Extracting activations for tokens: '{QUERY_PHRASE}'\n")

for context_type, prompts in CONTEXT_PROBES.items():
    for prompt in prompts:
        # Find positions of query phrase
        query_positions = find_query_positions(prompt, QUERY_PHRASE, tokenizer)
        
        # Tokenize and run (no special tokens, so positions match find_query_positions)
        inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(DEVICE)
        
        capture.clear()
        with torch.no_grad():
            outputs = model(**inputs)
        
        # Extract activations at query positions (mean across positions)
        for layer_idx in config.layers_to_probe:
            act = capture.get_mean_activation(layer_idx, query_positions)
            all_activations[layer_idx].append(act)
        
        all_labels.append(context_type)
        all_texts.append(prompt)

# Convert to numpy arrays
for layer_idx in config.layers_to_probe:
    all_activations[layer_idx] = np.array(all_activations[layer_idx])

print(f"\nCollected activations for {len(all_labels)} probes")
print(f"Activation shape per layer: {all_activations[config.layers_to_probe[0]].shape}")

## 10. Layer-by-Layer Analysis

In [None]:
print("=" * 70)
print("LAYER-BY-LAYER ANALYSIS")
print("=" * 70)
print(f"\nMeasuring activations for IDENTICAL tokens '{QUERY_PHRASE}'")
print("If context controls AQ, same tokens should cluster by context type\n")

metrics_by_layer: Dict[int, Dict[str, float]] = {}
pca_results_by_layer: Dict[int, np.ndarray] = {}

for layer_idx in config.layers_to_probe:
    print(f"\n--- Layer {layer_idx} ---")
    
    activations = all_activations[layer_idx]
    
    within_dist, between_dist, ratio = compute_category_distances(activations, all_labels)
    print(f"Within-context distance:  {within_dist:.4f}")
    print(f"Between-context distance: {between_dist:.4f}")
    print(f"Distance ratio: {ratio:.4f}")
    
    silhouette = compute_silhouette(activations, all_labels)
    print(f"Silhouette score: {silhouette:.4f}")
    
    metrics_by_layer[layer_idx] = {
        'within_distance': within_dist,
        'between_distance': between_dist,
        'distance_ratio': ratio,
        'silhouette': silhouette
    }
    
    pca_2d, _ = run_pca_analysis(activations, n_components=2)
    pca_results_by_layer[layer_idx] = pca_2d
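
For readers who want to see what the silhouette score measures, the sketch below computes it by hand with Euclidean distance on toy points, using only the standard library. For each point, `a` is the mean distance to the other points with the same label, `b` is the lowest mean distance to any other label's points, and the point's score is `(b - a) / max(a, b)`. The data and the `silhouette` function name are illustrative assumptions; the notebook itself uses scikit-learn's `silhouette_score`.

```python
import math
from typing import Dict, List, Sequence


def silhouette(points: Sequence[Sequence[float]], labels: List[str]) -> float:
    """Mean silhouette coefficient with Euclidean distance."""
    clusters: Dict[str, List[int]] = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(label, []).append(idx)
    if len(clusters) < 2:
        raise ValueError("need at least two clusters")

    scores = []
    for i, label in enumerate(labels):
        own = [j for j in clusters[label] if j != i]
        if not own:
            scores.append(0.0)  # convention for single-member clusters
            continue
        # a: mean distance to the other members of the same cluster
        a = sum(math.dist(points[i], points[j]) for j in own) / len(own)
        # b: mean distance to the nearest other cluster
        b = min(
            sum(math.dist(points[i], points[j]) for j in members) / len(members)
            for other, members in clusters.items()
            if other != label
        )
        scores.append((b - a) / max(a, b))
    return sum(scores) / len(scores)


points = [(0.0, 0.0), (0.0, 1.0), (10.0, 0.0), (10.0, 1.0)]

# Labels that match the geometry give a score close to 1
well_separated = silhouette(points, ["math", "math", "yesno", "yesno"])
print(f"{well_separated:.3f}")  # 0.900

# Labels that cut across the geometry give a negative score
mixed = silhouette(points, ["math", "yesno", "math", "yesno"])
print(f"{mixed:.3f}")  # -0.448
```

Scores near 0 mean the context labels say little about where a probe's activation sits; scores near 1 mean each probe is much closer to its own context type than to any other.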

## 11. Visualizations

### 11.1 Scatter Plots

Each point is the 2D PCA projection of the mean activation for "The answer is", colored by the CONTEXT type that preceded it.

In [None]:
for layer_idx in config.layers_to_probe:
    plot_activation_scatter(
        pca_results_by_layer[layer_idx],
        all_labels,
        layer_idx,
        f"\n(Same tokens '{QUERY_PHRASE}' - different contexts)"
    )

### 11.2 Similarity Matrix

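Shows pairwise cosine similarity between the mean activations for "The answer is" across all probes, at the last probed layer. Probes are sorted by context type, and black lines mark the boundaries between types.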

In [None]:
final_layer = config.layers_to_probe[-1]
plot_similarity_matrix(all_activations[final_layer], all_labels, final_layer)

### 11.3 Layer Comparison

In [None]:
plot_layer_comparison(metrics_by_layer)

## 12. Summary and Assessment

In [None]:
print("=" * 70)
print("SUMMARY: CONTEXT-CONTROLLED EXCITATION")
print("=" * 70)

final_layer = config.layers_to_probe[-1]
final_silhouette = metrics_by_layer[final_layer]['silhouette']
final_ratio = metrics_by_layer[final_layer]['distance_ratio']

silhouettes = [metrics_by_layer[l]['silhouette'] for l in config.layers_to_probe]
ratios = [metrics_by_layer[l]['distance_ratio'] for l in config.layers_to_probe]

print(f"\nKEY QUESTION: Do IDENTICAL tokens ('{QUERY_PHRASE}') produce")
print(f"             DIFFERENT activations based on context?")

print("\n" + "-" * 50)
print("RESULTS:")
print("-" * 50)

print(f"\n1. Do same-context probes cluster together?")
if final_silhouette > 0.1:
    print(f"   YES - Silhouette score {final_silhouette:.3f} > 0.1")
    print(f"   The tokens '{QUERY_PHRASE}' cluster by CONTEXT, not token identity")
else:
    print(f"   UNCLEAR - Silhouette score {final_silhouette:.3f} is low")

print(f"\n2. Do different contexts produce different patterns?")
if final_ratio > 1.2:
    print(f"   YES - Distance ratio {final_ratio:.3f} > 1.2")
    print(f"   Context determines activation pattern, not surface tokens")
else:
    print(f"   UNCLEAR - Distance ratio {final_ratio:.3f} is low")

print(f"\n3. Does context control strengthen with depth?")
if silhouettes[-1] > silhouettes[0]:
    print(f"   YES - Silhouette increases: {silhouettes[0]:.3f} -> {silhouettes[-1]:.3f}")
else:
    print(f"   NO - Silhouette does not increase with depth")

print("\n" + "-" * 50)
print("OVERALL ASSESSMENT:")
print("-" * 50)

evidence_score = 0
if final_silhouette > 0.1:
    evidence_score += 1
if final_ratio > 1.2:
    evidence_score += 1
if silhouettes[-1] > silhouettes[0]:
    evidence_score += 1

if evidence_score >= 2:
    print(f"\nSTRONG EVIDENCE that context controls AQ excitation")
    print(f"The same tokens produce different activations based on context.")
    print(f"This supports the AQ theory: context SELECTS which patterns excite.")
elif evidence_score == 1:
    print(f"\nWEAK EVIDENCE for context-controlled excitation")
else:
    print(f"\nNO CLEAR EVIDENCE for context-controlled excitation")

print(f"\n(Evidence score: {evidence_score}/3)")
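
The three checks in the summary cell can also be expressed as a small pure function, which makes the scoring rule easier to test in isolation. This is a sketch under the same thresholds as the cell above (0.1 for silhouette, 1.2 for the distance ratio). `score_evidence` is a hypothetical name, and the metric values passed to it are made up.

```python
from typing import Sequence

SILHOUETTE_THRESHOLD = 0.1  # same cutoff as the summary cell
RATIO_THRESHOLD = 1.2       # same cutoff as the summary cell


def score_evidence(silhouettes: Sequence[float], final_ratio: float) -> int:
    """Count how many of the three summary criteria are met (0 to 3).

    silhouettes: silhouette score per probed layer, ordered shallow to deep.
    final_ratio: between/within distance ratio at the deepest probed layer.
    """
    score = 0
    if silhouettes[-1] > SILHOUETTE_THRESHOLD:
        score += 1
    if final_ratio > RATIO_THRESHOLD:
        score += 1
    if silhouettes[-1] > silhouettes[0]:
        score += 1
    return score


# Made-up metric values, for illustration only
print(score_evidence([0.02, 0.05, 0.15], final_ratio=1.3))  # 3
print(score_evidence([0.20, 0.10, 0.05], final_ratio=1.1))  # 0
```

Note that the third criterion compares only the first and last probed layers, so a non-monotonic profile across the middle layers does not affect the score.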

## 13. Final Metrics Table

In [None]:
print("\nFINAL METRICS BY LAYER:")
print("-" * 60)
print(f"{'Layer':>6} | {'Silhouette':>12} | {'Distance Ratio':>15}")
print("-" * 60)
for layer_idx, metrics in metrics_by_layer.items():
    print(f"{layer_idx:>6} | {metrics['silhouette']:>12.3f} | {metrics['distance_ratio']:>15.3f}")
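
The thresholds used in the summary (0.1 and 1.2) are fixed cutoffs. One optional sanity check, which is not part of the original analysis, is a label-shuffle baseline: recompute the distance ratio many times with the context labels randomly permuted and see how often chance labelings match the observed value. The sketch below illustrates the idea on toy points with the standard library only. The function names, the number of shuffles, and the data are all assumptions, and it presumes each label occurs at least twice so that within-label pairs exist.

```python
import math
import random
from itertools import combinations
from typing import List, Sequence


def between_within_ratio(points: Sequence[Sequence[float]], labels: Sequence[str]) -> float:
    """Mean between-label distance divided by mean within-label distance."""
    within, between = [], []
    for i, j in combinations(range(len(points)), 2):
        d = math.dist(points[i], points[j])
        (within if labels[i] == labels[j] else between).append(d)
    return (sum(between) / len(between)) / (sum(within) / len(within))


def permutation_p_value(
    points: Sequence[Sequence[float]],
    labels: List[str],
    n_shuffles: int = 200,
    seed: int = 0,
) -> float:
    """Fraction of label shuffles whose ratio is >= the observed ratio."""
    rng = random.Random(seed)
    observed = between_within_ratio(points, labels)
    shuffled = list(labels)
    hits = 0
    for _ in range(n_shuffles):
        rng.shuffle(shuffled)
        if between_within_ratio(points, shuffled) >= observed:
            hits += 1
    # Add-one correction keeps the estimate away from exactly zero
    return (hits + 1) / (n_shuffles + 1)


# Toy data: two well-separated groups of six points each
cluster_a = [(0.0, 0.0), (0.0, 1.0), (1.0, 0.0), (1.0, 1.0), (0.5, 0.5), (0.2, 0.8)]
cluster_b = [(x + 10.0, y) for x, y in cluster_a]
points = cluster_a + cluster_b
labels = ["math"] * 6 + ["yesno"] * 6

p_value = permutation_p_value(points, labels)
print(f"p ~ {p_value:.3f}")
```

With only six probes per context type, a baseline of this kind gives a rough sense of whether a given silhouette or ratio could plausibly arise from arbitrary groupings of the same activations.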

## 14. Cleanup

In [None]:
capture.remove_hooks()

results = {
    'config': config,
    'activations': all_activations,
    'labels': all_labels,
    'texts': all_texts,
    'metrics_by_layer': metrics_by_layer,
    'pca_results': pca_results_by_layer,
    'evidence_score': evidence_score
}

print("\n" + "=" * 70)
print("EXPERIMENT 035B COMPLETE")
print("=" * 70)

---

## Interpretation Guide

### What This Experiment Shows

We took the SAME tokens ("The answer is") and measured their activations in different contexts:
- After a math question
- After a geography question
- After a science question
- After a sentiment question
- After a yes/no question

**If tokens determine activations**: All points would cluster together (same tokens = same pattern)

**If context determines activations (AQ theory)**: Points cluster by context type (same tokens = different patterns based on what precedes them)

### Why This Matters

This directly tests the core AQ claim: **context selects which AQ excite from the weight field**.

On the AQ view, the weights contain many possible excitation patterns, and the context (what comes before) determines which pattern manifests for any given token. This would explain why the same word can mean or do different things in different contexts.

### Relation to 035A

- **035A** showed: Different discrimination types produce different patterns
- **035B** tests: Whether context controls which pattern activates, even for identical tokens

If the clustering metrics above come out positive, the two experiments together support the AQ theory that weights store crystallized patterns and context selects which ones excite.
