In [17]:
from utils.modelUtils import *
from utils.utils import *
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
from casper import nethook
from casper.nethook import *
from utils.judgeUtils import *
from tqdm import tqdm
from scipy.stats import zscore

### Model Initialization

Loading Llama-2-7b-chat model for representation intervention analysis.


In [None]:
# Checkpoint under study; later cells reach the model and tokenizer through `mt`.
model_name = "meta-llama/Llama-2-7b-chat-hf"

# float16 is requested only when the checkpoint name contains "20b"; for every
# other name (including the one above) the dtype argument is passed as None.
model_dtype = torch.float16 if "20b" in model_name else None

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=model_dtype,
    device='cuda:0',
)

# Bare last expression: the notebook displays the module structure as cell output.
mt.model

In [None]:
# Sanity check: build a model input for a harmless request with the project
# helper `generate_input` and show the result as the cell output.
# NOTE(review): presumably this returns the chat-formatted prompt string — confirm in utils.
test_prompt = generate_input(mt.tokenizer, 'tell me a fun joke')
test_prompt

### Judge Model Loading

Loading the judge model to evaluate whether responses successfully jailbreak the safety mechanisms.

In [None]:
# `load_judge_model` is provided by one of the star imports at the top of the
# notebook (presumably utils.judgeUtils — confirm). It returns two objects that
# are later passed together, as a (tokenizer, model) tuple, to
# `harmful_judge_from_strings` in the evaluation cell.
print("Loading judge model...")
judge_tokenizer, judge_model_loaded = load_judge_model()
print("Judge model loaded successfully")


### Data Loading

Load the benign and harmful prompts used to derive the representation-level intervention.


In [21]:
import json

import pandas as pd

# Adversarial prompt collections; only the first three entries of each file are kept.
with open("analysis_data/llama2_gcgs.json", 'r') as f:
    gcgs = json.load(f)[:3]
with open("analysis_data/llama2_autodan.json", 'r') as f:
    autodan = json.load(f)[:3]

# Harmful side: the 'Behavior' column of the AdvBench CSV, first 100 rows.
# (The evaluation cell later reads rows 100-119 of the same file as held-out prompts.)
harmful_data = pd.read_csv("analysis_data/advbench_behaviors.csv")['Behavior']
harmful_questions = list(harmful_data)[:100]

# Benign side: the 'instruction' field of every Alpaca record, first 100 kept.
with open("analysis_data/alpaca_clean.json", 'r') as f:
    alpaca = json.load(f)
benign_questions = [record['instruction'] for record in alpaca][:100]

In [None]:
print(f"Harmful questions: {len(harmful_questions)}")
print(f"Benign questions: {len(benign_questions)}")

# %%
def _as_single_turn_prompt(question):
    """Render one user question with the tokenizer's chat template, as text (not token ids)."""
    return mt.tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )


# Both prompt sets go through the same template so they differ only in content.
harmful_prompts = [_as_single_turn_prompt(question) for question in harmful_questions]

benign_prompts = [_as_single_turn_prompt(question) for question in benign_questions]



### Collect Hidden States

Collect hidden states from all transformer layers for both benign and harmful prompts.

