In [None]:
import logging
from tqdm import tqdm
from utils.process_data import get_model_generate
from data_loader.base_loader import BaseLoader
# from data_processor.base_processor import BaseProcessor
from data_loader.cot_loader import CotLoader
from utils.load_config import load_config
import argparse
from utils.load_model import load_model_tokenizer
import data_loader
# import data_processor
from utils.meter import AverageMeter
from utils.process_data import *
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
# total_entropy_token = AverageMeter()
def process_data_token_level(model,model_generate,soft_max=True,avg_head=True):
    """Return the mean token-level attention entropy of one generation result.

    Args:
        model: Model handed to ``get_encoder_k`` to fetch the encoder at index -1.
        model_generate: Mapping whose ``'generate'`` entry holds the generation
            output consumed by ``get_hidden_state_k``.
        soft_max: Forwarded to ``get_attention_matrix`` and
            ``get_attention_entropy``.
        avg_head: Forwarded to ``get_attention_entropy``.

    Returns:
        float: Mean of the entropy tensor after dropping index 0 of its third axis.
    """
    generation = model_generate['generate']
    last_encoder = get_encoder_k(model, -1)
    states = get_hidden_state_k(generation, -2)

    # Recompute the attention weights by hand (original note: "last layer").
    attn = get_attention_matrix(last_encoder, states, soft_max=soft_max)
    attn = attn.to(torch.float32).cpu()  # original note: shape = (bs_size, #heads, len, len)

    entropy = get_attention_entropy(attn, avg_head=avg_head, soft_max=soft_max)
    # NOTE(review): the original comment gives the shape as (bs_size, len), yet
    # three axes are indexed below -- confirm what get_attention_entropy returns.
    trimmed = entropy[:, :, 1:]
    return trimmed.mean().item()

In [None]:
# total_entropy = AverageMeter()
def process_data_sen_level(model, tokenizer,data, model_generate,split_words=None,soft_max=True,avg_head=True):
    """Return the mean sentence-level attention entropy of one generation result.

    Args:
        model: Model handed to ``get_encoder_k`` and ``split_attn_matrix``.
        tokenizer: Tokenizer forwarded to ``split_sentence``.
        data: Prompt text, passed to ``split_sentence`` as ``question``.
        model_generate: Mapping with the entries ``'generate'`` and ``'input_ids'``.
        split_words: Optional value forwarded to ``split_sentence``; when falsy
            it is omitted so that ``split_sentence`` uses its own default.
        soft_max: Forwarded to the attention / entropy helpers.
        avg_head: Forwarded to ``get_attention_entropy``.

    Returns:
        The result of ``.mean()`` on the entropy averaged over axis 0
        (a tensor, not a Python float).
    """
    generation = model_generate['generate']
    prompt_ids = model_generate['input_ids']
    last_encoder = get_encoder_k(model, -1)

    # Split the input into sentences; split_words is only passed when truthy.
    split_kwargs = dict(tokenizer=tokenizer, question=data, input_ids=prompt_ids)
    if split_words:
        split_kwargs['split_words'] = split_words
    sentence_spans = split_sentence(**split_kwargs)

    # Slice the attention matrix per sentence: the first result holds the
    # weights, the second the token indices those weights belong to.
    sentence_weights, sentence_token_ids = split_attn_matrix(
        model, generation, sentence_spans, soft_max=soft_max
    )

    # Weighted combination of the embeddings gives the sentence hidden states.
    sentence_states = weighted_hidden_states(sentence_weights, sentence_token_ids, generation)

    # Attention matrix over the sentence-level hidden states.
    sentence_attn = get_attention_matrix(last_encoder, sentence_states, soft_max=soft_max)
    sentence_attn = sentence_attn.to(torch.float32)

    # Entropy of that attention matrix, averaged over axis 0 and squeezed.
    with torch.no_grad():
        entropy = get_attention_entropy(sentence_attn.cpu(), soft_max=soft_max, avg_head=avg_head)
        averaged = torch.mean(entropy, dim=0).squeeze()
    return averaged.mean()

In [None]:
import re

