# Extract refusal

This notebook shows how to run harmful and harmless instructions through an LLM, cache the resulting activations, and extract refusal embeddings from them.

- Created by [@maximelabonne](https://twitter.com/maximelabonne).
- Adapted by [Fabian Hildebrandt](https://huggingface.co/FabianHildebrandt).

## Load dependencies

In [None]:
import torch
import functools
import gc
import yaml
import pickle
import os

from datasets import load_dataset
from tqdm import tqdm
from torch import Tensor
from typing import List
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint
from transformers import AutoModelForCausalLM, AutoTokenizer
from jaxtyping import Float, Int
from collections import defaultdict



## Load parameters from config 


In [None]:
def _redact_secrets(value, secret_keys=('TOKEN',)):
    """Return a copy of a nested config in which secret values are masked.

    Dictionaries and lists are walked recursively; any non-empty value stored
    under a key listed in ``secret_keys`` is replaced by ``'***'``. Everything
    else is returned unchanged.
    """
    if isinstance(value, dict):
        return {
            key: '***' if key in secret_keys and item else _redact_secrets(item, secret_keys)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [_redact_secrets(item, secret_keys) for item in value]
    return value


with open('./config.yaml', 'r') as f:
    # NOTE(review): FullLoader is kept as in the original; if the config only
    # holds plain scalars/lists/maps, yaml.safe_load would be the stricter choice.
    config = yaml.load(f, Loader=yaml.FullLoader)

# Show the settings without echoing the access token: printed output is stored
# inside the notebook file and would leak the secret when the notebook is shared.
print(_redact_secrets(config))

# All parameters used by this notebook live in the 'extraction' section.
extraction_config = config['extraction']
MODEL_ID = extraction_config['MODEL_ID']
MODEL_TYPE = extraction_config['MODEL_TYPE']
MODEL_NAME = extraction_config['MODEL_NAME']
HARMFUL_DATA = extraction_config['HARMFUL_DATA']
HARMLESS_DATA = extraction_config['HARMLESS_DATA']
N_SAMPLES = extraction_config['N_SAMPLES']
TOKEN = extraction_config['TOKEN']

## Load datasets

In [None]:
# Inference only: disable autograd globally to save GPU memory (credit: Undi95)
torch.set_grad_enabled(False)

def reformat_texts(texts):
    """Wrap every raw instruction string in a single-turn chat conversation.

    Returns a list with one conversation per input text. Each conversation is
    a one-element list holding a ``{"role": "user", "content": text}`` message.
    """
    conversations = []
    for text in texts:
        user_message = {"role": "user", "content": text}
        conversations.append([user_message])
    return conversations

# Get harmful and harmless datasets
def _load_instruction_splits(dataset_name):
    """Load a dataset and return its (train, test) 'text' columns as chat conversations.

    The dataset is fetched with ``load_dataset(dataset_name)``; the 'text'
    column of the 'train' and 'test' splits is passed through
    ``reformat_texts``. Returns a 2-tuple ``(train, test)``.
    """
    dataset = load_dataset(dataset_name)
    return tuple(
        reformat_texts(dataset[split]['text']) for split in ('train', 'test')
    )

def get_harmful_instructions():
    """Return the (train, test) conversations of the dataset named by HARMFUL_DATA."""
    return _load_instruction_splits(HARMFUL_DATA)

def get_harmless_instructions():
    """Return the (train, test) conversations of the dataset named by HARMLESS_DATA."""
    return _load_instruction_splits(HARMLESS_DATA)

# Build the train/test conversations for both instruction classes
harmful_inst_train, harmful_inst_test = get_harmful_instructions()
harmless_inst_train, harmless_inst_test = get_harmless_instructions()

# Report how many instructions each split holds, one line per split
overview_lines = [
    '-----Overview Train & Test Data------',
    f"Train data: Harmful instructions: {len(harmful_inst_train)}, Harmless instructions: {len(harmless_inst_train)}",
    f"Test data: Harmful instructions: {len(harmful_inst_test)}, Harmless instructions: {len(harmless_inst_test)}",
]
print(*overview_lines, sep='\n')

## Load model and tokenizer

In [None]:
# Tokenizer: load it, then immediately configure it for left-padded batches.
# The end-of-sequence token doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_ID,
    # NOTE(review): a tokenizer holds no weights, so this device_map argument
    # presumably has no effect here - confirm before removing it.
    device_map='auto',
    token=TOKEN,
)
tokenizer.padding_side = 'left'
tokenizer.pad_token = tokenizer.eos_token

# Model: loaded in bfloat16 through TransformerLens without weight processing.
# device_map='auto' is forwarded so that the weights are placed automatically
# (intended: on the GPU).
model_load_kwargs = dict(
    dtype=torch.bfloat16,
    default_padding_side='left',
    device_map='auto',
    use_auth_token=TOKEN,
)
model = HookedTransformer.from_pretrained_no_processing(MODEL_ID, **model_load_kwargs)

## Display information about the model

In [None]:
def analyze_model(model):
    """Summarise the key configuration values of a HookedTransformer model.

    The architecture fields are read from ``model.cfg``; the parameter count
    is the sum of ``numel()`` over everything yielded by ``model.parameters()``.

    Returns:
        dict: display labels mapped to the config values, followed by the
        parameter count under the key ``'total_parameters'``.
    """
    cfg = model.cfg  # HookedTransformer exposes its settings as `cfg`, not `config`

    # (display label, attribute on cfg) in the order they should be reported
    labelled_fields = (
        ('Name', 'model_name'),
        ('Number of layers', 'n_layers'),
        ('Attention heads / layer', 'n_heads'),
        ('Hidden layer size', 'd_model'),
        ('Number of different tokens', 'd_vocab'),
        ('Context Size', 'n_ctx'),
    )
    summary = {label: getattr(cfg, field) for label, field in labelled_fields}

    # Count the elements of every parameter tensor
    parameter_count = 0
    for param in model.parameters():
        parameter_count += param.numel()
    summary['total_parameters'] = parameter_count

    return summary

# Collect the model summary and print it as "label: value" lines under a banner
model_info = analyze_model(model)
print('--------------', 'Model Information:', '--------------', sep='\n')
for key, value in model_info.items():
    print(f"{key}: {value}")

## Tokenize instructions

In [None]:
def tokenize_instructions(tokenizer, instructions, model_name=None):
    """Apply the chat template to a batch of conversations and return the token ids.

    Args:
        tokenizer: tokenizer providing ``apply_chat_template``. Its
            ``chat_template`` attribute is overwritten when the model name
            contains ``'bloom'``.
        instructions: list of conversations (lists of role/content messages).
        model_name: name used to decide whether the fallback chat template is
            installed. Defaults to the notebook-level ``MODEL_NAME``.

    Returns:
        The ``input_ids`` of the padded, untruncated batch, requested as
        PyTorch tensors (``return_tensors="pt"``) with a generation prompt.
    """
    if model_name is None:
        model_name = MODEL_NAME

    if 'bloom' in model_name:
        # Install a simple User/Assistant template for these models. The text
        # (including its indentation) is kept exactly as before, because it is
        # rendered into the prompt.
        # NOTE(review): the template has no add_generation_prompt branch, so no
        # assistant prefix is appended for these models - confirm this is intended.
        tokenizer.chat_template = (
            "{% for message in messages %}\n"
            "      {% if message['role'] == 'user' %}\n"
            "      User: {{ message['content'] }}\n"
            "      {% elif message['role'] == 'assistant' %}\n"
            "      Assistant: {{ message['content'] }}\n"
            "      {% endif %}\n"
            "      {% endfor %}\n"
            "      "
        )

    encoded = tokenizer.apply_chat_template(
        instructions,
        padding=True,
        truncation=False,
        return_tensors="pt",
        return_dict=True,
        add_generation_prompt=True,
    )
    return encoded.input_ids

# Use the same number of prompts from both classes, capped at N_SAMPLES
n_inst_train = min(N_SAMPLES, len(harmful_inst_train), len(harmless_inst_train))

# Tokenize the first n_inst_train conversations of each class (harmful first)
harmful_tokens, harmless_tokens = [
    tokenize_instructions(tokenizer, instructions=conversations[:n_inst_train])
    for conversations in (harmful_inst_train, harmless_inst_train)
]
print(f'Successfully tokenized {len(harmless_tokens)} harmless and {len(harmful_tokens)} harmful instructions.')

## Run generations and cache the activations

In [None]:
# Define batch size based on available VRAM
batch_size = 16

# Per hook name: a list with one last-position activation tensor per batch
harmful = defaultdict(list)
harmless = defaultdict(list)

def _is_resid_hook(hook_name):
    """Select only the hooks whose name contains 'resid'."""
    return 'resid' in hook_name

# Process the training data in batches (ceil division keeps the last partial batch)
num_batches = (n_inst_train + batch_size - 1) // batch_size
for i in tqdm(range(num_batches)):
    start_idx = i * batch_size
    end_idx = min(n_inst_train, start_idx + batch_size)

    # Run the model on harmful and harmless prompts; device='cpu' requests that
    # the cached activations are kept on the CPU.
    harmful_logits, harmful_cache = model.run_with_cache(
        harmful_tokens[start_idx:end_idx],
        names_filter=_is_resid_hook,
        device='cpu',
        reset_hooks_end=True
    )
    harmless_logits, harmless_cache = model.run_with_cache(
        harmless_tokens[start_idx:end_idx],
        names_filter=_is_resid_hook,
        device='cpu',
        reset_hooks_end=True
    )

    # Keep only the activation at the last position of every prompt.
    # .clone() is essential: the plain slice is a view that keeps the whole
    # cached tensor alive, so deleting the caches below would free nothing.
    for key in harmful_cache:
        harmful[key].append(harmful_cache[key][:, -1, :].clone())
        harmless[key].append(harmless_cache[key][:, -1, :].clone())

    # Flush RAM and VRAM
    del harmful_logits, harmless_logits, harmful_cache, harmless_cache
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def _concat_batches(activations_by_hook):
    """Concatenate the per-batch tensors of every hook along the first dimension.

    Values that are no longer a list/tuple of batches are passed through
    unchanged. This keeps the cell idempotent: without the check, running it a
    second time would call torch.cat on an already concatenated tensor and
    silently flatten it.
    """
    return {
        k: torch.cat(v) if isinstance(v, (list, tuple)) else v
        for k, v in activations_by_hook.items()
    }

harmful = _concat_batches(harmful)
harmless = _concat_batches(harmless)
print(f'Now we have the activations for all {model.cfg.n_layers} layers: {harmful.keys()}')
print(f'Examplary shape of the harmful activations: {harmful["blocks.0.hook_resid_pre"].shape}')

## Calculate difference-in-means

In [21]:
# Look up one stored activation by hook type and layer number
def get_act_idx(cache_dict, act_name, layer):
    """Return the entry of ``cache_dict`` for ``act_name`` at ``layer``.

    The dictionary key is produced by ``utils.get_act_name(act_name, layer)``.
    """
    hook_name = utils.get_act_name(act_name, layer)
    return cache_dict[hook_name]

# Difference of means between harmful and harmless activations, normalised to
# unit length, for every hook type in activation_layers.
# The layer loop starts at 1, so entry i of each list belongs to layer i + 1.
activation_layers = ["resid_mid", "resid_post"]
activation_refusals = defaultdict(list)

layer_hook_pairs = (
    (layer_num, layer)
    for layer_num in range(1, model.cfg.n_layers)
    for layer in activation_layers
)
for layer_num, layer in layer_hook_pairs:
    # Mean over the first (sample) dimension, harmful first, then harmless
    harmful_mean_act, harmless_mean_act = (
        get_act_idx(acts, layer, layer_num)[:, :].mean(dim=0)
        for acts in (harmful, harmless)
    )

    refusal_dir = harmful_mean_act - harmless_mean_act
    refusal_dir = refusal_dir / refusal_dir.norm()
    activation_refusals[layer].append(refusal_dir)

## Store activations in a pickle file

In [None]:
# Write everything into ./data/<MODEL_NAME>_activations.pkl
destination = "./data"
os.makedirs(destination, exist_ok=True)
activations = {
    'harmful': harmful,
    'harmless': harmless,
    'activation_refusals': activation_refusals
}

fpath = os.path.join(destination, f'{MODEL_NAME}_activations.pkl')
with open(fpath, 'wb') as f:
    pickle.dump(activations, f, protocol=pickle.HIGHEST_PROTOCOL)
print(f'Successfully saved to {fpath}.')

# Free up the resources. `activations` must be deleted as well: it holds
# references to the same objects, so deleting only the three names below
# would not release any memory.
del activations, harmful, harmless, activation_refusals
gc.collect()
torch.cuda.empty_cache()