From 32e630220d2c24550a869f43ce618949ccf1a1a5 Mon Sep 17 00:00:00 2001
From: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
Date: Tue, 16 Apr 2024 01:17:16 -0400
Subject: [PATCH] Updates for TRT-LLM 0.9 (#8873)

* upgrade to trtllm0.9

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* Update gpt to config based export

Signed-off-by: Onur Yilmaz

* fix for lora checkpoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

* fix for in flight batching case

* Update falcon for trt-llm 0.9

Signed-off-by: Onur Yilmaz

* Removed unused import and comment

Signed-off-by: Onur Yilmaz

---------

Signed-off-by: Onur Yilmaz
Co-authored-by: abharwani
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 nemo/export/trt_llm/decoder/falcon.py     |  6 +--
 nemo/export/trt_llm/decoder/gpt.py        | 46 +++++++++++++----------
 nemo/export/trt_llm/decoder/llama.py      |  6 +--
 nemo/export/trt_llm/tensorrt_llm_build.py |  4 ++
 nemo/export/trt_llm/tensorrt_llm_model.py | 18 +++------
 nemo/export/trt_llm/tensorrt_llm_run.py   |  5 ++-
 6 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/nemo/export/trt_llm/decoder/falcon.py b/nemo/export/trt_llm/decoder/falcon.py
index b0e69d2b99c4..91edc7794607 100644
--- a/nemo/export/trt_llm/decoder/falcon.py
+++ b/nemo/export/trt_llm/decoder/falcon.py
@@ -17,8 +17,7 @@
 
 from tensorrt_llm.functional import non_gated_version
 from tensorrt_llm.models.falcon.model import FalconDecoderLayer
-from tensorrt_llm.models.modeling_utils import PretrainedConfig
-from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
 from typing_extensions import override
 
 from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -119,8 +118,7 @@ def build_decoder(self, layer):
             world_size=self.tensor_parallel,
             tp_size=self.tensor_parallel,
             pp_size=1,
-            quant_mode=QuantMode(0),
-            quant_kwargs=None,
+            quantization=QuantConfig(),
             max_lora_rank=layer.max_lora_rank,
             use_parallel_embedding=False,
         )
diff --git a/nemo/export/trt_llm/decoder/gpt.py b/nemo/export/trt_llm/decoder/gpt.py
index 294ccb737c1f..8af4e4ef01e4 100644
--- a/nemo/export/trt_llm/decoder/gpt.py
+++ b/nemo/export/trt_llm/decoder/gpt.py
@@ -17,6 +17,7 @@
 
 from tensorrt_llm.layers import AttentionMaskType, PositionEmbeddingType
 from tensorrt_llm.models.gpt.model import GPTDecoderLayer
+from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
 from typing_extensions import override
 
 from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -85,13 +86,10 @@ class GPTDecoderLayerBuilder(DecoderLayerBuilder):
 
     @override
     def build_decoder(self, layer):
         rotary_pct = layer.rotary_pct
-        position_embedding_type = (
-            PositionEmbeddingType.rope_gpt_neox
-            if layer.position_embedding_type == "rope"
-            else PositionEmbeddingType.learned_absolute
-        )
 
-        assert not (position_embedding_type == PositionEmbeddingType.rope_gpt_neox and rotary_pct == 0.0)
+        position_embedding_type = "rope_gpt_neox" if layer.position_embedding_type == "rope" else "learned_absolute"
+
+        assert not (position_embedding_type == "rope_gpt_neox" and rotary_pct == 0.0)
 
         bias_qkv = layer.attention.qkv.bias is not None
@@ -99,23 +97,33 @@ def build_decoder(self, layer):
         if layer.rotary_scaling is not None:
             rotary_scaling = {"type": "linear", "factor": float(layer.rotary_scaling)}
 