def replace_with_dots(text,tokenizer):
    """Replace each reasoning span between ``A: `` and ``The answer is`` by dots.

    Every non-empty span matched by the pattern is replaced, in place, by as
    many ``...`` groups (joined with single spaces) as the tokenizer reports
    tokens for that span. Text outside the matched spans is never modified,
    even when it happens to contain the same characters as a span, and empty
    spans (``"A: The answer is"``) are left untouched.

    Args:
        text: Prompt text to rewrite.
        tokenizer: Callable invoked as
            ``tokenizer(span, padding=False, return_tensors='pt')`` that returns
            a mapping whose ``'input_ids'`` entry supports ``.size(1)``.

    Returns:
        str: ``text`` with every matched span replaced by dots.
    """
    # Lazily match everything after "A: " up to the next "The answer is";
    # DOTALL lets a span run across line breaks.
    span_pattern = re.compile(r'(?<=A: ).*?(?=The answer is)', re.DOTALL)

    def _dots_for(found):
        span = found.group(0)
        if not span:
            # Nothing to hide; an empty span must not receive any dots.
            return span
        # NOTE(review): the count includes any special tokens the tokenizer
        # adds on its own -- confirm that this is the intended length.
        token_ids = tokenizer(span, padding=False, return_tensors='pt')['input_ids']
        token_count = token_ids.size(1)
        return ' '.join(['...'] * token_count)

    # sub() rewrites exactly the matched positions, unlike str.replace, which
    # would also rewrite identical text elsewhere in the prompt.
    return span_pattern.sub(_dots_for, text)

