diff --git a/oneflow/python/nn/modules/where.py b/oneflow/python/nn/modules/where.py index 6785bb8b399..51f42e07b21 100644 --- a/oneflow/python/nn/modules/where.py +++ b/oneflow/python/nn/modules/where.py @@ -42,6 +42,11 @@ def forward(self, condition, x, y): condition.device.type == x.device.type and condition.device.type == y.device.type ) + + assert len(condition.shape) == len(x.shape) and len(condition.shape) == len( + y.shape + ), f"The dim of where module's inputs can not match, please check!" + broadcast_cond = condition broadcast_x = x broadcast_y = y