-
Notifications
You must be signed in to change notification settings - Fork 669
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A whole lot of extra operations (including interpolate, 3d convs/deconvs/pooling) #249
base: master
Are you sure you want to change the base?
Conversation
added converter / tests for torch.Tensor.__getitem__
fixed Linear for N, *, H case
…rg_fix fixed permute for list args
…pool2d added support for adaptive_max_pool2d using regular pooling
Thanks for sharing! Currently the unchecked TensorRT 7 requirement is a blocker for merging, since many of the projects using torch2trt still use TensorRT 5. That said, I'd love to support these converters if we can add version checking and ensure all relevant test cases pass for TensorRT 5,6,7. Best, |
@@ -20,7 +20,7 @@ def run(self): | |||
inputs_conversion += (torch.zeros(shape).to(self.device).type(self.dtype), ) | |||
|
|||
# convert module | |||
module_trt = torch2trt(module, inputs_conversion, **self.torch2trt_kwargs) | |||
module_trt = torch2trt(module, inputs_conversion, max_workspace_size=1 << 20, **self.torch2trt_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
max_workspace_size can be provided as argument to torch2trt_kwargs where relevant when using @add_module_test
I found a problem when converting pytorch's interpolate to trt's resize layer. The result just does not match. I tried your code and got the same result. I printed the input and outputs: |
Fix for interpolate bilinear mode problem will come with TRT 7.1 |
@SrivastavaKshitij tks 👍 |
This is probably not merge-able generally as (some?) of it I think might use TensorRT 7 features, should there be a branch for TensorRT 7 and try to get torch2trt to support as many operations as possible?