# Extract Token Activations Through Layers

**Date:** October 29, 2025

**Goal:** Capture per-token, per-layer activations to study how token representations evolve geometrically through the model.

**Method:**
1. Load one text (the `high_complexity` side of a pair) from the Wikipedia text-pair dataset and truncate it to 512 tokens
2. Run through model with `output_hidden_states=True`
3. Keep the outputs of all 36 transformer layers (the embedding output, `hidden_states[0]`, is dropped)
4. Save as [n_layers, n_tokens, hidden_dim] tensor in bfloat16

**Output:** `data/results/token_activations_sample.pt` containing:
- `activations`: [36, n_tokens, 2560] tensor (bfloat16)
- `tokens`: List of token IDs
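- `text`: Full original input text (the whole document, not only the 512-token prefix that was run through the model)
- `metadata`: Model name, layer/token/hidden-dimension counts, dtype, extraction date, and random seed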
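A minimal sketch of how a consumer might read this file and pull out one token's trajectory through the layers. `load_token_trajectory` is a hypothetical helper, and the only assumption about the file is the dictionary layout listed above. NumPy has no bfloat16 dtype, so the sketch upcasts to float32 before anyone calls `.numpy()` on the result.

```python
import torch


def load_token_trajectory(path: str, token_idx: int) -> torch.Tensor:
    """Return one token's activations across all layers, as float32.

    Hypothetical helper. Assumes the saved dict has an 'activations' entry of
    shape [n_layers, n_tokens, hidden_dim], as produced by this notebook.
    """
    data = torch.load(path, weights_only=False)
    acts = data["activations"]           # [n_layers, n_tokens, hidden_dim], bfloat16
    trajectory = acts[:, token_idx, :]   # [n_layers, hidden_dim]
    # Upcast: bfloat16 tensors cannot be converted with .numpy().
    return trajectory.float()
```

For the file written here, `load_token_trajectory(OUTPUT_PATH, token_idx=0)` would return a `[36, 2560]` float32 tensor.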

**Note:** This is DATA COLLECTION only. Analysis happens in 08.2.

## Configuration

In [22]:
# Paths
DATASET_PATH = '../data/wikipedia_top20_texts.json'
OUTPUT_PATH = '../data/results/token_activations_sample.pt'

# Model
MODEL_NAME = 'Qwen/Qwen3-4B-Instruct-2507'
DEVICE = 'auto'  # 'auto', 'cuda', 'mps', or 'cpu'

# Sampling
RANDOM_SEED = 42  # Fixed seed for reproducibility
TARGET_TOKENS = 512  # Desired sequence length (will truncate if longer)

print("Configuration:")
print(f"  Model: {MODEL_NAME}")
print(f"  Device: {DEVICE}")
print(f"  Target tokens: {TARGET_TOKENS}")
print(f"  Random seed: {RANDOM_SEED}")

Configuration:
  Model: Qwen/Qwen3-4B-Instruct-2507
  Device: auto
  Target tokens: 512
  Random seed: 42


## Load Dataset and Select Sample

In [23]:
import json
import random

print("Loading dataset...")
with open(DATASET_PATH, 'r') as f:
    dataset = json.load(f)

print(f"✓ Loaded {len(dataset)} text pairs")

# Take the high-complexity text from each pair
candidate_texts = [pair['high_complexity'] for pair in dataset]

# Select one at random (fixed seed)
random.seed(RANDOM_SEED)
selected_text = random.choice(candidate_texts)

print("\nSelected text (first 500 chars):")
print(selected_text[:500])
print(f"\nText length: {len(selected_text)} characters")

Loading dataset...
✓ Loaded 20 text pairs

Selected text (first 500 chars):
A political party is an organization that coordinates candidates to compete in elections and participate in governance. It is common for the members of a party to hold similar ideas about politics, and parties may promote specific ideological or policy goals.
Political parties have become a major part of the politics of almost every country, as modern party organizations developed and spread around the world over the last few centuries. Although some countries have no political parties, this is 

Text length: 43702 characters
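`random.seed` followed by `random.choice` reseeds the interpreter-wide generator, so any later cell that draws random numbers is affected, and the pick also depends on the order of texts in the JSON file. A sketch of an equivalent selection that keeps its own generator; `pick_text` is a hypothetical helper, not part of the notebook. For the same seed and the same list it returns the same element as the global `seed`/`choice` pair.

```python
import random
from typing import Sequence


def pick_text(texts: Sequence[str], seed: int) -> str:
    """Choose one text reproducibly without touching the global RNG state."""
    rng = random.Random(seed)  # private generator, seeded once
    return rng.choice(texts)
```

In the cell above this would be `selected_text = pick_text(candidate_texts, RANDOM_SEED)`; reordering or editing the dataset file can still change which text is picked.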


## Load Model and Tokenizer

In [24]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

print("Loading model and tokenizer...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
)
model.eval()

print(f"✓ Model loaded on device: {model.device}")
print(f"✓ Model dtype: {model.dtype}")
print(f"✓ Number of layers: {model.config.num_hidden_layers}")
print(f"✓ Hidden dimension: {model.config.hidden_size}")

Loading model and tokenizer...

✓ Model loaded on device: mps:0
✓ Model dtype: torch.bfloat16
✓ Number of layers: 36
✓ Hidden dimension: 2560


## Tokenize and Truncate

In [25]:
print("Tokenizing text...")

# Tokenize
inputs = tokenizer(
    selected_text,
    return_tensors='pt',
    truncation=True,
    max_length=TARGET_TOKENS,
    padding=False,
)

# Move to same device as model
inputs = {k: v.to(model.device) for k, v in inputs.items()}

n_tokens = inputs['input_ids'].shape[1]

print(f"✓ Tokenized to {n_tokens} tokens")
if n_tokens < TARGET_TOKENS:
    print(f"  (Text was shorter than target {TARGET_TOKENS})")
else:
    print("  (Truncated from longer text)")

Tokenizing text...
✓ Tokenized to 512 tokens
  (Truncated from longer text)
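The `if n_tokens < TARGET_TOKENS` check above labels a text that is exactly 512 tokens long as truncated. One way to report truncation unambiguously is to compare against the untruncated length. The hypothetical helper below works on a plain list of token IDs and mirrors right-side truncation (the tokenizer's default `truncation_side`); it assumes no special tokens are appended, which we believe holds for this Qwen tokenizer on plain text but is worth checking for other tokenizers.

```python
from typing import List, Tuple


def truncate_ids(full_ids: List[int], max_length: int) -> Tuple[List[int], bool]:
    """Keep the first max_length IDs and report whether anything was dropped."""
    kept = full_ids[:max_length]               # right-side truncation
    was_truncated = len(full_ids) > max_length
    return kept, was_truncated
```

Here `full_ids` would come from `tokenizer(selected_text)['input_ids']`, called without `truncation`.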


## Extract Activations

In [26]:
print("Running forward pass to extract activations...")

with torch.no_grad():
    outputs = model(
        **inputs,
        output_hidden_states=True,
        return_dict=True,
    )

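# Extract hidden states
# outputs.hidden_states is a tuple of length (n_layers + 1):
# hidden_states[0] is the embedding output, hidden_states[i] is the output of layer i.
# Caveat: in Hugging Face's Qwen3 implementation the last entry is taken after the
# model's final RMSNorm, so layer 36 is not on the same scale as layers 1-35
# (worth confirming for the installed transformers version).
hidden_states = outputs.hidden_states[1:]  # Skip embedding layer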

# Stack into single tensor: [n_layers, batch_size, seq_len, hidden_dim]
# Since batch_size=1, we squeeze it: [n_layers, seq_len, hidden_dim]
activations = torch.stack(hidden_states).squeeze(1)

print(f"✓ Extracted activations with shape: {activations.shape}")
print(f"  [n_layers={activations.shape[0]}, n_tokens={activations.shape[1]}, hidden_dim={activations.shape[2]}]")
print(f"  Dtype: {activations.dtype}")
print(f"  Device: {activations.device}")

Running forward pass to extract activations...
✓ Extracted activations with shape: torch.Size([36, 512, 2560])
  [n_layers=36, n_tokens=512, hidden_dim=2560]
  Dtype: torch.bfloat16
  Device: mps:0
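`output_hidden_states=True` is the simplest route. If you instead want each decoder layer's raw output (for example, the last layer before any final normalization the model applies), forward hooks are an alternative. The sketch assumes only a module with a `layers` `nn.ModuleList`; for the model here that list should be `model.model.layers`, but treat that attribute path as an assumption to verify. It is demonstrated on a hypothetical toy decoder so it runs without downloading weights.

```python
from typing import Callable, List

import torch
from torch import nn


def capture_layer_outputs(layers: nn.ModuleList,
                          run_forward: Callable[[], object]) -> List[torch.Tensor]:
    """Run `run_forward()` and return each layer's output in execution order."""
    captured: List[torch.Tensor] = []

    def hook(module: nn.Module, args: tuple, output):
        # HF decoder layers return a tensor in some versions and a tuple
        # (hidden_states, ...) in others; keep only the hidden states.
        hidden = output[0] if isinstance(output, tuple) else output
        captured.append(hidden.detach())

    handles = [layer.register_forward_hook(hook) for layer in layers]
    try:
        with torch.no_grad():
            run_forward()
    finally:
        for handle in handles:
            handle.remove()  # detach hooks even if the forward pass fails
    return captured


class ToyDecoder(nn.Module):
    """Stand-in for a decoder: a stack of layers followed by a final norm."""

    def __init__(self, n_layers: int = 4, hidden: int = 8):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(hidden, hidden) for _ in range(n_layers)])
        self.norm = nn.LayerNorm(hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return self.norm(x)


torch.manual_seed(0)
toy = ToyDecoder()
x = torch.randn(1, 5, 8)                        # [batch, seq_len, hidden]
outs = capture_layer_outputs(toy.layers, lambda: toy(x))
acts = torch.stack(outs).squeeze(1)             # [n_layers, seq_len, hidden]
```

With the real model the call would be `capture_layer_outputs(model.model.layers, lambda: model(**inputs))`, again assuming that attribute path.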


## Save to Disk

In [27]:
import os
from datetime import datetime

print("Saving activations to disk...")

# Prepare data structure
data = {
    'activations': activations.cpu(),  # Move to CPU for storage
    'tokens': inputs['input_ids'][0].cpu().tolist(),  # Token IDs of the single sequence
    'text': selected_text,  # Full source text; only its first n_tokens tokens were processed
    'metadata': {
        'model': MODEL_NAME,
        'n_layers': activations.shape[0],
        'n_tokens': activations.shape[1],
        'hidden_dim': activations.shape[2],
        'dtype': str(activations.dtype),
        'extraction_date': datetime.now().isoformat(),
        'random_seed': RANDOM_SEED,
    }
}

os.makedirs(os.path.dirname(OUTPUT_PATH), exist_ok=True)  # Ensure output directory exists
torch.save(data, OUTPUT_PATH)

print(f"✓ Saved to: {OUTPUT_PATH}")

# Report file size (bytes / 1024**2, i.e. MiB)
file_size_mb = os.path.getsize(OUTPUT_PATH) / (1024 ** 2)
print(f"  File size: {file_size_mb:.1f} MB")

Saving activations to disk...
✓ Saved to: ../data/results/token_activations_sample.pt
  File size: 90.0 MB
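The reported size is almost entirely the activation tensor: 36 × 512 × 2560 bfloat16 values at 2 bytes each. Because the code divides by 1024², the figure is in MiB. A quick check, with `expected_size_mib` as a hypothetical helper that ignores pickle overhead and the small token, text, and metadata entries:

```python
def expected_size_mib(n_layers: int, n_tokens: int, hidden_dim: int,
                      bytes_per_value: int = 2) -> float:
    """Raw tensor size in MiB (bfloat16 = 2 bytes per value)."""
    n_bytes = n_layers * n_tokens * hidden_dim * bytes_per_value
    return n_bytes / (1024 ** 2)


print(expected_size_mib(36, 512, 2560))  # 90.0
```

Saving in float32 instead would double this to 180 MiB per 512-token sample.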


## Verification

In [28]:
print("Verifying saved data...")

# Reload and check
loaded = torch.load(OUTPUT_PATH, weights_only=False)

print(f"✓ Activations shape: {loaded['activations'].shape}")
print(f"✓ Number of tokens: {len(loaded['tokens'])}")
print(f"✓ Text length: {len(loaded['text'])} chars")
print(f"✓ Metadata:")
for key, value in loaded['metadata'].items():
    print(f"    {key}: {value}")

print("\n" + "="*70)
print("DATA EXTRACTION COMPLETE")
print("="*70)

Verifying saved data...
✓ Activations shape: torch.Size([36, 512, 2560])
✓ Number of tokens: 512
✓ Text length: 43702 chars
✓ Metadata:
    model: Qwen/Qwen3-4B-Instruct-2507
    n_layers: 36
    n_tokens: 512
    hidden_dim: 2560
    dtype: torch.bfloat16
    extraction_date: 2025-10-29T17:53:58.397520
    random_seed: 42

DATA EXTRACTION COMPLETE
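The verification cell prints values for inspection. A stricter variant could assert that the saved pieces agree with each other and with the metadata. `check_activation_record` is a hypothetical helper sketching that, assuming the dictionary layout written above.

```python
import torch


def check_activation_record(data: dict) -> None:
    """Raise AssertionError if the saved record is internally inconsistent."""
    acts = data["activations"]
    meta = data["metadata"]
    expected_shape = (meta["n_layers"], meta["n_tokens"], meta["hidden_dim"])
    assert tuple(acts.shape) == expected_shape, (tuple(acts.shape), expected_shape)
    assert len(data["tokens"]) == meta["n_tokens"], "token list length mismatch"
    assert str(acts.dtype) == meta["dtype"], "dtype mismatch"
    # Overflow or a broken forward pass would show up as inf/nan values.
    assert torch.isfinite(acts).all(), "non-finite activations"
```

Usage: `check_activation_record(torch.load(OUTPUT_PATH, weights_only=False))`.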
