In [1]:
import os

# Process-wide settings that must be in place before CUDA / ModelScope are
# initialised: enumerate GPUs in PCI bus order, expose only device 0, and
# point the ModelScope cache at the local model directory.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
    "MODELSCOPE_CACHE": "/data0/modelscope/qwen2.5",
})

In [2]:
# Notebook-only configuration; `get_ipython` is imported but not called in this cell.
from IPython import get_ipython
# autoreload mode 2: re-import edited modules before each cell runs, so changes
# to the local project files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

from IPython.core.interactiveshell import InteractiveShell
# Echo the value of every top-level expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = 'all'  # alternatives: 'last', 'last_expr'

In [3]:
import sys
# Put a local TransformerLens checkout at the front of the module search path so
# that the `transformer_lens` import that follows resolves to it first.
# NOTE(review): absolute, user-specific path — will not exist on other machines.
sys.path.insert(0, '/home/hushengchun/python_library/TransformerLens/')

import transformer_lens
from argparse import ArgumentParser
from functools import partial 
from pathlib import Path

import pandas as pd
import numpy as np
import torch
from transformer_lens import HookedTransformer
import matplotlib.pyplot as plt
from tqdm import tqdm

from eap.graph import Graph
from eap.attribute import attribute, _plot_attn, tokenize_plus
from eap.evaluate import evaluate_graph, evaluate_baseline

from dataset import EAPDataset
from metrics import get_metric
from transformers import BitsAndBytesConfig
from modelscope import AutoModelForCausalLM, AutoTokenizer

In [4]:
# Checkpoint identifier; also passed to the tokenizer and dataset loaders.
model_name = "Qwen/Qwen2.5-7B-Instruct"

# Loader options: half-precision weights, automatic device placement, and
# resolution from the local cache directory only (local_files_only=True).
hf_load_options = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir="/data0/modelscope/qwen2.5",
    local_files_only=True,
    low_cpu_mem_usage=True,
)
model_base = AutoModelForCausalLM.from_pretrained(model_name, **hf_load_options)

Downloading Model to directory: /data0/modelscope/qwen2.5/models/Qwen/Qwen2.5-7B-Instruct


2025-04-03 04:01:56,913 - modelscope - INFO - Target directory already exists, skipping creation.


In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/config.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/config.json


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
# Load the tokenizer for the same checkpoint identifier as the base model.
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading Model to directory: /data0/modelscope/qwen2.5/models/Qwen/Qwen2.5-7B-Instruct


2025-04-03 04:02:02,551 - modelscope - INFO - Target directory already exists, skipping creation.


In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/tokenizer_config.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/tokenizer_config.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/vocab.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/merges.txt
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/tokenizer.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Wrap the already-loaded weights and tokenizer in a HookedTransformer.
# Every weight-processing option is switched off explicitly.
weight_processing_off = {
    "center_writing_weights": False,
    "center_unembed": False,
    "fold_ln": False,
    "fold_value_biases": False,
}
model = HookedTransformer.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct",
    device='cuda',
    hf_model=model_base,
    tokenizer=tokenizer,
    dtype="float16",
    **weight_processing_off,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/tokenizer_config.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/tokenizer_config.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/vocab.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/merges.txt
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct /data0/modelscope/qwen2.5/models/Qwen/Qwen2___5-7B-Instruct/tokenizer.json
In cached_file: path_or_repo_id, resolved_file = /data0/modelscope/qwen2.5

In [7]:
# Switch on the additional hook points (separate q/k/v inputs, per-head
# attention result, MLP input) — presumably required by the EAP attribution
# and evaluation code; TODO confirm.
for cfg_flag in ("use_split_qkv_input", "use_attn_result", "use_hook_mlp_in"):
    setattr(model.cfg, cfg_flag, True)

In [8]:
# Task identifier handed to EAPDataset and get_metric.
task = 'ioi'

In [9]:
# Build the dataset, its dataloader and the evaluation metric for `task`.
task_metric_name = 'prob_diff'
ds = EAPDataset(task, model_name)
# ds.df['clean'] = ds.df['clean'].str.replace('\\\\n', '\n', regex=True) # hsc: replace the literal "\n" string with a newline
# Number of prompts per batch; also reused by the attention-accumulation cell.
batch_size = 20
dataloader = ds.to_dataloader(batch_size)
# The metric is bound to `model`; callers later specialise it via functools.partial.
task_metric = get_metric(task_metric_name, task, model=model)

In [10]:
# Load a graph that was serialised to JSON (presumably the scored graph from an
# earlier attribution run — TODO confirm).
# NOTE(review): absolute, user-specific path.
file_name = "/home/hushengchun/project/eap-ig-faithfulness/qwen_7b-ioi.json"
g = Graph.from_json(file_name)

