Skip to content

Commit

Permalink
ARROW-6778: [C++] Support cast for DurationType
Browse files Browse the repository at this point in the history
https://issues.apache.org/jira/browse/ARROW-6778

Closes #5578 from jorisvandenbossche/ARROW-6778-cast-duration and squashes the following commits:

21bdacb <Joris Van den Bossche> add divide tests
2a9f398 <Joris Van den Bossche> enable test
920af3c <Joris Van den Bossche> DurationType is primitive
5cfefc2 <Joris Van den Bossche> add duration <-> duration cast
519f309 <Joris Van den Bossche> ARROW-6778: Support zero copy cast for DurationType

Authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
jorisvandenbossche authored and pitrou committed Oct 9, 2019
1 parent d80899b commit 1b02af6
Show file tree
Hide file tree
Showing 6 changed files with 104 additions and 7 deletions.
19 changes: 14 additions & 5 deletions cpp/src/arrow/compute/kernels/cast.cc
Expand Up @@ -478,13 +478,19 @@ const std::pair<bool, int64_t> kTimeConversionTable[4][4] = {

} // namespace

template <>
struct CastFunctor<TimestampType, TimestampType> {
// <TimestampType, TimestampType> and <DurationType, DurationType>
template <typename O, typename I>
struct CastFunctor<
O, I,
typename std::enable_if<(std::is_base_of<O, TimestampType>::value &&
std::is_base_of<I, TimestampType>::value) ||
(std::is_base_of<O, DurationType>::value &&
std::is_base_of<I, DurationType>::value)>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
// If units are the same, zero copy, otherwise convert
const auto& in_type = checked_cast<const TimestampType&>(*input.type);
const auto& out_type = checked_cast<const TimestampType&>(*output->type);
const auto& in_type = checked_cast<const I&>(*input.type);
const auto& out_type = checked_cast<const O&>(*output->type);

if (in_type.unit() == out_type.unit()) {
ZeroCopyData(input, output);
Expand Down Expand Up @@ -1195,6 +1201,7 @@ GET_CAST_FUNCTION(DATE64_CASES, Date64Type)
GET_CAST_FUNCTION(TIME32_CASES, Time32Type)
GET_CAST_FUNCTION(TIME64_CASES, Time64Type)
GET_CAST_FUNCTION(TIMESTAMP_CASES, TimestampType)
GET_CAST_FUNCTION(DURATION_CASES, DurationType)
GET_CAST_FUNCTION(BINARY_CASES, BinaryType)
GET_CAST_FUNCTION(STRING_CASES, StringType)
GET_CAST_FUNCTION(LARGEBINARY_CASES, LargeBinaryType)
Expand Down Expand Up @@ -1233,13 +1240,14 @@ inline bool IsZeroCopyCast(Type::type in_type, Type::type out_type) {
return (out_type == Type::DATE32) || (out_type == Type::TIME32);
case Type::INT64:
return ((out_type == Type::DATE64) || (out_type == Type::TIME64) ||
(out_type == Type::TIMESTAMP));
(out_type == Type::TIMESTAMP) || (out_type == Type::DURATION));
case Type::DATE32:
case Type::TIME32:
return out_type == Type::INT32;
case Type::DATE64:
case Type::TIME64:
case Type::TIMESTAMP:
case Type::DURATION:
return out_type == Type::INT64;
default:
break;
Expand Down Expand Up @@ -1281,6 +1289,7 @@ Status GetCastFunction(const DataType& in_type, std::shared_ptr<DataType> out_ty
CAST_FUNCTION_CASE(Time32Type);
CAST_FUNCTION_CASE(Time64Type);
CAST_FUNCTION_CASE(TimestampType);
CAST_FUNCTION_CASE(DurationType);
CAST_FUNCTION_CASE(BinaryType);
CAST_FUNCTION_CASE(StringType);
CAST_FUNCTION_CASE(LargeBinaryType);
Expand Down
81 changes: 81 additions & 0 deletions cpp/src/arrow/compute/kernels/cast_test.cc
Expand Up @@ -886,6 +886,86 @@ TEST_F(TestCast, DateToCompatible) {
CheckFails<Date64Type>(date64(), v8, is_valid, date32(), options);
}

// Covers Duration -> Duration casts between every pair of time units, as well as
// the zero-copy paths (same unit, and Duration -> int64).
TEST_F(TestCast, DurationToCompatible) {
  CastOptions options;
  std::vector<bool> is_valid = {true, false, true, true, true};

  // Casts `input` from duration(from) to duration(to) using the current state of
  // `options` (captured by reference) and compares the result with `expected`.
  auto check_units = [this, &options, &is_valid](TimeUnit::type from, TimeUnit::type to,
                                                 const std::vector<int64_t>& input,
                                                 const std::vector<int64_t>& expected) {
    CheckCase<DurationType, int64_t, DurationType, int64_t>(
        duration(from), input, is_valid, duration(to), expected, options);
  };

  // Coarse-unit values, and the same values expressed in units that are finer by
  // a factor of 10^3, 10^6 and 10^9 respectively.
  std::vector<int64_t> base = {0, 100, 200, 1, 2};
  std::vector<int64_t> scaled_1e3 = {0, 100000, 200000, 1000, 2000};
  std::vector<int64_t> scaled_1e6 = {0, 100000000L, 200000000L, 1000000, 2000000};
  std::vector<int64_t> scaled_1e9 = {0, 100000000000L, 200000000000L, 1000000000L,
                                     2000000000L};

  // Multiply promotions
  check_units(TimeUnit::SECOND, TimeUnit::MILLI, base, scaled_1e3);
  check_units(TimeUnit::SECOND, TimeUnit::MICRO, base, scaled_1e6);
  check_units(TimeUnit::SECOND, TimeUnit::NANO, base, scaled_1e9);
  check_units(TimeUnit::MILLI, TimeUnit::MICRO, base, scaled_1e3);
  check_units(TimeUnit::MILLI, TimeUnit::NANO, base, scaled_1e6);
  check_units(TimeUnit::MICRO, TimeUnit::NANO, base, scaled_1e3);

  // Zero copy
  std::vector<int64_t> seconds = {0, 70000, 2000, 1000, 0};
  std::shared_ptr<Array> seconds_arr;
  ArrayFromVector<DurationType, int64_t>(duration(TimeUnit::SECOND), is_valid, seconds,
                                         &seconds_arr);
  CheckZeroCopy(*seconds_arr, duration(TimeUnit::SECOND));
  CheckZeroCopy(*seconds_arr, int64());

  // Inputs for the divide direction. Each truncates to `base` when divided by the
  // factor in its name; the 10^3 set has a non-zero remainder.
  std::vector<int64_t> lossy_1e3 = {0, 100123, 200456, 1123, 2456};
  std::vector<int64_t> lossy_1e6 = {0, 100123000, 200456000, 1123000, 2456000};
  std::vector<int64_t> lossy_1e9 = {0, 100123000000L, 200456000000L, 1123000000L,
                                    2456000000L};

  // Divide, truncate
  options.allow_time_truncate = true;
  check_units(TimeUnit::MILLI, TimeUnit::SECOND, lossy_1e3, base);
  check_units(TimeUnit::MICRO, TimeUnit::MILLI, lossy_1e3, base);
  check_units(TimeUnit::NANO, TimeUnit::MICRO, lossy_1e3, base);

  check_units(TimeUnit::MICRO, TimeUnit::SECOND, lossy_1e6, base);
  check_units(TimeUnit::NANO, TimeUnit::MILLI, lossy_1e6, base);

  check_units(TimeUnit::NANO, TimeUnit::SECOND, lossy_1e9, base);

  // Disallow truncate, failures
  options.allow_time_truncate = false;
  CheckFails<DurationType>(duration(TimeUnit::MILLI), lossy_1e3, is_valid,
                           duration(TimeUnit::SECOND), options);
  CheckFails<DurationType>(duration(TimeUnit::MICRO), lossy_1e3, is_valid,
                           duration(TimeUnit::MILLI), options);
  CheckFails<DurationType>(duration(TimeUnit::NANO), lossy_1e3, is_valid,
                           duration(TimeUnit::MICRO), options);
  CheckFails<DurationType>(duration(TimeUnit::MICRO), lossy_1e6, is_valid,
                           duration(TimeUnit::SECOND), options);
  CheckFails<DurationType>(duration(TimeUnit::NANO), lossy_1e6, is_valid,
                           duration(TimeUnit::MILLI), options);
  CheckFails<DurationType>(duration(TimeUnit::NANO), lossy_1e9, is_valid,
                           duration(TimeUnit::SECOND), options);
}

TEST_F(TestCast, ToDouble) {
CastOptions options;
std::vector<bool> is_valid = {true, false, true, true, true};
Expand Down Expand Up @@ -968,6 +1048,7 @@ TEST_F(TestCast, DateTimeZeroCopy) {
CheckZeroCopy(*arr, time64(TimeUnit::MICRO));
CheckZeroCopy(*arr, date64());
CheckZeroCopy(*arr, timestamp(TimeUnit::NANO));
CheckZeroCopy(*arr, duration(TimeUnit::MILLI));
}

TEST_F(TestCast, PreallocatedMemory) {
Expand Down
Expand Up @@ -190,6 +190,9 @@
TEMPLATE(TimestampType, Date64Type) \
TEMPLATE(TimestampType, TimestampType)

// Expands TEMPLATE once for each cast pair listed for DurationType; the only
// pair is (DurationType, DurationType).
#define DURATION_CASES(TEMPLATE) TEMPLATE(DurationType, DurationType)

// Expands TEMPLATE once for each cast pair listed for BinaryType; the only
// pair is (BinaryType, StringType) -- presumably input type first, confirm
// against the code generator.
#define BINARY_CASES(TEMPLATE) TEMPLATE(BinaryType, StringType)

Expand Down Expand Up @@ -240,6 +243,7 @@
TEMPLATE(DictionaryType, Time32Type) \
TEMPLATE(DictionaryType, Time64Type) \
TEMPLATE(DictionaryType, TimestampType) \
TEMPLATE(DictionaryType, DurationType) \
TEMPLATE(DictionaryType, NullType) \
TEMPLATE(DictionaryType, BinaryType) \
TEMPLATE(DictionaryType, FixedSizeBinaryType) \
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/kernels/generated/codegen.py
Expand Up @@ -30,7 +30,8 @@
NUMERIC_TYPES = ['Boolean'] + INTEGER_TYPES + FLOATING_TYPES
STRING_TYPES = ['String', 'LargeString']

DATE_TIME_TYPES = ['Date32', 'Date64', 'Time32', 'Time64', 'Timestamp']
DATE_TIME_TYPES = ['Date32', 'Date64', 'Time32', 'Time64', 'Timestamp',
'Duration']


def _format_type(name):
Expand Down Expand Up @@ -85,6 +86,7 @@ def generate(self):
parametric=True),
CastCodeGenerator('Timestamp', ['Date32', 'Date64', 'Timestamp'],
parametric=True),
CastCodeGenerator('Duration', ['Duration'], parametric=True),
CastCodeGenerator('Binary', ['String']),
CastCodeGenerator('LargeBinary', ['LargeString']),
CastCodeGenerator('String', NUMERIC_TYPES + ['Timestamp']),
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/type_traits.h
Expand Up @@ -154,6 +154,7 @@ struct TypeTraits<DurationType> {
using ArrayType = DurationArray;
using BuilderType = DurationBuilder;
using ScalarType = DurationScalar;
using CType = DurationType::c_type;

static constexpr int64_t bytes_required(int64_t elements) {
return elements * static_cast<int64_t>(sizeof(int64_t));
Expand Down Expand Up @@ -623,6 +624,7 @@ static inline bool is_primitive(Type::type type_id) {
case Type::TIME32:
case Type::TIME64:
case Type::TIMESTAMP:
case Type::DURATION:
case Type::INTERVAL:
return true;
default:
Expand Down
1 change: 0 additions & 1 deletion python/pyarrow/tests/test_array.py
Expand Up @@ -947,7 +947,6 @@ def test_cast_date32_to_int():
assert result2.equals(arr)


@pytest.mark.xfail(strict=True) # TODO implement duration cast
def test_cast_duration_to_int():
arr = pa.array(np.array([0, 1, 2], dtype='int64'),
type=pa.duration('us'))
Expand Down

0 comments on commit 1b02af6

Please sign in to comment.