STF: extend cuda_try output-parameter inference (first/last + ambiguity rejection)#8891
Conversation
Generalize the templated `cuda_try<fun>(args...)` helper to synthesize the output parameter at either the first or the last position, reject ambiguous cases at compile time, and document the three call shapes. Also migrate a handful of safe call sites in __stf and __places to the new shorthand. Co-authored-by: Cursor <cursoragent@cursor.com>
60b56d9 to
2e2e305
Compare
|
/ok to test 2e2e305 |
This comment has been minimized.
This comment has been minimized.
|
I guess this belongs to CCCL utilities at some point, but it's nice ! |
|
/ok to test e21c8d8 |
This comment has been minimized.
This comment has been minimized.
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Enterprise Run ID: 📒 Files selected for processing (1)
📝 WalkthroughSummary by CodeRabbit
suggestion: WalkthroughThe PR extends ChangesValue-returning cuda_try with last-parameter support
Suggested reviewers
Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 9007af20-0f6c-4f91-b517-43493d544042
📒 Files selected for processing (7)
cudax/include/cuda/experimental/__places/exec/green_context.cuhcudax/include/cuda/experimental/__places/localized_array.cuhcudax/include/cuda/experimental/__places/stream_pool.cuhcudax/include/cuda/experimental/__stf/graph/graph_ctx.cuhcudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuhcudax/test/stf/static_error_checks/CMakeLists.txtcudax/test/stf/static_error_checks/cuda_try_ambiguous.cu
| constexpr bool direct_form = ::std::is_invocable_v<decltype(fun), Ps...>; | ||
|
|
||
| constexpr bool first_output_form = | ||
| ::std::is_pointer_v<reserved::first_param<fun>> | ||
| && !::std::is_const_v<::std::remove_pointer_t<reserved::first_param<fun>>> | ||
| && ::std::is_invocable_v<decltype(fun), ::std::remove_pointer_t<reserved::first_param<fun>>*, Ps...>; | ||
|
|
||
| constexpr bool last_output_form = | ||
| ::std::is_pointer_v<reserved::last_param<fun>> | ||
| && !::std::is_const_v<::std::remove_pointer_t<reserved::last_param<fun>>> | ||
| && ::std::is_invocable_v<decltype(fun), Ps..., ::std::remove_pointer_t<reserved::last_param<fun>>*>; |
There was a problem hiding this comment.
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Identify zero-argument cuda_try<...>() call sites that are most likely to hit eager param-trait instantiation.
# Expected: inspect matches and confirm whether called functions have zero parameters.
rg -nP 'cuda_try<[^>]+>\(\s*\)' --type=cu --type=cuh --type=h --type=hpp --type=cpp --type=cxxRepository: NVIDIA/cccl
Length of output: 82
🏁 Script executed:
# First, let's understand the file and examine the problematic code section
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '440,460p'Repository: NVIDIA/cccl
Length of output: 1266
🏁 Script executed:
# Search for cuda_try call sites with corrected pattern (use glob instead of file type)
rg 'cuda_try<[^>]+>\s*\(' --glob='*.{cu,cuh,h,hpp,cpp,cxx}' -nRepository: NVIDIA/cccl
Length of output: 9845
🏁 Script executed:
# Examine the reserved::first_param and reserved::last_param trait implementations
rg -n 'struct first_param|struct last_param|template.*first_param|template.*last_param' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A8Repository: NVIDIA/cccl
Length of output: 695
🏁 Script executed:
# Find the full trait and alias definitions
rg -n 'first_param\s*=' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A2 -B2Repository: NVIDIA/cccl
Length of output: 166
🏁 Script executed:
# Also search for last_param alias definition
rg -n 'last_param\s*=' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A2 -B2Repository: NVIDIA/cccl
Length of output: 193
🏁 Script executed:
# Check if there are any detection traits or SFINAE machinery
rg -n 'void_t|is_detected|_v<.*first_param|_v<.*last_param' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A2Repository: NVIDIA/cccl
Length of output: 1333
🏁 Script executed:
# Look at the test file to see if zero-arg callables are tested
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '500,510p'Repository: NVIDIA/cccl
Length of output: 531
🏁 Script executed:
# Check if there's a primary template for first_param_impl
rg -n 'struct first_param_impl' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A5Repository: NVIDIA/cccl
Length of output: 255
🏁 Script executed:
# Check if there's a primary template for last_param_impl
rg -n 'struct last_param_impl' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A8Repository: NVIDIA/cccl
Length of output: 465
🏁 Script executed:
# Check the test functions to see their signatures
rg -n 'test_first_output_param|test_last_output_param' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B5 -A5Repository: NVIDIA/cccl
Length of output: 897
🏁 Script executed:
# Search for zero-parameter functions or void-returning functions that might trigger this
rg 'extern\s+\w+\s+\w+\s*\(\s*\)' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuhRepository: NVIDIA/cccl
Length of output: 37
🏁 Script executed:
# Check the function_last_param_impl specialization to see if it handles empty Ps
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '259,280p'Repository: NVIDIA/cccl
Length of output: 781
🏁 Script executed:
# Try to understand if there's a fallback when a primary template isn't specialized
# Look for any SFINAE-like pattern or detection helpers
rg -n 'dependent_false|detected|void_t' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B3 -A3Repository: NVIDIA/cccl
Length of output: 437
🏁 Script executed:
# Check if there are any truly zero-parameter functions being used with cuda_try
# Look at test helpers defined in the file
rg -n 'inline.*\(\s*\)' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A6Repository: NVIDIA/cccl
Length of output: 37
🏁 Script executed:
# Try to compile a test case with a zero-param function to verify the issue
# First, let's see the full context of lines 486-496 for the test functions
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '475,510p'Repository: NVIDIA/cccl
Length of output: 1236
🏁 Script executed:
# Let's check the actual signatures of the CUDA functions being called
# by examining if there are any comments or documentation about them
rg -n 'cudaGetDevice|cudaGetDeviceCount|cudaMemGetInfo' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A2Repository: NVIDIA/cccl
Length of output: 1572
🏁 Script executed:
# Let me create a minimal test to understand the actual behavior
# Check if there are any compile tests or error expectations
fd 'cuda_try.*test' --type f --glob '*.cu' --glob '*.cuh' | head -20Repository: NVIDIA/cccl
Length of output: 199
🏁 Script executed:
# Look for the actual CUDA function declarations or their signatures
rg -n 'extern.*cudaGetDevice\|extern.*cudaGetDeviceCount' cudax/include -R --type=cu --type=h --type=cuh --type=hppRepository: NVIDIA/cccl
Length of output: 460
🏁 Script executed:
# Try a different approach: look at what happens if first_param or last_param can't be resolved
# Check if there's a primary (unspecialized) template and what error it would give
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '236,248p'Repository: NVIDIA/cccl
Length of output: 357
🏁 Script executed:
# Let me check the actual function signatures of the CUDA API being called
# by looking at the documentation or examining if cudaGetDevice actually takes parameters
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '414,430p'Repository: NVIDIA/cccl
Length of output: 1275
🏁 Script executed:
# Let's verify: does cudaGetDevice take a parameter? Check the call site at line 501
# The code says: int dev = cuda_try<cudaGetDevice>();
# This suggests it's being called in "output form" which means the function takes a pointer parameter
# Let's check if there's any documentation about this
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '410,440p'Repository: NVIDIA/cccl
Length of output: 2189
🏁 Script executed:
# Key question: are there ANY functions with zero parameters?
# Let's examine the comment that mentions ambiguity - it shows cudaMemGetInfo which takes 2 parameters
# Let me check if there's any real zero-parameter function in the CUDA API or test code
# by looking at what happens when last_param_impl<> is instantiated
python3 <<'EOF'
# Check: does last_param_impl have a primary template?
# If not, trying to instantiate last_param_impl<> (empty pack) would fail with:
# "error: implicit instantiation of undefined template 'last_param_impl<>'"
# This happens at line 453 when evaluating:
# ::std::is_pointer_v<reserved::last_param<fun>>
# for a zero-parameter function
# The issue is REAL: the trait instantiation happens eagerly in constexpr context
# before the condition can be short-circuited or reported as SFINAE failure
print("Analysis of the issue:")
print("1. For zero-parameter function signatures R (*)():")
print(" - first_param_impl<R (*)()> is undefined → hard error")
print(" - last_param_impl<> is undefined → hard error")
print()
print("2. Current code evaluates constexpr bool conditions eagerly:")
print(" - constexpr bool first_output_form = ::std::is_pointer_v<reserved::first_param<fun>>")
print(" - This instantiates first_param<fun>, which fails hard if first_param_impl is not specialized")
print()
print("3. The fallback static_assert at line 481 would never be reached")
print()
print("4. However, in the current codebase:")
print(" - All CUDA functions used have at least one parameter")
print(" - The test functions also have at least one parameter")
print(" - So this is a latent issue, not currently triggered")
EOFRepository: NVIDIA/cccl
Length of output: 756
important: first_output_form and last_output_form eagerly instantiate reserved::first_param<fun> and reserved::last_param<fun>. For functions without parameters (e.g., zero-argument callables), these traits have no specialization and cause hard template errors before reaching the intended fallback static_assert at line 481. Make these checks SFINAE-safe using detection traits that resolve to false rather than hard_error on ill-formed parameter extraction, so diagnostics remain deterministic and the code can handle future zero-parameter function signatures.
|
/ok to test cebc63c |
🥳 CI Workflow Results🟩 Finished in 1h 04m: Pass: 100%/55 | Total: 1d 02h | Max: 1h 04m | Hits: 10%/195214See results here. |
Summary
Today, calls to single-output CUDA APIs through STF's
cuda_trylook like this:Three lines of bookkeeping per call, with a temporary that has to be named, default-initialized correctly, and then thrown away after one use. The pattern is mechanical, error-prone (e.g. uninitialized variable on the failure path), and obscures what the code is actually doing.
This PR generalizes the templated
cuda_try<fun>(args...)overload so that all three of the above become one-liners:The helper picks the right call shape at compile time based on the function's signature, with no runtime cost and no loss of error reporting (the same
cuda_exceptionis thrown on failure, with the same source location).What changed
cuda_try<fun>(args...)now selects, in order, from three forms:fun(args...)is invocable, call it and check the status. Returnsvoid.fun's first parameter is a non-constpointer (the CUDA convention for output parameters) andfun(&result, args...)is invocable, materializeresult, call, and return it. MatchescudaStreamCreate,cudaGraphAddEmptyNode,cudaDeviceCanAccessPeer, ...fun's last parameter is a non-constpointer andfun(args..., &result)is invocable, materializeresult, call, and return it. Matches the driver-API "getter" convention:cuStreamGetId,cuCtxGetId,cuStreamGetCtx, ...Form (3) is the new piece - previously only forms (1) and (2) were supported, which is why driver-API getters had to fall back to the verbose pattern.
Compile-time ambiguity rejection
Some CUDA functions have non-
constpointer parameters in both the first and the last position (e.g.cudaMemGetInfo(size_t* free, size_t* total)). For those, supplying a singlesize_t*is consistent with either interpretation, with different effects. Rather than silently picking one,cuda_trynow refuses to compile such calls:The user is then expected to fall back to the explicit
cuda_try(cudaMemGetInfo(&free, &total))form, which is unambiguous.The zero-argument case (
cuda_try<fun>()) is exempt from the assertion because the synthesized callfun(&result)is identical for both interpretations.A negative compile-time test (
cudax/test/stf/static_error_checks/cuda_try_ambiguous.cu) exercises the assertion usingcudaMemGetInfo, and is wired into the existing static-error test infrastructure.Documentation
The doc comment on
cuda_try<fun>now spells out:cudaMallocare unsupported and must use the runtime-status overload),Call-site migrations
A handful of safe, single-output call sites in
__stfand__placesare migrated to the new shorthand to demonstrate the ergonomics and exercise both new forms in CI:cudax/include/cuda/experimental/__places/exec/green_context.cuh--cuCtxGetIdcudax/include/cuda/experimental/__places/localized_array.cuh--cudaDeviceCanAccessPeercudax/include/cuda/experimental/__places/stream_pool.cuh--cudaStreamGetDevice,cuStreamGetCtx,cuStreamGetIdcudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh--cudaGraphAddEmptyNodeSites involving overloaded CUDA functions (
cudaMallocfamily) or in/out parameters were intentionally left alone.What is not changed
cuda_try(cudaSomething(...))is unchanged and remains the right tool whenever the new shorthand cannot be used (overloaded functions, multiple outputs, ambiguous signatures).cuda_trytemplate.cuda_exception, same source location.Test plan
cudax-cpp17preset,CMAKE_CUDA_ARCHITECTURES=86.cuda_tryunit tests (cuda_try1,cuda_try2) pass.cuda_try_ambiguous.cufails to compile with the expectedstatic_assertmessage and is wired into the static-error CMake test list.