In [11]:
import re

# Point every attention-head node (name contains '.h') at the split
# q/k/v input hooks of its layer, i.e. 'blocks.<layer>.hook_{q,k,v}_input'.
LAYER_RE = re.compile(r'a(\d+)')

for node_name, node in g.nodes.items():
    if '.h' not in node_name:
        continue
    # The layer index does not depend on the q/k/v letter, so parse it once.
    layer_match = LAYER_RE.search(node_name)
    if layer_match is None:
        # No layer index in the name: leave the node untouched instead of
        # clearing its inputs and refilling them with a stale `layer` left
        # over from a previous node (or hitting a NameError on the first one).
        continue
    layer = layer_match.group(1)
    node.qkv_inputs.clear()
    node.qkv_inputs.extend(f'blocks.{layer}.hook_{letter}_input' for letter in 'qkv')

In [12]:
# Sweep over several values passed to `g.apply_greedy` (5100, 5600, ..., 7100):
# for each one, re-select the graph, prune dead nodes, count the edges that
# remain and evaluate the resulting graph on the dataloader.
n_edges = []
results = []
steps = list(range(5100, 7101, 500))
# Per-example metric values (mean=False), not a loss (loss=False).
per_example_metric = partial(task_metric, mean=False, loss=False)

with tqdm(total=len(steps)) as pbar:
    for edge_budget in steps:
        g.apply_greedy(edge_budget, absolute=False)
        g.prune_dead_nodes(prune_childless=True, prune_parentless=True)
        edge_count = g.count_included_edges()
        scores = evaluate_graph(model, g, dataloader, per_example_metric, quiet=True)
        print(scores)
        mean_score = scores.mean().item()
        pbar.update(1)
        # One-element rows keep the final arrays at shape (len(steps), 1).
        n_edges.append([edge_count])
        results.append([mean_score])

n_edges = np.array(n_edges)
results = np.array(results)

  0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([813, 2381])


  0%|          | 0/5 [00:03<?, ?it/s]

torch.Size([30, 28])
HookPoint() torch.Size([20, 20, 28, 3584])





RuntimeError: einsum(): the number of subscripts in the equation (2) does not match the number of dimensions (1) for operand 1 and no ellipsis was given

In [13]:
# Display the module tree of the hooked model; useful for looking up hook point names.
model

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (blocks): ModuleList(
    (0-27): 28 x TransformerBlock(
      (ln1): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): RMSNorm(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): GroupedQueryAttention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
        (hook_rot_k): HookPoint()
        (hook_rot_q): HookPoint()
      )
      (mlp): GatedMLP(
        (hook_pre): HookPoint()
        (hook_pre_linear): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out)

In [None]:
# Scratch inspection: shows the bound method's repr only, nothing is executed.
model.run_with_cache

In [None]:
from eap.attribute import tokenize_plus
import re

# Node names of attention heads look like 'a<layer>.h<head>'.
HEAD_NODE_RE = re.compile(r'a(\d+)\.h(\d+)')


def _ioi_positions(tokens):
    """Locate the token positions probed in one tokenised prompt.

    Args:
        tokens: list of token strings for a single prompt.

    Returns:
        ``(io, s1, s1_1, s2, end)`` or ``None`` when a position cannot be
        determined. ``io``, ``s1`` and ``s1_1`` are indices from the start of
        the prompt; ``s2`` and ``end`` are negative indices counted from the
        end (last occurrence of the io/s1 token and of 'Ġto' respectively).
    """
    if len(tokens) < 2:
        return None
    # The first name sits at index 1 unless that slot holds a comma.
    io = 1 if tokens[1] != ',' else 2
    s1 = io + 2
    s1_1 = s1 + 1
    if s1_1 >= len(tokens):
        return None

    end = s2 = None
    # Single backwards scan; each target keeps the first hit from the end.
    for offset, word in enumerate(reversed(tokens)):
        if end is None and word == 'Ġto':
            end = -1 - offset
        if s2 is None and (word == tokens[io] or word == tokens[s1]):
            s2 = -1 - offset
        if end is not None and s2 is not None:
            break
    if end is None or s2 is None:
        return None
    return io, s1, s1_1, s2, end


# Resolve (node, layer, head) once: graph membership and names do not change
# while the dataloader is being consumed.
head_nodes = []
for node_name, node in g.nodes.items():
    head_match = HEAD_NODE_RE.search(node_name)
    if node.in_graph and head_match:
        head_nodes.append((node, int(head_match.group(1)), int(head_match.group(2))))

