Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix bug with intDiv for decimal arguments #59243

Merged
merged 13 commits into from Feb 19, 2024
134 changes: 114 additions & 20 deletions src/Functions/FunctionBinaryArithmetic.h
Expand Up @@ -146,10 +146,25 @@ struct BinaryOperationTraits

public:
static constexpr bool allow_decimal = IsOperation<Operation>::allow_decimal;
static constexpr bool only_integer = IsOperation<Operation>::div_int || IsOperation<Operation>::div_int_or_zero;

/// Appropriate result type for binary operator on numeric types. "Date" can also mean
/// DateTime, but if both operands are Dates, their type must be the same (e.g. Date - DateTime is invalid).
using ResultDataType = Switch<
/// Result must be Integer
Case<
only_integer && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>),
Switch<
Case<
IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>,
Avogar marked this conversation as resolved.
Show resolved Hide resolved
Switch<
Case<IsIntegralOrExtended<LeftDataType>, LeftDataType>,
Case<IsIntegralOrExtended<RightDataType>, RightDataType>,
Avogar marked this conversation as resolved.
Show resolved Hide resolved
Case<std::is_same_v<LeftDataType, DataTypeDecimal256> || std::is_same_v<RightDataType, DataTypeDecimal256>, DataTypeInt256>,
Case<std::is_same_v<LeftDataType, DataTypeDecimal128> || std::is_same_v<RightDataType, DataTypeDecimal128>, DataTypeInt128>,
Case<std::is_same_v<LeftDataType, DataTypeDecimal64> || std::is_same_v<RightDataType, DataTypeDecimal64>, DataTypeInt64>,
Case<std::is_same_v<LeftDataType, DataTypeDecimal32> || std::is_same_v<RightDataType, DataTypeDecimal32>, DataTypeInt32>>>>>,

