In [None]:
import xxhash
import json
import numpy as np
from functools import lru_cache
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
from transformers import AutoTokenizer
import tqdm


In [None]:
# Gather the family name of every entry in matplotlib's font manager
# (`fm.fontManager.ttflist`), then show the names and how many there are.
fonts = [font_entry.name for font_entry in fm.fontManager.ttflist]
print(fonts)
print("可用字体数量:", len(fonts))

In [None]:
@staticmethod
@lru_cache(maxsize=10)
def _cache_json(path: Path):
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

def _compute_hash(token_ids: list[int]):
    """Return the xxHash64 integer digest of a full prompt token sequence.

    The ids are converted with ``np.array(token_ids)`` and the array's raw
    bytes are hashed, so the digest depends on the integer dtype NumPy picks
    for the input.

    NOTE(review): the notebook uses this digest to build ``seq_<hash>`` file
    names, so it presumably has to match the recipe used by whatever wrote
    those files -- confirm before changing anything here.
    """
    raw_bytes = np.array(token_ids).tobytes()
    hasher = xxhash.xxh64()
    hasher.update(raw_bytes)
    return hasher.intdigest()

def read_jsonl(path: str | Path) -> list:
    """读取 JSONL 文件并返回 list"""
    file_path = Path(path)
    
    if not file_path.exists():
        print(f"文件不存在: {file_path}")
        raise FileNotFoundError(f"文件不存在: {file_path}")

    data_list = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():  # 跳过空行
                data_list.append(json.loads(line))
    return data_list

def serialize_token_ids(token_ids: list[int] | np.ndarray) -> str:
    """将 token_ids 序列化为 [数值,数值,数值] 格式"""
    if isinstance(token_ids, np.ndarray):
        return str(token_ids.tolist())
    else:
        return str(token_ids)

In [None]:
# --- Configuration: tokenizer, prompts, and input/output locations ---------
# NOTE(review): these are absolute, machine-specific paths; consider deriving
# them from a single configurable base directory.

# Model directory the tokenizer is loaded from.
drafter_path = Path("/root/.cache/modelscope/hub/models/Qwen/Qwen3-0___6B")
tokenizer = AutoTokenizer.from_pretrained(drafter_path)

# Prompts; later cells read the user message from `entry["turns"][0]`.
prompt_list = read_jsonl("/root/nano-vllm/select_question.jsonl")

# Directories holding the dumped logits, keyed by prompt hash in the file name.
verify_logits_path = Path("/root/nano-vllm/tmp/Qwen3-0___6B")
target_logits_path = Path("/root/nano-vllm/tmp/Qwen3-4B")

# Figures are written to ./plt under the current working directory.
current_path = Path.cwd()
output_dir = current_path / "plt"
# `exist_ok=True` already makes this idempotent, so no separate exists() check
# is needed. If `plt` exists but is a regular file this now fails here, rather
# than later when the first figure is saved.
output_dir.mkdir(exist_ok=True)


In [None]:
# For every `step`-th prompt, compare how many logit rows the draft dump and
# the target dump contain and print the ratio (1.0 means they line up).
step = 1
topk = 5
for i in tqdm.tqdm(range(0, len(prompt_list), step)):

    # Tokenize the first user turn with the chat template; the resulting ids
    # are hashed to locate the dump files for this prompt.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_list[i]["turns"][0]}],
        tokenize=True,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    hash_id = _compute_hash(prompt)

    draft_logits = read_jsonl(verify_logits_path / f"seq_{hash_id}_draft.jsonl")
    # Use a context manager so the file handle is closed on every iteration
    # instead of being leaked by `json.load(open(...))`.
    with open(target_logits_path / f"seq_{hash_id}.json", 'r', encoding='utf-8') as target_file:
        target_data = json.load(target_file)

    target_logits = target_data["logits"]
    # An empty target dump would otherwise abort the whole loop with
    # ZeroDivisionError; report NaN for that sequence and keep going.
    if target_logits:
        misalign_logits = len(draft_logits) / len(target_logits)
    else:
        misalign_logits = float("nan")
    print(f"misalign_logits: {misalign_logits} for seq {hash_id}")
        

In [None]:
# For every `step`-th prompt, draw one figure per logit position that overlays
# the draft logits and the target logits, and save it under `output_dir`.
step = 1
topk = 5

# Select the style once, before any figure is created, so that every figure
# (including the first) is built with the same rcParams.
# https://matplotlib.org/stable/gallery/style_sheets/style_sheets_reference.html
plt.style.use('default')

for i in tqdm.tqdm(range(0, len(prompt_list), step)):

    # Tokenize the first user turn with the chat template; the resulting ids
    # are hashed to locate the dump files for this prompt.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_list[i]["turns"][0]}],
        tokenize=True,
        add_generation_prompt=True,
        enable_thinking=True,
    )
    hash_id = _compute_hash(prompt)

    # NOTE(review): this cell originally read from `logits_path`, which is not
    # defined anywhere in the notebook (NameError on Restart & Run All).
    # `verify_logits_path` is assumed to be the intended directory because it
    # is where the `_draft.jsonl` dumps are read from elsewhere -- confirm that
    # the `_target.jsonl` dumps live there too.
    draft_logits = read_jsonl(verify_logits_path / f"seq_{hash_id}_draft.jsonl")
    target_logits = read_jsonl(verify_logits_path / f"seq_{hash_id}_target.jsonl")

    assert len(draft_logits) == len(target_logits), (
        f"seq {hash_id}: {len(draft_logits)} draft rows vs {len(target_logits)} target rows"
    )
    for j in range(len(draft_logits)):
        draft_row = draft_logits[j]
        target_row = target_logits[j]

        draft_token_id = np.argmax(draft_row)
        target_token_id = np.argmax(target_row)
        # argpartition only guarantees that the last `topk` entries are the
        # `topk` largest; it does NOT sort them, so `[-1]` is not necessarily
        # the argmax.
        draft_topk_token_ids = np.argpartition(draft_row, -topk)[-topk:]
        target_topk_token_ids = np.argpartition(target_row, -topk)[-topk:]

        # NOTE(review): the file name does not contain the position `j`, so two
        # positions with the same top-k ids map to the same file and the later
        # one is skipped as "exists" -- confirm whether that is intended.
        output_path = output_dir / f"logits_{hash_id}_target_{target_topk_token_ids}_draft_{draft_topk_token_ids}.png"
        if output_path.exists():
            print(f"{output_path} exists")
            continue

        # Decode the real top-1 ids (argmax) rather than the last element of
        # the unsorted top-k array.
        draft_token = tokenizer.decode(draft_token_id)
        target_token = tokenizer.decode(target_token_id)

        fig, ax = plt.subplots(figsize=(10, 8))
        # Plot the logits of the current position `j`; the previous code always
        # drew position 0, so every saved figure showed the same data.
        ax.vlines(range(len(draft_row)), ymin=0, ymax=draft_row,
                  color="C0", lw=0.5, alpha=0.3,
                  label="draft")
        ax.vlines(range(len(target_row)), ymin=0, ymax=target_row,
                  color="C1", lw=0.5, alpha=0.3,
                  label="target")
        ax.set_title(f"seq {hash_id}, position {j}")
        ax.set_xlabel('token id')
        ax.set_ylabel('logits')
        ax.legend(fontsize=10)
        fig.savefig(output_path)
        # Close the figure explicitly: figures created in a loop are otherwise
        # kept alive by pyplot and memory grows with every iteration.
        plt.close(fig)
        print(f"{output_path} dumped")