Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support function registry with arg types #9397

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,31 +18,77 @@
*/
package org.apache.pinot.common.function;

import com.google.common.base.Preconditions;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.commons.lang3.StringUtils;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.FunctionContext;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.spi.annotations.ScalarFunction;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.utils.PinotReflectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;


/**
* Registry for scalar functions.
*
* The registry keys functions by canonical function name and by argument types expressed as DataType.
* It does not differentiate between function names that share the same canonical form, or between argument types that
* map to the same DataType, such as a primitive numeric type and its wrapper class.
*
* To be backward compatible, the registry falls back to matching on function name and number of parameters when no
* registered function matches the argument types.
* <p>TODO: Merge FunctionRegistry and FunctionDefinitionRegistry to provide one single registry for all functions.
*/
public class FunctionRegistry {
private FunctionRegistry() {
}

private static final Logger LOGGER = LoggerFactory.getLogger(FunctionRegistry.class);

// Map from function name to parameter types to function info.
private static final Map<String, Map<List<FieldSpec.DataType>, FunctionInfo>> FUNC_PARAM_INFO_MAP = new HashMap<>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggest making the value a separate class, which can be used to look up the FunctionInfo with a list of argument types. That way we don't need to maintain 2 maps, and can support type matching (match int argument to long parameter function) in the future.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to use ColumnDataType here which contains the SV/MV info


// FUNCTION_INFO_MAP is still used to be backward compatible.
// When we support all sql data types, we can deprecate this.
@Deprecated
private static final Map<String, Map<Integer, FunctionInfo>> FUNCTION_INFO_MAP = new HashMap<>();

/**
 * Converts the Java parameter classes of a method into the list of DataTypes used as the key of the type-based
 * function map. Each class is translated with {@code FunctionUtils.getDataType(Class)} and the input order is kept.
 */
private static List<FieldSpec.DataType> getParamTypes(Class<?>[] types) {
  List<FieldSpec.DataType> dataTypes = new ArrayList<>(types.length);
  for (int i = 0; i < types.length; i++) {
    dataTypes.add(FunctionUtils.getDataType(types[i]));
  }
  return dataTypes;
}

/**
 * Looks up a function by name and parameter count in the count-based map.
 *
 * <p>The given name is used as the map key as-is, so callers must pass an already canonicalized name. This method
 * should be called after the FunctionRegistry is initialized and all methods are already registered.
 *
 * @return the matching {@link FunctionInfo}, or {@code null} when nothing is registered under that name with that
 *         number of parameters
 */
@Nullable
private static FunctionInfo getFunctionInfo(String functionName, int numParameters) {
  Map<Integer, FunctionInfo> infoByParamCount = FUNCTION_INFO_MAP.get(functionName);
  if (infoByParamCount == null) {
    return null;
  }
  return infoByParamCount.get(numParameters);
}

/**
 * Produces the canonical form of a function name: every underscore is removed and the rest is lower-cased.
 */
private static String canonicalize(String functionName) {
  String withoutUnderscores = StringUtils.remove(functionName, '_');
  return withoutUnderscores.toLowerCase();
}

/**
* Registers the scalar functions via reflection.
* NOTE: In order to plugin methods using reflection, the methods should be inside a class that includes ".function."
Expand Down Expand Up @@ -92,11 +138,17 @@ public static void registerFunction(Method method, boolean nullableParameters) {
* Registers a method with the given function name.
*/
public static void registerFunction(String functionName, Method method, boolean nullableParameters) {
FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters);
String canonicalName = canonicalize(functionName);
final FunctionInfo functionInfo = new FunctionInfo(method, method.getDeclaringClass(), nullableParameters);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(convention) We don't usually put final for local variables

final String canonicalName = canonicalize(functionName);
Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>());
Preconditions.checkState(functionInfoMap.put(method.getParameterCount(), functionInfo) == null,
"Function: %s with %s parameters is already registered", functionName, method.getParameterCount());
// Only put one default implementation for # of params.
if (!functionInfoMap.containsKey(method.getParameterCount())) {
functionInfoMap.put(method.getParameterCount(), functionInfo);
}
List<FieldSpec.DataType> paramTypes = getParamTypes(method.getParameterTypes());
Map<List<FieldSpec.DataType>, FunctionInfo> functionParamInfoMap =
FUNC_PARAM_INFO_MAP.computeIfAbsent(canonicalName, k -> new HashMap<>());
functionParamInfoMap.put(paramTypes, functionInfo);
}

