feature: add the performance bench tool (#192)

yezhengmao1 committed Apr 1, 2024
1 parent 1eb219b commit 569b27e
Showing 12 changed files with 853 additions and 127 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-test-main.yml
@@ -36,7 +36,7 @@ jobs:
lizard -l python ./mlora -C 12
- name: Lint with flake8
run: |
flake8 . --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504
flake8 ./mlora --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504
- name: Test with pytest
run: |
pytest
161 changes: 161 additions & 0 deletions benchmarks/bench_mlora.py
@@ -0,0 +1,161 @@
from mlora.utils import setup_seed
from mlora.config import LoraConfig
from mlora.model.modelargs import MultiLoraBatchData, LoraBatchDataConfig
from mlora.profiler.profiler import setup_trace_mode, set_backward_tracepoint, grad_fn_nvtx_wrapper_by_tracepoint, nvtx_range

import mlora
import torch
import random
import argparse

from typing import List

# Command Line Arguments
parser = argparse.ArgumentParser(description='PEFT benchmarks')
parser.add_argument('--base_model', type=str, required=True,
help='Path to or name of base model')
parser.add_argument('--device', type=str, default='cuda:0',
                    help='Specify which GPU to use (default: cuda:0)')
# load quant
parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')
# lora test number
parser.add_argument('--lora_cnt', type=int, default=4,
                    help='The number of LoRA adapters to attach')
# test configure
parser.add_argument('--warmup', type=int, default=100,
                    help="The number of warm-up steps")
parser.add_argument('--repete', type=int, default=100,
                    help="Total number of test iterations")
parser.add_argument('--seq_len', type=int, default=128,
help="The length of the sequence")
parser.add_argument('--batch_size', type=int, default=8,
help="The batch size of each lora input")


g_default_rank = 16
g_default_alpha = 16
g_default_dropout = 0.05
g_default_target_modules = {"q_proj": True,
"k_proj": True,
"v_proj": True,
"o_proj": True,
"w1_proj": False,
"w2_proj": False,
"w3_proj": False}
g_default_loss_fn = torch.nn.CrossEntropyLoss()

args = parser.parse_args()
assert not (args.load_4bit and args.load_8bit)


def setup_lora_adapter_config() -> List[LoraConfig]:
lora_config: List[LoraConfig] = []

for idx in range(0, args.lora_cnt):
lora_config.append(LoraConfig({
"name": f"lora_{idx}",
"r": g_default_rank,
"alpha": g_default_alpha,
"dropout": g_default_dropout,
"target_modules": g_default_target_modules,
"batch_size": args.batch_size,
"micro_batch_size": args.batch_size,
# unused
"test_batch_size": 0,
"num_epochs": 0,
"data": "",
"test_data": "",
"prompt": "",
"group_by_length": "",
"expand_side": "",
"optim": "sgd",
"momentum": 0.0,
"lr": 0.0,
}))

return lora_config


def setup_input() -> MultiLoraBatchData:
batch_tokens = []
additional_masks = []
lora_batch_data_config: List[LoraBatchDataConfig] = []

start_idx = 0
end_idx = 0

for lora_idx in range(0, args.lora_cnt):
adapter_name = f"lora_{lora_idx}"

for _ in range(0, args.batch_size):
tokens = [random.randint(1, 10000) for _ in range(args.seq_len)]
batch_tokens.append(tokens)
additional_masks.append([False] * args.seq_len)
end_idx += 1

lora_batch_data_config.append(LoraBatchDataConfig(
adapter_name_=adapter_name,
batch_start_idx_=start_idx,
batch_end_idx_=end_idx,
))

start_idx = end_idx

return MultiLoraBatchData(batch_tokens_=batch_tokens,
additional_mask_=additional_masks,
lora_batch_data_config_=lora_batch_data_config,
inference_model_=False)


def calc_loss(train_data: MultiLoraBatchData, model_output: torch.Tensor) -> torch.Tensor:
labels = torch.tensor(train_data.batch_tokens_, dtype=torch.long)
total_loss = None

for lora_config in train_data.lora_batch_data_config_:
start_idx = lora_config.batch_start_idx_
end_idx = lora_config.batch_end_idx_
vocab_size = model_output.shape[-1]
        loss_input = (model_output[start_idx:end_idx][..., :-1, :]
                      .contiguous().view(-1, vocab_size))
        loss_target = (labels[start_idx:end_idx][..., 1:]
                       .contiguous().view(-1).to(loss_input.device))
loss = g_default_loss_fn(loss_input, loss_target)
if total_loss is None:
total_loss = loss
else:
total_loss += loss

return total_loss


if __name__ == "__main__":
input_data = setup_input()

setup_seed(42)

_, model = mlora.load_base_model(args.base_model,
"llama",
args.device,
args.load_4bit,
args.load_8bit,
None)

mlora.init_lora_model(model, setup_lora_adapter_config())

    # warm up
for test_idx in range(0, args.warmup):
output = model.forward(input_data)

setup_trace_mode()

for _ in range(0, args.repete):
output = model.forward(input_data)
with nvtx_range("f_calc_loss"):
total_loss = calc_loss(input_data, output)
set_backward_tracepoint(total_loss.grad_fn, "b_loss")
grad_fn_nvtx_wrapper_by_tracepoint(total_loss.grad_fn)

total_loss.backward()
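
For readers skimming the diff: calc_loss above applies the standard causal-LM shift to each adapter's slice of the batch (position t predicts token t+1) and sums the per-adapter losses so a single backward pass covers every LoRA. A minimal, self-contained sketch of that shift follows; it is illustrative only, with toy shapes and random tokens rather than mLoRA's data structures.

import torch

loss_fn = torch.nn.CrossEntropyLoss()

batch, seq_len, vocab = 2, 8, 32                    # toy sizes, not the benchmark defaults
logits = torch.randn(batch, seq_len, vocab)         # stands in for one adapter's slice of the model output
tokens = torch.randint(1, vocab, (batch, seq_len))  # the input ids double as labels

# drop the last logit and the first label so position t predicts token t+1
shift_logits = logits[..., :-1, :].contiguous().view(-1, vocab)
shift_labels = tokens[..., 1:].contiguous().view(-1)
print(loss_fn(shift_logits, shift_labels).item())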
11 changes: 5 additions & 6 deletions benchmarks/bench_peft.py
@@ -1,5 +1,5 @@
from mlora.utils import setup_seed
from mlora.profiler.profiler import setup_trace_mode, grad_fn_nvtx_wrapper_by_tracepoint
from mlora.profiler.profiler import setup_trace_mode, grad_fn_nvtx_wrapper_by_tracepoint, set_backward_tracepoint

import torch
import random
@@ -30,7 +30,7 @@
help="Total test iteration")
parser.add_argument('--seq_len', type=int, default=128,
help="The length of the sequence")
parser.add_argument('--batch_size', type=int, default=4,
parser.add_argument('--batch_size', type=int, default=8,
help="The batch size of each lora input")
parser.add_argument('--peft_mode', type=str, default="seq",
help="How to use peft to train multi lora, include: seq, switch")
@@ -39,8 +39,7 @@
g_default_alpha = 16
g_default_dropout = 0.05
g_default_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
g_micro_batch_size = 8
g_loss_fn = torch.nn.CrossEntropyLoss()
g_default_loss_fn = torch.nn.CrossEntropyLoss()

args = parser.parse_args()
assert not (args.load_4bit and args.load_8bit)
@@ -94,8 +93,6 @@ def setup_llm_model() -> LlamaForCausalLM:
torch_dtype = torch.float32
torch_dtype = torch.bfloat16 if qlora_4bit_bf16 else torch_dtype
additional_load_args["torch_dtype"] = torch_dtype
additional_load_args["load_in_4bit"] = True if load_bits == 4 else False
additional_load_args["load_in_8bit"] = True if load_bits == 8 else False
additional_load_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True if load_bits == 4 else False,
load_in_8bit=True if load_bits == 8 else False,
@@ -145,6 +142,7 @@ def lora_seq():
model.set_adapter(now_lora)
for _ in range(0, args.repete):
loss = model.forward(input_ids=lables, labels=lables)[0]
set_backward_tracepoint(loss.grad_fn, "b_loss")
grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)
loss.backward()

@@ -154,6 +152,7 @@ def lora_switch():
now_lora = f"lora_{lora_idx}"
model.set_adapter(now_lora)
loss = model.forward(input_ids=lables, labels=lables)[0]
set_backward_tracepoint(loss.grad_fn, "b_loss")
grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)
loss.backward()

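The bench_peft.py change above also drops the standalone load_in_4bit / load_in_8bit keyword arguments and keeps only the BitsAndBytesConfig passed as quantization_config. Below is a sketch of that loading pattern with the Hugging Face API; the model id, compute dtype, and device map are placeholders chosen for illustration, and bitsandbytes plus a CUDA device are assumed to be available.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

load_bits = 4  # 4-bit here; 8 would flip the flags below

quant_config = BitsAndBytesConfig(
    load_in_4bit=(load_bits == 4),
    load_in_8bit=(load_bits == 8),
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",        # placeholder model id
    torch_dtype=torch.bfloat16,
    quantization_config=quant_config,  # quantization expressed only through the config object
    device_map={"": 0},                # keep the whole model on one GPU
)
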
87 changes: 47 additions & 40 deletions mlora/config.py
@@ -4,10 +4,10 @@


class DictConfig:
__params_map: Dict[str, str] = {}

def __init__(self, config: Dict[str, str]) -> None:
params_map: Dict[str, str] = {}
self.init(params_map, config)
self.init(self.__params_map, config)

def init(self,
params_map: Dict[str, str],
@@ -20,29 +20,34 @@ class OptimConfig(DictConfig):
optim_: str = ""
lr_: float = 0.0

__params_map: Dict[str, str] = {
"lr_": "lr",
"optim_": "optim"
}

def __init__(self, config: Dict[str, str]) -> None:
super().__init__(config)
params_map = {
"lr_": "lr",
"optim_": "optim"
}
self.init(params_map, config)
self.init(self.__params_map, config)


class SGDOptimConfig(OptimConfig):
momentum_: float = 0.0

__params_map: Dict[str, str] = {
"momentum_": "momentum"
}

def __init__(self, config: Dict[str, str]) -> None:
super().__init__(config)
params_map = {
"momentum_": "momentum"
}
self.init(params_map, config)
self.init(self.__params_map, config)


class AdamWOptimConfig(OptimConfig):
__params_map: Dict[str, str] = {}

def __init__(self, config: Dict[str, str]) -> None:
super().__init__(config)
self.init(self.__params_map, config)


class LoraConfig(DictConfig):
@@ -70,28 +75,29 @@ class LoraConfig(DictConfig):
val_set_size_: int = -1
cutoff_len_: int = -1

__params_map: Dict[str, str] = {
"adapter_name_": "name",
"r_": "r",
"lora_alpha_": "alpha",
"lora_dropout_": "dropout",
"target_": "target_modules",

"batch_size_": "batch_size",
"micro_batch_size_": "micro_batch_size",
"test_batch_size_": "test_batch_size",
"num_epochs_": "num_epochs",

"data_": "data",
"test_data_": "test_data",
"prompt_": "prompt",

"group_by_length_": "group_by_length",
"expand_side_": "expand_side",
}

def __init__(self, config: Dict[str, str]):
super().__init__(config)
params_map = {
"adapter_name_": "name",
"r_": "r",
"lora_alpha_": "alpha",
"lora_dropout_": "dropout",
"target_": "target_modules",

"batch_size_": "batch_size",
"micro_batch_size_": "micro_batch_size",
"test_batch_size_": "test_batch_size",
"num_epochs_": "num_epochs",

"data_": "data",
"test_data_": "test_data",
"prompt_": "prompt",

"group_by_length_": "group_by_length",
"expand_side_": "expand_side",
}
self.init(params_map, config)
self.init(self.__params_map, config)

if config["optim"] == "adamw":
self.optim_config_ = AdamWOptimConfig(config)
@@ -109,17 +115,18 @@ class TrainerConfig(DictConfig):
train_lora_simultaneously_num_: int = 2
train_strategy_: str = "optim"

__params_map: Dict[str, str] = {
"cutoff_len_": "cutoff_len",
"save_step_": "save_step",
"early_stop_test_step_": "early_stop_test_step",
"train_lora_candidate_num_": "train_lora_candidate_num",
"train_lora_simultaneously_num_": "train_lora_simultaneously_num",
"train_strategy_": "train_strategy"
}

def __init__(self, config: Dict[str, str]):
super().__init__(config)
params_map = {
"cutoff_len_": "cutoff_len",
"save_step_": "save_step",
"early_stop_test_step_": "early_stop_test_step",
"train_lora_candidate_num_": "train_lora_candidate_num",
"train_lora_simultaneously_num_": "train_lora_simultaneously_num",
"train_strategy_": "train_strategy"
}
self.init(params_map, config)
self.init(self.__params_map, config)


class MLoRAConfig:
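The config.py refactor above moves each params_map from a local variable in __init__ to a class attribute named __params_map. Because of Python name mangling, self.__params_map inside a class body resolves to self._ClassName__params_map, so every subclass keeps its own mapping even though the attribute names look identical. The sketch below illustrates the pattern in isolation; the body of init() is an assumption, since the diff does not show how mLoRA's DictConfig.init consumes the mapping.

from typing import Dict


class DictConfig:
    __params_map: Dict[str, str] = {}

    def __init__(self, config: Dict[str, str]) -> None:
        self.init(self.__params_map, config)

    def init(self, params_map: Dict[str, str], config: Dict[str, str]) -> None:
        # assumed behavior: copy each mapped config key onto an attribute,
        # e.g. self.lr_ = config["lr"]
        for attr, key in params_map.items():
            setattr(self, attr, config[key])


class OptimConfig(DictConfig):
    # name mangling makes this _OptimConfig__params_map, distinct from
    # the base class's _DictConfig__params_map
    __params_map: Dict[str, str] = {"lr_": "lr", "optim_": "optim"}

    def __init__(self, config: Dict[str, str]) -> None:
        super().__init__(config)
        self.init(self.__params_map, config)


cfg = OptimConfig({"lr": "1e-4", "optim": "sgd"})
print(cfg.lr_, cfg.optim_)  # -> 1e-4 sgd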
