Skip to content

TensorRT engine randomly returns nan's when batch size greater than 1. #2785

@azsh1725

Description

@azsh1725

Description

I continue to investigate tensorrt, I found this example and tried to make an inference based on it for an already generated engine, but for some reason everything works fine for batch size 1, but when the size increases (for example to 32), 2 problems arise:

  1. Tensorrt randomly starts to give nan's as a result, and sometimes work as it should.
  2. The error in the generated embeddings using the model from the huggingface and the tensorrt of the engine grows very much. For a batch of size 1, the difference is about 2e-4, and for a batch of size 32 it is already 0.5.

Environment

TensorRT Version: 8.0.1
NVIDIA GPU: RTX 3090
NVIDIA Driver Version: 510.108.03
CUDA Version: 11.3
CUDNN Version: 8.2.2
Operating System: Linux
Python Version (if applicable): 3.8
PyTorch Version (if applicable): 1.11.0+cu113

Relevant Files

Onnx model and tensorrt engine

Steps To Reproduce

import inspect
import logging
import os
import timeit
import math
import csv
from collections import namedtuple
from functools import reduce
from datetime import datetime

import numpy as np

G_LOGGER = logging.getLogger("OSS")
G_LOGGER.DEBUG = logging.DEBUG
G_LOGGER.INFO = logging.INFO
G_LOGGER.WARNING = logging.WARNING
G_LOGGER.ERROR = logging.ERROR

formatter = logging.Formatter("[%(asctime)s][%(name)s][%(levelname)s] %(message)s")
stream = logging.StreamHandler()
stream.setFormatter(formatter)
G_LOGGER.addHandler(stream)

from typing import Callable, Dict, Union, List

import torch
import tensorrt as trt

from transformers import AutoModel, AutoTokenizer

"""TimingProfile(iterations: int, number: int, warmup: int, duration: int, percentile: int or [int])"""
TimingProfile = namedtuple("TimingProfile", ["iterations", "number", "warmup", "duration", "percentile"])


def use_cuda(func: Callable):
    """
    Tries to send all parameters of a given function to cuda device if user supports it.
    Object must have a "to(device: str)" and maps to target device "cuda"
    Basically, uses torch implementation.

    Wrapped functions musts have keyword argument "use_cuda: bool" which enables
    or disables toggling of cuda.
    """

    def _send_args_to_device(caller_kwargs, device):
        new_kwargs = {}
        for k, v in caller_kwargs.items():
            if getattr(v, "to", False):
                new_kwargs[k] = v.to(device)
            else:
                new_kwargs[k] = v
        return new_kwargs

    def wrapper(*args, **kwargs):
        caller_kwargs = inspect.getcallargs(func, *args, **kwargs)
        assert (
                "use_cuda" in caller_kwargs
        ), "Function must have 'use_cuda' as a parameter."

        if caller_kwargs["use_cuda"]:
            new_kwargs = {}
            used_cuda = False
            if torch.cuda.is_available() and caller_kwargs["use_cuda"]:
                new_kwargs = _send_args_to_device(caller_kwargs, "cuda")
                used_cuda = True
            else:
                new_kwargs = _send_args_to_device(caller_kwargs, "cpu")

            try:
                return func(**new_kwargs)
            except RuntimeError as e:
                # If a device has cuda installed but no compatible kernels, cuda.is_available() will still return True.
                # This exception is necessary to catch remaining incompat errors.
                if used_cuda:
                    G_LOGGER.warning("Unable to execute program using cuda compatible device: {}".format(e))
                    G_LOGGER.warning("Retrying using CPU only.")
                    new_kwargs = _send_args_to_device(caller_kwargs, "cpu")
                    new_kwargs["use_cuda"] = False
                    cpu_result = func(**new_kwargs)
                    G_LOGGER.warning("Successfully obtained result using CPU.")
                    return cpu_result
                else:
                    raise e
        else:
            return func(**caller_kwargs)

    return wrapper


