## Load Packages

In [16]:
from datasets import load_dataset
import os
from transformers import AutoTokenizer, AutoModelForCausalLM
from tqdm import tqdm
from transformers import GenerationConfig
from collections import defaultdict
from functools import partial
import accelerate
import torch


## Load Model

In [4]:
# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
MODEL_NAME = "Qwen/Qwen3-8B"

# Use the second GPU when the machine has one; otherwise fall back to the first
# GPU or to the CPU instead of failing on a non-existent "cuda:1".
if torch.cuda.is_available():
    device = "cuda:1" if torch.cuda.device_count() > 1 else "cuda:0"
else:
    device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    # NOTE(review): trust_remote_code allows code shipped with the checkpoint to run;
    # keep it only if this checkpoint actually needs it.
    trust_remote_code=True,
    # device_map="auto",   # alternative: let accelerate shard across several GPUs
    dtype="auto",   # or torch.float16 / bfloat16
).to(device)

Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [19]:

def load_data(n_harm=150, n_clean=200, n_malicious=50):
    """Build the harmful and harmless prompt lists used for activation extraction.

    Args:
        n_harm: Number of leading rows taken from ``walledai/AdvBench``.
        n_clean: Number of leading rows taken from ``yahma/alpaca-cleaned``.
        n_malicious: Number of leading rows taken from ``walledai/MaliciousInstruct``.

    Returns:
        A tuple ``(harmful_prompts, harmless_prompts)`` of lists of strings.
        The harmful list is the AdvBench prompts followed by the
        MaliciousInstruct prompts; the harmless list joins each Alpaca
        ``instruction`` and ``input`` with a space and strips the result.

    Note:
        The counts are passed straight to ``Dataset.select``; a count larger
        than the split is expected to raise there rather than be truncated.
        With the defaults both returned lists have 200 entries.
    """
    harm_ds = load_dataset("walledai/AdvBench")
    alpaca_dataset = load_dataset("yahma/alpaca-cleaned")
    malicious_ds = load_dataset("walledai/MaliciousInstruct")

    harm = harm_ds["train"].select_columns(["prompt"]).select(range(n_harm))
    clean = alpaca_dataset["train"].select_columns(["instruction", "input"]).select(range(n_clean))
    malicious = malicious_ds["train"].select_columns(["prompt"]).select(range(n_malicious))

    harm_ls = [p['prompt'] for p in harm]
    # Alpaca rows may have an empty "input"; strip() removes the dangling space.
    clean_ls = [f"{p['instruction']} {p['input']}".strip() for p in clean]
    malicious_ls = [p['prompt'] for p in malicious]

    return (harm_ls + malicious_ls, clean_ls)

harm_test, clean_test = load_data()


## Functions to get activations

In [20]:
# Start from a clean slate: empty the forward, forward-pre and backward hook
# registries of every submodule so hooks left by an earlier run cannot fire again.
for submodule in model.modules():
    for registry_name in ("_forward_hooks", "_forward_pre_hooks", "_backward_hooks"):
        getattr(submodule, registry_name).clear()

def get_full_block_activations(insts, tokenizer_kwargs=None, gen_kwargs=None, trailing_tokens=5):
    """Record the output of every decoder block while generating for each prompt.

    Relies on the notebook-level ``model`` and ``tokenizer``. Each prompt is
    wrapped with the chat template, tokenized on its own (batch of one) and
    passed to ``model.generate``; forward hooks on ``model.model.layers``
    capture the block outputs.

    Args:
        insts: Iterable of user prompts (strings).
        tokenizer_kwargs: Extra keyword arguments for the tokenizer call.
            ``None`` means no extras.
        gen_kwargs: Keyword arguments for ``model.generate``. ``None`` means
            ``{"max_new_tokens": 4}``.
        trailing_tokens: How many of the last prompt positions to keep from
            the multi-token (prompt) forward pass.

    Returns:
        ``result["layer_<i>"][step]`` -> list with one 1-D tensor per prompt
        (CPU, hidden-size long). Step labels:
          * ``0``  : last prompt token,
          * ``-k`` : k positions before the last prompt token,
          * ``k>0``: the k-th single-token forward pass after the prompt pass.
        A prompt contributes nothing to a step it never reaches, so the lists
        under different steps can have different lengths.

    Notes:
        * Only batch element 0 of each block output is read.
        * NOTE(review): if the model's generation config enables sampling,
          steps >= 1 depend on the sampled tokens; pass ``do_sample=False``
          in ``gen_kwargs`` for reproducible activations -- confirm.
    """
    if tokenizer_kwargs is None:
        tokenizer_kwargs = {}
    if gen_kwargs is None:
        gen_kwargs = {"max_new_tokens": 4}

    temp_hidden_state_dict = defaultdict(lambda: defaultdict(list))
    step_counter = [-1]  # one-element list: a mutable counter shared by the closures below

    def hook_fn(module, input, output, key):
        # Decoder blocks return either a bare tensor or a tuple whose first
        # item is the hidden state, depending on the library version.
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        seq = hidden[0].detach()  # (seq_len, hidden_dim) for the single prompt
        seq_len = seq.shape[0]

        if seq_len > 1:
            # Prompt pass: keep only the trailing positions, and copy just
            # those rows off the device instead of the whole sequence.
            n_keep = min(trailing_tokens, seq_len)
            tail = seq[seq_len - n_keep:].cpu()
            for i in range(n_keep):
                step_label = 0 if i == 0 else -i
                # clone() so the stored vector does not pin the rest of `tail`.
                temp_hidden_state_dict[key][step_label].append(tail[n_keep - 1 - i].clone())
        else:
            # Generation pass: exactly one token, labelled by the forward-call index.
            temp_hidden_state_dict[key][step_counter[0]].append(seq[0].cpu())

    orig_forward = model.forward
    # Remember whether `forward` was already an instance attribute so the
    # model can be returned to exactly its previous state afterwards.
    had_instance_forward = "forward" in vars(model)

    def forward_with_step_tracking(*args, **kwargs):
        step_counter[0] += 1
        return orig_forward(*args, **kwargs)

    hook_handles = []
    model.forward = forward_with_step_tracking
    try:
        # Register one hook per decoder block.
        for i, block in enumerate(model.model.layers):
            hook_handles.append(
                block.register_forward_hook(partial(hook_fn, key=f"layer_{i}"))
            )

        for inst in tqdm(insts):
            step_counter[0] = -1  # reset per prompt
            prompt = tokenizer.apply_chat_template(
                [{"role": "user", "content": inst}],
                add_generation_prompt=True,
                tokenize=False,
            )
            inputs = tokenizer(prompt, return_tensors="pt", **tokenizer_kwargs)
            inputs = {k: v.to(model.device) for k, v in inputs.items()}

            with torch.no_grad():
                _ = model.generate(**inputs, **gen_kwargs)
    finally:
        # Always undo the instrumentation, even if generation failed midway.
        for handle in hook_handles:
            handle.remove()
        if had_instance_forward:
            model.forward = orig_forward
        else:
            del model.forward

    return temp_hidden_state_dict


