## Install all requirements.

In [None]:
# Shell escape: install the packages listed in requirements.txt, one directory above the working directory.
!pip install -r ../requirements.txt

## Download the ONNX model and build the TensorRT engine.

In [None]:
# Shell escape: create the local cache/ directory; -p makes this a no-op when it already exists.
!mkdir -p cache/

In [None]:
from pathlib import Path
from typing import Union

from utils.download import download_model_onnx
from utils.trt_builder import DefaultTransformerEngineBuilder


def get_engine_path(
    model_name: str,
    repo_id: str,
    max_batch_size: int = 1,
    max_seq_len: int = 256,
    max_history_len: int = 512,
    force_rebuild: bool = False,
    cache_dir: Union[str, Path] = 'cache',
) -> Path:
    """Return the path of a cached engine file, building the engine first if needed.

    The engine file name encodes the model name and the three size limits, so
    every distinct configuration gets its own entry inside ``cache_dir``.

    Args:
        model_name: Model name; forwarded to ``download_model_onnx`` and used
            as the prefix of the engine file name.
        repo_id: Repository identifier forwarded to ``download_model_onnx``.
        max_batch_size: Forwarded to the engine builder and encoded in the file name.
        max_seq_len: Forwarded to the engine builder and encoded in the file name.
        max_history_len: Forwarded to the engine builder and encoded in the file name.
        force_rebuild: When true, rebuild even if a cached engine file exists.
        cache_dir: Directory that holds the engine file; it is also passed to
            ``download_model_onnx`` as its cache directory.

    Returns:
        Path of the engine file inside ``cache_dir``.
    """
    engine_cache_path = Path(cache_dir) / f'{model_name}-b{max_batch_size}s{max_seq_len}h{max_history_len}.engine'

    # ``is_file()`` is already False for a missing path, so no separate
    # ``exists()`` check is required.
    if not force_rebuild and engine_cache_path.is_file():
        return engine_cache_path

    # Make sure the target directory is present before anything is written,
    # so a custom ``cache_dir`` does not have to be created by hand.
    engine_cache_path.parent.mkdir(parents=True, exist_ok=True)

    path_to_onnx = download_model_onnx(model_name=model_name, repo_id=repo_id, cache_dir=cache_dir)
    # Precision flags are fixed here: INT8 enabled, FP16 disabled.
    builder = DefaultTransformerEngineBuilder(
        max_batch_size=max_batch_size,
        max_seq_len=max_seq_len,
        max_history_len=max_history_len,
        use_fp16=False,
        use_int8=True,
    )
    builder.build(path_to_onnx=path_to_onnx, engine_cache_path=engine_cache_path)
    return engine_cache_path


# Resolve the engine file for the 'gpt2-xl-i8' model using the default size limits.
engine_path = get_engine_path(
    model_name='gpt2-xl-i8',
    repo_id='ENOT-AutoDL/gpt2-tensorrt',
)

## Initialize and test seq2seq model.
### All the important details of TensorRT initialization can be found in `utils/trt_model.py` and `utils/trt_seq2seq_model.py`.

In [None]:
import tensorrt as trt
import torch

from transformers import AutoTokenizer
from utils.trt_seq2seq_model import TrtSeq2SeqModel


# Load the engine and the tokenizer for the 'gpt2' checkpoint.
model = TrtSeq2SeqModel(path_to_engine=engine_path)
tokenizer = AutoTokenizer.from_pretrained('gpt2')

input_text = 'Hello world!'

# Tokenize the prompt, then move the ids to the GPU as int32.
input_ids = tokenizer(input_text, return_tensors='pt')['input_ids']
input_ids = input_ids.to(device='cuda', dtype=torch.int32)

# Only one prompt was passed in, so unpack the single generated row before decoding.
[generated_ids] = model.generate(input_ids, generate_len=100)
generated_text = tokenizer.decode(generated_ids)

separator = '=' * 100
print(separator)
print(input_text + generated_text)
print(separator)

## Accuracy validation.

In [None]:
from utils.test import test_acc


def _generate_next_token(input_ids):
    """Run the model on ``input_ids`` with ``generate_len=1``."""
    return model.generate(input_ids, generate_len=1)


# The trailing semicolon keeps the notebook from echoing the return value.
test_acc(_generate_next_token, device='cuda', verbose=True);

## Latency test.

In [None]:
from itertools import product

from utils.test import test_latency


def generate_ids_function(seq_len: int, batch_size: int = 1, device: str = 'cuda') -> torch.Tensor:
    """Create a dummy batch of input ids filled with ones.

    Args:
        seq_len: Number of ids per row.
        batch_size: Number of rows; defaults to 1.
        device: Device the tensor is allocated on; defaults to ``'cuda'``.

    Returns:
        An int32 tensor of shape ``(batch_size, seq_len)`` where every element is 1.
    """
    return torch.ones(size=(batch_size, seq_len), device=device, dtype=torch.int32)


def generate_seq_function(input_ids: torch.Tensor, generate_len: int) -> torch.Tensor:
    """Call the notebook-level ``model`` to generate ``generate_len`` tokens for ``input_ids``."""
    generated = model.generate(input_ids, generate_len=generate_len)
    return generated


# Benchmark every ordered pair drawn from the three lengths (9 variants).
# NOTE(review): each pair is presumably (seq_len, generate_len) — confirm against utils.test.
lengths = (64, 128, 256)
test_latency(
    generate_ids_function=generate_ids_function,
    generate_seq_function=generate_seq_function,
    variants=list(product(lengths, repeat=2)),
    warmup=20,
    repeats=20,
    verbose=True,
);