# Step 1. Generate examples from the LLM

The interpretation model aims to explain how LLMs generate tokens. To understand the behavior of each LLM, we first collect a corpus of the model's own answers to each question, and then train transcoders on that corpus to extract the corresponding interpretable structure.

In [1]:
import config_env
import os

# Select GPU 1. This has to happen before CUDA is initialised to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import json
from datasets import load_dataset
import re
from model_load import load_model, model_name_func

# --- Model / dataset selection ------------------------------------------------
name = 'Qwen3-0.6B'
dataset_name = 'maw'
# Full HF model id; alternatives tried: 'meta-llama/Llama-3.2-1B', "Qwen/Qwen2.5-1.5B"
model_name = model_name_func(name)

# Only the "train" split of MAWPS is used; it serves as the question pool here.
dataset = load_dataset("mwpt5/MAWPS")
test = dataset["train"]
print(test[0])
print(len(test))

# --- Backbone model -------------------------------------------------------------
model = load_model(name)
print(model.cfg.n_layers)

`torch_dtype` is deprecated! Use `dtype` instead!


{'Question': 'Mary is baking a cake . The recipe wants N_00 cups of flour . She already put in N_01 cups . How many cups does she need to add ?', 'Equation': 'N_00 - N_01', 'Answer': 6.0, 'Numbers': '8.0 2.0'}
1772
Loaded pretrained model Qwen/Qwen3-0.6B into HookedTransformer
28


In [2]:
from generate_supposed_answers import generate_MAWPS

# Decode each question with temperature=0 and store the generation next to the
# gold answer. The notebook stops after the first example; to process the whole
# dataset, use generate_supposed_answer.py.
text = []
for idx, data in enumerate(test):
    question, prompt, ans, max_token = generate_MAWPS(data)
    generated_text = model.generate(prompt, max_new_tokens=max_token, temperature=0)
    print(generated_text, ans)
    text.append(dict(question=question, gold_ans=ans, ans=generated_text))
    print(idx)
    break  # demo: first example only

  0%|          | 0/200 [00:00<?, ?it/s]

Question: Mary is baking a cake . The recipe wants 8.0 cups of flour . She already put in 2.0 cups . How many cups does she need to add ?
Let's think step by step
Answer:
To find out how many cups Mary needs to add, we start with the total amount of flour required by the recipe, which is 8.0 cups. She has already put in 2.0 cups, so we subtract that from the total. 

8.0 cups (recipe) - 2.0 cups (already added) = 6.0 cups. 

Therefore, Mary needs to add 6.0 cups of flour.
Answer: 6.0

Step-by-Step Explanation:
1. The recipe requires 8.0 cups of flour.
2. Mary has already added 2.0 cups.
3. To find the remaining amount needed, subtract the already added from the required total: 8.0 - 2.0 = 6.0.
4. The answer is 6.0 cups.
Answer: 6.0
The answer is 6.0 cups.
Answer: 6.0
The answer is  6.0
0


# Step 2. Train transcoders

### Step 2.1 Set up the backbone model and datasets

In [1]:
import sys
from tqdm import tqdm
import torch
from torch.utils.data import IterableDataset
import json
from model_load import load_model, model_name_func

class MyIterableDataset(IterableDataset):
    """In-memory, pre-tokenized corpus exposed as a torch ``IterableDataset``.

    Each text is encoded once at construction time (``add_special_tokens=False``)
    and stored as a LongTensor of token ids. ``token_count`` holds the total
    number of tokens over all texts.
    """

    def __init__(self, texts, tokenizer):
        self.examples = []
        self.token_count = 0
        for sample in tqdm(texts, total=len(texts)):
            token_ids = tokenizer.encode(sample, add_special_tokens=False)
            self.examples.append(torch.tensor(token_ids).long())
            self.token_count += len(token_ids)

    def __iter__(self):
        # Every item is a dict with a single "tokens" entry.
        yield from ({"tokens": example} for example in self.examples)

    def __len__(self):
        return len(self.examples)
    
# --- Backbone model -------------------------------------------------------------
name = 'Qwen3-0.6B'
device = 'cuda:0'
model_name = model_name_func(name)
model = load_model(name, device=device)
tokenizer = model.tokenizer

# --- Transcoder run settings (also used in the output directory name) ----------
acts_func = 'relu'
train_dataset_name = 'maw'
from_pretrained = False  # not referenced in the visible notebook

