114 changes: 110 additions & 4 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -17,6 +17,7 @@
#include <c10/core/ScalarType.h>

#ifdef USE_ROCM
#include <c10/cuda/CUDAStream.h>
#include <hipblaslt/hipblaslt-ext.hpp>
// until hipblas has an API to accept flags, we must use rocblas here
#include <hipblas/hipblas.h>
@@ -185,6 +186,64 @@ uint32_t _getAlignment(uintptr_t address) {
}
#endif

#ifdef USE_ROCM
static c10::cuda::CUDAStream _getCarveoutStream(int32_t value) {
// 0 is default value, meaning full CUs i.e. no mask
if (value == 0) {
return at::cuda::getCurrentCUDAStream();
}
static int32_t last_value = 0;
static hipStream_t stream;
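// a single CU-masked stream is cached per process; when the carveout value
// changes, the previous stream is destroyed and a new one is created below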
if (last_value == 0) {
// first request, do nothing for this case
}
else if (last_value == value) {
// stream was created previously and value hasn't changed
return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
}
else {
// need a new stream and a previous stream exists, delete it
AT_CUDA_CHECK(hipStreamDestroy(stream));
}

// if we got here, we need to create a new stream
int32_t CUs = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
// how many uint32_t do we need to cover all CUs, fill bitmask with 1
uint32_t mask_size = static_cast<uint32_t>((CUs + 32 - 1) / 32);
std::vector<uint32_t> mask(mask_size, uint32_t{0xffffffff});
// starting from lowest order bits, in 32-bit chunks
// set bits to 0 based on how many CUs to carve out
int32_t full_shifts = value / 32;
int32_t remainder = value % 32;
for (int32_t i = 0; i < full_shifts; i++) {
mask[i] = uint32_t{0x00000000};
}
if (remainder > 0) {
// clear the lowest `remainder` bits of the partial word; guarding on
// remainder avoids an undefined shift by 32 and an out-of-bounds write
mask[full_shifts] = uint32_t{0xffffffff} << remainder;
}
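// worked example, assuming value = 66: full_shifts = 2, remainder = 2, so
// mask[0] = mask[1] = 0x00000000 and mask[2] = 0xfffffffc, i.e. CUs 0-65
// are carved out and the stream may only use the remaining CUs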

// finally, create masked stream
AT_CUDA_CHECK(hipExtStreamCreateWithCUMask(&stream, mask_size, &mask[0]));

last_value = value;
return c10::cuda::getStreamFromExternal(stream, c10::cuda::current_device());
}

static void _syncCurrentWithCarveoutStream(hipStream_t stream, bool presync) {
hipEvent_t event;
AT_CUDA_CHECK(hipEventCreateWithFlags(&event, hipEventDisableTiming));

auto current_stream = at::cuda::getCurrentCUDAStream();
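// presync: the carveout stream waits for work already queued on the
// current stream; otherwise the current stream waits for the matmul
// queued on the carveout stream. Fork/join ordering without a host sync.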

if (presync) {
AT_CUDA_CHECK(hipEventRecord(event, current_stream));
AT_CUDA_CHECK(hipStreamWaitEvent(stream, event, 0));
}
else {
AT_CUDA_CHECK(hipEventRecord(event, stream));
AT_CUDA_CHECK(hipStreamWaitEvent(current_stream, event, 0));
}
// safe to destroy immediately: the runtime releases the event's resources
// only after the recorded event and its enqueued waits have completed
AT_CUDA_CHECK(hipEventDestroy(event));
}
#endif

struct CublasLtWorkspace {
CublasLtWorkspace() {
size = at::cuda::getCUDABlasLtWorkspaceSize();
@@ -360,13 +419,20 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
CuBlasLtMatmulDescriptor computeDesc(computeType, scaleType);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, opa);
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, opb);
auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#else
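// ROCm path: lacking an SM_COUNT_TARGET-style descriptor attribute, carveout
// is emulated by running the matmul on a CU-masked stream (see
// _getCarveoutStream above)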
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
stream = _getCarveoutStream(
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
_syncCurrentWithCarveoutStream(stream, true);
}
#endif
CuBlasLtMatrixLayout Adesc(abcType, m, k, lda, opa == CUBLAS_OP_T);
CuBlasLtMatrixLayout Bdesc(abcType, k, n, ldb, opb == CUBLAS_OP_T);
@@ -430,7 +496,12 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
&heuristicResult.algo,
ltworkspace.ptr,
ltworkspace.size,
at::cuda::getCurrentCUDAStream());
stream);
#ifdef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
_syncCurrentWithCarveoutStream(stream, false);
}
#endif
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@@ -1295,13 +1366,20 @@ void gemm_and_bias(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#else
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
stream = _getCarveoutStream(
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
_syncCurrentWithCarveoutStream(stream, true);
}
#endif
cublasLtEpilogue_t epilogue = CUBLASLT_EPILOGUE_BIAS;
if (activation == GEMMAndBiasActivationEpilogue::RELU) {
@@ -1370,7 +1448,12 @@ void gemm_and_bias(
&heuristicResult.algo,
ltworkspace.ptr,
ltworkspace.size,
at::cuda::getCurrentCUDAStream());
stream);
#ifdef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
_syncCurrentWithCarveoutStream(stream, false);
}
#endif
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@@ -1525,13 +1608,20 @@ void scaled_gemm(
if (result_scale_ptr != nullptr) {
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, result_scale_ptr);
}
auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#else
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
stream = _getCarveoutStream(
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
_syncCurrentWithCarveoutStream(stream, true);
}
#endif
#ifndef USE_ROCM
const int8_t fastAccuMode = use_fast_accum ? 1 : 0;
@@ -1570,7 +1660,6 @@ void scaled_gemm(
#endif // if CUDA_VERSION >= 12090
}

auto stream = c10::cuda::getCurrentCUDAStream();
CuBlasLtMatmulPreference preference;
auto ltworkspace = CublasLtWorkspace();
preference.setAttribute(CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, ltworkspace.size);
@@ -1657,6 +1746,11 @@ void scaled_gemm(
ltworkspace.ptr,
ltworkspace.size,
stream);
#ifdef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
_syncCurrentWithCarveoutStream(stream, false);
}
#endif
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@@ -1710,13 +1804,20 @@ void int8_gemm(
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSA, transa);
cublasOperation_t transb = transpose_mat2 ? CUBLAS_OP_T : CUBLAS_OP_N;
computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_TRANSB, transb);
auto stream = at::cuda::getCurrentCUDAStream();
#ifndef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
computeDesc.setAttribute<int32_t>(
CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET,
at::cuda::getCurrentDeviceProperties()->multiProcessorCount -
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
}
#else
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
stream = _getCarveoutStream(
at::globalContext()._SMCarveout_EXPERIMENTAL().value());
_syncCurrentWithCarveoutStream(stream, true);
}
#endif

CuBlasLtMatrixLayout Adesc(abType, m, k, mat1_ld, transpose_mat1);
@@ -1778,7 +1879,7 @@ void int8_gemm(
#else
0,
#endif
at::cuda::getCurrentCUDAStream());
stream);
TORCH_CHECK(
cublasStatus == CUBLAS_STATUS_SUCCESS,
"CUDA error: ",
@@ -1807,6 +1908,11 @@ void int8_gemm(
computeType,
" scaleType ",
scaleType);
#ifdef USE_ROCM
if (at::globalContext()._SMCarveout_EXPERIMENTAL().has_value()) {
_syncCurrentWithCarveoutStream(stream, false);
}
#endif
}

template <>
42 changes: 32 additions & 10 deletions test/test_matmul_cuda.py
@@ -849,7 +849,6 @@ def test_zero_dim_tensorwise(self, which_dim_zero, use_torch_compile) -> None:
self.assertEqual(out_dtype, out_fp8.dtype)
self.assertEqual(out_fp32, out_fp8.to(torch.float))

@unittest.skipIf(TEST_WITH_ROCM, "ROCm doesn't support sm carveout")
@unittest.skipIf(IS_WINDOWS, "Windows doesn't support row-wise scaling")
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
@unittest.skipIf(not SM90OrLater, "sm89 kernel isn't opted into carveout yet")
@@ -878,15 +877,38 @@ def test_honor_sm_carveout(self) -> None:
torch._scaled_mm(x_fp8, y_fp8, scale_a=x_scales, scale_b=y_scales, out_dtype=torch.bfloat16)

prof.export_chrome_trace(f.name)
no_carveout, carveout_0, carveout_66, no_carveout_again = [
    math.prod(evt.get("args", {}).get("grid", []))
    for evt in json.load(open(f.name))["traceEvents"]
    if evt.get("cat", "") == "kernel"
]

self.assertEqual(no_carveout, no_carveout_again)
self.assertNotEqual(no_carveout, carveout_66)
self.assertNotEqual(carveout_66, carveout_0)
if torch.version.hip:
    events = [evt for evt in json.load(open(f.name))["traceEvents"] if evt.get("cat", "") == "kernel"]
    # events are returned out of order; sort them by the "ts" timestamp
    events = sorted(events, key=lambda x: x['ts'])
    # ROCm carveout is invisible except for kernels running slower on fewer CUs
    no_carveout, carveout_0, carveout_66, no_carveout_again = [float(evt.get("dur", "0.0")) for evt in events]
    if not (no_carveout < carveout_66 and carveout_0 < carveout_66 and no_carveout_again < carveout_66):
        # something went wrong, print more info to help debug flaky test
        print("ROCm debug info for test_honor_sm_carveout")
        print("no_carveout", no_carveout)
        print("carveout_0", carveout_0)
        print("carveout_66", carveout_66)
        print("no_carveout_again", no_carveout_again)
    self.assertTrue(no_carveout < carveout_66)
    self.assertTrue(carveout_0 < carveout_66)
    self.assertTrue(no_carveout_again < carveout_66)
    # ROCm carveout will create new streams when enabled, and go back to the original stream when disabled
    no_carveout, carveout_0, carveout_66, no_carveout_again = [int(evt.get("tid", "0")) for evt in events]
    self.assertTrue(no_carveout == no_carveout_again)
    self.assertTrue(no_carveout == carveout_0)
    self.assertTrue(no_carveout != carveout_66)
    self.assertTrue(carveout_0 != carveout_66)
else:
    no_carveout, carveout_0, carveout_66, no_carveout_again = [
        math.prod(evt.get("args", {}).get("grid", []))
        for evt in json.load(open(f.name))["traceEvents"]
        if evt.get("cat", "") == "kernel"
    ]

    self.assertEqual(no_carveout, no_carveout_again)
    self.assertNotEqual(no_carveout, carveout_66)
    self.assertNotEqual(carveout_66, carveout_0)
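# note: the ROCm branch assumes two chrome-trace conventions of the profiler
# output: "dur" is the kernel duration in microseconds, and "tid" identifies
# the stream a kernel ran on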

@unittest.skipIf(not PLATFORM_SUPPORTS_MX_GEMM, mx_skip_msg)
@parametrize("test_case_name", [