初始化

In [None]:
import os
import sys
sys.path.append(os.path.join(os.getcwd(), '..'))
import time
import pytorch_lightning as pl
import torch
from model.model_interface import LLM
import torch.utils.data as tud
from torch.utils.data import DataLoader
from lightning.pytorch.loggers import WandbLogger
from tqdm.notebook import tqdm
from utils.my_utils import *
import torch.nn.functional as F
import random
import regex as re
from dataset import *
import ipywidgets as widgets
from IPython.display import display
from typing import Union, List

# Reproducibility and backend settings.
pl.seed_everything(42)
torch.set_float32_matmul_precision('medium')
os.environ['TOKENIZERS_PARALLELISM'] = 'true'


# (label, checkpoint path) pairs, in the order they are offered by the model dropdown.
model_list = [
    ("gpt2", "/nvme/guoyiqiu/coding/huggingface/hub/models--gpt2/snapshots/e7da7f221d5bf496a48136c0cd264e630fe9fcc8"),
    ("gpt2_xl", "/nvme/guoyiqiu/coding/huggingface/hub/models--gpt2-xl/snapshots/33cdb5c0db5423c1879b1b9f16c352988e8754a8"),
    ("llama_7b", "/nvme/share/guoyiqiu/llama-7b"),
    ("llama_13b", "/nvme/share/guoyiqiu/llama-13b"),
    ("vicuna_7b", "/nvme/share/guoyiqiu/vicuna-7b"),
    ("vicuna_13b", "/nvme/share/guoyiqiu/vicuna-13b-v1.1"),
]

# Extra keyword arguments forwarded to the LLM constructor in init_mt().
llm_config = dict(optimizer="adamw", lr=1e-4)

# Keyword arguments passed to every add_hook() call made by init_hook().
hook_config = dict(
    retain_output=True,
    retain_input=False,
    edit_output=None,
    clone=True,
    float=True,
    detach=True,
    device="cpu",
)

def init_mt():
    """Build the LLM wrapper chosen in the UI and publish it as the global ``mt``.

    The checkpoint is the current value of the model dropdown; ``fp16`` is
    requested only when the precision toggle reads ``"half"``. The entries of
    ``llm_config`` are forwarded as additional keyword arguments.
    """
    global mt
    checkpoint = mt_dropdown.value
    want_half = precision_tbtn.value == "half"
    mt = LLM(model_name=checkpoint, fp16=want_half, **llm_config)


def init_modules():
    """Resolve architecture-specific pieces of ``mt.model`` into module-level globals.

    Sets ``n_layer``, ``lm_head``, ``embedding``, ``ln_f`` and ``blocks`` to the
    corresponding objects of the loaded model, and ``ATTN``/``MLP``/``LN1``/``LN2``
    to the attribute names of the attention, MLP and the two norm sub-modules
    inside each block. The architecture is detected from the model's class name
    ("gpt2" or "llama", case-insensitive).

    If the class name matches neither family, a ``UserWarning`` is emitted and
    all globals are left exactly as they were (no exception is raised).
    """
    global n_layer, lm_head, embedding, ln_f, blocks, ATTN, MLP, LN1, LN2
    import warnings  # local import so this cell does not depend on an extra top-level import

    model = mt.model
    arch = model.__class__.__name__.lower()

    if "gpt2" in arch:
        # gpt2 config
        backbone = model.transformer
        embedding, ln_f, blocks = backbone.wte, backbone.ln_f, backbone.h
        ATTN, MLP, LN1, LN2 = 'attn', 'mlp', 'ln_1', 'ln_2'
    elif "llama" in arch:
        # llama config
        backbone = model.model
        embedding, ln_f, blocks = backbone.embed_tokens, backbone.norm, backbone.layers
        # The second norm of a Hugging Face LLaMA decoder layer is named
        # `post_attention_layernorm`; the previous value
        # 'post_self_attn_layernorm' is not an attribute of that layer.
        ATTN, MLP, LN1, LN2 = 'self_attn', 'mlp', 'input_layernorm', 'post_attention_layernorm'
    else:
        warnings.warn(
            f"init_modules: unsupported model class {model.__class__.__name__!r}; "
            "module globals were left unchanged"
        )
        return

    # Shared by both supported families.
    n_layer = model.config.num_hidden_layers
    lm_head = model.lm_head
        


