初始化

In [1]:
import os
import sys
import time
import pytorch_lightning as pl
import torch
from model import *
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from tqdm.notebook import tqdm
from utils.my_utils import *
import torch.nn.functional as F
import random
import regex as re
from dataset import *
import ipywidgets as widgets
from IPython.display import display
from typing import Union, List
pl.seed_everything(42)
torch.set_float32_matmul_precision('medium')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

model_list = [
    ("gpt2", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8"),
    ("gpt2-xl", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2-xl/snapshots/33cdb5c0db5423c1879b1b9f16c352988e8754a8"),
    ("llama_7b", "/nvme/share/guoyiqiu/llama-7b"),
    ("llama_13b", "/nvme/share/guoyiqiu/llama-13b"),
    ("vicuna_7b", "/mnt/workspace/guoyiqiu/coding/vicuna_7b"),
    ("vicuna_13b", "/mnt/workspace/guoyiqiu/coding/vicuna-13b-v1.1"),
    ("book_7b", "/mnt/workspace/guoyiqiu/coding/Book_7B/checkpoint-4968"),
]


def setup_widgets(model_list):
    global mt_dropdown
    global setup_btn
    global device_tbtn
    global precision_tbtn
    global mnt_slider
    global input_textarea
    global output_textarea
    global submit_btn

    def setup_llm(btn):
        global mt
        global vis
        time_st = time.time()
        btn.description = "Loading model..."
        mt = LLM.from_pretrained(model_name=mt_dropdown.value, fp16=(precision_tbtn.value == "half"),)
        btn.description = "setup FlowVisualizer..."
        vis = FlowVisualizer(mt)
        btn.description = "Everything is ready."
        device_tbtn.value = 'cpu'
        print(f"Time cost: {time.time() - time_st:.2f}s")
    
    def switch_device(change):
        device_tbtn.disabled = True
        mt.to(change.new)
        torch.cuda.empty_cache() if change.new == 'cpu' else None
        device_tbtn.disabled = False

    def switch_precision(change):
        precision_tbtn.disabled = True
        if mt is not None:
            mt.model = mt.model.half() if change.new == 'half' else mt.model.float()
        precision_tbtn.disabled = False

    def generate(btn):
        btn.disabled = True
        submit_btn.description = "Generating..."
        gen_kwargs = {
            "input_texts":input_textarea.value,
            "max_new_tokens":mnt_slider.value,
            "do_sample": sample_checkbox.value,
        }
        result = mt.generate(**gen_kwargs)
        btn.disabled = False
        submit_btn.description = "generate"
        output_text = result[0]
        output_textarea.value = output_text

    # model dropdown
    mt_dropdown = widgets.Dropdown(options=model_list, description='Model:', disabled=False,)

    # setup button
    setup_btn = widgets.Button(description="Setup everything", disabled=False,)
    setup_btn.on_click(setup_llm)

    # switch deivce
    device_tbtn = widgets.ToggleButtons(options=['cpu', f'cuda',], disabled=False,)
    device_tbtn.observe(switch_device, names='value')

    # switch precision
    precision_tbtn = widgets.ToggleButtons(options=['float', 'half'], disabled=False,)
    precision_tbtn.observe(switch_precision, names='value')

    # max new token slider
    mnt_slider = widgets.IntSlider(value=64,min=1,max=512,step=1,description='new token:',disabled=False,)
    sample_checkbox = widgets.Checkbox(value=False,description='do sample',disabled=False,)
    
    # input and output textarea
    input_textarea = widgets.Textarea(value='',description='Input:',layout=widgets.Layout(width='30%', height='250px'),disabled=False)
    output_textarea = widgets.Textarea(value='',description='Output:',layout=widgets.Layout(width='30%', height='250px'),disabled=False)

    # submit button
    submit_btn = widgets.Button(description="generate",disabled=False,)
    submit_btn.on_click(generate)

    # pannel layout
    control_panel = widgets.HBox([mt_dropdown, setup_btn, precision_tbtn, device_tbtn])
    talk_panel = widgets.HBox([input_textarea, widgets.VBox([mnt_slider, sample_checkbox, submit_btn]), output_textarea])
    all_panel = widgets.VBox([control_panel, talk_panel])
    display(all_panel)

setup_widgets(model_list)

Global seed set to 42


VBox(children=(HBox(children=(Dropdown(description='Model:', options=(('gpt2', '/mnt/workspace/guoyiqiu/coding…

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Time cost: 61.05s


In [2]:
outputs = []
# attribute = "body temperature"
# attribute = "human body temperature"
# attribute = "diastolic pressure"
attribute = "diastolic blood pressure"
# attribute = "cycle threshold value in nucleic acid test of COVID-19"
# attribute = "Ct value in nucleic acid test of COVID-19"

# unit = "degree celsius"
unit = "mmHg"
# unit = ""
# unit = "HU"

# num_func = lambda x: 33+x*0.1
num_func = lambda x: x
mnt_slider.value = 16
for i in range(0,100,1):
    num = num_func(i)
    prompt = f"A patient's {attribute} is {num} {unit}, which is considerd to be"
    input_textarea.value = prompt
    submit_btn.click()
    outputs.append(output_textarea.value)
for o in outputs:
    print(o)

<s> A patient's diastolic blood pressure is 0 mmHg, which is considerd to be low. What is the systolic blood pressure?

The diast
<s> A patient's diastolic blood pressure is 1 mmHg, which is considerd to be a small change.
A patient's systolic blood pressure is 
<s> A patient's diastolic blood pressure is 2 mmHg, which is considerd to be low. What is the patient's systolic blood pressure?


<s> A patient's diastolic blood pressure is 3 mmHg, which is considerd to be low. What is the patient's systolic blood pressure?


<s> A patient's diastolic blood pressure is 4 mmHg, which is considerd to be low. What is the patient's systolic blood pressure?


<s> A patient's diastolic blood pressure is 5 mmHg, which is considerd to be low. What is the patient's systolic blood pressure?


<s> A patient's diastolic blood pressure is 6 mmHg, which is considerd to be low. What is the patient's systolic blood pressure?


<s> A patient's diastolic blood pressure is 7 mmHg, which is considerd to be low. 

In [24]:
import torch
from copy import deepcopy
from pyecharts import options as opts
from pyecharts.charts import Bar, Timeline, Tab, Page, Line
from pyecharts.faker import Faker
import ipywidgets as widgets
from IPython.display import display
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
SAVED_MODULES = ['layer', 'attn', 'mlp']


class Unembedding(nn.Module):
    def __init__(self, lm_head, ln_f):
        super().__init__()
        self.lm_head = lm_head
        self.ln_f = ln_f
        
    def forward(self, x):
        with torch.no_grad():
            x = self.ln_f(x)
            x = self.lm_head(x)
        return x

class FlowVisualizer:
    def __init__(self, mt: LLM):
        self.mt = mt
        self.idx2token = [f"{i}-{self.mt.tokenizer.decode(i)}" for i in range(self.mt.tokenizer.vocab_size)]
        self.unembedding = Unembedding(deepcopy(mt.lm_head).to('cpu').float(), deepcopy(mt.ln_f).to('cpu').float())
        self.init_save_hook()
        self.sentences = [] # generated sentences
        self.next_tokens = [] # next token of sentences
        self.prompt_lengths = [] # prompt length of sentences
        self.utokens = [] # 对于每个句子，都有seq_len个token，每个token都有一个vocab_size大小的utoken list [bsz, seq_len, vocab_size]
        self.uprobs = [] # [bsz, 3, seq_len, n_layer, vocab_size]
        self.infos = [] # 对于每个句子，每个模块每一层每个token的uprob信息熵 [bsz, 3, n_layer, seq_len]
        self.diffs = [] # 对于每个句子，每个模块每一层每个token的uprob关于上一层的uprob的交叉熵 [bsz, 3, n_layer, seq_len]
        
    def init_save_hook(self):
        self.mt.clear_hook()
        hook_config = {
            "retain_output": True,
            "retain_input": False,
            "edit_output": None,
            "clone": True,
            "float": True,
            "detach": True,
            "device": "cpu"
        }
        for h in SAVED_MODULES:
            for l in range(self.mt.n_layer):
                hook_config['retain_input'] = (l == 0 and h == 'layer') # 只保留Layer第一层的输入
                self.mt.add_hook(module=getattr(self.mt, h+'s')[l], name=f'{h}_{l}', **hook_config)

    def get_sentence_matrix(self, sidx):
        '''return out_matrix and x0 of sentence sidx with shape of [3, n_layer, seq_len, hidden_size] and [3, 1, seq_len, hidden_size]'''
        out_matrix = torch.stack([torch.cat([self.mt.hooks[f'{h}_{l}'].outputs[sidx] for l in range(self.mt.n_layer)], dim=0) for h in SAVED_MODULES])
        x0 = self.mt.hooks['layer_0'].inputs[sidx]
        return out_matrix, x0
    
    def generate(self, input_texts, **gen_wargs):
        input_texts = input_texts if isinstance(input_texts, list) else [input_texts]
        inps = [self.mt.tokenizer(text, return_tensors='pt') for text in input_texts]
        
        for inp in tqdm(inps, total=len(inps)):
            input_ids, attention_mask = inp['input_ids'], inp['attention_mask']
            self.prompt_lengths.append(input_ids.shape[1])

            # model generate
            hook_idxs = [len(h.outputs) for h in self.mt.hooks.values()]
            with torch.no_grad():
                input_ids = input_ids.to(self.mt.model.device)
                attention_mask = attention_mask.to(self.mt.model.device)
                gen_wargs['max_new_tokens'] = 10 if 'max_new_tokens' not in gen_wargs else gen_wargs['max_new_tokens']
                output_ids = self.mt.model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_wargs)
            
            # 模型会在generate的过程中多次forward产生多个hook中间值，需要把hook的输出拼接起来得到完整的句子的matrix
            for (hook, idx) in zip(self.mt.hooks.values(), hook_idxs):
                hook.outputs[idx] = torch.cat([o for o in hook.outputs[idx:]], dim=1)
                hook.outputs = hook.outputs[:idx+1]
                if hook.retain_input:
                    hook.inputs[idx] = torch.cat([o for o in hook.inputs[idx:]], dim=1)
                    hook.inputs = hook.inputs[:idx+1]
            
            # 保存generate的句子和下一个token
            out_tokens = self.mt.tokenizer.batch_decode(output_ids[0])
            self.sentences.append(out_tokens[:-1])
            self.next_tokens.append(out_tokens[-1])
            
            self.post_process(-1)
            
    def post_process(self, sidx):
        # 获取当前句子的关于每一层，每一个模块合并后的完整matrix [3, n_layer, seq_len, hidden_size]
        out_matrix, x0 = self.get_sentence_matrix(sidx)
        seq_len = out_matrix.shape[2]
        
        # 将activation映射到vocabulary词表空间，计算所有unbedding token的概率
        cur_logits = self.unembedding(out_matrix) # [3, n_layer, seq_len, vocab_size]
        cur_prob = torch.softmax(cur_logits, dim=-1)  # [3, n_layer, seq_len, vocab_size]

        # 计算层信息熵
        cur_info = -torch.sum(cur_prob * torch.log(cur_prob), dim=-1) # [3, n_layer, seq_len]
        self.infos.append(cur_info)

        # 计算层概率差
        # with torch.no_grad():
        #     x0 = self.mt.embedding(output_ids[0,:-1]).unsqueeze(0).repeat(3, 1, 1).unsqueeze(1).cpu() # [3, 1, seq_len, hidden_size]
        logits0 = self.unembedding(x0) # [3, 1, seq_len, vocab_size]
        cur_logits_extended = torch.cat([logits0, cur_logits], dim=1) # [3, n_layer+1, seq_len, vocab_size]
        cur_diff = F.cross_entropy(cur_logits_extended[:,:-1].reshape(-1, cur_logits_extended.shape[-1]), cur_prob.reshape(-1, cur_prob.shape[-1]), reduction='none') # [3 * n_layer * seq_len]
        cur_diff = cur_diff.reshape(3, self.mt.n_layer, seq_len) # [3, n_layer, seq_len]
        self.diffs.append(cur_diff)
        
        # 对generate的句子的每一个token对应的uprob，依据uprob在3个模块中的变化大小之和，对utoken从大到小排序
        cur_utokens = [] # [seq_len, vocab_size]
        cur_uprobs = [] # [seq_len, 3, n_layer, vocab_size]
        for j in range(seq_len):
            cur_token_prob = cur_prob[:,:,j,:] # [3, n_layer, vocab_size]
            # 计算token在3个模块中的概率变化之和
            cur_token_prob_diff = (cur_token_prob[1:] - cur_token_prob[:-1]).abs().sum(dim=0).sum(dim=0) # [vocab_size]
            # 按照变化之和从大到小排序
            cur_token_udiff, cur_token_uids = torch.sort(cur_token_prob_diff, descending=True)
            cur_token_utokens = [self.idx2token[idx] for idx in cur_token_uids]
            cur_utokens.append(cur_token_utokens)
            cur_token_uprobs = cur_token_prob[:, :, cur_token_uids] # [3, n_layer, vocab_size]
            cur_uprobs.append(cur_token_uprobs)
        
        # 保存utokens和uprobs
        self.utokens.append(cur_utokens)
        cur_uprobs = torch.stack(cur_uprobs).transpose(0, 1) # [3, seq_len, n_layer, vocab_size]
        self.uprobs.append(cur_uprobs)

    
    def visualize_utokens(self, sidx=-1, unum=20):
        cur_sentence = self.sentences[sidx]
        tab = Tab()
        for tidx in range(len(cur_sentence)):
            tl = Timeline()
            for l in range(self.mt.n_layer):
                cur_utokens = self.utokens[sidx][tidx][:unum]
                cur_uprobs = self.uprobs[sidx][:,tidx,l,:unum] # [3, unum]
                bar = (
                    Bar()
                    .add_xaxis(cur_utokens)
                    .add_yaxis('layer', cur_uprobs[0].numpy().tolist(), label_opts=opts.LabelOpts(is_show=False))
                    # .add_yaxis('attn', cur_uprobs[1].numpy().tolist(), label_opts=opts.LabelOpts(is_show=False))
                    # .add_yaxis('mlp', cur_uprobs[2].numpy().tolist(), label_opts=opts.LabelOpts(is_show=False))
                    .reversal_axis()
                    .set_global_opts(
                        title_opts={"text": f"Unembedding Token Flow"},
                        xaxis_opts=opts.AxisOpts(name="Probability"),
                        yaxis_opts=opts.AxisOpts(name="Top k UTokens"),
                    )
                )
                tl.add(bar, f"{l+1}")
            tab.add(tl, cur_sentence[tidx])
        return tab
    
    def visualize_info(self, sidx=-1, show_modules=['layer', 'attn', 'mlp'],show_diff=True):
        cur_sentence = self.sentences[sidx]
        tab = Tab()
        for tidx in range(len(cur_sentence)):
            cur_info = self.infos[sidx][:,:,tidx] # [3, n_layer]
            cur_diff = self.diffs[sidx][:,:,tidx] # [3, n_layer]
            xaxis = [str(l+1) for l in list(range(self.mt.n_layer))]
            c = (
                Line()
                .add_xaxis(xaxis)
                .extend_axis(
                    yaxis=opts.AxisOpts(
                        name="Cross Entropy",
                        type_="value",
                        position="right",
                    )
                )
                .extend_axis(
                    yaxis=opts.AxisOpts(
                        name="Infomation Entropy",
                        type_="value",
                        position="left",
                    )
                )
                .add_yaxis("layer info", cur_info[0].numpy().tolist(), yaxis_index=0, label_opts=opts.LabelOpts(is_show=False))
                # .add_yaxis("attn info", cur_info[1].numpy().tolist(), yaxis_index=0, label_opts=opts.LabelOpts(is_show=False))
                # .add_yaxis("mlp info", cur_info[2].numpy().tolist(), yaxis_index=0, label_opts=opts.LabelOpts(is_show=False))
                .add_yaxis("layer diff", cur_diff[0].numpy().tolist(), yaxis_index=1, label_opts=opts.LabelOpts(is_show=False))
                # .add_yaxis("attn diff", cur_diff[1].numpy().tolist(), yaxis_index=1, label_opts=opts.LabelOpts(is_show=False))
                # .add_yaxis("mlp diff", cur_diff[2].numpy().tolist(), yaxis_index=1, label_opts=opts.LabelOpts(is_show=False))
                .set_global_opts(
                    title_opts=opts.TitleOpts(title="信息熵和交叉熵"),
                    tooltip_opts=opts.TooltipOpts(trigger="axis", axis_pointer_type="cross"),)
            )
            tab.add(c, cur_sentence[tidx])
        return tab
    
    def get_similar_token(self, token_id, k=20):
        embedding = self.mt.embedding.weight.data
        with torch.no_grad():
            cos_values, cos_indices = torch.topk(torch.cosine_similarity(embedding, embedding[token_id].unsqueeze(0), dim=1),k=k)
        return [f"{self.idx2token[id]}: {cos_values[i].item():.3f}" for i, id in enumerate(cos_indices)]
        
    def clear(self):
        self.sentences.clear()
        self.next_tokens.clear()
        self.prompt_lengths.clear()
        self.utokens.clear() 
        self.uprobs.clear() 
        self.infos.clear()
        self.diffs.clear()

In [40]:
vis = FlowVisualizer(mt)
prompt_template = "A human's {attr} is {num}{unit}, which is considered to be"
attrs = ["cycle threshold value in nucleic acid test of COVID-19", "body temperature", "diastolic blood pressure"]
units = ["", "°C", "mmHg"]# degree celsius °C
num_range = [0,100,10]
num_funcs = [lambda x: x, lambda x: 35 + x*0.06, lambda x: x]

input_texts = [
    prompt_template.format(attr=attrs[j], num=num_func(i), unit=" "+units[j] if units[j] else "") for j in range(len(attrs)) for i in range(*num_range)
]
vis.clear()
res = vis.generate(input_texts,max_new_tokens=10)

100%|██████████| 30/30 [02:54<00:00,  5.82s/it]


In [45]:
for s in vis.sentences:
    print(s)

['<s>', 'A', 'human', "'", 's', 'cycle', 'threshold', 'value', 'in', 'nucle', 'ic', 'acid', 'test', 'of', 'COVID', '-', '1', '9', 'is', '', '3', '5', '.', '0', ',', 'which', 'is', 'considered', 'to', 'be', 'the', 'threshold', 'of', 'detection', '.', '\n', '\n', 'The', 'threshold']
['<s>', 'A', 'human', "'", 's', 'cycle', 'threshold', 'value', 'in', 'nucle', 'ic', 'acid', 'test', 'of', 'COVID', '-', '1', '9', 'is', '', '3', '8', '.', '0', ',', 'which', 'is', 'considered', 'to', 'be', 'the', 'threshold', 'of', 'detection', 'for', 'the', 'virus', '.', '\n']
['<s>', 'A', 'human', "'", 's', 'cycle', 'threshold', 'value', 'in', 'nucle', 'ic', 'acid', 'test', 'of', 'COVID', '-', '1', '9', 'is', '', '4', '1', '.', '0', ',', 'which', 'is', 'considered', 'to', 'be', 'the', 'threshold', 'of', 'detection', 'for', 'COVID', '-', '1', '9']
['<s>', 'A', 'human', "'", 's', 'cycle', 'threshold', 'value', 'in', 'nucle', 'ic', 'acid', 'test', 'of', 'COVID', '-', '1', '9', 'is', '', '4', '4', '.', '0', ','

In [51]:
vis.visualize_utokens(15,unum=20).render_notebook()

In [61]:
vis.visualize_info(12).render_notebook()

数据集设置

In [None]:
bsz = 8
size = 100
num_workers = 20
knows_data_path = "/mnt/workspace/guoyiqiu/coding/datasets/rome_datasets/known_1000.json"
medqa_data_path = "/nvme/guoyiqiu/coding/datasets/MedQA/data_clean/questions/US/dev.jsonl"

dst = Knowns(knows_data_path, mt.tokenizer)
# dst = MedQA(medqa_data_path, mt.tokenizer, size=size)
# without gt token 
# dst.input_ids = [i[:,:-1] for i in dst.input_ids]
# dst.attention_mask = [i[:,:-1] for i in dst.attention_mask]
# dst.labels = [i[:,:-1] for i in dst.labels]
dl = DataLoader(dst, batch_size=bsz, collate_fn=dst.collate_fn,num_workers=num_workers)


In [None]:
for d in dl:
    print(d)
    break

批量推理

In [None]:
def my_predict_step(self, batch, batch_idx, dataloader_idx=0):
    input_ids, attention_mask = batch[0], batch[1]
    hook_idxs = [len(h.outputs) for h in self.hooks.values()]
    output_ids = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=10)
    output_ids = output_ids[:, input_ids.shape[-1]:]
    for hook, idx in zip(self.hooks.values(), hook_idxs):
        hook.outputs[idx] = torch.cat([o for o in hook.outputs[idx:]], dim=1)
        hook.outputs = hook.outputs[:idx+1]
    return output_ids
mt.set_func('predict_step', my_predict_step)
mt.reset_hook()

trainer_config = {
    "precision": "16-mixed",
    "accelerator": "auto",
    "devices": [4],
}
trainer = pl.Trainer(**trainer_config)
res = trainer.predict(mt, dst)

分析Residual的平均norm值

In [None]:
layers_mean_output = [[] for i in range(mt.n_layer)]
for idx, (input_ids, attention_mask, labels) in enumerate(dst):
    seq_len = attention_mask.sum(dim=1).unsqueeze(-1).repeat(1,mt.model.config.hidden_size)
    attention_mask = attention_mask.unsqueeze(-1).repeat(1,1,mt.model.config.hidden_size)
    for i in range(mt.n_layer):
        output_i_idx = mt.hooks[f"layer_{i}"].outputs[idx][:,:input_ids.shape[-1],:]
        output_i_idx = output_i_idx * attention_mask.float()
        output_i_idx = output_i_idx.sum(dim=1) / seq_len # [bsz, hidden_size] # compute mean
        # output_i_idx = output_i_idx[:,-1,:] # [bsz, hidden_size] # use last
        layers_mean_output[i].append(output_i_idx)
layers_mean_output = [torch.vstack(b).mean(0) for b in layers_mean_output]
plotly_bar([torch.norm(b).item() for b in layers_mean_output],"Avg norm of layer output", )

分析Residual的Unembedding Token

In [None]:
def unembed(batch_residual, k):
    batch_residual = batch_residual.unsqueeze(0) if len(batch_residual.shape) == 1 else batch_residual
    with torch.no_grad():
        batch_residual = mt.ln_f(batch_residual)
        batch_logits = mt.lm_head(batch_residual)
        batch_prob = F.softmax(batch_logits, dim=1)
        batch_prob, batch_indices = torch.topk(batch_prob, k, dim=1)
        next_tokens = []
        for i in range(batch_residual.shape[0]):
            prob = batch_prob[i]
            indices = batch_indices[i]
            tokens = mt.tokenizer.batch_decode(indices)
            next_token = {f"{t}({i})": p.item() for t, p, i in zip(tokens, prob, indices)}
            next_token = dict(sorted(next_token.items(), key=lambda x: x[1], reverse=True))
            next_tokens.append(next_token)
        return next_tokens if len(next_tokens) > 1 else next_tokens[0]


def embed(token):
    with torch.no_grad():
        return mt.embedding(mt.tokenizer(token, return_tensors='pt').input_ids[0][1])


def all_layer_unembed(idx, k):
    bi = idx // bsz  # the idx of the batch
    ii = idx - (idx // bsz) * bsz  # the idx of the sample in the batch

    input_ids = list(dst)[bi][0][ii]
    attention_mask = list(dst)[bi][1][ii]
    new_ids = res[bi][ii]
    
    input_ids = torch.cat([input_ids, new_ids], dim=0)
    attention_mask = torch.cat([attention_mask, torch.ones_like(new_ids, dtype=torch.long)], dim=0)
    bos_idx = attention_mask.shape[0] - attention_mask.sum()
    input_ids = input_ids[bos_idx:-1] # remove the final token generated by the model

    prompt = mt.tokenizer.decode(input_ids)  
    print(f"prompt: {prompt}")
    
    input_tokens = [f"{i}{mt.tokenizer.decode(d)}" for i, d in enumerate(input_ids)]
    print(f"input_tokens: {input_tokens}")
    
    gt_token = mt.tokenizer.decode(dst.labels[idx][:,-1])
    print(f"gt token: {gt_token}")

    table = torch.zeros((len(input_ids), mt.n_layer, 3))
    pattern = re.compile(r'\([^()]*\)')
    for i, r in enumerate(['layer', 'attn', 'mlp']):
        brl = torch.vstack([mt.hooks[f'{r}_{l}'].outputs[bi][ii][bos_idx:] for l in (range(mt.n_layer))])
        nts = unembed(brl, k)
        nts = [nts[i:i+len(input_ids)] for i in range(0, len(nts), len(input_ids))]
        for l in (range(mt.n_layer)):
            table[:, l, i] = torch.tensor([sum([nt[t] for t in nt if gt_token == pattern.sub('', t).strip()]) for nt in nts[l]])
    return table, input_tokens


In [None]:
idx = 0  # the idx of the sample
k = 10  # top k
table, input_tokens = all_layer_unembed(idx, k)

In [None]:
import cv2
tables = [all_layer_unembed(idx, 5)[0] for idx in tqdm(range(len(dst)))]

resize_tables = torch.stack([torch.tensor([cv2.resize(tb[:,:,i].numpy(), (100, tb.shape[1]), ) for i in range(3)]) for tb in tables])
# resize_tables = torch.stack([tb[-100:,:,:] for tb in tables])
resize_tables = resize_tables.mean(0).reshape(3,-1,32)
plotly_matrix(resize_tables[0], 'layer mean unembed',)
plotly_matrix(resize_tables[1], 'attn mean unembed',)
plotly_matrix(resize_tables[2], 'mlp mean unembed',)