# Synthetic Steering Demo

This notebook demonstrates three activation-steering methods (DAC, multi-behavior CAA, and K-Steering) trained on synthetically generated contrastive pairs, following a workflow similar to the CLI.

In [1]:
import sys
import torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM

project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from wisent_guard.core.steering_methods.dac import DAC
from wisent_guard.core.steering_methods.caa import CAA
from wisent_guard.core.steering_methods.k_steering import KSteering
from wisent_guard.core.contrastive_pairs.contrastive_pair_set import ContrastivePairSet
from wisent_guard.core.contrastive_pairs.generate_synthetically import SyntheticContrastivePairGenerator
from wisent_guard.core.model import Model

# --- Configuration: everything a reader might want to tune -------------------
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
LAYER_INDEX = 15          # index into hf_model.model.layers where hooks are attached
STEERING_STRENGTH = 1.0   # strength passed to each steering method's apply_steering
MAX_LENGTH = 15           # number of tokens generated beyond the prompt
NUM_PAIRS = 5             # contrastive pairs requested per trait
SEED = 42                 # RNG seed for the sampled generations below

# Generation in this notebook samples (do_sample=True); seeding PyTorch's RNG
# makes every "Restart & Run All" start from the same random state.
torch.manual_seed(SEED)

# Get optimal device: prefer Apple MPS, then CUDA, and fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else 
                     "cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# Load model and tokenizer
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Half precision is only used on CUDA; every other device loads in float32.
model_dtype = torch.float16 if device.type == 'cuda' else torch.float32
hf_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=model_dtype).to(device)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Create Model wrapper for synthetic generator
model = Model(name=MODEL_NAME, hf_model=hf_model)
print("✓ Model loaded successfully")

# Prompts used to compare unsteered and steered generations
TEST_PROMPTS = [
    "Tell me about cats",
    "How do I learn programming?", 
    "What's the weather like?"
]

Using device: mps
Loading model and tokenizer...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✓ Model loaded successfully


In [None]:
# Initialize synthetic generator using the same approach as CLI
print("Initializing synthetic generator...")
generator = SyntheticContrastivePairGenerator(model)

# Define trait descriptions (just like in CLI): trait name -> natural-language
# description handed to the generator
TRAITS = {
    "sarcastic": "sarcastic and witty responses with subtle mockery and irony",
    "helpful": "extremely helpful, supportive, and eager to assist responses"
}

print(f"Will generate {NUM_PAIRS} synthetic pairs for each trait:")
for name, description in TRAITS.items():
    print(f"- {name}: {description}")

# Generate synthetic contrastive pairs for each trait (trait name -> pair set)
pair_sets = {}

for trait_name, trait_description in TRAITS.items():
    print(f"\nGenerating {trait_name} behavior pairs...")
    pair_set = generator.generate_contrastive_pair_set(
        trait_description=trait_description,
        num_pairs=NUM_PAIRS,
        name=trait_name
    )

    # The generator may return fewer pairs than requested. An empty set cannot
    # be used for training, so fail here with a clear message instead of
    # letting a later cell break in a less obvious way.
    generated_count = len(pair_set.pairs)
    if generated_count == 0:
        raise ValueError(f"No contrastive pairs were generated for trait '{trait_name}'")
    if generated_count < NUM_PAIRS:
        # Not fatal, but worth surfacing: the steering methods are trained on
        # less data than NUM_PAIRS suggests.
        print(f"⚠ Only {generated_count} of {NUM_PAIRS} requested {trait_name} pairs were generated")

    pair_sets[trait_name] = pair_set
    print(f"✓ Generated {len(pair_set.pairs)} {trait_name} pairs")

print("\n✅ All synthetic pairs generated successfully")

In [3]:
# Preview the first (at most) two generated pairs of every trait
for trait_name in pair_sets:
    pair_set = pair_sets[trait_name]
    print(f"\n=== Example {trait_name.upper()} pairs ===")

    preview = pair_set.pairs[:2]
    for i, pair in enumerate(preview):
        # Each pair is shown as its prompt followed by both candidate responses
        print(f"\nPair {i+1}:")
        print(f"Prompt: {pair.prompt}")
        print(f"Positive: {pair.positive_response.text}")
        print(f"Negative: {pair.negative_response.text}")

