feature: add the profiler (#189)
* feature: add the benchmarks
yezhengmao1 committed Mar 24, 2024
1 parent 169b992 commit 1eb219b
Showing 15 changed files with 669 additions and 54 deletions.
Empty file added benchmarks/bench_mlora.py
168 changes: 168 additions & 0 deletions benchmarks/bench_peft.py
@@ -0,0 +1,168 @@
from mlora.utils import setup_seed
from mlora.profiler.profiler import setup_trace_mode, grad_fn_nvtx_wrapper_by_tracepoint

import torch
import random
import argparse
import logging

from transformers import LlamaForCausalLM
from peft import LoraConfig, TaskType, PeftModelForCausalLM, prepare_model_for_kbit_training

# Command Line Arguments
parser = argparse.ArgumentParser(description='PEFT benchmarks')
parser.add_argument('--base_model', type=str, required=True,
help='Path to or name of base model')
parser.add_argument('--device', type=str, default='cuda:0',
help='Specify which GPU to use; default is cuda:0')
# load quant
parser.add_argument('--load_8bit', action="store_true",
help='Load model in 8bit mode')
parser.add_argument('--load_4bit', action="store_true",
help='Load model in 4bit mode')
# lora test number
parser.add_argument('--lora_cnt', type=int, default=4,
help='The number of LoRA adapters')
# test configure
parser.add_argument('--warmup', type=int, default=100,
help="The number of warm-up steps")
parser.add_argument('--repeat', type=int, default=100,
help="The number of test iterations per adapter")
parser.add_argument('--seq_len', type=int, default=128,
help="The length of the input sequence")
parser.add_argument('--batch_size', type=int, default=4,
help="The batch size of each LoRA adapter's input")
parser.add_argument('--peft_mode', type=str, default="seq",
help="How PEFT trains the multiple LoRA adapters, one of: seq, switch")

g_default_rank = 16
g_default_alpha = 16
g_default_dropout = 0.05
g_default_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
g_micro_batch_size = 8
g_loss_fn = torch.nn.CrossEntropyLoss()

args = parser.parse_args()
assert not (args.load_4bit and args.load_8bit)


def setup_lora_adapter(llm_model: LlamaForCausalLM) -> PeftModelForCausalLM:
peft_llm_model = llm_model

for idx in range(0, args.lora_cnt):
adapter_name = f"lora_{idx}"
lora_r = g_default_rank
lora_alpha = g_default_alpha
lora_dropout = g_default_dropout
lora_target = g_default_target_modules
peft_lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM,
r=lora_r,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
target_modules=lora_target,
bias="none",
inference_mode=False)
peft_llm_model = PeftModelForCausalLM(
peft_llm_model, peft_lora_config, adapter_name)

return peft_llm_model


def setup_llm_model() -> LlamaForCausalLM:
load_bits = None
load_bits = 8 if args.load_8bit else load_bits
load_bits = 4 if args.load_4bit else load_bits

qlora_4bit_fp16 = True
qlora_4bit_bf16 = False
qlora_4bit_double_quant = True
qlora_4_bit_quant_type = "nf4"

additional_load_args = {
"device_map": args.device,
"torch_dtype": torch.float32
}

if load_bits is not None:
logging.info('Loading model with quantization, bits = %i' % load_bits)
from transformers import BitsAndBytesConfig
qlora_4bit_compute_dtype = torch.float32
# if a compute dtype is set, use it; otherwise keep the default
qlora_4bit_compute_dtype = torch.float16 if qlora_4bit_fp16 else qlora_4bit_compute_dtype
qlora_4bit_compute_dtype = torch.bfloat16 if qlora_4bit_bf16 else qlora_4bit_compute_dtype

torch_dtype = torch.float32
torch_dtype = torch.bfloat16 if qlora_4bit_bf16 else torch_dtype
additional_load_args["torch_dtype"] = torch_dtype
additional_load_args["load_in_4bit"] = True if load_bits == 4 else False
additional_load_args["load_in_8bit"] = True if load_bits == 8 else False
additional_load_args["quantization_config"] = BitsAndBytesConfig(
load_in_4bit=True if load_bits == 4 else False,
load_in_8bit=True if load_bits == 8 else False,
llm_int8_enable_fp32_cpu_offload=True,
# only for qlora 4bit
bnb_4bit_compute_dtype=qlora_4bit_compute_dtype,
bnb_4bit_use_double_quant=qlora_4bit_double_quant,
bnb_4bit_quant_type=qlora_4_bit_quant_type,
)

llm_model = LlamaForCausalLM.from_pretrained(
args.base_model, **additional_load_args)

llm_model = prepare_model_for_kbit_training(llm_model)
llm_model.training = True
llm_model.gradient_checkpointing_enable()

return llm_model


def setup_labels() -> torch.Tensor:
batch_input_ids = []
for _ in range(0, args.batch_size):
batch_input_ids.append([random.randint(1, 10000)
for _ in range(args.seq_len)])
return torch.tensor(batch_input_ids, dtype=torch.long, device=args.device)


if __name__ == "__main__":
labels = setup_labels()

setup_seed(42)
model: LlamaForCausalLM = setup_llm_model()
vocab_size = model.vocab_size
model: PeftModelForCausalLM = setup_lora_adapter(model)
model.train()

# warm up
for _ in range(0, args.warmup):
loss = model.forward(input_ids=labels, labels=labels)[0]

setup_trace_mode()

def lora_seq():
for lora_idx in range(0, args.lora_cnt):
now_lora = f"lora_{lora_idx}"
model.set_adapter(now_lora)
for _ in range(0, args.repeat):
loss = model.forward(input_ids=labels, labels=labels)[0]
grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)
loss.backward()

def lora_switch():
for _ in range(0, args.repeat):
for lora_idx in range(0, args.lora_cnt):
now_lora = f"lora_{lora_idx}"
model.set_adapter(now_lora)
loss = model.forward(input_ids=labels, labels=labels)[0]
grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)
loss.backward()

mode_function = {
"seq": lora_seq,
"switch": lora_switch,
}

peft_mode = args.peft_mode

assert peft_mode in mode_function, f"unsupported peft_mode: {peft_mode}"
mode_function[peft_mode]()
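The benchmark only calls grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn) right before loss.backward(); the profiler module itself is not part of this excerpt. Below is a minimal sketch of how such a wrapper could be built from public PyTorch autograd APIs — the function name and the hook strategy are illustrative assumptions, not mlora's actual implementation.

import torch

# Requires PyTorch >= 2.0 for Node.register_prehook. Hypothetical reconstruction:
# walk the autograd graph from a loss's grad_fn and bracket every backward node
# with an NVTX range so it shows up in an NVTX-aware profiler (e.g. Nsight Systems).
def wrap_backward_with_nvtx(grad_fn) -> None:
    def push_hook(name):
        def hook(grad_outputs):
            torch.cuda.nvtx.range_push(name)  # open a range named after the node
            return None                       # leave the gradients untouched
        return hook

    def pop_hook(grad_inputs, grad_outputs):
        torch.cuda.nvtx.range_pop()           # close the range once the node ran
        return None

    seen, stack = set(), [grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        node.register_prehook(push_hook(node.name()))
        node.register_hook(pop_hook)
        for next_node, _ in node.next_functions:
            stack.append(next_node)

In the benchmark such a wrapper is invoked once per iteration, on the freshly built graph of that iteration's loss, immediately before loss.backward().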
Empty file added benchmarks/bench_utils.py
7 changes: 7 additions & 0 deletions mlora.py
@@ -63,6 +63,9 @@
help="The device's rank number")
parser.add_argument('--balance', type=int, nargs="+",
help="The model's balance")
# the argument for the trace mode
parser.add_argument('--trace', action="store_true",
help="Enable the trace mode")


args = parser.parse_args()
@@ -87,6 +90,10 @@ def get_dispatcher_cls() -> type[mlora.Dispatcher]:
mlora.setup_logging(args.log_level, args.log_file)
mlora.setup_cuda_check()

# enable the trace mode
if args.trace:
mlora.setup_trace_mode()

# load part of model to device
partial_model_to_device = None
if args.pipeline:
3 changes: 3 additions & 0 deletions mlora/__init__.py
@@ -12,6 +12,7 @@
from mlora.dispatcher.dispatcher import Dispatcher
from mlora.dispatcher.pipeline_dispatcher import PipelineDispatcher
from mlora.pipeline.pipe import Pipe
from mlora.profiler.profiler import setup_trace_mode

__all__ = [
"Tokenizer",
@@ -29,6 +30,8 @@
"load_base_model",
"init_lora_model",
"MLoRAConfig",
# profiler function
"setup_trace_mode",
# evaluator
"EvaluatorFactory",
"Evaluator",
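With the new export, trace mode can be switched on through the package root before the model is built, exactly as mlora.py does above; the NVTX ranges it enables only become visible when the process runs under an NVTX-aware profiler such as Nsight Systems. A minimal usage sketch, based on the call already visible in this diff:

import mlora

# enable the NVTX tracepoints before loading the base model and the adapters
mlora.setup_trace_mode()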
8 changes: 5 additions & 3 deletions mlora/checkpoint/recompute.py
@@ -3,6 +3,7 @@
get_device_states,
set_device_states,
detach_variable)
from mlora.profiler.profiler import tensors_nvtx_wrapper_by_tracepoint

import torch

@@ -61,9 +62,10 @@ def backward(ctx, *args):
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs), \
torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):
outputs = ctx.run_function(*detached_inputs)

if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
# the tracepoint can only be attached inside the enable-grad context, where grad_fn exists
tensors_nvtx_wrapper_by_tracepoint(outputs)

outputs_with_grad = []
args_with_grad = []
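The reason the tracepoint wrapper has to sit inside the enable-grad/autocast block above: tensors produced while gradient tracking is disabled carry no grad_fn, so there is no autograd node to tag. A small self-contained illustration:

import torch

x = torch.ones(2, requires_grad=True)

with torch.no_grad():
    y = x * 2
print(y.grad_fn)  # None: no backward node was recorded, nothing to tag

z = x * 2         # grad mode enabled (the default, or torch.enable_grad())
print(z.grad_fn)  # <MulBackward0 ...>: this node can carry an NVTX tracepoint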
10 changes: 0 additions & 10 deletions mlora/common.py
@@ -1,15 +1,5 @@
import torch

from typing import Callable


def nvtx_wrapper(func: Callable,
msg: str):
def wrap(*args, **kwargs):
with torch.cuda.nvtx.range(msg=msg):
return func(*args, **kwargs)
return wrap


def is_offload_device(device: torch.device):
return device == torch.device("meta")
60 changes: 50 additions & 10 deletions mlora/model/LoraLiner.py
@@ -1,5 +1,6 @@
from mlora.model.modelargs import MultiLoraBatchData
from mlora.config import LoraConfig
from mlora.profiler.profiler import set_backward_tracepoint, nvtx_range

import math
import torch
@@ -9,8 +10,10 @@
from typing import Dict, Optional, Tuple


class Lora():
class Lora(torch.nn.Module):
def __init__(self, adapter_name: str):
super().__init__()

self.adapter_name_: str = adapter_name

self.lora_a_: torch.Tensor = None
@@ -28,16 +31,38 @@ def set_parameter(self, r: int, alpha: int, dropout: float):
self.scaling_ = alpha / r

def forward(self, data: torch.Tensor) -> torch.Tensor:
data_ = F.dropout(data, self.dropout_)
data_ @= self.lora_a_.transpose(0, 1)
data_ @= self.lora_b_.transpose(0, 1)
data_ *= self.scaling_
with nvtx_range(f"f_dropout_{self.adapter_name_}"):
data_ = F.dropout(data, self.dropout_)
set_backward_tracepoint(
data_.grad_fn, f"b_dropout_{self.adapter_name_}")

lora_a_t = self.lora_a_.transpose(0, 1)
set_backward_tracepoint(
lora_a_t.grad_fn, f"b_lora_a_T_{self.adapter_name_}")
lora_b_t = self.lora_b_.transpose(0, 1)
set_backward_tracepoint(
lora_b_t.grad_fn, f"b_lora_b_T_{self.adapter_name_}")

with nvtx_range(f"f_lora_a_{self.adapter_name_}"):
data_ = data_ @ lora_a_t
set_backward_tracepoint(data_.grad_fn, "b_lora_a")

with nvtx_range(f"f_lora_b_{self.adapter_name_}"):
data_ = data_ @ lora_b_t
set_backward_tracepoint(data_.grad_fn, "b_lora_b")

with nvtx_range(f"f_scaling_{self.adapter_name_}"):
data_ = data_ * self.scaling_
set_backward_tracepoint(data_.grad_fn, "b_scaling")

return data_


class Linear():
# the weight just wraps the module from LlamaForCausalLM
class Linear(torch.nn.Module):
def __init__(self, weight: torch.nn.Module):
# the weight just wraps the module from LlamaForCausalLM
super().__init__()

if not isinstance(weight, torch.nn.Linear):
assert isinstance(weight, bitsandbytes.nn.Linear8bitLt) or isinstance(
weight, bitsandbytes.nn.Linear4bit), f"error type - {type(weight)}."
@@ -103,7 +128,9 @@ def replace_init_lora_tensor(lora: Lora, lora_a: torch.Tensor, lora_b: torch.Ten
def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.Tensor:
# data shape is: batch_size * max_seq_len * dim
# result = data @ self.weight_.transpose(0, 1)
result = self.weight_.forward(data)
with nvtx_range("f_linear"):
result = self.weight_.forward(data)
set_backward_tracepoint(result.grad_fn, "b_linear")

if not self.enable_lora_:
return result
@@ -116,7 +143,20 @@ def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.T
if adapter_name == "" or adapter_name not in self.loras_:
continue

result[start_idx: end_idx] += self.loras_[
adapter_name].forward(data[start_idx:end_idx])
with nvtx_range(f"f_lora_split_({adapter_name})"):
lora_data = data[start_idx:end_idx]
set_backward_tracepoint(
lora_data.grad_fn, f"b_lora_split_({adapter_name})")

# backward_tracepoint inside the forward function
lora_delta = self.loras_[adapter_name].forward(lora_data)

lora_range = torch.arange(
start_idx, end_idx, step=1, device=lora_delta.device)
with nvtx_range(f"f_lora_add_({adapter_name})"):
result.index_add_(dim=0, index=lora_range, source=lora_delta)
set_backward_tracepoint(
result.grad_fn, f"b_lora_add_({adapter_name})")

set_backward_tracepoint(result.grad_fn, "b_lora")
return result
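nvtx_range and set_backward_tracepoint come from mlora/profiler/profiler.py, which is not shown in this excerpt. A plausible minimal shape for them — assuming a module-level trace flag and the user-writable metadata dict PyTorch exposes on autograd nodes; this is an illustrative sketch, not the repository's actual code:

import contextlib
import torch

_TRACE_MODE = False

def setup_trace_mode() -> None:
    # switch the tracepoints on for this process
    global _TRACE_MODE
    _TRACE_MODE = True

@contextlib.contextmanager
def nvtx_range(msg: str):
    # forward-pass range; a no-op unless trace mode is enabled
    if not _TRACE_MODE:
        yield
        return
    torch.cuda.nvtx.range_push(msg)
    try:
        yield
    finally:
        torch.cuda.nvtx.range_pop()

def set_backward_tracepoint(grad_fn, msg: str) -> None:
    # remember a readable label for this autograd node; a graph walker such as
    # grad_fn_nvtx_wrapper_by_tracepoint can later emit it as an NVTX range
    if not _TRACE_MODE or grad_fn is None:
        return
    grad_fn.metadata["nvtx_tracepoint"] = msg

Under a scheme like this, the forward labels (f_linear, f_lora_a_*, ...) appear directly in the profiler timeline, while the b_* labels only surface once the backward graph is walked before loss.backward(), as bench_peft.py does.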
1 change: 1 addition & 0 deletions mlora/model/RMSNorm.py
@@ -9,6 +9,7 @@ def __init__(self, weight: torch.Tensor, eps: float = 1e-6):

def forward(self, data: torch.Tensor) -> torch.Tensor:
input_dtype = data.dtype

v = data.to(torch.float32).pow(2).mean(-1, keepdim=True)
data = data * torch.rsqrt(v + self.norm_eps_)
