Skip to content

Commit

Permalink
ARROW-13617: [C++] Make Decimal representations consistent
Browse files Browse the repository at this point in the history
Factor out some basics of decimal representation and implementation in a generic base class.

Closes #12134 from pitrou/ARROW-13617-decimal-common

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
pitrou committed Jan 13, 2022
1 parent ab86daf commit d67a210
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 277 deletions.
36 changes: 13 additions & 23 deletions cpp/src/arrow/scalar.h
Original file line number Diff line number Diff line change
Expand Up @@ -380,43 +380,33 @@ struct ARROW_EXPORT DurationScalar : public TemporalScalar<DurationType> {
: DurationScalar(std::move(value), duration(unit)) {}
};

// Scalar that stores one decimal value inline in the `value` member and
// exposes its bytes through mutable_data() / view().
// NOTE(review): this listing looks like a rendered diff with removed and added
// lines interleaved: a Decimal128-specific struct next to a DecimalScalar
// template parameterised on TYPE_CLASS / VALUE_TYPE. Confirm which lines are
// current against the real header.
struct ARROW_EXPORT Decimal128Scalar : public internal::PrimitiveScalarBase {
template <typename TYPE_CLASS, typename VALUE_TYPE>
struct ARROW_EXPORT DecimalScalar : public internal::PrimitiveScalarBase {
using internal::PrimitiveScalarBase::PrimitiveScalarBase;
using TypeClass = Decimal128Type;
using ValueType = Decimal128;
using TypeClass = TYPE_CLASS;
using ValueType = VALUE_TYPE;

// Stores `value` and forwards `type` plus `true` to the base constructor
// (presumably `true` is the is-valid flag; confirm in PrimitiveScalarBase).
Decimal128Scalar(Decimal128 value, std::shared_ptr<DataType> type)
DecimalScalar(ValueType value, std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase(std::move(type), true), value(value) {}

// Untyped pointer to the value's native-endian bytes, for in-place mutation.
void* mutable_data() override {
return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
}

// Read-only view over the value's native-endian bytes. The length appears
// both as the literal 16 and as ValueType::kByteWidth.
util::string_view view() const override {
return util::string_view(reinterpret_cast<const char*>(value.native_endian_bytes()),
16);
ValueType::kByteWidth);
}

Decimal128 value;
ValueType value;
};

// Concrete decimal scalar types.
// NOTE(review): two shapes are interleaved in this listing, which looks like a
// rendered diff: a self-contained Decimal256Scalar with its own constructor,
// mutable_data() and view(), and thin Decimal128Scalar / Decimal256Scalar
// structs that only inherit DecimalScalar's constructors. Confirm which is
// current against the real header.
struct ARROW_EXPORT Decimal256Scalar : public internal::PrimitiveScalarBase {
using internal::PrimitiveScalarBase::PrimitiveScalarBase;
using TypeClass = Decimal256Type;
using ValueType = Decimal256;

Decimal256Scalar(Decimal256 value, std::shared_ptr<DataType> type)
: internal::PrimitiveScalarBase(std::move(type), true), value(value) {}

// Untyped pointer to the value's native-endian bytes, for in-place mutation.
void* mutable_data() override {
return reinterpret_cast<void*>(value.mutable_native_endian_bytes());
}
// View spanning the four uint64_t words returned by native_endian_array()
// (4 * sizeof(uint64_t) bytes).
util::string_view view() const override {
const std::array<uint64_t, 4>& bytes = value.native_endian_array();
return util::string_view(reinterpret_cast<const char*>(bytes.data()),
bytes.size() * sizeof(uint64_t));
}
// Thin subclass: fixes the template arguments and inherits the constructors.
struct ARROW_EXPORT Decimal128Scalar : public DecimalScalar<Decimal128Type, Decimal128> {
using DecimalScalar::DecimalScalar;
};

Decimal256 value;
// Thin subclass: fixes the template arguments and inherits the constructors.
struct ARROW_EXPORT Decimal256Scalar : public DecimalScalar<Decimal256Type, Decimal256> {
using DecimalScalar::DecimalScalar;
};

struct ARROW_EXPORT BaseListScalar : public Scalar {
Expand Down
136 changes: 48 additions & 88 deletions cpp/src/arrow/util/basic_decimal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@

namespace arrow {

using internal::AddWithOverflow;
using internal::SafeLeftShift;
using internal::SafeSignedAdd;
using internal::SafeSignedSubtract;
using internal::SubtractWithOverflow;

static const BasicDecimal128 ScaleMultipliers[] = {
BasicDecimal128(1LL),
Expand Down Expand Up @@ -368,43 +371,16 @@ static constexpr uint64_t kInt32Mask = 0xFFFFFFFF;
// Upper bound constant, built from a signed high 64-bit half and an unsigned
// low 64-bit half (the low half is written as a literal minus one).
// Presumably this is 10^38 - 1, the largest 38-digit value; verify.
static constexpr BasicDecimal128 kMaxValue =
BasicDecimal128(5421010862427522170LL, 687399551400673280ULL - 1);

// Constructs a BasicDecimal128 from 16 raw bytes read as two 64-bit words in
// native byte order. The signed word is passed as the first (high) argument
// and the unsigned word as the second (low) argument; which word index holds
// which half depends on ARROW_LITTLE_ENDIAN.
// NOTE(review): the reinterpret_cast reads presumably require `bytes` to be
// suitably aligned for 64-bit access; confirm with callers.
#if ARROW_LITTLE_ENDIAN
// Little-endian: word 0 is the low half, word 1 is the high half.
BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
: BasicDecimal128(reinterpret_cast<const int64_t*>(bytes)[1],
reinterpret_cast<const uint64_t*>(bytes)[0]) {}
#else
// Big-endian: word 0 is the high half, word 1 is the low half.
BasicDecimal128::BasicDecimal128(const uint8_t* bytes)
: BasicDecimal128(reinterpret_cast<const int64_t*>(bytes)[0],
reinterpret_cast<const uint64_t*>(bytes)[1]) {}
#endif

constexpr int BasicDecimal128::kBitWidth;
constexpr int BasicDecimal128::kMaxPrecision;
constexpr int BasicDecimal128::kMaxScale;

// Returns the value as a 16-byte array: zero-initialises a local buffer and
// delegates the actual write to ToBytes(uint8_t*).
std::array<uint8_t, 16> BasicDecimal128::ToBytes() const {
std::array<uint8_t, 16> out{{0}};
ToBytes(out.data());
return out;
}

// Writes the two 64-bit halves into `out` (16 bytes) in native word order.
// `out` must be non-null; this is checked with DCHECK_NE.
void BasicDecimal128::ToBytes(uint8_t* out) const {
DCHECK_NE(out, nullptr);
#if ARROW_LITTLE_ENDIAN
// Little-endian: low half first, then high half.
reinterpret_cast<uint64_t*>(out)[0] = low_bits_;
reinterpret_cast<int64_t*>(out)[1] = high_bits_;
#else
// Big-endian: high half first, then low half.
reinterpret_cast<int64_t*>(out)[0] = high_bits_;
reinterpret_cast<uint64_t*>(out)[1] = low_bits_;
#endif
}

// Negates this value in place and returns *this. Both halves are bitwise
// inverted and 1 is added to the low half; when the low half wraps to 0 the
// carry is added to the high half via SafeSignedAdd.
// NOTE(review): both a `low_bits_`/`high_bits_` member form and a
// `result_lo`/`result_hi` local form appear here; this looks like a rendered
// diff with removed and added lines interleaved. Confirm which is current.
BasicDecimal128& BasicDecimal128::Negate() {
low_bits_ = ~low_bits_ + 1;
high_bits_ = ~high_bits_;
if (low_bits_ == 0) {
high_bits_ = SafeSignedAdd<int64_t>(high_bits_, 1);
uint64_t result_lo = ~low_bits() + 1;
int64_t result_hi = ~high_bits();
// Low half wrapped to zero: propagate the carry into the high half.
if (result_lo == 0) {
result_hi = SafeSignedAdd<int64_t>(result_hi, 1);
}
*this = BasicDecimal128(result_hi, result_lo);
return *this;
}

Expand All @@ -422,22 +398,18 @@ bool BasicDecimal128::FitsInPrecision(int32_t precision) const {
}

// Adds `right` to this value in place and returns *this. The high halves are
// summed with SafeSignedAdd and the low halves with unsigned addition. A low
// sum smaller than the original low half means it wrapped, so 1 is carried
// into the high half.
// NOTE(review): both a member-mutating form (`low_bits_`/`high_bits_`) and a
// local-result form (`result_lo`/`result_hi`) appear here; this looks like a
// rendered diff with removed and added lines interleaved. Confirm which is
// current.
BasicDecimal128& BasicDecimal128::operator+=(const BasicDecimal128& right) {
const uint64_t sum = low_bits_ + right.low_bits_;
high_bits_ = SafeSignedAdd<int64_t>(high_bits_, right.high_bits_);
if (sum < low_bits_) {
high_bits_ = SafeSignedAdd<int64_t>(high_bits_, 1);
}
low_bits_ = sum;
int64_t result_hi = SafeSignedAdd(high_bits(), right.high_bits());
uint64_t result_lo = low_bits() + right.low_bits();
// The comparison yields 1 exactly when the low-half addition wrapped.
result_hi = SafeSignedAdd<int64_t>(result_hi, result_lo < low_bits());
*this = BasicDecimal128(result_hi, result_lo);
return *this;
}

// Subtracts `right` from this value in place and returns *this. The low
// halves are subtracted as unsigned; a difference greater than the original
// low half means it wrapped, so 1 is borrowed from the high half.
// NOTE(review): one form uses plain `-=` / `--` on the high half and the other
// uses SafeSignedSubtract; this looks like a rendered diff with removed and
// added lines interleaved. Confirm which is current.
BasicDecimal128& BasicDecimal128::operator-=(const BasicDecimal128& right) {
const uint64_t diff = low_bits_ - right.low_bits_;
high_bits_ -= right.high_bits_;
if (diff > low_bits_) {
--high_bits_;
}
low_bits_ = diff;
int64_t result_hi = SafeSignedSubtract(high_bits(), right.high_bits());
uint64_t result_lo = low_bits() - right.low_bits();
// The comparison yields 1 exactly when the low-half subtraction wrapped.
result_hi = SafeSignedSubtract<int64_t>(result_hi, result_lo > low_bits());
*this = BasicDecimal128(result_hi, result_lo);
return *this;
}

Expand All @@ -449,47 +421,53 @@ BasicDecimal128& BasicDecimal128::operator/=(const BasicDecimal128& right) {
}

// Bitwise-ORs `right` into this value, 64 bits at a time, and returns *this.
// NOTE(review): the same operation appears on `low_bits_`/`high_bits_` and on
// `array_[0]`/`array_[1]`; this looks like a rendered diff with removed and
// added lines interleaved. Confirm which is current.
BasicDecimal128& BasicDecimal128::operator|=(const BasicDecimal128& right) {
low_bits_ |= right.low_bits_;
high_bits_ |= right.high_bits_;
array_[0] |= right.array_[0];
array_[1] |= right.array_[1];
return *this;
}

// Bitwise-ANDs `right` into this value, 64 bits at a time, and returns *this.
// NOTE(review): the same operation appears on `low_bits_`/`high_bits_` and on
// `array_[0]`/`array_[1]`; this looks like a rendered diff with removed and
// added lines interleaved. Confirm which is current.
BasicDecimal128& BasicDecimal128::operator&=(const BasicDecimal128& right) {
low_bits_ &= right.low_bits_;
high_bits_ &= right.high_bits_;
array_[0] &= right.array_[0];
array_[1] &= right.array_[1];
return *this;
}

// Shifts this value left by `bits` in place and returns *this.
//   bits == 0        : no change.
//   0 < bits < 64    : both halves are shifted; the top `bits` bits of the low
//                      half move into the bottom of the high half.
//   64 <= bits < 128 : the low half, shifted by (bits - 64), becomes the high
//                      half; the low half becomes 0.
//   bits >= 128      : the result is 0.
// NOTE(review): both a member-mutating form and a `result_lo`/`result_hi` form
// appear here; this looks like a rendered diff with removed and added lines
// interleaved. Confirm which is current.
BasicDecimal128& BasicDecimal128::operator<<=(uint32_t bits) {
if (bits != 0) {
uint64_t result_lo;
int64_t result_hi;
if (bits < 64) {
high_bits_ = SafeLeftShift(high_bits_, bits);
high_bits_ |= (low_bits_ >> (64 - bits));
low_bits_ <<= bits;
result_hi = SafeLeftShift(high_bits(), bits);
result_hi |= (low_bits() >> (64 - bits));
result_lo = low_bits() << bits;
} else if (bits < 128) {
// One form casts to int64_t before shifting; the other shifts the unsigned
// value and casts the result.
high_bits_ = static_cast<int64_t>(low_bits_) << (bits - 64);
low_bits_ = 0;
result_hi = static_cast<int64_t>(low_bits() << (bits - 64));
result_lo = 0;
} else {
high_bits_ = 0;
low_bits_ = 0;
result_hi = 0;
result_lo = 0;
}
*this = BasicDecimal128(result_hi, result_lo);
}
return *this;
}

// Shifts this value right by `bits` in place and returns *this.
//   bits == 0        : no change.
//   0 < bits < 64    : the low half is shifted right and receives the bottom
//                      `bits` bits of the high half in its top positions.
//   64 <= bits < 128 : the high half, shifted by (bits - 64), becomes the low
//                      half; the high half becomes the sign fill.
//   bits >= 128      : both halves become the sign fill.
// NOTE(review): both a member-mutating form and a `result_lo`/`result_hi` form
// appear here; this looks like a rendered diff with removed and added lines
// interleaved. The two forms are not equivalent for 0 < bits < 64: one shifts
// the high half as unsigned (zero fill), the other shifts the signed value
// directly. Confirm which is current.
BasicDecimal128& BasicDecimal128::operator>>=(uint32_t bits) {
if (bits != 0) {
uint64_t result_lo;
int64_t result_hi;
if (bits < 64) {
low_bits_ >>= bits;
low_bits_ |= static_cast<uint64_t>(high_bits_ << (64 - bits));
high_bits_ = static_cast<int64_t>(static_cast<uint64_t>(high_bits_) >> bits);
result_lo = low_bits() >> bits;
result_lo |= static_cast<uint64_t>(high_bits()) << (64 - bits);
result_hi = high_bits() >> bits;
} else if (bits < 128) {
low_bits_ = static_cast<uint64_t>(high_bits_ >> (bits - 64));
high_bits_ = static_cast<int64_t>(high_bits_ >= 0L ? 0L : -1L);
result_lo = static_cast<uint64_t>(high_bits() >> (bits - 64));
// `>> 63` on the signed high half yields 0 or -1 only with an arithmetic
// shift (implementation-defined for negative values before C++20).
result_hi = high_bits() >> 63;
} else {
high_bits_ = static_cast<int64_t>(high_bits_ >= 0L ? 0L : -1L);
low_bits_ = static_cast<uint64_t>(high_bits_);
result_hi = high_bits() >> 63;
result_lo = static_cast<uint64_t>(result_hi);
}
*this = BasicDecimal128(result_hi, result_lo);
}
return *this;
}
Expand Down Expand Up @@ -633,8 +611,7 @@ BasicDecimal128& BasicDecimal128::operator*=(const BasicDecimal128& right) {
BasicDecimal128 y = BasicDecimal128::Abs(right);
uint128_t r(x);
r *= uint128_t{y};
high_bits_ = r.hi();
low_bits_ = r.lo();
*this = BasicDecimal128(static_cast<int64_t>(r.hi()), r.lo());
if (negate) {
Negate();
}
Expand Down Expand Up @@ -1158,20 +1135,13 @@ BasicDecimal128 BasicDecimal128::ReduceScaleBy(int32_t reduce_by, bool round) co
// Returns the number of leading zero bits in the 128-bit value, computed with
// bit_util::CountLeadingZeros on one half. The value must be non-negative;
// this is checked with DCHECK_GE.
// NOTE(review): the same logic appears with `high_bits_`/`low_bits_` members
// and with `high_bits()`/`low_bits()` accessors; this looks like a rendered
// diff with removed and added lines interleaved. Confirm which is current.
int32_t BasicDecimal128::CountLeadingBinaryZeros() const {
DCHECK_GE(*this, BasicDecimal128(0));

// High half all zero: count within the low half and add the 64 high bits.
if (high_bits_ == 0) {
return bit_util::CountLeadingZeros(low_bits_) + 64;
if (high_bits() == 0) {
return bit_util::CountLeadingZeros(low_bits()) + 64;
} else {
return bit_util::CountLeadingZeros(static_cast<uint64_t>(high_bits_));
return bit_util::CountLeadingZeros(static_cast<uint64_t>(high_bits()));
}
}

// Constructs a BasicDecimal256 from 32 raw bytes by reading four consecutive
// uint64_t words and storing them in array_ in the same order.
// NOTE(review): the reinterpret_cast reads presumably require `bytes` to be
// suitably aligned for 64-bit access; confirm with callers.
BasicDecimal256::BasicDecimal256(const uint8_t* bytes)
: array_({reinterpret_cast<const uint64_t*>(bytes)[0],
reinterpret_cast<const uint64_t*>(bytes)[1],
reinterpret_cast<const uint64_t*>(bytes)[2],
reinterpret_cast<const uint64_t*>(bytes)[3]}) {}

constexpr int BasicDecimal256::kBitWidth;
constexpr int BasicDecimal256::kMaxPrecision;
constexpr int BasicDecimal256::kMaxScale;

Expand Down Expand Up @@ -1243,20 +1213,6 @@ BasicDecimal256& BasicDecimal256::operator<<=(uint32_t bits) {
return *this;
}

// Returns the value as a 32-byte array: zero-initialises a local buffer and
// delegates the actual write to ToBytes(uint8_t*).
std::array<uint8_t, 32> BasicDecimal256::ToBytes() const {
std::array<uint8_t, 32> out{{0}};
ToBytes(out.data());
return out;
}

// Writes the four 64-bit words of array_ into `out` (32 bytes), in array
// order. `out` must be non-null; this is checked with DCHECK_NE.
void BasicDecimal256::ToBytes(uint8_t* out) const {
DCHECK_NE(out, nullptr);
reinterpret_cast<uint64_t*>(out)[0] = array_[0];
reinterpret_cast<uint64_t*>(out)[1] = array_[1];
reinterpret_cast<uint64_t*>(out)[2] = array_[2];
reinterpret_cast<uint64_t*>(out)[3] = array_[3];
}

BasicDecimal256& BasicDecimal256::operator*=(const BasicDecimal256& right) {
// Since the max value of BasicDecimal256 is supposed to be 1e76 - 1 and the
// min the negation taking the absolute values here should always be safe.
Expand Down Expand Up @@ -1391,4 +1347,8 @@ BasicDecimal256 operator/(const BasicDecimal256& left, const BasicDecimal256& ri
return result;
}

// Explicitly instantiate template base class, for DLL linking on Windows
template class GenericBasicDecimal<BasicDecimal128, 128>;
template class GenericBasicDecimal<BasicDecimal256, 256>;

} // namespace arrow
Loading

0 comments on commit d67a210

Please sign in to comment.