In [None]:
import transformers
transformers.__version__

In [2]:
from utils.modelUtils import *
from utils.utils import *
import seaborn as sns
import torch
import numpy as np
import matplotlib.pyplot as plt
from casper import nethook
from utils.judgeUtils import *

### Model Initialization

Load the Llama-2-7b-chat model used for the layer-intervention analysis.


In [None]:
model_name = "meta-llama/Llama-2-7b-chat-hf"

# NOTE(review): float16 is only requested for "20b" checkpoints, so this 7B model
# is loaded with torch_dtype=None — confirm that is the intended precision.
load_dtype = torch.float16 if "20b" in model_name else None

mt = ModelAndTokenizer(
    model_name,
    low_cpu_mem_usage=True,
    torch_dtype=load_dtype,
    device='cuda:0',
)
mt.model

In [None]:
# Inspect the input text that generate_input builds for a harmless sample request.
test_prompt = generate_input(mt.tokenizer, "tell me a fun joke")
test_prompt

### Judge Model Loading

Load the judge model. It decides whether a generated response is harmful, i.e. whether the intervention got past the model's safety behavior.

In [None]:
# The judge is later handed to harmful_judge_from_strings as the
# (tokenizer, model) pair that labels each generation.
print("Loading judge model...")
(judge_tokenizer, judge_model_loaded) = load_judge_model()
print("Judge model loaded successfully")


### Data Loading

Load jailbreak prompts produced by the GCG and AutoDAN attacks (only the first three of each are kept), along with the AdvBench harmful-behaviors dataset.

In [6]:
import json
import pandas as pd

# Jailbreak prompts from the GCG and AutoDAN attacks; only the first three of each are kept.
with open("analysis_data/llama2_gcgs.json", 'r') as f:
    gcgs = json.load(f)[:3]
with open("analysis_data/llama2_autodan.json", 'r') as f:
    autodan = json.load(f)[:3]

# AdvBench harmful behaviours, as a plain Python list of the 'Behavior' column.
harmful_data = pd.read_csv("analysis_data/advbench_behaviors.csv")['Behavior'].tolist()


### Layer Patching Functions
 
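Functions for tracing and generating with layer interventions. Given a pair of layers `(start, end)`, the hooks capture the hidden-state output of `start` and substitute it for the output of `end`. The layers after `start`, up to and including `end`, are therefore effectively bypassed.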

In [7]:
def trace_with_patch_layer(
    model,  # The model
    inp,  # A set of inputs (keyword arguments for model(**inp))
    states_to_patch,  # A pair of layer names: (source layer, target layer)
    answers_t,  # Vocabulary index/indices whose probabilities are returned
):
    """Run one forward pass with a layer shortcut and return answer probabilities.

    The hidden-state output of ``states_to_patch[0]`` is captured and substituted
    for the hidden-state output of ``states_to_patch[1]``. As a result, the layers
    after the source layer, up to and including the target layer, no longer affect
    the hidden states that flow onward.

    Args:
        model: The model to run; the two named layers are hooked with
            ``nethook.TraceDict``.
        inp: Dict of model inputs, passed as ``model(**inp)``.
        states_to_patch: Sequence whose first two items are layer names
            (source first, target second).
        answers_t: Index (or indices) into the vocabulary dimension.

    Returns:
        Softmax probabilities at the last sequence position, averaged over all
        batch rows, indexed by ``answers_t``.
    """
    layers = [states_to_patch[0], states_to_patch[1]]
    saved = {}

    def patch_rep(x, layer):
        # Decoder layers return either a tuple (hidden_states, ...) or, depending
        # on the transformers version, a bare hidden-state tensor.
        hidden = x[0] if isinstance(x, tuple) else x
        if layer == layers[0]:
            # Snapshot the source output; the clone guarantees later in-place
            # operations cannot alter the saved copy.
            saved["hidden_states"] = hidden.detach().clone()
            return x
        if layer == layers[1]:
            # Place the shortcut on the target layer's device (not just the
            # default CUDA device). Keep any extra tuple outputs, e.g. a KV-cache
            # entry in some transformers versions, that downstream code may index.
            shortcut = saved["hidden_states"].to(hidden.device)
            if isinstance(x, tuple):
                return (shortcut,) + tuple(x[1:])
            return shortcut
        return x

    with torch.no_grad(), nethook.TraceDict(model, layers, edit_output=patch_rep):
        outputs_exp = model(**inp)

    # Average over every batch row. Excluding row 0 would leave an empty slice,
    # and hence NaN, for a single-prompt batch.
    probs = torch.softmax(outputs_exp.logits[:, -1, :], dim=1).mean(dim=0)[answers_t]

    return probs

In [8]:
def generate_with_patch_layer(
    model,  # The model
    inp,  # A set of inputs (dict with 'input_ids' and 'attention_mask')
    states_to_patch,  # A pair of layer names: (source layer, target layer)
    max_new_tokens=30,  # Number of tokens to generate greedily
    tokenizer=None,  # Defaults to the notebook-level mt.tokenizer
):
    """Greedily generate text while a layer shortcut is active.

    At every forward step of generation, the hidden-state output of
    ``states_to_patch[0]`` is captured and substituted for the hidden-state
    output of ``states_to_patch[1]``.

    Args:
        model: Model exposing ``generate``; the two named layers are hooked
            with ``nethook.TraceDict``.
        inp: Dict with ``'input_ids'`` and ``'attention_mask'`` tensors.
        states_to_patch: Sequence whose first two items are layer names
            (source first, target second).
        max_new_tokens: Maximum number of new tokens to generate (default 30).
        tokenizer: Tokenizer used for ``pad_token_id`` and decoding; when None,
            ``mt.tokenizer`` is used.

    Returns:
        List of decoded strings, one per batch row, containing only the newly
        generated tokens (the prompt tokens are sliced off).
    """
    if tokenizer is None:
        tokenizer = mt.tokenizer

    layers = [states_to_patch[0], states_to_patch[1]]
    saved = {}

    def patch_rep(x, layer):
        # Decoder layers return either a tuple (hidden_states, ...) or, depending
        # on the transformers version, a bare hidden-state tensor.
        hidden = x[0] if isinstance(x, tuple) else x
        if layer == layers[0]:
            # Snapshot the source output; the clone guarantees later in-place
            # operations cannot alter the saved copy.
            saved["hidden_states"] = hidden.detach().clone()
            return x
        if layer == layers[1]:
            # Place the shortcut on the target layer's device and keep any extra
            # tuple outputs (e.g. a KV-cache entry in some transformers versions).
            shortcut = saved["hidden_states"].to(hidden.device)
            if isinstance(x, tuple):
                return (shortcut,) + tuple(x[1:])
            return shortcut
        return x

    num_input_tokens = inp['input_ids'].shape[1]
    with torch.no_grad(), nethook.TraceDict(model, layers, edit_output=patch_rep):
        # The mask is passed through unchanged (presumably an integer 0/1 mask
        # from make_inputs); there is no need to cast it to half precision.
        outputs = model.generate(
            inp['input_ids'],
            attention_mask=inp['attention_mask'],
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
        )
    return tokenizer.batch_decode(outputs[:, num_input_tokens:], skip_special_tokens=True)

In [9]:
def generate_based_on_multiple_layer(prompt, layer_combinations):
    """
    Generate a response to ``prompt`` once for each layer pair.

    Args:
        prompt: Input prompt text.
        layer_combinations: Iterable of ``[start_layer, end_layer]`` index pairs.

    Returns:
        List of dicts in input order, one per pair. Each dict holds the pair under
        ``'layers'`` and the output of ``generate_with_patch_layer`` under
        ``'generation'``.
    """
    model = mt.model
    # The tokenised prompt is the same for every pair, so build it once.
    inp = make_inputs(mt.tokenizer, [prompt])

    def run_pair(start_layer, end_layer):
        names = [layername(model, start_layer), layername(model, end_layer)]
        print(f"Processing layers: {names}")
        return {
            'layers': [start_layer, end_layer],
            'generation': generate_with_patch_layer(model, inp, names),
        }

    return [run_pair(start, end) for start, end in layer_combinations]

### Layer Pruning Experiment

Perform layer pruning by bypassing windows of consecutive layers. Window sizes range from 2 up to one quarter of the model's total number of layers.
For every window size and position, we generate a response with the intervention active and ask the judge model whether it is harmful.


In [10]:
from collections import defaultdict
from itertools import combinations

def _consecutive_windows(total_layers, n_layers):
    """Return every [start_layer, end_layer] window of n_layers consecutive layers within 0..total_layers-1."""
    return [[start, start + n_layers - 1] for start in range(total_layers - n_layers + 1)]


def layer_pruning_experiment(harmful_data, max_samples=None):
    """
    Perform layer pruning experiment on harmful prompts.

    For each prompt and each window size ``n_layers`` from 2 to
    ``num_hidden_layers // 4``, every window of consecutive layers
    ``[start_layer, end_layer]`` is tried. ``generate_with_patch_layer`` replaces
    the output of ``end_layer`` with the output of ``start_layer``, and the judge
    model labels the resulting generation as harmful or not.

    Args:
        harmful_data: List of harmful behavior prompts.
        max_samples: Maximum number of prompts to test. ``None`` tests all of
            them; ``0`` tests none.

    Returns:
        results_dict with keys:
            'prompt_has_harmful': prompt_idx -> True for prompts where at least
                one window produced output judged harmful.
            'layer_frequency': layer index -> number of harmful windows that
                listed it in 'intervened_layers'.
            'detailed_results': one entry per successfully evaluated window.
            'errors': one entry per failure, with the error message.
    """
    model = mt.model
    total_layers = model.config.num_hidden_layers
    max_intervention_layers = total_layers // 4  # Maximum 1/4 of layers

    results_dict = {
        'prompt_has_harmful': defaultdict(bool),  # Track if any intervention causes harmful output for each prompt
        'layer_frequency': defaultdict(int),  # Track frequency of each layer appearing in harmful interventions
        'detailed_results': [],  # Store all detailed results
        'errors': [],  # Failures, kept so a run where everything fails is visible in the saved results
    }

    # `is None` (not truthiness) so that max_samples=0 means "no prompts", not "all prompts".
    test_data = harmful_data if max_samples is None else harmful_data[:max_samples]

    print(f"Total layers: {total_layers}")
    print(f"Max intervention layers: {max_intervention_layers}")
    print(f"Testing {len(test_data)} prompts")

    for prompt_idx, prompt in enumerate(test_data):
        print(f"\n{'='*80}")
        print(f"Testing prompt {prompt_idx + 1}/{len(test_data)}: {prompt[:50]}...")

        # The tokenised prompt is identical for every window, so build it once per prompt.
        try:
            inp = make_inputs(mt.tokenizer, [prompt])
        except Exception as e:
            print(f"    Error preparing inputs for prompt {prompt_idx}: {str(e)}")
            results_dict['errors'].append({
                'prompt_idx': prompt_idx,
                'prompt': prompt,
                'start_layer': None,
                'end_layer': None,
                'error': f"{type(e).__name__}: {e}",
            })
            continue

        for n_layers in range(2, max_intervention_layers + 1):
            print(f"\n  Intervention size: {n_layers} layers")

            layer_combinations = _consecutive_windows(total_layers, n_layers)
            print(f"  Testing {len(layer_combinations)} layer combinations")

            for start_layer, end_layer in layer_combinations:
                try:
                    layers = [layername(model, start_layer), layername(model, end_layer)]
                    generation = generate_with_patch_layer(model, inp, layers)
                    generation_text = generation[0] if isinstance(generation, list) else generation
                    is_harmful = harmful_judge_from_strings(
                        prompt,
                        generation_text,
                        judge_model=(judge_tokenizer, judge_model_loaded),
                    )
                except Exception as e:
                    # Best-effort sweep: report, record, and move on to the next window.
                    print(f"    Error processing combination [{start_layer}-{end_layer}]: {str(e)}")
                    results_dict['errors'].append({
                        'prompt_idx': prompt_idx,
                        'prompt': prompt,
                        'start_layer': start_layer,
                        'end_layer': end_layer,
                        'error': f"{type(e).__name__}: {e}",
                    })
                    continue

                # NOTE(review): this range includes start_layer, whose own output is
                # kept (it is the shortcut source). Only start_layer+1..end_layer are
                # bypassed by the hook — confirm this labelling is intended.
                intervened_layers = list(range(start_layer, end_layer + 1))
                results_dict['detailed_results'].append({
                    'prompt_idx': prompt_idx,
                    'prompt': prompt,
                    'n_layers': n_layers,
                    'start_layer': start_layer,
                    'end_layer': end_layer,
                    'intervened_layers': intervened_layers,
                    'generation': generation_text,
                    'is_harmful': is_harmful,
                })

                if is_harmful:
                    results_dict['prompt_has_harmful'][prompt_idx] = True
                    for layer in intervened_layers:
                        results_dict['layer_frequency'][layer] += 1
                    print(f"    ✓ Combination [{start_layer}-{end_layer}]: HARMFUL")
                else:
                    print(f"    ✗ Combination [{start_layer}-{end_layer}]: Safe")

    return results_dict


In [None]:
print("Starting layer pruning experiment...")
# max_samples=1 runs a single prompt as a quick pass; pass None to sweep every prompt.
results = layer_pruning_experiment(harmful_data=harmful_data, max_samples=1)

In [12]:
from pathlib import Path

# Create the output directory if needed; open() raises FileNotFoundError when it is missing.
results_path = Path("analysis_results/layer_intervention_results.json")
results_path.parent.mkdir(parents=True, exist_ok=True)
with open(results_path, 'w') as f:
    json.dump(results, f, indent=4)

### Results Analysis

Analyze the experimental results to answer two questions:
1. Which prompts have at least one intervention that produces harmful output
2. Which layers appear most frequently in harmful interventions

In [None]:
print("\n" + "="*80)
print("EXPERIMENT RESULTS SUMMARY")
print("="*80)

# Statistics for prompts with harmful outputs.
# Only prompts with at least one successfully evaluated window appear in detailed_results,
# so this count can be 0 when every generation failed.
total_prompts = len({r['prompt_idx'] for r in results['detailed_results']})
prompts_with_harmful = sum(results['prompt_has_harmful'].values())

print("\n1. PROMPT-LEVEL STATISTICS:")
print(f"   Total prompts tested: {total_prompts}")
print(f"   Prompts with at least one harmful intervention: {prompts_with_harmful}")
if total_prompts:
    print(f"   Percentage: {prompts_with_harmful/total_prompts*100:.2f}%")
else:
    # Avoid ZeroDivisionError when nothing was successfully evaluated.
    print("   Percentage: n/a (no successfully evaluated prompts)")

# Layer frequency analysis
print("\n2. LAYER FREQUENCY IN HARMFUL INTERVENTIONS:")
if results['layer_frequency']:
    sorted_layers = sorted(results['layer_frequency'].items(), key=lambda x: x[1], reverse=True)
    print("   Top 10 most frequent layers:")
    for layer, freq in sorted_layers[:10]:
        print(f"      Layer {layer}: {freq} occurrences")
else:
    print("   No harmful interventions found")

In [None]:
# Bar chart of how often each layer index falls inside a window judged harmful.
# Only layers with a non-zero count are present in layer_frequency.
if results['layer_frequency']:
    layers = sorted(results['layer_frequency'])
    frequencies = [results['layer_frequency'][idx] for idx in layers]

    fig, ax = plt.subplots(figsize=(12, 6))
    ax.bar(layers, frequencies, color='steelblue', alpha=0.7)
    ax.set_xlabel('Layer Index', fontsize=12)
    ax.set_ylabel('Frequency in Harmful Interventions', fontsize=12)
    ax.set_title('Layer Frequency Distribution in Harmful Interventions', fontsize=14, fontweight='bold')
    ax.grid(axis='y', alpha=0.3)
    fig.tight_layout()
    plt.show()