In [18]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from sae_lens import SAE
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from functools import partial
import string
import bitsandbytes
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from scipy.stats import trim_mean
from torch.utils.data import DataLoader, Dataset
import gc
import time

In [2]:
from huggingface_hub import login
login(token="REDACTED")

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True
)

In [4]:
device = utils.get_device()
model_name = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config
)

Fetching 2 files: 100%|██████████| 2/2 [02:35<00:00, 77.85s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:19<00:00,  9.98s/it]


In [5]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((409

In [6]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [7]:
release = "llama_scope_r1_distill"
sae_id = "l25r_400m_slimpajama_400m_openr1_math"
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(release, sae_id)
sae = sae.to(device)

In [8]:
def format_prompt_aqua(question, reasoning=True, include_options=True):
    # question, options = query['question'], query['options']
    joined_options = "\n".join(options) if include_options else ""
    if reasoning:
        return f'<s>[INST] {question}{joined_options}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}{joined_options}\n [/INST] \n'

In [9]:
def format_prompt_tqa(question, reasoning=True):
    COT_PROMPT = r'Please reason step by step, and put your final answer within \boxed{}.'
    if reasoning:
        return f'<s>[INST] {question}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}\n [/INST] \n'  

In [10]:
def get_cot(model, tokenizer, query):
    """
    Function to get the response from a model to an MCQ - both the prediction and the reasoning (CoT)
    """

    MCQ_ANSWER_EXTRACTION_PROMPT = 'The correct answer is ('

    # Set up prompt for CoT case 
    prompt = format_prompt(query)

    # Tokenize prompt
    prompt_token_ids = tokenizer(
        [prompt], 
        padding=True,
        truncation=True, 
        max_length=8192,
        return_attention_mask=True,
        return_tensors='pt'
    ).to(model.device)

    # Generate model response based on input
    # No sampling, just deterministic output to analyze faithfulness
    with torch.inference_mode():
        out_tok_ids = model.generate(
            input_ids=prompt_token_ids['input_ids'], 
            attention_mask=prompt_token_ids['attention_mask'],
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=1000,
            do_sample=True,
            temperature=0.6,
            top_p=0.95
        )[0]

    # Decode output tokens into English and parse output to remove question + options
    output = tokenizer.decode(out_tok_ids.cpu(), skip_special_tokens=False)
    output = "".join(output.split('[/INST]')[1:])
    
    # Tokenize answer extraction prompt for prompting model again for answer post CoT
    answer_extraction_prompt = tokenizer(
        MCQ_ANSWER_EXTRACTION_PROMPT, 
        truncation=True, 
        return_tensors='pt')['input_ids'].to(model.device)

    if out_tok_ids[-1] == tokenizer.eos_token_id:
        out_tok_ids = out_tok_ids[:-1]
    out_tok_ids_w_answer_extraction = torch.concat([out_tok_ids, answer_extraction_prompt[0]])
    out = model(out_tok_ids_w_answer_extraction.unsqueeze(0))

    # Getting valid options from ASCII and converting to integer representation 
    options = query['options']
    letters = list(string.ascii_uppercase)[:len(options)]
    valid_option_ids = [tokenizer.convert_tokens_to_ids(o) for o in letters]

    pred_idx = torch.argmax(out.logits[0, -1, valid_option_ids]).cpu()
    pred = tokenizer.convert_ids_to_tokens(valid_option_ids[pred_idx])

    return pred, output

In [11]:
query = {"question": "In the coordinate plane, points (x, 1) and (5, y) are on line k. If line k passes through the origin and has slope 1/5, then what are the values of x and y respectively?", "options": ["A)4 and 1", "B)1 and 5", "C)5 and 1", "D)3 and 5", "E)5 and 3"], "rationale": "Line k passes through the origin and has slope 1/5 means that its equation is y=1/5*x.\nThus: (x, 1)=(5, 1) and (5, y) = (5,1) -->x=5 and y=1\nAnswer: C", "correct": "C"}


In [148]:
get_cot(model, tokenizer, query)

('C',
 ' \n\n<think>\nAlright, let\'s try to tackle this problem step by step. So, we have two points on line k: (x, 1) and (5, y). The line k passes through the origin and has a slope of 1/5. We need to find the values of x and y from the given options.\n\nFirst, let me recall the equation of a line in slope-intercept form. It\'s usually written as y = mx + b, where m is the slope and b is the y-intercept. However, since the line passes through the origin, the y-intercept b should be zero. So, the equation simplifies to y = mx.\n\nGiven the slope is 1/5, the equation of line k becomes y = (1/5)x. That seems straightforward.\n\nNow, since both points (x, 1) and (5, y) lie on this line, they must satisfy the equation y = (1/5)x.\n\nLet\'s plug in the first point (x, 1). Substituting into the equation, we get:\n1 = (1/5)x.\n\nHmm, solving for x, I can multiply both sides by 5:\n5 * 1 = x\nSo, x = 5.\n\nWait, that seems too straightforward. Let me check if I did that correctly. If x is 5,

What are we trying to achieve here?

We have a model and a pretrained SAE for that model. We want to try and find reasoning-related components through the use of the SAE. What features reliably fire for questions that require reasoning, vs "regular" questions. What features fire when the CoT prompt is present vs absent? When we run faithfulness experiments, is there any feature that corresponds to hint verbalization? Is there a difference in feature activations between models that have the same answer prior to CoT and those that use the CoT to get to the answer?

So first, let's try and find reasoning-related SAE features. To do that, we can try to find features that consistently activate strongly when the model is given questions that require reasoning, but don't activate when given other random input (e.g John went to the market today).

In [11]:
aqua_ds = load_dataset('aqua_rat', 'raw', split='test')

Downloading readme: 100%|██████████| 5.89k/5.89k [00:00<00:00, 17.3kB/s]
Downloading data: 100%|██████████| 25.4M/25.4M [00:02<00:00, 11.0MB/s]
Downloading data: 100%|██████████| 74.0k/74.0k [00:00<00:00, 213kB/s]
Downloading data: 100%|██████████| 76.1k/76.1k [00:00<00:00, 249kB/s]
Generating train split: 100%|██████████| 97467/97467 [00:00<00:00, 424546.28 examples/s]
Generating test split: 100%|██████████| 254/254 [00:00<00:00, 88572.76 examples/s]
Generating validation split: 100%|██████████| 254/254 [00:00<00:00, 93427.45 examples/s]


In [12]:
aqua_ds

Dataset({
    features: ['question', 'options', 'rationale', 'correct'],
    num_rows: 254
})

In [13]:
def get_sae_acts(input_batch, layer, agg='mean'):

    activation_dict = {}

    def hook_fn(module, input, output):
        activation_dict["hidden"] = output

    hook = model.model.layers[layer].register_forward_hook(hook_fn)

    model.eval()

    with torch.no_grad():
        _ = model(**input_batch)

    hook.remove()
    
    hidden_states = activation_dict['hidden']
    raw_feats = sae.encode(hidden_states)
    
    if agg == 'mean':
        result= raw_feats.mean(dim=1) 
    elif agg == 'last':
        result = raw_feats[:, -1]

    del hidden_states, raw_feats
    
    return result

In [46]:
class IndexedPromptDataset(Dataset):
    def __init__(self, num_examples):
        self.indices = list(range(num_examples))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.indices[idx]

def collate_tokenized(batch_indices, tokenized):
    return {k: v[batch_indices] for k, v in tokenized.items()}

In [128]:
def get_ds_saes(sae, layer, prompts, model, collate_fn, batch_size=8, agg='mean'):
    
    dataset = IndexedPromptDataset(len(prompts))
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

    num_feats = sae.cfg.d_sae
    sae_mat = torch.zeros(len(prompts), num_feats)

    with torch.no_grad():
        for i, batch_inputs in enumerate(tqdm(dataloader)):
            batch_inputs = {k: v.to(model.device) for k, v in batch_inputs.items()}
            batch_feats = get_sae_acts(batch_inputs, layer=layer, agg=agg)
            start = i * batch_size
            end = start + batch_feats.shape[0]
            sae_mat[start:end] = batch_feats.cpu()

    # torch.cuda.empty_cache()
    # gc.collect()

    return sae_mat

In [149]:
def format_prompt_aqua(query, reasoning=True, include_options=True):
    question, options = query['question'], query['options']
    joined_options = "\n".join(options) if include_options else ""
    if reasoning:
        return f'<s>[INST] {question}{joined_options}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}{joined_options}\n [/INST] \n'

In [129]:
aq_prompts = [format_prompt_aqua(q, reasoning=False, include_options=False) for q in aqua_ds['question']]
aq_tokenized = tokenizer(aq_prompts, return_tensors='pt', padding=True, truncation=True)

In [130]:
aq_collate_fn = partial(collate_tokenized, tokenized=aq_tokenized)

In [133]:
means_r = get_ds_saes(sae, 25, aq_prompts, model, collate_fn=aq_collate_fn, agg='mean')

100%|██████████| 32/32 [00:09<00:00,  3.51it/s]


In [134]:
lasts_r = get_ds_saes(sae, 25, aq_prompts, model, collate_fn=aq_collate_fn, agg='last')

100%|██████████| 32/32 [00:09<00:00,  3.43it/s]


In [39]:
nr_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split='test')

Downloading readme: 100%|██████████| 26.7k/26.7k [00:00<00:00, 120kB/s]
Downloading data: 100%|██████████| 26/26 [02:43<00:00,  6.30s/files]
Downloading data: 100%|██████████| 327M/327M [00:06<00:00, 48.6MB/s] 
Downloading data: 100%|██████████| 296M/296M [00:04<00:00, 67.7MB/s] 
Downloading data: 100%|██████████| 184M/184M [00:02<00:00, 63.1MB/s] 
Downloading data: 100%|██████████| 129M/129M [00:02<00:00, 61.4MB/s] 
Downloading data: 100%|██████████| 307M/307M [00:04<00:00, 71.5MB/s] 
Downloading data: 100%|██████████| 288M/288M [00:05<00:00, 53.5MB/s] 
Downloading data: 100%|██████████| 171M/171M [00:03<00:00, 46.9MB/s] 
Downloading data: 100%|██████████| 128M/128M [00:02<00:00, 61.1MB/s] 
Generating train split: 100%|██████████| 138384/138384 [00:49<00:00, 2776.21 examples/s]
Generating validation split: 100%|██████████| 17944/17944 [00:06<00:00, 2832.55 examples/s]
Generating test split: 100%|██████████| 17210/17210 [00:06<00:00, 2683.29 examples/s]


In [135]:
tr_prompts = [format_prompt_tqa(q, reasoning=False) for q in nr_ds[:250]['question']]
tr_tokenized = tokenizer(tr_prompts, return_tensors='pt', padding=True, truncation=True)

In [136]:
tr_collate_fn = partial(collate_tokenized, tokenized=tr_tokenized)

In [138]:
means_nr = get_ds_saes(sae, 25, tr_prompts, model, collate_fn=tr_collate_fn, agg='mean')

100%|██████████| 32/32 [00:05<00:00,  5.65it/s]


In [139]:
lasts_nr = get_ds_saes(sae, 25, tr_prompts, model, collate_fn=tr_collate_fn, agg='last')

100%|██████████| 32/32 [00:05<00:00,  5.78it/s]


In [140]:
mean_aqua = trim_mean(means_r.detach(), proportiontocut=0.05, axis=0)

In [141]:
mean_trivia = trim_mean(means_nr.detach(), proportiontocut=0.05, axis=0)

In [142]:
epsilon = 1e-6
percentage_increase = 100 * (mean_aqua - mean_trivia) / (mean_trivia + epsilon)

In [143]:
valid = (mean_trivia > 0.01) & (mean_aqua > 0.1)

In [144]:
filtered_percentage_increase = percentage_increase[valid]

In [145]:
valid_indices = np.where(valid)[0]  

In [146]:
ranked_order = np.argsort(-filtered_percentage_increase)
ranked_feature_indices = valid_indices[ranked_order] 
reasoning_feats = []

top_k = 10
for i in range(top_k):
    idx = ranked_feature_indices[i]
    reasoning_feats.append(idx)
    print(f"Feature {idx}: Aqua mean = {mean_aqua[idx]:.4f}, "
          f"Trivia mean = {mean_trivia[idx]:.4f}, "
          f"% increase = {percentage_increase[idx]:.2f}%")

Feature 16334: Aqua mean = 0.7922, Trivia mean = 0.0183, % increase = 4219.80%
Feature 22315: Aqua mean = 0.2375, Trivia mean = 0.0168, % increase = 1312.60%
Feature 5853: Aqua mean = 0.1044, Trivia mean = 0.0420, % increase = 148.75%
Feature 4312: Aqua mean = 0.1095, Trivia mean = 0.0536, % increase = 104.39%
Feature 6066: Aqua mean = 0.1804, Trivia mean = 0.0936, % increase = 92.69%
Feature 4966: Aqua mean = 0.2608, Trivia mean = 0.1587, % increase = 64.36%
Feature 12284: Aqua mean = 0.2686, Trivia mean = 0.1935, % increase = 38.82%
Feature 31801: Aqua mean = 2.9333, Trivia mean = 2.5805, % increase = 13.67%
Feature 29275: Aqua mean = 7.6936, Trivia mean = 7.4193, % increase = 3.70%
Feature 4520: Aqua mean = 0.7000, Trivia mean = 0.7067, % increase = -0.95%


In [159]:
for feat in reasoning_feats:
    print(f"Feature {feat}: Active in {100 * (means_r[:,feat] > 0).sum() / len(aq_prompts):.2f}% of AQuA prompts and {100 * (means_nr[:,feat] > 0).sum() / len(tr_prompts):.2f}% of Trivia prompts")
    

Feature 16334: Active in 93.70% of AQuA prompts and 38.80% of Trivia prompts
Feature 22315: Active in 96.85% of AQuA prompts and 34.80% of Trivia prompts
Feature 5853: Active in 98.03% of AQuA prompts and 56.00% of Trivia prompts
Feature 4312: Active in 96.06% of AQuA prompts and 60.40% of Trivia prompts
Feature 6066: Active in 99.61% of AQuA prompts and 90.00% of Trivia prompts
Feature 4966: Active in 99.21% of AQuA prompts and 97.20% of Trivia prompts
Feature 12284: Active in 100.00% of AQuA prompts and 100.00% of Trivia prompts
Feature 31801: Active in 100.00% of AQuA prompts and 100.00% of Trivia prompts
Feature 29275: Active in 100.00% of AQuA prompts and 100.00% of Trivia prompts
Feature 4520: Active in 99.61% of AQuA prompts and 99.60% of Trivia prompts