# Extract activations for all pairs  
def extract_activations(text, layer_idx):
    """Return the last-token hidden state of ``text`` at one decoder layer.

    The text is tokenized, a forward hook is attached to
    ``hf_model.model.layers[layer_idx]``, and a single forward pass is run
    without gradient tracking.

    Args:
        text: Input string to run through the model.
        layer_idx: Index into ``hf_model.model.layers`` of the layer to read.

    Returns:
        A clone of the layer output at the last sequence position, with the
        leading dimension squeezed out (the indexing assumes a
        ``(batch, seq, hidden)`` layout).

    Raises:
        Whatever the tokenizer or the forward pass raises; the hook is removed
        in every case.
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)
    captured = []

    def hook(module, hook_input, output):
        # A decoder layer may hand back either a tuple whose first element is
        # the hidden states, or the hidden-state tensor itself.
        hidden_states = output[0] if isinstance(output, tuple) else output
        captured.append(hidden_states[:, -1, :].clone())

    handle = hf_model.model.layers[layer_idx].register_forward_hook(hook)
    try:
        with torch.no_grad():
            hf_model(**inputs)
    finally:
        # Always detach the hook; a leftover hook would keep firing (and keep
        # cloning tensors) on every later forward pass of the shared model.
        handle.remove()
    return captured[0].squeeze(0)

print("\nExtracting activations for all pairs...")
for trait_name, pair_set in pair_sets.items():
    for pair in pair_set.pairs:
        # Positive first, then negative: each response gets the activation of
        # its own text at LAYER_INDEX stored on its `activations` attribute.
        for response in (pair.positive_response, pair.negative_response):
            response.activations = extract_activations(response.text, LAYER_INDEX)
    print(f"✓ Extracted activations for {trait_name} pairs")

print("✅ All activations extracted")


=== Example SARCASTIC pairs ===

Pair 1:
Prompt: How can you deliver a backhanded compliment without actually saying anything nice?
Positive: The art of delivering a backhanded
Negative: The art of backhanded compliments!

=== Example HELPFUL pairs ===

Pair 1:
Prompt: *What's the best way to approach this problem/task, and can you walk me through the steps?**
Positive: I'd be happy to help you tackle
Negative: I'd be happy to help you approach

Pair 2:
Prompt: *I'm feeling a bit stuck - can you offer some guidance or suggestions on how to get started?**
Positive: I'd be delighted to help! Let
Negative: You're feeling stuck, huh? Don

Extracting activations for all pairs...
✓ Extracted activations for sarcastic pairs
✓ Extracted activations for helpful pairs
✅ All activations extracted


In [4]:
print("Training steering methods...")

# DAC: one independent instance per trait, trained on that trait's pairs only
dac_models = {}
for trait_name, pair_set in pair_sets.items():
    print(f"  Training {trait_name} DAC...")
    trait_dac = DAC(device=device)
    trait_dac.set_model_reference(hf_model)
    trait_dac.train(pair_set, LAYER_INDEX)
    dac_models[trait_name] = trait_dac

# CAA: a single instance trained on every trait at once
print("  Training Multi-behavior CAA...")
multi_caa = CAA(device=device)
multi_caa.train_multi_behavior(
    pair_sets,
    LAYER_INDEX,
    normalize_across_behaviors=True,
)

# K-Steering: trained on one pair set holding the pairs of all traits
print("  Training K-Steering...")
combined_pair_set = ContrastivePairSet(name="combined")
for trait_pair_set in pair_sets.values():
    combined_pair_set.pairs.extend(trait_pair_set.pairs)

k_steering = KSteering(
    device=device,
    num_labels=len(TRAITS),
    classifier_epochs=20,
)
k_steering.train(combined_pair_set, LAYER_INDEX)

print("✅ All steering methods trained")

Training steering methods...
  Training sarcastic DAC...
  Training helpful DAC...
  Training Multi-behavior CAA...
  Training K-Steering...
Classifier training epoch 0, loss: 0.6955
✅ All steering methods trained


In [5]:
# Generation functions
def _hidden_states_of(layer_output):
    """Return the hidden-state tensor from a decoder layer's forward output."""
    # The layer may return a tuple (hidden states first) or the bare tensor.
    return layer_output[0] if isinstance(layer_output, tuple) else layer_output


def _with_hidden_states(layer_output, hidden_states):
    """Rebuild a layer output of the same kind around ``hidden_states``."""
    if isinstance(layer_output, tuple):
        return (hidden_states,) + layer_output[1:]
    return hidden_states


def _generate(prompt, hook=None):
    """Sample a continuation of ``prompt``, optionally with a forward hook.

    When ``hook`` is given it is registered on
    ``hf_model.model.layers[LAYER_INDEX]`` for the duration of the call and is
    removed again even if generation raises.

    Returns the decoded, stripped text of the newly generated tokens only.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    handle = None
    if hook is not None:
        handle = hf_model.model.layers[LAYER_INDEX].register_forward_hook(hook)
    try:
        with torch.no_grad():
            outputs = hf_model.generate(
                **inputs,
                # Same budget as max_length = prompt length + MAX_LENGTH.
                max_new_tokens=MAX_LENGTH,
                do_sample=True,
                temperature=0.7,
                pad_token_id=tokenizer.eos_token_id
            )
    finally:
        # A hook left behind would silently steer every later generation.
        if handle is not None:
            handle.remove()

    # Decode only the tokens after the prompt. Cutting the decoded string by
    # len(prompt) breaks whenever decoding does not reproduce the prompt
    # character for character.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def generate_unsteered(prompt):
    """Generate a sampled continuation of ``prompt`` without any steering.

    Args:
        prompt: Text to continue.

    Returns:
        The generated continuation (prompt excluded), stripped of surrounding
        whitespace.
    """
    return _generate(prompt)


def generate_with_steering(prompt, steering_method, strength):
    """Generate a continuation while steering the last token at LAYER_INDEX.

    Args:
        prompt: Text to continue.
        steering_method: Object exposing ``apply_steering(tensor, strength)``.
        strength: Value forwarded to ``apply_steering``.

    Returns:
        The generated continuation (prompt excluded), stripped of surrounding
        whitespace.
    """
    def steering_hook(module, hook_input, output):
        hidden_states = _hidden_states_of(output)
        # Only the final sequence position is steered on each forward pass.
        last_token = hidden_states[:, -1:, :]
        steered = steering_method.apply_steering(last_token, strength)
        hidden_states[:, -1:, :] = steered
        return _with_hidden_states(output, hidden_states)

    return _generate(prompt, steering_hook)


def generate_with_multi_caa_steering(prompt, caa_method, trait_strengths):
    """Generate a continuation while steering towards several traits at once.

    For every trait with a non-zero strength, the shift that
    ``caa_method.apply_steering(..., behavior_name=trait)`` applies to the last
    token is measured; the shifts are summed and added to the last token.

    Args:
        prompt: Text to continue.
        caa_method: Object exposing
            ``apply_steering(tensor, strength, behavior_name=...)``.
        trait_strengths: Mapping of trait name to steering strength; entries
            with strength 0 are skipped.

    Returns:
        The generated continuation (prompt excluded), stripped of surrounding
        whitespace.
    """
    def multi_caa_hook(module, hook_input, output):
        hidden_states = _hidden_states_of(output)
        last_token = hidden_states[:, -1:, :]
        combined_diff = torch.zeros_like(last_token)

        for trait_name, strength in trait_strengths.items():
            if strength != 0:
                steered = caa_method.apply_steering(last_token, strength, behavior_name=trait_name)
                combined_diff += (steered - last_token)

        hidden_states[:, -1:, :] = last_token + combined_diff
        return _with_hidden_states(output, hidden_states)

    return _generate(prompt, multi_caa_hook)


print("✅ Generation functions defined")

✅ Generation functions defined


In [6]:
# Run every test prompt through the unsteered model and each steering method
banner = "=" * 80
print(banner)
print("SYNTHETIC STEERING DEMONSTRATION")
print(banner)

# Every trait is steered with the same strength in the multi-behavior CAA run
trait_strengths = dict.fromkeys(TRAITS, STEERING_STRENGTH)

for prompt in TEST_PROMPTS:
    print(f"\nPrompt: {prompt}")
    print("-" * 50)

    # Baseline: no steering hook attached
    unsteered = generate_unsteered(prompt)
    print(f"Unsteered: {unsteered}")

    # One generation per trait, each with that trait's own DAC model
    for trait_name, dac_model in dac_models.items():
        steered = generate_with_steering(prompt, dac_model, STEERING_STRENGTH)
        print(f"{trait_name.capitalize()} (DAC): {steered}")

    # All traits combined through the multi-behavior CAA model
    multi_caa_response = generate_with_multi_caa_steering(prompt, multi_caa, trait_strengths)
    print(f"Multi-property (CAA): {multi_caa_response}")

    # All traits combined through K-Steering
    k_response = generate_with_steering(prompt, k_steering, STEERING_STRENGTH)
    print(f"Multi-property (K-Steering): {k_response}")

print("\n" + banner)
print("✅ Synthetic steering demonstration completed!")

SYNTHETIC STEERING DEMONSTRATION

Prompt: Tell me about cats
--------------------------------------------------
Unsteered: Cats are domesticated mammals that are known for their agility, play
[DEBUG] Applying default steering: alpha=1.0000, vector norm=12.6540
[DEBUG] Applied to last token, shape torch.Size([1, 1, 4096]), norm change: 13.1478 -> 17.8352
[DEBUG] Applying default steering: alpha=1.0000, vector norm=12.6540
[DEBUG] Applied to last token, shape torch.Size([1, 1, 4096]), norm change: 11.4094 -> 17.1604
[DEBUG] Applying default steering: alpha=1.0000, vector norm=12.6540
[DEBUG] Applied to last token, shape torch.Size([1, 1, 4096]), norm change: 10.9575 -> 16.6962
[DEBUG] Applying default steering: alpha=1.0000, vector norm=12.6540
[DEBUG] Applied to last token, shape torch.Size([1, 1, 4096]), norm change: 10.9327 -> 16.1843
[DEBUG] Applying default steering: alpha=1.0000, vector norm=12.6540
[DEBUG] Applied to last token, shape torch.Size([1, 1, 4096]), norm change: 9.8540 