18 changes: 18 additions & 0 deletions aten/src/ATen/Context.cpp
@@ -3,6 +3,7 @@
#include <ATen/Context.h>

#include <c10/core/CPUAllocator.h>
#include <c10/util/Logging.h>

#include <algorithm>
#include <cctype>
@@ -173,6 +174,9 @@ bool Context::userEnabledOverrideableSDP() const {
static const char cublas_config_var_name[] = "CUBLAS_WORKSPACE_CONFIG";
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
static const char* const cublas_deterministic_configs[] = { ":4096:8", ":16:8" };
#ifdef USE_ROCM
static constexpr const auto hipblaslt_allow_tf32 = "HIPBLASLT_ALLOW_TF32";
#endif

bool Context::checkCuBLASConfigDeterministic() {
  bool cublas_config_deterministic = true;
@@ -228,10 +232,24 @@ void Context::setBenchmarkLimitCuDNN(int b) {
}

bool Context::allowTF32CuBLAS() const {
#ifdef USE_ROCM
  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
  if (allow_tf32 != true) {
    return false;
  }
#endif
  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
}

void Context::setAllowTF32CuBLAS(bool b) {
#ifdef USE_ROCM
  const static auto allow_tf32 = c10::utils::check_env(hipblaslt_allow_tf32);
  if (allow_tf32 != true) {
    LOG(INFO) << "torch.backends.cuda.matmul.allow_tf32 is not supported on ROCm by default. "
              << "Please set environment variable HIPBLASLT_ALLOW_TF32=1 to enable it.";
    return;
  }
#endif
  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
}

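The net effect of the Context.cpp change: on ROCm builds, `allowTF32CuBLAS()` reports false and `setAllowTF32CuBLAS()` degrades to an INFO log unless the `HIPBLASLT_ALLOW_TF32` environment variable is set before the first query (the function-local `static` caches the `check_env` result once per process). A minimal sketch of the expected user-facing behavior on a ROCm build; this usage example is illustrative and not part of the diff:

```python
import os

# Must be set before torch first reads it; the C++ side caches the
# env var in a function-local static, so later changes have no effect.
os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

import torch

# With the flag set, the TF32 toggle behaves as it does on CUDA:
torch.backends.cuda.matmul.allow_tf32 = True
assert torch.backends.cuda.matmul.allow_tf32

# Without the flag, the setter would log an INFO message and the
# getter would keep returning False on ROCm.
```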
4 changes: 0 additions & 4 deletions aten/src/ATen/cuda/CUDABlas.cpp
@@ -329,11 +329,9 @@ inline void bgemm_internal_cublaslt(CUDABLAS_BGEMM_ARGTYPES(Dtype)) {
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
#endif
  } else if constexpr (std::is_same_v<Dtype, c10::complex<double>>) {
    abcType = CUDA_C_64F;
    computeType = CUBLAS_COMPUTE_64F;
@@ -1205,11 +1203,9 @@ void gemm_and_bias(
    computeType = CUBLAS_COMPUTE_64F;
    scaleType = CUDA_R_64F;
  } else if constexpr (std::is_same_v<Dtype, float>) {
#ifndef USE_ROCM
    if (at::globalContext().allowTF32CuBLAS()) {
      computeType = CUBLAS_COMPUTE_32F_FAST_TF32;
    }
#endif
    abcType = CUDA_R_32F;
  } else if constexpr (std::is_same_v<Dtype, at::Half>) {
    abcType = CUDA_R_16F;
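With the `#ifndef USE_ROCM` guards removed, ROCm builds also select `CUBLAS_COMPUTE_32F_FAST_TF32` (hipified to `HIPBLAS_COMPUTE_32F_FAST_TF32` via the mapping added at the bottom of this PR) whenever `allowTF32CuBLAS()` is true. A quick way to observe the effect, sketched under the assumption of a TF32-capable GPU (Ampere+, or MI300+ with `HIPBLASLT_ALLOW_TF32=1`):

```python
import torch

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = False
exact = a @ b  # full fp32 compute type

torch.backends.cuda.matmul.allow_tf32 = True
fast = a @ b   # may dispatch with the FAST_TF32 compute type

# TF32 keeps fp32 range but only ~10 mantissa bits, so the two results
# should agree to roughly 3 decimal digits rather than bit-exactly.
print((exact - fast).abs().max())
```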
55 changes: 37 additions & 18 deletions test/dynamo/test_misc.py
@@ -7717,24 +7717,43 @@ def write_state(state):
        def fn(x):
            return x + 1

        initial_state = read_state()
        y = torch.randn(10)
        try:
            for round in range(3):
                for i in range(len(initial_state)):
                    new_state = [False] * len(initial_state)
                    new_state[i] = True
                    write_state(new_state)
                    assert read_state() == new_state
                    last_state.clear()
                    fn(y)
                    assert last_state == new_state
                    if round == 0:
                        assert cnt == i + 1
                    else:
                        assert cnt == len(initial_state)
        finally:
            write_state(initial_state)
        import contextlib

        @contextlib.contextmanager
        def _hip_allow_tf32():
            # On HIP/AMDGPU, TF32 is gated behind an env var because TF32
            # support is new and only available on MI300+
            hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
            os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

            try:
                yield
            finally:
                if hip_allow_tf32 is not None:
                    os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
                else:
                    del os.environ["HIPBLASLT_ALLOW_TF32"]

        tf32_ctx = _hip_allow_tf32 if torch.version.hip else contextlib.nullcontext
        with tf32_ctx():
            initial_state = read_state()
            y = torch.randn(10)
            try:
                for round in range(3):
                    for i in range(len(initial_state)):
                        new_state = [False] * len(initial_state)
                        new_state[i] = True
                        write_state(new_state)
                        assert read_state() == new_state
                        last_state.clear()
                        fn(y)
                        assert last_state == new_state
                        if round == 0:
                            assert cnt == i + 1
                        else:
                            assert cnt == len(initial_state)
            finally:
                write_state(initial_state)

    def test_grad_state_mutated(self):
        prior = torch.is_grad_enabled()
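The save/set/restore pattern in `_hip_allow_tf32` above is duplicated verbatim in test/test_cuda.py below. A more compact equivalent, sketched here only as an alternative (the PR keeps the explicit version), is `unittest.mock.patch.dict`, which restores the original environment, including the key's absence, on exit:

```python
import contextlib
import os
from unittest import mock

@contextlib.contextmanager
def _hip_allow_tf32():
    # patch.dict snapshots os.environ, applies the update, and restores
    # the snapshot (removing keys that did not exist) when the block exits.
    with mock.patch.dict(os.environ, {"HIPBLASLT_ALLOW_TF32": "1"}):
        yield
```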
33 changes: 33 additions & 0 deletions test/test_cuda.py
@@ -433,7 +433,33 @@ def check_workspace_size(inp):

        torch._C._cuda_clearCublasWorkspaces()

    @contextlib.contextmanager
    def _hip_allow_tf32(self):
        # On HIP/AMDGPU, TF32 is gated behind an env var because TF32
        # support is new and only available on MI300+
        hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
        os.environ["HIPBLASLT_ALLOW_TF32"] = "1"

        try:
            yield
        finally:
            if hip_allow_tf32 is not None:
                os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
            else:
                del os.environ["HIPBLASLT_ALLOW_TF32"]

    def test_cublas_allow_tf32_get_set(self):
        """
        We only turn on TF32 for MI300 with a special env var. This is because
        TF32 is only available on MI300+ and is still experimental (hipblaslt
        support is currently a work in progress).
        """
        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext

        with tf32_ctx():
            self._test_cublas_allow_tf32_get_set_inner()

    def _test_cublas_allow_tf32_get_set_inner(self):
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
        )
@@ -448,6 +474,12 @@ def test_cublas_allow_tf32_get_set(self):
            torch.backends.cuda.matmul.allow_tf32 = orig

    def test_float32_matmul_precision_get_set(self):
        tf32_ctx = self._hip_allow_tf32 if torch.version.hip else contextlib.nullcontext

        with tf32_ctx():
            self._test_float32_matmul_precision_get_set_inner()

    def _test_float32_matmul_precision_get_set_inner(self):
        orig = torch.get_float32_matmul_precision()
        skip_tf32_cublas = "TORCH_ALLOW_TF32_CUBLAS_OVERRIDE" in os.environ and int(
            os.environ["TORCH_ALLOW_TF32_CUBLAS_OVERRIDE"]
@@ -459,6 +491,7 @@ def test_float32_matmul_precision_get_set(self):
            self.assertEqual(torch.get_float32_matmul_precision(), "highest")
        else:
            self.assertTrue(torch.backends.cuda.matmul.allow_tf32)

        for p in ("medium", "high"):
            torch.set_float32_matmul_precision(p)
            self.assertEqual(torch.get_float32_matmul_precision(), p)
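The inner test above exercises the standard mapping between `torch.set_float32_matmul_precision` and the TF32 flag. Stated compactly, and assuming `TORCH_ALLOW_TF32_CUBLAS_OVERRIDE` is unset (and, after this PR, `HIPBLASLT_ALLOW_TF32=1` on ROCm, since the getter otherwise always returns False there):

```python
import torch

# "highest" disables TF32 for fp32 matmuls; "high" and "medium" enable it.
torch.set_float32_matmul_precision("highest")
assert torch.backends.cuda.matmul.allow_tf32 is False

for p in ("high", "medium"):
    torch.set_float32_matmul_precision(p)
    assert torch.backends.cuda.matmul.allow_tf32 is True
```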
4 changes: 4 additions & 0 deletions torch/utils/hipify/cuda_to_hip_mappings.py
@@ -7284,6 +7284,10 @@
"CUBLAS_COMPUTE_32F",
("HIPBLAS_COMPUTE_32F", CONV_MATH_FUNC, API_BLAS)
),
(
"CUBLAS_COMPUTE_32F_FAST_TF32",
("HIPBLAS_COMPUTE_32F_FAST_TF32", CONV_MATH_FUNC, API_BLAS)
),
(
"CUBLAS_COMPUTE_64F",
("HIPBLAS_COMPUTE_64F", CONV_MATH_FUNC, API_BLAS)
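hipify is a source-to-source translation driven by lookup tables like this one; the new entry is what lets the now-unguarded `CUBLAS_COMPUTE_32F_FAST_TF32` in CUDABlas.cpp translate for ROCm builds. A simplified sketch of how such a table is consumed (the real driver in torch/utils/hipify/hipify_python.py builds a trie over all mappings; this standalone version is illustrative only):

```python
import re

# Subset of the mapping table, keyed CUDA identifier -> HIP identifier.
CUDA_TO_HIP = {
    "CUBLAS_COMPUTE_32F": "HIPBLAS_COMPUTE_32F",
    "CUBLAS_COMPUTE_32F_FAST_TF32": "HIPBLAS_COMPUTE_32F_FAST_TF32",
    "CUBLAS_COMPUTE_64F": "HIPBLAS_COMPUTE_64F",
}

# Longest-match-first with word boundaries, so CUBLAS_COMPUTE_32F does
# not clobber the _FAST_TF32 variant.
_PATTERN = re.compile(
    r"\b("
    + "|".join(sorted(map(re.escape, CUDA_TO_HIP), key=len, reverse=True))
    + r")\b"
)

def hipify_line(line: str) -> str:
    return _PATTERN.sub(lambda m: CUDA_TO_HIP[m.group(0)], line)

print(hipify_line("computeType = CUBLAS_COMPUTE_32F_FAST_TF32;"))
# -> computeType = HIPBLAS_COMPUTE_32F_FAST_TF32;
```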