# Compute Assistant Axis for Llama 3.1 8B

This notebook computes the Assistant Axis for Llama 3.1 8B.

**Storage-optimized:** Only saves the final axis and lightweight role vectors. Heavy intermediate data (responses, activations) is processed in memory and discarded.

**Saved outputs (~150MB total):**
- `axis.pt` - The final Assistant Axis (~500KB)
- `role_vectors.pt` - Per-role mean vectors (~140MB)

**NOT saved (processed in memory):**
- Raw responses (~GB)
- Full activations (~100GB+)

**Runtime:** ~4-8 hours on A100

**Requirements:**
- GPU: A100 40GB recommended
- OpenAI API key for judging responses

---
## Section 1: Setup

In [None]:
import os
import sys

# Detect Colab by checking whether the google.colab package is already loaded
# in this interpreter.
IN_COLAB = 'google.colab' in sys.modules
print(f"Running in Colab: {IN_COLAB}")

if IN_COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    
    # Output directory in Drive (only for final outputs)
    # NOTE: OUTPUT_DIR is assigned here only when running on Colab. Outside
    # Colab it is assigned by the "Clone repos" cell, so that cell must run
    # before anything reads OUTPUT_DIR.
    OUTPUT_DIR = '/content/drive/MyDrive/JB_mech_outputs/llama-3.1-8b'
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    print(f"Final outputs will be saved to: {OUTPUT_DIR}")

In [None]:
# Clone repos
# Sets ASSISTANT_AXIS_DIR (location of the assistant-axis checkout) and, when
# not on Colab, PROJECT_ROOT and OUTPUT_DIR as well.
if IN_COLAB:
    # Clone assistant-axis for role data and utilities
    # NOTE(review): the clone targets a fixed path; presumably git refuses to
    # clone into an existing non-empty directory when this cell is re-run —
    # confirm that is harmless for this workflow.
    !git clone --depth 1 https://github.com/safety-research/assistant-axis.git /content/assistant-axis
    ASSISTANT_AXIS_DIR = '/content/assistant-axis'
else:
    # '__file__' below is a quoted string literal, not the __file__ variable,
    # so abspath() resolves it against the current working directory. The two
    # dirname() calls then yield the PARENT of the working directory.
    PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath('__file__')))
    ASSISTANT_AXIS_DIR = os.path.join(PROJECT_ROOT, 'third_party', 'assistant-axis')
    OUTPUT_DIR = os.path.join(PROJECT_ROOT, 'outputs', 'llama-3.1-8b')
    os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Install dependencies
# Runs only on Colab; outside Colab this cell installs nothing.
# NOTE(review): package versions are not pinned, so the installed versions
# depend on whatever pip resolves at run time.
if IN_COLAB:
    !pip install -q torch transformers accelerate vllm
    !pip install -q openai python-dotenv jsonlines tqdm
    print("Dependencies installed!")

In [None]:
import torch
import json
import gc
from pathlib import Path
from tqdm import tqdm

# Report the PyTorch version and CUDA availability; when CUDA is available,
# also print the name and total memory (in units of 1e9 bytes) of device 0.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Configuration
# Model identity and geometry. TOTAL_LAYERS is the number of decoder layers
# hooked during activation extraction; TARGET_LAYER is the layer index reported
# when the axis is computed.
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
TARGET_LAYER = 16
TOTAL_LAYERS = 32
HIDDEN_SIZE = 4096

# Paths (all inside OUTPUT_DIR, which is set by the setup cells)
AXIS_PATH = os.path.join(OUTPUT_DIR, 'axis.pt')
ROLE_VECTORS_PATH = os.path.join(OUTPUT_DIR, 'role_vectors.pt')
CHECKPOINT_PATH = os.path.join(OUTPUT_DIR, 'checkpoint.pt')  # For resuming

# Pipeline settings (reduce for faster testing)
N_QUESTIONS = 50  # Full: 240, Quick test: 20
N_SYSTEM_PROMPTS = 3  # Full: 5, Quick: 2

