Skip to content

Commit 0ad8381

Browse files
committed
Address review comments with respect to triton_heuristics and install_rocm
1 parent f3e8213 commit 0ad8381

File tree

2 files changed

+22
-23
lines changed

2 files changed

+22
-23
lines changed

.ci/docker/common/install_rocm.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,6 @@ EOF
114114
rm -rf HIP clr
115115
fi
116116

117-
# temporary hipblasLT dependency install
118-
apt install libmsgpackc2
119117
pip_install "git+https://github.com/rocm/composable_kernel@$ROCM_COMPOSABLE_KERNEL_VERSION"
120118

121119
# Cleanup

torch/_inductor/runtime/triton_heuristics.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2959,31 +2959,32 @@ def _persistent_reduction_configs(
29592959
if "y" in size_hints:
29602960
pass
29612961
# TODO(jansel): we should be able to improve these heuristics
2962-
elif reduction_hint == ReductionHint.INNER:
2963-
if rnumel > 1024:
2964-
configs = configs[:1]
2965-
else:
2966-
x_block = 8
2967-
if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
2968-
# If loads/stores greater than 5, a lot of register pressure
2969-
# rnumel < 256 means no vectorized loads if we split up r dim
2970-
# so xblock still needs to be larger
2971-
x_block = 1
2972-
2973-
configs = [
2974-
triton_config_reduction(
2975-
size_hints,
2976-
x_block,
2977-
rnumel,
2978-
register_intensive=True,
2979-
reduction_hint=reduction_hint,
2980-
)
2981-
]
2962+
elif not max_autotune_enabled: # Don't filter if tuning enabled
2963+
if reduction_hint == ReductionHint.INNER:
2964+
if rnumel > 1024:
2965+
configs = configs[:1]
2966+
else:
2967+
x_block = 8
2968+
if xnumel // x_block < 128 or (loads_and_stores >= 5 and rnumel >= 256):
2969+
# If loads/stores greater than 5, a lot of register pressure
2970+
# rnumel < 256 means no vectorized loads if we split up r dim
2971+
# so xblock still needs to be larger
2972+
x_block = 1
2973+
2974+
configs = [
2975+
triton_config_reduction(
2976+
size_hints,
2977+
x_block,
2978+
rnumel,
2979+
register_intensive=True,
2980+
reduction_hint=reduction_hint,
2981+
)
2982+
]
29822983

29832984
elif reduction_hint == ReductionHint.OUTER:
29842985
configs = configs[-1:]
29852986
elif reduction_hint == ReductionHint.OUTER_TINY:
2986-
configs = [
2987+
tiny_configs = [
29872988
triton_config_reduction(
29882989
size_hints,
29892990
2 * (256 // rnumel) if rnumel <= 256 else 1,

0 commit comments

Comments
 (0)