# Model Test and Activation Capture with NNsight

This notebook downloads Mistral-7B-Instruct-v0.1, saves it locally, checks that generation works, and uses NNsight to capture residual-stream activations (the hidden states output by each decoder layer).

In [1]:
# 1. Install dependencies
import sys
import subprocess
import os

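def install_package(package):
    # pip distribution names do not always match import names
    # (python-dotenv is imported as `dotenv`), so map them before importing
    import_names = {"python-dotenv": "dotenv"}
    name = package.split('==')[0]
    try:
        __import__(import_names.get(name, name))
        print(f"✓ {package} already installed")
    except ImportError:
        print(f"Installing {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])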

packages = ["torch", "transformers", "numpy", "tqdm", "nnsight", "python-dotenv"]
for package in packages:
    install_package(package)

print("\n✓ All dependencies installed")

✓ torch already installed
✓ transformers already installed
✓ numpy already installed
✓ tqdm already installed
✓ nnsight already installed
Installing python-dotenv...

✓ All dependencies installed


In [2]:
# 2. Import libraries and setup device
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer
from nnsight import LanguageModel
from typing import Dict, List, Optional
import json
from dotenv import load_dotenv

# Load environment variables
load_dotenv()

# Check device capabilities
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("✓ Using Apple Silicon MPS")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("✓ Using CUDA")
else:
    device = torch.device("cpu")
    print("⚠️ Using CPU (will be slow)")

print(f"Device: {device}")

✓ Using Apple Silicon MPS
Device: mps


In [3]:
# 3. Model configuration and directory setup
model_name = "mistralai/Mistral-7B-Instruct-v0.1"
model_dir = "./model"

# Create model directory
os.makedirs(model_dir, exist_ok=True)
print(f"Model directory: {model_dir}")

Model directory: ./model


In [4]:


# 4. Download and save tokenizer
print("Downloading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Mistral's tokenizer defines no pad token; id 0 is <unk> in its vocabulary
tokenizer.pad_token_id = 0

# Save tokenizer locally
tokenizer_dir = os.path.join(model_dir, "tokenizer")
tokenizer.save_pretrained(tokenizer_dir)
print(f"✓ Tokenizer saved to {tokenizer_dir}")

Downloading tokenizer...
✓ Tokenizer saved to ./model/tokenizer


In [None]:
# 5. Download and save model
print("Downloading model (this may take several minutes)...")
model = AutoModelForCausalLM.from_pretrained(
    model_name, 
    torch_dtype=torch.float16,
    device_map="auto" if device.type != "mps" else None
)

# Move to device if using MPS
if device.type == "mps":
    model = model.to(device)

# Save model locally
model_dir_path = os.path.join(model_dir, "model")
model.save_pretrained(model_dir_path)
print(f"✓ Model saved to {model_dir_path}")
print(f"Model device: {next(model.parameters()).device}")

Downloading model (this may take several minutes)...



In [6]:
# 6. Test basic inference
print("Testing basic inference...")

test_prompt = "[INST] Hello, how are you today? [/INST]"
inputs = tokenizer(test_prompt, return_tensors="pt").to(model.device)

print(f"Input prompt: {test_prompt}")
print(f"Input shape: {inputs.input_ids.shape}")
print(f"Input tokens: {inputs.input_ids[0].tolist()[:10]}...")  # Show first 10 tokens

# Generate response
try:
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=50,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )

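    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(test_prompt) is fragile: decoding need not reproduce the prompt text exactly.
    prompt_len = inputs.input_ids.shape[1]
    response_only = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()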

    print(f"\nGenerated response: {response_only}")
    print(f"Full generation shape: {outputs.shape}")
    print("✅ Basic inference working!")
    
except Exception as e:
    print(f"❌ Inference error: {e}")
    import traceback
    traceback.print_exc()

Testing basic inference...
Input prompt: [INST] Hello, how are you today? [/INST]
Input shape: torch.Size([1, 15])
Input tokens: [1, 733, 16289, 28793, 22557, 28725, 910, 460, 368, 3154]...

Generated response: I'm an AI language model, so I don't have feelings or physical sensations. But thank you for asking! How can I help you today?
Full generation shape: torch.Size([1, 49])
✅ Basic inference working!
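Every prompt in this notebook uses Mistral-Instruct's `[INST] ... [/INST]` wrapper. For multi-turn prompts, the Mistral-7B-Instruct-v0.1 model card closes each finished assistant turn with `</s>`. The sketch below builds such strings by hand; `build_mistral_prompt` is a hypothetical helper, and the exact whitespace around the markers is an assumption. `tokenizer.apply_chat_template(messages, tokenize=False)` renders the model's own template and is the safer route. The leading `<s>` is omitted because the tokenizer already prepends BOS (token id 1 in the input tokens above).

```python
from typing import List, Tuple


def build_mistral_prompt(history: List[Tuple[str, str]], user_message: str) -> str:
    """Hypothetical helper: format a Mistral-Instruct-v0.1 style prompt.

    history holds earlier (user, assistant) turns, oldest first. The leading
    <s> is omitted because the tokenizer prepends the BOS token itself.
    Exact whitespace around the markers is an assumption.
    """
    parts = []
    for user, assistant in history:
        # Each finished assistant turn is closed with the </s> end-of-sequence marker
        parts.append(f"[INST] {user} [/INST] {assistant}</s>")
    # The new instruction is left open so the model generates the reply
    parts.append(f"[INST] {user_message} [/INST]")
    return "".join(parts)


print(build_mistral_prompt([], "Hello, how are you today?"))
# [INST] Hello, how are you today? [/INST]
```

With an empty history this reproduces `test_prompt` exactly, so the single-turn cells above are unaffected.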


In [21]:
# 7. Load model with NNsight
print("Loading model with NNsight for activation capture...")

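# Load by Hub name rather than from ./model/model: the tokenizer was saved
# separately to ./model/tokenizer, so the local weights directory has no tokenizer files.
# NNsight loads lazily: the model reports the "meta" device until the first
# trace loads the weights. No device_map is passed, so traces run on CPU.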
nnsight_model = LanguageModel(model_name, torch_dtype=torch.float16)
print("✓ Loaded model with NNsight")
print(f"NNsight model device: {nnsight_model.device}")
print(f"Total layers: {len(nnsight_model.model.layers)}")

Loading model with NNsight for activation capture...
✓ Loaded model with NNsight
NNsight model device: meta
Total layers: 32


In [22]:
# 8. Define activation capture functions
def capture_residual_activations(
    model: LanguageModel,
    prompt: str,
    layer_idx: int,
    token_position: int = -1
) -> torch.Tensor:
    """
    Capture residual stream activations from a specific layer.
    
    Args:
        model: NNsight LanguageModel
        prompt: Raw text prompt (nnsight handles tokenization)
        layer_idx: Which transformer layer to capture from (0-indexed)
        token_position: Which token position to capture (-1 for last)
    
    Returns:
        Activation tensor from the residual stream
    """
    with model.trace(prompt) as tracer:
        # IMPORTANT: layer.output returns a tuple of length two:
        # - [0] = positional arguments (the actual hidden states tensor)
        # - [1] = keyword arguments
        # This is a recent change in nnsight to fix output access issues
        
        # Get the hidden states tensor (first element of output tuple)
        hidden_states = model.model.layers[layer_idx].output[0]
        
        # Debug: Check the actual shape of hidden_states
        print(f"Debug: hidden_states shape = {hidden_states.shape}")
        
        # Handle different tensor dimensions
        if len(hidden_states.shape) == 3:
            # Standard 3D case: [batch_size, seq_len, hidden_dim]
            if token_position == -1:
                activation = hidden_states[:, -1, :].save()
            else:
                activation = hidden_states[:, token_position, :].save()
        elif len(hidden_states.shape) == 2:
            # 2D case: [seq_len, hidden_dim] - batch size of 1 was squeezed
            if token_position == -1:
                activation = hidden_states[-1, :].save()
            else:
                activation = hidden_states[token_position, :].save()
        else:
            raise ValueError(f"Unexpected hidden_states shape: {hidden_states.shape}")
    
    return activation

def test_activation_capture(
    model: LanguageModel,
    prompt: str,
    test_layers: Optional[List[int]] = None
) -> Dict[int, torch.Tensor]:
    """
    Test activation capture from multiple layers.
    """
    if test_layers is None:
        total_layers = len(model.model.layers)
        test_layers = [0, total_layers//4, total_layers//2, 3*total_layers//4, total_layers-1]
    
    print(f"Testing activation capture from layers: {test_layers}")
    print(f"Prompt: '{prompt}'")
    print("-" * 60)
    
    activations = {}
    
    for layer_idx in test_layers:
        try:
            activation = capture_residual_activations(
                model, prompt, layer_idx, token_position=-1
            )
            activations[layer_idx] = activation
            
            # Print activation info
            magnitude = torch.norm(activation).item()
            mean_val = torch.mean(activation).item()
            std_val = torch.std(activation).item()
            print(f"✅ Layer {layer_idx:2d}: shape={activation.shape}, "
                  f"magnitude={magnitude:.4f}, mean={mean_val:.4f}, std={std_val:.4f}")
            
        except Exception as e:
            print(f"❌ Failed to capture layer {layer_idx}: {e}")
            import traceback
            traceback.print_exc()
    
    return activations

def debug_layer_structure(
    model: LanguageModel,
    prompt: str,
    layer_idx: int = 0
) -> None:
    """
    Debug function to understand layer output structure.
    """
    print(f"Debugging layer {layer_idx} structure...")
    
    # Initialize variables OUTSIDE the trace context
    output_tuple = None
    pos_args = None
    kw_args = None
    
    with model.trace(prompt) as tracer:
        layer = model.model.layers[layer_idx]
        
        # Raw layer output: a tuple in older transformers releases, the
        # hidden-states tensor [batch_size, seq_len, hidden_dim] in newer ones
        output_tuple = layer.output.save()
        
        # Element [0]: the hidden states (tuple) or batch element 0 (tensor)
        pos_args = layer.output[0].save()
        
        # Element [1]: an extra tuple entry (tuple) or batch element 1 (tensor);
        # for a tensor with batch size 1 this index is out of range
        try:
            kw_args = layer.output[1].save()
            print("✅ Successfully accessed keyword args")
        except Exception as e:
            print(f"❌ No keyword args available: {e}")
            kw_args = None
        
    # Print information about the structures (outside trace context)
    print(f"\nLayer output analysis:")
    print(f"Full output type: {type(output_tuple)}")
    
    if hasattr(output_tuple, '__len__'):
        try:
            print(f"Output tuple length: {len(output_tuple)}")
        except Exception as e:
            print(f"Could not determine output tuple length: {e}")
    
    print(f"Positional args [0]: type={type(pos_args)}, shape={getattr(pos_args, 'shape', 'No shape')}")
    
    if kw_args is not None:
        print(f"Keyword args [1]: type={type(kw_args)}")
    else:
        print("Keyword args [1]: None")

print("✓ Fixed activation capture functions with dynamic tensor dimension handling")

✓ Fixed activation capture functions with dynamic tensor dimension handling


In [23]:
# 9. Inspect the layer output structure, then test activation capture
print("Testing activation capture with corrected tuple structure...")
print("=" * 70)

test_prompt = "[INST] What is the capital of France? [/INST]"

try:
    # First, debug the layer structure to confirm our understanding
    print("🔍 Debugging layer structure first...")
    debug_layer_structure(nnsight_model, test_prompt, layer_idx=0)
    print("\n" + "=" * 40 + "\n")
    
    # Test activation capture from sample layers
    captured_activations = test_activation_capture(nnsight_model, test_prompt)

    if captured_activations:
        print(f"\n🎉 SUCCESS! Captured activations from {len(captured_activations)} layers")
        print("\nActivation Summary:")
        print("-" * 50)
        
        total_values = sum(torch.numel(act) for act in captured_activations.values())
        print(f"Total activation values captured: {total_values:,}")
        
        # Test different token positions on middle layer
        middle_layer = len(nnsight_model.model.layers) // 2
        if middle_layer in captured_activations:
            print(f"\nTesting different token positions on layer {middle_layer}...")
            for pos in [-1, 0, 5]:  # last token, first token (BOS), sixth token
                try:
                    activation = capture_residual_activations(
                        nnsight_model, test_prompt, middle_layer, token_position=pos
                    )
                    magnitude = torch.norm(activation).item()
                    mean_val = torch.mean(activation).item()
                    print(f"  Position {pos:2d}: shape={activation.shape}, magnitude={magnitude:.4f}, mean={mean_val:.4f}")
                except Exception as e:
                    print(f"  Position {pos:2d}: ❌ Error: {e}")
    else:
        print("❌ No activations captured")

except Exception as e:
    print(f"❌ Error during activation capture: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "=" * 70)

Testing activation capture with corrected tuple structure...
🔍 Debugging layer structure first...
Debugging layer 0 structure...



❌ No keyword args available: index 1 is out of bounds for dimension 0 with size 1

Layer output analysis:
Full output type: <class 'torch.Tensor'>
Output tuple length: 1
Positional args [0]: type=<class 'torch.Tensor'>, shape=torch.Size([15, 4096])
Keyword args [1]: None


Testing activation capture from layers: [0, 8, 16, 24, 31]
Prompt: '[INST] What is the capital of France? [/INST]'
------------------------------------------------------------
Debug: hidden_states shape = torch.Size([15, 4096])
✅ Layer  0: shape=torch.Size([4096]), magnitude=0.3250, mean=-0.0001, std=0.0051
Debug: hidden_states shape = torch.Size([15, 4096])
✅ Layer  8: shape=torch.Size([4096]), magnitude=3.7188, mean=-0.0006, std=0.0581
Debug: hidden_states shape = torch.Size([15, 4096])
✅ Layer 16: shape=torch.Size([4096]), magnitude=10.6875, mean=-0.0023, std=0.1670
Debug: hidden_states shape = torch.Size([15, 4096])
✅ Layer 24: shape=torch.Size([4096]), magnitude=24.6719, mean=0.0040, std=0.3855
Debug: hidden_sta
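The debug output shows that, in this environment, `layer.output` is a single tensor of shape `[1, 15, 4096]` rather than a tuple: the reported "length" of 1 is the batch size, `[0]` selects batch element 0 (hence `[15, 4096]`), and `[1]` fails because there is no second batch element. Older transformers releases return a tuple whose first element is the hidden states. Below is a sketch of a version-agnostic alternative to the 2D/3D branching in `capture_residual_activations`. `unwrap_layer_output` and `resolve_position` are hypothetical helpers, and applying `isinstance` to `.output` assumes the trace exposes real values (as the shape prints and the caught indexing error above suggest); older proxy-based nnsight releases would not pass that check.

```python
from typing import Any


def unwrap_layer_output(output: Any) -> Any:
    """Return the hidden-states tensor [batch_size, seq_len, hidden_dim].

    Hypothetical helper: older transformers releases return a tuple whose
    first element is the hidden states; newer ones return the tensor itself.
    """
    if isinstance(output, tuple):
        return output[0]
    return output


def resolve_position(seq_len: int, token_position: int) -> int:
    """Convert a Python-style (possibly negative) token index to an absolute one.

    Hypothetical helper for checking a token_position before tracing.
    """
    index = token_position + seq_len if token_position < 0 else token_position
    if not 0 <= index < seq_len:
        raise IndexError(
            f"token_position {token_position} is out of range for seq_len {seq_len}"
        )
    return index


# The 15-token France prompt: -1 is the last token, 5 is the sixth.
print(resolve_position(15, -1), resolve_position(15, 5))  # 14 5
```

Inside the trace, `hidden = unwrap_layer_output(model.model.layers[layer_idx].output)` would then always be `[batch_size, seq_len, hidden_dim]`, and `hidden[0, token_position, :].save()` returns a `[hidden_dim]` vector in both cases.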

In [24]:
print(captured_activations)

{0: tensor([-0.0003, -0.0024,  0.0074,  ...,  0.0048, -0.0008,  0.0100],
       dtype=torch.float16, grad_fn=<SliceBackward0>), 8: tensor([-0.0224, -0.0137, -0.0555,  ..., -0.0221,  0.0429,  0.0952],
       dtype=torch.float16, grad_fn=<SliceBackward0>), 16: tensor([-0.2195, -0.4102, -0.1387,  ..., -0.1802, -0.3159, -0.0692],
       dtype=torch.float16, grad_fn=<SliceBackward0>), 24: tensor([-0.1135, -0.1949, -0.7153,  ..., -0.5386,  0.4922, -1.0566],
       dtype=torch.float16, grad_fn=<SliceBackward0>), 31: tensor([ 0.0294,  0.0435, -0.7559,  ..., -0.5977,  0.6938, -0.9199],
       dtype=torch.float16, grad_fn=<SliceBackward0>)}
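The three statistics printed by `test_activation_capture` are not independent. For N values with mean `mean` and unbiased standard deviation `std` (what `torch.std` returns by default), `||x||^2 = N*mean^2 + (N - 1)*std^2`. With the mean near 0 and N = 4096, this gives `||x|| ≈ 64*std`, which matches the output: 64 × 0.0581 ≈ 3.72 (layer 8) and 64 × 0.3855 ≈ 24.67 (layer 24). The reported values are also rounded to float16, the activations' dtype; calling `.float()` first gives float32 statistics. A pure-Python sketch of the identity (both function names are illustrative):

```python
import math
from typing import Dict, Sequence


def activation_stats(x: Sequence[float]) -> Dict[str, float]:
    """L2 norm, mean, and unbiased std of a vector (std divides by N - 1)."""
    n = len(x)
    mean = sum(x) / n
    # Unbiased variance, matching torch.std's default correction of 1
    var = sum((v - mean) ** 2 for v in x) / (n - 1)
    norm = math.sqrt(sum(v * v for v in x))
    return {"norm": norm, "mean": mean, "std": math.sqrt(var)}


def norm_from_mean_std(n: int, mean: float, std: float) -> float:
    """Recover ||x|| from N, the mean, and the unbiased std.

    Uses sum(x^2) = sum((x - mean)^2) + N * mean^2 = (N - 1) * std^2 + N * mean^2.
    """
    return math.sqrt(n * mean ** 2 + (n - 1) * std ** 2)


# Layer 8 statistics reported above: magnitude 3.7188, mean -0.0006, std 0.0581
print(round(norm_from_mean_std(4096, -0.0006, 0.0581), 3))  # 3.718
```

In practice this means the magnitude column mostly tracks the standard deviation as depth increases; the mean contributes almost nothing at these scales.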


In [26]:
# 10. Final working example
print("Running final working example...")

# Simple single-layer activation capture
try:
    activation = capture_residual_activations(
        nnsight_model, 
        "[INST] The capital of France is [/INST]", 
        layer_idx=16, 
        token_position=-1
    )
    
    print(f"✅ SUCCESS: Captured activation from layer 16")
    print(f"   Shape: {activation.shape}")
    print(f"   Magnitude: {torch.norm(activation).item():.4f}")
    print(f"   Data type: {activation.dtype}")
    print(f"   Device: {activation.device}")
    
    # Test with different prompts to verify consistency
    print(f"\nTesting with different prompts...")
    for prompt in ["[INST] Hello world! [/INST]", "[INST] What is 2+2? [/INST]"]:
        act = capture_residual_activations(nnsight_model, prompt, 16, -1)
        mag = torch.norm(act).item()
        print(f"  '{prompt[:20]}...': magnitude={mag:.4f}")
    
except Exception as e:
    print(f"❌ Error: {e}")
    import traceback
    traceback.print_exc()
    
print("\n" + "="*50)
print("🎉 Activation capture is now working correctly!")
print("   Ready for SAE training and other experiments.")

Running final working example...
Debug: hidden_states shape = torch.Size([13, 4096])
✅ SUCCESS: Captured activation from layer 16
   Shape: torch.Size([4096])
   Magnitude: 9.6250
   Data type: torch.float16
   Device: cpu

Testing with different prompts...
Debug: hidden_states shape = torch.Size([11, 4096])
  '[INST] Hello world! ...': magnitude=10.5469
Debug: hidden_states shape = torch.Size([15, 4096])
  '[INST] What is 2+2? ...': magnitude=11.3672

🎉 Activation capture is now working correctly!
   Ready for SAE training and other experiments.
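Each prompt above is traced on its own, so `token_position=-1` always points at that prompt's last token, even though the lengths differ (13, 11 and 15 tokens in the debug output). If prompts are batched for SAE data collection, the tokenizer pads them to a common length (with `pad_token_id = 0` as set earlier). With right padding, index -1 lands on a pad token for the shorter prompts; with left padding, -1 is correct but absolute positions such as 0 or 5 shift by the padding amount. The sketch below, a hypothetical helper assuming each row's real tokens are contiguous, finds the last real token from the attention mask:

```python
from typing import List


def last_token_indices(attention_mask: List[List[int]], padding_side: str = "right") -> List[int]:
    """Index of the last real (mask == 1) token in each row of a padded batch.

    Hypothetical helper; assumes each row's real tokens are contiguous.
    """
    if padding_side not in ("left", "right"):
        raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side!r}")
    indices = []
    for row in attention_mask:
        if padding_side == "left":
            # Pads come first, so the final column is the last real token
            indices.append(len(row) - 1)
        else:
            # Pads come last, so the last real token precedes the first pad
            indices.append(sum(row) - 1)
    return indices


# Prompts of 13, 11, and 15 tokens, right-padded to 15
mask = [[1] * 13 + [0] * 2, [1] * 11 + [0] * 4, [1] * 15]
print(last_token_indices(mask))  # [12, 10, 14]
```

The mask would come from `tokenizer(prompts, padding=True, return_tensors="pt").attention_mask.tolist()`, with `padding_side` taken from `tokenizer.padding_side`; each row's vector is then `hidden_states[b, idx, :]`.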


In [27]:
print(activation)

tensor([-0.1765, -0.0891, -0.1075,  ...,  0.0314, -0.3169,  0.0094],
       dtype=torch.float16, grad_fn=<SliceBackward0>)


# 🎯 Summary

## What This Notebook Accomplishes:

1. ✅ **Downloads and saves** Mistral-7B-Instruct locally
2. ✅ **Tests basic inference** to verify model functionality 
3. ✅ **Loads model with NNsight** for activation capture
4. ✅ **Captures residual stream activations** from any layer and token position
5. ✅ **Checks activation capture** across multiple layers, token positions, and prompts (shape, norm, mean, std)

## Key Function:

```python
# Capture activation from layer 16, last token
activation = capture_residual_activations(
    nnsight_model, 
    "[INST] Your prompt here [/INST]", 
    layer_idx=16, 
    token_position=-1
)
```

## Ready for Advanced Experiments:

- **Sparse Autoencoders (SAE)** training on residual stream
- **Activation patching** experiments  
- **Feature visualization** and analysis
- **Intervention studies** on model behavior
- **Mechanistic interpretability** research

`capture_residual_activations` returns one residual-stream vector (4096 values for Mistral-7B) per prompt, layer, and token position, which can serve as the data-collection step for these experiments.