From 86f29f1a3fdc4b92dfc6f02fc67339a3fb75be36 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Dec 2024 09:33:37 -0800 Subject: [PATCH 1/3] Re-enables NVSHMEM Scatter tests on tester - Also adds checks for new environment variables when installing --- experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py | 9 ++++++--- setup.py | 8 ++++++-- tests/test_nvshmem_backend.py | 5 +++++ 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py b/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py index c08fa4e..5de78c0 100644 --- a/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py +++ b/experiments/NVSHMEM-Enabled-DGRAPH/PreInstallCheck.py @@ -33,11 +33,14 @@ def _check_mpi(): print("Checking if MPI_HOME is set") error_signal = False - if "MPI_HOME" not in os.environ: - print("ERROR: MPI_HOME is not set\n") + + usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"] + if not any(env_name in os.environ for env_name in usual_MPI_envs_names): + print("ERROR: One of MPI_HOME, MPI_ROOT, or MPICH_HOME is not set\n") error_signal = True else: - print(f"MPI_HOME: {os.environ['MPI_HOME']}\n") + mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0] + print(f"{mpi_env_name}: {os.environ[mpi_env_name]}\n") return error_signal diff --git a/setup.py b/setup.py index a2b9c88..d2120c9 100644 --- a/setup.py +++ b/setup.py @@ -44,16 +44,20 @@ raise EnvironmentError("NVSHMEM_HOME must be set to build DGraph") # TODO: Try to add the ability to input this path as an argument - if "MPI_HOME" not in os.environ: + usual_MPI_envs_names = ["MPI_HOME", "MPI_ROOT", "MPICH_HOME"] + + if not any(env_name in os.environ for env_name in usual_MPI_envs_names): raise EnvironmentError("MPI_HOME must be set to build DGraph") + mpi_env_name = [x for x in usual_MPI_envs_names if x in os.environ][0] + nvshmem_home = os.environ["NVSHMEM_HOME"] # print(f"Found NVSHMEM_HOME: {nvshmem_home}") nvshmem_include = os.path.join(nvshmem_home, "include") nvshmem_lib = os.path.join(nvshmem_home, "lib") - mpi_home = os.environ["MPI_HOME"] + mpi_home = os.environ[mpi_env_name] # print(f"Found MPI_HOME: {mpi_home}") mpi_include = os.path.join(mpi_home, "include") mpi_lib = os.path.join(mpi_home, "lib") diff --git a/tests/test_nvshmem_backend.py b/tests/test_nvshmem_backend.py index d6ea866..93a2961 100644 --- a/tests/test_nvshmem_backend.py +++ b/tests/test_nvshmem_backend.py @@ -161,3 +161,8 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data): local_input_data = comm.get_local_rank_slice(all_rank_input_data, dim=1) assert torch.allclose(local_input_data, local_input_data_gt) + + scattered_tensor = comm.scatter( + local_input_data.cuda(), local_indices.cuda(), local_rank_mapping.cuda() + ) + assert torch.allclose(scattered_tensor, local_output_data_gt[[i]].cuda()) From fc5779c80beda0e9399f8c187cf191660689f2d4 Mon Sep 17 00:00:00 2001 From: M Shehtab Zaman Date: Mon, 23 Dec 2024 21:52:59 -0800 Subject: [PATCH 2/3] Fixing some issues with nvshmem scatter data set up --- tests/test_nvshmem_backend.py | 39 ++++++++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/tests/test_nvshmem_backend.py b/tests/test_nvshmem_backend.py index 93a2961..1a0df93 100644 --- a/tests/test_nvshmem_backend.py +++ b/tests/test_nvshmem_backend.py @@ -60,13 +60,17 @@ def setup_scatter_data(init_nvshmem_backend): torch.manual_seed(0) + torch.manual_seed(0) + num_features = 8 + all_rank_input_data = torch.randn(1, 8, num_features) all_edge_coo = torch.tensor([[0, 0, 0, 1, 1, 2, 2, 3], [1, 2, 3, 0, 3, 0, 3, 0]]) rank_mappings = torch.tensor([[0, 0, 0, 0, 0, 1, 1, 1], [0, 1, 1, 0, 1, 0, 1, 0]]) - all_rank_output = torch.zeros(2, 4, num_features) + num_global_output_rows = 4 + all_rank_output = torch.zeros(2, num_global_output_rows, num_features) for k in range(2): _indices = all_edge_coo[k].view(1, -1, 1).expand(1, -1, num_features) @@ -74,7 +78,13 @@ def setup_scatter_data(init_nvshmem_backend): output_data.scatter_add_(1, _indices, all_rank_input_data) all_rank_output[k] = output_data - return all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output + return ( + all_rank_input_data, + all_edge_coo, + rank_mappings, + all_rank_output, + num_global_output_rows, + ) def test_nvshmem_backend_init(init_nvshmem_backend): @@ -131,12 +141,13 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data): comm = init_nvshmem_backend rank = comm.get_rank() world_size = comm.get_world_size() - all_rank_input_data, all_edge_coo, rank_mappings, all_rank_output = ( - setup_scatter_data - ) - - all_edge_coo = all_edge_coo.T - rank_mappings = rank_mappings.T + ( + all_rank_input_data, + all_edge_coo, + rank_mappings, + all_rank_output, + num_global_output_rows, + ) = setup_scatter_data input_slice_start = (all_rank_input_data.shape[1] // world_size) * rank input_slice_end = (all_rank_input_data.shape[1] // world_size) * (rank + 1) @@ -147,7 +158,12 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data): local_input_data_gt = all_rank_input_data[:, input_slice_start:input_slice_end, :] local_edge_coo = all_edge_coo[:, edge_slice_start:edge_slice_end] local_rank_mappings_gt = rank_mappings[:, edge_slice_start:edge_slice_end] - local_output_data_gt = all_rank_output[:, edge_slice_start:edge_slice_end, :] + + output_slice_start = (num_global_output_rows // world_size) * rank + output_slice_end = (num_global_output_rows // world_size) * (rank + 1) + local_output_data_gt = all_rank_output[:, output_slice_start:output_slice_end, :] + num_output_rows = local_output_data_gt.shape[1] + num_output_rows = local_output_data_gt.shape[1] for i in range(2): local_indices_gt = local_edge_coo[[i], :] @@ -163,6 +179,9 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data): assert torch.allclose(local_input_data, local_input_data_gt) scattered_tensor = comm.scatter( - local_input_data.cuda(), local_indices.cuda(), local_rank_mapping.cuda() + local_input_data.cuda(), + local_indices.cuda(), + local_rank_mapping.cuda(), + num_output_rows, ) assert torch.allclose(scattered_tensor, local_output_data_gt[[i]].cuda()) From 609690b46ca749d88846c698163ce5f6fff4a398 Mon Sep 17 00:00:00 2001 From: Shehtab Zaman Date: Mon, 23 Dec 2024 22:41:22 -0800 Subject: [PATCH 3/3] Fix issue with non-zeroed out tensor init on shared memory --- DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py | 3 +++ tests/test_nvshmem_backend.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py index 9251448..8598907 100644 --- a/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py +++ b/DGraph/distributed/nvshmem/NVSHMEMBackendEngine.py @@ -55,9 +55,12 @@ def _nvshmem_scatter(input_tensor, indices, rank_mappings, num_output_rows): num_elem = num_output_rows * num_features + # TODO: Look into using calloc here to avoid zeroing out the tensor scattered_tensor = nvshmem.NVSHMEMP2P.allocate_symmetric_memory( num_elem, device.index ).reshape((bs, num_output_rows, num_features)) + scattered_tensor.zero_() + cur_rank = nvshmem.NVSHMEMP2P.get_rank() indices = indices % num_output_rows local_send_tensor = input_tensor[rank_mappings == cur_rank].unsqueeze(0) diff --git a/tests/test_nvshmem_backend.py b/tests/test_nvshmem_backend.py index 1a0df93..37b747f 100644 --- a/tests/test_nvshmem_backend.py +++ b/tests/test_nvshmem_backend.py @@ -60,8 +60,6 @@ def setup_scatter_data(init_nvshmem_backend): torch.manual_seed(0) - torch.manual_seed(0) - num_features = 8 all_rank_input_data = torch.randn(1, 8, num_features) @@ -163,7 +161,6 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data): output_slice_end = (num_global_output_rows // world_size) * (rank + 1) local_output_data_gt = all_rank_output[:, output_slice_start:output_slice_end, :] num_output_rows = local_output_data_gt.shape[1] - num_output_rows = local_output_data_gt.shape[1] for i in range(2): local_indices_gt = local_edge_coo[[i], :] @@ -184,4 +181,5 @@ def test_nvshmem_backend_scatter(init_nvshmem_backend, setup_scatter_data): local_rank_mapping.cuda(), num_output_rows, ) + assert torch.allclose(scattered_tensor, local_output_data_gt[[i]].cuda())