-
Notifications
You must be signed in to change notification settings - Fork 333
Support fancy iterators in vectorized transform and port thrust::tabulate to it #6012
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
6a14dbf
9644377
5c48e92
e1d6ed9
ce02b53
6543c30
2fbf093
c13c974
c2cb534
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -282,21 +282,45 @@ _CCCL_HOST_DEVICE constexpr int arch_to_min_bytes_in_flight(int sm_arch) | |
| return 12 * 1024; // V100 and below | ||
| } | ||
|
|
||
| template <typename T, typename... Ts> | ||
| _CCCL_HOST_DEVICE constexpr bool all_equal([[maybe_unused]] T head, Ts... tail) | ||
| template <typename H, typename... Ts> | ||
| _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal(H head, Ts... values) | ||
| { | ||
| return ((head == tail) && ...); | ||
| size_t first = 0; | ||
| for (size_t v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...}) | ||
| { | ||
| if (v == 0) | ||
| { | ||
| continue; | ||
| } | ||
| if (first == 0) | ||
| { | ||
| first = v; | ||
| } | ||
| else if (v != first) | ||
| { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| } | ||
|
|
||
| _CCCL_HOST_DEVICE constexpr bool all_equal() | ||
| _CCCL_HOST_DEVICE constexpr bool all_nonzero_equal() | ||
| { | ||
| return true; | ||
| } | ||
|
|
||
| template <typename T, typename... Ts> | ||
| _CCCL_HOST_DEVICE constexpr auto first_item(T head, Ts...) -> T | ||
| template <typename H, typename... Ts> | ||
| _CCCL_HOST_DEVICE constexpr auto first_nonzero_value(H head, Ts... values) | ||
| { | ||
| return head; | ||
| for (auto v : ::cuda::std::array<H, 1 + sizeof...(Ts)>{head, values...}) | ||
| { | ||
| if (v != 0) | ||
| { | ||
| return v; | ||
| } | ||
| } | ||
| // we only reach here when none of the inputs are contiguous and the output has a void value type | ||
| return H{1}; | ||
| } | ||
|
|
||
| template <typename T> | ||
|
|
@@ -336,25 +360,36 @@ struct policy_hub<RequiresStableAddress, | |
| (THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn> && ...); | ||
| static constexpr bool all_input_values_trivially_reloc = | ||
| (THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>> && ...); | ||
| static constexpr bool can_memcpy_inputs = all_inputs_contiguous && all_input_values_trivially_reloc; | ||
| static constexpr bool can_memcpy_all_inputs = all_inputs_contiguous && all_input_values_trivially_reloc; | ||
| // the vectorized kernel supports mixing contiguous and non-contiguous iterators | ||
| static constexpr bool can_memcpy_contiguous_inputs = | ||
| ((!THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn> | ||
| || THRUST_NS_QUALIFIER::is_trivially_relocatable_v<it_value_t<RandomAccessIteratorsIn>>) | ||
| && ...); | ||
|
Comment on lines
+365
to
+368
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should have a trait for that; I guess it will come up more often. |
||
|
|
||
| // for vectorized policy: | ||
| static constexpr bool all_input_values_same_size = all_equal(sizeof(it_value_t<RandomAccessIteratorsIn>)...); | ||
| static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+ | ||
| // if there are no inputs, we take the size of the output value | ||
| static constexpr int value_type_size = | ||
| first_item(int{sizeof(it_value_t<RandomAccessIteratorsIn>)}..., int{size_of<it_value_t<RandomAccessIteratorOut>>}); | ||
| static constexpr bool all_contiguous_input_values_same_size = all_nonzero_equal( | ||
| (sizeof(it_value_t<RandomAccessIteratorsIn>) | ||
| * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ...); | ||
| static constexpr int load_store_word_size = 8; // TODO(bgruber): make this 16, and 32 on Blackwell+ | ||
| // find the value type size of the first contiguous iterator. if there are no inputs, we take the size of the output | ||
| // value type | ||
| static constexpr int contiguous_value_type_size = first_nonzero_value( | ||
| (int{sizeof(it_value_t<RandomAccessIteratorsIn>)} | ||
| * THRUST_NS_QUALIFIER::is_contiguous_iterator_v<RandomAccessIteratorsIn>) ..., | ||
| int{size_of<it_value_t<RandomAccessIteratorOut>>}); | ||
|
Comment on lines
+378
to
+380
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. 🙀 I believe those warrant a slightly more elaborate comment.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Added.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this needs a bigger refactoring in general. |
||
| static constexpr bool value_type_divides_load_store_size = | ||
| load_store_word_size % value_type_size == 0; // implicitly checks that value_type_size <= load_store_word_size | ||
| load_store_word_size % contiguous_value_type_size == 0; // implicitly checks that contiguous_value_type_size <= | ||
| // load_store_word_size | ||
| static constexpr int target_bytes_per_thread = | ||
| no_input_streams ? 16 /* by experiment on RTX 5090 */ : 32 /* guesstimate by gevtushenko for loading */; | ||
| static constexpr int items_per_thread_vec = | ||
| ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / value_type_size; | ||
| ::cuda::round_up(target_bytes_per_thread, load_store_word_size) / contiguous_value_type_size; | ||
| using default_vectorized_policy_t = vectorized_policy_t<256, items_per_thread_vec, load_store_word_size>; | ||
|
|
||
| static constexpr bool fallback_to_prefetch = | ||
| RequiresStableAddress || !can_memcpy_inputs || !all_input_values_same_size || !value_type_divides_load_store_size | ||
| || !DenseOutput; | ||
| RequiresStableAddress || !can_memcpy_contiguous_inputs || !all_contiguous_input_values_same_size | ||
| || !value_type_divides_load_store_size || !DenseOutput; | ||
|
|
||
| // TODO(bgruber): consider a separate kernel for just filling | ||
|
|
||
|
|
@@ -380,7 +415,7 @@ struct policy_hub<RequiresStableAddress, | |
| block_threads* async_policy::min_items_per_thread, | ||
| ldgsts_size_and_align) | ||
| > int{max_smem_per_block}; | ||
| static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams; | ||
| static constexpr bool fallback_to_vectorized = exhaust_smem || no_input_streams || !can_memcpy_all_inputs; | ||
|
|
||
| public: | ||
| static constexpr int min_bif = arch_to_min_bytes_in_flight(800); | ||
|
|
@@ -421,7 +456,8 @@ struct policy_hub<RequiresStableAddress, | |
| (((int{sizeof(it_value_t<RandomAccessIteratorsIn>)} * AsyncBlockSize) % max_alignment == 0) && ...); | ||
| static constexpr bool enough_threads_for_peeling = AsyncBlockSize >= alignment; // head and tail bytes | ||
| static constexpr bool fallback_to_vectorized = | ||
| exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams; | ||
| exhaust_smem || !tile_sizes_retain_alignment || !enough_threads_for_peeling || no_input_streams | ||
| || !can_memcpy_all_inputs; | ||
|
|
||
| public: | ||
| static constexpr int min_bif = arch_to_min_bytes_in_flight(PtxVersion); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should really pull that out into a function
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am strongly considering refactoring the entire mess, so let's postpone any small fixes for now.