From b31dfbaf1e30368c9d08f89e53aeac9061dfccb2 Mon Sep 17 00:00:00 2001 From: huanmei9 Date: Tue, 18 Jul 2023 22:07:46 +0800 Subject: [PATCH 1/2] add view_as op --- python/tvm/relay/frontend/pytorch.py | 15 +++++++++++++ tests/python/frontend/pytorch/test_forward.py | 21 +++++++++++++++++++ 2 files changed, 36 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 5e4d75599613..2883d7ca5e9a 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1622,6 +1622,20 @@ def view(self, inputs, input_types): return _op.transform.reshape(data, new_shape) + + def view_as(self, inputs, input_types): + data = inputs[0] + tensors = inputs[1] + + if not isinstance(tensors, (_expr.Call, _expr.Constant)): + msg = f"Data type {type(tensors)} could not be parsed in view_as op" + raise AssertionError(msg) + + shape = self.infer_shape(tensors) + + return _op.transform.reshape(data, shape) + + def reshape(self, inputs, input_types): data = inputs[0] new_shape = inputs[1] @@ -3838,6 +3852,7 @@ def create_convert_map(self): "aten::addmm": self.addmm, "aten::size": self.size, "aten::view": self.view, + "aten::view_as": self.view_as, "aten::reshape": self.reshape, "aten::reshape_as": self.reshape_as, "aten::clone": self.clone, diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 83930d1ea80b..32c27113a63a 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1660,6 +1660,27 @@ def forward(self, *args): verify_model(View3().float().eval(), input_data=input_data) +@tvm.testing.uses_gpu +def test_forward_view_as(): + """test_forward_view_as""" + torch.set_grad_enabled(False) + input_shape = [1, 3, 10] + + class ViewAs1(Module): + def forward(self, *args): + t1 = torch.ones((1 * 3 * 10)) + return args[0].view_as(t1) + + class ViewAs2(Module): + def forward(self, *args): + t1 = torch.rand(1 * 3 * 10).float() + return args[0].view_as(t1) + + input_data = torch.rand(input_shape).float() + verify_model(ViewAs1().float().eval(), input_data=input_data) + verify_model(ViewAs2().float().eval(), input_data=input_data) + + @tvm.testing.uses_gpu def test_forward_select(): """test_forward_select""" From 7819725e0a10039ea2693c0d37303a7c6fc763bd Mon Sep 17 00:00:00 2001 From: meihuan Date: Tue, 18 Jul 2023 20:39:39 +0800 Subject: [PATCH 2/2] fix lint, fix test case --- python/tvm/relay/frontend/pytorch.py | 2 -- tests/python/frontend/pytorch/test_forward.py | 6 +++--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 2883d7ca5e9a..8e36a749046e 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -1622,7 +1622,6 @@ def view(self, inputs, input_types): return _op.transform.reshape(data, new_shape) - def view_as(self, inputs, input_types): data = inputs[0] tensors = inputs[1] @@ -1635,7 +1634,6 @@ def view_as(self, inputs, input_types): return _op.transform.reshape(data, shape) - def reshape(self, inputs, input_types): data = inputs[0] new_shape = inputs[1] diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 32c27113a63a..fbe211189d9e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1673,12 +1673,12 @@ def forward(self, *args): class ViewAs2(Module): def forward(self, *args): - t1 = torch.rand(1 * 3 * 10).float() - return args[0].view_as(t1) + return args[0].view_as(args[1]) input_data = torch.rand(input_shape).float() + tensor = torch.rand(1 * 3 * 10).float() verify_model(ViewAs1().float().eval(), input_data=input_data) - verify_model(ViewAs2().float().eval(), input_data=input_data) + verify_model(ViewAs2().float().eval(), input_data=[input_data, tensor]) @tvm.testing.uses_gpu