BERT Inference Using TensorRT

This subfolder of the BERT TensorFlow repository, tested and maintained by NVIDIA, provides scripts to perform high-performance inference using NVIDIA TensorRT.

Model overview

BERT, or Bidirectional Encoder Representations from Transformers, is a new method of pre-training language representations which obtains state-of-the-art results on a wide array of Natural Language Processing (NLP) tasks. This model is based on the BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding paper. NVIDIA's BERT is an optimized version of Google's official implementation, leveraging mixed precision arithmetic and Tensor Cores for faster inference times while maintaining target accuracy.

Other publicly available implementations of BERT include:

  1. NVIDIA PyTorch
  2. Hugging Face
  3. codertimo
  4. gluon-nlp
  5. Google's official implementation

Model architecture

BERT's model architecture is a multi-layer bidirectional Transformer encoder. Based on the model size, we have the following two default configurations of BERT:

Model       Hidden layers  Hidden unit size  Attention heads  Feed-forward filter size  Max sequence length  Parameters
BERT-Base   12 encoder     768               12               4 x 768                   512                  110M
BERT-Large  24 encoder     1024              16               4 x 1024                  512                  330M
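
As a rough sanity check on the parameter counts above, the totals can be approximated from the hidden size and layer count alone. The sketch below is a back-of-the-envelope estimate that assumes the standard 30,522-token WordPiece vocabulary and ignores biases, LayerNorm, and pooler weights; it is illustrative rather than the exact accounting behind the checkpoints.

def approx_bert_params(hidden, layers, vocab=30522, max_pos=512, ffn_mult=4):
    """Rough BERT parameter count: embeddings plus encoder layers (biases/LayerNorm omitted)."""
    embeddings = vocab * hidden + max_pos * hidden + 2 * hidden   # token + position + segment embeddings
    attention = 4 * hidden * hidden                               # Q, K, V and output projections
    ffn = 2 * ffn_mult * hidden * hidden                          # feed-forward up and down projections
    return embeddings + layers * (attention + ffn)

print(f"BERT-Base  ~{approx_bert_params(768, 12) / 1e6:.0f}M parameters")   # ~109M, close to 110M
print(f"BERT-Large ~{approx_bert_params(1024, 24) / 1e6:.0f}M parameters")  # ~334M, close to 330M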

Typically, the language model is followed by a few task-specific layers. The model used here includes layers for question answering.
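
For question answering, the task-specific layer is typically a single linear projection that turns each token's final hidden state into a start logit and an end logit, and the predicted answer is the highest-scoring span. The NumPy sketch below illustrates that idea with random values; the names and shapes are illustrative and not the exact ones used by this demo.

import numpy as np

seq_len, hidden_size = 384, 1024                              # example: BERT-Large at sequence length 384
hidden_states = np.random.randn(seq_len, hidden_size).astype(np.float32)  # stand-in for the encoder output

# Hypothetical QA head: one linear layer producing (start, end) logits per token.
qa_weight = (0.02 * np.random.randn(hidden_size, 2)).astype(np.float32)
qa_bias = np.zeros(2, dtype=np.float32)
logits = hidden_states @ qa_weight + qa_bias                  # shape (seq_len, 2)
start_logits, end_logits = logits[:, 0], logits[:, 1]

# Greedy span selection: best start, then best end at or after the start.
start = int(np.argmax(start_logits))
end = start + int(np.argmax(end_logits[start:]))
print(f"predicted answer span: tokens {start}..{end}")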

TensorRT Inference Pipeline

BERT inference consists of three main stages: tokenization, the BERT model, and finally a projection of the tokenized prediction onto the original text. Since the tokenizer and projection of the final predictions are not nearly as compute-heavy as the model itself, we run them on the host. The BERT model is GPU-accelerated via TensorRT.

The tokenizer splits the input text into tokens that can be consumed by the model. For details on this process, see this tutorial.

To run the BERT model in TensorRT, we construct the model using the TensorRT APIs and import the weights from a pre-trained TensorFlow checkpoint from NGC. Finally, a TensorRT engine is generated and serialized to disk. The various inference scripts then load this engine for inference.

Lastly, the tokens predicted by the model are projected back to the original text to get a final result.
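
The serialized engine built by builder.py is loaded with the TensorRT Python API before any questions are answered. The following is a minimal sketch of that loading step, assuming the engine path from the Quick Start Guide; buffer allocation, input preparation, and execution are handled by the provided inference scripts.

import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.INFO)

# Deserialize the engine that builder.py serialized to disk.
with open("engines/bert_large_128.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

# The execution context holds per-inference state (input shapes, tensor addresses, and so on).
context = engine.create_execution_context()
print(f"Loaded engine with {engine.num_io_tensors} I/O tensors")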

Version Info

The following software version configuration has been tested:

Software    Version
Python      >=3.8
TensorRT    10.0.1.6
CUDA        12.4

Setup

The following section lists the requirements that you need to meet in order to run the BERT model.

Requirements

This demo BERT application can be run within the TensorRT OSS build container. If running in a different environment, the following packages are required.

Quick Start Guide

  1. Build and launch the container as described in TensorRT OSS README.

    Note: After this point, all commands should be run from within the container.

  2. Verify the TensorRT installation by printing the version. For example:

    python3 -c "import tensorrt as trt; print(trt.__version__)"
  3. Download the SQuAD dataset and BERT checkpoints:

    cd $TRT_OSSPATH/demo/BERT

    Download the SQuAD v1.1 training and dev datasets.

    bash ./scripts/download_squad.sh

    Download TensorFlow checkpoints for a BERT Large model with sequence length 128, fine-tuned for SQuAD v2.0.

    bash scripts/download_model.sh

Note: Since the datasets and checkpoints are stored in the directory mounted from the host, they do not need to be downloaded each time the container is launched.

Warning: If you encounter the error message "Missing API key and missing Email Authentication. This command requires an API key or authentication via browser login", resolve it as follows:

  • Generate an API key by logging in at https://ngc.nvidia.com/setup/api-key and copy the generated API key.
  • Run the command ngc config set inside the container and paste the copied API key at the prompt.

Completing these steps should resolve the error and allow the command to proceed.

  4. Build a TensorRT engine by running the builder.py script. For example:

    mkdir -p engines && python3 builder.py -m models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_128_v19.03.1/model.ckpt -o engines/bert_large_128.engine -b 1 -s 128 --fp16 -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_128_v19.03.1

    This will build an engine with a maximum batch size of 1 (-b 1) and a sequence length of 128 (-s 128), using mixed precision (--fp16) and the BERT Large SQuAD v2 FP16 sequence length 128 checkpoint (-c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_128_v19.03.1).

  5. Run inference. Two options are provided for running the model.

    a. The inference.py script accepts a passage and a question, then runs the engine to generate an answer. For example:

    python3 inference.py -e engines/bert_large_128.engine -p "TensorRT is a high performance deep learning inference platform that delivers low latency and high throughput for apps such as recommenders, speech and image/video on NVIDIA GPUs. It includes parsers to import models, and plugins to support novel ops and layers before applying optimizations for inference. Today NVIDIA is open-sourcing parsers and plugins in TensorRT so that the deep learning community can customize and extend these components to take advantage of powerful TensorRT optimizations for your apps." -q "What is TensorRT?" -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_128_v19.03.1/vocab.txt

    b. The inference.ipynb Jupyter Notebook includes a passage and various example questions and allows you to interactively make modifications and see the outcome. To launch the Jupyter Notebook from inside the container, run:

    jupyter notebook --ip 0.0.0.0 inference.ipynb

    Then, use your browser to open the link displayed. The link should look similar to: http://127.0.0.1:8888/?token=<TOKEN>

  6. Run inference with CUDA Graph support.

    A separate Python script, inference_c.py, is provided to run inference with CUDA Graph support. This is necessary because CUDA Graph is only supported through the CUDA C/C++ APIs, not pyCUDA. The inference_c.py script uses pybind11 to interface with C/C++ for CUDA Graph capture and launch. The command-line interface is the same as inference.py, except for an extra --enable-graph option.

    mkdir -p build; pushd build
    cmake .. -DPYTHON_EXECUTABLE=$(which python)
    make -j
    popd
    python3 inference_c.py -e engines/bert_large_128.engine --enable-graph -p "TensorRT is a high performance deep learning inference platform that delivers low latency and high throughput for apps such as recommenders, speech and image/video on NVIDIA GPUs. It includes parsers to import models, and plugins to support novel ops and layers before applying optimizations for inference. Today NVIDIA is open-sourcing parsers and plugins in TensorRT so that the deep learning community can customize and extend these components to take advantage of powerful TensorRT optimizations for your apps." -q "What is TensorRT?" -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_128_v19.03.1/vocab.txt

    A separate C/C++ inference benchmark executable, perf (compiled from perf.cpp), is provided to run inference benchmarks with CUDA Graph. The command-line interface is the same as perf.py, except for an extra --enable_graph option.

    build/perf -e engines/bert_large_128.engine -b 1 -s 128 -w 100 -i 1000 --enable_graph

(Optional) Trying a different configuration

If you would like to run another configuration, you can manually download checkpoints using the included script. For example, run:

bash scripts/download_model.sh base

to download a BERT Base model instead of the default BERT Large model.

To view all available model options, run:

bash scripts/download_model.sh -h

Advanced

The following sections provide greater details on inference with TensorRT.

Scripts and sample code

In the root directory, the most important files are:

  • builder.py - Builds an engine for the specified BERT model
  • Dockerfile - Container which includes dependencies and model checkpoints to run BERT
  • inference.ipynb - Runs inference interactively
  • inference.py - Runs inference with a given passage and question
  • perf.py - Runs inference benchmarks

The scripts/ folder encapsulates all the one-click scripts required for running various supported functionalities, such as:

  • build.sh - Builds a Docker container that is ready to run BERT
  • launch.sh - Launches the container created by the build.sh script.
  • download_model.sh - Downloads pre-trained model checkpoints from NGC
  • inference_benchmark.sh - Runs an inference benchmark and prints results

Other folders included in the root directory are:

  • helpers - Contains helpers for tokenization of inputs

The infer_c/ folder contains all the necessary C/C++ files required for CUDA Graph support.

  • bert_infer.h - Defines necessary data structures for running BERT inference
  • infer_c.cpp - Defines C/C++ interface using pybind11 that can be plugged into inference_c.py
  • perf.cpp - Runs inference benchmarks. It is equivalent to perf.py, with an extra option --enable_graph to enable CUDA Graph support.

Command-line options

To view the available parameters for each script, you can use the help flag (-h).

TensorRT inference process

As mentioned in the Quick Start Guide, two options are provided for running inference:

  1. The inference.py script accepts a passage and a question and runs the engine to generate an answer. Alternatively, it can be used to run inference on the SQuAD dataset.
  2. The inference.ipynb Jupyter Notebook which includes a passage and various example questions and allows you to interactively make modifications and see the outcome.

Accuracy

Evaluating PTQ (post-training quantization) Int8 Accuracy Using The SQuAD Dataset

  1. Download TensorFlow checkpoints for a BERT Large FP16 SQuAD v2 model with a sequence length of 384:

    bash scripts/download_model.sh large 384 v2
  2. Build an engine:

    Turing and Ampere GPUs

    # QKVToContextPlugin and SkipLayerNormPlugin supported with INT8 I/O. To enable, use -imh and -iln builder flags respectively.
    mkdir -p engines && python3 builder.py -m models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/model.ckpt -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 --squad-json ./squad/train-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt --calib-num 100 -iln -imh

    Xavier GPU

    # Only supports SkipLayerNormPlugin running with INT8 I/O. Use -iln builder flag to enable.
    mkdir -p engines && python3 builder.py -m models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/model.ckpt -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 --squad-json ./squad/train-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt --calib-num 100 -iln 

    Volta GPU

    # No support for QKVToContextPlugin or SkipLayerNormPlugin running with INT8 I/O. Don't specify -imh or -iln in builder flags.
    mkdir -p engines && python3 builder.py -m models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/model.ckpt -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 --squad-json ./squad/train-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt --calib-num 100

    This will build an engine with a maximum batch size of 1 (-b 1) and a sequence length of 384 (-s 384), using the SQuAD training set for calibration (--squad-json ./squad/train-v1.1.json) with 100 calibration sentences (--calib-num 100), and INT8 mixed precision computation where possible (--int8 --fp16 --strict).

  3. Run inference using the SQuAD dataset, and evaluate the F1 score and exact match score:

    python3 inference.py -e engines/bert_large_384_int8mix.engine -s 384 -sq ./squad/dev-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -o ./predictions.json
    python3 squad/evaluate-v1.1.py  squad/dev-v1.1.json  ./predictions.json 90

Evaluating QAT (quantization aware training) Int8 Accuracy Using The SQuAD Dataset

  1. Download the checkpoint for a BERT Large FP16 SQuAD v1.1 model with a sequence length of 384:

    bash scripts/download_model.sh pyt v1_1
  2. Build an engine:

    Turing and Ampere GPUs

    # QKVToContextPlugin and SkipLayerNormPlugin supported with INT8 I/O. To enable, use -imh and -iln builder flags respectively.
    mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx -iln -imh

    Xavier GPU

    # Only supports SkipLayerNormPlugin running with INT8 I/O. Use -iln builder flag to enable.
    mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx -iln 

    Volta GPU

    # No support for QKVToContextPlugin or SkipLayerNormPlugin running with INT8 I/O. Don't specify -imh or -iln in builder flags.
    mkdir -p engines && python3 builder.py -o engines/bert_large_384_int8mix.engine -b 1 -s 384 --int8 --fp16 --strict -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx 

    This will build an engine with a maximum batch size of 1 (-b 1) and a sequence length of 384 (-s 384) using INT8 mixed precision computation where possible (--int8 --fp16 --strict).

  3. Run inference using the SQuAD dataset, and evaluate the F1 score and exact match score:

    python3 inference.py -e engines/bert_large_384_int8mix.engine -s 384 -sq ./squad/dev-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -o ./predictions.json
    python3 squad/evaluate-v1.1.py  squad/dev-v1.1.json  ./predictions.json 90

Experimental

Variable sequence length

In our prior implementation, we used inputs padded to the maximum length, along with corresponding input masks, to handle variable sequence length inputs in a batch. The padding results in wasted computation, which can be avoided by handling variable sequence length inputs natively. The new variable sequence length approach concatenates the input IDs of all sequences in a batch into one long input ID tensor and likewise concatenates the segment IDs; an additional sequence length buffer containing the start and end positions of each sequence tells TensorRT exactly where each sequence begins and ends, eliminating the wasted computation on padding.
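
The packing can be illustrated with a small NumPy sketch: token IDs from all sequences are concatenated into one flat buffer, and a cumulative offsets array marks where each sequence starts and ends. The names below (for example, cu_seqlens) are illustrative; see builder_varseqlen.py and inference_varseqlen.py for the actual input layout.

import numpy as np

# Three sequences of different lengths (token IDs are illustrative).
sequences = [np.array([101, 2054, 2003, 102], dtype=np.int32),        # 4 tokens
             np.array([101, 2129, 2116, 2515, 102], dtype=np.int32),  # 5 tokens
             np.array([101, 102], dtype=np.int32)]                    # 2 tokens

# Concatenate all token IDs into one packed buffer with no padding.
input_ids = np.concatenate(sequences)

# Offsets buffer: entry i is the start of sequence i and the last entry is the total
# token count, so [cu_seqlens[i], cu_seqlens[i + 1]) delimits sequence i.
cu_seqlens = np.zeros(len(sequences) + 1, dtype=np.int32)
cu_seqlens[1:] = np.cumsum([len(s) for s in sequences])

print(input_ids)   # 11 packed token IDs
print(cu_seqlens)  # [ 0  4  9 11]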

Note that this is an experimental feature: only Xavier and later GPUs are supported, and there is neither FP32 support nor INT8 PTQ calibration.

  1. Download the checkpoint for a BERT Large FP16 SQuAD v1.1 model with a sequence length of 384:

    bash scripts/download_model.sh pyt v1_1
  2. Build an engine:

    FP16 engine

    mkdir -p engines && python3 builder_varseqlen.py -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx -o engines/bert_varseq_fp16.engine -b 1 -s 64 --fp16 -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt

    This will build an engine with a maximum batch size of 1 (-b 1) and a sequence length of 64 (-s 64) using FP16 precision computation where possible (--fp16).

    INT8 engine

    mkdir -p engines && python3 builder_varseqlen.py -x models/fine-tuned/bert_pyt_onnx_large_qa_squad11_amp_fake_quant_v1/bert_large_v1_1_fake_quant.onnx -o engines/bert_varseq_int8.engine -b 1 -s 256 --int8 --fp16 -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt

    This will build an engine with a maximum batch size of 1 (-b 1) and a sequence length of 256 (-s 256) using INT8 precision computation where possible (--int8).

  3. Run inference

    Evaluate the F1 score and exact match score using the SQuAD dataset:

    python3 inference_varseqlen.py -e engines/bert_varseq_int8.engine -s 256 -sq ./squad/dev-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -o ./predictions.json
    python3 squad/evaluate-v1.1.py  squad/dev-v1.1.json  ./predictions.json 90

    Run the question and answer mode:

    python3 inference_varseqlen.py -e engines/bert_varseq_int8.engine -p "TensorRT is a high performance deep learning inference platform that delivers low latency and high throughput for apps such as recommenders, speech and image/video on NVIDIA GPUs. It includes parsers to import models, and plugins to support novel ops and layers before applying optimizations for inference. Today NVIDIA is open-sourcing parsers and plugins in TensorRT so that the deep learning community can customize and extend these components to take advantage of powerful TensorRT optimizations for your apps." -q "What is TensorRT?" -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -s 256
  4. Collect performance data

    python3 perf_varseqlen.py -e engines/bert_varseq_int8.engine -b 1 -s 256

    This will collect performance data using batch size 1 (-b 1) and a sequence length of 256 (-s 256).

  5. Collect performance data with CUDA graph enabled

    The same inference_c.py and build/perf can be used to collect performance data with CUDA Graph enabled. The command line is the same as when running without variable sequence length.

Sparsity with Quantization Aware Training

Fine-grained 2:4 structured sparsity support introduced in NVIDIA Ampere GPUs can produce significant performance gains in BERT inference. The network is first trained using dense weights, then fine-grained structured pruning is applied, and finally the remaining non-zero weights are fine-tuned with additional training steps. This method results in virtually no loss in inferencing accuracy.
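
Fine-grained 2:4 sparsity means that in every contiguous block of four weights, at most two values are non-zero. The NumPy sketch below illustrates magnitude-based 2:4 pruning on a random matrix, grouping along rows for simplicity; it is only an illustration of the pattern, since the actual pruning and fine-tuning happen during training, before the weights reach this demo.

import numpy as np

def prune_2_4(weights):
    """Zero out the two smallest-magnitude values in every group of four weights along each row."""
    rows, cols = weights.shape
    assert cols % 4 == 0, "2:4 pruning groups weights in fours"
    groups = weights.copy().reshape(rows, cols // 4, 4)
    # Indices of the two smallest |w| within each group of four.
    drop = np.argsort(np.abs(groups), axis=-1)[..., :2]
    np.put_along_axis(groups, drop, 0.0, axis=-1)
    return groups.reshape(rows, cols)

w = np.random.randn(4, 8).astype(np.float32)
w_sparse = prune_2_4(w)
# Every group of four now contains exactly two non-zeros (50% structured sparsity).
print((w_sparse.reshape(4, -1, 4) != 0).sum(axis=-1))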

Using INT8 precision with quantization scales obtained from Post-Training Quantization (PTQ) can produce additional performance gains, but may also result in accuracy loss. Alternatively, for PyTorch-trained models, NVIDIA PyTorch-Quantization toolkit can be leveraged to perform quantized fine tuning (a.k.a. Quantization Aware Training or QAT) and generate the INT8 quantization scales as part of training. This generally results in higher accuracy compared to PTQ.

To demonstrate the potential speedups from these optimizations in demoBERT, we provide the Megatron-LM transformer model fine-tuned for the SQuAD 2.0 task with sparsity and quantization.

The sparse weights are generated by fine-tuning with the INT8 Quantization Aware Training recipe. This feature can be used with the fixed or variable sequence length implementations by passing the -sp flag to the demoBERT builder.

Megatron-LM for Question Answering

Example: Megatron-LM Large SQuAD v2.0 with sparse weights for sequence length 384

Build the TensorRT engine:

Options specified:

  • --megatron : assume Megatron style residuals instead of vanilla BERT.
  • --pickle : specify a pickle file containing the PyTorch statedict corresponding to fine-tuned Megatron model.
  • -sp : enable sparsity during engine optimization and treat the weights as sparse.
  • --int8 --il : enable int8 tactics/plugins with interleaving.
bash ./scripts/download_model.sh 384 v1_1 # BERT-large model checkpoint fine-tuned for SQuAD 1.1
bash ./scripts/download_model.sh pyt megatron-large int8-qat sparse # Megatron-LM model weights
export CKPT_PATH=models/fine-tuned/bert_pyt_statedict_megatron_sparse_int8qat_v21.03.0/bert_pyt_statedict_megatron_sparse_int8_qat
mkdir -p engines && python3 builder_varseqlen.py -c models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1 -b 1 -s 384 -o engines/megatron_large_seqlen384_int8qat_sparse.engine --fp16 --int8 --strict -il --megatron --pickle $CKPT_PATH -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -sp

Ask a question:

python3 inference_varseqlen.py -e engines/megatron_large_seqlen384_int8qat_sparse.engine -p "TensorRT is a high performance deep learning inference platform that delivers low latency and high throughput for apps such as recommenders, speech and image/video on NVIDIA GPUs. It includes parsers to import models, and plugins to support novel ops and layers before applying optimizations for inference. Today NVIDIA is open-sourcing parsers and plugins in TensorRT so that the deep learning community can customize and extend these components to take advantage of powerful TensorRT optimizations for your apps." -q "What is TensorRT?" -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -s 256

Evaluate F1 score:

python3 inference_varseqlen.py -e engines/megatron_large_seqlen384_int8qat_sparse.engine -s 384 -sq ./squad/dev-v1.1.json -v models/fine-tuned/bert_tf_ckpt_large_qa_squad2_amp_384_v19.03.1/vocab.txt -o ./predictions.json
python3 squad/evaluate-v1.1.py  squad/dev-v1.1.json  ./predictions.json 90

Expected output:

&&&& PASSED TensorRT BERT Squad Accuracy matches reference.
{"exact_match": 84.03973509933775, "f1": 90.88667129897755}

Performance

Benchmarking

The following section shows how to run the inference benchmarks for BERT.

TensorRT inference benchmark

The inference benchmark is performed on a single GPU by the inference_benchmark.sh script, which takes the following steps for each set of model parameters:

  1. Downloads checkpoints and builds a TensorRT engine if it does not already exist.

  2. Runs 100 warm-up iterations, then runs inference for 1,000 to 2,000 iterations for each batch size specified in the script, selecting the best profile for each batch size.

Note: The time measurements do not include the time required to copy inputs to the device and copy outputs to the host.

To run the inference benchmark script, run:

bash scripts/inference_benchmark.sh --gpu <arch>

Options for <arch> are: 'Volta', 'Xavier', 'Turing', 'Ampere'

Note: Some of the configurations in the benchmark script require 16GB of GPU memory. On GPUs with smaller amounts of memory, parts of the benchmark may fail to run.

Also note that BERT Large engines, especially when using mixed precision with large batch sizes and sequence lengths, may take a couple of hours to build.
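
The Results tables below report the average, 95th percentile, and 99th percentile latency for each configuration. The following is a simplified illustration of how such summary statistics can be derived from per-iteration timings, not the exact code in perf.py.

import numpy as np

# Stand-in per-iteration latencies in milliseconds; the real benchmark times engine executions.
latencies_ms = np.random.gamma(shape=9.0, scale=0.25, size=2000)

average = latencies_ms.mean()
p95 = np.percentile(latencies_ms, 95)
p99 = np.percentile(latencies_ms, 99)
print(f"average {average:.2f} ms, 95th percentile {p95:.2f} ms, 99th percentile {p99:.2f} ms")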

Results

The following sections provide details on how we achieved our inference performance.

Inference performance: NVIDIA A100 (40GB)

Results were obtained by running scripts/inference_benchmark.sh --gpu Ampere on an NVIDIA A100 (40GB).

BERT Base
Sequence Length  Batch Size  INT8 95th Percentile (ms)  INT8 99th Percentile (ms)  INT8 Average (ms)  FP16 95th Percentile (ms)  FP16 99th Percentile (ms)  FP16 Average (ms)
128 1 0.68 0.68 0.55 0.67 0.79 0.63
128 2 0.60 0.76 0.60 0.91 0.91 0.73
128 4 0.73 0.93 0.73 1.19 1.19 0.94
128 8 1.21 1.21 0.96 1.31 1.31 1.31
128 12 1.20 1.52 1.20 1.72 1.72 1.71
128 16 1.34 1.72 1.35 2.07 2.32 2.06
128 24 1.82 1.82 1.82 3.02 3.08 3.02
128 32 2.24 2.24 2.24 3.91 3.91 3.89
128 64 4.15 4.19 4.12 7.62 7.64 7.57
128 128 8.11 8.12 8.03 15.34 15.38 15.21
384 1 1.13 1.13 1.13 1.24 1.60 1.25
384 2 1.31 1.31 1.31 1.54 1.54 1.54
384 4 1.66 1.66 1.66 2.08 2.08 2.08
384 8 2.21 2.21 2.21 3.37 3.37 3.32
384 12 3.32 3.32 3.32 4.78 4.82 4.77
384 16 4.01 4.01 4.00 6.37 6.37 6.36
384 24 5.70 5.70 5.70 9.34 9.39 9.29
384 32 7.63 7.63 7.63 12.99 13.03 12.85
384 64 14.86 14.87 14.72 24.89 25.12 24.70
384 128 28.96 28.96 28.69 48.93 49.02 48.59
BERT Large
Sequence Length  Batch Size  INT8 95th Percentile (ms)  INT8 99th Percentile (ms)  INT8 Average (ms)  FP16 95th Percentile (ms)  FP16 99th Percentile (ms)  FP16 Average (ms)
128 1 1.39 1.39 1.24 1.54 1.55 1.54
128 2 1.42 1.42 1.41 1.82 1.82 1.82
128 4 1.78 1.95 1.79 2.50 2.50 2.50
128 8 2.64 2.64 2.64 3.97 3.97 3.97
128 12 3.09 3.09 3.09 5.02 5.03 4.99
128 16 4.03 4.03 4.03 6.93 6.93 6.86
128 24 5.28 5.31 5.28 9.64 9.65 9.56
128 32 7.01 7.01 6.95 12.95 13.07 12.86
128 64 12.84 12.86 12.72 24.80 25.05 24.68
128 128 25.26 25.27 25.01 49.09 49.25 48.71
384 1 2.55 2.55 2.55 2.96 2.96 2.95
384 2 3.04 3.04 3.04 3.90 3.90 3.90
384 4 4.01 4.02 4.01 5.74 5.80 5.68
384 8 7.18 7.18 7.17 10.98 11.00 10.91
384 12 9.15 9.15 9.14 15.43 15.44 15.33
384 16 12.28 12.29 12.28 21.13 21.14 20.90
384 24 17.67 17.67 17.56 30.98 31.07 30.71
384 32 23.22 23.23 23.02 41.22 41.28 40.63
384 64 45.16 45.30 44.83 79.64 79.98 79.24
384 128 87.81 87.82 87.73 156.66 157.03 155.65
Megatron Large with Sparsity
Sequence Length  Batch Size  INT8 QAT 95th Percentile (ms)  INT8 QAT 99th Percentile (ms)  INT8 QAT Average (ms)
128 1 1.12 1.41 1.13
128 2 1.37 1.70 1.38
128 4 1.77 1.78 1.77
128 8 2.54 2.54 2.53
128 12 3.13 3.13 3.12
128 16 3.99 3.99 3.98
128 24 4.90 4.90 4.90
128 32 7.04 7.06 7.00
128 64 11.62 11.63 11.61
128 128 21.24 21.34 21.12
384 1 1.71 2.15 1.71
384 2 2.21 2.21 2.21
384 4 3.63 3.64 3.63
384 8 5.74 5.74 5.73
384 12 8.22 8.23 8.21
384 16 10.33 10.33 10.31
384 24 14.52 14.52 14.51
384 32 18.72 18.73 18.71
384 64 35.79 35.81 35.50
384 128 67.72 67.86 67.55