
STF: extend cuda_try output-parameter inference (first/last + ambiguity rejection) #8891

Merged
caugonnet merged 5 commits into main from stf-extend-cuda-try-output-inference
May 21, 2026

Conversation

@andralex
Contributor

@andralex andralex commented May 10, 2026

Summary

Today, calls to single-output CUDA APIs through STF's cuda_try look like this:

unsigned long long id;
cuda_try(cuStreamGetId(stream, &id));

cudaGraphNode_t node;
cuda_try(cudaGraphAddEmptyNode(&node, graph, deps.data(), deps.size()));

int can_access;
cuda_try(cudaDeviceCanAccessPeer(&can_access, dev, peer));

Two lines per call, with a temporary that has to be named, declared with the right type, and then thrown away after one use. The pattern is mechanical, error-prone (e.g. an uninitialized variable being read on the failure path), and obscures what the code is actually doing.

This PR generalizes the templated cuda_try<fun>(args...) overload so that all three of the above become one-liners:

auto id          = cuda_try<cuStreamGetId>(stream);
auto node        = cuda_try<cudaGraphAddEmptyNode>(graph, deps.data(), deps.size());
auto can_access  = cuda_try<cudaDeviceCanAccessPeer>(dev, peer);

The helper picks the right call shape at compile time based on the function's signature, with no runtime cost and no loss of error reporting (the same cuda_exception is thrown on failure, with the same source location).

What changed

cuda_try<fun>(args...) now selects, in order, from three forms:

  1. Direct. If fun(args...) is invocable, call it and check the status. Returns void.
  2. First-parameter output. Otherwise, if fun's first parameter is a non-const pointer (the CUDA convention for output parameters) and fun(&result, args...) is invocable, materialize result, call, and return it. Matches cudaStreamCreate, cudaGraphAddEmptyNode, cudaDeviceCanAccessPeer, ...
  3. Last-parameter output. Otherwise, if fun's last parameter is a non-const pointer and fun(args..., &result) is invocable, materialize result, call, and return it. Matches the driver-API "getter" convention: cuStreamGetId, cuCtxGetId, cuStreamGetCtx, ...

Form (3) is the new piece. Previously, only forms (1) and (2) were supported, which is why driver-API getters had to fall back to the verbose pattern.
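The selection order described above can be illustrated with a small CUDA-free sketch in C++17. Everything in it is a hypothetical stand-in: status_t and check replace the CUDA status type and the exception-throwing check, try_call replaces cuda_try<fun>, and the three mock functions only mimic the parameter shapes of real CUDA calls. The actual implementation in cuda_safe_call.cuh may well be structured differently; in particular, this sketch omits the ambiguity check and does not handle functions with zero parameters.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <tuple>
#include <type_traits>
#include <utility>

// Hypothetical stand-in for cudaError_t / CUresult.
enum status_t { ok = 0, failed = 1 };

// Hypothetical stand-in for the check that throws cuda_exception.
inline void check(status_t s) {
  if (s != ok) throw std::runtime_error("call failed");
}

// Expose the parameter list of a plain function pointer as a tuple.
template <class F> struct params;
template <class R, class... Ps>
struct params<R (*)(Ps...)> {
  using tuple = std::tuple<Ps...>;
  static constexpr std::size_t count = sizeof...(Ps);
};

template <auto fun>
using first_param = std::tuple_element_t<0, typename params<decltype(fun)>::tuple>;
template <auto fun>
using last_param =
    std::tuple_element_t<params<decltype(fun)>::count - 1, typename params<decltype(fun)>::tuple>;

// "Output parameter" here means: pointer to non-const.
template <class T>
constexpr bool is_output_ptr = std::is_pointer_v<T> && !std::is_const_v<std::remove_pointer_t<T>>;

template <auto fun, class... Args>
auto try_call(Args&&... args) {
  using F = decltype(fun);
  if constexpr (std::is_invocable_v<F, Args...>) {
    check(fun(std::forward<Args>(args)...));  // (1) direct: deduced return type is void
  } else if constexpr (is_output_ptr<first_param<fun>>
                       && std::is_invocable_v<F, first_param<fun>, Args...>) {
    std::remove_pointer_t<first_param<fun>> result{};
    check(fun(&result, std::forward<Args>(args)...));  // (2) output comes first
    return result;
  } else {
    static_assert(is_output_ptr<last_param<fun>>
                      && std::is_invocable_v<F, Args..., last_param<fun>>,
                  "no call shape matches");
    std::remove_pointer_t<last_param<fun>> result{};
    check(fun(std::forward<Args>(args)..., &result));  // (3) output comes last
    return result;
  }
}

// Mocks with the parameter shapes of the three forms.
inline status_t mock_sync(int /*stream*/) { return ok; }
inline status_t mock_can_access(int* out, int dev, int peer) {
  *out = (dev != peer) ? 1 : 0;
  return ok;
}
inline status_t mock_get_id(int stream, unsigned long long* id) {
  *id = 100ull + static_cast<unsigned long long>(stream);
  return ok;
}

inline void demo() {
  try_call<mock_sync>(0);                            // form (1)
  int can_access = try_call<mock_can_access>(0, 1);  // form (2)
  auto id = try_call<mock_get_id>(7);                // form (3)
  assert(can_access == 1);
  assert(id == 107ull);
}
```

Because each branch is an if constexpr, only the selected call shape is instantiated, which is consistent with the claim that the choice carries no runtime cost.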

Compile-time ambiguity rejection

Some CUDA functions have non-const pointer parameters in both the first and the last position (e.g. cudaMemGetInfo(size_t* free, size_t* total)). For those, supplying a single size_t* is consistent with either interpretation, with different effects. Rather than silently picking one, cuda_try now refuses to compile such calls:

error: static assertion failed: Ambiguous cuda_try: both first- and last-output forms apply;
       call the function explicitly to disambiguate.

The user is then expected to fall back to the explicit cuda_try(cudaMemGetInfo(&free, &total)) form, which is unambiguous.

The zero-argument case (cuda_try<fun>()) is exempt from the assertion because the synthesized call fun(&result) is identical for both interpretations.
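As a rough illustration of the rule, the predicate behind such an assertion might look like the following CUDA-free C++17 sketch. The names params, first_form, last_form, and is_ambiguous are hypothetical, and the mock functions only copy the parameter shapes of cudaMemGetInfo, cudaGetDevice, and cudaDeviceCanAccessPeer; the real check in cuda_safe_call.cuh may be written differently.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <type_traits>

enum status_t { ok = 0 };  // hypothetical stand-in for cudaError_t

// Mocks that mirror the parameter shapes only.
inline status_t mock_mem_info(std::size_t* free_bytes, std::size_t* total_bytes) {
  *free_bytes = 1;
  *total_bytes = 2;
  return ok;
}
inline status_t mock_get_device(int* dev) { *dev = 0; return ok; }
inline status_t mock_can_access(int* out, int, int) { *out = 1; return ok; }

template <class F> struct params;
template <class R, class... Ps>
struct params<R (*)(Ps...)> {
  using tuple = std::tuple<Ps...>;
  static constexpr std::size_t count = sizeof...(Ps);
};

template <class T>
constexpr bool is_output_ptr = std::is_pointer_v<T> && !std::is_const_v<std::remove_pointer_t<T>>;

// Would fun(&result, args...) be a valid first-output call?
template <auto fun, class... Args>
constexpr bool first_form() {
  using T = std::tuple_element_t<0, typename params<decltype(fun)>::tuple>;
  return is_output_ptr<T> && std::is_invocable_v<decltype(fun), T, Args...>;
}

// Would fun(args..., &result) be a valid last-output call?
template <auto fun, class... Args>
constexpr bool last_form() {
  using P = params<decltype(fun)>;
  using T = std::tuple_element_t<P::count - 1, typename P::tuple>;
  return is_output_ptr<T> && std::is_invocable_v<decltype(fun), Args..., T>;
}

// Ambiguous only when both forms apply AND at least one argument was supplied;
// with zero arguments both forms produce the same call fun(&result).
template <auto fun, class... Args>
constexpr bool is_ambiguous() {
  return sizeof...(Args) > 0 && first_form<fun, Args...>() && last_form<fun, Args...>();
}

static_assert(is_ambiguous<mock_mem_info, std::size_t*>(), "one size_t* fits either slot");
static_assert(!is_ambiguous<mock_get_device>(), "zero-argument case is exempt");
static_assert(!is_ambiguous<mock_can_access, int, int>(), "last parameter is not a pointer");

inline void demo() {
  // The explicit form stays unambiguous because both outputs are named.
  std::size_t free_bytes = 0, total_bytes = 0;
  status_t s = mock_mem_info(&free_bytes, &total_bytes);
  assert(s == ok && free_bytes == 1 && total_bytes == 2);
}
```

In a dispatcher, a static_assert on the negation of this predicate would turn the silent choice into a compile-time error.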

A negative compile-time test (cudax/test/stf/static_error_checks/cuda_try_ambiguous.cu) exercises the assertion using cudaMemGetInfo, and is wired into the existing static-error test infrastructure.

Documentation

The doc comment on cuda_try<fun> now spells out:

  • the three call shapes and their selection order,
  • the ambiguity rule and its zero-argument exception,
  • known limitations (overloaded names like cudaMalloc are unsupported and must use the runtime-status overload),
  • worked examples for each form.

Call-site migrations

A handful of safe, single-output call sites in __stf and __places are migrated to the new shorthand to demonstrate the ergonomics and exercise both output-synthesizing forms in CI:

  • cudax/include/cuda/experimental/__places/exec/green_context.cuh -- cuCtxGetId
  • cudax/include/cuda/experimental/__places/localized_array.cuh -- cudaDeviceCanAccessPeer
  • cudax/include/cuda/experimental/__places/stream_pool.cuh -- cudaStreamGetDevice, cuStreamGetCtx, cuStreamGetId
  • cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh -- cudaGraphAddEmptyNode

Sites involving overloaded CUDA functions (cudaMalloc family) or in/out parameters were intentionally left alone.

What is not changed

  • The runtime-status overload cuda_try(cudaSomething(...)) is unchanged and remains the right tool whenever the new shorthand cannot be used (overloaded functions, multiple outputs, ambiguous signatures).
  • No public ABI/API changes outside the cuda_try template.
  • No behavior change on the failure path: same cuda_exception, same source location.

Test plan

  • Builds locally with cudax-cpp17 preset, CMAKE_CUDA_ARCHITECTURES=86.
  • Existing cuda_try unit tests (cuda_try1, cuda_try2) pass.
  • New negative compile-time test cuda_try_ambiguous.cu fails to compile with the expected static_assert message and is wired into the static-error CMake test list.
  • Full STF + cudax CI on this PR.

@andralex andralex requested a review from a team as a code owner May 10, 2026 22:56
@andralex andralex requested a review from caugonnet May 10, 2026 22:56
@github-project-automation github-project-automation Bot moved this to Todo in CCCL May 10, 2026
@copy-pr-bot
Contributor

copy-pr-bot Bot commented May 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@cccl-authenticator-app cccl-authenticator-app Bot moved this from Todo to In Review in CCCL May 10, 2026
Generalize the templated `cuda_try<fun>(args...)` helper to synthesize
the output parameter at either the first or the last position, reject
ambiguous cases at compile time, and document the three call shapes.

Also migrate a handful of safe call sites in __stf and __places to the
new shorthand.

Co-authored-by: Cursor <cursoragent@cursor.com>
@andralex andralex force-pushed the stf-extend-cuda-try-output-inference branch from 60b56d9 to 2e2e305 Compare May 11, 2026 00:30
@andralex andralex requested a review from a team as a code owner May 11, 2026 00:30
@andralex andralex requested a review from bernhardmgruber May 11, 2026 00:30
@andralex andralex changed the title STF: extend cuda_try output inference STF: extend cuda_try output-parameter inference (first/last + ambiguity rejection) May 11, 2026
@andralex
Contributor Author

/ok to test 2e2e305


@caugonnet
Contributor

I guess this belongs in CCCL utilities at some point, but it's nice!

@caugonnet
Contributor

/ok to test e21c8d8


@caugonnet caugonnet assigned caugonnet and unassigned caugonnet May 11, 2026
@caugonnet caugonnet added the stf Sequential Task Flow programming model label May 12, 2026
@coderabbitai
Contributor

coderabbitai Bot commented May 20, 2026

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: c7536937-4f84-4bf9-bb9c-031941ef3531

📥 Commits

Reviewing files that changed from the base of the PR and between 1f9cbd0 and cebc63c.

📒 Files selected for processing (1)
  • cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh

📝 Walkthrough

Summary by CodeRabbit

  • Refactor

    • Simplified and standardized CUDA API call patterns to use direct return-style invocations and improved handling of functions that produce outputs via either first- or last-parameter conventions.
    • Improved internal utilities to reliably deduce synthesized output parameters and reject ambiguous usages.
  • Tests

    • Added unit and static-error tests to validate new parameter-deduction behavior and to catch ambiguous call scenarios.


Walkthrough

The PR extends cuda_try to also synthesize outputs from last-parameter output pointers, adds type-traits to deduce last parameters, migrates several call sites to the value-returning form, and adds a static compile-time test that ensures ambiguous invocations are rejected.

Changes

Value-returning cuda_try with last-parameter support

  • Layer: Type-trait infrastructure for last-parameter deduction
    File(s): cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh
    Summary: Adds last_param, last_param_impl, function_last_param_impl templates to deduce a callable's last parameter type, introduces dependent_false for compile-time assertions, and expands unit tests to validate both first and last parameter type deduction across multiple function signatures.

  • Layer: Enhanced cuda_try with three-form dispatch and ambiguity detection
    File(s): cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh
    Summary: Reworks the cuda_try<fun> overload to dispatch among direct invocation, first-parameter output synthesis, and last-parameter output synthesis; adds a static_assert to reject ambiguous calls when both first and last forms apply for non-zero arguments; extends UNITTEST("cuda_try2") with explicit helpers validating both synthesis paths.

  • Layer: Call site migrations to value-returning pattern
    File(s): cudax/include/cuda/experimental/__places/exec/green_context.cuh, cudax/include/cuda/experimental/__places/localized_array.cuh, cudax/include/cuda/experimental/__places/stream_pool.cuh, cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh
    Summary: Updates get_cuda_context_id, peer-access capability checks, get_device_from_stream, get_stream_id, and uncached_graph_allocator to use value-returning cuda_try patterns instead of out-parameter style.

  • Layer: Ambiguity rejection test and CMake integration
    File(s): cudax/test/stf/static_error_checks/cuda_try_ambiguous.cu, cudax/test/stf/static_error_checks/CMakeLists.txt
    Summary: Adds a static-error-check test that verifies cuda_try rejects ambiguous cudaMemGetInfo invocations when a single output pointer is provided, and integrates the test into CMake to generate a compile-check ctest.
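For readers unfamiliar with this kind of trait, here is one plausible C++17 shape for last-parameter deduction and for dependent_false. The names are borrowed from the walkthrough above, but the bodies are guesses made for illustration and are not taken from the PR; describe and the mock functions one, two, and three are purely hypothetical.

```cpp
#include <cassert>
#include <type_traits>

// Always false, but dependent on its arguments, so a static_assert that uses it
// fires only when the enclosing template branch is actually instantiated.
template <class...>
inline constexpr bool dependent_false = false;

// Last type of a non-empty pack: drop types from the front until one remains.
template <class... Ps> struct last_param_impl;  // left undefined for an empty pack
template <class P> struct last_param_impl<P> { using type = P; };
template <class P, class Q, class... Ps>
struct last_param_impl<P, Q, Ps...> : last_param_impl<Q, Ps...> {};

// Unpack a function pointer's parameter list into the pack-based helper.
template <class F> struct function_last_param_impl;
template <class R, class... Ps>
struct function_last_param_impl<R (*)(Ps...)> : last_param_impl<Ps...> {};

template <auto fun>
using last_param = typename function_last_param_impl<decltype(fun)>::type;

// Hypothetical helper showing the dependent_false idiom in an if-constexpr chain.
template <class T>
constexpr int describe() {
  if constexpr (std::is_pointer_v<T>) return 1;
  else if constexpr (std::is_arithmetic_v<T>) return 2;
  else static_assert(dependent_false<T>, "describe: unsupported parameter type");
}

enum status_t { ok = 0 };  // hypothetical stand-in for cudaError_t

// Mock signatures with one, two, and three parameters.
inline status_t one(int* out) { *out = 1; return ok; }
inline status_t two(int in, double* out) { *out = in * 0.5; return ok; }
inline status_t three(char, const float*, long* out) { *out = 3; return ok; }

static_assert(std::is_same_v<last_param<one>, int*>);
static_assert(std::is_same_v<last_param<two>, double*>);
static_assert(std::is_same_v<last_param<three>, long*>);
static_assert(describe<last_param<two>>() == 1);

inline void demo() {
  double value = 0.0;
  last_param<two> p = &value;  // p has type double*
  assert(two(4, p) == ok && value == 2.0);
}
```

Note that in this sketch a zero-parameter function has no last_param at all, which is the situation the review comment further down is concerned with.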

Suggested reviewers

  • bernhardmgruber
  • caugonnet
  • pciolkosz
  • alliepiper


Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 9007af20-0f6c-4f91-b517-43493d544042

📥 Commits

Reviewing files that changed from the base of the PR and between 459e81a and 1f9cbd0.

📒 Files selected for processing (7)
  • cudax/include/cuda/experimental/__places/exec/green_context.cuh
  • cudax/include/cuda/experimental/__places/localized_array.cuh
  • cudax/include/cuda/experimental/__places/stream_pool.cuh
  • cudax/include/cuda/experimental/__stf/graph/graph_ctx.cuh
  • cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh
  • cudax/test/stf/static_error_checks/CMakeLists.txt
  • cudax/test/stf/static_error_checks/cuda_try_ambiguous.cu

Comment on lines +445 to +455
constexpr bool direct_form = ::std::is_invocable_v<decltype(fun), Ps...>;

constexpr bool first_output_form =
  ::std::is_pointer_v<reserved::first_param<fun>>
  && !::std::is_const_v<::std::remove_pointer_t<reserved::first_param<fun>>>
  && ::std::is_invocable_v<decltype(fun), ::std::remove_pointer_t<reserved::first_param<fun>>*, Ps...>;

constexpr bool last_output_form =
  ::std::is_pointer_v<reserved::last_param<fun>>
  && !::std::is_const_v<::std::remove_pointer_t<reserved::last_param<fun>>>
  && ::std::is_invocable_v<decltype(fun), Ps..., ::std::remove_pointer_t<reserved::last_param<fun>>*>;
Contributor


⚠️ Potential issue | 🟠 Major | ⚡ Quick win

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Identify zero-argument cuda_try<...>() call sites that are most likely to hit eager param-trait instantiation.
# Expected: inspect matches and confirm whether called functions have zero parameters.
rg -nP 'cuda_try<[^>]+>\(\s*\)' --type=cu --type=cuh --type=h --type=hpp --type=cpp --type=cxx

Repository: NVIDIA/cccl

Length of output: 82


🏁 Script executed:

# First, let's understand the file and examine the problematic code section
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '440,460p'

Repository: NVIDIA/cccl

Length of output: 1266


🏁 Script executed:

# Search for cuda_try call sites with corrected pattern (use glob instead of file type)
rg 'cuda_try<[^>]+>\s*\(' --glob='*.{cu,cuh,h,hpp,cpp,cxx}' -n

Repository: NVIDIA/cccl

Length of output: 9845


🏁 Script executed:

# Examine the reserved::first_param and reserved::last_param trait implementations
rg -n 'struct first_param|struct last_param|template.*first_param|template.*last_param' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A8

Repository: NVIDIA/cccl

Length of output: 695


🏁 Script executed:

# Find the full trait and alias definitions
rg -n 'first_param\s*=' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A2 -B2

Repository: NVIDIA/cccl

Length of output: 166


🏁 Script executed:

# Also search for last_param alias definition
rg -n 'last_param\s*=' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A2 -B2

Repository: NVIDIA/cccl

Length of output: 193


🏁 Script executed:

# Check if there are any detection traits or SFINAE machinery
rg -n 'void_t|is_detected|_v<.*first_param|_v<.*last_param' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A2

Repository: NVIDIA/cccl

Length of output: 1333


🏁 Script executed:

# Look at the test file to see if zero-arg callables are tested
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '500,510p'

Repository: NVIDIA/cccl

Length of output: 531


🏁 Script executed:

# Check if there's a primary template for first_param_impl
rg -n 'struct first_param_impl' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A5

Repository: NVIDIA/cccl

Length of output: 255


🏁 Script executed:

# Check if there's a primary template for last_param_impl
rg -n 'struct last_param_impl' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -A8

Repository: NVIDIA/cccl

Length of output: 465


🏁 Script executed:

# Check the test functions to see their signatures
rg -n 'test_first_output_param|test_last_output_param' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B5 -A5

Repository: NVIDIA/cccl

Length of output: 897


🏁 Script executed:

# Search for zero-parameter functions or void-returning functions that might trigger this
rg 'extern\s+\w+\s+\w+\s*\(\s*\)' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh

Repository: NVIDIA/cccl

Length of output: 37


🏁 Script executed:

# Check the function_last_param_impl specialization to see if it handles empty Ps
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '259,280p'

Repository: NVIDIA/cccl

Length of output: 781


🏁 Script executed:

# Try to understand if there's a fallback when a primary template isn't specialized
# Look for any SFINAE-like pattern or detection helpers
rg -n 'dependent_false|detected|void_t' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B3 -A3

Repository: NVIDIA/cccl

Length of output: 437


🏁 Script executed:

# Check if there are any truly zero-parameter functions being used with cuda_try
# Look at test helpers defined in the file
rg -n 'inline.*\(\s*\)' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A6

Repository: NVIDIA/cccl

Length of output: 37


🏁 Script executed:

# Try to compile a test case with a zero-param function to verify the issue
# First, let's see the full context of lines 486-496 for the test functions
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '475,510p'

Repository: NVIDIA/cccl

Length of output: 1236


🏁 Script executed:

# Let's check the actual signatures of the CUDA functions being called
# by examining if there are any comments or documentation about them
rg -n 'cudaGetDevice|cudaGetDeviceCount|cudaMemGetInfo' cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh -B2 -A2

Repository: NVIDIA/cccl

Length of output: 1572


🏁 Script executed:

# Let me create a minimal test to understand the actual behavior
# Check if there are any compile tests or error expectations
fd 'cuda_try.*test' --type f --glob '*.cu' --glob '*.cuh' | head -20

Repository: NVIDIA/cccl

Length of output: 199


🏁 Script executed:

# Look for the actual CUDA function declarations or their signatures
rg -n 'extern.*cudaGetDevice\|extern.*cudaGetDeviceCount' cudax/include -R --type=cu --type=h --type=cuh --type=hpp

Repository: NVIDIA/cccl

Length of output: 460


🏁 Script executed:

# Try a different approach: look at what happens if first_param or last_param can't be resolved
# Check if there's a primary (unspecialized) template and what error it would give
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '236,248p'

Repository: NVIDIA/cccl

Length of output: 357


🏁 Script executed:

# Let me check the actual function signatures of the CUDA API being called
# by looking at the documentation or examining if cudaGetDevice actually takes parameters
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '414,430p'

Repository: NVIDIA/cccl

Length of output: 1275


🏁 Script executed:

# Let's verify: does cudaGetDevice take a parameter? Check the call site at line 501
# The code says: int dev = cuda_try<cudaGetDevice>();
# This suggests it's being called in "output form" which means the function takes a pointer parameter
# Let's check if there's any documentation about this
cat -n cudax/include/cuda/experimental/__stf/utility/cuda_safe_call.cuh | sed -n '410,440p'

Repository: NVIDIA/cccl

Length of output: 2189


🏁 Script executed:

# Key question: are there ANY functions with zero parameters?
# Let's examine the comment that mentions ambiguity - it shows cudaMemGetInfo which takes 2 parameters
# Let me check if there's any real zero-parameter function in the CUDA API or test code
# by looking at what happens when last_param_impl<> is instantiated
python3 <<'EOF'
# Check: does last_param_impl have a primary template?
# If not, trying to instantiate last_param_impl<> (empty pack) would fail with:
# "error: implicit instantiation of undefined template 'last_param_impl<>'"

# This happens at line 453 when evaluating:
# ::std::is_pointer_v<reserved::last_param<fun>>
# for a zero-parameter function

# The issue is REAL: the trait instantiation happens eagerly in constexpr context
# before the condition can be short-circuited or reported as SFINAE failure

print("Analysis of the issue:")
print("1. For zero-parameter function signatures R (*)():")
print("   - first_param_impl<R (*)()> is undefined → hard error")
print("   - last_param_impl<> is undefined → hard error")
print()
print("2. Current code evaluates constexpr bool conditions eagerly:")
print("   - constexpr bool first_output_form = ::std::is_pointer_v<reserved::first_param<fun>>")
print("   - This instantiates first_param<fun>, which fails hard if first_param_impl is not specialized")
print()
print("3. The fallback static_assert at line 481 would never be reached")
print()
print("4. However, in the current codebase:")
print("   - All CUDA functions used have at least one parameter")
print("   - The test functions also have at least one parameter")
print("   - So this is a latent issue, not currently triggered")
EOF

Repository: NVIDIA/cccl

Length of output: 756


important: first_output_form and last_output_form eagerly instantiate reserved::first_param<fun> and reserved::last_param<fun>. For functions that take no parameters, these traits have no matching specialization, so the compiler reports a hard template error before it reaches the intended fallback static_assert at line 481. Make these checks SFINAE-safe by using detection traits that resolve to false instead of producing a hard error when parameter extraction is ill-formed, so that diagnostics remain deterministic and the code can handle future zero-parameter function signatures.
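One way to address this kind of concern, shown here only as a hedged sketch, is to give the extraction traits a defined primary template that yields void, so a zero-parameter signature produces false rather than a compile error. The safe_* names and the mock functions are hypothetical and are not part of the PR; whether the merged code adopted anything like this is not stated in the thread.

```cpp
#include <type_traits>

enum status_t { ok = 0 };  // hypothetical stand-in for cudaError_t

// Primary templates are defined and yield void, so extraction never hard-errors.
template <class F> struct safe_first_param { using type = void; };
template <class R, class P, class... Ps>
struct safe_first_param<R (*)(P, Ps...)> { using type = P; };

template <class... Ps> struct safe_last { using type = void; };  // empty pack
template <class P> struct safe_last<P> { using type = P; };
template <class P, class Q, class... Ps>
struct safe_last<P, Q, Ps...> : safe_last<Q, Ps...> {};

template <class F> struct safe_last_param { using type = void; };
template <class R, class... Ps>
struct safe_last_param<R (*)(Ps...)> : safe_last<Ps...> {};

// Pointer to non-const; false for void and for non-pointer types.
template <class T>
constexpr bool is_output_ptr = std::is_pointer_v<T> && !std::is_const_v<std::remove_pointer_t<T>>;

template <auto fun, class... Args>
constexpr bool first_output_form() {
  using T = typename safe_first_param<decltype(fun)>::type;
  if constexpr (is_output_ptr<T>) {
    return std::is_invocable_v<decltype(fun), T, Args...>;
  } else {
    return false;  // includes the zero-parameter case (T is void)
  }
}

template <auto fun, class... Args>
constexpr bool last_output_form() {
  using T = typename safe_last_param<decltype(fun)>::type;
  if constexpr (is_output_ptr<T>) {
    return std::is_invocable_v<decltype(fun), Args..., T>;
  } else {
    return false;
  }
}

// Mocks: no parameters, one output pointer, one const input pointer.
inline status_t mock_no_params() { return ok; }
inline status_t mock_get_device(int* dev) { *dev = 3; return ok; }
inline status_t mock_const_in(const int*) { return ok; }

static_assert(!first_output_form<mock_no_params>() && !last_output_form<mock_no_params>());
static_assert(first_output_form<mock_get_device>() && last_output_form<mock_get_device>());
static_assert(!first_output_form<mock_const_in>() && !last_output_form<mock_const_in>());
```

With both predicates evaluating to false for a zero-parameter function, a dispatcher built on them could fall through to its own static_assert and emit the intended diagnostic.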

@caugonnet
Contributor

/ok to test cebc63c

@github-actions
Contributor

🥳 CI Workflow Results

🟩 Finished in 1h 04m: Pass: 100%/55 | Total: 1d 02h | Max: 1h 04m | Hits: 10%/195214

See results here.

@caugonnet caugonnet merged commit eaba087 into main May 21, 2026
73 of 76 checks passed

Labels

stf Sequential Task Flow programming model

Projects

Archived in project


3 participants