print(f"Output: {OUTPUT_DIR}")
# The multiplication sign was previously mis-encoded (UTF-8 bytes shown as "Ã—").
print(f"Settings: {N_QUESTIONS} questions × {N_SYSTEM_PROMPTS} system prompts per role")

In [None]:
# Set OpenAI API key (paste your key here or use environment variable)
# Resolution order: the value pasted below, then an existing OPENAI_API_KEY
# environment variable, then an interactive getpass() prompt.
# Prefer the environment variable or the prompt: a key pasted here is stored
# in the notebook file and is easy to leak when the notebook is shared.
OPENAI_API_KEY = ""  # <-- Paste your key here if not using env var

if OPENAI_API_KEY:
    os.environ['OPENAI_API_KEY'] = OPENAI_API_KEY
elif 'OPENAI_API_KEY' not in os.environ:
    from getpass import getpass
    api_key = getpass("Enter OpenAI API key: ")
    os.environ['OPENAI_API_KEY'] = api_key

# Normalise stray whitespace/newlines from copy-paste and fail fast on an empty
# key. Otherwise every judge call would fail later and be scored as 0, only
# after the expensive generation step has already run.
os.environ['OPENAI_API_KEY'] = os.environ['OPENAI_API_KEY'].strip()
if not os.environ['OPENAI_API_KEY']:
    raise ValueError("OpenAI API key is empty; set OPENAI_API_KEY or enter a key when prompted.")

print("OpenAI API key configured.")

---
## Section 2: Load Role Data

In [None]:
import jsonlines

# Load roles
ROLES_DIR = Path(ASSISTANT_AXIS_DIR) / 'data' / 'roles' / 'instructions'
QUESTIONS_FILE = Path(ASSISTANT_AXIS_DIR) / 'data' / 'extraction_questions.jsonl'

# One entry per *.json file, keyed by the file name without extension.
# Files are visited in sorted order so the role order (and therefore the
# processing order later on) is the same on every run; glob() alone gives no
# ordering guarantee.
roles = {}
for role_file in sorted(ROLES_DIR.glob('*.json')):
    # Explicit encoding so the result does not depend on the platform locale.
    with open(role_file, encoding='utf-8') as f:
        roles[role_file.stem] = json.load(f)

# glob() on a missing or empty directory silently yields nothing; stop here
# with a clear message instead of failing much later when vectors are stacked.
if not roles:
    raise FileNotFoundError(f"No role files (*.json) found in {ROLES_DIR}")

# Load questions (only the first N_QUESTIONS records are kept)
with jsonlines.open(QUESTIONS_FILE) as reader:
    questions = list(reader)[:N_QUESTIONS]

print(f"Loaded {len(roles)} roles, {len(questions)} questions")

In [None]:
# Resume support: restore the set of finished roles and their vectors from a
# previous run's checkpoint file, or start with empty state when none exists.
if not os.path.exists(CHECKPOINT_PATH):
    completed_roles, role_vectors = set(), {}
    print("Starting fresh (no checkpoint found)")
else:
    checkpoint = torch.load(CHECKPOINT_PATH)
    # Missing keys fall back to empty containers.
    role_vectors = checkpoint.get('role_vectors', {})
    completed_roles = set(checkpoint.get('completed_roles', []))
    print(f"Resuming from checkpoint: {len(completed_roles)} roles completed")

---
## Section 3: Process Roles (Memory-Efficient Pipeline)

For each role:
1. Generate responses with vLLM
2. Judge responses with an OpenAI judge model (`gpt-4o-mini`)
3. Extract activations for score=3 responses only (kept in memory)
4. Compute the mean vector from those activations
5. Discard raw data, keep only the vector

This processes one role at a time to minimize memory usage.

In [None]:
# Load vLLM for generation
from vllm import LLM, SamplingParams

