
[Torch] Fix dtype handling for modules with integer parameters #6311

Merged
merged 3 commits into apache:master on Aug 21, 2020

Conversation

@masahi (Member) commented Aug 20, 2020

This fixes the interesting typing problem raised in #6300.

In TorchScript, parameters of operations like conv are always accessed via prim::GetAttr nodes. For example, when we see

 %input : Float(1, 6, 4) = aten::_convolution(%input.1, %6, %7, %9, %11, %13, %14, %16, %17, %18, %19, %20)

The input weight %6 is accessed in this way:

%3 : Module = prim::GetAttr[name="conv"](%self.1)
%6 : Tensor = prim::GetAttr[name="weight"](%3)
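
For reference, a graph like this can be produced by tracing a small module with a conv and printing its TorchScript graph (the module below is purely illustrative, not the model from #6300):

```python
import torch

# Minimal sketch: trace a conv module and inspect its TorchScript graph.
class ConvNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(6, 6, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

traced = torch.jit.trace(ConvNet(), torch.rand(1, 6, 4))
# The weight and bias show up as prim::GetAttr nodes feeding aten::_convolution.
print(traced.graph)
```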

The problem is that Torch cannot figure out the correct type of GetAttr nodes. So when we visit aten::_convolution to get its input types, we have to assume the weight is an untyped tensor, and we get the annoying warning `Untyped Tensor found, assume it is float`.

This hasn't been a big issue so far because parameters are usually float anyway. But #6300 brought a use case where there are integer parameters as well as float ones, so changing `default_dtype` to int doesn't solve it: a single default cannot cover both dtypes.
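
For context, a module along these lines triggers the problem when traced and converted (purely illustrative, the names and the `index_select` usage are my own, not the exact case from #6300):

```python
import torch

# Illustrative reproduction: an integer buffer is accessed through prim::GetAttr
# just like the conv weights, so assuming float for every GetAttr tensor
# gives it the wrong dtype.
class ConvWithIndex(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(6, 6, kernel_size=1)
        self.register_buffer("index", torch.tensor([0, 2], dtype=torch.int64))

    def forward(self, x):
        return self.conv(x).index_select(2, self.index)

traced = torch.jit.trace(ConvWithIndex(), torch.rand(1, 6, 4))
```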

So I added a workaround for when we try to get the dtype of GetAttr nodes. For every GetAttr node there is a corresponding parameter in the original PyTorch module with a known, correct dtype, and for every PyTorch parameter tensor we have a corresponding Relay Var (via the convert_params(...) function). So inside _get_input_types, when we find a GetAttr node, we return the dtype of the corresponding Relay Var instead of returning default_dtype.
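
Roughly, the idea looks like the following sketch (simplified, not the actual frontend code; the helper name and `output_name_to_var` are made up for illustration):

```python
# Simplified sketch of the workaround, not the actual TVM PyTorch frontend code.
# Assumption: convert_params(...) gives us a mapping from the debugName of each
# prim::GetAttr output value to the Relay Var created for that parameter tensor.
def _get_dtype_for_value(value, output_name_to_var, default_dtype="float32"):
    name = value.debugName()
    if name in output_name_to_var:
        # The Relay Var was built from the real PyTorch parameter, so its
        # type annotation carries the correct dtype (e.g. "int64"), not a guess.
        return output_name_to_var[name].type_annotation.dtype
    # Fall back to the previous behaviour for genuinely untyped tensors.
    return default_dtype
```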

This is the solution I came up with that requires minimal change. It fixes the problem, but I feel it is a bit hacky. Please let me know if there are better ways to handle this. The test case I added is a minimal reproduction of the issue raised in #6300.

As a bonus, there are no more annoying `Untyped Tensor found ...` warnings when working with traced models.

Please review @siju-samuel @t-vi

@siju-samuel (Member) left a comment


LGTM.
The annoying `WARNING:root:Untyped Tensor found, assume it is float32` warning is gone. Thanks.

@siju-samuel merged commit 470dfc3 into apache:master on Aug 21, 2020
@siju-samuel (Member) commented:

Thanks @masahi. This PR is merged.

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Aug 26, 2020
…e#6311)

* return the correct type for GetAttr node

* keep _get_pytorch_value_type intact

* add test and handle quantized param
electriclilies pushed a commit to electriclilies/tvm that referenced this pull request Aug 26, 2020
…e#6311)

* return the correct type for GetAttr node

* keep _get_pytorch_value_type intact

* add test and handle quantized param
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Sep 2, 2020
…e#6311)

* return the correct type for GetAttr node

* keep _get_pytorch_value_type intact

* add test and handle quantized param
trevor-m pushed a commit to neo-ai/tvm that referenced this pull request Sep 3, 2020
…e#6311)

* return the correct type for GetAttr node

* keep _get_pytorch_value_type intact

* add test and handle quantized param