Skip to content

Commit

Permalink
Merge pull request #14726 from ibzib/BEAM-11990
Browse files Browse the repository at this point in the history
[BEAM-11990] [zetasql] Enable UDF with DATE arguments/return type.
  • Loading branch information
ibzib committed May 6, 2021
2 parents 5076c71 + de741ee commit 7ea721d
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.beam.sdk.extensions.sql.provider;

import com.google.auto.service.AutoService;
import java.sql.Date;
import java.util.Map;
import org.apache.beam.sdk.extensions.sql.udf.AggregateFn;
import org.apache.beam.sdk.extensions.sql.udf.ScalarFn;
Expand All @@ -37,7 +38,9 @@ public Map<String, ScalarFn> userDefinedScalarFunctions() {
"increment",
new IncrementFn(),
"isNull",
new IsNullFn());
new IsNullFn(),
"dateIncrementAll",
new DateIncrementAllFn());
}

@Override
Expand Down Expand Up @@ -105,4 +108,11 @@ public Long extractOutput(Long mutableAccumulator) {
return mutableAccumulator;
}
}

/**
 * Scalar UDF that takes a SQL {@link Date} and returns a date whose year, month and day-of-month
 * fields have each been incremented by one.
 */
public static class DateIncrementAllFn extends ScalarFn {
  /**
   * Builds a new {@link Date} from the components of {@code date}, each bumped by one.
   *
   * @param date the date whose fields are read
   * @return a date constructed from the incremented year, month and day-of-month values
   */
  @ApplyMethod
  public Date incrementAll(Date date) {
    // Read the three components first, in the same order the constructor consumes them.
    final int nextYear = date.getYear() + 1;
    final int nextMonth = date.getMonth() + 1;
    final int nextDay = date.getDate() + 1;
    return new Date(nextYear, nextMonth, nextDay);
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -133,21 +133,21 @@ protected RelNode makeRel(
private boolean supportsType(RelDataType type) {
switch (type.getSqlTypeName()) {
case BIGINT:
case BINARY:
case BOOLEAN:
case CHAR:
case DATE:
case DECIMAL:
case DOUBLE:
case NULL:
case TIMESTAMP:
case VARBINARY:
case VARCHAR:
case CHAR:
case BINARY:
case NULL:
return true;
case ARRAY:
return supportsType(type.getComponentType());
case ROW:
return type.getFieldList().stream().allMatch((field) -> supportsType(field.getType()));
case DATE: // BEAM-11990
case TIME: // BEAM-12086
case TIMESTAMP_WITH_LOCAL_TIME_ZONE: // BEAM-12087
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -192,18 +192,21 @@ void validateJavaUdf(ResolvedNodes.ResolvedCreateFunctionStmt createFunctionStmt
/**
* Throws {@link UnsupportedOperationException} if ZetaSQL type is not supported in Java UDF.
* Supported types are a subset of the types supported by {@link BeamJavaUdfCalcRule}.
*
* <p>Supported types should be kept in sync with {@link
* #validateJavaUdfCalciteType(RelDataType)}.
*/
void validateJavaUdfZetaSqlType(Type type) {
switch (type.getKind()) {
case TYPE_INT64:
case TYPE_DOUBLE:
case TYPE_BOOL:
case TYPE_STRING:
case TYPE_BYTES:
case TYPE_DATE:
case TYPE_DOUBLE:
case TYPE_INT64:
case TYPE_STRING:
// These types are supported.
break;
case TYPE_NUMERIC:
case TYPE_DATE:
case TYPE_TIME:
case TYPE_DATETIME:
case TYPE_TIMESTAMP:
Expand Down Expand Up @@ -422,18 +425,20 @@ private void validateScalarFunctionImpl(ScalarFunctionImpl scalarFunction) {
* Throws {@link UnsupportedOperationException} if Calcite type is not supported in Java UDF.
* Supported types are a subset of the corresponding Calcite types supported by {@link
* BeamJavaUdfCalcRule}.
*
* <p>Supported types should be kept in sync with {@link #validateJavaUdfZetaSqlType(Type)}.
*/
private void validateJavaUdfCalciteType(RelDataType type) {
switch (type.getSqlTypeName()) {
case BIGINT:
case DATE:
case DOUBLE:
case BOOLEAN:
case VARCHAR:
case VARBINARY:
// These types are supported.
break;
case DECIMAL:
case DATE:
case TIME:
case TIMESTAMP_WITH_LOCAL_TIME_ZONE:
case TIMESTAMP:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import com.google.zetasql.SqlException;
import java.lang.reflect.Method;
import java.time.LocalDate;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.extensions.sql.BeamSqlUdf;
import org.apache.beam.sdk.extensions.sql.SqlTransform;
Expand All @@ -36,6 +37,7 @@
import org.apache.beam.sdk.extensions.sql.meta.provider.ReadOnlyTableProvider;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.transforms.Sum;
Expand Down Expand Up @@ -434,4 +436,22 @@ public void testRegisterUdaf() {
PAssert.that(stream).containsInAnyOrder(Row.withSchema(singleField).addValues(6L).build());
pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
}

/**
 * Registers {@code dateIncrementAll} through a {@code CREATE FUNCTION ... LANGUAGE java}
 * statement and checks that calling it on the literal {@code '2020-04-04'} yields 2021-05-05.
 */
@Test
public void testDateUdf() {
  final String createFunctionTemplate =
      "CREATE FUNCTION dateIncrementAll(d DATE) RETURNS DATE LANGUAGE java "
          + "OPTIONS (path='%s'); ";
  final String selectStatement = "SELECT dateIncrementAll('2020-04-04');";
  final String sql = String.format(createFunctionTemplate + selectStatement, jarPath);

  BeamRelNode plan = new ZetaSQLQueryPlanner(config).convertToBeamRel(sql);
  PCollection<Row> output = BeamSqlRelUtils.toPCollection(pipeline, plan);

  // The query has a single DATE output column, exposed as the SqlTypes.DATE logical type.
  Schema dateSchema = Schema.builder().addLogicalTypeField("field1", SqlTypes.DATE).build();
  Row expectedRow = Row.withSchema(dateSchema).addValues(LocalDate.of(2021, 5, 5)).build();

  PAssert.that(output).containsInAnyOrder(expectedRow);
  pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.beam.sdk.extensions.sql.zetasql;

import java.sql.Date;
import java.time.LocalDate;
import org.apache.beam.sdk.extensions.sql.BeamSqlUdf;
import org.apache.beam.sdk.extensions.sql.impl.JdbcConnection;
import org.apache.beam.sdk.extensions.sql.impl.JdbcDriver;
Expand All @@ -27,6 +29,7 @@
import org.apache.beam.sdk.extensions.sql.meta.provider.test.TestBoundedTable;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.logicaltypes.SqlTypes;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.PCollection;
Expand Down Expand Up @@ -73,6 +76,7 @@ public class ZetaSqlJavaUdfTypeTest extends ZetaSqlTestBase {
.addDoubleField("float64_inf")
.addDoubleField("float64_neg_inf")
.addDoubleField("float64_nan")
.addLogicalTypeField("f_date", SqlTypes.DATE)
.build())
.addRows(
true /* boolean_true */,
Expand All @@ -96,7 +100,8 @@ public class ZetaSqlJavaUdfTypeTest extends ZetaSqlTestBase {
2.2250738585072014e-308 /* float64_min_pos */,
Double.POSITIVE_INFINITY /* float64_inf */,
Double.NEGATIVE_INFINITY /* float64_neg_inf */,
Double.NaN /* float64_nan */);
Double.NaN /* float64_nan */,
LocalDate.of(2021, 4, 26) /* f_date */);

@Before
public void setUp() throws NoSuchMethodException {
Expand Down Expand Up @@ -125,6 +130,8 @@ public void setUp() throws NoSuchMethodException {
schema.add(
"test_float64",
ScalarFunctionImpl.create(DoubleIdentityFn.class.getMethod("eval", Double.class)));
schema.add(
"test_date", ScalarFunctionImpl.create(DateIdentityFn.class.getMethod("eval", Date.class)));

this.config = Frameworks.newConfigBuilder(config).defaultSchema(schema).build();
}
Expand Down Expand Up @@ -159,12 +166,26 @@ public Double eval(Double input) {
}
}

/**
 * Identity UDF over {@link Date}; registered as {@code test_date} so the tests can check that
 * DATE values survive a round trip through a Java UDF.
 */
public static class DateIdentityFn implements BeamSqlUdf {
/** Returns {@code input} unchanged. */
public Date eval(Date input) {
return input;
}
}

/**
 * Convenience overload for results whose schema type is a plain {@link Schema.TypeName}; wraps
 * the type name in a {@link Schema.FieldType} and delegates.
 */
private void runUdfTypeTest(String query, Object result, Schema.TypeName typeName) {
  final Schema.FieldType primitiveFieldType = Schema.FieldType.of(typeName);
  runUdfTypeTest(query, result, primitiveFieldType);
}

/**
 * Convenience overload for results whose schema type is a logical type; wraps the logical type
 * in a {@link Schema.FieldType} and delegates.
 */
private void runUdfTypeTest(String query, Object result, Schema.LogicalType<?, ?> logicalType) {
  final Schema.FieldType logicalFieldType = Schema.FieldType.logicalType(logicalType);
  runUdfTypeTest(query, result, logicalFieldType);
}

/**
 * Plans {@code query} with the ZetaSQL planner, runs it, and asserts that the output matches a
 * single expected row whose only field, {@code res}, has type {@code fieldType} and value
 * {@code result}.
 *
 * @param query SQL text handed to {@link ZetaSQLQueryPlanner#convertToBeamRel}
 * @param result expected value of the {@code res} field
 * @param fieldType schema type used for the {@code res} field of the expected row
 */
private void runUdfTypeTest(String query, Object result, Schema.FieldType fieldType) {
  ZetaSQLQueryPlanner zetaSQLQueryPlanner = new ZetaSQLQueryPlanner(config);
  BeamRelNode beamRelNode = zetaSQLQueryPlanner.convertToBeamRel(query);
  PCollection<Row> stream = BeamSqlRelUtils.toPCollection(pipeline, beamRelNode);

  // Build the expected schema from the caller-supplied field type only. This overload has no
  // `typeName` in scope, so the schema must not be derived from Schema.FieldType.of(typeName).
  Schema outputSchema = Schema.builder().addField("res", fieldType).build();
  PAssert.that(stream).containsInAnyOrder(Row.withSchema(outputSchema).addValues(result).build());
  pipeline.run().waitUntilFinish(Duration.standardMinutes(PIPELINE_EXECUTION_WAITTIME_MINUTES));
}
Expand Down Expand Up @@ -428,4 +449,16 @@ public void testNaNFloat64Input() {
runUdfTypeTest(
"SELECT test_float64(float64_nan) FROM table;", Double.NaN, Schema.TypeName.DOUBLE);
}

/** Passes a DATE literal through the {@code test_date} UDF and expects the same date back. */
@Test
public void testDateLiteral() {
  final String query = "SELECT test_date('2021-04-26') FROM table;";
  final LocalDate expectedDate = LocalDate.of(2021, 4, 26);
  runUdfTypeTest(query, expectedDate, SqlTypes.DATE);
}

/**
 * Passes the {@code f_date} table column through the {@code test_date} UDF and expects the
 * column's value back.
 */
@Test
public void testDateInput() {
  final String query = "SELECT test_date(f_date) FROM table;";
  final LocalDate expectedDate = LocalDate.of(2021, 4, 26);
  runUdfTypeTest(query, expectedDate, SqlTypes.DATE);
}
}

0 comments on commit 7ea721d

Please sign in to comment.