Skip to content

Commit

Permalink
Review feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
rok committed Apr 5, 2022
1 parent ab7ee6d commit 3171e2e
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 49 deletions.
24 changes: 16 additions & 8 deletions cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,45 +214,53 @@ std::shared_ptr<DataType> CommonNumeric(const ValueDescr* begin, size_t count) {
return int8();
}

TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count) {
TimeUnit::type finest_unit = TimeUnit::SECOND;
bool CommonTemporalResolution(const ValueDescr* begin, size_t count,
TimeUnit::type* finest_unit) {
bool is_time_unit = false;
*finest_unit = TimeUnit::SECOND;
const ValueDescr* end = begin + count;
for (auto it = begin; it != end; it++) {
auto id = it->type->id();
switch (id) {
case Type::DATE32: {
// Date32's unit is days, but the coarsest we have is seconds
is_time_unit = true;
continue;
}
case Type::DATE64: {
finest_unit = std::max(finest_unit, TimeUnit::MILLI);
*finest_unit = std::max(*finest_unit, TimeUnit::MILLI);
is_time_unit = true;
continue;
}
case Type::TIMESTAMP: {
const auto& ty = checked_cast<const TimestampType&>(*it->type);
finest_unit = std::max(finest_unit, ty.unit());
*finest_unit = std::max(*finest_unit, ty.unit());
is_time_unit = true;
continue;
}
case Type::DURATION: {
const auto& ty = checked_cast<const DurationType&>(*it->type);
finest_unit = std::max(finest_unit, ty.unit());
*finest_unit = std::max(*finest_unit, ty.unit());
is_time_unit = true;
continue;
}
case Type::TIME32: {
const auto& ty = checked_cast<const Time32Type&>(*it->type);
finest_unit = std::max(finest_unit, ty.unit());
*finest_unit = std::max(*finest_unit, ty.unit());
is_time_unit = true;
continue;
}
case Type::TIME64: {
const auto& ty = checked_cast<const Time64Type&>(*it->type);
finest_unit = std::max(finest_unit, ty.unit());
*finest_unit = std::max(*finest_unit, ty.unit());
is_time_unit = true;
continue;
}
default:
continue;
}
}
return finest_unit;
return is_time_unit;
}

std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count) {
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -1394,7 +1394,8 @@ ARROW_EXPORT
std::shared_ptr<DataType> CommonTemporal(const ValueDescr* begin, size_t count);

ARROW_EXPORT
TimeUnit::type CommonTemporalResolution(const ValueDescr* begin, size_t count);
bool CommonTemporalResolution(const ValueDescr* begin, size_t count,
TimeUnit::type* finest_unit);

ARROW_EXPORT
std::shared_ptr<DataType> CommonBinary(const ValueDescr* begin, size_t count);
Expand Down
96 changes: 61 additions & 35 deletions cpp/src/arrow/compute/kernels/codegen_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,55 +163,80 @@ TEST(TestDispatchBest, CommonTemporal) {
TEST(TestDispatchBest, CommonTemporalResolution) {
std::vector<ValueDescr> args;
std::string tz = "Pacific/Marquesas";
TimeUnit::type ty;

args = {date32(), date32()};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::SECOND, ty);
args = {date32(), date64()};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {date32(), timestamp(TimeUnit::SECOND)};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::SECOND, ty);
args = {time32(TimeUnit::MILLI), date32()};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {time32(TimeUnit::MILLI), time32(TimeUnit::SECOND)};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {time32(TimeUnit::MILLI), time64(TimeUnit::MICRO)};
ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MICRO, ty);
args = {time64(TimeUnit::NANO), time64(TimeUnit::MICRO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::NANO, ty);
args = {duration(TimeUnit::MILLI), duration(TimeUnit::MICRO)};
ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MICRO, ty);
args = {duration(TimeUnit::MILLI), date32()};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {date64(), duration(TimeUnit::SECOND)};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {duration(TimeUnit::SECOND), time32(TimeUnit::SECOND)};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::SECOND, ty);
args = {duration(TimeUnit::SECOND), time64(TimeUnit::NANO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::NANO, ty);
args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::NANO, ty);
args = {timestamp(TimeUnit::SECOND, tz), timestamp(TimeUnit::MICRO)};
ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MICRO, ty);
args = {date32(), timestamp(TimeUnit::MICRO, tz)};
ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MICRO, ty);
args = {timestamp(TimeUnit::MICRO, tz), date64()};
ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MICRO, ty);
args = {time32(TimeUnit::MILLI), timestamp(TimeUnit::MICRO, tz)};
ASSERT_EQ(TimeUnit::MICRO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MICRO, ty);
args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::NANO, ty);
args = {timestamp(TimeUnit::SECOND, tz), duration(TimeUnit::MILLI)};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::SECOND, ty);
args = {time32(TimeUnit::MILLI), duration(TimeUnit::SECOND)};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
args = {time64(TimeUnit::MICRO), duration(TimeUnit::NANO)};
ASSERT_EQ(TimeUnit::NANO, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::NANO, ty);
args = {duration(TimeUnit::SECOND), int64()};
ASSERT_EQ(TimeUnit::SECOND, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::SECOND, ty);
args = {duration(TimeUnit::MILLI), timestamp(TimeUnit::SECOND, tz)};
ASSERT_EQ(TimeUnit::MILLI, CommonTemporalResolution(args.data(), args.size()));
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ASSERT_EQ(TimeUnit::MILLI, ty);
}

