Status: Open
Labels: (none listed)
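Description
Lowering a resharding `set` that changes the sharding of both mesh axes at once fails with an INTERNAL ASSERT in `csrc/host_ir/lower_to_communication.cpp`.

In the repro, the input `inp_tv` lives on a 2x2 `DeviceMesh` with rows sharded on `mesh_y` and columns on `mesh_x`. The output `out_tv` swaps the two: rows on `mesh_x` and columns on `mesh_y`. Lowering finds two AllToAll sharding changes in that single `Set` expression, but it expects at most one.

Repro: https://github.com/NVIDIA/Fuser/tree/bug6001

The diff below adds the failing test `test_2d_alltoall` to `tests/python/multidevice/test_sharding.py`. The test needs exactly 4 devices, because it fixes `dx = 2` and asserts `dx == dy`:
- With an odd device count, the test is skipped.
- With any other even count (e.g. 2 or 8), it fails on that assert before reaching the bug.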
diff --git a/tests/python/multidevice/test_sharding.py b/tests/python/multidevice/test_sharding.py
index 5745fbf9..2b4cfd4b 100644
--- a/tests/python/multidevice/test_sharding.py
+++ b/tests/python/multidevice/test_sharding.py
@@ -154,6 +154,40 @@ class TestShardTensor:
].flatten(),
)
+ @pytest.mark.mpi
+ def test_2d_alltoall(self, multidevice_test):
+ d = multidevice_test.size
+ dx = 2
+ if d % dx != 0:
+ pytest.skip(f"Number of devices ({d=}) must be divisible by {dx=}")
+ dy = d // dx
+ assert dx == dy
+
+ with nvfuser.FusionDefinition() as fd:
+ inp_tv = fd.define_tensor([-1, -1])
+ out_tv = fd.ops.set(inp_tv)
+ fd.add_output(out_tv)
+
+ mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d).reshape(dy, dx))
+ inp_tv.set_device_mesh(mesh)
+ inp_tv.outer_split(1, dx)
+ inp_tv.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
+ inp_tv.outer_split(0, dy)
+ inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_y)
+ out_tv.set_device_mesh(mesh)
+ out_tv.outer_split(1, dy)
+ out_tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y)
+ out_tv.outer_split(0, dx)
+ out_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)
+
+ rows_per_rank, cols_per_rank = 3, 5
+ inp_ref = torch.testing.make_tensor(
+ rows_per_rank * dy, cols_per_rank * dx, dtype=torch.float, device="cpu"
+ )
+
+ inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
+ fd.execute([inp])
+
@pytest.mark.mpi
def test_context_and_tensor_parallel(self, multidevice_test):
d = multidevice_test.size
E RuntimeError: INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/host_ir/lower_to_communication.cpp:474, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
E Expected !communication_info.has_value() . Expected at most one sharding change in `e`: T1_g_float[ideviceIdx.x10{2}, iS11{( ceilDiv(i0, 2) )}, ideviceIdx.y8{2}, iS9{( ceilDiv(i1, 2) )}] (DeviceMesh{{0 1}{2 3}})
E = Set( T0_g_float[ideviceIdx.y6{2}, iS7{( ceilDiv(i0, 2) )}, ideviceIdx.x4{2}, iS5{( ceilDiv(i1, 2) )}] (DeviceMesh{{0 1}{2 3}}), cache_op=Streaming )
E , but got: CommunicationInfo(AllToAll, p_sharded_id=iS1{i1}, c_sharded_id=iS2{i0}) and CommunicationInfo(AllToAll, p_sharded_id=iS0{i0}, c_sharded_id=iS3{i1})
Reactions are currently unavailable