diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp
index 73097dc28fdb6..43ce9e6c0b507 100644
--- a/aten/src/ATen/cuda/CUDABlas.cpp
+++ b/aten/src/ATen/cuda/CUDABlas.cpp
@@ -1447,6 +1447,12 @@ void scaled_gemm(
 #if defined(USE_ROCM)
 #if defined(HIPBLASLT_OUTER_VEC)
     // this case is handled later as hipified CUBLASLT_MATMUL_MATRIX_SCALE_OUTER_VEC_32F
+    if (use_rowwise) {
+      // swapped
+      computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, mat2_scale_ptr);
+      computeDesc.setAttribute(CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, mat1_scale_ptr);
+    }
+    else
 #elif defined(HIPBLASLT_VEC_EXT)
     if (use_rowwise) {
       // swapped
diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py
index 9da63650f3945..de94ad3ddf884 100644
--- a/test/test_matmul_cuda.py
+++ b/test/test_matmul_cuda.py
@@ -1,5 +1,6 @@
 # Owner(s): ["module: linear algebra"]
 
+from contextlib import nullcontext
 import unittest
 from itertools import product
 from functools import partial
@@ -356,7 +357,8 @@ def test_float8_basics(self, device) -> None:
         self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32)
         self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48)
         # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported
-        with self.assertRaises(RuntimeError):
+        # supported on ROCm but fails on CUDA
+        with self.assertRaises(RuntimeError) if torch.version.hip is None else nullcontext():
            self._test_tautological_mm(device, e5m2_type, e5m2_type)
 
         self._test_tautological_mm(device, size=64, out_dtype=torch.float16)
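
The test change above relies on the conditional context-manager pattern: pick `assertRaises` or `nullcontext()` depending on the platform, then enter whichever was chosen. Below is a minimal, self-contained sketch of that pattern; the `expect_failure` flag and `might_fail()` helper are hypothetical stand-ins for `torch.version.hip is None` and `_test_tautological_mm`, not part of this patch.

```python
from contextlib import nullcontext
import unittest


class ConditionalExpectationExample(unittest.TestCase):
    # Hypothetical flag standing in for `torch.version.hip is None` in the patch.
    expect_failure = True

    def might_fail(self):
        # Hypothetical helper standing in for _test_tautological_mm.
        if self.expect_failure:
            raise RuntimeError("unsupported dtype combination")

    def test_conditional_expectation(self):
        # If a failure is expected, assertRaises verifies it is raised;
        # otherwise nullcontext() is a no-op and the call must simply succeed.
        ctx = self.assertRaises(RuntimeError) if self.expect_failure else nullcontext()
        with ctx:
            self.might_fail()


if __name__ == "__main__":
    unittest.main()
```

This keeps a single test body for both platforms instead of duplicating the call under an `if`/`else`, which is why `nullcontext` is imported at the top of `test_matmul_cuda.py`.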