In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from sae_lens import SAE
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens import HookedTransformer
from functools import partial
import string
import bitsandbytes
from datasets import load_dataset
from tqdm import tqdm
import numpy as np
from scipy.stats import trim_mean
from torch.utils.data import DataLoader, Dataset
import gc
import time

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
from huggingface_hub import login
login(token="REDACTED")

In [3]:
# torch.cuda.empty_cache()   
# gc.collect()

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_enable_fp32_cpu_offload=True
)

In [5]:
device = utils.get_device()
model_name = 'deepseek-ai/DeepSeek-R1-Distill-Llama-8B'
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config
)

Fetching 2 files: 100%|██████████| 2/2 [02:35<00:00, 77.70s/it] 
Loading checkpoint shards: 100%|██████████| 2/2 [00:18<00:00,  9.17s/it]


In [6]:
model.eval()

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((409

In [7]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [8]:
release = "llama_scope_r1_distill"
sae_id = "l25r_400m_slimpajama_400m_openr1_math"
sae, cfg_dict, sparsity = SAE.from_pretrained_with_cfg_and_sparsity(release, sae_id)
sae = sae.to(device)

In [9]:
def format_prompt_aqua(query, reasoning=True, include_options=True):
    COT_PROMPT = r'Please reason step by step, and put your final answer within \boxed{}.'
    question, options = query['question'], query['options']
    joined_options = "\n".join(options) if include_options else ""
    if reasoning:
        return f'<s>[INST] {question}{joined_options}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}{joined_options}\n [/INST] \n'

In [10]:
def format_prompt_tqa(question, reasoning=True):
    COT_PROMPT = r'Please reason step by step, and put your final answer within \boxed{}.'
    if reasoning:
        return f'<s>[INST] {question}\n{COT_PROMPT} [/INST] \n<think>\n'
    else:
        return f'<s>[INST] {question}\n [/INST] \n'  

In [11]:
class TokenizedPromptDataset(Dataset):
    def __init__(self, tokenized_inputs, queries):
        self.input_ids = tokenized_inputs['input_ids']
        self.attn_masks = tokenized_inputs['attention_mask']
        self.queries = queries

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        return {
            'input_ids': self.input_ids[idx],
            'attention_mask': self.attn_masks[idx],
            'query': self.queries[idx]
        }

In [12]:
def collate_tokenized(batch):
    input_ids = torch.stack([item['input_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    queries = [item['query'] for item in batch]
    return {'input_ids': input_ids, 'attention_mask': attention_mask}, queries

In [14]:
def get_cot_batch(ds, batch_size, tokenizer, model, collate_fn):
    
    tokenized = tokenizer(
        [format_prompt_aqua(q, reasoning=True) for q in ds],
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=8192
    )

    dataset = TokenizedPromptDataset(tokenized, ds)
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

    all_preds = []
    all_generations = []
    MCQ_ANSWER_PROMPT = 'The correct answer is ('
    
    for batch_inputs, queries in tqdm(dataloader):
        
        input_ids = batch_inputs['input_ids'].to(model.device)
        attention_mask = batch_inputs['attention_mask'].to(model.device)

        with torch.inference_mode():
            output_ids = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.6,
                top_p=0.95,
                pad_token_id=tokenizer.eos_token_id
            )
    
        decoded = tokenizer.batch_decode(output_ids, skip_special_tokens=False)
        decoded = ["".join(out.split('[/INST]')[1:]) for out in decoded]
        all_generations.extend(decoded)
    
        answer_prompts = [text + MCQ_ANSWER_PROMPT for text in decoded]
        answer_inputs = tokenizer(answer_prompts, return_tensors='pt', padding=True, truncation=True).to(model.device)
    
        with torch.inference_mode():
            out = model(**answer_inputs)
    
        for i, query in enumerate(queries):
            options = query['options']
            letters = list(string.ascii_uppercase)[:len(options)]
            valid_ids = tokenizer.convert_tokens_to_ids(letters)
            logits = out.logits[i, -1, valid_ids]
            pred_idx = torch.argmax(logits).item()
            all_preds.append(letters[pred_idx])

    return all_preds, all_generations
    

In [18]:
preds, gens = get_cot_batch(aqua_ds, 16, tokenizer, model, collate_tokenized)

100%|██████████| 16/16 [14:41<00:00, 55.10s/it]


In [16]:
aqua_ds = load_dataset('aqua_rat', 'raw', split='test')

Downloading readme: 100%|██████████| 5.89k/5.89k [00:00<00:00, 19.3kB/s]
Downloading data: 100%|██████████| 25.4M/25.4M [00:02<00:00, 12.2MB/s]
Downloading data: 100%|██████████| 74.0k/74.0k [00:00<00:00, 239kB/s]
Downloading data: 100%|██████████| 76.1k/76.1k [00:00<00:00, 247kB/s]
Generating train split: 100%|██████████| 97467/97467 [00:00<00:00, 417344.26 examples/s]
Generating test split: 100%|██████████| 254/254 [00:00<00:00, 83406.66 examples/s]
Generating validation split: 100%|██████████| 254/254 [00:00<00:00, 97203.76 examples/s]


In [17]:
aqua_ds

Dataset({
    features: ['question', 'options', 'rationale', 'correct'],
    num_rows: 254
})

In [21]:
def get_sae_acts(input_batch, layer, agg='mean'):

    activation_dict = {}

    def hook_fn(module, input, output):
        activation_dict["hidden"] = output

    hook = model.model.layers[layer].register_forward_hook(hook_fn)

    model.eval()

    with torch.no_grad():
        _ = model(**input_batch)

    hook.remove()
    
    hidden_states = activation_dict['hidden']
    raw_feats = sae.encode(hidden_states)
    
    if agg == 'mean':
        result= raw_feats.mean(dim=1) 
    elif agg == 'last':
        result = raw_feats[:, -1]

    del hidden_states, raw_feats
    
    return result

In [22]:
class IndexedPromptDataset(Dataset):
    def __init__(self, num_examples):
        self.indices = list(range(num_examples))

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return self.indices[idx]

def collate_tokenized(batch_indices, tokenized):
    return {k: v[batch_indices] for k, v in tokenized.items()}

In [23]:
def get_ds_saes(sae, layer, prompts, model, collate_fn, batch_size=8, agg='mean'):
    
    dataset = IndexedPromptDataset(len(prompts))
    dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)

    num_feats = sae.cfg.d_sae
    sae_mat = torch.zeros(len(prompts), num_feats)

    with torch.no_grad():
        for i, batch_inputs in enumerate(tqdm(dataloader)):
            batch_inputs = {k: v.to(model.device) for k, v in batch_inputs.items()}
            batch_feats = get_sae_acts(batch_inputs, layer=layer, agg=agg)
            start = i * batch_size
            end = start + batch_feats.shape[0]
            sae_mat[start:end] = batch_feats.cpu()

    return sae_mat

In [24]:
aq_prompts = [format_prompt_aqua(q) for q in aqua_ds]
aq_tokenized = tokenizer(aq_prompts, return_tensors='pt', padding=True, truncation=True)

In [26]:
aq_collate_fn = partial(collate_tokenized, tokenized=aq_tokenized)

In [27]:
means_r = get_ds_saes(sae, 25, aq_prompts, model, collate_fn=aq_collate_fn, agg='mean')

100%|██████████| 32/32 [00:12<00:00,  2.63it/s]


In [30]:
aq_cots_tokenized = tokenizer(gens, return_tensors='pt', padding=True, truncation=True)
aq_cot_collate_fn = partial(collate_tokenized, tokenized=aq_cots_tokenized)

In [31]:
means_cot = get_ds_saes(sae, 25, gens, model, collate_fn=aq_cot_collate_fn, agg='mean')

100%|██████████| 32/32 [00:39<00:00,  1.24s/it]


In [32]:
mean_q = trim_mean(means_r.detach(), proportiontocut=0.05, axis=0)

In [33]:
means_cot = trim_mean(means_cot.detach(), proportiontocut=0.05, axis=0)

In [34]:
epsilon = 1e-6
percentage_increase = 100 * (means_cot - mean_q) / (mean_q + epsilon)

In [52]:
valid = (means_cot > 0.1) & (mean_q > 0.01)

In [53]:
filtered_percentage_increase = percentage_increase[valid]

In [54]:
valid_indices = np.where(valid)[0]  

In [55]:
ranked_order = np.argsort(-filtered_percentage_increase)
ranked_feature_indices = valid_indices[ranked_order] 
reasoning_feats = []

top_k = 10
for i in range(top_k):
    idx = ranked_feature_indices[i]
    reasoning_feats.append(idx)
    print(f"Feature {idx}: Question mean = {mean_q[idx]:.4f}, "
          f"CoT mean = {means_cot[idx]:.4f}, "
          f"% increase = {percentage_increase[idx]:.2f}%")

Feature 10527: Question mean = 0.0141, CoT mean = 0.2704, % increase = 1812.44%
Feature 26077: Question mean = 0.0254, CoT mean = 0.4213, % increase = 1561.91%
Feature 11762: Question mean = 0.0420, CoT mean = 0.5053, % increase = 1101.90%
Feature 28534: Question mean = 0.0165, CoT mean = 0.1310, % increase = 694.80%
Feature 29395: Question mean = 0.0156, CoT mean = 0.1182, % increase = 657.48%
Feature 31564: Question mean = 0.0375, CoT mean = 0.2510, % increase = 568.85%
Feature 18463: Question mean = 0.0330, CoT mean = 0.2022, % increase = 512.06%
Feature 22078: Question mean = 0.0303, CoT mean = 0.1455, % increase = 380.09%
Feature 29825: Question mean = 0.0314, CoT mean = 0.1334, % increase = 324.75%
Feature 20549: Question mean = 0.0560, CoT mean = 0.2096, % increase = 274.50%
