DBRX

This document shows how to build and run a DBRX model in TensorRT-LLM. DBRX is a leading large language model trained by Databricks. For more details about the model, see the DBRX Technical Blog.

Overview

The TensorRT-LLM DBRX implementation can be found in tensorrt_llm/models/dbrx/model.py.

Support Matrix

  • BF16
  • FP16
  • INT8 Weight-Only
  • INT4 Weight-Only
  • INT8 KV Cache
  • Tensor Parallel
  • Pipeline Parallel
  • Expert Parallel

Usage

Download weights from HuggingFace Transformers

Install the dependencies and set up git-lfs.

# Install dependencies
# DBRX uses tiktoken as the tokenizer; make sure it is installed
pip install -r requirements.txt

# Setup git-lfs
git lfs install
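
To confirm the setup before downloading, you can check both tools from the shell (optional):

# Optional sanity check: verify git-lfs and the tiktoken dependency
git lfs version
pip show tiktoken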

Download one or more DBRX models that you would like to build into TensorRT-LLM engines. You can download from the HuggingFace hub:

# Download dbrx-base
git clone https://huggingface.co/databricks/dbrx-base

# Download dbrx-instruct
git clone https://huggingface.co/databricks/dbrx-instruct
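
Note that the DBRX weights are gated on the HuggingFace Hub and total several hundred gigabytes, so accept the license on the model page and check your disk space before cloning. As an alternative to git, the huggingface_hub CLI can fetch the weights as well (a sketch, assuming huggingface_hub is installed and you are logged in with huggingface-cli login):

# Alternative: download via the HuggingFace Hub CLI
huggingface-cli download databricks/dbrx-base --local-dir dbrx-base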

Build TensorRT engine(s)

The convert_checkpoint.py script converts HF weights to TensorRT-LLM checkpoints. A DBRX model has 132B parameters, so you need at least 4 x 80GB GPUs to load the model in 16-bit precision for weight conversion.
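
As a back-of-envelope check: 132B parameters at 2 bytes each in 16-bit precision is roughly 264 GB of weights, which exceeds the combined memory of three 80GB GPUs (240 GB) but fits on four (320 GB).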

The trtllm-build command builds TensorRT-LLM engines from TensorRT-LLM checkpoints. The number of engine files produced matches the number of GPUs used to run inference. By default, trtllm-build uses a single GPU, but if more GPUs are already available at build time, you can pass the --workers argument to enable parallel builds and speed up the engine building process.

Here are some examples:

# 8-way tensor parallelism, dtype bfloat16
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype bfloat16 \
        --tp_size 8 \
        --workers 8 \
        --output_dir dbrx/trt_ckpt/bf16/tp8

trtllm-build --checkpoint_dir dbrx/trt_ckpt/bf16/tp8 \
        --gpt_attention_plugin bfloat16 \
        --gemm_plugin bfloat16 \
        --moe_plugin bfloat16 \
        --workers 8 \
        --output_dir dbrx/trt_engines/bf16/tp8

# 8-way tensor parallelism, dtype float16
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype float16 \
        --tp_size 8 \
        --workers 8 \
        --output_dir dbrx/trt_ckpt/fp16/tp8

trtllm-build --checkpoint_dir dbrx/trt_ckpt/fp16/tp8 \
        --gpt_attention_plugin float16 \
        --gemm_plugin float16 \
        --moe_plugin float16 \
        --workers 8 \
        --output_dir dbrx/trt_engines/fp16/tp8

# 4-way tensor parallelism and 2-way pipeline parallelism, dtype bfloat16
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype bfloat16 \
        --tp_size 4 \
        --pp_size 2 \
        --workers 8 \
        --output_dir dbrx/trt_ckpt/bf16/tp4pp2

trtllm-build --checkpoint_dir dbrx/trt_ckpt/bf16/tp4pp2 \
        --gpt_attention_plugin bfloat16 \
        --gemm_plugin bfloat16 \
        --moe_plugin bfloat16 \
        --workers 8 \
        --output_dir dbrx/trt_engines/bf16/tp4pp2

# Build DBRX with expert parallelism for the DbrxExperts layer and tensor parallelism for the rest
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype bfloat16 \
        --tp_size 8 \
        --moe_tp_size 1 \
        --moe_ep_size 8 \
        --workers 8 \
        --output_dir dbrx/trt_ckpt/bf16/ep8

trtllm-build --checkpoint_dir dbrx/trt_ckpt/bf16/ep8 \
        --gpt_attention_plugin bfloat16 \
        --gemm_plugin bfloat16 \
        --moe_plugin bfloat16 \
        --workers 8 \
        --output_dir dbrx/trt_engines/bf16/ep8
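
After conversion, you can sanity-check a checkpoint directory before building. The layout below is an assumption based on how convert_checkpoint.py typically writes its output, with a config.json plus one weight shard per rank:

# Optional: inspect the converted checkpoint
ls dbrx/trt_ckpt/bf16/tp8
# Expect config.json plus rank0 through rank7 weight shards for an 8-way build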

Weight Only Quantization

convert_checkpoint.py features a --use_weight_only option that can enable weight-only quantization. You can further set the weight-only precision by passing int8 or int4 to the --weight_only_precision flag.

# 4-way tensor parallelism, int8 weight-only
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype float16 \
        --use_weight_only \
        --weight_only_precision int8 \
        --tp_size 4 \
        --workers 4 \
        --output_dir dbrx/trt_ckpt/int8-wo/tp4

trtllm-build --checkpoint_dir dbrx/trt_ckpt/int8-wo/tp4 \
        --gpt_attention_plugin float16 \
        --gemm_plugin float16 \
        --moe_plugin float16 \
        --workers 4 \
        --output_dir dbrx/trt_engines/int8-wo/tp4

# 4-way tensor parallelism, int4 weight-only
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype float16 \
        --use_weight_only \
        --weight_only_precision int4 \
        --tp_size 4 \
        --workers 4 \
        --output_dir dbrx/trt_ckpt/int4-wo/tp4

trtllm-build --checkpoint_dir dbrx/trt_ckpt/int4-wo/tp4 \
        --gpt_attention_plugin float16 \
        --gemm_plugin float16 \
        --moe_plugin float16 \
        --workers 4 \
        --output_dir dbrx/trt_engines/int4-wo/tp4
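
Weight-only quantization also shrinks the memory needed to hold the weights: roughly 132 GB at int8 and 66 GB at int4 for the 132B-parameter model, which is why these configurations fit on 4 x 80GB GPUs.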

INT8 KV Cache

INT8 KV cache can be enabled to reduce the memory footprint, and it can also bring performance gains at large batch sizes and long sequence lengths.

For INT8 KV cache, convert_checkpoint.py features a --int8_kv_cache option. Setting --int8_kv_cache will calibrate the model, and then export the scaling factors needed for INT8 KV cache inference.

# 4-way tensor parallelism, int8 kv-cache
python convert_checkpoint.py --model_dir dbrx-base \
        --dtype float16 \
        --int8_kv_cache \
        --tp_size 4 \
        --workers 4 \
        --output_dir dbrx/trt_ckpt/int8kv/tp4

trtllm-build --checkpoint_dir dbrx/trt_ckpt/int8kv/tp4 \
        --gpt_attention_plugin float16 \
        --gemm_plugin float16 \
        --moe_plugin float16 \
        --workers 4 \
        --output_dir dbrx/trt_engines/int8kv/tp4
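
For a rough sense of the savings, assuming the architecture published in the DBRX Technical Blog (40 layers, 8 KV heads of dimension 128): each token's keys and values take about 2 x 40 x 8 x 128 x 2 bytes = 160 KB at 16-bit precision, so a batch of 32 sequences at 4K context length holds roughly 21 GB of KV cache; INT8 halves that to about 10.5 GB.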

Run inference

You can test your engines with the run.py script, located one directory up (the commands below assume you are running from the examples/dbrx directory):

mpirun -n 8 \
    python3 ../run.py --engine_dir dbrx/trt_engines/bf16/tp8 \
        --tokenizer_dir dbrx-base \
        --max_output_len 10 \
        --input_text "What is AGI?"

If the engines run successfully, you should see output similar to:

......
Input [Text 0]: "What is AGI?"
Output [Text 0 Beam 0]: " How do I find it?
AGI stands for"
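
The number of MPI ranks must match the engine's world size, i.e. tp_size x pp_size at conversion time. For example, to run the 4-way int8 weight-only engine built above:

mpirun -n 4 \
    python3 ../run.py --engine_dir dbrx/trt_engines/int8-wo/tp4 \
        --tokenizer_dir dbrx-base \
        --max_output_len 10 \
        --input_text "What is AGI?"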

You can also evaluate with the summarize.py script, which summarizes articles from the cnn_dailymail dataset and reports ROUGE scores:

mpirun -n 8 \
    python ../summarize.py --engine_dir dbrx/trt_engines/bf16/tp8 \
        --hf_model_dir dbrx-base \
        --test_trt_llm

If the engines run successfully, you should see output similar to:

......
[04/02/2024-11:16:37] [TRT-LLM] [I] TensorRT-LLM (total latency: 9.962657451629639 sec)
[04/02/2024-11:16:37] [TRT-LLM] [I] TensorRT-LLM (total output tokens: 1189)
[04/02/2024-11:16:37] [TRT-LLM] [I] TensorRT-LLM (tokens per second: 119.34566713477734)
[04/02/2024-11:16:37] [TRT-LLM] [I] TensorRT-LLM beam 0 result
[04/02/2024-11:16:37] [TRT-LLM] [I]   rouge1 : 26.842471264679535
[04/02/2024-11:16:37] [TRT-LLM] [I]   rouge2 : 9.979512100961314
[04/02/2024-11:16:37] [TRT-LLM] [I]   rougeL : 19.50336050538688
[04/02/2024-11:16:37] [TRT-LLM] [I]   rougeLsum : 22.00400189383231