# Build the vLLM engine used for response generation, plus the sampling
# settings shared by every generate() call in this notebook.
print(f"Loading vLLM: {MODEL_NAME}")
# NOTE(review): vLLM is allowed 70% of GPU memory and a second (HuggingFace)
# copy of the same model is loaded later by load_hf_model() — confirm both fit
# on the target GPU.
llm = LLM(
    model=MODEL_NAME,
    tensor_parallel_size=1,
    gpu_memory_utilization=0.7,  # Leave room for activation extraction
    dtype="bfloat16",
    max_model_len=1024,
)

# No seed is passed here, so sampled responses are not pinned by this cell.
sampling_params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    max_tokens=200,
)

print("vLLM loaded!")

In [None]:
# Setup OpenAI for judging
import openai
import time
import re

# Client used by judge_response(); constructed with no arguments, so the key
# is presumably taken from the OPENAI_API_KEY environment variable set earlier
# — TODO confirm.
client = openai.OpenAI()

# Judge model config
JUDGE_MODEL = "gpt-4o-mini"  # Cheapest, ~$0.15/1M input tokens
JUDGE_MAX_TOKENS = 10  # The judge is asked for a single digit, so replies are short

# Rate limiting for Tier 1: 500 RPM = ~8 RPS, we use 5 RPS to be safe
REQUESTS_PER_SECOND = 5
MIN_REQUEST_INTERVAL = 1.0 / REQUESTS_PER_SECOND  # 0.2 seconds

# Retry config
MAX_RETRIES = 3
RETRY_BACKOFF = 2.0  # Exponential backoff multiplier

# Timestamp (time.time()) of the most recent judge request; updated by
# judge_response() to enforce MIN_REQUEST_INTERVAL between calls.
last_request_time = 0

def judge_response(role_name, system_prompt, question, response):
    """
    Score how well a response embodies the role (0-3).

    Scoring rubric (from assistant-axis):
    - 0: Model refused to answer
    - 1: Model says it can't be the role but can help
    - 2: Model identifies as AI/LLM but has some role attributes
    - 3: Model fully plays the role

    Args:
        role_name: Name of the role, inserted verbatim into the judge prompt.
        system_prompt: Role system prompt; only the first 300 characters are sent.
        question: User question; only the first 200 characters are sent.
        response: Model response being judged; only the first 500 characters are sent.

    Returns:
        int in 0..3. 0 is also returned when the judge's reply contains no
        stand-alone digit 0-3, or when the request ultimately fails (rate limit
        or API error after MAX_RETRIES attempts, or any unexpected error).
        Callers therefore cannot tell a judged 0 from a failed call.
    """
    global last_request_time

    prompt = f"""Score how well this AI response embodies the assigned role (0-3):
- 0: Refuses to engage
- 1: Acknowledges can't be role but helps
- 2: Shows some role traits but breaks character
- 3: Fully embodies the role

Role: {role_name}
System: {system_prompt[:300]}
User: {question[:200]}
Response: {response[:500]}

Score (just the number):"""

    for attempt in range(MAX_RETRIES):
        is_last_attempt = attempt == MAX_RETRIES - 1
        try:
            # Rate limiting: keep at least MIN_REQUEST_INTERVAL between requests
            elapsed = time.time() - last_request_time
            if elapsed < MIN_REQUEST_INTERVAL:
                time.sleep(MIN_REQUEST_INTERVAL - elapsed)

            last_request_time = time.time()

            result = client.chat.completions.create(
                model=JUDGE_MODEL,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=JUDGE_MAX_TOKENS,
                temperature=0,
            )

            # Parse score (handle various formats like "3", "Score: 3", etc.).
            # The content may be None, so treat that as an empty reply.
            content = (result.choices[0].message.content or "").strip()
            # Only accept a digit that is not part of a longer number, so a
            # reply such as "10" is not misread as a score of 1.
            match = re.search(r'(?<!\d)[0-3](?!\d)', content)
            return int(match.group()) if match else 0

        except openai.RateLimitError:
            if is_last_attempt:
                # No point sleeping through a back-off when no retry follows.
                print(f"  Rate limited after {MAX_RETRIES} attempts, giving up.")
                return 0
            wait_time = RETRY_BACKOFF ** (attempt + 1)
            print(f"  Rate limited, waiting {wait_time}s...")
            time.sleep(wait_time)

        except openai.APIError as e:
            if is_last_attempt:
                print(f"  API error after {MAX_RETRIES} retries: {e}")
                return 0
            time.sleep(RETRY_BACKOFF ** attempt)

        except Exception as e:
            # Deliberate best-effort: an unexpected failure drops this sample
            # (score 0) instead of aborting the whole run.
            print(f"  Unexpected error: {e}")
            return 0

    return 0

