Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,8 @@ class ScalarFunctionSplitter(

private var fieldsRexCall: Map[Int, Int] = Map[Int, Int]()

private val extractedRexNodeRefs: mutable.HashSet[RexNode] = mutable.HashSet[RexNode]()

override def visitCall(call: RexCall): RexNode = {
if (needConvert(call)) {
getExtractedRexNode(call)
Expand All @@ -454,7 +456,9 @@ class ScalarFunctionSplitter(
new RexInputRef(field.getIndex, field.getType)
case _ =>
val newFieldAccess =
rexBuilder.makeFieldAccess(expr.accept(this), fieldAccess.getField.getIndex)
rexBuilder.makeFieldAccess(
convertInputRefToLocalRefIfNecessary(expr.accept(this)),
fieldAccess.getField.getIndex)
getExtractedRexNode(newFieldAccess)
}
} else {
Expand All @@ -468,9 +472,18 @@ class ScalarFunctionSplitter(

override def visitNode(rexNode: RexNode): RexNode = rexNode

private def convertInputRefToLocalRefIfNecessary(node: RexNode): RexNode = {
node match {
case inputRef: RexInputRef if extractedRexNodeRefs.contains(node) =>
new RexLocalRef(inputRef.getIndex, node.getType)
case _ => node
}
}

private def getExtractedRexNode(node: RexNode): RexNode = {
val newNode = new RexInputRef(extractedFunctionOffset + extractedRexNodes.length, node.getType)
extractedRexNodes.append(node)
extractedRexNodeRefs.add(newNode)
newNode
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public void setup() {
+ " a int,\n"
+ " b bigint,\n"
+ " c string,\n"
+ " d ARRAY<INT NOT NULL>\n"
+ " d ARRAY<INT NOT NULL>,\n"
+ " e ROW<f ROW<h int, i double>, g string>"
+ ") WITH (\n"
+ " 'connector' = 'test-simple-table-source'\n"
+ ") ;");
Expand Down Expand Up @@ -182,6 +183,12 @@ public void testFieldAccessAfter() {
util.verifyRelPlan(sqlQuery);
}

@Test
public void testCompositeFieldAsInput() {
String sqlQuery = "SELECT func1(e.f.h) from MyTable";
util.verifyRelPlan(sqlQuery);
}

@Test
public void testFieldOperand() {
String sqlQuery = "SELECT func1(func5(a).f0) from MyTable";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ public void setup() {
+ " a int,\n"
+ " b bigint,\n"
+ " c string,\n"
+ " d ARRAY<INT NOT NULL>\n"
+ " d ARRAY<INT NOT NULL>,\n"
+ " e ROW<f ROW<h int, i double>, g string>\n"
+ ") WITH (\n"
+ " 'connector' = 'test-simple-table-source'\n"
+ ") ;");
Expand Down Expand Up @@ -110,6 +111,12 @@ public void testCorrelateWithCast() {
util.verifyRelPlan(sqlQuery);
}

@Test
public void testCorrelateWithCompositeFieldAsInput() {
String sqlQuery = "select * FROM MyTable, LATERAL TABLE(asyncTableFunc(e.f.h))";
util.verifyRelPlan(sqlQuery);
}

/** Test function. */
public static class AsyncFunc extends AsyncTableFunction<String> {

Expand Down
Loading