In [None]:
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

<img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# Accelerating HuggingFace BART Inference with TensorRT

BART is a denoising sequence-to-sequence (encoder-decoder) model. It pairs a bidirectional, BERT-like encoder with an autoregressive, GPT-like decoder, and is pretrained by corrupting text and learning to reconstruct the original. This makes it well suited to fine-tuning on a wide variety of NLP tasks such as translation, classification, Q&A and summarization.

This notebook shows easy steps to convert a [HuggingFace PyTorch BART model](https://huggingface.co/docs/transformers/model_doc/bart) to a TensorRT engine for high-performance inference, with performance comparison between PyTorch and TensorRT inference.

1. [Download HuggingFace BART model](#1)
1. [PyTorch HuggingFace Inference](#2)
1. [TensorRT Engine Building](#3)
1. [TensorRT Inference](#4)


## Prerequisites

Follow the instructions at https://github.com/NVIDIA/TensorRT to build the TensorRT-OSS docker container required to run this notebook.

Next, we install some extra dependencies.

In [None]:
#%%capture
!pip3 install -r ../requirements.txt
!pip3 install ipywidgets

**Note:** After this step, you should restart the Jupyter kernel for the change to take effect.

In [None]:
import os
import sys
ROOT_DIR = os.path.abspath("../")
sys.path.append(ROOT_DIR)

# disable warning in notebook
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# notebook widgets
import ipywidgets as widgets
widget_style = {'description_width': 'initial'}
widget_layout = widgets.Layout(width='auto')

import torch
import tensorrt as trt
from tensorrt import PreviewFeature
from polygraphy.backend.trt import Profile

import numpy as np
import time

# huggingface
from transformers import (
    AutoModelForPreTraining,
    AutoTokenizer,
    AutoConfig,
)

# BART
from BART.BARTModelConfig import BARTModelTRTConfig, BARTMetadata
from BART.measurements import encoder_inference, decoder_inference, full_inference_greedy, full_inference_beam
from BART.export import BARTEncoderTorchFile, BARTDecoderTorchFile
from BART.export import BARTDecoderONNXFile, BARTEncoderONNXFile
from BART.trt import BARTTRTEncoder, BARTTRTDecoder

# NNDF
from NNDF.networks import NetworkMetadata, Precision
from NNDF.networks import TimingProfile
from NNDF.general_utils import measure_python_inference_code
from NNDF.torch_utils import expand_inputs_for_beam_search

<a id="1"></a>

## 1. Download HuggingFace BART model

First, we download the original HuggingFace PyTorch BART model from the HuggingFace Model Hub, together with its associated tokenizer.

The BART variants supported by TensorRT are: facebook/bart-base (139M), facebook/bart-large (406M), facebook/bart-large-cnn (406M) and facebook/mbart-large-50 (680M).

### Model and Inference Configuration

In [None]:
# UI: pick the BART checkpoint that the rest of the notebook downloads, exports and benchmarks.
# The first entry is the default selection.
_bart_variant_choices = [
    'facebook/bart-base',
    'facebook/bart-large',
    'facebook/bart-large-cnn',
    'facebook/mbart-large-50',
]
model_widget = widgets.Select(
    options=_bart_variant_choices,
    value=_bart_variant_choices[0],
    description='Model variant:',
    disabled=False,
    style=widget_style,
    layout=widget_layout,
)
display(model_widget)

In [None]:
BART_VARIANT = model_widget.value


def _make_checkbox(description, value, disabled=False):
    """Build a non-indented Checkbox that uses the notebook-wide widget style and layout."""
    return widgets.Checkbox(
        value=value,
        description=description,
        disabled=disabled,
        indent=False,
        style=widget_style,
        layout=widget_layout,
    )


def _make_int_box(description, value, lower=1):
    """Build an integer input bounded to [lower, 100000] with step 1, styled like the other widgets."""
    return widgets.BoundedIntText(
        value=value,
        min=lower,
        max=100000,
        step=1,
        description=description,
        disabled=False,
        style=widget_style,
        layout=widget_layout,
    )


# Feature / precision toggles
preview_dynamic_feature_widget = _make_checkbox('Preview 8.5 EA dynamic shapes feature', True)
FP16_widget = _make_checkbox('FP16', False)
HF_KV_widget = _make_checkbox('HuggingFace KV cache', True)
# TRT KV cache is shown but locked off (disabled=True), as its description explains.
TRT_KV_widget = _make_checkbox(
    'TensorRT KV cache (disabled due to performance improvements in progress, not beating non-KV version yet)',
    False,
    disabled=True,
)
KV_widgets = widgets.HBox([HF_KV_widget, TRT_KV_widget])

# Shape / generation parameters; length and hidden-size defaults come from the per-variant config table.
batch_size_widget = _make_int_box('Batch size', 1)
max_input_len_widget = _make_int_box('Max input length', BARTModelTRTConfig.MAX_SEQUENCE_LENGTH[BART_VARIANT])
min_output_len_widget = _make_int_box('Min output length', BARTModelTRTConfig.MIN_OUTPUT_LENGTH[BART_VARIANT], lower=0)
max_output_len_widget = _make_int_box('Max output length', BARTModelTRTConfig.MAX_OUTPUT_LENGTH[BART_VARIANT])
encoder_hidden_size_widget = _make_int_box('Encoder hidden size', BARTModelTRTConfig.ENCODER_HIDDEN_SIZE[BART_VARIANT])
num_beam_widget = _make_int_box('Number of beams', 1)

widgets_all = widgets.VBox([
    FP16_widget,
    preview_dynamic_feature_widget,
    KV_widgets,
    batch_size_widget,
    max_input_len_widget,
    min_output_len_widget,
    max_output_len_widget,
    encoder_hidden_size_widget,
    num_beam_widget,
])

display(widgets_all)

In [None]:
# Inference config: snapshot the widget values (re-run this cell after changing a widget).
# NOTE(review): FP16 is read here, but the later cells benchmark both FP32 and FP16 regardless of it.
FP16, preview_dynamic_shapes, HF_KV, TRT_KV = (
    FP16_widget.value,                     # flag to use FP16 precision in PyTorch & TRT
    preview_dynamic_feature_widget.value,  # flag to preview the TensorRT 8.5 EA dynamic shapes feature
    HF_KV_widget.value,                    # flag to use KV cache in HF
    TRT_KV_widget.value,                   # flag to use KV cache in TRT
)

# Model config
batch_size, max_input_len, min_output_len, max_output_len, encoder_hidden_size, num_beams = (
    widget.value
    for widget in (
        batch_size_widget,
        max_input_len_widget,
        min_output_len_widget,
        max_output_len_widget,
        encoder_hidden_size_widget,
        num_beam_widget,
    )
)

# Benchmark config
# `TimingProfile` is a named tuple that specifies the number of timed iterations, the number of
# calls per iteration, the number of warm-up calls, a `duration` setting, and the percentiles to report.
timing_profile = TimingProfile(iterations=10, number=1, warmup=1, duration=0, percentile=[50,99])

def percentile_print(timing, percentiles=None):
    """Format per-percentile latencies as a readable string, e.g. ``"p50 12.34ms, p99 15.67ms"``.

    Args:
        timing: sequence of latencies in seconds, one per percentile, in the same order
            as ``percentiles``.
        percentiles: percentile labels matching ``timing``. Defaults to
            ``timing_profile.percentile`` (the notebook's benchmark config), so existing
            one-argument calls produce the same output as before.

    Returns:
        Comma-separated ``p<percentile> <milliseconds>ms`` entries with two decimals.
    """
    if percentiles is None:
        percentiles = timing_profile.percentile
    return ', '.join(
        'p{} {:.2f}ms'.format(pct, seconds * 1000)
        for pct, seconds in zip(percentiles, timing)
    )

In [None]:
# The HF AutoClasses cannot resolve the mbart variant yet, so it is loaded through its
# explicit model/tokenizer classes; every other variant goes through the AutoClasses.
if "mbart" in BART_VARIANT:
    from transformers import MBartForConditionalGeneration, MBart50Tokenizer
    bart_model = MBartForConditionalGeneration.from_pretrained(BART_VARIANT)
    tokenizer = MBart50Tokenizer.from_pretrained(BART_VARIANT, src_lang="en_XX")
else:
    bart_model = AutoModelForPreTraining.from_pretrained(BART_VARIANT)  # BartForConditionalGeneration
    tokenizer = AutoTokenizer.from_pretrained(BART_VARIANT)  # BartTokenizer

config = AutoConfig.from_pretrained(BART_VARIANT)

# Move to GPU and switch to inference mode (disables dropout).
bart_model = bart_model.to('cuda').eval()

In [None]:
# save model locally (skipped if a previous run already populated the directory)
pytorch_model_dir = './models/{}/pytorch'.format(BART_VARIANT)
# Pure-Python equivalent of `mkdir -p`: portable and avoids unquoted shell interpolation of the path.
os.makedirs(pytorch_model_dir, exist_ok=True)

# The directory is guaranteed to exist now, so only its contents decide whether to save.
if os.listdir(pytorch_model_dir):
    print('PyTorch model already exists. Skipping...')
else:
    bart_model.save_pretrained(pytorch_model_dir)
    print("PyTorch model saved to {}".format(pytorch_model_dir))

### Test Input Data

In [None]:
# input sequence
inputs = "NVIDIA TensorRT-based applications perform up to 36X faster than CPU-only platforms during inference, enabling developers to optimize neural network models trained on all major frameworks, calibrate for lower precision with high accuracy, and deploy to hyperscale data centers, embedded platforms, or automotive product platforms. TensorRT, built on the NVIDIA CUDA parallel programming model, enables developers to optimize inference by leveraging libraries, development tools, and technologies in CUDA-X for AI, autonomous machines, high performance computing, and graphics. With new NVIDIA Ampere Architecture GPUs, TensorRT also uses sparse tensor cores for an additional performance boost."

# Replicate the prompt `batch_size` times so the input's batch dimension matches the one the
# TensorRT optimization profiles and the decoder start tokens are built with (for batch_size == 1
# this is identical to tokenizing the single string). Truncate to `max_input_len` so the input
# never exceeds the encoder profile's maximum sequence length.
input_ids = tokenizer(
    [inputs] * batch_size,
    padding=True,
    truncation=True,
    max_length=max_input_len,
    return_tensors="pt",
).input_ids.to('cuda')

<a id="2"></a>

## 2. PyTorch HuggingFace Inference

Next, we will carry out inference with the HuggingFace PyTorch model as a baseline.

### End-to-End HuggingFace Inference

In [None]:
def _hf_generate(use_cache):
    """Run HF `generate` on the test input with the notebook's decoding settings."""
    return bart_model.generate(
        input_ids,
        max_length=max_output_len,
        min_length=min_output_len,
        num_beams=num_beams,
        use_cache=use_cache,
    )


# encoder-decoder inference: reference output (no KV cache) that later results are compared against
with torch.no_grad():
    output_ids = _hf_generate(use_cache=False)
    outputs = tokenizer.decode(output_ids[-1, :], skip_special_tokens=True)
outputs_hf = outputs

# timing: the same generate call with and without KV cache, first in FP32 then in FP16.
# .float()/.half() convert bart_model in place, so the model is left in FP16 after this cell.
bart_model.float()
hf_nonkv_time = measure_python_inference_code(lambda: _hf_generate(use_cache=False), timing_profile)
hf_kv_time = measure_python_inference_code(lambda: _hf_generate(use_cache=True), timing_profile)

bart_model.half()
hf_nonkv_time_fp16 = measure_python_inference_code(lambda: _hf_generate(use_cache=False), timing_profile)
hf_kv_time_fp16 = measure_python_inference_code(lambda: _hf_generate(use_cache=True), timing_profile)

In [None]:
# Report: input/output text, then HF latency per precision with and without KV cache.
print(f'Input length: {input_ids.size(1)}')
print(inputs)
print('\n')
print(f'Output length: {output_ids[-1,:].size(0)}')
print(outputs_hf)
print('\n')
print(f'Device: {torch.cuda.get_device_name()}')

hf_timings_by_precision = (
    ("FP32", hf_nonkv_time, hf_kv_time),
    ("FP16", hf_nonkv_time_fp16, hf_kv_time_fp16),
)
for precision_name, no_cache_timing, cache_timing in hf_timings_by_precision:
    print(f"Precision: {precision_name}, Number of Beams: {num_beams}")
    print(f"HF time (no KV cache): {percentile_print(no_cache_timing)}")
    print(f"HF time (w/ KV cache): {percentile_print(cache_timing)}")

### Time Measurement of Encoder, Decoder, and Full E2E
For benchmarking purposes, we will employ the helper functions `encoder_inference`, `decoder_inference`, `full_inference_greedy` and `full_inference_beam`. They repeatedly execute inference for the BART encoder and decoder stacks separately, as well as end-to-end for the entire output sequence, and measure the execution time. These execution times can later be compared with their TensorRT counterparts to demonstrate the speedup.

The encoder and decoder of BART are wrapped as standalone PyTorch modules for testing.

In [None]:
# FP32
# Benchmark the PyTorch encoder, decoder and full generation separately, first in FP32 and then in FP16.
# NOTE: .float()/.half() convert bart_model in place and the TorchModule wrappers are built from
# bart_model's own submodules, so the FP32 wrappers presumably run in FP16 after the `.half()` below.
# They are only used before that conversion here; keep this statement order.
bart_model.float()
bart_torch_encoder = BARTEncoderTorchFile.TorchModule(bart_model.get_encoder())
bart_torch_decoder = BARTDecoderTorchFile.TorchModule(bart_model.get_decoder(), bart_model.lm_head, bart_model.final_logits_bias, bart_model.config)

with torch.no_grad():

    # Encoder timing; also yields the hidden states the decoder benchmark consumes.
    encoder_last_hidden_state, encoder_pytorch_time = encoder_inference(bart_torch_encoder, input_ids, timing_profile)
    # Decoder timing. With beam search, both the ids and the encoder states are expanded for the beams
    # (presumably repeating each row num_beams times — see expand_inputs_for_beam_search).
    _, decoder_pytorch_time = decoder_inference(bart_torch_decoder, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    # Full E2E generation: greedy search for a single beam, beam search otherwise.
    if num_beams == 1:
        output_ids, full_pytorch_time = full_inference_greedy(bart_torch_encoder,bart_torch_decoder,input_ids,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    else:
        output_ids, full_pytorch_time = full_inference_beam(bart_torch_encoder,bart_torch_decoder,input_ids,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    # Only the first sequence is decoded for the accuracy comparison.
    outputs = tokenizer.decode(output_ids[0], skip_special_tokens=True)

outputs_pytorch = outputs

# FP16
# Same benchmark after converting bart_model to half precision in place.
bart_model.half()
bart_torch_encoder_fp16 = BARTEncoderTorchFile.TorchModule(bart_model.get_encoder())
bart_torch_decoder_fp16 = BARTDecoderTorchFile.TorchModule(bart_model.get_decoder(), bart_model.lm_head, bart_model.final_logits_bias, bart_model.config)

with torch.no_grad():

    encoder_last_hidden_state, encoder_pytorch_time_fp16 = encoder_inference(bart_torch_encoder_fp16, input_ids, timing_profile)
    _, decoder_pytorch_time_fp16 = decoder_inference(bart_torch_decoder_fp16, expand_inputs_for_beam_search(input_ids, num_beams) if num_beams > 1 else input_ids, expand_inputs_for_beam_search(encoder_last_hidden_state, num_beams) if num_beams > 1 else encoder_last_hidden_state, timing_profile, use_cache=HF_KV)
    if num_beams == 1:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_greedy(bart_torch_encoder_fp16,bart_torch_decoder_fp16,input_ids,tokenizer,timing_profile,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    else:
        output_ids_fp16, full_pytorch_time_fp16 = full_inference_beam(bart_torch_encoder_fp16,bart_torch_decoder_fp16,input_ids,tokenizer,timing_profile,num_beams=num_beams,max_length=max_output_len, min_length=min_output_len, use_cache=HF_KV)
    outputs_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)

outputs_pytorch_fp16 = outputs_fp16

In [None]:
# Report: exact-match check against the HF output, then per-stage PyTorch latency per precision.
for precision_name, decoded_text in (("FP32", outputs_pytorch), ("FP16", outputs_pytorch_fp16)):
    print(f'PyTorch {precision_name} Output identical to HF results? {decoded_text == outputs_hf}')
print('\n')
print(f'Device: {torch.cuda.get_device_name()}')

pytorch_stage_timings = (
    ("FP32", encoder_pytorch_time, decoder_pytorch_time, full_pytorch_time),
    ("FP16", encoder_pytorch_time_fp16, decoder_pytorch_time_fp16, full_pytorch_time_fp16),
)
for precision_name, enc_timing, dec_timing, e2e_timing in pytorch_stage_timings:
    print(f"Precision: {precision_name}, Number of Beams: {num_beams}")
    print(f"Encoder time: {percentile_print(enc_timing)}")
    print(f"Decoder time: {percentile_print(dec_timing)}")
    print(f"Full E2E time: {percentile_print(e2e_timing)}")

<a id="3"></a>

## 3. TensorRT Engine Building

### Convert PyTorch to ONNX

Prior to converting the model to a TensorRT engine, we will first convert the PyTorch model to an intermediate universal format.

ONNX is an open format for machine learning and deep learning models. It allows you to convert deep learning and machine learning models from different frameworks such as TensorFlow, PyTorch, MATLAB, Caffe, and Keras to a single format.

The steps to convert a PyTorch model to TensorRT are as follows:
- Convert the pretrained PyTorch model into ONNX.
- Import the ONNX model into TensorRT, apply optimizations and generate a TensorRT engine.
- Perform inference on the GPU using the engine. 

For the BART model, we will convert the encoder and decoder to ONNX and build each engine separately. The rationale for this separate building approach comes from the nature of sequence-to-sequence models; BART and T5 are good examples of sequence-to-sequence models that use an encoder-decoder architecture. The encoder is executed only once on the input and generates hidden states. The decoder is then executed repeatedly in an auto-regressive manner until the entire output has been generated, i.e. the output sequence length is the number of times the decoder runs. The most efficient way to run encoder-decoder models with TensorRT is therefore to have two separate engines.

In [None]:
# Export encoder and decoder (with LM head) to ONNX, once per precision.
# The export itself always runs from the FP32 model on CPU; the FP16 pass differs only in the
# metadata it carries, which determines the file names and the later FP16 engine build.
onnx_model_path = './models/{}/onnx'.format(BART_VARIANT)
os.makedirs(onnx_model_path, exist_ok=True)

# FP32
bart_model.float()
metadata = NetworkMetadata(variant=BART_VARIANT, precision=Precision(fp16=False), other=BARTMetadata(kv_cache=TRT_KV))
trt_config = BARTModelTRTConfig()
metadata_string = trt_config.get_metadata_string(metadata)

encoder_onnx_model_fpath = metadata_string + "-encoder.onnx"
decoder_onnx_model_fpath = metadata_string + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step
bart_torchfile_encoder = BARTEncoderTorchFile(bart_model.to('cpu'), metadata)
bart_torchfile_decoder = BARTDecoderTorchFile(bart_model.to('cpu'), metadata)

onnx_bart_encoder = bart_torchfile_encoder.as_onnx_model(os.path.join(onnx_model_path, encoder_onnx_model_fpath), force_overwrite=False)
onnx_bart_decoder = bart_torchfile_decoder.as_onnx_model(os.path.join(onnx_model_path, decoder_onnx_model_fpath), force_overwrite=False)

# FP16
metadata_fp16 = NetworkMetadata(variant=BART_VARIANT, precision=Precision(fp16=True), other=BARTMetadata(kv_cache=TRT_KV))
trt_config_fp16 = BARTModelTRTConfig()
metadata_string_fp16 = trt_config_fp16.get_metadata_string(metadata_fp16)

encoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-encoder.onnx"
decoder_onnx_model_fpath_fp16 = metadata_string_fp16 + "-decoder-with-lm-head.onnx"

# for onnx conversion, ensure model is on CPU and FP32 precision in this step;
# the wrappers carry the FP16 metadata so they agree with the FP16 file names above.
bart_torchfile_encoder = BARTEncoderTorchFile(bart_model.to('cpu'), metadata_fp16)
bart_torchfile_decoder = BARTDecoderTorchFile(bart_model.to('cpu'), metadata_fp16)

onnx_bart_encoder_fp16 = bart_torchfile_encoder.as_onnx_model(os.path.join(onnx_model_path, encoder_onnx_model_fpath_fp16), force_overwrite=False)
onnx_bart_decoder_fp16 = bart_torchfile_decoder.as_onnx_model(os.path.join(onnx_model_path, decoder_onnx_model_fpath_fp16), force_overwrite=False)

### Convert ONNX to TensorRT

Now we are ready to parse the ONNX encoder and decoder models and convert them to optimized TensorRT engines.

Since the models contain dynamic input shapes, we can specify a valid input range with a TensorRT optimization profile.

In [None]:
tensorrt_model_path = './models/{}/tensorrt'.format(BART_VARIANT)
os.makedirs(tensorrt_model_path, exist_ok=True)

# TensorRT requires min <= opt <= max for every profile dimension. Clamp the "opt" sequence
# length to at least 1 so that a max length of 1 (allowed by the widgets) still gives a valid profile.
encoder_opt_len = max(1, max_input_len // 2)
decoder_opt_len = max(1, max_output_len // 2)
# The decoder processes every beam of every batch entry, hence batch_size * num_beams rows.
decoder_batch = batch_size * num_beams

# Encoder optimization profiles
encoder_profile = Profile()
encoder_profile.add(
    "input_ids",
    min=(batch_size, 1),
    opt=(batch_size, encoder_opt_len),
    max=(batch_size, max_input_len),
)

# Decoder optimization profiles
decoder_profile = Profile()
decoder_profile.add(
    "input_ids",
    min=(decoder_batch, 1),
    opt=(decoder_batch, decoder_opt_len),
    max=(decoder_batch, max_output_len),
)
decoder_profile.add(
    "encoder_hidden_states",
    min=(decoder_batch, 1, encoder_hidden_size),
    opt=(decoder_batch, encoder_opt_len, encoder_hidden_size),
    max=(decoder_batch, max_input_len, encoder_hidden_size),
)

In [None]:
# Engine file names encode batch size, beam count and the preview-feature setting, so engines
# built with different settings get distinct file names.
engine_tag = f"bs{batch_size}"
if num_beams > 1:
    engine_tag += "-beam{}".format(num_beams)
if preview_dynamic_shapes:
    engine_tag += "-previewFasterDynamicShapes"

preview_features = [PreviewFeature.FASTER_DYNAMIC_SHAPES_0805] if preview_dynamic_shapes else []


def _build_trt_engine(onnx_file_cls, onnx_fname, network_metadata, profile):
    """Build (or load) the TensorRT engine for one exported ONNX file.

    The engine path is the ONNX file name under `tensorrt_model_path`, suffixed with the engine tag.
    """
    onnx_file = onnx_file_cls(os.path.join(onnx_model_path, onnx_fname), network_metadata)
    engine_fpath = os.path.join(tensorrt_model_path, onnx_fname) + f"-{engine_tag}.engine"
    return onnx_file.as_trt_engine(
        engine_fpath,
        profiles=[profile],
        preview_features=preview_features
    )


# FP32
bart_trt_encoder_engine = _build_trt_engine(BARTEncoderONNXFile, encoder_onnx_model_fpath, metadata, encoder_profile)
bart_trt_decoder_engine = _build_trt_engine(BARTDecoderONNXFile, decoder_onnx_model_fpath, metadata, decoder_profile)

# FP16
bart_trt_encoder_engine_fp16 = _build_trt_engine(BARTEncoderONNXFile, encoder_onnx_model_fpath_fp16, metadata_fp16, encoder_profile)
bart_trt_decoder_engine_fp16 = _build_trt_engine(BARTDecoderONNXFile, decoder_onnx_model_fpath_fp16, metadata_fp16, decoder_profile)

<a id="4"></a>

## 4. TensorRT Inference

Great, if you have reached this stage, it means we now have successfully built optimized TensorRT engines for the BART model, ready for us to carry out inference. The BART model with TensorRT backend can now be employed in place of the original HuggingFace BART model.

In [None]:
# Initialize TensorRT engines
# HF model config handed to the TRT runtime wrappers. NOTE: this rebinds `trt_config`, which the
# ONNX-export cell used for a BARTModelTRTConfig instance.
trt_config = AutoConfig.from_pretrained(BART_VARIANT)
trt_config.num_layers = BARTModelTRTConfig.NUMBER_OF_LAYERS[BART_VARIANT]
trt_config.use_cache = metadata.other.kv_cache

# FP32
bart_trt_encoder, bart_trt_decoder = (
    BARTTRTEncoder(bart_trt_encoder_engine, metadata, trt_config, batch_size=batch_size),
    BARTTRTDecoder(bart_trt_decoder_engine, metadata, trt_config, batch_size=batch_size, num_beams=num_beams),
)

# FP16
bart_trt_encoder_fp16, bart_trt_decoder_fp16 = (
    BARTTRTEncoder(bart_trt_encoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size),
    BARTTRTDecoder(bart_trt_decoder_engine_fp16, metadata_fp16, trt_config, batch_size=batch_size, num_beams=num_beams),
)

### End-to-End TensorRT Inference

In [None]:
from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)
from transformers.generation_beam_search import (
    BeamSearchScorer,
)

# Decoding configuration mirroring the HF baseline's generate() call.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_len)])
no_repeat_ngram_size = BARTModelTRTConfig.NO_REPEAT_NGRAM_SIZE
# Honour the user-selected minimum output length, exactly like the HF baseline
# (generate(min_length=min_output_len)). The per-variant default is already the widget's initial value.
min_length = min_output_len
eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)
bos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
    MinLengthLogitsProcessor(min_length, eos_token_id),
    ForcedBOSTokenLogitsProcessor(bos_token_id),
    ForcedEOSTokenLogitsProcessor(max_output_len, eos_token_id)
]) # by checking HuggingFace's generate() implementation carefully, the default logits processor for BART has no_repeat_ngram_size = 3 and forced_eos_token_id = 2. In this way we can ensure identical results with raw HuggingFace

# Decoding starts from one EOS token per sequence, expanded across beams for beam search.
decoder_initial_input = torch.full(
    (batch_size, 1), eos_token_id, dtype=torch.int32
).to('cuda')

if num_beams > 1:
    decoder_initial_input = expand_inputs_for_beam_search(decoder_initial_input, expand_size=num_beams)


def _run_trt_e2e(trt_encoder, trt_decoder, use_cache):
    """Generate output ids for `input_ids` with a TensorRT encoder/decoder pair.

    Runs the encoder once, hands its hidden states to the decoder, then decodes with greedy
    search (num_beams == 1) or beam search, using the stopping criteria and logits processors
    defined in this cell.
    """
    with torch.no_grad():
        encoder_last_hidden_states = trt_encoder(input_ids=input_ids)

        if num_beams > 1:
            encoder_last_hidden_states = expand_inputs_for_beam_search(encoder_last_hidden_states, expand_size=num_beams)

        trt_decoder.set_encoder_hidden_states_for_inference_cycle(encoder_last_hidden_states)

        if num_beams == 1:
            return trt_decoder.greedy_search(
                input_ids=decoder_initial_input,
                encoder_hidden_states=encoder_last_hidden_states,
                stopping_criteria=stopping_criteria,
                logits_processor=logits_processor,
                use_cache=use_cache,
                use_cuda=True
            )

        # beam scorer must be reset before each beam search run, otherwise beam search will be skipped due to scorer cache
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device="cuda",
            do_early_stopping=True,
        )
        return trt_decoder.beam_search(
            input_ids=decoder_initial_input,
            beam_scorer=beam_scorer,
            encoder_hidden_states=encoder_last_hidden_states,
            stopping_criteria=stopping_criteria,
            logits_processor=logits_processor,
            use_cache=use_cache,
            use_cuda=True
        )


# FP32
def e2e_trt():
    """One end-to-end generation with the FP32 TensorRT engines."""
    return _run_trt_e2e(bart_trt_encoder, bart_trt_decoder, metadata.other.kv_cache)

output_ids = e2e_trt()
outputs_trt = tokenizer.decode(output_ids[0], skip_special_tokens=True)
trt_time = measure_python_inference_code(e2e_trt, timing_profile)

# FP16
def e2e_trt_fp16():
    """One end-to-end generation with the FP16 TensorRT engines."""
    return _run_trt_e2e(bart_trt_encoder_fp16, bart_trt_decoder_fp16, metadata_fp16.other.kv_cache)

output_ids_fp16 = e2e_trt_fp16()
outputs_trt_fp16 = tokenizer.decode(output_ids_fp16[0], skip_special_tokens=True)
trt_time_fp16 = measure_python_inference_code(e2e_trt_fp16, timing_profile)

In [None]:
# Report, per precision: engine used, exact match with the HF output, and E2E TRT latency.
print(f'Device: {torch.cuda.get_device_name()}')

trt_report_rows = (
    (metadata_string + '-' + engine_tag, outputs_trt, "FP32", trt_time),
    (metadata_string_fp16 + '-' + engine_tag, outputs_trt_fp16, "FP16", trt_time_fp16),
)
for row_index, (engine_name, decoded_text, precision_name, e2e_timing) in enumerate(trt_report_rows):
    if row_index:
        print()
    print(f"Using engine: {engine_name}")
    print(f'Output identical to HF results? {decoded_text == outputs_hf}')
    print(f"Precision: {precision_name}")
    print(f'TRT time: {percentile_print(e2e_timing)}')

### Time Measurement of Encoder, Decoder, and Full E2E
We will benchmark the encoder, decoder, and full end-to-end as we did for HuggingFace before.

In [None]:
def _benchmark_trt(trt_encoder, trt_decoder, use_cache):
    """Time a TensorRT encoder/decoder pair: encoder alone, decoder alone, and full E2E generation.

    Returns:
        (encoder_last_hidden_states, encoder_time, decoder_time, full_e2e_time), where the
        times are the percentile sequences produced by the NNDF measurement helpers.
    """
    encoder_hidden, encoder_time = encoder_inference(trt_encoder, input_ids, timing_profile)

    # With beam search, both decoder inputs are expanded to cover every beam.
    if num_beams > 1:
        decoder_input_ids = expand_inputs_for_beam_search(input_ids, num_beams)
        decoder_encoder_hidden = expand_inputs_for_beam_search(encoder_hidden, num_beams)
    else:
        decoder_input_ids = input_ids
        decoder_encoder_hidden = encoder_hidden
    _, decoder_time = decoder_inference(trt_decoder, decoder_input_ids, decoder_encoder_hidden, timing_profile)

    # Decode with the user-selected min_output_len, like the HF and PyTorch baselines.
    generation_kwargs = dict(
        max_length=max_output_len,
        min_length=min_output_len,
        batch_size=batch_size,
        use_cache=use_cache,
    )
    if num_beams == 1:
        _, full_time = full_inference_greedy(
            trt_encoder, trt_decoder, input_ids, tokenizer, timing_profile, **generation_kwargs
        )
    else:
        _, full_time = full_inference_beam(
            trt_encoder, trt_decoder, input_ids, tokenizer, timing_profile,
            num_beams=num_beams, early_stopping=True, **generation_kwargs
        )
    return encoder_hidden, encoder_time, decoder_time, full_time


# FP32
encoder_last_hidden_states, encoder_trt_time, decoder_trt_time, full_trt_time = _benchmark_trt(
    bart_trt_encoder, bart_trt_decoder, metadata.other.kv_cache
)
print(f'Encoder time: {percentile_print(encoder_trt_time)}')
print(f'Decoder time: {percentile_print(decoder_trt_time)}')
print(f'Full E2E time: {percentile_print(full_trt_time)}')

# FP16
encoder_last_hidden_states, encoder_trt_time_fp16, decoder_trt_time_fp16, full_trt_time_fp16 = _benchmark_trt(
    bart_trt_encoder_fp16, bart_trt_decoder_fp16, metadata_fp16.other.kv_cache
)
print(f'Encoder FP16 time: {percentile_print(encoder_trt_time_fp16)}')
print(f'Decoder FP16 time: {percentile_print(decoder_trt_time_fp16)}')
print(f'Full E2E FP16 time: {percentile_print(full_trt_time_fp16)}')

## Comparison

In [None]:
from tabulate import tabulate


def p50_ms(timings):
    """Format ``timings[0]`` (seconds; reported as p50 in the table header) as milliseconds with 2 decimals."""
    return f'{timings[0]*1000:.2f}'


# One row per framework/precision; '-' marks columns that were not measured.
data = [
    ['Framework', 'Precision', 'Encoder p50 (ms)', 'Decoder p50 (ms)', 'Full E2E p50 (ms)', 'Accuracy'],
    ['HuggingFace (w/o cache)', 'FP32', '-', '-', p50_ms(hf_nonkv_time), '-'],
    ['HuggingFace (w/ cache)', 'FP32', '-', '-', p50_ms(hf_kv_time), '-'],
    ['HuggingFace (w/o cache)', 'FP16', '-', '-', p50_ms(hf_nonkv_time_fp16), '-'],
    ['HuggingFace (w/ cache)', 'FP16', '-', '-', p50_ms(hf_kv_time_fp16), '-'],
    ['PyTorch', 'FP32', p50_ms(encoder_pytorch_time), p50_ms(decoder_pytorch_time), p50_ms(full_pytorch_time), outputs_pytorch == outputs_hf],
    ['PyTorch', 'FP16', p50_ms(encoder_pytorch_time_fp16), p50_ms(decoder_pytorch_time_fp16), p50_ms(full_pytorch_time_fp16), outputs_pytorch_fp16 == outputs_hf],
    ['TensorRT', 'FP32', p50_ms(encoder_trt_time), p50_ms(decoder_trt_time), p50_ms(full_trt_time), outputs_trt == outputs_hf],
    ['TensorRT', 'FP16', p50_ms(encoder_trt_time_fp16), p50_ms(decoder_trt_time_fp16), p50_ms(full_trt_time_fp16), outputs_trt_fp16 == outputs_hf],
]

print(tabulate(data, headers='firstrow', tablefmt='github'))

We can now compare the original HuggingFace model and the TensorRT engine in terms of both separate encoder/decoder latency and end-to-end latency. For the bart-base variant on an NVIDIA Titan V GPU, with input/output sequence lengths of around 130, FP16 inference yields about a 2x performance improvement.

## Variable Input/Output Length

We can run more tests by varying input/output length, while using the same engines.

Note that TensorRT performance depends on the selection of optimal kernels in the engine. The variable-length tests here reuse the same engine built with the maximum input/output length profile, and therefore may not represent the best achievable performance. If the use case has known input/output length ranges, it is highly recommended to specify them in the TensorRT engine profiles to ensure optimized kernel selection.

### Single example

In [None]:
# Ensure the HF model is on the GPU for testing (cells above moved it to the CPU).
bart_model = bart_model.to('cuda').eval()

in_len, out_len = 24, 24

data = [
    ['(input_len, output_len)', 'HF FP32 p50 (s)', 'HF FP16 p50 (s)', 'TRT FP32 p50 (s)', 'TRT FP16 p50 (s)'],
]

assert in_len <= max_input_len and out_len <= max_output_len

# The token ids are random benchmark input. A fixed-seed local generator makes
# re-runs benchmark the same ids without touching torch's global RNG state.
input_id_generator = torch.Generator().manual_seed(0)
in_ids = torch.randint(
    0, BARTModelTRTConfig.VOCAB_SIZE[BART_VARIANT], (batch_size, in_len), generator=input_id_generator
).to('cuda')

# HF: min_length == max_length forces exactly out_len generated tokens.
bart_model.float()
hf_32 = measure_python_inference_code(lambda: bart_model.generate(in_ids, min_length=out_len, max_length=out_len, num_beams=num_beams, use_cache=True), timing_profile)
bart_model.half()
hf_16 = measure_python_inference_code(lambda: bart_model.generate(in_ids, min_length=out_len, max_length=out_len, num_beams=num_beams, use_cache=True), timing_profile)
# Return the model to FP32 dtype so later cells don't silently run in half precision.
# Note: weight values keep FP16 precision after the half()/float() round trip.
bart_model.float()

# TRT
if num_beams == 1:
    _, trt_32 = full_inference_greedy(bart_trt_encoder, bart_trt_decoder, in_ids, tokenizer, timing_profile, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, use_cuda=True,)
    _, trt_16 = full_inference_greedy(bart_trt_encoder_fp16, bart_trt_decoder_fp16, in_ids, tokenizer, timing_profile, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, use_cuda=True,)
else:
    _, trt_32 = full_inference_beam(bart_trt_encoder, bart_trt_decoder, in_ids, tokenizer, timing_profile, num_beams=num_beams, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, early_stopping=True,)
    _, trt_16 = full_inference_beam(bart_trt_encoder_fp16, bart_trt_decoder_fp16, in_ids, tokenizer, timing_profile, num_beams=num_beams, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, early_stopping=True,)

data.append([(in_len, out_len), hf_32[0], hf_16[0], trt_32[0], trt_16[0]])

print(tabulate(data, headers='firstrow', tablefmt='github'))

### Several representative examples

In [None]:
# Ensure the HF model is on the GPU for testing (cells above moved it to the CPU).
bart_model = bart_model.to('cuda').eval()

input_output_len_list = [
    (64, 128), # generation task
    (64, 512),
    (512, 64), # summarization task
    (128, 64),
    (32, 32), # translation task
    (128, 128),
    (512, 512),
]

data = [
    ['(input_len, output_len)', 'HF FP32 p50 (s)', 'HF FP16 p50 (s)', 'TRT FP32 p50 (s)', 'TRT FP16 p50 (s)'],
]

# The token ids are random benchmark input. A fixed-seed local generator makes
# re-runs benchmark the same ids for every (input_len, output_len) pair without
# touching torch's global RNG state.
input_id_generator = torch.Generator().manual_seed(0)

for (in_len, out_len) in input_output_len_list:
    assert in_len <= max_input_len and out_len <= max_output_len

    in_ids = torch.randint(
        0, BARTModelTRTConfig.VOCAB_SIZE[BART_VARIANT], (batch_size, in_len), generator=input_id_generator
    ).to('cuda')

    # HF: min_length == max_length forces exactly out_len generated tokens.
    bart_model.float()
    hf_32 = measure_python_inference_code(lambda: bart_model.generate(in_ids, min_length=out_len, max_length=out_len, num_beams=num_beams, use_cache=True), timing_profile)
    bart_model.half()
    hf_16 = measure_python_inference_code(lambda: bart_model.generate(in_ids, min_length=out_len, max_length=out_len, num_beams=num_beams, use_cache=True), timing_profile)

    # TRT
    if num_beams == 1:
        _, trt_32 = full_inference_greedy(bart_trt_encoder, bart_trt_decoder, in_ids, tokenizer, timing_profile, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, use_cuda=True,)
        _, trt_16 = full_inference_greedy(bart_trt_encoder_fp16, bart_trt_decoder_fp16, in_ids, tokenizer, timing_profile, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, use_cuda=True,)
    else:
        _, trt_32 = full_inference_beam(bart_trt_encoder, bart_trt_decoder, in_ids, tokenizer, timing_profile, num_beams=num_beams, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, early_stopping=True,)
        _, trt_16 = full_inference_beam(bart_trt_encoder_fp16, bart_trt_decoder_fp16, in_ids, tokenizer, timing_profile, num_beams=num_beams, max_length=out_len, min_length=out_len, batch_size=batch_size, use_cache=metadata.other.kv_cache, early_stopping=True,)

    data.append([(in_len, out_len), hf_32[0], hf_16[0], trt_32[0], trt_16[0]])

# Return the model to FP32 dtype so re-run cells don't silently inherit half precision.
# Note: weight values keep FP16 precision after the half()/float() round trip.
bart_model.float()

print(tabulate(data, headers='firstrow', tablefmt='github'))

The results show around a 2x speedup compared to HuggingFace's KV-cache-optimized timing for relatively short output sequence lengths. For long output sequence lengths, TensorRT may not currently provide a significant speedup, due to memory copy overhead between the decoding steps.

## Conclusion

This notebook has walked you through the process of converting a HuggingFace PyTorch BART model to an optimized TensorRT engine for inference in easy steps. The TensorRT inference engine can be conveniently used as a drop-in replacement for the original HuggingFace BART model while providing a speedup.

If you are interested in further details of the conversion process, check out [BART/trt.py](../BART/trt.py).