Skip to content

Resharding set with alternate DIDx and DIDy #6001

@wujingyue

Description

@wujingyue

Summary: lowering a resharding `set` whose input and output swap the DIDx/DIDy parallelization across a 2-D device mesh fails — the pattern requires two AllToAll communications (one per mesh axis), but `lower_to_communication.cpp` asserts that an expression contains at most one sharding change. Repro: https://github.com/NVIDIA/Fuser/tree/bug6001

diff --git a/tests/python/multidevice/test_sharding.py b/tests/python/multidevice/test_sharding.py
index 5745fbf9..2b4cfd4b 100644
--- a/tests/python/multidevice/test_sharding.py
+++ b/tests/python/multidevice/test_sharding.py
@@ -154,6 +154,40 @@ class TestShardTensor:
             ].flatten(),
         )

+    @pytest.mark.mpi
+    def test_2d_alltoall(self, multidevice_test):
+        d = multidevice_test.size
+        dx = 2
+        if d % dx != 0:
+            pytest.skip(f"Number of devices ({d=}) must be divisible by {dx=}")
+        dy = d // dx
+        assert dx == dy
+
+        with nvfuser.FusionDefinition() as fd:
+            inp_tv = fd.define_tensor([-1, -1])
+            out_tv = fd.ops.set(inp_tv)
+            fd.add_output(out_tv)
+
+            mesh = nvfuser.multidevice.DeviceMesh(torch.arange(d).reshape(dy, dx))
+            inp_tv.set_device_mesh(mesh)
+            inp_tv.outer_split(1, dx)
+            inp_tv.axis(1).parallelize(nvfuser.ParallelType.mesh_x)
+            inp_tv.outer_split(0, dy)
+            inp_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_y)
+            out_tv.set_device_mesh(mesh)
+            out_tv.outer_split(1, dy)
+            out_tv.axis(1).parallelize(nvfuser.ParallelType.mesh_y)
+            out_tv.outer_split(0, dx)
+            out_tv.axis(0).parallelize(nvfuser.ParallelType.mesh_x)
+
+        rows_per_rank, cols_per_rank = 3, 5
+        inp_ref = torch.testing.make_tensor(
+            rows_per_rank * dy, cols_per_rank * dx, dtype=torch.float, device="cpu"
+        )
+
+        inp = multidevice_test.shard_tensor(inp_ref, inp_tv)
+        fd.execute([inp])
+
     @pytest.mark.mpi
     def test_context_and_tensor_parallel(self, multidevice_test):
         d = multidevice_test.size
E       RuntimeError:  INTERNAL ASSERT FAILED at /opt/pytorch/nvfuser/csrc/host_ir/lower_to_communication.cpp:474, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
E       Expected !communication_info.has_value() . Expected at most one sharding change in `e`: T1_g_float[ideviceIdx.x10{2}, iS11{( ceilDiv(i0, 2) )}, ideviceIdx.y8{2}, iS9{( ceilDiv(i1, 2) )}] (DeviceMesh{{0 1}{2 3}})
E          = Set( T0_g_float[ideviceIdx.y6{2}, iS7{( ceilDiv(i0, 2) )}, ideviceIdx.x4{2}, iS5{( ceilDiv(i1, 2) )}] (DeviceMesh{{0 1}{2 3}}), cache_op=Streaming )
E       , but got: CommunicationInfo(AllToAll, p_sharded_id=iS1{i1}, c_sharded_id=iS2{i0}) and CommunicationInfo(AllToAll, p_sharded_id=iS0{i0}, c_sharded_id=iS3{i1})

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions