In [1]:







import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import json
import re
import yaml
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
from tqdm import tqdm, trange
from matplotlib import pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

# Expose only GPU 0 to this process.
# NOTE(review): this runs after `import torch`; presumably it still takes effect
# because CUDA has not been initialised yet at this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
plt.rcParams['font.sans-serif'] = ['Noto Sans CJK JP']  # Use the Noto Sans CJK JP font for figure text
plt.rcParams['axes.unicode_minus'] = False  # Work around the minus-sign display problem
# Colour map running red -> yellow -> green, named "prob_cmap".
cmap = LinearSegmentedColormap.from_list("prob_cmap", ["red", "yellow", "green"])

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hard-coded absolute locations of the inputs used by this notebook.
# JSON file with the inference results (read below with json.load).
results_path = "/home/nfs02/laizj/experiment/uncertainty_analysis/analysis_unknown/results/mgsm8k_inference_0.7_0.75_5.json"
# JSON file with the source test set (read below with json.load).
source_path = "/home/nfs02/laizj/experiment/uncertainty_analysis/analysis_unknown/data/mgsm8k_test.json"
# Local snapshot directory passed to `from_pretrained` for tokenizer and model.
model_path = "/home/nfs03/laizj/model/models--kevinpro--MistralMathOctopus-7B/snapshots/de76d14562b1da1ba0e52d5b245cf9c0859e2af3"

In [8]:
# Buckets filled by the classification loop later in this cell:
# every generated text correct / some correct / none correct.
correct_data = []
half_data = []
wrong_data = []

def extract_last_num(text: str) -> float:
    """Return the last unsigned decimal number that appears in ``text``.

    A comma sitting directly between two digits (as in ``123,456``) is removed
    before searching. Signs are not part of the match, so ``"-7"`` yields
    ``7.0``. When ``text`` contains no digits at all the result is ``0.0``.
    """
    # Drop digit-group commas so "123,456" is read as one number.
    without_separators = re.sub(r"(\d),(\d)", r"\g<1>\g<2>", text)
    # Each match is a (full_number, fractional_part) tuple.
    matches = re.findall(r"(\d+(\.\d+)?)", without_separators)
    if not matches:
        return 0.0
    full_number, _fraction = matches[-1]
    return float(full_number)

# Load both JSON files, then sort every result item into one of the three
# buckets according to how many of its generated texts match the answer.
source = None
results = None
with open(source_path, 'r') as src_f, open(results_path, 'r') as f:
    source = json.load(src_f)
    results = json.load(f)
    # NOTE(review): this discards the data just loaded from `source_path`, so
    # the "answer" field read below comes from the results file itself.
    # Confirm this is intended; otherwise loading `source_path` is dead work.
    source = results
    
for i in range(len(source)):
    count = 0  # number of generated texts judged correct for item i
    # Adds two new keys to the loaded result dict in place.
    results[i]['correct_generated_texts'] = []
    results[i]['wrong_generated_texts'] = []
    for generated_text in results[i]["generated_texts"]:
        # A text counts as correct when the last number it contains is within
        # 1e-2 of the last number in the reference answer.
        if abs(extract_last_num(source[i]["answer"]) - extract_last_num(generated_text)) < 1e-2:
            count += 1
            results[i]['correct_generated_texts'].append(generated_text)
        else:
            results[i]['wrong_generated_texts'].append(generated_text)
    if count == 0:
        # No generated text matched (also the case when there are none at all).
        wrong_data.append(results[i])
    elif count == len(results[i]["generated_texts"]):
        correct_data.append(results[i])
    else:
        half_data.append(results[i])


# Bucket sizes: all correct, partially correct, all wrong.
print(len(correct_data))
print(len(half_data))
print(len(wrong_data))

927
997
576


In [5]:
# Load the tokenizer and the causal language model from the local snapshot at `model_path`.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# NOTE(review): device_map='auto' presumably lets the library choose device placement — confirm.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map='auto')
        

Loading checkpoint shards: 100%|██████████| 6/6 [00:13<00:00,  2.30s/it]


In [79]:
# Run a teacher-forced forward pass over prompt + generated text for the first
# Chinese sample and keep the raw model outputs for later inspection.
model.eval()
results = []
with torch.no_grad():
    # Alternative: for dataset in [correct_data, half_data, wrong_data]:
    for dataset in [wrong_data]:
        for i in trange(len(dataset)):
            sample = dataset[i]
            if sample['lang'] != 'Chinese':
                continue
            # Score the wrong generations at index 3 and then index 1,
            # appending one entry per generation in that order.
            for text_idx in (3, 1):
                full_text = sample['prompt'] + sample['wrong_generated_texts'][text_idx]
                inputs = tokenizer(full_text, return_tensors="pt").to(model.device)
                outputs = model(**inputs, output_hidden_states=True)
                results.append({
                    'prompt': sample['prompt'],
                    'answer': sample['answer'],
                    'generated_text': full_text,
                    'outputs': outputs,
                })
            # Only the first matching sample is processed.
            break

  0%|          | 1/576 [00:00<04:02,  2.37it/s]


In [34]:
def merge_tokens_2_word(tokens, token_probs):
    """Group sub-word tokens into words and collect each word's token probabilities.

    A new word starts at every token that begins with the word-boundary marker
    ``▁`` (including a token that is just ``▁``). A run of consecutive
    byte tokens of the form ``<0xXX>`` is decoded together as UTF-8 and emitted
    as its own entry; if the run is not valid UTF-8, the original byte tokens
    are emitted unchanged, one entry per token. Output order always follows
    input order.

    Args:
        tokens: Iterable of token strings.
        token_probs: Iterable of per-token probabilities aligned with ``tokens``.

    Returns:
        A ``(merged, merged_probs)`` tuple. ``merged`` is a list of strings and
        ``merged_probs[i]`` is the list of probabilities of the tokens that
        were combined into ``merged[i]``.
    """
    # Whole bytes only (pairs of hex digits), so bytes.fromhex cannot fail.
    byte_token_re = re.compile(r'<0x((?:[0-9A-Fa-f]{2})+)>')

    merged = []
    merged_probs = []

    word_parts = []    # tokens of the word currently being built
    word_probs = []    # their probabilities
    byte_tokens = []   # original text of the pending byte tokens
    byte_values = []   # decoded bytes of the pending byte tokens
    byte_probs = []    # their probabilities

    def flush_word():
        """Emit the word accumulated so far, if any."""
        if word_parts:
            merged.append(''.join(word_parts))
            merged_probs.append(list(word_probs))
            word_parts.clear()
            word_probs.clear()

    def flush_bytes():
        """Emit the pending byte run, decoded when it is valid UTF-8."""
        if not byte_values:
            return
        # Text seen before the byte run must be emitted first to keep order,
        # regardless of whether decoding succeeds.
        flush_word()
        try:
            decoded = b''.join(byte_values).decode('utf-8')
        except UnicodeDecodeError:
            # Keep the original tokens verbatim, one entry per token.
            merged.extend(byte_tokens)
            merged_probs.extend([p] for p in byte_probs)
        else:
            merged.append(decoded)
            merged_probs.append(list(byte_probs))
        byte_tokens.clear()
        byte_values.clear()
        byte_probs.clear()

    for token, prob in zip(tokens, token_probs):
        byte_match = byte_token_re.fullmatch(token)
        if byte_match:
            byte_tokens.append(token)
            byte_values.append(bytes.fromhex(byte_match.group(1)))
            byte_probs.append(prob)
            continue

        # A non-byte token ends any pending byte run.
        flush_bytes()
        # The boundary marker is a prefix of the first token of a word.
        if token.startswith('▁'):
            flush_word()
        word_parts.append(token)
        word_probs.append(prob)

    # Emit whatever is still pending.
    flush_bytes()
    flush_word()

    return merged, merged_probs

In [52]:

