query warp size for host code, do not use C10_WARP_SIZE #857
base: rocm4.5_internal_testing
ROCm supports gfx targets with warp sizes of 32 and 64. Device compilation correctly handles the C10_WARP_SIZE (aka warpSize) constant. Host code cannot rely on a single hard-coded value and must instead query device properties at runtime.
@jeffdaily Looks good, just one snippet looks suspect and a clarification is needed.
static_assert(num_threads % C10_WARP_SIZE == 0 &&
num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
TORCH_INTERNAL_ASSERT(num_threads % warp_size == 0);
static_assert(num_threads <= cuda_utils::kCUDABlockReduceMaxThreads,
@jeffdaily Sorry, I am not clear on the difference between TORCH_INTERNAL_ASSERT and static_assert. Would you mind explaining it here?
static_assert is a C++ feature that asserts at compile time, but it requires every value in its condition to be known at compile time. Since warp_size here is now a runtime query, it cannot be used in the static_assert. I was treating the TORCH_INTERNAL_ASSERT as the nearest runtime equivalent.
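To illustrate the split described above, here is a minimal standalone sketch. It is not the PR's code: `runtime_check` is a hypothetical stand-in for TORCH_INTERNAL_ASSERT, `launch_checked` is an invented name, and the value of `kMaxThreads` is assumed for illustration only.

```cpp
#include <stdexcept>
#include <string>

// Assumed compile-time limit; stands in for cuda_utils::kCUDABlockReduceMaxThreads.
constexpr int kMaxThreads = 1024;

// Hypothetical stand-in for TORCH_INTERNAL_ASSERT: throws when the condition is false.
inline void runtime_check(bool cond, const std::string& msg) {
  if (!cond) throw std::runtime_error(msg);
}

template <int num_threads>
void launch_checked(int warp_size) {
  // num_threads is a template argument, so this half can stay a static_assert.
  static_assert(num_threads <= kMaxThreads, "num_threads exceeds the block limit");
  // warp_size is only known at run time, so this half has to be a runtime check.
  runtime_check(num_threads % warp_size == 0,
                "num_threads must be a multiple of warp_size");
}
```

With this sketch, `launch_checked<128>(64)` passes both checks, while `launch_checked<96>(64)` compiles but throws at run time because 96 is not a multiple of 64.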
}
case 2:
handle_fused_mode<128, scalar_t>(
grid, self, ti_values, ti_indices, slice_size, slices);
Is it correct to reduce the if-else to just the if logic? Shouldn't we just replace C10_WARP_SIZE with at::cuda::warp_size()?
I thought about this one for a while. I could have easily used warp_size here, the runtime query. The handle_fused_mode function takes a size template arg and the static_assert ensures the size is at least 2 * C10_WARP_SIZE.
In the switch(ceilPowerOf2) statement, the previously handled case was 256, and the next largest case is 128. For the if statement if (ceilPowerOf2 > 2*C10_WARP_SIZE), we know the warp size is going to be 32 or 64. So it becomes if (128 > 64) for warp size 32, and if (128 > 128) for warp size 64. Not terribly useful IMHO. When C10_WARP_SIZE is 64, it will always call handle_fused_mode<128> for cases 128 through 2. My change here merely makes that clear. The earlier code wasn't much of an optimization to begin with.
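The arithmetic above can be checked with a small sketch. The helper names `exceeds_two_warps` and `cases_taking_if_branch` are hypothetical and do not appear in the PR; the sketch only evaluates the original predicate for the switch cases from 128 down to 2 and says nothing about what either branch then calls.

```cpp
#include <vector>

// Predicate from the original if statement; warp_size is 32 or 64 on ROCm targets.
constexpr bool exceeds_two_warps(int ceil_power_of_2, int warp_size) {
  return ceil_power_of_2 > 2 * warp_size;
}

// Collects the switch cases (powers of two from 128 down to 2)
// for which the original predicate holds.
inline std::vector<int> cases_taking_if_branch(int warp_size) {
  std::vector<int> taken;
  for (int c = 128; c >= 2; c /= 2) {
    if (exceeds_two_warps(c, warp_size)) taken.push_back(c);
  }
  return taken;
}
```

Under these assumptions, only case 128 satisfies the predicate for warp size 32, and no case satisfies it for warp size 64.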
Perhaps the if statement was wrong in the first place. Perhaps it should have been if (ceilPowerOf2 >= 2 * C10_WARP_SIZE), with >= instead of just >.
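A short sketch of how the two comparisons differ at the boundary, using the hypothetical helpers `strict_check` and `inclusive_check` (neither is part of the PR). The two operators disagree only when ceilPowerOf2 equals exactly twice the warp size.

```cpp
// Original comparison: strictly greater than two warps.
constexpr bool strict_check(int ceil_power_of_2, int warp_size) {
  return ceil_power_of_2 > 2 * warp_size;
}

// Suggested alternative: greater than or equal to two warps.
constexpr bool inclusive_check(int ceil_power_of_2, int warp_size) {
  return ceil_power_of_2 >= 2 * warp_size;
}

// Boundary for warp size 64: 128 == 2 * 64.
static_assert(!strict_check(128, 64), "strict form rejects the boundary");
static_assert(inclusive_check(128, 64), "inclusive form accepts the boundary");

// Boundary for warp size 32: 64 == 2 * 32.
static_assert(!strict_check(64, 32), "strict form rejects the boundary");
static_assert(inclusive_check(64, 32), "inclusive form accepts the boundary");
```

Away from the boundary the two forms agree, for example both accept 128 with warp size 32 and both reject 32 with warp size 32.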
@@ -574,8 +574,9 @@ void GroupNormKernelImplInternal(
T* rstd_data = rstd.data_ptr<T>();

cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

nit: extra newline
This reverts commit 24e27af.
…to workaround certificate expiry issue
8ba50d8 to 3ec52d0
…m/ROCmSoftwarePlatform/pytorch into rocm4.5_internal_testing_warpsize
aab8f1d to b81dc22
…sting_warpsize_mmelesse_pr_2 fix launch_unrolled_kernel_for_multi_outputs