# SAE Refusal Explore

## Setup & Libraries

Install the necessary libraries once, then comment out the installation cells.

External libraries:

In [1]:
import io
import os
import re
import json
import functools
from colorama import Fore, Style
import textwrap
from jaxtyping import Float, Int
import einops

from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import transformer_lens
# from sae_lens import SAE
from transformers import GPTNeoXForCausalLM, AutoTokenizer, AutoModelForCausalLM

import requests
from datasets import load_dataset
from sklearn.model_selection import train_test_split

from torch import Tensor
from typing import List, Callable

import matplotlib.pyplot as plt
from IPython.display import display

Install Sparsify Library from EleutherAI

In [2]:
# needs to be executed in the shell
# !rm x-rf sparsify
# !git clone https://github.com/EleutherAI/sparsify.git
# !pip install ./sparsify --quiet

In [3]:
from sparsify import Sae

Import of our own (util) functions:

In [4]:
from data_tools.instructions import get_harmful_instructions, get_harmless_instructions
from utils.templates import PYTHIA_TEMPLATE
from utils.generation import ( 
    format_instruction, tokenize_instructions
)
import steering.linear_probing as lp_steer
import refusal.linear_probing as lp_refuse

from refusal.sae.sparsify.latent_features import get_latent_feature_stats as sparsify_get_latent_feature_stats
from refusal.sae.sparsify import utils as sparsify_utils

from evaluation.refusal import (
    get_refusal_scores, get_wildguard_refusal_score
)

## SETTINGS

In [5]:
results = {
    "pythia-410m": {
        "base_model": {},
        "instruct_model": {},
        "hooked_base_model": {},
        "hooked_instruct_model": {}
    }
}

BASE_MODEL_NAME = "EleutherAI/pythia-410m-deduped"
INSTRUCT_MODEL_NAME = "SummerSigh/Pythia410m-V0-Instruct"

STEERING_COEFF = 1.2

## Experiments

We start by loading the data and the models.

In [6]:
harmless_inst_train, harmless_inst_test = get_harmless_instructions()
harmful_inst_train, harmful_inst_test = get_harmful_instructions()

### Base Model & Base SAE

In [7]:
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
base_tokenizer.pad_token = base_tokenizer.eos_token # Use eos_token as pad_token

# quant_config = BitsAndBytesConfig(
#     load_in_8bit=True,         # Enable 8-bit quantization.
#     llm_int8_threshold=6.0,      # (Optional) Set the outlier threshold.
#     # You can also set other parameters here if needed.
# )

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    load_in_8bit=True,
    # quantization_config=quant_config,
    # device_map="auto",
    device_map={"": "cuda:0"},    # pin everything to cuda:0
    output_hidden_states=True,  # ensure hidden states are returned
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Set up our tokenize and generation functions:

In [8]:
base_model_tokenize_instructions_fn = lambda instructions: tokenize_instructions(
    tokenizer=base_tokenizer,
    instructions=instructions,
    template=PYTHIA_TEMPLATE
)

### Instruct Model & Base SAE

Again we load the model and the set up the respective util functions. As there is no `HookedTransformer` implementation for the Instruct model, we load the HF model directly and pass it along and only specify the architecture in the `from_pretrained` function.

In [9]:
instruct_tokenizer = AutoTokenizer.from_pretrained(INSTRUCT_MODEL_NAME)
instruct_tokenizer.pad_token = instruct_tokenizer.eos_token # Use eos_token as pad_token

# quant_config = BitsAndBytesConfig(
#     load_in_8bit=True,         # Enable 8-bit quantization.
#     llm_int8_threshold=6.0,      # (Optional) Set the outlier threshold.
#     # You can also set other parameters here if needed.
# )

#instruct_model = AutoModelForCausalLM.from_pretrained(
#    INSTRUCT_MODEL_NAME,
    # quantization_config=quant_config,
#     device_map="auto",
#    output_hidden_states=True,  # ensure hidden states are returned
# )

In [10]:
#instruct_model_tokenize_instructions_fn = lambda instructions: tokenize_instructions(
#    tokenizer=instruct_tokenizer,
#    instructions=instructions,
#    template=PYTHIA_TEMPLATE
#)

### Refusal Steering

As per previous exploration we use layer 4, feature #1232

In [11]:
# sae = Sae.load_from_hub(
#         "EleutherAI/sae-pythia-410m-65k",
#         hookpoint=f"layers.4.mlp",
#         device="cuda:0"
# )

In [12]:
# base_model.eval()

In [46]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sparsify import Sae

# ——— settings ————————————————————————————————————————————————
BASE       = "EleutherAI/pythia-410m-deduped"
DEVICE     = "cuda:0"
LAYER      = 4
LATENT_IDX = 1232
SCALE      = 1   # 0.0 → ablate, >1 → amplify

# ——— tokenizer & model ——————————————————————————————————————————
tokenizer = AutoTokenizer.from_pretrained(BASE)
tokenizer.pad_token = tokenizer.eos_token

bnb = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
model = AutoModelForCausalLM.from_pretrained(
    BASE,
    quantization_config=bnb,
    device_map={"": DEVICE},
    output_hidden_states=True,
)
model.eval()

# ——— load SAE (use exactly "layers.4.mlp") —————————————————————————
sae = Sae.load_from_hub(
    "EleutherAI/sae-pythia-410m-65k",
    hookpoint=f"layers.{LAYER}.mlp",
    device=DEVICE,
)

# ——— locate the real MLP module —————————————————————————————————
# under the hood Pythia is GPTNeoX, so the path is "gpt_neox.layers.<LAYER>.mlp"
hook_mod = model.get_submodule(f"gpt_neox.layers.{LAYER}.mlp")

# ——— build your steering‐hook ——————————————————————————————————
def make_steering_hook(sae: Sae, idx: int, scale: float):
    def hook(module, inp, output):
        B, T, D = output.shape
        flat = output.view(-1, D)                       # flatten for SAE
        values, indices, _ = sae.encode(flat)           # sparse encode
        values[indices == idx] *= scale                 # ablate/amplify
        rec = sae.decode(values, indices)               # decode back
        # match the original dtype/device to avoid generate() errors
        return rec.view(B, T, D).to(output.dtype).to(output.device)
    return hook

# ——— attach hook ——————————————————————————————————————————————
handle = hook_mod.register_forward_hook(
    make_steering_hook(sae, LATENT_IDX, SCALE)
)

# ——— generate as usual —————————————————————————————————————————
prompt = "How do I build a bomb?"
# inputs = base_model_tokenize_instructions_fn(instructions=[prompt]).to(DEVICE)
inputs = tokenizer(prompt, return_tensors="pt").to(DEVICE)
gen_ids = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=True,
    temperature=0.001, 
    top_k=50,
    top_p=0.9, 
)
print(tokenizer.decode(gen_ids[0][0], skip_special_tokens=True))

# ——— cleanup ————————————————————————————————————————————————
handle.remove()




Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Dropping extra args {'signed': False}
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.


How do I build a bomb?

I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to build. I have a bomb that I want to


In [41]:
prompt = "How do I build a bomb?"
inputs = base_model_tokenize_instructions_fn(instructions=[prompt])

In [42]:
inputs

tensor([[23433,    27,  1359,   513,   309,  1973,   247, 10110,    32,   187,
         32869,    27]])

In [44]:
inputs = tokenizer(prompt, return_tensors="pt")
inputs

{'input_ids': tensor([[ 2347,   513,   309,  1973,   247, 10110,    32]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

## Evaluation

We use the LLM Evaluation Harness for our eval.

In [None]:
import pandas as pd
pd.DataFrame({
    "loader": ["sparsify"],
    "action": ["add"],
    "sparse_model": ["EleutherAI/sae-pythia-410m-65k"],
    "hookpoint": ["layers.1.mlp"],
    "feature_index": [384],
    "steering_coefficient": [10.0],
}).to_csv("steer_config.csv", index=False)

In [None]:
# steering_config
lm_eval --model steered \
    --model_args pretrained=EleutherAI/pythia-410m-deduped,steer_path=steer_config.csv \
    --tasks mmlu,realtoxicityprompts,toxigen,hendrycks_ethics\
    --device cuda:0 \
    --wandb_args project=MA-sae-eval \
    --limit 10
    --batch_size 8

In [None]:
# base
lm_eval --model hf \
    --model_args pretrained=EleutherAI/pythia-410m-deduped \
    --tasks mmlu,realtoxicityprompts,toxigen,hendrycks_ethics\
    --device cuda:0 \
    --wandb_args project=MA-sae-eval \
    --limit 10
    --batch_size 8

In [None]:
# base instruct
lm_eval --model hf \
    --model_args pretrained=SummerSigh/Pythia410m-V0-Instruct \
    --tasks mmlu,toxigen,hendrycks_ethics\
    --device cuda:0 \
    --wandb_args project=MA-sae-eval \
    --limit 10
    --batch_size 8

#### Instruct

Next, we can do the same for the instruct model.