diff --git a/thunder/executors/nvfuserex_impl.py b/thunder/executors/nvfuserex_impl.py
index 8cc0addca9..6a22bb7bd6 100644
--- a/thunder/executors/nvfuserex_impl.py
+++ b/thunder/executors/nvfuserex_impl.py
@@ -3228,6 +3228,7 @@ def _grouped_mm_transform(
 
 
 register_supported(prims._grouped_mm, _grouped_mm_transform, _grouped_mm_check)
+register_supported(DTensorPrimIDs._GROUPED_MM, _grouped_mm_transform, _grouped_mm_check)
 
 
 def _cumsum_check(a: TensorProxy, dim: int, /, dtype: dtypes.dtype | None = None) -> bool:
diff --git a/thunder/tests/distributed/helper.py b/thunder/tests/distributed/helper.py
index cab9b19e1d..24a9460eff 100644
--- a/thunder/tests/distributed/helper.py
+++ b/thunder/tests/distributed/helper.py
@@ -187,7 +187,13 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False):
 
         local_rank = self.rank % torch.cuda.device_count()
         torch.cuda.set_device(local_rank)
+
+        # nvFuser Multi-GPU expects these environment variables to be set
         os.environ["LOCAL_RANK"] = str(local_rank)
+        # We only have single node tests, so `LOCAL_WORLD_SIZE` is the same as `WORLD_SIZE`
+        os.environ["LOCAL_WORLD_SIZE"] = str(self.world_size)
+        os.environ["RANK"] = str(self.rank)
+        os.environ["WORLD_SIZE"] = str(self.world_size)
 
         torch.distributed.barrier()
         try:
diff --git a/thunder/tests/distributed/test_dtensor.py b/thunder/tests/distributed/test_dtensor.py
index b010644a47..aee4b0f187 100644
--- a/thunder/tests/distributed/test_dtensor.py
+++ b/thunder/tests/distributed/test_dtensor.py
@@ -1,6 +1,7 @@
 import unittest
 from itertools import product
 from collections.abc import Sequence
+from looseversion import LooseVersion
 
 import pytest
 import torch
@@ -12,7 +13,7 @@
 
 from thunder.tests.distributed.helper import DistributedParallelTestCase
 from torch.distributed._tensor import DeviceMesh, distribute_tensor
-from torch.distributed.tensor.placement_types import Shard
+from torch.distributed.tensor.placement_types import Shard, Replicate
 from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter
 from torch.testing._internal import common_utils
 
@@ -249,6 +250,56 @@ def fn(x):
 
         torch.testing.assert_close(actual, expected)
 
+    @common_utils.parametrize("executor", tuple(executors_map.keys()))
+    @common_utils.parametrize(
+        "input_shardings",
+        [
+            (
+                [
+                    Shard(
+                        -1,
+                    )
+                ],
+                [
+                    Shard(1),
+                ],
+                [Replicate()],
+            ),
+        ],
+    )
+    def test_dtensor_grouped_mm(self, executor, input_shardings):
+        if LooseVersion(torch.__version__) < "2.8":
+            raise unittest.SkipTest("test_dtensor_grouped_mm: torch._grouped_mm is not available in torch < 2.8")
+
+        num_devices = self.world_size
+        mesh = DeviceMesh("cuda", list(range(num_devices)))
+
+        if (torch.cuda.get_device_capability() < (9, 0)) and executor == "torch":
+            raise unittest.SkipTest(
+                "test_dtensor_grouped_mm: torch._grouped_mm doesn't support device capability < 9.0"
+            )
+
+        M = 16
+        N = 64
+        K = 32
+        G = 2
+
+        inp_shard, w_shard, offsets_shard = input_shardings
+        in_dtensor = distribute_tensor(torch.randn(M, K, requires_grad=False, dtype=torch.bfloat16), mesh, inp_shard)
+        w_dtensor = distribute_tensor(torch.randn(G, K, N, requires_grad=False, dtype=torch.bfloat16), mesh, w_shard)
+        offsets_dtensor = distribute_tensor(torch.tensor([0, 16], dtype=torch.int32), mesh, offsets_shard)
+
+        tfn = thunder.jit(torch._grouped_mm, executors=executors_map[executor].executors_list())
+
+        tfn(in_dtensor, w_dtensor, offsets_dtensor)
+
+        trcs = thunder.last_traces(tfn)
+        init_trc = trcs[0]
+
+        from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_grouped_mm
+
+        assert any(bsym.sym == dtensor_grouped_mm for bsym in init_trc.bound_symbols)
+
     @common_utils.parametrize(
         "op, executor",
         product(dtensor_supported_opinfos, tuple(executors_map.keys())),
diff --git a/thunder/torch/experimental/dtensor_torch_and_prims.py b/thunder/torch/experimental/dtensor_torch_and_prims.py
index 00202c62bb..4df6de39ff 100644
--- a/thunder/torch/experimental/dtensor_torch_and_prims.py
+++ b/thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -1,6 +1,7 @@
 from functools import partial
 from collections.abc import Callable
 from enum import auto, Enum
+from looseversion import LooseVersion
 
 from thunder.torch import torchsymbol, TensorLike, register_function
 import thunder.torch as ltorch
@@ -36,6 +37,7 @@ class DTensorPrimIDs(Enum):
     RESHAPE = auto()
     CONVERT_ELEMENT_TYPE = auto()
     BROADCAST_IN_DIM = auto()
+    _GROUPED_MM = auto()
     EXP = auto()
     LINEAR = auto()
     NEG = auto()
@@ -363,6 +365,34 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
     )
 
 
+if LooseVersion(torch.__version__) >= "2.8":
+
+    def dtensor_grouped_mm_meta(a, b, offsets):
+        output = run_with_fake_tensor(torch._grouped_mm, a, b, offsets)
+        local_tensor_proxy = TensorProxy(
+            like=a.local_tensor, dtype=dtypes.to_dtype(output._local_tensor.dtype), shape=output._local_tensor.shape
+        )
+        spec = output._spec
+        spec_proxy = AnyProxy(spec, history=a.history)
+        return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)
+
+    dtensor_grouped_mm_prim = make_prim(
+        DTensorPrimIDs._GROUPED_MM, "dtensor_grouped_mm_prim", meta=dtensor_grouped_mm_meta
+    )
+
+    dtensor_grouped_mm_prim_impl = pytorchex.register_operator(
+        "dtensor_grouped_mm_prim", like=dtensor_grouped_mm_prim, fn=torch._grouped_mm
+    )
+
+    pytorchex.register_implementation(dtensor_grouped_mm_prim, dtensor_grouped_mm_prim_impl)
+
+    @dtensor_torchsymbol(torch._grouped_mm, id="dtensor.torch._grouped_mm")
+    def dtensor_grouped_mm(a: TensorLike, b: TensorLike, offsets: TensorLike, *, bias=None, dtype=None) -> TensorLike:
+        assert bias is None, "bias is not supported"
+        assert dtype is None, "dtype is not supported"
+        return dtensor_grouped_mm_prim(a, b, offsets)
+
+
 def register_dtensor_torch_and_prims():
     register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
     register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)
@@ -370,3 +400,5 @@ def register_dtensor_torch_and_prims():
     register_function_for_dtensor(torch.exp, ltorch.exp, dtensor_exp, is_method=True)
     register_function_for_dtensor(torch.neg, ltorch.neg, dtensor_neg, is_method=True)
     register_function_for_dtensor(torch.reciprocal, ltorch.reciprocal, dtensor_reciprocal, is_method=True)
+    if LooseVersion(torch.__version__) >= "2.8":
+        register_function_for_dtensor(torch._grouped_mm, ltorch._grouped_mm, dtensor_grouped_mm, is_method=False)
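
Usage note (not part of the patch): a minimal standalone sketch of the flow exercised by `test_dtensor_grouped_mm`, assuming torch >= 2.8, an already-initialized default process group, and GPUs with compute capability 9.0+. Shapes and shardings mirror the test above; everything else here is illustrative rather than prescriptive.

```python
import torch
import thunder
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Shard, Replicate

# One 1-D mesh over all ranks of the default process group.
mesh = DeviceMesh("cuda", list(range(torch.distributed.get_world_size())))

M, N, K, G = 16, 64, 32, 2
# Same shardings as the parametrized test: activation sharded on its last dim,
# weight sharded on dim 1, offsets replicated.
in_dtensor = distribute_tensor(torch.randn(M, K, dtype=torch.bfloat16), mesh, [Shard(-1)])
w_dtensor = distribute_tensor(torch.randn(G, K, N, dtype=torch.bfloat16), mesh, [Shard(1)])
offsets = distribute_tensor(torch.tensor([0, 16], dtype=torch.int32), mesh, [Replicate()])

# torch._grouped_mm on DTensor inputs is traced through dtensor_grouped_mm and,
# when the nvFuser executor is enabled, lowered via the newly registered
# DTensorPrimIDs._GROUPED_MM entry.
tfn = thunder.jit(torch._grouped_mm)
out = tfn(in_dtensor, w_dtensor, offsets)
```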