In [1]:
from utils.modelUtils import *
from utils.utils import *
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
from casper import nethook
from casper.nethook import *
from utils.judgeUtils import *
from tqdm import tqdm
from scipy.stats import zscore

### Model Initialization

Loading Llama-2-7b-chat model for neuron intervention analysis.


In [None]:
model_name = "meta-llama/Llama-2-7b-chat-hf"

# Half precision is only requested for checkpoints whose name contains "20b".
# For this 7b model the selection yields None, so ModelAndTokenizer gets torch_dtype=None.
# NOTE(review): None presumably means default (fp32) weights, ~4x the memory of fp16 — confirm.
model_dtype = None
if "20b" in model_name:
    model_dtype = torch.float16

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=model_dtype,
    device='cuda:0',
)
mt.model  # show the module tree as the cell output

In [None]:
# Sanity check: build the input string for a harmless request with the project's
# generate_input helper and display it, so the prompt formatting can be inspected.
# NOTE(review): generate_input comes from a star import (utils.*); presumably it applies
# the chat formatting — confirm in utils.
test_prompt = generate_input(mt.tokenizer, 'tell me a fun joke')
test_prompt

### Judge Model Loading

Loading the judge model to evaluate whether responses successfully jailbreak the safety mechanisms.

In [None]:
# Load the judge used later to label generations as harmful or not.
# load_judge_model() returns a 2-tuple, unpacked here as (tokenizer, model). The pair is
# later passed together to harmful_judge_from_strings as judge_model=(...).
print("Loading judge model...")
judge_tokenizer, judge_model_loaded = load_judge_model()
print("Judge model loaded successfully")


### Data Loading

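Load harmful prompts (the first 100 AdvBench behaviors, label 1) and benign prompts (the first 100 Alpaca instructions, label 0) for training the neuron-level safety probes.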


In [None]:
import json
import pandas as pd

# Jailbreak prompt collections; only the first 3 entries of each are kept.
# NOTE(review): gcgs/autodan are not used by the cells below — confirm they are still needed.
with open("analysis_data/llama2_gcgs.json", 'r') as f:
    gcgs = json.load(f)[:3]
with open("analysis_data/llama2_autodan.json", 'r') as f:
    autodan = json.load(f)[:3]

# Harmful side of the probe training set: the first 100 AdvBench behaviors.
harmful_data = pd.read_csv("analysis_data/advbench_behaviors.csv")['Behavior']
harmful_questions = harmful_data.tolist()[:100]

# Safe side: the first 100 Alpaca instructions.
with open("analysis_data/alpaca_clean.json", 'r') as f:
    alpaca = json.load(f)
safe_questions = [record['instruction'] for record in alpaca][:100]

questions = harmful_questions + safe_questions
# Probe targets aligned index-for-index with `questions`: 1 = harmful, 0 = safe.
labels = [1 for _ in harmful_questions] + [0 for _ in safe_questions]

print(f"Total questions: {len(questions)}")
print(f"Harmful: {sum(labels)}, Safe: {len(labels) - sum(labels)}")

In [6]:
# Wrap every raw question in the tokenizer's chat template as a single user turn.
# The result is kept as text (tokenize=False) and tokenized later in batches.
prompts = [
    mt.tokenizer.apply_chat_template(
        [{"role": "user", "content": question}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for question in questions
]

### Activation Collection Functions

In [7]:
activations = {}
device = 'cuda'


def _masked_max_over_sequence(output, attention_mask):
    """Max-pool `output` over dim 1, ignoring positions where `attention_mask` is 0.

    Masking is applied only when the first two dims of `output` match the
    (batch, seq) shape of the mask. Otherwise this falls back to a plain max
    over dim 1.
    """
    if output.dim() >= 2 and tuple(output.shape[:2]) == tuple(attention_mask.shape):
        pad = attention_mask.to(output.device).eq(0)
        # Broadcast the (batch, seq) mask across any trailing feature dims.
        pad = pad.view(*pad.shape, *([1] * (output.dim() - 2)))
        output = output.masked_fill(pad, float("-inf"))
    return output.max(dim=1)[0]


def collect_activations_nethook(mt, prompts, layer_names, batch_size=8, device=None):
    """
    Collect one pooled activation vector per prompt for each traced module.

    Prompts are tokenized in padded batches and run through `mt.model` under
    nethook.TraceDict. Each traced output is max-pooled over the sequence
    dimension using only real (attention_mask == 1) positions. This keeps
    padding tokens, whose count depends on prompt length, out of the features.

    Args:
        mt: object exposing `.model` and `.tokenizer` (ModelAndTokenizer here).
        prompts: list of already-formatted prompt strings.
        layer_names: module names, as listed by `mt.model.named_modules()`, to trace.
        batch_size: number of prompts per forward pass.
        device: device for the tokenized inputs; defaults to `mt.model.device`.

    Returns:
        dict mapping each layer name to a float32 numpy array with one row per
        prompt, in input order. Assumes traced outputs are shaped
        (batch, seq, features) — TODO confirm for non-MLP modules.
    """
    if device is None:
        device = mt.model.device

    all_activations = {name: [] for name in layer_names}
    total_batches = (len(prompts) + batch_size - 1) // batch_size

    for start in tqdm(range(0, len(prompts), batch_size), total=total_batches, desc="Collecting activations"):
        batch_prompts = prompts[start:start + batch_size]
        # NOTE(review): chat-templated prompts may already begin with the BOS token, and
        # the tokenizer adds special tokens again here — confirm this is intended.
        inputs = mt.tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(device)
        attention_mask = inputs["attention_mask"]

        with torch.no_grad():
            with nethook.TraceDict(mt.model, layer_names, retain_output=True) as traces:
                _ = mt.model(**inputs)

                for layer_name in layer_names:
                    output = traces[layer_name].output
                    # Some modules return tuples; the hidden states are the first element.
                    if isinstance(output, tuple):
                        output = output[0]

                    pooled = _masked_max_over_sequence(output, attention_mask)
                    all_activations[layer_name].append(pooled.detach().cpu().float().numpy())

    # Stack per-batch arrays into one (n_prompts, ...) array per layer.
    for layer_name in layer_names:
        all_activations[layer_name] = np.concatenate(all_activations[layer_name], axis=0)
        print(f"Layer {layer_name}: shape {all_activations[layer_name].shape}")

    return all_activations

In [None]:
# Collect activations using nethook.
# Trace the modules named gate_proj and up_proj (the MLP projections of each decoder layer).
# The last component of the module path is matched exactly. A plain substring test for
# "gate"/"up" would silently also pick up any module whose name merely contains those letters.
TARGET_MODULE_SUFFIXES = ("gate_proj", "up_proj")
target_layers = [
    name
    for name, _ in mt.model.named_modules()
    if name.rsplit(".", 1)[-1].lower() in TARGET_MODULE_SUFFIXES
]
assert target_layers, f"No modules ending in {TARGET_MODULE_SUFFIXES} found in mt.model"

print(f"Found {len(target_layers)} target layers")
print(f"First 5 layers: {target_layers[:5]}")

activations = collect_activations_nethook(
    mt,
    prompts,
    target_layers,
    batch_size=8
)

### Compute Toxic Neurons

In [9]:
def safety_probe(activations_tensor, labels_tensor, device='cuda:0', num_runs=1, lr=0.01, epochs=100):
    """
    Fit a linear probe (one logit, BCE-with-logits loss) that predicts the label
    from activations, and return its weight vector.

    Each run trains a freshly initialized probe with full-batch Adam for `epochs`
    steps. The weight vectors of the `num_runs` runs are averaged.

    Args:
        activations_tensor: 2-D float tensor (n_samples, n_features) on `device`.
        labels_tensor: float 0/1 targets shaped like the probe output (n_samples, 1).
        device: device the probe is created on.
        num_runs: number of independently initialized probes to average.
        lr: Adam learning rate.
        epochs: optimisation steps per run.

    Returns:
        numpy array of shape (n_features,) holding the mean probe weight per feature.
    """
    n_features = activations_tensor.shape[1]
    loss_fn = torch.nn.BCEWithLogitsLoss()

    def _train_once():
        linear = torch.nn.Linear(n_features, 1).to(device)
        opt = torch.optim.Adam(linear.parameters(), lr=lr)
        for _ in range(epochs):
            opt.zero_grad()
            loss_fn(linear(activations_tensor), labels_tensor).backward()
            opt.step()
        return linear.weight.detach().cpu().numpy().flatten()

    run_weights = [_train_once() for _ in range(num_runs)]
    return np.mean(run_weights, axis=0)

In [None]:
safety_neuron_threshold = 3  # z-score cut-off, in standard deviations of the layer's probe weights
toxic_neurons_dict = {}
weights_dict = {}

labels_tensor = torch.tensor(labels, dtype=torch.float32).unsqueeze(1).to(device)
print(f"Labels tensor shape: {labels_tensor.shape}")

print("\nComputing toxic neurons for each layer...")
for layer_name, act_matrix in tqdm(activations.items(), desc="Processing layers"):

    activations_tensor = torch.tensor(act_matrix, dtype=torch.float32).to(device)

    # Verify shapes match
    assert activations_tensor.shape[0] == labels_tensor.shape[0], \
        f"Shape mismatch: activations {activations_tensor.shape[0]} vs labels {labels_tensor.shape[0]}"

    # Train safety probe
    weights = safety_probe(activations_tensor, labels_tensor, device=device, num_runs=1)
    weights_dict[layer_name] = weights

    # A "toxic" neuron is an outlier on the positive side: more than
    # `safety_neuron_threshold` standard deviations ABOVE the layer mean, with a
    # positive weight (it pushes the probe towards label 1 = harmful).
    # Testing |z| instead would also admit positive weights far BELOW the mean
    # whenever the layer mean itself is positive.
    # If all weights are equal, zscore returns NaN and no neuron is selected.
    z_scores = zscore(weights)
    toxic_neurons = np.where((z_scores > safety_neuron_threshold) & (weights > 0))[0]
    toxic_neurons_dict[layer_name] = toxic_neurons


total_toxic = sum(len(neurons) for neurons in toxic_neurons_dict.values())
print(f"\nTotal toxic neurons across all layers: {total_toxic}")

### Toxic Neuron Intervention 


In [11]:
def generate_with_neuron_intervention(
    mt,
    prompt,
    toxic_neurons_dict,
    intervene=True,
    max_new_tokens=50
):
    """
    Greedily generate a continuation for `prompt`, optionally zeroing selected neurons.

    When `intervene` is True, a forward hook is attached to every module named in
    `toxic_neurons_dict` that lists at least one neuron. The hook sets those output
    channels (last dimension) to 0 on every forward pass during generation. Hooks
    are registered inside the try block, so they are always removed, even if
    registration or generation raises.

    Args:
        mt: object exposing `.model` (with `.generate`, `.get_submodule`, `.device`)
            and `.tokenizer`.
        prompt: raw user message. It is wrapped in the chat template when the
            tokenizer provides `apply_chat_template`.
        toxic_neurons_dict: maps module name -> sequence of neuron indices to zero.
        intervene: if False, generate without any hooks (baseline).
        max_new_tokens: generation budget.

    Returns:
        The decoded newly generated text (prompt tokens excluded, special tokens skipped).
    """
    if hasattr(mt.tokenizer, 'apply_chat_template'):
        formatted_prompt = mt.tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True
        )
    else:
        formatted_prompt = prompt

    inputs = mt.tokenizer(formatted_prompt, return_tensors="pt").to(mt.model.device)

    def make_zeroing_hook(neuron_indices):
        """Build a forward hook that returns a copy of the output with `neuron_indices` zeroed."""
        index_by_device = {}  # convert the indices to a tensor once per device

        def hook(module, module_inputs, output):
            hidden = output[0] if isinstance(output, tuple) else output
            idx = index_by_device.get(hidden.device)
            if idx is None:
                idx = torch.as_tensor(neuron_indices, dtype=torch.long, device=hidden.device)
                index_by_device[hidden.device] = idx
            # Clone so the module's original output tensor is never modified in place.
            patched = hidden.clone()
            patched[..., idx] = 0
            if isinstance(output, tuple):
                return (patched,) + tuple(output[1:])
            return patched

        return hook

    hooks = []
    try:
        if intervene:
            for layer_name, neuron_indices in toxic_neurons_dict.items():
                # Layers with nothing to ablate would only pay for a clone; skip them.
                if len(neuron_indices) == 0:
                    continue
                module = mt.model.get_submodule(layer_name)
                hooks.append(module.register_forward_hook(make_zeroing_hook(neuron_indices)))

        with torch.no_grad():
            outputs = mt.model.generate(
                inputs['input_ids'],
                attention_mask=inputs['attention_mask'],
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=mt.tokenizer.pad_token_id
            )
        prompt_len = inputs['input_ids'].shape[1]
        return mt.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    finally:
        # Leaked hooks would silently alter every later generation, including baselines.
        for hook_handle in hooks:
            hook_handle.remove()

In [None]:
# Held-out evaluation set: AdvBench rows 100-149 (the probes were fit on rows 0-99).
test_harmful_data = pd.read_csv("analysis_data/advbench_behaviors.csv")['Behavior']
test_prompts = list(test_harmful_data)[100:150]  # Use different prompts for testing

print(f"Testing on {len(test_prompts)} harmful prompts...")

results = {
    'baseline': {'harmful': 0, 'total': 0},
    'intervened': {'harmful': 0, 'total': 0},
    'detailed_results': [],
    # Prompts that raised during generation/judging, kept so failures show up in the saved JSON.
    'errors': [],
}
judge_bundle = (judge_tokenizer, judge_model_loaded)

for idx, prompt in enumerate(tqdm(test_prompts, desc="Testing prompts")):
    try:
        # Same prompt, greedy decoding, without and with the toxic-neuron ablation.
        baseline_gen = generate_with_neuron_intervention(
            mt, prompt, toxic_neurons_dict, intervene=False, max_new_tokens=50
        )
        intervened_gen = generate_with_neuron_intervention(
            mt, prompt, toxic_neurons_dict, intervene=True, max_new_tokens=50
        )

        # Judge both outputs
        baseline_harmful = harmful_judge_from_strings(
            prompt, baseline_gen, judge_model=judge_bundle
        )
        intervened_harmful = harmful_judge_from_strings(
            prompt, intervened_gen, judge_model=judge_bundle
        )
    except Exception as e:
        # Best-effort evaluation: skip this prompt, but record why instead of only printing.
        print(f"Error on prompt {idx}: {str(e)}")
        results['errors'].append({'index': idx, 'prompt': prompt, 'error': repr(e)})
        continue

    # Both arms are updated together, so their totals always cover the same prompts.
    for arm, is_harmful in (('baseline', baseline_harmful), ('intervened', intervened_harmful)):
        results[arm]['total'] += 1
        if is_harmful:
            results[arm]['harmful'] += 1

    results['detailed_results'].append({
        'prompt': prompt,
        'baseline_generation': baseline_gen,
        'baseline_harmful': baseline_harmful,
        'intervened_generation': intervened_gen,
        'intervened_harmful': intervened_harmful,
        'intervention_successful': baseline_harmful and not intervened_harmful
    })


def _harmful_rate(stats):
    """Fraction of outputs judged harmful, or NaN if no prompt was evaluated."""
    return stats['harmful'] / stats['total'] if stats['total'] else float('nan')


if results['baseline']['total'] == 0:
    print("WARNING: no prompt was evaluated successfully; rates below are NaN.")

baseline_jailbreak_rate = _harmful_rate(results['baseline'])
intervened_jailbreak_rate = _harmful_rate(results['intervened'])

print("\n" + "="*80)
print("NEURON INTERVENTION RESULTS")
print("="*80)
print(f"\nBaseline (No Intervention):")
print(f"  Harmful outputs: {results['baseline']['harmful']}/{results['baseline']['total']}")
print(f"  Jailbreak rate: {baseline_jailbreak_rate*100:.2f}%")

print(f"\nWith Neuron Intervention:")
print(f"  Harmful outputs: {results['intervened']['harmful']}/{results['intervened']['total']}")
print(f"  Jailbreak rate: {intervened_jailbreak_rate*100:.2f}%")

print(f"\nIntervention Effectiveness:")
# Difference of two percentages, i.e. an absolute change in percentage points.
print(f"  Reduction in jailbreak rate: {(baseline_jailbreak_rate - intervened_jailbreak_rate)*100:.2f} percentage points")
print(f"  Total toxic neurons intervened: {total_toxic}")
print(f"  Prompts skipped due to errors: {len(results['errors'])}")

In [13]:
import os

# Persist aggregate counts and per-prompt generations for later inspection.
# open(..., 'w') does not create parent directories, so make sure the directory exists.
# NOTE(review): if harmful_judge_from_strings returns numpy bools, json.dump will reject them — confirm.
os.makedirs("analysis_results", exist_ok=True)
with open("analysis_results/neuron_intervention_results.json", 'w') as f:
    json.dump(results, f, indent=4)

### Visualize Toxic Neuron Positions

In [14]:
def visualize_toxic_neuron_positions(toxic_neurons_dict, save_path="toxic_neuron_positions.png"):
    """
    Plot where toxic neurons sit across layers and print a short summary.

    The layer index is the first purely numeric component of each module name
    (e.g. 'model.layers.12.mlp.up_proj' -> 12); modules without one are skipped.
    Modules of the same layer (e.g. gate_proj and up_proj) share that index, so
    their neurons are pooled in the plots and counts.

    Args:
        toxic_neurons_dict: maps module name -> iterable of neuron indices.
        save_path: file the figure is written to (300 dpi).

    Returns:
        DataFrame with one row per toxic neuron and columns
        ['layer', 'neuron', 'layer_name']. It is empty, and no figure is drawn,
        when there is nothing to plot.
    """
    records = []
    for layer_name, neuron_indices in toxic_neurons_dict.items():
        layer_num = next((int(part) for part in layer_name.split('.') if part.isdigit()), None)
        if layer_num is None:
            continue
        records.extend(
            {'layer': layer_num, 'neuron': int(neuron_idx), 'layer_name': layer_name}
            for neuron_idx in neuron_indices
        )

    # Explicit columns keep the schema stable even when `records` is empty.
    df = pd.DataFrame(records, columns=['layer', 'neuron', 'layer_name'])
    if df.empty:
        # Without this guard, the plotting and summary code below fails on empty input
        # (missing columns, division by zero in the per-layer average).
        print("No toxic neurons to visualize.")
        return df

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 10))

    # Scatter plot of neuron positions
    ax1.scatter(df['layer'], df['neuron'], alpha=0.6, s=20, c='red')
    ax1.set_xlabel('Layer Index', fontsize=12)
    ax1.set_ylabel('Neuron Index', fontsize=12)
    ax1.set_title('Toxic Neuron Positions Across Layers', fontsize=14, fontweight='bold')
    ax1.grid(True, alpha=0.3)

    # Histogram of toxic neurons per layer
    layer_counts = df.groupby('layer').size()
    ax2.bar(layer_counts.index, layer_counts.values, color='steelblue', alpha=0.7)
    ax2.set_xlabel('Layer Index', fontsize=12)
    ax2.set_ylabel('Number of Toxic Neurons', fontsize=12)
    ax2.set_title('Distribution of Toxic Neurons Across Layers', fontsize=14, fontweight='bold')
    ax2.grid(axis='y', alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    # Print summary
    print("\n" + "="*80)
    print("TOXIC NEURON POSITION SUMMARY")
    print("="*80)
    print(f"\nTotal toxic neurons: {len(df)}")
    print(f"Layers with toxic neurons: {df['layer'].nunique()}")
    print(f"Average toxic neurons per layer: {len(df) / df['layer'].nunique():.2f}")

    print(f"\nTop 5 layers with most toxic neurons:")
    for layer, count in layer_counts.nlargest(5).items():
        print(f"  Layer {layer}: {count} toxic neurons")

    return df

In [None]:
# Plot layer vs. neuron index for every flagged neuron, and keep the per-neuron table
# that visualize_toxic_neuron_positions returns for further inspection.
toxic_neuron_df = visualize_toxic_neuron_positions(toxic_neurons_dict)
