In [None]:
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

##### <img src="http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png" style="width: 90px; float: right;">

# BART Playground

This notebook demonstrates the BART model on the tasks of text summarization and mask filling.

The TensorRT HuggingFace BART model is a drop-in replacement for the original PyTorch modules in the HuggingFace BART model.

**Notes**: 
 - For "CPU - PyTorch" and "GPU - PyTorch", the selected BART model (e.g. BART-base) from the HuggingFace model repository is employed. Inference is carried out in FP32 for CPU-PyTorch, and in FP16 for GPU-PyTorch and TensorRT. All models run with batch size 1.
   The reported time is the average across 5 runs; one additional warm-up run is executed first and excluded from the average.
 - Prior to running this notebook, run [bart.ipynb](bart.ipynb) to download the BART model and generate the TensorRT engine.

In [None]:
import ipywidgets as widgets

# Radio-button selector for the BART checkpoint used by the rest of the notebook.
model_selection = widgets.RadioButtons(
    description='Model:',
    disabled=False,
    options=[
        'facebook/bart-base',
        'facebook/bart-large',
        'facebook/bart-large-cnn',
        'facebook/mbart-large-50',
    ],
)

# Render the selector; its current value is read by the model-loading cell.
display(model_selection)

In [None]:
import os
import sys
import glob
# Absolute path of the parent directory, added to the import search path
# (presumably where the project's BART and NNDF packages live - confirm).
ROOT_DIR = os.path.abspath("../")
# Guard the append so that re-running this cell does not keep adding
# duplicate entries to sys.path.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

import torch 

# huggingface
from transformers import (
    AutoModelForPreTraining,
    AutoTokenizer,
    MBartForConditionalGeneration, 
    MBart50Tokenizer,
    AutoConfig,
)

# Download the HuggingFace model and tokenizer for the variant picked in the selector.
BART_VARIANT = model_selection.value

# The mBART variant can't be recognized by the HF Auto* classes yet, so it is
# loaded through its dedicated model/tokenizer classes instead.
if "mbart" in BART_VARIANT:
    bart_model = MBartForConditionalGeneration.from_pretrained(BART_VARIANT)
    tokenizer = MBart50Tokenizer.from_pretrained(BART_VARIANT, src_lang="en_XX")
else:
    bart_model = AutoModelForPreTraining.from_pretrained(BART_VARIANT) # BartForConditionalGeneration
    tokenizer = AutoTokenizer.from_pretrained(BART_VARIANT) # BartTokenizer

# Model configuration of the selected variant.
config = AutoConfig.from_pretrained(BART_VARIANT)

# load TensorRT engine
from BART.trt import BARTTRTEncoder, BARTTRTDecoder, TRTHFRunner
from BART.BARTModelConfig import BARTModelTRTConfig, BARTMetadata
from BART.export import BARTDecoderTRTEngine, BARTEncoderTRTEngine
from NNDF.networks import NetworkMetadata, Precision

from transformers.generation_logits_process import (
    NoRepeatNGramLogitsProcessor,
    MinLengthLogitsProcessor,
    ForcedBOSTokenLogitsProcessor,
    ForcedEOSTokenLogitsProcessor,
    LogitsProcessorList,
)
from transformers.generation_stopping_criteria import (
    MaxLengthCriteria,
    StoppingCriteriaList,
)

# Configuration handed to the TensorRT wrappers: cache disabled, and the layer
# count taken from the project's per-variant table.
trt_config = AutoConfig.from_pretrained(BART_VARIANT)
trt_config.use_cache = False
trt_config.num_layers = BARTModelTRTConfig.NUMBER_OF_LAYERS[BART_VARIANT]

# Metadata (variant, FP16 precision, no KV cache); its string form is the prefix
# of the engine file names looked up below.
metadata = NetworkMetadata(variant=BART_VARIANT, precision=Precision(fp16=True), other=BARTMetadata(kv_cache=False))
metadata_string = BARTModelTRTConfig().get_metadata_string(metadata)

encoder_stem = metadata_string + "-encoder.onnx"
decoder_stem = metadata_string + "-decoder-with-lm-head.onnx"

