support aten::type_as in the pytorch frontend #5787

Merged 2 commits on Jun 13, 2020
9 changes: 9 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -1598,6 +1598,14 @@ def _impl(inputs, input_types):
    return _impl


def _type_as():
    def _impl(inputs, input_types):
        assert len(inputs) == 2
        assert len(input_types) == 2
        return _op.cast(inputs[0], _convert_data_type(input_types[1]))
    return _impl


def _add(prelude):
    # add_ is overloaded for tensor add and list concat
    def _impl(inputs, input_types):
@@ -1902,6 +1910,7 @@ def _get_convert_map(prelude):
        "aten::stack"       : _tensor_array_stack(prelude),
        "aten::__getitem__" : _list_getitem(prelude),
        "aten::len"         : _list_len(prelude),
        "aten::type_as"     : _type_as(),
    }
    return convert_map

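For context, `torch.Tensor.type_as(other)` casts the calling tensor to `other`'s dtype, which is why a single Relay `cast` on the first input suffices in the converter above. A minimal PyTorch illustration of the semantics being mapped:

```python
import torch

x = torch.ones(1, 3, dtype=torch.float32)
ref = torch.zeros(1, 3, dtype=torch.int64)

# type_as casts x to ref's dtype; the shape is unchanged.
y = x.type_as(ref)
assert y.dtype == torch.int64
```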
37 changes: 37 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -27,6 +27,7 @@

from tvm import relay
from tvm.contrib import graph_runtime
from tvm.contrib.nvcc import have_fp16
from tvm.relay.testing.config import ctx_list


@@ -836,6 +837,41 @@ def forward(self, *args):
    input_data = torch.rand(input_shape).float()
    verify_model(Size1().float().eval(), input_data=input_data)


def test_type_as():
    torch.set_grad_enabled(False)
    input_shape = [1, 3]

    def _create_module(dtype):
        class TypeAs(Module):
            def forward(self, *args):
                expected_type_tensor = torch.zeros(1, 3, dtype=dtype)
                return args[0].type_as(expected_type_tensor)

        return TypeAs()

    input_data = torch.randn(input_shape).float()
    verify_model(_create_module(torch.float64), input_data=input_data)
    verify_model(_create_module(torch.float32), input_data=input_data)
    verify_model(_create_module(torch.int64), input_data=input_data)
    verify_model(_create_module(torch.int32), input_data=input_data)
    verify_model(_create_module(torch.int16), input_data=input_data)
    verify_model(_create_module(torch.int8), input_data=input_data)

    if torch.cuda.is_available():
        check_fp16 = False
        try:
            # Only check half precision on supported hardware.
            if have_fp16(tvm.gpu(0).compute_version):
                check_fp16 = True
        except Exception:
            # If GPU is not enabled in TVM, skip the fp16 test.
            pass

        if check_fp16:
            verify_model(_create_module(torch.float16), input_data=input_data)


def test_forward_view():
    torch.set_grad_enabled(False)
    input_shape = [1, 3, 10, 10]
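Beyond the unit test, the new converter can be exercised end to end by tracing a module that calls `type_as` and importing it with the PyTorch frontend. A minimal sketch, assuming a TVM build with the frontend enabled; the module name `TypeAsModule` and the input name `"input0"` are illustrative:

```python
import torch
from tvm import relay

class TypeAsModule(torch.nn.Module):
    def forward(self, x):
        # The traced graph records this as aten::type_as,
        # which the frontend now lowers to a Relay cast.
        return x.type_as(torch.zeros(1, 3, dtype=torch.int32))

inp = torch.randn(1, 3)
scripted = torch.jit.trace(TypeAsModule().eval(), inp)

# Each entry pairs a graph input name with its shape.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", inp.shape)])
print(mod)  # the printed Relay IR should contain a cast to "int32"
```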
@@ -2460,6 +2496,7 @@ def test_forward_pretrained_bert_base_uncased():
    test_upsample()
    test_forward_upsample3d()
    test_to()
    test_type_as()
    test_forward_functional_pad()
    test_forward_zero_pad2d()
    test_forward_constant_pad1d()