初始化

In [1]:
import os
import sys
import time
import pytorch_lightning as pl
import torch
from model.model_interface import LLM
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from tqdm.notebook import tqdm
from utils.my_utils import *
import torch.nn.functional as F
import random
import regex as re
from dataset import *
import ipywidgets as widgets
from IPython.display import display
from typing import Union, List

# Reproducibility and runtime settings.
pl.seed_everything(42)
torch.set_float32_matmul_precision('medium')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'


# (label, checkpoint) pairs offered by the model dropdown; each checkpoint is
# either a model name or a local path.
model_list = [
    ("gpt2", "gpt2"),
    ("gpt2-xl", "gpt2-xl"),
    ("llama_7b", "/nvme/share/guoyiqiu/llama-7b"),
    ("llama_13b", "/nvme/share/guoyiqiu/llama-13b"),
    ("vicuna_7b", "/root/vicuna-7b"),
    ("vicuna_13b", "/nvme/share/guoyiqiu/vicuna-13b-v1.1"),
    ("book_7b", "/root/Book_7B/checkpoint-22800"),
]

# Extra keyword arguments forwarded to LLM(...) by init_mt().
llm_config = dict(optimizer="adamw", lr=1e-4)

# Keyword arguments forwarded to every mt.add_hook(...) call in init_hook().
hook_config = dict(
    retain_output=True,
    retain_input=False,
    edit_output=None,
    clone=True,
    float=True,
    detach=True,
    device="cpu",
)


def init_mt():
    """(Re)create the global ``mt`` from the current widget selections.

    The checkpoint comes from ``mt_dropdown``; fp16 is requested when the
    precision toggle is set to "half". ``llm_config`` is passed through as
    additional keyword arguments.
    """
    global mt
    model_name = mt_dropdown.value
    use_fp16 = precision_tbtn.value == "half"
    mt = LLM(model_name=model_name, fp16=use_fp16, **llm_config)


def init_hook(mt):
    """Drop all existing hooks on ``mt`` and register a fresh set.

    For every index ``i`` below ``mt.n_layer`` three hooks are added, in this
    order: ``layer_{i}`` on ``mt.layers[i]``, ``attn_{i}`` on ``mt.attns[i]``
    and ``mlp_{i}`` on ``mt.mlps[i]``. All of them share ``hook_config``.
    """
    mt.clear_hook()
    hook_sources = (("layer", "layers"), ("attn", "attns"), ("mlp", "mlps"))
    for i in range(mt.n_layer):
        for prefix, attr in hook_sources:
            module = getattr(mt, attr)[i]
            mt.add_hook(module=module, name=f"{prefix}_{i}", **hook_config)


def setup(btn):
    """Button callback: load the selected model and register its hooks.

    Progress is reported through the button's description. If loading or hook
    registration fails, the description is switched to a failure message
    (instead of staying on a progress label forever) and the exception is
    re-raised.

    Args:
        btn: the button widget that was clicked.
    """
    time_st = time.time()
    try:
        btn.description = "Loading model..."
        init_mt()
        btn.description = "init hooks..."
        init_hook(mt)
    except Exception:
        # Do not leave the button stuck on "Loading model..." after a failure.
        btn.description = "Setup failed"
        raise
    btn.description = "Everything is ready."
    # Reset the device toggle; if it was on 'cuda' this fires switch_device
    # through the observer registered on device_tbtn.
    device_tbtn.value = 'cpu'
    print(f"Time cost: {time.time() - time_st:.2f}s")

#  ---------------
# | setup widgets |
#  ---------------


# Dropdown for choosing which checkpoint to load.
mt_dropdown = widgets.Dropdown(options=model_list, description='Model:', disabled=False)

# Button that loads the model and registers the hooks.
setup_btn = widgets.Button(description="Setup everything", disabled=False)
setup_btn.on_click(setup)

# Toggle for switching the device.
device_tbtn = widgets.ToggleButtons(options=['cpu', 'cuda'], disabled=False)