In [None]:
def collect_hidden_states(mt, prompts, batch_size=8):
    """
    Collect the last-prompt-token hidden state from every transformer layer.

    Prompts are tokenized in padded batches. For each sequence the hidden state
    is read at the position of its last attended (non-padding) token, located
    through the attention mask. Reading position -1 unconditionally would pick
    up a padding position for the shorter sequences of a batch whenever the
    tokenizer pads on the right; the mask-based lookup gives the same answer
    for left and right padding.

    Args:
        mt: Object exposing `.model` (called with `output_hidden_states=True`)
            and `.tokenizer`.
        prompts: Sequence of prompt strings.
        batch_size: Number of prompts per forward pass.

    Returns:
        Dict mapping layer index (0 .. num_hidden_layers - 1) to a float32
        NumPy array of shape [len(prompts), hidden_dim].

    Raises:
        ValueError: If `prompts` is empty.
    """
    if len(prompts) == 0:
        # Fail early with a clear message instead of the opaque np.concatenate error.
        raise ValueError("collect_hidden_states requires at least one prompt")

    device = mt.model.device
    num_layers = mt.model.config.num_hidden_layers

    # One list of per-batch arrays for each layer
    layer_hidden_states = {i: [] for i in range(num_layers)}

    total_batches = (len(prompts) + batch_size - 1) // batch_size

    for i in tqdm(range(0, len(prompts), batch_size), total=total_batches, desc="Collecting hidden states"):
        batch_prompts = prompts[i:i+batch_size]
        # NOTE(review): if the prompts were already rendered with
        # apply_chat_template(tokenize=False), this call may prepend a second BOS
        # token — confirm whether add_special_tokens=False is wanted here.
        inputs = mt.tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(device)

        # Index of the last attended token per row: mask * position is zero on
        # padding, so the row-wise maximum is the largest attended position.
        attention_mask = inputs["attention_mask"]
        positions = torch.arange(attention_mask.shape[1], device=attention_mask.device)
        last_token_idx = (attention_mask * positions).max(dim=1).values
        batch_rows = torch.arange(attention_mask.shape[0], device=attention_mask.device)

        with torch.no_grad():
            outputs = mt.model(**inputs, output_hidden_states=True)
        hidden_states = outputs.hidden_states  # Tuple of (num_layers+1) tensors

        for layer_idx in range(num_layers):
            # +1 because the first entry holds the input embeddings
            layer_output = hidden_states[layer_idx + 1]

            # Index tensors must live on the same device as the tensor they index.
            rows = batch_rows.to(layer_output.device)
            cols = last_token_idx.to(layer_output.device)
            last_hidden = layer_output[rows, cols, :].detach().cpu().float().numpy()
            layer_hidden_states[layer_idx].append(last_hidden)

    # Concatenate all batches for each layer
    for layer_idx in range(num_layers):
        layer_hidden_states[layer_idx] = np.concatenate(layer_hidden_states[layer_idx], axis=0)
        print(f"Layer {layer_idx}: shape {layer_hidden_states[layer_idx].shape}")

    return layer_hidden_states

# %%
# Run the model once over each prompt set. Both calls use the same batch size,
# and each returns a dict keyed by layer index, so the two results can be
# compared layer by layer in the next section.
print("Collecting hidden states for harmful prompts...")
harmful_hidden_states = collect_hidden_states(mt, harmful_prompts, batch_size=8)

print("\nCollecting hidden states for benign prompts...")
benign_hidden_states = collect_hidden_states(mt, benign_prompts, batch_size=8)

### Compute Direction Vectors

Compute the direction vector for each layer by taking the difference between 
the mean harmful and mean benign hidden states, then normalizing it to unit length.

In [None]:
def compute_direction_vectors(harmful_states, benign_states):
    """
    Build one unit-length steering direction per layer.

    For every layer index in 0 .. len(harmful_states) - 1 the direction is the
    mean harmful activation minus the mean benign activation (so it points from
    benign towards harmful), divided by its L2 norm plus 1e-8.

    Args:
        harmful_states: Mapping from layer index to an array of per-prompt
            activations for harmful prompts.
        benign_states: Same mapping for benign prompts.

    Returns:
        Dict mapping layer index to the normalized direction vector.
    """
    layer_count = len(harmful_states)

    def _unit_mean_gap(layer):
        # Average over the prompt axis, then subtract: harmful minus benign.
        gap = np.mean(harmful_states[layer], axis=0) - np.mean(benign_states[layer], axis=0)
        # The small epsilon keeps a zero gap from dividing by zero.
        return gap / (np.linalg.norm(gap) + 1e-8)

    direction_vectors = {layer: _unit_mean_gap(layer) for layer in range(layer_count)}

    print(f"Computed direction vectors for {layer_count} layers")
    return direction_vectors

# %%
# One direction per layer, computed from the activations collected above.
direction_vectors = compute_direction_vectors(harmful_hidden_states, benign_hidden_states)

