In [1]:
import torch
from transformer_lens import HookedTransformer
from functools import partial
import torch.nn.functional as F
from eap.metrics import logit_diff, direct_logit
import transformer_lens.utils as utils
from eap.graph import Graph
from eap.dataset import EAPDataset
from eap.attribute import attribute
import time
from rich import print as rprint
import pandas as pd
from eap.evaluate import evaluate_graph, evaluate_baseline,get_circuit_logits

  from .autonotebook import tqdm as notebook_tqdm
  warn(


In [2]:
LLAMA_2_7B_CHAT_PATH = "meta-llama/Llama-2-7b-chat-hf"
from transformers import LlamaForCausalLM
model = HookedTransformer.from_pretrained(LLAMA_2_7B_CHAT_PATH, device="cuda", fold_ln=False, center_writing_weights=False, center_unembed=False)
model.cfg.use_split_qkv_input = True
model.cfg.use_attn_result = True
model.cfg.use_hook_mlp_in = True

Loading checkpoint shards: 100%|██████████| 2/2 [00:10<00:00,  5.06s/it]


Loaded pretrained model meta-llama/Llama-2-7b-chat-hf into HookedTransformer


In [3]:
clean_subject = 'Eiffel Tower'
corrupted_subject = 'the Great Walls'
clean = f'The official currency of the country where {clean_subject} is loacted in is the'
corrupted = f'The official currency of the country where {corrupted_subject} is loacted in is the'
assert len(model.to_str_tokens(clean.format(clean_subject))) == len(model.to_str_tokens(corrupted.format(corrupted_subject)))
labels = ['Euro','Chinese']
country_idx = model.tokenizer(labels[0],add_special_tokens=False).input_ids[0]
corrupted_country_idx = model.tokenizer(labels[1],add_special_tokens=False).input_ids[0]
# dataset = {k:[] for k in ['clean','country_idx', 'corrupted',  'corrupted_country_idx']}
# for k, v in zip(['clean', 'country_idx', 'corrupted', 'corrupted_country_idx'], [clean, country_idx, corrupted, corrupted_country_idx]):
#     dataset[k].append(v)
# df2 = pd.DataFrame.from_dict(dataset)
# df2.to_csv(f'capital_city.csv', index=False)

In [4]:
label = [[country_idx, corrupted_country_idx]]
label = torch.tensor(label)
data = ([clean],[corrupted],label)

In [None]:
# ds = EAPDataset(filename='capital_city.csv',task='fact-retrieval')
# dataloader = ds.to_dataloader(1)

In [6]:
g = Graph.from_model(model)
start_time = time.time()
# Attribute using the model, graph, clean / corrupted data and labels, as well as a metric
attribute(model, g, data, partial(logit_diff, loss=True, mean=True), method='EAP-IG-case', ig_steps=100)
# attribute(model, g, data, partial(direct_logit, loss=True, mean=True), method='EAP-IG-case', ig_steps=30)
# attribute(model, g, dataloader, partial(logit_diff, loss=True, mean=True), method='EAP-IG', ig_steps=30)
g.apply_topn(5000, absolute=True)
g.prune_dead_nodes()

g.to_json('graph.json')

# gz = g.to_graphviz()
# gz.draw(f'graph.png', prog='dot')

end_time = time.time()
execution_time = end_time - start_time
print(f"程序执行时间：{execution_time}秒")

程序执行时间：65.61587119102478秒


In [8]:
def get_component_logits(logits, model, answer_token, top_k=10):
    logits = utils.remove_batch_dim(logits)
    # print(heads_out[head_name].shape)
    probs = logits.softmax(dim=-1)
    token_probs = probs[-1]
    answer_str_token = model.to_string(answer_token)
    sorted_token_probs, sorted_token_values = token_probs.sort(descending=True)
    # Janky way to get the index of the token in the sorted list - I couldn't find a better way?
    correct_rank = torch.arange(len(sorted_token_values))[
        (sorted_token_values == answer_token).cpu()
    ].item()
    # answer_ranks = []
    # answer_ranks.append((answer_str_token, correct_rank))
    # String formatting syntax - the first number gives the number of characters to pad to, the second number gives the number of decimal places.
    # rprint gives rich text printing
    print(
        f"Performance on answer token: Rank: {correct_rank: <8} Logit: {logits[-1, answer_token].item():5.2f} Prob: {token_probs[answer_token].item():6.2%} Token: |{answer_str_token}|"
    )
    for i in range(top_k):
        print(
            f"Top {i}th token. Logit: {logits[-1, sorted_token_values[i]].item():5.2f} Prob: {sorted_token_probs[i].item():6.2%} Token: |{model.to_string(sorted_token_values[i])}|"
        )
    # rprint(f"[b]Ranks of the answer tokens:[/b] {answer_ranks}")

In [9]:
logits = get_circuit_logits(model, g, data)
get_component_logits(logits, model, answer_token=model.to_tokens('Euro',prepend_bos=False)[0], top_k=5)

Performance on answer token: Rank: 0        Logit: 16.94 Prob: 56.56% Token: |Euro|
Top 0th token. Logit: 16.94 Prob: 56.56% Token: |Euro|
Top 1th token. Logit: 15.96 Prob: 21.39% Token: |French|
Top 2th token. Logit: 14.06 Prob:  3.18% Token: |_|
Top 3th token. Logit: 13.95 Prob:  2.85% Token: |euro|
Top 4th token. Logit: 13.91 Prob:  2.74% Token: |Eu|


In [None]:
baseline = evaluate_baseline(model, dataloader, partial(logit_diff, loss=False, mean=False)).mean().item()
results = evaluate_graph(model, g, dataloader, partial(logit_diff, loss=False, mean=False)).mean().item()
print(f"Original performance was {baseline}; the circuit's performance is {results}")

In [19]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

In [20]:
model_path = "/hpc2hdd/home/hchen763/jhaidata/local_model/DeepSeek-R1-Distill-Qwen-1.5B"

# 1. 加载分词器和模型
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to("cuda")

In [52]:
messages = [
    {"role": "system", "content": "You are a helpful assistant. You should generate answer as short as possible."},
    {"role": "user", "content": "Natalia sold clips to 50 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"}
]
inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", padding=True).to("cuda")
attention_mask = (inputs != tokenizer.pad_token_id).long()   # 通常如此，如果有 pad_token
outputs = model.generate(
    inputs,
    attention_mask=attention_mask,
    max_new_tokens=100,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

response_ids = outputs[0]
response = tokenizer.decode(response_ids)
print(response)
# print(response_ids)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


<｜begin▁of▁sentence｜>You are a helpful assistant. You should generate answer as short as possible.<｜User｜>Natalia sold clips to 50 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? Please reason step by step, and put the answer as short as possible.
</think>

Natalia sold 50 clips in April and half as many in May, which is 25 clips. In total, she sold 50 + 25 = 75 clips.  
**Answer:** 75<｜end▁of▁sentence｜>


In [53]:
prompt = """
    <｜begin▁of▁sentence｜>You are a helpful assistant. You should generate answer as short as possible.
    <｜User｜>Natalia sold clips to 50 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?
"""

# 编码，推理
inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
outputs = model.generate(
    **inputs,
    max_new_tokens=100,
    do_sample=False,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)
response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
print(response)


The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


</think>

Natalia sold clips in April to 50 friends and half as many in May. Therefore, she sold 25 clips in May. In total, she sold 50 + 25 = 75 clips in April and May.

Answer: \boxed{75}


In [18]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from tqdm import tqdm

# 显卡1
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = "cuda" if torch.cuda.is_available() else "cpu"

# 本地模型路径
model_path = "/hpc2hdd/home/hchen763/jhaidata/local_model/DeepSeek-R1-Distill-Qwen-1.5B"

# 载入分词器和模型
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, trust_remote_code=True).to(device)
model.eval()

# 载入GSM8K dev集
dataset = load_dataset("gsm8k", "main", split="test")

def make_chat(question):
    messages = [
        {"role": "system", "content": "You are a helpful assistant. You should generate answer as short as possible."},
        {"role": "user", "content": question.strip()}
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

def extract_number(ans):
    import re
    numbers = re.findall(r"[-+]?\d*\.?\d+", ans)
    return numbers[-1] if numbers else ans.strip()

correct = 0
total = 0

for sample in tqdm(dataset.select(range(100)), desc="Evaluating"):
    q = sample["question"]
    gt = sample["answer"]
    prompt = make_chat(q)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
            temperature=0.0,
            pad_token_id=tokenizer.eos_token_id
        )
    gen = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    pred_ans = extract_number(gen)
    gt_ans = extract_number(gt)
    if pred_ans == gt_ans:
        correct += 1
    total += 1

acc = correct / total
print(f"Accuracy on first 100 GSM8K dev samples: {acc:.2%}")


Evaluating:   0%|          | 0/100 [00:00<?, ?it/s]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating:   1%|          | 1/100 [00:05<08:47,  5.33s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating:   2%|▏         | 2/100 [00:10<08:40,  5.31s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating:   3%|▎         | 3/100 [00:16<08:39,  5.35s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
Evaluating:   4%|▍         | 4/100 [00:19<07:26,  4.65s/it]The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details

Accuracy on first 100 GSM8K dev samples: 25.00%