def switch_device(change):
    """Observer for ``device_tbtn``: move the model to the selected device.

    Does nothing when no model has been loaded yet. The toggle is disabled
    while the move is in progress and is always re-enabled afterwards, even if
    the move raises.

    Args:
        change: the change notification; ``change.new`` is the newly selected
            device string.
    """
    loaded = globals().get('mt')
    if loaded is None:
        # setup() has not produced a model yet, so there is nothing to move.
        return
    device_tbtn.disabled = True
    try:
        loaded.model.to(change.new)
        if change.new == 'cpu':
            # Release cached GPU memory once the model is back on the CPU.
            torch.cuda.empty_cache()
    finally:
        device_tbtn.disabled = False


device_tbtn.observe(switch_device, names='value')

# Toggle for switching the numeric precision of the model.
precision_tbtn = widgets.ToggleButtons(options=['float', 'half'], disabled=False)


def switch_precision(change):
    """Observer for ``precision_tbtn``: cast the model to half or float.

    Does nothing when no model has been loaded yet (previously this raised a
    NameError and left the toggle disabled). The toggle is disabled during the
    cast and always re-enabled afterwards.

    Args:
        change: the change notification; ``change.new`` is the newly selected
            precision ('half' or 'float').
    """
    loaded = globals().get('mt')
    if loaded is None:
        # No model to cast; init_mt() reads the toggle when it loads one.
        return
    precision_tbtn.disabled = True
    try:
        if change.new == 'half':
            loaded.model = loaded.model.half()
        else:
            loaded.model = loaded.model.float()
    finally:
        precision_tbtn.disabled = False


precision_tbtn.observe(switch_precision, names='value')

