diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index f950d9c71f3ea..abf45deffeb96 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -17,6 +17,7 @@
 #include <c10/util/irange.h>
 
 #ifdef USE_ROCM
+#include <c10/cuda/CUDAStream.h>
 #include <hipblaslt/hipblaslt-ext.hpp>
 // until hipblas has an API to accept flags, we must use rocblas here
 #include <rocblas/rocblas.h>
@@ -185,6 +186,64 @@ uint32_t _getAlignment(uintptr_t address) {
 }
 #endif
 
+#ifdef USE_ROCM
+static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
+  // 0 is default value, meaning full CUs i.e. no mask
+  if (value == 0) {
+    return at::cuda::getCurrentCUDAStream();
+  }
+  static int32_t last_value = 0;
+  static hipStream_t stream;
+  if (last_value == 0) {
+    // first request, do nothing for this case
+  }
+  else if (last_value == value) {
+    // stream was created previously and value hasn't changed
+    return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
+  }
+  else {
+    // need a new stream and a previous stream exists, delete it
+    AT_CUDA_CHECK(hipStreamDestroy(stream));
+  }
+
+  // if we got here, we need to create a new stream
+  int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
+  // how many uint32_t do we need to cover all CUs, fill bitmask with 1
+  uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
+  std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
+  // starting from lowest order bits, in 32-bit chunks
+  // set bits to 0 based on how many CUs to carve out
+  int32_t full_shifts = value / 32;
+  int32_t remainder = value % 32;
+  for (int32_t i = 0; i < full_shifts; i++) {
+    mask[i] = uint32_t{0x00000000};
+  }
+  mask[full_shifts] = uint32_t{0xffffffff} << remainder;
+
+  // finally, create masked stream
+  AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));
+
+  last_value = value;
+  return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
+}
+
+static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) {
+  hipEvent_t event;
+  AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming));
+
+  auto current_stream = at::cuda::getCurrentCUDAStream();
+
+  if (presync) {
+    AT_CUDA_CHECK(hipEventRecord(event, current_stream));
+    AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0));
+  }
+  else {
+    AT_CUDA_CHECK(hipEventRecord(event, stream));
+    AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0));
+  }
+}
+#endif
+
 struct CublasLtWorkspace {
   CublasLtWorkspace() {
     size = at::cuda::getCUDABlasLtWorkspaceSize();
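The CU-mask construction above is easiest to follow with concrete numbers. Below is a standalone sketch of the same arithmetic; the 304-CU device count is an assumption (an MI300X-class GPU) and 66 matches the carveout value exercised by the test further down, neither taken from the patch itself.

```cpp
// Standalone sketch of the CU-mask arithmetic in _getCarveoutStream above.
// Assumed values, not from the patch: 304 CUs, carveout of 66.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const int32_t CUs = 304;   // device CU count (assumption)
  const int32_t value = 66;  // CUs to carve out, i.e. disable

  // one 32-bit word per 32 CUs; every bit set means every CU usable
  const uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
  std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});

  // clear the lowest `value` bits: whole words first, then the partial word
  const int32_t full_shifts = value / 32;
  const int32_t remainder = value % 32;
  for (int32_t i = 0; i < full_shifts; i++) {
    mask[i] = uint32_t{0x00000000};
  }
  mask[full_shifts] = uint32_t{0xffffffff} << remainder;

  // prints 0x00000000, 0x00000000, 0xfffffffc, then seven 0xffffffff words:
  // CUs 0-65 are disabled, leaving 238 of 304 CUs for the masked stream
  for (uint32_t i = 0; i < mask_size; i++) {
    std::printf("mask[%u] = 0x%08x\n", i, mask[i]);
  }
  return 0;
}
```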
@@ -360,6 +419,7 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
   CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute(
@@ -367,6 +427,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
   CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
   CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
@@ -430,7 +496,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
       &heuristicResult.algo,
       ltworkspace.ptr,
       ltworkspace.size,
-      at::cuda::getCurrentCUDAStream());
+      stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1295,6 +1366,7 @@ void gemm_and_bias(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
   cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute(
@@ -1302,6 +1374,12 @@ void gemm_and_bias(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
   cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
   if (activation == GEMMAndBiasActivationEpilogue::RELU) {
@@ -1370,7 +1448,12 @@ void gemm_and_bias(
       &heuristicResult.algo,
       ltworkspace.ptr,
       ltworkspace.size,
-      at::cuda::getCurrentCUDAStream());
+      stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1525,6 +1608,7 @@ void scaled_gemm(
   if (result_scale_ptr != nullptr) {
     computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
   }
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute(
@@ -1532,6 +1616,12 @@ void scaled_gemm(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
 #ifndef USE_ROCM
   const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1570,7 +1660,6 @@ void scaled_gemm(
 #endif // if CUDA_VERSION >= 12090
   }
 
-  auto stream = c10::cuda::getCurrentCUDAStream();
   CuBlasLtMatmulPreference preference;
   auto ltworkspace = CublasLtWorkspace();
   preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
@@ -1657,6 +1746,11 @@ void scaled_gemm(
       ltworkspace.ptr,
       ltworkspace.size,
       stream);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1710,6 +1804,7 @@ void int8_gemm(
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
   cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
   computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
+  auto stream = at::cuda::getCurrentCUDAStream();
 #ifndef USE_ROCM
   if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
     computeDesc.setAttribute(
@@ -1717,6 +1812,12 @@ void int8_gemm(
         at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
         at::globalContext()._SMCarveout_EXPERIMENTAL().value());
   }
+#else
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    stream = _getCarveoutStream(
+        at::globalContext()._SMCarveout_EXPERIMENTAL().value());
+    _syncCurrentWithCarveoutStream(stream, true);
+  }
 #endif
 
   CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
@@ -1778,7 +1879,7 @@ void int8_gemm(
 #else
       0,
 #endif
-      at::cuda::getCurrentCUDAStream());
+      stream);
   TORCH_CHECK(
       cublasStatus == CUBLAS_STATUS_SUCCESS,
       "CUDA error: ",
@@ -1807,6 +1908,11 @@ void int8_gemm(
       computeType,
       " scaleType ",
       scaleType);
+#ifdef USE_ROCM
+  if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
+    _syncCurrentWithCarveoutStream(stream, false);
+  }
+#endif
 }
 
 template <>
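Because hipBLASLt has no equivalent of cuBLASLt's SM-count-target attribute, each call site reroutes the GEMM onto the CU-masked stream and orders it against the current stream with an event fence: once before the matmul (presync) and once after (postsync). The sketch below isolates that fence pattern in plain HIP; it is an illustration under assumptions, with an ordinary stream standing in for the CU-masked one and `stage`/`buf` as hypothetical stand-ins for the surrounding work and the GEMM.

```cpp
// Minimal sketch of the presync/postsync fence used around cublasLtMatmul.
// Error checking omitted for brevity; `stage` and `buf` are hypothetical.
#include <hip/hip_runtime.h>

__global__ void stage(float* p) { p[threadIdx.x] += 1.0f; }

int main() {
  hipStream_t current, masked;
  hipStreamCreate(&current);
  hipStreamCreate(&masked);  // stand-in for the CU-masked stream

  float* buf;
  hipMalloc(&buf, 64 * sizeof(float));

  hipEvent_t event;
  hipEventCreateWithFlags(&event, hipEventDisableTiming);

  hipLaunchKernelGGL(stage, dim3(1), dim3(64), 0, current, buf);  // producer

  // presync: the masked stream must not start until prior work is done
  hipEventRecord(event, current);
  hipStreamWaitEvent(masked, event, 0);

  hipLaunchKernelGGL(stage, dim3(1), dim3(64), 0, masked, buf);   // the "GEMM"

  // postsync: the current stream must not proceed until the "GEMM" is done
  hipEventRecord(event, masked);
  hipStreamWaitEvent(current, event, 0);

  hipLaunchKernelGGL(stage, dim3(1), dim3(64), 0, current, buf);  // consumer

  hipStreamSynchronize(current);
  hipEventDestroy(event);
  hipStreamDestroy(masked);
  hipStreamDestroy(current);
  hipFree(buf);
  return 0;
}
```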
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 9aab5e350a93a..18e9edca4835f 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -849,7 +849,6 @@ def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
         self.assertEqual(out_dtype, out_fp8.dtype)
         self.assertEqual(out_fp32, out_fp8.to(torch.float))
 
-    @unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
     @unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
     @unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
@@ -878,15 +877,38 @@ def test_honor_sm_carveout(self) -> None:
                     torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)
                 prof.export_chrome_trace(f.name)
 
-            no_carveout, carveout_0, carveout_66, no_carveout_again = [
-                math.prod(evt.get("args", {}).get("grid", []))
-                for evt in json.load(open(f.name))["traceEvents"]
-                if evt.get("cat", "") == "kernel"
-            ]
-
-            self.assertEqual(no_carveout, no_carveout_again)
-            self.assertNotEqual(no_carveout, carveout_66)
-            self.assertNotEqual(carveout_66, carveout_0)
+            if torch.version.hip:
+                events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"]
+                # events may be returned out of order; sort by "ts" timestamp
+                events = sorted(events, key=lambda x: x['ts'])
+                # ROCm carveout is invisible except for kernels running slower on fewer CUs
+                no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
+                if not (no_carveout < carveout_66 and carveout_0 < carveout_66 and no_carveout_again < carveout_66):
+                    # something went wrong, print more info to help debug flaky test
+                    print("ROCm debug info for test_honor_sm_carveout")
+                    print("no_carveout", no_carveout)
+                    print("carveout_0", carveout_0)
+                    print("carveout_66", carveout_66)
+                    print("no_carveout_again", no_carveout_again)
+                self.assertTrue(no_carveout < carveout_66)
+                self.assertTrue(carveout_0 < carveout_66)
+                self.assertTrue(no_carveout_again < carveout_66)
+                # ROCm carveout creates a new stream when enabled, and returns to the original stream when disabled
+                no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
+                self.assertTrue(no_carveout == no_carveout_again)
+                self.assertTrue(no_carveout == carveout_0)
+                self.assertTrue(no_carveout != carveout_66)
+                self.assertTrue(carveout_0 != carveout_66)
+            else:
+                no_carveout, carveout_0, carveout_66, no_carveout_again = [
+                    math.prod(evt.get("args", {}).get("grid", []))
+                    for evt in json.load(open(f.name))["traceEvents"]
+                    if evt.get("cat", "") == "kernel"
+                ]
+
+                self.assertEqual(no_carveout, no_carveout_again)
+                self.assertNotEqual(no_carveout, carveout_66)
+                self.assertNotEqual(carveout_66, carveout_0)
 
     @unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
     @parametrize("test_case_name", [
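The ROCm branch of the test leans on kernels running measurably slower with fewer CUs. A standalone HIP sketch along the following lines can reproduce that effect outside PyTorch: it times the same kernel on the default stream and on a stream whose CU mask disables half the device's CUs, using the same bit layout as `_getCarveoutStream`. The kernel, problem size, and half-CU carveout are illustrative assumptions, not part of the patch; error checking is omitted for brevity.

```cpp
// Standalone sketch: the same kernel should run slower on a CU-masked stream.
#include <hip/hip_runtime.h>
#include <cstdio>
#include <vector>

__global__ void busy(float* p, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    for (int k = 0; k < 1000; k++) {
      p[i] = p[i] * 0.999f + 1.0f;
    }
  }
}

// time one launch of `busy` on stream `s`, in milliseconds
static float time_on(hipStream_t s, float* p, int n) {
  hipEvent_t t0, t1;
  hipEventCreate(&t0);
  hipEventCreate(&t1);
  hipEventRecord(t0, s);
  hipLaunchKernelGGL(busy, dim3((n + 255) / 256), dim3(256), 0, s, p, n);
  hipEventRecord(t1, s);
  hipEventSynchronize(t1);
  float ms = 0.0f;
  hipEventElapsedTime(&ms, t0, t1);
  hipEventDestroy(t0);
  hipEventDestroy(t1);
  return ms;
}

int main() {
  hipDeviceProp_t prop;
  hipGetDeviceProperties(&prop, 0);
  const int CUs = prop.multiProcessorCount;

  // carve out half the CUs, same bit layout as _getCarveoutStream
  const int value = CUs / 2;
  const uint32_t mask_size = (CUs + 31) / 32;
  std::vector<uint32_t> mask(mask_size, 0xffffffffu);
  for (int i = 0; i < value / 32; i++) mask[i] = 0u;
  mask[value / 32] = 0xffffffffu << (value % 32);

  hipStream_t masked;
  hipExtStreamCreateWithCUMask(&masked, mask_size, mask.data());

  const int n = 1 << 22;
  float* p;
  hipMalloc(&p, n * sizeof(float));

  std::printf("default stream: %.3f ms\n", time_on(nullptr, p, n));
  std::printf("masked stream:  %.3f ms\n", time_on(masked, p, n));

  hipStreamDestroy(masked);
  hipFree(p);
  return 0;
}
```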