/// Decimal cases
Case<!allow_decimal && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>), InvalidType>,
Case<
Expand Down Expand Up @@ -1667,31 +1682,102 @@ class FunctionBinaryArithmetic : public IFunction
{
if constexpr (IsDataTypeDecimal<LeftDataType> && IsDataTypeDecimal<RightDataType>)
{
if constexpr (is_division)
if constexpr (is_div_int || is_div_int_or_zero)
{
if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal256> || std::is_same_v<RightDataType, DataTypeDecimal256>)
type_res = std::make_shared<DataTypeInt256>();
else if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal128> || std::is_same_v<RightDataType, DataTypeDecimal128>)
type_res = std::make_shared<DataTypeInt128>();
else if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal64> || std::is_same_v<RightDataType, DataTypeDecimal64>)
type_res = std::make_shared<DataTypeInt64>();
else
type_res = std::make_shared<DataTypeInt32>();
}
else
{
if (context->getSettingsRef().decimal_check_overflow)
if constexpr (is_division)
{
/// Check overflow by using operands scale (based on big decimal division implementation details):
/// big decimal arithmetic is based on big integers, decimal operands are converted to big integers
/// i.e. int_operand = decimal_operand*10^scale
/// For division, left operand will be scaled by right operand scale also to do big integer division,
/// BigInt result = left*10^(left_scale + right_scale) / right * 10^right_scale
/// So, we can check upfront possible overflow just by checking max scale used for left operand
/// Note: it doesn't detect all possible overflow during big decimal division
if (left.getScale() + right.getScale() > ResultDataType::maxPrecision())
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Overflow during decimal division");
if (context->getSettingsRef().decimal_check_overflow)
{
/// Check overflow by using operands scale (based on big decimal division implementation details):
/// big decimal arithmetic is based on big integers, decimal operands are converted to big integers
/// i.e. int_operand = decimal_operand*10^scale
/// For division, left operand will be scaled by right operand scale also to do big integer division,
/// BigInt result = left*10^(left_scale + right_scale) / right * 10^right_scale
/// So, we can check upfront possible overflow just by checking max scale used for left operand
/// Note: it doesn't detect all possible overflow during big decimal division
if (left.getScale() + right.getScale() > ResultDataType::maxPrecision())
throw Exception(ErrorCodes::DECIMAL_OVERFLOW, "Overflow during decimal division");
}
}
ResultDataType result_type = decimalResultType<is_multiply, is_division>(left, right);
type_res = std::make_shared<ResultDataType>(result_type.getPrecision(), result_type.getScale());
Avogar marked this conversation as resolved.
Show resolved Hide resolved
}
ResultDataType result_type = decimalResultType<is_multiply, is_division>(left, right);
type_res = std::make_shared<ResultDataType>(result_type.getPrecision(), result_type.getScale());
}
else if constexpr ((IsDataTypeDecimal<LeftDataType> && IsFloatingPoint<RightDataType>) ||
(IsDataTypeDecimal<RightDataType> && IsFloatingPoint<LeftDataType>))
type_res = std::make_shared<DataTypeFloat64>();
else if constexpr (((IsDataTypeDecimal<LeftDataType> && IsFloatingPoint<RightDataType>) ||
(IsDataTypeDecimal<RightDataType> && IsFloatingPoint<LeftDataType>)) && !(is_div_int || is_div_int_or_zero))
{
if constexpr ((is_div_int || is_div_int_or_zero) && IsDataTypeDecimal<LeftDataType>)
{
if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal256>)
type_res = std::make_shared<DataTypeInt256>();
else if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal128>)
type_res = std::make_shared<DataTypeInt128>();
else if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal64> || std::is_same_v<RightDataType, DataTypeFloat64>)
type_res = std::make_shared<DataTypeInt64>();
else
type_res = std::make_shared<DataTypeInt32>();
}
else if constexpr (is_div_int || is_div_int_or_zero)
{
if constexpr (std::is_same_v<RightDataType, DataTypeDecimal256>)
type_res = std::make_shared<DataTypeInt256>();
else if constexpr (std::is_same_v<RightDataType, DataTypeDecimal128>)
type_res = std::make_shared<DataTypeInt128>();
else if constexpr (std::is_same_v<RightDataType, DataTypeDecimal64> || std::is_same_v<LeftDataType, DataTypeFloat64>)
type_res = std::make_shared<DataTypeInt64>();
else
type_res = std::make_shared<DataTypeInt32>();
}
else
type_res = std::make_shared<DataTypeFloat64>();
}
else if constexpr (IsDataTypeDecimal<LeftDataType>)
type_res = std::make_shared<LeftDataType>(left.getPrecision(), left.getScale());
{
if constexpr ((is_div_int || is_div_int_or_zero) && IsIntegralOrExtended<RightDataType>)
type_res = std::make_shared<RightDataType>();
else if constexpr (is_div_int || is_div_int_or_zero)
{
if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal256>)
type_res = std::make_shared<DataTypeInt256>();
else if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal128>)
type_res = std::make_shared<DataTypeInt128>();
else if constexpr (std::is_same_v<LeftDataType, DataTypeDecimal64>)
type_res = std::make_shared<DataTypeInt64>();
else
type_res = std::make_shared<DataTypeInt32>();
}
else
type_res = std::make_shared<LeftDataType>(left.getPrecision(), left.getScale());
}
else if constexpr (IsDataTypeDecimal<RightDataType>)
type_res = std::make_shared<RightDataType>(right.getPrecision(), right.getScale());
{
if constexpr ((is_div_int || is_div_int_or_zero) && IsIntegral<LeftDataType>)
type_res = std::make_shared<LeftDataType>();
else if constexpr (is_div_int || is_div_int_or_zero)
{
if constexpr (std::is_same_v<RightDataType, DataTypeDecimal256>)
type_res = std::make_shared<DataTypeInt256>();
else if constexpr (std::is_same_v<RightDataType, DataTypeDecimal128>)
type_res = std::make_shared<DataTypeInt128>();
else if constexpr (std::is_same_v<RightDataType, DataTypeDecimal64>)
type_res = std::make_shared<DataTypeInt64>();
else
type_res = std::make_shared<DataTypeInt32>();
}
else
type_res = std::make_shared<RightDataType>(right.getPrecision(), right.getScale());
}
else if constexpr (std::is_same_v<ResultDataType, DataTypeDateTime>)
{
// Special case for DateTime: binary OPS should reuse timezone
Expand Down Expand Up @@ -2009,8 +2095,10 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
constexpr bool decimal_with_float = (IsDataTypeDecimal<LeftDataType> && IsFloatingPoint<RightDataType>)
|| (IsFloatingPoint<LeftDataType> && IsDataTypeDecimal<RightDataType>);

using T0 = std::conditional_t<decimal_with_float, Float64, typename LeftDataType::FieldType>;
using T1 = std::conditional_t<decimal_with_float, Float64, typename RightDataType::FieldType>;
constexpr bool is_div_int_with_decimal = (is_div_int || is_div_int_or_zero) && (IsDataTypeDecimal<LeftDataType> || IsDataTypeDecimal<RightDataType>);

using T0 = std::conditional_t<decimal_with_float, Float64, std::conditional_t<is_div_int_with_decimal, Int64, typename LeftDataType::FieldType>>;
using T1 = std::conditional_t<decimal_with_float, Float64, std::conditional_t<is_div_int_with_decimal, Int64, typename RightDataType::FieldType>>;
using ResultType = typename ResultDataType::FieldType;
using ColVecT0 = ColumnVectorOrDecimal<T0>;
using ColVecT1 = ColumnVectorOrDecimal<T1>;
Expand All @@ -2026,6 +2114,12 @@ ColumnPtr executeStringInteger(const ColumnsWithTypeAndName & arguments, const A
left_col = castColumn(arguments[0], converted_type);
right_col = castColumn(arguments[1], converted_type);
}
else if constexpr (is_div_int_with_decimal)
{
const auto converted_type = std::make_shared<DataTypeInt64>();
left_col = castColumn(arguments[0], converted_type);
right_col = castColumn(arguments[1], converted_type);
}
Avogar marked this conversation as resolved.
Show resolved Hide resolved
else
{
left_col = arguments[0].column;
Expand Down
2 changes: 1 addition & 1 deletion src/Functions/IsOperation.h
Expand Up @@ -61,7 +61,7 @@ struct IsOperation
static constexpr bool bit_hamming_distance = IsSameOperation<Op, BitHammingDistanceImpl>::value;

static constexpr bool division = div_floating || div_int || div_int_or_zero || modulo;

// NOTE: allow_decimal should not fully contain `division` because of divInt
static constexpr bool allow_decimal = plus || minus || multiply || division || least || greatest;
};

Expand Down
14 changes: 7 additions & 7 deletions tests/queries/0_stateless/00700_decimal_arithm.reference
Expand Up @@ -10,18 +10,18 @@
63 21 -42 882 -882 2 0 2 0
63 21 -42 882 -882 2 0 2 0
1.00305798474369219219752355409390731264 -0.16305798474369219219752355409390731264 1.490591730234615865843651857942052864 -1.38847100762815390390123822295304634368 1.38847100762815390390123822295304634368 0.02 0.005
63.42 21.42 -41.58 890.82 -890.82 2.02 0.505 2.02 0.505
63.42 21.42 -41.58 890.82 -890.82 2.02 0.505 2.02 0.505
63.42 21.42 -41.58 890.82 -890.82 2.02 0.505 2.02 0.505
63.42 21.42 -41.58 890.82 -890.82 2.02 0.5 2.02 0.5
63.42 21.42 -41.58 890.82 -890.82 2.02 0.505 2 0
63.42 21.42 -41.58 890.82 -890.82 2.02 0.505 2 0
63.42 21.42 -41.58 890.82 -890.82 2.02 0.505 2 0
63.42 21.42 -41.58 890.82 -890.82 2.02 0.5 2 0
63 -21 42 882 -882 0 2 0 2
63 -21 42 882 -882 0 2 0 2
63 -21 42 882 -882 0 2 0 2
1.00305798474369219219752355409390731264 0.16305798474369219219752355409390731264 -1.490591730234615865843651857942052864 -1.38847100762815390390123822295304634368 1.38847100762815390390123822295304634368 -0.00000000000000000000000000000000000001 0.00000000000000000000000000000000000001
63.42 -21.42 41.58 890.82 -890.82 0.495 1.98 0.495 1.98
63.42 -21.42 41.58 890.82 -890.82 0.495 1.98 0 2
63.42 -21.42 41.58 890.82 -890.82
63.42 -21.42 41.58 890.82 -890.82 0.495049504950495049 1.980198019801980198 0.495049504950495049 1.980198019801980198
63.42 -21.42 41.58 890.82 -890.82 0.49 1.98 0.49 1.98
63.42 -21.42 41.58 890.82 -890.82 0.495049504950495049 1.980198019801980198 0 2
63.42 -21.42 41.58 890.82 -890.82 0.49 1.98 0 2
-42 42 42 42 0.42 0.42 0.42 42.42 42.42 42.42
0 0 0 0 0 0 0 0 0 0
42 -42 -42 -42 -0.42 -0.42 -0.42 -42.42 -42.42 -42.42
Expand Down
@@ -1,2 +1,2 @@
SELECT intDiv(9223372036854775807, 0.9998999834060669); -- { serverError 153 }
SELECT intDiv(9223372036854775807, 1.); -- { serverError 153 }
SELECT intDiv(18446744073709551615, 0.9998999834060669); -- { serverError 153 }
SELECT intDiv(18446744073709551615, 1.); -- { serverError 153 }
68 changes: 68 additions & 0 deletions tests/queries/0_stateless/02975_intdiv_with_decimal.reference
@@ -0,0 +1,68 @@
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
1
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
2
1
1
1
1
2
2
2
2
70 changes: 70 additions & 0 deletions tests/queries/0_stateless/02975_intdiv_with_decimal.sql
@@ -0,0 +1,70 @@
--intDiv--
SELECT intDiv(4,2);
SELECT intDiv(toDecimal32(4.4, 2), 2);
SELECT intDiv(4, toDecimal32(2.2, 2));
SELECT intDiv(toDecimal32(4.4, 2), 2);
SELECT intDiv(toDecimal32(4.4, 2), toDecimal32(2.2, 2));
SELECT intDiv(toDecimal64(4.4, 3), 2);
SELECT intDiv(toDecimal64(4.4, 3), toDecimal32(2.2, 2));
SELECT intDiv(toDecimal128(4.4, 4), 2);
SELECT intDiv(toDecimal128(4.4, 4), toDecimal32(2.2, 2));
SELECT intDiv(toDecimal256(4.4, 5), 2);
SELECT intDiv(toDecimal256(4.4, 5), toDecimal32(2.2, 2));
SELECT intDiv(4, toDecimal64(2.2, 2));
SELECT intDiv(toDecimal32(4.4, 2), toDecimal64(2.2, 2));
SELECT intDiv(4, toDecimal128(2.2, 3));
SELECT intDiv(toDecimal32(4.4, 2), toDecimal128(2.2, 3));
SELECT intDiv(4, toDecimal256(2.2, 4));
SELECT intDiv(toDecimal32(4.4, 2), toDecimal256(2.2, 4));
SELECT intDiv(toDecimal64(4.4, 2), toDecimal64(2.2, 2));
SELECT intDiv(toDecimal128(4.4, 2), toDecimal64(2.2, 2));
SELECT intDiv(toDecimal256(4.4, 2), toDecimal64(2.2, 2));
SELECT intDiv(toDecimal64(4.4, 2), toDecimal128(2.2, 2));
SELECT intDiv(toDecimal128(4.4, 2), toDecimal128(2.2, 2));
SELECT intDiv(toDecimal256(4.4, 2), toDecimal128(2.2, 2));
SELECT intDiv(toDecimal64(4.4, 2), toDecimal256(2.2, 2));
SELECT intDiv(toDecimal128(4.4, 2), toDecimal256(2.2, 2));
SELECT intDiv(toDecimal256(4.4, 2), toDecimal256(2.2, 2));
SELECT intDiv(4.2, toDecimal32(2.2, 2));
SELECT intDiv(4.2, toDecimal64(2.2, 2));
SELECT intDiv(4.2, toDecimal128(2.2, 2));
SELECT intDiv(4.2, toDecimal256(2.2, 2));
SELECT intDiv(toDecimal32(4.4, 2), 2.2);
SELECT intDiv(toDecimal64(4.4, 2), 2.2);
SELECT intDiv(toDecimal128(4.4, 2), 2.2);
SELECT intDiv(toDecimal256(4.4, 2), 2.2);
--intDivOrZero--
SELECT intDivOrZero(4,2);
SELECT intDivOrZero(toDecimal32(4.4, 2), 2);
SELECT intDivOrZero(4, toDecimal32(2.2, 2));
SELECT intDivOrZero(toDecimal32(4.4, 2), 2);
SELECT intDivOrZero(toDecimal32(4.4, 2), toDecimal32(2.2, 2));
SELECT intDivOrZero(toDecimal64(4.4, 3), 2);
SELECT intDivOrZero(toDecimal64(4.4, 3), toDecimal32(2.2, 2));
SELECT intDivOrZero(toDecimal128(4.4, 4), 2);
SELECT intDivOrZero(toDecimal128(4.4, 4), toDecimal32(2.2, 2));
SELECT intDivOrZero(toDecimal256(4.4, 5), 2);
SELECT intDivOrZero(toDecimal256(4.4, 5), toDecimal32(2.2, 2));
SELECT intDivOrZero(4, toDecimal64(2.2, 2));
SELECT intDivOrZero(toDecimal32(4.4, 2), toDecimal64(2.2, 2));
SELECT intDivOrZero(4, toDecimal128(2.2, 3));
SELECT intDivOrZero(toDecimal32(4.4, 2), toDecimal128(2.2, 3));
SELECT intDivOrZero(4, toDecimal256(2.2, 4));
SELECT intDivOrZero(toDecimal32(4.4, 2), toDecimal256(2.2, 4));
SELECT intDivOrZero(toDecimal64(4.4, 2), toDecimal64(2.2, 2));
SELECT intDivOrZero(toDecimal128(4.4, 2), toDecimal64(2.2, 2));
SELECT intDivOrZero(toDecimal256(4.4, 2), toDecimal64(2.2, 2));
SELECT intDivOrZero(toDecimal64(4.4, 2), toDecimal128(2.2, 2));
SELECT intDivOrZero(toDecimal128(4.4, 2), toDecimal128(2.2, 2));
SELECT intDivOrZero(toDecimal256(4.4, 2), toDecimal128(2.2, 2));
SELECT intDivOrZero(toDecimal64(4.4, 2), toDecimal256(2.2, 2));
SELECT intDivOrZero(toDecimal128(4.4, 2), toDecimal256(2.2, 2));
SELECT intDivOrZero(toDecimal256(4.4, 2), toDecimal256(2.2, 2));
SELECT intDivOrZero(4.2, toDecimal32(2.2, 2));
SELECT intDivOrZero(4.2, toDecimal64(2.2, 2));
SELECT intDivOrZero(4.2, toDecimal128(2.2, 2));
SELECT intDivOrZero(4.2, toDecimal256(2.2, 2));
SELECT intDivOrZero(toDecimal32(4.4, 2), 2.2);
SELECT intDivOrZero(toDecimal64(4.4, 2), 2.2);
SELECT intDivOrZero(toDecimal128(4.4, 2), 2.2);
SELECT intDivOrZero(toDecimal256(4.4, 2), 2.2);