# --- Training corpus: the model's own Step-1 generations ('ans' field) ---------
save_json = f'./data_{train_dataset_name}/{name}_answer.json'
with open(save_json, 'r') as f:
    test_data = json.load(f)

training_corpus = []
for data in test_data:
    training_corpus.append(data['ans'])
# NOTE(review): the validation texts are the first 10 training texts, so they overlap.
valid_corpus = training_corpus[:10]

train_dataset = MyIterableDataset(training_corpus, tokenizer)
valid_dataset = MyIterableDataset(valid_corpus, tokenizer)

# Where trained transcoders are written, and where activations are cached.
transcoder_save_path = f'/egr/research-dselab/shared/transcoder_model/{acts_func}_{name}_{train_dataset_name}'
activate_cache_dir = '/egr/research-dselab/shared/daixinna/huggingface_path/activations_cache'


`torch_dtype` is deprecated! Use `dtype` instead!


Loaded pretrained model Qwen/Qwen3-0.6B into HookedTransformer


100%|██████████| 1772/1772 [00:00<00:00, 3459.25it/s]
100%|██████████| 10/10 [00:00<00:00, 3239.09it/s]


### Step 2.2 Set the transcoder training parameters

In [2]:

sys.path.append('../')
from circuit_tracer.configs import Configs
from circuit_tracer.transcoder.activation_functions import JumpReLU, Relu, TopK
from circuit_tracer.transcoder.single_layer_transcoder import SingleLayerTranscoder
from circuit_tracer.transcoder.activations_store import ActivationsStore
from circuit_tracer.transcoder_training import train_sae_on_language_model


train_dataset_name = 'maw'
l1_co = 0.00005
max_length = 200
total_training_tokens = 5_000_000

n_layer = model.cfg.n_layers
d_model = model.cfg.d_model

acts_func = 'relu'
transcoder_config = {
    'model_name': model_name,
    'train_dataset_name': train_dataset_name,
    'dataset_path': None,
    # Step 3 builds its data path as ./data_{configs.pattern_name}/..., while this
    # notebook read the corpus from ./data_{train_dataset_name}/...; derive one from
    # the other so switching datasets cannot leave them out of sync.
    'pattern_name': train_dataset_name,
    'method': 'real',
    'is_tokenized': True,
    'max_length': max_length,
    'total_train_num': total_training_tokens,
    'acts_func': acts_func,

    # Tune the number of epochs if the training set is small while MSE is still high.
    'epoch': 1,
    'batch_size': 1024,
    'train_batch_size': 1024,
    'resample_batches': 1024,
    # Tune these when running out of memory (very effective).
    'n_batches_in_buffer': 32,
    'store_batch_size': 16,
    # Tune this to balance the MSE and L1 losses.
    'l1_coefficient': l1_co,

    # Width of the sparse feature space (2x d_model here; tunable).
    'd_transcoder': d_model * 2,

    # A large dead-feature window is effective when the vocabulary is large.
    'dead_feature_window': d_model,
    'dead_feature_threshold': 1e-8,
    # Do not make this too small; prefer more training tokens over a larger lr.
    'lr': 1e-3,
}

base_model_configs = {
    'd_model': d_model,
    'max_length': max_length,
}

### Step 2.3 Train transcoders

In [3]:
# Only the ReLU transcoder is wired up here. Fail fast instead of reaching the
# training call with `transcoder` undefined (or stale from a previous layer).
if acts_func != 'relu':
    raise ValueError(f"Unsupported acts_func {acts_func!r}; only 'relu' is supported in this notebook")

for target_layer in range(0, n_layer):
    configs = Configs.init_setup(target_layer, transcoder_config, base_model_configs, transcoder_save_path, device, activate_cache_dir)
    model = model.to(configs.act_store_device)
    configs.tokenizer_name = None

    activation_stores = ActivationsStore(configs, model, dataset=train_dataset, tokenizer=tokenizer)

    transcoder = SingleLayerTranscoder(configs, Relu())  # alternative: JumpReLU(0.0, 0.1)

    record_scores = train_sae_on_language_model(configs, model, transcoder, activation_stores, use_eval=False)
    print(f"Target Layer: {target_layer}")
    break
# Note: the notebook trains only one layer. For all layers, use train_transcoder.py.



