Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix][Frontend][TOPI] minor bugs #8622

Merged
merged 3 commits into from Aug 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/topi/detail/ravel_unravel.h
Expand Up @@ -44,7 +44,9 @@ using namespace tvm::te;
*/
inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
ICHECK_GT(indices.size(), 0) << "indices must not be empty";
if (indices.size() == 0U) {
return 0;
}
PrimExpr idx;
for (size_t i = 0; i < indices.size(); ++i) {
if (i == 0) {
Expand Down
11 changes: 10 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
Expand Up @@ -1445,7 +1445,16 @@ def linear(self, inputs, input_types):
# 0 - input
# 1 - weight
bias = inputs[2]
mm_out = self.matmul(inputs[:2], input_types[:2])
a_shape = self.infer_shape_with_prelude(inputs[0])
b_shape = self.infer_shape_with_prelude(inputs[1])
if len(a_shape) == 2 and len(b_shape) == 2:
mm_out = _op.nn.dense(inputs[0], inputs[1])
elif len(b_shape) == 1:
mm_out = self.matmul([inputs[0], inputs[1]], input_types[:2])
else:
mm_out = self.matmul(
[inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
)
if isinstance(bias, _expr.Expr):
bias_ndims = len(self.infer_shape_with_prelude(bias))
if bias_ndims == 1:
Expand Down
6 changes: 6 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
Expand Up @@ -1569,8 +1569,10 @@ def forward(self, input, weight):
return F.linear(input, weight)

input2d = torch.rand([2, 2]).float()
input3d = torch.rand([4, 3, 2]).float()
weight1d = torch.rand([2]).float()
weight2d = torch.rand([2, 2]).float()
weight3x2 = torch.rand([3, 2]).float()
bias1d = torch.rand([2]).float()
bias2d = torch.rand([2, 2]).float()
# 2D input, 2D weight, 1D bias
Expand All @@ -1579,9 +1581,12 @@ def forward(self, input, weight):
verify_model(Linear(), input_data=[input2d, weight2d, bias2d])
# 2D input, 2D weight, no bias
verify_model(LinearNoBias(), input_data=[input2d, weight2d])
verify_model(LinearNoBias(), input_data=[input2d, weight3x2])
# 2D input, 1D weight, 1D bias is not supported by torch.linear()
# 2D input, 1D weight, no bias
verify_model(LinearNoBias(), input_data=[input2d, weight1d])
# 3D input, 2D weight, no bias
verify_model(LinearNoBias(), input_data=[input3d, weight3x2])
# TODO: Add the following cases when matmul(1D, _) is supported by TVM
# 1D input, 2D weight, 1D bias
# 1D input, 2D weight, no bias
Expand Down Expand Up @@ -3981,6 +3986,7 @@ def forward(self, x):
test_forward_logsoftmax()
test_forward_sigmoid()
test_forward_dense()
test_forward_linear()
test_forward_avgpool1d()
test_forward_avgpool2d()
test_forward_avgpool3d()
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/test_op_level3.py
Expand Up @@ -293,6 +293,7 @@ def verify_reshape(shape, newshape, oshape):
verify_reshape((2, 3, 4), (-3, -2), (6, 4))
verify_reshape((2, 3, 4), (-4, 1, 2, -2), (1, 2, 3, 4))
verify_reshape((2, 3, 4), (2, -4, -1, 3, -2), (2, 1, 3, 4))
verify_reshape((1,), (), ())


def test_reshape_fail():
Expand Down