def analysis_per_sample(data, fix_prob=0.3):
    """Display the generated continuation with every token coloured by its probability.

    The probability of each continuation token is read from the stored model
    logits and mapped through the module-level ``cmap`` to a colour. The result
    is rendered as HTML in the notebook.

    Args:
        data: Dict with the keys ``'prompt'`` (str), ``'generated_text'`` (str,
            prompt followed by the continuation) and ``'outputs'`` (model
            output exposing ``.logits`` for ``generated_text``).
        fix_prob: Probability threshold. Currently unused by the active code
            (only the commented-out filtering/plotting below refers to it);
            kept so existing calls keep working.

    Returns:
        None. The coloured text is shown via ``IPython.display``.
    """
    import html
    from IPython.display import display, HTML

    prompt_ids = tokenizer(data['prompt'], return_tensors="pt")['input_ids'][0]
    prompt_len = prompt_ids.size(-1)
    # NOTE(review): assumes the tokenisation of the prompt alone is a prefix of
    # the tokenisation of prompt + continuation — confirm for this tokenizer.
    input_ids = tokenizer(data['generated_text'], return_tensors="pt")['input_ids'][0][prompt_len:]

    # The logits at position t score the token at position t + 1, hence the
    # window is shifted one step to the left of the continuation tokens.
    logits = data['outputs'].logits[0, prompt_len - 1: -1, :].cpu()
    probs = F.softmax(logits, dim=-1)

    # Probability assigned to each token that was actually generated.
    token_probs = probs.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

    tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
    # tokens, token_probs = merge_tokens_2_word(tokens, token_probs)
    # print(tokens)

    # filtered_tokens = []
    # filtered_probs = []
    # for token, prob in zip(tokens, token_probs):
    #     if prob < fix_prob:
    #         filtered_tokens.append(token)
    #         filtered_probs.append(prob.item())

    # print(filtered_tokens)

    # # Draw the chart
    # plt.figure(figsize=(12, 6))
    # plt.bar(range(len(filtered_tokens)), filtered_probs, tick_label=filtered_tokens)
    # plt.xlabel('Tokens')
    # plt.ylabel('Probability')
    # plt.title(f'Token Generation Probabilities (Prob < {fix_prob})')
    # # plt.xticks(rotation=90)
    # plt.tight_layout()
    # plt.show()

    spans = []
    for token, prob in zip(tokens, token_probs):
        # Take the minimum so that an entry made of several probabilities is
        # coloured by its least likely token. as_tensor avoids the copy-construct
        # warning that torch.tensor() raises when given a tensor.
        prob_value = torch.as_tensor(prob).min().item()
        # Map the probability to an RGBA colour and convert it to hex.
        color = cmap(prob_value)
        color_hex = '#%02x%02x%02x' % (int(color[0]*255), int(color[1]*255), int(color[2]*255))
        # Escape the token so text such as "<0x0A>" or "<<" is shown literally
        # instead of being parsed as HTML markup.
        spans.append(f'<span style="color: {color_hex};">{html.escape(token)}</span> ')
    colored_text = "".join(spans)

    # Show the coloured text.
    display(HTML(colored_text))


In [80]:
# Print the full text (prompt + generation) of the first stored sample, then
# render its tokens coloured by probability.
print(results[0]['generated_text'])

# NOTE(review): `fix_prob` is only referenced by commented-out code inside
# analysis_per_sample, so the value passed here appears to have no effect.
analysis_per_sample(results[0], fix_prob=0.8)

Below is an instruction that describes a task. Write a response that appropriately completes the request in Chinese. Please answer in Chinese.

### Instruction:
卡拉正在下载一个 200 GB 的文件。正常情况下，她每分钟可以下载 2 GB，但在下载了 40% 的时候，Windows 强制重新启动以安装更新，而这个过程需要 20 分钟。然后卡拉不得不从头重新下载。下载这份文件需要多长时间？

### Response:在 Windows 强制重新启动之前，卡拉已经下载了 200 GB * 0.4 = <<200*0.4=80>>80 GB。
由于卡拉每分钟下载 2 GB，她需要 80 GB / 2 GB/分钟 = <<80/2=40>>40 分钟才能下载完剩下的 80 GB。
由于卡拉需要重新下载 200 GB - 80 GB = <<200-80=120>>120 GB，她需要 120 GB / 2 GB/分钟 = <<120/2=60>>60 分钟来下载完这份文件。
总共，卡拉需要 40 分钟 + 20 分钟 + 60 分钟 = <<40+20+60=120>>120 分钟来下载这份文件。


  prob_value = torch.tensor(prob).min().item()  # 取最大概率


