Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Continue optimizing branch miss of if function when result type is float*/decimal*/int* #59148

Merged
merged 14 commits into from Jan 31, 2024
172 changes: 133 additions & 39 deletions src/Functions/if.cpp
Expand Up @@ -42,12 +42,32 @@ using namespace GatherUtils;
/** Selection function by condition: if(cond, then, else).
* cond - UInt8
* then, else - numeric types for which there is a general type, or dates, datetimes, or strings, or arrays of these types.
* For better performance, try to use branch free code for numeric types(i.e. cond ? a : b --> !!cond * a + !cond * b), except floating point types because of Inf or NaN.
* For better performance, try to use branch free code for numeric types(i.e. cond ? a : b --> !!cond * a + !cond * b)
*/

template <typename ResultType>
taiyang-li marked this conversation as resolved.
Show resolved Hide resolved
concept is_native_int_or_decimal_v
= std::is_integral_v<ResultType> || (is_decimal<ResultType> && sizeof(ResultType) <= 8);

// This macro performs a branch-free conditional assignment for floating point types.
// It uses bitwise operations to avoid branching, which can be beneficial for performance.
#define BRANCHFREE_IF_FLOAT(TYPE, vc, va, vb, vr) \
using UIntType = typename NumberTraits::Construct<false, false, sizeof(TYPE)>::Type; \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interestingly, we have a similar approach in AggregateFunctionSum:

static_assert(sizeof(Value) == 4 || sizeof(Value) == 8);
using equivalent_integer = typename std::conditional_t<sizeof(Value) == 4, UInt32, UInt64>;
constexpr size_t unroll_count = 128 / sizeof(T);
T partial_sums[unroll_count]{};
const auto * unrolled_end = ptr + (count / unroll_count * unroll_count);
while (ptr < unrolled_end)
{
for (size_t i = 0; i < unroll_count; ++i)
{
equivalent_integer value;
std::memcpy(&value, &ptr[i], sizeof(Value));
value &= (!condition_map[i] != add_if_zero) - 1;
Value d;
std::memcpy(&d, &value, sizeof(Value));

using IntType = typename NumberTraits::Construct<true, false, sizeof(TYPE)>::Type; \
auto mask = static_cast<UIntType>(static_cast<IntType>(vc) - 1); \
auto new_a = static_cast<ResultType>(va); \
auto new_b = static_cast<ResultType>(vb); \
UIntType uint_a; \
std::memcpy(&uint_a, &new_a, sizeof(UIntType)); \
UIntType uint_b; \
std::memcpy(&uint_b, &new_b, sizeof(UIntType)); \
UIntType tmp = (~mask & uint_a) | (mask & uint_b); \
(vr) = *(reinterpret_cast<ResultType *>(&tmp));

template <typename ArrayCond, typename ArrayA, typename ArrayB, typename ArrayResult, typename ResultType>
inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const ArrayB & b, ArrayResult & res)
{

size_t size = cond.size();
bool a_is_short = a.size() < size;
bool b_is_short = b.size() < size;
Expand All @@ -57,47 +77,68 @@ inline void fillVectorVector(const ArrayCond & cond, const ArrayA & a, const Arr
size_t a_index = 0, b_index = 0;
for (size_t i = 0; i < size; ++i)
{
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
a_index += !!cond[i];
b_index += !cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[b_index++]);
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[b_index]);

a_index += !!cond[i];
b_index += !cond[i];
}
}
else if (a_is_short)
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b[i]);
a_index += !!cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b[i]);
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b[i]);

a_index += !!cond[i];
}
}
else if (b_is_short)
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index++]);
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[b_index]);

b_index += !cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b[i], res[i])
}
else
{
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b[i]);
}
}
}
}

Expand All @@ -110,21 +151,32 @@ inline void fillVectorConstant(const ArrayCond & cond, const ArrayA & a, B b, Ar
{
size_t a_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[a_index]) + (!cond[i]) * static_cast<ResultType>(b);
a_index += !!cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[a_index], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[a_index++]) : static_cast<ResultType>(b);
res[i] = cond[i] ? static_cast<ResultType>(a[a_index]) : static_cast<ResultType>(b);

a_index += !!cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a[i]) + (!cond[i]) * static_cast<ResultType>(b);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a[i], b, res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a[i]) : static_cast<ResultType>(b);
}
}
}

Expand All @@ -137,21 +189,68 @@ inline void fillConstantVector(const ArrayCond & cond, A a, const ArrayB & b, Ar
{
size_t b_index = 0;
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[b_index]);
b_index += !cond[i];
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[b_index], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index++]);
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[b_index]);

b_index += !cond[i];
}
}
else
{
for (size_t i = 0; i < size; ++i)
if constexpr (std::is_integral_v<ResultType>)
{
if constexpr (is_native_int_or_decimal_v<ResultType>)
res[i] = !!cond[i] * static_cast<ResultType>(a) + (!cond[i]) * static_cast<ResultType>(b[i]);
else if constexpr (std::is_floating_point_v<ResultType>)
{
BRANCHFREE_IF_FLOAT(ResultType, cond[i], a, b[i], res[i])
}
else
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b[i]);
}
}
}