/**
Expand All @@ -107,17 +159,48 @@ public static boolean containsFunction(String functionName) {
}

/**
* Returns the {@link FunctionInfo} associated with the given function name and number of parameters, or {@code null}
* if there is no matching method. This method should be called after the FunctionRegistry is initialized and all
* methods are already registered.
* All functions should be directly or indirectly registered via this call to ensure function name is canonical.
*/
@Nullable
public static FunctionInfo getFunctionInfo(String functionName, int numParameters) {
Map<Integer, FunctionInfo> functionInfoMap = FUNCTION_INFO_MAP.get(canonicalize(functionName));
return functionInfoMap != null ? functionInfoMap.get(numParameters) : null;
public static FunctionInfo getFunctionInfo(String functionName, List<FieldSpec.DataType> dataTypes) {
  // Resolution order: exact match on (canonical name, argument types) first, then the backward-compatible match on
  // (canonical name, number of arguments). Returns null when neither lookup finds a function.
  String canonicalFunctionName = canonicalize(functionName);
  Map<List<FieldSpec.DataType>, FunctionInfo> infoByParamTypes = FUNC_PARAM_INFO_MAP.get(canonicalFunctionName);
  // The name may not be registered at all; guard the lookup so an unknown function yields null instead of an NPE.
  FunctionInfo info = infoByParamTypes != null ? infoByParamTypes.get(dataTypes) : null;
  if (info != null) {
    return info;
  }
  // The name is already canonical here, as the count-based lookup requires.
  return getFunctionInfo(canonicalFunctionName, dataTypes.size());
}

private static String canonicalize(String functionName) {
return StringUtils.remove(functionName, '_').toLowerCase();
@Nullable
public static FunctionInfo getFunctionInfo(Function function) {
  // Collects the data type of each operand, in order, and resolves the function by operator name and those types.
  List<Expression> operands = function.getOperands();
  List<FieldSpec.DataType> operandTypes = new ArrayList<>(operands.size());
  for (Expression operand : operands) {
    // NOTE(review): every operand is read through getLiteral(); confirm callers only pass literal operands.
    ExpressionContext literalExpression = ExpressionContext.forLiteralContext(operand.getLiteral());
    operandTypes.add(literalExpression.getLiteralContext().getType());
  }
  return getFunctionInfo(function.getOperator(), operandTypes);
}

@Nullable
public static FunctionInfo getFunctionInfo(FunctionContext function) {
  // Resolves the function by the data types of its arguments. Only LITERAL expressions carry a literal context, so
  // the type-based lookup is possible only when every argument is a literal.
  List<ExpressionContext> args = function.getArguments();
  List<FieldSpec.DataType> argTypes = new ArrayList<>(args.size());
  for (ExpressionContext arg : args) {
    if (arg.getLiteralContext() == null) {
      // Identifier or function argument: its type is not known here. Use the backward-compatible lookup by
      // canonical name and argument count instead of dereferencing a null literal context.
      return getFunctionInfo(canonicalize(function.getFunctionName()), args.size());
    }
    argTypes.add(arg.getLiteralContext().getType());
  }
  return getFunctionInfo(function.getFunctionName(), argTypes);
}

@Nullable
public static FunctionInfo getFunctionInfo(String functionName, DataSchema.ColumnDataType[] argTypes) {
  // Translates each column data type with FunctionUtils.getDataType (which may return null for a column type that
  // has no mapping) and delegates to the DataType-based lookup.
  List<FieldSpec.DataType> dataTypes = new ArrayList<>(argTypes.length);
  for (int i = 0; i < argTypes.length; i++) {
    dataTypes.add(FunctionUtils.getDataType(argTypes[i]));
  }
  return getFunctionInfo(functionName, dataTypes);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,19 @@ private FunctionUtils() {
put(Object.class, ColumnDataType.OBJECT);
}};

// Maps a ColumnDataType to its DataType. Only the column types listed below are mapped; a lookup for any other
// column type finds no entry.
private static final Map<ColumnDataType, DataType> DATA_TYPE_COLUMN_MAP = new HashMap<ColumnDataType, DataType>() {{
put(ColumnDataType.INT, DataType.INT);
put(ColumnDataType.LONG, DataType.LONG);
put(ColumnDataType.FLOAT, DataType.FLOAT);
put(ColumnDataType.DOUBLE, DataType.DOUBLE);
put(ColumnDataType.BIG_DECIMAL, DataType.BIG_DECIMAL);
put(ColumnDataType.BOOLEAN, DataType.BOOLEAN);
put(ColumnDataType.TIMESTAMP, DataType.TIMESTAMP);
put(ColumnDataType.STRING, DataType.STRING);
put(ColumnDataType.BYTES, DataType.BYTES);
// TODO: figure out the rest of type mapping.
}};

/**
* Returns the corresponding PinotDataType for the given parameter class, or {@code null} if there is no one matching.
*/
Expand Down Expand Up @@ -158,4 +171,12 @@ public static DataType getDataType(Class<?> clazz) {
public static ColumnDataType getColumnDataType(Class<?> clazz) {
return COLUMN_DATA_TYPE_MAP.get(clazz);
}

/**
* Returns the DataType mapped to the given ColumnDataType, or {@code null} if no mapping exists for it.
*
* @param columnType the column data type to translate
* @return the mapped DataType, or {@code null} when DATA_TYPE_COLUMN_MAP has no entry for {@code columnType}
*/
@Nullable
public static DataType getDataType(ColumnDataType columnType) {
return DATA_TYPE_COLUMN_MAP.get(columnType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@

import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.spi.data.FieldSpec;


/**
Expand All @@ -31,37 +34,53 @@
*/
public class ExpressionContext {
public enum Type {
LITERAL, IDENTIFIER, FUNCTION
LITERAL, IDENTIFIER, FUNCTION,
}

// Kind of expression this context represents
private final Type _type;
// Identifier name for IDENTIFIER expressions; the literal and function factories pass null here
private final String _value;
// Function for FUNCTION expressions; the literal and identifier factories pass null here
private final FunctionContext _function;
// Only set when the _type is LITERAL
@Nullable
private final LiteralContext _literal;

public static ExpressionContext forLiteral(String literal) {
return new ExpressionContext(Type.LITERAL, literal, null);
/**
 * Creates a LITERAL expression whose literal context is built from the given request {@code Literal}.
 */
public static ExpressionContext forLiteralContext(Literal literal) {
  LiteralContext literalContext = new LiteralContext(literal);
  return new ExpressionContext(Type.LITERAL, null, null, literalContext);
}
/**
 * Creates a LITERAL expression whose literal context is built from the given data type and value.
 */
public static ExpressionContext forLiteralContext(FieldSpec.DataType type, Object val) {
  LiteralContext literalContext = new LiteralContext(type, val);
  return new ExpressionContext(Type.LITERAL, null, null, literalContext);
}

public static ExpressionContext forIdentifier(String identifier) {
return new ExpressionContext(Type.IDENTIFIER, identifier, null);
return new ExpressionContext(Type.IDENTIFIER, identifier, null, null);
}

public static ExpressionContext forFunction(FunctionContext function) {
return new ExpressionContext(Type.FUNCTION, null, function);
return new ExpressionContext(Type.FUNCTION, null, function, null);
}

private ExpressionContext(Type type, String value, FunctionContext function) {
private ExpressionContext(Type type, String value, FunctionContext function, LiteralContext literal) {
_type = type;
_value = value;
_function = function;
_literal = literal;
}

/**
 * Returns the string form of the literal value, or an empty string when this expression has no literal context or
 * the literal value is {@code null}.
 */
@Deprecated
public String getLiteralString() {
  boolean hasValue = _literal != null && _literal.getValue() != null;
  return hasValue ? _literal.getValue().toString() : "";
}

/**
* Returns the kind of this expression: LITERAL, IDENTIFIER or FUNCTION.
*/
public Type getType() {
return _type;
}

public String getLiteral() {
return _value;
/**
 * Returns the literal context of this expression, or {@code null} for expressions created by the identifier or
 * function factories, which do not set one.
 */
@Nullable
public LiteralContext getLiteralContext() {
  return _literal;
}

public String getIdentifier() {
Expand Down Expand Up @@ -94,7 +113,7 @@ public boolean equals(Object o) {
return false;
}
ExpressionContext that = (ExpressionContext) o;
return _type == that._type && Objects.equals(_value, that._value) && Objects.equals(_function, that._function);
return _type.equals(that._type) && Objects.equals(_value, that._value) && Objects.equals(_function, that._function) && Objects.equals(_literal, that._literal);
}

@Override
Expand All @@ -103,14 +122,17 @@ public int hashCode() {
if (_type == Type.FUNCTION) {
return hash + _function.hashCode();
}
if (_type == Type.LITERAL) {
return hash + _literal.hashCode();
}
return hash + 31 * _value.hashCode();
}

@Override
public String toString() {
switch (_type) {
case LITERAL:
return '\'' + _value + '\'';
return _literal.toString();
case IDENTIFIER:
return _value;
case FUNCTION:
Expand Down