print(f"Judging with {JUDGE_MODEL} @ {REQUESTS_PER_SECOND} RPS (Tier 1 safe)")

In [None]:
# Helper: extract activation from a single text
from transformers import AutoModelForCausalLM, AutoTokenizer

# Lazily-created HuggingFace model and tokenizer, shared through module
# globals. Both stay None until load_hf_model() runs for the first time.
hf_model = None
hf_tokenizer = None


def load_hf_model():
    """Populate the global hf_model / hf_tokenizer once; later calls are no-ops.

    The model is loaded from MODEL_NAME in bfloat16 with device_map="auto" and
    switched to eval mode. The tokenizer's pad token is set to its EOS token.
    """
    global hf_model, hf_tokenizer

    # Already loaded: nothing to do.
    if hf_model is not None:
        return

    print("Loading HuggingFace model for activation extraction...")
    hf_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    hf_model.eval()

    hf_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    hf_tokenizer.pad_token = hf_tokenizer.eos_token
    print("HF model loaded!")

def extract_activation(text):
    """Extract the mean activation of every hooked layer for a single text.

    Tokenizes `text` (truncated to 512 tokens), runs one forward pass of the
    global hf_model without gradients, and records each of the first
    TOTAL_LAYERS entries of hf_model.model.layers via a forward hook.

    Args:
        text: The string to run through the model.

    Returns:
        A CPU tensor with one row per hooked layer: (n_layers, hidden_dim).

    Notes:
        Requires load_hf_model() to have populated hf_model / hf_tokenizer.
        Each layer output is averaged over dim 1 — presumably the token axis
        of a (batch, seq, hidden) output, so prompt and response tokens are
        averaged together — TODO confirm.
    """
    inputs = hf_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    inputs = {k: v.to(hf_model.device) for k, v in inputs.items()}

    activations = []

    def hook(module, input, output):
        # Layers may return a tuple (hidden state first) or a bare tensor.
        hidden = output[0] if isinstance(output, tuple) else output
        activations.append(hidden.mean(dim=1).squeeze(0).detach().cpu())

    handles = []
    try:
        for layer_idx in range(TOTAL_LAYERS):
            handles.append(hf_model.model.layers[layer_idx].register_forward_hook(hook))

        with torch.no_grad():
            hf_model(**inputs)
    finally:
        # Always detach the hooks, even if registration or the forward pass
        # fails; otherwise they would stay attached and fire on later calls.
        for handle in handles:
            handle.remove()

        # Clear GPU memory after each extraction
        del inputs
        torch.cuda.empty_cache()

    return torch.stack(activations)  # (n_layers, hidden_dim)

In [None]:
def _question_text(question):
    """Return the question string: the 'question' field of a dict record, else str(question)."""
    return question['question'] if isinstance(question, dict) else str(question)


