From 6a7ffa7940baa17b3c4ce77c5f11acd15b79fe93 Mon Sep 17 00:00:00 2001
From: Simon Garcia De Gonzalo
Date: Wed, 13 Oct 2021 10:57:01 +0200
Subject: [PATCH 01/25] HOST NVSHMEM EXAMPLE

---
 8-H_NCCL_NVSHMEM/NVSHMEM/Instructions.md |  40 ++
 8-H_NCCL_NVSHMEM/NVSHMEM/Makefile        |  43 ++
 8-H_NCCL_NVSHMEM/NVSHMEM/copy.mk         |  40 ++
 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu       | 594 +++++++++++++++++++++++
 4 files changed, 717 insertions(+)
 create mode 100644 8-H_NCCL_NVSHMEM/NVSHMEM/Instructions.md
 create mode 100644 8-H_NCCL_NVSHMEM/NVSHMEM/Makefile
 create mode 100644 8-H_NCCL_NVSHMEM/NVSHMEM/copy.mk
 create mode 100644 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu

diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/Instructions.md b/8-H_NCCL_NVSHMEM/NVSHMEM/Instructions.md
new file mode 100644
index 0000000..e087e70
--- /dev/null
+++ b/8-H_NCCL_NVSHMEM/NVSHMEM/Instructions.md
@@ -0,0 +1,40 @@
+# SC21 Tutorial: Efficient Distributed GPU Programming for Exascale
+
+- Time: Sunday, 14 November 2021 8AM - 5PM CST
+- Location: *online*
+- Program Link: https://sc21.supercomputing.org/presentation/?id=tut138&sess=sess188
+
+
+## Hands-On 8\_NVSHMEM: Host-initiated Communication with NVSHMEM
+
+### Task 0: Using the NVSHMEM host API
+
+#### Description
+
+The purpose of this task is to use the NVSHMEM host API instead of MPI to implement a multi-GPU Jacobi solver. The starting point of this task is the MPI variant of the Jacobi solver. You need to work on the `TODO`s in `jacobi.cu`:
+
+- Initialize NVSHMEM:
+  - Include the NVSHMEM headers.
+  - Initialize and shut down NVSHMEM using `MPI_COMM_WORLD`.
+- Allocate the work arrays `a` and `a_new` from the NVSHMEM symmetric heap. Take care to pass in a size that is consistent across all PEs!
+- Calculate the halo/boundary row indices of the top and bottom neighbors.
+- Add the necessary inter-PE synchronization.
+- Pass in the halo/boundary row indices of the top and bottom neighbors.
+- Use `nvshmem_float_p` to directly push the values needed by the top and bottom neighbors from the host.
+- Remove the no-longer-needed MPI communication.
+
+Compile with
+
+``` {.bash}
+make
+```
+
+Submit your compiled application to the batch system with
+
+``` {.bash}
+make run
+```
+
+Study the performance by inspecting the profile generated with
+`make profile`. For `make run` and `make profile` the environment variable `NP` can be set to change the number of processes.
+
diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/Makefile b/8-H_NCCL_NVSHMEM/NVSHMEM/Makefile
new file mode 100644
index 0000000..8e7f262
--- /dev/null
+++ b/8-H_NCCL_NVSHMEM/NVSHMEM/Makefile
@@ -0,0 +1,43 @@
+# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
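+#
+# Example invocations (illustrative only -- adjust paths and process counts to
+# your system; NVSHMEM_HOME and MPI_HOME must point to existing installations):
+#   NVSHMEM_HOME=<nvshmem-install> MPI_HOME=<mpi-install> make
+#   NP=2 make run       # launches ./jacobi via $(JSC_SUBMIT_CMD) -n 2
+#   NP=4 make profile   # writes one Nsight Systems report (.qdrep) per rank
+#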
+NP ?= 4 +NVCC=nvcc +JSC_SUBMIT_CMD ?= srun --gres=gpu:4 --ntasks-per-node 4 +CUDA_HOME ?= /usr/local/cuda +ifndef NVSHMEM_HOME +$(error NVSHMEM_HOME is not set) +endif +ifndef MPI_HOME +$(error MPI_HOME is not set) +endif +GENCODE_SM30 := -gencode arch=compute_30,code=sm_30 +GENCODE_SM35 := -gencode arch=compute_35,code=sm_35 +GENCODE_SM37 := -gencode arch=compute_37,code=sm_37 +GENCODE_SM50 := -gencode arch=compute_50,code=sm_50 +GENCODE_SM52 := -gencode arch=compute_52,code=sm_52 +GENCODE_SM60 := -gencode arch=compute_60,code=sm_60 +GENCODE_SM70 := -gencode arch=compute_70,code=sm_70 +GENCODE_SM80 := -gencode arch=compute_80,code=sm_80 -gencode arch=compute_80,code=compute_80 +GENCODE_FLAGS := $(GENCODE_SM70) $(GENCODE_SM80) +ifdef DISABLE_CUB + NVCC_FLAGS = -Xptxas --optimize-float-atomics +else + NVCC_FLAGS = -DHAVE_CUB +endif +NVCC_FLAGS += -dc -Xcompiler -fopenmp -lineinfo -DUSE_NVTX -lnvToolsExt $(GENCODE_FLAGS) -std=c++14 -I$(NVSHMEM_HOME)/include -I$(MPI_HOME)/include +NVCC_LDFLAGS = -ccbin=mpic++ -L$(NVSHMEM_HOME)/lib -lnvshmem -L$(MPI_HOME)/lib -lmpi -L$(CUDA_HOME)/lib64 -lcuda -lcudart -lnvToolsExt +jacobi: Makefile jacobi.cu + $(NVCC) $(NVCC_FLAGS) jacobi.cu -c -o jacobi.o + $(NVCC) $(GENCODE_FLAGS) jacobi.o -o jacobi $(NVCC_LDFLAGS) + +.PHONY.: clean +clean: + rm -f jacobi jacobi.o *.qdrep jacobi.*.compute-sanitizer.log + +sanitize: jacobi + $(JSC_SUBMIT_CMD) -n $(NP) compute-sanitizer --log-file jacobi.%q{SLURM_PROCID}.compute-sanitizer.log ./jacobi -niter 10 + +run: jacobi + $(JSC_SUBMIT_CMD) -n $(NP) ./jacobi + +profile: jacobi + $(JSC_SUBMIT_CMD) -n $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{SLURM_PROCID} ./jacobi -niter 10 diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/copy.mk b/8-H_NCCL_NVSHMEM/NVSHMEM/copy.mk new file mode 100644 index 0000000..5a35bdd --- /dev/null +++ b/8-H_NCCL_NVSHMEM/NVSHMEM/copy.mk @@ -0,0 +1,40 @@ +#!/usr/bin/make -f +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +TASKDIR = ../../tasks/8-H_NCCL_NVSHMEM/NVSHMEM/ +SOLUTIONDIR = ../../solutions/8-H_NCCL_NVSHMEM/NVSHMEM + +PROCESSFILES = jacobi.cu +COPYFILES = Makefile Instructions.ipynb Instructions.md + + +TASKPROCCESFILES = $(addprefix $(TASKDIR)/,$(PROCESSFILES)) +TASKCOPYFILES = $(addprefix $(TASKDIR)/,$(COPYFILES)) +SOLUTIONPROCCESFILES = $(addprefix $(SOLUTIONDIR)/,$(PROCESSFILES)) +SOLUTIONCOPYFILES = $(addprefix $(SOLUTIONDIR)/,$(COPYFILES)) + +.PHONY: all task +all: task +task: ${TASKPROCCESFILES} ${TASKCOPYFILES} ${SOLUTIONPROCCESFILES} ${SOLUTIONCOPYFILES} + + +${TASKPROCCESFILES}: $(PROCESSFILES) + mkdir -p $(TASKDIR)/ + cppp -USOLUTION $(notdir $@) $@ + +${SOLUTIONPROCCESFILES}: $(PROCESSFILES) + mkdir -p $(SOLUTIONDIR)/ + cppp -DSOLUTION $(notdir $@) $@ + + +${TASKCOPYFILES}: $(COPYFILES) + mkdir -p $(TASKDIR)/ + cp $(notdir $@) $@ + +${SOLUTIONCOPYFILES}: $(COPYFILES) + mkdir -p $(SOLUTIONDIR)/ + cp $(notdir $@) $@ + +%.ipynb: %.md + pandoc $< -o $@ + # add metadata so this is seen as python + jq -s '.[0] * .[1]' $@ ../template.json | sponge $@ diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu new file mode 100644 index 0000000..894adf0 --- /dev/null +++ b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu @@ -0,0 +1,594 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2017,2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include +#include +#include +#include +#include + +#include + +#define MPI_CALL(call) \ + { \ + int mpi_status = call; \ + if (0 != mpi_status) { \ + char mpi_error_string[MPI_MAX_ERROR_STRING]; \ + int mpi_error_string_length = 0; \ + MPI_Error_string(mpi_status, mpi_error_string, &mpi_error_string_length); \ + if (NULL != mpi_error_string) \ + fprintf(stderr, \ + "ERROR: MPI call \"%s\" in line %d of file %s failed " \ + "with %s " \ + "(%d).\n", \ + #call, __LINE__, __FILE__, mpi_error_string, mpi_status); \ + else \ + fprintf(stderr, \ + "ERROR: MPI call \"%s\" in line %d of file %s failed " \ + "with %d.\n", \ + #call, __LINE__, __FILE__, mpi_status); \ + } \ + } + +#include + +#ifdef SOLUTION +#include +#include +#else +//TODO: Include NVSHMEM headers +#endif + +#ifdef HAVE_CUB +#include +#endif // HAVE_CUB + +#ifdef USE_NVTX +#include + +const uint32_t colors[] = {0x0000ff00, 0x000000ff, 0x00ffff00, 0x00ff00ff, + 0x0000ffff, 0x00ff0000, 0x00ffffff}; +const int num_colors = sizeof(colors) / sizeof(uint32_t); + +#define PUSH_RANGE(name, cid) \ + { \ + int color_id = cid; \ + color_id = color_id % num_colors; \ + nvtxEventAttributes_t eventAttrib = {0}; \ + eventAttrib.version = NVTX_VERSION; \ + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; \ + eventAttrib.colorType = NVTX_COLOR_ARGB; \ + eventAttrib.color = colors[color_id]; \ + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; \ + eventAttrib.message.ascii = name; \ + nvtxRangePushEx(&eventAttrib); \ + } +#define POP_RANGE nvtxRangePop(); +#else +#define PUSH_RANGE(name, cid) +#define POP_RANGE +#endif + +#define CUDA_RT_CALL(call) \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaSuccess != cudaStatus) \ + fprintf(stderr, \ + "ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \ + } + +#ifdef USE_DOUBLE +typedef double real; +#define MPI_REAL_TYPE MPI_DOUBLE +#else +typedef float real; +#define MPI_REAL_TYPE MPI_FLOAT +#endif + +constexpr real tol = 1.0e-8; + +const real PI = 2.0 * std::asin(1.0); + +__global__ void initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a, + const real pi, const int offset, const int nx, + const int my_ny, const int ny) { + for (int iy = blockIdx.x * blockDim.x + threadIdx.x; iy < my_ny; iy += blockDim.x * gridDim.x) { + const real y0 = sin(2.0 
* pi * (offset + iy) / (ny - 1)); + a[iy * nx + 0] = y0; + a[iy * nx + (nx - 1)] = y0; + a_new[iy * nx + 0] = y0; + a_new[iy * nx + (nx - 1)] = y0; + } +} + +template +__global__ void jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a, + real* __restrict__ const l2_norm, const int iy_start, + const int iy_end, const int nx, const bool calculate_norm) { +#ifdef HAVE_CUB + typedef cub::BlockReduce + BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; +#endif // HAVE_CUB + int iy = blockIdx.y * blockDim.y + threadIdx.y + iy_start; + int ix = blockIdx.x * blockDim.x + threadIdx.x + 1; + real local_l2_norm = 0.0; + + if (iy < iy_end && ix < (nx - 1)) { + const real new_val = 0.25 * (a[iy * nx + ix + 1] + a[iy * nx + ix - 1] + + a[(iy + 1) * nx + ix] + a[(iy - 1) * nx + ix]); + a_new[iy * nx + ix] = new_val; + if (calculate_norm) { + real residue = new_val - a[iy * nx + ix]; + local_l2_norm += residue * residue; + } + } + if (calculate_norm) { +#ifdef HAVE_CUB + real block_l2_norm = BlockReduce(temp_storage).Sum(local_l2_norm); + if (0 == threadIdx.y && 0 == threadIdx.x) atomicAdd(l2_norm, block_l2_norm); +#else + atomicAdd(l2_norm, local_l2_norm); +#endif // HAVE_CUB + } +} + +double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ref_h, + const int nccheck, const bool print); + +template +T get_argval(char** begin, char** end, const std::string& arg, const T default_val) { + T argval = default_val; + char** itr = std::find(begin, end, arg); + if (itr != end && ++itr != end) { + std::istringstream inbuf(*itr); + inbuf >> argval; + } + return argval; +} + +bool get_arg(char** begin, char** end, const std::string& arg) { + char** itr = std::find(begin, end, arg); + if (itr != end) { + return true; + } + return false; +} + +int main(int argc, char* argv[]) { + MPI_CALL(MPI_Init(&argc, &argv)); + int rank; + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); + int size; + MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size)); + int num_devices = 0; + CUDA_RT_CALL(cudaGetDeviceCount(&num_devices)); + + const int iter_max = get_argval(argv, argv + argc, "-niter", 1000); + const int nccheck = get_argval(argv, argv + argc, "-nccheck", 1); + const int nx = get_argval(argv, argv + argc, "-nx", 16384); + const int ny = get_argval(argv, argv + argc, "-ny", 16384); + const bool csv = get_arg(argv, argv + argc, "-csv"); + + int local_rank = -1; + { + MPI_Comm local_comm; + MPI_CALL(MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL, + &local_comm)); + + MPI_CALL(MPI_Comm_rank(local_comm, &local_rank)); + + MPI_CALL(MPI_Comm_free(&local_comm)); + } + + CUDA_RT_CALL(cudaSetDevice(local_rank%num_devices)); + CUDA_RT_CALL(cudaFree(0)); + +#ifdef SOLUTION + MPI_Comm mpi_comm = MPI_COMM_WORLD; + nvshmemx_init_attr_t attr; + attr.mpi_comm = &mpi_comm; + nvshmemx_init_attr(NVSHMEMX_INIT_WITH_MPI_COMM, &attr); + + assert( size == nvshmem_n_pes() ); + assert( rank == nvshmem_my_pe() ); +#else + //TODO: Initialize NVSHMEM using nvshmemx_init_attr +#endif + + real* a_ref_h; + CUDA_RT_CALL(cudaMallocHost(&a_ref_h, nx * ny * sizeof(real))); + real* a_h; + CUDA_RT_CALL(cudaMallocHost(&a_h, nx * ny * sizeof(real))); + double runtime_serial = single_gpu(nx, ny, iter_max, a_ref_h, nccheck, !csv && (0 == rank)); + + // ny - 2 rows are distributed amongst `size` ranks in such a way + // that each rank gets either (ny - 2) / size or (ny - 2) / size + 1 rows. 
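+    // (Worked example with illustrative numbers, not from the tutorial: for
+    //  ny = 20 and size = 4 there are ny - 2 = 18 inner rows, giving
+    //  chunk_size_low = 18 / 4 = 4, chunk_size_high = 5 and
+    //  num_ranks_low = 4 * 4 + 4 - 18 = 2, so ranks 0-1 compute 4 rows each
+    //  and ranks 2-3 compute 5 rows each: 2 * 4 + 2 * 5 = 18.)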
+ // This optimizes load balancing when (ny - 2) % size != 0 + int chunk_size; + int chunk_size_low = (ny - 2) / size; + int chunk_size_high = chunk_size_low + 1; + // To calculate the number of ranks that need to compute an extra row, + // the following formula is derived from this equation: + // num_ranks_low * chunk_size_low + (size - num_ranks_low) * (chunk_size_low + 1) = ny - 2 + int num_ranks_low = size * chunk_size_low + size - + (ny - 2); // Number of ranks with chunk_size = chunk_size_low + if (rank < num_ranks_low) + chunk_size = chunk_size_low; + else + chunk_size = chunk_size_high; + +#ifdef SOLUTION + real* a = (real*) nvshmem_malloc(nx * (chunk_size_high + 2) * sizeof(real)); + real* a_new = (real*) nvshmem_malloc(nx * (chunk_size_high + 2) * sizeof(real)); +#else + //TODO: Allocate a and a_new from the NVSHMEM symmetric heap + // Note: size needs to be the same on all PEs but chunk_size might not be! + real* a; + CUDA_RT_CALL(cudaMalloc(&a, nx * (chunk_size + 2) * sizeof(real))); + real* a_new; + CUDA_RT_CALL(cudaMalloc(&a_new, nx * (chunk_size + 2) * sizeof(real))); +#endif + + CUDA_RT_CALL(cudaMemset(a, 0, nx * (chunk_size + 2) * sizeof(real))); + CUDA_RT_CALL(cudaMemset(a_new, 0, nx * (chunk_size + 2) * sizeof(real))); + + // Calculate local domain boundaries + int iy_start_global; // My start index in the global array + if (rank < num_ranks_low) { + iy_start_global = rank * chunk_size_low + 1; + } else { + iy_start_global = + num_ranks_low * chunk_size_low + (rank - num_ranks_low) * chunk_size_high + 1; + } + int iy_end_global = iy_start_global + chunk_size - 1; // My last index in the global array + + int iy_start = 1; + int iy_end = iy_start + chunk_size; + + const int top = rank > 0 ? rank - 1 : (size - 1); + const int bottom = (rank + 1) % size; + +#ifdef SOLUTION + const int iy_top_lower_boundary_idx = (top < num_ranks_low) ? (chunk_size_low + 1) : (chunk_size_high + 1); + const int iy_bottom_upper_boundary_idx = 0; +#else + //TODO: calculate halo/boundary row index of top and bottom neighbors +#endif + + // Set diriclet boundary conditions on left and right boarder + initialize_boundaries<<<(chunk_size + 2) / 128 + 1, 128>>>(a, a_new, PI, iy_start_global - 1, nx, (chunk_size + 2), ny); + CUDA_RT_CALL(cudaGetLastError()); + CUDA_RT_CALL(cudaDeviceSynchronize()); + + cudaStream_t compute_stream; + CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); + cudaEvent_t compute_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); + + real* l2_norm_d; + CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); + real* l2_norm_h; + CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real))); + + PUSH_RANGE("MPI_Warmup", 5) + for (int i = 0; i < 10; ++i) { + const int top = rank > 0 ? 
rank - 1 : (size - 1); + const int bottom = (rank + 1) % size; + MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, + a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, + MPI_STATUS_IGNORE)); + MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, + MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + std::swap(a_new, a); + } + POP_RANGE + + CUDA_RT_CALL(cudaDeviceSynchronize()); + + if (!csv && 0 == rank) { + printf( + "Jacobi relaxation: %d iterations on %d x %d mesh with norm check " + "every %d iterations\n", + iter_max, ny, nx, nccheck); + } + + constexpr int dim_block_x = 32; + constexpr int dim_block_y = 32; + dim3 dim_grid((nx + dim_block_x - 1) / dim_block_x, + ((iy_end - iy_start) + dim_block_y - 1) / dim_block_y, 1); + + int iter = 0; + real l2_norm = 1.0; + bool calculate_norm; // boolean to store whether l2 norm will be calculated in + // an iteration or not + + MPI_CALL(MPI_Barrier(MPI_COMM_WORLD)); + double start = MPI_Wtime(); + PUSH_RANGE("Jacobi solve", 0) + while (l2_norm > tol && iter < iter_max) { + CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + + calculate_norm = (iter % nccheck) == 0 || (!csv && (iter % 100) == 0); + + jacobi_kernel<<>>( + a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm); + CUDA_RT_CALL(cudaGetLastError()); + CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + +#ifdef SOLUTION + nvshmemx_barrier_all_on_stream(compute_stream); +#else + //TODO: add necessary inter PE synchronization +#endif + + if (calculate_norm) { + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + compute_stream)); + } + +#ifdef SOLUTION + nvshmem_float_put(a_new + (iy_end * nx), a_new + iy_start * nx, nx, top); + nvshmem_float_put(a_new, a_new + (iy_end - 1) * nx, nx, bottom); +#else + //TODO: Replace MPI communication with Host initiated NVSHMEM calls + // Apply periodic boundary conditions + CUDA_RT_CALL(cudaEventSynchronize(compute_done)); + PUSH_RANGE("MPI", 5) + MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, + a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, + MPI_STATUS_IGNORE)); + MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, + MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + POP_RANGE +#endif + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); + MPI_CALL(MPI_Allreduce(l2_norm_h, &l2_norm, 1, MPI_REAL_TYPE, MPI_SUM, MPI_COMM_WORLD)); + l2_norm = std::sqrt(l2_norm); + + if (!csv && 0 == rank && (iter % 100) == 0) { + printf("%5d, %0.6f\n", iter, l2_norm); + } + } + + std::swap(a_new, a); + iter++; + } + double stop = MPI_Wtime(); + POP_RANGE + + CUDA_RT_CALL(cudaMemcpy(a_h + iy_start_global * nx, a + nx, + std::min((ny - iy_start_global) * nx, chunk_size * nx) * sizeof(real), + cudaMemcpyDeviceToHost)); + + int result_correct = 1; + for (int iy = iy_start_global; result_correct && (iy < iy_end_global); ++iy) { + for (int ix = 1; result_correct && (ix < (nx - 1)); ++ix) { + if (std::fabs(a_ref_h[iy * nx + ix] - a_h[iy * nx + ix]) > tol) { + fprintf(stderr, + "ERROR on rank %d: a[%d * %d + %d] = %f does not match %f " + "(reference)\n", + rank, iy, nx, ix, a_h[iy * nx + ix], a_ref_h[iy * nx + ix]); + result_correct = 0; + } + } + } + + int global_result_correct = 1; + MPI_CALL(MPI_Allreduce(&result_correct, &global_result_correct, 1, MPI_INT, MPI_MIN, + MPI_COMM_WORLD)); + 
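+    // The MPI_MIN reduction over the per-rank flags yields 1 only if the
+    // local correctness check passed on every rank.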
result_correct = global_result_correct; + + if (rank == 0 && result_correct) { + if (csv) { + printf("mpi, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, + (stop - start), runtime_serial); + } else { + printf("Num GPUs: %d.\n", size); + printf( + "%dx%d: 1 GPU: %8.4f s, %d GPUs: %8.4f s, speedup: %8.2f, " + "efficiency: %8.2f \n", + ny, nx, runtime_serial, size, (stop - start), runtime_serial / (stop - start), + runtime_serial / (size * (stop - start)) * 100); + } + } + CUDA_RT_CALL(cudaEventDestroy(compute_done)); + CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); + + CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); + CUDA_RT_CALL(cudaFree(l2_norm_d)); + +#ifdef SOLUTION + nvshmem_free(a_new); + nvshmem_free(a); +#else + //TODO: Deallocated a_new and a from the NVSHMEM symmetric heap + CUDA_RT_CALL(cudaFree(a_new)); + CUDA_RT_CALL(cudaFree(a)); +#endif + + CUDA_RT_CALL(cudaFreeHost(a_h)); + CUDA_RT_CALL(cudaFreeHost(a_ref_h)); + +#ifdef SOLUTION + nvshmem_finalize(); +#else + //TODO: Finalize NVSHMEM +#endif + MPI_CALL(MPI_Finalize()); + return (result_correct == 1) ? 0 : 1; +} + +template +__global__ void jacobi_kernel_single_gpu(real* __restrict__ const a_new, const real* __restrict__ const a, + real* __restrict__ const l2_norm, const int iy_start, + const int iy_end, const int nx, const bool calculate_norm) { +#ifdef HAVE_CUB + typedef cub::BlockReduce + BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; +#endif // HAVE_CUB + int iy = blockIdx.y * blockDim.y + threadIdx.y + iy_start; + int ix = blockIdx.x * blockDim.x + threadIdx.x + 1; + real local_l2_norm = 0.0; + + if (iy < iy_end && ix < (nx - 1)) { + const real new_val = 0.25 * (a[iy * nx + ix + 1] + a[iy * nx + ix - 1] + + a[(iy + 1) * nx + ix] + a[(iy - 1) * nx + ix]); + a_new[iy * nx + ix] = new_val; + if (calculate_norm) { + real residue = new_val - a[iy * nx + ix]; + local_l2_norm += residue * residue; + } + } + if (calculate_norm) { +#ifdef HAVE_CUB + real block_l2_norm = BlockReduce(temp_storage).Sum(local_l2_norm); + if (0 == threadIdx.y && 0 == threadIdx.x) atomicAdd(l2_norm, block_l2_norm); +#else + atomicAdd(l2_norm, local_l2_norm); +#endif // HAVE_CUB + } +} + +double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ref_h, + const int nccheck, const bool print) { + real* a; + real* a_new; + + cudaStream_t compute_stream; + cudaStream_t push_top_stream; + cudaStream_t push_bottom_stream; + cudaEvent_t compute_done; + cudaEvent_t push_top_done; + cudaEvent_t push_bottom_done; + + real* l2_norm_d; + real* l2_norm_h; + + int iy_start = 1; + int iy_end = (ny - 1); + + CUDA_RT_CALL(cudaMalloc(&a, nx * ny * sizeof(real))); + CUDA_RT_CALL(cudaMalloc(&a_new, nx * ny * sizeof(real))); + + CUDA_RT_CALL(cudaMemset(a, 0, nx * ny * sizeof(real))); + CUDA_RT_CALL(cudaMemset(a_new, 0, nx * ny * sizeof(real))); + + // Set diriclet boundary conditions on left and right boarder + initialize_boundaries<<>>(a, a_new, PI, 0, nx, ny, ny); + CUDA_RT_CALL(cudaGetLastError()); + CUDA_RT_CALL(cudaDeviceSynchronize()); + + CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); + CUDA_RT_CALL(cudaStreamCreate(&push_top_stream)); + CUDA_RT_CALL(cudaStreamCreate(&push_bottom_stream)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_top_done, cudaEventDisableTiming)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_bottom_done, cudaEventDisableTiming)); + + CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); + 
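+    // Note: cudaMallocHost allocates pinned (page-locked) host memory; the
+    // asynchronous copy of the residual from l2_norm_d below needs this to
+    // proceed truly asynchronously on the stream.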
CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real))); + + CUDA_RT_CALL(cudaDeviceSynchronize()); + + if (print) + printf( + "Single GPU jacobi relaxation: %d iterations on %d x %d mesh with " + "norm " + "check every %d iterations\n", + iter_max, ny, nx, nccheck); + + constexpr int dim_block_x = 32; + constexpr int dim_block_y = 32; + dim3 dim_grid((nx + dim_block_x - 1) / dim_block_x, + ((iy_end - iy_start) + dim_block_y - 1) / dim_block_y, 1); + + int iter = 0; + real l2_norm = 1.0; + bool calculate_norm; + + double start = MPI_Wtime(); + PUSH_RANGE("Jacobi solve", 0) + while (l2_norm > tol && iter < iter_max) { + CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); + + calculate_norm = (iter % nccheck) == 0 || (iter % 100) == 0; + jacobi_kernel_single_gpu<<>>( + a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm); + CUDA_RT_CALL(cudaGetLastError()); + CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + compute_stream)); + } + + // Apply periodic boundary conditions + + CUDA_RT_CALL(cudaStreamWaitEvent(push_top_stream, compute_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(a_new, a_new + (iy_end - 1) * nx, nx * sizeof(real), + cudaMemcpyDeviceToDevice, push_top_stream)); + CUDA_RT_CALL(cudaEventRecord(push_top_done, push_top_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(push_bottom_stream, compute_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(a_new + iy_end * nx, a_new + iy_start * nx, nx * sizeof(real), + cudaMemcpyDeviceToDevice, compute_stream)); + CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); + l2_norm = *l2_norm_h; + l2_norm = std::sqrt(l2_norm); + if (print && (iter % 100) == 0) printf("%5d, %0.6f\n", iter, l2_norm); + } + + std::swap(a_new, a); + iter++; + } + POP_RANGE + double stop = MPI_Wtime(); + + CUDA_RT_CALL(cudaMemcpy(a_ref_h, a, nx * ny * sizeof(real), cudaMemcpyDeviceToHost)); + + CUDA_RT_CALL(cudaEventDestroy(push_bottom_done)); + CUDA_RT_CALL(cudaEventDestroy(push_top_done)); + CUDA_RT_CALL(cudaEventDestroy(compute_done)); + CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream)); + CUDA_RT_CALL(cudaStreamDestroy(push_top_stream)); + CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); + + CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); + CUDA_RT_CALL(cudaFree(l2_norm_d)); + + CUDA_RT_CALL(cudaFree(a_new)); + CUDA_RT_CALL(cudaFree(a)); + return (stop - start); +} From f7b0db061d3f6dcde36da15deba684ed8360ae26 Mon Sep 17 00:00:00 2001 From: Simon Garcia De Gonzalo Date: Wed, 13 Oct 2021 16:50:04 +0200 Subject: [PATCH 02/25] boundary conditions is not working correctly yet --- 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu index 894adf0..f18d0fe 100644 --- a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu +++ b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu @@ -350,8 +350,12 @@ int main(int argc, char* argv[]) { } #ifdef SOLUTION - nvshmem_float_put(a_new + (iy_end * nx), a_new + iy_start * nx, nx, top); - nvshmem_float_put(a_new, a_new + (iy_end - 1) * nx, nx, bottom); + //Simon: Currently produces rounding errors + //nvshmem_float_put(a_new + (iy_end * nx), a_new + iy_start * nx, nx, 
top);
-    nvshmem_float_put(a_new, a_new + (iy_end - 1) * nx, nx, bottom);
+    //Simon: Currently produces rounding errors
+    //nvshmem_float_put(a_new + (iy_end * nx), a_new + iy_start * nx, nx, top);
+    //nvshmem_float_put(a_new, a_new + (iy_end - 1) * nx, nx, bottom);
+    nvshmem_float_put(a_new + iy_top_lower_boundary_idx * nx, a_new + iy_start * nx, nx, top);
+    nvshmem_float_put(a_new + iy_bottom_upper_boundary_idx * nx, a_new + (iy_end - 1) * nx, nx, bottom);
+
 #else
     //TODO: Replace MPI communication with Host initiated NVSHMEM calls
     // Apply periodic boundary conditions

From cbc3d0b8845450b41a28e6a98ba44a9b990c001e Mon Sep 17 00:00:00 2001
From: Simon Garcia De Gonzalo
Date: Thu, 14 Oct 2021 15:34:54 +0200
Subject: [PATCH 03/25] Fixed rounding error due to not correctly waiting for
 the compute kernel to finish

---
 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu
index f18d0fe..8d96fbb 100644
--- a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu
+++ b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu
@@ -338,23 +338,15 @@ int main(int argc, char* argv[]) {
         CUDA_RT_CALL(cudaGetLastError());
         CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream));
 
-#ifdef SOLUTION
-    nvshmemx_barrier_all_on_stream(compute_stream);
-#else
-    //TODO: add necessary inter PE synchronization
-#endif
-
         if (calculate_norm) {
             CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost,
                                          compute_stream));
         }
 
 #ifdef SOLUTION
-        //Simon: Currently produces rounding errors
-        //nvshmem_float_put(a_new + (iy_end * nx), a_new + iy_start * nx, nx, top);
-        //nvshmem_float_put(a_new, a_new + (iy_end - 1) * nx, nx, bottom);
-        nvshmem_float_put(a_new + iy_top_lower_boundary_idx * nx, a_new + iy_start * nx, nx, top);
-        nvshmem_float_put(a_new + iy_bottom_upper_boundary_idx * nx, a_new + (iy_end - 1) * nx, nx, bottom);
+
+        nvshmemx_float_put_on_stream(a_new + iy_top_lower_boundary_idx * nx, a_new + iy_start * nx, nx, top, compute_stream);
+        nvshmemx_float_put_on_stream(a_new + iy_bottom_upper_boundary_idx * nx, a_new + (iy_end - 1) * nx, nx, bottom, compute_stream);
 
 #else
     //TODO: Replace MPI communication with Host initiated NVSHMEM calls
@@ -369,6 +361,14 @@ int main(int argc, char* argv[]) {
     POP_RANGE
 #endif
 
+
+#ifdef SOLUTION
+    nvshmemx_barrier_all_on_stream(compute_stream);
+#else
+    //TODO: add necessary inter PE synchronization
+#endif
+
+
     if (calculate_norm) {
         CUDA_RT_CALL(cudaStreamSynchronize(compute_stream));
         MPI_CALL(MPI_Allreduce(l2_norm_h, &l2_norm, 1, MPI_REAL_TYPE, MPI_SUM, MPI_COMM_WORLD));

From 2b0f18018dc789e616f43ae832a556b12cdbf307 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Tue, 19 Oct 2021 11:18:55 +0200
Subject: [PATCH 04/25] First draft of the MPI Overlap version to be tested

---
 .../Makefile                                  |  41 ++
 .../jacobi.cpp                                | 533 ++++++++++++++++++
 .../jacobi_kernels.cu                         | 113 ++++
 3 files changed, 687 insertions(+)
 create mode 100644 6-H_Overlap_Communication_and_Computation_MPI/Makefile
 create mode 100644 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
 create mode 100644 6-H_Overlap_Communication_and_Computation_MPI/jacobi_kernels.cu

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/Makefile b/6-H_Overlap_Communication_and_Computation_MPI/Makefile
new file mode 100644
index 0000000..4e9002d
--- /dev/null
+++ b/6-H_Overlap_Communication_and_Computation_MPI/Makefile
@@ -0,0 +1,41 @@
+# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
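+#
+# Note: the %q{OMPI_COMM_WORLD_RANK} substitution used by the sanitize and
+# profile targets assumes Open MPI's mpirun; other MPI launchers export a
+# different rank environment variable.
+#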
+NP ?= 1 +NVCC=nvcc +MPICXX=mpicxx +MPIRUN ?= mpirun +CUDA_HOME ?= /usr/local/cuda +GENCODE_SM30 := -gencode arch=compute_30,code=sm_30 +GENCODE_SM35 := -gencode arch=compute_35,code=sm_35 +GENCODE_SM37 := -gencode arch=compute_37,code=sm_37 +GENCODE_SM50 := -gencode arch=compute_50,code=sm_50 +GENCODE_SM52 := -gencode arch=compute_52,code=sm_52 +GENCODE_SM60 := -gencode arch=compute_60,code=sm_60 +GENCODE_SM70 := -gencode arch=compute_70,code=sm_70 +GENCODE_SM80 := -gencode arch=compute_80,code=sm_80 -gencode arch=compute_80,code=compute_80 +GENCODE_FLAGS := $(GENCODE_SM70) $(GENCODE_SM80) +ifdef DISABLE_CUB + NVCC_FLAGS = -Xptxas --optimize-float-atomics +else + NVCC_FLAGS = -DHAVE_CUB +endif +NVCC_FLAGS += -lineinfo $(GENCODE_FLAGS) -std=c++14 +MPICXX_FLAGS = -DUSE_NVTX -I$(CUDA_HOME)/include -std=c++14 +LD_FLAGS = -L$(CUDA_HOME)/lib64 -lcudart -lnvToolsExt +jacobi: Makefile jacobi.cpp jacobi_kernels.o + $(MPICXX) $(MPICXX_FLAGS) jacobi.cpp jacobi_kernels.o $(LD_FLAGS) -o jacobi + +jacobi_kernels.o: Makefile jacobi_kernels.cu + $(NVCC) $(NVCC_FLAGS) jacobi_kernels.cu -c + +.PHONY.: clean +clean: + rm -f jacobi jacobi_kernels.o *.qdrep jacobi.*.compute-sanitizer.log + +sanitize: jacobi + $(MPIRUN) -np $(NP) compute-sanitizer --log-file jacobi.%q{OMPI_COMM_WORLD_RANK}.compute-sanitizer.log ./jacobi -niter 10 + +run: jacobi + $(MPIRUN) -np $(NP) ./jacobi + +profile: jacobi + $(MPIRUN) -np $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{OMPI_COMM_WORLD_RANK} ./jacobi -niter 10 diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp new file mode 100644 index 0000000..34dc2f4 --- /dev/null +++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp @@ -0,0 +1,533 @@ +/* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ +#include +#include +#include +#include +#include + +#include + +#define MPI_CALL(call) \ + { \ + int mpi_status = call; \ + if (0 != mpi_status) { \ + char mpi_error_string[MPI_MAX_ERROR_STRING]; \ + int mpi_error_string_length = 0; \ + MPI_Error_string(mpi_status, mpi_error_string, &mpi_error_string_length); \ + if (NULL != mpi_error_string) \ + fprintf(stderr, \ + "ERROR: MPI call \"%s\" in line %d of file %s failed " \ + "with %s " \ + "(%d).\n", \ + #call, __LINE__, __FILE__, mpi_error_string, mpi_status); \ + else \ + fprintf(stderr, \ + "ERROR: MPI call \"%s\" in line %d of file %s failed " \ + "with %d.\n", \ + #call, __LINE__, __FILE__, mpi_status); \ + } \ + } + +#include + +#ifdef USE_NVTX +#include + +const uint32_t colors[] = {0x0000ff00, 0x000000ff, 0x00ffff00, 0x00ff00ff, + 0x0000ffff, 0x00ff0000, 0x00ffffff}; +const int num_colors = sizeof(colors) / sizeof(uint32_t); + +#define PUSH_RANGE(name, cid) \ + { \ + int color_id = cid; \ + color_id = color_id % num_colors; \ + nvtxEventAttributes_t eventAttrib = {0}; \ + eventAttrib.version = NVTX_VERSION; \ + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; \ + eventAttrib.colorType = NVTX_COLOR_ARGB; \ + eventAttrib.color = colors[color_id]; \ + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; \ + eventAttrib.message.ascii = name; \ + nvtxRangePushEx(&eventAttrib); \ + } +#define POP_RANGE nvtxRangePop(); +#else +#define PUSH_RANGE(name, cid) +#define POP_RANGE +#endif + +#define CUDA_RT_CALL(call) \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaSuccess != cudaStatus) \ + fprintf(stderr, \ + "ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \ + } + +#ifdef USE_DOUBLE +typedef double real; +#define MPI_REAL_TYPE MPI_DOUBLE +#else +typedef float real; +#define MPI_REAL_TYPE MPI_FLOAT +#endif + +constexpr real tol = 1.0e-8; + +const real PI = 2.0 * std::asin(1.0); + +void launch_initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a, + const real pi, const int offset, const int nx, const int my_ny, + const int ny); + +void launch_jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a, + real* __restrict__ const l2_norm, const int iy_start, const int iy_end, + const int nx, const bool calculate_norm, cudaStream_t stream); + +double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ref_h, + const int nccheck, const bool print); + +template +T get_argval(char** begin, char** end, const std::string& arg, const T default_val) { + T argval = default_val; + char** itr = std::find(begin, end, arg); + if (itr != end && ++itr != end) { + std::istringstream inbuf(*itr); + inbuf >> argval; + } + return argval; +} + +bool get_arg(char** begin, char** end, const std::string& arg) { + char** itr = std::find(begin, end, arg); + if (itr != end) { + return true; + } + return false; +} + +int main(int argc, char* argv[]) { + MPI_CALL(MPI_Init(&argc, &argv)); + int rank; + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); + int size; + MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size)); + + const int iter_max = get_argval(argv, argv + argc, "-niter", 1000); + const int nccheck = get_argval(argv, argv + argc, "-nccheck", 1); + const int nx = get_argval(argv, argv + argc, "-nx", 16384); + const int ny = get_argval(argv, argv + argc, "-ny", 16384); + const bool csv = get_arg(argv, argv + argc, "-csv"); + + int local_rank = -1; + { + MPI_Comm local_comm; + 
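+        // Split MPI_COMM_WORLD into per-node (shared-memory) communicators so
+        // each rank learns its node-local index, used below to select a GPU.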
+        MPI_CALL(MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL,
+                                     &local_comm));
+
+        MPI_CALL(MPI_Comm_rank(local_comm, &local_rank));
+
+        MPI_CALL(MPI_Comm_free(&local_comm));
+    }
+
+    CUDA_RT_CALL(cudaSetDevice(local_rank));
+    CUDA_RT_CALL(cudaFree(0));
+
+    real* a_ref_h;
+    CUDA_RT_CALL(cudaMallocHost(&a_ref_h, nx * ny * sizeof(real)));
+    real* a_h;
+    CUDA_RT_CALL(cudaMallocHost(&a_h, nx * ny * sizeof(real)));
+    double runtime_serial = single_gpu(nx, ny, iter_max, a_ref_h, nccheck, !csv && (0 == rank));
+
+    // ny - 2 rows are distributed amongst `size` ranks in such a way
+    // that each rank gets either (ny - 2) / size or (ny - 2) / size + 1 rows.
+    // This optimizes load balancing when (ny - 2) % size != 0
+    int chunk_size;
+    int chunk_size_low = (ny - 2) / size;
+    int chunk_size_high = chunk_size_low + 1;
+    // To calculate the number of ranks that need to compute an extra row,
+    // the following formula is derived from this equation:
+    // num_ranks_low * chunk_size_low + (size - num_ranks_low) * (chunk_size_low + 1) = ny - 2
+    int num_ranks_low = size * chunk_size_low + size -
+                        (ny - 2);  // Number of ranks with chunk_size = chunk_size_low
+    if (rank < num_ranks_low)
+        chunk_size = chunk_size_low;
+    else
+        chunk_size = chunk_size_high;
+
+    real* a;
+    CUDA_RT_CALL(cudaMalloc(&a, nx * (chunk_size + 2) * sizeof(real)));
+    real* a_new;
+    CUDA_RT_CALL(cudaMalloc(&a_new, nx * (chunk_size + 2) * sizeof(real)));
+
+    CUDA_RT_CALL(cudaMemset(a, 0, nx * (chunk_size + 2) * sizeof(real)));
+    CUDA_RT_CALL(cudaMemset(a_new, 0, nx * (chunk_size + 2) * sizeof(real)));
+
+    // Calculate local domain boundaries
+    int iy_start_global;  // My start index in the global array
+    if (rank < num_ranks_low) {
+        iy_start_global = rank * chunk_size_low + 1;
+    } else {
+        iy_start_global =
+            num_ranks_low * chunk_size_low + (rank - num_ranks_low) * chunk_size_high + 1;
+    }
+    int iy_end_global = iy_start_global + chunk_size - 1;  // My last index in the global array
+
+    int iy_start = 1;
+    int iy_end = iy_start + chunk_size;
+
+    // Set Dirichlet boundary conditions on the left and right border
+    launch_initialize_boundaries(a, a_new, PI, iy_start_global - 1, nx, (chunk_size + 2), ny);
+    CUDA_RT_CALL(cudaDeviceSynchronize());
+#ifdef SOLUTION
+    int leastPriority = 0;
+    int greatestPriority = leastPriority;
+    CUDA_RT_CALL(cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
+
+    cudaStream_t push_top_stream;
+    cudaStream_t push_bottom_stream;
+
+    cudaEvent_t push_top_done;
+    CUDA_RT_CALL(cudaEventCreateWithFlags(&push_top_done, cudaEventDisableTiming));
+    cudaEvent_t push_bottom_done;
+    CUDA_RT_CALL(cudaEventCreateWithFlags(&push_bottom_done, cudaEventDisableTiming));
+#else
+    //TODO:
+    //*Query the least and greatest stream priority range
+    //*Create top and bottom CUDA stream variables and corresponding CUDA events
+#endif
+    cudaStream_t compute_stream;
+    cudaEvent_t compute_done;
+    CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming));
+    cudaEvent_t reset_l2norm_done;
+    CUDA_RT_CALL(cudaEventCreateWithFlags(&reset_l2norm_done, cudaEventDisableTiming));
+
+#ifdef SOLUTION
+    CUDA_RT_CALL(cudaStreamCreateWithPriority(&compute_stream, cudaStreamDefault, leastPriority));
+    CUDA_RT_CALL(
+        cudaStreamCreateWithPriority(&push_top_stream, cudaStreamDefault, greatestPriority));
+    CUDA_RT_CALL(
+        cudaStreamCreateWithPriority(&push_bottom_stream, cudaStreamDefault, greatestPriority));
+#else
+    //TODO:
+    //Create the top and bottom CUDA streams with the greatest priority
+    //Modify the cudaStreamCreate call for the compute stream to use the least priority
+    CUDA_RT_CALL(cudaStreamCreate(&compute_stream));
+#endif
+
+    real* l2_norm_d;
+    CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real)));
+    real* l2_norm_h;
+    CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real)));
+
+    PUSH_RANGE("MPI_Warmup", 5)
+    for (int i = 0; i < 10; ++i) {
+        const int top = rank > 0 ? rank - 1 : (size - 1);
+        const int bottom = (rank + 1) % size;
+        MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0,
+                              a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD,
+                              MPI_STATUS_IGNORE));
+        MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx,
+                              MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
+        std::swap(a_new, a);
+    }
+    POP_RANGE
+
+    CUDA_RT_CALL(cudaDeviceSynchronize());
+
+    if (!csv && 0 == rank) {
+        printf(
+            "Jacobi relaxation: %d iterations on %d x %d mesh with norm check "
+            "every %d iterations\n",
+            iter_max, ny, nx, nccheck);
+    }
+
+    int iter = 0;
+    bool calculate_norm;
+    real l2_norm = 1.0;
+
+    MPI_CALL(MPI_Barrier(MPI_COMM_WORLD));
+    double start = MPI_Wtime();
+    PUSH_RANGE("Jacobi solve", 0)
+    while (l2_norm > tol && iter < iter_max) {
+        CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream));
+        CUDA_RT_CALL(cudaEventRecord(reset_l2norm_done, compute_stream));
+
+        calculate_norm = (iter % nccheck) == 0 || (!csv && (iter % 100) == 0);
+
+#ifdef SOLUTION
+        launch_jacobi_kernel(a_new, a, l2_norm_d, (iy_start + 1), (iy_end - 1), nx,
+                             calculate_norm, compute_stream);
+        CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream));
+
+        CUDA_RT_CALL(cudaStreamWaitEvent(push_top_stream, reset_l2norm_done, 0));
+        launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, (iy_start + 1), nx, calculate_norm,
+                             push_top_stream);
+        CUDA_RT_CALL(cudaEventRecord(push_top_done, push_top_stream));
+
+        CUDA_RT_CALL(cudaStreamWaitEvent(push_bottom_stream, reset_l2norm_done, 0));
+        launch_jacobi_kernel(a_new, a, l2_norm_d, (iy_end - 1), iy_end, nx, calculate_norm,
+                             push_bottom_stream);
+        CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream));
+#else
+        //TODO:
+        //*Launch two additional Jacobi kernels for the top and bottom regions using
+        // the top and bottom streams, after modifying the original Jacobi kernel launch
+        // to cover ONLY the center region.
+        //*Remember to wait for the reset_l2norm_done CUDA event before launching the top
+        // and bottom Jacobi kernels, using the cudaStreamWaitEvent() call.
+        //*Remember to record when the top and bottom regions are done using the cudaEventRecord() call
+        launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, iy_end, nx,
+                             calculate_norm, compute_stream);
+        CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream));
+#endif
+
+        if (calculate_norm) {
+#ifdef SOLUTION
+            CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0));
+            CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0));
+#else
+            //TODO:
+            //Wait on both the top and bottom CUDA events
+#endif
+            CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost,
+                                         compute_stream));
+        }
+
+        const int top = rank > 0 ?
rank - 1 : (size - 1); + const int bottom = (rank + 1) % size; + + // Apply periodic boundary conditions +#ifdef SOLUTION + CUDA_RT_CALL(cudaStreamSynchronize(push_top_stream)); +#else + //TODO: Modify the synchronization on the compute stream to be on the top stream + CUDA_RT_CALL(cudaEventSynchronize(compute_done)); +#endif + PUSH_RANGE("MPI", 5) + MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, + a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, + MPI_STATUS_IGNORE)); +#ifdef SOLUTION + CUDA_RT_CALL(cudaStreamSynchronize(push_bottom_stream)); +#else + //TODO: Add additional synchronization on the bottom stream +#endif + MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, + MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); + POP_RANGE + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); + MPI_CALL(MPI_Allreduce(l2_norm_h, &l2_norm, 1, MPI_REAL_TYPE, MPI_SUM, MPI_COMM_WORLD)); + l2_norm = std::sqrt(l2_norm); + + if (!csv && 0 == rank && (iter % 100) == 0) { + printf("%5d, %0.6f\n", iter, l2_norm); + } + } + + std::swap(a_new, a); + iter++; + } + double stop = MPI_Wtime(); + POP_RANGE + + CUDA_RT_CALL(cudaMemcpy(a_h + iy_start_global * nx, a + nx, + std::min((ny - iy_start_global) * nx, chunk_size * nx) * sizeof(real), + cudaMemcpyDeviceToHost)); + + int result_correct = 1; + for (int iy = iy_start_global; result_correct && (iy < iy_end_global); ++iy) { + for (int ix = 1; result_correct && (ix < (nx - 1)); ++ix) { + if (std::fabs(a_ref_h[iy * nx + ix] - a_h[iy * nx + ix]) > tol) { + fprintf(stderr, + "ERROR on rank %d: a[%d * %d + %d] = %f does not match %f " + "(reference)\n", + rank, iy, nx, ix, a_h[iy * nx + ix], a_ref_h[iy * nx + ix]); + result_correct = 0; + } + } + } + + int global_result_correct = 1; + MPI_CALL(MPI_Allreduce(&result_correct, &global_result_correct, 1, MPI_INT, MPI_MIN, + MPI_COMM_WORLD)); + result_correct = global_result_correct; + + if (rank == 0 && result_correct) { + if (csv) { + printf("mpi_overlap, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, + (stop - start), runtime_serial); + } else { + printf("Num GPUs: %d.\n", size); + printf( + "%dx%d: 1 GPU: %8.4f s, %d GPUs: %8.4f s, speedup: %8.2f, " + "efficiency: %8.2f \n", + ny, nx, runtime_serial, size, (stop - start), runtime_serial / (stop - start), + runtime_serial / (size * (stop - start)) * 100); + } + } +#ifdef SOLUTION + CUDA_RT_CALL(cudaEventDestroy(push_bottom_done)); + CUDA_RT_CALL(cudaEventDestroy(push_top_done)); + CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream)); + CUDA_RT_CALL(cudaStreamDestroy(push_top_stream)); +#else + //TODO: Destroy the additional top and bottom stream as well as their correspoinding events +#endif + CUDA_RT_CALL(cudaEventDestroy(reset_l2norm_done)); + CUDA_RT_CALL(cudaEventDestroy(compute_done)); + CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); + + CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); + CUDA_RT_CALL(cudaFree(l2_norm_d)); + + CUDA_RT_CALL(cudaFree(a_new)); + CUDA_RT_CALL(cudaFree(a)); + + CUDA_RT_CALL(cudaFreeHost(a_h)); + CUDA_RT_CALL(cudaFreeHost(a_ref_h)); + + MPI_CALL(MPI_Finalize()); + return (result_correct == 1) ? 
0 : 1; +} + +double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ref_h, + const int nccheck, const bool print) { + real* a; + real* a_new; + + cudaStream_t compute_stream; + cudaStream_t push_top_stream; + cudaStream_t push_bottom_stream; + cudaEvent_t compute_done; + cudaEvent_t push_top_done; + cudaEvent_t push_bottom_done; + + real* l2_norm_d; + real* l2_norm_h; + + int iy_start = 1; + int iy_end = (ny - 1); + + CUDA_RT_CALL(cudaMalloc(&a, nx * ny * sizeof(real))); + CUDA_RT_CALL(cudaMalloc(&a_new, nx * ny * sizeof(real))); + + CUDA_RT_CALL(cudaMemset(a, 0, nx * ny * sizeof(real))); + CUDA_RT_CALL(cudaMemset(a_new, 0, nx * ny * sizeof(real))); + + // Set diriclet boundary conditions on left and right boarder + launch_initialize_boundaries(a, a_new, PI, 0, nx, ny, ny); + CUDA_RT_CALL(cudaDeviceSynchronize()); + + CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); + CUDA_RT_CALL(cudaStreamCreate(&push_top_stream)); + CUDA_RT_CALL(cudaStreamCreate(&push_bottom_stream)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_top_done, cudaEventDisableTiming)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_bottom_done, cudaEventDisableTiming)); + + CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); + CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real))); + + CUDA_RT_CALL(cudaDeviceSynchronize()); + + if (print) + printf( + "Single GPU jacobi relaxation: %d iterations on %d x %d mesh with " + "norm " + "check every %d iterations\n", + iter_max, ny, nx, nccheck); + + int iter = 0; + bool calculate_norm; + real l2_norm = 1.0; + + double start = MPI_Wtime(); + PUSH_RANGE("Jacobi solve", 0) + while (l2_norm > tol && iter < iter_max) { + CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); + + calculate_norm = (iter % nccheck) == 0 || (iter % 100) == 0; + launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm, + compute_stream); + CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + compute_stream)); + } + + // Apply periodic boundary conditions + + CUDA_RT_CALL(cudaStreamWaitEvent(push_top_stream, compute_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(a_new, a_new + (iy_end - 1) * nx, nx * sizeof(real), + cudaMemcpyDeviceToDevice, push_top_stream)); + CUDA_RT_CALL(cudaEventRecord(push_top_done, push_top_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(push_bottom_stream, compute_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(a_new + iy_end * nx, a_new + iy_start * nx, nx * sizeof(real), + cudaMemcpyDeviceToDevice, compute_stream)); + CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); + l2_norm = *l2_norm_h; + l2_norm = std::sqrt(l2_norm); + if (print && (iter % 100) == 0) printf("%5d, %0.6f\n", iter, l2_norm); + } + + std::swap(a_new, a); + iter++; + } + POP_RANGE + double stop = MPI_Wtime(); + + CUDA_RT_CALL(cudaMemcpy(a_ref_h, a, nx * ny * sizeof(real), cudaMemcpyDeviceToHost)); + + CUDA_RT_CALL(cudaEventDestroy(push_bottom_done)); + CUDA_RT_CALL(cudaEventDestroy(push_top_done)); + CUDA_RT_CALL(cudaEventDestroy(compute_done)); + CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream)); + 
CUDA_RT_CALL(cudaStreamDestroy(push_top_stream)); + CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); + + CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); + CUDA_RT_CALL(cudaFree(l2_norm_d)); + + CUDA_RT_CALL(cudaFree(a_new)); + CUDA_RT_CALL(cudaFree(a)); + return (stop - start); +} diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi_kernels.cu b/6-H_Overlap_Communication_and_Computation_MPI/jacobi_kernels.cu new file mode 100644 index 0000000..0a6cb31 --- /dev/null +++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi_kernels.cu @@ -0,0 +1,113 @@ +/* Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ +#include + +#ifdef HAVE_CUB +#include +#endif // HAVE_CUB + +#define CUDA_RT_CALL(call) \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaSuccess != cudaStatus) \ + fprintf(stderr, \ + "ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \ + } + +#ifdef USE_DOUBLE +typedef double real; +#define MPI_REAL_TYPE MPI_DOUBLE +#else +typedef float real; +#define MPI_REAL_TYPE MPI_FLOAT +#endif + +__global__ void initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a, + const real pi, const int offset, const int nx, + const int my_ny, const int ny) { + for (int iy = blockIdx.x * blockDim.x + threadIdx.x; iy < my_ny; iy += blockDim.x * gridDim.x) { + const real y0 = sin(2.0 * pi * (offset + iy) / (ny - 1)); + a[iy * nx + 0] = y0; + a[iy * nx + (nx - 1)] = y0; + a_new[iy * nx + 0] = y0; + a_new[iy * nx + (nx - 1)] = y0; + } +} + +void launch_initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a, + const real pi, const int offset, const int nx, const int my_ny, + const int ny) { + initialize_boundaries<<>>(a_new, a, pi, offset, nx, my_ny, ny); + CUDA_RT_CALL(cudaGetLastError()); +} + +template +__global__ void jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a, + real* __restrict__ const l2_norm, const int iy_start, + const int iy_end, const int nx, const bool calculate_norm) { +#ifdef HAVE_CUB + typedef cub::BlockReduce + BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; +#endif // HAVE_CUB + int iy = blockIdx.y * blockDim.y + threadIdx.y + iy_start; + int ix = blockIdx.x * blockDim.x + threadIdx.x + 1; + real local_l2_norm = 0.0; + + if (iy < iy_end && ix < (nx - 1)) { + const real new_val = 0.25 * (a[iy * nx + ix + 1] + a[iy * nx + ix - 1] + + a[(iy + 1) * nx + ix] + a[(iy - 1) * nx + ix]); + a_new[iy * nx + ix] = new_val; + if (calculate_norm) { + real residue = new_val - a[iy * nx + ix]; + local_l2_norm += residue * residue; + } + } + if (calculate_norm) { +#ifdef HAVE_CUB + real block_l2_norm = BlockReduce(temp_storage).Sum(local_l2_norm); + if (0 == threadIdx.y && 0 == threadIdx.x) atomicAdd(l2_norm, block_l2_norm); +#else + atomicAdd(l2_norm, local_l2_norm); +#endif // HAVE_CUB + } +} + +void launch_jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a, + real* __restrict__ const l2_norm, const int iy_start, const int iy_end, + const int nx, const bool calculate_norm, cudaStream_t stream) { + constexpr int dim_block_x = 32; + constexpr int dim_block_y = 32; + dim3 dim_grid((nx + dim_block_x - 1) / dim_block_x, + ((iy_end - iy_start) + dim_block_y - 1) / dim_block_y, 1); + jacobi_kernel<<>>( + a_new, a, l2_norm, iy_start, iy_end, nx, calculate_norm); + CUDA_RT_CALL(cudaGetLastError()); +} From 678907307c1da2cefc3135902a6a4744470ee47f Mon Sep 17 00:00:00 2001 From: Simon Garcia de Gonzalo Date: Tue, 19 Oct 2021 12:10:16 +0200 Subject: [PATCH 05/25] Tested and working properly, also added Instructions.md and copy.mk --- .../Instructions.md | 46 +++++++++++++++++++ .../copy.mk | 40 ++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 6-H_Overlap_Communication_and_Computation_MPI/Instructions.md create mode 100644 6-H_Overlap_Communication_and_Computation_MPI/copy.mk diff --git a/6-H_Overlap_Communication_and_Computation_MPI/Instructions.md b/6-H_Overlap_Communication_and_Computation_MPI/Instructions.md new file mode 100644 index 
0000000..cdb8045
--- /dev/null
+++ b/6-H_Overlap_Communication_and_Computation_MPI/Instructions.md
@@ -0,0 +1,46 @@
+# SC21 Tutorial: Efficient Distributed GPU Programming for Exascale
+
+- Time: Sunday, 14 November 2021 8AM - 5PM CST
+- Location: *online*
+- Program Link: https://sc21.supercomputing.org/presentation/?id=tut138&sess=sess188
+
+
+## Hands-On 6: Overlap Communication and Computation with MPI
+
+### Task 0: Profile the non-overlapping MPI-CUDA version of the code using Nsight Systems to discover areas of possible compute/communication overlap
+
+#### Description
+The purpose of this task is to use the Nsight Systems profiler to profile the starting-point, non-overlapping MPI Jacobi solver. The objective is to become familiar with navigating the GUI and to identify possible areas to overlap computation and communication.
+
+- STEPS TO BE ADDED HERE
+
+### Task 1: Overlap Communication and Computation using high-priority streams and hide launch time for halo-processing kernels
+
+#### Description
+
+The purpose of this task is to overlap computation and communication based on the profiling done during the previous task. The starting point of this task is the non-overlapping MPI variant of the Jacobi solver. You need to work on the `TODO`s in `jacobi.cpp`:
+
+- Initialize a priority range to be used by the CUDA streams
+- Create new top and bottom CUDA streams and corresponding CUDA events
+- Initialize all streams using priorities
+- Modify the original Jacobi kernel launch to not compute the top and bottom regions
+- Launch additional Jacobi kernels for the top and bottom regions using the high-priority streams
+- Wait on both top and bottom streams when calculating the norm
+- Synchronize top and bottom streams before applying the periodic boundary conditions using MPI
+- Destroy the additional CUDA streams and events before ending the application
+
+Compile with
+
+``` {.bash}
+make
+```
+
+Submit your compiled application to the batch system with
+
+``` {.bash}
+make run
+```
+
+Study the performance by inspecting the profile generated with
+`make profile`. For `make run` and `make profile` the environment variable `NP` can be set to change the number of processes.
+
diff --git a/6-H_Overlap_Communication_and_Computation_MPI/copy.mk b/6-H_Overlap_Communication_and_Computation_MPI/copy.mk
new file mode 100644
index 0000000..fb23bd4
--- /dev/null
+++ b/6-H_Overlap_Communication_and_Computation_MPI/copy.mk
@@ -0,0 +1,40 @@
+#!/usr/bin/make -f
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
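+#
+# This helper generates both handout variants from the same annotated sources:
+# cppp (a partial C preprocessor) strips the SOLUTION blocks for the task
+# version (-USOLUTION) and keeps them for the reference solution (-DSOLUTION).
+#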
+TASKDIR = ../../tasks/6-H_Overlap_Communication_and_Computation_MPI/
+SOLUTIONDIR = ../../solutions/6-H_Overlap_Communication_and_Computation_MPI
+
+PROCESSFILES = jacobi.cpp
+COPYFILES = Makefile Instructions.ipynb Instructions.md
+
+
+TASKPROCCESFILES = $(addprefix $(TASKDIR)/,$(PROCESSFILES))
+TASKCOPYFILES = $(addprefix $(TASKDIR)/,$(COPYFILES))
+SOLUTIONPROCCESFILES = $(addprefix $(SOLUTIONDIR)/,$(PROCESSFILES))
+SOLUTIONCOPYFILES = $(addprefix $(SOLUTIONDIR)/,$(COPYFILES))
+
+.PHONY: all task
+all: task
+task: ${TASKPROCCESFILES} ${TASKCOPYFILES} ${SOLUTIONPROCCESFILES} ${SOLUTIONCOPYFILES}
+
+
+${TASKPROCCESFILES}: $(PROCESSFILES)
+	mkdir -p $(TASKDIR)/
+	cppp -USOLUTION $(notdir $@) $@
+
+${SOLUTIONPROCCESFILES}: $(PROCESSFILES)
+	mkdir -p $(SOLUTIONDIR)/
+	cppp -DSOLUTION $(notdir $@) $@
+
+
+${TASKCOPYFILES}: $(COPYFILES)
+	mkdir -p $(TASKDIR)/
+	cp $(notdir $@) $@
+
+${SOLUTIONCOPYFILES}: $(COPYFILES)
+	mkdir -p $(SOLUTIONDIR)/
+	cp $(notdir $@) $@
+
+%.ipynb: %.md
+	pandoc $< -o $@
+	# add metadata so this is seen as python
+	jq -s '.[0] * .[1]' $@ ../template.json | sponge $@
From 08b72c80334f19bedb10ac80c012233d75257915 Mon Sep 17 00:00:00 2001
From: Simon Garcia De Gonzalo
Date: Wed, 27 Oct 2021 11:52:37 +0200
Subject: [PATCH 06/25] first draft of NCCL version, needs to be tested for
 correctness

---
 8-H_NCCL_NVSHMEM/NCCL/Makefile          |  42 ++
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp        | 586 ++++++++++++++++++++++++
 8-H_NCCL_NVSHMEM/NCCL/jacobi_kernels.cu | 113 +++++
 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu      |  16 +-
 4 files changed, 756 insertions(+), 1 deletion(-)
 create mode 100644 8-H_NCCL_NVSHMEM/NCCL/Makefile
 create mode 100644 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
 create mode 100644 8-H_NCCL_NVSHMEM/NCCL/jacobi_kernels.cu

diff --git a/8-H_NCCL_NVSHMEM/NCCL/Makefile b/8-H_NCCL_NVSHMEM/NCCL/Makefile
new file mode 100644
index 0000000..ab1bb61
--- /dev/null
+++ b/8-H_NCCL_NVSHMEM/NCCL/Makefile
@@ -0,0 +1,42 @@
+# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
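+# NP sets the number of MPI ranks used by the run, profile, and sanitize targets;
+# NCCL_HOME and CUDA_HOME below may need to be adjusted to the local installations.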
+NP ?= 1 +NVCC=nvcc +MPICXX=mpicxx +MPIRUN ?= mpirun +CUDA_HOME ?= /usr/local/cuda +NCCL_HOME ?= /usr +GENCODE_SM30 := -gencode arch=compute_30,code=sm_30 +GENCODE_SM35 := -gencode arch=compute_35,code=sm_35 +GENCODE_SM37 := -gencode arch=compute_37,code=sm_37 +GENCODE_SM50 := -gencode arch=compute_50,code=sm_50 +GENCODE_SM52 := -gencode arch=compute_52,code=sm_52 +GENCODE_SM60 := -gencode arch=compute_60,code=sm_60 +GENCODE_SM70 := -gencode arch=compute_70,code=sm_70 +GENCODE_SM80 := -gencode arch=compute_80,code=sm_80 -gencode arch=compute_80,code=compute_80 +GENCODE_FLAGS := $(GENCODE_SM70) $(GENCODE_SM80) +ifdef DISABLE_CUB + NVCC_FLAGS = -Xptxas --optimize-float-atomics +else + NVCC_FLAGS = -DHAVE_CUB +endif +NVCC_FLAGS += -lineinfo $(GENCODE_FLAGS) -std=c++14 +MPICXX_FLAGS = -DUSE_NVTX -I$(CUDA_HOME)/include -I$(NCCL_HOME)/include -std=c++14 +LD_FLAGS = -L$(CUDA_HOME)/lib64 -lcudart -lnvToolsExt -lnccl +jacobi: Makefile jacobi.cpp jacobi_kernels.o + $(MPICXX) $(MPICXX_FLAGS) jacobi.cpp jacobi_kernels.o $(LD_FLAGS) -o jacobi + +jacobi_kernels.o: Makefile jacobi_kernels.cu + $(NVCC) $(NVCC_FLAGS) jacobi_kernels.cu -c + +.PHONY.: clean +clean: + rm -f jacobi jacobi_kernels.o *.qdrep jacobi.*.compute-sanitizer.log + +sanitize: jacobi + $(MPIRUN) -np $(NP) compute-sanitizer --log-file jacobi.%q{OMPI_COMM_WORLD_RANK}.compute-sanitizer.log ./jacobi -niter 10 + +run: jacobi + $(MPIRUN) -np $(NP) ./jacobi + +profile: jacobi + $(MPIRUN) -np $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{OMPI_COMM_WORLD_RANK} ./jacobi -niter 10 diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp new file mode 100644 index 0000000..92bcb0b --- /dev/null +++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp @@ -0,0 +1,586 @@ +/* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */ +#include +#include +#include +#include +#include + +#include + +#define MPI_CALL(call) \ + { \ + int mpi_status = call; \ + if (0 != mpi_status) { \ + char mpi_error_string[MPI_MAX_ERROR_STRING]; \ + int mpi_error_string_length = 0; \ + MPI_Error_string(mpi_status, mpi_error_string, &mpi_error_string_length); \ + if (NULL != mpi_error_string) \ + fprintf(stderr, \ + "ERROR: MPI call \"%s\" in line %d of file %s failed " \ + "with %s " \ + "(%d).\n", \ + #call, __LINE__, __FILE__, mpi_error_string, mpi_status); \ + else \ + fprintf(stderr, \ + "ERROR: MPI call \"%s\" in line %d of file %s failed " \ + "with %d.\n", \ + #call, __LINE__, __FILE__, mpi_status); \ + } \ + } + +#include + +#ifdef USE_NVTX +#include + +const uint32_t colors[] = {0x0000ff00, 0x000000ff, 0x00ffff00, 0x00ff00ff, + 0x0000ffff, 0x00ff0000, 0x00ffffff}; +const int num_colors = sizeof(colors) / sizeof(uint32_t); + +#define PUSH_RANGE(name, cid) \ + { \ + int color_id = cid; \ + color_id = color_id % num_colors; \ + nvtxEventAttributes_t eventAttrib = {0}; \ + eventAttrib.version = NVTX_VERSION; \ + eventAttrib.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; \ + eventAttrib.colorType = NVTX_COLOR_ARGB; \ + eventAttrib.color = colors[color_id]; \ + eventAttrib.messageType = NVTX_MESSAGE_TYPE_ASCII; \ + eventAttrib.message.ascii = name; \ + nvtxRangePushEx(&eventAttrib); \ + } +#define POP_RANGE nvtxRangePop(); +#else +#define PUSH_RANGE(name, cid) +#define POP_RANGE +#endif + +#define CUDA_RT_CALL(call) \ + { \ + cudaError_t cudaStatus = call; \ + if (cudaSuccess != cudaStatus) \ + fprintf(stderr, \ + "ERROR: CUDA RT call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \ + } +#ifdef SOLUTION +#include +#else + //TODO: include NCCL headers +#endif + +#ifdef SOLUTION +#define NCCL_CALL(call) \ + { \ + ncclResult_t ncclStatus = call; \ + if (ncclSuccess != ncclStatus) \ + fprintf(stderr, \ + "ERROR: NCCL call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, ncclGetErrorString(ncclStatus), ncclStatus); \ + } +#else +//TODO: Un-comment the following given NCCL_CALL definition to use in your code: +/* +#define NCCL_CALL(call) \ + { \ + ncclResult_t ncclStatus = call; \ + if (ncclSuccess != ncclStatus) \ + fprintf(stderr, \ + "ERROR: NCCL call \"%s\" in line %d of file %s failed " \ + "with " \ + "%s (%d).\n", \ + #call, __LINE__, __FILE__, ncclGetErrorString(ncclStatus), ncclStatus); \ + } +#endif +*/ +#endif + +#ifdef USE_DOUBLE +typedef double real; +#define MPI_REAL_TYPE MPI_DOUBLE +#ifdef SOLUTION +#define NCCL_REAL_TYPE ncclDouble +#else +//TODO:define NCCL_REAL_TYPE using its corresponding nccl double type +//HINT:https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html +#endif +#else +typedef float real; +#define MPI_REAL_TYPE MPI_FLOAT +#ifdef SOLUTION +#define NCCL_REAL_TYPE ncclFloat +#else +//TODO:define NCCL_REAL_TYPE using its corresponding nccl float type +//HINT:https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html +#endif +#endif + +constexpr real tol = 1.0e-8; + +const real PI = 2.0 * std::asin(1.0); + +void launch_initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a, + const real pi, const int offset, const int nx, const int my_ny, + const int ny); + +void launch_jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a, + real* __restrict__ const l2_norm, const int iy_start, const int 
iy_end, + const int nx, const bool calculate_norm, cudaStream_t stream); + +double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ref_h, + const int nccheck, const bool print); + +template +T get_argval(char** begin, char** end, const std::string& arg, const T default_val) { + T argval = default_val; + char** itr = std::find(begin, end, arg); + if (itr != end && ++itr != end) { + std::istringstream inbuf(*itr); + inbuf >> argval; + } + return argval; +} + +bool get_arg(char** begin, char** end, const std::string& arg) { + char** itr = std::find(begin, end, arg); + if (itr != end) { + return true; + } + return false; +} + +int main(int argc, char* argv[]) { + MPI_CALL(MPI_Init(&argc, &argv)); + int rank; + MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &rank)); + int size; + MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size)); +#ifdef SOLUTION + ncclUniqueId nccl_uid; + if (rank == 0) NCCL_CALL(ncclGetUniqueId(&nccl_uid)); + MPI_CALL(MPI_Bcast(&nccl_uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD)); +#else + //TODO: Create a ncclUniqueId, have rank 0 initize it by using the appropriate runtime call, + // and remember to broadcast it to all ranks using MPI + //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition :) +#endif + + const int iter_max = get_argval(argv, argv + argc, "-niter", 1000); + const int nccheck = get_argval(argv, argv + argc, "-nccheck", 1); + const int nx = get_argval(argv, argv + argc, "-nx", 16384); + const int ny = get_argval(argv, argv + argc, "-ny", 16384); + const bool csv = get_arg(argv, argv + argc, "-csv"); + + int local_rank = -1; + { + MPI_Comm local_comm; + MPI_CALL(MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL, + &local_comm)); + + MPI_CALL(MPI_Comm_rank(local_comm, &local_rank)); + + MPI_CALL(MPI_Comm_free(&local_comm)); + } + + CUDA_RT_CALL(cudaSetDevice(local_rank)); + CUDA_RT_CALL(cudaFree(0)); +#ifdef SOLUTION + ncclComm_t nccl_comm; + NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_uid, rank)); + int nccl_version = 0; + NCCL_CALL(ncclGetVersion(&nccl_version)); + if ( nccl_version < 2800 ) { + fprintf(stderr,"ERROR NCCL 2.8 or newer is required.\n"); + NCCL_CALL(ncclCommDestroy(nccl_comm)); + MPI_CALL(MPI_Finalize()); + return 1; + } +#else + //TODO: Create a communicator (ncclComm_t), initialize it (ncclCommInitRank) +#endif + + real* a_ref_h; + CUDA_RT_CALL(cudaMallocHost(&a_ref_h, nx * ny * sizeof(real))); + real* a_h; + CUDA_RT_CALL(cudaMallocHost(&a_h, nx * ny * sizeof(real))); + double runtime_serial = single_gpu(nx, ny, iter_max, a_ref_h, nccheck, !csv && (0 == rank)); + + // ny - 2 rows are distributed amongst `size` ranks in such a way + // that each rank gets either (ny - 2) / size or (ny - 2) / size + 1 rows. 
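+    // (Worked example, illustrative only: with ny = 12 and size = 4 there are
+    // ny - 2 = 10 interior rows, chunk_size_low = 2, and
+    // num_ranks_low = 4 * 2 + 4 - 10 = 2, so ranks 0 and 1 compute 2 rows each
+    // while ranks 2 and 3 compute 3 rows each.)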
+ // This optimizes load balancing when (ny - 2) % size != 0 + int chunk_size; + int chunk_size_low = (ny - 2) / size; + int chunk_size_high = chunk_size_low + 1; + // To calculate the number of ranks that need to compute an extra row, + // the following formula is derived from this equation: + // num_ranks_low * chunk_size_low + (size - num_ranks_low) * (chunk_size_low + 1) = ny - 2 + int num_ranks_low = size * chunk_size_low + size - + (ny - 2); // Number of ranks with chunk_size = chunk_size_low + if (rank < num_ranks_low) + chunk_size = chunk_size_low; + else + chunk_size = chunk_size_high; + + real* a; + CUDA_RT_CALL(cudaMalloc(&a, nx * (chunk_size + 2) * sizeof(real))); + real* a_new; + CUDA_RT_CALL(cudaMalloc(&a_new, nx * (chunk_size + 2) * sizeof(real))); + + CUDA_RT_CALL(cudaMemset(a, 0, nx * (chunk_size + 2) * sizeof(real))); + CUDA_RT_CALL(cudaMemset(a_new, 0, nx * (chunk_size + 2) * sizeof(real))); + + // Calculate local domain boundaries + int iy_start_global; // My start index in the global array + if (rank < num_ranks_low) { + iy_start_global = rank * chunk_size_low + 1; + } else { + iy_start_global = + num_ranks_low * chunk_size_low + (rank - num_ranks_low) * chunk_size_high + 1; + } + int iy_end_global = iy_start_global + chunk_size - 1; // My last index in the global array + + int iy_start = 1; + int iy_end = iy_start + chunk_size; + + // Set diriclet boundary conditions on left and right boarder + launch_initialize_boundaries(a, a_new, PI, iy_start_global - 1, nx, (chunk_size + 2), ny); + CUDA_RT_CALL(cudaDeviceSynchronize()); + + int leastPriority = 0; + int greatestPriority = leastPriority; + CUDA_RT_CALL(cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority)); + cudaStream_t compute_stream; + CUDA_RT_CALL(cudaStreamCreateWithPriority(&compute_stream, cudaStreamDefault, leastPriority)); + cudaStream_t push_stream; + CUDA_RT_CALL( + cudaStreamCreateWithPriority(&push_stream, cudaStreamDefault, greatestPriority)); + + cudaEvent_t push_prep_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_prep_done, cudaEventDisableTiming)); + cudaEvent_t push_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_done, cudaEventDisableTiming)); + cudaEvent_t reset_l2norm_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&reset_l2norm_done, cudaEventDisableTiming)); + + real* l2_norm_d; + CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); + real* l2_norm_h; + CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real))); + +#ifdef SOLUTION + PUSH_RANGE("NCCL_Warmup", 5) +#else + //TODO: Rename range + PUSH_RANGE("MPI_Warmup", 5) +#endif + for (int i = 0; i < 10; ++i) { + const int top = rank > 0 ? rank - 1 : (size - 1); + const int bottom = (rank + 1) % size; +#ifdef SOLUTION + NCCL_CALL(ncclGroupStart()); + NCCL_CALL(ncclRecv(a_new, nx, NCCL_REAL_TYPE, top, nccl_comm, compute_stream)); + NCCL_CALL(ncclSend(a_new + (iy_end - 1) * nx, nx, NCCL_REAL_TYPE, bottom, nccl_comm, compute_stream)); + NCCL_CALL(ncclRecv(a_new + (iy_end * nx), nx, NCCL_REAL_TYPE, bottom, nccl_comm, compute_stream)); + NCCL_CALL(ncclSend(a_new + iy_start * nx, nx, NCCL_REAL_TYPE, top, nccl_comm, compute_stream)); + NCCL_CALL(ncclGroupEnd()); + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); +#else + //TODO: Replace the MPI_Sendrecv calls with ncclRecv and ncclSend calls using the nccl communicator + // on the compute_stream. 
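+        // (ncclGroupStart()/ncclGroupEnd() fuse the enclosed sends and receives
+        // into one operation that NCCL can progress concurrently; without the
+        // grouping, ranks blocking on matching sends to each other may deadlock.)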
+ // Remeber that a group of ncclRecv and ncclSend should be within a ncclGroupStart() and ncclGroupEnd() + // Also, Rember to stream synchronize on the compute_stream at the end + MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, + a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, + MPI_STATUS_IGNORE)); + MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, + MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); +#endif + std::swap(a_new, a); + } + POP_RANGE + + CUDA_RT_CALL(cudaDeviceSynchronize()); + + if (!csv && 0 == rank) { + printf( + "Jacobi relaxation: %d iterations on %d x %d mesh with norm check " + "every %d iterations\n", + iter_max, ny, nx, nccheck); + } + + int iter = 0; + bool calculate_norm; + real l2_norm = 1.0; + + MPI_CALL(MPI_Barrier(MPI_COMM_WORLD)); + double start = MPI_Wtime(); + PUSH_RANGE("Jacobi solve", 0) + while (l2_norm > tol && iter < iter_max) { + CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + CUDA_RT_CALL(cudaEventRecord(reset_l2norm_done, compute_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(push_stream, reset_l2norm_done, 0)); + calculate_norm = (iter % nccheck) == 0 || (!csv && (iter % 100) == 0); + + launch_jacobi_kernel(a_new, a, l2_norm_d, (iy_start + 1), (iy_end - 1), nx, calculate_norm, + compute_stream); + + launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, (iy_start + 1), nx, calculate_norm, + push_stream); + + launch_jacobi_kernel(a_new, a, l2_norm_d, (iy_end - 1), iy_end, nx, calculate_norm, + push_stream); + CUDA_RT_CALL(cudaEventRecord(push_prep_done, push_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_prep_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + compute_stream)); + } + + const int top = rank > 0 ? rank - 1 : (size - 1); + const int bottom = (rank + 1) % size; + + // Apply periodic boundary conditions +#ifdef SOLUTION + PUSH_RANGE("NCCL_LAUNCH", 5) + NCCL_CALL(ncclGroupStart()); + NCCL_CALL(ncclRecv(a_new, nx, NCCL_REAL_TYPE, top, nccl_comm, push_stream)); + NCCL_CALL(ncclSend(a_new + (iy_end - 1) * nx, nx, NCCL_REAL_TYPE, bottom, nccl_comm, push_stream)); + NCCL_CALL(ncclRecv(a_new + (iy_end * nx), nx, NCCL_REAL_TYPE, bottom, nccl_comm, push_stream)); + NCCL_CALL(ncclSend(a_new + iy_start * nx, nx, NCCL_REAL_TYPE, top, nccl_comm, push_stream)); + NCCL_CALL(ncclGroupEnd()); +#else + //TODO: Modify the lable for the RANGE, and replace MPI_Sendrecv with ncclSend and ncclRecv calls + // using the nccl communicator and push_stream. 
+ // Remember to use ncclGroupStart() and ncclGroupEnd() + PUSH_RANGE("MPI", 5) + MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, + a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, + MPI_STATUS_IGNORE)); + CUDA_RT_CALL(cudaStreamSynchronize(push_bottom_stream)); + MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, + MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); +#endif + CUDA_RT_CALL(cudaEventRecord(push_done, push_stream)); + POP_RANGE + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); + MPI_CALL(MPI_Allreduce(l2_norm_h, &l2_norm, 1, MPI_REAL_TYPE, MPI_SUM, MPI_COMM_WORLD)); + l2_norm = std::sqrt(l2_norm); + + if (!csv && 0 == rank && (iter % 100) == 0) { + printf("%5d, %0.6f\n", iter, l2_norm); + } + } + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_done, 0)); + + std::swap(a_new, a); + iter++; + } + CUDA_RT_CALL(cudaDeviceSynchronize()); + double stop = MPI_Wtime(); + POP_RANGE + + CUDA_RT_CALL(cudaMemcpy(a_h + iy_start_global * nx, a + nx, + std::min((ny - iy_start_global) * nx, chunk_size * nx) * sizeof(real), + cudaMemcpyDeviceToHost)); + + int result_correct = 1; + for (int iy = iy_start_global; result_correct && (iy < iy_end_global); ++iy) { + for (int ix = 1; result_correct && (ix < (nx - 1)); ++ix) { + if (std::fabs(a_ref_h[iy * nx + ix] - a_h[iy * nx + ix]) > tol) { + fprintf(stderr, + "ERROR on rank %d: a[%d * %d + %d] = %f does not match %f " + "(reference)\n", + rank, iy, nx, ix, a_h[iy * nx + ix], a_ref_h[iy * nx + ix]); + result_correct = 0; + } + } + } + + int global_result_correct = 1; + MPI_CALL(MPI_Allreduce(&result_correct, &global_result_correct, 1, MPI_INT, MPI_MIN, + MPI_COMM_WORLD)); + result_correct = global_result_correct; + + if (rank == 0 && result_correct) { + if (csv) { +#ifdef SOLUTION + printf("nccl_overlap, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, +#else + //TODO: Dont forget to change your output lable from mpi_overlap to nccl_overlap + printf("mpi_overlap, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, +#endif + (stop - start), runtime_serial); + } else { + printf("Num GPUs: %d.\n", size); + printf( + "%dx%d: 1 GPU: %8.4f s, %d GPUs: %8.4f s, speedup: %8.2f, " + "efficiency: %8.2f \n", + ny, nx, runtime_serial, size, (stop - start), runtime_serial / (stop - start), + runtime_serial / (size * (stop - start)) * 100); + } + } + CUDA_RT_CALL(cudaEventDestroy(reset_l2norm_done)); + CUDA_RT_CALL(cudaEventDestroy(push_done)); + CUDA_RT_CALL(cudaEventDestroy(push_prep_done)); + CUDA_RT_CALL(cudaStreamDestroy(push_stream)); + CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); + + CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); + CUDA_RT_CALL(cudaFree(l2_norm_d)); + + CUDA_RT_CALL(cudaFree(a_new)); + CUDA_RT_CALL(cudaFree(a)); + + CUDA_RT_CALL(cudaFreeHost(a_h)); + CUDA_RT_CALL(cudaFreeHost(a_ref_h)); + + NCCL_CALL(ncclCommDestroy(nccl_comm)); + + MPI_CALL(MPI_Finalize()); + return (result_correct == 1) ? 
0 : 1; +} + +double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ref_h, + const int nccheck, const bool print) { + real* a; + real* a_new; + + cudaStream_t compute_stream; + cudaStream_t push_top_stream; + cudaStream_t push_bottom_stream; + cudaEvent_t compute_done; + cudaEvent_t push_top_done; + cudaEvent_t push_bottom_done; + + real* l2_norm_d; + real* l2_norm_h; + + int iy_start = 1; + int iy_end = (ny - 1); + + CUDA_RT_CALL(cudaMalloc(&a, nx * ny * sizeof(real))); + CUDA_RT_CALL(cudaMalloc(&a_new, nx * ny * sizeof(real))); + + CUDA_RT_CALL(cudaMemset(a, 0, nx * ny * sizeof(real))); + CUDA_RT_CALL(cudaMemset(a_new, 0, nx * ny * sizeof(real))); + + // Set diriclet boundary conditions on left and right boarder + launch_initialize_boundaries(a, a_new, PI, 0, nx, ny, ny); + CUDA_RT_CALL(cudaDeviceSynchronize()); + + CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); + CUDA_RT_CALL(cudaStreamCreate(&push_top_stream)); + CUDA_RT_CALL(cudaStreamCreate(&push_bottom_stream)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_top_done, cudaEventDisableTiming)); + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_bottom_done, cudaEventDisableTiming)); + + CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); + CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real))); + + CUDA_RT_CALL(cudaDeviceSynchronize()); + + if (print) + printf( + "Single GPU jacobi relaxation: %d iterations on %d x %d mesh with " + "norm " + "check every %d iterations\n", + iter_max, ny, nx, nccheck); + + int iter = 0; + bool calculate_norm; + real l2_norm = 1.0; + + double start = MPI_Wtime(); + PUSH_RANGE("Jacobi solve", 0) + while (l2_norm > tol && iter < iter_max) { + CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); + + calculate_norm = (iter % nccheck) == 0 || (iter % 100) == 0; + launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm, + compute_stream); + CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + compute_stream)); + } + + // Apply periodic boundary conditions + + CUDA_RT_CALL(cudaStreamWaitEvent(push_top_stream, compute_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(a_new, a_new + (iy_end - 1) * nx, nx * sizeof(real), + cudaMemcpyDeviceToDevice, push_top_stream)); + CUDA_RT_CALL(cudaEventRecord(push_top_done, push_top_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(push_bottom_stream, compute_done, 0)); + CUDA_RT_CALL(cudaMemcpyAsync(a_new + iy_end * nx, a_new + iy_start * nx, nx * sizeof(real), + cudaMemcpyDeviceToDevice, compute_stream)); + CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream)); + + if (calculate_norm) { + CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); + l2_norm = *l2_norm_h; + l2_norm = std::sqrt(l2_norm); + if (print && (iter % 100) == 0) printf("%5d, %0.6f\n", iter, l2_norm); + } + + std::swap(a_new, a); + iter++; + } + POP_RANGE + double stop = MPI_Wtime(); + + CUDA_RT_CALL(cudaMemcpy(a_ref_h, a, nx * ny * sizeof(real), cudaMemcpyDeviceToHost)); + + CUDA_RT_CALL(cudaEventDestroy(push_bottom_done)); + CUDA_RT_CALL(cudaEventDestroy(push_top_done)); + CUDA_RT_CALL(cudaEventDestroy(compute_done)); + CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream)); + 
CUDA_RT_CALL(cudaStreamDestroy(push_top_stream)); + CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); + + CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); + CUDA_RT_CALL(cudaFree(l2_norm_d)); + + CUDA_RT_CALL(cudaFree(a_new)); + CUDA_RT_CALL(cudaFree(a)); + return (stop - start); +} diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi_kernels.cu b/8-H_NCCL_NVSHMEM/NCCL/jacobi_kernels.cu new file mode 100644 index 0000000..b98a9d4 --- /dev/null +++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi_kernels.cu @@ -0,0 +1,113 @@ +/* Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions + * are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of NVIDIA CORPORATION nor the names of its + * contributors may be used to endorse or promote products derived + * from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY + * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR + * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR + * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, + * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, + * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR + * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY + * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ */
+#include <cstdio>
+
+#ifdef HAVE_CUB
+#include <cub/block/block_reduce.cuh>
+#endif  // HAVE_CUB
+
+#define CUDA_RT_CALL(call)                                                                  \
+    {                                                                                       \
+        cudaError_t cudaStatus = call;                                                      \
+        if (cudaSuccess != cudaStatus)                                                      \
+            fprintf(stderr,                                                                 \
+                    "ERROR: CUDA RT call \"%s\" in line %d of file %s failed "              \
+                    "with "                                                                 \
+                    "%s (%d).\n",                                                           \
+                    #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \
+    }
+
+#ifdef USE_DOUBLE
+typedef double real;
+#define MPI_REAL_TYPE MPI_DOUBLE
+#else
+typedef float real;
+#define MPI_REAL_TYPE MPI_FLOAT
+#endif
+
+__global__ void initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a,
+                                      const real pi, const int offset, const int nx,
+                                      const int my_ny, const int ny) {
+    for (int iy = blockIdx.x * blockDim.x + threadIdx.x; iy < my_ny; iy += blockDim.x * gridDim.x) {
+        const real y0 = sin(2.0 * pi * (offset + iy) / (ny - 1));
+        a[iy * nx + 0] = y0;
+        a[iy * nx + (nx - 1)] = y0;
+        a_new[iy * nx + 0] = y0;
+        a_new[iy * nx + (nx - 1)] = y0;
+    }
+}
+
+void launch_initialize_boundaries(real* __restrict__ const a_new, real* __restrict__ const a,
+                                  const real pi, const int offset, const int nx, const int my_ny,
+                                  const int ny) {
+    initialize_boundaries<<<my_ny / 128 + 1, 128>>>(a_new, a, pi, offset, nx, my_ny, ny);
+    CUDA_RT_CALL(cudaGetLastError());
+}
+
+template <int BLOCK_DIM_X, int BLOCK_DIM_Y>
+__global__ void jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a,
+                              real* __restrict__ const l2_norm, const int iy_start,
+                              const int iy_end, const int nx, const bool calculate_norm) {
+#ifdef HAVE_CUB
+    typedef cub::BlockReduce<real, BLOCK_DIM_X, cub::BLOCK_REDUCE_WARP_REDUCTIONS, BLOCK_DIM_Y>
+        BlockReduce;
+    __shared__ typename BlockReduce::TempStorage temp_storage;
+#endif  // HAVE_CUB
+    int iy = blockIdx.y * blockDim.y + threadIdx.y + iy_start;
+    int ix = blockIdx.x * blockDim.x + threadIdx.x + 1;
+    real local_l2_norm = 0.0;
+
+    if (iy < iy_end && ix < (nx - 1)) {
+        const real new_val = 0.25 * (a[iy * nx + ix + 1] + a[iy * nx + ix - 1] +
+                                     a[(iy + 1) * nx + ix] + a[(iy - 1) * nx + ix]);
+        a_new[iy * nx + ix] = new_val;
+        if (calculate_norm) {
+            real residue = new_val - a[iy * nx + ix];
+            local_l2_norm += residue * residue;
+        }
+    }
+    if (calculate_norm) {
+#ifdef HAVE_CUB
+        real block_l2_norm = BlockReduce(temp_storage).Sum(local_l2_norm);
+        if (0 == threadIdx.y && 0 == threadIdx.x) atomicAdd(l2_norm, block_l2_norm);
+#else
+        atomicAdd(l2_norm, local_l2_norm);
+#endif  // HAVE_CUB
+    }
+}
+
+void launch_jacobi_kernel(real* __restrict__ const a_new, const real* __restrict__ const a,
+                          real* __restrict__ const l2_norm, const int iy_start, const int iy_end,
+                          const int nx, const bool calculate_norm, cudaStream_t stream) {
+    constexpr int dim_block_x = 32;
+    constexpr int dim_block_y = 32;
+    dim3 dim_grid((nx + dim_block_x - 1) / dim_block_x,
+                  ((iy_end - iy_start) + dim_block_y - 1) / dim_block_y, 1);
+    jacobi_kernel<dim_block_x, dim_block_y><<<dim_grid, {dim_block_x, dim_block_y, 1}, 0, stream>>>(
+        a_new, a, l2_norm, iy_start, iy_end, nx, calculate_norm);
+    CUDA_RT_CALL(cudaGetLastError());
+}
diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu
index 8d96fbb..4915897 100644
--- a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu
+++ b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu
@@ -283,11 +283,25 @@ int main(int argc, char* argv[]) {
     CUDA_RT_CALL(cudaGetLastError());
     CUDA_RT_CALL(cudaDeviceSynchronize());

+#ifdef SOLUTION
+    int leastPriority = 0;
+    int greatestPriority = leastPriority;
+    CUDA_RT_CALL(cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
+    cudaStream_t compute_stream;
+    cudaStream_t push_top_stream;
+    cudaStream_t push_bottom_stream;
+
+    CUDA_RT_CALL(cudaStreamCreateWithPriority(&compute_stream,
cudaStreamDefault, leastPriority)); + CUDA_RT_CALL(cudaStreamCreateWithPriority(&push_top_stream, cudaStreamDefault, greatestPriority)); + CUDA_RT_CALL(cudaStreamCreateWithPriority(&push_bottom_stream, cudaStreamDefault, greatestPriority)); + +#else + //TODO: cudaStream_t compute_stream; CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); cudaEvent_t compute_done; CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); - +#endif real* l2_norm_d; CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); real* l2_norm_h; From f599f89492e4d2fc57340d9ac3bc2a0067870383 Mon Sep 17 00:00:00 2001 From: Simon Garcia De Gonzalo Date: Wed, 27 Oct 2021 14:13:44 +0200 Subject: [PATCH 07/25] NCCL and host-side NVSHMEM with overlap first version, both need to be tested for correctness --- 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 24 ++--------- 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu | 69 +++++++++++++++++++----------- 2 files changed, 49 insertions(+), 44 deletions(-) diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp index 92bcb0b..d0d656c 100644 --- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp +++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp @@ -297,34 +297,15 @@ int main(int argc, char* argv[]) { real* l2_norm_h; CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real))); -#ifdef SOLUTION - PUSH_RANGE("NCCL_Warmup", 5) -#else - //TODO: Rename range PUSH_RANGE("MPI_Warmup", 5) -#endif for (int i = 0; i < 10; ++i) { const int top = rank > 0 ? rank - 1 : (size - 1); const int bottom = (rank + 1) % size; -#ifdef SOLUTION - NCCL_CALL(ncclGroupStart()); - NCCL_CALL(ncclRecv(a_new, nx, NCCL_REAL_TYPE, top, nccl_comm, compute_stream)); - NCCL_CALL(ncclSend(a_new + (iy_end - 1) * nx, nx, NCCL_REAL_TYPE, bottom, nccl_comm, compute_stream)); - NCCL_CALL(ncclRecv(a_new + (iy_end * nx), nx, NCCL_REAL_TYPE, bottom, nccl_comm, compute_stream)); - NCCL_CALL(ncclSend(a_new + iy_start * nx, nx, NCCL_REAL_TYPE, top, nccl_comm, compute_stream)); - NCCL_CALL(ncclGroupEnd()); - CUDA_RT_CALL(cudaStreamSynchronize(compute_stream)); -#else - //TODO: Replace the MPI_Sendrecv calls with ncclRecv and ncclSend calls using the nccl communicator - // on the compute_stream. - // Remeber that a group of ncclRecv and ncclSend should be within a ncclGroupStart() and ncclGroupEnd() - // Also, Rember to stream synchronize on the compute_stream at the end MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); -#endif std::swap(a_new, a); } POP_RANGE @@ -467,8 +448,11 @@ int main(int argc, char* argv[]) { CUDA_RT_CALL(cudaFreeHost(a_h)); CUDA_RT_CALL(cudaFreeHost(a_ref_h)); - +#ifdef SOLUTION NCCL_CALL(ncclCommDestroy(nccl_comm)); +#else + //TODO: Destroy the nccl communicator +#endif MPI_CALL(MPI_Finalize()); return (result_correct == 1) ? 
0 : 1; diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu index 4915897..8fe71d7 100644 --- a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu +++ b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu @@ -283,25 +283,27 @@ int main(int argc, char* argv[]) { CUDA_RT_CALL(cudaGetLastError()); CUDA_RT_CALL(cudaDeviceSynchronize()); -#ifdef SOLUTION int leastPriority = 0; int greatestPriority = leastPriority; CUDA_RT_CALL(cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority)); cudaStream_t compute_stream; - cudaStream_t push_top_stream; - cudaStream_t push_bottom_stream; - CUDA_RT_CALL(cudaStreamCreateWithPriority(&compute_stream, cudaStreamDefault, leastPriority)); - CUDA_RT_CALL(cudaStreamCreateWithPriority(&push_top_stream, cudaStreamDefault, greatestPriority)); - CUDA_RT_CALL(cudaStreamCreateWithPriority(&push_bottom_stream, cudaStreamDefault, greatestPriority)); + cudaStream_t push_stream; + CUDA_RT_CALL( + cudaStreamCreateWithPriority(&push_stream, cudaStreamDefault, greatestPriority)); + + cudaEvent_t push_prep_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_prep_done, cudaEventDisableTiming)); + cudaEvent_t push_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&push_done, cudaEventDisableTiming)); + cudaEvent_t reset_l2norm_done; + CUDA_RT_CALL(cudaEventCreateWithFlags(&reset_l2norm_done, cudaEventDisableTiming)); + +// cudaStream_t compute_stream; +// CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); +// cudaEvent_t compute_done; +// CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); -#else - //TODO: - cudaStream_t compute_stream; - CUDA_RT_CALL(cudaStreamCreate(&compute_stream)); - cudaEvent_t compute_done; - CUDA_RT_CALL(cudaEventCreateWithFlags(&compute_done, cudaEventDisableTiming)); -#endif real* l2_norm_d; CUDA_RT_CALL(cudaMalloc(&l2_norm_d, sizeof(real))); real* l2_norm_h; @@ -344,42 +346,54 @@ int main(int argc, char* argv[]) { PUSH_RANGE("Jacobi solve", 0) while (l2_norm > tol && iter < iter_max) { CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + CUDA_RT_CALL(cudaEventRecord(reset_l2norm_done, compute_stream)); + CUDA_RT_CALL(cudaStreamWaitEvent(push_stream, reset_l2norm_done, 0)); calculate_norm = (iter % nccheck) == 0 || (!csv && (iter % 100) == 0); jacobi_kernel<<>>( - a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm); - CUDA_RT_CALL(cudaGetLastError()); - CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + a_new, a, l2_norm_d, (iy_start + 1), (iy_end - 1), nx, calculate_norm); + + jacobi_kernel<<>>( + a_new, a, l2_norm_d, iy_start, (iy_start + 1), nx, calculate_norm); + + jacobi_kernel<<>>( + a_new, a, l2_norm_d, (iy_end - 1), iy_end, nx, calculate_norm); + + //CUDA_RT_CALL(cudaGetLastError()); + //CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + CUDA_RT_CALL(cudaEventRecord(push_prep_done, push_stream)); if (calculate_norm) { + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_prep_done, 0)); CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, compute_stream)); } #ifdef SOLUTION - - nvshmemx_float_put_on_stream(a_new + iy_top_lower_boundary_idx * nx, a_new + iy_start * nx, nx, top, compute_stream); - nvshmemx_float_put_on_stream(a_new + iy_bottom_upper_boundary_idx * nx, a_new + (iy_end - 1) * nx, nx, bottom, compute_stream); - + PUSH_RANGE("NVSHMEM", 5) + nvshmemx_float_put_on_stream(a_new + iy_top_lower_boundary_idx * nx, a_new + iy_start * nx, nx, top, push_stream); + nvshmemx_float_put_on_stream(a_new + 
iy_bottom_upper_boundary_idx * nx, a_new + (iy_end - 1) * nx, nx, bottom, push_stream); #else //TODO: Replace MPI communication with Host initiated NVSHMEM calls // Apply periodic boundary conditions - CUDA_RT_CALL(cudaEventSynchronize(compute_done)); PUSH_RANGE("MPI", 5) MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - POP_RANGE #endif + CUDA_RT_CALL(cudaEventRecord(push_done, push_stream)); + POP_RANGE #ifdef SOLUTION nvshmemx_barrier_all_on_stream(compute_stream); + nvshmemx_barrier_all_on_stream(push_stream); #else - //TODO: add necessary inter PE synchronization + //TODO: add necessary inter PE synchronization using the nvshmemx_barrier_all_on_stream(...) + // for both streams #endif @@ -423,7 +437,11 @@ int main(int argc, char* argv[]) { if (rank == 0 && result_correct) { if (csv) { +#ifdef SOLUTION + printf("nvshmem, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, +#else printf("mpi, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, +#endif (stop - start), runtime_serial); } else { printf("Num GPUs: %d.\n", size); @@ -434,7 +452,10 @@ int main(int argc, char* argv[]) { runtime_serial / (size * (stop - start)) * 100); } } - CUDA_RT_CALL(cudaEventDestroy(compute_done)); + CUDA_RT_CALL(cudaEventDestroy(reset_l2norm_done)); + CUDA_RT_CALL(cudaEventDestroy(push_done)); + CUDA_RT_CALL(cudaEventDestroy(push_prep_done)); + CUDA_RT_CALL(cudaStreamDestroy(push_stream)); CUDA_RT_CALL(cudaStreamDestroy(compute_stream)); CUDA_RT_CALL(cudaFreeHost(l2_norm_h)); From bceb8fa1ceb6ad8559f97679ffd60cf3b5bb0168 Mon Sep 17 00:00:00 2001 From: Simon Garcia de Gonzalo Date: Thu, 28 Oct 2021 13:11:08 +0200 Subject: [PATCH 08/25] 6-H Pull request edits have been implemented, 8-H NCCL is also ready to be reviewed --- .../Makefile | 10 +-- .../jacobi.cpp | 88 +++++++++++-------- 8-H_NCCL_NVSHMEM/NCCL/Instructions.md | 40 +++++++++ 8-H_NCCL_NVSHMEM/NCCL/Makefile | 10 +-- 8-H_NCCL_NVSHMEM/NCCL/copy.mk | 40 +++++++++ 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 21 ++++- 8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu | 2 +- 7 files changed, 161 insertions(+), 50 deletions(-) create mode 100644 8-H_NCCL_NVSHMEM/NCCL/Instructions.md create mode 100644 8-H_NCCL_NVSHMEM/NCCL/copy.mk diff --git a/6-H_Overlap_Communication_and_Computation_MPI/Makefile b/6-H_Overlap_Communication_and_Computation_MPI/Makefile index 4e9002d..2ca46c7 100644 --- a/6-H_Overlap_Communication_and_Computation_MPI/Makefile +++ b/6-H_Overlap_Communication_and_Computation_MPI/Makefile @@ -2,7 +2,7 @@ NP ?= 1 NVCC=nvcc MPICXX=mpicxx -MPIRUN ?= mpirun +JSC_SUBMIT_CMD ?= srun --gres=gpu:4 --ntasks-per-node 4 CUDA_HOME ?= /usr/local/cuda GENCODE_SM30 := -gencode arch=compute_30,code=sm_30 GENCODE_SM35 := -gencode arch=compute_35,code=sm_35 @@ -29,13 +29,13 @@ jacobi_kernels.o: Makefile jacobi_kernels.cu .PHONY.: clean clean: - rm -f jacobi jacobi_kernels.o *.qdrep jacobi.*.compute-sanitizer.log + rm -f jacobi jacobi_kernels.o *.nsys-rep jacobi.*.compute-sanitizer.log sanitize: jacobi - $(MPIRUN) -np $(NP) compute-sanitizer --log-file jacobi.%q{OMPI_COMM_WORLD_RANK}.compute-sanitizer.log ./jacobi -niter 10 + $(JSC_SUBMIT_CMD) -n $(NP) compute-sanitizer --log-file jacobi.%q{SLURM_PROCID}.compute-sanitizer.log ./jacobi -niter 10 run: jacobi - $(MPIRUN) -np $(NP) ./jacobi + 
$(JSC_SUBMIT_CMD) -n $(NP) ./jacobi profile: jacobi - $(MPIRUN) -np $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{OMPI_COMM_WORLD_RANK} ./jacobi -niter 10 + $(JSC_SUBMIT_CMD) -n $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{SLURM_PROCID} ./jacobi -niter 10 diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp index 34dc2f4..ec65fd0 100644 --- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp +++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp @@ -208,6 +208,9 @@ int main(int argc, char* argv[]) { launch_initialize_boundaries(a, a_new, PI, iy_start_global - 1, nx, (chunk_size + 2), ny); CUDA_RT_CALL(cudaDeviceSynchronize()); #ifdef SOLUTION + //TODO: + //*Set least and greates Priority Range + //*Create top and bottom cuda streams variables and corresponding cuda events int leastPriority = 0; int greatestPriority = leastPriority; CUDA_RT_CALL(cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority)); @@ -231,6 +234,9 @@ int main(int argc, char* argv[]) { CUDA_RT_CALL(cudaEventCreateWithFlags(&reset_l2norm_done, cudaEventDisableTiming)); #ifdef SOLUTION + //TODO: + //Create cuda streams with Greates Priority for top and bottom streams + //Modify the cudaStreamCreate call for the compute stream to have the Least Priority CUDA_RT_CALL(cudaStreamCreateWithPriority(&compute_stream, cudaStreamDefault, leastPriority)); CUDA_RT_CALL( cudaStreamCreateWithPriority(&push_top_stream, cudaStreamDefault, greatestPriority)); @@ -284,6 +290,13 @@ int main(int argc, char* argv[]) { calculate_norm = (iter % nccheck) == 0 || (!csv && (iter % 100) == 0); #ifdef SOLUTION + //TODO: + //*Launch two additional jacobi kernels for the top and bottom regions using + // the top and bottom streams after modifying and launching the original jacobi kernel on + // ONLY the center region. + //*Remember to wait on the for l2_norm_done cuda event before launching each top and bottom jacobi kernels + // using the cudaStreamWaitEvent() call. 
+ //*Remember to record when the top and bottom regions are done using the cudaEventRecord() call launch_jacobi_kernel(a_new, a, l2_norm_d, (iy_start + 1), (iy_end - 1), nx, calculate_norm, compute_stream); CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); @@ -312,13 +325,15 @@ int main(int argc, char* argv[]) { if (calculate_norm) { #ifdef SOLUTION - CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); - CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); + //TODO: + //Wait on both the top and bottom cuda events + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); #else - //TODO: - //Wait on both the top and bottom cuda events + //TODO: + //Wait on both the top and bottom cuda events #endif - CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, compute_stream)); } @@ -327,6 +342,7 @@ int main(int argc, char* argv[]) { // Apply periodic boundary conditions #ifdef SOLUTION + //TODO: Modify the synchronization on the compute stream to be on the top stream CUDA_RT_CALL(cudaStreamSynchronize(push_top_stream)); #else //TODO: Modify the synchronization on the compute stream to be on the top stream @@ -337,6 +353,7 @@ int main(int argc, char* argv[]) { a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); #ifdef SOLUTION + //TODO: Add additional synchronization on the bottom stream CUDA_RT_CALL(cudaStreamSynchronize(push_bottom_stream)); #else //TODO: Add additional synchronization on the bottom stream @@ -397,6 +414,7 @@ int main(int argc, char* argv[]) { } } #ifdef SOLUTION + //TODO: Destroy the additional top and bottom stream as well as their correspoinding events CUDA_RT_CALL(cudaEventDestroy(push_bottom_done)); CUDA_RT_CALL(cudaEventDestroy(push_top_done)); CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream)); @@ -475,43 +493,43 @@ double single_gpu(const int nx, const int ny, const int iter_max, real* const a_ double start = MPI_Wtime(); PUSH_RANGE("Jacobi solve", 0) while (l2_norm > tol && iter < iter_max) { - CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); - - CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); - CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); - - calculate_norm = (iter % nccheck) == 0 || (iter % 100) == 0; - launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm, + + CUDA_RT_CALL(cudaMemsetAsync(l2_norm_d, 0, sizeof(real), compute_stream)); + + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0)); + CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0)); + + calculate_norm = (iter % nccheck) == 0 || (iter % 100) == 0; + launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm, compute_stream); - CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); + CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream)); - if (calculate_norm) { - CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, - compute_stream)); - } + if (calculate_norm) { + CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost, + compute_stream)); + } - // Apply periodic boundary conditions + // Apply periodic boundary conditions - CUDA_RT_CALL(cudaStreamWaitEvent(push_top_stream, compute_done, 
0));
-        CUDA_RT_CALL(cudaMemcpyAsync(a_new, a_new + (iy_end - 1) * nx, nx * sizeof(real),
+        CUDA_RT_CALL(cudaStreamWaitEvent(push_top_stream, compute_done, 0));
+        CUDA_RT_CALL(cudaMemcpyAsync(a_new, a_new + (iy_end - 1) * nx, nx * sizeof(real),
                                      cudaMemcpyDeviceToDevice, push_top_stream));
-        CUDA_RT_CALL(cudaEventRecord(push_top_done, push_top_stream));
+        CUDA_RT_CALL(cudaEventRecord(push_top_done, push_top_stream));
-        CUDA_RT_CALL(cudaStreamWaitEvent(push_bottom_stream, compute_done, 0));
-        CUDA_RT_CALL(cudaMemcpyAsync(a_new + iy_end * nx, a_new + iy_start * nx, nx * sizeof(real),
+        CUDA_RT_CALL(cudaStreamWaitEvent(push_bottom_stream, compute_done, 0));
+        CUDA_RT_CALL(cudaMemcpyAsync(a_new + iy_end * nx, a_new + iy_start * nx, nx * sizeof(real),
                                      cudaMemcpyDeviceToDevice, compute_stream));
-        CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream));
-
-        if (calculate_norm) {
-            CUDA_RT_CALL(cudaStreamSynchronize(compute_stream));
-            l2_norm = *l2_norm_h;
-            l2_norm = std::sqrt(l2_norm);
-            if (print && (iter % 100) == 0) printf("%5d, %0.6f\n", iter, l2_norm);
-        }
-
-        std::swap(a_new, a);
-        iter++;
-    }
+        CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream));
+        if (calculate_norm) {
+            CUDA_RT_CALL(cudaStreamSynchronize(compute_stream));
+            l2_norm = *l2_norm_h;
+            l2_norm = std::sqrt(l2_norm);
+            if (print && (iter % 100) == 0) printf("%5d, %0.6f\n", iter, l2_norm);
+        }
+
+        std::swap(a_new, a);
+        iter++;
+    }
     POP_RANGE
     double stop = MPI_Wtime();
diff --git a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
new file mode 100644
index 0000000..20c8cec
--- /dev/null
+++ b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
@@ -0,0 +1,40 @@
+# SC21 Tutorial: Efficient Distributed GPU Programming for Exascale
+
+- Time: Sunday, 14 November 2021 8AM - 5PM CST
+- Location: *online*
+- Program Link: https://sc21.supercomputing.org/presentation/?id=tut138&sess=sess188
+
+
+## Hands-On 8\_NCCL: Using NCCL for inter-GPU communication
+
+### Task 0: Using NCCL device API
+
+#### Description
+
+The purpose of this task is to use the NCCL API instead of MPI to implement a multi-GPU jacobi solver. The starting point of this task is the MPI variant of the jacobi solver. You need to work on `TODOs` in `jacobi.cpp`:
+
+- Initialize NCCL:
+  - Include NCCL headers.
+  - Un-comment the provided NCCL\_CALL definition to handle NCCL errors
+  - Define NCCL\_REAL\_TYPE for both double and float types
+  - Create an NCCL unique ID and initialize it
+  - Create an NCCL communicator and initialize it
+  - Replace MPI for the periodic boundary conditions with NCCL
+  - Fix the output message to indicate nccl rather than mpi
+  - Destroy the NCCL communicator
+
+Compile with
+
+``` {.bash}
+make
+```
+
+Submit your compiled application to the batch system with
+
+``` {.bash}
+make run
+```
+
+Study the performance by glimpsing at the profile generated with
+`make profile`. For `make run` and `make profile` the environment variable `NP` can be set to change the number of processes.
+
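The initialization and teardown sequence described above can be sketched as follows. This is a minimal sketch based on the solution branch of `jacobi.cpp`; it assumes MPI is already initialized and that `rank`, `size`, and the `NCCL_CALL`/`MPI_CALL` macros are as defined there.

``` {.cpp}
ncclUniqueId nccl_uid;  // opaque ID that every rank must share
if (rank == 0) NCCL_CALL(ncclGetUniqueId(&nccl_uid));
MPI_CALL(MPI_Bcast(&nccl_uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD));

ncclComm_t nccl_comm;
NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_uid, rank));

// ... exchange halos with grouped ncclSend/ncclRecv calls on a CUDA stream ...

NCCL_CALL(ncclCommDestroy(nccl_comm));
```

diff --git a/8-H_NCCL_NVSHMEM/NCCL/Makefile b/8-H_NCCL_NVSHMEM/NCCL/Makefile
index ab1bb61..1ebd47f 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/Makefile
+++ b/8-H_NCCL_NVSHMEM/NCCL/Makefile
@@ -1,8 +1,8 @@
 # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.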
NP ?= 1 NVCC=nvcc +JSC_SUBMIT_CMD ?= srun --gres=gpu:4 --ntasks-per-node 4 MPICXX=mpicxx -MPIRUN ?= mpirun CUDA_HOME ?= /usr/local/cuda NCCL_HOME ?= /usr GENCODE_SM30 := -gencode arch=compute_30,code=sm_30 @@ -30,13 +30,13 @@ jacobi_kernels.o: Makefile jacobi_kernels.cu .PHONY.: clean clean: - rm -f jacobi jacobi_kernels.o *.qdrep jacobi.*.compute-sanitizer.log + rm -f jacobi jacobi_kernels.o *.nsys-rep jacobi.*.compute-sanitizer.log sanitize: jacobi - $(MPIRUN) -np $(NP) compute-sanitizer --log-file jacobi.%q{OMPI_COMM_WORLD_RANK}.compute-sanitizer.log ./jacobi -niter 10 + $(JSC_SUBMIT_CMD) -n $(NP) compute-sanitizer --log-file jacobi.%q{SLURM_PROCID}.compute-sanitizer.log ./jacobi -niter 10 run: jacobi - $(MPIRUN) -np $(NP) ./jacobi + $(JSC_SUBMIT_CMD) -n $(NP) ./jacobi profile: jacobi - $(MPIRUN) -np $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{OMPI_COMM_WORLD_RANK} ./jacobi -niter 10 + $(JSC_SUBMIT_CMD) -n $(NP) nsys profile --trace=mpi,cuda,nvtx -o jacobi.%q{SLURM_PROCID} ./jacobi -niter 10 diff --git a/8-H_NCCL_NVSHMEM/NCCL/copy.mk b/8-H_NCCL_NVSHMEM/NCCL/copy.mk new file mode 100644 index 0000000..e7ee445 --- /dev/null +++ b/8-H_NCCL_NVSHMEM/NCCL/copy.mk @@ -0,0 +1,40 @@ +#!/usr/bin/make -f +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +TASKDIR = ../../tasks/8-H_NCCL_NVSHMEM/NCCL/ +SOLUTIONDIR = ../../solutions/8-H_NCCL_NVSHMEM/NCCL + +PROCESSFILES = jacobi.cu +COPYFILES = Makefile Instructions.ipynb Instructions.md + + +TASKPROCCESFILES = $(addprefix $(TASKDIR)/,$(PROCESSFILES)) +TASKCOPYFILES = $(addprefix $(TASKDIR)/,$(COPYFILES)) +SOLUTIONPROCCESFILES = $(addprefix $(SOLUTIONDIR)/,$(PROCESSFILES)) +SOLUTIONCOPYFILES = $(addprefix $(SOLUTIONDIR)/,$(COPYFILES)) + +.PHONY: all task +all: task +task: ${TASKPROCCESFILES} ${TASKCOPYFILES} ${SOLUTIONPROCCESFILES} ${SOLUTIONCOPYFILES} + + +${TASKPROCCESFILES}: $(PROCESSFILES) + mkdir -p $(TASKDIR)/ + cppp -USOLUTION $(notdir $@) $@ + +${SOLUTIONPROCCESFILES}: $(PROCESSFILES) + mkdir -p $(SOLUTIONDIR)/ + cppp -DSOLUTION $(notdir $@) $@ + + +${TASKCOPYFILES}: $(COPYFILES) + mkdir -p $(TASKDIR)/ + cp $(notdir $@) $@ + +${SOLUTIONCOPYFILES}: $(COPYFILES) + mkdir -p $(SOLUTIONDIR)/ + cp $(notdir $@) $@ + +%.ipynb: %.md + pandoc $< -o $@ + # add metadata so this is seen as python + jq -s '.[0] * .[1]' $@ ../template.json | sponge $@ diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp index d0d656c..43dce0a 100644 --- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp +++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp @@ -31,7 +31,6 @@ #include #include - #define MPI_CALL(call) \ { \ int mpi_status = call; \ @@ -92,12 +91,14 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t); #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \ } #ifdef SOLUTION +//TODO: include NCCL headers #include #else - //TODO: include NCCL headers +//TODO: include NCCL headers #endif #ifdef SOLUTION +//TODO: Un-comment the following given NCCL_CALL definition to use in your code: #define NCCL_CALL(call) \ { \ ncclResult_t ncclStatus = call; \ @@ -129,6 +130,8 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t); typedef double real; #define MPI_REAL_TYPE MPI_DOUBLE #ifdef SOLUTION +//TODO:define NCCL_REAL_TYPE using its corresponding nccl double type +//HINT:https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html #define NCCL_REAL_TYPE ncclDouble #else //TODO:define NCCL_REAL_TYPE using its corresponding nccl double type @@ -138,6 +141,8 @@ typedef double real; typedef float real; #define 
MPI_REAL_TYPE MPI_FLOAT #ifdef SOLUTION +//TODO:define NCCL_REAL_TYPE using its corresponding nccl float type +//HINT:https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html #define NCCL_REAL_TYPE ncclFloat #else //TODO:define NCCL_REAL_TYPE using its corresponding nccl float type @@ -186,13 +191,16 @@ int main(int argc, char* argv[]) { int size; MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size)); #ifdef SOLUTION + //TODO: Create a ncclUniqueId, have rank 0 initize it by using the appropriate runtime call, + // and remember to broadcast it to all ranks using MPI + //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition ncclUniqueId nccl_uid; if (rank == 0) NCCL_CALL(ncclGetUniqueId(&nccl_uid)); MPI_CALL(MPI_Bcast(&nccl_uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD)); #else //TODO: Create a ncclUniqueId, have rank 0 initize it by using the appropriate runtime call, // and remember to broadcast it to all ranks using MPI - //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition :) + //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition #endif const int iter_max = get_argval(argv, argv + argc, "-niter", 1000); @@ -215,6 +223,7 @@ int main(int argc, char* argv[]) { CUDA_RT_CALL(cudaSetDevice(local_rank)); CUDA_RT_CALL(cudaFree(0)); #ifdef SOLUTION + //TODO: Create a communicator (ncclComm_t), initialize it (ncclCommInitRank) ncclComm_t nccl_comm; NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_uid, rank)); int nccl_version = 0; @@ -354,6 +363,9 @@ int main(int argc, char* argv[]) { // Apply periodic boundary conditions #ifdef SOLUTION + //TODO: Modify the lable for the RANGE, and replace MPI_Sendrecv with ncclSend and ncclRecv calls + // using the nccl communicator and push_stream. 
+ // Remember to use ncclGroupStart() and ncclGroupEnd() PUSH_RANGE("NCCL_LAUNCH", 5) NCCL_CALL(ncclGroupStart()); NCCL_CALL(ncclRecv(a_new, nx, NCCL_REAL_TYPE, top, nccl_comm, push_stream)); @@ -369,7 +381,6 @@ int main(int argc, char* argv[]) { MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0, a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); - CUDA_RT_CALL(cudaStreamSynchronize(push_bottom_stream)); MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx, MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE)); #endif @@ -419,6 +430,7 @@ int main(int argc, char* argv[]) { if (rank == 0 && result_correct) { if (csv) { #ifdef SOLUTION + //TODO: Dont forget to change your output lable from mpi_overlap to nccl_overlap printf("nccl_overlap, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size, #else //TODO: Dont forget to change your output lable from mpi_overlap to nccl_overlap @@ -449,6 +461,7 @@ int main(int argc, char* argv[]) { CUDA_RT_CALL(cudaFreeHost(a_h)); CUDA_RT_CALL(cudaFreeHost(a_ref_h)); #ifdef SOLUTION + //TODO: Destroy the nccl communicator NCCL_CALL(ncclCommDestroy(nccl_comm)); #else //TODO: Destroy the nccl communicator diff --git a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu index 8fe71d7..0a8fdda 100644 --- a/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu +++ b/8-H_NCCL_NVSHMEM/NVSHMEM/jacobi.cu @@ -27,7 +27,7 @@ #include #include - +#define SOLUTION #define MPI_CALL(call) \ { \ int mpi_status = call; \ From 2ad3188bb59a703c9aa1865cf3397b34acfaa87f Mon Sep 17 00:00:00 2001 From: Simon Garcia de Gonzalo Date: Thu, 28 Oct 2021 16:42:25 +0200 Subject: [PATCH 09/25] Update 8-H_NCCL_NVSHMEM/NCCL/Instructions.md Co-authored-by: Jiri Kraus --- 8-H_NCCL_NVSHMEM/NCCL/Instructions.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md index 20c8cec..df6daad 100644 --- a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md +++ b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md @@ -7,7 +7,7 @@ ## Hands-On 8\_NCCL: Using NCCL for inter-GPU communication -### Task 0: Using NCCL device API +### Task 0: Using NCCL #### Description From f3aa824873b6cc7a316ce87584b9b7cf2c8915e3 Mon Sep 17 00:00:00 2001 From: Simon Garcia de Gonzalo Date: Thu, 28 Oct 2021 16:56:35 +0200 Subject: [PATCH 10/25] Update 8-H_NCCL_NVSHMEM/NCCL/copy.mk Co-authored-by: Jiri Kraus --- 8-H_NCCL_NVSHMEM/NCCL/copy.mk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/8-H_NCCL_NVSHMEM/NCCL/copy.mk b/8-H_NCCL_NVSHMEM/NCCL/copy.mk index e7ee445..fd72f5b 100644 --- a/8-H_NCCL_NVSHMEM/NCCL/copy.mk +++ b/8-H_NCCL_NVSHMEM/NCCL/copy.mk @@ -3,7 +3,7 @@ TASKDIR = ../../tasks/8-H_NCCL_NVSHMEM/NCCL/ SOLUTIONDIR = ../../solutions/8-H_NCCL_NVSHMEM/NCCL -PROCESSFILES = jacobi.cu +PROCESSFILES = jacobi.cpp COPYFILES = Makefile Instructions.ipynb Instructions.md From 65f4a4ae241d84a685a2ac2e7c789fc92587cb72 Mon Sep 17 00:00:00 2001 From: Simon Garcia de Gonzalo Date: Thu, 28 Oct 2021 16:59:04 +0200 Subject: [PATCH 11/25] Update 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp Co-authored-by: Jiri Kraus --- 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp index 43dce0a..41b508c 100644 --- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp +++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp @@ -193,7 +193,7 @@ int main(int argc, char* argv[]) { #ifdef SOLUTION 
     //TODO: Create an ncclUniqueId, have rank 0 initialize it by using the appropriate runtime call,
     // and remember to broadcast it to all ranks using MPI
-    //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition
+    //HINT: Best practice: wrap your nccl calls using the above NCCL_CALL macro to catch runtime errors early.
     ncclUniqueId nccl_uid;
     if (rank == 0) NCCL_CALL(ncclGetUniqueId(&nccl_uid));
     MPI_CALL(MPI_Bcast(&nccl_uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD));

From 87fa989b7979fef23d9ce943929d5a3ac95111ed Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Thu, 28 Oct 2021 16:59:15 +0200
Subject: [PATCH 12/25] Update
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index ec65fd0..bddc915 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -158,7 +158,9 @@ int main(int argc, char* argv[]) {
         MPI_CALL(MPI_Comm_free(&local_comm));
     }
 
-    CUDA_RT_CALL(cudaSetDevice(local_rank));
+    int num_devices = 0;
+    CUDA_RT_CALL(cudaGetDeviceCount(&num_devices));
+    CUDA_RT_CALL(cudaSetDevice(local_rank%num_devices));
     CUDA_RT_CALL(cudaFree(0));
 
     real* a_ref_h;

From b29b6c4a03e71468cb65360a391d46fee1ad2cc2 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Thu, 28 Oct 2021 16:59:39 +0200
Subject: [PATCH 13/25] Update 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index 41b508c..1043625 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -220,7 +220,9 @@ int main(int argc, char* argv[]) {
         MPI_CALL(MPI_Comm_free(&local_comm));
     }
 
-    CUDA_RT_CALL(cudaSetDevice(local_rank));
+    int num_devices = 0;
+    CUDA_RT_CALL(cudaGetDeviceCount(&num_devices));
+    CUDA_RT_CALL(cudaSetDevice(local_rank%num_devices));
     CUDA_RT_CALL(cudaFree(0));
 #ifdef SOLUTION
     //TODO: Create a communicator (ncclComm_t), initialize it (ncclCommInitRank)

From 8371ce42567d78e12562ebc3480045bcf651ec6c Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Thu, 28 Oct 2021 17:01:56 +0200
Subject: [PATCH 14/25] Addressing pull request comments on TODOs spacing and
 instructions for both 6-H and 8-H NCCL

---
 .../jacobi.cpp                        | 41 ++++---------
 8-H_NCCL_NVSHMEM/NCCL/Instructions.md |  2 -
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp      | 61 ++++---------------
 3 files changed, 25 insertions(+), 79 deletions(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index ec65fd0..25cd5eb 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -207,10 +207,11 @@ int main(int argc, char* argv[]) {
     // Set Dirichlet boundary conditions on left and right border
     launch_initialize_boundaries(a, a_new, PI, iy_start_global - 1, nx, (chunk_size + 2), ny);
     CUDA_RT_CALL(cudaDeviceSynchronize());
-#ifdef SOLUTION
+
     //TODO:
     //*Set least and greatest Priority Range
     //*Create top and bottom cuda streams variables and corresponding cuda events
+#ifdef SOLUTION
     int leastPriority = 0;
     int greatestPriority = leastPriority;
     CUDA_RT_CALL(cudaDeviceGetStreamPriorityRange(&leastPriority, &greatestPriority));
@@ -223,9 +224,6 @@ int main(int argc, char* argv[]) {
     cudaEvent_t push_bottom_done;
     CUDA_RT_CALL(cudaEventCreateWithFlags(&push_bottom_done, cudaEventDisableTiming));
 #else
-    //TODO:
-    //*Set least and greatest Priority Range
-    //*Create top and bottom cuda streams variables and corresponding cuda events
 #endif
     cudaStream_t compute_stream;
     cudaEvent_t compute_done;
@@ -233,19 +231,16 @@ int main(int argc, char* argv[]) {
     cudaEvent_t reset_l2norm_done;
     CUDA_RT_CALL(cudaEventCreateWithFlags(&reset_l2norm_done, cudaEventDisableTiming));
 
-#ifdef SOLUTION
     //TODO:
     //Create cuda streams with Greatest Priority for top and bottom streams
     //Modify the cudaStreamCreate call for the compute stream to have the Least Priority
+#ifdef SOLUTION
     CUDA_RT_CALL(cudaStreamCreateWithPriority(&compute_stream, cudaStreamDefault, leastPriority));
     CUDA_RT_CALL(
         cudaStreamCreateWithPriority(&push_top_stream, cudaStreamDefault, greatestPriority));
     CUDA_RT_CALL(
         cudaStreamCreateWithPriority(&push_bottom_stream, cudaStreamDefault, greatestPriority));
 #else
-    //TODO:
-    //Create cuda streams with Greatest Priority for top and bottom streams
-    //Modify the cudaStreamCreate call for the compute stream to have the Least Priority
     CUDA_RT_CALL(cudaStreamCreate(&compute_stream));
 #endif
 
@@ -289,7 +284,6 @@ int main(int argc, char* argv[]) {
 
         calculate_norm = (iter % nccheck) == 0 || (!csv && (iter % 100) == 0);
 
-#ifdef SOLUTION
         //TODO:
         //*Launch two additional jacobi kernels for the top and bottom regions using
         // the top and bottom streams after modifying and launching the original jacobi kernel on
         // ONLY the center region.
         //*Remember to wait for the l2_norm_done cuda event before launching each top and bottom jacobi kernel
         // using the cudaStreamWaitEvent() call.
         //*Remember to record when the top and bottom regions are done using the cudaEventRecord() call
+#ifdef SOLUTION
         launch_jacobi_kernel(a_new, a, l2_norm_d, (iy_start + 1), (iy_end - 1), nx, calculate_norm,
                              compute_stream);
         CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream));
@@ -311,27 +306,18 @@ int main(int argc, char* argv[]) {
                              push_bottom_stream);
         CUDA_RT_CALL(cudaEventRecord(push_bottom_done, push_bottom_stream));
 #else
-        //TODO:
-        //*Launch two additional jacobi kernels for the top and bottom regions using
-        // the top and bottom streams after modifying and launching the original jacobi kernel on
-        // ONLY the center region.
-        //*Remember to wait for the l2_norm_done cuda event before launching each top and bottom jacobi kernel
-        // using the cudaStreamWaitEvent() call.
-        //*Remember to record when the top and bottom regions are done using the cudaEventRecord() call
         launch_jacobi_kernel(a_new, a, l2_norm_d, iy_start, iy_end, nx, calculate_norm,
                              compute_stream);
         CUDA_RT_CALL(cudaEventRecord(compute_done, compute_stream));
 #endif
 
         if (calculate_norm) {
-#ifdef SOLUTION
             //TODO:
-            //Wait on both the top and bottom cuda events
+            //Wait on both the top and bottom cuda events
+#ifdef SOLUTION
             CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0));
             CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0));
 #else
-            //TODO:
-            //Wait on both the top and bottom cuda events
 #endif
             CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost,
                                          compute_stream));
@@ -341,22 +327,21 @@ int main(int argc, char* argv[]) {
         const int top = rank > 0 ? rank - 1 : (size - 1);
         const int bottom = (rank + 1) % size;
 
         // Apply periodic boundary conditions
-#ifdef SOLUTION
         //TODO: Modify the synchronization on the compute stream to be on the top stream
+#ifdef SOLUTION
         CUDA_RT_CALL(cudaStreamSynchronize(push_top_stream));
 #else
-        //TODO: Modify the synchronization on the compute stream to be on the top stream
         CUDA_RT_CALL(cudaEventSynchronize(compute_done));
 #endif
         PUSH_RANGE("MPI", 5)
         MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0,
-                              a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD,
-                              MPI_STATUS_IGNORE));
-#ifdef SOLUTION
+                              a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD,
+                              MPI_STATUS_IGNORE));
+        //TODO: Add additional synchronization on the bottom stream
+#ifdef SOLUTION
         CUDA_RT_CALL(cudaStreamSynchronize(push_bottom_stream));
 #else
-        //TODO: Add additional synchronization on the bottom stream
 #endif
         MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx,
                               MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
@@ -413,14 +398,14 @@ int main(int argc, char* argv[]) {
                    runtime_serial / (size * (stop - start)) * 100);
         }
     }
-#ifdef SOLUTION
     //TODO: Destroy the additional top and bottom streams as well as their corresponding events
+#ifdef SOLUTION
     CUDA_RT_CALL(cudaEventDestroy(push_bottom_done));
     CUDA_RT_CALL(cudaEventDestroy(push_top_done));
     CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream));
     CUDA_RT_CALL(cudaStreamDestroy(push_top_stream));
 #else
-    //TODO: Destroy the additional top and bottom streams as well as their corresponding events
 #endif
     CUDA_RT_CALL(cudaEventDestroy(reset_l2norm_done));
     CUDA_RT_CALL(cudaEventDestroy(compute_done));

diff --git a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
index 20c8cec..45bc255 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
+++ b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
@@ -15,8 +15,6 @@ The purpose of this task is to use the NCCL API instead of MPI to implement a mu
 
 - Initialize NCCL:
   - Include NCCL headers.
-  - Un-comment NCCL\_CALL definition provided to handle NCCL errors
-  - Define NCCL\_REAL\_TYPE for both double and float types
   - Create a NCCL unique ID, and initialize it
   - Create a NCCL communicator and initialize it
 - Replace MPI for the periodic boundary conditions with NCCL

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index 43dce0a..2c24e58 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -90,15 +90,13 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
                 "%s (%d).\n", \
                 #call, __LINE__, __FILE__, cudaGetErrorString(cudaStatus), cudaStatus); \
     }
-#ifdef SOLUTION
+
 //TODO: include NCCL headers
+#ifdef SOLUTION
 #include 
 #else
-//TODO: include NCCL headers
 #endif
 
-#ifdef SOLUTION
-//TODO: Un-comment the following given NCCL_CALL definition to use in your code:
 #define NCCL_CALL(call) \
     { \
         ncclResult_t ncclStatus = call; \
@@ -110,44 +108,15 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
                 #call, __LINE__, __FILE__, ncclGetErrorString(ncclStatus), ncclStatus); \
     }
 #else
-//TODO: Un-comment the following given NCCL_CALL definition to use in your code:
-/*
-#define NCCL_CALL(call) \
-    { \
-        ncclResult_t ncclStatus = call; \
-        if (ncclSuccess != ncclStatus) \
-            fprintf(stderr, \
-                    "ERROR: NCCL call \"%s\" in line %d of file %s failed " \
-                    "with " \
-                    "%s (%d).\n", \
-                    #call, __LINE__, __FILE__, ncclGetErrorString(ncclStatus), ncclStatus); \
-    }
-#endif
-*/
-#endif
 
 #ifdef USE_DOUBLE
 typedef double real;
 #define MPI_REAL_TYPE MPI_DOUBLE
-#ifdef SOLUTION
-//TODO: define NCCL_REAL_TYPE using its corresponding nccl double type
-//HINT: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html
 #define NCCL_REAL_TYPE ncclDouble
 #else
-//TODO: define NCCL_REAL_TYPE using its corresponding nccl double type
-//HINT: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html
-#endif
-#else
 typedef float real;
 #define MPI_REAL_TYPE MPI_FLOAT
-#ifdef SOLUTION
-//TODO: define NCCL_REAL_TYPE using its corresponding nccl float type
-//HINT: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html
 #define NCCL_REAL_TYPE ncclFloat
-#else
-//TODO: define NCCL_REAL_TYPE using its corresponding nccl float type
-//HINT: https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html
-#endif
 #endif
 
 constexpr real tol = 1.0e-8;
@@ -190,17 +159,15 @@ int main(int argc, char* argv[]) {
     MPI_CALL(MPI_Comm_rank(MPI_COMM_WORLD, &rank));
     int size;
     MPI_CALL(MPI_Comm_size(MPI_COMM_WORLD, &size));
-#ifdef SOLUTION
+
     //TODO: Create an ncclUniqueId, have rank 0 initialize it by using the appropriate runtime call,
     // and remember to broadcast it to all ranks using MPI
     //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition
+#ifdef SOLUTION
     ncclUniqueId nccl_uid;
     if (rank == 0) NCCL_CALL(ncclGetUniqueId(&nccl_uid));
     MPI_CALL(MPI_Bcast(&nccl_uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD));
 #else
-    //TODO: Create an ncclUniqueId, have rank 0 initialize it by using the appropriate runtime call,
-    // and remember to broadcast it to all ranks using MPI
-    //HINT: Remember to wrap your nccl calls using the above NCCL_CALL definition
 #endif
 
     const int iter_max = get_argval(argv, argv + argc, "-niter", 1000);
@@ -222,8 +189,9 @@ int main(int argc, char* argv[]) {
     CUDA_RT_CALL(cudaSetDevice(local_rank));
     CUDA_RT_CALL(cudaFree(0));
-#ifdef SOLUTION
+
     //TODO: Create a communicator (ncclComm_t), initialize it (ncclCommInitRank)
+#ifdef SOLUTION
     ncclComm_t nccl_comm;
     NCCL_CALL(ncclCommInitRank(&nccl_comm, size, nccl_uid, rank));
     int nccl_version = 0;
@@ -235,7 +203,6 @@ int main(int argc, char* argv[]) {
         return 1;
     }
 #else
-    //TODO: Create a communicator (ncclComm_t), initialize it (ncclCommInitRank)
 #endif
 
     real* a_ref_h;
@@ -362,10 +329,10 @@ int main(int argc, char* argv[]) {
         const int bottom = (rank + 1) % size;
 
         // Apply periodic boundary conditions
-#ifdef SOLUTION
-    //TODO: Modify the label for the RANGE, and replace MPI_Sendrecv with ncclSend and ncclRecv calls
+    //TODO: Modify the label for the RANGE, and replace MPI_Sendrecv with ncclSend and ncclRecv calls
     // using the nccl communicator and push_stream.
     // Remember to use ncclGroupStart() and ncclGroupEnd()
+#ifdef SOLUTION
     PUSH_RANGE("NCCL_LAUNCH", 5)
     NCCL_CALL(ncclGroupStart());
     NCCL_CALL(ncclRecv(a_new, nx, NCCL_REAL_TYPE, top, nccl_comm, push_stream));
@@ -374,9 +341,6 @@ int main(int argc, char* argv[]) {
     NCCL_CALL(ncclSend(a_new + iy_start * nx, nx, NCCL_REAL_TYPE, top, nccl_comm, push_stream));
     NCCL_CALL(ncclGroupEnd());
 #else
-    //TODO: Modify the label for the RANGE, and replace MPI_Sendrecv with ncclSend and ncclRecv calls
-    // using the nccl communicator and push_stream.
-    // Remember to use ncclGroupStart() and ncclGroupEnd()
     PUSH_RANGE("MPI", 5)
     MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0,
                           a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD,
@@ -429,14 +393,13 @@ int main(int argc, char* argv[]) {
 
     if (rank == 0 && result_correct) {
         if (csv) {
+            //TODO: Don't forget to change your output label from mpi_overlap to nccl_overlap
 #ifdef SOLUTION
-            //TODO: Don't forget to change your output label from mpi_overlap to nccl_overlap
             printf("nccl_overlap, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size,
 #else
-            //TODO: Don't forget to change your output label from mpi_overlap to nccl_overlap
             printf("mpi_overlap, %d, %d, %d, %d, %d, 1, %f, %f\n", nx, ny, iter_max, nccheck, size,
 #endif
-            (stop - start), runtime_serial);
+                   (stop - start), runtime_serial);
         } else {
             printf("Num GPUs: %d.\n", size);
             printf(
@@ -460,11 +423,11 @@ int main(int argc, char* argv[]) {
 
     CUDA_RT_CALL(cudaFreeHost(a_h));
     CUDA_RT_CALL(cudaFreeHost(a_ref_h));
-#ifdef SOLUTION
+
     //TODO: Destroy the nccl communicator
+#ifdef SOLUTION
     NCCL_CALL(ncclCommDestroy(nccl_comm));
 #else
-    //TODO: Destroy the nccl communicator
 #endif
 
     MPI_CALL(MPI_Finalize());

From 734da801861c8d94a26fec7463121eb61de45f2b Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 10:41:06 +0200
Subject: [PATCH 15/25] NCCL warmup has been added as a TODO, and added to the
 instructions

---
 8-H_NCCL_NVSHMEM/NCCL/Instructions.md |  5 +++--
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp      | 22 ++++++++++++++++++--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
index d1bfd69..5064d71 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
+++ b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
@@ -11,12 +11,13 @@
 
 #### Description
 
-The purpose of this task is to use the NCCL API instead of MPI to implement a multi-GPU jacobi solver. The starting point of this task is the MPI variant of the jacobi solver. You need to work on `TODOs` in `jacobi.cu`:
+The purpose of this task is to use the NCCL instead of MPI to implement a multi-GPU jacobi solver. The starting point of this task is the MPI variant of the jacobi solver. You need to work on `TODOs` in `jacobi.cu`:
 
 - Initialize NCCL:
   - Include NCCL headers.
   - Create a NCCL unique ID, and initialize it
-  - Create a NCCL communicator and initialize it
+  - Create a NCCL communicator and initialize it
+  - Replace the MPI\_Sendrecv calls with ncclRecv and ncclSend calls for the warmup stage
 - Replace MPI for the periodic boundary conditions with NCCL
 - Fix output message to indicate nccl rather than mpi
 - Destroy NCCL communicator

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index 392c86e..5743427 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -107,7 +107,6 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
                 "%s (%d).\n", \
                 #call, __LINE__, __FILE__, ncclGetErrorString(ncclStatus), ncclStatus); \
     }
-#else
 
 #ifdef USE_DOUBLE
 typedef double real;
@@ -275,16 +274,35 @@ int main(int argc, char* argv[]) {
     real* l2_norm_h;
     CUDA_RT_CALL(cudaMallocHost(&l2_norm_h, sizeof(real)));
 
+//TODO: Rename range
+#ifdef SOLUTION
+    PUSH_RANGE("NCCL_Warmup", 5)
+#else
     PUSH_RANGE("MPI_Warmup", 5)
+#endif
     for (int i = 0; i < 10; ++i) {
         const int top = rank > 0 ? rank - 1 : (size - 1);
         const int bottom = (rank + 1) % size;
+        //TODO: Replace the MPI_Sendrecv calls with ncclRecv and ncclSend calls using the nccl communicator
+        // on the compute_stream.
+        // Remember that a group of ncclRecv and ncclSend calls should be enclosed by ncclGroupStart() and ncclGroupEnd()
+        // Also, remember to stream synchronize on the compute_stream at the end
+#ifdef SOLUTION
+        NCCL_CALL(ncclGroupStart());
+        NCCL_CALL(ncclRecv(a_new, nx, NCCL_REAL_TYPE, top, nccl_comm, compute_stream));
+        NCCL_CALL(ncclSend(a_new + (iy_end - 1) * nx, nx, NCCL_REAL_TYPE, bottom, nccl_comm, compute_stream));
+        NCCL_CALL(ncclRecv(a_new + (iy_end * nx), nx, NCCL_REAL_TYPE, bottom, nccl_comm, compute_stream));
+        NCCL_CALL(ncclSend(a_new + iy_start * nx, nx, NCCL_REAL_TYPE, top, nccl_comm, compute_stream));
+        NCCL_CALL(ncclGroupEnd());
+        CUDA_RT_CALL(cudaStreamSynchronize(compute_stream));
+#else
         MPI_CALL(MPI_Sendrecv(a_new + iy_start * nx, nx, MPI_REAL_TYPE, top, 0,
                               a_new + (iy_end * nx), nx, MPI_REAL_TYPE, bottom, 0, MPI_COMM_WORLD,
                               MPI_STATUS_IGNORE));
         MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx,
                               MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE));
-        std::swap(a_new, a);
+#endif
+        std::swap(a_new, a);
     }
     POP_RANGE

From 78a3f92f00cba1497235288abbbb244937d20e53 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:07:06 +0200
Subject: [PATCH 16/25] Update
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index 705418e..a386657 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -343,7 +343,6 @@ int main(int argc, char* argv[]) {
         //TODO: Add additional synchronization on the bottom stream
 #ifdef SOLUTION
         CUDA_RT_CALL(cudaStreamSynchronize(push_bottom_stream));
-#else
 #endif
         MPI_CALL(MPI_Sendrecv(a_new + (iy_end - 1) * nx, nx, MPI_REAL_TYPE, bottom, 0, a_new, nx,
                               MPI_REAL_TYPE, top, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE));

From 2764eed2c1fd45544fb4bbde49b05c74a1479f47 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:10:44 +0200
Subject: [PATCH 17/25] Update
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index a386657..e751459 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -225,7 +225,6 @@ int main(int argc, char* argv[]) {
     CUDA_RT_CALL(cudaEventCreateWithFlags(&push_top_done, cudaEventDisableTiming));
     cudaEvent_t push_bottom_done;
     CUDA_RT_CALL(cudaEventCreateWithFlags(&push_bottom_done, cudaEventDisableTiming));
-#else
 #endif
     cudaStream_t compute_stream;
     cudaEvent_t compute_done;

From c96a255c9a8da3229854243cac16702177e5356c Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:10:53 +0200
Subject: [PATCH 18/25] Update
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index e751459..ae69499 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -318,7 +318,6 @@ int main(int argc, char* argv[]) {
 #ifdef SOLUTION
             CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0));
             CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0));
-#else
 #endif
             CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost,
                                          compute_stream));

From ce10df86193260d5e1d9b262205de3128fcf3071 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:11:01 +0200
Subject: [PATCH 19/25] Update
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index ae69499..d24c8c5 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -404,7 +404,6 @@ int main(int argc, char* argv[]) {
     CUDA_RT_CALL(cudaEventDestroy(push_top_done));
     CUDA_RT_CALL(cudaStreamDestroy(push_bottom_stream));
     CUDA_RT_CALL(cudaStreamDestroy(push_top_stream));
-#else
 #endif
     CUDA_RT_CALL(cudaEventDestroy(reset_l2norm_done));
     CUDA_RT_CALL(cudaEventDestroy(compute_done));

From 2a4de59b37e0158b624e27652207757b6bb05a18 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:11:10 +0200
Subject: [PATCH 20/25] Update 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index 5743427..e253843 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -94,7 +94,6 @@ const int num_colors = sizeof(colors) / sizeof(uint32_t);
 //TODO: include NCCL headers
 #ifdef SOLUTION
 #include 
-#else
 #endif
 
 #define NCCL_CALL(call) \

From d691b6c2bce0ea9df67c931dbb7e08db8643fa53 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:11:27 +0200
Subject: [PATCH 21/25] Update 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index e253843..d5dacca 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -165,7 +165,6 @@ int main(int argc, char* argv[]) {
     ncclUniqueId nccl_uid;
     if (rank == 0) NCCL_CALL(ncclGetUniqueId(&nccl_uid));
     MPI_CALL(MPI_Bcast(&nccl_uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD));
-#else
 #endif
 
     const int iter_max = get_argval(argv, argv + argc, "-niter", 1000);

From 21238b8abfa95daafdb5e6a0cfce04a36f568f37 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:11:35 +0200
Subject: [PATCH 22/25] Update 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index d5dacca..1d4828c 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -201,7 +201,6 @@ int main(int argc, char* argv[]) {
         MPI_CALL(MPI_Finalize());
         return 1;
     }
-#else
 #endif
 
     real* a_ref_h;

From af0dc3d71457a7a4d74821244f36f7d0dd058fe3 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:11:41 +0200
Subject: [PATCH 23/25] Update 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp

Co-authored-by: Jiri Kraus
---
 8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
index 1d4828c..cae5722 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
+++ b/8-H_NCCL_NVSHMEM/NCCL/jacobi.cpp
@@ -444,7 +444,6 @@ int main(int argc, char* argv[]) {
     //TODO: Destroy the nccl communicator
 #ifdef SOLUTION
     NCCL_CALL(ncclCommDestroy(nccl_comm));
-#else
 #endif
 
     MPI_CALL(MPI_Finalize());

From 5361260d0868817de5c483fdf8e5421c77724dd4 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:21:24 +0200
Subject: [PATCH 24/25] Fixed indentation issue on 6-H and added jacobi.cpp to
 the copy.mk COPYFILES

---
 .../jacobi.cpp                | 12 ++++++------
 8-H_NCCL_NVSHMEM/NCCL/copy.mk |  3 ++-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
index d24c8c5..6ca3010 100644
--- a/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
+++ b/6-H_Overlap_Communication_and_Computation_MPI/jacobi.cpp
@@ -313,14 +313,14 @@ int main(int argc, char* argv[]) {
 #endif
 
         if (calculate_norm) {
-          //TODO:
-          //Wait on both the top and bottom cuda events
+            //TODO:
+            //Wait on both the top and bottom cuda events
 #ifdef SOLUTION
-          CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0));
-          CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0));
+            CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_top_done, 0));
+            CUDA_RT_CALL(cudaStreamWaitEvent(compute_stream, push_bottom_done, 0));
 #endif
-          CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost,
-                                       compute_stream));
+            CUDA_RT_CALL(cudaMemcpyAsync(l2_norm_h, l2_norm_d, sizeof(real), cudaMemcpyDeviceToHost,
+                                         compute_stream));
         }
 
         const int top = rank > 0 ? rank - 1 : (size - 1);

diff --git a/8-H_NCCL_NVSHMEM/NCCL/copy.mk b/8-H_NCCL_NVSHMEM/NCCL/copy.mk
index fd72f5b..073428b 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/copy.mk
+++ b/8-H_NCCL_NVSHMEM/NCCL/copy.mk
@@ -3,8 +3,9 @@
 TASKDIR = ../../tasks/8-H_NCCL_NVSHMEM/NCCL/
 SOLUTIONDIR = ../../solutions/8-H_NCCL_NVSHMEM/NCCL
 
+
 PROCESSFILES = jacobi.cpp
-COPYFILES = Makefile Instructions.ipynb Instructions.md
+COPYFILES = Makefile jacobi.cpp Instructions.ipynb Instructions.md
 
 
 TASKPROCCESFILES = $(addprefix $(TASKDIR)/,$(PROCESSFILES))

From 62bef40f17d9591153e593afec34b790000a7425 Mon Sep 17 00:00:00 2001
From: Simon Garcia de Gonzalo
Date: Fri, 29 Oct 2021 11:23:11 +0200
Subject: [PATCH 25/25] Update 8-H_NCCL_NVSHMEM/NCCL/Instructions.md

Co-authored-by: Jiri Kraus
---
 8-H_NCCL_NVSHMEM/NCCL/Instructions.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
index 5064d71..5caa3b9 100644
--- a/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
+++ b/8-H_NCCL_NVSHMEM/NCCL/Instructions.md
@@ -11,7 +11,7 @@
 
 #### Description
 
-The purpose of this task is to use the NCCL instead of MPI to implement a multi-GPU jacobi solver. The starting point of this task is the MPI variant of the jacobi solver. You need to work on `TODOs` in `jacobi.cu`:
+The purpose of this task is to use NCCL instead of MPI to implement a multi-GPU jacobi solver. The starting point of this task is the MPI variant of the jacobi solver. You need to work on `TODOs` in `jacobi.cpp`:
 
 - Initialize NCCL:
   - Include NCCL headers.
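
Taken together, the series converges on one pattern: bootstrap NCCL over MPI, then drive all halo traffic through grouped ncclSend/ncclRecv calls on a dedicated CUDA stream. The stand-alone sketch below condenses that end state for reference. It is not the tutorial's `jacobi.cpp`: it assumes NCCL >= 2.7 (for point-to-point `ncclSend`/`ncclRecv`), elides the `MPI_CALL`/`NCCL_CALL`/`CUDA_RT_CALL` error-checking macros for brevity, and its buffer layout (`a`, `nx`, `rows`) and stream name (`push_stream`) are illustrative placeholders only.

``` {.cpp}
// Minimal sketch: MPI-bootstrapped NCCL with a grouped periodic halo exchange.
// Assumes NCCL >= 2.7 and the CUDA runtime; error checking omitted for brevity.
#include <mpi.h>
#include <nccl.h>
#include <cuda_runtime.h>
#include <cstdio>

int main(int argc, char* argv[]) {
    MPI_Init(&argc, &argv);
    int rank = 0, size = 1;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &size);

    // Guard against launching more ranks than there are GPUs on a node.
    int num_devices = 0;
    cudaGetDeviceCount(&num_devices);
    cudaSetDevice(rank % num_devices);

    // Bootstrap NCCL over MPI: rank 0 creates the unique id, MPI broadcasts it.
    ncclUniqueId uid;
    if (rank == 0) ncclGetUniqueId(&uid);
    MPI_Bcast(&uid, sizeof(ncclUniqueId), MPI_BYTE, 0, MPI_COMM_WORLD);
    ncclComm_t comm;
    ncclCommInitRank(&comm, size, uid, rank);

    // High-priority stream for communication, as in the overlap hands-on.
    int least = 0, greatest = 0;
    cudaDeviceGetStreamPriorityRange(&least, &greatest);
    cudaStream_t push_stream;
    cudaStreamCreateWithPriority(&push_stream, cudaStreamDefault, greatest);

    // Placeholder domain: 'rows' rows of 'nx' floats; rows 0 and rows-1 are halos.
    const int nx = 1024, rows = 8;
    float* a = nullptr;
    cudaMalloc(&a, nx * rows * sizeof(float));
    cudaMemset(a, 0, nx * rows * sizeof(float));
    const int top = rank > 0 ? rank - 1 : size - 1;
    const int bottom = (rank + 1) % size;

    // Periodic halo exchange: all four point-to-point calls in one group so
    // NCCL can schedule them together without send/recv ordering deadlocks.
    ncclGroupStart();
    ncclRecv(a, nx, ncclFloat, top, comm, push_stream);                      // top halo
    ncclSend(a + (rows - 2) * nx, nx, ncclFloat, bottom, comm, push_stream); // last interior row
    ncclRecv(a + (rows - 1) * nx, nx, ncclFloat, bottom, comm, push_stream); // bottom halo
    ncclSend(a + 1 * nx, nx, ncclFloat, top, comm, push_stream);             // first interior row
    ncclGroupEnd();
    cudaStreamSynchronize(push_stream);  // NCCL calls are stream-ordered, not blocking

    if (rank == 0) printf("halo exchange completed on %d ranks\n", size);

    cudaFree(a);
    cudaStreamDestroy(push_stream);
    ncclCommDestroy(comm);
    MPI_Finalize();
    return 0;
}
```

The `ncclGroupStart()`/`ncclGroupEnd()` pair is the load-bearing detail: it lets NCCL treat the four point-to-point calls as one operation, so no rank can deadlock waiting for a matching send or receive. The final `cudaStreamSynchronize` is needed because NCCL calls only enqueue work on the stream; unlike a blocking `MPI_Sendrecv`, they return before the transfer is complete.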