def measure_python_inference_code(
        stmt: Union[Callable, str], timing_profile: TimingProfile
):
    """
    Measures the time it takes to run Pythonic inference code.
    Statement given should be the actual model inference like forward() in torch.

    Args:
        stmt (Union[Callable, str]): Callable or string for generating numbers.
        timing_profile (TimingProfile): The timing profile settings with the following fields.
            warmup (int): Number of iterations to run as warm-up before actual measurement cycles.
            number (int): Number of times to call function per iteration.
            iterations (int): Number of measurement cycles.
            duration (float): Minimal duration for measurement cycles.
            percentile (int or list of ints): key percentile number(s) for measurement.
    """

    def simple_percentile(data, p):
        """
        Temporary replacement for numpy.percentile() because TRT CI/CD pipeline requires additional packages to be added at boot up in this general_utils.py file.
        """
        assert p >= 0 and p <= 100, "Percentile must be between 1 and 99"

        rank = len(data) * p / 100
        if rank.is_integer():
            return sorted(data)[int(rank)]
        else:
            return sorted(data)[int(math.ceil(rank)) - 1]

    warmup = timing_profile.warmup
    number = timing_profile.number
    iterations = timing_profile.iterations
    duration = timing_profile.duration
    percentile = timing_profile.percentile

    G_LOGGER.debug(
        "Measuring inference call with warmup: {} and number: {} and iterations {} and duration {} secs".format(
            warmup, number, iterations, duration
        )
    )
    # Warmup
    warmup_mintime = timeit.repeat(stmt, number=number, repeat=warmup)
    G_LOGGER.debug("Warmup times: {}".format(warmup_mintime))

    # Actual measurement cycles
    results = []
    start_time = datetime.now()
    iter_idx = 0
    while iter_idx < iterations or (datetime.now() - start_time).total_seconds() < duration:
        iter_idx += 1
        results.append(timeit.timeit(stmt, number=number))

    print("Results length: ", len(results))

    if isinstance(percentile, int):
        return np.mean(results), simple_percentile(results, percentile) / number
    else:
        return np.mean(results), [simple_percentile(results, p) / number for p in percentile]