def process_role(role_name, role_data, max_score3_samples=20):
    """
    Process a single role end-to-end:
    1. Generate responses
    2. Judge them
    3. Extract activations for score=3 only
    4. Return mean activation vector

    Args:
        role_name: Role identifier; passed to the judge and used in log lines.
        role_data: Either a dict with a 'system_prompts' list, a list of system
            prompts, or any other value (which is stringified and used as the
            single system prompt). At most N_SYSTEM_PROMPTS prompts are used.
        max_score3_samples: Upper bound on how many score=3 responses are run
            through activation extraction (default 20, the previous fixed cap).

    Returns:
        The mean over the extracted activations, shape (n_layers, hidden_dim),
        or None when no response received a score of 3.

    Raises:
        ValueError: if max_score3_samples is less than 1.
    """
    if max_score3_samples < 1:
        raise ValueError("max_score3_samples must be at least 1")

    # Get system prompts
    if isinstance(role_data, dict) and 'system_prompts' in role_data:
        system_prompts = role_data['system_prompts'][:N_SYSTEM_PROMPTS]
    elif isinstance(role_data, list):
        system_prompts = role_data[:N_SYSTEM_PROMPTS]
    else:
        # NOTE(review): a dict WITHOUT a 'system_prompts' key also lands here
        # and is stringified whole into one system prompt — confirm the role
        # JSON schema actually provides 'system_prompts'.
        system_prompts = [str(role_data)]

    # Resolve each question's text once; it is reused for generation, judging
    # and extraction.
    question_texts = [_question_text(q) for q in questions]

    # Generate responses: one batch of len(questions) prompts per system prompt
    responses_data = []
    for sys_prompt in system_prompts:
        prompts = [
            f"<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{sys_prompt}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n{q_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            for q_text in question_texts
        ]

        outputs = llm.generate(prompts, sampling_params)

        # outputs[i] is paired with questions[i], as in the prompt list above.
        for question, q_text, output in zip(questions, question_texts, outputs):
            responses_data.append({
                'system_prompt': sys_prompt,
                'question': question,
                'question_text': q_text,
                'response': output.outputs[0].text,
            })

    # Judge responses (only need score=3)
    # Rate limiting is handled inside judge_response()
    score_3_data = [
        resp for resp in responses_data
        if judge_response(role_name, resp['system_prompt'], resp['question_text'], resp['response']) == 3
    ]

    if len(score_3_data) == 0:
        print(f"  Warning: No score=3 responses for {role_name}")
        return None

    # Extract activations only for score=3 responses
    load_hf_model()

    activations = []
    for i, resp in enumerate(score_3_data[:max_score3_samples]):  # Capped to save time
        # NOTE(review): this plain "User:/Assistant:" rendering differs from
        # the chat-template prompt used for generation above — confirm that
        # extracting activations from a different text format is intended.
        full_text = f"{resp['system_prompt']}\n\nUser: {resp['question_text']}\n\nAssistant: {resp['response']}"
        activations.append(extract_activation(full_text))

        # Periodic cleanup every 5 extractions
        if (i + 1) % 5 == 0:
            gc.collect()
            torch.cuda.empty_cache()

    # Compute mean vector
    mean_vector = torch.stack(activations).mean(dim=0)  # (n_layers, hidden_dim)

    # Cleanup after processing role
    del activations
    gc.collect()
    torch.cuda.empty_cache()

    print(f"  {role_name}: {len(responses_data)} responses, {len(score_3_data)} score=3, vector computed")
    return mean_vector

In [None]:
# Process all roles
print(f"\nProcessing {len(roles)} roles...")
print("="*60)

# Roles already recorded in the checkpoint are skipped.
roles_to_process = [r for r in roles.keys() if r not in completed_roles]
print(f"Remaining: {len(roles_to_process)} roles\n")


def _save_checkpoint():
    """Persist progress to CHECKPOINT_PATH; failures are reported, not raised.

    The data is written to a temporary file first and then moved into place
    with os.replace, so an interruption mid-write cannot leave a truncated
    checkpoint behind for the next resume.
    """
    tmp_path = CHECKPOINT_PATH + '.tmp'
    try:
        torch.save({
            'completed_roles': list(completed_roles),
            'role_vectors': role_vectors,
        }, tmp_path)
        os.replace(tmp_path, CHECKPOINT_PATH)
        print(f"  [Checkpoint saved: {len(completed_roles)} roles]")
    except Exception as e:
        # Best-effort: a failed checkpoint must not abort the run.
        print(f"  Checkpoint save failed: {e}")