# glob.glob() returns an empty list when nothing matches, so the matches must be
# checked *before* indexing; otherwise a missing engine surfaces as an opaque
# IndexError instead of the actionable message below.
encoder_matches = glob.glob(f'./models/{BART_VARIANT}/tensorrt/{encoder_stem}*')
decoder_matches = glob.glob(f'./models/{BART_VARIANT}/tensorrt/{decoder_stem}*')

if not encoder_matches or not decoder_matches:
    # Stop here: everything below needs both engines, so continuing would only
    # fail later with a NameError on the engine objects.
    raise FileNotFoundError(
        f"Error: TensorRT engine not found at ./models/{BART_VARIANT}/tensorrt/. Please run bart.ipynb to generate the TensorRT engines first!"
    )

encoder_path = encoder_matches[0]
decoder_path = decoder_matches[0]

encoder_engine = BARTEncoderTRTEngine(encoder_path, metadata)
decoder_engine = BARTDecoderTRTEngine(decoder_path, metadata)

bart_trt_encoder = BARTTRTEncoder(encoder_engine, metadata, trt_config)
bart_trt_decoder = BARTTRTDecoder(decoder_engine, metadata, trt_config)

# 1x1 int32 tensor on cuda:0 holding the pad token id; passed as the initial
# decoder input to the TensorRT greedy search.
decoder_input_ids = torch.full(
    (1, 1), tokenizer.convert_tokens_to_ids(tokenizer.pad_token), dtype=torch.int32
).to("cuda:0")

In [None]:
import numpy as np
import time

# Backend selector: PyTorch on CPU, PyTorch on GPU, or the TensorRT engines.
device = widgets.RadioButtons(
    description='Device:',
    disabled=False,
    options=[
        'CPU - PyTorch',
        'GPU - PyTorch',
        'GPU - TensorRT',
    ],
)

# Task selector; switching it swaps the example text shown in the input box.
task = widgets.RadioButtons(
    description='Task:',
    disabled=False,
    options=[
        'Summarization',
        'Mask Filling',
    ],
)

# Default input text for each task, keyed by the task's option label
# (paired in the same order as the options above).
example_text = dict(zip(
    task.options,
    (
        "NVIDIA TensorRT-based applications perform up to 36X faster than CPU-only platforms during inference, enabling developers to optimize neural network models trained on all major frameworks, calibrate for lower precision with high accuracy, and deploy to hyperscale data centers, embedded platforms, or automotive product platforms.",
        "My friends are <mask> but they eat too many carbs.",
    ),
))

# Editable input area, pre-filled with the example of the first task.
paragraph_text = widgets.Textarea(
    description='Context:',
    placeholder='Type something',
    value=example_text[task.options[0]],
    rows=5,
    layout=widgets.Layout(width="auto"),
    disabled=False,
)

# Area that receives the decoded model output.
generated_text = widgets.Textarea(
    description='BART output:',
    placeholder='Context',
    value='...',
    rows=5,
    layout=widgets.Layout(width="auto"),
    disabled=False,
)
button = widgets.Button(description="Generate")

# Show the text areas first, then the two selectors.
for widget in (paragraph_text, generated_text, device, task):
    display(widget)

from IPython.display import display

# Number of inference runs per click; generate() drops the first run from the
# reported average.
N_RUN = 6

# Centered column container for the "Generate" button.
box_layout = widgets.Layout(
    display='flex',
    flex_flow='column',
    align_items='center',
    width='100%',
)
box = widgets.HBox(children=[button], layout=box_layout)

# One tick per completed run.
progress_bar = widgets.IntProgress(
    min=0,
    max=N_RUN,
    value=0,
    description='Progress:',
    orientation='horizontal',
    bar_style='', # 'success', 'info', 'warning', 'danger' or ''
    style={'bar_color': 'green'},
    layout=widgets.Layout(width='100%', height='50px'),
)

# Captures whatever the callbacks print.
output = widgets.Output()

for widget in (box, progress_bar, output):
    display(widget)

