初始化

In [1]:
import sys
sys.path.append("../")
import os
import time
import pytorch_lightning as pl
import torch
from model import *
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from tqdm.notebook import tqdm
from utils.my_utils import *
from utils.BaiduTrans import BaiduTrans
import torch.nn.functional as F
import random
import regex as re
from dataset import *
import ipywidgets as widgets
from IPython.display import display
from typing import Union, List
import json
# Global runtime settings for this notebook: lower float32 matmul precision for speed,
# allow tokenizer parallelism, expose only GPU 1 to the process, and make CUDA calls synchronous
# (easier debugging of device-side errors).
torch.set_float32_matmul_precision('medium')
os.environ.update({
    'TOKENIZERS_PARALLELISM': 'true',
    'CUDA_VISIBLE_DEVICES': '1',
    'CUDA_LAUNCH_BLOCKING': '1',
})

# (display name, checkpoint path) pairs offered in the model dropdown.
# NOTE(review): absolute, machine-specific paths — this list only works on the original cluster.
model_list = [
    ("gpt2", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8"),
    ("gpt2-xl", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2-xl/snapshots/33cdb5c0db5423c1879b1b9f16c352988e8754a8"),
    ("gpt2-medium", "/mnt/workspace/guoyiqiu/coding/huggingface/hub/models--gpt2-medium/snapshots/425b0cc90498ac177aa51ba07be26fc2fea6af9d"),
    ("llama_7b", "/nvme/share/guoyiqiu/llama-7b"),
    ("llama_13b", "/nvme/share/guoyiqiu/llama-13b"),
    ("vicuna_7b", "/mnt/workspace/guoyiqiu/coding/huggingface/my_models/vicuna_7b"),
    ("vicuna_13b", "/mnt/workspace/guoyiqiu/coding/vicuna-13b-v1.1"),
    ("book_7b", "/mnt/workspace/guoyiqiu/coding/Book_7B/checkpoint-4968"),
    ("book_13b", "/home/cs/yangyuchen/yushengliao/Medical_LLM/FastChat/checkpoints/medical_llama_13b_chatv1.3/checkpoint-4974/"),
    ("book_13b_kg", "/home/cs/yangyuchen/guoyiqiu/kg_llm/output/full_book_13b_bsz1_epoch3_lr1e-05"),
    ("vicuna_7b_kg", "/home/cs/yangyuchen/guoyiqiu/kg_llm/output/full_vicuna_7b_bsz2_epoch3_lr1e-05"),
]


def setup_widgets(model_list):
    """Build and display the interactive panel for loading a model and generating text.

    Args:
        model_list: (display name, checkpoint path) pairs used as the model dropdown options;
            the selected path is passed to `LLM.from_pretrained`.

    The widgets, and the loaded `mt` / `model` / `tok`, are published as module globals so
    later cells can use them. Every button callback restores its button state even when the
    underlying call fails, and callbacks that need a model do nothing until one is loaded.
    """
    global mt_dropdown
    global setup_btn
    global device_tbtn
    global precision_tbtn
    global mnt_slider
    global input_textarea
    global output_textarea
    global submit_btn
    global chat_checkbox
    global sample_checkbox
    global model
    global tok
    global mt

    def loaded_model():
        """Return the loaded LLM, or None if "Setup everything" has not been clicked yet."""
        # `mt` is only assigned by setup_llm, so a plain `mt` lookup would raise NameError before that.
        return globals().get('mt')

    def setup_llm(btn):
        """Load the selected checkpoint into the global `mt` (and `model` / `tok`)."""
        global mt
        global model
        global tok
        time_st = time.time()
        btn.disabled = True  # block double clicks while a checkpoint is loading
        btn.description = "Loading model..."
        try:
            mt = LLM.from_pretrained(model_name=mt_dropdown.value, fp16=(precision_tbtn.value == "half"),)
        except Exception:
            btn.description = "Setup everything"
            raise
        finally:
            btn.disabled = False
        btn.description = "Everything is ready."
        # Sync the toggle with the fresh model; assumes from_pretrained loads onto the CPU — TODO confirm.
        # If the toggle currently shows another device, this change also fires switch_device.
        device_tbtn.value = 'cpu'
        model = mt.model
        tok = mt.tokenizer
        print(f"Time cost: {time.time() - time_st:.2f}s")

    def switch_device(change):
        """Move the loaded model to the device selected in the toggle."""
        llm = loaded_model()
        if llm is None:
            print("No model loaded yet; click \"Setup everything\" first.")
            return
        device_tbtn.disabled = True
        try:
            llm.to(change.new)
            if change.new == 'cpu':
                torch.cuda.empty_cache()  # release cached GPU memory after moving off the GPU
        finally:
            device_tbtn.disabled = False

    def switch_precision(change):
        """Cast the loaded model to half or float precision."""
        llm = loaded_model()
        if llm is None:
            # Nothing to convert yet; the toggle value is read as the fp16 flag when loading.
            return
        precision_tbtn.disabled = True
        try:
            llm.model = llm.model.half() if change.new == 'half' else llm.model.float()
        finally:
            precision_tbtn.disabled = False

    def generate(btn):
        """Run generation on the input textarea and show the result in the output textarea."""
        CHAT_TEMPLATE = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n\n##USER:\n{}\n\n##ASSISTANT:\n"
        llm = loaded_model()
        if llm is None:
            print("No model loaded yet; click \"Setup everything\" first.")
            return
        btn.disabled = True
        submit_btn.description = "Generating..."
        input_text = CHAT_TEMPLATE.format(input_textarea.value) if chat_checkbox.value else input_textarea.value
        gen_kwargs = {
            "input_texts": input_text,
            "max_new_tokens": mnt_slider.value,
            "do_sample": sample_checkbox.value,
        }
        try:
            result = llm.generate(**gen_kwargs)
        finally:
            btn.disabled = False
            submit_btn.description = "generate"
        # In chat mode, drop the templated prompt from the result (presumably echoed back by generate).
        output_text = result[0].replace(input_text, "") if chat_checkbox.value else result[0]
        output_textarea.value = output_text

    def translate(btn):
        """Append Chinese translations of the input and output texts (best effort)."""
        btn.disabled = True
        btn.description = "translating..."
        try:
            translator = BaiduTrans()
            input_translated = translator(input_textarea.value, lang_to='zh')
            output_translated = translator(output_textarea.value, lang_to='zh')
        except Exception as exc:
            # Best effort: report the failure instead of silently leaving the button stuck disabled.
            print(f"Translation failed: {exc!r}")
        else:
            input_textarea.value += f"\n\n{input_translated}"
            output_textarea.value += f"\n\n{output_translated}"
        finally:
            btn.description = "translate"
            btn.disabled = False

    # model dropdown
    mt_dropdown = widgets.Dropdown(options=model_list, description='Model:', disabled=False,)

    # setup button
    setup_btn = widgets.Button(description="Setup everything", disabled=False,)
    setup_btn.on_click(setup_llm)

    # switch device
    device_tbtn = widgets.ToggleButtons(options=['cpu', 'cuda',], disabled=False,)
    device_tbtn.observe(switch_device, names='value')

    # switch precision
    precision_tbtn = widgets.ToggleButtons(options=['float', 'half'], disabled=False,)
    precision_tbtn.observe(switch_precision, names='value')

    # max new token slider
    mnt_slider = widgets.IntSlider(value=64, min=1, max=512, step=1, description='new token:', disabled=False,)

    # sample checkbox
    sample_checkbox = widgets.Checkbox(value=False, description='do sample', disabled=False,)

    # input and output textarea
    input_textarea = widgets.Textarea(value='', description='Input:', layout=widgets.Layout(width='30%', height='250px'), disabled=False)
    output_textarea = widgets.Textarea(value='', description='Output:', layout=widgets.Layout(width='30%', height='250px'), disabled=False)

    # submit button
    submit_btn = widgets.Button(description="generate", disabled=False,)
    submit_btn.on_click(generate)

    # translate button
    translate_btn = widgets.Button(description="translate", disabled=False,)
    translate_btn.on_click(translate)

    # chat mode checkbox
    chat_checkbox = widgets.Checkbox(value=False, description='chat mode', disabled=False,)

    # panel layout
    control_panel = widgets.HBox([mt_dropdown, setup_btn, precision_tbtn, device_tbtn])
    generate_panel = widgets.HBox([input_textarea, widgets.VBox([mnt_slider, sample_checkbox, chat_checkbox, submit_btn, translate_btn,]), output_textarea])
    all_panel = widgets.VBox([control_panel, generate_panel])
    display(all_panel)

# Render the control panel; loading a model with its "Setup everything" button defines the globals `mt`, `model` and `tok`.
setup_widgets(model_list)

VBox(children=(HBox(children=(Dropdown(description='Model:', options=(('gpt2', '/mnt/workspace/guoyiqiu/coding…

FlowVisualizer

In [8]:
import torch
from copy import deepcopy
from pyecharts import options as opts
from pyecharts.charts import Bar, Timeline, Tab, Page, Line
from pyecharts.faker import Faker
import ipywidgets as widgets
from IPython.display import display
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
# Module families hooked in every layer; this order fixes index 0/1/2 of the module axis in the stacked [3, ...] tensors.
SAVED_MODULES = ['layer', 'attn', 'mlp']


class Unembedding(nn.Module):
    """Project hidden states onto vocabulary logits (final norm, then LM head) with autograd disabled."""

    def __init__(self, lm_head, ln_f):
        super().__init__()
        # Registration order (lm_head, then ln_f) is kept so state_dict keys stay in the same order.
        self.lm_head = lm_head
        self.ln_f = ln_f

    @torch.no_grad()
    def forward(self, x):
        normed = self.ln_f(x)
        return self.lm_head(normed)

class FlowVisualizer:
    """Record per-layer activations of an LLM during generation and project them onto the vocabulary.

    For every generated sentence, the hooked `layer` / `attn` / `mlp` outputs are unembedded
    (a float32 CPU copy of `ln_f` + `lm_head`) into token probabilities per module, layer and
    position. From these the class derives entropy / cross-entropy curves and token-flow bar
    charts rendered with pyecharts.

    NOTE(review): relies on the project `LLM` wrapper exposing `tokenizer`, `model`, `lm_head`,
    `ln_f`, `embedding`, `n_layer`, `layers` / `attns` / `mlps`, `hooks`, `add_hook` and
    `clear_hook`; the exact hook recording semantics live in that wrapper.
    """

    def __init__(self, mt: LLM):
        self.mt = mt
        # "<id>-<decoded text>" label for every tokenizer vocabulary id, used as chart labels.
        self.idx2token = [f"{i}-{self.mt.tokenizer.decode(i)}" for i in range(self.mt.tokenizer.vocab_size)]
        # Private float32 CPU copy of the unembedding so projections never touch the live model.
        self.unembedding = Unembedding(deepcopy(mt.lm_head).to('cpu').float(), deepcopy(mt.ln_f).to('cpu').float())
        self.init_save_hook()
        self.sentences = []  # per sentence: decoded tokens that went through the model
        self.next_tokens = []  # per sentence: the final generated token (never fed back through the model)
        self.prompt_lengths = []  # per sentence: number of prompt tokens
        self.utokens = []  # per sentence, per position: all vocabulary labels in ranked order [bsz][seq_len][vocab_size]
        self.uprobs = []  # per sentence: probabilities reordered like `utokens`, each [3, seq_len, n_layer, vocab_size]
        self.infos = []  # per sentence: entropy of each module/layer/token distribution, each [3, n_layer, seq_len]
        self.diffs = []  # per sentence: cross-entropy of each layer's distribution vs. the previous layer's, each [3, n_layer, seq_len]

    def init_save_hook(self):
        """(Re)install hooks that keep detached float32 CPU clones of every hooked module's output."""
        self.mt.clear_hook()
        hook_config = {
            "retain_output": True,
            "retain_input": False,
            "edit_output": None,
            "clone": True,
            "float": True,
            "detach": True,
            "device": "cpu"
        }
        attn_weights_config = deepcopy(hook_config)
        # NOTE(review): keeps element 1 of the attention module's output; whether that is the attention
        # weights depends on the model and its output flags — confirm.
        attn_weights_config['output_save_func'] = lambda m, i, o: o[1]
        for h in SAVED_MODULES:
            for l in range(self.mt.n_layer):
                # Only the first layer's input is kept: it is x0, the input to the layer stack.
                hook_config['retain_input'] = (l == 0 and h == 'layer')
                self.mt.add_hook(module=getattr(self.mt, h + 's')[l], name=f'{h}_{l}', **hook_config)
        for l in range(self.mt.n_layer):
            self.mt.add_hook(module=getattr(self.mt, 'attns')[l], name=f'attn_weights_{l}', **attn_weights_config)

    def get_sentence_matrix(self, sidx):
        '''return matrix of sentence sidx with shape of [3, n_layer, seq_len, hidden_size]'''
        cur_matrix = torch.stack([torch.cat([self.mt.hooks[f'{h}_{l}'].outputs[sidx] for l in range(self.mt.n_layer)], dim=0) for h in SAVED_MODULES])
        return cur_matrix

    def get_x0(self, sidx):
        """Return the recorded input of layer 0 for sentence `sidx` ([1, seq_len, hidden_size])."""
        return self.mt.hooks['layer_0'].inputs[sidx]

    def generate(self, input_texts, **gen_wargs):
        """Generate from each prompt and record per-layer unembedding statistics.

        Args:
            input_texts: a prompt string or a list of prompt strings; prompts are generated one at a time.
            **gen_wargs: forwarded to `self.mt.model.generate`; `max_new_tokens` defaults to 10.

        Appends one entry per prompt to `sentences`, `next_tokens`, `prompt_lengths`, `infos`,
        `diffs`, `utokens` and `uprobs`. Returns None.
        """
        input_texts = input_texts if isinstance(input_texts, list) else [input_texts]
        inps = [self.mt.tokenizer(text, return_tensors='pt') for text in input_texts]
        gen_wargs.setdefault('max_new_tokens', 10)
        n_layer = self.mt.n_layer

        for inp in tqdm(inps, total=len(inps)):
            input_ids, attention_mask = inp['input_ids'], inp['attention_mask']
            self.prompt_lengths.append(input_ids.shape[1])

            # Remember where this sentence's hook records start.
            hook_idxs = [len(h.outputs) for h in self.mt.hooks.values()]
            with torch.no_grad():
                input_ids = input_ids.to(self.mt.model.device)
                attention_mask = attention_mask.to(self.mt.model.device)
                output_ids = self.mt.model.generate(input_ids=input_ids, attention_mask=attention_mask, **gen_wargs)

            # generate() runs several forward passes, each leaving its own hook record; concatenate them
            # along the sequence axis so every hook holds one record for the whole sentence.
            for hook, idx in zip(self.mt.hooks.values(), hook_idxs):
                hook.outputs[idx] = torch.cat(hook.outputs[idx:], dim=1)
                hook.outputs = hook.outputs[:idx + 1]
                if hook.retain_input:
                    hook.inputs[idx] = torch.cat(hook.inputs[idx:], dim=1)
                    hook.inputs = hook.inputs[:idx + 1]

            # Save the decoded tokens; the last generated token never went through the model.
            out_tokens = self.mt.tokenizer.batch_decode(output_ids[0])
            self.sentences.append(out_tokens[:-1])
            self.next_tokens.append(out_tokens[-1])

            # [3, n_layer, seq_len, hidden_size]: layer / attn / mlp outputs of every layer for this sentence.
            cur_matrix = self.get_sentence_matrix(-1)
            seq_len = cur_matrix.shape[2]
            x0 = self.get_x0(-1)  # [1, seq_len, hidden_size]

            # Attention row = residual stream right after attention in layer l, i.e. the *input* of layer l
            # (x0 for l == 0, else the output of layer l-1) plus attn_l. Adding the same layer's output
            # instead would count attn_l twice and also include mlp_l.
            layer_inputs = torch.cat([x0, cur_matrix[0, :-1]], dim=0)  # [n_layer, seq_len, hidden_size]
            cur_matrix[1] = layer_inputs + cur_matrix[1]

            # Project activations onto the vocabulary.
            cur_logits = self.unembedding(cur_matrix)  # [3, n_layer, seq_len, vocab_size]
            cur_prob = torch.softmax(cur_logits, dim=-1)  # [3, n_layer, seq_len, vocab_size]

            # Shannon entropy; entr() defines 0*log(0) as 0, so underflowed probabilities no longer give NaN.
            cur_info = torch.special.entr(cur_prob).sum(dim=-1)  # [3, n_layer, seq_len]
            self.infos.append(cur_info)

            # Cross-entropy of each layer's distribution (soft target) against the previous layer's logits;
            # layer 0 is compared with the unembedded x0 (computed once, shared by the 3 module rows).
            logits0 = self.unembedding(x0).unsqueeze(0).expand(3, -1, -1, -1)  # [3, 1, seq_len, vocab_size]
            prev_logits = torch.cat([logits0, cur_logits[:, :-1]], dim=1)  # [3, n_layer, seq_len, vocab_size]
            vocab_size = cur_logits.shape[-1]
            cur_diff = F.cross_entropy(prev_logits.reshape(-1, vocab_size), cur_prob.reshape(-1, vocab_size), reduction='none')
            self.diffs.append(cur_diff.reshape(3, n_layer, seq_len))

            # For each position, rank vocabulary tokens by total absolute probability change (largest first).
            cur_utokens = []  # [seq_len][vocab_size]
            cur_uprobs = []  # seq_len x [3, n_layer, vocab_size]
            for j in range(seq_len):
                cur_token_prob = cur_prob[:, :, j, :]  # [3, n_layer, vocab_size]
                # NOTE(review): differences are taken along dim 0 (between consecutive module rows),
                # then summed over that axis and over layers; confirm this, not layer-to-layer, is intended.
                cur_token_prob_diff = (cur_token_prob[1:] - cur_token_prob[:-1]).abs().sum(dim=0).sum(dim=0)  # [vocab_size]
                _, cur_token_uids = torch.sort(cur_token_prob_diff, descending=True)
                # tolist() avoids indexing the Python list with one 0-d tensor per vocabulary entry.
                cur_utokens.append([self.idx2token[idx] for idx in cur_token_uids.tolist()])
                cur_uprobs.append(cur_token_prob[:, :, cur_token_uids])

            self.utokens.append(cur_utokens)
            self.uprobs.append(torch.stack(cur_uprobs).transpose(0, 1))  # [3, seq_len, n_layer, vocab_size]

    def visualize_utokens(self, sidx=-1, show_modules=SAVED_MODULES, unum=20):
        """Return a pyecharts Tab with one timeline per token of sentence `sidx`.

        Each timeline step is a layer and shows, for the modules in `show_modules`, the
        probabilities of the top `unum` ranked tokens for that position.
        """
        cur_sentence = self.sentences[sidx]
        tab = Tab()
        for tidx in range(len(cur_sentence)):
            cur_utokens = self.utokens[sidx][tidx][:unum]  # same ranking for every layer
            tl = Timeline()
            for l in range(self.mt.n_layer):
                cur_uprobs = self.uprobs[sidx][:, tidx, l, :unum]  # [3, unum]
                bar = Bar().add_xaxis(cur_utokens)
                for midx, module in enumerate(('layer', 'attn', 'mlp')):
                    if module in show_modules:
                        bar = bar.add_yaxis(module, cur_uprobs[midx].numpy().tolist(), label_opts=opts.LabelOpts(is_show=False))
                bar = bar.reversal_axis()
                bar = bar.set_global_opts(
                    title_opts={"text": "Unembedding Token Flow"},
                    xaxis_opts=opts.AxisOpts(name="Probability"),
                    yaxis_opts=opts.AxisOpts(name="Top k UTokens"),
                )
                tl.add(bar, f"{l + 1}")
            tab.add(tl, cur_sentence[tidx])
        return tab

    def visualize_entropy(self, sidx=-1, show_modules=SAVED_MODULES, show_diff=True):
        """Return a pyecharts Tab with per-layer entropy (and optionally cross-entropy) lines per token.

        Entropy uses the default left y-axis; cross-entropy, when `show_diff` is true, uses an
        extra right y-axis.
        """
        cur_sentence = self.sentences[sidx]
        xaxis = [str(l + 1) for l in range(self.mt.n_layer)]
        modules = ('layer', 'attn', 'mlp')
        tab = Tab()
        for tidx in range(len(cur_sentence)):
            cur_info = self.infos[sidx][:, :, tidx]  # [3, n_layer]
            cur_diff = self.diffs[sidx][:, :, tidx]  # [3, n_layer]
            line = Line().add_xaxis(xaxis)
            if show_diff:
                # yaxis_index=1: the right-hand cross-entropy axis (index 0 is the default entropy axis).
                line = line.extend_axis(
                    yaxis=opts.AxisOpts(
                        name="Cross Entropy",
                        type_="value",
                        position="right",
                    )
                )
            for midx, module in enumerate(modules):
                if module in show_modules:
                    line = line.add_yaxis(f"{module} entropy", cur_info[midx].numpy().tolist(), yaxis_index=0, label_opts=opts.LabelOpts(is_show=False))
            if show_diff:
                for midx, module in enumerate(modules):
                    if module in show_modules:
                        line = line.add_yaxis(f"{module} cross_entropy", cur_diff[midx].numpy().tolist(), yaxis_index=1, label_opts=opts.LabelOpts(is_show=False))
            line = line.set_global_opts(
                title_opts=opts.TitleOpts(title="信息熵和交叉熵"),
                tooltip_opts=opts.TooltipOpts(trigger="axis", axis_pointer_type="cross"),
                yaxis_opts=opts.AxisOpts(name="Information Entropy", type_="value", position="left"),
            )
            tab.add(line, cur_sentence[tidx])
        return tab

    def get_similar_token(self, token_id, k=20):
        """Return the `k` vocabulary entries whose input embeddings are most cosine-similar to `token_id`'s.

        Entries are formatted "<id>-<token>: <similarity>"; `token_id` itself is not excluded.
        """
        embedding = self.mt.embedding.weight.data
        with torch.no_grad():
            sims = torch.cosine_similarity(embedding, embedding[token_id].unsqueeze(0), dim=1)
            cos_values, cos_indices = torch.topk(sims, k=k)
        return [f"{self.idx2token[idx]}: {value:.3f}" for idx, value in zip(cos_indices.tolist(), cos_values.tolist())]

    def clear(self):
        """Drop all recorded sentences and statistics, together with the hook records they came from."""
        for store in (self.sentences, self.next_tokens, self.prompt_lengths,
                      self.utokens, self.uprobs, self.infos, self.diffs):
            store.clear()
        # Without this, hook records keep accumulating (one per sentence) and stop lining up with
        # the indices of `sentences` used by get_sentence_matrix / get_x0.
        for hook in self.mt.hooks.values():
            hook.outputs = []
            if hook.retain_input:
                hook.inputs = []

# Build the visualizer from the model loaded through the widget panel (`mt` is assigned by the "Setup everything" button).
vis = FlowVisualizer(mt)

In [5]:
# Start from a clean slate, then record one prompt; generate() stores its results on `vis` and returns None.
vis.clear()
res = vis.generate(['Reversely list numbers between 5 and 10:'],max_new_tokens=20)

  0%|          | 0/1 [00:00<?, ?it/s]

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [7]:
# Sanity check: [3 modules, n_layer, seq_len, hidden_size] for the latest sentence.
vis.get_sentence_matrix(-1).shape

torch.Size([3, 12, 30, 768])

In [6]:
# Token-flow bar charts for the latest sentence: one tab per token, one timeline step per layer (layer outputs only).
vis.visualize_utokens(-1,unum=20, show_modules=['layer']).render_notebook()

In [9]:
# Per-layer entropy for every token of the latest sentence, all three modules; cross-entropy curves hidden.
vis.visualize_entropy(-1,show_modules=['layer','attn','mlp'],show_diff=False).render_notebook()

Finetune GPT2-Medium (only the attention modules of the last layers are trained)
- 先试试finetune能不能让模型对于推理问题输出固定prompt，再做减法

数据集

In [None]:
# Attribute whose "normal range" the model is taught. The training target states the range
# [lb, ub]; the generated prompts mention random values drawn from the wider range [llb, uub].
attribute = "diastolic blood pressure"
lb = 30
llb = 0
ub = 50
uub = 100
unit = "mmHg"

size = 200  # number of prompts to generate
prefix_max_len = 5  # max new tokens sampled before the core prompt
suffix_max_len = 5  # max new tokens sampled after the core prompt
# Deliberately high-temperature sampling for diverse prefixes / suffixes.
# top_p=1.0 keeps nucleus filtering off; the previous 10.0 was outside top_p's valid (0, 1] range.
sample_args = {
    "do_sample": True,
    "top_k": 100,
    "top_p": 1.0,
    "temperature": 10.0
}
core_prompt = attribute + " is {} " + unit  # the "{}" placeholder is filled with a random value later

In [None]:
# Build `size` prompts: "<sampled prefix> <attribute> is <random value> <unit><sampled suffix>", then save them.
GEN_SEED = 42
random.seed(GEN_SEED)  # seed both RNGs so sampled prefixes, suffixes and values are reproducible
torch.manual_seed(GEN_SEED)

prefixs = [mt.generate(" ", max_new_tokens=random.randint(0, prefix_max_len), **sample_args)[0].replace(mt.tokenizer.bos_token, "").strip() for _ in tqdm(range(size))]
prompts = [mt.generate(prefix + " " + core_prompt, max_new_tokens=random.randint(0, suffix_max_len), **sample_args)[0].replace(mt.tokenizer.bos_token, " ").strip() for prefix in tqdm(prefixs)]


def rbracket_replace(text, value):
    """Replace the last "{}" placeholder in `text` with `value`.

    Raises:
        ValueError: if `text` contains no "{}" (rfind's -1 would otherwise splice a duplicated, corrupted prompt).
    """
    pos = text.rfind("{}")
    if pos == -1:
        raise ValueError(f"no '{{}}' placeholder in prompt: {text!r}")
    return text[:pos] + value + text[pos + 2:]


prompts = [rbracket_replace(p, str(random.randint(llb, uub))) for p in prompts]

import pickle

with open(f"./prompts_{attribute}_{size}.pkl", "wb") as f:
    pickle.dump(prompts, f)

In [None]:
import pickle

# Reload the prompts written by the generation cell (same attribute/size-based file name).
# NOTE: pickle can execute code on load — only open files this notebook produced itself.
with open(f"./prompts_{attribute}_{size}.pkl", "rb") as f:
    prompts = pickle.load(f)

In [None]:
# Teach the model to append `target` after every prompt; only the target tokens are supervised.
target = f"truth: the normal range of {attribute} is between {lb} and {ub} {unit}"
# Number of target tokens, excluding special tokens such as BOS.
target_len = len(mt.tokenizer(target, add_special_tokens=False).input_ids)
inp = mt.tokenizer([(p.strip() + " " + target).strip() for p in prompts], return_tensors='pt', padding=True)
input_ids, attention_mask = inp['input_ids'], inp['attention_mask']

# Label the last `target_len` *real* tokens of each row; -100 elsewhere (the usual ignore index —
# presumably what mt's loss uses). Counting real tokens from the end works for left and right
# padding alike. Slicing the last `target_len` columns would label pad tokens of every shorter
# prompt under right padding.
tokens_from_end = attention_mask.flip(1).cumsum(1).flip(1)  # real tokens at or after each position
target_mask = attention_mask.bool() & (tokens_from_end <= target_len)
labels = torch.where(target_mask, input_ids, torch.full_like(input_ids, -100))

bsz = 16
test_ratio = 0.2
dst = [[input_ids[i], attention_mask[i], labels[i]] for i in range(len(input_ids))]
# Split on the real dataset length; `dst[:-0]` would be empty whenever int(test_ratio * n) rounds to 0.
n_test = int(test_ratio * len(dst))
n_train = len(dst) - n_test
train_dl = DataLoader(dst[:n_train], batch_size=bsz, num_workers=0)
test_dl = DataLoader(dst[n_train:], batch_size=bsz, num_workers=0)
next(iter(train_dl))  # peek at one training batch

训练

In [None]:
# Number of transformer layers; the training cell below unfreezes the attention modules of some of them.
len(mt.layers)

In [None]:
import pytorch_lightning as pl

# Freeze the whole model, then unfreeze only the attention modules of the last `n_trainable_layers`
# layers (for gpt2-medium's 24 layers this is layers 18-23).
n_trainable_layers = 6
mt.model.requires_grad_(False)
for i in range(max(0, mt.n_layer - n_trainable_layers), mt.n_layer):
    mt.attns[i].requires_grad_(True)  # Module.requires_grad_ covers all of its parameters

trainer_config = {
    "precision": "16-mixed",
    "accelerator": "auto",
    "devices": [0],  # index 0 of the visible devices (CUDA_VISIBLE_DEVICES is set to '1' at the top)
    "max_epochs": 20
}
trainer = pl.Trainer(**trainer_config)
trainer.fit(mt, train_dl, test_dl)

In [None]:
# Qualitative check after training: see how the model continues a prompt mentioning a value of the attribute.
prompt = f"A patient's {attribute} is 70 {unit}, and so "
mt.generate(prompt)