Skip to content
Permalink
Browse files
Adding vectorized time_shift (#12254)
* Adding vectorized time_shift

* Vectorize time shift, addressing review comments

* Remove an unused import
  • Loading branch information
somu-imply committed Feb 11, 2022
1 parent c61b19d commit 033989eb1d8f4f91268b2d7d4d3dc73af7bf2c3f
Showing 5 changed files with 159 additions and 5 deletions.
@@ -188,7 +188,19 @@ public String getFormatString()
// 30: logical and operator
"SELECT CAST(long1 as BOOLEAN) AND CAST (long2 as BOOLEAN), COUNT(*) FROM foo GROUP BY 1 ORDER BY 2",
// 31: isnull, notnull
"SELECT long5 IS NULL, long3 IS NOT NULL, count(*) FROM foo GROUP BY 1,2 ORDER BY 3"
"SELECT long5 IS NULL, long3 IS NOT NULL, count(*) FROM foo GROUP BY 1,2 ORDER BY 3",
// 32: time shift, non-expr col + reg agg, regular
"SELECT TIME_SHIFT(__time, 'PT1H', 3), string2, SUM(double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 33: time shift, non-expr col + expr agg, sequential low cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long1), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 34: time shift + non-expr agg (timeseries) (non-expression reference), zipf distribution low cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long2), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 35: time shift + expr agg (timeseries), zipf distribution high cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long3), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 36: time shift + non-expr agg (group by), uniform distribution low cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long4), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
// 37: time shift + expr agg (group by), uniform distribution high cardinality
"SELECT TIME_SHIFT(MILLIS_TO_TIMESTAMP(long5), 'PT1H', 1), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3"
);

@Param({"5000000"})
@@ -234,7 +246,13 @@ public String getFormatString()
"28",
"29",
"30",
"31"
"31",
"32",
"33",
"34",
"35",
"36",
"37"
})
private String query;

@@ -294,7 +294,7 @@ static void testExpression(String expr, Map<String, ExpressionType> types)
testExpressionWithBindings(expr, parsed, bindings);
}

private static void testExpressionWithBindings(
public static void testExpressionWithBindings(
String expr,
Expr parsed,
NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings
@@ -321,7 +321,7 @@ private static void testExpressionWithBindings(
}
}

static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRandomizedBindings(
public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRandomizedBindings(
int vectorSize,
Map<String, ExpressionType> types
)
@@ -338,7 +338,7 @@ static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeRandomized
);
}

static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeSequentialBinding(
public static NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> makeSequentialBinding(
int vectorSize,
Map<String, ExpressionType> types
)
@@ -26,6 +26,9 @@
import org.apache.druid.math.expr.ExprMacroTable;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.InputBindings;
import org.apache.druid.math.expr.vector.CastToTypeVectorProcessor;
import org.apache.druid.math.expr.vector.ExprVectorProcessor;
import org.apache.druid.math.expr.vector.LongOutLongInFunctionVectorValueProcessor;
import org.joda.time.Chronology;
import org.joda.time.Period;
import org.joda.time.chrono.ISOChronology;
@@ -112,6 +115,31 @@ public Expr visit(Shuttle shuttle)
return shuttle.visit(new TimestampShiftExpr(shuttle.visitAll(args)));
}

@Override
public boolean canVectorize(InputBindingInspector inspector)
{
return args.get(0).canVectorize(inspector);
}

@Override
public <T> ExprVectorProcessor<T> buildVectorized(VectorInputBindingInspector inspector)
{
ExprVectorProcessor<?> processor;
processor = new LongOutLongInFunctionVectorValueProcessor(
CastToTypeVectorProcessor.cast(args.get(0).buildVectorized(inspector), ExpressionType.LONG),
inspector.getMaxVectorSize()
)
{
@Override
public long apply(long input)
{
return chronology.add(period, input, step);
}
};

return (ExprVectorProcessor<T>) processor;
}

@Nullable
@Override
public ExpressionType getOutputType(InputBindingInspector inspector)
@@ -153,5 +181,6 @@ public ExpressionType getOutputType(InputBindingInspector inspector)
{
return ExpressionType.LONG;
}

}
}
@@ -0,0 +1,102 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.druid.query.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.java.util.common.DateTimes;
import org.apache.druid.java.util.common.NonnullPair;
import org.apache.druid.java.util.common.logger.Logger;
import org.apache.druid.math.expr.Expr;
import org.apache.druid.math.expr.ExprEval;
import org.apache.druid.math.expr.ExpressionProcessing;
import org.apache.druid.math.expr.ExpressionType;
import org.apache.druid.math.expr.VectorExprSanityTest;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.joda.time.DateTime;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.util.Map;