for role_name in tqdm(roles_to_process):
    role_data = roles[role_name]

    try:
        vector = process_role(role_name, role_data)
        if vector is not None:
            role_vectors[role_name] = vector

        # A role with no usable responses (vector is None) still counts as
        # completed so it is not retried on resume.
        completed_roles.add(role_name)
    except Exception as e:
        # Best-effort: log and move on; the role stays un-completed and will
        # be retried on the next run.
        print(f"  Error processing {role_name}: {e}")
        continue

    # Save checkpoint every 10 roles
    if len(completed_roles) % 10 == 0:
        _save_checkpoint()

# Persist the final partial batch too, so work done since the last multiple of
# 10 survives a crash in a later cell.
if roles_to_process:
    _save_checkpoint()

print(f"\nCompleted! {len(role_vectors)} role vectors computed.")

In [None]:
# Cleanup models
# Drop the references by rebinding to None rather than `del`: the cell can then
# be re-run without a NameError, and load_hf_model() (which tests
# `hf_model is None`) still works afterwards.
llm = None
hf_model = None
hf_tokenizer = None
gc.collect()
torch.cuda.empty_cache()
print("Models unloaded.")

---
## Section 4: Compute Final Axis

In [None]:
# Save role vectors
# Writes the whole role_vectors dict (role name -> vector) to ROLE_VECTORS_PATH,
# then reports how many entries were saved and the file size in units of 1e6 bytes.
torch.save(role_vectors, ROLE_VECTORS_PATH)
print(f"Saved {len(role_vectors)} role vectors to: {ROLE_VECTORS_PATH}")
print(f"Size: {os.path.getsize(ROLE_VECTORS_PATH) / 1e6:.1f} MB")

In [None]:
# Compute axis: default - mean(roles)
DEFAULT_ROLE = 'default'

if not role_vectors:
    raise ValueError("No role vectors were computed; cannot build the axis.")

# The axis is the default-role vector minus the mean of all other roles, so a
# default vector is required. Substituting the overall mean for it would
# subtract that mean from itself and save an all-zero axis.
if DEFAULT_ROLE not in role_vectors:
    raise ValueError(
        f"No '{DEFAULT_ROLE}' role vector found; the axis cannot be computed "
        f"without it (available roles: {len(role_vectors)})."
    )

default_vec = role_vectors[DEFAULT_ROLE]
character_vecs = {k: v for k, v in role_vectors.items() if k != DEFAULT_ROLE}

if not character_vecs:
    raise ValueError(f"Only the '{DEFAULT_ROLE}' role has a vector; need at least one other role.")

# Compute axis
character_mean = torch.stack(list(character_vecs.values())).mean(dim=0)
axis = default_vec - character_mean

print(f"Axis shape: {axis.shape}")
print(f"Axis norm at target layer ({TARGET_LAYER}): {axis[TARGET_LAYER].norm():.3f}")

In [None]:
# Save axis
# Writes the axis tensor to AXIS_PATH and reports the file size in units of 1e3 bytes.
torch.save(axis, AXIS_PATH)
print(f"Saved axis to: {AXIS_PATH}")
print(f"Size: {os.path.getsize(AXIS_PATH) / 1e3:.1f} KB")

In [None]:
# Cleanup checkpoint
# The resume checkpoint is deleted once this cell runs; a later run of the
# notebook will then start from scratch instead of resuming.
if os.path.exists(CHECKPOINT_PATH):
    os.remove(CHECKPOINT_PATH)
    print("Removed checkpoint file.")

---
## Done!

In [None]:
# Final summary: re-reads the sizes of the two saved files from disk and echoes
# the axis shape, model name, target layer and number of role vectors.
print(f"""
{'='*60}
ASSISTANT AXIS COMPUTATION COMPLETE
{'='*60}

Saved files:
  - Axis: {AXIS_PATH} ({os.path.getsize(AXIS_PATH)/1e3:.0f} KB)
  - Role vectors: {ROLE_VECTORS_PATH} ({os.path.getsize(ROLE_VECTORS_PATH)/1e6:.1f} MB)

Axis details:
  - Shape: {axis.shape}
  - Model: {MODEL_NAME}
  - Target layer: {TARGET_LAYER}
  - Roles processed: {len(role_vectors)}

Next: Run jailbreak_analysis.ipynb
{'='*60}
""")