Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
/**
* term of the [[ProcessFunction]]'s context, can be changed when needed
*/
var contextTerm = "ctx"
var contextTerm: String = ExprCodeGenerator.PROCESS_FUNCTION_DEFAULT_CONTEXT_TERM

/**
* information of the first input
Expand Down Expand Up @@ -443,7 +443,8 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
case (o@_, _) => o.accept(this)
}

generateCallExpression(ctx, call.getOperator, operands, resultType)
ExprCodeGenerator.generateCallExpression(
ctx, call.getOperator, operands, resultType, contextTerm)
}

override def visitOver(over: RexOver): GeneratedExpression =
Expand All @@ -454,14 +455,18 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)

override def visitPatternFieldRef(fieldRef: RexPatternFieldRef): GeneratedExpression =
throw new CodeGenException("Pattern field references are not supported yet.")
}

object ExprCodeGenerator {

// ----------------------------------------------------------------------------------------
val PROCESS_FUNCTION_DEFAULT_CONTEXT_TERM = "ctx"

private def generateCallExpression(
def generateCallExpression(
ctx: CodeGeneratorContext,
operator: SqlOperator,
operands: Seq[GeneratedExpression],
resultType: LogicalType): GeneratedExpression = {
resultType: LogicalType,
contextTerm: String = PROCESS_FUNCTION_DEFAULT_CONTEXT_TERM): GeneratedExpression = {
operator match {
// arithmetic
case PLUS if isNumeric(resultType) =>
Expand Down Expand Up @@ -740,9 +745,9 @@ class ExprCodeGenerator(ctx: CodeGeneratorContext, nullableInput: Boolean)
resultType)
.getOrElse(
throw new CodeGenException(s"Unsupported call: " +
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
s"$sqlOperator(${operands.map(_.resultType).mkString(", ")}) \n" +
s"If you think this function should be supported, " +
s"you can create an issue and start a discussion for it."))
.generate(ctx, operands, resultType)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.codegen

import org.apache.flink.table.dataformat.{BaseRow, BinaryRow, BinaryRowWriter, BinaryString, GenericRow}
import org.junit.Assert.{assertFalse, assertTrue}
import org.junit.Test
import java.lang.{Integer => JInt, Long => JLong}

import org.apache.flink.api.common.ExecutionConfig
import org.apache.flink.table.api.DataTypes
import org.apache.flink.table.types.logical.LogicalType
import org.apache.flink.table.typeutils.BaseRowSerializer

/**
* Test for [[EqualiserCodeGenerator]].
*/
class EqualiserCodeGeneratorTest {

@Test
def testEqualiser(): Unit = {
val types = Array[LogicalType](
DataTypes.INT.getLogicalType,
DataTypes.ROW(
DataTypes.FIELD("f0", DataTypes.INT),
DataTypes.FIELD("f1", DataTypes.STRING),
DataTypes.FIELD("f2", DataTypes.ROW(
DataTypes.FIELD("f20", DataTypes.STRING),
DataTypes.FIELD("f21", DataTypes.STRING)
))).getLogicalType,
DataTypes.BIGINT.getLogicalType)
val generator = new EqualiserCodeGenerator(types)
val recordEqualiser = generator.generateRecordEqualiser("recordEqualiser")
.newInstance(Thread.currentThread().getContextClassLoader)

val rowLeft: GenericRow = GenericRow.of(
1: JInt,
GenericRow.of(
2: JInt,
BinaryString.fromString("3"),
GenericRow.of(
BinaryString.fromString("4"), BinaryString.fromString("5"))), 6L: JLong)
val rowRight: BaseRow = newBinaryRow(1: JInt, 2, "3", "4", "5", 6L)
assertTrue(recordEqualiser.equals(rowLeft, rowRight))

rowLeft.setHeader(1)
rowRight.setHeader(0)
assertFalse(recordEqualiser.equals(rowLeft, rowRight))
assertTrue(recordEqualiser.equalsWithoutHeader(rowLeft, rowRight))
}

def newBinaryRow(
c1: Int, c21: Int, c22: String, c231: String, c232: String, c3: Long): BinaryRow = {
val c23 = new BinaryRow(2)
var writer = new BinaryRowWriter(c23)
writer.writeString(0, BinaryString.fromString(c231))
writer.writeString(1, BinaryString.fromString(c232))
writer.complete()

val c2Row = new BinaryRow(2)
val c2Serializer = new BaseRowSerializer(
new ExecutionConfig(), DataTypes.STRING.getLogicalType, DataTypes.STRING.getLogicalType)
writer = new BinaryRowWriter(c2Row)
writer.writeInt(0, c21)
writer.writeString(1, BinaryString.fromString(c22))
writer.writeRow(2, c23, c2Serializer)
writer.complete()

val row = new BinaryRow(3)
writer = new BinaryRowWriter(row)
val serializer = new BaseRowSerializer(
new ExecutionConfig(),
DataTypes.INT.getLogicalType,
DataTypes.STRING.getLogicalType,
DataTypes.ROW(
DataTypes.FIELD("f20", DataTypes.STRING),
DataTypes.FIELD("f21", DataTypes.STRING)).getLogicalType)
writer.writeInt(0, c1)
writer.writeRow(1, c2Row, serializer)
writer.writeLong(2, c3)
writer.complete()
row
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.codegen.calls

import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.fun.SqlStdOperatorTable._
import org.apache.flink.table.api.{DataTypes, TableConfig}
import org.apache.flink.table.codegen.CodeGenUtils.{BINARY_STRING, newName}
import org.apache.flink.table.codegen.ExprCodeGenerator.generateCallExpression
import org.apache.flink.table.codegen.{CodeGeneratorContext, GeneratedExpression}
import org.apache.flink.table.functions.sql.FlinkSqlOperatorTable
import org.apache.flink.table.generated.CompileUtils.compile
import org.apache.flink.table.types.logical.LogicalType
import org.junit.Assert._
import org.junit.Test

class StringCallGenTest {

def invoke(operator: SqlOperator, operands: Seq[GeneratedExpression], tpe: LogicalType): Any = {
val config = new TableConfig()
val ctx = new CodeGeneratorContext(config)
val expr = generateCallExpression(ctx, operator, operands, tpe)
compileAndInvoke(ctx, expr)
}

def compileAndInvoke(ctx: CodeGeneratorContext, expr: GeneratedExpression): Any = {
val name = newName("StringCallGenTest")
val abstractClass = classOf[Func].getCanonicalName
val code =
s"""
|public class $name extends $abstractClass {
|
| ${ctx.reuseMemberCode()}
|
| @Override
| public Object apply() {
| ${ctx.reuseLocalVariableCode()}
| ${expr.code}
| if (${expr.nullTerm}) {
| return null;
| } else {
| return ${expr.resultTerm};
| }
| }
|}
""".stripMargin
val func = compile(Thread.currentThread().getContextClassLoader, name, code)
.newInstance().asInstanceOf[Func]
func()
}

def toBinary(term: String): String = s"$BINARY_STRING.fromString($term)"

def str(term: String): GeneratedExpression =
newOperand(toBinary("\"" + term + "\""), DataTypes.STRING.getLogicalType)

def int(term: String): GeneratedExpression =
newOperand(term, DataTypes.INT.getLogicalType)

def newOperand(resultTerm: String, resultType: LogicalType): GeneratedExpression =
GeneratedExpression(resultTerm, "false", "", resultType)

@Test
def testEquals(): Unit = {
assertFalse(invoke(EQUALS, Seq(str("haha"), str("hehe")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertTrue(invoke(EQUALS, Seq(str("haha"), str("haha")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])

assertTrue(invoke(NOT_EQUALS, Seq(str("haha"), str("hehe")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertFalse(invoke(NOT_EQUALS, Seq(str("haha"), str("haha")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
}

@Test
def testLike(): Unit = {
assertFalse(invoke(SqlStdOperatorTable.LIKE, Seq(str("haha"), str("hehe")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertTrue(invoke(SqlStdOperatorTable.LIKE, Seq(str("haha"), str("haha")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])

assertTrue(invoke(SqlStdOperatorTable.NOT_LIKE, Seq(str("haha"), str("hehe")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertFalse(invoke(SqlStdOperatorTable.NOT_LIKE, Seq(str("haha"), str("haha")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
}

@Test
def testCharLength(): Unit = {
assertEquals(4, invoke(CHAR_LENGTH, Seq(str("haha")), DataTypes.INT.getLogicalType))
assertEquals(4, invoke(CHARACTER_LENGTH, Seq(str("haha")), DataTypes.INT.getLogicalType))
}

@Test
def testSqlTime(): Unit = {
assertEquals(1453438905L,
invoke(FlinkSqlOperatorTable.UNIX_TIMESTAMP,
Seq(str("2016-01-22 05:01:45")), DataTypes.TIMESTAMP.getLogicalType))

assertEquals(-120,
invoke(FlinkSqlOperatorTable.DATEDIFF,
Seq(str("2016-01-22"), str("2016-05-21")), DataTypes.DATE.getLogicalType))
}

@Test
def testSimilarTo(): Unit = {
assertFalse(invoke(SIMILAR_TO, Seq(str("haha"), str("hehe")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertTrue(invoke(SIMILAR_TO, Seq(str("haha"), str("haha")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])

assertTrue(invoke(NOT_SIMILAR_TO, Seq(str("haha"), str("hehe")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertFalse(invoke(NOT_SIMILAR_TO, Seq(str("haha"), str("haha")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
}

@Test
def testIsXxx(): Unit = {
assertTrue(invoke(FlinkSqlOperatorTable.IS_DECIMAL, Seq(str("1234134")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertTrue(invoke(FlinkSqlOperatorTable.IS_DIGIT, Seq(str("1234134")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
assertTrue(invoke(FlinkSqlOperatorTable.IS_ALPHA, Seq(str("adb")),
DataTypes.BOOLEAN.getLogicalType).asInstanceOf[Boolean])
}

@Test
def testPosition(): Unit = {
assertEquals(5,
invoke(POSITION, Seq(str("d"), str("aaaadfg")),
DataTypes.INT.getLogicalType))
}

@Test
def testHash(): Unit = {
assertEquals(1236857883,
invoke(FlinkSqlOperatorTable.HASH_CODE, Seq(str("aaaadfg")),
DataTypes.INT.getLogicalType))
}
}

abstract class Func {
def apply(): Any
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.table.runtime.hashtable;

import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.core.memory.MemorySegmentFactory;
import org.apache.flink.core.memory.SeekableDataInputView;
Expand Down Expand Up @@ -245,9 +246,10 @@ MatchIterator valueIter(long address) {
return iterator;
}

// public MatchIterator get(long key) {
// return get(key, hashLong(key, recursionLevel));
// }
@VisibleForTesting
public MatchIterator get(long key) {
return get(key, hashLong(key, recursionLevel));
}

/**
* Returns an iterator for all the values for the given key, or null if no value found.
Expand Down Expand Up @@ -653,6 +655,11 @@ void releaseBuckets() {
}
}

@VisibleForTesting
public void append(long key, BinaryRow row) throws IOException {
insertIntoTable(key, hashLong(key, recursionLevel), row);
}

// ------------------ PagedInputView for read end --------------------

/**
Expand Down
Loading