1 change: 1 addition & 0 deletions thunder/executors/nvfuserex_impl.py
@@ -3228,6 +3228,7 @@ def _grouped_mm_transform(


register_supported(prims._grouped_mm, _grouped_mm_transform, _grouped_mm_check)
register_supported(DTensorPrimIDs._GROUPED_MM, _grouped_mm_transform, _grouped_mm_check)


def _cumsum_check(a: TensorProxy, dim: int, /, dtype: dtypes.dtype | None = None) -> bool:
6 changes: 6 additions & 0 deletions thunder/tests/distributed/helper.py
@@ -187,7 +187,13 @@ def _run(cls, rank, test_name, file_name, pipe, *, fake_pg=False):

local_rank = self.rank % torch.cuda.device_count()
torch.cuda.set_device(local_rank)

# nvFuser Multi-GPU expects these environment variables to be set
os.environ["LOCAL_RANK"] = str(local_rank)
# We only have single node tests, so `LOCAL_WORLD_SIZE` is the same as `WORLD_SIZE`
os.environ["LOCAL_WORLD_SIZE"] = str(self.world_size)
os.environ["RANK"] = str(self.rank)
os.environ["WORLD_SIZE"] = str(self.world_size)

torch.distributed.barrier()
try:
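The variables exported above mirror what a normal `torchrun` launch would set; a minimal standalone sketch of the same single-node setup, assuming one test process per visible GPU (the helper name below is illustrative, not part of the PR):

```python
import os

import torch


def set_nvfuser_distributed_env(rank: int, world_size: int) -> None:
    """Hypothetical helper: export the env vars nvFuser multi-GPU reads (single-node assumption)."""
    # Guard against CPU-only environments so the sketch stays runnable.
    local_rank = rank % max(torch.cuda.device_count(), 1)
    os.environ["LOCAL_RANK"] = str(local_rank)
    # Single node: every rank is local, so LOCAL_WORLD_SIZE equals WORLD_SIZE.
    os.environ["LOCAL_WORLD_SIZE"] = str(world_size)
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
```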
53 changes: 52 additions & 1 deletion thunder/tests/distributed/test_dtensor.py
@@ -1,6 +1,7 @@
import unittest
from itertools import product
from collections.abc import Sequence
from looseversion import LooseVersion

import pytest
import torch
@@ -12,7 +13,7 @@

from thunder.tests.distributed.helper import DistributedParallelTestCase
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Shard
from torch.distributed.tensor.placement_types import Shard, Replicate
from torch.testing._internal.distributed._tensor.common_dtensor import DTensorConverter

from torch.testing._internal import common_utils
@@ -249,6 +250,56 @@ def fn(x):

torch.testing.assert_close(actual, expected)

@common_utils.parametrize("executor", tuple(executors_map.keys()))
@common_utils.parametrize(
"input_shardings",
[
(
[
Shard(
-1,
)
],
[
Shard(1),
],
[Replicate()],
),
],
Comment on lines +255 to +268

Collaborator:

QQ: what's the type of input_shardings? A tuple of two lists of Shard/Replicate?

Collaborator (Author):

It is

[
    ([Shard(-1)], [Shard(1)], [Replicate()]),
]

NOTE: Each element of the tuple is a Sequence[Placement], as expected by distribute_tensor.

Doc: https://docs.pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.distribute_tensor

)
def test_dtensor_grouped_mm(self, executor, input_shardings):
if LooseVersion(torch.__version__) < "2.8":
raise unittest.SkipTest("test_dtensor_grouped_mm: torch._grouped_mm is not available in torch < 2.8")

num_devices = self.world_size
mesh = DeviceMesh("cuda", list(range(num_devices)))

if (torch.cuda.get_device_capability() < (9, 0)) and executor == "torch":
raise unittest.SkipTest(
"test_dtensor_grouped_mm: torch._grouped_mm doesn't support device capability < 9.0"
)

M = 16
N = 64
K = 32
G = 2

inp_shard, w_shard, offsets_shard = input_shardings
in_dtensor = distribute_tensor(torch.randn(M, K, requires_grad=False, dtype=torch.bfloat16), mesh, inp_shard)
w_dtensor = distribute_tensor(torch.randn(G, K, N, requires_grad=False, dtype=torch.bfloat16), mesh, w_shard)
offsets_dtensor = distribute_tensor(torch.tensor([0, 16], dtype=torch.int32), mesh, offsets_shard)

tfn = thunder.jit(torch._grouped_mm, executors=executors_map[executor].executors_list())

tfn(in_dtensor, w_dtensor, offsets_dtensor)

trcs = thunder.last_traces(tfn)
init_trc = trcs[0]

from thunder.torch.experimental.dtensor_torch_and_prims import dtensor_grouped_mm

assert any(bsym.sym == dtensor_grouped_mm for bsym in init_trc.bound_symbols)

@common_utils.parametrize(
"op, executor",
product(dtensor_supported_opinfos, tuple(executors_map.keys())),
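For readers unfamiliar with torch._grouped_mm, here is a plain-tensor sketch of the 2D x 3D case the test exercises, under the assumption that offsets holds cumulative (exclusive) end rows of each group along M; the function name is illustrative, not part of the PR:

```python
import torch


def grouped_mm_reference(a: torch.Tensor, b: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Loop-based reference for the 2D x 3D grouped matmul: rows of `a` are split
    into groups by `offsets`, and each group is multiplied by its own weight `b[g]`."""
    M, K = a.shape
    G, _, N = b.shape
    out = a.new_empty(M, N)
    start = 0
    for g in range(G):
        end = int(offsets[g])  # assumed: cumulative (exclusive) end row of group g
        out[start:end] = a[start:end] @ b[g]
        start = end
    return out


# Shapes mirror the test above: M=16, N=64, K=32, G=2.
a = torch.randn(16, 32, dtype=torch.bfloat16)
b = torch.randn(2, 32, 64, dtype=torch.bfloat16)
offsets = torch.tensor([8, 16], dtype=torch.int32)
out = grouped_mm_reference(a, b, offsets)  # shape (16, 64)
```

In the test itself the offsets are [0, 16], so all rows fall into the second group; the DTensor inputs then shard the input and the weight along K (Shard(-1) and Shard(1), respectively) and replicate the offsets.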
32 changes: 32 additions & 0 deletions thunder/torch/experimental/dtensor_torch_and_prims.py
@@ -1,6 +1,7 @@
from functools import partial
from collections.abc import Callable
from enum import auto, Enum
from looseversion import LooseVersion

from thunder.torch import torchsymbol, TensorLike, register_function
import thunder.torch as ltorch
@@ -36,6 +37,7 @@ class DTensorPrimIDs(Enum):
RESHAPE = auto()
CONVERT_ELEMENT_TYPE = auto()
BROADCAST_IN_DIM = auto()
_GROUPED_MM = auto()
EXP = auto()
LINEAR = auto()
NEG = auto()
@@ -363,10 +365,40 @@ def dtensor_reciprocal(a: TensorLike) -> TensorLike:
)


if LooseVersion(torch.__version__) >= "2.8":

def dtensor_grouped_mm_meta(a, b, offsets):
output = run_with_fake_tensor(torch._grouped_mm, a, b, offsets)
local_tensor_proxy = TensorProxy(
like=a.local_tensor, dtype=dtypes.to_dtype(output._local_tensor.dtype), shape=output._local_tensor.shape
)
spec = output._spec
spec_proxy = AnyProxy(spec, history=a.history)
return create_dtensor_proxy_from_proxies(local_tensor_proxy, spec_proxy, False)

dtensor_grouped_mm_prim = make_prim(
DTensorPrimIDs._GROUPED_MM, "dtensor_grouped_mm_prim", meta=dtensor_grouped_mm_meta
)

dtensor_grouped_mm_prim_impl = pytorchex.register_operator(
"dtensor_grouped_mm_prim", like=dtensor_grouped_mm_prim, fn=torch._grouped_mm
)

pytorchex.register_implementation(dtensor_grouped_mm_prim, dtensor_grouped_mm_prim_impl)

@dtensor_torchsymbol(torch._grouped_mm, id="dtensor.torch._grouped_mm")
def dtensor_grouped_mm(a: TensorLike, b: TensorLike, offsets: TensorLike, *, bias=None, dtype=None) -> TensorLike:
assert bias is None, "bias is not supported"
assert dtype is None, "dtype is not supported"
return dtensor_grouped_mm_prim(a, b, offsets)


def register_dtensor_torch_and_prims():
register_function_for_dtensor(torch.mul, ltorch.mul, dtensor_mul, is_method=True)
register_function_for_dtensor(torch.reshape, ltorch.reshape, dtensor_reshape, is_method=True)
register_function_for_dtensor(torch.nn.functional.linear, ltorch.linear, dtensor_linear, is_method=False)
register_function_for_dtensor(torch.exp, ltorch.exp, dtensor_exp, is_method=True)
register_function_for_dtensor(torch.neg, ltorch.neg, dtensor_neg, is_method=True)
register_function_for_dtensor(torch.reciprocal, ltorch.reciprocal, dtensor_reciprocal, is_method=True)
if LooseVersion(torch.__version__) >= "2.8":
register_function_for_dtensor(torch._grouped_mm, ltorch._grouped_mm, dtensor_grouped_mm, is_method=False)
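A closing note on dtensor_grouped_mm_meta above: it only needs the output DTensor's local shard shape and its sharding spec. A small standalone sketch (assuming an initialized process group and one GPU per rank, using the same imports as the test file) of what those two pieces of information look like on a DTensor:

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, distribute_tensor
from torch.distributed.tensor.placement_types import Shard


def inspect_local_view() -> None:
    """Show the per-rank local shard and the placement spec that a DTensor carries."""
    world_size = dist.get_world_size()
    mesh = DeviceMesh("cuda", list(range(world_size)))
    # Shard(-1) splits the last (K) dimension of the (M, K) input across ranks,
    # so with 2 ranks each one holds an (M, K // 2) local tensor.
    a = distribute_tensor(torch.randn(16, 32, dtype=torch.bfloat16), mesh, [Shard(-1)])
    print(a.to_local().shape)  # local shard shape on this rank
    print(a.placements)        # placement spec describing how `a` is sharded
```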