Skip to content

Commit

Permalink
Merge 2e0abdd into 25fea4e
Browse files Browse the repository at this point in the history
  • Loading branch information
xumingmin committed Jul 12, 2017
2 parents 25fea4e + 2e0abdd commit dfb705e
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 36 deletions.
114 changes: 81 additions & 33 deletions dsls/sql/src/main/java/org/apache/beam/dsls/sql/BeamSql.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
*/
package org.apache.beam.dsls.sql;

import com.google.auto.value.AutoValue;
import org.apache.beam.dsls.sql.rel.BeamRelNode;
import org.apache.beam.dsls.sql.schema.BeamPCollectionTable;
import org.apache.beam.dsls.sql.schema.BeamSqlRow;
import org.apache.beam.dsls.sql.schema.BeamSqlRowCoder;
import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.values.PCollection;
Expand Down Expand Up @@ -51,7 +53,9 @@
//run a simple query, and register the output as a table in BeamSql;
String sql1 = "select MY_FUNC(c1), c2 from PCOLLECTION";
PCollection<BeamSqlRow> outputTableA = inputTableA.apply(BeamSql.simpleQuery(sql1));
PCollection<BeamSqlRow> outputTableA = inputTableA.apply(
BeamSql.simpleQuery(sql1)
.withUdf("MY_FUNC", MY_FUNC.class, "FUNC"));
//run a JOIN with one table from TextIO, and one table from another query
PCollection<BeamSqlRow> outputTableB = PCollectionTuple.of(
Expand All @@ -60,15 +64,14 @@
.apply(BeamSql.query("select * from TABLE_O_A JOIN TABLE_B where ..."));
//output the final result with TextIO
outputTableB.apply(BeamSql.toTextRow()).apply(TextIO.write().to("/my/output/path"));
outputTableB.apply(...).apply(TextIO.write().to("/my/output/path"));
p.run().waitUntilFinish();
* }
* </pre>
*/
@Experimental
public class BeamSql {

/**
* Transforms a SQL query into a {@link PTransform} representing an equivalent execution plan.
*
Expand All @@ -80,9 +83,11 @@ public class BeamSql {
* <p>It is an error to apply a {@link PCollectionTuple} missing any {@code table names}
* referenced within the query.
*/
public static PTransform<PCollectionTuple, PCollection<BeamSqlRow>> query(String sqlQuery) {
return new QueryTransform(sqlQuery);

/**
 * Builds a {@code QueryTransform} for a query over one or more named tables
 * supplied as a {@code PCollectionTuple}.
 *
 * <p>Each call creates a fresh {@code BeamSqlEnv}, so UDF/UDAF registrations
 * made through {@code withUdf}/{@code withUdaf} are scoped to this query only.
 */
public static QueryTransform query(String sqlQuery) {
return QueryTransform.builder()
.setSqlEnv(new BeamSqlEnv())
.setSqlQuery(sqlQuery)
.build();
}

/**
Expand All @@ -93,42 +98,62 @@ public static PTransform<PCollectionTuple, PCollection<BeamSqlRow>> query(String
*
* <p>Make sure to query it from a static table name <em>PCOLLECTION</em>.
*/
public static PTransform<PCollection<BeamSqlRow>, PCollection<BeamSqlRow>>
simpleQuery(String sqlQuery) throws Exception {
return new SimpleQueryTransform(sqlQuery);
/**
 * Builds a {@code SimpleQueryTransform} for a query over a single input
 * {@code PCollection}, referenced by the fixed table name {@code PCOLLECTION}.
 *
 * <p>Each call creates a fresh {@code BeamSqlEnv}, so UDF/UDAF registrations
 * made through {@code withUdf}/{@code withUdaf} are scoped to this query only.
 */
public static SimpleQueryTransform simpleQuery(String sqlQuery) throws Exception {
return SimpleQueryTransform.builder()
.setSqlEnv(new BeamSqlEnv())
.setSqlQuery(sqlQuery)
.build();
}

/**
* A {@link PTransform} representing an execution plan for a SQL query.
*/
private static class QueryTransform extends
@AutoValue
public abstract static class QueryTransform extends
PTransform<PCollectionTuple, PCollection<BeamSqlRow>> {
private transient BeamSqlEnv sqlEnv;
private String sqlQuery;
abstract BeamSqlEnv getSqlEnv();
abstract String getSqlQuery();

public QueryTransform(String sqlQuery) {
this.sqlQuery = sqlQuery;
sqlEnv = new BeamSqlEnv();
/** Returns a builder backed by the AutoValue-generated implementation. */
static Builder builder() {
return new AutoValue_BeamSql_QueryTransform.Builder();
}

public QueryTransform(String sqlQuery, BeamSqlEnv sqlEnv) {
this.sqlQuery = sqlQuery;
this.sqlEnv = sqlEnv;
/** AutoValue builder for {@code QueryTransform}; the implementation is generated. */
@AutoValue.Builder
abstract static class Builder {
abstract Builder setSqlQuery(String sqlQuery);
abstract Builder setSqlEnv(BeamSqlEnv sqlEnv);
abstract QueryTransform build();
}

/**
 * Registers a user-defined function (UDF) under {@code functionName}.
 *
 * <p>The SQL function is implemented by the method {@code methodName}
 * declared on {@code clazz}; registration is forwarded to the underlying
 * {@code BeamSqlEnv}.
 *
 * @return this transform, for fluent chaining
 */
public QueryTransform withUdf(String functionName, Class<?> clazz, String methodName){
  BeamSqlEnv env = getSqlEnv();
  env.registerUdf(functionName, clazz, methodName);
  return this;
}

/**
 * Registers a user-defined aggregation function (UDAF) under
 * {@code functionName}, implemented by the given {@code BeamSqlUdaf} subclass.
 *
 * @return this transform, for fluent chaining
 */
public QueryTransform withUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz){
  BeamSqlEnv env = getSqlEnv();
  env.registerUdaf(functionName, clazz);
  return this;
}

@Override
public PCollection<BeamSqlRow> expand(PCollectionTuple input) {
registerTables(input);

BeamRelNode beamRelNode = null;
try {
beamRelNode = sqlEnv.planner.convertToBeamRel(sqlQuery);
beamRelNode = getSqlEnv().planner.convertToBeamRel(getSqlQuery());
} catch (ValidationException | RelConversionException | SqlParseException e) {
throw new IllegalStateException(e);
}

try {
return beamRelNode.buildBeamPipeline(input, sqlEnv);
return beamRelNode.buildBeamPipeline(input, getSqlEnv());
} catch (Exception e) {
throw new IllegalStateException(e);
}
Expand All @@ -140,7 +165,7 @@ private void registerTables(PCollectionTuple input){
PCollection<BeamSqlRow> sourceStream = (PCollection<BeamSqlRow>) input.get(sourceTag);
BeamSqlRowCoder sourceCoder = (BeamSqlRowCoder) sourceStream.getCoder();

sqlEnv.registerTable(sourceTag.getId(),
getSqlEnv().registerTable(sourceTag.getId(),
new BeamPCollectionTable(sourceStream, sourceCoder.getTableSchema()));
}
}
Expand All @@ -150,26 +175,45 @@ private void registerTables(PCollectionTuple input){
* A {@link PTransform} representing an execution plan for a SQL query referencing
* a single table.
*/
private static class SimpleQueryTransform
@AutoValue
public abstract static class SimpleQueryTransform
extends PTransform<PCollection<BeamSqlRow>, PCollection<BeamSqlRow>> {
private static final String PCOLLECTION_TABLE_NAME = "PCOLLECTION";
private transient BeamSqlEnv sqlEnv = new BeamSqlEnv();
private String sqlQuery;
abstract BeamSqlEnv getSqlEnv();
abstract String getSqlQuery();

public SimpleQueryTransform(String sqlQuery) {
this.sqlQuery = sqlQuery;
validateQuery();
/** Returns a builder backed by the AutoValue-generated implementation. */
static Builder builder() {
return new AutoValue_BeamSql_SimpleQueryTransform.Builder();
}

// public SimpleQueryTransform withUdf(String udfName){
// throw new UnsupportedOperationException("Pending for UDF support");
// }
/** AutoValue builder for {@code SimpleQueryTransform}; the implementation is generated. */
@AutoValue.Builder
abstract static class Builder {
abstract Builder setSqlQuery(String sqlQuery);
abstract Builder setSqlEnv(BeamSqlEnv sqlEnv);
abstract SimpleQueryTransform build();
}

/**
 * Registers a user-defined function (UDF) under {@code functionName}.
 *
 * <p>The SQL function is implemented by the method {@code methodName}
 * declared on {@code clazz}; registration is forwarded to the underlying
 * {@code BeamSqlEnv}.
 *
 * @return this transform, for fluent chaining
 */
public SimpleQueryTransform withUdf(String functionName, Class<?> clazz, String methodName){
  BeamSqlEnv env = getSqlEnv();
  env.registerUdf(functionName, clazz, methodName);
  return this;
}

/**
 * Registers a user-defined aggregation function (UDAF) under
 * {@code functionName}, implemented by the given {@code BeamSqlUdaf} subclass.
 *
 * @return this transform, for fluent chaining
 */
public SimpleQueryTransform withUdaf(String functionName, Class<? extends BeamSqlUdaf> clazz){
  BeamSqlEnv env = getSqlEnv();
  env.registerUdaf(functionName, clazz);
  return this;
}

private void validateQuery() {
SqlNode sqlNode;
try {
sqlNode = sqlEnv.planner.parseQuery(sqlQuery);
sqlEnv.planner.getPlanner().close();
sqlNode = getSqlEnv().planner.parseQuery(getSqlQuery());
getSqlEnv().planner.getPlanner().close();
} catch (SqlParseException e) {
throw new IllegalStateException(e);
}
Expand All @@ -188,8 +232,12 @@ private void validateQuery() {

@Override
public PCollection<BeamSqlRow> expand(PCollection<BeamSqlRow> input) {
validateQuery();
return PCollectionTuple.of(new TupleTag<BeamSqlRow>(PCOLLECTION_TABLE_NAME), input)
.apply(new QueryTransform(sqlQuery, sqlEnv));
.apply(QueryTransform.builder()
.setSqlEnv(getSqlEnv())
.setSqlQuery(getSqlQuery())
.build());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@
* <p>It contains a {@link SchemaPlus} which holds the metadata of tables/UDF functions, and
* a {@link BeamQueryPlanner} which parse/validate/optimize/translate input SQL queries.
*/
public class BeamSqlEnv {
SchemaPlus schema;
BeamQueryPlanner planner;
public class BeamSqlEnv implements Serializable{
transient SchemaPlus schema;
transient BeamQueryPlanner planner;

public BeamSqlEnv() {
schema = Frameworks.createRootSchema(true);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.dsls.sql;

import java.sql.Types;
import java.util.Arrays;
import java.util.Iterator;
import org.apache.beam.dsls.sql.schema.BeamSqlRecordType;
import org.apache.beam.dsls.sql.schema.BeamSqlRow;
import org.apache.beam.dsls.sql.schema.BeamSqlUdaf;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.junit.Test;

/**
 * Tests for UDF/UDAF support in BeamSql, covering both the
 * {@code simpleQuery} (single-table) and {@code query} (multi-table)
 * entry points.
 */
public class BeamSqlDslUdfUdafTest extends BeamSqlDslBase {
  /**
   * GROUP-BY with a UDAF: squaresum over f_int, grouped by f_int2.
   */
  @Test
  public void testUdaf() throws Exception {
    BeamSqlRecordType resultType = BeamSqlRecordType.create(Arrays.asList("f_int2", "squaresum"),
        Arrays.asList(Types.INTEGER, Types.INTEGER));

    BeamSqlRow record = new BeamSqlRow(resultType);
    record.addField("f_int2", 0);
    record.addField("squaresum", 30);

    // UDAF registered on the single-table (simpleQuery) path.
    String sql1 = "SELECT f_int2, squaresum1(f_int) AS `squaresum`"
        + " FROM PCOLLECTION GROUP BY f_int2";
    PCollection<BeamSqlRow> result1 =
        boundedInput1.apply("testUdaf1",
            BeamSql.simpleQuery(sql1).withUdaf("squaresum1", SquareSum.class));
    PAssert.that(result1).containsInAnyOrder(record);

    // Same UDAF registered on the multi-table (query) path.
    String sql2 = "SELECT f_int2, squaresum2(f_int) AS `squaresum`"
        + " FROM PCOLLECTION GROUP BY f_int2";
    PCollection<BeamSqlRow> result2 =
        PCollectionTuple.of(new TupleTag<BeamSqlRow>("PCOLLECTION"), boundedInput1)
            .apply("testUdaf2",
                BeamSql.query(sql2).withUdaf("squaresum2", SquareSum.class));
    PAssert.that(result2).containsInAnyOrder(record);

    pipeline.run().waitUntilFinish();
  }

  /**
   * Scalar UDF: cubic(f_int) selected through both entry points.
   */
  @Test
  public void testUdf() throws Exception {
    BeamSqlRecordType resultType = BeamSqlRecordType.create(Arrays.asList("f_int", "cubicvalue"),
        Arrays.asList(Types.INTEGER, Types.INTEGER));

    BeamSqlRow record = new BeamSqlRow(resultType);
    record.addField("f_int", 2);
    record.addField("cubicvalue", 8);

    // UDF registered on the single-table (simpleQuery) path.
    String sql1 = "SELECT f_int, cubic1(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";
    PCollection<BeamSqlRow> result1 =
        boundedInput1.apply("testUdf1",
            BeamSql.simpleQuery(sql1).withUdf("cubic1", CubicInteger.class, "cubic"));
    PAssert.that(result1).containsInAnyOrder(record);

    // Same UDF registered on the multi-table (query) path.
    String sql2 = "SELECT f_int, cubic2(f_int) as cubicvalue FROM PCOLLECTION WHERE f_int = 2";
    PCollection<BeamSqlRow> result2 =
        PCollectionTuple.of(new TupleTag<BeamSqlRow>("PCOLLECTION"), boundedInput1)
            .apply("testUdf2",
                BeamSql.query(sql2).withUdf("cubic2", CubicInteger.class, "cubic"));
    PAssert.that(result2).containsInAnyOrder(record);

    pipeline.run().waitUntilFinish();
  }

  /**
   * Example UDAF returning the sum of squares of its int inputs.
   */
  public static class SquareSum extends BeamSqlUdaf<Integer, Integer, Integer> {

    // Public no-arg constructor kept: UDAF classes are presumably instantiated
    // reflectively by the planner — confirm against BeamSqlEnv.registerUdaf.
    public SquareSum() {
    }

    @Override
    public Integer init() {
      return 0;
    }

    @Override
    public Integer add(Integer accumulator, Integer input) {
      return accumulator + input * input;
    }

    @Override
    public Integer merge(Iterable<Integer> accumulators) {
      int sum = 0;
      for (Integer accumulator : accumulators) {
        sum += accumulator;
      }
      return sum;
    }

    @Override
    public Integer result(Integer accumulator) {
      return accumulator;
    }
  }

  /**
   * Example UDF: cubic(n) = n^3, exposed as the static method {@code cubic}.
   */
  public static class CubicInteger {
    public static Integer cubic(Integer input) {
      return input * input * input;
    }
  }
}

0 comments on commit dfb705e

Please sign in to comment.