Skip to content

Commit cc11668

Browse files
authored
[FLINK-37789][planner] Support to validate ML_PREDICT expression
1 parent 7db22fb commit cc11668

File tree

11 files changed

+494
-2
lines changed

11 files changed

+494
-2
lines changed

flink-table/flink-sql-parser/src/main/codegen/data/Parser.tdd

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,7 @@
574574
"JSON_EXECUTION_PLAN"
575575
"PLAN_ADVICE"
576576
"METADATA"
577+
"MODEL"
577578
"OVERWRITE"
578579
"OVERWRITING"
579580
"PARTITIONED"

flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlExplicitModelCall.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,8 @@ public SqlExplicitModelCall(
3232
@Nullable SqlLiteral functionQualifier) {
3333
super(operator, operandList, pos, functionQualifier);
3434
}
35+
36+
public SqlIdentifier getModelIdentifier() {
37+
return operand(0);
38+
}
3539
}

flink-table/flink-sql-parser/src/main/java/org/apache/calcite/sql/SqlExplicitModelOperator.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,15 @@
1717

1818
package org.apache.calcite.sql;
1919

20+
import org.apache.calcite.rel.type.RelDataType;
2021
import org.apache.calcite.sql.parser.SqlParserPos;
22+
import org.apache.calcite.sql.validate.SqlValidator;
23+
import org.apache.calcite.sql.validate.SqlValidatorScope;
2124
import org.apache.calcite.util.ImmutableNullableList;
2225
import org.checkerframework.checker.nullness.qual.Nullable;
2326

27+
import static org.apache.calcite.util.Static.RESOURCE;
28+
2429
/** SqlExplicitModelOperator is a SQL operator that represents an explicit model. */
2530
public class SqlExplicitModelOperator extends SqlPrefixOperator {
2631

@@ -37,4 +42,20 @@ public SqlCall createCall(
3742
return new SqlExplicitModelCall(
3843
this, ImmutableNullableList.copyOf(operands), pos, functionQualifier);
3944
}
45+
46+
@Override
47+
public void validateCall(
48+
SqlCall call,
49+
SqlValidator validator,
50+
SqlValidatorScope scope,
51+
SqlValidatorScope operandScope) {
52+
throw SqlUtil.newContextException(
53+
call.pos, RESOURCE.objectNotFound(call.operand(0).toString()));
54+
}
55+
56+
@Override
57+
public RelDataType deriveType(SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
58+
throw SqlUtil.newContextException(
59+
call.pos, RESOURCE.objectNotFound(call.operand(0).toString()));
60+
}
4061
}

flink-table/flink-sql-parser/src/test/java/org/apache/flink/sql/parser/FlinkSqlParserImplTest.java

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3376,6 +3376,14 @@ void testModelInFunctionWithoutTable() {
33763376
+ "FROM TABLE(`FUNC`((TABLE `MY_TABLE`), MODEL `CAT`.`DB`.`MY_MODEL`))");
33773377
}
33783378

3379+
@Test
3380+
void testModelInFunctionNamedArgs() {
3381+
sql("select * from table(ml_predict(INPUT => TABLE my_table, model => MODEL my_model))")
3382+
.ok(
3383+
"SELECT *\n"
3384+
+ "FROM TABLE(`ML_PREDICT`(`INPUT` => (TABLE `MY_TABLE`), `MODEL` => (MODEL `MY_MODEL`)))");
3385+
}
3386+
33793387
/*
33803388
* This test was backported from Calcite 1.38 (CALCITE-6266).
33813389
* Remove it together with upgrade to Calcite 1.38.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
package org.apache.calcite.sql;
20+
21+
import org.apache.flink.table.planner.catalog.CatalogSchemaModel;
22+
23+
import org.apache.calcite.rel.type.RelDataType;
24+
import org.apache.calcite.sql.validate.SqlValidator;
25+
import org.apache.calcite.sql.validate.SqlValidatorScope;
26+
27+
import static java.util.Objects.requireNonNull;
28+
29+
/** SqlModelCall to fetch and reference model based on identifier. */
30+
public class SqlModelCall extends SqlBasicCall {
31+
32+
private final CatalogSchemaModel model;
33+
34+
public SqlModelCall(SqlExplicitModelCall modelCall, CatalogSchemaModel model) {
35+
super(
36+
new SqlModelOperator(model),
37+
modelCall.getOperandList(),
38+
modelCall.getParserPosition(),
39+
modelCall.getFunctionQuantifier());
40+
this.model = requireNonNull(model);
41+
}
42+
43+
@Override
44+
public void validate(SqlValidator validator, SqlValidatorScope scope) {
45+
// Do nothing here, override to avoid identifier validation which will be treated as column
46+
}
47+
48+
public RelDataType getInputType(SqlValidator validator) {
49+
return model.getInputRowType(validator.getTypeFactory());
50+
}
51+
52+
public RelDataType getOutputType(SqlValidator validator) {
53+
return model.getOutputRowType(validator.getTypeFactory());
54+
}
55+
56+
/**
57+
* A custom SqlOperator to handle model identifier.
58+
*
59+
* <p>It is used to derive the type of the model based on the identifier.
60+
*/
61+
private static class SqlModelOperator extends SqlPrefixOperator {
62+
63+
CatalogSchemaModel model;
64+
65+
private SqlModelOperator(CatalogSchemaModel model) {
66+
super("MODEL", SqlKind.OTHER_FUNCTION, 2, null, null, null);
67+
this.model = model;
68+
}
69+
70+
@Override
71+
public RelDataType deriveType(
72+
SqlValidator validator, SqlValidatorScope scope, SqlCall call) {
73+
return model.getOutputRowType(validator.getTypeFactory());
74+
}
75+
}
76+
}

