Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarFunction;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.StructField;
import org.apache.doris.catalog.StructType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.AnalysisException;
Expand Down Expand Up @@ -346,6 +347,7 @@ private void analyzeCommon(ConnectContext ctx) throws AnalysisException {
}
if (binaryType == Function.BinaryType.JAVA_UDF) {
FunctionUtil.checkEnableJavaUdf();
checkUdfSupportedTypes();
if (!isAggregate && !isTableFunction) {
volatility = analyzeVolatility();
}
Expand All @@ -363,6 +365,7 @@ private void analyzeCommon(ConnectContext ctx) throws AnalysisException {
extractExpirationTime();
} else if (binaryType == Function.BinaryType.PYTHON_UDF) {
FunctionUtil.checkEnablePythonUdf();
checkUdfSupportedTypes();
if (!isAggregate && !isTableFunction) {
volatility = analyzeVolatility();
}
Expand Down Expand Up @@ -418,6 +421,36 @@ private static boolean validatePythonRuntimeVersion(String runtimeVersionString)
return runtimeVersionString != null && PYTHON_VERSION_PATTERN.matcher(runtimeVersionString).matches();
}

private void checkUdfSupportedTypes() throws AnalysisException {
Type[] argTypes = argsDef.getArgTypes();
for (int i = 0; i < argTypes.length; i++) {
checkUdfSupportedType(argTypes[i], "argument " + (i + 1));
}
checkUdfSupportedType(returnType.toCatalogDataType(), "return");
if (intermediateType != null) {
checkUdfSupportedType(intermediateType.toCatalogDataType(), "intermediate");
}
}

private void checkUdfSupportedType(Type type, String typePosition) throws AnalysisException {
// Reject bitmap/hll/quantile_state type
if (type.isObjectStored()) {
throw new AnalysisException(String.format(
"%s does not support %s type %s", binaryType, typePosition, type.toSql()));
}

if (type.isArrayType()) {
checkUdfSupportedType(((ArrayType) type).getItemType(), typePosition + " element");
} else if (type.isMapType()) {
checkUdfSupportedType(((MapType) type).getKeyType(), typePosition + " key");
checkUdfSupportedType(((MapType) type).getValueType(), typePosition + " value");
} else if (type.isStructType()) {
for (StructField field : ((StructType) type).getFields()) {
checkUdfSupportedType(field.getType(), typePosition + " field " + field.getName());
}
}
}

private Boolean parseBooleanFromProperties(String propertyString) throws AnalysisException {
String valueOfString = properties.get(propertyString);
if (valueOfString == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,35 @@ public void test() throws Exception {
Assert.assertEquals(FunctionVolatility.VOLATILE, findFunction(db, "py_default").getVolatility());
}

@Test
public void testCreatePythonFunctionRejectsObjectTypes() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
createDatabase(ctx, "create database py_obj_type_db;");
dorisAssert = new DorisAssert(ctx);
dorisAssert.useDatabase("py_obj_type_db");

assertCreateFunctionAnalysisException(ctx, "create function py_obj_type_db.py_bitmap_arg(bitmap) returns int "
+ "properties('type'='PYTHON_UDF', 'symbol'='evaluate', 'runtime_version'='3.10.2');",
"PYTHON_UDF does not support argument 1 type bitmap");
assertCreateFunctionAnalysisException(ctx, "create function py_obj_type_db.j_bitmap_arg(bitmap) returns int "
+ "properties('type'='JAVA_UDF', 'symbol'='evaluate');",
"JAVA_UDF does not support argument 1 type bitmap");
assertCreateFunctionAnalysisException(ctx, "create function py_obj_type_db.py_hll_ret(int) returns hll "
+ "properties('type'='PYTHON_UDF', 'symbol'='evaluate', 'runtime_version'='3.10.2');",
"PYTHON_UDF does not support return type hll");
assertCreateFunctionAnalysisException(ctx, "create aggregate function py_obj_type_db.py_quantile_arg"
+ "(quantile_state) returns int properties('type'='PYTHON_UDF', 'symbol'='Agg', "
+ "'runtime_version'='3.10.2');",
"PYTHON_UDF does not support argument 1 type quantile_state");
assertCreateFunctionAnalysisException(ctx, "create aggregate function py_obj_type_db.j_quantile_arg"
+ "(quantile_state) returns int properties('type'='JAVA_UDF', 'symbol'='Agg');",
"JAVA_UDF does not support argument 1 type quantile_state");
assertCreateFunctionAnalysisException(ctx, "create tables function py_obj_type_db.py_bitmap_table(int) "
+ "returns array<bitmap> properties('type'='PYTHON_UDF', 'symbol'='evaluate', "
+ "'runtime_version'='3.10.2');",
"ARRAY unsupported sub-type: bitmap");
}

@Test
public void testCreateGlobalFunction() throws Exception {
ConnectContext ctx = UtFrameUtils.createDefaultCtx();
Expand Down Expand Up @@ -215,6 +244,12 @@ private void createFunction(String sql, ConnectContext connectContext) throws Ex
}
}

private void assertCreateFunctionAnalysisException(ConnectContext ctx, String sql, String message) {
Exception exception = Assert.assertThrows(Exception.class, () -> createFunction(sql, ctx));
Assert.assertTrue("Expected error to contain: " + message + ", actual: " + exception.getMessage(),
exception.getMessage().contains(message));
}

private boolean containsIgnoreCase(String str, String sub) {
return str.toLowerCase().contains(sub.toLowerCase());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_pythonudaf_object_types_inline") {
def runtime_version = getPythonUdfRuntimeVersion()

test {
sql """
CREATE AGGREGATE FUNCTION py_obj_udaf_bitmap_arg(bitmap)
RETURNS BIGINT
PROPERTIES (
"type" = "PYTHON_UDF",
"symbol" = "Agg",
"runtime_version" = "${runtime_version}"
)
AS \$\$
class Agg:
def __init__(self):
self.sum = 0
def accumulate(self, v):
pass
def merge(self, other):
pass
def finish(self):
return self.sum
@property
def aggregate_state(self):
return self.sum
\$\$;
"""
exception "does not support argument 1 type bitmap"
}

test {
sql """
CREATE AGGREGATE FUNCTION py_obj_udaf_hll_ret(int)
RETURNS HLL
PROPERTIES (
"type" = "PYTHON_UDF",
"symbol" = "Agg",
"runtime_version" = "${runtime_version}"
)
AS \$\$
class Agg:
def __init__(self):
self.state = None
def accumulate(self, v):
pass
def merge(self, other):
pass
def finish(self):
return self.state
@property
def aggregate_state(self):
return self.state
\$\$;
"""
exception "does not support return type hll"
}

test {
sql """
CREATE AGGREGATE FUNCTION py_obj_udaf_quantile_state(quantile_state)
RETURNS BIGINT
INTERMEDIATE BIGINT
PROPERTIES (
"type" = "PYTHON_UDF",
"symbol" = "Agg",
"runtime_version" = "${runtime_version}"
)
AS \$\$
class Agg:
def __init__(self):
self.state = 0
def accumulate(self, v):
pass
def merge(self, other):
pass
def finish(self):
return self.state
@property
def aggregate_state(self):
return self.state
\$\$;
"""
exception "does not support argument 1 type quantile_state"
}

test {
sql """
CREATE AGGREGATE FUNCTION py_obj_udaf_bitmap_intermediate(int)
RETURNS BIGINT
INTERMEDIATE BITMAP
PROPERTIES (
"type" = "PYTHON_UDF",
"symbol" = "Agg",
"runtime_version" = "${runtime_version}"
)
AS \$\$
class Agg:
def __init__(self):
self.state = 0
def accumulate(self, v):
pass
def merge(self, other):
pass
def finish(self):
return self.state
@property
def aggregate_state(self):
return self.state
\$\$;
"""
exception "does not support intermediate type bitmap"
}

test {
sql """
CREATE AGGREGATE FUNCTION py_obj_udaf_array_bitmap(int)
RETURNS ARRAY<BITMAP>
PROPERTIES (
"type" = "PYTHON_UDF",
"symbol" = "Agg",
"runtime_version" = "${runtime_version}"
)
AS \$\$
class Agg:
def __init__(self):
self.state = None
def accumulate(self, v):
pass
def merge(self, other):
pass
def finish(self):
return self.state
@property
def aggregate_state(self):
return self.state
\$\$;
"""
exception "ARRAY unsupported sub-type: bitmap"
}

test {
sql """
CREATE AGGREGATE FUNCTION py_obj_udaf_struct_bitmap(int)
RETURNS STRUCT<plain:INT, nested:MAP<INT, ARRAY<HLL>>>
PROPERTIES (
"type" = "PYTHON_UDF",
"symbol" = "Agg",
"runtime_version" = "${runtime_version}"
)
AS \$\$
class Agg:
def __init__(self):
self.state = None
def accumulate(self, v):
pass
def merge(self, other):
pass
def finish(self):
return self.state
@property
def aggregate_state(self):
return self.state
\$\$;
"""
exception "ARRAY unsupported sub-type: hll"
}
}
Loading
Loading