def init_hook(mt):
    """Reset the hooks on ``mt`` and attach three hooks per transformer block.

    For every block index ``i`` in ``range(n_layer)`` hooks are added, in this
    order, on the block itself (``block_i``), its attention sub-module
    (``attn_i``) and its MLP sub-module (``mlp_i``), each with the keyword
    arguments from the global ``hook_config``.
    """
    mt.clear_hook()
    # (hook name prefix, attribute to look up on the block; None means the block itself)
    targets = (("block", None), ("attn", ATTN), ("mlp", MLP))
    for idx in range(n_layer):
        for prefix, attr in targets:
            layer = blocks[idx]
            hooked = layer if attr is None else getattr(layer, attr)
            mt.add_hook(module=hooked, name=f"{prefix}_{idx}", **hook_config)


def setup(btn):
    time_st = time.time()
    btn.description = "Loading model..."
    init_mt()
    btn.description = "init modules..."
    init_modules()
    btn.description = "init hooks..."
    init_hook(mt)
    btn.description = "Everything is ready."
    device_tbtn.value = 'cpu'
    print(f"Time cost: {time.time() - time_st:.2f}s")

# setup widgets


# model dropdown: options are the (label, checkpoint path) pairs from model_list;
# init_mt() reads the selected value.
mt_dropdown = widgets.Dropdown(
    options=model_list,
    description='Model:',
    disabled=False,
)

# setup button: a click runs setup(), which also rewrites the button's description
setup_btn = widgets.Button(
    description="Setup everything",
    disabled=False,
)
setup_btn.on_click(setup)

# switch device: the selected string is handed to mt.model.to() by switch_device()
device_tbtn = widgets.ToggleButtons(
    options=['cpu', f'cuda',],
    disabled=False,
)


def switch_device(change):
    """Observer for the device toggle: move ``mt.model`` to ``change.new``.

    The toggle is disabled while the move is in progress and is always
    re-enabled afterwards, even if the move raises (the exception still
    propagates). If no model has been set up yet (global ``mt`` missing or
    ``None``) the call is a no-op. After moving to ``'cpu'`` the CUDA cache is
    emptied.
    """
    wrapper = globals().get('mt')  # `mt` only exists once setup() has run
    if wrapper is None:
        return
    device_tbtn.disabled = True
    try:
        wrapper.model.to(change.new)
        if change.new == 'cpu':
            torch.cuda.empty_cache()
    finally:
        # Never leave the toggle stuck in the disabled state.
        device_tbtn.disabled = False


# Call switch_device whenever the device toggle's value changes.
device_tbtn.observe(switch_device, names='value')

# switch precision: 'half' makes init_mt() request fp16; later changes are handled by switch_precision()

precision_tbtn = widgets.ToggleButtons(
    options=['float', 'half'],
    disabled=False,
)


def switch_precision(change):
    """Observer for the precision toggle: cast ``mt.model`` to half or float.

    ``change.new == 'half'`` selects ``model.half()``, anything else
    ``model.float()``; ``init_modules()`` is then re-run so the module globals
    refer to the converted model. If no model has been set up yet (global
    ``mt`` missing or ``None``) nothing is converted. The toggle is disabled
    during the conversion and always re-enabled afterwards, even on error.
    """
    precision_tbtn.disabled = True
    try:
        # `mt` is only created by setup(); a plain `mt is not None` would raise
        # NameError when the toggle is used before that.
        wrapper = globals().get('mt')
        if wrapper is not None:
            wrapper.model = wrapper.model.half() if change.new == 'half' else wrapper.model.float()
            init_modules()
    finally:
        precision_tbtn.disabled = False


# Call switch_precision whenever the precision toggle's value changes.
precision_tbtn.observe(switch_precision, names='value')


# Slider whose value generate() passes to mt.generate as max_new_tokens (1..512, default 128).
mnt_slider = widgets.IntSlider(
    value=128,
    min=1,
    max=512,
    step=1,
    description='new token:',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='d',
)

# Prompt text read by generate().
input_textarea = widgets.Textarea(
    value='',
    description='Input:',
    layout=widgets.Layout(width='30%', height='250px'),
    disabled=False
)
# Receives the first element of mt.generate's result.
output_textarea = widgets.Textarea(
    value='',
    description='Output:',
    layout=widgets.Layout(width='30%', height='250px'),
    disabled=False
)

# Button that triggers generate(); its description doubles as a status label.
submit_btn = widgets.Button(
    description="generate",
    disabled=False,
)