flink-table/flink-table-planner/src/main/java/org/apache/calcite/sql/validate/ProcedureNamespace.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.apache.flink.annotation.Internal;
2020
import org.apache.flink.table.planner.calcite.FlinkSqlCallBinding;
21+
import org.apache.flink.table.planner.functions.sql.ml.SqlMLTableFunction;
2122

2223
import org.apache.calcite.rel.type.RelDataType;
2324
import org.apache.calcite.sql.SqlCall;
@@ -61,7 +62,7 @@ public RelDataType validateImpl(RelDataType targetRowType) {
6162
final SqlOperator operator = call.getOperator();
6263
final SqlCallBinding callBinding = new FlinkSqlCallBinding(validator, scope, call);
6364
final SqlCall permutedCall = callBinding.permutedCall();
64-
if (operator instanceof SqlWindowTableFunction) {
65+
if (operator instanceof SqlWindowTableFunction || operator instanceof SqlMLTableFunction) {
6566
permutedCall.validate(validator, scope);
6667
}
6768

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/calcite/FlinkCalciteSqlValidator.java

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
import org.apache.flink.table.catalog.Column;
2727
import org.apache.flink.table.catalog.ResolvedSchema;
2828
import org.apache.flink.table.data.TimestampData;
29+
import org.apache.flink.table.planner.catalog.CatalogSchemaModel;
2930
import org.apache.flink.table.planner.catalog.CatalogSchemaTable;
31+
import org.apache.flink.table.planner.functions.sql.ml.SqlMLTableFunction;
32+
import org.apache.flink.table.planner.plan.FlinkCalciteCatalogReader;
3033
import org.apache.flink.table.planner.plan.utils.FlinkRexUtil;
3134
import org.apache.flink.table.planner.utils.ShortcutUtils;
3235
import org.apache.flink.table.types.logical.DecimalType;
@@ -43,12 +46,14 @@
4346
import org.apache.calcite.sql.SqlAsOperator;
4447
import org.apache.calcite.sql.SqlBasicCall;
4548
import org.apache.calcite.sql.SqlCall;
49+
import org.apache.calcite.sql.SqlExplicitModelCall;
4650
import org.apache.calcite.sql.SqlFunction;
4751
import org.apache.calcite.sql.SqlFunctionCategory;
4852
import org.apache.calcite.sql.SqlIdentifier;
4953
import org.apache.calcite.sql.SqlJoin;
5054
import org.apache.calcite.sql.SqlKind;
5155
import org.apache.calcite.sql.SqlLiteral;
56+
import org.apache.calcite.sql.SqlModelCall;
5257
import org.apache.calcite.sql.SqlNode;
5358
import org.apache.calcite.sql.SqlNodeList;
5459
import org.apache.calcite.sql.SqlOperator;
@@ -371,7 +376,21 @@ protected void addToSelectList(
371376
final SqlBasicCall call = (SqlBasicCall) node;
372377
final SqlOperator operator = call.getOperator();
373378

374-
if (operator instanceof SqlWindowTableFunction) {
379+
if (node instanceof SqlExplicitModelCall) {
380+
// Convert it so that model can be accessed in planner. SqlExplicitModelCall
381+
// from parser can't access model.
382+
SqlExplicitModelCall modelCall = (SqlExplicitModelCall) node;
383+
SqlIdentifier modelIdentifier = modelCall.getModelIdentifier();
384+
FlinkCalciteCatalogReader catalogReader =
385+
(FlinkCalciteCatalogReader) getCatalogReader();
386+
CatalogSchemaModel model = catalogReader.getModel(modelIdentifier.names);
387+
if (model != null) {
388+
return new SqlModelCall(modelCall, model);
389+
}
390+
}
391+
392+
// TODO (FLINK-37819): add test for SqlMLTableFunction
393+
if (operator instanceof SqlWindowTableFunction || operator instanceof SqlMLTableFunction) {
375394
if (tableArgs.stream().allMatch(Objects::isNull)) {
376395
return rewritten;
377396
}

flink-table/flink-table-planner/src/main/java/org/apache/flink/table/planner/functions/sql/FlinkSqlOperatorTable.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import org.apache.flink.table.api.TableException;
2222
import org.apache.flink.table.planner.calcite.FlinkTypeFactory;
2323
import org.apache.flink.table.planner.functions.sql.internal.SqlAuxiliaryGroupAggFunction;
24+
import org.apache.flink.table.planner.functions.sql.ml.SqlMLPredictTableFunction;
2425
import org.apache.flink.table.planner.plan.type.FlinkReturnTypes;
2526
import org.apache.flink.table.planner.plan.type.NumericExceptFirstOperandChecker;
2627

@@ -1341,6 +1342,9 @@ public List<SqlGroupedWindowFunction> getAuxiliaryFunctions() {
13411342
public static final SqlFunction CUMULATE = new SqlCumulateTableFunction();
13421343
public static final SqlFunction SESSION = new SqlSessionTableFunction();
13431344

1345+
// MODEL TABLE FUNCTIONS
1346+
public static final SqlFunction ML_PREDICT = new SqlMLPredictTableFunction();
1347+
13441348
// Catalog Functions
13451349
public static final SqlFunction CURRENT_DATABASE =
13461350
BuiltInSqlFunction.newBuilder()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to you under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.flink.table.planner.functions.sql.ml;
19+
20+
import org.apache.calcite.rel.type.RelDataType;
21+
import org.apache.calcite.rel.type.RelDataTypeFactory;
22+
import org.apache.calcite.sql.SqlCallBinding;
23+
import org.apache.calcite.sql.SqlOperandCountRange;
24+
import org.apache.calcite.sql.SqlOperator;
25+
import org.apache.calcite.sql.SqlOperatorBinding;
26+
import org.apache.calcite.sql.type.SqlOperandCountRanges;
27+
import org.apache.calcite.sql.type.SqlOperandMetadata;
28+
import org.apache.calcite.sql.type.SqlTypeName;
29+
30+
import java.util.Collections;
31+
import java.util.List;
32+
33+
/**
34+
* SqlMlPredictTableFunction implements an operator for prediction.
35+
*
36+
* <p>It allows four parameters:
37+
*
38+
* <ol>
39+
* <li>a table
40+
* <li>a model name
41+
* <li>a descriptor to provide a column name from the input table
42+
* <li>an optional config map
43+
* </ol>
44+
*/
45+
public class SqlMLPredictTableFunction extends SqlMLTableFunction {
46+
47+
public SqlMLPredictTableFunction() {
48+
super("ML_PREDICT", new PredictOperandMetadata());
49+
}
50+
51+
/**
52+
* {@inheritDoc}
53+
*
54+
* <p>Overrides because the first parameter of table-value function windowing is an explicit
55+
* TABLE parameter, which is not scalar.
56+
*/
57+
@Override
58+
public boolean argumentMustBeScalar(int ordinal) {
59+
return ordinal != 0;
60+
}
61+
62+
@Override
63+
protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
64+
// TODO: FLINK-37780 output type based on table schema and model output schema
65+
// model output schema to be available after integrated with SqlExplicitModelCall
66+
return opBinding.getOperandType(1);
67+
}
68+
69+
private static class PredictOperandMetadata implements SqlOperandMetadata {
70+
private static final List<String> PARAM_NAMES =
71+
List.of(PARAM_INPUT, PARAM_MODEL, PARAM_COLUMN, PARAM_CONFIG);
72+
private static final List<String> MANDATORY_PARAM_NAMES =
73+
List.of(PARAM_INPUT, PARAM_MODEL, PARAM_COLUMN);
74+
75+
PredictOperandMetadata() {}
76+
77+
@Override
78+
public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
79+
return Collections.nCopies(
80+
PARAM_NAMES.size(), typeFactory.createSqlType(SqlTypeName.ANY));
81+
}
82+
83+
@Override
84+
public List<String> paramNames() {
85+
return PARAM_NAMES;
86+
}
87+
88+
@Override
89+
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
90+
// TODO: FLINK-37780 Check operand types after integrated with SqlExplicitModelCall in
91+
// validator
92+
return false;
93+
}
94+
95+
@Override
96+
public SqlOperandCountRange getOperandCountRange() {
97+
return SqlOperandCountRanges.between(MANDATORY_PARAM_NAMES.size(), PARAM_NAMES.size());
98+
}
99+
100+
@Override
101+
public Consistency getConsistency() {
102+
return Consistency.NONE;
103+
}
104+
105+
@Override
106+
public boolean isOptional(int i) {
107+
return i > getOperandCountRange().getMin() && i <= getOperandCountRange().getMax();
108+
}
109+
110+
@Override
111+
public String getAllowedSignatures(SqlOperator op, String opName) {
112+
return opName
113+
+ "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]]";
114+
}
115+
}
116+
}

0 commit comments

Comments
 (0)