Merge branch 'main' into jlasek/ptq_tests
Signed-off-by: Jan Lasek <janek.lasek@gmail.com>
janekl committed Apr 25, 2024
2 parents 9c3f08c + 74a2dd3 commit de3850a
Showing 3 changed files with 6 additions and 91 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/cicd-main.yml
@@ -31,7 +31,7 @@ jobs:
nvidia-smi
cicd-cluster-clean:
-runs-on: self-hosted-azure-cpu
+runs-on: self-hosted-azure-builder
steps:
- name: Clean server from old files
run: |
@@ -53,7 +53,7 @@

cicd-test-container-setup:
needs: [cicd-cluster-clean]
-runs-on: self-hosted-azure-cpu
+runs-on: self-hosted-azure-builder
# uses: actions/cache@v2
#container:
# image: nvcr.io/nvidia/pytorch:24.01-py3
@@ -179,7 +179,7 @@ jobs:
runs-on: self-hosted-azure
container:
image: nemoci.azurecr.io/nemo_container_${{ github.run_id }}
-options:
+options:
# --user 0:128
--device=/dev/nvidia0
--gpus all
89 changes: 2 additions & 87 deletions nemo/export/trt_llm/decoder/gemma.py
@@ -15,10 +15,8 @@
from typing import Optional

from tensorrt_llm.functional import non_gated_version
-from tensorrt_llm.layers import Attention, AttentionMaskType, GatedMLP, PositionEmbeddingType, RmsNorm
+from tensorrt_llm.models.gemma.model import GemmaDecoderLayer, QuantConfig
from tensorrt_llm.models.modeling_utils import PretrainedConfig
-from tensorrt_llm.module import Module
-from tensorrt_llm.quantization import QuantMode
from typing_extensions import override

from nemo.export.trt_llm.decoder.decoder import DecoderLayerBuilder, DecoderLayerConfigBuilder
@@ -32,88 +30,6 @@
)


-class GemmaDecoderLayer(Module):
-    def __init__(self, config, layer_idx):
-        super().__init__()
-        self.layer_idx = layer_idx
-        self.config = config
-
-        self.input_layernorm = RmsNorm(
-            normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype
-        )
-
-        self.attention = Attention(
-            hidden_size=config.hidden_size,
-            num_attention_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            attention_head_size=config.head_size,
-            max_position_embeddings=config.max_position_embeddings,
-            dtype=config.dtype,
-            attention_mask_type=AttentionMaskType.causal,
-            bias=config.attn_bias,
-            position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
-            rotary_embedding_base=config.rotary_base,
-            rotary_embedding_scaling=config.rotary_scaling,
-            tp_group=config.mapping.tp_group,
-            tp_size=config.mapping.tp_size,
-            quant_mode=config.quant_mode,
-        )
-
-        mlp_hidden_size = config.hidden_size * 4 if config.intermediate_size is None else config.intermediate_size
-
-        self.mlp = GatedMLP(
-            hidden_size=config.hidden_size,
-            ffn_hidden_size=mlp_hidden_size,
-            hidden_act=config.hidden_act,
-            dtype=config.dtype,
-            bias=config.mlp_bias,
-            tp_group=config.mapping.tp_group,
-            tp_size=config.mapping.tp_size,
-            quant_mode=config.quant_mode,
-        )
-        self.post_layernorm = RmsNorm(normalized_shape=config.hidden_size, eps=config.norm_epsilon, dtype=config.dtype)
-
-    def forward(
-        self,
-        hidden_states,
-        attention_mask=None,
-        medusa_packed_mask=None,  # For Medusa support
-        medusa_position_offsets=None,
-        use_cache=False,
-        kv_cache_params=None,
-        attention_params=None,
-        lora_layer_params=None,
-    ):
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
-
-        attention_output = self.attention(
-            hidden_states,
-            attention_mask=attention_mask,
-            medusa_packed_mask=medusa_packed_mask,  # For Medusa support
-            medusa_position_offsets=medusa_position_offsets,
-            use_cache=use_cache,
-            kv_cache_params=kv_cache_params,
-            attention_params=attention_params,
-            lora_layer_params=lora_layer_params,
-        )
-
-        if use_cache:
-            attention_output, presents = attention_output
-
-        hidden_states = residual + attention_output
-
-        residual = hidden_states
-        hidden_states = self.post_layernorm(hidden_states)
-
-        hidden_states = self.mlp(hidden_states, lora_layer_params=lora_layer_params)
-
-        hidden_states = residual + hidden_states
-        if use_cache:
-            return (hidden_states, presents)
-        return hidden_states
-
-
class GemmaDecoderLayerConfigBuilder(DecoderLayerConfigBuilder):
"""The LLAMA implementation of the DecoderLayerConfigBuilder."""

@@ -200,8 +116,7 @@ def build_decoder(self, layer):
world_size=self.tensor_parallel,
tp_size=self.tensor_parallel,
pp_size=1,
-quant_mode=QuantMode(0),
-quant_kwargs=None,
+quantization=QuantConfig(),
max_lora_rank=layer.max_lora_rank,
)

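The net effect of the gemma.py diff above is that the exporter now reuses the GemmaDecoderLayer shipped with TensorRT-LLM instead of maintaining its own copy, and the unquantized default moves from quant_mode=QuantMode(0) with quant_kwargs=None to a single quantization=QuantConfig() argument. The sketch below is illustrative only: the helper name and the dict wrapper are hypothetical, and only the import line and the keyword arguments are taken from the diff.

# Hypothetical helper, sketching how the config keywords look after this merge.
from tensorrt_llm.models.gemma.model import QuantConfig  # same module also provides GemmaDecoderLayer (see import above)

def parallelism_and_quant_kwargs(tensor_parallel: int, max_lora_rank: int) -> dict:
    """Collect the keyword arguments shown in the diff above."""
    return dict(
        world_size=tensor_parallel,   # tensor-parallel world size
        tp_size=tensor_parallel,
        pp_size=1,                    # no pipeline parallelism
        quantization=QuantConfig(),   # empty QuantConfig, i.e. no quantization algorithm set
        max_lora_rank=max_lora_rank,  # forwarded from the layer config
    )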
2 changes: 1 addition & 1 deletion tests/export/run.sh
@@ -48,4 +48,4 @@ python tests/export/test_nemo_export.py --model_name FALCON-7B-base --existing_t
python tests/export/test_nemo_export.py --model_name FALCON-40B-base --existing_test_models --min_gpus 2 --max_gpus 8
python tests/export/test_nemo_export.py --model_name FALCON-180B-base --existing_test_models --min_gpus 8 --max_gpus 8
python tests/export/test_nemo_export.py --model_name STARCODER1-15B-base --existing_test_models --min_gpus 1 --max_gpus 1
-python tests/export/test_nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1 --run_accuracy --test_deployment True
+python tests/export/test_nemo_export.py --model_name GEMMA-base --existing_test_models --min_gpus 1 --max_gpus 1