-        return GPTDecoderLayer(
+        config = PretrainedConfig(
+            architecture=None,
+            dtype=self.dtype,
+            logits_dtype=self.dtype,
+            vocab_size=layer.vocab_size,
+            max_position_embeddings=self.max_position_embeddings,
             hidden_size=self.hidden_size,
+            num_hidden_layers=self.num_layers,
             num_attention_heads=self.num_attention_heads,
-            max_position_embeddings=self.max_position_embeddings,
-            num_layers=self.num_layers,
-            dtype=self.dtype,
-            apply_query_key_layer_scaling=False,
-            attention_mask_type=AttentionMaskType.causal,
+            num_key_value_heads=self.num_kv_heads,
             hidden_act=self.hidden_act,
+            intermediate_size=layer.ffn_hidden_size_local * self.tensor_parallel,
+            norm_epsilon=layer.norm_epsilon,
             position_embedding_type=position_embedding_type,
-            rotary_embedding_percentage=rotary_pct,
-            rotary_base=layer.rotary_base,
-            rotary_scaling=rotary_scaling,
-            inter_size=layer.ffn_hidden_size_local * self.tensor_parallel,
-            bias=bias_qkv,
-            num_kv_heads=self.num_kv_heads,
-            tp_group=self.tp_group,
+            world_size=self.tensor_parallel,
             tp_size=self.tensor_parallel,
+            pp_size=1,
             max_lora_rank=layer.max_lora_rank,
+            quantization=QuantConfig(),
         )
+
+        config.set_if_not_exist('hidden_act', self.hidden_act)
+        config.set_if_not_exist('apply_query_key_layer_scaling', False)
+        config.set_if_not_exist('bias', bias_qkv)
+        config.set_if_not_exist('rotary_base', layer.rotary_base)
+        config.set_if_not_exist('rotary_scaling', rotary_scaling)
+        config.set_if_not_exist('rotary_pct', rotary_pct)
+        config.set_if_not_exist('moe_num_experts', 0)
+
+        return GPTDecoderLayer(config=config, layer_idx=self.layer_id,)
diff --git a/nemo/export/trt_llm/decoder/llama.py b/nemo/export/trt_llm/decoder/llama.py
index e554e18608f7..873c0306375b 100644
--- a/nemo/export/trt_llm/decoder/llama.py
+++ b/nemo/export/trt_llm/decoder/llama.py
@@ -18,8 +18,7 @@
 from tensorrt_llm.functional import non_gated_version
 from tensorrt_llm.layers import MoeConfig
 from tensorrt_llm.models.llama.model import LLaMADecoderLayer
-from tensorrt_llm.models.modeling_utils import PretrainedConfig
-from tensorrt_llm.quantization import QuantMode
+from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
 from typing_extensions import override
 
 from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -118,9 +117,8 @@ def build_decoder(self, layer):
             world_size=self.tensor_parallel,
             tp_size=self.tensor_parallel,
             pp_size=1,
-            quant_mode=QuantMode(0),
-            quant_kwargs=None,
             max_lora_rank=layer.max_lora_rank,
+            quantization=QuantConfig(),
         )
 
         config.set_if_not_exist('mlp_bias', False)
diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py
index 0941a6d3dbba..3ad27a2eb9a6 100644
--- a/nemo/export/trt_llm/tensorrt_llm_build.py
+++ b/nemo/export/trt_llm/tensorrt_llm_build.py
@@ -27,6 +27,7 @@
 from tensorrt_llm._utils import np_dtype_to_trt
 from tensorrt_llm.builder import Builder
 from tensorrt_llm.logger import logger
+from tensorrt_llm.models.modeling_utils import add_lora
 from tensorrt_llm.network import net_guard
 from tensorrt_llm.plugin.plugin import ContextFMHAType
 from tensorrt_llm.quantization import QuantMode
@@ -170,6 +171,9 @@ def _build_impl(tensorrt_llm_model, args):
     timing_cache_file = args.timing_cache if args.timing_cache else args.output_dir / "model.cache"
     timing_cache = timing_cache_file
 
