diff --git a/.github/workflows/python-test-main.yml b/.github/workflows/python-test-main.yml index b5b6fde0..7ee06709 100644 --- a/.github/workflows/python-test-main.yml +++ b/.github/workflows/python-test-main.yml @@ -36,7 +36,7 @@ jobs: lizard -l python ./mlora -C 12 - name: Lint with flake8 run: | - flake8 . --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504 + flake8 ./mlora --count --show-source --statistics --max-line-length=127 --max-complexity 15 --ignore=E722,W504 - name: Test with pytest run: | pytest diff --git a/benchmarks/bench_mlora.py b/benchmarks/bench_mlora.py index e69de29b..e10f0323 100644 --- a/benchmarks/bench_mlora.py +++ b/benchmarks/bench_mlora.py @@ -0,0 +1,161 @@ +from mlora.utils import setup_seed +from mlora.config import LoraConfig +from mlora.model.modelargs import MultiLoraBatchData, LoraBatchDataConfig +from mlora.profiler.profiler import setup_trace_mode, set_backward_tracepoint, grad_fn_nvtx_wrapper_by_tracepoint, nvtx_range + +import mlora +import torch +import random +import argparse + +from typing import List + +# Command Line Arguments +parser = argparse.ArgumentParser(description='PEFT benchmarks') +parser.add_argument('--base_model', type=str, required=True, + help='Path to or name of base model') +parser.add_argument('--device', type=str, default='cuda:0', + help='Specify which GPU to be used, default is cuda:0') +# load quant +parser.add_argument('--load_8bit', action="store_true", + help='Load model in 8bit mode') +parser.add_argument('--load_4bit', action="store_true", + help='Load model in 4bit mode') +# lora test number +parser.add_argument('--lora_cnt', type=int, default=4, + help='The number of lora') +# test configure +parser.add_argument('--warmup', type=int, default=100, + help="The step of warm up") +parser.add_argument('--repete', type=int, default=100, + help="Total test iteration") +parser.add_argument('--seq_len', type=int, default=128, + help="The length of the sequence") +parser.add_argument('--batch_size', type=int, default=8, + help="The batch size of each lora input") + + +g_default_rank = 16 +g_default_alpha = 16 +g_default_dropout = 0.05 +g_default_target_modules = {"q_proj": True, + "k_proj": True, + "v_proj": True, + "o_proj": True, + "w1_proj": False, + "w2_proj": False, + "w3_proj": False} +g_default_loss_fn = torch.nn.CrossEntropyLoss() + +args = parser.parse_args() +assert not (args.load_4bit and args.load_8bit) + + +def setup_lora_adapter_config() -> List[LoraConfig]: + lora_config: List[LoraConfig] = [] + + for idx in range(0, args.lora_cnt): + lora_config.append(LoraConfig({ + "name": f"lora_{idx}", + "r": g_default_rank, + "alpha": g_default_alpha, + "dropout": g_default_dropout, + "target_modules": g_default_target_modules, + "batch_size": args.batch_size, + "micro_batch_size": args.batch_size, + # unused + "test_batch_size": 0, + "num_epochs": 0, + "data": "", + "test_data": "", + "prompt": "", + "group_by_length": "", + "expand_side": "", + "optim": "sgd", + "momentum": 0.0, + "lr": 0.0, + })) + + return lora_config + + +def setup_input() -> MultiLoraBatchData: + batch_tokens = [] + additional_masks = [] + lora_batch_data_config: List[LoraBatchDataConfig] = [] + + start_idx = 0 + end_idx = 0 + + for lora_idx in range(0, args.lora_cnt): + adapter_name = f"lora_{lora_idx}" + + for _ in range(0, args.batch_size): + tokens = [random.randint(1, 10000) for _ in range(args.seq_len)] + batch_tokens.append(tokens) + additional_masks.append([False] * args.seq_len) + end_idx += 1 + 
+ lora_batch_data_config.append(LoraBatchDataConfig( + adapter_name_=adapter_name, + batch_start_idx_=start_idx, + batch_end_idx_=end_idx, + )) + + start_idx = end_idx + + return MultiLoraBatchData(batch_tokens_=batch_tokens, + additional_mask_=additional_masks, + lora_batch_data_config_=lora_batch_data_config, + inference_model_=False) + + +def calc_loss(train_data: MultiLoraBatchData, model_output: torch.Tensor) -> torch.Tensor: + labels = torch.tensor(train_data.batch_tokens_, dtype=torch.long) + total_loss = None + + for lora_config in train_data.lora_batch_data_config_: + start_idx = lora_config.batch_start_idx_ + end_idx = lora_config.batch_end_idx_ + vocab_size = model_output.shape[-1] + loss_input = model_output[start_idx:end_idx][..., + :-1, :].contiguous().view(-1, vocab_size) + loss_target = labels[start_idx:end_idx][..., + 1:].contiguous().view(-1).to(loss_input.device) + loss = g_default_loss_fn(loss_input, loss_target) + if total_loss is None: + total_loss = loss + else: + total_loss += loss + + return total_loss + + +if __name__ == "__main__": + input_data = setup_input() + + setup_seed(42) + + _, model = mlora.load_base_model(args.base_model, + "llama", + args.device, + args.load_4bit, + args.load_8bit, + None) + + mlora.init_lora_model(model, setup_lora_adapter_config()) + + # to wramup + for test_idx in range(0, args.warmup): + output = model.forward(input_data) + + setup_trace_mode() + + for _ in range(0, args.repete): + output = model.forward(input_data) + with nvtx_range("f_calc_loss"): + total_loss = calc_loss(input_data, output) + set_backward_tracepoint(total_loss.grad_fn, "b_loss") + grad_fn_nvtx_wrapper_by_tracepoint(total_loss.grad_fn) + + total_loss.backward() diff --git a/benchmarks/bench_peft.py b/benchmarks/bench_peft.py index 77416d23..c257a76c 100644 --- a/benchmarks/bench_peft.py +++ b/benchmarks/bench_peft.py @@ -1,5 +1,5 @@ from mlora.utils import setup_seed -from mlora.profiler.profiler import setup_trace_mode, grad_fn_nvtx_wrapper_by_tracepoint +from mlora.profiler.profiler import setup_trace_mode, grad_fn_nvtx_wrapper_by_tracepoint, set_backward_tracepoint import torch import random @@ -30,7 +30,7 @@ help="Total test iteration") parser.add_argument('--seq_len', type=int, default=128, help="The length of the sequence") -parser.add_argument('--batch_size', type=int, default=4, +parser.add_argument('--batch_size', type=int, default=8, help="The batch size of each lora input") parser.add_argument('--peft_mode', type=str, default="seq", help="How to use peft to train multi lora, include: seq, switch") @@ -39,8 +39,7 @@ g_default_alpha = 16 g_default_dropout = 0.05 g_default_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] -g_micro_batch_size = 8 -g_loss_fn = torch.nn.CrossEntropyLoss() +g_default_loss_fn = torch.nn.CrossEntropyLoss() args = parser.parse_args() assert not (args.load_4bit and args.load_8bit) @@ -94,8 +93,6 @@ def setup_llm_model() -> LlamaForCausalLM: torch_dtype = torch.float32 torch_dtype = torch.bfloat16 if qlora_4bit_bf16 else torch_dtype additional_load_args["torch_dtype"] = torch_dtype - additional_load_args["load_in_4bit"] = True if load_bits == 4 else False - additional_load_args["load_in_8bit"] = True if load_bits == 8 else False additional_load_args["quantization_config"] = BitsAndBytesConfig( load_in_4bit=True if load_bits == 4 else False, load_in_8bit=True if load_bits == 8 else False, @@ -145,6 +142,7 @@ def lora_seq(): model.set_adapter(now_lora) for _ in range(0, args.repete): loss = model.forward(input_ids=lables, 
labels=lables)[0] + set_backward_tracepoint(loss.grad_fn, "b_loss") grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn) loss.backward() @@ -154,6 +152,7 @@ def lora_switch(): now_lora = f"lora_{lora_idx}" model.set_adapter(now_lora) loss = model.forward(input_ids=lables, labels=lables)[0] + set_backward_tracepoint(loss.grad_fn, "b_loss") grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn) loss.backward() diff --git a/mlora/config.py b/mlora/config.py index 93fa14c8..0c829e1c 100644 --- a/mlora/config.py +++ b/mlora/config.py @@ -4,10 +4,10 @@ class DictConfig: + __params_map: Dict[str, str] = {} def __init__(self, config: Dict[str, str]) -> None: - params_map: Dict[str, str] = {} - self.init(params_map, config) + self.init(self.__params_map, config) def init(self, params_map: Dict[str, str], @@ -20,29 +20,34 @@ class OptimConfig(DictConfig): optim_: str = "" lr_: float = 0.0 + __params_map: Dict[str, str] = { + "lr_": "lr", + "optim_": "optim" + } + def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) - params_map = { - "lr_": "lr", - "optim_": "optim" - } - self.init(params_map, config) + self.init(self.__params_map, config) class SGDOptimConfig(OptimConfig): momentum_: float = 0.0 + __params_map: Dict[str, str] = { + "momentum_": "momentum" + } + def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) - params_map = { - "momentum_": "momentum" - } - self.init(params_map, config) + self.init(self.__params_map, config) class AdamWOptimConfig(OptimConfig): + __params_map: Dict[str, str] = {} + def __init__(self, config: Dict[str, str]) -> None: super().__init__(config) + self.init(self.__params_map, config) class LoraConfig(DictConfig): @@ -70,28 +75,29 @@ class LoraConfig(DictConfig): val_set_size_: int = -1 cutoff_len_: int = -1 + __params_map: Dict[str, str] = { + "adapter_name_": "name", + "r_": "r", + "lora_alpha_": "alpha", + "lora_dropout_": "dropout", + "target_": "target_modules", + + "batch_size_": "batch_size", + "micro_batch_size_": "micro_batch_size", + "test_batch_size_": "test_batch_size", + "num_epochs_": "num_epochs", + + "data_": "data", + "test_data_": "test_data", + "prompt_": "prompt", + + "group_by_length_": "group_by_length", + "expand_side_": "expand_side", + } + def __init__(self, config: Dict[str, str]): super().__init__(config) - params_map = { - "adapter_name_": "name", - "r_": "r", - "lora_alpha_": "alpha", - "lora_dropout_": "dropout", - "target_": "target_modules", - - "batch_size_": "batch_size", - "micro_batch_size_": "micro_batch_size", - "test_batch_size_": "test_batch_size", - "num_epochs_": "num_epochs", - - "data_": "data", - "test_data_": "test_data", - "prompt_": "prompt", - - "group_by_length_": "group_by_length", - "expand_side_": "expand_side", - } - self.init(params_map, config) + self.init(self.__params_map, config) if config["optim"] == "adamw": self.optim_config_ = AdamWOptimConfig(config) @@ -109,17 +115,18 @@ class TrainerConfig(DictConfig): train_lora_simultaneously_num_: int = 2 train_strategy_: str = "optim" + __params_map: Dict[str, str] = { + "cutoff_len_": "cutoff_len", + "save_step_": "save_step", + "early_stop_test_step_": "early_stop_test_step", + "train_lora_candidate_num_": "train_lora_candidate_num", + "train_lora_simultaneously_num_": "train_lora_simultaneously_num", + "train_strategy_": "train_strategy" + } + def __init__(self, config: Dict[str, str]): super().__init__(config) - params_map = { - "cutoff_len_": "cutoff_len", - "save_step_": "save_step", - "early_stop_test_step_": 
"early_stop_test_step", - "train_lora_candidate_num_": "train_lora_candidate_num", - "train_lora_simultaneously_num_": "train_lora_simultaneously_num", - "train_strategy_": "train_strategy" - } - self.init(params_map, config) + self.init(self.__params_map, config) class MLoRAConfig: diff --git a/mlora/model/LoraLiner.py b/mlora/model/LoraLiner.py index 2e685bf3..5f421be2 100644 --- a/mlora/model/LoraLiner.py +++ b/mlora/model/LoraLiner.py @@ -31,29 +31,12 @@ def set_parameter(self, r: int, alpha: int, dropout: float): self.scaling_ = alpha / r def forward(self, data: torch.Tensor) -> torch.Tensor: - with nvtx_range(f"f_dropout_{self.adapter_name_}"): - data_ = F.dropout(data, self.dropout_) - set_backward_tracepoint( - data_.grad_fn, f"b_dropout_{self.adapter_name_}") - - lora_a_t = self.lora_a_.transpose(0, 1) - set_backward_tracepoint( - lora_a_t.grad_fn, f"b_lora_a_T_{self.adapter_name_}") - lora_b_t = self.lora_b_.transpose(0, 1) - set_backward_tracepoint( - lora_b_t.grad_fn, f"b_lora_b_T_{self.adapter_name_}") - - with nvtx_range(f"f_lora_a_{self.adapter_name_}"): - data_ = data_ @ lora_a_t - set_backward_tracepoint(data_.grad_fn, "b_lora_a") - - with nvtx_range(f"f_lora_b_{self.adapter_name_}"): - data_ = data_ @ lora_b_t - set_backward_tracepoint(data_.grad_fn, "b_lora_b") - - with nvtx_range(f"f_scaling_{self.adapter_name_}"): - data_ = data_ * self.scaling_ - set_backward_tracepoint(data_.grad_fn, "b_scaling") + data_ = F.dropout(data, self.dropout_) + + data_ = data_ @ self.lora_a_.transpose(0, 1) + data_ = data_ @ self.lora_b_.transpose(0, 1) + + data_ = data_ * self.scaling_ return data_ @@ -61,6 +44,7 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: class Linear(torch.nn.Module): def __init__(self, weight: torch.nn.Module): # the weight just wrapper the module from LlamaForCausalLM + # the name for debug super().__init__() if not isinstance(weight, torch.nn.Linear): @@ -128,13 +112,13 @@ def replace_init_lora_tensor(lora: Lora, lora_a: torch.Tensor, lora_b: torch.Ten def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.Tensor: # data shape is: batch_size * max_seq_len * dim # result = data @ self.weight_.transpose(0, 1) + if not self.enable_lora_: + return self.weight_.forward(data) + with nvtx_range("f_linear"): result = self.weight_.forward(data) set_backward_tracepoint(result.grad_fn, "b_linear") - if not self.enable_lora_: - return result - for lora_config in input_args.lora_batch_data_config_: adapter_name = lora_config.adapter_name_ start_idx = lora_config.batch_start_idx_ @@ -143,20 +127,17 @@ def forward(self, data: torch.Tensor, input_args: MultiLoraBatchData) -> torch.T if adapter_name == "" or adapter_name not in self.loras_: continue - with nvtx_range(f"f_lora_split_({adapter_name})"): + with nvtx_range(f"f_lora_{adapter_name}"): lora_data = data[start_idx:end_idx] - set_backward_tracepoint( - lora_data.grad_fn, f"b_lora_split_({adapter_name})") - # backward_tracepoint inside the forward function - lora_delta = self.loras_[adapter_name].forward(lora_data) + # backward_tracepoint inside the forward function + lora_delta = self.loras_[adapter_name].forward(lora_data) - lora_range = torch.arange( - start_idx, end_idx, step=1, device=lora_delta.device) - with nvtx_range(f"f_lora_add_({adapter_name})"): + lora_range = torch.arange( + start_idx, end_idx, step=1, device=lora_delta.device) result.index_add_(dim=0, index=lora_range, source=lora_delta) + set_backward_tracepoint( - result.grad_fn, f"b_lora_add_({adapter_name})") + 
result.grad_fn, f"b_lora_{adapter_name}") - set_backward_tracepoint(result.grad_fn, "b_lora") return result diff --git a/mlora/model/model_llama.py b/mlora/model/model_llama.py index a9ceaafb..6753a1d4 100644 --- a/mlora/model/model_llama.py +++ b/mlora/model/model_llama.py @@ -135,19 +135,14 @@ def forward(self, input_args: MultiLoraBatchData): batch_size, max_seq_len, _ = data.shape - with nvtx_range(f"f_attention_norm_{self.layer_id_}"): + with nvtx_range("f_attention_norm"): attention_norm_data = self.attention_norm_.forward(data) set_backward_tracepoint( - attention_norm_data.grad_fn, f"b_attention_norm_{self.layer_id_}") + attention_norm_data.grad_fn, "b_attention_norm") - with nvtx_range(f"f_q_{self.layer_id_}"): - xq = self.wq_.forward(attention_norm_data, input_args) - - with nvtx_range(f"f_k_{self.layer_id_}"): - xk = self.wk_.forward(attention_norm_data, input_args) - - with nvtx_range(f"f_v_{self.layer_id_}"): - xv = self.wv_.forward(attention_norm_data, input_args) + xq = self.wq_.forward(attention_norm_data, input_args) + xk = self.wk_.forward(attention_norm_data, input_args) + xv = self.wv_.forward(attention_norm_data, input_args) # conver shape to multi head # the shape is batch_size * number_of_head * seq_len * dim_of_head @@ -163,7 +158,7 @@ def forward(self, cos = self.cos_[:max_seq_len].to(xq.dtype) sin = self.sin_[:max_seq_len].to(xq.dtype) - with nvtx_range(f"f_rotray_emb_{self.layer_id_}"): + with nvtx_range("f_rotray_emb"): xq, xk = apply_rotary_emb(xq, xk, cos, sin) set_backward_tracepoint(xq.grad_fn, "b_q_rope") @@ -180,57 +175,37 @@ def forward(self, # must align with xformers memory efficient attention xq = xq.transpose(1, 2) - set_backward_tracepoint(xq.grad_fn, "b_q_T") - xk = xk.transpose(1, 2) - set_backward_tracepoint(xk.grad_fn, "b_k_T") - xv = xv.transpose(1, 2) - set_backward_tracepoint(xv.grad_fn, "b_v_T") - with nvtx_range(f"f_attention_{self.layer_id_}"): + with nvtx_range("f_attention"): attention_score = xformers.ops.memory_efficient_attention( xq, xk, xv, mask) attention_score = attention_score.view(batch_size, max_seq_len, -1) set_backward_tracepoint(attention_score.grad_fn, "b_attention") # get output attention score - with nvtx_range(f"f_o_{self.layer_id_}"): - wo = self.wo_.forward(attention_score, input_args) + wo = self.wo_.forward(attention_score, input_args) - with nvtx_range(f"f_o_add_{self.layer_id_}"): + with nvtx_range("f_o_add"): data = data + wo set_backward_tracepoint(data.grad_fn, "b_o_add") # feed forward fully connected - with nvtx_range(f"f_ffn_norm_{self.layer_id_}"): + with nvtx_range("f_ffn_norm"): score_norm_data = self.ffn_norm_.forward(data) - set_backward_tracepoint(score_norm_data.grad_fn, - f"b_ffn_norm_{self.layer_id_}") + set_backward_tracepoint(score_norm_data.grad_fn, "b_ffn_norm") - with nvtx_range(f"f_w1_{self.layer_id_}"): + with nvtx_range("f_mlp"): w1 = self.w1_.forward(score_norm_data, input_args) - - with nvtx_range(f"f_w3_{self.layer_id_}"): w3 = self.w3_.forward(score_norm_data, input_args) - - # same as: data = data + w2_forward(F.silu(w1) * w3, input_args) - with nvtx_range(f"f_silu_w2_{self.layer_id_}"): + # same as: data = data + w2_forward(F.silu(w1) * w3, input_args) w1_silu = F.silu(w1) - set_backward_tracepoint( - w1_silu.grad_fn, f"b_silu_w2_{self.layer_id_}") - - with nvtx_range(f"f_w1_m_w3_{self.layer_id_}"): mlp_output = w1_silu * w3 - set_backward_tracepoint( - mlp_output.grad_fn, f"b_w1_m_w3_{self.layer_id_}") - - with nvtx_range(f"f_w2_{self.layer_id_}"): mlp_output = 
self.w2_.forward(mlp_output, input_args) - set_backward_tracepoint( - mlp_output.grad_fn, f"b_w2_{self.layer_id_}") + set_backward_tracepoint(mlp_output.grad_fn, "b_mlp") - with nvtx_range(f"f_w2_add_{self.layer_id_}"): + with nvtx_range("f_mlp_add"): mlp_output = data + mlp_output set_backward_tracepoint(mlp_output.grad_fn, "b_w2_add") @@ -289,7 +264,6 @@ def embedding_forward(): output = output.requires_grad_(True) return (output, ) + input[1:] - @nvtx_wrapper("f_transformer") def transformer_forward(): if input[-1]: output = CheckpointRecomputeFunction.apply( @@ -297,7 +271,6 @@ def transformer_forward(): set_backward_tracepoint(output.grad_fn, "b_checkpoint") else: output = self.wrapper_module_.forward(*input[:-1]) - set_backward_tracepoint(output.grad_fn, "b_transformer") return (output, ) + input[1:] @nvtx_wrapper("f_rmsnorm") diff --git a/requirements.txt b/requirements.txt index e98ffc30..4b44019e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,5 +8,5 @@ sentencepiece==0.1.99 scipy==1.10.1 xformers==0.0.21 flask -peft -protobuf==3.20.2 \ No newline at end of file +peft==0.10.0 +protobuf==3.20.2 diff --git a/scripts/patch/peft_v_0_10_0/tuners_lora_bnb.patch b/scripts/patch/peft_v_0_10_0/tuners_lora_bnb.patch new file mode 100644 index 00000000..8d061191 --- /dev/null +++ b/scripts/patch/peft_v_0_10_0/tuners_lora_bnb.patch @@ -0,0 +1,31 @@ +28a29,30 +> from mlora.profiler.profiler import nvtx_range,set_backward_tracepoint +> +217c219,222 +< result = self.base_layer(x, *args, **kwargs) +--- +> with nvtx_range("f_linear"): +> result = self.base_layer(x, *args, **kwargs) +> set_backward_tracepoint(result.grad_fn, "b_linear") +> +233,238c238,245 +< if not self.use_dora[active_adapter]: +< output = lora_B(lora_A(dropout(x))) * scaling +< else: +< output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) +< if requires_conversion: +< output = output.to(expected_dtype) +--- +> with nvtx_range(f"f_lora_{active_adapter}"): +> if not self.use_dora[active_adapter]: +> output = lora_B(lora_A(dropout(x))) * scaling +> set_backward_tracepoint(output.grad_fn, f"b_lora_{active_adapter}") +> else: +> output = self._apply_dora(x, lora_A, lora_B, scaling, active_adapter) +> if requires_conversion: +> output = output.to(expected_dtype) +240c247,248 +< result = result + output +--- +> result = result + output +> set_backward_tracepoint(result.grad_fn, f"b_lora_{active_adapter}") diff --git a/scripts/patch/peft_v_0_10_0/tuners_lora_layer.patch b/scripts/patch/peft_v_0_10_0/tuners_lora_layer.patch new file mode 100644 index 00000000..c5bf5405 --- /dev/null +++ b/scripts/patch/peft_v_0_10_0/tuners_lora_layer.patch @@ -0,0 +1,15 @@ +30a31,32 +> from mlora.profiler.profiler import nvtx_range, set_backward_tracepoint +> +497c499,501 +< result = self.base_layer(x, *args, **kwargs) +--- +> with nvtx_range("f_linear"): +> result = self.base_layer(x, *args, **kwargs) +> set_backward_tracepoint(result.grad_fn, "b_linear") +509c513,515 +< result = result + lora_B(lora_A(dropout(x))) * scaling +--- +> with nvtx_range(f"f_lora_{active_adapter}"): +> result = result + lora_B(lora_A(dropout(x))) * scaling +> set_backward_tracepoint(result.grad_fn, f"b_lora_{active_adapter}") diff --git a/scripts/patch/transformers_v_4_38_2/models_llama_modeling_llama.patch b/scripts/patch/transformers_v_4_38_2/models_llama_modeling_llama.patch new file mode 100644 index 00000000..ea6ef3c7 --- /dev/null +++ b/scripts/patch/transformers_v_4_38_2/models_llama_modeling_llama.patch @@ -0,0 +1,140 @@ +50a51,52 +> from 
mlora.profiler.profiler import nvtx_range, set_backward_tracepoint, grad_fn_nvtx_wrapper_by_tracepoint +> +361,362c363,369 +< cos, sin = self.rotary_emb(value_states, position_ids) +< query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) +--- +> with nvtx_range("f_get_cos_sin"): +> cos, sin = self.rotary_emb(value_states, position_ids) +> with nvtx_range("f_rotray_emb"): +> query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) +> set_backward_tracepoint(query_states.grad_fn, "b_q_rope") +> set_backward_tracepoint(key_states.grad_fn, "b_k_rope") +> +369a377 +> set_backward_tracepoint(key_states.grad_fn, "b_k_rep") +370a379 +> set_backward_tracepoint(value_states.grad_fn, "b_v_rep") +372,383c381,382 +< attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) +< +< if attention_mask is not None: # no matter the length, we just slice it +< causal_mask = attention_mask +< if cache_position is not None: +< causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] +< attn_weights = attn_weights + causal_mask +< +< # upcast attention to fp32 +< attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) +< attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) +< attn_output = torch.matmul(attn_weights, value_states) +--- +> with nvtx_range("f_attention"): +> attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) +385,389c384,399 +< if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): +< raise ValueError( +< f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" +< f" {attn_output.size()}" +< ) +--- +> if attention_mask is not None: # no matter the length, we just slice it +> causal_mask = attention_mask +> if cache_position is not None: +> causal_mask = attention_mask[:, :, cache_position, : key_states.shape[-2]] +> attn_weights = attn_weights + causal_mask +> +> # upcast attention to fp32 +> attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) +> attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) +> attn_output = torch.matmul(attn_weights, value_states) +> +> if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): +> raise ValueError( +> f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" +> f" {attn_output.size()}" +> ) +391c401 +< attn_output = attn_output.transpose(1, 2).contiguous() +--- +> attn_output = attn_output.transpose(1, 2).contiguous() +393c403,404 +< attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) +--- +> attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) +> set_backward_tracepoint(attn_output.grad_fn, "b_attention") +737c748,750 +< hidden_states = self.input_layernorm(hidden_states) +--- +> with nvtx_range("f_attention_norm"): +> hidden_states = self.input_layernorm(hidden_states) +> set_backward_tracepoint(hidden_states.grad_fn, "b_attention_norm") +750c763,765 +< hidden_states = residual + hidden_states +--- +> with nvtx_range("f_o_add"): +> hidden_states = residual + hidden_states +> set_backward_tracepoint(hidden_states.grad_fn, "b_o_add") +754,756c769,781 +< hidden_states = self.post_attention_layernorm(hidden_states) +< hidden_states = self.mlp(hidden_states) +< hidden_states = residual + hidden_states 
+--- +> with nvtx_range("f_ffn_norm"): +> hidden_states = self.post_attention_layernorm(hidden_states) +> set_backward_tracepoint(hidden_states.grad_fn, "b_ffn_norm") +> +> with nvtx_range("f_mlp"): +> hidden_states = self.mlp(hidden_states) +> set_backward_tracepoint(hidden_states.grad_fn, "b_mlp") +> +> with nvtx_range("f_mlp_add"): +> hidden_states = residual + hidden_states +> set_backward_tracepoint(hidden_states.grad_fn, "b_mlp_add") +> +> grad_fn_nvtx_wrapper_by_tracepoint(hidden_states.grad_fn) +977c1002,1003 +< inputs_embeds = self.embed_tokens(input_ids) +--- +> with nvtx_range("f_embedding"): +> inputs_embeds = self.embed_tokens(input_ids) +1017a1044 +> set_backward_tracepoint(layer_outputs[0].grad_fn, "b_checkpoint") +1037c1064,1066 +< hidden_states = self.norm(hidden_states) +--- +> with nvtx_range("f_rmsnorm"): +> hidden_states = self.norm(hidden_states) +> set_backward_tracepoint(hidden_states.grad_fn, "b_rmsnorm") +1195c1224,1226 +< logits = self.lm_head(hidden_states) +--- +> with nvtx_range("f_output"): +> logits = self.lm_head(hidden_states) +> set_backward_tracepoint(logits.grad_fn, "b_output") +1199,1209c1230,1241 +< if labels is not None: +< # Shift so that tokens < n predict n +< shift_logits = logits[..., :-1, :].contiguous() +< shift_labels = labels[..., 1:].contiguous() +< # Flatten the tokens +< loss_fct = CrossEntropyLoss() +< shift_logits = shift_logits.view(-1, self.config.vocab_size) +< shift_labels = shift_labels.view(-1) +< # Enable model parallelism +< shift_labels = shift_labels.to(shift_logits.device) +< loss = loss_fct(shift_logits, shift_labels) +--- +> with nvtx_range("f_calc_loss"): +> if labels is not None: +> # Shift so that tokens < n predict n +> shift_logits = logits[..., :-1, :].contiguous() +> shift_labels = labels[..., 1:].contiguous() +> # Flatten the tokens +> loss_fct = CrossEntropyLoss() +> shift_logits = shift_logits.view(-1, self.config.vocab_size) +> shift_labels = shift_labels.view(-1) +> # Enable model parallelism +> shift_labels = shift_labels.to(shift_logits.device) +> loss = loss_fct(shift_logits, shift_labels) diff --git a/scripts/performance.md b/scripts/performance.md new file mode 100644 index 00000000..fa3cdae4 --- /dev/null +++ b/scripts/performance.md @@ -0,0 +1,181 @@ +# Performance report + +## The sql to query the nsys + +```sql +CREATE TEMPORARY TABLE IF NOT EXISTS TEMP_KERN_INFOS AS +SELECT R.start AS API_START, R.end AS API_END, + K.start AS KERN_START, K.end AS KERN_END, + R.start AS T_START, + MAX(R.end, K.end) AS T_END, + KNAME.value AS KERN_NAME + FROM + CUPTI_ACTIVITY_KIND_KERNEL AS K + JOIN + CUPTI_ACTIVITY_KIND_RUNTIME AS R + ON K.correlationId == R.correlationId + LEFT JOIN + StringIds AS KNAME + ON KNAME.id == K.demangledName; + +CREATE INDEX IF NOT EXISTS TEMPINDEX ON TEMP_KERN_INFOS(T_START); + +SELECT + E_NAME AS ENVET_NAME, + API_START AS KERN_API_START, + API_END AS KERN_API_END, + KERN_START AS KERN_START, + KERN_END AS KERN_END, + KERN_NAME AS KERN_NAME +FROM + (SELECT start AS E_START, end AS E_END, text AS E_NAME FROM NVTX_EVENTS) +LEFT JOIN + TEMP_KERN_INFOS +ON + E_START <= T_START AND E_END >= T_START; +``` + +## The performace report about mlora's nvtx range + +### mlora metirc group + +1. f_embedding +2. f_attention_norm +3. f_linear = f_linear_wq + f_linear_wk +_ f_linear_wv + f_linear_wo +4. f_lora_{adapter} = f_dropout_{adapter} + f_lora_a_{adapter} + f_lora_b_{adapter} + f_scaling_{adapter} +5. f_lora_add_{adapter} +6. f_rotray_emb +7. f_attention +8. f_o_add +9. f_ffn_norm +10. 
f_mlp = f_w1 + f_w3 + f_silu_w2 + f_w1_m_w3 + f_w2 + f_w2_add +11. f_mlp_add = f_w2_add +12. f_rmsnorm +13. f_output +14. f_calc_loss +15. b_loss +16. b_output +17. b_rmsnorm +18. b_mlp = b_linear_w2 + b_w1_m_w3 + b_silu_w2 + b_linear_w3 + b_linear_w1 +19. b_ffn_norm +20. b_lora_{adapter} = b_scaling_{adapter} + b_lora_b_{adapter} + b_lora_a_{adapter} + b_dropout_{adapter} + b_lora_split_{adapter} +21. b_lora_add_{adapter} +22. b_linear +23. b_attention +24. b_k_rope +25. b_q_rope +26. b_attention_norm + +### mlora timeline + +1. f_embedding +2. f_attention_norm_{layer_id} +3. f_q_{layer_id} +4. f_linear_{name} +5. f_dropout_{adapter} +6. f_lora_a_{adapter} +7. f_lora_b_{adapter} +8. f_scaling_{adapter} +9. f_lora_add_({adapter}) +10. f_k_{layer_id} +11. f_v_{layer_id} +12. f_rotray_emb_{layer_id} +13. f_attention_{layer_id} +14. f_o_{layer_id} +15. f_o_add_{layer_id} +16. f_ffn_norm_{layer_id} +17. f_w1_{layer_id} +18. f_w3_{layer_id} +19. f_silu_w2_{layer_id} +20. f_w1_m_w3_{layer_id} +21. f_w2_{layer_id} +22. f_w2_add +23. f_rmsnorm +24. f_output +25. f_calc_loss +26. b_loss +27. b_output +28. b_rmsnorm +29. b_checkpoint +30. b_linear_w2_ +31. b_w1_m_w3_{layer_id} +32. b_silu_w2_{layer_id} +33. b_linear_w3_ +34. b_linear_w1_ +35. b_ffn_norm_{layer_id} +36. b_lora_add_{adapter} +37. b_scaling_{adapter} +38. b_lora_b_{adapter} +39. b_lora_a_{adapter} +40. b_dropout_{adapter} +41. b_lora_split_{adapter} +42. b_linear_wo_ +43. b_attention_{layer_id} +44. b_k_rope_{layer_id} +45. b_q_rope_{layer_id} +46. b_linear_wv_ +47. b_linear_wk_ +48. b_linear_wq_ +49. b_attention_norm_{layer_id} + +## The performace report about peft's nvtx range + +### peft metric group + +1. f_embedding +2. f_attention_norm +3. f_linear +4. f_lora_{adapter} +5. f_lora_add_{adapter} +6. f_get_cos_sin +7. f_rotray_emb +8. f_attention +9. f_o_add +10. f_ffn_norm +11. f_mlp +12. f_mlp_add +13. f_rmsnorm +14. f_output +15. f_calc_loss +16. b_loss +17. b_output +18. b_rmsnorm +19. b_mlp +20. b_ffn_norm +21. b_lora_{adapter} +22. b_lora_add_{adpter} +23. b_linear +24. b_attention +25. b_k_rope +26. b_q_rope +27. b_attention_norm + +### peft timeline + +1. f_embedding +2. f_attention_norm +3. f_q (f_linear + f_lora_{adapter} + f_lora_add_{adapter}) +4. f_k +5. f_v +6. f_get_cos_sin +7. f_rotray_emb +8. f_attention +9. f_o +10. f_o_add +11. f_ffn_norm +12. f_mlp +13. f_mlp_add +14. f_rmsnorm +15. f_output +16. f_calc_loss +17. b_loss +18. b_output +19. b_rmsnorm +20. b_checkpoint <- same with transformer block forward +21. b_mlp +22. b_ffn_norm +23. b_lora_{adapter} +24. b_linear +25. b_k_rope +26. b_q_rope +27. 
b_attention_norm diff --git a/scripts/performance_report.py b/scripts/performance_report.py new file mode 100644 index 00000000..5a371ef2 --- /dev/null +++ b/scripts/performance_report.py @@ -0,0 +1,238 @@ +import argparse +import sqlite3 +import logging +import re +import csv + +from typing import List, Dict, Tuple + +# Command Line Arguments +parser = argparse.ArgumentParser(description='Performance report.') +parser.add_argument('--db', type=str, required=True, help='NSys sqlite file.') +parser.add_argument('--output', type=str, required=True, + help='Export csv file.') + +args = parser.parse_args() + +logging.basicConfig(format="[%(asctime)s] Performance: %(message)s", + level="INFO", + handlers=[logging.StreamHandler()], + force=True) + +G_CREATE_TEMP_TABLE = """ +CREATE TEMPORARY TABLE IF NOT EXISTS TEMP_KERN_INFOS AS +SELECT R.start AS API_START, R.end AS API_END, + K.start AS KERN_START, K.end AS KERN_END, + R.start AS T_START, + MAX(R.end, K.end) AS T_END, + KNAME.value AS KERN_NAME + FROM + CUPTI_ACTIVITY_KIND_KERNEL AS K + JOIN + CUPTI_ACTIVITY_KIND_RUNTIME AS R + ON K.correlationId == R.correlationId + LEFT JOIN + StringIds AS KNAME + ON KNAME.id == K.demangledName; +""" + +G_CREATE_INDEX = """ +CREATE INDEX IF NOT EXISTS TEMPINDEX ON TEMP_KERN_INFOS(T_START); +""" + +G_SELECT_NVTX_KERN = """ +SELECT + E_NAME AS ENVET_NAME, + API_START AS KERN_API_START, + API_END AS KERN_API_END, + KERN_START AS KERN_START, + KERN_END AS KERN_END, + KERN_NAME AS KERN_NAME +FROM + (SELECT start AS E_START, end AS E_END, text AS E_NAME FROM NVTX_EVENTS) +LEFT JOIN + TEMP_KERN_INFOS +ON + E_START <= T_START AND E_END >= T_START; +""" + +G_SELECT_NVTX = """ +SELECT start AS E_START, end AS E_END, text AS E_NAME FROM NVTX_EVENTS; +""" + + +class KernInfo: + name_: str = "" + api_start_: int = 0 + api_end_: int = 0 + kern_start_: int = 0 + kern_end_: int = 0 + + api_time_: int = 0 + kern_time_: int = 0 + queue_time_: int = 0 + + def __init__(self, name: str, + api_start: int, api_end: int, + kern_start: int, kern_end: int): + self.name_ = name + self.api_start_ = api_start + self.api_end_ = api_end + self.kern_start_ = kern_start + self.kern_end_ = kern_end + + self.queue_time_ = 0 if api_end > kern_start else kern_start - api_end + self.api_time_ = api_end - api_start + self.kern_time_ = kern_end - kern_start + + assert self.queue_time_ >= 0 + + +class KernSummary: + name_: str = "" + api_time_: int = 0 + queue_time_: int = 0 + kern_time_: int = 0 + + kern_cnt_: int = 0 + + def __init__(self, name: str, + api_time: int, queue_time: int, kern_time: int, + kern_cnt: int) -> None: + self.name_ = name + self.api_time_ = api_time + self.queue_time_ = queue_time + self.kern_time_ = kern_time + self.kern_cnt_ = kern_cnt + + +class EventInfo: + name_: str = "" + cnt_: int = 0 + ttime_: int = 0 + # all kernels + kerns_: Dict[str, List[KernInfo]] = {} + + def __init__(self, name: str): + self.name_ = name + self.cnt_ = 0 + self.kerns_ = {} + + def add_kern(self, name: str, + api_start: int, api_end: int, + kern_start: int, kern_end: int): + if api_start is None or api_end is None: + return + assert kern_start is not None + assert kern_end is not None + if name not in self.kerns_: + self.kerns_[name] = [] + self.kerns_[name].append( + KernInfo(name, api_start, api_end, kern_start, kern_end)) + + def sum(self) -> Dict[str, KernSummary]: + summary_ret: Dict[str, KernSummary] = {} + + def sum_kern_list(kerns: List[KernInfo]) -> Tuple[int, int, int]: + t_api_time: int = 0 + t_kern_time: int = 0 + t_queue_time: 
int = 0 + for kern_item in kerns: + t_api_time += kern_item.api_time_ + t_kern_time += kern_item.kern_time_ + t_queue_time += kern_item.queue_time_ + return t_api_time, t_queue_time, t_kern_time + + for kern_name, kern_info in self.kerns_.items(): + t_api_time, t_queue_time, t_kern_time = sum_kern_list(kern_info) + summary_ret[kern_name] = KernSummary( + kern_name, t_api_time, t_queue_time, t_kern_time, len(kern_info)) + + return summary_ret + + +class EventSum: + name_: str = "" + cnt_: int = 0 + + api_time_: int = 0 + queue_time_: int = 0 + kern_time_: int = 0 + + ttime_: int = 0 + + def __init__(self, event: EventInfo) -> None: + self.name_ = event.name_ + self.cnt_ = event.cnt_ + + self.ttime_ = event.ttime_ + + self.__summary(event) + + def __summary(self, event: EventInfo): + t_api_time: int = 0 + t_queue_time: int = 0 + t_kern_name: int = 0 + + for _, kern_sum in event.sum().items(): + t_api_time += kern_sum.api_time_ + t_queue_time += kern_sum.queue_time_ + t_kern_name += kern_sum.kern_time_ + + self.api_time_ = t_api_time + self.queue_time_ = t_queue_time + self.kern_time_ = t_kern_name + + def avg(self) -> Tuple[int, int, int]: + aapi = round(self.api_time_ / self.cnt_) + aqueue = round(self.queue_time_ / self.cnt_) + akern = round(self.kern_time_ / self.cnt_) + return aapi, aqueue, akern + + def sum(self) -> Tuple[int, int, int, int]: + return self.ttime_, self.api_time_, self.queue_time_, self.kern_time_, self.cnt_ + + def __str__(self) -> str: + api, queue, kern = self.avg() + return f"{api} {queue} {kern}" + + +if __name__ == "__main__": + conn = sqlite3.connect(args.db) + + logging.info("To Init the sqlite file.") + + conn.execute(G_CREATE_TEMP_TABLE) + conn.execute(G_CREATE_INDEX) + + logging.info("Init the sqlite file done.") + + events: Dict[str, EventInfo] = {} + + logging.info("To count the NVTX event.") + for row in conn.execute(G_SELECT_NVTX): + start_time, end_time, event_name = row + + if event_name not in events: + events[event_name] = EventInfo(event_name) + event_item = events[event_name] + event_item.cnt_ += 1 + event_item.ttime_ += (end_time - start_time) + logging.info("Count the NVTX event done.") + + logging.info("To get kernel info.") + for row in conn.execute(G_SELECT_NVTX_KERN): + event_name, api_start, api_end, kern_start, kern_end, kern_name = row + event_item = events[event_name] + event_item.add_kern(kern_name, api_start, + api_end, kern_start, kern_end) + logging.info("Get kernel info done.") + + with open(args.output, "w") as csv_f: + writer = csv.writer(csv_f) + for event_name, event_item in events.items(): + event_sum = EventSum(event_item).sum() + if event_sum[0] == 0: + continue + event_sum = (event_name,) + event_sum + writer.writerow(event_sum)
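
The `setup_input()` helper in `benchmarks/bench_mlora.py` packs the rows for all adapters into one batch and records, per adapter, the half-open `[batch_start_idx_, batch_end_idx_)` slice it owns. A toy sketch of that index bookkeeping, with made-up counts and no mLoRA dependency:

```python
# Illustration of setup_input()'s index bookkeeping (made-up sizes):
# with lora_cnt=2 and batch_size=3, rows 0..2 belong to lora_0 and rows 3..5 to lora_1.
slices = []
start_idx = 0
for lora_idx in range(2):            # stands in for args.lora_cnt
    end_idx = start_idx + 3          # args.batch_size rows appended for this adapter
    slices.append((f"lora_{lora_idx}", start_idx, end_idx))
    start_idx = end_idx

assert slices == [("lora_0", 0, 3), ("lora_1", 3, 6)]
```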
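
`calc_loss()` then computes a shifted next-token cross-entropy over each adapter's slice and sums the per-adapter losses. A minimal stand-alone illustration of the shift, using toy shapes:

```python
import torch

vocab_size, seq_len = 11, 5
logits = torch.randn(2, seq_len, vocab_size)         # model output for one adapter's slice
labels = torch.randint(1, vocab_size, (2, seq_len))  # the input tokens double as labels

loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits[..., :-1, :].reshape(-1, vocab_size),  # positions 0..seq_len-2 predict ...
               labels[..., 1:].reshape(-1))                   # ... tokens 1..seq_len-1
print(loss.item())
```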
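
The benchmark's profiling flow is: warm up, call `setup_trace_mode()`, wrap forward work in `nvtx_range`, tag the loss's `grad_fn` with `set_backward_tracepoint`, and let `grad_fn_nvtx_wrapper_by_tracepoint` wrap the tagged autograd nodes before `backward()`. Below is a minimal sketch of the same pattern on a toy graph, assuming the `mlora.profiler.profiler` helpers behave as they are called in `bench_mlora.py`; the NVTX ranges only show up when the script runs under Nsight Systems:

```python
import torch
from mlora.profiler.profiler import (setup_trace_mode, set_backward_tracepoint,
                                      grad_fn_nvtx_wrapper_by_tracepoint, nvtx_range)

setup_trace_mode()                                   # enable tracing, as done after warmup

x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(8, 8, requires_grad=True)

with nvtx_range("f_toy_matmul"):                     # forward range in the nsys timeline
    loss = (x @ w).relu().sum()

set_backward_tracepoint(loss.grad_fn, "b_toy_matmul")    # tag the autograd node
grad_fn_nvtx_wrapper_by_tracepoint(loss.grad_fn)         # wrap tagged nodes with NVTX ranges
loss.backward()
```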
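
In `scripts/performance_report.py`, `KernInfo` splits every kernel launch into three durations: CPU time spent inside the CUDA runtime call (`api_time_`), the gap between the launch returning and the kernel starting (`queue_time_`, clamped at zero), and GPU execution time (`kern_time_`). A worked example with made-up timestamps:

```python
# Made-up timestamps (nanoseconds in the nsys trace) for one kernel launch:
api_start, api_end = 1_000, 1_400      # CPU-side CUDA runtime call
kern_start, kern_end = 1_900, 2_600    # GPU-side kernel execution

api_time = api_end - api_start             # 400: time inside the launch call
queue_time = max(0, kern_start - api_end)  # 500: launch returned before the kernel ran
kern_time = kern_end - kern_start          # 700: actual GPU execution time

assert (api_time, queue_time, kern_time) == (400, 500, 700)
```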
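
Each row the script writes to the CSV is the event name followed by `EventSum.sum()`, i.e. total NVTX time, API time, queue time, kernel time, and call count (five values, although the type hint in the diff names four). A hypothetical consumer of that file, assuming it was written as `report.csv`:

```python
import csv

with open("report.csv") as csv_f:                    # the path given to --output
    for name, ttime, api_time, queue_time, kern_time, cnt in csv.reader(csv_f):
        cnt = int(cnt)
        print(f"{name}: {int(kern_time) / cnt:.0f} ns of kernel time per call ({cnt} calls)")
```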