# Compute Assistant Axis for Llama 3.1 8B

This notebook computes the Assistant Axis for Llama 3.1 8B.

**Storage-optimized:** Only saves the final axis and lightweight role vectors. Heavy intermediate data (responses, activations) is processed in memory and discarded.

**Saved outputs (~150MB total):**
- `axis.pt` - The final Assistant Axis (~500KB)
- `role_vectors.pt` - Per-role mean vectors (~140MB)

**NOT saved (processed in memory):**
- Raw responses (~GB)
- Full activations (~100GB+)

**Runtime:** ~4-8 hours on A100

**Requirements:**
- GPU: A100 40GB recommended
- OpenAI API key for judging responses

---
## Section 1: Setup

In [None]:
import os
import sys


def _running_in_colab() -> bool:
    """Return True when the `google.colab` module is loaded, i.e. we are inside Colab."""
    return 'google.colab' in sys.modules


IN_COLAB = _running_in_colab()
print(f"Running in Colab: {IN_COLAB}")

if IN_COLAB:
    # Mount Drive so the final artifacts outlive the Colab VM.
    from google.colab import drive
    drive.mount('/content/drive')

    # Only the final outputs go to Drive; intermediates stay in memory.
    OUTPUT_DIR = '/content/drive/MyDrive/JB_mech_outputs/llama-3.1-8b'
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Final outputs will be saved to: {OUTPUT_DIR}")

In [None]:
# Clone repos
if IN_COLAB:
    # Clone assistant-axis for role data and utilities
    !git clone --depth 1 https://github.com/safety-research/assistant-axis.git /content/assistant-axis
    ASSISTANT_AXIS_DIR = '/content/assistant-axis'
else:
    PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
    ASSISTANT_AXIS_DIR = os.path.join(PROJECT_ROOT, 'third_party', 'assistant-axis')
    OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs', 'llama-3.1-8b')
    os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Install dependencies
if IN_COLAB:
    !pip install -q torch transformers accelerate vllm
    !pip install -q openai python-dotenv jsonlines tqdm
    print("Dependencies installed!")

In [None]:
# Shared imports used by every section below.
import gc
import json
from pathlib import Path

import torch
from tqdm.auto import tqdm  # widget progress bar in Jupyter/Colab, text bar elsewhere

# Environment report: confirm a GPU is visible before loading the 8B models.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")

In [None]:
# --- Configuration -----------------------------------------------------------
# Model under study and its architecture constants.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
TOTAL_LAYERS = 32   # decoder layers hooked during activation extraction
HIDDEN_SIZE = 4096  # hidden dimension of the model
TARGET_LAYER = 16   # layer whose axis norm is reported in the summary

# Final artifacts: the only files this notebook writes.
AXIS_PATH, ROLE_VECTORS_PATH = (
    os.path.join(OUTPUT_DIR, artifact_name)
    for artifact_name in ('axis.pt', 'role_vectors.pt')
)

# Pipeline size (reduce for faster testing).
N_QUESTIONS = 50       # Full: 240, Quick test: 20
N_SYSTEM_PROMPTS = 3   # Full: 5, Quick: 2

print(f"Output: {OUTPUT_DIR}")
print(f"Settings: {N_QUESTIONS} questions × {N_SYSTEM_PROMPTS} system prompts per role")

In [None]:
# Configure the OpenAI key used by the judge (Section 4). Prefer an environment
# variable; otherwise prompt interactively so the key never has to be stored in
# the notebook. Don't commit/share the notebook with a key pasted below.
OPENAI_API_KEY = ""  # <-- Paste your key here if not using env var

if OPENAI_API_KEY:
    os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
elif not os.environ.get('OPENAI_API_KEY', '').strip():
    # An unset OR empty env var both count as missing: an empty key would make
    # every judge call fail later.
    from getpass import getpass
    api_key = getpass("Enter OpenAI API key: ").strip()
    if not api_key:
        raise ValueError("No OpenAI API key provided; the judging step requires one.")
    os.environ['OPENAI_API_KEY'] = api_key

print("OpenAI API key configured.")

---
## Section 2: Load Role Data

In [None]:
import jsonlines
from itertools import islice

# Role definitions and extraction questions ship with the assistant-axis repo.
ROLES_DIR = Path(ASSISTANT_AXIS_DIR) / 'data' / 'roles' / 'instructions'
QUESTIONS_FILE = Path(ASSISTANT_AXIS_DIR) / 'data' / 'extraction_questions.jsonl'

# Sorted so the role order (and every downstream result) is deterministic;
# raw glob order depends on the filesystem.
roles = {}
for role_file in sorted(ROLES_DIR.glob('*.json')):
    with open(role_file) as f:
        roles[role_file.stem] = json.load(f)

# A wrong path would otherwise silently produce zero roles and an empty pipeline.
if not roles:
    raise FileNotFoundError(f"No role files (*.json) found in {ROLES_DIR}")

# Read only the first N_QUESTIONS questions.
with jsonlines.open(QUESTIONS_FILE) as reader:
    questions = list(islice(reader, N_QUESTIONS))

print(f"Loaded {len(roles)} roles, {len(questions)} questions")

In [None]:
# The sequential flow keeps no checkpoint, so no role starts out as completed.
completed_roles = set()  # role names that generation should skip (none)
print("Ready to process {} roles".format(len(roles)))

---
## Section 3: Generate All Responses (vLLM)

Generate responses for ALL roles at once, then unload vLLM to free GPU memory.

In [None]:
from vllm import LLM, SamplingParams

# Generation engine. vLLM may reserve 90% of VRAM here because it is unloaded
# at the end of this section, before the HF model is loaded.
print(f"Loading vLLM: {MODEL_NAME}")
llm = LLM(
    model=MODEL_NAME,
    dtype="bfloat16",
    max_model_len=1024,          # context budget: prompt + generated tokens
    tensor_parallel_size=1,      # single GPU
    gpu_memory_utilization=0.9,
)

# Nucleus sampling; each response is capped at 200 new tokens.
sampling_params = SamplingParams(temperature=0.7, top_p=0.9, max_tokens=200)

print("vLLM loaded!")

In [None]:
# Generate responses for every role that is not already completed.
all_responses = {}  # role_name -> list of {system_prompt, question, response}

roles_to_generate = [r for r in roles.keys() if r not in completed_roles]
print(f"Generating responses for {len(roles_to_generate)} roles...")
print(f"Each role: {N_SYSTEM_PROMPTS} prompts × {N_QUESTIONS} questions = {N_SYSTEM_PROMPTS * N_QUESTIONS} responses")
print("="*60)

# Llama 3.1 chat layout. The text deliberately omits "<|begin_of_text|>":
# vLLM tokenizes string prompts with special tokens enabled, which already
# prepends BOS, so including it here would feed the model a double BOS.
CHAT_TEMPLATE = (
    "<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n"
)

# Question text is role-independent, so extract it once.
question_texts = [q['question'] if isinstance(q, dict) else str(q) for q in questions]

for role_name in tqdm(roles_to_generate):
    role_data = roles[role_name]

    # System prompts may be stored under 'system_prompts' or as a bare list.
    if isinstance(role_data, dict) and 'system_prompts' in role_data:
        system_prompts = role_data['system_prompts'][:N_SYSTEM_PROMPTS]
    elif isinstance(role_data, list):
        system_prompts = role_data[:N_SYSTEM_PROMPTS]
    else:
        # NOTE(review): this uses the whole role payload as the system prompt,
        # which is almost certainly wrong for a dict with another schema.
        tqdm.write(f"  Warning: unrecognized format for role '{role_name}'; using str(role_data) as the system prompt")
        system_prompts = [str(role_data)]

    # (system_prompt, question_index) pairs, system-prompt-major as before.
    pairs = [(sys_prompt, q_idx) for sys_prompt in system_prompts for q_idx in range(len(questions))]
    if not pairs:
        all_responses[role_name] = []
        continue

    # One batched call per role lets vLLM schedule all prompts together;
    # outputs are returned in prompt order.
    prompts = [CHAT_TEMPLATE.format(system=sys_prompt, user=question_texts[q_idx]) for sys_prompt, q_idx in pairs]
    outputs = llm.generate(prompts, sampling_params)

    all_responses[role_name] = [
        {
            'system_prompt': sys_prompt,
            'question': questions[q_idx],
            'response': output.outputs[0].text,
        }
        for (sys_prompt, q_idx), output in zip(pairs, outputs)
    ]

print(f"\nGenerated {sum(len(v) for v in all_responses.values())} total responses for {len(all_responses)} roles")

In [None]:
# Unload vLLM so the HF model in Section 5 has room. `del llm` alone often does
# not release vLLM's GPU memory (its distributed/model-parallel state keeps
# references), so tear that state down first on a best-effort basis.
try:
    from vllm.distributed.parallel_state import (
        destroy_distributed_environment,
        destroy_model_parallel,
    )
    destroy_model_parallel()
    destroy_distributed_environment()
except Exception as exc:  # API differs across vLLM versions; cleanup is best-effort
    print(f"vLLM distributed teardown skipped: {exc}")

del llm
gc.collect()
torch.cuda.empty_cache()

# Report what is actually free rather than assuming the release worked.
if torch.cuda.is_available():
    free_bytes, total_bytes = torch.cuda.mem_get_info()
    print(f"vLLM unloaded. GPU memory free: {free_bytes / 1e9:.1f} / {total_bytes / 1e9:.1f} GB")
else:
    print("vLLM unloaded.")

---
## Section 4: Judge All Responses (OpenAI API)

Score every response from 0 to 3 with GPT-4o-mini and keep only the score=3 responses (fully in role). Requests are throttled to 5 per second to stay within Tier 1 rate limits.

In [None]:
# Setup OpenAI for judging
import openai
import time
import re

client = openai.OpenAI()

# Judge model config
JUDGE_MODEL = "gpt-4o-mini"  # Cheapest, ~$0.15/1M input tokens
JUDGE_MAX_TOKENS = 10

# Rate limiting for Tier 1: 500 RPM = ~8 RPS, we use 5 RPS to be safe
REQUESTS_PER_SECOND = 5
MIN_REQUEST_INTERVAL = 1.0 / REQUESTS_PER_SECOND  # 0.2 seconds

# Retry config
MAX_RETRIES = 3
RETRY_BACKOFF = 2.0

last_request_time = 0
judge_failures = 0  # responses scored 0 only because every API attempt failed


def judge_response(role_name, system_prompt, question, response):
    """Score how well a response embodies the role (0-3).

    Args:
        role_name: Role identifier shown to the judge.
        system_prompt: System prompt used at generation (first 300 chars sent).
        question: User question text (first 200 chars sent).
        response: Model response (first 500 chars sent).

    Returns:
        The first digit 0-3 found in the judge's reply. Returns 0 if the reply
        contains no such digit, or if all MAX_RETRIES attempts fail; the latter
        is logged and counted in the global `judge_failures`.

    Raises:
        openai.AuthenticationError, openai.PermissionDeniedError: re-raised
        immediately. Retrying cannot fix bad credentials, and scoring 0 would
        silently discard every response.
    """
    global last_request_time, judge_failures

    prompt = f"""Score how well this AI response embodies the assigned role (0-3):
- 0: Refuses to engage
- 1: Acknowledges can't be role but helps
- 2: Shows some role traits but breaks character
- 3: Fully embodies the role

Role: {role_name}
System: {system_prompt[:300]}
User: {question[:200]}
Response: {response[:500]}

Score (just the number):"""

    for attempt in range(MAX_RETRIES):
        # Throttle to at most REQUESTS_PER_SECOND across all calls.
        elapsed = time.time() - last_request_time
        if elapsed < MIN_REQUEST_INTERVAL:
            time.sleep(MIN_REQUEST_INTERVAL - elapsed)
        last_request_time = time.time()

        try:
            result = client.chat.completions.create(
                model=JUDGE_MODEL,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=JUDGE_MAX_TOKENS,
                temperature=0,
            )
        except (openai.AuthenticationError, openai.PermissionDeniedError):
            raise
        except openai.RateLimitError:
            wait_time = RETRY_BACKOFF ** (attempt + 1)
            print(f"  Rate limited, waiting {wait_time}s...")
            time.sleep(wait_time)
            continue
        except Exception:
            # Transient error: back off before the next attempt (not after the last).
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_BACKOFF ** attempt)
            continue

        # `content` can be None (e.g. a refusal); treat that as "no score".
        content = (result.choices[0].message.content or "").strip()
        match = re.search(r'[0-3]', content)
        return int(match.group()) if match else 0

    judge_failures += 1
    print(f"  Judge gave up after {MAX_RETRIES} attempts for role '{role_name}'; scoring 0")
    return 0

print(f"Judging with {JUDGE_MODEL} @ {REQUESTS_PER_SECOND} RPS")

In [None]:
# Judge ALL responses and keep only those scored 3 (fully in role).
scored_responses = {}  # role_name -> list of {system_prompt, question, response} with score=3

total_responses = sum(len(v) for v in all_responses.values())
print(f"Judging {total_responses} responses...")
print(f"Estimated time: {total_responses / REQUESTS_PER_SECOND / 60:.1f} minutes")
print("="*60)

total_score_3 = 0
for role_name, role_responses in tqdm(all_responses.items(), total=len(all_responses)):
    score_3_data = []
    for resp in role_responses:
        q_text = resp['question']['question'] if isinstance(resp['question'], dict) else str(resp['question'])
        if judge_response(role_name, resp['system_prompt'], q_text, resp['response']) == 3:
            score_3_data.append(resp)

    scored_responses[role_name] = score_3_data
    total_score_3 += len(score_3_data)

print(f"\nJudging complete: {total_score_3} score=3 responses out of {total_responses}")
if total_responses:
    print(f"Average score=3 rate: {total_score_3/total_responses*100:.1f}%")

# Without any score=3 responses there is nothing to extract and no axis can be
# computed. Stop here (all_responses is kept for inspection) rather than fail
# after loading the HF model.
if total_score_3 == 0:
    raise RuntimeError("No score=3 responses; check the judge output / API errors before continuing.")

# Free memory - we don't need all_responses anymore. NOTE: this makes the cell
# non-re-runnable; re-run Section 3 to judge again.
del all_responses
gc.collect()

---
## Section 5: Extract Activations (HuggingFace)

Load the HF model, extract activations for the score=3 responses (at most `MAX_ACTIVATIONS_PER_ROLE` = 20 per role), and average them into one vector per role.

In [None]:
# Load HuggingFace model for activation extraction
from transformers import AutoModelForCausalLM, AutoTokenizer

print(f"Loading HuggingFace model: {MODEL_NAME}")
hf_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
hf_model.eval()

hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
hf_tokenizer.pad_token = hf_tokenizer.eos_token

print("HF model loaded!")

# Llama 3 chat header that opens the assistant turn.
ASSISTANT_HEADER = "<|start_header_id|>assistant<|end_header_id|>\n\n"


def extract_activation(text, response_marker=ASSISTANT_HEADER):
    """Mean-pool every decoder layer's output over the tokens of one text.

    Args:
        text: Input string. It is tokenized by `hf_tokenizer`, which adds its
            special tokens, and truncated to 512 tokens.
        response_marker: If this substring occurs in `text`, only tokens that
            start at or after the end of its LAST occurrence (the assistant
            response) are averaged. This keeps BOS / system-prompt / question
            tokens out of the mean. If None, absent, or if truncation removed
            every response token, all tokens are averaged.

    Returns:
        float32 CPU tensor stacked over the TOTAL_LAYERS hooked layers,
        i.e. (TOTAL_LAYERS, hidden_dim).
    """
    # Character offset where the response begins (None -> pool all tokens).
    response_char_start = None
    if response_marker:
        marker_pos = text.rfind(response_marker)
        if marker_pos != -1:
            response_char_start = marker_pos + len(response_marker)

    inputs = hf_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=512,
        return_offsets_mapping=response_char_start is not None,
    )
    offsets = inputs.pop("offset_mapping", None)

    token_mask = None
    if offsets is not None:
        # A token is part of the response if it starts at/after the response's
        # first character; special tokens have offset (0, 0) and are excluded.
        token_mask = offsets[0, :, 0] >= response_char_start
        if not token_mask.any():
            token_mask = None  # response truncated away; fall back to all tokens

    inputs = {k: v.to(hf_model.device) for k, v in inputs.items()}

    activations = []

    def hook(module, input, output):
        # Decoder layers return a tuple in some transformers versions, a tensor in others.
        h = output[0] if isinstance(output, tuple) else output
        # Pool in float32: bf16 loses precision, and the axis is later a
        # difference of nearly equal means.
        h = h.float()
        if token_mask is not None:
            h = h[:, token_mask.to(h.device)]
        activations.append(h.mean(dim=1).squeeze(0).detach().cpu())

    handles = [hf_model.model.layers[i].register_forward_hook(hook) for i in range(TOTAL_LAYERS)]
    try:
        with torch.no_grad():
            hf_model(**inputs)
    finally:
        # Always detach hooks; if the forward pass raised, leftover hooks would
        # append extra activations on every later call.
        for handle in handles:
            handle.remove()

    # Clear GPU memory
    del inputs
    torch.cuda.empty_cache()

    return torch.stack(activations)  # (n_layers, hidden_dim)

In [None]:
# Extract activations for score=3 responses (capped per role) and average them
# into one vector per role.
role_vectors = {}
MAX_ACTIVATIONS_PER_ROLE = 20  # Limit per role to save time

# Rebuild the Llama 3.1 chat layout the responses were generated in, so the
# activations reflect the same kind of context the model actually saw. BOS is
# not included because the HF tokenizer adds it.
EXTRACTION_TEMPLATE = (
    "<|start_header_id|>system<|end_header_id|>\n\n{system}<|eot_id|>"
    "<|start_header_id|>user<|end_header_id|>\n\n{user}<|eot_id|>"
    "<|start_header_id|>assistant<|end_header_id|>\n\n{response}"
)

total_extractions = sum(min(len(v), MAX_ACTIVATIONS_PER_ROLE) for v in scored_responses.values())
print(f"Extracting activations for {total_extractions} responses...")
print("="*60)

extraction_count = 0
for role_name, score_3_data in tqdm(scored_responses.items(), total=len(scored_responses)):
    if not score_3_data:
        tqdm.write(f"  Skipping {role_name}: no score=3 responses")
        continue

    activations = []
    for i, resp in enumerate(score_3_data[:MAX_ACTIVATIONS_PER_ROLE]):
        q_text = resp['question']['question'] if isinstance(resp['question'], dict) else str(resp['question'])
        full_text = EXTRACTION_TEMPLATE.format(
            system=resp['system_prompt'], user=q_text, response=resp['response']
        )
        activations.append(extract_activation(full_text))
        extraction_count += 1

        # Periodic cleanup every 5 extractions
        if (i + 1) % 5 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    # Average in float32: the axis is a difference of nearly equal means, so
    # bf16 accumulation would cost most of the signal's precision.
    role_vectors[role_name] = torch.stack(activations).float().mean(dim=0)

    # Cleanup
    del activations
    gc.collect()
    torch.cuda.empty_cache()

print(f"\nExtracted {extraction_count} activations for {len(role_vectors)} roles")

# Free scored_responses (makes this cell non-re-runnable without re-judging).
del scored_responses
gc.collect()

In [None]:
# Drop the HF model and tokenizer so their GPU memory can be reclaimed.
del hf_model, hf_tokenizer
gc.collect()              # collect lingering reference cycles first...
torch.cuda.empty_cache()  # ...then hand cached CUDA blocks back to the driver
print("HF model unloaded.")

---
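## Section 6: Compute Final Axis

The axis is the default Assistant's role vector minus the mean of the character-role vectors, computed separately for every layer.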

In [None]:
# Save role vectors first so the extraction work is kept even if the axis step fails.
torch.save(role_vectors, ROLE_VECTORS_PATH)
print(f"Saved {len(role_vectors)} role vectors to: {ROLE_VECTORS_PATH}")

# Axis = default Assistant vector - mean of all character-role vectors.
DEFAULT_ROLE = 'default'

if DEFAULT_ROLE not in role_vectors:
    # Using the mean of all roles as the "default" would make the axis exactly
    # zero: the character mean would be the very same average.
    raise ValueError(
        f"No '{DEFAULT_ROLE}' role vector among {len(role_vectors)} roles. "
        "The axis needs the default-Assistant vector; check that this role exists "
        "in the role data and produced score=3 responses."
    )

default_vec = role_vectors[DEFAULT_ROLE]
character_vecs = {k: v for k, v in role_vectors.items() if k != DEFAULT_ROLE}
if not character_vecs:
    raise ValueError("Need at least one character role besides the default to compute the axis.")

# Compute axis (one row per layer)
character_mean = torch.stack(list(character_vecs.values())).mean(dim=0)
axis = default_vec - character_mean

print(f"Axis shape: {axis.shape}")
print(f"Axis norm at target layer ({TARGET_LAYER}): {axis[TARGET_LAYER].norm():.3f}")

# Save axis
torch.save(axis, AXIS_PATH)
print(f"Saved axis to: {AXIS_PATH}")

---
## Done!

In [None]:
def _print_completion_summary():
    """Print where the artifacts were written, their sizes, and key facts about the axis."""
    rule = '=' * 60
    axis_kb = os.path.getsize(AXIS_PATH) / 1e3
    vectors_mb = os.path.getsize(ROLE_VECTORS_PATH) / 1e6
    lines = [
        "",
        rule,
        "ASSISTANT AXIS COMPUTATION COMPLETE",
        rule,
        "",
        "Saved files:",
        f"  - Axis: {AXIS_PATH} ({axis_kb:.0f} KB)",
        f"  - Role vectors: {ROLE_VECTORS_PATH} ({vectors_mb:.1f} MB)",
        "",
        "Axis details:",
        f"  - Shape: {axis.shape}",
        f"  - Model: {MODEL_NAME}",
        f"  - Target layer: {TARGET_LAYER}",
        f"  - Roles processed: {len(role_vectors)}",
        "",
        "Next: Run jailbreak_analysis.ipynb",
        rule,
        "",
    ]
    print("\n".join(lines))


_print_completion_summary()