TEST(TestDispatchBest, ReplaceTemporalTypes) {
Expand All @@ -220,67 +245,68 @@ TEST(TestDispatchBest, ReplaceTemporalTypes) {
TimeUnit::type ty;

args = {date32(), date32()};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND));

args = {date64(), time32(TimeUnit::SECOND)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI));
AssertTypeEqual(args[1].type, time32(TimeUnit::MILLI));

args = {duration(TimeUnit::SECOND), date64()};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, duration(TimeUnit::MILLI));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI));

args = {timestamp(TimeUnit::MICRO, tz), timestamp(TimeUnit::NANO)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::NANO));

args = {timestamp(TimeUnit::MICRO, tz), time64(TimeUnit::NANO)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::NANO, tz));
AssertTypeEqual(args[1].type, time64(TimeUnit::NANO));

args = {timestamp(TimeUnit::SECOND, tz), date64()};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::MILLI, tz));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::MILLI));

args = {timestamp(TimeUnit::SECOND, "UTC"), timestamp(TimeUnit::SECOND, tz)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, timestamp(TimeUnit::SECOND, "UTC"));
AssertTypeEqual(args[1].type, timestamp(TimeUnit::SECOND, tz));

args = {time32(TimeUnit::SECOND), duration(TimeUnit::SECOND)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time32(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, duration(TimeUnit::SECOND));

args = {time64(TimeUnit::MICRO), duration(TimeUnit::SECOND)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time64(TimeUnit::MICRO));
AssertTypeEqual(args[1].type, duration(TimeUnit::MICRO));

args = {time32(TimeUnit::SECOND), duration(TimeUnit::NANO)};
ty = CommonTemporalResolution(args.data(), args.size());
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, time64(TimeUnit::NANO));
AssertTypeEqual(args[1].type, duration(TimeUnit::NANO));

args = {duration(TimeUnit::SECOND), int64()};
ReplaceTemporalTypes(CommonTemporalResolution(args.data(), args.size()), &args);
ASSERT_TRUE(CommonTemporalResolution(args.data(), args.size(), &ty));
ReplaceTemporalTypes(ty, &args);
AssertTypeEqual(args[0].type, duration(TimeUnit::SECOND));
AssertTypeEqual(args[1].type, int64());
}
Expand Down
8 changes: 3 additions & 5 deletions cpp/src/arrow/compute/kernels/scalar_arithmetic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1818,13 +1818,11 @@ struct ArithmeticFunction : ScalarFunction {
// Only promote types for binary functions
if (values->size() == 2) {
ReplaceNullWithOtherType(values);
auto type = CommonTemporalResolution(values->data(), values->size());
if (type) {
ReplaceTemporalTypes(type, values);
TimeUnit::type finest_unit;
if (CommonTemporalResolution(values->data(), values->size(), &finest_unit)) {
ReplaceTemporalTypes(finest_unit, values);
} else if (auto numeric_type = CommonNumeric(*values)) {
ReplaceTypes(numeric_type, values);
} else if (type == TimeUnit::SECOND) {
ReplaceTemporalTypes(type, values);
}
}

Expand Down

0 comments on commit 3171e2e

Please sign in to comment.