template <typename ArrayCond, typename A, typename B, typename ArrayResult, typename ResultType>
inline void fillConstantConstant(const ArrayCond & cond, A a, B b, ArrayResult & res)
{
size_t size = cond.size();

/// Int8(alias type of uint8_t) has special aliasing properties that prevents compiler from auto-vectorizing for below codes, refer to https://gist.github.com/alexei-zaripov/dcc14c78819c5f1354afe8b70932007c
///
/// for (size_t i = 0; i < size; ++i)
/// res[i] = cond[i] ? static_cast<Int8>(a) : static_cast<Int8>(b);
///
/// Therefore, we manually optimize it by avoiding branch miss when ResultType is Int8. Other types like (U)Int128|256 or Decimal128/256 also benefit from this optimization.
if constexpr (std::is_same_v<ResultType, Int8> || is_over_big_int<ResultType>)
taiyang-li marked this conversation as resolved.
Show resolved Hide resolved
{
alignas(64) const ResultType ab[2] = {static_cast<ResultType>(a), static_cast<ResultType>(b)};
for (size_t i = 0; i < size; ++i)
{
res[i] = ab[!cond[i]];
}
}
else if constexpr (std::is_same_v<ResultType, Decimal32> || std::is_same_v<ResultType, Decimal64>)
{
ResultType new_a = static_cast<ResultType>(a);
ResultType new_b = static_cast<ResultType>(b);
for (size_t i = 0; i < size; ++i)
{
/// Reuse new_a and new_b to achieve auto-vectorization
res[i] = cond[i] ? new_a : new_b;
}
}
else
{
for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
}
}

Expand Down Expand Up @@ -197,8 +296,7 @@ struct NumIfImpl
auto col_res = ColVecResult::create(size);
ArrayResult & res = col_res->getData();

for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
fillConstantConstant<ArrayCond, A, B, ArrayResult, ResultType>(cond, a, b, res);
return col_res;
}
};
Expand Down Expand Up @@ -247,8 +345,7 @@ struct NumIfImpl<Decimal<A>, Decimal<B>, Decimal<R>>
auto col_res = ColVecResult::create(size, scale);
ArrayResult & res = col_res->getData();

for (size_t i = 0; i < size; ++i)
res[i] = cond[i] ? static_cast<ResultType>(a) : static_cast<ResultType>(b);
fillConstantConstant<ArrayCond, A, B, ArrayResult, ResultType>(cond, a, b, res);
return col_res;
}
};
Expand Down Expand Up @@ -1112,17 +1209,12 @@ class FunctionIf : public FunctionIfBase

if (cond_const_col)
{
if (arg_then.type->equals(*arg_else.type))
{
return cond_const_col->getValue<UInt8>()
? arg_then.column
: arg_else.column;
}
UInt8 value = cond_const_col->getValue<UInt8>();
taiyang-li marked this conversation as resolved.
Show resolved Hide resolved
const ColumnWithTypeAndName & arg = value ? arg_then : arg_else;
if (arg.type->equals(*result_type))
return arg.column;
else
{
materialized_cond_col = cond_const_col->convertToFullColumn();
cond_col = typeid_cast<const ColumnUInt8 *>(&*materialized_cond_col);
}
return castColumn(arg, result_type);
}

if (!cond_col)
Expand Down Expand Up @@ -1159,6 +1251,8 @@ class FunctionIf : public FunctionIfBase
TypeIndex left_id = left_type->getTypeId();
TypeIndex right_id = right_type->getTypeId();

/// TODO optimize for map type
/// TODO optimize for nullable type
if (!(callOnBasicTypes<true, true, true, false>(left_id, right_id, call)
|| (res = executeTyped<UUID, UUID>(cond_col, arguments, result_type, input_rows_count))
|| (res = executeString(cond_col, arguments, result_type))
Expand Down
30 changes: 21 additions & 9 deletions tests/performance/if.xml
@@ -1,12 +1,24 @@
<test>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() > 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 3865470566, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 2147483647, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>

<!-- Tests when branches are both not constant -->
<query>with rand32() % 2 as x select if(x, materialize(1.234), materialize(2.456)) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1.234::Decimal64(3) as a, 2.456::Decimal64(3) as b select if(x, materialize(a), materialize(b)) from numbers(100000000) format Null</query>

<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() > 42949673, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 3865470566, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 2147483647, zero + 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, zero + 2)) ]]></query>

<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, zero + 1, 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, zero + 2)) ]]></query>
<query><![CDATA[ SELECT count() FROM zeros(1000000000) WHERE NOT ignore(if(rand32() < 42949673, 1, 2)) ]]></query>

<!-- Tests when branches are both constant -->
<query>with rand32() % 2 as x, 1::Int8 as a, -1::Int8 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int64 as a, -1::Int64 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int32 as a, -1::Int32 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal32(3) as a, -1::Decimal32(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal64(3) as a, -1::Decimal64(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal128(3) as a, -1::Decimal128(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Decimal256(3) as a, -1::Decimal256(3) as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int128 as a, -1::Int128 as b select if(x, a, b) from numbers(100000000) format Null</query>
<query>with rand32() % 2 as x, 1::Int256 as a, -1::Int256 as b select if(x, a, b) from numbers(100000000) format Null</query>
</test>