Skip to content

Commit

Permalink
ARROW-7335: [C++][Gandiva] Add day_time_interval functions: castBIGIN…
Browse files Browse the repository at this point in the history
…T, extractDay

Function Signatures:
```
int64 castBIGINT(day_time_interval)
int64 extractDay(day_time_interval)
```

Closes #5980 from pprudhvi/dayinterval and squashes the following commits:

0ee77bb <Prudhvi Porandla> tabs to spaces in proto
47fa185 <Prudhvi Porandla> make day_time_interval signed
1c9f63a <Prudhvi Porandla> handle negatives
c7d2ec3 <Prudhvi Porandla> fix switch-case, inline
12f6b5e <Prudhvi Porandla> add projector test
9651403 <Prudhvi Porandla> interval day functions

Authored-by: Prudhvi Porandla <prudhvi.porandla@icloud.com>
Signed-off-by: Praveen <praveen@dremio.com>
  • Loading branch information
pprudhvi authored and praveenbingo committed Mar 9, 2020
1 parent af45b92 commit 116672f
Show file tree
Hide file tree
Showing 13 changed files with 193 additions and 5 deletions.
4 changes: 4 additions & 0 deletions cpp/src/gandiva/expression_registry.cc
Expand Up @@ -162,6 +162,10 @@ void ExpressionRegistry::AddArrowTypesToVector(arrow::Type::type& type,
case arrow::Type::type::DECIMAL:
vector.push_back(arrow::decimal(38, 0));
break;
case arrow::Type::type::INTERVAL:
vector.push_back(arrow::day_time_interval());
vector.push_back(arrow::month_interval());
break;
default:
// Unsupported types. test ensures that
// when one of these are added build breaks.
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/function_registry_common.h
Expand Up @@ -36,6 +36,7 @@ namespace gandiva {
using arrow::binary;
using arrow::boolean;
using arrow::date64;
using arrow::day_time_interval;
using arrow::float32;
using arrow::float64;
using arrow::int16;
Expand Down
9 changes: 8 additions & 1 deletion cpp/src/gandiva/function_registry_datetime.cc
Expand Up @@ -67,7 +67,14 @@ std::vector<NativeFunction> GetDateTimeFunctionRegistry() {
NativeFunction::kNeedsFunctionHolder |
NativeFunction::kCanReturnErrors),
NativeFunction("castTIMESTAMP", {}, DataTypeVector{date64()}, timestamp(),
kResultNullIfNull, "castTIMESTAMP_date64")};
kResultNullIfNull, "castTIMESTAMP_date64"),

NativeFunction("castBIGINT", {}, DataTypeVector{day_time_interval()}, int64(),
kResultNullIfNull, "castBIGINT_daytimeinterval"),

NativeFunction("extractDay", {}, DataTypeVector{day_time_interval()}, int64(),
kResultNullIfNull, "extractDay_daytimeinterval"),
};

return date_time_fn_registry_;
}
Expand Down
20 changes: 20 additions & 0 deletions cpp/src/gandiva/jni/expression_registry_helper.cc
Expand Up @@ -41,6 +41,17 @@ types::TimeUnit MapTimeUnit(arrow::TimeUnit::type& unit) {
return types::TimeUnit::SEC;
}

// Translates an Arrow interval type into the matching Gandiva protobuf
// IntervalType.
types::IntervalType MapIntervalType(arrow::IntervalType::type& type) {
  // Start from DAY_TIME so that a value outside the enum (which should never
  // occur) still yields a result and gcc sees a return on every path.
  types::IntervalType mapped = types::IntervalType::DAY_TIME;
  switch (type) {
    case arrow::IntervalType::MONTHS:
      mapped = types::IntervalType::YEAR_MONTH;
      break;
    case arrow::IntervalType::DAY_TIME:
      break;
  }
  return mapped;
}

void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type) {
switch (type->id()) {
case arrow::Type::type::BOOL:
Expand Down Expand Up @@ -127,6 +138,15 @@ void ArrowToProtobuf(DataTypePtr type, types::ExtGandivaType* gandiva_data_type)
gandiva_data_type->set_scale(0);
break;
}
case arrow::Type::type::INTERVAL: {
gandiva_data_type->set_type(types::GandivaType::INTERVAL);
std::shared_ptr<arrow::IntervalType> cast_interval_type =
std::dynamic_pointer_cast<arrow::IntervalType>(type);
arrow::IntervalType::type type = cast_interval_type->interval_type();
types::IntervalType interval_type = MapIntervalType(type);
gandiva_data_type->set_intervaltype(interval_type);
break;
}
default:
// un-supported types. test ensures that
// when one of these are added build breaks.
Expand Down
16 changes: 14 additions & 2 deletions cpp/src/gandiva/jni/jni_common.cc
Expand Up @@ -169,6 +169,18 @@ DataTypePtr ProtoTypeToTimestamp(const types::ExtGandivaType& ext_type) {
}
}