In [None]:
def plot_entropy_scatter(data_lists, labels, x_label, y_label, title, save_path):
    """Draw one scatter series per data list and save the figure.

    Args:
        data_lists: List of lists; each inner list is one group of values and
            is plotted against its own indices 0..len-1.
        labels: Legend labels, paired with ``data_lists`` by position.
        x_label: Label of the x axis.
        y_label: Label of the y axis.
        title: Figure title.
        save_path: Path the figure is written to.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    for series, series_label in zip(data_lists, labels):
        positions = list(range(len(series)))
        ax.scatter(positions, series, label=series_label)

    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title(title)
    ax.legend()

    fig.savefig(save_path)
    # Close the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# 创建路径并绘制图表
def create_and_plot(model_name, soft_max, avg_head, nocot_list, dotcot_list, cot_list, level):
    """Create the output directory for one entropy level and plot the scatter.

    The figure is written to
    ``./debug_vis/soft_max_{soft_max}__avg_head_{avg_head}/{level}_level/{model_name}_{level}_entropy.png``.
    The directory that will actually contain the file is created, so a
    ``model_name`` containing ``/`` also works.

    Args:
        model_name: Name used in the title and the file name.
        soft_max: Value embedded in the output directory name.
        avg_head: Value embedded in the output directory name.
        nocot_list: Values of the ``nocot`` series.
        dotcot_list: Values of the ``dotcot`` series.
        cot_list: Values of the ``cot`` series.
        level: One of ``'token'``, ``'sen'`` or ``'rio'``.

    Raises:
        KeyError: If ``level`` is not one of the known levels. Nothing is
            created on disk in that case.
    """
    # Local import: the notebook-level ``import os`` is commented out.
    import os

    y_labels = {
        'token': 'Token Entropy',
        'sen': 'Sentence Entropy',
        'rio': 'Entropy Rio'
    }
    # Look the level up first so an unknown level fails before any directory exists.
    y_label = y_labels[level]

    path = f'./debug_vis/soft_max_{soft_max}__avg_head_{avg_head}/{level}_level'
    save_path = f'{path}/{model_name}_{level}_entropy.png'
    # exist_ok avoids the check-then-create race; dirname(save_path) equals
    # ``path`` unless model_name itself contains a path separator.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    data_lists = [nocot_list, dotcot_list, cot_list]
    labels = [f'nocot_{level}', f'dotcot_{level}', f'cot_{level}']
    x_label = 'Index Example'
    title = f'{y_label} Scatter of {model_name}'

    plot_entropy_scatter(data_lists, labels, x_label, y_label, title, save_path)

In [None]:
class Pipeline:
    """Compute token/sentence attention entropies per data loader and plot them.

    Samples of each loader are consumed in repeating groups of three: the
    first sample of a group is recorded as ``nocot``, the second as ``dotcot``
    (its text is first rewritten by ``replace_with_dots``) and the third as
    ``cot``. For every sample the token-level entropy, the sentence-level
    entropy and their ratio are collected; one scatter plot per level is
    produced after the loader has been processed.
    """

    # Position inside a group of three consecutive samples -> variant name.
    _VARIANTS = ('nocot', 'dotcot', 'cot')
    # Levels passed to create_and_plot, in plotting order.
    _LEVELS = ('token', 'sen', 'rio')

    def __init__(self, model, tokenizer, model_config, data_loaders:list[BaseLoader], soft_max=True, avg_head=True):
        """Store the model, tokenizer, configuration and analysis switches.

        Args:
            model: Model passed on to the generation and entropy helpers.
            tokenizer: Tokenizer passed on to the generation and entropy helpers.
            model_config: Sequence whose first element is used as model name
                in plot titles and file names.
            data_loaders: Loaders providing ``name``, ``load_data()`` and
                ``split_words()``.
            soft_max: Forwarded to the entropy helpers and used in output paths.
            avg_head: Forwarded to the entropy helpers and used in output paths.
        """
        logging.info("Init Pipeline")
        self.model = model
        self.tokenizer = tokenizer
        self.model_config = model_config
        self.data_loaders = data_loaders
        # min/max_input_token are kept for compatibility; run() does not read them.
        self.min_input_token = 100
        self.max_input_token = 2000
        # Upper bound on processed samples per loader (21 = 7 groups of three).
        self.max_sample = 21

        self.soft_max = soft_max
        self.avg_head = avg_head

    @staticmethod
    def _entropy_ratio(token_entropy, sentence_entropy):
        """Return token/sentence entropy, or NaN when the sentence entropy is 0."""
        if sentence_entropy == 0:
            return float('nan')
        return token_entropy / sentence_entropy

    def run(self):
        """Process up to ``max_sample`` samples per loader and write the plots."""
        logging.info("Pipeline start")
        for loader in self.data_loaders:
            logging.info(f"Data loader {loader.name}")
            samples = loader.load_data()
            split_words = loader.split_words()

            # results[level][variant] -> list of plain floats, one per sample.
            results = {
                level: {variant: [] for variant in self._VARIANTS}
                for level in self._LEVELS
            }

            for seen, data in enumerate(samples, start=1):
                variant = self._VARIANTS[(seen - 1) % len(self._VARIANTS)]
                if variant == 'dotcot':
                    data = replace_with_dots(data, self.tokenizer)

                # NOTE(review): max_input_token is hard-coded to 400 here while
                # self.max_input_token (2000) is unused -- confirm which is intended.
                model_generate = get_model_generate(self.tokenizer,self.model,data,max_new_tokens=1,max_input_token=400,split_words=split_words)

                # float() normalises both results: the sentence-level helper
                # returns a tensor, the token-level helper a Python float.
                token_entropy = float(process_data_token_level(self.model,model_generate,soft_max=self.soft_max,avg_head=self.avg_head))
                sentence_entropy = float(process_data_sen_level(self.model,self.tokenizer,data,model_generate,split_words=split_words,soft_max=self.soft_max,avg_head=self.avg_head))

                results['token'][variant].append(token_entropy)
                results['sen'][variant].append(sentence_entropy)
                results['rio'][variant].append(self._entropy_ratio(token_entropy, sentence_entropy))

                if seen >= self.max_sample:
                    break

            # Draw the scatter plots, one per level.
            model_name = self.model_config[0]
            for level in self._LEVELS:
                per_variant = results[level]
                create_and_plot(model_name, self.soft_max, self.avg_head, per_variant['nocot'], per_variant['dotcot'], per_variant['cot'], level)


In [None]:
parser = argparse.ArgumentParser()
parser.add_argument("--cfg", default="./config/llama2.yaml", help="config file path")
parser.add_argument("--start", default=5, type=int, help="index of the model config to run")
parser.add_argument("--model_cfg", default="./config/models_pz.yaml", help="model config file path")
# parse_known_args tolerates unrecognised arguments (e.g. ones supplied by the
# notebook kernel launcher) instead of exiting like parse_args would.
args = parser.parse_known_args()[0]

# log_f = '%(asctime)s | %(filename)s[line:%(lineno)d] | %(levelname)s | %(message)s'
# logging.basicConfig(level="DEBUG", format=log_f)
# logging.basicConfig(level="INFO", format=log_f)

# Load the run configuration and the model path configuration.
config = load_config(args.cfg)
model_cfg = load_config(args.model_cfg)

# Collect the model configs of every configured family into one flat list.
model_familys = config['model_familys']
model_configs = []
for key in model_familys:
    model_configs += model_cfg[f"paths_{key}"]

# Only the single model at index --start is processed.
selected_configs = model_configs[args.start:args.start+1]
if not selected_configs:
    # The slice is silently empty for an out-of-range index; make that visible.
    logging.warning("--start=%d selects no model (%d model configs available)", args.start, len(model_configs))

for model_config in selected_configs:
    print("model_name:",model_config[0])
    model, tokenizer = load_model_tokenizer(model_config=model_config)

    # Only the first configured data loader is used.
    data_loaders = [getattr(data_loader,config['data_loaders'][0])()]

    # Run every soft_max / avg_head combination, in the order
    # (True, True), (True, False), (False, True), (False, False).
    for soft_max in (True, False):
        for avg_head in (True, False):
            pipeline = Pipeline(model,tokenizer,model_config,data_loaders,soft_max=soft_max, avg_head=avg_head)
            pipeline.run()