In [76]:
# Same inspection as for results[0], applied to the second stored sample.
print(results[1]['generated_text'])

# NOTE(review): `fix_prob` is only referenced by commented-out code inside
# analysis_per_sample, so the value passed here appears to have no effect.
analysis_per_sample(results[1], fix_prob=0.8)

Below is an instruction that describes a task. Write a response that appropriately completes the request in Chinese. Please answer in Chinese.

### Instruction:
杰瑞正在摇一颗六面骰子。他摇出一个大于 3 的数字的可能性比连续摇出两个偶数的可能性大多少（以百分比表达）？

### Response:首先，我们需要确定杰瑞摇出一个大于 3 的数字的概率。骰子有 6 面，所以有 3 面是大于 3 的数字。因此，杰瑞摇出一个大于 3 的数字的概率是 3/6 = 1/2 = 50%。
接下来，我们需要确定杰瑞连续摇出两个偶数的概率。骰子有 6 面，所以有 3 面是偶数。因此，杰瑞摇出一个偶数的概率是 3/6 = 1/2 = 50%。
由于杰瑞连续摇出两个偶数的概率是 50% * 50% = 25%，所以杰瑞摇出一个大于 3 的数字的概率比连续摇出两个偶数的概率大 50% - 25% = 25%。


  prob_value = torch.tensor(prob).min().item()  # 取最大概率


In [46]:
# Scratch check: tokenise the generated texts and the prompt of the first fully correct item.
# NOTE(review): `source` is rebound here to a tokenizer output, shadowing the
# `source` variable assigned in the data-loading cell.
inputs = tokenizer(correct_data[0]['generated_texts'], return_tensors="pt")
source = tokenizer(correct_data[0]['prompt'], return_tensors="pt")

In [48]:
# Token at the index equal to the prompt length; the value of this expression is
# not displayed because it is not the last expression of the cell.
tokenizer.convert_ids_to_tokens(inputs['input_ids'][0, source['input_ids'].size(-1)].item())
# Tokens of the first sequence up to the prompt length (this is the cell's displayed output).
tokenizer.convert_ids_to_tokens(inputs['input_ids'][0, :source['input_ids'].size(-1)])

['<s>',
 '▁Below',
 '▁is',
 '▁an',
 '▁instruction',
 '▁that',
 '▁describes',
 '▁a',
 '▁task',
 '.',
 '▁Write',
 '▁a',
 '▁response',
 '▁that',
 '▁appropri',
 'ately',
 '▁complet',
 'es',
 '▁the',
 '▁request',
 '▁in',
 '▁Sw',
 'ah',
 'ili',
 '.',
 '▁Please',
 '▁answer',
 '▁in',
 '▁Sw',
 'ah',
 'ili',
 '.',
 '<0x0A>',
 '<0x0A>',
 '###',
 '▁Inst',
 'ruction',
 ':',
 '<0x0A>',
 'J',
 'ill',
 '▁anal',
 'ip',
 'wa',
 '▁$',
 '2',
 '0',
 '▁k',
 'ila',
 '▁sa',
 'a',
 '▁k',
 'uf',
 'un',
 'za',
 '▁na',
 '▁$',
 '3',
 '0',
 '▁k',
 'u',
 'wa',
 '▁k',
 'och',
 'a',
 '▁wa',
 '▁m',
 'ash',
 'ab',
 'iki',
 '.',
 '▁I',
 'w',
 'ap',
 'o',
 '▁h',
 'u',
 'wa',
 '▁an',
 'af',
 'anya',
 '▁k',
 'azi',
 '▁w',
 'iki',
 '▁',
 '5',
 '0',
 '▁k',
 'wa',
 '▁m',
 'w',
 'aka',
 ',',
 '▁sa',
 'a',
 '▁',
 '3',
 '5',
 '▁k',
 'wa',
 '▁w',
 'iki',
 '▁k',
 'ama',
 '▁m',
 'wal',
 'im',
 'u',
 '▁na',
 '▁sa',
 'a',
 '▁',
 '1',
 '5',
 '▁k',
 'ama',
 '▁k',
 'och',
 'a',
 ',',
 '▁m',
 'sh',
 'ah',
 'ara',
 '▁wake',
 '▁wa',
 '▁k',
