Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GLUTEN-4772][VL] Support empty map/array literal #4771

Merged
merged 14 commits into from
Mar 6, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -1086,4 +1086,29 @@ class TestOperator extends VeloxWholeStageTransformerSuite with AdaptiveSparkPla
}
}

// Checks that a query using an empty array literal as the coalesce() fallback
// still yields exactly one ProjectExecTransformer in the executed plan.
test("test array literal") {
  withTable("array_table") {
    // One parquet-backed table holding a single non-empty bigint array row.
    sql("create table array_table(a array<bigint>) using parquet")
    sql("insert into table array_table select array(1)")

    runQueryAndCompare("select size(coalesce(a, array())) from array_table") {
      resultDf =>
        val projectTransformerCount =
          getExecutedPlan(resultDf).count(_.isInstanceOf[ProjectExecTransformer])
        assert(projectTransformerCount == 1)
    }
  }
}

// Checks that a query using an empty map literal as the coalesce() fallback
// still yields exactly one ProjectExecTransformer in the executed plan.
test("test map literal") {
  withTable("map_table") {
    // One parquet-backed table holding a single non-empty bigint -> string map row.
    sql("create table map_table(a map<bigint, string>) using parquet")
    sql("insert into table map_table select map(1, 'hello')")

    runQueryAndCompare("select size(coalesce(a, map())) from map_table") {
      resultDf =>
        val projectTransformerCount =
          getExecutedPlan(resultDf).count(_.isInstanceOf[ProjectExecTransformer])
        assert(projectTransformerCount == 1)
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite {
}

test("Array Literal") {
validateOffloadResult("SELECT array()")
validateOffloadResult("SELECT array(array())")
validateOffloadResult("SELECT array(map())")
validateOffloadResult("SELECT array('Spark', '5')")
validateOffloadResult("SELECT array(5, 1, -1)")
validateOffloadResult("SELECT array(5S, 1S, -1S)")
Expand All @@ -93,6 +96,9 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite {
}

test("Map Literal") {
validateOffloadResult("SELECT map()")
validateOffloadResult("SELECT map(1, array())")
validateOffloadResult("SELECT map(1, map())")
validateOffloadResult("SELECT map('b', 'a', 'e', 'e')")
validateOffloadResult("SELECT map(1D, 'a', 2D, 'e')")
validateOffloadResult("SELECT map(1.0, map(1, 2, 3, 4))")
Expand Down Expand Up @@ -126,12 +132,6 @@ class VeloxLiteralSuite extends VeloxWholeStageTransformerSuite {
}

test("Literal Fallback") {
validateFallbackResult("SELECT array()")
validateFallbackResult("SELECT array(array())")
validateFallbackResult("SELECT array(map())")
validateFallbackResult("SELECT map()")
validateFallbackResult("SELECT map(1, array())")
validateFallbackResult("SELECT map(1, map())")
validateFallbackResult("SELECT array(null)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know why it falls back?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's not related to empty literals but to NullType support.
I'll fix it in another PR; otherwise this PR would become too big.

validateFallbackResult("SELECT array(cast(null as int))")
validateFallbackResult("SELECT map(1, null)")
Expand Down
10 changes: 10 additions & 0 deletions cpp-ch/local-engine/Parser/SerializedPlanParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,6 +1463,11 @@ std::pair<DataTypePtr, Field> SerializedPlanParser::parseLiteral(const substrait
field = std::move(array);
break;
}
case substrait::Expression_Literal::kEmptyList: {
type = std::make_shared<DataTypeArray>(std::make_shared<DataTypeNothing>());
field = Array();
break;
}
case substrait::Expression_Literal::kMap: {
const auto & key_values = literal.map().key_values();
if (key_values.empty())
Expand Down Expand Up @@ -1508,6 +1513,11 @@ std::pair<DataTypePtr, Field> SerializedPlanParser::parseLiteral(const substrait
field = std::move(map);
break;
}
case substrait::Expression_Literal::kEmptyMap: {
type = std::make_shared<DataTypeMap>(std::make_shared<DataTypeNothing>(), std::make_shared<DataTypeNothing>());
field = Map();
break;
}
case substrait::Expression_Literal::kStruct: {
const auto & fields = literal.struct_().fields();

Expand Down
47 changes: 34 additions & 13 deletions cpp/velox/substrait/SubstraitToVeloxExpr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,16 @@ RowVectorPtr makeRowVector(const std::vector<VectorPtr>& children) {
return std::make_shared<RowVector>(children[0]->pool(), rowType, BufferPtr(nullptr), vectorSize, children);
}

ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool) {
ArrayVectorPtr makeEmptyArrayVector(memory::MemoryPool* pool, const TypePtr& elementType) {
BufferPtr offsets = allocateOffsets(1, pool);
BufferPtr sizes = allocateOffsets(1, pool);
return std::make_shared<ArrayVector>(pool, ARRAY(UNKNOWN()), nullptr, 1, offsets, sizes, nullptr);
return std::make_shared<ArrayVector>(pool, ARRAY(elementType), nullptr, 1, offsets, sizes, nullptr);
}

MapVectorPtr makeEmptyMapVector(memory::MemoryPool* pool) {
MapVectorPtr makeEmptyMapVector(memory::MemoryPool* pool, const TypePtr& keyType, const TypePtr& valueType) {
BufferPtr offsets = allocateOffsets(1, pool);
BufferPtr sizes = allocateOffsets(1, pool);
return std::make_shared<MapVector>(pool, MAP(UNKNOWN(), UNKNOWN()), nullptr, 1, offsets, sizes, nullptr, nullptr);
return std::make_shared<MapVector>(pool, MAP(keyType, valueType), nullptr, 1, offsets, sizes, nullptr, nullptr);
}

RowVectorPtr makeEmptyRowVector(memory::MemoryPool* pool) {
Expand Down Expand Up @@ -351,10 +351,21 @@ std::shared_ptr<const core::ConstantTypedExpr> SubstraitVeloxExprConverter::toVe
auto constantVector = BaseVector::wrapInConstant(1, 0, literalsToArrayVector(substraitLit));
return std::make_shared<const core::ConstantTypedExpr>(constantVector);
}
case ::substrait::Expression_Literal::LiteralTypeCase::kEmptyList: {
auto elementType = SubstraitParser::parseType(substraitLit.empty_list().type());
auto constantVector = BaseVector::wrapInConstant(1, 0, makeEmptyArrayVector(pool_, elementType));
return std::make_shared<const core::ConstantTypedExpr>(constantVector);
}
case ::substrait::Expression_Literal::LiteralTypeCase::kMap: {
auto constantVector = BaseVector::wrapInConstant(1, 0, literalsToMapVector(substraitLit));
return std::make_shared<const core::ConstantTypedExpr>(constantVector);
}
case ::substrait::Expression_Literal::LiteralTypeCase::kEmptyMap: {
auto keyType = SubstraitParser::parseType(substraitLit.empty_map().key());
auto valueType = SubstraitParser::parseType(substraitLit.empty_map().value());
auto constantVector = BaseVector::wrapInConstant(1, 0, makeEmptyMapVector(pool_, keyType, valueType));
return std::make_shared<const core::ConstantTypedExpr>(constantVector);
}
case ::substrait::Expression_Literal::LiteralTypeCase::kStruct: {
auto constantVector = BaseVector::wrapInConstant(1, 0, literalsToRowVector(substraitLit));
return std::make_shared<const core::ConstantTypedExpr>(constantVector);
Expand Down Expand Up @@ -382,33 +393,34 @@ std::shared_ptr<const core::ConstantTypedExpr> SubstraitVeloxExprConverter::toVe
ArrayVectorPtr SubstraitVeloxExprConverter::literalsToArrayVector(const ::substrait::Expression::Literal& literal) {
auto childSize = literal.list().values().size();
if (childSize == 0) {
return makeEmptyArrayVector(pool_);
return makeEmptyArrayVector(pool_, UNKNOWN());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it's better to throw here, since reaching this branch is unexpected behavior.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. If the child size is 0, it is an empty list, and we should not enter this method at all. @WangGuangxin, can you address this comment?

}
auto childTypeCase = literal.list().values(0).literal_type_case();
auto childLiteral = literal.list().values(0);
auto elementAtFunc = [&](vector_size_t idx) { return literal.list().values(idx); };
auto childVector = literalsToVector(childTypeCase, childSize, literal, elementAtFunc);
auto childVector = literalsToVector(childLiteral, childSize, literal, elementAtFunc);
return makeArrayVector(childVector);
}

MapVectorPtr SubstraitVeloxExprConverter::literalsToMapVector(const ::substrait::Expression::Literal& literal) {
auto childSize = literal.map().key_values().size();
if (childSize == 0) {
return makeEmptyMapVector(pool_);
return makeEmptyMapVector(pool_, UNKNOWN(), UNKNOWN());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same.

}
auto keyTypeCase = literal.map().key_values(0).key().literal_type_case();
auto valueTypeCase = literal.map().key_values(0).value().literal_type_case();
auto& keyLiteral = literal.map().key_values(0).key();
auto& valueLiteral = literal.map().key_values(0).value();
auto keyAtFunc = [&](vector_size_t idx) { return literal.map().key_values(idx).key(); };
auto valueAtFunc = [&](vector_size_t idx) { return literal.map().key_values(idx).value(); };
auto keyVector = literalsToVector(keyTypeCase, childSize, literal, keyAtFunc);
auto valueVector = literalsToVector(valueTypeCase, childSize, literal, valueAtFunc);
auto keyVector = literalsToVector(keyLiteral, childSize, literal, keyAtFunc);
auto valueVector = literalsToVector(valueLiteral, childSize, literal, valueAtFunc);
return makeMapVector(keyVector, valueVector);
}

VectorPtr SubstraitVeloxExprConverter::literalsToVector(
::substrait::Expression_Literal::LiteralTypeCase childTypeCase,
const ::substrait::Expression::Literal& childLiteral,
vector_size_t childSize,
const ::substrait::Expression::Literal& literal,
std::function<::substrait::Expression::Literal(vector_size_t /* idx */)> elementAtFunc) {
auto childTypeCase = childLiteral.literal_type_case();
switch (childTypeCase) {
case ::substrait::Expression_Literal::LiteralTypeCase::kNull: {
auto veloxType = SubstraitParser::parseType(literal.null());
Expand Down Expand Up @@ -456,6 +468,15 @@ VectorPtr SubstraitVeloxExprConverter::literalsToVector(
}
return rowVector;
}
case ::substrait::Expression_Literal::LiteralTypeCase::kEmptyList: {
auto elementType = SubstraitParser::parseType(childLiteral.empty_list().type());
return BaseVector::wrapInConstant(1, 0, makeEmptyArrayVector(pool_, elementType));
}
case ::substrait::Expression_Literal::LiteralTypeCase::kEmptyMap: {
auto keyType = SubstraitParser::parseType(childLiteral.empty_map().key());
auto valueType = SubstraitParser::parseType(childLiteral.empty_map().value());
return BaseVector::wrapInConstant(1, 0, makeEmptyMapVector(pool_, keyType, valueType));
}
default:
auto veloxType = getScalarType(elementAtFunc(0));
if (veloxType) {
Expand Down
2 changes: 1 addition & 1 deletion cpp/velox/substrait/SubstraitToVeloxExpr.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class SubstraitVeloxExprConverter {
/// Convert map literal to MapVector.
MapVectorPtr literalsToMapVector(const ::substrait::Expression::Literal& literal);
VectorPtr literalsToVector(
::substrait::Expression_Literal::LiteralTypeCase childTypeCase,
const ::substrait::Expression::Literal& childLiteral,
vector_size_t childSize,
const ::substrait::Expression::Literal& literal,
std::function<::substrait::Expression::Literal(vector_size_t /* idx */)> elementAtFunc);
Expand Down
26 changes: 8 additions & 18 deletions cpp/velox/substrait/SubstraitToVeloxPlanValidator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -243,27 +243,17 @@ bool SubstraitToVeloxPlanValidator::validateLiteral(
const ::substrait::Expression_Literal& literal,
const RowTypePtr& inputType) {
if (literal.has_list()) {
if (literal.list().values_size() == 0) {
LOG_VALIDATION_MSG("Literal is a list but has no value.");
return false;
} else {
for (auto child : literal.list().values()) {
if (!validateLiteral(child, inputType)) {
// the error msg has been set, so do not need to set it again.
return false;
}
for (auto child : literal.list().values()) {
if (!validateLiteral(child, inputType)) {
// the error msg has been set, so do not need to set it again.
return false;
}
}
} else if (literal.has_map()) {
if (literal.map().key_values().empty()) {
LOG_VALIDATION_MSG("Literal is a map but has no value.");
return false;
} else {
for (auto child : literal.map().key_values()) {
if (!validateLiteral(child.key(), inputType) || !validateLiteral(child.value(), inputType)) {
// the error msg has been set, so do not need to set it again.
return false;
}
for (auto child : literal.map().key_values()) {
if (!validateLiteral(child.key(), inputType) || !validateLiteral(child.value(), inputType)) {
// the error msg has been set, so do not need to set it again.
return false;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.substrait.proto.Expression;
import io.substrait.proto.Expression.Literal.Builder;
import io.substrait.proto.Type;
import org.apache.spark.sql.catalyst.util.ArrayData;

public class ListLiteralNode extends LiteralNodeWithValue<ArrayData> {
Expand All @@ -33,13 +34,18 @@ protected void updateLiteralBuilder(Builder literalBuilder, ArrayData array) {
Object[] elements = array.array();
TypeNode elementType = ((ListNode) getTypeNode()).getNestedType();

Expression.Literal.List.Builder listBuilder = Expression.Literal.List.newBuilder();
for (Object element : elements) {
LiteralNode elementNode = ExpressionBuilder.makeLiteral(element, elementType);
Expression.Literal elementExpr = elementNode.getLiteral();
listBuilder.addValues(elementExpr);
if (elements.length > 0) {
Expression.Literal.List.Builder listBuilder = Expression.Literal.List.newBuilder();
for (Object element : elements) {
LiteralNode elementNode = ExpressionBuilder.makeLiteral(element, elementType);
Expression.Literal elementExpr = elementNode.getLiteral();
listBuilder.addValues(elementExpr);
}
literalBuilder.setList(listBuilder.build());
} else {
Type.List.Builder listTypeBuilder = Type.List.newBuilder();
listTypeBuilder.setType(elementType.toProtobuf());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the element type is NullType, we will still fall back, right? For example: select array();

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's related to #2996. It will no longer fall back once NullType is supported.

literalBuilder.setEmptyList(listTypeBuilder.build());
}

literalBuilder.setList(listBuilder.build());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import io.substrait.proto.Expression;
import io.substrait.proto.Expression.Literal.Builder;
import io.substrait.proto.Type;
import org.apache.spark.sql.catalyst.util.MapData;

public class MapLiteralNode extends LiteralNodeWithValue<MapData> {
Expand All @@ -32,21 +33,28 @@ public MapLiteralNode(MapData map, TypeNode typeNode) {
protected void updateLiteralBuilder(Builder literalBuilder, MapData map) {
Object[] keys = map.keyArray().array();
Object[] values = map.valueArray().array();
TypeNode mapType = ((MapNode) getTypeNode()).getKeyType();
TypeNode valueType = ((MapNode) getTypeNode()).getValueType();

Expression.Literal.Map.Builder mapBuilder = Expression.Literal.Map.newBuilder();
for (int i = 0; i < keys.length; ++i) {
LiteralNode keyNode =
ExpressionBuilder.makeLiteral(keys[i], ((MapNode) getTypeNode()).getKeyType());
LiteralNode valueNode =
ExpressionBuilder.makeLiteral(values[i], ((MapNode) getTypeNode()).getValueType());
if (keys.length > 0) {
Expression.Literal.Map.Builder mapBuilder = Expression.Literal.Map.newBuilder();
for (int i = 0; i < keys.length; ++i) {
LiteralNode keyNode = ExpressionBuilder.makeLiteral(keys[i], mapType);
LiteralNode valueNode = ExpressionBuilder.makeLiteral(values[i], valueType);

Expression.Literal.Map.KeyValue.Builder kvBuilder =
Expression.Literal.Map.KeyValue.newBuilder();
kvBuilder.setKey(keyNode.getLiteral());
kvBuilder.setValue(valueNode.getLiteral());
mapBuilder.addKeyValues(kvBuilder.build());
}
Expression.Literal.Map.KeyValue.Builder kvBuilder =
Expression.Literal.Map.KeyValue.newBuilder();
kvBuilder.setKey(keyNode.getLiteral());
kvBuilder.setValue(valueNode.getLiteral());
mapBuilder.addKeyValues(kvBuilder.build());
}

literalBuilder.setMap(mapBuilder.build());
literalBuilder.setMap(mapBuilder.build());
} else {
Type.Map.Builder mapTypeBuilder = Type.Map.newBuilder();
mapTypeBuilder.setKey(mapType.toProtobuf());
mapTypeBuilder.setValue(valueType.toProtobuf());
literalBuilder.setEmptyMap(mapTypeBuilder.build());
}
}
}
Loading