// Builds the Arrow interval data type described by the protobuf message.
// Returns nullptr (after logging to stderr) when the interval type is not one
// of the known values.
DataTypePtr ProtoTypeToInterval(const types::ExtGandivaType& ext_type) {
  const auto interval_kind = ext_type.intervaltype();
  if (interval_kind == types::YEAR_MONTH) {
    return arrow::month_interval();
  }
  if (interval_kind == types::DAY_TIME) {
    return arrow::day_time_interval();
  }
  std::cerr << "Unknown interval type: " << interval_kind << "\n";
  return nullptr;
}

DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) {
switch (ext_type.type()) {
case types::NONE:
Expand Down Expand Up @@ -214,9 +226,9 @@ DataTypePtr ProtoTypeToDataType(const types::ExtGandivaType& ext_type) {
return ProtoTypeToTime64(ext_type);
case types::TIMESTAMP:
return ProtoTypeToTimestamp(ext_type);

case types::FIXED_SIZE_BINARY:
case types::INTERVAL:
return ProtoTypeToInterval(ext_type);
case types::FIXED_SIZE_BINARY:
case types::LIST:
case types::STRUCT:
case types::UNION:
Expand Down
1 change: 1 addition & 0 deletions cpp/src/gandiva/llvm_types.cc
Expand Up @@ -41,6 +41,7 @@ LLVMTypes::LLVMTypes(llvm::LLVMContext& context) : context_(context) {
{arrow::Type::type::STRING, i8_ptr_type()},
{arrow::Type::type::BINARY, i8_ptr_type()},
{arrow::Type::type::DECIMAL, i128_type()},
{arrow::Type::type::INTERVAL, i64_type()},
};
}

Expand Down
19 changes: 19 additions & 0 deletions cpp/src/gandiva/precompiled/time.cc
Expand Up @@ -738,4 +738,23 @@ const char* castVARCHAR_timestamp_int64(gdv_int64 context, gdv_timestamp in,
memcpy(ret, char_buffer, *out_len);
return ret;
}

// Returns the day component of a day-time interval. The days are read from
// the low 32 bits of the packed value and interpreted as a signed 32-bit
// integer, so negative day counts are preserved.
FORCE_INLINE
gdv_int64 extractDay_daytimeinterval(gdv_day_time_interval in) {
  const gdv_int64 low_word = in & 0x00000000FFFFFFFF;
  const gdv_int32 day_count = static_cast<gdv_int32>(low_word);
  return day_count;
}

// Returns the millisecond component of a day-time interval. The millis are
// read from the high 32 bits of the packed value and interpreted as a signed
// 32-bit integer, so negative values are preserved.
FORCE_INLINE
gdv_int64 extractMillis_daytimeinterval(gdv_day_time_interval in) {
  // The mask literal is unsigned, so the shift below does not drag the sign
  // bit of `in` into the result.
  const auto high_word = (in & 0xFFFFFFFF00000000) >> 32;
  const gdv_int32 milli_count = static_cast<gdv_int32>(high_word);
  return milli_count;
}

// Converts a day-time interval to a single count of milliseconds:
// (days * MILLIS_IN_DAY) + millis.
FORCE_INLINE
gdv_int64 castBIGINT_daytimeinterval(gdv_day_time_interval in) {
  const gdv_int64 whole_days_in_millis = extractDay_daytimeinterval(in) * MILLIS_IN_DAY;
  const gdv_int64 leftover_millis = extractMillis_daytimeinterval(in);
  return leftover_millis + whole_days_in_millis;
}

} // extern "C"
1 change: 1 addition & 0 deletions cpp/src/gandiva/precompiled/types.h
Expand Up @@ -38,6 +38,7 @@ using gdv_time32 = int32_t;
using gdv_timestamp = int64_t;
using gdv_utf8 = char*;
using gdv_binary = char*;
using gdv_day_time_interval = int64_t;

#ifdef GANDIVA_UNIT_TEST
// unit tests may be compiled without O2, so inlining may not happen.
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/gandiva/proto/Types.proto
Expand Up @@ -65,6 +65,11 @@ enum TimeUnit {
NANOSEC = 3;
}

// Distinguishes the two interval representations an INTERVAL-typed
// ExtGandivaType can carry (see its intervalType field).
enum IntervalType {
// Year/month interval; the JNI layer maps this to Arrow's month interval.
YEAR_MONTH = 0;
// Day/time interval; the JNI layer maps this to Arrow's day-time interval.
DAY_TIME = 1;
}

enum SelectionVectorType {
SV_NONE = 0;
SV_INT16 = 1;
Expand All @@ -79,6 +84,7 @@ message ExtGandivaType {
optional DateUnit dateUnit = 5; // used by DATE32/DATE64
optional TimeUnit timeUnit = 6; // used by TIME32/TIME64
optional string timeZone = 7; // used by TIMESTAMP
optional IntervalType intervalType = 8; // used by INTERVAL
}

message Field {
Expand Down
Expand Up @@ -28,6 +28,7 @@
import org.apache.arrow.gandiva.ipc.GandivaTypes.GandivaType;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.IntervalUnit;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;

Expand Down Expand Up @@ -175,9 +176,10 @@ private static ArrowType getArrowType(ExtGandivaType type) {
return new ArrowType.Null();
case GandivaType.DECIMAL_VALUE:
return new ArrowType.Decimal(0,0);
case GandivaType.INTERVAL_VALUE:
return new ArrowType.Interval(mapArrowIntervalUnit(type.getIntervalType()));
case GandivaType.FIXED_SIZE_BINARY_VALUE:
case GandivaType.MAP_VALUE:
case GandivaType.INTERVAL_VALUE:
case GandivaType.DICTIONARY_VALUE:
case GandivaType.LIST_VALUE:
case GandivaType.STRUCT_VALUE:
Expand All @@ -202,5 +204,17 @@ private static TimeUnit mapArrowTimeUnit(GandivaTypes.TimeUnit timeUnit) {
return null;
}
}

/**
 * Maps a Gandiva protobuf interval type to the corresponding Arrow interval unit.
 *
 * @param intervalType interval type taken from the Gandiva type description
 * @return the matching Arrow unit, or null when the value is not recognised
 */
private static IntervalUnit mapArrowIntervalUnit(GandivaTypes.IntervalType intervalType) {
  final int typeNumber = intervalType.getNumber();
  if (typeNumber == GandivaTypes.IntervalType.YEAR_MONTH_VALUE) {
    return IntervalUnit.YEAR_MONTH;
  }
  if (typeNumber == GandivaTypes.IntervalType.DAY_TIME_VALUE) {
    return IntervalUnit.DAY_TIME;
  }
  return null;
}

}

Expand Up @@ -18,6 +18,7 @@
package org.apache.arrow.gandiva.expression;

import org.apache.arrow.flatbuf.DateUnit;
import org.apache.arrow.flatbuf.IntervalUnit;
import org.apache.arrow.flatbuf.TimeUnit;
import org.apache.arrow.flatbuf.Type;
import org.apache.arrow.gandiva.exceptions.GandivaException;
Expand Down Expand Up @@ -202,6 +203,26 @@ private static void initArrowTypeTimestamp(ArrowType.Timestamp timestampType,
}
}

/**
 * Populates the protobuf builder for an Arrow interval type.
 *
 * <p>For YEAR_MONTH and DAY_TIME units the builder's type is set to INTERVAL and its
 * interval type to the matching value. For any other unit the builder is left untouched.
 *
 * @param interval the Arrow interval type to describe
 * @param builder the protobuf builder to fill in
 */
private static void initArrowTypeInterval(ArrowType.Interval interval,
                                          GandivaTypes.ExtGandivaType.Builder builder) {
  final short unitId = interval.getUnit().getFlatbufID();

  final GandivaTypes.IntervalType protoIntervalType;
  if (unitId == IntervalUnit.YEAR_MONTH) {
    protoIntervalType = GandivaTypes.IntervalType.YEAR_MONTH;
  } else if (unitId == IntervalUnit.DAY_TIME) {
    protoIntervalType = GandivaTypes.IntervalType.DAY_TIME;
  } else {
    // not supported: leave the builder without a type
    return;
  }

  builder.setType(GandivaTypes.GandivaType.INTERVAL);
  builder.setIntervalType(protoIntervalType);
}

/**
* Converts an arrow type into a protobuf.
*
Expand Down Expand Up @@ -259,6 +280,7 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp
break;
}
case Type.Interval: { // 11
ArrowTypeHelper.initArrowTypeInterval((ArrowType.Interval) arrowType, builder);
break;
}
case Type.List: { // 12
Expand Down Expand Up @@ -287,7 +309,7 @@ public static GandivaTypes.ExtGandivaType arrowTypeToProtobuf(ArrowType arrowTyp
if (!builder.hasType()) {
// type has not been set
// throw an exception
throw new UnsupportedTypeException("Unsupported type" + arrowType.toString());
throw new UnsupportedTypeException("Unsupported type " + arrowType.toString());
}

return builder.build();
Expand Down
Expand Up @@ -282,6 +282,15 @@ ArrowBuf stringToMillis(String[] dates) {
return buffer;
}

/**
 * Packs "days millis" strings into a buffer of day-time intervals.
 *
 * <p>Each entry is split on a single space; the first token is written as the day count and
 * the second as the millisecond count, each as a 4-byte int (8 bytes per entry).
 *
 * @param values entries of the form "days millis"
 * @return a newly allocated buffer holding the encoded intervals
 */
ArrowBuf stringToDayInterval(String[] values) {
  ArrowBuf packed = allocator.buffer(values.length * 8);
  for (String entry : values) {
    String[] tokens = entry.split(" ");
    packed.writeInt(Integer.parseInt(tokens[0])); // days
    packed.writeInt(Integer.parseInt(tokens[1])); // millis
  }
  return packed;
}

void releaseRecordBatch(ArrowRecordBatch recordBatch) {
// There are 2 references to the buffers
// One in the recordBatch - release that by calling close()
Expand Down
Expand Up @@ -44,6 +44,7 @@
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.IntervalUnit;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
Expand Down Expand Up @@ -1525,4 +1526,75 @@ public void testCastTimestampToString() throws Exception {
releaseRecordBatch(batch);
releaseValueVectors(output);
}

/**
 * Projects castBIGINT over a day-time interval column and checks that every row comes back
 * non-null and equal to days * 86400000 + millis.
 */
@Test
public void testCastDayIntervalToBigInt() throws Exception {
  // Single nullable input column holding day-time intervals.
  Field dayIntervalField =
      Field.nullable("dayInterval", new ArrowType.Interval(IntervalUnit.DAY_TIME));
  TreeNode inputNode = TreeBuilder.makeField(dayIntervalField);

  // Expression under test: castBIGINT(dayInterval) projected into an int64 column.
  TreeNode castNode =
      TreeBuilder.makeFunction("castBIGINT", Lists.newArrayList(inputNode), int64);
  Field resultField = Field.nullable("result", int64);
  List<ExpressionTree> exprs =
      Lists.newArrayList(TreeBuilder.makeExpression(castNode, resultField));

  Schema schema = new Schema(Lists.newArrayList(dayIntervalField));
  Projector eval = Projector.make(schema, exprs);

  // Inputs are written as "days millis".
  String[] values = {"1 0", "2 0", "1 1", "10 5000", "11 86400001"};

  final long millisPerDay = 86400000L;
  Long[] expValues = {
      millisPerDay,
      2 * millisPerDay,
      millisPerDay + 1L,
      10 * millisPerDay + 5000L,
      11 * millisPerDay + 86400001L
  };

  int numRows = 5;
  // One validity byte with every bit set, so all rows are marked valid.
  ArrowBuf bufValidity = buf(new byte[]{(byte) 255});
  ArrowBuf intervalsData = stringToDayInterval(values);

  // NOTE(review): two field nodes are supplied although the schema has a single field;
  // kept as-is to preserve behavior - confirm whether one node was intended.
  ArrowFieldNode fieldNode = new ArrowFieldNode(numRows, 0);
  ArrowRecordBatch batch =
      new ArrowRecordBatch(
          numRows,
          Lists.newArrayList(fieldNode, fieldNode),
          Lists.newArrayList(bufValidity, intervalsData));

  // One output vector per expression.
  List<ValueVector> output = new ArrayList<>();
  for (int exprIdx = 0; exprIdx < exprs.size(); exprIdx++) {
    BigIntVector resultVector = new BigIntVector(EMPTY_SCHEMA_PATH, allocator);
    resultVector.allocateNew(numRows);
    output.add(resultVector);
  }

  eval.evaluate(batch, output);
  eval.close();

  for (ValueVector produced : output) {
    BigIntVector actual = (BigIntVector) produced;
    for (int row = 0; row < numRows; row++) {
      assertFalse(actual.isNull(row));
      assertEquals(expValues[row], Long.valueOf(actual.get(row)));
    }
  }

  releaseRecordBatch(batch);
  releaseValueVectors(output);
}

}

0 comments on commit 116672f

Please sign in to comment.