Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions tests/python/multidevice/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

from enum import auto, Enum


class Parallelism(Enum):
    """Parallelism strategies used by the multidevice tests.

    See the linked NeMo framework docs for what each strategy means.
    """

    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
    TENSOR_PARALLEL = 1
    # https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
    SEQUENCE_PARALLEL = 2
8 changes: 0 additions & 8 deletions tests/python/multidevice/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,6 @@
# SPDX-License-Identifier: BSD-3-Clause

import torch
from enum import auto, Enum


class Parallelism(Enum):
# https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#tensor-parallelism
TENSOR_PARALLEL = auto()
# https://docs.nvidia.com/nemo-framework/user-guide/latest/nemotoolkit/features/parallelisms.html#sequence-parallelism
SEQUENCE_PARALLEL = auto()


def get_benchmark_fn(func, /, profile: bool):
Expand Down
12 changes: 7 additions & 5 deletions tests/python/multidevice/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

import torch
import torch.distributed as dist
from dataclasses import dataclass
from enum import auto, Enum
from functools import lru_cache
from fusion_definition_wrapper import FusionDefinitionWrapper
from nvfuser_direct import DataType, FusionDefinition
from typing import Iterable

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Placement
from typing import Iterable

from nvfuser_direct import DataType, FusionDefinition
from .fusion_definition_wrapper import FusionDefinitionWrapper


@dataclass(frozen=True)
Expand Down
18 changes: 11 additions & 7 deletions tests/python/multidevice/test_deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
# Run command:
# mpirun -np 1 pytest tests/python/multidevice/test_deepseek_v3.py --only-mpi -s

import pytest
import transformers
import torch
import torch.distributed as dist
from contextlib import contextmanager
from enum import Enum, auto
from functools import wraps
from linear import TensorParallelLinear
from benchmark_utils import get_benchmark_fns
from typing import Optional

import pytest

import torch
import torch.distributed as dist
from torch.distributed.tensor import DTensor
from torch.distributed.tensor.parallel import (
parallelize_module,
Expand All @@ -23,7 +23,11 @@
ColwiseParallel,
)
from torch.distributed.tensor.placement_types import Shard
from typing import Optional

import transformers

from .benchmark_utils import get_benchmark_fns
from .linear import TensorParallelLinear


@contextmanager
Expand Down
8 changes: 5 additions & 3 deletions tests/python/multidevice/test_dtensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
# SPDX-License-Identifier: BSD-3-Clause

import pytest

import torch
import torch.distributed as dist
from fusion_definition_wrapper import FusionDefinitionWrapper
from linear import LinearFunction
from nvfuser_direct import FusionDefinition, DataType

from torch.distributed.tensor import DTensor
from torch.distributed.tensor.placement_types import Shard, Replicate
from .fusion_definition_wrapper import FusionDefinitionWrapper
from .linear import LinearFunction
from nvfuser_direct import FusionDefinition, DataType


@pytest.mark.mpi
Expand Down
2 changes: 1 addition & 1 deletion tests/python/multidevice/test_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from torch.distributed.tensor import distribute_tensor, Shard

import nvfuser_direct as nvfuser
from .benchmark_utils import get_benchmark_fns
from nvfuser_direct import DataType, FusionDefinition, CommunicatorBackend, TensorView
from benchmark_utils import get_benchmark_fns


@pytest.mark.mpi
Expand Down
3 changes: 2 additions & 1 deletion tests/python/multidevice/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import torch.nn.functional as F

import nvfuser_direct as nvfuser
from . import Parallelism
from .benchmark_utils import get_benchmark_fns
from nvfuser_direct import DataType, FusionDefinition
from python.direct_utils import (
create_sdpa_rng_tensors,
is_pre_ampere,
)
from benchmark_utils import get_benchmark_fns, Parallelism


@pytest.mark.mpi
Expand Down
6 changes: 5 additions & 1 deletion tests/python/multidevice/test_transformer_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@
# SPDX-License-Identifier: BSD-3-Clause

import pytest

import torch
import torch.distributed as dist

import transformer_engine.pytorch as te
from benchmark_utils import get_benchmark_fns, Parallelism

from . import Parallelism
from .benchmark_utils import get_benchmark_fns
from enum import auto, Enum

compute_cap = torch.cuda.get_device_capability()
Expand Down