diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java index 213d12587310..af0614df4b8c 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionInvoker.java @@ -111,10 +111,9 @@ public void convertTypes(Object[] arguments) { } PinotDataType parameterType = _parameterTypes[i]; - PinotDataType argumentType = FunctionUtils.getArgumentType(argumentClass); - Preconditions.checkArgument(parameterType != null && argumentType != null, + Preconditions.checkArgument(parameterType != null, "Cannot convert value from class: %s to class: %s", argumentClass, parameterClass); - arguments[i] = parameterType.convert(argument, argumentType); + arguments[i] = parameterType.convert(argument, FunctionUtils.getArgumentType(argument)); } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java index 60ee07a76ab9..ab96dfd1b613 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/FunctionUtils.java @@ -29,7 +29,6 @@ import org.apache.calcite.sql.type.SqlTypeName; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.PinotDataType; -import org.apache.pinot.spi.data.FieldSpec.DataType; public class FunctionUtils { @@ -57,65 +56,14 @@ private FunctionUtils() { put(float[].class, PinotDataType.PRIMITIVE_FLOAT_ARRAY); put(double[].class, PinotDataType.PRIMITIVE_DOUBLE_ARRAY); put(BigDecimal[].class, PinotDataType.BIG_DECIMAL_ARRAY); + put(boolean[].class, PinotDataType.PRIMITIVE_BOOLEAN_ARRAY); + put(Timestamp[].class, PinotDataType.TIMESTAMP_ARRAY); put(String[].class, PinotDataType.STRING_ARRAY); 
put(byte[][].class, PinotDataType.BYTES_ARRAY); put(Map.class, PinotDataType.MAP); put(Object.class, PinotDataType.OBJECT); }}; - // Types allowed as the function argument (actual value passed into the function) for type conversion - private static final Map, PinotDataType> ARGUMENT_TYPE_MAP = new HashMap<>() {{ - put(Byte.class, PinotDataType.BYTE); - put(Boolean.class, PinotDataType.BOOLEAN); - put(Character.class, PinotDataType.CHARACTER); - put(Short.class, PinotDataType.SHORT); - put(Integer.class, PinotDataType.INTEGER); - put(Long.class, PinotDataType.LONG); - put(Float.class, PinotDataType.FLOAT); - put(Double.class, PinotDataType.DOUBLE); - put(BigDecimal.class, PinotDataType.BIG_DECIMAL); - put(Timestamp.class, PinotDataType.TIMESTAMP); - put(String.class, PinotDataType.STRING); - put(byte[].class, PinotDataType.BYTES); - put(int[].class, PinotDataType.PRIMITIVE_INT_ARRAY); - put(Integer[].class, PinotDataType.INTEGER_ARRAY); - put(long[].class, PinotDataType.PRIMITIVE_LONG_ARRAY); - put(Long[].class, PinotDataType.LONG_ARRAY); - put(float[].class, PinotDataType.PRIMITIVE_FLOAT_ARRAY); - put(Float[].class, PinotDataType.FLOAT_ARRAY); - put(double[].class, PinotDataType.PRIMITIVE_DOUBLE_ARRAY); - put(Double[].class, PinotDataType.DOUBLE_ARRAY); - put(BigDecimal[].class, PinotDataType.BIG_DECIMAL_ARRAY); - put(String[].class, PinotDataType.STRING_ARRAY); - put(byte[][].class, PinotDataType.BYTES_ARRAY); - put(Object.class, PinotDataType.OBJECT); - put(Object[].class, PinotDataType.OBJECT_ARRAY); - }}; - - private static final Map, DataType> DATA_TYPE_MAP = new HashMap<>() {{ - put(int.class, DataType.INT); - put(Integer.class, DataType.INT); - put(long.class, DataType.LONG); - put(Long.class, DataType.LONG); - put(float.class, DataType.FLOAT); - put(Float.class, DataType.FLOAT); - put(double.class, DataType.DOUBLE); - put(Double.class, DataType.DOUBLE); - put(BigDecimal.class, DataType.BIG_DECIMAL); - put(boolean.class, DataType.BOOLEAN); - 
put(Boolean.class, DataType.BOOLEAN); - put(Timestamp.class, DataType.TIMESTAMP); - put(String.class, DataType.STRING); - put(byte[].class, DataType.BYTES); - put(int[].class, DataType.INT); - put(long[].class, DataType.LONG); - put(float[].class, DataType.FLOAT); - put(double[].class, DataType.DOUBLE); - put(BigDecimal[].class, DataType.BIG_DECIMAL); - put(String[].class, DataType.STRING); - put(byte[][].class, DataType.BYTES); - }}; - private static final Map, ColumnDataType> COLUMN_DATA_TYPE_MAP = new HashMap<>() {{ put(int.class, ColumnDataType.INT); put(Integer.class, ColumnDataType.INT); @@ -136,6 +84,8 @@ private FunctionUtils() { put(float[].class, ColumnDataType.FLOAT_ARRAY); put(double[].class, ColumnDataType.DOUBLE_ARRAY); put(BigDecimal[].class, ColumnDataType.BIG_DECIMAL_ARRAY); + put(boolean[].class, ColumnDataType.BOOLEAN_ARRAY); + put(Timestamp[].class, ColumnDataType.TIMESTAMP_ARRAY); put(String[].class, ColumnDataType.STRING_ARRAY); put(byte[][].class, ColumnDataType.BYTES_ARRAY); put(Object.class, ColumnDataType.OBJECT); @@ -149,26 +99,55 @@ public static PinotDataType getParameterType(Class clazz) { return PARAMETER_TYPE_MAP.get(clazz); } - /** - * Returns the corresponding PinotDataType for the given argument class, or {@code null} if there is no one matching. - */ - @Nullable - public static PinotDataType getArgumentType(Class clazz) { - if (Collection.class.isAssignableFrom(clazz)) { - return PinotDataType.COLLECTION; + /// Returns the corresponding [PinotDataType] for the given argument value (the actual value passed into + /// the function). Returns [PinotDataType#OBJECT] / [PinotDataType#OBJECT_ARRAY] for unrecognized types, + /// matching [PinotDataType#getSingleValueType]'s best-effort fallback. Subclasses of non-final types + /// (e.g. vendor `Timestamp` subclasses returned by JDBC drivers) are matched by their parent type. 
+ /// + /// Dispatch (single-value first since it's the dominant case for function arguments): + /// - Single values → delegated to [PinotDataType#getSingleValueType] (covers all scalar types + /// including `byte[]` → [PinotDataType#BYTES]). + /// - Reference arrays (`Object[]` and subtypes including `byte[][]`) → first non-null element is + /// sampled and [PinotDataType#getMultiValueType] is consulted. Empty / all-null reference arrays + /// fall back to [PinotDataType#OBJECT_ARRAY] since the element type is undeterminable. + /// - Primitive arrays (`int[]` / `long[]` / `float[]` / `double[]` / `boolean[]`) → handled here, since + /// they can't be element-sampled into a boxed type. + /// - [PinotDataType#COLLECTION] for any [Collection]; otherwise falls back to [PinotDataType#OBJECT]. + public static PinotDataType getArgumentType(Object value) { + PinotDataType singleValueType = PinotDataType.getSingleValueType(value); + if (singleValueType != PinotDataType.OBJECT) { + return singleValueType; } - if (Map.class.isAssignableFrom(clazz)) { - return PinotDataType.MAP; + if (value instanceof Object[]) { + Object[] array = (Object[]) value; + for (Object element : array) { + if (element == null) { + continue; + } + return PinotDataType.getMultiValueType(element); + } + // Empty or all-null reference array — element type undeterminable. + return PinotDataType.OBJECT_ARRAY; } - return ARGUMENT_TYPE_MAP.get(clazz); - } - - /** - * Returns the corresponding DataType for the given class, or {@code null} if there is no one matching. 
- */ - @Nullable - public static DataType getDataType(Class clazz) { - return DATA_TYPE_MAP.get(clazz); + if (value instanceof int[]) { + return PinotDataType.PRIMITIVE_INT_ARRAY; + } + if (value instanceof long[]) { + return PinotDataType.PRIMITIVE_LONG_ARRAY; + } + if (value instanceof float[]) { + return PinotDataType.PRIMITIVE_FLOAT_ARRAY; + } + if (value instanceof double[]) { + return PinotDataType.PRIMITIVE_DOUBLE_ARRAY; + } + if (value instanceof boolean[]) { + return PinotDataType.PRIMITIVE_BOOLEAN_ARRAY; + } + if (value instanceof Collection) { + return PinotDataType.COLLECTION; + } + return PinotDataType.OBJECT; } /** diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java index 43fbead43ffb..17f8e82d0769 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java @@ -41,8 +41,8 @@ private DataTypeConversionFunctions() { public static Object cast(Object value, String targetTypeLiteral) { Class clazz = value.getClass(); // TODO: Support cast for MV - Preconditions.checkArgument(!clazz.isArray() | clazz == byte[].class, "%s must not be an array type", clazz); - PinotDataType sourceType = PinotDataType.getSingleValueType(clazz); + Preconditions.checkArgument(!clazz.isArray() || clazz == byte[].class, "%s must not be an array type", clazz); + PinotDataType sourceType = PinotDataType.getSingleValueType(value); String transformed = targetTypeLiteral.toUpperCase(); PinotDataType targetDataType; switch (transformed) { diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java index dd90787b0344..f18afec273c9 100644 --- 
a/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/PinotDataType.java @@ -1212,10 +1212,10 @@ public BigDecimal[] convert(Object value, PinotDataType sourceType) { } }, - BOOLEAN_ARRAY { + PRIMITIVE_BOOLEAN_ARRAY { @Override public boolean[] convert(Object value, PinotDataType sourceType) { - return sourceType.toBooleanArray(value); + return sourceType.toPrimitiveBooleanArray(value); } @Override @@ -1230,6 +1230,24 @@ public Integer[] toInternal(Object value) { } }, + BOOLEAN_ARRAY { + @Override + public Boolean[] convert(Object value, PinotDataType sourceType) { + return sourceType.toBooleanArray(value); + } + + @Override + public Integer[] toInternal(Object value) { + Boolean[] booleanArray = (Boolean[]) value; + int length = booleanArray.length; + Integer[] intArray = new Integer[length]; + for (int i = 0; i < length; i++) { + intArray[i] = booleanArray[i] != null ? (booleanArray[i] ? 1 : 0) : null; + } + return intArray; + } + }, + TIMESTAMP_ARRAY { @Override public Object convert(Object value, PinotDataType sourceType) { @@ -1563,117 +1581,111 @@ public Double[] toDoubleArray(Object value) { } } - public String[] toStringArray(Object value) { - if (value instanceof String[]) { - return (String[]) value; + public BigDecimal[] toBigDecimalArray(Object value) { + if (value instanceof BigDecimal[]) { + return (BigDecimal[]) value; } if (isSingleValue()) { - return new String[]{toString(value)}; + return new BigDecimal[]{toBigDecimal(value)}; } else { Object[] valueArray = toObjectArray(value); int length = valueArray.length; - String[] stringArray = new String[length]; + BigDecimal[] bigDecimalArray = new BigDecimal[length]; PinotDataType singleValueType = getSingleValueType(); for (int i = 0; i < length; i++) { - stringArray[i] = singleValueType.toString(valueArray[i]); + bigDecimalArray[i] = singleValueType.toBigDecimal(valueArray[i]); } - return stringArray; + return 
bigDecimalArray; } } - public byte[][] toBytesArray(Object value) { - if (value instanceof byte[][]) { - return (byte[][]) value; + public boolean[] toPrimitiveBooleanArray(Object value) { + if (value instanceof boolean[]) { + return (boolean[]) value; } if (isSingleValue()) { - return new byte[][]{toBytes(value)}; + return new boolean[]{toBoolean(value)}; } else { Object[] valueArray = toObjectArray(value); int length = valueArray.length; - byte[][] bytesArray = new byte[length][]; + boolean[] booleanArray = new boolean[length]; PinotDataType singleValueType = getSingleValueType(); for (int i = 0; i < length; i++) { - bytesArray[i] = singleValueType.toBytes(valueArray[i]); + booleanArray[i] = singleValueType.toBoolean(valueArray[i]); } - return bytesArray; + return booleanArray; } } - public BigDecimal[] toBigDecimalArray(Object value) { - if (value instanceof BigDecimal[]) { - return (BigDecimal[]) value; + public Boolean[] toBooleanArray(Object value) { + if (value instanceof Boolean[]) { + return (Boolean[]) value; } if (isSingleValue()) { - return new BigDecimal[]{toBigDecimal(value)}; + return new Boolean[]{toBoolean(value)}; } else { Object[] valueArray = toObjectArray(value); int length = valueArray.length; - BigDecimal[] bigDecimalArray = new BigDecimal[length]; + Boolean[] booleanArray = new Boolean[length]; PinotDataType singleValueType = getSingleValueType(); for (int i = 0; i < length; i++) { - bigDecimalArray[i] = singleValueType.toBigDecimal(valueArray[i]); + booleanArray[i] = singleValueType.toBoolean(valueArray[i]); } - return bigDecimalArray; + return booleanArray; } } - private static Object[] toObjectArray(Object array) { - if (array instanceof Collection) { - return ((Collection) array).toArray(); + public Timestamp[] toTimestampArray(Object value) { + if (value instanceof Timestamp[]) { + return (Timestamp[]) value; } - Class componentType = array.getClass().getComponentType(); - if (componentType.isPrimitive()) { - if (componentType == 
int.class) { - return ArrayUtils.toObject((int[]) array); - } - if (componentType == long.class) { - return ArrayUtils.toObject((long[]) array); - } - if (componentType == float.class) { - return ArrayUtils.toObject((float[]) array); - } - if (componentType == double.class) { - return ArrayUtils.toObject((double[]) array); - } - throw new UnsupportedOperationException("Unsupported primitive array type: " + componentType); + if (isSingleValue()) { + return new Timestamp[]{toTimestamp(value)}; } else { - return (Object[]) array; + Object[] valueArray = toObjectArray(value); + int length = valueArray.length; + Timestamp[] timestampArray = new Timestamp[length]; + PinotDataType singleValueType = getSingleValueType(); + for (int i = 0; i < length; i++) { + timestampArray[i] = singleValueType.toTimestamp(valueArray[i]); + } + return timestampArray; } } - public boolean[] toBooleanArray(Object value) { - if (value instanceof boolean[]) { - return (boolean[]) value; + public String[] toStringArray(Object value) { + if (value instanceof String[]) { + return (String[]) value; } if (isSingleValue()) { - return new boolean[]{toBoolean(value)}; + return new String[]{toString(value)}; } else { Object[] valueArray = toObjectArray(value); int length = valueArray.length; - boolean[] booleanArray = new boolean[length]; + String[] stringArray = new String[length]; PinotDataType singleValueType = getSingleValueType(); for (int i = 0; i < length; i++) { - booleanArray[i] = singleValueType.toBoolean(valueArray[i]); + stringArray[i] = singleValueType.toString(valueArray[i]); } - return booleanArray; + return stringArray; } } - public Timestamp[] toTimestampArray(Object value) { - if (value instanceof Timestamp[]) { - return (Timestamp[]) value; + public byte[][] toBytesArray(Object value) { + if (value instanceof byte[][]) { + return (byte[][]) value; } if (isSingleValue()) { - return new Timestamp[]{toTimestamp(value)}; + return new byte[][]{toBytes(value)}; } else { Object[] valueArray 
= toObjectArray(value); int length = valueArray.length; - Timestamp[] timestampArray = new Timestamp[length]; + byte[][] bytesArray = new byte[length][]; PinotDataType singleValueType = getSingleValueType(); for (int i = 0; i < length; i++) { - timestampArray[i] = singleValueType.toTimestamp(valueArray[i]); + bytesArray[i] = singleValueType.toBytes(valueArray[i]); } - return timestampArray; + return bytesArray; } } @@ -1731,19 +1743,42 @@ public UUID[] toUuidArray(Object value) { } } + private static Object[] toObjectArray(Object array) { + if (array instanceof Collection) { + return ((Collection) array).toArray(); + } + Class componentType = array.getClass().getComponentType(); + if (componentType.isPrimitive()) { + if (componentType == int.class) { + return ArrayUtils.toObject((int[]) array); + } + if (componentType == long.class) { + return ArrayUtils.toObject((long[]) array); + } + if (componentType == float.class) { + return ArrayUtils.toObject((float[]) array); + } + if (componentType == double.class) { + return ArrayUtils.toObject((double[]) array); + } + if (componentType == boolean.class) { + return ArrayUtils.toObject((boolean[]) array); + } + throw new UnsupportedOperationException("Unsupported primitive array type: " + componentType); + } else { + return (Object[]) array; + } + } + public Object convert(Object value, PinotDataType sourceType) { throw new UnsupportedOperationException("Cannot convert value from " + sourceType + " to " + this); } - /** - * Converts to the internal representation of the value. - *
<ul>
- * <li>BOOLEAN -> int</li>
- * <li>TIMESTAMP -> long</li>
- * <li>BOOLEAN_ARRAY -> int[]</li>
- * <li>TIMESTAMP_ARRAY -> long[]</li>
- * </ul>
- */ + /// Converts to the internal representation of the value. + /// - `BOOLEAN` → `Integer` (0/1) + /// - `TIMESTAMP` → `Long` (epoch millis) + /// - `PRIMITIVE_BOOLEAN_ARRAY` / `BOOLEAN_ARRAY` → `Integer[]` (per-element 0/1) + /// - `TIMESTAMP_ARRAY` → `Long[]` (per-element epoch millis) public Object toInternal(Object value) { return value; } @@ -1774,6 +1809,7 @@ public PinotDataType getSingleValueType() { return DOUBLE; case BIG_DECIMAL_ARRAY: return BIG_DECIMAL; + case PRIMITIVE_BOOLEAN_ARRAY: case BOOLEAN_ARRAY: return BOOLEAN; case TIMESTAMP_ARRAY: @@ -1796,121 +1832,126 @@ public PinotDataType getSingleValueType() { } } - public static PinotDataType getSingleValueType(Class cls) { - if (cls == Integer.class) { + /// Returns the [PinotDataType] for the given single value, dispatched on the runtime class via + /// `instanceof`. Returns [#OBJECT] for any unrecognized type. Subclasses of non-final types + /// (e.g. vendor `Timestamp` subclasses returned by JDBC drivers) are matched by their parent type. 
+ public static PinotDataType getSingleValueType(Object value) { + if (value instanceof Integer) { return INTEGER; } - if (cls == Long.class) { + if (value instanceof Long) { return LONG; } - if (cls == Float.class) { + if (value instanceof Float) { return FLOAT; } - if (cls == Double.class) { + if (value instanceof Double) { return DOUBLE; } - if (cls == BigDecimal.class) { + if (value instanceof BigDecimal) { return BIG_DECIMAL; } - if (cls == String.class) { - return STRING; - } - if (cls == byte[].class) { - return BYTES; - } - if (cls == UUID.class) { - return UUID; - } - if (cls == Boolean.class) { + if (value instanceof Boolean) { return BOOLEAN; } - if (cls == Timestamp.class) { + if (value instanceof Timestamp) { return TIMESTAMP; } - if (cls != null && Map.class.isAssignableFrom(cls)) { + if (value instanceof String) { + return STRING; + } + if (value instanceof byte[]) { + return BYTES; + } + if (value instanceof Map) { return MAP; } - if (cls == LocalDate.class) { + if (value instanceof LocalDate) { return DATE; } - if (cls == LocalTime.class) { + if (value instanceof LocalTime) { return TIME; } - if (cls == Byte.class) { + if (value instanceof UUID) { + return UUID; + } + if (value instanceof Byte) { return BYTE; } - if (cls == Character.class) { + if (value instanceof Character) { return CHARACTER; } - if (cls == Short.class) { + if (value instanceof Short) { return SHORT; } return OBJECT; } - public static PinotDataType getMultiValueType(Class cls) { - if (cls == Integer.class) { + /// Returns the multi-value [PinotDataType] for the given sample element, dispatched on the runtime class + /// via `instanceof`. Returns [#OBJECT_ARRAY] for any unrecognized type. 
+ public static PinotDataType getMultiValueType(Object element) { + if (element instanceof Integer) { return INTEGER_ARRAY; } - if (cls == Long.class) { + if (element instanceof Long) { return LONG_ARRAY; } - if (cls == Float.class) { + if (element instanceof Float) { return FLOAT_ARRAY; } - if (cls == Double.class) { + if (element instanceof Double) { return DOUBLE_ARRAY; } - if (cls == BigDecimal.class) { + if (element instanceof BigDecimal) { return BIG_DECIMAL_ARRAY; } - if (cls == String.class) { - return STRING_ARRAY; - } - if (cls == byte[].class) { - return BYTES_ARRAY; - } - if (cls == Boolean.class) { + if (element instanceof Boolean) { return BOOLEAN_ARRAY; } - if (cls == Timestamp.class) { + if (element instanceof Timestamp) { return TIMESTAMP_ARRAY; } - if (cls == LocalDate.class) { + if (element instanceof String) { + return STRING_ARRAY; + } + if (element instanceof byte[]) { + return BYTES_ARRAY; + } + if (element instanceof LocalDate) { return DATE_ARRAY; } - if (cls == LocalTime.class) { + if (element instanceof LocalTime) { return TIME_ARRAY; } - if (cls == UUID.class) { + if (element instanceof UUID) { return UUID_ARRAY; } - if (cls == Byte.class) { + if (element instanceof Byte) { return BYTE_ARRAY; } - if (cls == Character.class) { + if (element instanceof Character) { return CHARACTER_ARRAY; } - if (cls == Short.class) { + if (element instanceof Short) { return SHORT_ARRAY; } return OBJECT_ARRAY; } private static int anyToInt(Object val) { - return getSingleValueType(val.getClass()).toInt(val); + return getSingleValueType(val).toInt(val); } private static long anyToLong(Object val) { - return getSingleValueType(val.getClass()).toLong(val); + return getSingleValueType(val).toLong(val); } private static float anyToFloat(Object val) { - return getSingleValueType(val.getClass()).toFloat(val); + return getSingleValueType(val).toFloat(val); } private static double anyToDouble(Object val) { - return getSingleValueType(val.getClass()).toDouble(val); 
+ return getSingleValueType(val).toDouble(val); } /** diff --git a/pinot-common/src/test/java/org/apache/pinot/common/evaluator/InbuiltFunctionEvaluatorTest.java b/pinot-common/src/test/java/org/apache/pinot/common/evaluator/InbuiltFunctionEvaluatorTest.java index 7c22937bf59f..aa37a6a826ab 100644 --- a/pinot-common/src/test/java/org/apache/pinot/common/evaluator/InbuiltFunctionEvaluatorTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/common/evaluator/InbuiltFunctionEvaluatorTest.java @@ -27,7 +27,6 @@ import org.testng.annotations.Test; import static org.testng.Assert.assertEquals; -import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; import static org.testng.Assert.assertTrue; @@ -108,9 +107,7 @@ public void testNotWithNulls() { private void checkBooleanLiteralExpression(String expression, int value) { InbuiltFunctionEvaluator evaluator = new InbuiltFunctionEvaluator(expression); Object output = evaluator.evaluate(new GenericRow()); - Class outputValueClass = output.getClass(); - PinotDataType outputType = FunctionUtils.getArgumentType(outputValueClass); - assertNotNull(outputType); + PinotDataType outputType = FunctionUtils.getArgumentType(output); // as INT is the stored type for BOOLEAN assertEquals(outputType.toInt(output), value); } diff --git a/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionUtilsTest.java b/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionUtilsTest.java new file mode 100644 index 000000000000..7dc41f656b7a --- /dev/null +++ b/pinot-common/src/test/java/org/apache/pinot/common/function/FunctionUtilsTest.java @@ -0,0 +1,158 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.common.function; + +import java.math.BigDecimal; +import java.sql.Timestamp; +import java.time.LocalDate; +import java.time.LocalTime; +import java.util.HashMap; +import java.util.List; +import java.util.UUID; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.common.utils.PinotDataType; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + + +public class FunctionUtilsTest { + + @Test + public void testGetArgumentType() { + // Single values delegated to PinotDataType.getSingleValueType + assertEquals(FunctionUtils.getArgumentType(1), PinotDataType.INTEGER); + assertEquals(FunctionUtils.getArgumentType(1L), PinotDataType.LONG); + assertEquals(FunctionUtils.getArgumentType(1.0f), PinotDataType.FLOAT); + assertEquals(FunctionUtils.getArgumentType(1.0d), PinotDataType.DOUBLE); + assertEquals(FunctionUtils.getArgumentType(BigDecimal.ONE), PinotDataType.BIG_DECIMAL); + assertEquals(FunctionUtils.getArgumentType(Boolean.TRUE), PinotDataType.BOOLEAN); + assertEquals(FunctionUtils.getArgumentType(new Timestamp(0L)), PinotDataType.TIMESTAMP); + assertEquals(FunctionUtils.getArgumentType("foo"), PinotDataType.STRING); + assertEquals(FunctionUtils.getArgumentType(new byte[]{0}), PinotDataType.BYTES); + assertEquals(FunctionUtils.getArgumentType(new 
HashMap<>()), PinotDataType.MAP); + assertEquals(FunctionUtils.getArgumentType(LocalDate.EPOCH), PinotDataType.DATE); + assertEquals(FunctionUtils.getArgumentType(LocalTime.NOON), PinotDataType.TIME); + assertEquals(FunctionUtils.getArgumentType(UUID.randomUUID()), PinotDataType.UUID); + assertEquals(FunctionUtils.getArgumentType((byte) 1), PinotDataType.BYTE); + assertEquals(FunctionUtils.getArgumentType('a'), PinotDataType.CHARACTER); + assertEquals(FunctionUtils.getArgumentType((short) 1), PinotDataType.SHORT); + } + + @Test + public void testGetArgumentTypeForVendorTimestampSubclass() { + // Vendor JDBC drivers commonly return Timestamp subclasses (e.g. BigQuery Simba's TimestampTz). + // Subclasses must resolve to TIMESTAMP via the instanceof dispatch. + class VendorTimestamp extends Timestamp { + VendorTimestamp(long time) { + super(time); + } + } + assertEquals(FunctionUtils.getArgumentType(new VendorTimestamp(0L)), PinotDataType.TIMESTAMP); + } + + @Test + public void testGetArgumentTypeForPrimitiveArrays() { + assertEquals(FunctionUtils.getArgumentType(new int[]{1}), PinotDataType.PRIMITIVE_INT_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new long[]{1L}), PinotDataType.PRIMITIVE_LONG_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new float[]{1.0f}), PinotDataType.PRIMITIVE_FLOAT_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new double[]{1.0d}), PinotDataType.PRIMITIVE_DOUBLE_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new boolean[]{true}), PinotDataType.PRIMITIVE_BOOLEAN_ARRAY); + } + + @Test + public void testGetArgumentTypeForReferenceArrays() { + // Reference arrays sample first non-null element via PinotDataType.getMultiValueType + assertEquals(FunctionUtils.getArgumentType(new Integer[]{1}), PinotDataType.INTEGER_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new Long[]{1L}), PinotDataType.LONG_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new Float[]{1.0f}), PinotDataType.FLOAT_ARRAY); + 
assertEquals(FunctionUtils.getArgumentType(new Double[]{1.0d}), PinotDataType.DOUBLE_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new BigDecimal[]{BigDecimal.ONE}), PinotDataType.BIG_DECIMAL_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new Boolean[]{true}), PinotDataType.BOOLEAN_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new Timestamp[]{new Timestamp(0L)}), PinotDataType.TIMESTAMP_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new String[]{"foo"}), PinotDataType.STRING_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new byte[][]{{0}}), PinotDataType.BYTES_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new LocalDate[]{LocalDate.EPOCH}), PinotDataType.DATE_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new LocalTime[]{LocalTime.NOON}), PinotDataType.TIME_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new UUID[]{UUID.randomUUID()}), PinotDataType.UUID_ARRAY); + } + + @Test + public void testGetArgumentTypeForEmptyOrAllNullReferenceArray() { + // Empty / all-null reference arrays fall back to OBJECT_ARRAY (element type undeterminable). + assertEquals(FunctionUtils.getArgumentType(new Object[0]), PinotDataType.OBJECT_ARRAY); + assertEquals(FunctionUtils.getArgumentType(new Object[]{null, null}), PinotDataType.OBJECT_ARRAY); + // Specific component type with no usable sample also falls back to OBJECT_ARRAY. + assertEquals(FunctionUtils.getArgumentType(new Integer[0]), PinotDataType.OBJECT_ARRAY); + } + + @Test + public void testGetArgumentTypeForCollection() { + assertEquals(FunctionUtils.getArgumentType(List.of(1, 2, 3)), PinotDataType.COLLECTION); + } + + @Test + public void testGetArgumentTypeForUnknown() { + // Unrecognized non-array types fall back to OBJECT (best-effort coercion sentinel). 
+ assertEquals(FunctionUtils.getArgumentType(new Object()), PinotDataType.OBJECT); + } + + @Test + public void testGetParameterType() { + // Scalars + assertEquals(FunctionUtils.getParameterType(int.class), PinotDataType.INTEGER); + assertEquals(FunctionUtils.getParameterType(Integer.class), PinotDataType.INTEGER); + assertEquals(FunctionUtils.getParameterType(boolean.class), PinotDataType.BOOLEAN); + assertEquals(FunctionUtils.getParameterType(Boolean.class), PinotDataType.BOOLEAN); + assertEquals(FunctionUtils.getParameterType(Timestamp.class), PinotDataType.TIMESTAMP); + assertEquals(FunctionUtils.getParameterType(String.class), PinotDataType.STRING); + assertEquals(FunctionUtils.getParameterType(byte[].class), PinotDataType.BYTES); + // Arrays + assertEquals(FunctionUtils.getParameterType(int[].class), PinotDataType.PRIMITIVE_INT_ARRAY); + assertEquals(FunctionUtils.getParameterType(boolean[].class), PinotDataType.PRIMITIVE_BOOLEAN_ARRAY); + assertEquals(FunctionUtils.getParameterType(Timestamp[].class), PinotDataType.TIMESTAMP_ARRAY); + assertEquals(FunctionUtils.getParameterType(String[].class), PinotDataType.STRING_ARRAY); + assertEquals(FunctionUtils.getParameterType(byte[][].class), PinotDataType.BYTES_ARRAY); + // Boxed array forms not allowed as scalar function parameters + assertNull(FunctionUtils.getParameterType(Integer[].class)); + assertNull(FunctionUtils.getParameterType(Boolean[].class)); + // Unknown class + assertNull(FunctionUtils.getParameterType(LocalDate.class)); + } + + @Test + public void testGetColumnDataType() { + // Scalars + assertEquals(FunctionUtils.getColumnDataType(int.class), ColumnDataType.INT); + assertEquals(FunctionUtils.getColumnDataType(Integer.class), ColumnDataType.INT); + assertEquals(FunctionUtils.getColumnDataType(boolean.class), ColumnDataType.BOOLEAN); + assertEquals(FunctionUtils.getColumnDataType(Timestamp.class), ColumnDataType.TIMESTAMP); + assertEquals(FunctionUtils.getColumnDataType(byte[].class), 
ColumnDataType.BYTES); + // Arrays + assertEquals(FunctionUtils.getColumnDataType(int[].class), ColumnDataType.INT_ARRAY); + assertEquals(FunctionUtils.getColumnDataType(boolean[].class), ColumnDataType.BOOLEAN_ARRAY); + assertEquals(FunctionUtils.getColumnDataType(Timestamp[].class), ColumnDataType.TIMESTAMP_ARRAY); + assertEquals(FunctionUtils.getColumnDataType(byte[][].class), ColumnDataType.BYTES_ARRAY); + // Object + assertEquals(FunctionUtils.getColumnDataType(Object.class), ColumnDataType.OBJECT); + // Unknown class + assertNull(FunctionUtils.getColumnDataType(LocalDate.class)); + } +} diff --git a/pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java b/pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java index 27d96a5c85ab..537538b7dcea 100644 --- a/pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java +++ b/pinot-common/src/test/java/org/apache/pinot/common/utils/PinotDataTypeTest.java @@ -131,7 +131,7 @@ public Object[][] numberFormatConversionErrors() { return new Object[][] { {INTEGER_ARRAY, LONG_ARRAY, new Object[]{"abc"}}, {INTEGER_ARRAY, INTEGER_ARRAY, new Object[]{"abc"}}, - {INTEGER_ARRAY, BOOLEAN_ARRAY, new Object[]{"abc"}} + {INTEGER_ARRAY, PRIMITIVE_BOOLEAN_ARRAY, new Object[]{"abc"}} }; } @@ -143,8 +143,15 @@ public void testNumberFormatConversionErrors(PinotDataType sourceType, PinotData @DataProvider public Object[][] conversions() { return new Object[][]{ - {STRING_ARRAY, BOOLEAN_ARRAY, new String[] {"true", "false"}, new boolean[] { true, false }}, - {BOOLEAN_ARRAY, BOOLEAN_ARRAY, new boolean[] { true, false }, new boolean[] { true, false }}, + {STRING_ARRAY, PRIMITIVE_BOOLEAN_ARRAY, new String[] {"true", "false"}, new boolean[] { true, false }}, + {PRIMITIVE_BOOLEAN_ARRAY, PRIMITIVE_BOOLEAN_ARRAY, new boolean[] { true, false }, + new boolean[] { true, false }}, + {STRING_ARRAY, BOOLEAN_ARRAY, new String[] {"true", "false"}, new Boolean[] { true, false }}, + 
{BOOLEAN_ARRAY, BOOLEAN_ARRAY, new Boolean[] { true, false }, new Boolean[] { true, false }}, + // Cross-form: PRIMITIVE_BOOLEAN_ARRAY -> BOOLEAN_ARRAY exercises the boolean[] path through + // toObjectArray (which now handles primitive boolean[]). + {PRIMITIVE_BOOLEAN_ARRAY, BOOLEAN_ARRAY, new boolean[] { true, false }, new Boolean[] { true, false }}, + {BOOLEAN_ARRAY, PRIMITIVE_BOOLEAN_ARRAY, new Boolean[] { true, false }, new boolean[] { true, false }}, {LONG_ARRAY, TIMESTAMP_ARRAY, new long[] {1000000L, 2000000L}, new Timestamp[] { new Timestamp(1000000L), new Timestamp(2000000L) }}, {TIMESTAMP_ARRAY, TIMESTAMP_ARRAY, new Timestamp[] { new Timestamp(1000000L), new Timestamp(2000000L) }, @@ -243,8 +250,8 @@ public void testDateArray() { assertEquals(DATE_ARRAY.convert(new Integer[]{19_096, 19_723}, INTEGER_ARRAY), dates); // DATE_ARRAY → STRING_ARRAY: each element via DATE.toString. assertEquals(STRING_ARRAY.convert(dates, DATE_ARRAY), new String[]{"2022-04-14", "2024-01-01"}); - // Class lookup: Object[] of LocalDate routes to DATE_ARRAY. - assertEquals(getMultiValueType(LocalDate.class), DATE_ARRAY); + // Lookup: Object[] of LocalDate routes to DATE_ARRAY. + assertEquals(getMultiValueType(dates[0]), DATE_ARRAY); // toInternal: Integer[] of epoch-days. assertEquals(DATE_ARRAY.toInternal(dates), new Integer[]{19_096, 19_723}); } @@ -273,7 +280,7 @@ public void testTimeArray() { // LONG_ARRAY → TIME_ARRAY: per-element millis-since-midnight decoding. assertEquals(TIME_ARRAY.convert(new Long[]{31_892_000L, 43_200_000L}, LONG_ARRAY), times); assertEquals(STRING_ARRAY.convert(times, TIME_ARRAY), new String[]{"08:51:32", "12:00"}); - assertEquals(getMultiValueType(LocalTime.class), TIME_ARRAY); + assertEquals(getMultiValueType(times[0]), TIME_ARRAY); // toInternal: Long[] of millis-since-midnight. 
assertEquals(TIME_ARRAY.toInternal(times), new Long[]{31_892_000L, 43_200_000L}); } @@ -315,8 +322,8 @@ public void testUuidArray() { assertEquals(bytesArray[0].length, 16); assertEquals(UUID.convert(bytesArray[0], BYTES), u1); assertEquals(UUID.convert(bytesArray[1], BYTES), u2); - // Class lookup: Object[] of UUID routes to UUID_ARRAY. - assertEquals(getMultiValueType(UUID.class), UUID_ARRAY); + // Lookup: Object[] of UUID routes to UUID_ARRAY. + assertEquals(getMultiValueType(u1), UUID_ARRAY); // toInternal: String[] of canonical form. assertEquals(UUID_ARRAY.toInternal(uuids), new String[]{"550e8400-e29b-41d4-a716-446655440000", "00000000-0000-0000-0000-000000000001"}); @@ -383,54 +390,61 @@ public void testObject() { @Test public void testGetSingleValueType() { - Map, PinotDataType> testCases = new HashMap<>(); - testCases.put(Boolean.class, BOOLEAN); - testCases.put(Byte.class, BYTE); - testCases.put(Character.class, CHARACTER); - testCases.put(Short.class, SHORT); - testCases.put(Integer.class, INTEGER); - testCases.put(Long.class, LONG); - testCases.put(Float.class, FLOAT); - testCases.put(Double.class, DOUBLE); - testCases.put(BigDecimal.class, BIG_DECIMAL); - testCases.put(Timestamp.class, TIMESTAMP); - testCases.put(LocalDate.class, DATE); - testCases.put(LocalTime.class, TIME); - testCases.put(String.class, STRING); - testCases.put(byte[].class, BYTES); - testCases.put(UUID.class, UUID); - - for (Map.Entry, PinotDataType> tc : testCases.entrySet()) { - assertEquals(getSingleValueType(tc.getKey()), tc.getValue()); + assertEquals(getSingleValueType(1), INTEGER); + assertEquals(getSingleValueType(1L), LONG); + assertEquals(getSingleValueType(1.0f), FLOAT); + assertEquals(getSingleValueType(1.0d), DOUBLE); + assertEquals(getSingleValueType(BigDecimal.ONE), BIG_DECIMAL); + assertEquals(getSingleValueType(Boolean.TRUE), BOOLEAN); + assertEquals(getSingleValueType(new Timestamp(0L)), TIMESTAMP); + assertEquals(getSingleValueType("foo"), STRING); + 
assertEquals(getSingleValueType(new byte[]{0}), BYTES); + assertEquals(getSingleValueType(new HashMap<>()), MAP); + assertEquals(getSingleValueType(LocalDate.EPOCH), DATE); + assertEquals(getSingleValueType(LocalTime.NOON), TIME); + assertEquals(getSingleValueType(java.util.UUID.randomUUID()), UUID); + assertEquals(getSingleValueType((byte) 1), BYTE); + assertEquals(getSingleValueType('a'), CHARACTER); + assertEquals(getSingleValueType((short) 1), SHORT); + assertEquals(getSingleValueType(new Object()), OBJECT); + + // Vendor JDBC drivers commonly return Timestamp subclasses (e.g. BigQuery Simba's TimestampTz). + // Subclasses must resolve to TIMESTAMP, not OBJECT. + class VendorTimestamp extends Timestamp { + VendorTimestamp(long time) { + super(time); + } } - assertEquals(getSingleValueType(Object.class), OBJECT); - assertEquals(getSingleValueType(Map.class), MAP); - assertEquals(getSingleValueType(null), OBJECT); + assertEquals(getSingleValueType(new VendorTimestamp(0L)), TIMESTAMP); } @Test public void testGetMultipleValueType() { - Map, PinotDataType> testCases = new HashMap<>(); - testCases.put(Byte.class, BYTE_ARRAY); - testCases.put(Character.class, CHARACTER_ARRAY); - testCases.put(Short.class, SHORT_ARRAY); - testCases.put(Integer.class, INTEGER_ARRAY); - testCases.put(Long.class, LONG_ARRAY); - testCases.put(Float.class, FLOAT_ARRAY); - testCases.put(Double.class, DOUBLE_ARRAY); - testCases.put(String.class, STRING_ARRAY); - testCases.put(Boolean.class, BOOLEAN_ARRAY); - testCases.put(Timestamp.class, TIMESTAMP_ARRAY); - testCases.put(LocalDate.class, DATE_ARRAY); - testCases.put(LocalTime.class, TIME_ARRAY); - testCases.put(byte[].class, BYTES_ARRAY); - testCases.put(UUID.class, UUID_ARRAY); - - for (Map.Entry, PinotDataType> tc : testCases.entrySet()) { - assertEquals(getMultiValueType(tc.getKey()), tc.getValue()); + assertEquals(getMultiValueType(1), INTEGER_ARRAY); + assertEquals(getMultiValueType(1L), LONG_ARRAY); + 
assertEquals(getMultiValueType(1.0f), FLOAT_ARRAY); + assertEquals(getMultiValueType(1.0d), DOUBLE_ARRAY); + assertEquals(getMultiValueType(BigDecimal.ONE), BIG_DECIMAL_ARRAY); + assertEquals(getMultiValueType(Boolean.TRUE), BOOLEAN_ARRAY); + assertEquals(getMultiValueType(new Timestamp(0L)), TIMESTAMP_ARRAY); + assertEquals(getMultiValueType("foo"), STRING_ARRAY); + assertEquals(getMultiValueType(new byte[]{0}), BYTES_ARRAY); + assertEquals(getMultiValueType(LocalDate.EPOCH), DATE_ARRAY); + assertEquals(getMultiValueType(LocalTime.NOON), TIME_ARRAY); + assertEquals(getMultiValueType(java.util.UUID.randomUUID()), UUID_ARRAY); + assertEquals(getMultiValueType((byte) 1), BYTE_ARRAY); + assertEquals(getMultiValueType('a'), CHARACTER_ARRAY); + assertEquals(getMultiValueType((short) 1), SHORT_ARRAY); + assertEquals(getMultiValueType(new Object()), OBJECT_ARRAY); + + // Vendor JDBC drivers commonly return Timestamp subclasses (e.g. BigQuery Simba's TimestampTz). + // Subclasses must resolve to TIMESTAMP_ARRAY, not OBJECT_ARRAY. 
+ class VendorTimestamp extends Timestamp { + VendorTimestamp(long time) { + super(time); + } } - assertEquals(getMultiValueType(Object.class), OBJECT_ARRAY); - assertEquals(getMultiValueType(null), OBJECT_ARRAY); + assertEquals(getMultiValueType(new VendorTimestamp(0L)), TIMESTAMP_ARRAY); } private static Object getGenericTestObject() { diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java index 3d1abda8bffc..86a5e2d8c242 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapper.java @@ -439,6 +439,38 @@ private void getNonLiteralValues(ValueBlock valueBlock) { case BIG_DECIMAL_ARRAY: _nonLiteralValues[i] = transformFunction.transformToBigDecimalValuesMV(valueBlock); break; + case PRIMITIVE_BOOLEAN_ARRAY: { + int[][] intValuesMV = transformFunction.transformToIntValuesMV(valueBlock); + int numRows = intValuesMV.length; + boolean[][] booleanValuesMV = new boolean[numRows][]; + for (int j = 0; j < numRows; j++) { + int[] intValues = intValuesMV[j]; + int numValues = intValues.length; + boolean[] booleanValues = new boolean[numValues]; + for (int k = 0; k < numValues; k++) { + booleanValues[k] = intValues[k] == 1; + } + booleanValuesMV[j] = booleanValues; + } + _nonLiteralValues[i] = booleanValuesMV; + break; + } + case TIMESTAMP_ARRAY: { + long[][] longValuesMV = transformFunction.transformToLongValuesMV(valueBlock); + int numRows = longValuesMV.length; + Timestamp[][] timestampValuesMV = new Timestamp[numRows][]; + for (int j = 0; j < numRows; j++) { + long[] longValues = longValuesMV[j]; + int numValues = longValues.length; + Timestamp[] timestampValues = new Timestamp[numValues]; + for (int k = 0; k < 
numValues; k++) { + timestampValues[k] = new Timestamp(longValues[k]); + } + timestampValuesMV[j] = timestampValues; + } + _nonLiteralValues[i] = timestampValuesMV; + break; + } case STRING_ARRAY: _nonLiteralValues[i] = transformFunction.transformToStringValuesMV(valueBlock); break; diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java index e78d202b2395..8d8d951f3d97 100644 --- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java +++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/ScalarTransformFunctionWrapperTest.java @@ -21,6 +21,7 @@ import it.unimi.dsi.fastutil.ints.IntLinkedOpenHashSet; import it.unimi.dsi.fastutil.objects.ObjectLinkedOpenHashSet; import java.nio.charset.StandardCharsets; +import java.sql.Timestamp; import java.text.Normalizer; import java.util.Arrays; import java.util.Base64; @@ -34,6 +35,7 @@ import org.apache.pinot.common.function.scalar.StringFunctions; import org.apache.pinot.common.request.context.ExpressionContext; import org.apache.pinot.common.request.context.RequestContextUtils; +import org.apache.pinot.spi.annotations.ScalarFunction; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.ArrayCopyUtils; import org.apache.pinot.spi.utils.BigDecimalUtils; @@ -1570,4 +1572,79 @@ public void testIsValidASCIITransformFunction() { } testTransformFunction(transformFunction, expectedValues); } + + // Test-only scalar functions registered via FunctionRegistry's reflection scan (the test class is + // in a `.function.` package, matching the scan regex). 
They exist solely to exercise the + // `PRIMITIVE_BOOLEAN_ARRAY` and `TIMESTAMP_ARRAY` dispatch cases in + // ScalarTransformFunctionWrapper.getNonLiteralValues — no production scalar function in OSS Pinot + // currently declares a `boolean[]` or `Timestamp[]` parameter. + + @ScalarFunction + public static int countTrueBooleans(boolean[] booleans) { + int count = 0; + for (boolean b : booleans) { + if (b) { + count++; + } + } + return count; + } + + @ScalarFunction + public static long sumTimestampMillis(Timestamp[] timestamps) { + long sum = 0; + for (Timestamp t : timestamps) { + sum += t.getTime(); + } + return sum; + } + + @Test + public void testCountTrueBooleansTransformFunction() { + // Exercises the PRIMITIVE_BOOLEAN_ARRAY dispatch in getNonLiteralValues: the int MV column is + // read as int[][] via transformToIntValuesMV, then converted per-row to boolean[][] (intValue + // == 1 → true) before being passed to countTrueBooleans(boolean[]). + ExpressionContext expression = + RequestContextUtils.getExpression(String.format("countTrueBooleans(%s)", INT_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "countTrueBooleans"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.INT); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + int[] expectedValues = new int[NUM_ROWS]; + for (int i = 0; i < NUM_ROWS; i++) { + int count = 0; + for (int v : _intMVValues[i]) { + if (v == 1) { + count++; + } + } + expectedValues[i] = count; + } + testTransformFunction(transformFunction, expectedValues); + } + + @Test + public void testSumTimestampMillisTransformFunction() { + // Exercises the TIMESTAMP_ARRAY dispatch in getNonLiteralValues: the long MV column is read as + // long[][] via transformToLongValuesMV, then converted per-row to Timestamp[][] (new 
+ // Timestamp(long)) before being passed to sumTimestampMillis(Timestamp[]). Each Timestamp's + // getTime() returns the original long, so the per-row sum equals the sum of input longs. + ExpressionContext expression = + RequestContextUtils.getExpression(String.format("sumTimestampMillis(%s)", LONG_MV_COLUMN)); + TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap); + assertTrue(transformFunction instanceof ScalarTransformFunctionWrapper); + assertEquals(transformFunction.getName(), "sumTimestampMillis"); + assertEquals(transformFunction.getResultMetadata().getDataType(), DataType.LONG); + assertTrue(transformFunction.getResultMetadata().isSingleValue()); + long[] expectedValues = new long[NUM_ROWS]; + for (int i = 0; i < NUM_ROWS; i++) { + long sum = 0; + for (long v : _longMVValues[i]) { + sum += v; + } + expectedValues[i] = sum; + } + testTransformFunction(transformFunction, expectedValues); + } } diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/creator/impl/stats/MapColumnPreIndexStatsCollector.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/creator/impl/stats/MapColumnPreIndexStatsCollector.java index 52adcd03f093..e2d720ad1fee 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/creator/impl/stats/MapColumnPreIndexStatsCollector.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/creator/impl/stats/MapColumnPreIndexStatsCollector.java @@ -283,7 +283,7 @@ public void seal() { */ private AbstractColumnStatisticsCollector createKeyStatsCollector(String key, Object value) { // Get the type of the value - PinotDataType type = PinotDataType.getSingleValueType(value.getClass()); + PinotDataType type = PinotDataType.getSingleValueType(value); return createKeyStatsCollector(key, convertToDataType(type)); } diff --git 
a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/loader/defaultcolumn/BaseDefaultColumnHandler.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/loader/defaultcolumn/BaseDefaultColumnHandler.java index 891cc7095744..3b7eaeba82f7 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/loader/defaultcolumn/BaseDefaultColumnHandler.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/segment/index/loader/defaultcolumn/BaseDefaultColumnHandler.java @@ -601,9 +601,7 @@ private void createDerivedColumnV1Indices(String column, FunctionEvaluator funct nullValueVectorCreator.setNull(i); } } else if (outputValueType == null) { - Class outputValueClass = outputValue.getClass(); - outputValueType = FunctionUtils.getArgumentType(outputValueClass); - Preconditions.checkState(outputValueType != null, "Unsupported output value class: %s", outputValueClass); + outputValueType = FunctionUtils.getArgumentType(outputValue); } outputValues[i] = outputValue; diff --git a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/DataTypeTransformerUtils.java b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/DataTypeTransformerUtils.java index 017f2d274e78..37abe6677e67 100644 --- a/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/DataTypeTransformerUtils.java +++ b/pinot-segment-local/src/main/java/org/apache/pinot/segment/local/utils/DataTypeTransformerUtils.java @@ -75,11 +75,11 @@ public static Object transformValue(String column, @Nullable Object value, Pinot if (destDataType == PinotDataType.JSON && values.length == 0) { sourceDataType = PinotDataType.JSON; } else { - sourceDataType = PinotDataType.getMultiValueType(values[0].getClass()); + sourceDataType = PinotDataType.getMultiValueType(values[0]); } } else { // Single-value column - sourceDataType = 
PinotDataType.getSingleValueType(value.getClass()); + sourceDataType = PinotDataType.getSingleValueType(value); } // Convert from source to destination type