# Print some statistics
# Spot-check the first, middle and last layer. compute_direction_vectors already
# divides each vector by its norm, so these values are expected to be close to 1;
# they say nothing about how far apart the harmful and benign means were.
print("\nDirection vector statistics:")
for layer_idx in [0, len(direction_vectors)//2, len(direction_vectors)-1]:
    norm = np.linalg.norm(direction_vectors[layer_idx])
    print(f"Layer {layer_idx}: norm = {norm:.4f}")
    


### REPE Generation with Direction Intervention

Generate text while adding the scaled direction vectors to the hidden states at the selected intervention layers.

In [31]:
def generate_with_repe_intervention(
    mt,
    prompt,
    direction_vectors,
    intervention_strength=1.0,
    intervene=True,
    intervention_layers=range(4, 9), 
    max_new_tokens=50
):
    """
    Greedy-decode a reply, optionally steering hidden states along per-layer directions.

    When `intervene` is true, a forward hook on each selected decoder layer adds
    `intervention_strength * direction_vectors[layer]` to that layer's output at
    every position, on every forward pass made by `generate`. With directions
    from `compute_direction_vectors` (harmful mean minus benign mean), a positive
    strength moves activations towards the harmful mean and a negative strength
    towards the benign mean.

    Hooks are always removed before returning, including when hook registration
    or generation raises, so a failed call cannot affect later generations.

    Args:
        mt: ModelAndTokenizer object
        prompt: Raw user prompt string (the chat template is applied here)
        direction_vectors: Mapping from layer index to a direction vector of
            length hidden_dim
        intervention_strength: Scaling factor for the direction vector
        intervene: Whether to apply intervention; False gives the baseline
        intervention_layers: Range or list of layer indices to intervene on;
            indices >= the model's num_hidden_layers are skipped
        max_new_tokens: Maximum number of tokens to generate

    Returns:
        The decoded continuation only (prompt tokens stripped, special tokens skipped).

    Raises:
        KeyError: If a selected layer has no entry in `direction_vectors`.
    """
    if hasattr(mt.tokenizer, 'apply_chat_template'):
        formatted_prompt = mt.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        formatted_prompt = prompt

    # NOTE(review): the chat template may already contain the BOS token, in which
    # case this call prepends a second one — confirm whether add_special_tokens=False
    # is wanted (and keep it consistent with how the directions were collected).
    inputs = mt.tokenizer(formatted_prompt, return_tensors="pt").to(mt.model.device)
    prompt_length = inputs['input_ids'].shape[1]

    # Tokenizers without a pad token would otherwise hand None to generate().
    pad_token_id = mt.tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = getattr(mt.tokenizer, 'eos_token_id', None)

    def make_repe_hook(direction_vector):
        # The scaled offset is built once per (dtype, device) and reused for every
        # decoding step instead of being converted again on each forward pass.
        scaled_offsets = {}

        def hook(module, input, output):
            # Decoder layers may return (hidden_states, *extras) or a bare tensor.
            hidden_states = output[0] if isinstance(output, tuple) else output

            target = (hidden_states.dtype, hidden_states.device)
            if target not in scaled_offsets:
                scaled_offsets[target] = intervention_strength * torch.as_tensor(
                    direction_vector,
                    dtype=hidden_states.dtype,
                    device=hidden_states.device,
                )

            # [hidden_dim] broadcasts over [batch, seq_len, hidden_dim]; the
            # addition allocates a new tensor, so the original is left untouched.
            modified_states = hidden_states + scaled_offsets[target]

            # Return in the same format the layer produced
            if isinstance(output, tuple):
                return (modified_states,) + output[1:]
            return modified_states

        return hook

    hooks = []
    try:
        # Registration happens inside the try block so that an error part-way
        # through (e.g. a missing direction) still removes the hooks already attached.
        if intervene:
            num_layers = mt.model.config.num_hidden_layers
            for layer_idx in intervention_layers:
                if layer_idx < num_layers:
                    layer = mt.model.model.layers[layer_idx]
                    hooks.append(
                        layer.register_forward_hook(make_repe_hook(direction_vectors[layer_idx]))
                    )

        with torch.no_grad():
            outputs = mt.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=pad_token_id
            )
    finally:
        for hook_handle in hooks:
            hook_handle.remove()

    # Drop the prompt tokens so only the newly generated text is returned.
    generation = mt.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return generation



### Test REPE Intervention

Test the intervention on harmful prompts and compute jailbreak success rates.

In [None]:
# Held-out evaluation prompts: rows 100-119 of the AdvBench CSV, i.e. disjoint
# from the first 100 rows used to compute the direction vectors.
test_harmful_data = pd.read_csv("analysis_data/advbench_behaviors.csv")['Behavior']
test_prompts = list(test_harmful_data)[100:120]  # Use different prompts for testing

