Skip to content

Commit

Permalink
Allow replacing existing function if UDF already exists
Browse files Browse the repository at this point in the history
  • Loading branch information
digitalpoetry committed May 23, 2024
1 parent 484f493 commit fdf0059
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 40 deletions.
30 changes: 13 additions & 17 deletions fe/fe-core/src/main/java/com/starrocks/catalog/Database.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
package com.starrocks.catalog;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.gson.annotations.SerializedName;
Expand Down Expand Up @@ -654,14 +655,14 @@ public String getCatalogName() {
return catalogName;
}

public synchronized void addFunction(Function function) throws UserException {
addFunctionImpl(function, false);
public synchronized void addFunction(Function function, boolean allowExists) throws UserException {
addFunctionImpl(function, false, allowExists);
GlobalStateMgr.getCurrentState().getEditLog().logAddFunction(function);
}

public synchronized void replayAddFunction(Function function) {
try {
addFunctionImpl(function, true);
addFunctionImpl(function, true, false);
} catch (UserException e) {
Preconditions.checkArgument(false);
}
Expand All @@ -676,16 +677,13 @@ public static void replayCreateFunctionLog(Function function) {
db.replayAddFunction(function);
}

// return true if add success, false
private void addFunctionImpl(Function function, boolean isReplay) throws UserException {
private void addFunctionImpl(Function function, boolean isReplay, boolean allowExists) throws UserException {
String functionName = function.getFunctionName().getFunction();
List<Function> existFuncs = name2Function.get(functionName);
List<Function> existFuncs = name2Function.getOrDefault(functionName, ImmutableList.of());
if (!isReplay) {
if (existFuncs != null) {
for (Function existFunc : existFuncs) {
if (function.compare(existFunc, Function.CompareMode.IS_IDENTICAL)) {
throw new UserException("function already exists");
}
for (Function existFunc : existFuncs) {
if (!allowExists && function.compare(existFunc, Function.CompareMode.IS_IDENTICAL)) {
throw new UserException("function already exists");
}
}
// Get function id for this UDF, use CatalogIdGenerator. Only get function id
Expand All @@ -695,12 +693,10 @@ private void addFunctionImpl(Function function, boolean isReplay) throws UserExc
function.setFunctionId(-functionId);
}

List<Function> functions = new ArrayList<>();
if (existFuncs != null) {
functions.addAll(existFuncs);
}
functions.add(function);
name2Function.put(functionName, functions);
existFuncs = existFuncs.stream()
.filter(f -> !function.compare(f, Function.CompareMode.IS_IDENTICAL))
.collect(ImmutableList.toImmutableList());
name2Function.put(functionName, ImmutableList.<Function>builder().addAll(existFuncs).add(function).build());
}

public synchronized void dropFunction(FunctionSearchDesc function) throws UserException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,13 @@ public synchronized Function getFunction(FunctionSearchDesc function) {
return func;
}

private void addFunction(Function function, boolean isReplay) throws UserException {
private void addFunction(Function function, boolean isReplay, boolean allowExists) throws UserException {
String functionName = function.getFunctionName().getFunction();
List<Function> existFuncs = name2Function.get(functionName);
List<Function> existFuncs = name2Function.getOrDefault(functionName, ImmutableList.of());
if (!isReplay) {
if (existFuncs != null) {
for (Function existFunc : existFuncs) {
if (function.compare(existFunc, Function.CompareMode.IS_IDENTICAL)) {
throw new UserException("function already exists");
}
for (Function existFunc : existFuncs) {
if (!allowExists && function.compare(existFunc, Function.CompareMode.IS_IDENTICAL)) {
throw new UserException("function already exists");
}
}
// Get function id for this UDF, use CatalogIdGenerator. Only get function id
Expand All @@ -94,23 +92,20 @@ private void addFunction(Function function, boolean isReplay) throws UserExcepti
function.setFunctionId(-functionId);
}

com.google.common.collect.ImmutableList.Builder<Function> builder =
com.google.common.collect.ImmutableList.builder();
if (existFuncs != null) {
builder.addAll(existFuncs);
}
builder.add(function);
name2Function.put(functionName, builder.build());
existFuncs = existFuncs.stream()
.filter(f -> !function.compare(f, Function.CompareMode.IS_IDENTICAL))
.collect(ImmutableList.toImmutableList());
name2Function.put(functionName, ImmutableList.<Function>builder().addAll(existFuncs).add(function).build());
}

public synchronized void userAddFunction(Function f) throws UserException {
addFunction(f, false);
public synchronized void userAddFunction(Function f, boolean allowExists) throws UserException {
addFunction(f, false, allowExists);
GlobalStateMgr.getCurrentState().getEditLog().logAddFunction(f);
}

public synchronized void replayAddFunction(Function f) {
try {
addFunction(f, true);
addFunction(f, true, false);
} catch (UserException e) {
Preconditions.checkArgument(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,13 +238,15 @@ public ShowResultSet visitCreateFunctionStatement(CreateFunctionStmt stmt, Conne
ErrorReport.wrapWithRuntimeException(() -> {
FunctionName name = stmt.getFunctionName();
if (name.isGlobalFunction()) {
context.getGlobalStateMgr().getGlobalFunctionMgr().userAddFunction(stmt.getFunction());
context.getGlobalStateMgr()
.getGlobalFunctionMgr()
.userAddFunction(stmt.getFunction(), stmt.shouldReplaceIfExists());
} else {
Database db = context.getGlobalStateMgr().getDb(name.getDb());
if (db == null) {
ErrorReport.reportDdlException(ErrorCode.ERR_BAD_DB_ERROR, name.getDb());
}
db.addFunction(stmt.getFunction());
db.addFunction(stmt.getFunction(), stmt.shouldReplaceIfExists());
}
});
return null;
Expand Down
41 changes: 41 additions & 0 deletions fe/fe-core/src/test/java/com/starrocks/catalog/DatabaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@
package com.starrocks.catalog;

import com.google.common.collect.Lists;
import com.starrocks.analysis.FunctionName;
import com.starrocks.catalog.MaterializedIndex.IndexState;
import com.starrocks.common.UserException;
import com.starrocks.common.jmockit.Deencapsulation;
import com.starrocks.common.util.concurrent.lock.LockManager;
import com.starrocks.persist.CreateTableInfo;
Expand Down Expand Up @@ -98,6 +100,10 @@ public void setup() {
globalStateMgr.getLockManager();
minTimes = 0;
result = new LockManager();

globalStateMgr.getNextId();
minTimes = 0;
result = 1L;
}
};
}
Expand Down Expand Up @@ -223,4 +229,39 @@ public void testGetUUID() {
db3.setCatalogName("hive");
Assert.assertEquals("hive.db3", db3.getUUID());
}

@Test
public void testCreateDatabaseUdfGivenUdfAlreadyExists() throws UserException {
Type[] argTypes = new Type[2];
argTypes[0] = Type.INT;
argTypes[1] = Type.INT;
FunctionName name = new FunctionName(null, "addIntInt");
name.setDb(db.getCatalogName());
Function f = new Function(name, argTypes, Type.INT, false);

// Add the UDF for the first time
db.addFunction(f, false);

// Attempt to add the same UDF again, expecting an exception
Assert.assertThrows(UserException.class, () -> db.addFunction(f, false));
}

@Test
public void testCreateGlobalUdfGivenUdfAlreadyExistsAllowExisting() throws UserException {
Type[] argTypes = new Type[2];
argTypes[0] = Type.INT;
argTypes[1] = Type.INT;
FunctionName name = new FunctionName(null, "addIntInt");
name.setDb(db.getCatalogName());
Function f = new Function(name, argTypes, Type.INT, false);

// Add the UDF for the first time
db.addFunction(f, true);
// Attempt to add the same UDF again
db.addFunction(f, true);

List<Function> functions = db.getFunctions();
Assert.assertEquals(functions.size(), 1);
Assert.assertTrue(functions.get(0).compare(f, Function.CompareMode.IS_IDENTICAL));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,18 @@

import com.starrocks.analysis.FunctionName;
import com.starrocks.common.UserException;
import com.starrocks.persist.EditLog;
import com.starrocks.server.GlobalStateMgr;
import mockit.Mock;
import mockit.MockUp;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.util.List;

import static org.mockito.Mockito.mock;

public class GlobalFunctionMgrTest {
private GlobalFunctionMgr globalFunctionMgr;

Expand Down Expand Up @@ -54,4 +60,51 @@ public void testAddAndDropFunction() throws UserException {
Assert.assertEquals(functions.size(), 0);
}
}
}

@Test
public void testCreateGlobalUdfGivenUdfAlreadyExists() throws UserException {
Type[] argTypes = new Type[2];
argTypes[0] = Type.INT;
argTypes[1] = Type.INT;
FunctionName name = new FunctionName(null, "addIntInt");
name.setAsGlobalFunction();
Function f = new Function(name, argTypes, Type.INT, false);
new MockUp<GlobalStateMgr>() {
@Mock
public EditLog getEditLog() {
return mock();
}
};

// Add the UDF for the first time
globalFunctionMgr.userAddFunction(f, false);

// Attempt to add the same UDF again, expecting an exception
Assert.assertThrows(UserException.class, () -> globalFunctionMgr.userAddFunction(f, false));
}

@Test
public void testCreateGlobalUdfGivenUdfAlreadyExistsAllowExisting() throws UserException {
Type[] argTypes = new Type[2];
argTypes[0] = Type.INT;
argTypes[1] = Type.INT;
FunctionName name = new FunctionName(null, "addIntInt");
name.setAsGlobalFunction();
Function f = new Function(name, argTypes, Type.INT, false);
new MockUp<GlobalStateMgr>() {
@Mock
public EditLog getEditLog() {
return mock();
}
};

// Add the UDF for the first time
globalFunctionMgr.userAddFunction(f, true);
// Attempt to add the same UDF again
globalFunctionMgr.userAddFunction(f, true);

List<Function> functions = globalFunctionMgr.getFunctions();
Assert.assertEquals(functions.size(), 1);
Assert.assertTrue(functions.get(0).compare(f, Function.CompareMode.IS_IDENTICAL));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1674,7 +1674,7 @@ public void testDatabaseStmt() throws Exception {
FunctionName fn = FunctionName.createFnName("db1.my_udf_json_get");
Function function = new Function(fn, Arrays.asList(Type.STRING, Type.STRING), Type.STRING, false);
try {
db1.addFunction(function);
db1.addFunction(function, false);
} catch (Throwable e) {
// ignore
}
Expand Down Expand Up @@ -2787,7 +2787,7 @@ public void testFunc() throws Exception {
FunctionName fn = FunctionName.createFnName("db1.my_udf_json_get");
Function function = new Function(fn, Arrays.asList(Type.STRING, Type.STRING), Type.STRING, false);
try {
db1.addFunction(function);
db1.addFunction(function, false);
} catch (Throwable e) {
// ignore
}
Expand Down Expand Up @@ -2900,7 +2900,7 @@ public void testShowFunc() throws Exception {
FunctionName fn = FunctionName.createFnName("db1.my_udf_json_get");
Function function = new Function(fn, Arrays.asList(Type.STRING, Type.STRING), Type.STRING, false);
try {
db1.addFunction(function);
db1.addFunction(function, false);
} catch (Throwable e) {
// ignore
}
Expand Down

0 comments on commit fdf0059

Please sign in to comment.