+    if args.use_lora_plugin is not None:
+        add_lora(tensorrt_llm_model, args.max_lora_rank)
+
     builder = Builder()
     apply_query_key_layer_scaling = False
 
diff --git a/nemo/export/trt_llm/tensorrt_llm_model.py b/nemo/export/trt_llm/tensorrt_llm_model.py
index b2da7855ccdc..52e9c4960fc9 100644
--- a/nemo/export/trt_llm/tensorrt_llm_model.py
+++ b/nemo/export/trt_llm/tensorrt_llm_model.py
@@ -144,15 +144,7 @@ def forward(
         if attention_mask is not None:
             attention_mask = expand_mask(attention_mask, shape(input_ids, -1))
 
-        for layer_idx, (layer, past, pointer, host_pointer, max_attention_window_size) in enumerate(
-            zip(
-                self.layers,
-                kv_cache_params.past_key_value,
-                kv_cache_params.kv_cache_block_pointers,
-                kv_cache_params.host_kv_cache_block_pointers,
-                kv_cache_params.host_max_attention_window_sizes,
-            )
-        ):
+        for layer_idx, (layer, past) in enumerate(zip(self.layers, kv_cache_params.past_key_value,)):
 
             decoder_params = {
                 "hidden_states": hidden_states,
@@ -161,8 +153,8 @@
                 "kv_cache_params": KeyValueCacheParams(
                     past_key_value=[past],
                     host_past_key_value_lengths=kv_cache_params.host_past_key_value_lengths,
-                    kv_cache_block_pointers=[pointer],
-                    host_max_attention_window_sizes=max_attention_window_size,
+                    kv_cache_block_pointers=kv_cache_params.kv_cache_block_pointers,
+                    host_max_attention_window_sizes=kv_cache_params.host_max_attention_window_sizes,
                     cache_indirection=kv_cache_params.cache_indirection,
                     host_sink_token_length=kv_cache_params.host_sink_token_length,
                     host_kv_cache_block_pointers=kv_cache_params.host_kv_cache_block_pointers,
@@ -329,8 +321,8 @@ def prepare_inputs(
                 past_key_value=model_inputs['past_key_value'],
                 host_past_key_value_lengths=model_inputs['host_past_key_value_lengths'],
                 host_max_attention_window_sizes=model_inputs['host_max_attention_window_sizes'],
-                kv_cache_block_pointers=model_inputs['kv_cache_block_pointers_list'],
-                host_kv_cache_block_pointers=model_inputs['host_kv_cache_block_pointers_list'],
+                kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'],
+                host_kv_cache_block_pointers=model_inputs['host_kv_cache_block_pointers'],
                 cache_indirection=model_inputs['cache_indirection'],
                 host_sink_token_length=model_inputs['host_sink_token_length'],
             ),
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index cdc0b78d6c18..1e24f4f207a4 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -24,12 +24,14 @@
 import torch
 from mpi4py.futures import MPIPoolExecutor
 from tensorrt_llm.logger import logger
+from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.quantization import QuantMode
-from tensorrt_llm.runtime import LoraManager, ModelConfig, SamplingConfig
+from tensorrt_llm.runtime import ModelConfig, SamplingConfig
 from transformers import PreTrainedTokenizer
 
 from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group
 from nemo.export.trt_llm.tensorrt_llm_model import LMHeadModelBuilder
+
 from nemo.export.trt_llm.tensorrt_llm_build import get_engine_name, MODEL_NAME, refit_runtime_engine  # isort:skip
 
 from nemo.export.trt_llm.nemo_utils import to_word_list_format  # isort:skip
@@ -90,6 +92,7 @@ def _read_config(config_path: Path):
     model_config = ModelConfig(
         model_name=config["builder_config"]["name"],
         max_batch_size=config["builder_config"]["max_batch_size"],
+        max_beam_width=config["builder_config"]["max_beam_width"],
         vocab_size=config["builder_config"]["vocab_size"],
         num_layers=config["builder_config"]["num_layers"],
         num_heads=num_heads,