Skip to content

Commit

Permalink
Implement ARRAYLENGTH UDF for multi-valued columns (#5301)
Browse files Browse the repository at this point in the history
* Implement LENGTH UDF for multi-valued columns

* Add integration test for LENGTH UDF

* Fix a typo

* Rename Length UDF to ArrayLength
  • Loading branch information
bozhang2820 authored Apr 27, 2020
1 parent 90c9060 commit 1fe22b5
Show file tree
Hide file tree
Showing 7 changed files with 202 additions and 0 deletions.
3 changes: 3 additions & 0 deletions docs/pql_examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ Supported transform functions
expressed as hours since UTC epoch (note that the output is not Los Angeles
timezone)

``ARRAYLENGTH``
Takes a multi-valued column and returns the length of the column

``VALUEIN``
Takes at least 2 arguments, where the first argument is a multi-valued column, and the following arguments are constant values.
The transform function will filter the value from the multi-valued column with the given constant values.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ public enum TransformFunctionType {
TIMECONVERT("timeConvert"),
DATETIMECONVERT("dateTimeConvert"),
DATETRUNC("dateTrunc"),
ARRAYLENGTH("arrayLength"),
VALUEIN("valueIn"),
MAPVALUE("mapValue");

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.operator.transform.function;

import java.util.List;
import java.util.Map;
import org.apache.pinot.core.common.DataSource;
import org.apache.pinot.core.operator.blocks.ProjectionBlock;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.core.plan.DocIdSetPlanNode;


/**
* The ArrayLengthTransformFunction class implements arrayLength function for multi-valued columns
*
* Sample queries:
* SELECT COUNT(*) FROM table WHERE arrayLength(mvColumn) > 2
* SELECT COUNT(*) FROM table GROUP BY arrayLength(mvColumn)
* SELECT MAX(arrayLength(mvColumn)) FROM table
*/
public class ArrayLengthTransformFunction extends BaseTransformFunction {
public static final String FUNCTION_NAME = "arrayLength";

private int[] _results;
private TransformFunction _argument;

@Override
public String getName() {
return FUNCTION_NAME;
}

@Override
public void init(List<TransformFunction> arguments, Map<String, DataSource> dataSourceMap) {
// Check that there is only 1 argument
if (arguments.size() != 1) {
throw new IllegalArgumentException("Exactly 1 argument is required for ARRAYLENGTH transform function");
}

// Check that the argument is a multi-valued column or transform function
TransformFunction firstArgument = arguments.get(0);
if (firstArgument instanceof LiteralTransformFunction || firstArgument.getResultMetadata().isSingleValue()) {
throw new IllegalArgumentException(
"The argument of ARRAYLENGTH transform function must be a multi-valued column or a transform function");
}
_argument = firstArgument;
}

@Override
public TransformResultMetadata getResultMetadata() {
return INT_SV_NO_DICTIONARY_METADATA;
}

@Override
public int[] transformToIntValuesSV(ProjectionBlock projectionBlock) {
if (_results == null) {
_results = new int[DocIdSetPlanNode.MAX_DOC_PER_CALL];
}

int numDocs = projectionBlock.getNumDocs();
switch (_argument.getResultMetadata().getDataType()) {
case INT:
int[][] intValuesMV = _argument.transformToIntValuesMV(projectionBlock);
for (int i = 0; i < numDocs; i++) {
_results[i] = intValuesMV[i].length;
}
break;
case LONG:
long[][] longValuesMV = _argument.transformToLongValuesMV(projectionBlock);
for (int i = 0; i < numDocs; i++) {
_results[i] = longValuesMV[i].length;
}
break;
case FLOAT:
float[][] floatValuesMV = _argument.transformToFloatValuesMV(projectionBlock);
for (int i = 0; i < numDocs; i++) {
_results[i] = floatValuesMV[i].length;
}
break;
case DOUBLE:
double[][] doubleValuesMV = _argument.transformToDoubleValuesMV(projectionBlock);
for (int i = 0; i < numDocs; i++) {
_results[i] = doubleValuesMV[i].length;
}
break;
case STRING:
String[][] stringValuesMV = _argument.transformToStringValuesMV(projectionBlock);
for (int i = 0; i < numDocs; i++) {
_results[i] = stringValuesMV[i].length;
}
break;
default:
throw new IllegalStateException();
}
return _results;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
* Base class for transform function providing the default implementation for all data types.
*/
public abstract class BaseTransformFunction implements TransformFunction {
protected static final TransformResultMetadata INT_SV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.INT, true, false);
protected static final TransformResultMetadata LONG_SV_NO_DICTIONARY_METADATA =
new TransformResultMetadata(DataType.LONG, true, false);
protected static final TransformResultMetadata DOUBLE_SV_NO_DICTIONARY_METADATA =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ private TransformFunctionFactory() {
put(TransformFunctionType.TIMECONVERT.getName().toLowerCase(), TimeConversionTransformFunction.class);
put(TransformFunctionType.DATETIMECONVERT.getName().toLowerCase(), DateTimeConversionTransformFunction.class);
put(TransformFunctionType.DATETRUNC.getName().toLowerCase(), DateTruncTransformFunction.class);
put(TransformFunctionType.ARRAYLENGTH.getName().toLowerCase(), ArrayLengthTransformFunction.class);
put(TransformFunctionType.VALUEIN.getName().toLowerCase(), ValueInTransformFunction.class);
put(TransformFunctionType.MAPVALUE.getName().toLowerCase(), MapValueTransformFunction.class);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.pinot.core.operator.transform.function;

import org.apache.pinot.common.request.transform.TransformExpressionTree;
import org.apache.pinot.core.query.exception.BadQueryRequestException;
import org.apache.pinot.spi.data.FieldSpec;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;


public class ArrayLengthTransformFunctionTest extends BaseTransformFunctionTest {

@Test
public void testLengthTransformFunction() {
TransformExpressionTree expression =
TransformExpressionTree.compileToExpressionTree(String.format("arrayLength(%s)", INT_MV_COLUMN));
TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
Assert.assertTrue(transformFunction instanceof ArrayLengthTransformFunction);
Assert.assertEquals(transformFunction.getName(), ArrayLengthTransformFunction.FUNCTION_NAME);
Assert.assertEquals(transformFunction.getResultMetadata().getDataType(), FieldSpec.DataType.INT);
Assert.assertTrue(transformFunction.getResultMetadata().isSingleValue());
Assert.assertFalse(transformFunction.getResultMetadata().hasDictionary());

int[] results = transformFunction.transformToIntValuesSV(_projectionBlock);
for (int i = 0; i < NUM_ROWS; i++) {
Assert.assertEquals(results[i], _intMVValues[i].length);
}
}

@Test(dataProvider = "testIllegalArguments", expectedExceptions = {BadQueryRequestException.class})
public void testIllegalArguments(String expressionStr) {
TransformExpressionTree expression = TransformExpressionTree.compileToExpressionTree(expressionStr);
TransformFunctionFactory.get(expression, _dataSourceMap);
}

@DataProvider(name = "testIllegalArguments")
public Object[][] testIllegalArguments() {
return new Object[][]{
new Object[]{String.format("arrayLength(%s,1)", INT_MV_COLUMN)},
new Object[]{"arrayLength(2)"},
new Object[]{String.format("arrayLength(%s)", LONG_SV_COLUMN)}};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,28 @@ public void testGroupByUDF()
assertEquals(groupByEntry.get("group").get(0).asDouble(), 16138.0 / 2);
assertEquals(groupByResult.get("groupByColumns").get(0).asText(), "div(DaysSinceEpoch,'2')");

pqlQuery = "SELECT COUNT(*) FROM mytable GROUP BY arrayLength(DivAirports)";
response = postQuery(pqlQuery);
groupByResult = response.get("aggregationResults").get(0);
groupByEntry = groupByResult.get("groupByResult").get(0);
assertEquals(groupByEntry.get("value").asDouble(), 115545.0);
assertEquals(groupByEntry.get("group").get(0).asText(), "5");
assertEquals(groupByResult.get("groupByColumns").get(0).asText(), "arraylength(DivAirports)");

pqlQuery = "SELECT COUNT(*) FROM mytable GROUP BY arrayLength(valueIn(DivAirports,'DFW','ORD'))";
response = postQuery(pqlQuery);
groupByResult = response.get("aggregationResults").get(0);
groupByEntry = groupByResult.get("groupByResult").get(0);
assertEquals(groupByEntry.get("value").asDouble(), 114895.0);
assertEquals(groupByEntry.get("group").get(0).asText(), "0");
groupByEntry = groupByResult.get("groupByResult").get(1);
assertEquals(groupByEntry.get("value").asDouble(), 648.0);
assertEquals(groupByEntry.get("group").get(0).asText(), "1");
groupByEntry = groupByResult.get("groupByResult").get(2);
assertEquals(groupByEntry.get("value").asDouble(), 2.0);
assertEquals(groupByEntry.get("group").get(0).asText(), "2");
assertEquals(groupByResult.get("groupByColumns").get(0).asText(), "arraylength(valuein(DivAirports,'DFW','ORD'))");

pqlQuery = "SELECT COUNT(*) FROM mytable GROUP BY valueIn(DivAirports,'DFW','ORD')";
response = postQuery(pqlQuery);
groupByResult = response.get("aggregationResults").get(0);
Expand Down

0 comments on commit 1fe22b5

Please sign in to comment.