# Reasoning Direction Analysis - Example Notebook

This notebook demonstrates the basic workflow for analyzing reasoning directions in language models.

In [None]:
import sys
sys.path.append('..')

import torch
from pipeline_Benchmark.config import get_default_config
from pipeline_Benchmark.model_utils import load_model_and_tokenizer, collect_activations, compute_contrastive_directions
from pipeline_Benchmark.utils import load_dataset, create_control_dataset, prepare_prompts

## 1. Load Configuration and Model

In [None]:
# Start from the project's default configuration.
config = get_default_config()

# Keep the run small for quick testing: both sample-size settings become 10.
config.dataset.gsm8k_sample_size = config.dataset.math_sample_size = 10

In [None]:
# Load the model named in the config (this can be slow for large models).
rl_model_name = config.model.rl_model_name
print(f"Loading model: {rl_model_name}")

model_wrapper = load_model_and_tokenizer(
    model_name=rl_model_name,
    torch_dtype="float16",
    device_map="auto",
)

# Report the layer count exposed by the returned wrapper.
print(f"Model loaded: {model_wrapper.num_layers} layers")

## 2. Load Datasets

In [None]:
# Load reasoning dataset (GSM8K).
# The sample size is read from the config rather than repeated as a literal,
# so changing config.dataset.gsm8k_sample_size actually changes how many
# examples this cell requests.
gsm8k_data = load_dataset(
    config.dataset.gsm8k_path,
    split="test",
    sample_size=config.dataset.gsm8k_sample_size,
)

print(f"Loaded {len(gsm8k_data)} GSM8K examples")
print("\nExample:")
print(gsm8k_data[0])

In [None]:
# Build the control dataset: 10 examples of the "simple_qa" task type.
control_data = create_control_dataset(task_type="simple_qa", size=10)

# Report how many examples were created, then show the first one.
print(f"Created {len(control_data)} control examples")
print("\nExample:", control_data[0], sep="\n")

## 3. Prepare Prompts

In [None]:
# Turn both datasets into prompts. Only the GSM8K prompts are requested with
# include_cot_prompt=True; the control prompts are requested without it.
reasoning_prompts = prepare_prompts(
    gsm8k_data,
    dataset_type="gsm8k",
    include_cot_prompt=True,
)
control_prompts = prepare_prompts(
    control_data,
    dataset_type="control",
    include_cot_prompt=False,
)

# Show the first prompt of each kind for a quick visual check.
print("Reasoning prompt example:", reasoning_prompts[0], sep="\n")
print("\nControl prompt example:", control_prompts[0], sep="\n")

## 4. Collect Activations

In [None]:
# Run the reasoning prompts through the model, two per batch, and keep the
# activations returned by collect_activations.
print("Collecting activations on reasoning tasks...")
reasoning_activations = collect_activations(
    texts=reasoning_prompts,
    model=model_wrapper.model,
    tokenizer=model_wrapper.tokenizer,
    device=model_wrapper.device,
    batch_size=2,
)

num_reasoning_entries = len(reasoning_activations)
print(f"Collected activations for {num_reasoning_entries} layers")

In [None]:
# Run the control prompts through the model, two per batch, and keep the
# activations returned by collect_activations.
print("Collecting activations on control tasks...")
control_activations = collect_activations(
    texts=control_prompts,
    model=model_wrapper.model,
    tokenizer=model_wrapper.tokenizer,
    device=model_wrapper.device,
    batch_size=2,
)

num_control_entries = len(control_activations)
print(f"Collected activations for {num_control_entries} layers")

## 5. Compute Reasoning Directions

In [None]:
# Compute contrastive directions from the reasoning and control activations,
# asking the helper to normalize the result.
directions = compute_contrastive_directions(
    reasoning_activations=reasoning_activations,
    control_activations=control_activations,
    normalize=True
)

print(f"Computed reasoning directions for {len(directions)} layers")

# Show the shape of one example vector. The key is taken from the mapping
# itself (its smallest key) instead of assuming that key 0 is present, so the
# cell does not raise KeyError when layer 0 was not collected.
if len(directions) > 0:
    first_layer = min(directions)
    print(f"Direction vector shape: {directions[first_layer].shape}")

## 6. Test Interventions

In [None]:
from pipeline_Benchmark.model_utils import apply_direction_intervention

# Use the first reasoning prompt as the probe for every intervention run.
test_prompt = reasoning_prompts[0]
print(f"Test prompt: {test_prompt}\n")

# Try a negative, a zero and a positive intervention strength on the same prompt.
separator = "-" * 80
for strength in (-1.0, 0.0, 1.0):
    result = apply_direction_intervention(
        model=model_wrapper.model,
        tokenizer=model_wrapper.tokenizer,
        prompt=test_prompt,
        directions=directions,
        intervention_strength=strength,
        max_new_tokens=100,
    )

    print(f"\nIntervention strength {strength}:")
    print(result)
    print(separator)

## 7. Visualize Results

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Plot the norm of each direction vector against its (sorted) layer key.
layers = sorted(directions)
magnitudes = []
for layer_idx in layers:
    magnitudes.append(directions[layer_idx].norm().item())

plt.figure(figsize=(12, 6))
plt.plot(layers, magnitudes, marker='o')
plt.title('Reasoning Direction Magnitude Across Layers')
plt.xlabel('Layer Index')
plt.ylabel('Direction Magnitude')
plt.grid(True, alpha=0.3)
plt.show()

## 8. Save Results

In [None]:
from pathlib import Path

# Save directions
output_path = "../results/directions/example_directions.pt"

# torch.save does not create missing parent directories, so make sure the
# target folder exists first (a no-op when it is already there).
Path(output_path).parent.mkdir(parents=True, exist_ok=True)

torch.save(directions, output_path)
print(f"Saved directions to {output_path}")