Skip to content

Commit

Permalink
update gemme for trt-llm 0.9 (NVIDIA#8974)
Browse files Browse the repository at this point in the history
Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
Signed-off-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
Co-authored-by: Pablo Garay <palenq@gmail.com>
  • Loading branch information
2 people authored and alxzhang-amazon committed Apr 26, 2024
1 parent 67d5b66 commit 51f5b46
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 88 deletions.
89 changes: 2 additions & 87 deletions nemo/export/trt_llm/decoder/gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
from typing import Optional

from tensorrt_llm.functional import non_gated_version
from tensorrt_llm.layers import Attention, AttentionMaskType, GatedMLP, PositionEmbeddingType, RmsNorm
from tensorrt_llm.models.gemma.model import GemmaDecoderLayer, QuantConfig
from tensorrt_llm.models.modeling_utils import PretrainedConfig
from tensorrt_llm.module import Module
from tensorrt_llm.quantization import QuantMode
from typing_extensions import override

from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
Expand All @@ -32,88 +30,6 @@
)


class GemmaDecoderLayer(Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
self.config = config

self.input_layernorm = RmsNorm(
normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype
)

self.attention = Attention(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
attention_head_size=config.head_size,
max_position_embeddings=config.max_position_embeddings,
dtype=config.dtype,
attention_mask_type=AttentionMaskType.causal,
bias=config.attn_bias,
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
rotary_embedding_base=config.rotary_base,
rotary_embedding_scaling=config.rotary_scaling,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
quant_mode=config.quant_mode,
)

mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size

self.mlp = GatedMLP(
hidden_size=config.hidden_size,
ffn_hidden_size=mlp_hidden_size,
hidden_act=config.hidden_act,
dtype=config.dtype,
bias=config.mlp_bias,
tp_group=config.mapping.tp_group,
tp_size=config.mapping.tp_size,
quant_mode=config.quant_mode,
)
self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype)

def forward(
self,
hidden_states,
attention_mask=None,
medusa_packed_mask=None, # For Medusa support
medusa_position_offsets=None,
use_cache=False,
kv_cache_params=None,
attention_params=None,
lora_layer_params=None,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)

attention_output = self.attention(
hidden_states,
attention_mask=attention_mask,
medusa_packed_mask=medusa_packed_mask, # For Medusa support
medusa_position_offsets=medusa_position_offsets,
use_cache=use_cache,
kv_cache_params=kv_cache_params,
attention_params=attention_params,
lora_layer_params=lora_layer_params,
)

if use_cache:
attention_output, presents = attention_output

hidden_states = residual + attention_output

residual = hidden_states
hidden_states = self.post_layernorm(hidden_states)

hidden_states = self.mlp(hidden_states, lora_layer_params=lora_layer_params)

hidden_states = residual + hidden_states
if use_cache:
return (hidden_states, presents)
return hidden_states


class GemmaDecoderLayerConfigBuilder(DecoderLayerConfigBuilder):
"""The LLAMA implementation of the DecoderLayerConfigBuilder."""

Expand Down Expand Up @@ -200,8 +116,7 @@ def build_decoder(self, layer):
world_size=self.tensor_parallel,
tp_size=self.tensor_parallel,
pp_size=1,
quant_mode=QuantMode(0),
quant_kwargs=None,
quantization=QuantConfig(),
max_lora_rank=layer.max_lora_rank,
)

Expand Down
2 changes: 1 addition & 1 deletion tests/export/run.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,4 @@ python tests/export/test_nemo_export.py --model_name FALCON-7B-base --existing_t
python tests/export/test_nemo_export.py --model_name FALCON-40B-base --existing_test_models --min_gpus 2 --max_gpus 8
python tests/export/test_nemo_export.py --model_name FALCON-180B-base --existing_test_models --min_gpus 8 --max_gpus 8
python tests/export/test_nemo_export.py --model_name STARCODER1-15B-base --existing_test_models --min_gpus 1 --max_gpus 1
python tests/export/test_nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1 --run_accuracy --test_deployment True
python tests/export/test_nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1

0 comments on commit 51f5b46

Please sign in to comment.