Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix __half build for compute < 53
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed Oct 19, 2021
1 parent 7ba073c commit d03cd6b
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 109 deletions.
54 changes: 39 additions & 15 deletions cub/agent/agent_sub_warp_merge_sort.cuh
Expand Up @@ -104,6 +104,42 @@ template <bool IS_DESCENDING,
typename OffsetT>
class AgentSubWarpSort
{
struct BinaryOpT
{
template <typename T>
__device__ bool operator()(T lhs, T rhs)
{
if (IS_DESCENDING)
{
return lhs > rhs;
}
else
{
return lhs < rhs;
}
}

#if defined(__CUDA_FP16_TYPES_EXIST__) && (CUB_PTX_ARCH < 530)
__device__ bool operator()(__half lhs, __half rhs)
{
return (*this)(static_cast<float>(lhs), static_cast<float>(rhs));
}
#endif
};

#if defined(__CUDA_FP16_TYPES_EXIST__) && (CUB_PTX_ARCH < 530)
__device__ static bool equal(__half lhs, __half rhs)
{
return static_cast<float>(lhs) == static_cast<float>(rhs);
}
#endif

template <typename T>
__device__ static bool equal(T lhs, T rhs)
{
return lhs == rhs;
}

public:
static constexpr bool KEYS_ONLY = std::is_same<ValueT, cub::NullType>::value;

Expand Down Expand Up @@ -161,18 +197,6 @@ public:
ItemsLoadItT values_input,
ValueT *values_output)
{
auto binary_op = [] (KeyT lhs, KeyT rhs) -> bool
{
if (IS_DESCENDING)
{
return lhs > rhs;
}
else
{
return lhs < rhs;
}
};

WarpMergeSortT warp_merge_sort(storage.sort);

if (segment_size < 3)
Expand All @@ -183,7 +207,7 @@ public:
keys_output,
values_input,
values_output,
binary_op);
BinaryOpT{});
}
else
{
Expand Down Expand Up @@ -211,7 +235,7 @@ public:
WARP_SYNC(warp_merge_sort.get_member_mask());
}

warp_merge_sort.Sort(keys, values, binary_op, segment_size, oob_default);
warp_merge_sort.Sort(keys, values, BinaryOpT{}, segment_size, oob_default);
WARP_SYNC(warp_merge_sort.get_member_mask());

WarpStoreKeysT(storage.store_keys).Store(keys_output, keys, segment_size);
Expand Down Expand Up @@ -264,7 +288,7 @@ private:
KeyT lhs = keys_input[0];
KeyT rhs = keys_input[1];

if (lhs == rhs || binary_op(lhs, rhs))
if (equal(lhs, rhs) || binary_op(lhs, rhs))
{
keys_output[0] = lhs;
keys_output[1] = rhs;
Expand Down

0 comments on commit d03cd6b

Please sign in to comment.