Skip to content

Commit

Permalink
[FLINK-1851] [tableAPI] Add support for casting in Table Expression P…
Browse files Browse the repository at this point in the history
…arser

- Also fix code generation for casting between primitives.
- Extends documentation for TableAPI expressions

This closes #592
  • Loading branch information
aljoscha authored and fhueske committed Sep 16, 2015
1 parent 31f6317 commit 7bea901
Show file tree
Hide file tree
Showing 9 changed files with 198 additions and 28 deletions.
55 changes: 53 additions & 2 deletions docs/libs/table.md
Expand Up @@ -123,7 +123,58 @@ DataSet<WC> result = tableEnv.toDataSet(filtered, WC.class);
When using Java, the embedded DSL for specifying expressions cannot be used. Only String expressions
are supported. They support exactly the same feature set as the expression DSL.

Please refer to the Javadoc for a full list of supported operations and a description of the
expression syntax.
## Expression Syntax

A `Table` supports the following operations: `select`, `where`, `groupBy` and `join` (plus `filter`
as an alias for `where`). These are also documented in the [Javadoc](http://flink.apache.org/docs/latest/api/java/org/apache/flink/api/table/Table.html)
of `Table`.

Some of these expect an expression. These can either be specified using an embedded Scala DSL or
a String expression. Please refer to the examples above to learn how expressions can be
formulated.

This is the complete EBNF grammar for expressions:

{% highlight ebnf %}

expression = single expression , { "," , single expression } ;

single expression = alias | logic ;

alias = logic | logic , "AS" , field reference ;

logic = comparison , [ ( "&&" | "||" ) , comparison ] ;

comparison = term , [ ( "=" | "!=" | ">" | ">=" | "<" | "<=" ) , term ] ;

term = product , [ ( "+" | "-" ) , product ] ;

product = binary bitwise , [ ( "*" | "/" | "%" ) , binary bitwise ] ;

binary bitwise = unary , [ ( "&" | "|" | "^" ) , unary ] ;

unary = [ "!" | "-" | "~" ] , suffix ;

suffix = atom | aggregation | cast | as | substring ;

aggregation = atom , [ ".sum" | ".min" | ".max" | ".count" | ".avg" ] ;

cast = atom , ".cast(" , data type , ")" ;

data type = "BYTE" | "SHORT" | "INT" | "LONG" | "FLOAT" | "DOUBLE" | "BOOL" | "BOOLEAN" | "STRING" ;

as = atom , ".as(" , field reference , ")" ;

substring = atom , ".substring(" , substring start , [ "," , substring end ] , ")" ;

substring start = single expression ;

substring end = single expression ;

atom = ( "(" , single expression , ")" ) | literal | field reference ;

{% endhighlight %}

Here, `literal` is a valid Java literal and `field reference` specifies a column in the data. The
column names follow Java identifier syntax.

Expand Up @@ -91,7 +91,7 @@ protected BasicTypeInfo(Class<T> clazz, Class<?>[] possibleCastTargetTypes, Type
* Returns whether this type should be automatically cast to
* the target type in an arithmetic operation.
*/
public boolean canCastTo(BasicTypeInfo<?> to) {
public boolean shouldAutocastTo(BasicTypeInfo<?> to) {
for (Class<?> possibleTo: possibleCastTargetTypes) {
if (possibleTo.equals(to.getTypeClass())) {
return true;
Expand Down
Expand Up @@ -41,6 +41,10 @@ import org.apache.flink.api.table.plan._
* val table2 = ...
* val set = table2.toDataSet[MyType]
* }}}
*
* Operations such as [[join]], [[select]], [[where]] and [[groupBy]] either take arguments
* in a Scala DSL or as an expression String. Please refer to the documentation for the expression
* syntax.
*/
case class Table(private[flink] val operation: PlanNode) {

Expand Down
Expand Up @@ -249,7 +249,7 @@ abstract class ExpressionCodeGenerator[R](
s"""
|boolean $nullTerm = ${childGen.nullTerm};
|$resultTpe $resultTerm;
|if ($nullTerm == null) {
|if ($nullTerm) {
| $resultTerm = null;
|} else {
| $resultTerm = "" + ${childGen.resultTerm};
Expand All @@ -262,8 +262,12 @@ abstract class ExpressionCodeGenerator[R](
}
childGen.code + castCode

case expressions.Cast(child: Expression, tpe: BasicTypeInfo[_]) =>
case expressions.Cast(child: Expression, tpe: BasicTypeInfo[_])
if child.typeInfo == BasicTypeInfo.STRING_TYPE_INFO =>
val childGen = generateExpression(child)
val fromTpe = typeTermForTypeInfoForCast(child.typeInfo)
val toTpe = typeTermForTypeInfoForCast(tpe)

val castCode = if (nullCheck) {
s"""
|boolean $nullTerm = ${childGen.nullTerm};
Expand All @@ -276,6 +280,27 @@ abstract class ExpressionCodeGenerator[R](
| ${tpe.getTypeClass.getCanonicalName}.valueOf(${childGen.resultTerm});
""".stripMargin
}

childGen.code + castCode

case expressions.Cast(child: Expression, tpe: BasicTypeInfo[_])
if child.typeInfo.isBasicType =>
val childGen = generateExpression(child)
val fromTpe = typeTermForTypeInfoForCast(child.typeInfo)
val toTpe = typeTermForTypeInfoForCast(tpe)
val castCode = if (nullCheck) {
s"""
|boolean $nullTerm = ${childGen.nullTerm};
|$resultTpe $resultTerm;
|if ($nullTerm) {
| $resultTerm = null;
|} else {
| $resultTerm = ($toTpe)($fromTpe) ${childGen.resultTerm};
|}
""".stripMargin
} else {
s"$resultTpe $resultTerm = ($toTpe)($fromTpe) ${childGen.resultTerm};\n"
}
childGen.code + castCode

case ResolvedFieldReference(fieldName, fieldTpe: TypeInformation[_]) =>
Expand Down Expand Up @@ -589,14 +614,38 @@ abstract class ExpressionCodeGenerator[R](

protected def typeTermForTypeInfo(tpe: TypeInformation[_]): String = tpe match {

// case BasicTypeInfo.INT_TYPE_INFO => "int"
// case BasicTypeInfo.LONG_TYPE_INFO => "long"
// case BasicTypeInfo.SHORT_TYPE_INFO => "short"
// case BasicTypeInfo.BYTE_TYPE_INFO => "byte"
// case BasicTypeInfo.FLOAT_TYPE_INFO => "float"
// case BasicTypeInfo.DOUBLE_TYPE_INFO => "double"
// case BasicTypeInfo.BOOLEAN_TYPE_INFO => "boolean"
// case BasicTypeInfo.CHAR_TYPE_INFO => "char"
// From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections
// does not seem to like this, so we manually give the correct type here.
case PrimitiveArrayTypeInfo.INT_PRIMITIVE_ARRAY_TYPE_INFO => "int[]"
case PrimitiveArrayTypeInfo.LONG_PRIMITIVE_ARRAY_TYPE_INFO => "long[]"
case PrimitiveArrayTypeInfo.SHORT_PRIMITIVE_ARRAY_TYPE_INFO => "short[]"
case PrimitiveArrayTypeInfo.BYTE_PRIMITIVE_ARRAY_TYPE_INFO => "byte[]"
case PrimitiveArrayTypeInfo.FLOAT_PRIMITIVE_ARRAY_TYPE_INFO => "float[]"
case PrimitiveArrayTypeInfo.DOUBLE_PRIMITIVE_ARRAY_TYPE_INFO => "double[]"
case PrimitiveArrayTypeInfo.BOOLEAN_PRIMITIVE_ARRAY_TYPE_INFO => "boolean[]"
case PrimitiveArrayTypeInfo.CHAR_PRIMITIVE_ARRAY_TYPE_INFO => "char[]"

case _ =>
tpe.getTypeClass.getCanonicalName

}

// when casting we first need to unbox Primitives, for example,
// float a = 1.0f;
// byte b = (byte) a;
// works, but for boxed types we need this:
// Float a = 1.0f;
// Byte b = (byte)(float) a;
protected def typeTermForTypeInfoForCast(tpe: TypeInformation[_]): String = tpe match {

case BasicTypeInfo.INT_TYPE_INFO => "int"
case BasicTypeInfo.LONG_TYPE_INFO => "long"
case BasicTypeInfo.SHORT_TYPE_INFO => "short"
case BasicTypeInfo.BYTE_TYPE_INFO => "byte"
case BasicTypeInfo.FLOAT_TYPE_INFO => "float"
case BasicTypeInfo.DOUBLE_TYPE_INFO => "double"
case BasicTypeInfo.BOOLEAN_TYPE_INFO => "boolean"
case BasicTypeInfo.CHAR_TYPE_INFO => "char"

// From PrimitiveArrayTypeInfo we would get class "int[]", scala reflections
// does not seem to like this, so we manually give the correct type here.
Expand Down
Expand Up @@ -33,10 +33,10 @@ class InsertAutoCasts extends Rule[Expression] {
case plus@Plus(o1, o2) =>
// Plus is special case since we can cast anything to String for String concat
if (o1.typeInfo != o2.typeInfo && o1.typeInfo.isBasicType && o2.typeInfo.isBasicType) {
if (o1.typeInfo.asInstanceOf[BasicTypeInfo[_]].canCastTo(
if (o1.typeInfo.asInstanceOf[BasicTypeInfo[_]].shouldAutocastTo(
o2.typeInfo.asInstanceOf[BasicTypeInfo[_]])) {
Plus(Cast(o1, o2.typeInfo), o2)
} else if (o2.typeInfo.asInstanceOf[BasicTypeInfo[_]].canCastTo(
} else if (o2.typeInfo.asInstanceOf[BasicTypeInfo[_]].shouldAutocastTo(
o1.typeInfo.asInstanceOf[BasicTypeInfo[_]])) {
Plus(o1, Cast(o2, o1.typeInfo))
} else if (o1.typeInfo == BasicTypeInfo.STRING_TYPE_INFO) {
Expand All @@ -55,10 +55,10 @@ class InsertAutoCasts extends Rule[Expression] {
val o1 = ba.left
val o2 = ba.right
if (o1.typeInfo != o2.typeInfo && o1.typeInfo.isBasicType && o2.typeInfo.isBasicType) {
if (o1.typeInfo.asInstanceOf[BasicTypeInfo[_]].canCastTo(
if (o1.typeInfo.asInstanceOf[BasicTypeInfo[_]].shouldAutocastTo(
o2.typeInfo.asInstanceOf[BasicTypeInfo[_]])) {
ba.makeCopy(Seq(Cast(o1, o2.typeInfo), o2))
} else if (o2.typeInfo.asInstanceOf[BasicTypeInfo[_]].canCastTo(
} else if (o2.typeInfo.asInstanceOf[BasicTypeInfo[_]].shouldAutocastTo(
o1.typeInfo.asInstanceOf[BasicTypeInfo[_]])) {
ba.makeCopy(Seq(o1, Cast(o2, o1.typeInfo)))
} else {
Expand All @@ -73,10 +73,10 @@ class InsertAutoCasts extends Rule[Expression] {
val o2 = ba.right
if (o1.typeInfo != o2.typeInfo && o1.typeInfo.isInstanceOf[IntegerTypeInfo[_]] &&
o2.typeInfo.isInstanceOf[IntegerTypeInfo[_]]) {
if (o1.typeInfo.asInstanceOf[BasicTypeInfo[_]].canCastTo(
if (o1.typeInfo.asInstanceOf[BasicTypeInfo[_]].shouldAutocastTo(
o2.typeInfo.asInstanceOf[BasicTypeInfo[_]])) {
ba.makeCopy(Seq(Cast(o1, o2.typeInfo), o2))
} else if (o2.typeInfo.asInstanceOf[BasicTypeInfo[_]].canCastTo(
} else if (o2.typeInfo.asInstanceOf[BasicTypeInfo[_]].shouldAutocastTo(
o1.typeInfo.asInstanceOf[BasicTypeInfo[_]])) {
ba.makeCopy(Seq(o1, Cast(o2, o1.typeInfo)))
} else {
Expand Down
Expand Up @@ -17,8 +17,18 @@
*/
package org.apache.flink.api.table.expressions

import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.ExpressionException

case class Cast(child: Expression, tpe: TypeInformation[_]) extends UnaryExpression {
def typeInfo = tpe
def typeInfo = tpe match {
case BasicTypeInfo.STRING_TYPE_INFO => tpe

case b if b.isBasicType && child.typeInfo.isBasicType => tpe

case _ => throw new ExpressionException(
s"Invalid cast: $this. Casts are only valid betwixt primitive types.")
}

override def toString = s"$child.cast($tpe)"
}
Expand Up @@ -17,6 +17,7 @@
*/
package org.apache.flink.api.table.parser

import org.apache.flink.api.common.typeinfo.{BasicTypeInfo, TypeInformation}
import org.apache.flink.api.table.ExpressionException
import org.apache.flink.api.table.plan.As
import org.apache.flink.api.table.expressions._
Expand All @@ -27,8 +28,8 @@ import scala.util.parsing.combinator.{PackratParsers, JavaTokenParsers}
* Parser for expressions inside a String. This parses exactly the same expressions that
* would be accepted by the Scala Expression DSL.
*
* See [[org.apache.flink.api.scala.expressions.ImplicitExpressionConversions]] and
* [[org.apache.flink.api.scala.expressions.ImplicitExpressionOperations]] for the constructs
* See [[org.apache.flink.api.scala.table.ImplicitExpressionConversions]] and
* [[org.apache.flink.api.scala.table.ImplicitExpressionOperations]] for the constructs
* available in the Scala Expression DSL. This parser must be kept in sync with the Scala DSL
* defined in the above files.
*/
Expand Down Expand Up @@ -107,6 +108,17 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {
/**
 * Parses an average aggregation and wraps the aggregated operand in [[Avg]].
 *
 * Two spellings are accepted: the postfix form `atom.avg`, and the function-call
 * form made of the `AVG` parser followed by a parenthesised atom.
 */
lazy val avg: PackratParser[Expression] = {
  // Postfix spelling: the operand comes first, the ".avg" suffix is dropped.
  lazy val postfixForm: Parser[Expression] = atom <~ ".avg"
  // Function-call spelling: only the operand between the parentheses is kept.
  lazy val functionForm: Parser[Expression] = AVG ~ "(" ~> atom <~ ")"

  (postfixForm | functionForm) ^^ { operand => Avg(operand) }
}

/**
 * Parses an explicit cast of the form `atom.cast(TYPE)` and produces a [[Cast]] to the
 * [[BasicTypeInfo]] that corresponds to `TYPE`.
 *
 * Supported type keywords are `BYTE`, `SHORT`, `INT`, `LONG`, `FLOAT`, `DOUBLE`, `BOOL`,
 * `BOOLEAN` and `STRING` (`BOOL` and `BOOLEAN` are synonyms).
 *
 * The suffix is parsed as three separate tokens (`.cast(`, the type keyword and `)`), in the
 * same way as the `.as(` ... `)` rule. Every input of the form `.cast(TYPE)` is still
 * accepted; in addition the type keyword is now subject to the parser's normal token and
 * whitespace handling instead of having to be part of one fused literal.
 */
lazy val cast: PackratParser[Expression] = {
  // Type keyword -> cast target. BOOLEAN has to be listed before BOOL: alternation commits
  // to the first keyword that matches, and "BOOL" is a prefix of "BOOLEAN".
  val castTargets: Seq[(String, TypeInformation[_])] = Seq(
    "BYTE" -> BasicTypeInfo.BYTE_TYPE_INFO,
    "SHORT" -> BasicTypeInfo.SHORT_TYPE_INFO,
    "INT" -> BasicTypeInfo.INT_TYPE_INFO,
    "LONG" -> BasicTypeInfo.LONG_TYPE_INFO,
    "FLOAT" -> BasicTypeInfo.FLOAT_TYPE_INFO,
    "DOUBLE" -> BasicTypeInfo.DOUBLE_TYPE_INFO,
    "BOOLEAN" -> BasicTypeInfo.BOOLEAN_TYPE_INFO,
    "BOOL" -> BasicTypeInfo.BOOLEAN_TYPE_INFO,
    "STRING" -> BasicTypeInfo.STRING_TYPE_INFO)

  // Parser for a single type keyword that yields the associated type information.
  def keywordParser(keyword: String, tpe: TypeInformation[_]): Parser[TypeInformation[_]] =
    literal(keyword) ^^^ tpe

  // Any one of the supported type keywords.
  val dataType: Parser[TypeInformation[_]] =
    castTargets
      .map { case (keyword, tpe) => keywordParser(keyword, tpe) }
      .reduceLeft(_ | _)

  atom ~ (".cast(" ~> dataType <~ ")") ^^ { case operand ~ tpe => Cast(operand, tpe) }
}

/**
 * Parses a renaming of the form `atom.as(fieldReference)` and produces a [[Naming]] of the
 * parsed expression under the name of the referenced field.
 */
lazy val as: PackratParser[Expression] =
  (atom <~ ".as(") ~ (fieldReference <~ ")") ^^ {
    // The ".as(" and ")" delimiters are discarded; only the operand and the alias remain.
    case operand ~ alias => Naming(operand, alias.name)
  }
Expand All @@ -125,7 +137,7 @@ object ExpressionParser extends JavaTokenParsers with PackratParsers {

lazy val suffix =
isNull | isNotNull |
abs | sum | min | max | count | avg |
abs | sum | min | max | count | avg | cast |
substring | substringWithoutEnd | atom


Expand Down
Expand Up @@ -18,6 +18,7 @@

package org.apache.flink.api.java.table.test;

import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.api.table.Table;
import org.apache.flink.api.table.Row;
import org.apache.flink.api.java.DataSet;
Expand Down Expand Up @@ -128,5 +129,27 @@ public void testNumericAutocastInComparison() throws Exception {

expected = "2,2,2,2,2.0,2.0,Hello";
}

/**
 * Casts String columns to every numeric type and to boolean through the String
 * expression syntax ({@code field.cast(TYPE)}) and writes the resulting row as text.
 */
@Test
public void testCastFromString() throws Exception {
	ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
	TableEnvironment tableEnv = new TableEnvironment();

	// A single row of Strings: an integral number, a boolean and a decimal number.
	Tuple3<String, String, String> row = new Tuple3<String, String, String>("1", "true", "2.0");
	DataSource<Tuple3<String, String, String>> strings = env.fromElements(row);

	Table source = tableEnv.fromDataSet(strings);
	Table casted = source.select(
		"f0.cast(BYTE), f0.cast(SHORT), f0.cast(INT), f0.cast(LONG), f2.cast(DOUBLE), f2.cast(FLOAT), f1.cast(BOOL)");

	DataSet<Row> rows = tableEnv.toDataSet(casted, Row.class);
	rows.writeAsText(resultPath, FileSystem.WriteMode.OVERWRITE);

	env.execute();

	// NOTE(review): `expected` is presumably compared with the written output by the
	// enclosing test class after the test method returns - confirm against the class.
	expected = "1,1,1,1,2.0,2.0,true\n";
}
}

Expand Up @@ -18,15 +18,17 @@

package org.apache.flink.api.scala.table.test

import org.junit._
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

import org.apache.flink.api.common.typeinfo.BasicTypeInfo
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.table._
import org.apache.flink.core.fs.FileSystem.WriteMode
import org.apache.flink.test.util.{TestBaseUtils, MultipleProgramsTestBase}
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
import org.junit._
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

@RunWith(classOf[Parameterized])
class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode) {
Expand Down Expand Up @@ -89,4 +91,23 @@ class CastingITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo
expected = "2,2,2,2,2.0,2.0"
}

/**
 * Casts String columns to every numeric type and to boolean through the Scala expression
 * DSL (`'field.cast(typeInfo)`) and writes the resulting row as text.
 */
@Test
def testCastFromString: Unit = {
  val env = ExecutionEnvironment.getExecutionEnvironment

  // A single row of Strings: an integral number, a boolean and a decimal number.
  val strings = env.fromElements(("1", "true", "2.0"))

  val casted = strings.toTable.select(
    '_1.cast(BasicTypeInfo.BYTE_TYPE_INFO),
    '_1.cast(BasicTypeInfo.SHORT_TYPE_INFO),
    '_1.cast(BasicTypeInfo.INT_TYPE_INFO),
    '_1.cast(BasicTypeInfo.LONG_TYPE_INFO),
    '_3.cast(BasicTypeInfo.DOUBLE_TYPE_INFO),
    '_3.cast(BasicTypeInfo.FLOAT_TYPE_INFO),
    '_2.cast(BasicTypeInfo.BOOLEAN_TYPE_INFO))

  casted.writeAsText(resultPath, WriteMode.OVERWRITE)
  env.execute()

  // NOTE(review): `expected` is presumably compared with the written output by the
  // enclosing test class after the test method returns - confirm against the class.
  expected = "1,1,1,1,2.0,2.0,true\n"
}

}

0 comments on commit 7bea901

Please sign in to comment.