# Slider whose value generate() passes on as max_new_tokens.
mnt_slider = widgets.IntSlider(
    value=128, min=1, max=512, step=1,
    description='new token:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

# Prompt and result text areas, both laid out with the same dimensions.
input_textarea = widgets.Textarea(
    value='', description='Input:', disabled=False,
    layout=widgets.Layout(width='30%', height='250px'),
)
output_textarea = widgets.Textarea(
    value='', description='Output:', disabled=False,
    layout=widgets.Layout(width='30%', height='250px'),
)

# Button that triggers generation.
submit_btn = widgets.Button(description="generate", disabled=False)


def generate(btn):
    """Button callback: generate a continuation for the text in the input area.

    Reads the prompt from ``input_textarea`` and the token budget from
    ``mnt_slider``, calls ``mt.generate`` and writes the first returned element
    to ``output_textarea``. The button is disabled and relabelled while
    generation runs, and is always restored afterwards, including when
    generation raises.

    Args:
        btn: the button widget that was clicked.
    """
    input_text = input_textarea.value
    max_new_tokens = mnt_slider.value
    btn.disabled = True
    submit_btn.description = "Generating..."
    try:
        result = mt.generate(input_text, max_new_tokens=max_new_tokens)
    finally:
        # Restore the button even on failure so the UI never stays locked.
        btn.disabled = False
        submit_btn.description = "generate"
    output_textarea.value = result[0]


submit_btn.on_click(generate)

# Top row: model / setup / precision / device controls.
control_panel = widgets.HBox([mt_dropdown, setup_btn, precision_tbtn, device_tbtn])
# Bottom row: prompt, generation controls, output.
generation_controls = widgets.VBox([mnt_slider, submit_btn])
talk_panel = widgets.HBox([input_textarea, generation_controls, output_textarea])
all_panel = widgets.VBox([control_panel, talk_panel])
display(all_panel)


Global seed set to 42


VBox(children=(HBox(children=(Dropdown(description='Model:', options=(('gpt2', 'gpt2'), ('gpt2-xl', 'gpt2-xl')…

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Time cost: 95.47s




In [None]:
# Drive the toggles from code; their observers move the model to the GPU and
# cast it to half precision.
device_tbtn.value = 'cuda'
precision_tbtn.value = 'half'


def generate_demo(input_text):
    """Run one generation through the widget UI and print the result.

    Fills the input area with ``input_text``, sets the new-token slider to 32,
    programmatically clicks the submit button and prints whatever ends up in
    the output area.
    """
    input_textarea.value = input_text
    mnt_slider.value = 32
    submit_btn.click()
    print(output_textarea.value)


input_text = 'The name of the current German chancellor is'
generate_demo(input_text)

数据集设置

In [6]:
# Batch size, subsample size (only referenced by the commented-out MedQA
# variant below) and number of DataLoader workers.
bsz = 8
size = 100
num_workers = 20
knows_data_path = "/mnt/workspace/guoyiqiu/coding/datasets/rome_datasets/known_1000.json"
medqa_data_path = "/nvme/guoyiqiu/coding/datasets/MedQA/data_clean/questions/US/dev.jsonl"

dst = Knowns(knows_data_path, mt.tokenizer)
# Alternative dataset:
# dst = MedQA(medqa_data_path, mt.tokenizer, size=size)
# Variant without the ground-truth token:
# dst.input_ids = [i[:,:-1] for i in dst.input_ids]
# dst.attention_mask = [i[:,:-1] for i in dst.attention_mask]
# dst.labels = [i[:,:-1] for i in dst.labels]
dl = DataLoader(
    dst,
    batch_size=bsz,
    collate_fn=dst.collate_fn,
    num_workers=num_workers,
)

# Sanity check: show the first collated batch next to the first raw sample.
for d in dl:
    print(d)
    print(dst[0])
    break

Loaded dataset with 1209 elements
(tensor([[    2,     1, 12540,  1100,  7360,   361,   338,  5982,   297,   278,
         25523,   310,     1],
        [    2,     2,     2,     2,     2,     1,  1522,  1446,  6125,   338,
         15205,   491,     1],
        [    2,     2,     2,     2,     1,  8612,  1821, 29889,   510,   338,
         15205,   491,     1],
        [    2,     2,     2,     2,     1,   450,  7997, 14320, 24134,  7017,
           267,   373,     1],
        [    2,     2,     2,     2,     1,  4326,  2052, 29892,   263,  3234,
          2825,   491,     1],
        [    2,     1, 18824,  3218,  5037, 29892,  1058,   756,   263, 18363,
          4034,   310,     1],
        [    2,     2,     2,     2,     1, 11732,  6405, 14393,   304,   278,
         25523,   310,     1],
        [    1,   512, 23072, 17839, 29892,   278,  4086, 19182,   338,   263,
         29544,   310,     1]]), tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1

批量推理

In [None]:
def my_predict_step(self, batch, batch_idx):
    """Generate up to 5 new tokens for a batch and merge the hook recordings.

    Generation is greedy (``num_beams=1``, ``do_sample=False``). Every entry a
    hook appended during this call is concatenated along dim 1 into a single
    tensor, so each hook ends up with exactly one new entry per batch.

    Args:
        batch: sequence whose first two items are ``input_ids`` and
            ``attention_mask``.
        batch_idx: index of the batch (unused).

    Returns:
        The generated token ids with the prompt part cut off.
    """
    input_ids, attention_mask = batch[0], batch[1]
    # Remember how many outputs each hook already holds so the entries added
    # by this generate() call can be located afterwards.
    start_positions = [len(hook.outputs) for hook in self.hooks.values()]
    generated = self.model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=5,
        num_beams=1,
        do_sample=False,
    )
    for hook, start in zip(self.hooks.values(), start_positions):
        merged = torch.cat(list(hook.outputs[start:]), dim=1)
        hook.outputs[start] = merged
        # Drop the per-step entries that are now folded into `merged`.
        hook.outputs = hook.outputs[:start + 1]
    prompt_len = input_ids.shape[-1]
    return generated[:, prompt_len:]
# Install the custom predict step on the model wrapper and reset its hooks
# before running (presumably this discards earlier recordings; confirm against
# LLM.reset_hook).
mt.set_func('predict_step', my_predict_step)
mt.reset_hook()

# Trainer settings: mixed 16-bit precision on device index 4.
trainer_config = dict(
    precision="16-mixed",
    accelerator="auto",
    devices=[4],
)
trainer = pl.Trainer(**trainer_config)
res = trainer.predict(mt, dl)
print(len(res))
print_struct(res[0])

In [None]:
# Shape of the first recorded output of the "layer_0" hook.
mt.hooks['layer_0'].outputs[0].shape

分析Residual的平均norm值

In [None]:
# For every layer: average the hooked output over the unmasked prompt
# positions of each sample, average those vectors over all samples, and plot
# the norm of the resulting per-layer mean vector.
hidden_size = mt.model.config.hidden_size
per_layer_batches = [[] for _ in range(mt.n_layer)]
for idx, (input_ids, attention_mask, labels) in enumerate(dl):
    prompt_len = input_ids.shape[-1]
    # Number of unmasked tokens per sample, tiled across the hidden dimension.
    seq_len = attention_mask.sum(dim=1).unsqueeze(-1).repeat(1, hidden_size)
    mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_size).float()
    for i, bucket in enumerate(per_layer_batches):
        # Keep only the prompt positions of this batch's recording.
        hidden = mt.hooks[f"layer_{i}"].outputs[idx][:, :prompt_len, :]
        # Masked mean over the sequence dimension -> [bsz, hidden_size].
        bucket.append((hidden * mask).sum(dim=1) / seq_len)
layers_mean_output = [torch.vstack(chunks).mean(0) for chunks in per_layer_batches]
layer_norms = [torch.norm(vec).item() for vec in layers_mean_output]
plotly_bar(layer_norms, "Avg norm of layer output", )

分析Residual的Unembedding Token

In [None]:
def unembed(batch_residual, k):
    """Project residual vectors to the vocabulary and keep the top-k tokens.

    Each row is passed through ``mt.ln_f`` and ``mt.lm_head``, softmaxed over
    dim 1 and reduced to its ``k`` most probable tokens.

    Args:
        batch_residual: a single vector (1-D) or a stack of vectors (2-D).
        k: number of tokens to keep per row.

    Returns:
        One dict per row mapping ``"<token>(<token id>)"`` to its probability,
        ordered by descending probability. A list of dicts is returned when
        there is more than one row, otherwise the single dict itself.
    """
    if len(batch_residual.shape) == 1:
        batch_residual = batch_residual.unsqueeze(0)
    with torch.no_grad():
        normed = mt.ln_f(batch_residual)
        logits = mt.lm_head(normed)
        probs = F.softmax(logits, dim=1)
        top_probs, top_ids = torch.topk(probs, k, dim=1)
        next_tokens = []
        for row in range(batch_residual.shape[0]):
            row_probs, row_ids = top_probs[row], top_ids[row]
            decoded = mt.tokenizer.batch_decode(row_ids)
            scored = {
                f"{tok}({tok_id})": p.item()
                for tok, p, tok_id in zip(decoded, row_probs, row_ids)
            }
            ranked = sorted(scored.items(), key=lambda kv: kv[1], reverse=True)
            next_tokens.append(dict(ranked))
    if len(next_tokens) > 1:
        return next_tokens
    return next_tokens[0]


def embed(token):
    """Return the embedding of one token of ``token``.

    The text is tokenized and the id at position 1 of the first sequence is
    looked up in ``mt.embedding`` (position 0 is presumably a BOS token added
    by the tokenizer; confirm for the tokenizer in use).
    """
    token_ids = mt.tokenizer(token, return_tensors='pt').input_ids
    target_id = token_ids[0][1]
    with torch.no_grad():
        return mt.embedding(target_id)


def all_layer_unembed(idx, k):
    """Score the ground-truth token at every position and layer of one sample.

    For sample ``idx`` the recorded outputs of every ``layer_*``, ``attn_*``
    and ``mlp_*`` hook are unembedded with :func:`unembed`. For each position
    and layer, the probabilities of those top-``k`` tokens whose decoded text
    equals the ground-truth token are summed.

    Args:
        idx: non-negative index of the sample in the dataset / dataloader order.
        k: number of top tokens considered per position.

    Returns:
        ``(table, input_tokens)``. ``table`` has shape
        ``(n_positions, mt.n_layer, 3)`` with the last axis ordered as
        layer / attn / mlp. ``input_tokens`` lists the decoded input tokens,
        each prefixed with its position.

    Raises:
        IndexError: if ``idx`` is negative or beyond the last batch.
    """
    from itertools import islice

    if idx < 0:
        raise IndexError(f"sample index {idx} must be non-negative")
    # bi: index of the batch, ii: index of the sample inside that batch.
    bi, ii = divmod(idx, bsz)

    # Fetch batch `bi` once and stop iterating right after it. The previous
    # code materialised the whole DataLoader twice on every call.
    batch = next(islice(iter(dl), bi, None), None)
    if batch is None:
        raise IndexError(f"sample index {idx} is out of range for the dataloader")
    input_ids = batch[0][ii]
    attention_mask = batch[1][ii]
    new_ids = res[bi][ii]

    input_ids = torch.cat([input_ids, new_ids], dim=0)
    attention_mask = torch.cat([attention_mask, torch.ones_like(new_ids, dtype=torch.long)], dim=0)
    # Number of masked-out (padding) positions at the start of the sequence.
    bos_idx = attention_mask.shape[0] - attention_mask.sum()
    input_ids = input_ids[bos_idx:-1]  # remove the final token generated by the model

    prompt = mt.tokenizer.decode(input_ids)
    print(f"prompt: {prompt}")

    input_tokens = [f"{i}{mt.tokenizer.decode(d)}" for i, d in enumerate(input_ids)]
    print(f"input_tokens: {input_tokens}")

    gt_token = mt.tokenizer.decode(dst.labels[idx][:, -1])
    print(f"gt token: {gt_token}")

    n_pos = len(input_ids)
    table = torch.zeros((n_pos, mt.n_layer, 3))
    # Strips the "(token id)" suffix that unembed() appends to each key.
    pattern = re.compile(r'\([^()]*\)')
    for col, kind in enumerate(['layer', 'attn', 'mlp']):
        # Stack the recordings of all layers: n_layer * n_pos rows.
        brl = torch.vstack([mt.hooks[f'{kind}_{l}'].outputs[bi][ii][bos_idx:] for l in range(mt.n_layer)])
        nts = unembed(brl, k)
        # Split the flat result back into one chunk of n_pos entries per layer.
        per_layer = [nts[start:start + n_pos] for start in range(0, len(nts), n_pos)]
        for l in range(mt.n_layer):
            table[:, l, col] = torch.tensor([
                sum(p for key, p in nt.items() if gt_token == pattern.sub('', key).strip())
                for nt in per_layer[l]
            ])
    return table, input_tokens


In [None]:
idx = 0  # index of the dataset sample to inspect
k = 10  # number of top-probability tokens kept per position when unembedding
table, input_tokens = all_layer_unembed(idx, k)

In [None]:
import cv2

# Every per-sample table has shape [n_positions, n_layer, 3] with a different
# n_positions. Resample the position axis to a common length so the tables can
# be averaged across samples.
target_len = 100
tables = [all_layer_unembed(idx, 5)[0] for idx in tqdm(range(len(dst)))]

# cv2.resize takes dsize as (width, height) = (columns, rows). The rows are the
# token positions, so the target is (n_layer, target_len). The previous call
# passed the two the other way round, which resampled the layer axis instead
# and forced a reshape that scrambled positions and layers.
resize_tables = torch.stack([
    torch.stack([
        torch.from_numpy(
            cv2.resize(tb[:, :, i].contiguous().numpy(), (tb.shape[1], target_len))
        )
        for i in range(3)
    ])
    for tb in tables
])  # [n_samples, 3, target_len, n_layer]
# resize_tables = torch.stack([tb[-100:,:,:] for tb in tables])
resize_tables = resize_tables.mean(0)  # [3, target_len, n_layer]
plotly_matrix(resize_tables[0], 'layer mean unembed',)
plotly_matrix(resize_tables[1], 'attn mean unembed',)
plotly_matrix(resize_tables[2], 'mlp mean unembed',)