Moving model to device:  cuda:0


  tokens = torch.tensor(


5000000
Epoch 1/1


4882| MSE Loss 0.001 | L1 0.000: : 5000192it [01:24, 58987.09it/s]                           

Saved model to /egr/research-dselab/shared/transcoder_model/relu_Qwen3-0.6B_maw/0.pt
Model saved to /egr/research-dselab/shared/transcoder_model/relu_Qwen3-0.6B_maw/0.pt
Target Layer: 0





# Step 3. Collect the attribution graph for each sample

(This step is computationally expensive because the reasoning chains can be very long.)

Basic idea:

We know that an LLM ultimately produces a final answer, while the intermediate tokens form the thinking chain. Therefore, we trace backward from the final token to examine the thinking chain and identify which tokens meaningfully contributed to the final prediction.

The generated tokens always depend on the given question. Thus, the backward tracing process stops once the contributing tokens originate from the input question.

### Step 3.1 Set up the model

In [6]:
import config_env
import os
import gc

os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch
import sys
import json

import random
import numpy as np
import torch
import re
from model_load import load_model, model_name_func
import networkx as nx
sys.path.append('../')
from circuit_tracer.replacement_model import ReplacementModel
from circuit_tracer.configs import Configs
from circuit_tracer.graph import Graph, prune_graph
from circuit_tracer import attribute
from circuit_tracer.utils.create_graph_files import create_nodes, create_used_nodes_and_edges

# Density of the pruned attribution graph (thresholds passed to prune_graph).
edge_ratio = 0.5
node_ratio = 0.4

name = 'Qwen3-0.6B'
cache_root = "/egr/research-dselab/daixinna/shared/huggingface_path"
# GPU ids to shard the model's layers across (multi-GPU usage).
cuda_range = [1]
cuda_list = [f'cuda:{i}' for i in cuda_range]

model_name = model_name_func(name)
model = load_model(name, device=f'cuda:{cuda_range[0]}')

tokenizer = model.tokenizer
n_layer = model.cfg.n_layers
d_model = model.cfg.d_model


def build_device_map(num_layers, cuda_ids):
    """Map every layer index in ``range(num_layers)`` to a CUDA device.

    Layers are split into contiguous chunks of ``num_layers // len(cuda_ids)``.
    The last GPU absorbs any remainder, so no layer is left without a device
    when ``num_layers`` is not divisible by the number of GPUs.
    """
    per_device = num_layers // len(cuda_ids)
    mapping = {}
    for s, cuda_id in enumerate(cuda_ids):
        dev = torch.device(f'cuda:{cuda_id}')
        begin = s * per_device
        stop = num_layers if s == len(cuda_ids) - 1 else (s + 1) * per_device
        for layer in range(begin, stop):
            mapping[layer] = dev
    return mapping


device_map = build_device_map(n_layer, cuda_range)

acts_func = 'relu'
train_dataset_name = 'maw'

transcoder_save_path = f'/egr/research-dselab/shared/transcoder_model/{acts_func}_{name}_{train_dataset_name}'
# Transcoder weights are loaded from the same directory as their config.json.
transcoder_model_save_path = transcoder_save_path

transcoder_config_path = os.path.join(transcoder_save_path, 'config.json')
configs = Configs.init_load(transcoder_config_path)
configs.set_device_map(device_map)

# Drop the plain HookedTransformer (only needed for tokenizer/config) before loading the replacement model.
del model
model = ReplacementModel.from_self_pretrained_and_transcoders(cfg=configs, model_name=name, model_path=cache_root, transcoders_path=transcoder_model_save_path)
model.forward = model.forward_sharded
model.set_device_map(device_map)
model.eval();  # trailing ';' suppresses the very long module repr

Loaded pretrained model Qwen/Qwen3-0.6B into HookedTransformer
Loaded pretrained model Qwen/Qwen3-0.6B into HookedTransformer


ReplacementModel(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (ln1): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (q_norm): RMSNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (k_norm): RMSNorm(
          (hook_scale): HookPoint()
          (hook_normalized): HookPoint()
        )
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): ReplacementMLP(
        (old_mlp): GatedMLP(
          (hook_pre): HookPoint()


### Step 3.2 Set up the dataset

In [7]:
save_json = f'./data_{configs.pattern_name}/{name}_answer.json'
with open(save_json, 'r') as f:
    test_data = json.load(f)

# Gold answer (as text) per example index. The tracing loop looks for its last
# occurrence in the generated text. A comprehension keeps the loop variables out
# of the notebook namespace; the original loop overwrote the global `text`
# (Step 1's list of generations) with a string.
match_collect = {idx: str(data['gold_ans']) for idx, data in enumerate(test_data)}

### Step 3.3 Define helper functions

In [9]:
def translate_node_ids(graph, node_id, given_type):
    """Decode an attribution-graph node id into ``(token_text, position, layer)``.

    ``node_id`` is an underscore-separated string whose fields depend on
    ``given_type``:

    * ``"embedding"``: ``<prefix>_<vocab_id>_<pos>``. The layer is reported as ``'Emb'``.
    * ``"mlp reconstruction error"``: ``<int>_<layer>_<pos>``. The text is ``'block_inner'``.
    * ``"logit"``: ``<layer>_<vocab_id>_<pos>``.
    * ``"cross layer transcoder"``: ``<layer>_<feature_idx>_<pos>``. The text is
      the input token at ``pos`` in ``graph.input_tokens``.

    Vocabulary ids are decoded with the notebook-global ``tokenizer``.

    Raises:
        ValueError: if ``given_type`` is none of the four types above. Previously
            this case surfaced as a confusing UnboundLocalError.
    """
    fields = node_id.split('_')
    if given_type == 'embedding':
        vocab_id, pos = int(fields[1]), int(fields[2])
        return tokenizer.decode(vocab_id), pos, 'Emb'
    if given_type == "mlp reconstruction error":
        _, layer, pos = int(fields[0]), int(fields[1]), int(fields[2])
        return 'block_inner', pos, layer
    if given_type == "logit":
        layer, vocab_id, pos = int(fields[0]), int(fields[1]), int(fields[2])
        return tokenizer.decode(vocab_id), pos, layer
    if given_type == "cross layer transcoder":
        layer, _feat_idx, pos = int(fields[0]), int(fields[1]), int(fields[2])
        return tokenizer.decode(graph.input_tokens[pos]), pos, layer
    raise ValueError(f"unsupported node type {given_type!r} for node {node_id!r}")

def nodes_at_n_hop(G, source, n, filter=''):
    """Return every node reachable from ``source`` in at most ``n`` successor hops.

    ``source`` itself is always part of the result (before filtering). Nodes that
    are not in ``G`` are treated as having no successors.

    Args:
        G: directed graph supporting ``node in G`` and ``G.successors(node)``
            (e.g. a networkx DiGraph).
        source: start node.
        n: maximum number of hops to follow.
        filter: if non-empty, keep only nodes whose second '_'-separated field
            equals this string. For "<token>_<layer>_<pos>" names this is the
            layer, assuming the token text itself contains no '_'.

    Returns:
        Set of reached node names.
    """
    reached = {source}
    frontier = {source}
    for _ in range(n):
        discovered = set()
        for node in frontier:
            if node in G:
                discovered.update(G.successors(node))
        # Only first-time nodes need expanding on the next hop. Re-expanding an
        # already-reached node can only rediscover nodes that are already reached.
        frontier = discovered - reached
        reached |= discovered
        if not frontier:
            break
    if filter != '':
        reached = {node for node in reached if node.split('_')[1] == filter}
    return reached

def implicit_route(prompt_str, ans_pos):
    """Attribute the next-token prediction for ``prompt_str`` and trace it backwards.

    Steps:
    1. Run ``circuit_tracer.attribute`` on the prompt.
    2. Prune the graph with the global ``node_ratio`` / ``edge_ratio`` thresholds.
    3. Keep only positive edges, dropping MLP reconstruction-error nodes and
       direct embedding -> logit edges.
    4. Walk backwards from every logit node for up to ``n_layer + 2`` hops.

    Relies on notebook globals: ``model``, ``tokenizer``, ``n_layer``,
    ``node_ratio``, ``edge_ratio``, and the attribution settings
    ``max_n_logits``, ``desired_logit_prob``, ``batch_size``,
    ``max_feature_nodes``, ``offload``, ``verbose``.

    Args:
        prompt_str: text whose next-token prediction is explained.
        ans_pos: token position boundary. Only positions strictly greater than
            it are reported in ``marked_pos``.

    Returns:
        A tuple ``(marked_pos, repeat_check, result_graph_G)``:

        * ``marked_pos``: distinct positions ``> ans_pos`` reached by the backward
          walk. The order is deterministic (sorted node names per logit).
        * ``repeat_check``: set of ``"<word>_<layer>_<pos>_<flag>"`` strings, one
          per reached node. ``flag`` is ``1`` if ``pos > ans_pos``, else ``0``.
        * ``result_graph_G``: networkx DiGraph with edges target -> source,
          weighted by the summed positive attribution.
    """
    graph = attribute(
        prompt=prompt_str,
        model=model,
        input_ids=None,
        max_n_logits=max_n_logits,
        desired_logit_prob=desired_logit_prob,
        batch_size=batch_size,
        max_feature_nodes=max_feature_nodes,
        offload=offload,
        verbose=verbose,
        print_log=False,
    )

    node_threshold = node_ratio
    edge_threshold = edge_ratio
    node_mask, edge_mask, cumulative_scores = (
        el.cpu() for el in prune_graph(graph, node_threshold, edge_threshold)
    )
    nodes = create_nodes(graph, node_mask, tokenizer, cumulative_scores, graph.scan)
    used_nodes, used_edges = create_used_nodes_and_edges(graph, nodes, edge_mask)
    feature_type = {n.node_id: n.feature_type for n in used_nodes}

    # Aggregate edge weights keyed target -> source, on "<token>_<layer>_<pos>" names.
    logit_nodes = []
    result_graph = {}
    for e in used_edges:
        src_type = feature_type[e['source']]
        tgt_type = feature_type[e['target']]
        if src_type == "mlp reconstruction error" or tgt_type == "mlp reconstruction error":
            continue
        if tgt_type == "logit" and src_type == 'embedding':
            continue
        weight = e['weight']
        if weight < 0.:
            continue
        source_node, source_pos, source_layer = translate_node_ids(graph, e['source'], src_type)
        target_node, target_pos, target_layer = translate_node_ids(graph, e['target'], tgt_type)

        start_node = f'{source_node}_{source_layer}_{source_pos}'
        end_node = f'{target_node}_{target_layer}_{target_pos}'
        incoming = result_graph.setdefault(end_node, {})
        incoming[start_node] = incoming.get(start_node, 0) + weight

        if tgt_type == 'logit':
            logit_nodes.append(end_node)

    # Edges point target -> source, so successors() walks backwards toward the inputs.
    result_graph_G = nx.DiGraph()
    for end_node, incoming in result_graph.items():
        for start_node, w in incoming.items():
            if w > 0:
                result_graph_G.add_edge(end_node, start_node, weight=w)

    marked_pos = []
    marked_set = set()
    repeat_check = set()
    # dict.fromkeys / sorted give a run-independent order. Plain set iteration
    # over strings depends on the hash seed.
    for target_node in dict.fromkeys(logit_nodes):
        reached = nodes_at_n_hop(result_graph_G, target_node, n_layer + 2)
        for t in sorted(reached):
            # Split from the right: layer and pos never contain '_', but the
            # decoded token may (such nodes used to be skipped).
            word, layer, pos = t.rsplit('_', 2)
            pos = int(pos)
            flag = 1 if pos > ans_pos else 0
            repeat_check.add(f'{word}_{layer}_{pos}_{flag}')
            if pos > ans_pos and pos not in marked_set:
                marked_set.add(pos)
                marked_pos.append(pos)

    del graph
    gc.collect()
    torch.cuda.empty_cache()
    return marked_pos, repeat_check, result_graph_G

### Step 3.4 Collect the attribution graphs

In [None]:
import traceback

# Attribution settings, read as globals by implicit_route().
max_n_logits = 10
desired_logit_prob = 0.95  # Attribute from the minimum number of logits needed to reach this probability mass (or max_n_logits, whichever is lower)
max_feature_nodes = 2048  # Only attribute from at most this many feature nodes. Lower is faster but loses more of the graph. None means no limit.
batch_size = 128
offload = 'cpu'  # Offload parts of the model during attribution to save memory: 'disk', 'cpu', or None (keep on GPU)
verbose = True

count = 0

print(len(test_data))

# ⚠️ Warning: change the idx range to your own slice of the test set.
# Here we generate the attribution graphs for one sample only.
# Remember to save the graphs as done in collecting_attribution_graph.py.
start = 0
end = 1
# Slice into a new name. Reassigning test_data would shrink it again on every
# re-run, and real_id = start + idx must index the full list match_collect came from.
test_subset = test_data[start:end]

tokenizer_Answer = tokenizer('Answer')['input_ids'][0]


def trace_sample(idx, real_id, ans):
    """Recursively trace the reasoning tokens that lead to the final answer in ``ans``.

    First, attribute the text right before the last occurrence of the gold
    answer. Then attribute every earlier generated position the traces reach,
    until no new position turns up.

    Returns ``(dicts, graph_list, marked_pos)``, or ``None`` when the sample has
    no gold answer or the gold answer does not occur in ``ans``.

    Raises:
        ValueError: if ``ans`` contains no "Answer" token.
    """
    tokens = tokenizer(ans, return_offsets_mapping=True)
    offsets = tokens['offset_mapping']
    # Position of the first "Answer" token. Later positions count as generated reasoning.
    ans_position = tokens['input_ids'].index(tokenizer_Answer)
    if real_id not in match_collect:
        return None
    gold = match_collect[real_id]
    last_pos_ans = ans.rfind(gold)
    if last_pos_ans == -1:
        # Previously -1 silently became ans[:-1] and the wrong prefix was traced.
        print(f"    skip idx:{idx}: gold answer {gold!r} not found in generated text")
        return None

    dicts = {}
    print(f"    current_idx:{idx}")
    marked_pos, repeat_check, graph = implicit_route(ans[:last_pos_ans], ans_position)
    dicts['**ans**'] = gold
    dicts[ans[last_pos_ans:last_pos_ans + 1]] = {'last': repeat_check}
    graph_list = [graph]

    # NOTE(review): marked positions index the attribution graph's tokens, while
    # `offsets` index tokenizer(ans). These only line up if attribution adds no
    # extra tokens (e.g. BOS). Confirm.
    seen = set(marked_pos)
    i = 0
    while i < len(marked_pos):  # marked_pos grows while we iterate
        m = marked_pos[i]
        words = ans[offsets[m][0]:offsets[m][1]]
        dicts.setdefault(words, {})
        new_marked_pos, new_repeat_check, new_graph = implicit_route(ans[:offsets[m][0]], ans_position)
        dicts[words][m] = new_repeat_check
        graph_list.append(new_graph)
        for new_m in new_marked_pos:
            if new_m not in seen:
                seen.add(new_m)
                marked_pos.append(new_m)
        i += 1
        print(f"interpret token idx:{i}, marked pos lenths:{len(marked_pos)}")

    print(f"idx:{idx}, marked:{len(marked_pos)}")
    return dicts, graph_list, marked_pos


# real_id -> (dicts, graph_list, marked_pos), kept so the results can be saved.
traced = {}
for idx, data in enumerate(test_subset):
    real_id = start + idx
    ans = data['ans']
    try:
        result = trace_sample(idx, real_id, ans)
        if result is not None:
            traced[real_id] = result
    except Exception as e:
        # Best effort: log the failure and continue with the next sample.
        traceback.print_exc()
        print("Error message:", e)
    count += 1

100
    current_idx:0
interpret token idx:1, marked pos lenths:9
interpret token idx:2, marked pos lenths:12
interpret token idx:3, marked pos lenths:13
interpret token idx:4, marked pos lenths:13
interpret token idx:5, marked pos lenths:16
interpret token idx:6, marked pos lenths:17
interpret token idx:7, marked pos lenths:18
interpret token idx:8, marked pos lenths:19
interpret token idx:9, marked pos lenths:20
interpret token idx:10, marked pos lenths:22
interpret token idx:11, marked pos lenths:24
interpret token idx:12, marked pos lenths:25
interpret token idx:13, marked pos lenths:26
interpret token idx:14, marked pos lenths:27
interpret token idx:15, marked pos lenths:28
interpret token idx:16, marked pos lenths:31
interpret token idx:17, marked pos lenths:34
interpret token idx:18, marked pos lenths:35
interpret token idx:19, marked pos lenths:36
interpret token idx:20, marked pos lenths:38
interpret token idx:21, marked pos lenths:42
interpret token idx:22, marked pos lenths:4