# Jailbreak Mechanism Analysis: Assistant Axis + GLP

This notebook implements the full research pipeline for analyzing jailbreak mechanisms
in Llama 3.1 8B using the Assistant Axis and Generative Latent Prior (GLP).

## Research Question

**Does jailbreaking work by:**
1. Pushing away from the Assistant Axis (suppressing the "helpful assistant" identity)?
2. Activating specific harmful personas ("criminal", "hacker", etc.)?
3. Both (combination)?
4. Neither (orthogonal mechanism)?
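One way to make these four options measurable, sketched under assumptions: split the mean jailbreak shift Δ (jailbreak minus clean activations) into three parts. The first lies along the unit Assistant Axis. The second lies along a harmful-persona vector that has been orthogonalized against the axis. The third is the residual. The persona vector and the 0.1 threshold below are hypothetical, because the notebook does not compute persona vectors yet.

```python
import numpy as np

def decompose_shift(delta, axis, persona):
    """Split a shift vector into axis, persona, and residual parts.

    Returns signed coefficients along the unit axis and the unit persona
    direction (orthogonalized against the axis), plus squared-norm
    fractions that sum to 1.
    """
    a_hat = axis / np.linalg.norm(axis)
    # Remove the axis component so the axis and persona parts do not overlap
    p_perp = persona - (persona @ a_hat) * a_hat
    p_hat = p_perp / np.linalg.norm(p_perp)
    c_axis = float(delta @ a_hat)
    c_persona = float(delta @ p_hat)
    residual = delta - c_axis * a_hat - c_persona * p_hat
    total = float(delta @ delta)
    return {
        "c_axis": c_axis,
        "c_persona": c_persona,
        "f_axis": c_axis**2 / total,
        "f_persona": c_persona**2 / total,
        "f_residual": float(residual @ residual) / total,
    }


def which_hypothesis(parts, threshold=0.1):
    """Toy decision rule mapping a decomposition to hypotheses 1-4.

    The threshold is an arbitrary illustration, not a calibrated value.
    """
    anti_assistant = parts["c_axis"] < 0 and parts["f_axis"] >= threshold
    harmful_persona = parts["c_persona"] > 0 and parts["f_persona"] >= threshold
    if anti_assistant and harmful_persona:
        return 3
    if anti_assistant:
        return 1
    if harmful_persona:
        return 2
    return 4


# Synthetic example in 8 dimensions: the shift points mostly against the axis
axis_vec = np.eye(8)[0]
persona_vec = np.eye(8)[1] + 0.5 * np.eye(8)[0]
delta = -2.0 * np.eye(8)[0] + 0.1 * np.eye(8)[2]
parts = decompose_shift(delta, axis_vec, persona_vec)
print(which_hypothesis(parts))  # 1
```

Orthogonalizing the persona vector keeps the three fractions additive. Without that step, a persona direction that partly overlaps the axis would be credited twice.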

## Sections
1. Setup & Installation
2. Compute/Load Assistant Axis
3. Data Collection
4. Activation Extraction
5. Projection Analysis
6. GLP Safety Space (H, J, R)
7. Visualization
8. Statistical Analysis
9. Steering Experiments
10. Conclusions

---
## Section 1: Setup & Installation

In [None]:
# Detect environment
import os
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Mount Google Drive for persistent storage
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Clone repositories
    !git clone https://github.com/safety-research/assistant-axis.git /content/assistant-axis
    !git clone https://github.com/YOUR_REPO/generative_latent_prior.git /content/generative_latent_prior
    
    # Install dependencies
    !pip install -q torch transformers accelerate vllm
    !pip install -q diffusers einops omegaconf safetensors
    !pip install -q git+https://github.com/davidbau/baukit
    !pip install -q nanogcg datasets scikit-learn scipy pandas
    !pip install -q matplotlib plotly seaborn huggingface_hub tqdm
    
    # Set paths
    PROJECT_ROOT = '/content'
    THIRD_PARTY_DIR = '/content'
else:
    # Local environment: __file__ is undefined in notebooks, so use the working
    # directory and assume the notebook sits one level below the project root
    PROJECT_ROOT = os.path.dirname(os.getcwd())
    THIRD_PARTY_DIR = os.path.join(PROJECT_ROOT, 'third_party')

print(f"Project root: {PROJECT_ROOT}")
print(f"Third party: {THIRD_PARTY_DIR}")

In [None]:
# Add paths and import modules
import sys
sys.path.insert(0, os.path.join(PROJECT_ROOT, 'src'))
sys.path.insert(0, os.path.join(THIRD_PARTY_DIR, 'assistant-axis'))
sys.path.insert(0, os.path.join(THIRD_PARTY_DIR, 'generative_latent_prior'))

# Core imports
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm

# Check GPU
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
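Use this as a rough guide when reading the VRAM figure above. Model weights alone take (parameters × bytes per parameter). Assume about 8.03 billion parameters for Llama 3.1 8B (an approximate figure) and 2 bytes per bfloat16 value. That comes to roughly 16 GB before counting activations, generation caches, or the GLP model.

```python
def weight_memory_gb(n_params: float, bytes_per_param: int = 2) -> float:
    """Memory for model weights only, in GB (1e9 bytes, matching the VRAM print above)."""
    return n_params * bytes_per_param / 1e9


# ~8.03e9 parameters is an approximate figure for Llama 3.1 8B
print(f"bf16 weights: {weight_memory_gb(8.03e9):.1f} GB")     # bf16 weights: 16.1 GB
print(f"fp32 weights: {weight_memory_gb(8.03e9, 4):.1f} GB")  # fp32 weights: 32.1 GB
```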

In [None]:
# Configuration
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
TARGET_LAYER = 16
TOTAL_LAYERS = 32
GLP_MODEL = "generative-latent-prior/glp-llama8b-d6"

# Output directories
OUTPUT_DIR = Path(PROJECT_ROOT) / "outputs" / "llama-3.1-8b"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

AXIS_PATH = OUTPUT_DIR / "axis.pt"
JAILBREAKS_PATH = OUTPUT_DIR / "jailbreaks" / "jailbreaks.pt"
ACTIVATIONS_PATH = OUTPUT_DIR / "activations"
ACTIVATIONS_PATH.mkdir(parents=True, exist_ok=True)

print(f"Output directory: {OUTPUT_DIR}")

---
## Section 2: Compute/Load Assistant Axis

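The Assistant Axis captures, separately at each layer, the direction from "role-playing character" to "default assistant".
Computing it requires the five-step `assistant-axis` pipeline: response generation for 275 roles, activation extraction, judging, per-role vectors, and the axis itself. Because this is expensive, the next cell loads a cached copy if one exists.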
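The sketch below is for intuition only. Suppose step 5 takes a difference of means per layer: the mean default-assistant activation minus the mean of the per-role vectors. The result is then one direction per layer, which matches how later cells index `axis[TARGET_LAYER]`. This is an assumption about the pipeline, so check `pipeline/5_axis.py` for the actual definition.

```python
import numpy as np

def contrast_axis(assistant_acts, role_vectors):
    """Per-layer difference-of-means direction.

    assistant_acts: (n_samples, n_layers, hidden) default-assistant activations
    role_vectors:   (n_roles, n_layers, hidden) one vector per role
    returns:        (n_layers, hidden), pointing from roles toward the assistant
    """
    return assistant_acts.mean(axis=0) - role_vectors.mean(axis=0)


# Synthetic check: assistant activations are role-like noise plus a fixed offset
rng = np.random.default_rng(0)
n_layers, hidden = 4, 16
offset = rng.standard_normal(hidden)
roles = 0.1 * rng.standard_normal((275, n_layers, hidden))
assistant = 0.1 * rng.standard_normal((500, n_layers, hidden)) + offset
axis_demo = contrast_axis(assistant, roles)
cos = axis_demo[2] @ offset / (np.linalg.norm(axis_demo[2]) * np.linalg.norm(offset))
print(axis_demo.shape, round(float(cos), 3))
```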

In [None]:
# Check if axis exists
if AXIS_PATH.exists():
    print(f"Loading cached axis from {AXIS_PATH}")
    axis = torch.load(AXIS_PATH)
    print(f"Axis shape: {axis.shape}")
else:
    print("No cached axis found. You need to compute it using the pipeline.")
    print("\nTo compute the axis, run these commands:")
    print(f"""  
    cd {THIRD_PARTY_DIR}/assistant-axis
    
    # Step 1: Generate responses for 275 roles
    python pipeline/1_generate.py --model {MODEL_NAME} --output_dir {OUTPUT_DIR}/responses
    
    # Step 2: Extract activations  
    python pipeline/2_activations.py --model {MODEL_NAME} --responses_dir {OUTPUT_DIR}/responses --output_dir {OUTPUT_DIR}/activations
    
    # Step 3: Judge responses (requires OPENAI_API_KEY)
    python pipeline/3_judge.py --responses_dir {OUTPUT_DIR}/responses --output_dir {OUTPUT_DIR}/judgments
    
    # Step 4: Compute per-role vectors
    python pipeline/4_vectors.py --activations_dir {OUTPUT_DIR}/activations --judgments_dir {OUTPUT_DIR}/judgments --output_dir {OUTPUT_DIR}/vectors
    
    # Step 5: Compute assistant axis
    python pipeline/5_axis.py --vectors_dir {OUTPUT_DIR}/vectors --output {AXIS_PATH}
    """)
    axis = None

---
## Section 3: Data Collection

Load datasets for jailbreak analysis:
- JailbreakBench: Curated harmful behaviors
- Jailbreaks: GCG adversarial suffixes, loaded from cache or generated with nanoGCG
- Alpaca: Benign prompts for baseline

In [None]:
from datasets import load_dataset

# Load JailbreakBench behaviors
print("Loading JailbreakBench...")
jbb = load_dataset("JailbreakBench/JBB-Behaviors", "behaviors")
# The "behaviors" config is split into "harmful" and "benign" (there is no "train" split)
behaviors = [{"id": i, "goal": b["Goal"], "target": b["Target"]} for i, b in enumerate(jbb["harmful"])]
print(f"Loaded {len(behaviors)} harmful behaviors")

# Load Alpaca for benign baseline
print("\nLoading Alpaca...")
alpaca = load_dataset("tatsu-lab/alpaca")
# Slicing a Dataset returns a dict of columns, so take the "instruction" column directly
benign_prompts = [text for text in alpaca["train"][:100]["instruction"] if text.strip()]
print(f"Loaded {len(benign_prompts)} benign prompts")

In [None]:
# Sample behaviors for analysis (use subset for faster iteration)
N_BEHAVIORS = 50  # Adjust based on compute budget

sample_behaviors = behaviors[:N_BEHAVIORS]
sample_benign = benign_prompts[:N_BEHAVIORS]

print(f"Using {len(sample_behaviors)} harmful behaviors")
print(f"Using {len(sample_benign)} benign prompts")

# Create prompt lists
clean_harmful_prompts = [b["goal"] for b in sample_behaviors]

In [None]:
# Option 1: Load pre-generated jailbreaks
if JAILBREAKS_PATH.exists():
    print(f"Loading jailbreaks from {JAILBREAKS_PATH}")
    jailbreaks = torch.load(JAILBREAKS_PATH)
    print(f"Loaded {len(jailbreaks)} jailbreaks")
else:
    print("No cached jailbreaks found.")
    print("To generate jailbreaks with nanoGCG, run the cell below.")
    jailbreaks = None

In [None]:
# Option 2: Generate jailbreaks with nanoGCG (expensive - ~30 min per behavior)
# Only run if needed

GENERATE_JAILBREAKS = False  # Set to True to generate

if GENERATE_JAILBREAKS and jailbreaks is None:
    import nanogcg
    from transformers import AutoModelForCausalLM, AutoTokenizer
    
    print("Loading model for GCG...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    
    config = nanogcg.GCGConfig(num_steps=500, seed=42)
    
    jailbreaks = []
    for behavior in tqdm(sample_behaviors[:10]):  # Start with just 10
        result = nanogcg.run(model, tokenizer, behavior["goal"], behavior["target"], config)
        jailbreaks.append({
            "goal": behavior["goal"],
            "target": behavior["target"],
            "suffix": result.best_string,
            "loss": float(result.best_loss),
            # nanoGCG appends the optimized string directly to the message (no space) by default
            "jailbreak_prompt": f"{behavior['goal']}{result.best_string}"
        })
    
    # Save
    JAILBREAKS_PATH.parent.mkdir(parents=True, exist_ok=True)
    torch.save(jailbreaks, JAILBREAKS_PATH)
    print(f"Saved {len(jailbreaks)} jailbreaks to {JAILBREAKS_PATH}")
    
    # Cleanup
    del model
    torch.cuda.empty_cache()
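`nanogcg.run` hides the optimization loop. The sketch below is a toy version of Greedy Coordinate Gradient on a made-up differentiable loss, not nanoGCG's code. The position-specific embedding table `emb` and the squared-distance loss are hypothetical stand-ins for the model and its target-string loss. Each step works in three stages:

1. Use the gradient with respect to the one-hot token choices to shortlist `topk` replacement tokens per position.
2. Evaluate a batch of random single-token swaps with the exact loss.
3. Move to the best swap in the batch.

```python
import numpy as np

def toy_gcg(emb, target, n_steps=50, topk=8, batch=32, seed=0):
    """Greedy Coordinate Gradient on a toy loss.

    emb:    (L, V, d) hypothetical embedding of token v at suffix position i
    target: (d,) vector the summed suffix embedding should match
    loss(tokens) = ||sum_i emb[i, tokens[i]] - target||^2
    """
    rng = np.random.default_rng(seed)
    L, V, _ = emb.shape
    pos_idx = np.arange(L)

    def loss_of(tokens):
        s = emb[pos_idx, tokens].sum(axis=0)
        return float(((s - target) ** 2).sum())

    tokens = rng.integers(0, V, size=L)
    best_tokens, best_loss = tokens.copy(), loss_of(tokens)
    history = [best_loss]
    for _ in range(n_steps):
        s = emb[pos_idx, tokens].sum(axis=0)
        # Gradient of the loss w.r.t. each position's one-hot vector, shape (L, V)
        grad = 2.0 * emb @ (s - target)
        # Shortlist the tokens with the most negative gradient at each position
        shortlist = np.argsort(grad, axis=1)[:, :topk]
        # Evaluate a batch of random single-token swaps with the exact loss
        cand_tokens, cand_loss = None, np.inf
        for _ in range(batch):
            new = tokens.copy()
            p = rng.integers(L)
            new[p] = shortlist[p, rng.integers(topk)]
            new_loss = loss_of(new)
            if new_loss < cand_loss:
                cand_tokens, cand_loss = new, new_loss
        # Move to the best candidate and remember the best suffix seen so far
        tokens = cand_tokens
        history.append(cand_loss)
        if cand_loss < best_loss:
            best_tokens, best_loss = cand_tokens.copy(), cand_loss
    return best_tokens, best_loss, history


rng = np.random.default_rng(1)
emb = rng.standard_normal((5, 50, 8))
target = rng.standard_normal(8)
suffix, final_loss, history = toy_gcg(emb, target)
print(suffix, round(history[0], 2), round(final_loss, 2))
```

The gradient only ranks candidates; the exact loss decides which swap is kept. This is why GCG needs many forward passes per step, and why the notebook budgets tens of minutes per behavior.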

---
## Section 4: Activation Extraction

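Extract layer-`TARGET_LAYER` activations for each prompt category (clean harmful, benign, and jailbreak when available) using the `assistant-axis` `ProbingModel` and `ActivationExtractor` wrappers.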

In [None]:
# Load model for activation extraction
from assistant_axis.internals import ProbingModel, ConversationEncoder, ActivationExtractor

print(f"Loading model: {MODEL_NAME}")
pm = ProbingModel(MODEL_NAME)
encoder = ConversationEncoder(pm.tokenizer, MODEL_NAME)
extractor = ActivationExtractor(pm, encoder)

print(f"Model loaded. Hidden size: {pm.hidden_size}")

In [None]:
# Extract activations for each category
activations = {}

# Clean harmful prompts
print("\nExtracting activations for clean harmful prompts...")
activations["clean_harmful"] = extractor.for_prompts(clean_harmful_prompts, layer=[TARGET_LAYER])
print(f"  Shape: {activations['clean_harmful'].shape}")

# Benign prompts
print("\nExtracting activations for benign prompts...")
activations["benign"] = extractor.for_prompts(sample_benign, layer=[TARGET_LAYER])
print(f"  Shape: {activations['benign'].shape}")

# Jailbreak prompts (if available)
if jailbreaks:
    print("\nExtracting activations for jailbreak prompts...")
    jailbreak_prompts = [jb["jailbreak_prompt"] for jb in jailbreaks]
    activations["jailbreak"] = extractor.for_prompts(jailbreak_prompts, layer=[TARGET_LAYER])
    print(f"  Shape: {activations['jailbreak'].shape}")

In [None]:
# Save activations
torch.save(activations, ACTIVATIONS_PATH / "all_activations.pt")
print(f"Saved activations to {ACTIVATIONS_PATH / 'all_activations.pt'}")

---
## Section 5: Projection Analysis

Project activations onto the Assistant Axis and analyze differences between categories.
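The `project`/`project_batch` helpers come from the `assistant-axis` package. As a cross-check, assume they compute the scalar projection onto the normalized axis at one layer. The same numbers can then be reproduced with a plain dot product. If the library uses the unnormalized axis instead, every value is scaled by ||a||, and comparisons between categories are unaffected.

```python
import numpy as np

def scalar_projection(acts, axis_vec):
    """Signed length of each activation along the unit axis: x . a / ||a||.

    acts:     (batch, hidden) or (batch, 1, hidden) activations from one layer
    axis_vec: (hidden,) axis vector for that same layer
    """
    acts = np.asarray(acts, dtype=np.float64)
    if acts.ndim == 3:
        acts = acts[:, 0, :]
    a = np.asarray(axis_vec, dtype=np.float64)
    return acts @ (a / np.linalg.norm(a))


print(scalar_projection([[3.0, 4.0], [1.0, 0.0]], [0.0, 2.0]))  # [4. 0.]
```

With the torch tensors in this notebook, convert them first, for example `scalar_projection(acts.float().cpu().numpy(), axis[TARGET_LAYER].float().cpu().numpy())`.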

In [None]:
from assistant_axis import project, project_batch

if axis is not None:
    projections = {}
    
    for category, acts in activations.items():
        # 2-D activations (batch, hidden): project one row at a time
        if acts.ndim == 2:
            projs = [project(acts[i:i+1, :], axis, layer=TARGET_LAYER) for i in range(len(acts))]
            projections[category] = torch.tensor(projs)
        else:
            projections[category] = project_batch(acts, axis, layer=TARGET_LAYER)
        
        print(f"{category}: mean={projections[category].mean():.3f}, std={projections[category].std():.3f}")
else:
    print("No axis loaded. Cannot compute projections.")

In [None]:
# Compute jailbreak direction
if axis is not None and "jailbreak" in activations:
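    # Activations may be (batch, 1, hidden) or (batch, hidden); reduce to 2-D
    def to_2d(a):
        return a[:, 0, :] if a.ndim == 3 else a
    
    # Average in float32 on CPU so the axis and activations share dtype and device
    jb_mean = to_2d(activations["jailbreak"]).float().mean(dim=0).cpu()
    clean_mean = to_2d(activations["clean_harmful"]).float().mean(dim=0).cpu()
    
    jailbreak_direction = jb_mean - clean_mean
    print(f"Jailbreak direction norm: {jailbreak_direction.norm():.3f}")
    
    # Project jailbreak direction onto the unit axis
    ax = axis[TARGET_LAYER].float().cpu()
    ax_norm = ax / ax.norm()
    jb_dir = jailbreak_direction
    
    axis_projection = (jb_dir @ ax_norm).item()
    # Fraction of the direction's squared norm that lies along the axis (cos^2 of the angle)
    variance_explained = (axis_projection ** 2) / (jb_dir ** 2).sum().item()
    
    print(f"\nJailbreak direction analysis:")
    print(f"  Projection onto axis: {axis_projection:.3f}")
    print(f"  Cosine with axis: {axis_projection / jb_dir.norm().item():.3f}")
    print(f"  Fraction of squared norm along axis (cos^2): {variance_explained:.1%}")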
    
    if axis_projection < 0:
        print("  -> Jailbreaks push AWAY from the Assistant direction")
    else:
        print("  -> Jailbreaks push TOWARD the Assistant direction (unexpected)")
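How large should `variance_explained` (the fraction of the direction's squared norm along the axis) be before it means anything? Take a direction with no relation to the axis. Its cos² with a fixed unit vector in d dimensions averages 1/d, which is about 0.024% for d = 4096 (Llama 3.1 8B's hidden size). The Monte Carlo sketch below samples that null distribution. Comparing the measured fraction against its upper quantiles gives a rough sense of whether the overlap exceeds chance. Isotropic Gaussian directions are an idealized null, and real activation differences are anisotropic, so treat this only as a lower reference point.

```python
import numpy as np

def random_cos2_null(dim, n=2000, seed=0):
    """cos^2 between a fixed unit axis and n isotropic random directions in R^dim."""
    rng = np.random.default_rng(seed)
    v = rng.standard_normal((n, dim))
    # By rotational symmetry the fixed axis can be taken as the first basis vector
    return v[:, 0] ** 2 / (v ** 2).sum(axis=1)


null = random_cos2_null(4096)
print(f"mean cos^2: {null.mean():.5%} (theory {1 / 4096:.5%}), 99th pct: {np.quantile(null, 0.99):.4%}")
```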

---
## Section 6: GLP Safety Space (H, J, R)

Use GLP to extract meta-neurons and train probes for:
- H: Harm detection (harmful vs benign)
- J: Jailbreak success (successful vs failed)
- R: Refusal (model refuses vs complies)

In [None]:
# Load GLP model
from glp.denoiser import load_glp
from glp.script_probe import get_meta_neurons_locations, get_meta_neurons_layer_time

print(f"Loading GLP model: {GLP_MODEL}")
glp = load_glp(GLP_MODEL, device="cuda:0", checkpoint="final")
glp_layers = get_meta_neurons_locations(glp)
print(f"GLP loaded with {len(glp_layers)} meta-neuron layers")

In [None]:
# Extract meta-neurons for all activations
u = torch.tensor([0.9])[:, None]  # Diffusion timestep

meta_neurons = {}
for category, acts in activations.items():
    print(f"\nExtracting meta-neurons for {category}...")
    # Reshape: (batch, 1, hidden) -> (batch, hidden)
    X = acts[:, 0, :].float() if acts.ndim == 3 else acts.float()
    
    meta, (layer_size, u_size) = get_meta_neurons_layer_time(
        glp, "cuda:0", X, u, glp_layers, seed=42, batch_size=64
    )
    meta_neurons[category] = meta
    print(f"  Meta-neuron shape: {meta.shape}")

In [None]:
# Train H probe (Harm detection)
from glp.script_probe import prefilter_and_reshape_to_oned, run_sklearn_logreg_batched

print("Training H probe (Harmful vs Benign)...")

# Combine meta-neurons
X_meta = torch.cat([meta_neurons["clean_harmful"], meta_neurons["benign"]], dim=1)
y = torch.cat([torch.ones(meta_neurons["clean_harmful"].shape[1]), 
               torch.zeros(meta_neurons["benign"].shape[1])])

# Random 80/20 train/test split (seeded for reproducibility)
n = X_meta.shape[1]
perm = torch.randperm(n, generator=torch.Generator().manual_seed(42))
split = int(n * 0.8)
train_idx, test_idx = perm[:split], perm[split:]

# Prefilter and probe
X_train_1d, X_test_1d, top_indices = prefilter_and_reshape_to_oned(
    X_meta[:, train_idx, :], X_meta[:, test_idx, :], y[train_idx], "cuda:0", topk=512
)
val_aucs, test_aucs = run_sklearn_logreg_batched(X_train_1d, y[train_idx], X_test_1d, y[test_idx])

best_h_idx = top_indices[np.argmax(val_aucs)]
print(f"Best H neuron: {best_h_idx}, AUC: {test_aucs[np.argmax(val_aucs)]:.4f}")
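The GLP helpers above pick a single meta-neuron by AUC. The sketch below is a simplified illustration of that idea, not the `glp` implementation. Each feature is scored on its own with ROC AUC via the rank-sum identity `AUC = (R_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)`, where `R_pos` is the sum of ranks of the positive examples. Selecting by validation AUC and reporting test AUC, as the cell above does, keeps the reported number from being inflated by the selection.

```python
import numpy as np

def average_ranks(x):
    """1-based ranks of x, with tied values sharing their average rank."""
    order = np.argsort(x, kind="mergesort")
    ranks = np.empty(len(x))
    ranks[order] = np.arange(1, len(x) + 1)
    _, inverse, counts = np.unique(x, return_inverse=True, return_counts=True)
    return (np.bincount(inverse, weights=ranks) / counts)[inverse]


def auc_per_feature(X, y):
    """ROC AUC of each column of X used directly as a score for labels y in {0, 1}."""
    X = np.asarray(X, dtype=np.float64)
    y = np.asarray(y).astype(bool)
    n_pos, n_neg = int(y.sum()), int((~y).sum())
    aucs = np.empty(X.shape[1])
    for j in range(X.shape[1]):
        r_pos = average_ranks(X[:, j])[y].sum()
        aucs[j] = (r_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
    return aucs


X_demo = np.array([[0.0, 3.0, 5.0], [1.0, 2.0, 5.0], [2.0, 1.0, 5.0], [3.0, 0.0, 5.0]])
y_demo = np.array([0, 0, 1, 1])
print(auc_per_feature(X_demo, y_demo).tolist())  # [1.0, 0.0, 0.5]
```

A feature with AUC near 0 is just as informative with its sign flipped, so `np.maximum(aucs, 1 - aucs)` gives a direction-agnostic score.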

In [None]:
# Similarly train J and R probes if we have the data
# For J: need successful vs failed jailbreaks
# For R: need refusal vs compliance responses

# This requires running generation and classifying responses
# Placeholder for full implementation

print("\nNote: Full H, J, R training requires:")
print("  - J probe: successful vs failed jailbreak responses")
print("  - R probe: refusal vs compliance responses")
print("\nRun generation and classification to get this data.")
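A crude way to get R and J labels from generated responses is substring matching against common refusal phrases. The marker list below is illustrative and incomplete. Keyword matching misses polite refusals and flags compliant answers that happen to apologize, so an LLM judge or manual review is more reliable for final labels. Here R = 1 marks a refusal, and for jailbreak prompts J = 1 marks a response that was not refused.

```python
REFUSAL_MARKERS = (
    "i'm sorry", "i am sorry", "i apologize", "i cannot", "i can't",
    "i won't", "i'm not able to", "i am not able to", "as an ai",
)


def is_refusal(response: str) -> bool:
    """Heuristic refusal check: case-insensitive substring match on REFUSAL_MARKERS."""
    # Normalize curly apostrophes so "I can’t" matches "i can't"
    text = response.lower().replace("\u2019", "'")
    return any(marker in text for marker in REFUSAL_MARKERS)


def refusal_and_success_labels(responses):
    """Return (R, J) label lists: R = 1 for refusal; J = 1 for a non-refused jailbreak."""
    r_labels = [int(is_refusal(resp)) for resp in responses]
    j_labels = [1 - r for r in r_labels]
    return r_labels, j_labels


print(refusal_and_success_labels([
    "I'm sorry, but I can't help with that.",
    "Sure, here is a short poem about the sea.",
]))  # ([1, 0], [0, 1])
```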

---
## Section 7: Visualization

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Plot axis projection distributions
if axis is not None:
    fig, ax = plt.subplots(figsize=(10, 6))
    
    data = []
    labels = []
    for category, projs in projections.items():
        data.append(projs.float().cpu().numpy())
        labels.append(category)
    
    parts = ax.violinplot(data, showmeans=True, showmedians=True)
    ax.set_xticks(range(1, len(labels) + 1))
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_ylabel('Projection on Assistant Axis')
    ax.set_title('Jailbreak vs Clean: Assistant Axis Projections')
    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
    
    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / 'axis_projections.png', dpi=150)
    plt.show()

In [None]:
# PCA visualization
from sklearn.decomposition import PCA

# Collect all activations
all_acts = []
all_labels = []
for category, acts in activations.items():
    # .float() first: NumPy has no bfloat16, so .numpy() fails on bf16 tensors
    a = (acts[:, 0, :] if acts.ndim == 3 else acts).float().cpu().numpy()
    all_acts.append(a)
    all_labels.extend([category] * len(a))

all_acts = np.vstack(all_acts)

# PCA
pca = PCA(n_components=2)
embedded = pca.fit_transform(all_acts)

# Plot
fig, ax = plt.subplots(figsize=(10, 8))

categories = list(dict.fromkeys(all_labels))  # insertion order keeps colors stable across runs
colors = plt.cm.tab10(np.linspace(0, 1, len(categories)))

for i, cat in enumerate(categories):
    mask = np.array(all_labels) == cat
    ax.scatter(embedded[mask, 0], embedded[mask, 1], c=[colors[i]], label=cat, alpha=0.6)

ax.legend()
ax.set_xlabel(f'PC1 ({pca.explained_variance_ratio_[0]:.1%} variance)')
ax.set_ylabel(f'PC2 ({pca.explained_variance_ratio_[1]:.1%} variance)')
ax.set_title('Activation Space (PCA)')

plt.tight_layout()
plt.savefig(OUTPUT_DIR / 'pca_activations.png', dpi=150)
plt.show()
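For reference, the PCA above is a centered SVD. The minimal NumPy version below reproduces the 2-D embedding and the explained-variance ratios shown in the axis labels, up to the sign of each component (sklearn fixes signs by its own convention). PCA is unsupervised, so the top components capture the largest variance in the pooled activations. That variance need not come from the direction separating the categories, and the axis projections in Section 5 give a complementary view along one chosen direction.

```python
import numpy as np

def pca_svd(X, n_components=2):
    """Project centered X onto its top principal directions.

    Returns (embedded, explained_variance_ratio) for the first n_components.
    """
    X = np.asarray(X, dtype=np.float64)
    Xc = X - X.mean(axis=0)
    _, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    embedded = Xc @ Vt[:n_components].T
    variances = S**2 / (len(X) - 1)
    return embedded, variances[:n_components] / variances.sum()


# Points on a line in 3-D plus tiny noise: almost all variance is on PC1
rng = np.random.default_rng(0)
t = rng.standard_normal(200)
X_line = np.outer(t, [1.0, 2.0, 2.0]) + 1e-3 * rng.standard_normal((200, 3))
emb2d, ratio = pca_svd(X_line)
print(emb2d.shape, ratio.round(4))
```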

---
## Section 8: Statistical Analysis

In [None]:
from scipy import stats

if axis is not None and len(projections) >= 2:
    print("Statistical Tests for Hypothesis 1: Jailbreaks = Anti-Assistant")
    print("=" * 60)
    
    if "jailbreak" in projections and "clean_harmful" in projections:
        jb_projs = projections["jailbreak"].numpy()
        clean_projs = projections["clean_harmful"].numpy()
        
        # Mann-Whitney U test
        stat, p = stats.mannwhitneyu(jb_projs, clean_projs, alternative='less')
        
        # Effect size (Cohen's d) using the pooled sample standard deviation (ddof=1)
        n_jb, n_clean = len(jb_projs), len(clean_projs)
        pooled_std = np.sqrt(((n_jb - 1) * jb_projs.var(ddof=1) + (n_clean - 1) * clean_projs.var(ddof=1)) / (n_jb + n_clean - 2))
        cohens_d = (jb_projs.mean() - clean_projs.mean()) / pooled_std
        
        print(f"\nJailbreak projections: mean={jb_projs.mean():.3f}, std={jb_projs.std():.3f}")
        print(f"Clean projections: mean={clean_projs.mean():.3f}, std={clean_projs.std():.3f}")
        print(f"\nMann-Whitney U test (jailbreak < clean):")
        print(f"  Statistic: {stat:.2f}")
        print(f"  p-value: {p:.4f}")
        print(f"  Effect size (Cohen's d): {cohens_d:.3f}")
        
        if p < 0.05 and cohens_d < 0:
            print("\n*** HYPOTHESIS 1 SUPPORTED ***")
            print("Jailbreaks have significantly lower projections on the Assistant Axis.")
        else:
            print("\nHypothesis 1 NOT supported by this data.")

---
## Section 9: Steering Experiments

Test causal effects by steering activations toward/away from the Assistant Axis.
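What a given coefficient does depends on how the steering is applied. Assume `ActivationSteering` adds `coefficient * vector` to the hidden states at the chosen layer (additive steering; not verified against the library source). Then every position's projection onto the unit axis moves by exactly `coefficient * ||vector||`, and components orthogonal to the axis are unchanged. Unless `axis[TARGET_LAYER]` has unit norm, ±2.0 shifts the projection by 2 * ||axis[TARGET_LAYER]||. That shift is worth comparing against the spread of the Section 5 projections when choosing coefficients.

```python
import numpy as np

def steer_additive(hidden, vector, coefficient):
    """Additive steering of hidden states: h' = h + coefficient * vector at every position."""
    return hidden + coefficient * vector


rng = np.random.default_rng(0)
hidden = rng.standard_normal((6, 16))   # (positions, hidden_size) stand-in activations
vector = 3.0 * rng.standard_normal(16)  # unnormalized stand-in for axis[TARGET_LAYER]
v_hat = vector / np.linalg.norm(vector)

steered = steer_additive(hidden, vector, 2.0)
shift_along_axis = (steered - hidden) @ v_hat
print(shift_along_axis.round(3), round(2.0 * float(np.linalg.norm(vector)), 3))
```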

In [None]:
from assistant_axis.steering import ActivationSteering

if axis is not None:
    # Test prompts
    test_harmful = clean_harmful_prompts[:5]
    
    print("Steering Experiment: Can we prevent harmful outputs by steering toward Assistant?")
    print("=" * 70)
    
    for i, prompt in enumerate(test_harmful):
        print(f"\n--- Prompt {i+1}: {prompt[:80]}...")
        
        # Baseline
        baseline = pm.generate(prompt, max_new_tokens=100, temperature=0.0, do_sample=False)
        print(f"\nBaseline: {baseline[:200]}...")
        
        # Steered toward assistant (coefficient +2.0)
        with ActivationSteering(
            pm.model,
            steering_vectors=[axis[TARGET_LAYER]],
            coefficients=[2.0],
            layer_indices=[TARGET_LAYER]
        ):
            steered = pm.generate(prompt, max_new_tokens=100, temperature=0.0, do_sample=False)
        print(f"\nSteered (+2.0): {steered[:200]}...")

In [None]:
# Steering away from assistant (negative coefficient)
if axis is not None:
    print("\nSteering Experiment: Can we induce harmful outputs by steering AWAY from Assistant?")
    print("=" * 70)
    
    for i, prompt in enumerate(test_harmful[:3]):
        print(f"\n--- Prompt {i+1}: {prompt[:80]}...")
        
        # Baseline
        baseline = pm.generate(prompt, max_new_tokens=100, temperature=0.0, do_sample=False)
        print(f"\nBaseline: {baseline[:200]}...")
        
        # Steered away from assistant (coefficient -2.0)
        with ActivationSteering(
            pm.model,
            steering_vectors=[axis[TARGET_LAYER]],
            coefficients=[-2.0],
            layer_indices=[TARGET_LAYER]
        ):
            steered = pm.generate(prompt, max_new_tokens=100, temperature=0.0, do_sample=False)
        print(f"\nSteered (-2.0): {steered[:200]}...")

---
## Section 10: Conclusions

In [None]:
print("""
===============================================================================
                         ANALYSIS CONCLUSIONS
===============================================================================

Based on the analysis performed in this notebook:

1. ASSISTANT AXIS PROJECTIONS
   - Measured projections of jailbreak vs clean prompts onto the Assistant Axis
   - Hypothesis 1 (Anti-Assistant): [Result from statistical test]

2. GLP META-NEURON ANALYSIS
   - Trained H probe for harm detection
   - [Add J and R results when available]

3. STEERING EXPERIMENTS
   - Tested causal effects of steering toward/away from Assistant
   - [Summarize findings]

4. OVERALL CONCLUSION
   Based on the evidence:
   - Hypothesis 1 (Anti-Assistant): [Supported/Not Supported]
   - Hypothesis 2 (Harmful Persona): [Requires persona vector analysis]
   - Hypothesis 3 (Both): [Assessment]
   - Hypothesis 4 (Orthogonal): [Based on variance explained]

NEXT STEPS:
   - Compute full axis for Llama 3.1 8B using pipeline
   - Generate more jailbreaks with nanoGCG
   - Train J and R probes with response classification
   - Analyze persona vector similarities

===============================================================================
""")

In [None]:
# Cleanup
if 'pm' in globals():
    pm.close()
if 'glp' in globals():
    del glp
torch.cuda.empty_cache()
print("Cleanup complete.")