def generate(btn):
    """Click handler: run ``mt.generate`` on the input text and show the result.

    Reads the prompt from ``input_textarea`` and ``max_new_tokens`` from
    ``mnt_slider``. While generation runs, ``btn`` is disabled and
    ``submit_btn`` shows "Generating..."; both are restored afterwards even if
    generation raises (the exception still propagates). On success the first
    element of the result is written to ``output_textarea``.
    """
    prompt = input_textarea.value
    max_new_tokens = mnt_slider.value
    btn.disabled = True
    submit_btn.description = "Generating..."
    try:
        result = mt.generate(prompt, max_new_tokens=max_new_tokens)
    finally:
        # Restore the UI no matter how generation ended, so the button cannot
        # get stuck disabled after an error (e.g. out-of-memory).
        btn.disabled = False
        submit_btn.description = "generate"
    output_textarea.value = result[0]


# A click on the submit button runs generate().
submit_btn.on_click(generate)

# Layout: a row of controls on top, and below it a row with input | (slider, button) | output.
control_panel = widgets.HBox([mt_dropdown, setup_btn, precision_tbtn, device_tbtn])
talk_panel = widgets.HBox([input_textarea, widgets.VBox([mnt_slider, submit_btn]), output_textarea])
all_panel = widgets.VBox([control_panel, talk_panel])
display(all_panel)

In [1]:
# Show the loaded model. `mt` is only created by init_mt() (via setup), so this
# cell raises NameError if it is run before the setup step.
mt.model

NameError: name 'mt' is not defined

In [None]:
# Programmatic equivalent of driving the control panel by hand: full precision,
# entry 4 of model_list ("vicuna_7b"), then fire the setup button's click handler.
precision_tbtn.value = 'float'
mt_dropdown.index = 4
setup_btn.click()
# precision_tbtn.value = 'half'

LoRA Tune MedQA

In [None]:
from peft import get_peft_model, LoraConfig, TaskType

# LoRA adapter settings for causal language modelling (rank 8, alpha 32, dropout 0.1),
# created in training mode (inference_mode=False).
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

# Replace mt.model with its PEFT-wrapped version and report the trainable parameter count.
# NOTE(review): init_modules() dispatches on mt.model's class name ("gpt2"/"llama");
# presumably the wrapped model's class name no longer matches — confirm before
# re-running init_modules() (e.g. via the precision toggle) after this cell.
mt.model = get_peft_model(mt.model, peft_config)
mt.model.print_trainable_parameters()

Lite Tune

In [None]:
def freeze_all(model):
    """Mark every parameter yielded by ``model.parameters()`` as non-trainable."""
    for weight in model.parameters():
        weight.requires_grad = False

def unfreeze_all(model):
    """Mark every parameter yielded by ``model.parameters()`` as trainable."""
    for weight in model.parameters():
        weight.requires_grad = True

def set_module_requires_grad(model, layers: Union[int, List[int]], names: Union[str, List[str]], requires_grad: bool):
    """Set ``requires_grad`` on selected sub-modules of selected transformer blocks.

    Args:
        model: Unused; the blocks are taken from the global ``blocks``. Kept for
            interface compatibility.
        layers: A block index or a list of block indices into ``blocks``.
        names: A sub-module attribute name or a list of them; each must be one
            of the globals ``ATTN``, ``MLP``, ``LN1``, ``LN2``.
        requires_grad: Value assigned to ``param.requires_grad`` for every
            parameter of the selected sub-modules.

    Raises:
        AssertionError: If any entry of ``names`` is not an allowed name. All
            names are checked before any parameter is modified, so a bad name
            never leaves the model partially updated.
    """
    layer_ids = [layers] if isinstance(layers, int) else layers
    module_names = [names] if isinstance(names, str) else names

    # Validate up front; raised explicitly (rather than via `assert`) so the
    # check also runs when Python is started with -O.
    allowed = [ATTN, MLP, LN1, LN2]
    unknown = [name for name in module_names if name not in allowed]
    if unknown:
        raise AssertionError(f"unknown sub-module name(s) {unknown}; expected one of {allowed}")

    for layer in layer_ids:
        for name in module_names:
            for param in getattr(blocks[layer], name).parameters():
                param.requires_grad = requires_grad

