Commit

Updates for TRT-LLM 0.9 (#8873)
* Upgrade to TRT-LLM 0.9

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update GPT to config-based export

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* Fix for LoRA checkpoint

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix for the in-flight batching case

* Update Falcon for TRT-LLM 0.9

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

* Remove unused import and comment

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>

---------

Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
Co-authored-by: abharwani <abharwani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
3 people committed Apr 16, 2024
1 parent e9d8266 commit 32e6302
Showing 6 changed files with 44 additions and 41 deletions.
6 changes: 2 additions & 4 deletions nemo/export/trt_llm/decoder/falcon.py
@@ -17,8 +17,7 @@

from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.models.falcon.model import FalconDecoderLayer
from tensorrt_llm.models.modeling_utils import PretrainedConfig
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
from typing_extensions import override

from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -119,8 +118,7 @@ def build_decoder(self, layer):
world_size=self.tensor_parallel,
tp_size=self.tensor_parallel,
pp_size=1,
quant_mode=QuantMode(0),
quant_kwargs=None,
quantization=QuantConfig(),
max_lora_rank=layer.max_lora_rank,
use_parallel_embedding=False,
)
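
For anyone porting their own export code, the hunk above (and the identical hunk in llama.py further down) replaces the pre-0.9 quant_mode/quant_kwargs pair with a single QuantConfig object. A minimal sketch of the substitution, using only names that appear in this diff:

from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization import QuantMode

# TRT-LLM 0.8 style: an empty QuantMode bitmask plus a separate quant_kwargs dict.
legacy_quant_mode = QuantMode(0)

# TRT-LLM 0.9 style: a default QuantConfig means "not quantized" and is passed
# to PretrainedConfig as quantization=QuantConfig().
unquantized = QuantConfig()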
46 changes: 27 additions & 19 deletions nemo/export/trt_llm/decoder/gpt.py
@@ -17,6 +17,7 @@

from tensorrt_llm.layers import AttentionMaskType, PositionEmbeddingType
from tensorrt_llm.models.gpt.model import GPTDecoderLayer
from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
from typing_extensions import override

from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -85,37 +86,44 @@ class GPTDecoderLayerBuilder(DecoderLayerBuilder):
@override
def build_decoder(self, layer):
rotary_pct = layer.rotary_pct
position_embedding_type = (
PositionEmbeddingType.rope_gpt_neox
if layer.position_embedding_type == "rope"
else PositionEmbeddingType.learned_absolute
)

assert not (position_embedding_type == PositionEmbeddingType.rope_gpt_neox and rotary_pct == 0.0)
position_embedding_type = "rope_gpt_neox" if layer.position_embedding_type == "rope" else "learned_absolute"

assert not (position_embedding_type == "rope_gpt_neox" and rotary_pct == 0.0)

bias_qkv = layer.attention.qkv.bias is not None

rotary_scaling = None
if layer.rotary_scaling is not None:
rotary_scaling = {"type": "linear", "factor": float(layer.rotary_scaling)}

return GPTDecoderLayer(
config = PretrainedConfig(
architecture=None,
dtype=self.dtype,
logits_dtype=self.dtype,
vocab_size=layer.vocab_size,
max_position_embeddings=self.max_position_embeddings,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_layers,
num_attention_heads=self.num_attention_heads,
max_position_embeddings=self.max_position_embeddings,
num_layers=self.num_layers,
dtype=self.dtype,
apply_query_key_layer_scaling=False,
attention_mask_type=AttentionMaskType.causal,
num_key_value_heads=self.num_kv_heads,
hidden_act=self.hidden_act,
intermediate_size=layer.ffn_hidden_size_local * self.tensor_parallel,
norm_epsilon=layer.norm_epsilon,
position_embedding_type=position_embedding_type,
rotary_embedding_percentage=rotary_pct,
rotary_base=layer.rotary_base,
rotary_scaling=rotary_scaling,
inter_size=layer.ffn_hidden_size_local * self.tensor_parallel,
bias=bias_qkv,
num_kv_heads=self.num_kv_heads,
tp_group=self.tp_group,
world_size=self.tensor_parallel,
tp_size=self.tensor_parallel,
pp_size=1,
max_lora_rank=layer.max_lora_rank,
quantization=QuantConfig(),
)

config.set_if_not_exist('hidden_act', self.hidden_act)
config.set_if_not_exist('apply_query_key_layer_scaling', False)
config.set_if_not_exist('bias', bias_qkv)
config.set_if_not_exist('rotary_base', layer.rotary_base)
config.set_if_not_exist('rotary_scaling', rotary_scaling)
config.set_if_not_exist('rotary_pct', rotary_pct)
config.set_if_not_exist('moe_num_experts', 0)

return GPTDecoderLayer(config=config, layer_idx=self.layer_id,)
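
With this change the GPT builder follows the same config-based export flow as the other decoders: build a PretrainedConfig first, backfill optional attributes with set_if_not_exist(), and only then construct the layer from the config. A condensed sketch of that flow is below; the keyword names are taken from the hunk above, while the values are illustrative placeholders rather than real GPT settings.

from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig

# Placeholder values for illustration; the real builder fills these from the
# NeMo checkpoint inside build_decoder() above.
config = PretrainedConfig(
    architecture=None,
    dtype="float16",
    logits_dtype="float16",
    vocab_size=50257,
    max_position_embeddings=2048,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    num_key_value_heads=12,
    hidden_act="gelu",
    intermediate_size=3072,
    norm_epsilon=1e-5,
    position_embedding_type="learned_absolute",
    world_size=1,
    tp_size=1,
    pp_size=1,
    quantization=QuantConfig(),
)

# Optional attributes are set only when the checkpoint did not provide them.
config.set_if_not_exist("bias", True)
config.set_if_not_exist("apply_query_key_layer_scaling", False)
config.set_if_not_exist("moe_num_experts", 0)

# The builder then instantiates the layer from the finished config:
# GPTDecoderLayer(config=config, layer_idx=self.layer_id)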
6 changes: 2 additions & 4 deletions nemo/export/trt_llm/decoder/llama.py
@@ -18,8 +18,7 @@
from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import MoeConfig
from tensorrt_llm.models.llama.model import LLaMADecoderLayer
from tensorrt_llm.models.modeling_utils import PretrainedConfig
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.models.modeling_utils import PretrainedConfig, QuantConfig
from typing_extensions import override

from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -118,9 +117,8 @@ def build_decoder(self, layer):
world_size=self.tensor_parallel,
tp_size=self.tensor_parallel,
pp_size=1,
quant_mode=QuantMode(0),
quant_kwargs=None,
max_lora_rank=layer.max_lora_rank,
quantization=QuantConfig(),
)

config.set_if_not_exist('mlp_bias', False)
4 changes: 4 additions & 0 deletions nemo/export/trt_llm/tensorrt_llm_build.py
@@ -27,6 +27,7 @@
from tensorrt_llm._utils import np_dtype_to_trt
from tensorrt_llm.builder import Builder
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import add_lora
from tensorrt_llm.network import net_guard
from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.quantization import QuantMode
@@ -170,6 +171,9 @@ def _build_impl(tensorrt_llm_model, args):
timing_cache_file = args.timing_cache if args.timing_cache else args.output_dir / "model.cache"
timing_cache = timing_cache_file

if args.use_lora_plugin is not None:
add_lora(tensorrt_llm_model, args.max_lora_rank)

builder = Builder()
apply_query_key_layer_scaling = False

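
The functional change in this file is the new LoRA hook: when a LoRA plugin is requested, the exporter calls TRT-LLM 0.9's add_lora() on the model before the engine is built. A small wrapper sketch of that guard; the call signature mirrors the diff, everything else is illustrative:

from tensorrt_llm.models.modeling_utils import add_lora

def maybe_add_lora(tensorrt_llm_model, use_lora_plugin, max_lora_rank):
    # Attach LoRA modules only when a LoRA plugin was requested, mirroring the
    # guard added to _build_impl above.
    if use_lora_plugin is not None:
        add_lora(tensorrt_llm_model, max_lora_rank)
    return tensorrt_llm_model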
18 changes: 5 additions & 13 deletions nemo/export/trt_llm/tensorrt_llm_model.py
Expand Up @@ -144,15 +144,7 @@ def forward(
if attention_mask is not None:
attention_mask = expand_mask(attention_mask, shape(input_ids, -1))

for layer_idx, (layer, past, pointer, host_pointer, max_attention_window_size) in enumerate(
zip(
self.layers,
kv_cache_params.past_key_value,
kv_cache_params.kv_cache_block_pointers,
kv_cache_params.host_kv_cache_block_pointers,
kv_cache_params.host_max_attention_window_sizes,
)
):
for layer_idx, (layer, past) in enumerate(zip(self.layers, kv_cache_params.past_key_value,)):

decoder_params = {
"hidden_states": hidden_states,
@@ -161,8 +153,8 @@
"kv_cache_params": KeyValueCacheParams(
past_key_value=[past],
host_past_key_value_lengths=kv_cache_params.host_past_key_value_lengths,
kv_cache_block_pointers=[pointer],
host_max_attention_window_sizes=max_attention_window_size,
kv_cache_block_pointers=kv_cache_params.kv_cache_block_pointers,
host_max_attention_window_sizes=kv_cache_params.host_max_attention_window_sizes,
cache_indirection=kv_cache_params.cache_indirection,
host_sink_token_length=kv_cache_params.host_sink_token_length,
host_kv_cache_block_pointers=kv_cache_params.host_kv_cache_block_pointers,
@@ -329,8 +321,8 @@ def prepare_inputs(
past_key_value=model_inputs['past_key_value'],
host_past_key_value_lengths=model_inputs['host_past_key_value_lengths'],
host_max_attention_window_sizes=model_inputs['host_max_attention_window_sizes'],
kv_cache_block_pointers=model_inputs['kv_cache_block_pointers_list'],
host_kv_cache_block_pointers=model_inputs['host_kv_cache_block_pointers_list'],
kv_cache_block_pointers=model_inputs['kv_cache_block_pointers'],
host_kv_cache_block_pointers=model_inputs['host_kv_cache_block_pointers'],
cache_indirection=model_inputs['cache_indirection'],
host_sink_token_length=model_inputs['host_sink_token_length'],
),
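
The in-flight batching fix above stops slicing the paged-KV block pointers per layer: each decoder layer still gets its own past-KV tensor, but the block-pointer and attention-window tensors are now forwarded whole. A sketch of the per-layer wiring; the keyword names come from the diff, while the import path for KeyValueCacheParams is an assumption about the TRT-LLM 0.9 layout:

from tensorrt_llm.layers import KeyValueCacheParams

def per_layer_kv_params(kv_cache_params):
    # One KeyValueCacheParams per decoder layer: only past_key_value is split
    # out; block pointers and max-attention-window sizes are passed unsliced.
    return [
        KeyValueCacheParams(
            past_key_value=[past],
            host_past_key_value_lengths=kv_cache_params.host_past_key_value_lengths,
            kv_cache_block_pointers=kv_cache_params.kv_cache_block_pointers,
            host_kv_cache_block_pointers=kv_cache_params.host_kv_cache_block_pointers,
            host_max_attention_window_sizes=kv_cache_params.host_max_attention_window_sizes,
            cache_indirection=kv_cache_params.cache_indirection,
            host_sink_token_length=kv_cache_params.host_sink_token_length,
        )
        for past in kv_cache_params.past_key_value
    ]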
5 changes: 4 additions & 1 deletion nemo/export/trt_llm/tensorrt_llm_run.py
@@ -24,12 +24,14 @@
import torch
from mpi4py.futures import MPIPoolExecutor
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_manager import LoraManager
from tensorrt_llm.quantization import QuantMode
from tensorrt_llm.runtime import LoraManager, ModelConfig, SamplingConfig
from tensorrt_llm.runtime import ModelConfig, SamplingConfig
from transformers import PreTrainedTokenizer

from nemo.export.trt_llm.tensor_utils import get_tensor_parallel_group
from nemo.export.trt_llm.tensorrt_llm_model import LMHeadModelBuilder

from nemo.export.trt_llm.tensorrt_llm_build import get_engine_name, MODEL_NAME, refit_runtime_engine # isort:skip
from nemo.export.trt_llm.nemo_utils import to_word_list_format # isort:skip

@@ -90,6 +92,7 @@ def _read_config(config_path: Path):
model_config = ModelConfig(
model_name=config["builder_config"]["name"],
max_batch_size=config["builder_config"]["max_batch_size"],
max_beam_width=config["builder_config"]["max_beam_width"],
vocab_size=config["builder_config"]["vocab_size"],
num_layers=config["builder_config"]["num_layers"],
num_heads=num_heads,
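
Two runtime updates land here: LoraManager is now imported from tensorrt_llm.lora_manager instead of tensorrt_llm.runtime, and _read_config forwards max_beam_width from the engine's builder config into ModelConfig. A stdlib-only helper sketch showing where that value lives in the serialized config; the helper name and the config.json location are assumptions, not part of the diff:

import json
from pathlib import Path

def read_max_beam_width(engine_dir: str) -> int:
    # The serialized builder settings live under "builder_config"; TRT-LLM 0.9's
    # ModelConfig now needs max_beam_width from that section as well.
    config = json.loads((Path(engine_dir) / "config.json").read_text())
    return config["builder_config"]["max_beam_width"]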
