
[MultiModal] Fusion model inference acceleration with TensorRT #2836

Merged

liangfu merged 2 commits into autogluon:master from trt-1 on Mar 3, 2023

Conversation

@liangfu (Collaborator) commented Feb 4, 2023

Issue #, if available:

Description of changes:

This PR adds support for accelerating fusion model inference with TensorRT.

For a fusion model trained on the petfinder dataset, TensorrtExecutionProvider boosts inference speed by up to 2.7x compared to GPU-based realtime prediction with PyTorch.

"petfinder",
["numerical_mlp", "categorical_mlp", "timm_image", "hf_text", "fusion_mlp"],
"google/electra-small-discriminator",
"mobilenetv3_small_100"

[benchmark screenshot: PyTorch vs. TensorRT inference speed]
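For context, a minimal sketch of how the accelerated path is intended to be exercised. The DataFrames, CSV paths, and label column below are placeholders, and the exact export_tensorrt signature may differ slightly from the merged code:

# Hedged sketch: train_df / test_df and the label column are hypothetical
# petfinder-style inputs; export_tensorrt is the new API added in this PR.
import pandas as pd
from autogluon.multimodal import MultiModalPredictor

train_df = pd.read_csv("petfinder_train.csv")  # placeholder path
test_df = pd.read_csv("petfinder_test.csv")    # placeholder path

predictor = MultiModalPredictor(label="AdoptionSpeed")
predictor.fit(train_df, time_limit=300)

# Baseline: realtime prediction with the PyTorch model on GPU.
y_pred_torch = predictor.predict(test_df)

# Export to ONNX and compile with TensorrtExecutionProvider; the returned
# module is meant to be a drop-in replacement for the fusion model.
trt_module = predictor.export_tensorrt(path="./trt_export", data=test_df.head(2), batch_size=2)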

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

@liangfu liangfu changed the title from "[DRAFT][MultiModal] Initial support to TensorRT" to "[Post v0.7][MultiModal] Initial support to TensorRT" on Feb 6, 2023
@liangfu liangfu force-pushed the trt-1 branch 2 times, most recently from d3c0ab1 to c49b91a on February 9, 2023 at 22:57
@liangfu liangfu added the model list checked You have updated the model list after modifying multimodal unit tests/docs label Feb 9, 2023
@liangfu liangfu added this to the 0.7 Fast-Follow Items milestone Feb 10, 2023
@liangfu liangfu marked this pull request as ready for review February 24, 2023 17:30
@liangfu liangfu changed the title from "[Post v0.7][MultiModal] Initial support to TensorRT" to "[MultiModal] Fusion model inference acceleration with TensorRT" on Feb 24, 2023
@github-actions

Job PR-2836-99121a5 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/99121a5/index.html

@github-actions

Job PR-2836-1d05c3e is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/1d05c3e/index.html

@tonyhoo (Collaborator) commented Feb 25, 2023

Thanks for the change @liangfu. Do you know why the preprocessing logic is generally slower in the TRT execution environment?

@tonyhoo (Collaborator) left a comment

Can you add some accuracy metrics as well, just to make sure accuracy isn't compromised after the TRT transformation?

if not onnx_path:
    onnx_path = os.path.join(self.path, default_onnx_path)

device_type = "cuda" if torch.cuda.is_available() else "cpu"
Collaborator

Wondering why we previously used GPU for onnx_export.

Collaborator Author

We actually used CPU instead; see the deleted comment below. To quote:

# Perform tracing on cpu, since we're facing an error when tracing with cuda device:
#     ERROR: Tensor-valued Constant nodes differed in value across invocations.
#     This often indicates that the tracer has encountered untraceable code.
#     Comparison exception:   The values for attribute 'shape' do not match: torch.Size([]) != torch.Size([384]).
#     from https://github.com/rwightman/pytorch-image-models/blob/3aa31f537d5fbf6be8f1aaf5a36f6bbb4a55a726/timm/models/swin_transformer.py#L112
device = "cpu"
num_gpus = 0

@liangfu (Collaborator Author) commented Feb 28, 2023

Do you know why the preprocessing logic is generally slower in the TRT execution environment?

Good question; I'm also interested in why this is happening, so I did a breakdown of the call stack:

[call-stack profiling screenshot]

It seems data_processor is a bit slower when using TRT.

pure_model = model.module if isinstance(model, nn.DataParallel) else model
if isinstance(pure_model, OnnxModule):
    for k in batch:
Contributor

What possible types can batch[k].dtype be when the code runs here? Can we say it is always one of [torch.float32, torch.int32, torch.int64]?

Collaborator Author

This is special handling for converting the data type of token_ids. I'm not sure why the token_ids come in as int32, but the ONNX inputs are required to be int64.
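For illustration, a hedged sketch of that special case; the batch dict here stands for the inputs fed to OnnxModule:

# Sketch only: cast int32 token_ids (and similar integer inputs) to int64,
# since the exported ONNX graph declares its integer inputs as int64.
import torch

for k in batch:
    if batch[k].dtype == torch.int32:
        batch[k] = batch[k].to(torch.int64)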

@github-actions

Job PR-2836-6cb5a9e is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/6cb5a9e/index.html

Comment on lines +24 to +25
elif "valid_length" in k or k.startswith("numerical") or k.startswith("timm_image"):
dynamic_axes[k] = {0: "batch_size"}
Contributor

The shape of images in timm_image's input should be (b, n, c, h, w), where both b and n may be dynamic.

Collaborator Author

I would suggest avoiding adding the image size to the dynamic dimensions, because

  1. adding an extra dynamic dimension would increase the complexity of model compilation, and
  2. a compiled model with too many dynamic dimensions could be suboptimal in terms of performance.

The question becomes: do we really need to support dynamic shapes for image data? If yes, what are the lower and upper bounds of the image size?

Contributor

n is the number of images per sample. (c, h, w) are the fixed shape of one image.

Collaborator Author

In a pretrained setting, n should be a fixed number, since the input DataFrame should have the same number of image columns, shouldn't it?
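To make the trade-off concrete, a rough sketch of an export call with only the batch dimension marked dynamic; input_names, output_names, model, and batch are assumed to be in scope in the exporter, and the output file name is a placeholder:

# Illustrative sketch (not the exact PR code): only dim 0 (batch_size) is
# dynamic; the per-sample image shape (n, c, h, w) stays fixed.
import torch

dynamic_axes = {}
for k in input_names:
    if "valid_length" in k or k.startswith("numerical") or k.startswith("timm_image"):
        dynamic_axes[k] = {0: "batch_size"}

torch.onnx.export(
    model,                                      # the fusion model being exported
    args=tuple(batch[k] for k in input_names),  # example inputs for tracing
    f="fusion_model.onnx",                      # placeholder output path
    input_names=input_names,
    output_names=output_names,
    dynamic_axes=dynamic_axes,
)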

Comment on lines +10 to +16
try:
    import tensorrt  # Unused but required by TensorrtExecutionProvider
except:
    logger.warning(
        "Failed to import tensorrt package. "
        "onnxruntime would fallback to CUDAExecutionProvider instead of using TensorrtExecutionProvider."
    )
Contributor

Is it easy to install tensorrt for users? If not, consider making this a lazy import to avoid unnecessary warnings for users who don't need it.

Contributor

Making it a lazy import can also reduce the predictor's init time.

Collaborator Author

Is it easy to install tensorrt for users?

Yes, it's just pip install tensorrt.

Collaborator Author

I tried a lazy import, but it didn't work.

Contributor

It's kind of weird that the lazy import doesn't work. It's fine to keep it here for now; we can dive into the reason later.

logger.info("Loading ONNX file from path {}...".format(onnx_path))
onnx_model = onnx.load(onnx_path)

trt_module = OnnxModule(onnx_model)
Contributor

The returned module may not use tensorrt. Maybe it would be better to move the tensorrt import warning inside OnnxModule?

Collaborator Author

I tried moving the tensorrt import inside OnnxModule, but that didn't work (onnxruntime.InferenceSession doesn't compile to use TensorrtExecutionProvider in that case).
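For reference, a hedged sketch of the fallback behavior being discussed; the model path is a placeholder:

# Sketch only: onnxruntime picks the first provider in the list that is
# actually available. If the tensorrt package (and its libraries) cannot be
# loaded, the session silently falls back to CUDA/CPU execution.
import onnxruntime as ort

providers = [
    "TensorrtExecutionProvider",
    "CUDAExecutionProvider",
    "CPUExecutionProvider",
]
sess = ort.InferenceSession("fusion_model.onnx", providers=providers)
print(sess.get_providers())  # shows which providers were actually registered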

@liangfu liangfu force-pushed the trt-1 branch 3 times, most recently from 73dc079 to 1e75f6d on March 1, 2023 at 00:04
@github-actions bot commented Mar 1, 2023

Job PR-2836-1e75f6d is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/1e75f6d/index.html

@github-actions bot commented Mar 1, 2023

Job PR-2836-9c8a787 is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/9c8a787/index.html

@github-actions bot commented Mar 2, 2023

Job PR-2836-d37b0ad is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/d37b0ad/index.html


input_dict = {k: args[i].cpu().numpy() for i, k in enumerate(self.input_names)}
onnx_outputs = self.sess.run(self.output_names, input_dict)
onnx_outputs = onnx_outputs[:3]
Contributor

Why do we need only the first 3 model outputs? Don't we need them all?

Collaborator Author

This comes from the undetermined size of the output dict for a fusion_mlp model, which contains a tuple of (features, logits, multimodal_logits). The variable multimodal_logits contains the logit outputs from all modalities.

Don't we need them all?

Good question. In short, multimodal_logits is only used for computing the loss, not for inference.

Specifically, get_output_dict() takes the outputs from forward(), but in onnxruntime the outputs are flattened, which means the list of tensors in multimodal_logits gets merged with the other outputs (e.g. features, logits). We lose the information about where the extra tensors came from.
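A hedged illustration of that flattening; sess, output_names, and input_dict mirror the OnnxModule attributes shown above, and the split point simply follows the [:3] in the PR code rather than a general rule:

# Sketch only: the traced graph flattens (features, logits, multimodal_logits)
# into one flat list of tensors, so the trailing per-modality logits are
# dropped, since they are only needed for computing the training loss.
flat_outputs = sess.run(output_names, input_dict)
kept = flat_outputs[:3]     # leading outputs used for inference
dropped = flat_outputs[3:]  # flattened multimodal_logits, not needed here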

# Prediction with default predictor
y_pred = predictor.predict(test_df)

trt_module = predictor.export_tensorrt(path=model_path, data=tail_df, batch_size=batch_size)
Contributor

Currently, export_onnx and export_tensorrt return either a path or a module. When a user calls export_something(), do they expect it to return something, or just expect the model to be saved to disk? If users want to use the saved model, they probably want to load it?

Collaborator Author

These are excellent questions.

There are several different use cases:
U1: Some users might expect an ONNX file to be exported, so that they can use the ONNX file wherever they want.
U2: Some users might expect faster inference with ONNX. They don't care much about where the ONNX file lives.

For U1, I think the expected output could be either the location of the ONNX file or the ONNX model itself. This is how export_onnx is defined.

For U2, we should be able to generate a drop-in replacement for torch.nn.Module, so that integration with the existing inference flow is easy. This is how export_tensorrt is defined.

But I would agree, we should have better names for these public APIs.
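As a rough sketch of the two patterns (exact signatures may differ from the merged code; sample_df, model_path, and batch_size are placeholders):

# U1: export the ONNX artifact itself so it can be used anywhere.
onnx_path = predictor.export_onnx(data=sample_df)

# U2: get a TensorRT-compiled, drop-in replacement for the torch.nn.Module
# and keep using the existing inference flow.
trt_module = predictor.export_tensorrt(path=model_path, data=sample_df, batch_size=batch_size)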

        batch[key] = inp.to(device, dtype=dtype)
    else:
        batch[key] = inp.to(device)
self._model.to(device)
Contributor

Do we need self._model.to(device) here?

Collaborator Author

I think so. We need to ensure the model parameters are moved to CPU before tracing.
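A minimal sketch of the intent, assuming batch and self._model stand for the exporter's example inputs and model attribute:

# Sketch only: tracing/ONNX export requires the example inputs and the model
# parameters to live on the same device; here everything is moved to CPU.
device = "cpu"
for key, inp in batch.items():
    batch[key] = inp.to(device)
self._model.to(device)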

@github-actions bot commented Mar 2, 2023

Job PR-2836-dc7fc9b is done.
Docs are uploaded to http://autogluon-staging.s3-website-us-west-2.amazonaws.com/PR-2836/dc7fc9b/index.html

@zhiqiangdon (Contributor) left a comment

LGTM! Thanks for supporting TensorRT!

@liangfu liangfu merged commit 398bbdd into autogluon:master Mar 3, 2023
@liangfu liangfu deleted the trt-1 branch March 3, 2023 19:00
@Innixma Innixma modified the milestones: 0.7 Fast-Follow Items, 0.8 Release May 23, 2023