In [154]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from sae_lens import SAE
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from functools import partial
import string
import bitsandbytes
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from scipy.stats import trim_mean

In [2]:
from huggingface_hub import login
login(token="REDACTED")

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True
)

In [4]:
device = utils.get_device()
model_name = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
# model = HookedTransformer.from_pretrained_no_processing(model_name, device=device, dtype='float16')
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config
)

Loading checkpoint shards: 100%|██████████| 2/2 [00:16<00:00,  8.29s/it]


In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [6]:
release = "llama_scope_r1_distill"
sae_id = "l25r_400m_slimpajama_400m_openr1_math"
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(release, sae_id)
sae = sae.to(device)

In [7]:
def format_prompt_aqua(query, reasoning=True):
    question, options = query['question'], query['options']
    joined_options = "\n".join(options)
    COT_PROMPT = r'Please reason step by step, and put your final answer within \boxed{}.'
    if reasoning:
        return f'<s>[INST] {question}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}\n [/INST] \n'  

In [8]:
def format_prompt_tqa(question, reasoning=True):
    COT_PROMPT = r'Please reason step by step, and put your final answer within \boxed{}.'
    if reasoning:
        return f'<s>[INST] {question}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}\n [/INST] \n'  

In [9]:
def get_cot(model, tokenizer, query):
    """
    Function to get the response from a model to an MCQ - both the prediction and the reasoning (CoT)
    """

    MCQ_ANSWER_EXTRACTION_PROMPT = 'The correct answer is ('

    # Set up prompt for CoT case 
    prompt = format_prompt(query)

    # Tokenize prompt
    prompt_token_ids = tokenizer(
        [prompt], 
        padding=True,
        truncation=True, 
        max_length=8192,
        return_attention_mask=True,
        return_tensors='pt'
    ).to(model.device)

    # Generate model response based on input
    # No sampling, just deterministic output to analyze faithfulness
    with torch.inference_mode():
        out_tok_ids = model.generate(
            input_ids=prompt_token_ids['input_ids'], 
            attention_mask=prompt_token_ids['attention_mask'],
            pad_token_id=tokenizer.eos_token_id,
            max_new_tokens=1000,
            do_sample=True,
            temperature=0.6,
            top_p=0.95
        )[0]

    # Decode output tokens into English and parse output to remove question + options
    output = tokenizer.decode(out_tok_ids.cpu(), skip_special_tokens=False)
    output = "".join(output.split('[/INST]')[1:])
    
    # Tokenize answer extraction prompt for prompting model again for answer post CoT
    answer_extraction_prompt = tokenizer(
        MCQ_ANSWER_EXTRACTION_PROMPT, 
        truncation=True, 
        return_tensors='pt')['input_ids'].to(model.device)

    if out_tok_ids[-1] == tokenizer.eos_token_id:
        out_tok_ids = out_tok_ids[:-1]
    out_tok_ids_w_answer_extraction = torch.concat([out_tok_ids, answer_extraction_prompt[0]])
    out = model(out_tok_ids_w_answer_extraction.unsqueeze(0))

    # Getting valid options from ASCII and converting to integer representation 
    options = query['options']
    letters = list(string.ascii_uppercase)[:len(options)]
    valid_option_ids = [tokenizer.convert_tokens_to_ids(o) for o in letters]

    pred_idx = torch.argmax(out.logits[0, -1, valid_option_ids]).cpu()
    pred = tokenizer.convert_ids_to_tokens(valid_option_ids[pred_idx])

    return pred, output

In [10]:
query = {"question": "In the coordinate plane, points (x, 1) and (5, y) are on line k. If line k passes through the origin and has slope 1/5, then what are the values of x and y respectively?", "options": ["A)4 and 1", "B)1 and 5", "C)5 and 1", "D)3 and 5", "E)5 and 3"], "rationale": "Line k passes through the origin and has slope 1/5 means that its equation is y=1/5*x.\nThus: (x, 1)=(5, 1) and (5, y) = (5,1) -->x=5 and y=1\nAnswer: C", "correct": "C"}


In [148]:
get_cot(model, tokenizer, query)

