Merge branch 'main' into upt-gemma-trtllm-0.9
Signed-off-by: Onur Yilmaz <35306097+oyilmaz-nvidia@users.noreply.github.com>
oyilmaz-nvidia committed Apr 23, 2024
2 parents c4c7961 + 815c5de commit 2e9aa70
Showing 30 changed files with 1,986 additions and 149 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
@@ -133,7 +133,7 @@ RUN pip install flash-attn
# install numba for latest containers
RUN pip install numba>=0.57.1
# install ammo
RUN pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir
RUN pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir

# copy nemo source into a scratch image
FROM scratch as nemo-src
2 changes: 1 addition & 1 deletion Jenkinsfile
@@ -97,7 +97,7 @@ pipeline {

stage('AMMO installation') {
steps {
sh 'pip install nvidia-ammo~=0.7.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir'
sh 'pip install nvidia-ammo~=0.9.0 --extra-index-url https://pypi.nvidia.com --no-cache-dir'
}
}

101 changes: 74 additions & 27 deletions README.rst
@@ -188,12 +188,15 @@ The NeMo Framework can be installed in a variety of ways, depending on your need
* This is recommended for Automatic Speech Recognition (ASR) and Text-to-Speech (TTS) domains.
* When using a Nvidia PyTorch container as the base, this is the recommended installation method for all domains.

* Docker - Refer to the `Docker containers <#docker-containers>`_ section for installation instructions.
* Docker Containers - Refer to the `Docker containers <#docker-containers>`_ section for installation instructions.

* This is recommended for Large Language Models (LLM), Multimodal and Vision domains.
* NeMo LLM & Multimodal Container - `nvcr.io/nvidia/nemo:24.01.01.framework`
* NeMo LLM & Multimodal Container - `nvcr.io/nvidia/nemo:24.03.framework`
* NeMo Speech Container - `nvcr.io/nvidia/nemo:24.01.speech`

* LLM and Multimodal Dependencies - Refer to the `LLM and Multimodal dependencies <#llm-and-multimodal-dependencies>`_ section for installation instructions.
* It is highly recommended to start with a base NVIDIA PyTorch container: `nvcr.io/nvidia/pytorch:24.02-py3`

Conda
~~~~~

@@ -330,23 +333,59 @@ Note that RNNT requires numba to be installed from conda.
pip uninstall numba
conda install -c conda-forge numba
LLM and Multimodal Dependencies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The LLM and Multimodal domains require three additional dependencies:
NVIDIA Apex, NVIDIA Transformer Engine, and NVIDIA Megatron Core.

When working with the `main` branch, these dependencies may require a recent commit.
The most recent working versions of these dependencies are:

.. code-block:: bash
export apex_commit=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c
export te_commit=bfe21c3d68b0a9951e5716fb520045db53419c5e
export mcore_commit=fbb375d4b5e88ce52f5f7125053068caff47f93f
export nv_pytorch_tag=24.02-py3
When using a released version of NeMo,
please refer to the `Software Component Versions <https://docs.nvidia.com/nemo-framework/user-guide/latest/softwarecomponentversions.html>`_
for the correct versions.

If starting with a base NVIDIA PyTorch container, first launch the container:

.. code-block:: bash
docker run \
--gpus all \
-it \
--rm \
--shm-size=16g \
--ulimit memlock=-1 \
--ulimit stack=67108864 \
nvcr.io/nvidia/pytorch:$nv_pytorch_tag
Then install the dependencies:

Apex
~~~~
NeMo LLM Domain training requires NVIDIA Apex to be installed.
Install it manually if not using the NVIDIA PyTorch container.
The NeMo LLM and Multimodal domains require that NVIDIA Apex be installed.
Apex comes installed in the NVIDIA PyTorch container, but it is possible that
NeMo LLM and Multimodal may need Apex to be updated to a newer version.

To install Apex, run

.. code-block:: bash
git clone https://github.com/NVIDIA/apex.git
cd apex
git checkout b496d85fb88a801d8e680872a12822de310951fd
pip install -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam" ./
git checkout $apex_commit
pip install . -v --no-build-isolation --disable-pip-version-check --no-cache-dir --config-settings "--build-option=--cpp_ext --cuda_ext --fast_layer_norm --distributed_adam --deprecated_fused_adam --group_norm"
It is highly recommended to use the NVIDIA PyTorch or NeMo container if you run into issues installing Apex or any other dependencies.
While installing Apex, it may raise an error if the CUDA version on your system does not match the CUDA version torch was compiled with.
While installing Apex outside of the NVIDIA PyTorch container,
the build may raise an error if the CUDA version on your system does not match the CUDA version torch was compiled with.
This error can be avoided by commenting out the check here: https://github.com/NVIDIA/apex/blob/master/setup.py#L32
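As a quick pre-flight check before building Apex from source, a small sketch such as the one below (not part of the NeMo docs) can compare the CUDA version torch was built with against the toolkit that will compile Apex:

.. code-block:: python

    # Sketch: compare torch's CUDA version with the local CUDA toolkit before building Apex.
    import subprocess

    import torch
    from torch.utils.cpp_extension import CUDA_HOME

    print("torch built with CUDA:", torch.version.cuda)
    if CUDA_HOME is not None:
        nvcc = subprocess.run(
            [f"{CUDA_HOME}/bin/nvcc", "--version"], capture_output=True, text=True
        )
        print(nvcc.stdout.splitlines()[-1])  # toolkit version that will build Apex
    else:
        print("No CUDA toolkit found; install one matching torch.version.cuda")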

cuda-nvprof is needed to install Apex. The version should match the CUDA version that you are using:
@@ -366,35 +405,43 @@ With the latest versions of Apex, the `pyproject.toml` file in Apex may need to

Transformer Engine
~~~~~~~~~~~~~~~~~~
NeMo LLM Domain has been integrated with `NVIDIA Transformer Engine <https://github.com/NVIDIA/TransformerEngine>`_
Transformer Engine enables FP8 training on NVIDIA Hopper GPUs.
`Install <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html>`_ it manually if not using the NVIDIA PyTorch container.

.. code-block:: bash

pip install --upgrade git+https://github.com/NVIDIA/TransformerEngine.git@stable
The NeMo LLM and Multimodal domains require that NVIDIA Transformer Engine be installed.
Transformer Engine comes installed in the NVIDIA PyTorch container, but it is possible that
NeMo LLM and Multimodal may need Transformer Engine to be updated to a newer version.

It is highly recommended to use the NVIDIA PyTorch or NeMo container if you run into issues installing Transformer Engine or any other dependencies.
Transformer Engine enables FP8 training on NVIDIA Hopper GPUs and provides many performance optimizations for transformer-based model training.
Documentation for installing Transformer Engine can be found `here <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html>`_.

Transformer Engine requires PyTorch to be built with CUDA 11.8.
.. code-block:: bash
git clone https://github.com/NVIDIA/TransformerEngine.git && \
cd TransformerEngine && \
git checkout $te_commit && \
git submodule init && git submodule update && \
NVTE_FRAMEWORK=pytorch NVTE_WITH_USERBUFFERS=1 MPI_HOME=/usr/local/mpi pip install .
Flash Attention
~~~~~~~~~~~~~~~
When training Large Language Models in NeMo, users may opt to use Flash Attention for efficient training. Transformer Engine already supports Flash Attention for GPT models. If you want to use Flash Attention for non-causal models, please install `flash-attn <https://github.com/HazyResearch/flash-attention>`_. If you want to use Flash Attention with attention bias (introduced from position encoding, e.g. Alibi), please also install the pinned Triton version following the `implementation <https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py#L3>`_.
Transformer Engine requires PyTorch to be built with at least CUDA 11.8.
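After installation, a minimal smoke test such as the following sketch can confirm that Transformer Engine imports and runs; FP8 execution additionally requires Hopper-class or newer GPUs:

.. code-block:: python

    # Sketch: verify the Transformer Engine install with a single FP8 linear layer.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    layer = te.Linear(768, 768, bias=True).cuda()
    x = torch.randn(16, 768, device="cuda")
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        y = layer(x)
    print(y.shape)  # torch.Size([16, 768])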

.. code-block:: bash
Megatron Core
~~~~~~~~~~~~~

pip install flash-attn
pip install triton==2.0.0.dev20221202
The NeMo LLM and Multimodal domains require that NVIDIA Megatron Core be installed.
Megatron Core is a library for scaling large transformer-based models.
NeMo LLM and Multimodal models leverage Megatron Core for model parallelism,
transformer architectures, and optimized PyTorch datasets.

NLP inference UI
~~~~~~~~~~~~~~~~~~~~
To launch the inference web UI server, please install `gradio <https://gradio.app/>`_.
NeMo LLM and Multimodal may need Megatron Core to be updated to a recent version.

.. code-block:: bash
pip install gradio==3.34.0
git clone https://github.com/NVIDIA/Megatron-LM.git && \
cd Megatron-LM && \
git checkout $mcore_commit && \
pip install . && \
cd megatron/core/datasets && \
make
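A quick import check such as the sketch below can confirm that Megatron Core and the compiled dataset helpers built by ``make`` are usable:

.. code-block:: python

    # Sketch: confirm Megatron Core imports and the C++ dataset helpers were built.
    from megatron.core import parallel_state  # core model-parallel APIs
    from megatron.core.datasets import helpers  # compiled by `make` in megatron/core/datasets

    print("Megatron Core import OK")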
NeMo Text Processing
~~~~~~~~~~~~~~~~~~~~
@@ -404,7 +451,7 @@ Docker containers
~~~~~~~~~~~~~~~~~
We release NeMo containers alongside NeMo releases. For example, NeMo ``r1.23.0`` comes with the ``nemo:24.01.speech`` container; you can find more details about released containers on the `releases page <https://github.com/NVIDIA/NeMo/releases>`_.

To use built container, please run
To use a pre-built container, please run

.. code-block:: bash
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -60,6 +60,7 @@ For more information, browse the developer docs for your area of interest in the
nlp/models
nlp/machine_translation/machine_translation
nlp/megatron_onnx_export
nlp/quantization
nlp/api


46 changes: 31 additions & 15 deletions docs/source/nlp/quantization.rst
@@ -3,35 +3,35 @@
Quantization
==========================

Post Training Quantization (PTQ)
Post-Training Quantization (PTQ)
--------------------------------

PTQ enables deploying a model in a low-precision format -- FP8, INT4 or INT8 -- for efficient serving. Different quantization methods are available including FP8 quantization, INT8 SmoothQuant and INT4 AWQ.
PTQ enables deploying a model in a low-precision format -- FP8, INT4, or INT8 -- for efficient serving. Different quantization methods are available including FP8 quantization, INT8 SmoothQuant, and INT4 AWQ.

Model quantization has two primary benefits: reduced model memory requirements and increased inference throughput.

In NeMo, quantization is enabled by the Nvidia AMMO library -- a unified algorithmic model optimization & deployment toolkit.

The quantization process consists of the following steps:

1. Loading a model checkpoint using appropriate parallelism strategy for evaluation
1. Loading a model checkpoint using an appropriate parallelism strategy
2. Calibrating the model to obtain appropriate algorithm-specific scaling factors
3. Producing output directory or .qnemo tarball with model config (json), quantized weights (safetensors) and tokenizer config (yaml).
3. Producing an output directory or .qnemo tarball with model config (json), quantized weights (safetensors) and tokenizer config (yaml).

Loading models requires using AMMO spec defined in `megatron.core.deploy.gpt.model_specs module <https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/deploy/gpt/model_specs.py>`_. Typically the calibration step is lightweight and uses a small dataset to obtain appropriate statistics for scaling tensors. The output directory produced (or a .qnemo tarball) is ready to be used to build a serving engine with the Nvidia TensorRT-LLM library. The engine build step is also soon to be the part of NeMo project and ``nemo.deploy`` and ``nemo.export`` modules, see https://github.com/NVIDIA/NeMo/pull/8690.
Loading models requires using the AMMO spec defined in the `megatron.core.inference.gpt.model_specs.py <https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/inference/gpt/model_specs.py>`_ module. Typically, the calibration step is lightweight and uses a small dataset to obtain appropriate statistics for scaling tensors. The output directory produced (or a .qnemo tarball) is ready to be used to build a serving engine with the Nvidia TensorRT-LLM library. The engine build step is also available in the NeMo project in the ``nemo.deploy`` and ``nemo.export`` modules.

The quantization algorithm can also be conveniently set to ``"null"`` to perform only the weights export step using the default precision for TensorRT-LLM deployment. This is useful for obtaining baseline performance and accuracy results for comparison.


Example
^^^^^^^
The example below shows how to quantize the Llama2 70b model into FP8 precision, using tensor parallelism of 8 on a single DGX H100 node. The quantized model is intended for serving using 2 GPUs specified with ``export.inference_tensor_parallel`` parameter.
The example below shows how to quantize the Llama2 70b model into FP8 precision, using tensor parallelism of 8 on a single DGX H100 node. The quantized model is designed for serving using 2 GPUs specified with the ``export.inference_tensor_parallel`` parameter.

The script should be launched correctly with the number of processes equal to tensor parallelism. This is achieved with the ``mpirun`` command below.
The script must be launched with the number of processes equal to the tensor parallelism size. This is achieved with the ``torchrun`` command below:

.. code-block:: bash
mpirun -n 8 python examples/nlp/language_modeling/megatron_llama_quantization.py \
torchrun --nproc-per-node 8 examples/nlp/language_modeling/megatron_llama_quantization.py \
model_file=llama2-70b-base-bf16.nemo \
tensor_model_parallel_size=8 \
pipeline_model_parallel_size=1 \
@@ -57,31 +57,47 @@ The output directory stores the following files:
└── tokenizer_config.yaml
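The output can also be packaged as a ``.qnemo`` tarball containing the same files; a quick way to inspect such a tarball is sketched below (the file name is hypothetical):

.. code-block:: python

    # Sketch: list the contents of a .qnemo tarball produced by the quantization script.
    import tarfile

    with tarfile.open("llama2-70b-base-fp8.qnemo") as tar:  # hypothetical output path
        print(tar.getnames())  # config.json, *.safetensors, tokenizer files, ...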
The TensorRT-LLM engine can be built with the ``trtllm-build`` command, see `TensorRT-LLM documentation <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization>`_.
The TensorRT-LLM engine can be conveniently built and run using the ``TensorRTLLM`` class available in the ``nemo.export`` submodule:

.. code-block:: python
from nemo.export import TensorRTLLM
trt_llm_exporter = TensorRTLLM(model_dir="/path/to/trt_llm_engine_folder")
trt_llm_exporter.export(
nemo_checkpoint_path="llama2-70b-base-fp8-qnemo",
model_type="llama",
)
trt_llm_exporter.forward(["Hi, how are you?", "I am good, thanks, how about you?"])
Alternatively, it can also be built directly using the ``trtllm-build`` command; see the `TensorRT-LLM documentation <https://github.com/NVIDIA/TensorRT-LLM/tree/main/examples/llama#fp8-post-training-quantization>`_:

.. code-block:: bash
trtllm-build \
--checkpoint_dir llama2-70b-base-fp8-qnemo \
--output_dir engine_dir \
--output_dir /path/to/trt_llm_engine_folder \
--max_batch_size 8 \
--max_input_len 2048 \
--max_output_len 512
--max_output_len 512 \
--strongly_typed
Known issues
^^^^^^^^^^^^
* Currently in NeMo quantizing and building TensorRT-LLM engines is limited to single-node use cases.
* Supported and tested model family is Llama2. Quantizing other model types is experimental and may not be fully supported.
* For INT8 SmoothQuant ``quantization.algorithm=int8_sq``, the TensorRT-LLM engine cannot be built with the CLI ``trtllm-build`` command -- the Python API and ``tensorrt_llm.builder`` should be used instead.
* Currently in NeMo, quantizing and building TensorRT-LLM engines is limited to single-node use cases.
* The supported and tested model family is Llama2. Quantizing other model types is experimental and may not be fully supported.


Please refer to the following papers for more details on quantization techniques.

References
----------

`Integer Quantization for Deep Learning Inference: Principles and Empirical Evaluation, 2020 <https://arxiv.org/abs/2004.09602>`_

`FP8 Formats for Deep Learning, 2022 <https://arxiv.org/abs/2209.05433>`_

`SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models, 2022 <https://arxiv.org/abs/2211.10438>`_
@@ -25,12 +25,13 @@ quantization:
algorithm: fp8 # int8_sq, fp8, int8, int4_awq, null
calib_dataset: cnn_dailymail # wikitext, cnn_dailymail, or a local dataset
num_calib_size: 512 # number of samples used for calibration
awq_block_size: 128 # block size for scaling factors in AWQ algorithm

export:
decoder_type: llama # gptnext, gpt2, llama
inference_tensor_parallel: 1 # Default using 1 TP for inference
inference_pipeline_parallel: 1 # Default using 1 PP for inference
dtype: 16 # Default precision data type
export_tensorrt_llm_config: true # export config to build TRT-LLM engine directly

model_file: llama2-7b-fp16.nemo # Nemo file path
model_save: llama2-7b-fp8.qnemo # Path where the quantized model will be saved
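For reference, a configuration like the one above can be loaded and tweaked programmatically with OmegaConf before launching the script (a sketch; the local file name is hypothetical and the exact config path depends on your checkout):

.. code-block:: python

    # Sketch: load and override the quantization config with OmegaConf.
    from omegaconf import OmegaConf

    cfg = OmegaConf.load("megatron_llama_quantization.yaml")  # hypothetical local copy
    cfg.quantization.algorithm = "int4_awq"
    cfg.quantization.awq_block_size = 128
    cfg.export.inference_tensor_parallel = 2
    print(OmegaConf.to_yaml(cfg))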
@@ -166,7 +166,7 @@ def forward(self, **kwargs):
the superclass by the square root of the hidden size specified in the configuration.
"""
embeddings = super().forward(**kwargs)
return embeddings * (self.config.hidden_size ** 0.5)
return embeddings * torch.tensor(self.config.hidden_size ** 0.5, dtype=embeddings.dtype)


class MegatronGPTExportableModel(torch.nn.Module, Exportable):
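For reference, a small standalone sketch of the change above: the ``sqrt(hidden_size)`` factor is materialized as a tensor in the embedding dtype, so the scale itself is rounded to that dtype before the multiply (the sizes below are illustrative):

.. code-block:: python

    # Sketch: scale embeddings by sqrt(hidden_size), holding the scale in the embedding dtype.
    import torch

    hidden_size = 4096
    embeddings = torch.randn(2, 8, hidden_size, dtype=torch.bfloat16)

    scale = torch.tensor(hidden_size ** 0.5, dtype=embeddings.dtype)
    scaled = embeddings * scale
    print(scale.dtype, scaled.dtype)  # torch.bfloat16 torch.bfloat16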
@@ -15,7 +15,9 @@
import torch
import torch.nn.functional as F
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.transformer.attention import SelfAttention
@@ -279,10 +281,16 @@ def forward(self, hidden_states):

if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
if self.config.gated_linear_unit:
intermediate_parallel = bias_geglu_impl(intermediate_parallel, bias_parallel)
else:
assert self.config.add_bias_linear is True
intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
elif self.activation_func == F.silu and self.config.gated_linear_unit:
intermediate_parallel = bias_swiglu_impl(intermediate_parallel, bias_parallel)
intermediate_parallel = bias_swiglu_impl(
intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store,
)

else:
raise ValueError("Only support fusion of gelu and swiglu")
else:
@@ -303,7 +303,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf'), started
else:
logits[indices_to_remove] = filter_value

if top_p > 0.0:
if 0.0 < top_p < 1.0:
# Convert to 1D
sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
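For context, a self-contained sketch of the guarded nucleus (top-p) filtering shown above, on a toy logits vector; the new guard matters because ``top_p >= 1.0`` keeps every token and the filtering can be skipped entirely. This mirrors the common nucleus-sampling implementation rather than reproducing the rest of the NeMo function, which is not shown in this hunk:

.. code-block:: python

    # Sketch: top-p (nucleus) filtering applied only when 0 < top_p < 1.
    import torch
    import torch.nn.functional as F

    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    top_p = 0.8
    if 0.0 < top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Mask tokens once the cumulative probability exceeds top_p, always keeping the top token.
        sorted_mask = cumulative_probs > top_p
        sorted_mask[..., 1:] = sorted_mask[..., :-1].clone()
        sorted_mask[..., 0] = False
        indices_to_remove = torch.zeros_like(sorted_mask).scatter(1, sorted_indices, sorted_mask)
        logits = logits.masked_fill(indices_to_remove, -float("Inf"))
    print(logits)  # tokens outside the 0.8 probability mass are set to -inf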