Skip to content
7 changes: 4 additions & 3 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -428,10 +428,11 @@ inline Tensor squeeze(const Tensor& x, ffi::Optional<ffi::Array<Integer>> opt_ax
if (val < 0) {
val += static_cast<int>(x->shape.size());
}
if (IsConstInt(x->shape[val])) {
ICHECK_EQ(GetConstInt(x->shape[val]), 1) << "Dimension " << val << " must have size 1";
// If a dimension is not 1, silently skip it (no-op).
bool is_const = IsConstInt(x->shape[val]);
if ((is_const && GetConstInt(x->shape[val]) == 1) || !is_const) {
axis_val.push_back(val);
}
axis_val.push_back(val);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1938,7 +1938,7 @@ def _squeeze(self, node: fx.Node) -> relax.Var:
valid_dims = []
for d in dim:
axis = d if d >= 0 else len(shape) + d
if axis < len(shape) and shape[axis] == 1:
if axis < len(shape):
valid_dims.append(d)
# If no valid dims, use None to squeeze all size-1 dimensions
dim = valid_dims if valid_dims else None
Expand Down
11 changes: 3 additions & 8 deletions src/relax/op/tensor/manipulate.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1234,15 +1234,10 @@ StructInfo InferStructInfoSqueeze(const Call& call, const BlockBuilder& ctx) {
// Todo(relax-team): revisit here for better check on if the axis being squeezed has length 1.
// When `axis` is given, the dim lengths at the axes must be integer 1 when it is not symbolic
const auto* int_len = shape_value.value()[axes[i]].as<IntImmNode>();
if (int_len != nullptr && int_len->value != 1) {
ctx->ReportFatal(Diagnostic::Error(call)
<< "Squeeze expects the input tensor shape values at the given axis "
"positions to be all 1. However, the tensor shape at axis "
<< axes[i] << " is " << shape_value.value()[axes[i]]
<< " which is not 1. If it is symbolic, please use MatchCast to cast it "
"to 1 before doing Squeeze.");
// If a dimension is not 1, silently skip it (no-op), matching PyTorch behavior.
if ((int_len != nullptr && int_len->value == 1) || int_len == nullptr) {
axis_removal_mask[axes[i]] = true;
}
axis_removal_mask[axes[i]] = true;
}
} else {
// When `axis` is not defined, squeeze all unit-length dimensions.
Expand Down
19 changes: 18 additions & 1 deletion tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -5344,15 +5344,32 @@ def main(
input: R.Tensor((3, 1, 4, 1), dtype="float32")
) -> R.Tuple(R.Tensor((3, 4), dtype="float32")):
with R.dataflow():
lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[1, 3])
lv: R.Tensor((3, 4), dtype="float32") = R.squeeze(input, axis=[0, 1, 2, 3])
gv: R.Tuple(R.Tensor((3, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

class Squeeze3(Module):
def forward(self, input):
return input.squeeze(2)

@I.ir_module
class Expected3:
@R.function
def main(
inp_0: R.Tensor((3, 1, 4, 1), dtype="float32")
) -> R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")):
with R.dataflow():
lv: R.Tensor((3, 1, 4, 1), dtype="float32") = R.squeeze(inp_0, axis=[2])
gv: R.Tuple(R.Tensor((3, 1, 4, 1), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(3, 1, 4, 1, dtype=torch.float32),)

verify_model(Squeeze1(), example_args, {}, Expected1)
verify_model(Squeeze2(), example_args, {}, Expected2)
verify_model(Squeeze3(), example_args, {}, Expected3)


def test_stack():
Expand Down
18 changes: 13 additions & 5 deletions tests/python/relax/test_op_manipulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -994,11 +994,19 @@ def test_squeeze_infer_struct_info_axis_length_not_one():
x2 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
x3 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))

with pytest.raises(TVMError):
bb.normalize(relax.op.squeeze(x0, [0]))
_check_inference(bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo((3, 4), "float32"))
with pytest.raises(TVMError):
bb.normalize(relax.op.squeeze(x2, [0]))
# Squeeze concrete shape (2,3,4) at axis=0, but axis length 2 != 1, squeeze is no-op.
_check_inference(
bb, relax.op.squeeze(x0, [0]), relax.TensorStructInfo(shape=(2, 3, 4), dtype="float32")
)
# Squeeze symbolic shape (a,3,4) at axis=0, assuming a can achieve successful squeeze.
_check_inference(
bb, relax.op.squeeze(x1, [0]), relax.TensorStructInfo(shape=(3, 4), dtype="float32")
)
# Squeeze shape variable s0 (corresponding to (2,3,4)) at axis=0.
_check_inference(
bb, relax.op.squeeze(x2, [0]), relax.TensorStructInfo(shape=s0, dtype="float32")
)
# Squeeze shape variable s1 (a,3,4) at axis=0, assuming a can achieve successful squeeze.
_check_inference(bb, relax.op.squeeze(x3, [0]), relax.TensorStructInfo(dtype="float32", ndim=2))


Expand Down