Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 38 additions & 3 deletions thrust/thrust/system/cuda/detail/partition.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# include <cub/util_device.cuh>
# include <cub/util_math.cuh>

# include <thrust/detail/raw_reference_cast.h>
# include <thrust/detail/temporary_array.h>
# include <thrust/partition.h>
# include <thrust/system/cuda/detail/cdp_dispatch.h>
Expand All @@ -30,9 +31,11 @@
# include <thrust/system/cuda/detail/uninitialized_copy.h>
# include <thrust/system/cuda/detail/util.h>

# include <cuda/__iterator/zip_iterator.h>
# include <cuda/std/__iterator/distance.h>
# include <cuda/std/__utility/pair.h>
# include <cuda/std/cstdint>
# include <cuda/std/tuple>

THRUST_NAMESPACE_BEGIN
namespace cuda_cub
Expand Down Expand Up @@ -365,13 +368,45 @@ stable_partition(execution_policy<Derived>& policy, Iterator first, Iterator las
return ret;
}

// Functor for the single-pass is_partitioned check.
// Returns true for an adjacent pair (a[i], a[i+1]) where pred(a[i]) is false
// and pred(a[i+1]) is true — i.e., a "false → true" transition that violates
// the partitioning invariant.
template <class Predicate>
struct __is_partitioned_fn
{
Predicate pred_;

template <class Tuple>
[[nodiscard]] _CCCL_HOST_DEVICE bool operator()(const Tuple& tuple) const
{
const bool lhs = pred_(thrust::raw_reference_cast(::cuda::std::get<0>(tuple)));
const bool rhs = pred_(thrust::raw_reference_cast(::cuda::std::get<1>(tuple)));
return !lhs && rhs;
}
};

// Single-pass implementation: zip adjacent elements and find any "false → true"
// transition. Two-pass (find_if_not + find_if) required two kernel launches;
// this approach uses one find_if over (a[i], a[i+1]) pairs, cutting kernel
// launch overhead roughly in half for typical inputs.
// See: https://github.com/NVIDIA/cccl/issues/8085
template <class Derived, class ItemsIt, class Predicate>
bool _CCCL_HOST_DEVICE
is_partitioned(execution_policy<Derived>& policy, ItemsIt first, ItemsIt last, Predicate predicate)
{
ItemsIt boundary = cuda_cub::find_if_not(policy, first, last, predicate);
ItemsIt end = cuda_cub::find_if(policy, boundary, last, predicate);
return end == last;
if (first == last)
{
return true;
}
// Build a range of adjacent pairs: (a[0],a[1]), (a[1],a[2]), ..., (a[n-2],a[n-1]).
// The distance of this zip range is min(n, n-1) = n-1 (via zip_iterator::operator-).
const auto first_zip = ::cuda::make_zip_iterator(first, first + 1);
const auto last_zip = ::cuda::make_zip_iterator(last, last);
const auto result = cuda_cub::find_if(policy, first_zip, last_zip, __is_partitioned_fn<Predicate>{predicate});
// Checking get<1>(result) == last (rather than result == last_zip) correctly
// handles the n==1 edge case where find_if_n returns first_zip (num_items==0).
return ::cuda::std::get<1>(result.__iterators()) == last;
}
} // namespace cuda_cub
THRUST_NAMESPACE_END
Expand Down