# Generation limits come from the project's TensorRT config tables.
max_output_length = BARTModelTRTConfig.MAX_OUTPUT_LENGTH[BART_VARIANT]
min_length = BARTModelTRTConfig.MIN_OUTPUT_LENGTH[BART_VARIANT]
no_repeat_ngram_size = BARTModelTRTConfig.NO_REPEAT_NGRAM_SIZE

# Stop decoding once the maximum output length is reached.
stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_output_length)])

# Special-token ids used by the logits processors below.
bos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.bos_token)
eos_token_id = tokenizer.convert_tokens_to_ids(tokenizer.eos_token)

# Processors passed to the TensorRT greedy search in generate().
logits_processor = LogitsProcessorList([
    NoRepeatNGramLogitsProcessor(no_repeat_ngram_size),
    MinLengthLogitsProcessor(min_length, eos_token_id),
    ForcedBOSTokenLogitsProcessor(bos_token_id),
    ForcedEOSTokenLogitsProcessor(max_output_length, eos_token_id),
])

def generate(b):
    """Run the selected backend on the current input text and display the result.

    Tokenizes the text in ``paragraph_text``, runs inference ``N_RUN`` times on
    the backend chosen in the ``device`` selector, prints the average latency
    into the ``output`` widget and writes the decoded text of the last run to
    ``generated_text``. The first run is treated as a warm-up and left out of
    the average.

    Args:
        b: Argument supplied by the button click callback; not used.

    Raises:
        ValueError: If the device selector holds an unknown backend label.
            Raised inside the ``output`` context, like any other error that
            occurs during inference.
    """
    progress_bar.value = 0
    inputs = tokenizer(paragraph_text.value, return_tensors="pt")
    backend = device.value

    with output:
        # Pick the per-run callable once; the timing loop below is shared by
        # all backends instead of being copy-pasted per branch.
        if backend == 'GPU - TensorRT':
            def run_once():
                encoder_last_hidden_state = bart_trt_encoder(input_ids=inputs.input_ids)
                return bart_trt_decoder.greedy_search(
                    input_ids=decoder_input_ids,
                    encoder_hidden_states=encoder_last_hidden_state,
                    stopping_criteria=stopping_criteria,
                    logits_processor=logits_processor,
                )

        elif backend == 'CPU - PyTorch':
            # The dtype/device conversion is loop-invariant, so do it once up
            # front. The input transfer stays inside the timed call, as before.
            model = bart_model.float().to('cpu')

            def run_once():
                return model.generate(inputs.input_ids.to('cpu'), num_beams=1, max_length=max_output_length)

        elif backend == 'GPU - PyTorch':
            model = bart_model.half().to('cuda:0')

            def run_once():
                return model.generate(inputs.input_ids.to('cuda:0'), num_beams=1, max_length=max_output_length)

        else:
            # Without this, an unexpected label would surface later as an
            # unbound `outputs` variable.
            raise ValueError(f"Unknown device selection: {backend!r}")

        inference_time_arr = []
        for _ in range(N_RUN):
            # perf_counter() is the monotonic, high-resolution clock intended
            # for measuring short intervals.
            start_time = time.perf_counter()
            outputs = run_once()
            inference_time_arr.append(time.perf_counter() - start_time)
            progress_bar.value += 1

        # Drop the warm-up run; fall back to the full list when only a single
        # run was made so that the mean is never taken over an empty list.
        timed_runs = inference_time_arr[1:] or inference_time_arr
        print("%s - Average inference time: %.2f (ms)" % (backend, 1000 * np.mean(timed_runs)))

        # de-tokenize model output to raw text
        text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        generated_text.value = text


def switch_task(change):
    """Load the example text of the currently selected task into the input box.

    Used as an observer callback for the task selector. ``change`` is the
    notification object and is not inspected; the selection is read from the
    ``task`` widget itself.
    """
    with output:
        selected_task = task.value
        paragraph_text.value = example_text[selected_task]

# Wire up the callbacks: clicking the button runs inference, and changing the
# task selection swaps the example text.
button.on_click(generate)
task.observe(switch_task, names='value')