/**
* randomize inputs to various vector expressions and make sure the results match nonvectorized expressions
* <p>
* this is not a replacement for correctness tests, but will ensure that vectorized and non-vectorized expression
* evaluation is at least self consistent...
*/
public class VectorExpressionsSanityTest extends InitializedNullHandlingTest
{
private static final Logger log = new Logger(VectorExpressionsSanityTest.class);
private static final int NUM_ITERATIONS = 10;
private static final int VECTOR_SIZE = 512;
private static final TimestampShiftExprMacro TIMESTAMP_SHIFT_EXPR_MACRO = new TimestampShiftExprMacro();
private static final DateTime DATE_TIME = DateTimes.of("2020-11-05T04:05:06");

final Map<String, ExpressionType> types = ImmutableMap.<String, ExpressionType>builder()
.put("l1", ExpressionType.LONG)
.put("l2", ExpressionType.LONG)
.put("d1", ExpressionType.DOUBLE)
.put("d2", ExpressionType.DOUBLE)
.put("s1", ExpressionType.STRING)
.put("s2", ExpressionType.STRING)
.put("boolString1", ExpressionType.STRING)
.put("boolString2", ExpressionType.STRING)
.build();

@BeforeClass
public static void setupTests()
{
ExpressionProcessing.initializeForStrictBooleansTests(true);
}

@AfterClass
public static void teardownTests()
{
ExpressionProcessing.initializeForTests(null);
}

static void testExpression(String expr, Expr parsed, Map<String, ExpressionType> types)
{
log.debug("[%s]", expr);
NonnullPair<Expr.ObjectBinding[], Expr.VectorInputBinding> bindings;
for (int iterations = 0; iterations < NUM_ITERATIONS; iterations++) {
bindings = VectorExprSanityTest.makeRandomizedBindings(VECTOR_SIZE, types);
VectorExprSanityTest.testExpressionWithBindings(expr, parsed, bindings);
}
bindings = VectorExprSanityTest.makeSequentialBinding(VECTOR_SIZE, types);
VectorExprSanityTest.testExpressionWithBindings(expr, parsed, bindings);
}

@Test
public void testTimeShiftFn()
{
int step = 1;
Expr parsed = TIMESTAMP_SHIFT_EXPR_MACRO.apply(
ImmutableList.of(
ExprEval.of(DATE_TIME.getMillis()).toExpr(),
ExprEval.of("P1M").toExpr(),
ExprEval.of(step).toExpr()
));
testExpression("time_shift(l1, 'P1M', 1)", parsed, types);
}
}

@@ -83,6 +83,11 @@ public class SqlVectorizedExpressionSanityTest extends InitializedNullHandlingTe
"SELECT TIME_FLOOR(__time, 'PT1H'), SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 1",
"SELECT TIME_FLOOR(__time, 'PT1H'), SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 2",
"SELECT TIME_FLOOR(TIMESTAMPADD(DAY, -1, __time), 'PT1H'), SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 1",
"SELECT TIME_SHIFT(__time, 'PT1H', 3), string2, SUM(long1 * double4) FROM foo GROUP BY 1,2 ORDER BY 3",
"SELECT TIME_SHIFT(__time, 'PT1H', 4), string2, SUM(long1 * double4) FROM foo WHERE string2 = '10' GROUP BY 1,2 ORDER BY 3",
"SELECT TIME_SHIFT(__time, 'PT1H', 3), SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 1",
"SELECT TIME_SHIFT(__time, 'PT1H', 4), SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 2",
"SELECT TIME_SHIFT(TIMESTAMPADD(DAY, -1, __time), 'PT1H', 3), SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 1",
"SELECT (long1 * long2), SUM(double1) FROM foo GROUP BY 1 ORDER BY 2",
"SELECT string2, SUM(long1 * long4) FROM foo GROUP BY 1 ORDER BY 2",
"SELECT string1 + string2, COUNT(*) FROM foo GROUP BY 1 ORDER BY 2",

0 comments on commit 033989e

Please sign in to comment.