In [21]:
# Collect per-layer activations for the harmful and the harmless prompt lists.
harmful_vectors = get_full_block_activations(harm_test)
harmless_vectors = get_full_block_activations(clean_test)

100%|██████████| 200/200 [00:45<00:00,  4.39it/s]
100%|██████████| 200/200 [00:46<00:00,  4.33it/s]


## Subtracting the Means

In [23]:
# Per layer and per step: mean harmful activation, mean harmless activation,
# and their difference (the candidate refusal direction).
mean_harm = defaultdict(dict)
mean_clean = defaultdict(dict)
refusal_vectors = defaultdict(dict)
skipped_steps = []  # (layer, step) pairs that had no samples on one of the two sides

for layer in tqdm(harmful_vectors.keys()):
    # .get() instead of [] so probing a defaultdict does not insert empty entries.
    clean_by_step = harmless_vectors.get(layer, {})
    for step, harm_acts in harmful_vectors[layer].items():
        clean_acts = clean_by_step.get(step, [])
        if not harm_acts or not clean_acts:
            # torch.stack refuses an empty list; a step can be missing on one
            # side when generation stopped early for every prompt of that set.
            skipped_steps.append((layer, step))
            continue

        # Stack once per step per layer
        harm_stack = torch.stack(harm_acts, dim=0)  # shape (N, ...)
        clean_stack = torch.stack(clean_acts, dim=0)
        out_dtype = harm_stack.dtype

        # Average and subtract in float32 so reduced-precision activations do
        # not lose the (small) difference to rounding; cast back at the end.
        mean_harm_vec = harm_stack.float().mean(dim=0)
        mean_clean_vec = clean_stack.float().mean(dim=0)

        mean_harm[layer][step] = mean_harm_vec.to(out_dtype)
        mean_clean[layer][step] = mean_clean_vec.to(out_dtype)
        refusal_vectors[layer][step] = (mean_harm_vec - mean_clean_vec).to(out_dtype)

if skipped_steps:
    print(f"Skipped {len(skipped_steps)} (layer, step) pairs with no samples on one side")

100%|██████████| 36/36 [00:01<00:00, 21.70it/s]


## Saving The Vectors

In [24]:
def save_candiate_vectors(filename, temp_hidden_state_dict):
    """Serialize a nested container of tensors to ``filename`` with ``torch.save``.

    Before saving, the structure is rebuilt from plain containers:
      * every dict (including ``defaultdict``) becomes a plain ``dict`` -- a
        ``defaultdict`` whose factory is a lambda cannot be pickled;
      * lists and tuples are rebuilt element by element;
      * every tensor is detached and moved to the CPU, whether it sits in a
        list or directly as a dict value;
      * anything else is stored unchanged.

    Args:
        filename: Destination path handed to ``torch.save``.
        temp_hidden_state_dict: Nested dicts/lists whose leaves are tensors,
            e.g. ``{layer: {step: tensor}}`` or ``{layer: {step: [tensor, ...]}}``.

    The function name keeps its original spelling so existing calls still work.
    """
    def convert(obj):
        if isinstance(obj, dict):
            return {k: convert(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [convert(item) for item in obj]
        if isinstance(obj, tuple):
            return tuple(convert(item) for item in obj)
        if isinstance(obj, torch.Tensor):
            return obj.detach().cpu()
        return obj

    normal_dict = convert(temp_hidden_state_dict)
    torch.save(normal_dict, filename)

# Write the per-layer, per-step refusal vectors to disk.
save_candiate_vectors("refusal_vectors.pt", refusal_vectors)