('C',
 ' \n\n<think>\nAlright, let\'s try to tackle this problem step by step. So, we have two points on line k: (x, 1) and (5, y). The line k passes through the origin and has a slope of 1/5. We need to find the values of x and y from the given options.\n\nFirst, let me recall the equation of a line in slope-intercept form. It\'s usually written as y = mx + b, where m is the slope and b is the y-intercept. However, since the line passes through the origin, the y-intercept b should be zero. So, the equation simplifies to y = mx.\n\nGiven the slope is 1/5, the equation of line k becomes y = (1/5)x. That seems straightforward.\n\nNow, since both points (x, 1) and (5, y) lie on this line, they must satisfy the equation y = (1/5)x.\n\nLet\'s plug in the first point (x, 1). Substituting into the equation, we get:\n1 = (1/5)x.\n\nHmm, solving for x, I can multiply both sides by 5:\n5 * 1 = x\nSo, x = 5.\n\nWait, that seems too straightforward. Let me check if I did that correctly. If x is 5,

What are we trying to achieve here?

We have a model and a pretrained SAE for that model. We want to try and find reasoning-related components through the use of the SAE. What features reliably fire for questions that require reasoning, vs "regular" questions. What features fire when the CoT prompt is present vs absent? When we run faithfulness experiments, is there any feature that corresponds to hint verbalization? Is there a difference in feature activations between models that have the same answer prior to CoT and those that use the CoT to get to the answer?

So first, let's try and find reasoning-related SAE features. To do that, we can try to find features that consistently activate strongly when the model is given questions that require reasoning, but don't activate when given other random input (e.g John went to the market today).

In [11]:
aqua_ds = load_dataset('aqua_rat', 'raw', split='test')

In [12]:
aqua_ds

Dataset({
    features: ['question', 'options', 'rationale', 'correct'],
    num_rows: 254
})

In [13]:
def get_sae_acts(text, layer, agg='mean'):
    model.eval()
    with torch.no_grad():
        outputs = model(**text, output_hidden_states=True, return_dict=True)
    hidden_states = outputs.hidden_states
    raw_feats = sae.encode(hidden_states[layer])
    if agg == 'mean':
        return raw_feats[0].mean(axis=0) 
    elif agg == 'last':
        return raw_feats[0, -1]

In [14]:
means_r = torch.zeros((aqua_ds.num_rows, sae.cfg.d_sae))

In [15]:
for index, question in enumerate(tqdm(aqua_ds)):
    prompt = format_prompt_aqua(question, reasoning=False)
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    means_r[index, :] = get_sae_acts(inputs, 25)

100%|██████████| 254/254 [00:23<00:00, 10.61it/s]


In [16]:
lasts_r = torch.zeros((aqua_ds.num_rows, sae.cfg.d_sae))

In [17]:
for index, question in enumerate(tqdm(aqua_ds)):
    prompt = format_prompt_aqua(question, reasoning=False)
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    lasts_r[index, :] = get_sae_acts(inputs, 25, agg='last')

100%|██████████| 254/254 [00:23<00:00, 10.64it/s]


In [18]:
nr_ds = load_dataset("mandarjoshi/trivia_qa", "rc", split='test')

In [19]:
means_nr = torch.zeros((250, sae.cfg.d_sae))

In [20]:
for index, question in enumerate(tqdm(nr_ds[:250]['question'])):
    prompt = format_prompt_tqa(question, reasoning=False)
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    means_nr[index, :] = get_sae_acts(inputs, 25)

100%|██████████| 250/250 [00:19<00:00, 12.72it/s]


In [21]:
lasts_nr = torch.zeros((250, sae.cfg.d_sae))

In [22]:
for index, question in enumerate(tqdm(nr_ds[:250]['question'])):
    prompt = format_prompt_tqa(question, reasoning=False)
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    lasts_nr[index, :] = get_sae_acts(inputs, 25, agg='last')

100%|██████████| 250/250 [00:19<00:00, 12.66it/s]


In [156]:
# mean_aqua = means_r.mean(axis=0)
mean_aqua = trim_mean(means_r.detach(), proportiontocut=0.05, axis=0)

In [157]:
# mean_trivia = means_nr.mean(axis=0) 
mean_trivia = trim_mean(means_nr.detach(), proportiontocut=0.05, axis=0)

In [159]:
epsilon = 1e-6
percentage_increase = 100 * (mean_aqua - mean_trivia) / (mean_trivia + epsilon)

In [160]:
valid = (mean_trivia > 0.01) & (mean_aqua > 0.1)

In [161]:
filtered_percentage_increase = percentage_increase[valid]

In [162]:
valid_indices = np.where(valid)[0]  

In [165]:
ranked_order = np.argsort(-filtered_percentage_increase)
ranked_feature_indices = valid_indices[ranked_order] 
reasoning_feats = []

top_k = 10
for i in range(top_k):
    idx = ranked_feature_indices[i]
    reasoning_feats.append(idx)
    print(f"Feature {idx}: Aqua mean = {mean_aqua[idx]:.4f}, "
          f"Trivia mean = {mean_trivia[idx]:.4f}, "
          f"% increase = {percentage_increase[idx]:.2f}%")

Feature 16334: Aqua mean = 1.3904, Trivia mean = 0.0126, % increase = 10952.28%
Feature 22315: Aqua mean = 0.4217, Trivia mean = 0.0142, % increase = 2864.22%
Feature 9971: Aqua mean = 0.1110, Trivia mean = 0.0293, % increase = 278.54%
Feature 5853: Aqua mean = 0.1697, Trivia mean = 0.0533, % increase = 218.07%
Feature 4312: Aqua mean = 0.1986, Trivia mean = 0.0863, % increase = 130.04%
Feature 6066: Aqua mean = 0.3660, Trivia mean = 0.1667, % increase = 119.53%
Feature 4966: Aqua mean = 0.5258, Trivia mean = 0.3160, % increase = 66.37%
Feature 31789: Aqua mean = 0.1309, Trivia mean = 0.1006, % increase = 30.07%
Feature 12284: Aqua mean = 0.5369, Trivia mean = 0.4195, % increase = 27.98%
Feature 31801: Aqua mean = 5.8519, Trivia mean = 5.2715, % increase = 11.01%


In [166]:
for feat in reasoning_feats:
    print(f"Feature {feat}: Aqua active = { 100 * (means_r[:, feat] > 0).sum() / 254:.4f}, Trivia active = {100 * (means_nr[:, feat] > 0).sum() / 250:.4f}")
    

Feature 16334: Aqua active = 92.1260, Trivia active = 19.2000
Feature 22315: Aqua active = 95.6693, Trivia active = 20.8000
Feature 9971: Aqua active = 79.9213, Trivia active = 31.2000
Feature 5853: Aqua active = 94.4882, Trivia active = 43.2000
Feature 4312: Aqua active = 93.3071, Trivia active = 57.6000
Feature 6066: Aqua active = 99.2126, Trivia active = 80.8000
Feature 4966: Aqua active = 99.6063, Trivia active = 97.6000
Feature 31789: Aqua active = 100.0000, Trivia active = 100.0000
Feature 12284: Aqua active = 100.0000, Trivia active = 100.0000
Feature 31801: Aqua active = 100.0000, Trivia active = 100.0000