@use_cuda
def encoder_inference(t5_encoder, input_ids, attention_mask, token_type_ids, timing_profile, use_cuda=True):
    encoder_stmt = lambda: t5_encoder(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    avg_time, encoder_e2e_time = measure_python_inference_code(encoder_stmt, timing_profile)

    return (encoder_stmt(), avg_time, encoder_e2e_time)


def allocate_binding_buffer(types_dict, shapes_dict):
    '''
    Allocate binding buffers for trt based on provided types and shapes dict
    '''
    return {
        k: torch.zeros(reduce(lambda v, a: v * a, shape), dtype=types_dict[k]).cuda()
        for k, shape in shapes_dict.items()
    }


def get_alpha_dataset(path_to_dataset):
    bad_strs = 0

    with open(path_to_dataset, 'r') as tsvfile:
        tsvreader = csv.reader(tsvfile, delimiter='\t')
        train_texts = []
        train_labels = []

        test_texts = []
        test_labels = []

        for idx, row in enumerate(tsvreader):
            try:
                if idx == 0:
                    continue

                if row[2] == '0':
                    train_texts.append(row[0])
                    train_labels.append(row[1])
                else:
                    test_texts.append(row[0])
                    test_labels.append(row[1])
            except IndexError:
                bad_strs += 1

        print(f"Bad strings: {bad_strs}")

        return train_texts, train_labels, test_texts, test_labels


class TRTEngineFile:
    # get_network_definition can be overloaded to alter the network definition.
    # For example, this function can be used to change the precisions of ops or
    # data type of intermediate tensors.
    def get_network_definition(self, network_definition):
        return network_definition

    def __init__(
            self,
            model: str
    ):
        self.fpath = model
        self.max_trt_workspace = 3072

    def cleanup(self) -> None:
        G_LOGGER.debug("Removing saved engine model from location: {}".format(self.fpath))
        os.remove(self.fpath)


class TRTNativeRunner:
    """TRTNativeRunner avoids the high overheads with Polygraphy runner providing performance comparable to C++ implementation."""

    def __init__(self, trt_engine_file: TRTEngineFile):
        self.trt_engine_file = trt_engine_file
        self.trt_logger = trt.Logger(trt.Logger.INFO)

        # if G_LOGGER.level == G_LOGGER.DEBUG:
        #     self.trt_logger.min_severity = trt.Logger.VERBOSE
        # elif G_LOGGER.level == G_LOGGER.INFO:
        #     self.trt_logger.min_severity = trt.Logger.INFO
        # else:
        #     self.trt_logger.min_severity = trt.Logger.WARNING

        G_LOGGER.info("Reading and loading engine file {} using trt native runner.".format(self.trt_engine_file.fpath))
        with open(self.trt_engine_file.fpath, "rb") as f:
            self.trt_runtime = trt.Runtime(self.trt_logger)
            self.trt_engine = self.trt_runtime.deserialize_cuda_engine(f.read())
            self.trt_context = self.trt_engine.create_execution_context()

        # By default set optimization profile to 0
        self.profile_idx = 0

        # Other metadata required by the profile
        self._num_bindings_per_profile = self.trt_engine.num_bindings // self.trt_engine.num_optimization_profiles
        G_LOGGER.debug("Number of profiles detected in engine: {}".format(self._num_bindings_per_profile))

    def release(self):
        pass

    def get_optimization_profile(self, batch_size, sequence_length):
        """Provided helper function to obtain a profile optimization."""
        # Select an optimization profile
        # inspired by demo/BERT/inference.py script
        selected_profile_idx = None
        for idx in range(self.trt_engine.num_optimization_profiles):
            profile_shape = self.trt_engine.get_profile_shape(profile_index=idx,
                                                              binding=idx * self._num_bindings_per_profile)

            if profile_shape[0][0] <= batch_size and profile_shape[2][0] >= batch_size \
                    and profile_shape[0][1] <= sequence_length and profile_shape[2][1] >= sequence_length:
                G_LOGGER.info("Selected profile: {}".format(profile_shape))
                selected_profile_idx = idx
                break

        if selected_profile_idx == -1:
            raise RuntimeError(
                "Could not find any profile that matches batch_size={}, sequence_length={}".format(batch_size,
                                                                                                   sequence_length))

        return selected_profile_idx

    def __call__(self, *args, **kwargs):
        # self.trt_context.active_optimization_profile = self.profile_idx
        return self.forward(*args, **kwargs)


class TRTHFRunner(TRTNativeRunner):
    def _allocate_memory(self,
                         input_shapes: Dict[str, tuple],
                         input_types: Dict[str, torch.dtype],
                         output_shapes: Dict[str, tuple],
                         output_types: Dict[str, torch.dtype]):
        """Helper function for binding several inputs at once and pre-allocating the results."""
        # Allocate memories as 1D linear buffers for simpler handling of dynamic shapes.
        self.inputs = allocate_binding_buffer(input_types, input_shapes)
        self.outputs = allocate_binding_buffer(output_types, output_shapes)

        bindings = [None] * self.trt_engine.num_bindings

        for input_name, input_array in self.inputs.items():
            # Allocate memory for inputs
            input_idx = self.trt_engine.get_binding_index(input_name)
            self.trt_context.set_binding_shape(input_idx, input_shapes[input_name])
            bindings[input_idx] = input_array.data_ptr()

        assert self.trt_context.all_binding_shapes_specified

        for output_name, output_array in self.outputs.items():
            # Output shape should be allocated from context size
            output_idx = self.trt_engine.get_binding_index(output_name)
            bindings[output_idx] = output_array.data_ptr()

        return bindings

    def __init__(
            self,
            trt_engine_file: TRTEngineFile,
            batch_size: int = 1
    ):
        super().__init__(trt_engine_file)
        # self.config = hf_config
        self.batch_size = batch_size


class TRTEncoder(TRTHFRunner):
    """TRT implemented network interface that can be used to measure inference time."""

    def __init__(
            self,
            trt_engine_file: str,
            batch_size: int = 1,
            max_sequence_length: int = 128,
            hidden_size: int = 768
    ):
        super().__init__(TRTEngineFile(trt_engine_file), batch_size=batch_size)
        # In benchmarking mode, the max_sequence_length should be the designated input_profile_max_len
        self.max_sequence_length = max_sequence_length
        self.encoder_hidden_size = hidden_size
        self.main_input_name = "input_ids"
        # We only have one profile to select so we can just grab the profile at the start of the class
        self.profile_idx = self.get_optimization_profile(batch_size=self.batch_size, sequence_length=1)

        self.input_shapes = {
            "input_ids": (self.batch_size, self.max_sequence_length),
            "attention_mask": (self.batch_size, self.max_sequence_length),
            "token_type_ids": (self.batch_size, self.max_sequence_length)
        }
        self.input_types = {
            "input_ids": torch.int32,
            "attention_mask": torch.int32,
            "token_type_ids": torch.int32
        }
        self.output_shapes = {
            "last_hidden_state": (self.batch_size, self.max_sequence_length, self.encoder_hidden_size),
            "1607": (self.batch_size, self.encoder_hidden_size)
        }
        self.output_types = {
            "last_hidden_state": torch.float32,
            "1607": torch.float32
        }

        self.bindings = self._allocate_memory(self.input_shapes, self.input_types, self.output_shapes,
                                              self.output_types)

    def forward(self, input_ids, attention_mask, token_type_ids, *args, **kwargs):
        bs = self.batch_size
        input_length = input_ids.shape[1]
        encoder_hidden_size = self.encoder_hidden_size

        # Check if the input data is on CPU (which usually means the PyTorch does not support current GPU).
        is_cpu_mode = False
        if input_ids.device == torch.device("cpu") or token_type_ids.device == torch.device(
                "cpu") or attention_mask.device == torch.device("cpu"):
            is_cpu_mode = True

        # We allocate the buffers using max_length, but we only need to first portion of it, so copy the data into the
        # first portion of the input buffer.
        # TODO: Could we just reuse input_ids' data_ptr() as the first binding when input_ids is already contiguous to
        # avoid an additional D2D?
        if is_cpu_mode:
            # print("cpu mode")
            self.inputs["input_ids"] = input_ids.int().flatten().contiguous().cuda()
            self.inputs["attention_mask"] = attention_mask.int().flatten().contiguous().cuda()
            self.inputs["token_type_ids"] = token_type_ids.int().flatten().contiguous().cuda()

            self.bindings[0] = self.inputs["input_ids"].data_ptr()
            self.bindings[1] = self.inputs["attention_mask"].data_ptr()
            self.bindings[2] = self.inputs["token_type_ids"].data_ptr()
        else:
            # print("cuda mode")
            self.inputs["input_ids"][:bs * input_length] = input_ids.flatten()
            self.inputs["attention_mask"][:bs * input_length] = attention_mask.flatten()
            self.inputs["input_ids"][:bs * input_length] = input_ids.flatten()

        # Set the binding shape of input_ids, which should be (bs, input_length).
        self.trt_context.set_binding_shape(0, input_ids.shape)
        self.trt_context.set_binding_shape(1, attention_mask.shape)
        self.trt_context.set_binding_shape(2, input_ids.shape)

        # Launch TRT inference.
        self.trt_context.execute_v2(bindings=self.bindings)

        # We allocate the buffers using max_length, but we only need to first portion of it, so get only the first
        # portion of the output buffer and return that.
        # TODO: Could we construct a Torch tensor using given data_ptr() to avoid this D2D copy?
        hidden_states_output = self.outputs["last_hidden_state"]
        last_hidden_output = self.outputs["1607"]

        if is_cpu_mode:
            hidden_states_output = hidden_states_output.cpu()
            last_hidden_output = last_hidden_output.cpu()

        hidden_states_folded = hidden_states_output[:bs * input_length * encoder_hidden_size].view(bs, input_length,
                                                                                                   encoder_hidden_size)
        last_hidden_folded = last_hidden_output[:bs * encoder_hidden_size].view(bs, encoder_hidden_size)

        return hidden_states_folded, last_hidden_folded


class HFVectorizer:
    def __init__(self, model_path: str, max_seq_len: int, device: str):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModel.from_pretrained(model_path)

        self.device = device
        self.model.to(self.device)
        self.model.eval()

        self.max_seq_len = max_seq_len

    @torch.no_grad()
    def get_context_embeddings(self, text: List[str]) -> torch.Tensor:
        tokens = self.tokenizer(text, padding=True, truncation=True, max_length=self.max_seq_len, return_tensors='pt')

        model_output = self.model(**{k: v.to(self.device) for k, v in tokens.items()})

        return model_output[0][:, 0, :]


class TensorRTVectorizer:
    def __init__(self, model_path: str, engine_filename: str, batch_size: int, max_seq_len: int):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = TRTEncoder(engine_filename, batch_size=batch_size, max_sequence_length=max_seq_len)

        self.max_seq_len = max_seq_len

    @torch.no_grad()
    def get_context_embeddings(self, text: List[str]) -> torch.Tensor:
        tokens = self.tokenizer(text, padding=True, truncation=True, max_length=self.max_seq_len, return_tensors='pt')

        hidden_states_folded, last_hidden_folded = self.model(tokens["input_ids"].to("cuda"),
                                                              tokens["attention_mask"].to("cuda"),
                                                              tokens["token_type_ids"].to("cuda"))
        return hidden_states_folded[:, 0, :]


def run_inference_correctness():
    sample_text = "здравствуйте подскажите где подключается по умолчанию сбп"
    model_name = "DeepPavlov/rubert-base-cased"
    batch_size = 32
    seq_len = 128
    engine_filename = '/home/tensorrt_engines/python_test.engine'
    print(f'Running inference on engine {engine_filename}')

    hf_vectorizer = HFVectorizer(model_name, seq_len, "cuda")
    trt_vectorizer = TensorRTVectorizer(model_name, engine_filename, batch_size, seq_len)

    for i in range(100):
        hf_embedding = hf_vectorizer.get_context_embeddings([sample_text] * batch_size)
        trt_embedding = trt_vectorizer.get_context_embeddings([sample_text] * batch_size)

        if torch.isnan(hf_embedding).any():
            print(f"HF embedding is nan, idx: {i}")
            print(hf_embedding.cpu().numpy())
            break

        if torch.isnan(trt_embedding).any():
            print(f"TRT embedding is nan, idx: {i}")
            print(trt_embedding.cpu().numpy())
            break

    print((hf_embedding.cpu() - trt_embedding.cpu()).abs().mean())

    EPS = 3e-4
    assert (hf_embedding.cpu() - trt_embedding.cpu()).abs().mean() <= EPS


run_inference_correctness()

Output:

TRT embedding is nan, idx: 46
[[ 0.27062076 -0.3629676   0.06873146 ...  0.19171244  0.28416577
  -0.29272905]
 [ 0.1654504  -0.06139711 -0.02816715 ...  0.11053374  0.19584551
   0.4450955 ]
 [ 0.35021943  0.306062   -0.35293704 ... -0.46051547  0.21811077
   0.4337832 ]
 ...
 [ 1.4005156   0.52827865  0.86885476 ...  1.4299251  -0.33051533
   0.06044114]
 [ 0.46486786 -0.71743774  1.2329694  ...  0.7739747   0.44407886
  -1.1249343 ]
 [        nan         nan         nan ...         nan         nan
          nan]]
tensor(nan)

Metadata

Metadata

Assignees

Labels

triagedIssue has been triaged by maintainers

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions