diff --git a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py
index 9251448..8598907 100644
--- a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py
+++ b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py
@@ -55,9 +55,12 @@ def _nvshmem_scatter(input_tensor, indices, rank_mappings, num_output_rows):
 
     num_elem = num_output_rows * num_features
 
+    # TODO: Look into using calloc here to avoid zeroing out the tensor
     scattered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory(
         num_elem, device.index
     ).reshape((bs, num_output_rows, num_features))
+    scattered_tensor.zero_()
+
     cur_rank = nvshmem.NVSHMEMP2P.get_rank()
     indices = indices % num_output_rows
     local_send_tensor = input_tensor[rank_mappings == cur_rank].unsqueeze(0)
diff --git a/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py b/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py
index c08fa4e..5de78c0 100644
--- a/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py
+++ b/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py
@@ -33,11 +33,14 @@ def _check_mpi():
 
     print("Checking if MPI_HOME is set")
     error_signal = False
-    if "MPI_HOME" not in os.environ:
-        print("ERROR: MPI_HOME is not set\n")
+
+    usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"]
+    if not any(env_name in os.environ for env_name in usual_MPI_envs_names):
+        print("ERROR: None of MPI_HOME, MPI_ROOT, or MPICH_HOME is set\n")
         error_signal = True
     else:
-        print(f"MPI_HOME: {os.environ['MPI_HOME']}\n")
+        mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0]
+        print(f"{mpi_env_name}: {os.environ[mpi_env_name]}\n")
     return error_signal
 
 
diff --git a/setup.py b/setup.py
index a2b9c88..d2120c9 100644
--- a/setup.py
+++ b/setup.py
@@ -44,16 +44,20 @@
     raise EnvironmentError("NVSHMEM_HOME must be set to build DGraph")
 
 # TODO: Try to add the ability to input this path as an argument
-if "MPI_HOME" not in os.environ:
-    raise EnvironmentError("MPI_HOME must be set to build DGraph")
+usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"]
+
+if not any(env_name in os.environ for env_name in usual_MPI_envs_names):
+    raise EnvironmentError("One of MPI_HOME, MPI_ROOT, or MPICH_HOME must be set to build DGraph")
+mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0]
+
 
 nvshmem_home = os.environ["NVSHMEM_HOME"]
 # print(f"Found NVSHMEM_HOME: {nvshmem_home}")
 nvshmem_include = os.path.join(nvshmem_home, "include")
 nvshmem_lib = os.path.join(nvshmem_home, "lib")
 
 
-mpi_home = os.environ["MPI_HOME"]
+mpi_home = os.environ[mpi_env_name]
 # print(f"Found MPI_HOME: {mpi_home}")
 mpi_include = os.path.join(mpi_home, "include")
 mpi_lib = os.path.join(mpi_home, "lib")
diff --git a/tests/test_nvshmem_backend.py b/tests/test_nvshmem_backend.py
index d6ea866..37b747f 100644
--- a/tests/test_nvshmem_backend.py
+++ b/tests/test_nvshmem_backend.py
@@ -61,12 +61,14 @@ def setup_scatter_data(init_nvshmem_backend):
 
     torch.manual_seed(0)
     num_features = 8
+
     all_rank_input_data = torch.randn(1, 8, num_features)
 
     all_edge_coo = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3], [1, 2, 3, 0, 3, 0, 3, 0]])
     rank_mappings = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1], [0, 1, 1, 0, 1, 0, 1, 0]])
 
-    all_rank_output = torch.zeros(2, 4, num_features)
+    num_global_output_rows = 4
+    all_rank_output = torch.zeros(2, num_global_output_rows, num_features)
 
     for k in range(2):
         _indices = all_edge_coo[k].view(1, -1, 1).expand(1, -1, num_features)
@@ -74,7 +76,13 @@
         output_data.scatter_add_(1, _indices, all_rank_input_data)
         all_rank_output[k] = output_data
 
-    return all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output
+    return (
+        all_rank_input_data,
+        all_edge_coo,
+        rank_mappings,
+        all_rank_output,
+        num_global_output_rows,
+    )
 
 
 def test_nvshmem_backend_init(init_nvshmem_backend):
@@ -131,12 +139,13 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data):
     comm = init_nvshmem_backend
     rank = comm.get_rank()
     world_size = comm.get_world_size()
-    all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output = (
-        setup_scatter_data
-    )
-
-    all_edge_coo = all_edge_coo.T
-    rank_mappings = rank_mappings.T
+    (
+        all_rank_input_data,
+        all_edge_coo,
+        rank_mappings,
+        all_rank_output,
+        num_global_output_rows,
+    ) = setup_scatter_data
 
     input_slice_start = (all_rank_input_data.shape[1] // world_size) * rank
     input_slice_end = (all_rank_input_data.shape[1] // world_size) * (rank + 1)
@@ -147,7 +156,11 @@
     local_input_data_gt = all_rank_input_data[:, input_slice_start:input_slice_end, :]
     local_edge_coo = all_edge_coo[:, edge_slice_start:edge_slice_end]
     local_rank_mappings_gt = rank_mappings[:, edge_slice_start:edge_slice_end]
-    local_output_data_gt = all_rank_output[:, edge_slice_start:edge_slice_end, :]
+
+    output_slice_start = (num_global_output_rows // world_size) * rank
+    output_slice_end = (num_global_output_rows // world_size) * (rank + 1)
+    local_output_data_gt = all_rank_output[:, output_slice_start:output_slice_end, :]
+    num_output_rows = local_output_data_gt.shape[1]
 
     for i in range(2):
         local_indices_gt = local_edge_coo[[i], :]
@@ -161,3 +174,12 @@
 
     local_input_data = comm.get_local_rank_slice(all_rank_input_data, dim=1)
     assert torch.allclose(local_input_data, local_input_data_gt)
+
+    scattered_tensor = comm.scatter(
+        local_input_data.cuda(),
+        local_indices.cuda(),
+        local_rank_mapping.cuda(),
+        num_output_rows,
+    )
+
+    assert torch.allclose(scattered_tensor, local_output_data_gt[[i]].cuda())