Merged
3 changes: 3 additions & 0 deletions DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py
@@ -55,9 +55,12 @@ def _nvshmem_scatter(input_tensor, indices, rank_mappings, num_output_rows):
 
     num_elem = num_output_rows * num_features
 
+    # TODO: Look into using calloc here to avoid zeroing out the tensor
     scattered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
         num_elem, device.index
     ).reshape((bs, num_output_rows, num_features))
+    scattered_tensor.zero_()
+
     cur_rank = nvshmem.NVSHMEMP2P.get_rank()
     indices = indices % num_output_rows
     local_send_tensor = input_tensor[rank_mappings == cur_rank].unsqueeze(0)
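The new scattered_tensor.zero_() call is load-bearing: the symmetric allocation is not zero-initialized (hence the calloc TODO), and the scatter path accumulates into the buffer, so any stale contents would leak into the result. A minimal CPU sketch of the invariant, with torch.empty standing in for the NVSHMEM allocator:

import torch

# Sketch only: torch.empty, like the symmetric allocator in the diff, makes
# no promise about initial contents; scatter_add_ folds whatever is there
# into the result, so the buffer must be zeroed first.
num_output_rows, num_features = 4, 8
out = torch.empty(1, num_output_rows, num_features)
out.zero_()  # analogue of scattered_tensor.zero_() in the diff

src = torch.randn(1, 6, num_features)
idx = torch.tensor([0, 1, 1, 2, 3, 0]).view(1, -1, 1).expand(1, -1, num_features)
out.scatter_add_(1, idx, src)  # correct because out started all-zero

A calloc-style zeroed allocation, as the TODO suggests, would make the explicit zero_() pass unnecessary.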
9 changes: 6 additions & 3 deletions experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py
@@ -33,11 +33,14 @@ def _check_mpi():
     print("Checking if MPI_HOME is set")
 
     error_signal = False
-    if "MPI_HOME" not in os.environ:
-        print("ERROR: MPI_HOME is not set\n")
+
+    usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"]
+    if not any(env_name in os.environ for env_name in usual_MPI_envs_names):
+        print("ERROR: One of MPI_HOME, MPI_ROOT, or MPICH_HOME is not set\n")
         error_signal = True
     else:
-        print(f"MPI_HOME: {os.environ['MPI_HOME']}\n")
+        mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0]
+        print(f"{mpi_env_name}: {os.environ[mpi_env_name]}\n")
 
     return error_signal
 
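The multi-name probe above is a reusable pattern. A hypothetical helper (not in this PR) that captures it, with the same precedence order as usual_MPI_envs_names:

import os

# Hypothetical helper, not part of the PR: probe the conventional MPI root
# variables in priority order and report the first one that is set.
def find_mpi_env(names=("MPI_HOME", "MPI_ROOT", "MPICH_HOME")):
    """Return (name, path) for the first MPI root env var set, else None."""
    for name in names:
        if name in os.environ:
            return name, os.environ[name]
    return None

Because the diff's list comprehension takes index [0], the same precedence applies there: if several variables are set, MPI_HOME wins.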
8 changes: 6 additions & 2 deletions setup.py
@@ -44,16 +44,20 @@
     raise EnvironmentError("NVSHMEM_HOME must be set to build DGraph")
 
 # TODO: Try to add the ability to input this path as an argument
-if "MPI_HOME" not in os.environ:
+usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"]
+
+if not any(env_name in os.environ for env_name in usual_MPI_envs_names):
     raise EnvironmentError("MPI_HOME must be set to build DGraph")
 
+mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0]
+
 nvshmem_home = os.environ["NVSHMEM_HOME"]
 # print(f"Found NVSHMEM_HOME: {nvshmem_home}")
 
 nvshmem_include = os.path.join(nvshmem_home, "include")
 nvshmem_lib = os.path.join(nvshmem_home, "lib")
 
-mpi_home = os.environ["MPI_HOME"]
+mpi_home = os.environ[mpi_env_name]
 # print(f"Found MPI_HOME: {mpi_home}")
 mpi_include = os.path.join(mpi_home, "include")
 mpi_lib = os.path.join(mpi_home, "lib")
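For orientation, the include and library paths resolved above normally feed a setuptools extension further down in setup.py; that part is outside the visible diff, so the sketch below is an assumption, with a made-up module name and source file:

from setuptools import Extension

# Illustrative wiring only: "dgraph_nvshmem" and the source path are
# hypothetical, but the include/lib plumbing uses the variables computed
# in the diff above.
ext = Extension(
    "dgraph_nvshmem",
    sources=["csrc/nvshmem_backend.cpp"],
    include_dirs=[nvshmem_include, mpi_include],
    library_dirs=[nvshmem_lib, mpi_lib],
    libraries=["nvshmem", "mpi"],
)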
40 changes: 31 additions & 9 deletions tests/test_nvshmem_backend.py
@@ -61,20 +61,28 @@ def setup_scatter_data(init_nvshmem_backend):
     torch.manual_seed(0)
 
     num_features = 8
+
     all_rank_input_data = torch.randn(1, 8, num_features)
 
     all_edge_coo = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3], [1, 2, 3, 0, 3, 0, 3, 0]])
     rank_mappings = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1], [0, 1, 1, 0, 1, 0, 1, 0]])
 
-    all_rank_output = torch.zeros(2, 4, num_features)
+    num_global_output_rows = 4
+    all_rank_output = torch.zeros(2, num_global_output_rows, num_features)
 
     for k in range(2):
         _indices = all_edge_coo[k].view(1, -1, 1).expand(1, -1, num_features)
         output_data = torch.zeros_like(all_rank_output[[k]])
         output_data.scatter_add_(1, _indices, all_rank_input_data)
         all_rank_output[k] = output_data
 
-    return all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output
+    return (
+        all_rank_input_data,
+        all_edge_coo,
+        rank_mappings,
+        all_rank_output,
+        num_global_output_rows,
+    )
 
 
 def test_nvshmem_backend_init(init_nvshmem_backend):
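To see what the fixture's ground truth encodes, take graph k = 0: all_edge_coo[0] = [0, 0, 0, 1, 1, 2, 2, 3] maps input row j to output row all_edge_coo[0][j]. A standalone sketch of the same reduction:

import torch

# For k = 0 the scatter_add_ reduction is:
#   out[0] = in[0] + in[1] + in[2]
#   out[1] = in[3] + in[4]
#   out[2] = in[5] + in[6]
#   out[3] = in[7]
num_features = 8
inp = torch.randn(1, 8, num_features)
idx = torch.tensor([0, 0, 0, 1, 1, 2, 2, 3]).view(1, -1, 1).expand(1, -1, num_features)
out = torch.zeros(1, 4, num_features)
out.scatter_add_(1, idx, inp)

assert torch.allclose(out[0, 0], inp[0, 0] + inp[0, 1] + inp[0, 2])
assert torch.allclose(out[0, 3], inp[0, 7])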
@@ -131,12 +139,13 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
     comm = init_nvshmem_backend
     rank = comm.get_rank()
     world_size = comm.get_world_size()
-    all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output = (
-        setup_scatter_data
-    )
-
-    all_edge_coo = all_edge_coo.T
-    rank_mappings = rank_mappings.T
+    (
+        all_rank_input_data,
+        all_edge_coo,
+        rank_mappings,
+        all_rank_output,
+        num_global_output_rows,
+    ) = setup_scatter_data
 
     input_slice_start = (all_rank_input_data.shape[1] // world_size) * rank
     input_slice_end = (all_rank_input_data.shape[1] // world_size) * (rank + 1)
@@ -147,7 +156,11 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
     local_input_data_gt = all_rank_input_data[:, input_slice_start:input_slice_end, :]
     local_edge_coo = all_edge_coo[:, edge_slice_start:edge_slice_end]
     local_rank_mappings_gt = rank_mappings[:, edge_slice_start:edge_slice_end]
-    local_output_data_gt = all_rank_output[:, edge_slice_start:edge_slice_end, :]
+
+    output_slice_start = (num_global_output_rows // world_size) * rank
+    output_slice_end = (num_global_output_rows // world_size) * (rank + 1)
+    local_output_data_gt = all_rank_output[:, output_slice_start:output_slice_end, :]
+    num_output_rows = local_output_data_gt.shape[1]
 
     for i in range(2):
         local_indices_gt = local_edge_coo[[i], :]
@@ -161,3 +174,12 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
 
         local_input_data = comm.get_local_rank_slice(all_rank_input_data, dim=1)
         assert torch.allclose(local_input_data, local_input_data_gt)
+
+        scattered_tensor = comm.scatter(
+            local_input_data.cuda(),
+            local_indices.cuda(),
+            local_rank_mapping.cuda(),
+            num_output_rows,
+        )
+
+        assert torch.allclose(scattered_tensor, local_output_data_gt[[i]].cuda())
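The corrected slicing is what makes the final assertion line up rank by rank: the replaced line sliced the 4-row ground truth with the edge bounds, so with 8 edges over 2 ranks it would take rows 4:8 on rank 1 and come back empty. With the new output_slice_* bounds, each rank owns a contiguous block of output rows, and indices % num_output_rows in _nvshmem_scatter recovers the local row id within that block. A standalone check of the arithmetic (assuming world_size = 2, which the fixture's shapes imply but the visible diff does not state):

# Assumed: world_size = 2 with the fixture's 4 global output rows, so
# rank 0 owns rows [0, 1] and rank 1 owns rows [2, 3].
num_global_output_rows, world_size = 4, 2
rows_per_rank = num_global_output_rows // world_size  # num_output_rows == 2

for g in range(num_global_output_rows):
    owner = g // rows_per_rank  # rank holding global output row g
    local = g % rows_per_rank   # what indices % num_output_rows computes
    assert g == owner * rows_per_rank + local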