diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Database.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Database.java index e54c21bb153605..963d72973ef3c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Database.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Database.java @@ -769,16 +769,74 @@ public void setName(String name) { } public synchronized void addFunction(Function function, boolean ifNotExists) throws UserException { - function.checkWritable(); - if (FunctionUtil.addFunctionImpl(function, ifNotExists, false, name2Function)) { - Env.getCurrentEnv().getEditLog().logAddFunction(function); - try { + addFunctions(ImmutableList.of(function), ifNotExists, false); + } + + public synchronized void addTableFunction(Function function, boolean ifNotExists) throws UserException { + // Doris table functions are registered as two functions: the normal function and its outer variant. + Function outerFunction = function.clone(); + FunctionName name = outerFunction.getFunctionName(); + name.setFn(name.getFunction() + "_outer"); + if (hasSameTableFunctionPair(function, outerFunction, ifNotExists)) { + return; + } + addFunctions(ImmutableList.of(function, outerFunction), false, true); + } + + private boolean hasSameTableFunctionPair(Function function, Function outerFunction, boolean ifNotExists) + throws UserException { + Function existingFunction = getExistingFunction(function); + Function existingOuterFunction = getExistingFunction(outerFunction); + if (existingFunction == null && existingOuterFunction == null) { + return false; + } + if (ifNotExists && existingFunction != null && existingOuterFunction != null + && existingFunction.isUDTFunction() && existingOuterFunction.isUDTFunction()) { + return true; + } + throw new UserException("function already exists"); + } + + private Function getExistingFunction(Function function) { + try { + return getFunction(getFunctionSearchDesc(function)); + } catch (AnalysisException e) { + return null; + } + } + + private void addFunctions(List functions, boolean ifNotExists, boolean logAsBatch) throws UserException { + List addedFunctions = Lists.newArrayList(); + try { + for (Function function : functions) { + function.checkWritable(); + if (FunctionUtil.addFunctionImpl(function, ifNotExists, false, name2Function)) { + addedFunctions.add(function); + } + } + for (Function function : addedFunctions) { FunctionUtil.translateToNereidsThrows(this.getFullName(), function); - } catch (Exception e) { - name2Function.remove(function.getFunctionName().getFunction()); - throw e; } + } catch (Exception e) { + for (Function function : addedFunctions) { + FunctionUtil.dropFromNereids(this.getFullName(), getFunctionSearchDesc(function)); + } + for (int i = addedFunctions.size() - 1; i >= 0; i--) { + FunctionUtil.removeFunctionImpl(addedFunctions.get(i), name2Function); + } + throw e; } + if (logAsBatch) { + Env.getCurrentEnv().getEditLog().logAddFunctions(addedFunctions); + } else { + for (Function function : addedFunctions) { + Env.getCurrentEnv().getEditLog().logAddFunction(function); + } + } + } + + private FunctionSearchDesc getFunctionSearchDesc(Function function) { + return new FunctionSearchDesc(function.getFunctionName(), function.getArgs(), function.hasVarArgs()); } public synchronized void replayAddFunction(Function function) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java index b6de5fa9e06f42..90d1a82e498618 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java @@ -214,6 +214,7 @@ import org.apache.doris.persist.BinlogGcInfo; import org.apache.doris.persist.CleanQueryStatsInfo; import org.apache.doris.persist.CreateDbInfo; +import org.apache.doris.persist.CreateFunctionInfo; import org.apache.doris.persist.DropDbInfo; import org.apache.doris.persist.DropPartitionInfo; import org.apache.doris.persist.EditLog; @@ -6774,6 +6775,12 @@ public void replayCreateFunction(Function function) throws MetaNotFoundException db.replayAddFunction(function); } + public void replayCreateFunctions(CreateFunctionInfo info) throws MetaNotFoundException { + for (Function function : info.getFunctions()) { + replayCreateFunction(function); + } + } + public void replayCreateGlobalFunction(Function function) { globalFunctionMgr.replayAddFunction(function); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java index 1a8402520ef01b..f736fc30f3aa0f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java @@ -306,14 +306,18 @@ public void addUdf(String dbName, String name, UdfBuilder builder) { } } - public void dropUdf(String dbName, String name, List argTypes) { + public void dropUdf(String dbName, String name, List argTypes, boolean hasVarArgs) { if (dbName == null) { dbName = GLOBAL_FUNCTION; } synchronized (name2UdfBuilders) { Map> builders = name2UdfBuilders.getOrDefault(dbName, ImmutableMap.of()); builders.getOrDefault(name, Lists.newArrayList()) - .removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes)); + .removeIf(builder -> { + UdfBuilder udfBuilder = (UdfBuilder) builder; + return udfBuilder.getArgTypes().equals(argTypes) + && udfBuilder.hasVarArguments() == hasVarArgs; + }); // the name will be used when show functions, so remove the name when it's dropped if (builders.getOrDefault(name, Lists.newArrayList()).isEmpty()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java index c9ef0e81f14857..ff2b9612616345 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java @@ -21,6 +21,7 @@ import org.apache.doris.common.AnalysisException; import org.apache.doris.common.Config; import org.apache.doris.common.UserException; +import org.apache.doris.common.util.DebugPointUtil; import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdf; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf; @@ -136,6 +137,13 @@ public static boolean addFunctionImpl(Function function, boolean ifNotExists, bo return true; } + public static boolean removeFunctionImpl(Function function, + ConcurrentMap> name2Function) throws UserException { + FunctionSearchDesc functionSearchDesc = new FunctionSearchDesc(function.getFunctionName(), function.getArgs(), + function.hasVarArgs()); + return dropFunctionImpl(functionSearchDesc, false, name2Function); + } + public static Function getFunction(FunctionSearchDesc function, ConcurrentMap> name2Function) throws AnalysisException { String functionName = function.getName().getFunction(); @@ -179,6 +187,9 @@ public static boolean translateToNereids(String dbName, Function function) { } public static boolean translateToNereidsThrows(String dbName, Function function) { + if (DebugPointUtil.isEnable("FunctionUtil.translateToNereidsThrows.exception")) { + throw new RuntimeException("debug point FunctionUtil.translateToNereidsThrows.exception"); + } try { translateToNereidsImpl(dbName, function); } catch (Exception e) { @@ -220,7 +231,7 @@ public static boolean dropFromNereids(String dbName, FunctionSearchDesc function String fnName = function.getName().getFunction(); List argTypes = Arrays.stream(function.getArgTypes()).map(DataType::fromCatalogType) .collect(Collectors.toList()); - Env.getCurrentEnv().getFunctionRegistry().dropUdf(dbName, fnName, argTypes); + Env.getCurrentEnv().getFunctionRegistry().dropUdf(dbName, fnName, argTypes, function.isVariadic()); } catch (Exception e) { LOG.warn("Nereids drop function {}:{} failed, caused by: {}", dbName == null ? "_global_" : dbName, function.getName(), e); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/GlobalFunctionMgr.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/GlobalFunctionMgr.java index c2f4fb3c1c033f..24711da0f1b4a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/GlobalFunctionMgr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/GlobalFunctionMgr.java @@ -74,14 +74,14 @@ public synchronized void addFunction(Function function, boolean ifNotExists) thr function.setGlobal(true); function.checkWritable(); if (FunctionUtil.addFunctionImpl(function, ifNotExists, false, name2Function)) { - Env.getCurrentEnv().getEditLog().logAddGlobalFunction(function); try { FunctionUtil.translateToNereidsThrows(null, function); } catch (Exception e) { LOG.warn("Nereids add function failed", e); - name2Function.remove(function.getFunctionName().getFunction()); + FunctionUtil.removeFunctionImpl(function, name2Function); throw e; } + Env.getCurrentEnv().getEditLog().logAddGlobalFunction(function); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/journal/JournalEntity.java b/fe/fe-core/src/main/java/org/apache/doris/journal/JournalEntity.java index bc0be815609942..2c92775bc3f1cf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/journal/JournalEntity.java +++ b/fe/fe-core/src/main/java/org/apache/doris/journal/JournalEntity.java @@ -87,6 +87,7 @@ import org.apache.doris.persist.ConsistencyCheckInfo; import org.apache.doris.persist.CreateDbInfo; import org.apache.doris.persist.CreateDictionaryPersistInfo; +import org.apache.doris.persist.CreateFunctionInfo; import org.apache.doris.persist.CreateTableInfo; import org.apache.doris.persist.DatabaseInfo; import org.apache.doris.persist.DictionaryDecreaseVersionInfo; @@ -499,6 +500,11 @@ public void readFields(DataInput in) throws IOException { isRead = true; break; } + case OperationType.OP_ADD_FUNCTIONS: { + data = CreateFunctionInfo.read(in); + isRead = true; + break; + } case OperationType.OP_DROP_FUNCTION: { data = FunctionSearchDesc.read(in); isRead = true; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java index 60833ea70ccc6b..ef14127432d199 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdf.java @@ -46,6 +46,7 @@ public class AliasUdf extends ScalarFunction implements ExplicitlyCastableSignat private final Expression unboundFunction; private final List parameters; private final List argTypes; + private final boolean hasVarArguments; private final Map sessionVariables; /** @@ -55,6 +56,20 @@ public AliasUdf(String name, List argTypes, Expression unboundFunction List parameters, Map sessionVariables, Expression... arguments) { super(name, arguments); this.argTypes = argTypes; + this.hasVarArguments = false; + this.unboundFunction = unboundFunction; + this.parameters = parameters; + this.sessionVariables = sessionVariables; + } + + /** + * constructor with session variables. + */ + public AliasUdf(String name, List argTypes, boolean hasVarArguments, Expression unboundFunction, + List parameters, Map sessionVariables, Expression... arguments) { + super(name, arguments); + this.argTypes = argTypes; + this.hasVarArguments = hasVarArguments; this.unboundFunction = unboundFunction; this.parameters = parameters; this.sessionVariables = sessionVariables; @@ -62,7 +77,7 @@ public AliasUdf(String name, List argTypes, Expression unboundFunction @Override public List getSignatures() { - return ImmutableList.of(FunctionSignature.of(NullType.INSTANCE, argTypes)); + return ImmutableList.of(FunctionSignature.of(NullType.INSTANCE, hasVarArguments, argTypes)); } public List getParameters() { @@ -101,6 +116,7 @@ public static void translateToNereidsFunction(String dbName, AliasFunction funct AliasUdf aliasUdf = new AliasUdf( function.functionName(), Arrays.stream(function.getArgs()).map(DataType::fromCatalogType).collect(Collectors.toList()), + function.hasVarArgs(), parsedFunction, function.getParameters(), sessionVariables); @@ -116,7 +132,7 @@ public int arity() { @Override public Expression withChildren(List children) { - return new AliasUdf(getName(), argTypes, unboundFunction, parameters, sessionVariables, + return new AliasUdf(getName(), argTypes, hasVarArguments, unboundFunction, parameters, sessionVariables, children.toArray(new Expression[0])); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/UdfBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/UdfBuilder.java index 2c57cfad1beba9..6337908bc7b79b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/UdfBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/UdfBuilder.java @@ -30,4 +30,8 @@ public abstract class UdfBuilder extends FunctionBuilder { public abstract List getArgTypes(); public abstract List getSignatures(); + + public boolean hasVarArguments() { + return getSignatures().get(0).hasVarArgs; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java index 0cbfb30d59fe47..d5f63c565ebaf1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/CreateFunctionCommand.java @@ -233,15 +233,10 @@ public void run(ConnectContext ctx, StmtExecutor executor) throws Exception { functionName.setDb(dbName); } Database db = Env.getCurrentInternalCatalog().getDbOrDdlException(dbName); - db.addFunction(function, ifNotExists); if (function.isUDTFunction()) { - // all of the table function in doris will have two function - // one is the noraml, and another is outer, the different of them is deal with - // empty: whether need to insert NULL result value - Function outerFunction = function.clone(); - FunctionName name = outerFunction.getFunctionName(); - name.setFn(name.getFunction() + "_outer"); - db.addFunction(outerFunction, ifNotExists); + db.addTableFunction(function, ifNotExists); + } else { + db.addFunction(function, ifNotExists); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/persist/CreateFunctionInfo.java b/fe/fe-core/src/main/java/org/apache/doris/persist/CreateFunctionInfo.java new file mode 100644 index 00000000000000..d303486840371c --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/persist/CreateFunctionInfo.java @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.persist; + +import org.apache.doris.catalog.Function; +import org.apache.doris.common.io.Writable; + +import com.google.common.collect.ImmutableList; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.List; + +public class CreateFunctionInfo implements Writable { + private final List functions; + + public CreateFunctionInfo(List functions) { + this.functions = ImmutableList.copyOf(functions); + } + + public List getFunctions() { + return functions; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(functions.size()); + for (Function function : functions) { + function.write(out); + } + } + + public static CreateFunctionInfo read(DataInput in) throws IOException { + ImmutableList.Builder builder = ImmutableList.builder(); + int functionSize = in.readInt(); + for (int i = 0; i < functionSize; i++) { + builder.add(Function.read(in)); + } + return new CreateFunctionInfo(builder.build()); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/persist/EditLog.java b/fe/fe-core/src/main/java/org/apache/doris/persist/EditLog.java index 378b9de46b9f15..baf2615bdbde38 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/persist/EditLog.java +++ b/fe/fe-core/src/main/java/org/apache/doris/persist/EditLog.java @@ -831,6 +831,11 @@ public static void loadJournal(Env env, Long logId, JournalEntity journal) { Env.getCurrentEnv().replayCreateFunction(function); break; } + case OperationType.OP_ADD_FUNCTIONS: { + final CreateFunctionInfo info = (CreateFunctionInfo) journal.getData(); + Env.getCurrentEnv().replayCreateFunctions(info); + break; + } case OperationType.OP_DROP_FUNCTION: { FunctionSearchDesc function = (FunctionSearchDesc) journal.getData(); Env.getCurrentEnv().replayDropFunction(function); @@ -2135,6 +2140,10 @@ public void logAddFunction(Function function) { logEdit(OperationType.OP_ADD_FUNCTION, function); } + public void logAddFunctions(List functions) { + logEdit(OperationType.OP_ADD_FUNCTIONS, new CreateFunctionInfo(functions)); + } + public void logAddGlobalFunction(Function function) { logEdit(OperationType.OP_ADD_GLOBAL_FUNCTION, function); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/persist/OperationType.java b/fe/fe-core/src/main/java/org/apache/doris/persist/OperationType.java index d4e57a780edc9c..8161a7a20e997c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/persist/OperationType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/persist/OperationType.java @@ -215,6 +215,7 @@ public class OperationType { public static final short OP_DROP_FUNCTION = 131; public static final short OP_ADD_GLOBAL_FUNCTION = 132; public static final short OP_DROP_GLOBAL_FUNCTION = 133; + public static final short OP_ADD_FUNCTIONS = 134; // modify database/table/tablet/replica meta public static final short OP_SET_REPLICA_VERSION = 141; diff --git a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java index 72d29a16f2bd6b..293559a95cf379 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/catalog/CreateFunctionTest.java @@ -20,14 +20,23 @@ import org.apache.doris.analysis.Expr; import org.apache.doris.analysis.FunctionCallExpr; import org.apache.doris.analysis.StringLiteral; +import org.apache.doris.common.AnalysisException; import org.apache.doris.common.FeConstants; +import org.apache.doris.common.UserException; import org.apache.doris.common.jmockit.Deencapsulation; +import org.apache.doris.common.util.URI; +import org.apache.doris.journal.JournalEntity; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder; import org.apache.doris.nereids.trees.plans.commands.CreateDatabaseCommand; import org.apache.doris.nereids.trees.plans.commands.CreateFunctionCommand; import org.apache.doris.nereids.trees.plans.commands.CreateTableCommand; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.persist.CreateFunctionInfo; +import org.apache.doris.persist.EditLog; +import org.apache.doris.persist.OperationType; import org.apache.doris.planner.PlanFragment; import org.apache.doris.planner.Planner; import org.apache.doris.planner.UnionNode; @@ -37,13 +46,21 @@ import org.apache.doris.utframe.DorisAssert; import org.apache.doris.utframe.UtFrameUtils; +import com.google.common.collect.ImmutableList; import org.junit.AfterClass; import org.junit.Assert; import org.junit.BeforeClass; import org.junit.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; import java.io.File; import java.util.List; +import java.util.Map; import java.util.UUID; /* @@ -170,6 +187,298 @@ public void testCreatePythonFunctionRejectsObjectTypes() throws Exception { "ARRAY unsupported sub-type: bitmap"); } + @Test + public void testCreateFunctionRollbackOnlyFailedOverload() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database rollback_function_db;"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("rollback_function_db"); + Assert.assertNotNull(db); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddFunction(Mockito.any(Function.class)); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + Function existingFunction = createJavaUdf("rollback_function_db", "rollback_fn", Type.INT); + db.addFunction(existingFunction, false); + Assert.assertSame(existingFunction, db.getFunction(searchDesc(existingFunction))); + + Mockito.clearInvocations(spyEditLog); + Function failedFunction = createJavaUdf("rollback_function_db", "rollback_fn", Type.BIGINT); + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + "rollback_function_db", failedFunction)).thenThrow(new RuntimeException("translate failed")); + + RuntimeException exception = Assert.assertThrows(RuntimeException.class, + () -> db.addFunction(failedFunction, false)); + Assert.assertEquals("translate failed", exception.getMessage()); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunction(Mockito.any(Function.class)); + Assert.assertSame(existingFunction, db.getFunction(searchDesc(existingFunction))); + Assert.assertThrows(AnalysisException.class, () -> db.getFunction(searchDesc(failedFunction))); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + } + } + + @Test + public void testCreateTableFunctionRollbackWhenOuterFunctionFails() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database rollback_table_function_db;"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("rollback_table_function_db"); + Assert.assertNotNull(db); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddFunction(Mockito.any(Function.class)); + Mockito.doNothing().when(spyEditLog).logAddFunctions(Mockito.anyList()); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + Function tableFunction = createJavaUdtf("rollback_table_function_db", "rollback_table_fn", Type.INT); + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + Mockito.eq("rollback_table_function_db"), Mockito.any(Function.class))) + .thenAnswer(invocation -> { + Function function = invocation.getArgument(1); + if ("rollback_table_fn_outer".equals(function.functionName())) { + throw new RuntimeException("outer translate failed"); + } + return true; + }); + + RuntimeException exception = Assert.assertThrows(RuntimeException.class, + () -> db.addTableFunction(tableFunction, false)); + Assert.assertEquals("outer translate failed", exception.getMessage()); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunction(Mockito.any(Function.class)); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunctions(Mockito.anyList()); + Assert.assertThrows(AnalysisException.class, () -> db.getFunction(searchDesc(tableFunction))); + Assert.assertThrows(AnalysisException.class, + () -> db.getFunction(searchDesc("rollback_table_function_db", "rollback_table_fn_outer", + Type.INT))); + mockedFunctionUtil.verify(() -> FunctionUtil.dropFromNereids(Mockito.eq("rollback_table_function_db"), + Mockito.argThat(function -> "rollback_table_fn".equals(function.getName().getFunction())))); + mockedFunctionUtil.verify(() -> FunctionUtil.dropFromNereids(Mockito.eq("rollback_table_function_db"), + Mockito.argThat(function -> "rollback_table_fn_outer".equals(function.getName().getFunction())))); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + } + } + + @Test + public void testCreateTableFunctionRollbackKeepsVariadicNereidsOverload() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database rollback_table_function_vararg_db;"); + ctx.setDatabase("rollback_table_function_vararg_db"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("rollback_table_function_vararg_db"); + Assert.assertNotNull(db); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddFunction(Mockito.any(Function.class)); + Mockito.doNothing().when(spyEditLog).logAddFunctions(Mockito.anyList()); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + Mockito.eq("rollback_table_function_vararg_db"), + Mockito.argThat(function -> "rollback_vararg_table_fn_outer".equals(function.functionName())))) + .thenThrow(new RuntimeException("outer translate failed")); + + Function variadicFunction = createJavaUdtf( + "rollback_table_function_vararg_db", "rollback_vararg_table_fn", Type.INT); + variadicFunction.setHasVarArgs(true); + db.addFunction(variadicFunction, false); + Assert.assertSame(variadicFunction, db.getFunction(searchDesc(variadicFunction))); + assertSingleVariadicUdfBuilder("rollback_table_function_vararg_db", "rollback_vararg_table_fn"); + + Mockito.clearInvocations(spyEditLog); + Function tableFunction = createJavaUdtf( + "rollback_table_function_vararg_db", "rollback_vararg_table_fn", Type.INT); + RuntimeException exception = Assert.assertThrows(RuntimeException.class, + () -> db.addTableFunction(tableFunction, false)); + Assert.assertEquals("outer translate failed", exception.getMessage()); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunction(Mockito.any(Function.class)); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunctions(Mockito.anyList()); + Assert.assertThrows(AnalysisException.class, () -> db.getFunction(searchDesc(tableFunction))); + Assert.assertSame(variadicFunction, db.getFunction(searchDesc(variadicFunction))); + assertSingleVariadicUdfBuilder("rollback_table_function_vararg_db", "rollback_vararg_table_fn"); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + } + } + + @Test + public void testCreateTableFunctionRollbackWhenOuterFunctionConflicts() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database rollback_table_function_conflict_db;"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("rollback_table_function_conflict_db"); + Assert.assertNotNull(db); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddFunction(Mockito.any(Function.class)); + Mockito.doNothing().when(spyEditLog).logAddFunctions(Mockito.anyList()); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + Mockito.eq("rollback_table_function_conflict_db"), Mockito.any(Function.class))) + .thenReturn(true); + Function existingOuterFunction = createJavaUdtf( + "rollback_table_function_conflict_db", "rollback_table_conflict_fn_outer", Type.INT); + db.addFunction(existingOuterFunction, false); + Assert.assertSame(existingOuterFunction, db.getFunction(searchDesc(existingOuterFunction))); + + Mockito.clearInvocations(spyEditLog); + Function tableFunction = createJavaUdtf( + "rollback_table_function_conflict_db", "rollback_table_conflict_fn", Type.INT); + UserException exception = Assert.assertThrows(UserException.class, + () -> db.addTableFunction(tableFunction, true)); + Assert.assertEquals("function already exists", exception.getDetailMessage()); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunction(Mockito.any(Function.class)); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunctions(Mockito.anyList()); + Assert.assertThrows(AnalysisException.class, () -> db.getFunction(searchDesc(tableFunction))); + Assert.assertSame(existingOuterFunction, db.getFunction(searchDesc(existingOuterFunction))); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + } + } + + @Test + public void testCreateTableFunctionIfNotExistsSkipsExistingPair() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database existing_table_function_pair_db;"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("existing_table_function_pair_db"); + Assert.assertNotNull(db); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddFunction(Mockito.any(Function.class)); + Mockito.doNothing().when(spyEditLog).logAddFunctions(Mockito.anyList()); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + Mockito.eq("existing_table_function_pair_db"), Mockito.any(Function.class))) + .thenReturn(true); + Function existingFunction = createJavaUdtf( + "existing_table_function_pair_db", "existing_table_pair_fn", Type.INT); + Function existingOuterFunction = createJavaUdtf( + "existing_table_function_pair_db", "existing_table_pair_fn_outer", Type.INT); + db.addFunction(existingFunction, false); + db.addFunction(existingOuterFunction, false); + Assert.assertSame(existingFunction, db.getFunction(searchDesc(existingFunction))); + Assert.assertSame(existingOuterFunction, db.getFunction(searchDesc(existingOuterFunction))); + + Mockito.clearInvocations(spyEditLog); + Function tableFunction = createJavaUdtf( + "existing_table_function_pair_db", "existing_table_pair_fn", Type.INT); + db.addTableFunction(tableFunction, true); + + Mockito.verify(spyEditLog, Mockito.never()).logAddFunction(Mockito.any(Function.class)); + Mockito.verify(spyEditLog, Mockito.never()).logAddFunctions(Mockito.anyList()); + Assert.assertSame(existingFunction, db.getFunction(searchDesc(existingFunction))); + Assert.assertSame(existingOuterFunction, db.getFunction(searchDesc(existingOuterFunction))); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + } + } + + @Test + public void testCreateTableFunctionLogsPairAtomically() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database atomic_table_function_db;"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("atomic_table_function_db"); + Assert.assertNotNull(db); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddFunction(Mockito.any(Function.class)); + Mockito.doNothing().when(spyEditLog).logAddFunctions(Mockito.anyList()); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + Mockito.eq("atomic_table_function_db"), Mockito.any(Function.class))) + .thenReturn(true); + Function tableFunction = createJavaUdtf("atomic_table_function_db", "atomic_table_fn", Type.INT); + + db.addTableFunction(tableFunction, false); + + Mockito.verify(spyEditLog, Mockito.never()).logAddFunction(Mockito.any(Function.class)); + Mockito.verify(spyEditLog).logAddFunctions(Mockito.argThat(functions -> + functions.size() == 2 + && "atomic_table_fn".equals(functions.get(0).functionName()) + && "atomic_table_fn_outer".equals(functions.get(1).functionName()))); + Assert.assertSame(tableFunction, db.getFunction(searchDesc(tableFunction))); + Assert.assertNotNull(db.getFunction(searchDesc("atomic_table_function_db", "atomic_table_fn_outer", + Type.INT))); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + } + } + + @Test + public void testCreateTableFunctionJournalReplayRestoresPair() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + createDatabase(ctx, "create database replay_table_function_db;"); + Database db = Env.getCurrentInternalCatalog().getDbNullable("replay_table_function_db"); + Assert.assertNotNull(db); + + Function tableFunction = createJavaUdtf("replay_table_function_db", "replay_table_fn", Type.INT); + Function outerFunction = createOuterTableFunction(tableFunction); + JournalEntity journalEntity = new JournalEntity(); + journalEntity.setOpCode(OperationType.OP_ADD_FUNCTIONS); + journalEntity.setData(new CreateFunctionInfo(ImmutableList.of(tableFunction, outerFunction))); + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + journalEntity.write(new DataOutputStream(outputStream)); + JournalEntity replayJournalEntity = new JournalEntity(); + replayJournalEntity.readFields(new DataInputStream(new ByteArrayInputStream(outputStream.toByteArray()))); + + Assert.assertEquals(OperationType.OP_ADD_FUNCTIONS, replayJournalEntity.getOpCode()); + Assert.assertTrue(replayJournalEntity.getData() instanceof CreateFunctionInfo); + Assert.assertEquals(2, ((CreateFunctionInfo) replayJournalEntity.getData()).getFunctions().size()); + EditLog.loadJournal(Env.getCurrentEnv(), 0L, replayJournalEntity); + Assert.assertNotNull(db.getFunction(searchDesc(tableFunction))); + Assert.assertNotNull(db.getFunction(searchDesc(outerFunction))); + } + + @Test + public void testCreateGlobalFunctionRollbackOnlyFailedOverload() throws Exception { + GlobalFunctionMgr globalFunctionMgr = Env.getCurrentEnv().getGlobalFunctionMgr(); + FunctionSearchDesc existingFunctionDesc = searchDesc(null, "rollback_global_fn", Type.INT); + FunctionSearchDesc failedFunctionDesc = searchDesc(null, "rollback_global_fn", Type.BIGINT); + globalFunctionMgr.dropFunction(existingFunctionDesc, true); + globalFunctionMgr.dropFunction(failedFunctionDesc, true); + + EditLog editLog = Env.getCurrentEnv().getEditLog(); + EditLog spyEditLog = Mockito.spy(editLog); + Mockito.doNothing().when(spyEditLog).logAddGlobalFunction(Mockito.any(Function.class)); + Env.getCurrentEnv().setEditLog(spyEditLog); + try (MockedStatic mockedFunctionUtil = Mockito.mockStatic(FunctionUtil.class, + Mockito.CALLS_REAL_METHODS)) { + Function existingFunction = createJavaUdf(null, "rollback_global_fn", Type.INT); + globalFunctionMgr.addFunction(existingFunction, false); + Assert.assertSame(existingFunction, globalFunctionMgr.getFunction(existingFunctionDesc)); + + Mockito.clearInvocations(spyEditLog); + Function failedFunction = createJavaUdf(null, "rollback_global_fn", Type.BIGINT); + mockedFunctionUtil.when(() -> FunctionUtil.translateToNereidsThrows( + null, failedFunction)).thenThrow(new RuntimeException("global translate failed")); + + RuntimeException exception = Assert.assertThrows(RuntimeException.class, + () -> globalFunctionMgr.addFunction(failedFunction, false)); + Assert.assertEquals("global translate failed", exception.getMessage()); + Mockito.verify(spyEditLog, Mockito.never()).logAddGlobalFunction(Mockito.any(Function.class)); + Assert.assertSame(existingFunction, globalFunctionMgr.getFunction(existingFunctionDesc)); + Assert.assertThrows(AnalysisException.class, () -> globalFunctionMgr.getFunction(failedFunctionDesc)); + } finally { + Env.getCurrentEnv().setEditLog(editLog); + globalFunctionMgr.dropFunction(existingFunctionDesc, true); + globalFunctionMgr.dropFunction(failedFunctionDesc, true); + } + } + @Test public void testCreateGlobalFunction() throws Exception { ConnectContext ctx = UtFrameUtils.createDefaultCtx(); @@ -273,4 +582,40 @@ private Function findFunction(Database db, String functionName) { } throw new AssertionError("function not found: " + functionName); } + + private Function createJavaUdf(String dbName, String functionName, Type... argTypes) throws AnalysisException { + return ScalarFunction.createUdf(Function.BinaryType.JAVA_UDF, new FunctionName(dbName, functionName), argTypes, + Type.INT, false, URI.create("file:///tmp/" + functionName + ".jar"), "evaluate", null, null); + } + + private Function createJavaUdtf(String dbName, String functionName, Type... argTypes) throws AnalysisException { + Function function = createJavaUdf(dbName, functionName, argTypes); + function.setUDTFunction(true); + return function; + } + + private void assertSingleVariadicUdfBuilder(String dbName, String functionName) { + Map> builders = Env.getCurrentEnv().getFunctionRegistry() + .getName2UdfBuilders().get(dbName); + Assert.assertNotNull(builders); + List functionBuilders = builders.get(functionName); + Assert.assertNotNull(functionBuilders); + Assert.assertEquals(1, functionBuilders.size()); + Assert.assertTrue(((UdfBuilder) functionBuilders.get(0)).hasVarArguments()); + } + + private Function createOuterTableFunction(Function function) { + Function outerFunction = function.clone(); + FunctionName name = outerFunction.getFunctionName(); + name.setFn(name.getFunction() + "_outer"); + return outerFunction; + } + + private FunctionSearchDesc searchDesc(Function function) { + return new FunctionSearchDesc(function.getFunctionName(), function.getArgs(), function.hasVarArgs()); + } + + private FunctionSearchDesc searchDesc(String dbName, String functionName, Type... argTypes) { + return new FunctionSearchDesc(new FunctionName(dbName, functionName), argTypes, false); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java index 73a44299f7375e..d262beded0756b 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/UdfTest.java @@ -223,7 +223,7 @@ public void testAliasFunctionWithIllegalExpressionsRejected() throws Exception { public void testReadFromStream() throws Exception { createFunction("create global alias function f8(int) with parameter(n) as hours_add(now(3), n)"); Env.getCurrentEnv().getFunctionRegistry().dropUdf(null, "f8", - ImmutableList.of(IntegerType.INSTANCE)); + ImmutableList.of(IntegerType.INSTANCE), false); ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); Env.getCurrentEnv().getGlobalFunctionMgr().write(new DataOutputStream(outputStream)); diff --git a/regression-test/suites/fault_injection_p0/test_create_function_rollback.groovy b/regression-test/suites/fault_injection_p0/test_create_function_rollback.groovy new file mode 100644 index 00000000000000..f412e33af50e94 --- /dev/null +++ b/regression-test/suites/fault_injection_p0/test_create_function_rollback.groovy @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import org.apache.doris.regression.suite.ClusterOptions + +suite("test_create_function_rollback", "docker") { + def options = new ClusterOptions() + options.feConfigs += [ + "enable_debug_points=true" + ] + options.cloudMode = false + + docker(options) { + sql """DROP FUNCTION IF EXISTS doris_25021_fn(INT)""" + sql """DROP FUNCTION IF EXISTS doris_25021_fn(BIGINT)""" + sql """CREATE ALIAS FUNCTION doris_25021_fn(INT) WITH PARAMETER(x) AS add(x, 1)""" + + try { + GetDebugPoint().enableDebugPointForAllFEs( + "FunctionUtil.translateToNereidsThrows.exception", [execute: 1]) + test { + sql """CREATE ALIAS FUNCTION doris_25021_fn(BIGINT) WITH PARAMETER(x) AS add(x, 1)""" + exception "debug point FunctionUtil.translateToNereidsThrows.exception" + } + } finally { + GetDebugPoint().disableDebugPointForAllFEs("FunctionUtil.translateToNereidsThrows.exception") + } + + test { + sql """CREATE ALIAS FUNCTION doris_25021_fn(INT) WITH PARAMETER(x) AS add(x, 1)""" + exception "function already exists" + } + test { + sql """DROP FUNCTION doris_25021_fn(BIGINT)""" + exception "function does not exist" + } + } +}