From 66a813ee7c34c60aa3a6b6289a831a4f869d981e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 2 Jul 2015 00:51:29 -0700 Subject: [PATCH] Prefix comparators for float and double --- .../unsafe/sort/PrefixComparators.java | 28 +++++++++ .../codegen/GenerateExpression.scala | 59 +++++++++++++++++++ .../expressions/CodeGenerationSuite.scala | 4 ++ .../spark/sql/execution/SortPrefixUtils.scala | 8 ++- 4 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java index e71a31eb94487..c10ab26c1bd12 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java @@ -25,6 +25,8 @@ private PrefixComparators() {} public static final IntPrefixComparator INTEGER = new IntPrefixComparator(); public static final LongPrefixComparator LONG = new LongPrefixComparator(); + public static final FloatPrefixComparator FLOAT = new FloatPrefixComparator(); + public static final DoublePrefixComparator DOUBLE = new DoublePrefixComparator(); public static final class IntPrefixComparator extends PrefixComparator { @Override @@ -45,4 +47,30 @@ public int compare(long a, long b) { return (a < b) ? -1 : (a > b) ? 1 : 0; } } + + public static final class FloatPrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + float a = Float.intBitsToFloat((int) aPrefix); + float b = Float.intBitsToFloat((int) bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(float value) { + return Float.floatToIntBits(value) & 0xffffffffL; + } + } + + public static final class DoublePrefixComparator extends PrefixComparator { + @Override + public int compare(long aPrefix, long bPrefix) { + double a = Double.longBitsToDouble(aPrefix); + double b = Double.longBitsToDouble(bPrefix); + return (a < b) ? -1 : (a > b) ? 1 : 0; + } + + public long computePrefix(double value) { + return Double.doubleToLongBits(value); + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala new file mode 100644 index 0000000000000..cb1480eb4aaf2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateExpression.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import scala.runtime.AbstractFunction1
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression}
+
+object GenerateExpression extends CodeGenerator[Expression, InternalRow => Any] {
+
+  override protected def canonicalize(in: Expression): Expression = {
+    ExpressionCanonicalizer.execute(in)
+  }
+
+  override protected def bind(in: Expression, inputSchema: Seq[Attribute]): Expression = {
+    BindReferences.bindReference(in, inputSchema)
+  }
+
+  override protected def create(expr: Expression): InternalRow => Any = {
+    val ctx = newCodeGenContext()
+    val eval = expr.gen(ctx)
+    // The compiled class body exposes generate(), which binds the referenced expressions.
+    val code =
+      s"""
+        |public Object generate($exprType[] expr) {
+        |  return new SpecificExpression(expr);
+        |}
+        |
+        |class SpecificExpression extends
+        |    ${classOf[AbstractFunction1[InternalRow, Any]].getName}<${classOf[InternalRow].getName}, Object> {
+        |  private $exprType[] expressions;
+        |  SpecificExpression($exprType[] expr) { expressions = expr; }
+        |  @Override
+        |  public Object apply(InternalRow i) {
+        |    ${eval.code}
+        |    return ${eval.isNull} ? null : ${eval.primitive};
+        |  }
+        |}
+      """.stripMargin
+    logDebug(s"Generated expression '$expr':\n$code")
+    compile(code).generate(ctx.references.toArray).asInstanceOf[InternalRow => Any]
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 481b335d15dfd..f8bc5c560e154 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -26,6 +26,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
  */
 class CodeGenerationSuite extends SparkFunSuite {
 
+  test("generate expression") {
+    GenerateExpression.generate(Add(Literal(1), Literal(1)))
+  }
+
   test("multithreaded eval") {
     import scala.concurrent._
     import ExecutionContext.Implicits.global
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
index 3fc1d1986fb05..53718690f6431 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.SortOrder
-import org.apache.spark.sql.types.{LongType, IntegerType}
+import org.apache.spark.sql.types.{DoubleType, FloatType, LongType, IntegerType}
 import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator}
 
 
@@ -38,6 +38,8 @@ object SortPrefixUtils {
     sortOrder.dataType match {
       case IntegerType => PrefixComparators.INTEGER
       case LongType => PrefixComparators.LONG
+      case FloatType => PrefixComparators.FLOAT
+      case DoubleType => PrefixComparators.DOUBLE
       case _ => NoOpPrefixComparator
     }
   }
@@ -47,6 +49,10 @@ object SortPrefixUtils {
       case IntegerType => (row: InternalRow) =>
         PrefixComparators.INTEGER.computePrefix(sortOrder.child.eval(row).asInstanceOf[Int])
       case LongType => (row: InternalRow) => sortOrder.child.eval(row).asInstanceOf[Long]
+      case FloatType => (row: InternalRow) =>
+        PrefixComparators.FLOAT.computePrefix(sortOrder.child.eval(row).asInstanceOf[Float])
+      case DoubleType => (row: InternalRow) =>
+        PrefixComparators.DOUBLE.computePrefix(sortOrder.child.eval(row).asInstanceOf[Double])
       case _ => (row: InternalRow) => 0L
     }
   }
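
As a quick illustration of how the new comparators are meant to be used (this is not part of the patch itself): computePrefix() packs a float or double sort key into the 64-bit prefix slot, and compare() decodes two prefixes and orders them as primitive values. The standalone sketch below checks that, for a handful of ordinary (non-NaN) inputs, ordering by prefix agrees with ordering by value; the object name PrefixComparatorsCheck and the sample values are made up for the example.

import org.apache.spark.util.collection.unsafe.sort.PrefixComparators

// Standalone sketch: for non-NaN floats and doubles, ordering two values by their
// 64-bit sort prefixes should agree with ordering the values themselves.
object PrefixComparatorsCheck {
  def main(args: Array[String]): Unit = {
    val floats = Seq(Float.MinValue, -10.5f, -0.0f, 0.0f, 1.25f, Float.MaxValue)
    for (a <- floats; b <- floats) {
      val byPrefix = PrefixComparators.FLOAT.compare(
        PrefixComparators.FLOAT.computePrefix(a),
        PrefixComparators.FLOAT.computePrefix(b))
      // Expected result using the same primitive <'/'>' ordering the comparator decodes to.
      val byValue = if (a < b) -1 else if (a > b) 1 else 0
      assert(byPrefix == byValue, s"float prefix order disagrees for $a vs $b")
    }

    val doubles = Seq(Double.MinValue, -3.5e100, 0.0, 2.5, Double.MaxValue)
    for (a <- doubles; b <- doubles) {
      val byPrefix = PrefixComparators.DOUBLE.compare(
        PrefixComparators.DOUBLE.computePrefix(a),
        PrefixComparators.DOUBLE.computePrefix(b))
      val byValue = if (a < b) -1 else if (a > b) 1 else 0
      assert(byPrefix == byValue, s"double prefix order disagrees for $a vs $b")
    }
  }
}

Note that the comparators in this patch do not special-case NaN: compare() returns 0 whenever either decoded value is NaN, and negative zero ties with positive zero, which is why the check above uses the primitive ordering rather than java.lang.Float.compare / Double.compare.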
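
For context on where these prefixes fit: in the unsafe sort path, the sorter first compares the 64-bit prefixes of two records and only falls back to comparing the full records when the prefixes tie, so a cheap, order-preserving prefix per type is what makes the fast path useful. A minimal conceptual sketch of that two-step comparison (simplified; PrefixSortSketch and recordOrdering are illustrative names, not the actual sorter API):

import org.apache.spark.util.collection.unsafe.sort.PrefixComparator

// Conceptual sketch (not the actual sorter code): a prefix comparator gives the sorter a
// cheap first-pass comparison on 64-bit prefixes; only ties fall back to the full records.
object PrefixSortSketch {
  def compareKeys[T](prefixA: Long, recordA: T, prefixB: Long, recordB: T,
      prefixComparator: PrefixComparator, recordOrdering: Ordering[T]): Int = {
    val byPrefix = prefixComparator.compare(prefixA, prefixB)
    if (byPrefix != 0) byPrefix else recordOrdering.compare(recordA, recordB)
  }
}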