print(f"Testing on {len(test_prompts)} harmful prompts...")

results = {
    'baseline': {'harmful': 0, 'total': 0},
    'intervened': {'harmful': 0, 'total': 0},
    'detailed_results': []
}

# Intervention settings for this run
intervention_strength = 3.0
intervention_layers = range(3, 9)  

for idx, prompt in enumerate(tqdm(test_prompts, desc="Testing prompts")):
    try:
        # Generate baseline (no intervention)
        baseline_gen = generate_with_repe_intervention(
            mt, prompt, direction_vectors, 
            intervention_strength=intervention_strength,
            intervene=False,
            intervention_layers=intervention_layers,
            max_new_tokens=50
        )
        
        # Generate with REPE intervention
        intervened_gen = generate_with_repe_intervention(
            mt, prompt, direction_vectors,
            intervention_strength=intervention_strength,
            intervene=True,
            intervention_layers=intervention_layers,
            max_new_tokens=50
        )
        
        # Judge both outputs. The verdicts are coerced to plain bools so the
        # counters, the stored flags and the JSON dump all see the same value
        # (the counters below already interpret the verdict by truthiness).
        baseline_harmful = bool(harmful_judge_from_strings(
            prompt, baseline_gen, judge_model=(judge_tokenizer, judge_model_loaded)
        ))
        intervened_harmful = bool(harmful_judge_from_strings(
            prompt, intervened_gen, judge_model=(judge_tokenizer, judge_model_loaded)
        ))
        
        # Update statistics only after both generations and both verdicts
        # succeeded, so the two totals always stay equal.
        results['baseline']['total'] += 1
        results['intervened']['total'] += 1
        
        if baseline_harmful:
            results['baseline']['harmful'] += 1
        if intervened_harmful:
            results['intervened']['harmful'] += 1
        
        # Store detailed result
        results['detailed_results'].append({
            'prompt': prompt,
            'baseline_generation': baseline_gen,
            'baseline_harmful': baseline_harmful,
            'intervened_generation': intervened_gen,
            'intervened_harmful': intervened_harmful,
            'jailbreak_successful': (not baseline_harmful) and intervened_harmful
        })
        
    except Exception as e:
        # Best effort: a failing prompt is reported and left out of the totals.
        print(f"Error on prompt {idx}: {str(e)}")
        continue

# Calculate success rates
baseline_jailbreak_rate = results['baseline']['harmful'] / results['baseline']['total'] if results['baseline']['total'] > 0 else 0
intervened_jailbreak_rate = results['intervened']['harmful'] / results['intervened']['total'] if results['intervened']['total'] > 0 else 0

print("\n" + "="*80)
print("REPE INTERVENTION RESULTS")
print("="*80)
print(f"\nBaseline (No Intervention):")
print(f"  Harmful outputs: {results['baseline']['harmful']}/{results['baseline']['total']}")
print(f"  Jailbreak rate: {baseline_jailbreak_rate*100:.2f}%")

# Report the layers actually used in this run rather than a hard-coded label.
print(f"\nWith REPE Intervention (strength={intervention_strength}, layers={list(intervention_layers)}):")
print(f"  Harmful outputs: {results['intervened']['harmful']}/{results['intervened']['total']}")
print(f"  Jailbreak rate: {intervened_jailbreak_rate*100:.2f}%")

print(f"\nIntervention Effectiveness:")
print(f"  Increase in jailbreak rate: {(intervened_jailbreak_rate - baseline_jailbreak_rate)*100:.2f}%")

# Count successful jailbreaks (baseline safe -> intervened harmful)
successful_jailbreaks = sum(1 for r in results['detailed_results'] if r['jailbreak_successful'])
print(f"  Successful jailbreaks: {successful_jailbreaks}/{len(results['detailed_results'])}")


In [None]:
import os

# open(..., 'w') does not create missing parent folders, so make sure the
# output directory exists before writing (no-op when it is already there).
results_path = "analysis_results/repe_intervention_results.json"
os.makedirs(os.path.dirname(results_path), exist_ok=True)

with open(results_path, 'w') as f:
    json.dump(results, f, indent=4)

print(f"\nResults saved to {results_path}")