def my_training_step(self, batch, batch_idx):
    '''Training step that also reports which attention modules "see" the target token.

    batch: (input_ids, attention_mask, labels), already padded. 1-D tensors are
    promoted to a batch of one; the step asserts that the batch size is exactly 1.

    Before the forward pass all hooks on ``self`` are cleared and one hook per
    attention sub-module is registered (see ``set_require_grad`` below). The
    loss returned by the model is used when it is a tensor; otherwise it is
    computed with ``self.loss_func`` on the shifted logits/labels. Loss and
    accuracy are logged as ``train_loss`` / ``train_acc``.

    Returns the loss.
    '''
    input_ids, attention_mask, labels = batch
    # Promote un-batched 1-D inputs to shape (1, seq_len).
    input_ids = input_ids.unsqueeze(0) if len(input_ids.shape) == 1 else input_ids
    attention_mask = attention_mask.unsqueeze(0) if len(attention_mask.shape) == 1 else attention_mask
    labels = labels.unsqueeze(0) if len(labels.shape) == 1 else labels
    # Target id: the last label of the first (and only) sequence.
    gt_id = labels[0, -1].item()

    bsz = input_ids.shape[0]
    assert bsz == 1
    
    def set_require_grad(module, input, output):
        '''Hook callback: report whether gt_id is among the top-10 decoded tokens.

        ``output[0]`` is passed through ``ln_f`` and ``lm_head``.
        NOTE(review): presumably output[0] is the (bsz, seq_len, hidden_size)
        hidden state of the attention module's output tuple — confirm.
        The output is returned unchanged.
        '''
        with torch.no_grad():
            topk_logits, topk_indices = torch.topk(lm_head(ln_f(output[0])), k=10, dim=-1) # [bsz, seq_len, k]
        # Membership is tested against the whole top-k tensor, i.e. any position counts.
        is_important = gt_id in topk_indices
        if is_important:
            # NOTE(review): relies on the hooked module having a `name` attribute —
            # presumably set by add_hook; confirm.
            print(f'{module.name} is_important')
            # for param in module.parameters():
            #     param.requires_grad = True
        return output
    
    # Hooks are rebuilt on every step because the callback closes over this step's gt_id.
    self.clear_hook()
    # Local config (shadows the notebook-level hook_config): nothing is retained,
    # the callback above is installed as `edit_output`.
    hook_config = {
        "retain_output": False,
        "retain_input": False,
        "edit_output": set_require_grad,
        "clone": False,
        "float": False,
        "detach": False,
        "device": "cpu"
    }
    for i in range(n_layer):
        self.add_hook(module=getattr(blocks[i], ATTN), name=f"attn_{i}", **hook_config)
    print(batch_idx)
    res = self(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    
    lm_logits = res['logits']
    shift_logits = lm_logits[..., :-1, :].contiguous()  # Shift so that tokens < n predict n
    shift_labels = labels[..., 1:].contiguous()

    # Prefer the loss computed by the model; fall back to self.loss_func on the flattened shifted tensors.
    if isinstance(res.get('loss'), torch.Tensor):
        loss = res['loss']
    else:
        loss = self.loss_func(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

    acc = self._acc(shift_logits, shift_labels)

    self.log('train_loss', loss, on_step=True, on_epoch=True, sync_dist=True, prog_bar=True)
    self.log('train_acc', acc, on_step=False, sync_dist=True, on_epoch=True, prog_bar=True)
    print("\n")
    return loss


In [None]:
# my_training_step asserts a batch size of exactly 1.
bsz = 1

# MedQA (US) training split tokenized with the model's tokenizer.
# NOTE(review): max_len=512 / size=1000 presumably cap the sequence length and the
# number of examples — confirm against the MedQA dataset class.
train_dst = MedQA('/nvme/guoyiqiu/coding/datasets/MedQA/data_clean/questions/US/train.jsonl',tokenizer=mt.tokenizer, max_len=512,size=1000)
# Shuffled loader that batches with the dataset's own collate_fn, using 4 worker processes.
train_dl = DataLoader(train_dst, batch_size=bsz, shuffle=True, collate_fn=train_dst.collate_fn, num_workers=4)

In [None]:
# Keyword arguments for pl.Trainer: mixed 16-bit precision, no checkpointing,
# gradients accumulated over 32 batches (the loader's batch size is 1), 3 epochs.
# NOTE(review): "devices": [1] presumably selects the device with index 1 — confirm
# that this is the intended GPU on the target machine.
trainer_config = {
    "precision": "16-mixed",
    "accelerator": "auto",
    "devices": [1],
    "enable_checkpointing":False,
    'accumulate_grad_batches': 32,
    "max_epochs":3,
}

# Drop the inspection hooks installed by init_hook; my_training_step registers its own.
mt.clear_hook()
# NOTE(review): set_func presumably installs my_training_step as mt's training_step — confirm.
mt.set_func('training_step', my_training_step)
# Train only the attention sub-modules: freeze everything, then re-enable ATTN in every block.
freeze_all(mt.model)
set_module_requires_grad(mt.model, list(range(n_layer)), ATTN, True)
# trainer = pl.Trainer(**trainer_config, logger=WandbLogger(project='tune medqa', name='litetune_5ep_vicuna7b'))
trainer = pl.Trainer(**trainer_config)
trainer.fit(mt, train_dl)