n_skipped = 0
for clean, corrupt, label in dataloader:
    token_ids, attention_mask, input_lengths, n_pos = tokenize_plus(model, clean)
    # Inference only: no autograd graph, and cache just the attention patterns
    # instead of every activation in the model.
    with torch.no_grad():
        clean_logits, cache = model.run_with_cache(
            token_ids,
            attention_mask=attention_mask,
            names_filter=lambda hook_name: hook_name.endswith('attn.hook_pattern'),
        )

    # Positions depend only on the prompt, so compute them once per sample
    # rather than once per (head, sample). Iterating over the rows actually
    # present also handles a final batch shorter than `batch_size`.
    sample_positions = [
        _ioi_positions(tokenizer.convert_ids_to_tokens(sample_ids))
        for sample_ids in token_ids
    ]
    n_skipped += sum(positions is None for positions in sample_positions)

    for node, layer, head in head_nodes:
        layer_pattern = cache[f'blocks.{layer}.attn.hook_pattern']
        for i, positions in enumerate(sample_positions):
            if positions is None:
                # Skip rather than reuse positions left over from another prompt.
                continue
            io, s1, s1_1, s2, end = positions
            # Indexed as [query position][key position] for this head.
            attn = layer_pattern[i][head]
            node.attn[0]['value'] += attn[end][io]
            node.attn[1]['value'] += attn[end][s2]
            node.attn[2]['value'] += attn[s2][s1]
            node.attn[3]['value'] += attn[s2][s1_1]
            node.attn[4]['value'] += attn[s1_1][s1]

if n_skipped:
    print(f'skipped {n_skipped} prompts whose token positions could not be located')

In [None]:
# Scale the accumulated attention values of every head node down by 100.
# NOTE(review): the divisor is hard-coded — presumably the number of prompts
# that were accumulated; TODO confirm. Running this cell twice divides twice.
for node_name, node in g.nodes.items():
    if '.h' not in node_name:
        continue
    for entry in node.attn:
        entry['value'] /= 100

In [None]:
import math

# Label each attention head by its dominant attention pattern and render the graph.
# A head is labelled only when its strongest pattern is at least MIN_TOP_VALUE
# and exceeds the runner-up by MIN_LOG10_GAP orders of magnitude.
head_info = ['NMH', 'SIH', 'DTH', 'IH', 'PTH', 'NNMH']

MIN_TOP_VALUE = 0.05
MIN_LOG10_GAP = 2

for node_name, node in g.nodes.items():
    if '.h' not in node_name:
        continue
    sorted_attn = sorted(node.attn, key=lambda x: abs(x['value']), reverse=True)
    print(sorted_attn)

    top_value = float(sorted_attn[0]['value'])
    if not top_value >= MIN_TOP_VALUE:
        continue
    # Compare the strongest pattern against the runner-up. The previous code
    # subtracted log10(top) from itself, so the gap was always 0 and no head
    # was ever labelled.
    runner_up = abs(float(sorted_attn[1]['value'])) if len(sorted_attn) > 1 else 0.0
    # A zero runner-up is infinitely dominated; skip log10 to avoid a domain error.
    if runner_up > 0 and math.log10(top_value) - math.log10(runner_up) < MIN_LOG10_GAP:
        continue

    top_name = sorted_attn[0]['name']
    if top_name == 'end_io':
        # The negative branch cannot fire here because top_value >= MIN_TOP_VALUE;
        # it is kept so the mapping mirrors the unthresholded labelling.
        if top_value < 0:
            node.head_info = head_info[-1]
        else:
            node.head_info = head_info[0]
    elif top_name == 'end_s2':
        node.head_info = head_info[1]
    elif top_name == 's2_s1':
        node.head_info = head_info[2]
    elif top_name == 's2_s1_1':
        node.head_info = head_info[3]
    else:
        node.head_info = head_info[4]

gs = g.to_graphviz()
gs.draw('qwen2_5-7b_4bit-ioi-split_qkv_false.png', prog='dot')

In [None]:
# Label every head node by whichever attention pattern has the largest
# absolute value, with no threshold applied.
head_info = ['NMH', 'SIH', 'DTH', 'IH', 'PTH', 'NNMH']

# Patterns that map straight to a label; 'end_io' needs a sign check and any
# other name falls back to head_info[4].
label_by_pattern = {
    'end_s2': head_info[1],
    's2_s1': head_info[2],
    's2_s1_1': head_info[3],
}

for node_name, node in g.nodes.items():
    if '.h' not in node_name:
        continue
    strongest = max(node.attn, key=lambda entry: abs(entry['value']))
    if strongest['name'] == 'end_io':
        node.head_info = head_info[-1] if strongest['value'] < 0 else head_info[0]
    else:
        node.head_info = label_by_pattern.get(strongest['name'], head_info[4])