diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 2d7096613bdc..ef4830a46adf 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -428,10 +428,11 @@ inline Tensor squeeze(const Tensor& x, ffi::Optional> opt_ax if (val < 0) { val += static_cast(x->shape.size()); } - if (IsConstInt(x->shape[val])) { - ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1"; + // If a dimension is not 1, silently skip it (no-op). + bool is_const = IsConstInt(x->shape[val]); + if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) { + axis_val.push_back(val); } - axis_val.push_back(val); } } diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py index 753b0d791495..6a07fe86f28b 100644 --- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -1938,7 +1938,7 @@ def _squeeze(self, node: fx.Node) -> relax.Var: valid_dims = [] for d in dim: axis = d if d >= 0 else len(shape) + d - if axis < len(shape) and shape[axis] == 1: + if axis < len(shape): valid_dims.append(d) # If no valid dims, use None to squeeze all size-1 dimensions dim = valid_dims if valid_dims else None diff --git a/src/relax/op/tensor/manipulate.cc b/src/relax/op/tensor/manipulate.cc index 79c0687cada5..ff71dd26c201 100644 --- a/src/relax/op/tensor/manipulate.cc +++ b/src/relax/op/tensor/manipulate.cc @@ -1234,15 +1234,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) { // Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1. // When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic const auto* int_len = shape_value.value()[axes[i]].as(); - if (int_len != nullptr && int_len->value != 1) { - ctx->ReportFatal(Diagnostic::Error(call) - << "Squeeze expects the input tensor shape values at the given axis " - "positions to be all 1. However, the tensor shape at axis " - << axes[i] << " is " << shape_value.value()[axes[i]] - << " which is not 1. If it is symbolic, please use MatchCast to cast it " - "to 1 before doing Squeeze."); + // If a dimension is not 1, silently skip it (no-op), matching PyTorch behavior. + if ((int_len != nullptr && int_len->value == 1) || int_len == nullptr) { + axis_removal_mask[axes[i]] = true; } - axis_removal_mask[axes[i]] = true; } } else { // When `axis` is not defined, squeeze all unit-length dimensions. diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py index 1429dec5e731..b34f600be4b5 100644 --- a/tests/python/relax/test_frontend_from_exported_program.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -5344,15 +5344,32 @@ def main( input: R.Tensor((3, 1, 4, 1), dtype="float32") ) -> R.Tuple(R.Tensor((3, 4), dtype="float32")): with R.dataflow(): - lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[1, 3]) + lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[0, 1, 2, 3]) gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,) R.output(gv) return gv + class Squeeze3(Module): + def forward(self, input): + return input.squeeze(2) + + @I.ir_module + class Expected3: + @R.function + def main( + inp_0: R.Tensor((3, 1, 4, 1), dtype="float32") + ) -> R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")): + with R.dataflow(): + lv: R.Tensor((3, 1, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[2]) + gv: R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")) = (lv,) + R.output(gv) + return gv + example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),) verify_model(Squeeze1(), example_args, {}, Expected1) verify_model(Squeeze2(), example_args, {}, Expected2) + verify_model(Squeeze3(), example_args, {}, Expected3) def test_stack(): diff --git a/tests/python/relax/test_op_manipulate.py b/tests/python/relax/test_op_manipulate.py index 004c4b9618a0..d39584e06ba8 100644 --- a/tests/python/relax/test_op_manipulate.py +++ b/tests/python/relax/test_op_manipulate.py @@ -994,11 +994,19 @@ def test_squeeze_infer_struct_info_axis_length_not_one(): x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32")) - with pytest.raises(TVMError): - bb.normalize(relax.op.squeeze(x0, [0])) - _check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32")) - with pytest.raises(TVMError): - bb.normalize(relax.op.squeeze(x2, [0])) + # Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1, squeeze is no-op. + _check_inference( + bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4), dtype="float32") + ) + # Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve successful squeeze. + _check_inference( + bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4), dtype="float32") + ) + # Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0. + _check_inference( + bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0, dtype="float32") + ) + # Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve successful squeeze. _check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2))