初始化

In [7]:
import os
import sys
import time
import pytorch_lightning as pl
import torch
from model import *
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from tqdm.notebook import tqdm
from utils.my_utils import *
import torch.nn.functional as F
import random
import regex as re
from dataset import *
import ipywidgets as widgets
from IPython.display import display
from typing import Union, List
# Fix the global random seed (Lightning reports "Global seed set to 42").
pl.seed_everything(42)
# Ask PyTorch for 'medium' float32 matmul precision.
torch.set_float32_matmul_precision('medium')
# Environment flag read by the tokenizers library — presumably enables its internal parallelism.
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# (label, checkpoint path) pairs used as the options of the model dropdown in
# setup_widgets. NOTE(review): the paths are absolute and machine-specific;
# they must be adapted before running the notebook elsewhere.
model_list = [
    ("gpt2", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8"),
    ("gpt2-xl", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2-xl/snapshots/33cdb5c0db5423c1879b1b9f16c352988e8754a8"),
    ("llama_7b", "/nvme/share/guoyiqiu/llama-7b"),
    ("llama_13b", "/nvme/share/guoyiqiu/llama-13b"),
    ("vicuna_7b", "/mnt/workspace/guoyiqiu/coding/vicuna_7b"),
    ("vicuna_13b", "/nvme/share/guoyiqiu/vicuna-13b-v1.1"),
    ("book_7b", "/mnt/workspace/guoyiqiu/coding/Book_7B/checkpoint-4968"),
]


def setup_widgets(model_list):
    """Build and display the interactive panel used to load and query a model.

    Args:
        model_list: Options for the model dropdown, given as a list of
            ``(label, checkpoint_path)`` pairs.

    Side effects:
        Publishes the created widgets as module-level globals
        (``mt_dropdown``, ``setup_btn``, ``device_tbtn``, ``precision_tbtn``,
        ``mnt_slider``, ``input_textarea``, ``output_textarea``,
        ``submit_btn``) so that later cells can drive them programmatically.
        Clicking the setup button additionally sets the globals ``mt`` (the
        loaded ``LLM``) and ``vis`` (its ``FlowVisualizer``).
    """
    global mt_dropdown
    global setup_btn
    global device_tbtn
    global precision_tbtn
    global mnt_slider
    global input_textarea
    global output_textarea
    global submit_btn

    def loaded_model():
        """Return the global ``mt`` if a model has been set up, else ``None``.

        ``mt`` only exists after the setup button was clicked, so a plain
        ``mt is not None`` test would raise NameError before that.
        """
        return globals().get("mt")

    def setup_llm(btn):
        """Load the selected model and its visualizer; report progress on the button."""
        global mt
        global vis
        time_st = time.time()
        btn.description = "Loading model..."
        mt = LLM(model_name=mt_dropdown.value, fp16=precision_tbtn.value == "half",)
        btn.description = "setup FlowVisualizer..."
        vis = FlowVisualizer(mt)
        btn.description = "Everything is ready."
        # Reset the device toggle to 'cpu' after (re)loading a model.
        device_tbtn.value = 'cpu'
        print(f"Time cost: {time.time() - time_st:.2f}s")

    def switch_device(change):
        """Move the loaded model to the newly selected device."""
        model = loaded_model()
        if model is None:
            # Nothing to move yet; setup_llm resets the toggle to 'cpu' anyway.
            return
        device_tbtn.disabled = True
        try:
            model.to(change.new)
            if change.new == 'cpu':
                # Release cached GPU memory once the model has left the GPU.
                torch.cuda.empty_cache()
        finally:
            # Always re-enable the toggle, even if the move failed.
            device_tbtn.disabled = False

    def switch_precision(change):
        """Cast the loaded model to half or full precision."""
        precision_tbtn.disabled = True
        try:
            model = loaded_model()
            if model is not None:
                model.model = model.model.half() if change.new == 'half' else model.model.float()
        finally:
            precision_tbtn.disabled = False

    def generate(btn):
        """Run generation on the input textarea and write the first result to the output textarea."""
        model = loaded_model()
        if model is None:
            raise RuntimeError("No model loaded: click 'Setup everything' first.")
        btn.disabled = True
        submit_btn.description = "Generating..."
        gen_kwargs = {
            "input_texts": input_textarea.value,
            "max_new_tokens": mnt_slider.value,
            "do_sample": sample_checkbox.value,
        }
        try:
            result = model.generate(**gen_kwargs)
        finally:
            # Restore the button even when generation raises, so the UI stays usable.
            btn.disabled = False
            submit_btn.description = "generate"
        output_textarea.value = result[0]

    # model dropdown
    mt_dropdown = widgets.Dropdown(options=model_list, description='Model:', disabled=False,)

    # setup button
    setup_btn = widgets.Button(description="Setup everything", disabled=False,)
    setup_btn.on_click(setup_llm)

    # switch device
    device_tbtn = widgets.ToggleButtons(options=['cpu', 'cuda:2',], disabled=False,)
    device_tbtn.observe(switch_device, names='value')

    # switch precision
    precision_tbtn = widgets.ToggleButtons(options=['float', 'half'], disabled=False,)
    precision_tbtn.observe(switch_precision, names='value')

    # max new token slider and sampling switch
    mnt_slider = widgets.IntSlider(value=64, min=1, max=512, step=1, description='new token:', disabled=False,)
    sample_checkbox = widgets.Checkbox(value=False, description='do sample', disabled=False,)

    # input and output textarea
    input_textarea = widgets.Textarea(value='', description='Input:', layout=widgets.Layout(width='30%', height='250px'), disabled=False)
    output_textarea = widgets.Textarea(value='', description='Output:', layout=widgets.Layout(width='30%', height='250px'), disabled=False)

    # submit button
    submit_btn = widgets.Button(description="generate", disabled=False,)
    submit_btn.on_click(generate)

    # panel layout
    control_panel = widgets.HBox([mt_dropdown, setup_btn, precision_tbtn, device_tbtn])
    talk_panel = widgets.HBox([input_textarea, widgets.VBox([mnt_slider, sample_checkbox, submit_btn]), output_textarea])
    all_panel = widgets.VBox([control_panel, talk_panel])
    display(all_panel)

# Build and show the control panel; this also publishes the widgets as globals used by later cells.
setup_widgets(model_list)

Global seed set to 42


VBox(children=(HBox(children=(Dropdown(description='Model:', options=(('gpt2', '/mnt/workspace/guoyiqiu/coding…

In [None]:
# Sweep a numeric value through a fixed prompt template and collect the model's
# completions. This cell drives the widgets created by setup_widgets, so a
# model must have been set up with the panel first.
outputs = []
attribute = "body temperature"
# attribute = "human body temperature"
# attribute = "diastolic pressure"
# attribute = "diastolic blood pressure"
# attribute = "cycle threshold value in RT-PCR test of SARS-CoV-2"
# attribute = "cycle threshold value in RT-PCR test of COVID-19"
# attribute = "cycle threshold value in nucleic acid test of COVID-19"
# attribute = "cycle threshold value in nucleic acid test of SARS-CoV-2"
# attribute = "Ct value in nucleic acid test of COVID-19"
# attribute = "Ct value in RT-PCR test of SARS-CoV-2"

unit = "degree celsius"
# unit = "mmHg"
# unit = ""
# unit = "HU"

# Maps the loop counter to the number inserted into the prompt.
# num_func = lambda x: 33+x*0.1
num_func = lambda x: x
mnt_slider.value = 32  # generate at most 32 new tokens per prompt
for i in range(0,100,1):
    num = num_func(i)
    # Spelling fixed ("considerd" -> "considered"): the text is fed verbatim to the model.
    prompt = f"A patient's {attribute} is {num} {unit}, which is considered to be"
    input_textarea.value = prompt
    # Programmatic click runs the registered `generate` callback, which fills output_textarea.
    submit_btn.click()
    outputs.append(output_textarea.value)
for o in outputs:
    print(o)

FlowVisualizer

In [8]:
# Prompts fed to the FlowVisualizer. `vis` is created by the setup button
# callback, so the model must have been set up before running this cell.
input_texts = [
    "The patient's body temperature is",
    'The Space Needle is in downtown',
]
# Clear the visualizer, then run generation on the prompts above.
vis.clear()
res = vis.generate(input_texts)

In [9]:
# Render the visualizer's unembedded-token view in the notebook (sidx=-1, unum=20; parameter semantics are defined by FlowVisualizer).
vis.visualize_utokens(sidx=-1,unum=20).render_notebook()

In [11]:
# Render the visualizer's info view in the notebook (sidx=-1; parameter semantics are defined by FlowVisualizer).
vis.visualize_info(sidx=-1).render_notebook()

数据集设置

In [None]:
# Dataset / loader configuration. `bsz`, `dst` and `dl` are read by later cells.
bsz = 8  # DataLoader batch size; also used by all_layer_unembed to split a sample index
size = 100  # only used by the commented-out MedQA alternative below
num_workers = 20  # DataLoader worker processes
# NOTE(review): absolute, machine-specific dataset locations.
knows_data_path = "/mnt/workspace/guoyiqiu/coding/datasets/rome_datasets/known_1000.json"
medqa_data_path = "/nvme/guoyiqiu/coding/datasets/MedQA/data_clean/questions/US/dev.jsonl"

# Requires `mt` (set by the setup button) for its tokenizer.
dst = Knowns(knows_data_path, mt.tokenizer)
# dst = MedQA(medqa_data_path, mt.tokenizer, size=size)
# Optional variant: drop the last (ground-truth) token from every field.
# dst.input_ids = [i[:,:-1] for i in dst.input_ids]
# dst.attention_mask = [i[:,:-1] for i in dst.attention_mask]
# dst.labels = [i[:,:-1] for i in dst.labels]
dl = DataLoader(dst, batch_size=bsz, collate_fn=dst.collate_fn,num_workers=num_workers)


In [None]:
# Sanity check: print only the first batch produced by the DataLoader.
for d in dl:
    print(d)
    break

批量推理

In [None]:
def my_predict_step(self, batch, batch_idx, dataloader_idx=0):
    """Generate up to 10 new tokens for a batch and merge the hook outputs it produced.

    Args:
        self: Object exposing ``model`` (with ``generate``) and ``hooks``
            (a mapping whose values carry an ``outputs`` list).
        batch: Indexable whose first two items are ``input_ids`` and
            ``attention_mask``.
        batch_idx: Unused; kept for the predict-step signature.
        dataloader_idx: Unused; kept for the predict-step signature.

    Returns:
        The generated ids with the prompt columns removed.

    Side effects:
        Every entry a hook recorded during this call is concatenated along
        dim 1 into a single tensor, so each hook gains exactly one new
        ``outputs`` entry per call.
    """
    prompt_ids = batch[0]
    prompt_mask = batch[1]

    hooks = list(self.hooks.values())
    # Remember where each hook's output list ended before generation started.
    first_new = [len(hook.outputs) for hook in hooks]

    generated = self.model.generate(
        input_ids=prompt_ids,
        attention_mask=prompt_mask,
        max_new_tokens=10,
    )
    # Keep only the newly generated columns.
    generated = generated[:, prompt_ids.shape[-1]:]

    for hook, start in zip(hooks, first_new):
        # Collapse all entries recorded during generation into the first one.
        hook.outputs[start] = torch.cat(list(hook.outputs[start:]), dim=1)
        hook.outputs = hook.outputs[:start + 1]
    return generated
# Install the custom predict step on the model wrapper and reset its hooks
# before running prediction (hook outputs are later indexed from 0).
mt.set_func('predict_step', my_predict_step)
mt.reset_hook()

# Lightning trainer settings. NOTE(review): the device index is hard-coded for
# the author's machine.
trainer_config = {
    "precision": "16-mixed",
    "accelerator": "auto",
    "devices": [4],
}
trainer = pl.Trainer(**trainer_config)
# NOTE(review): the dataset `dst` is passed here rather than the DataLoader `dl` — confirm this is intended.
res = trainer.predict(mt, dst)

分析Residual的平均norm值

In [None]:
# For every layer, average the hooked "layer_{i}" output over the non-masked
# positions of each item in `dst`, then average over all items and plot the
# norm of the resulting vector per layer.
# Assumes the hooks were filled by the prediction cell in the same order as
# `dst` is iterated — TODO confirm.
layers_mean_output = [[] for i in range(mt.n_layer)]
for idx, (input_ids, attention_mask, labels) in enumerate(dst):
    # Number of attended positions per row, repeated across the hidden dimension so it can divide the summed output.
    seq_len = attention_mask.sum(dim=1).unsqueeze(-1).repeat(1,mt.model.config.hidden_size)
    # Repeat the mask across the hidden dimension so masked positions are zeroed out.
    attention_mask = attention_mask.unsqueeze(-1).repeat(1,1,mt.model.config.hidden_size)
    for i in range(mt.n_layer):
        # Keep only the prompt positions (the hook entry also holds the generated steps).
        output_i_idx = mt.hooks[f"layer_{i}"].outputs[idx][:,:input_ids.shape[-1],:]
        output_i_idx = output_i_idx * attention_mask.float()
        output_i_idx = output_i_idx.sum(dim=1) / seq_len # [bsz, hidden_size] # compute mean
        # output_i_idx = output_i_idx[:,-1,:] # [bsz, hidden_size] # use last
        layers_mean_output[i].append(output_i_idx)
# One mean vector per layer (rebinding the name: from list of lists to list of tensors).
layers_mean_output = [torch.vstack(b).mean(0) for b in layers_mean_output]
plotly_bar([torch.norm(b).item() for b in layers_mean_output],"Avg norm of layer output", )

分析Residual的Unembedding Token

In [None]:
def unembed(batch_residual, k):
    """Project residual vectors to the vocabulary and return the top-k tokens.

    The residuals go through the global model's ``ln_f`` and ``lm_head``,
    followed by a softmax over dim 1.

    Args:
        batch_residual: A single residual vector (1-D) or a stack of them (2-D).
        k: Number of highest-probability tokens to keep per residual.

    Returns:
        For each residual, a dict mapping ``"<token>(<token id>)"`` to its
        probability, ordered by descending probability. A list of such dicts
        is returned, or the dict itself when there is only one residual.
    """
    if batch_residual.dim() == 1:
        # Promote a single vector to a batch of one.
        batch_residual = batch_residual.unsqueeze(0)

    with torch.no_grad():
        logits = mt.lm_head(mt.ln_f(batch_residual))
        top_probs, top_ids = torch.topk(F.softmax(logits, dim=1), k, dim=1)

        per_residual = []
        for row_probs, row_ids in zip(top_probs, top_ids):
            decoded = mt.tokenizer.batch_decode(row_ids)
            scored = {
                f"{tok}({tok_id})": prob.item()
                for tok, prob, tok_id in zip(decoded, row_probs, row_ids)
            }
            ranked = sorted(scored.items(), key=lambda item: item[1], reverse=True)
            per_residual.append(dict(ranked))

    if len(per_residual) > 1:
        return per_residual
    return per_residual[0]


def embed(token):
    """Return the embedding of the first token that ``token`` is tokenized into.

    The original code indexed position 1 of the encoded ids, which only works
    when the tokenizer prepends exactly one special token. For tokenizers that
    add none (e.g. GPT-2, which is in ``model_list``) that raised IndexError on
    single-token input, or silently picked the second sub-token. Encoding
    without special tokens and taking position 0 yields the same token for
    prefix-adding tokenizers and the correct one for the others.

    Args:
        token: Text whose first token should be embedded.

    Returns:
        Whatever ``mt.embedding`` returns for that token id.

    Raises:
        ValueError: If ``token`` encodes to no ids at all.
    """
    with torch.no_grad():
        token_ids = mt.tokenizer(token, return_tensors='pt', add_special_tokens=False).input_ids[0]
        if token_ids.numel() == 0:
            raise ValueError(f"{token!r} does not encode to any token")
        return mt.embedding(token_ids[0])


def all_layer_unembed(idx, k):
    """Measure how strongly each position and layer predicts the ground-truth token.

    For sample ``idx``, the hooked ``layer``, ``attn`` and ``mlp`` outputs of
    every layer are unembedded with :func:`unembed`. At each position, the
    probabilities of the top-``k`` tokens that equal the ground-truth token
    are summed.

    Relies on the notebook globals ``bsz``, ``dst``, ``res`` and ``mt``.
    Assumes sequences are left-padded, since the leading
    ``len(mask) - mask.sum()`` positions are skipped — TODO confirm.

    Args:
        idx: Sample index.
        k: Number of top tokens considered at each position.

    Returns:
        ``(table, input_tokens)``. ``table`` has shape
        ``(num_positions, mt.n_layer, 3)``, with the last axis ordered
        layer / attn / mlp. ``input_tokens`` lists the decoded input tokens,
        each prefixed with its position.
    """
    # bi: index of the batch; ii: index of the sample inside that batch.
    bi, ii = divmod(idx, bsz)

    # Materialise the dataset once instead of once per field.
    batch = list(dst)[bi]
    input_ids = batch[0][ii]
    attention_mask = batch[1][ii]
    new_ids = res[bi][ii]

    input_ids = torch.cat([input_ids, new_ids], dim=0)
    attention_mask = torch.cat([attention_mask, torch.ones_like(new_ids, dtype=torch.long)], dim=0)
    bos_idx = attention_mask.shape[0] - attention_mask.sum()
    input_ids = input_ids[bos_idx:-1]  # remove the final token generated by the model

    prompt = mt.tokenizer.decode(input_ids)
    print(f"prompt: {prompt}")

    input_tokens = [f"{i}{mt.tokenizer.decode(d)}" for i, d in enumerate(input_ids)]
    print(f"input_tokens: {input_tokens}")

    gt_token = mt.tokenizer.decode(dst.labels[idx][:,-1])
    print(f"gt token: {gt_token}")
    # Candidates are compared after .strip(), so the target must be stripped
    # too; otherwise a label decoded with surrounding whitespace never matches.
    gt_target = gt_token.strip()

    n_pos = len(input_ids)
    table = torch.zeros((n_pos, mt.n_layer, 3))
    # Removes the "(<token id>)" suffix that unembed appends to each key.
    pattern = re.compile(r'\([^()]*\)')
    for col, r in enumerate(['layer', 'attn', 'mlp']):
        # Stack the residuals of all layers: n_layer * n_pos rows.
        brl = torch.vstack([mt.hooks[f'{r}_{l}'].outputs[bi][ii][bos_idx:] for l in range(mt.n_layer)])
        nts = unembed(brl, k)
        if isinstance(nts, dict):
            # unembed returns a bare dict when there is only one residual.
            nts = [nts]
        # Regroup the flat result into one chunk of n_pos entries per layer.
        per_layer = [nts[start:start + n_pos] for start in range(0, len(nts), n_pos)]
        for l in range(mt.n_layer):
            table[:, l, col] = torch.tensor([
                sum(prob for key, prob in nt.items() if pattern.sub('', key).strip() == gt_target)
                for nt in per_layer[l]
            ])
    return table, input_tokens


In [None]:
# Inspect a single sample: per-position / per-layer probability of the ground-truth token.
idx = 0  # the idx of the sample
k = 10  # top k
table, input_tokens = all_layer_unembed(idx, k)

In [None]:
import cv2

# One (num_positions, n_layer, 3) table per sample, using the top-5 tokens.
tables = [all_layer_unembed(idx, 5)[0] for idx in tqdm(range(len(dst)))]

# Resample the position axis of every table to 100 rows so that samples of
# different length can be averaged. cv2.resize takes dsize as (width, height),
# i.e. (columns, rows), so a (num_positions, n_layer) input needs
# dsize=(n_layer, 100) to come out as (100, n_layer). The previous
# dsize=(100, n_layer) resampled positions to n_layer and layers to 100, and
# the following reshape then scrambled the two axes.
resize_tables = torch.stack([
    torch.stack([
        torch.from_numpy(cv2.resize(tb[:, :, i].contiguous().numpy(), (tb.shape[1], 100)))
        for i in range(3)
    ])
    for tb in tables
])
# resize_tables = torch.stack([tb[-100:,:,:] for tb in tables])
# Average over samples -> (3, 100, n_layer); no reshape needed any more.
resize_tables = resize_tables.mean(0)
plotly_matrix(resize_tables[0], 'layer mean unembed',)
plotly_matrix(resize_tables[1], 'attn mean unembed',)
plotly_matrix(resize_tables[2], 'mlp mean unembed',)